Source code for neuralogic.dataset.db

import contextlib
import csv
import io
from typing import Optional, List, Union, Callable

from neuralogic.core.constructs.relation import BaseRelation, WeightedRelation
from neuralogic.core.constructs.rule import Rule
from neuralogic.dataset.logic import Dataset
from neuralogic.dataset.csv import CSVDataset, CSVFile, Mode
from neuralogic.dataset.base import ConvertibleDataset

DatasetEntries = Union[BaseRelation, WeightedRelation, Rule]


class DBSource:
    """Description of how to export one database table into a :class:`CSVFile`.

    The selected columns of ``table_name`` are written into an in-memory CSV
    buffer: the ``term_columns`` first (in the given order), then the optional
    ``value_column``. The buffer is wrapped in a ``CSVFile`` named
    ``relation_name``; interpreting the rows is left to ``CSVFile``.

    :param relation_name: Relation name forwarded to ``CSVFile``.
    :param table_name: Table to read from. It is interpolated into SQL / passed to
        ``copy_to`` verbatim, so it must come from trusted configuration.
    :param term_columns: Names of the columns used as terms; must be non-empty.
    :param value_column: Optional name of the column holding the value.
    :param default_value: Forwarded to ``CSVFile`` (presumably used when no value
        column is given — confirm against ``CSVFile``).
    :param value_mapper: Forwarded to ``CSVFile``.
    :param skip_rows: Forwarded to ``CSVFile``.
    :param n_rows: Forwarded to ``CSVFile``.
    :param replace_empty_column: Forwarded to ``CSVFile``.
    :param sep: Field separator used both when exporting the rows and when
        ``CSVFile`` parses them; expected to be a single character.
    :raises NotImplementedError: If ``term_columns`` is empty.
    """

    __slots__ = (
        "relation_name",
        "table_name",
        "term_columns",
        "value_column",
        "default_value",
        "value_mapper",
        "skip_rows",
        "n_rows",
        "replace_empty_column",
        "sep",
    )

    def __init__(
        self,
        relation_name: str,
        table_name: str,
        term_columns: List[str],
        value_column: Optional[str] = None,
        default_value: Union[float, int] = 1.0,
        value_mapper: Optional[Callable] = None,
        skip_rows: int = 0,
        n_rows: Optional[int] = None,
        replace_empty_column: Union[str, float, int] = 0,
        sep=",",
    ):
        if len(term_columns) == 0:
            raise NotImplementedError("Cannot create DBSource with zero terms")

        self.table_name = table_name
        self.relation_name = relation_name
        self.sep = sep
        self.value_column = value_column
        self.default_value = default_value
        self.value_mapper = value_mapper
        self.term_columns = term_columns
        self.skip_rows = skip_rows
        self.n_rows = n_rows
        self.replace_empty_column = replace_empty_column

    def to_csv(self, cursor) -> CSVFile:
        """Export the configured columns through ``cursor`` into a ``CSVFile``.

        :param cursor: A DB-API cursor. If it has a ``copy_to`` method (as e.g.
            psycopg2 cursors do), the table is streamed with it; otherwise a
            ``SELECT`` is executed and the fetched rows are written with :mod:`csv`.
        :return: A ``CSVFile`` reading from an in-memory buffer rewound to its start.
        """
        source = io.StringIO()

        columns = list(self.term_columns)
        # Term columns always occupy the leading positions of every exported row.
        term_columns = list(range(len(columns)))

        value_column = None
        if self.value_column is not None:
            columns.append(self.value_column)
            value_column = len(columns) - 1

        if hasattr(cursor, "copy_to"):
            # NULLs are exported as empty fields (null="").
            cursor.copy_to(source, self.table_name, sep=self.sep, null="", columns=columns)
        else:
            # NOTE: SQL identifiers cannot be bound as DB-API parameters; table and
            # column names come from this object's configuration and must be trusted.
            cursor.execute(f"SELECT {','.join(columns)} FROM {self.table_name}")
            # Write with the configured separator: CSVFile below is told to split
            # on self.sep, so a hard-coded comma would mis-split the rows.
            # csv.writer writes None as "", matching the copy_to path's null="".
            csv_writer = csv.writer(source, delimiter=self.sep, lineterminator="\n")
            csv_writer.writerows(cursor.fetchall())

        source.seek(0)

        return CSVFile(
            self.relation_name,
            source,
            self.sep,
            value_column,
            self.default_value,
            self.value_mapper,
            term_columns,
            # Neither export path writes a header row (presumably CSVFile's
            # `header` flag — confirm against CSVFile's signature).
            False,
            self.skip_rows,
            self.n_rows,
            self.replace_empty_column,
        )
class DBDataset(ConvertibleDataset):
    """Dataset whose examples (and optionally queries) are loaded from a database.

    Every :class:`DBSource` is exported through one cursor of ``connection`` into
    an in-memory CSV file, and the resulting files are converted with
    :class:`CSVDataset`.
    """

    def __init__(
        self,
        connection,
        db_sources: Union[List[DBSource], DBSource],
        queries_db_source: Optional[DBSource] = None,
        mode: Mode = Mode.ONE_EXAMPLE,
    ):
        """
        :param connection: A DB-API connection; only its ``cursor()`` method is used.
        :param db_sources: A single source or a sequence of sources. A sequence is
            copied, so :meth:`add_db_source` never mutates the caller's object.
        :param queries_db_source: Optional source providing the queries.
        :param mode: Forwarded to ``CSVDataset``.
        """
        self.connection = connection
        # Copy so add_db_source() does not mutate the caller's list and also works
        # when a tuple (or other iterable) of sources was passed.
        self.db_sources = [db_sources] if isinstance(db_sources, DBSource) else list(db_sources)
        self.queries_db_source = queries_db_source
        self.mode = mode

    def add_db_source(self, db_source: DBSource):
        """Append another example source."""
        self.db_sources.append(db_source)

    def set_queries(self, db_source: DBSource):
        """Replace the source used for queries."""
        self.queries_db_source = db_source

    def to_dataset(self) -> Dataset:
        """Read every configured source from the database and build a ``Dataset``.

        A single cursor is opened for all sources and closed afterwards; the
        exported data lives in memory, so conversion happens after the cursor
        has been released.
        """
        cursor = self.connection.cursor()

        # DB-API 2.0 does not require cursors to be context managers (sqlite3's
        # are not), so wrap those in closing() to still release them reliably.
        cursor_type = type(cursor)
        if hasattr(cursor_type, "__enter__") and hasattr(cursor_type, "__exit__"):
            cursor_context = cursor
        else:
            cursor_context = contextlib.closing(cursor)

        with cursor_context as cur:
            csv_files = [db_source.to_csv(cur) for db_source in self.db_sources]
            csv_queries = None if self.queries_db_source is None else self.queries_db_source.to_csv(cur)

        csv_dataset = CSVDataset(csv_files, csv_queries, self.mode)
        return csv_dataset.to_dataset()