Source code for neuralogic.core.constructs.relation

from typing import Iterable, Union

import numpy as np

from neuralogic.core.constructs.function.enum import Combination, Transformation
from neuralogic.core.constructs.function import FContainer
from neuralogic.core.constructs.predicate import Predicate
from neuralogic.core.constructs import rule, factories
from neuralogic.core.constructs.function.function import TransformationFunction, CombinationFunction


class BaseRelation:
    """A relation literal: a predicate with terms, optionally negated or wrapped in a function.

    New relations are derived through operators:

    * calling a relation assigns terms;
    * indexing it attaches a weight (yielding a ``WeightedRelation``);
    * ``~`` toggles negation;
    * ``-`` and ``.T`` wrap it in the REVERSE / TRANSP transformation;
    * ``<=`` builds a rule;
    * ``&`` builds a rule body;
    * ``+``, ``*`` and ``@`` build function containers.
    """

    __slots__ = "predicate", "function", "terms", "negated"

    def __init__(
        self,
        predicate: Predicate,
        terms=None,
        function: Union[TransformationFunction, CombinationFunction] = None,
        negated: bool = False,
    ):
        """Create a relation.

        :param predicate: The predicate of the relation.
        :param terms: A single term or an iterable of terms. A string counts as a single term,
            ``None`` entries are skipped and list entries are flattened one level.
        :param function: Optional function that wraps the relation when rendered.
        :param negated: Whether the relation is negated.
        """
        self.predicate = predicate
        self.function = function
        self.negated = negated
        self.terms = []

        # A lone non-iterable term (or a string) is treated as a one-element term list.
        if not isinstance(terms, Iterable) or isinstance(terms, str):
            terms = [terms]

        for term in terms:
            if term is None:
                continue
            if isinstance(term, list):
                self.terms.extend(term)
            else:
                self.terms.append(term)

    def __neg__(self) -> "BaseRelation":
        """Return a copy of this relation wrapped in the REVERSE transformation."""
        return self.attach_activation_function(Transformation.REVERSE)

    def __invert__(self) -> "BaseRelation":
        """Return a copy of this relation with its negation flag toggled.

        :raises ValueError: If a function is attached to the relation.
        """
        if self.function is not None:
            raise ValueError(f"Cannot negate relation {self} with attached function.")

        # NOTE(review): ``True`` is passed as the third positional argument of ``Predicate``,
        # while ``__call__``/``__truediv__`` pass ``self.predicate.hidden`` in the matching
        # slot of ``get_predicate``. Confirm Predicate's signature and that this is intended.
        predicate = Predicate(self.predicate.name, self.predicate.arity, True, self.predicate.special)
        return BaseRelation(predicate, self.terms, self.function, not self.negated)

    @property
    def T(self) -> "BaseRelation":
        """Return a copy of this relation wrapped in the TRANSP (transpose) transformation."""
        return self.attach_activation_function(Transformation.TRANSP)

    def attach_activation_function(self, function: Union[TransformationFunction, CombinationFunction]):
        """Return a shallow copy of this relation with ``function`` attached.

        :raises ValueError: If the relation is negated.
        """
        if self.negated:
            raise ValueError(f"Cannot attach function to negated relation {self}")

        relation = self.__copy__()
        relation.function = function
        return relation

    def __truediv__(self, other):
        """Support ``relation / arity``: return this zero-arity relation's predicate with arity ``other``.

        The predicate is obtained from the predicate factory with the same name, hidden and
        special flags.

        :raises NotImplementedError: If ``other`` is not a non-negative ``int`` or the relation's
            predicate does not have arity 0.
        """
        if not isinstance(other, int) or self.predicate.arity != 0 or other < 0:
            raise NotImplementedError

        name, hidden, special = self.predicate.name, self.predicate.hidden, self.predicate.special
        return factories.AtomFactory.Predicate.get_predicate(name, other, hidden, special)

    def __call__(self, *args) -> "BaseRelation":
        """Return a new relation with the given terms assigned.

        Terms can be passed as separate arguments or as a single non-string iterable. The
        predicate is fetched from the predicate factory with the arity equal to the term count.
        The attached function and the negation flag are carried over.

        :raises ValueError: If this relation already has terms assigned.
        """
        if self.terms:
            # Previously a bare, message-less ``Exception``; ValueError subclasses Exception,
            # so existing ``except Exception`` handlers keep working.
            raise ValueError(f"Relation {self} already has terms assigned.")

        if len(args) == 1 and isinstance(args[0], Iterable) and not isinstance(args[0], str):
            terms = list(args[0])
        else:
            terms = list(args)

        arity = len(terms)
        name, hidden, special = self.predicate.name, self.predicate.hidden, self.predicate.special
        predicate = factories.AtomFactory.Predicate.get_predicate(name, arity, hidden, special)

        return BaseRelation(predicate, terms, self.function, self.negated)

    def __getitem__(self, item) -> "WeightedRelation":
        """Return a weighted version of this relation, using ``item`` as the weight specification.

        :raises ValueError: If the predicate is hidden or special (no learnable parameters
            allowed), or if the relation is negated. ``WeightedRelation`` is always built
            unnegated, so the negation would otherwise be silently dropped.
        """
        if self.predicate.hidden or self.predicate.special:
            raise ValueError(f"Special/Hidden relation {self} cannot have learnable parameters.")
        if self.negated:
            raise ValueError(f"Negated relation {self} cannot have learnable parameters.")

        return WeightedRelation(item, self.predicate, False, self.terms, self.function)

    def __le__(self, other: Union[Iterable["BaseRelation"], "BaseRelation"]) -> rule.Rule:
        """Build a rule with this relation as the head and ``other`` as the body."""
        return rule.Rule(self, other)

    def to_str(self, end=False) -> str:
        """Render the relation as text, e.g. ``!name(X, Y)`` or the function-wrapped literal.

        :param end: Whether to append a terminating ``.``.
        """
        end = "." if end else ""

        literal = self.predicate.to_str()
        if self.terms:
            literal = f"{literal}({', '.join(str(term) for term in self.terms)})"

        if self.negated:
            return f"!{literal}{end}"
        if self.function:
            return f"{self.function.wrap(literal)}{end}"
        return f"{literal}{end}"

    def __str__(self) -> str:
        return self.to_str(True)

    def __repr__(self) -> str:
        return self.__str__()

    def __copy__(self):
        """Return a shallow copy; the ``terms`` list is shared with the original."""
        relation = BaseRelation.__new__(BaseRelation)
        relation.function = self.function
        relation.terms = self.terms
        relation.predicate = self.predicate
        relation.negated = self.negated
        return relation

    def __and__(self, other) -> rule.RuleBody:
        """Combine two relations into a rule body.

        :raises NotImplementedError: If ``other`` is not a ``BaseRelation``.
        """
        if isinstance(other, BaseRelation):
            return rule.RuleBody(self, other)
        raise NotImplementedError

    def __add__(self, other):
        """Combine with ``other`` using ``Combination.SUM``."""
        return FContainer((self, other), Combination.SUM)

    def __mul__(self, other):
        """Combine with ``other`` using ``Combination.ELPRODUCT``."""
        return FContainer((self, other), Combination.ELPRODUCT)

    def __matmul__(self, other):
        """Combine with ``other`` using ``Combination.PRODUCT``."""
        return FContainer((self, other), Combination.PRODUCT)
class WeightedRelation(BaseRelation):
    """A relation carrying a weight specification (a learnable or fixed parameter).

    Weighted relations cannot be modified further. Terms, functions, transposition,
    reversal and negation must all be applied before the weight is attached.
    """

    __slots__ = "weight", "weight_name", "is_fixed"

    def __init__(
        self,
        weight,
        predicate: Predicate,
        fixed=False,
        terms=None,
        function: Union[TransformationFunction, CombinationFunction] = None,
    ):
        """Create a weighted relation.

        :param weight: The weight specification.

            * A slice ``name: value`` names the weight (``weight_name = str(name)``, ``weight = value``).
            * A tuple whose first item is such a slice names a tuple weight.
            * A NumPy array (named or not) is converted to nested lists.
        :param predicate: The predicate of the relation.
        :param fixed: Whether the weight is fixed (rendered as ``<weight>``).
        :param terms: Terms of the relation, handled as in ``BaseRelation``.
        :param function: Optional function wrapping the relation.
        """
        super().__init__(predicate, terms, function, False)
        self.weight = weight
        self.weight_name = None
        self.is_fixed = fixed

        if isinstance(weight, slice):
            self.weight_name = str(weight.start)
            self.weight = weight.stop
        elif isinstance(weight, tuple) and weight and isinstance(weight[0], slice):
            # ``weight and`` guards ``weight[0]`` against an empty tuple (IndexError).
            self.weight_name = str(weight[0].start)
            self.weight = (weight[0].stop, *weight[1:])

        # Inspect the extracted value (not the raw argument) so that named array
        # weights such as ``name: np.array(...)`` are converted as well.
        if isinstance(self.weight, np.ndarray):
            self.weight = self.weight.tolist()

    def fixed(self) -> "WeightedRelation":
        """Return a copy of this relation whose weight is fixed; the weight name is preserved.

        :raises ValueError: If the weight is already fixed.
        """
        if self.is_fixed:
            raise ValueError(f"Weighted relation {self} is already fixed")

        relation = WeightedRelation(self.weight, self.predicate, True, self.terms, self.function)
        # The constructor only derives a name from slice syntax, so carry it over explicitly.
        relation.weight_name = self.weight_name
        return relation

    def to_str(self, end=False):
        """Render as ``weight literal``, ``$name=weight literal`` or ``<weight> literal`` (fixed).

        Tuple weights are rendered as ``{a, b, ...}``.

        :param end: Whether to append a terminating ``.``.
        """
        if isinstance(self.weight, tuple):
            weight = f"{{{', '.join(str(w) for w in self.weight)}}}"
        else:
            weight = str(self.weight)

        if self.weight_name:
            weight = f"${self.weight_name}={weight}"

        if self.is_fixed:
            return f"<{weight}> {super().to_str(end)}"
        return f"{weight} {super().to_str(end)}"

    def __str__(self) -> str:
        return self.to_str(True)

    def __repr__(self) -> str:
        return self.__str__()

    def __call__(self, *args) -> BaseRelation:
        raise NotImplementedError(f"Cannot assign terms to weighted relation {self.predicate}")

    def __getitem__(self, item) -> "WeightedRelation":
        raise NotImplementedError(f"Cannot assign weight to weighted relation {self.predicate}")

    def attach_activation_function(self, function: Union[TransformationFunction, CombinationFunction]):
        raise NotImplementedError(
            f"Cannot attach a function to weighted relation {self}. Attach the function before adding weights."
        )

    @property
    def T(self) -> "WeightedRelation":
        raise NotImplementedError(
            f"Cannot transpose weighted relation {self} Apply the transposition before adding weights."
        )

    def __invert__(self) -> "WeightedRelation":
        raise NotImplementedError(f"Weighted relations ({self}) cannot be negated.")

    def __neg__(self) -> "WeightedRelation":
        raise NotImplementedError(
            f"Cannot negate weighted relation {self} Apply the reverse function before adding weights."
        )

    def __copy__(self):
        """Return a shallow copy with every slot populated; ``terms`` is shared with the original."""
        relation = WeightedRelation.__new__(WeightedRelation)
        relation.predicate = self.predicate
        relation.function = self.function
        relation.terms = self.terms
        relation.weight = self.weight
        # Without this, the unset slot makes ``to_str`` on the copy raise AttributeError.
        relation.weight_name = self.weight_name
        relation.is_fixed = self.is_fixed
        relation.negated = self.negated
        return relation