import json
from typing import Dict, Optional, Sized

import jpype

from neuralogic import is_initialized, initialize
from neuralogic.core import BuiltDataset
from neuralogic.core.constructs.java_objects import ValueFactory
from neuralogic.nn.base import AbstractNeuraLogic
from neuralogic.core.settings import SettingsProxy
class NeuraLogic(AbstractNeuraLogic):
    """Evaluator that drives a Java ``PythonTrainingStrategy`` through JPype.

    The instance wraps a Java neural model and forwards training and
    evaluation requests to the Java strategy object, decoding the JSON
    strings the strategy returns.
    """

    def __init__(self, model, dataset_builder, template, settings: SettingsProxy):
        """Create the Java training strategy and reset its parameters.

        Args:
            model: Java neural model handed to the training strategy and
                later queried for its weights.
            dataset_builder: Forwarded unchanged to the base class.
            template: Forwarded unchanged to the base class.
            settings: Settings proxy; supplies the optimizer, the learning
                rate decay and the underlying Java settings object.
        """
        super().__init__(dataset_builder, template, settings)

        # The Java classes below can only be resolved once the backend is up.
        if not is_initialized():
            initialize()

        python_strategy = jpype.JClass(
            "cz.cvut.fel.ida.neural.networks.computation.training.strategies.PythonTrainingStrategy"
        )

        self.do_train = True
        self.need_sync = False

        self.value_factory = ValueFactory()

        optimizer = self.settings.optimizer.initialize()
        lr_decay = self.settings.optimizer.get_lr_decay()

        self.neural_model = model
        self.strategy = python_strategy(settings.settings, model, optimizer, lr_decay)
        self.samples_len = 0
        self.number_format = self.settings.settings_class.superDetailedNumberFormat

        # Proxy implementing the Java hook-handler interface: every hook call
        # coming from Java is JSON-decoded and dispatched to ``run_hook`` of
        # the owning module.
        @jpype.JImplements(
            jpype.JClass("cz.cvut.fel.ida.neural.networks.computation.iteration.actions.PythonHookHandler")
        )
        class HookHandler:
            def __init__(self, module: "NeuraLogic"):
                self.module = module

            @jpype.JOverride
            def handleHook(self, hook, value):
                self.module.run_hook(hook, json.loads(value))

        self.hook_handler = HookHandler(self)
        self.reset_parameters()

    def reset_parameters(self):
        """Ask the Java strategy to reset its parameters."""
        self.strategy.resetParameters()

    def train(self):
        """Switch to training mode for subsequent calls."""
        self.do_train = True

    def test(self):
        """Switch to evaluation mode for subsequent calls."""
        self.do_train = False

    def set_training_samples(self, samples):
        """Store ``samples`` in the strategy and remember how many there are.

        The remembered count is what ``__call__`` reports when it is invoked
        without a dataset.
        """
        self.samples_len = len(samples)
        # NOTE(review): unlike ``__call__``, the samples are not unwrapped via
        # ``.java_sample`` here - presumably callers already pass Java
        # samples; verify against callers.
        self.strategy.setSamples(jpype.java.util.ArrayList(samples))

    def __call__(self, dataset=None, train: Optional[bool] = None, epochs: int = 1):
        """Train on or evaluate the given samples.

        Args:
            dataset: One of
                * ``None`` - learn on the samples previously stored in the
                  strategy (see ``set_training_samples``),
                * a ``BuiltDataset`` - its ``samples`` and ``batch_size``
                  are used,
                * a sized collection of samples, or
                * a single (non-sized) sample.
                Samples are expected to expose ``java_sample``.
            train: When not ``None``, overrides (and persists) the current
                training/evaluation mode.
            epochs: Number of epochs passed to the strategy when learning on
                a collection of samples.

        Returns:
            When learning: a tuple ``(results, number_of_samples)``.
            When evaluating: only ``results``.
            ``results`` is the JSON-decoded string form of whatever the Java
            strategy returned.
        """
        self.hooks_set = bool(self.hooks)

        if isinstance(dataset, BuiltDataset):
            samples = dataset.samples
            batch_size = dataset.batch_size
        else:
            samples = dataset
            batch_size = 1

        if self.hooks_set:
            self.strategy.setHooks(set(self.hooks.keys()), self.hook_handler)

        if train is not None:
            self.do_train = train

        # No dataset: learn on the samples already stored in the strategy.
        # This path always learns, regardless of ``do_train``.
        if samples is None:
            results = self.strategy.learnSamples(epochs, batch_size)
            return json.loads(str(results)), self.samples_len

        # A single sample (anything without ``__len__``).
        if not isinstance(samples, Sized):
            if self.do_train:
                result = self.strategy.learnSample(samples.java_sample)
                return json.loads(str(result)), 1
            return json.loads(str(self.strategy.evaluateSample(samples.java_sample)))

        sample_array = jpype.java.util.ArrayList([sample.java_sample for sample in samples])

        if self.do_train:
            results = self.strategy.learnSamples(sample_array, epochs, batch_size)
            return json.loads(str(results)), len(samples)

        results = self.strategy.evaluateSamples(sample_array, batch_size)
        return json.loads(str(results))

    def backprop(self, sample, gradient):
        """Backpropagate ``gradient`` through ``sample`` using the Java trainer.

        Args:
            sample: Sample exposing ``java_sample``.
            gradient: Value converted to a Java value via the value factory.

        Returns:
            A tuple ``(state_index, weight_updater)`` where ``state_index`` is
            the ``backproper`` attribute of the trainer's backpropagation
            object and ``weight_updater`` is the result of its
            ``backpropagate`` call.
        """
        trainer = self.strategy.getTrainer()
        # ``get_value`` returns a pair; only its second element is needed.
        _, gradient_value = self.value_factory.get_value(gradient)

        backpropagation = trainer.getBackpropagation()
        weight_updater = backpropagation.backpropagate(sample.java_sample, gradient_value)
        state_index = backpropagation.backproper

        return state_index, weight_updater

    def state_dict(self) -> Dict:
        """Collect the model's learnable weights.

        Returns:
            A dictionary with two entries, both keyed by weight index:
            ``"weights"`` (weight values converted from Java) and
            ``"weight_names"`` (weight names as ``str``).
        """
        weights = self.neural_model.getAllWeights()
        weights_dict = {}
        weight_names = {}

        for weight in weights:
            # NOTE(review): ``isLearnable`` is read as an attribute, not
            # called - confirm it is a Java field; a bound Java method would
            # always be truthy.
            if weight.isLearnable:
                weights_dict[weight.index] = ValueFactory.from_java(weight.value, SettingsProxy.number_format())
                weight_names[weight.index] = str(weight.name)

        return {
            "weights": weights_dict,
            "weight_names": weight_names,
        }

    def load_state_dict(self, state_dict: Dict):
        """Synchronise ``state_dict`` with all weights of the Java model.

        The work is delegated to the inherited ``sync_template``.
        """
        self.sync_template(state_dict, self.neural_model.getAllWeights())