from typing import List, Set, Tuple

from neuralogic.core.constructs.relation import BaseRelation
from neuralogic.core.constructs.rule import Rule
from neuralogic.db.converter import Converter
from import helpers

# Template of a PostgreSQL ``CREATE OR REPLACE FUNCTION`` statement, filled in via ``str.format``
# with the keys name, params, return_type, body, language and volatility.
# ``{body}`` is placed between the ``$$`` dollar-quote delimiters.
FUNCTION_TEMPLATE = """
CREATE OR REPLACE FUNCTION {name}({params}) RETURNS {return_type} AS $$
{body}
$$ LANGUAGE {language} {volatility};
"""

    "mul": "neuralogic_std.mul",
    "sum": "neuralogic_std.sum",
    "tanh": "neuralogic_std.tanh",
    "sigmoid": "neuralogic_std.sigmoid",
    "relu": "neuralogic_std.relu",
    "identity": "",
    "avg": "AVG",
    "max": "MAX",
    "min": "MIN",

class PostgresConverter(Converter):
    """Translate NeuraLogic facts and rules into PostgreSQL function definitions.

    Every statement is rendered through ``FUNCTION_TEMPLATE``; SQL names of activation,
    aggregation and helper functions are looked up in ``FUNCTION_MAP``.
    """

    @staticmethod
    def _quote_literal(value) -> str:
        """Render ``value`` as a single-quoted SQL string literal, doubling embedded quotes."""
        return "'" + str(value).replace("'", "''") + "'"

    @staticmethod
    def get_function(
        name: str,
        params: List[str],
        return_type: List[str],
        body: str,
        language: str = "SQL",
        volatility: str = "STABLE",
    ) -> str:
        """Render a ``CREATE OR REPLACE FUNCTION`` statement returning a table.

        :param name: Fully qualified SQL function name.
        :param params: Parameter declarations, e.g. ``["p0 TEXT"]``.
        :param return_type: Column declarations of the returned table, e.g. ``["value NUMERIC"]``.
        :param body: Function body placed between the ``$$`` delimiters.
        :param language: Value of the ``LANGUAGE`` clause.
        :param volatility: Volatility category (``STABLE``, ``IMMUTABLE``, ``VOLATILE``).
        :return: The SQL statement.
        """
        return FUNCTION_TEMPLATE.format(
            name=name,
            params=",".join(params),
            return_type=f"Table({','.join(return_type)})",
            body=body,
            language=language,
            volatility=volatility,
        )

    @staticmethod
    def get_empty_function(
        name: str,
        params: List[str],
        return_type: List[str],
        language: str = "SQL",
        volatility: str = "STABLE",
    ) -> str:
        """Render a placeholder function with the given signature and a constant body.

        The body selects ``1`` for the first returned column and ``'1'`` for each further one.

        :raises NotImplementedError: If ``return_type`` is empty.
        """
        if not return_type:
            raise NotImplementedError
        select = ["1", *("'1'" for _ in return_type[1:])]
        return PostgresConverter.get_function(
            name, params, return_type, f"SELECT {','.join(select)}", language, volatility
        )

    def get_helpers(self, functions: Set[str]) -> str:
        """Return SQL creating the ``neuralogic``/``neuralogic_std`` schemas and the used helper functions.

        :param functions: Helper names to include; ``mul`` and ``sum`` are always added because
            generated rule bodies reference them for weights and for combining body values.
        """
        used_functions = set(functions)
        used_functions.update(("mul", "sum"))

        function_sources = [
            "CREATE SCHEMA IF NOT EXISTS neuralogic_std;",
            "CREATE SCHEMA IF NOT EXISTS neuralogic;",
        ]
        # Sorted so the emitted SQL is deterministic (set iteration order varies between runs).
        for fun in sorted(used_functions):
            helper_fun = helpers[fun]
            # A ``None`` entry means no helper SQL is needed for that name.
            if helper_fun is not None:
                function_sources.append(helper_fun.strip())
        return "\n".join(function_sources)

    def get_fact_sql_function(self, relation: BaseRelation, index: int, weight_indices: List[int], weights) -> str:
        """Generate a SQL function for a ground fact.

        The function returns a single row (value, t0, t1, ...) when every parameter either equals
        the corresponding term or is NULL. The value is 1 for an unweighted fact, otherwise
        ``weights[weight_indices[0]]``.
        """
        n_terms = len(relation.terms)
        value = 1 if weight_indices[0] is None else weights[weight_indices[0]]
        parameters = [f"p{i} TEXT" for i in range(n_terms)]
        return_type = ["value NUMERIC", *(f"t{i} TEXT" for i in range(n_terms))]
        literals = [self._quote_literal(term) for term in relation.terms]

        select = [f"{value} as value", *(f"{literal} as t{i}" for i, literal in enumerate(literals))]
        name = f"neuralogic._{relation.predicate.name}_{relation.predicate.arity}_{index}"
        body = f"SELECT {','.join(select)}"

        # A NULL parameter acts as a wildcard for that position.
        condition = " AND ".join(f"(p{i} = {literal} OR p{i} IS NULL)" for i, literal in enumerate(literals))
        if condition:
            body = f"{body} WHERE {condition}"
        return self.get_function(name, parameters, return_type, body, volatility="IMMUTABLE")

    def get_relation_interface_sql_function(self, relation: str, arity: int) -> Tuple[str, str]:
        """Return the SQL function that should be used by the end users.

        :return: A pair ``(placeholder, implementation)`` for ``neuralogic.<relation>``. The
            placeholder has the final signature but a constant body (presumably so that other
            functions referencing it can be created first — confirm). The PL/pgSQL implementation
            materialises ``neuralogic._<relation>_<arity>(NULL, ...)`` into a temporary table
            (``ON COMMIT DROP``; reused within the transaction thanks to ``IF NOT EXISTS``) and
            filters it by the arguments, a NULL argument matching any value.
        """
        function_parameters = [f"p{i}" for i in range(arity)]
        params = [f"{param} TEXT" for param in function_parameters]
        name = f"neuralogic.{relation}"
        return_type = ["value NUMERIC", *(f"t{i} TEXT" for i in range(arity))]
        function_name = f"neuralogic._{relation}_{arity}({','.join('NULL' for _ in range(arity))})"
        tmp_table_name = f"__neuralogic___tmp_{relation}_{arity}"

        # Exact match with a NULL fallback instead of ``LIKE COALESCE(p, '%')``: '_' and '%' inside
        # argument values must not act as wildcards. NULL columns are still excluded, as with LIKE.
        conditions = " AND ".join(f"out.t{i} = COALESCE(p{i}, out.t{i})" for i in range(arity))
        return_select = f"RETURN QUERY SELECT * FROM {tmp_table_name} as out"
        if conditions:
            return_select = f"{return_select} WHERE {conditions}"

        body = (
            f"BEGIN CREATE TEMPORARY TABLE IF NOT EXISTS {tmp_table_name} ON COMMIT DROP AS "
            f"SELECT * FROM {function_name}; {return_select}; RETURN; END;"
        )
        return (
            self.get_empty_function(name, params, return_type, "SQL", "VOLATILE"),
            self.get_function(name, params, return_type, body, "plpgsql", "VOLATILE"),
        )

    def get_rule_aggregation_function(
        self, name: str, arity: int, number_of_rules: int, activation: str, aggregation: str, is_fact: bool = False
    ) -> str:
        """Generate a SQL function that aggregates rule functions (similar to an aggregation neuron).

        Unions the rows of ``neuralogic._<name>_<arity>_<i>`` for every rule index ``i``, sums the
        values per term tuple and, unless ``is_fact``, wraps the sum in the ``activation`` function.
        ``aggregation`` is accepted for interface compatibility but not used here.
        """
        function_parameters = [f"p{i}" for i in range(arity)]
        call_args = ",".join(function_parameters)

        inner_select = []
        for index in range(number_of_rules):
            columns = [f"s{index}.value as value", *(f"s{index}.t{i}" for i in range(arity))]
            inner_select.append(
                f"SELECT {', '.join(columns)} FROM neuralogic._{name}_{arity}_{index}({call_args}) as s{index}"
            )

        value = "SUM(out.value)" if is_fact else f"{FUNCTION_MAP[activation]}(SUM(out.value))"
        select = [f"{value} as value", *(f"out.t{i}" for i in range(arity))]
        select_from = " UNION ALL ".join(inner_select)
        return_type = ["value NUMERIC", *(f"t{i} TEXT" for i in range(arity))]

        function_name = f"neuralogic._{name}_{arity}"
        body = f"SELECT {', '.join(select)} FROM ({select_from}) AS out"
        if arity:
            body += f" GROUP BY {', '.join(f'out.t{i}' for i in range(arity))}"
        return self.get_function(function_name, [f"{param} TEXT" for param in function_parameters], return_type, body)

    def get_rule_sql_function(
        self, rule: Rule, index: int, activation: str, aggregation: str, weight_indices: List[int], weights
    ) -> str:
        """Return the SQL function of one rule.

        The inner query joins the functions (or mapped tables) of the body relations and combines
        their (optionally weighted) values into one ``value`` column. The outer query applies
        ``activation`` per row, then ``aggregation``, grouped by the head variables, and multiplies
        by the rule weight when ``weight_indices[0]`` is not None. ``weight_indices[1:]`` hold the
        per-body-relation weights (None = unweighted).
        """
        head = rule.head
        function_parameters = [f"p{i}" for i in range(head.predicate.arity)]

        aggregated = f"{FUNCTION_MAP[aggregation]}({FUNCTION_MAP[activation]}(out.value))"
        if weight_indices[0] is None:
            select = [f"{aggregated} as value"]
        else:
            select = [f"{FUNCTION_MAP['mul']}({weights[weight_indices[0]]}, {aggregated}) as value"]

        # head variable -> output column alias (t<i>)
        vars_mapping = {}
        # WHERE conditions of the inner (join) query
        where = []
        # variable -> SQL expression currently bound to it (head parameter p<i> or a body column)
        vars_body_mapping = {}
        # variable -> body column it is bound to, used to build JOIN ... ON conditions
        join_vars_mapping = {}
        # ordered (not a set) so the generated SQL is deterministic; duplicates are prevented below
        inner_select: List[str] = []
        inner_selected_terms = set()
        inner_value_select = []
        from_clause: List[str] = []

        for term_idx, term in enumerate(head.terms):
            if Converter._is_var(term):
                term_name = f"t{term_idx}"
                vars_mapping[str(term)] = term_name
                vars_body_mapping[str(term)] = f"p{term_idx}"
                select.append(f"out.{term_name}")
            else:
                select.append(f"{self._quote_literal(term)} as t{term_idx}")

        for t_index, (relation, weight_id) in enumerate(zip(rule.body, weight_indices[1:])):
            alias = f"s{t_index}"
            join_on = []
            relation_mapping = self.table_mappings.get(str(relation.predicate).replace("*", "_"), None)

            # Hidden relations only take part in the join; they contribute no value.
            if not relation.predicate.hidden:
                if relation_mapping is None:
                    value = "value"
                else:
                    value_column = relation_mapping.value_column
                    value = None if value_column is None else f"{value_column}::NUMERIC"
                # A mapped table without a value column contributes a constant 1.
                selected_value = "1" if value is None else f"{alias}.{value}"
                if relation.function is not None:
                    function = str(relation.function).lower()
                    selected_value = f"{FUNCTION_MAP[function]}({selected_value})"
                    self._used_functions.add(function)
                if weight_id is None:
                    inner_value_select.append(selected_value)
                else:
                    inner_value_select.append(f"{FUNCTION_MAP['mul']}({weights[weight_id]}, {selected_value})")
                # Fold pairwise so at most one combined value expression is kept.
                if len(inner_value_select) == 2:
                    inner_value_select = [f"{FUNCTION_MAP['sum']}({', '.join(inner_value_select)})"]

            from_function_parameters = []
            for term_idx, term in enumerate(relation.terms):
                if relation_mapping is None:
                    field = f"t{term_idx}"
                else:
                    field = f"{relation_mapping.term_columns[term_idx]}::TEXT"
                column = f"{alias}.{field}"
                term_key = str(term)

                if not self._is_var(term):
                    # Constants are passed to the relation function, or filtered on for mapped tables.
                    if relation_mapping is None:
                        from_function_parameters.append(self._quote_literal(term))
                    else:
                        where.append(f"{column} = {self._quote_literal(term)}")
                    continue

                from_function_parameters.append("NULL")
                if term_key in vars_mapping and term_key not in inner_selected_terms:
                    inner_select.append(f"{column} AS {vars_mapping[term_key]}")
                    inner_selected_terms.add(term_key)

                if term_key in vars_body_mapping:
                    bound = vars_body_mapping[term_key]
                    # Exact match with a NULL fallback (not LIKE) so '_'/'%' in values are not wildcards.
                    where.append(f"{column} = COALESCE({bound}, {column})")
                    if term_key in join_vars_mapping:
                        join_on.append(f"{bound} = {column}")
                    else:
                        vars_body_mapping[term_key] = column
                        join_vars_mapping[term_key] = column
                    continue

                join_vars_mapping[term_key] = column
                vars_body_mapping[term_key] = column

            if relation_mapping is None:
                params = ",".join(from_function_parameters)
                source = f"neuralogic.{relation.predicate.name}({params})"
            else:
                source = relation_mapping.table_name

            if not from_clause:
                from_clause.append(f"{source} AS {alias}")
            else:
                on_condition = " AND ".join(join_on) if join_on else "1 = 1"
                from_clause.append(f"{source} AS {alias} ON {on_condition}")

        if len(inner_value_select) == 1:
            inner_select.append(f"{inner_value_select[0]} as value")

        inner_query = f"SELECT {', '.join(inner_select)} FROM {' INNER JOIN '.join(from_clause)}"
        if where:
            inner_query += f" WHERE {' AND '.join(where)}"

        return_type = ["value NUMERIC", *(f"t{i} TEXT" for i in range(len(head.terms)))]
        name = f"neuralogic._{head.predicate.name}_{head.predicate.arity}_{index}"
        body = f"SELECT {', '.join(select)} FROM ({inner_query}) AS out"
        if vars_mapping:
            body += f" GROUP BY {', '.join('out.' + term_alias for term_alias in vars_mapping.values())}"
        return self.get_function(name, [f"{param} TEXT" for param in function_parameters], return_type, body)