from typing import Union, List, Optional, Set, Dict, Any, Callable, Iterable
import jpype
from neuralogic import is_initialized, initialize
from neuralogic.core.builder import Builder, DatasetBuilder
from neuralogic.core.constructs.relation import BaseRelation, WeightedRelation
from neuralogic.core.constructs.rule import Rule
from neuralogic.core.constructs.predicate import PredicateMetadata
from neuralogic.core.constructs.java_objects import JavaFactory
from neuralogic.core.settings import SettingsProxy, Settings
from neuralogic.nn.module.module import Module
from neuralogic.utils.visualize import draw_model
TemplateEntries = Union[BaseRelation, WeightedRelation, Rule]
[docs]
class Template:
def __init__(
self,
*,
template_file: Optional[str] = None,
):
self.template: List[TemplateEntries] = []
self.template_file = template_file
self.hooks: Dict[str, Set] = {}
[docs]
def add_hook(self, relation: Union[BaseRelation, str], callback: Callable[[Any], None]) -> None:
"""Hooks the callable to be called with the relation's value as an argument when the value of
the relation is being calculated.
:param relation:
:param callback:
:return:
"""
name = str(relation)
if isinstance(relation, BaseRelation):
name = name[:-1]
if name not in self.hooks:
self.hooks[name] = {callback}
else:
self.hooks[name].add(callback)
[docs]
def remove_hook(self, relation: Union[BaseRelation, str], callback):
"""Removes the callable from the relation's hooks
:param relation:
:param callback:
:return:
"""
name = str(relation)
if isinstance(relation, BaseRelation):
name = name[:-1]
if name not in self.hooks:
return
self.hooks[name].discard(callback)
[docs]
def add_rule(self, rule) -> None:
"""Adds one rule to the template
:param rule:
:return:
"""
self.add_rules([rule])
[docs]
def add_rules(self, rules: List):
"""Adds multiple rules to the template
:param rules:
:return:
"""
self.template.extend(rules)
[docs]
def add_module(self, module: Module):
"""Expands the module into rules and adds them into the template
:param module:
:return:
"""
self.add_rules(module())
[docs]
def get_parsed_template(self, settings: SettingsProxy, java_factory: JavaFactory):
if not is_initialized():
initialize()
if self.template_file is not None:
return Builder(settings).build_template_from_file(settings, self.template_file)
predicate_metadata = []
weighted_rules = []
valued_facts = []
for rule in self.template:
if isinstance(rule, PredicateMetadata):
predicate_metadata.append(java_factory.get_predicate_metadata_pair(rule))
elif isinstance(rule, Rule):
weighted_rules.append(java_factory.get_rule(rule))
elif isinstance(rule, (WeightedRelation, BaseRelation)):
valued_facts.append(java_factory.get_valued_fact(rule, java_factory.get_variable_factory()))
parsed_template = jpype.JClass("cz.cvut.fel.ida.logic.constructs.template.types.ParsedTemplate")
template = parsed_template(jpype.java.util.ArrayList(weighted_rules), jpype.java.util.ArrayList(valued_facts))
template.weightsMetadata = (jpype.java.util.List) @ jpype.java.util.ArrayList([])
template.predicatesMetadata = jpype.java.util.ArrayList(predicate_metadata)
metadata_processor = jpype.JClass("cz.cvut.fel.ida.logic.constructs.template.transforming.MetadataProcessor")
metadata_processor = metadata_processor(settings.settings)
metadata_processor.processMetadata(template)
template.inferTemplateFacts()
return template
[docs]
def remove_duplicates(self):
"""Remove duplicates from the template"""
entries = set()
deduplicated_template: List[TemplateEntries] = []
for entry in self.template:
entry_str = str(entry)
if entry_str in entries:
continue
entries.add(entry_str)
deduplicated_template.append(entry)
self.template = deduplicated_template
[docs]
def build(self, settings: Settings):
from neuralogic.nn import get_neuralogic_layer
java_factory = JavaFactory()
settings_proxy = settings.create_proxy()
parsed_template = self.get_parsed_template(settings_proxy, java_factory)
model = Builder(settings_proxy).build_model(parsed_template, settings_proxy)
return get_neuralogic_layer()(model, DatasetBuilder(parsed_template, java_factory), self, settings_proxy)
[docs]
def draw(
self,
filename: Optional[str] = None,
show=True,
img_type="png",
value_detail: int = 0,
graphviz_path: Optional[str] = None,
model=None,
*args,
**kwargs,
):
if model is None:
model = self.build(Settings())
return draw_model(model, filename, show, img_type, value_detail, graphviz_path, *args, **kwargs)
def __str__(self) -> str:
return "\n".join(str(r) for r in self.template)
def __repr__(self) -> str:
return self.__str__()
def __iadd__(self, other) -> "Template":
if isinstance(other, Iterable):
self.template.extend(other)
elif isinstance(other, Module):
self.template.extend(other())
else:
self.template.append(other)
return self
def __getitem__(self, item) -> TemplateEntries:
return self.template[item]
def __delitem__(self, key):
self.template.pop(key)
def __setitem__(self, key, value):
if isinstance(value, (Iterable, Module)):
raise NotImplementedError
self.template[key] = value
def __copy__(self) -> "Template":
temp = Template()
temp.template_file = self.template_file
temp.template = self.template
return temp
[docs]
def clone(self) -> "Template":
return self.__copy__()