Source code for neuralogic.db.converter

import dataclasses
from collections import defaultdict
from typing import List, Dict, Tuple, Set, Optional

from neuralogic.core.constructs.relation import BaseRelation, WeightedRelation

from neuralogic.core import Aggregation, Settings, Metadata
from neuralogic.core.constructs.predicate import PredicateMetadata
from neuralogic.core.constructs.rule import Rule


@dataclasses.dataclass
class TableMapping:
    """Describe how a template relation is backed by a database table.

    A mapping is looked up by ``relation_name`` together with the number of
    entries in ``term_columns``, which therefore acts as the relation's arity.

    :param relation_name: Name of the relation in the template.
    :param table_name: Name of the database table backing the relation.
    :param term_columns: Table columns mapped to the relation's terms, in term order.
    :param value_column: Optional column for the relation's value; ``None`` when not given
        (presumably the relation then has no explicit value column — confirm in subclasses).
    """

    relation_name: str
    table_name: str
    term_columns: List[str]
    value_column: Optional[str] = dataclasses.field(default=None)
class Converter:
    """Base class translating a model's template into SQL source.

    This class walks the template, groups relations by ``name/arity``, resolves the
    activation and aggregation function names for every rule and predicate, and
    concatenates the SQL fragments produced by the ``get_*`` hooks. Subclasses must
    implement those hooks; here they raise ``NotImplementedError``.
    """

    def __init__(self, model, table_mappings: List[TableMapping], settings: Settings):
        """
        :param model: Model whose ``source_template`` and ``state_dict()["weights"]`` are converted.
        :param table_mappings: Relations backed by database tables; indexed internally by
            ``"<relation_name>/<number of term columns>"``.
        :param settings: Settings supplying the default rule and relation transformations.
        """
        self.table_mappings: Dict[str, TableMapping] = {
            f"{mapping.relation_name}/{len(mapping.term_columns)}": mapping for mapping in table_mappings
        }
        self.model = model
        self.settings = settings

        # Activation/aggregation names referenced while converting; handed to ``get_helpers``.
        self._used_functions: Set[str] = set()
        # Both are filled in lazily by ``_convert``.
        self.sql_source: Optional[str] = None
        self.std_functions: Optional[str] = None

    def _process_template_entries(self) -> Tuple[defaultdict, Dict[str, Metadata]]:
        """Group template entries by predicate and assign sequential weight indices.

        Weights are numbered in template order; for a rule the head is numbered first,
        then each body relation in order. A relation that is not a ``WeightedRelation``,
        or whose ``weight`` is ``None``, gets ``None`` instead of an index.

        :returns: ``(batched_relations, predicates_metadata)`` where
            ``batched_relations[name][arity]`` is a list of ``(entry, weight_indices)`` tuples
            and ``predicates_metadata`` maps ``str(predicate)`` to that predicate's metadata.
        :raises NotImplementedError: If the template contains an unsupported entry type.
        """
        next_weight_index = 0

        def weight_index_of(relation) -> Optional[int]:
            # Consume the next sequential index only for relations that actually carry a weight.
            nonlocal next_weight_index
            if isinstance(relation, WeightedRelation) and relation.weight is not None:
                index = next_weight_index
                next_weight_index += 1
                return index
            return None

        batched_relations: defaultdict[str, defaultdict[int, List]] = defaultdict(lambda: defaultdict(list))
        predicates_metadata: Dict[str, Metadata] = {}

        for entry in self.model.source_template:
            if isinstance(entry, Rule):
                # Head first, then body relations, matching the numbering documented above.
                weight_indices: List[Optional[int]] = [
                    weight_index_of(relation) for relation in (entry.head, *entry.body)
                ]
                predicate = entry.head.predicate
            elif isinstance(entry, BaseRelation):
                weight_indices = [weight_index_of(entry)]
                predicate = entry.predicate
            elif isinstance(entry, PredicateMetadata):
                predicates_metadata[str(entry.predicate)] = entry.metadata
                continue
            else:
                raise NotImplementedError("Template can contain only relations or predicate metadata!")

            batched_relations[predicate.name][predicate.arity].append((entry, weight_indices))

        return batched_relations, predicates_metadata

    @staticmethod
    def _resolve_functions(
        metadata: Optional[Metadata], default_activation: str, default_aggregation: str
    ) -> Tuple[str, str]:
        """Return lower-cased ``(activation, aggregation)`` names, preferring those set on ``metadata``.

        :param metadata: Rule or predicate metadata, or ``None`` to fall back to both defaults.
        :param default_activation: Activation name used when ``metadata.transformation`` is unset.
        :param default_aggregation: Aggregation name used when ``metadata.aggregation`` is unset.
        """
        activation, aggregation = default_activation, default_aggregation
        if metadata is not None:
            if metadata.transformation is not None:
                activation = str(metadata.transformation).lower()
            if metadata.aggregation is not None:
                aggregation = str(metadata.aggregation).lower()
        return activation, aggregation

    def _convert(self):
        """Generate the SQL source for the whole template and store it on the instance.

        Sets ``self.std_functions`` to the output of ``get_helpers`` and ``self.sql_source``
        to the newline-joined SQL: all relation interface headers first, followed by the
        per-rule/fact functions, aggregation functions and interface bodies in template order.

        :raises ValueError: If a relation defined in the template is also mapped to a table.
        """
        weights = self.model.state_dict()["weights"]
        rule_default_activation = str(self.settings.rule_transformation).lower()
        relation_default_activation = str(self.settings.relation_transformation).lower()
        default_aggregation = str(Aggregation.AVG).lower()

        batched_relations, predicates_metadata = self._process_template_entries()

        sql_source_headers = []
        sql_source = []

        for name, arities in batched_relations.items():
            for arity, relations_by_arity in arities.items():
                key = f"{name}/{arity}"
                if key in self.table_mappings:
                    # A relation must come either from the template or from a table, not both.
                    raise ValueError(
                        f"Relation '{key}' is defined in the template and is also mapped to the "
                        f"table '{self.table_mappings[key].table_name}'"
                    )

                is_fact = False
                for index, (relation, weight_indices) in enumerate(relations_by_arity):
                    if isinstance(relation, Rule):
                        act, agg = self._resolve_functions(
                            relation.metadata, rule_default_activation, default_aggregation
                        )
                        self._used_functions.update((act, agg))
                        sql_func = self.get_rule_sql_function(relation, index, act, agg, weight_indices, weights)
                    else:
                        is_fact = True
                        sql_func = self.get_fact_sql_function(relation, index, weight_indices, weights)
                    sql_source.append(sql_func)

                # Per-predicate function combining all rules/facts of this name/arity.
                activation, aggregation = self._resolve_functions(
                    predicates_metadata.get(key, None), relation_default_activation, default_aggregation
                )
                self._used_functions.update((activation, aggregation))

                sql_source.append(
                    self.get_rule_aggregation_function(
                        name,
                        arity,
                        len(relations_by_arity),
                        activation,
                        aggregation,
                        is_fact,
                    )
                )

                interface = self.get_relation_interface_sql_function(name, arity)
                sql_source_headers.append(interface[0])
                sql_source.append(interface[1])

        self.std_functions = self.get_helpers(self._used_functions)
        sql_source_headers.extend(sql_source)
        self.sql_source = "\n".join(sql_source_headers)

    def get_relation_interface_sql_function(self, relation: str, arity: int) -> Tuple[str, str]:
        """Return ``(header, body)`` SQL for the public interface of ``relation/arity``.

        ``_convert`` places every header before all other generated SQL.
        """
        raise NotImplementedError

    def get_rule_sql_function(
        self,
        rule: Rule,
        index: int,
        activation: str,
        aggregation: str,
        weight_indices: List[Optional[int]],
        weights,
    ) -> str:
        """Return SQL for the ``index``-th entry (a rule) of its ``name/arity`` group.

        :param weight_indices: Weight index of the head followed by one per body relation;
            ``None`` for unweighted relations.
        :param weights: ``model.state_dict()["weights"]`` (presumably indexable by weight index — confirm).
        """
        raise NotImplementedError

    def get_fact_sql_function(
        self, relation: BaseRelation, index: int, weight_indices: List[Optional[int]], weights
    ) -> str:
        """Return SQL for the ``index``-th entry (a fact) of its ``name/arity`` group.

        :param weight_indices: Single-element list holding the fact's weight index or ``None``.
        :param weights: ``model.state_dict()["weights"]``.
        """
        raise NotImplementedError

    def get_rule_aggregation_function(
        self, name: str, arity: int, number_of_rules: int, activation: str, aggregation: str, is_fact: bool = False
    ) -> str:
        """Return SQL combining the ``number_of_rules`` entries of ``name/arity``.

        :param is_fact: ``True`` when at least one entry of the group is a fact rather than a rule.
        """
        raise NotImplementedError

    def get_helpers(self, functions: Set[str]) -> str:
        """Return SQL helper definitions for the given activation/aggregation names."""
        raise NotImplementedError

    def _ensure_converted(self) -> bool:
        """Run the conversion if it has not happened yet; return whether SQL source is available."""
        if self.sql_source is None:
            self._convert()
        return self.sql_source is not None

    def get_std_functions(self) -> str:
        """Return the helper SQL produced by ``get_helpers``, converting lazily on first use."""
        if not self._ensure_converted():
            return ""
        return self.std_functions

    def to_sql(self) -> str:
        """Return the generated SQL source, converting lazily on first use."""
        if not self._ensure_converted():
            return ""
        return self.sql_source

    @staticmethod
    def _is_var(term) -> bool:
        """Return ``True`` if ``term`` is a variable, i.e. its string form starts with an uppercase letter.

        Anything else is treated as a constant, including a term whose string form is empty
        (slicing avoids the ``IndexError`` that ``str(term)[0]`` would raise).
        """
        return str(term)[:1].isupper()