Source code for neuralogic.nn.module.general.transformer
from typing import Optional
from neuralogic.core.constructs.function import Transformation
from neuralogic.core.constructs.factories import R
from neuralogic.nn.module.module import Module
from neuralogic.nn.module.general.mlp import MLP
from neuralogic.nn.module.general.attention import MultiheadAttention
class Transformer(Module):
r"""
    A transformer module based on `"Attention Is All You Need" <https://arxiv.org/abs/1706.03762>`_.

    Parameters
----------
input_dim : int
The number of expected features.
num_heads : int
The number of heads in the multi-head attention module.
dim_feedforward : int
The dimension of the feedforward network.
output_name : str
Output (head) predicate name of the module.
src_name : str
The name of the predicate of the input to the encoder.
tgt_name : str
The name of the predicate of the input to the decoder.
src_mask_name : str, optional
The name of the predicate of the encoder input mask. Default: ``None``
tgt_mask_name : str, optional
The name of the predicate of the decoder input mask. Default: ``None``
memory_mask_name : str, optional
The name of the predicate of the encoder output mask. Default: ``None``
arity : int
Arity of the input and output predicate. Default: ``1``
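
    Examples
    --------
    A minimal usage sketch (the predicate names and dimensions below are
    illustrative placeholders, assuming the library's usual
    ``Template.add_module`` workflow):

    >>> from neuralogic.core import Template
    >>> template = Template()
    >>> template.add_module(
    ...     Transformer(
    ...         input_dim=4, num_heads=2, dim_feedforward=8,
    ...         output_name="trans", src_name="h_src", tgt_name="h_tgt",
    ...     )
    ... )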
"""
def __init__(
self,
input_dim: int,
num_heads: int,
dim_feedforward: int,
output_name: str,
src_name: str,
tgt_name: str,
src_mask_name: Optional[str] = None,
tgt_mask_name: Optional[str] = None,
memory_mask_name: Optional[str] = None,
arity: int = 1,
):
self.input_dim = input_dim
self.num_heads = num_heads
self.dim_feedforward = dim_feedforward
self.output_name = output_name
self.src_name = src_name
self.tgt_name = tgt_name
self.src_mask_name = src_mask_name
self.tgt_mask_name = tgt_mask_name
self.memory_mask_name = memory_mask_name
self.arity = arity
def __call__(self):
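        # The encoder publishes its output under the auxiliary predicate
        # "{output_name}__encoder"; the decoder then consumes that predicate as
        # its memory (cross-attention key/value) input, mirroring the
        # encoder-decoder wiring of the original architecture.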
encoder = TransformerEncoder(
self.input_dim,
self.num_heads,
self.dim_feedforward,
f"{self.output_name}__encoder",
self.src_name,
self.src_mask_name,
self.arity,
)
decoder = TransformerDecoder(
self.input_dim,
self.num_heads,
self.dim_feedforward,
self.output_name,
self.tgt_name,
f"{self.output_name}__encoder",
self.tgt_mask_name,
self.memory_mask_name,
self.arity,
)
return [
*encoder(),
*decoder(),
]
class EncoderBlock(Module):
def __init__(
self,
input_dim: int,
num_heads: int,
dim_feedforward: int,
output_name: str,
query_name: str,
key_name: str,
value_name: str,
mask_name: Optional[str] = None,
arity: int = 1,
mlp: bool = True,
):
self.input_dim = input_dim
self.num_heads = num_heads
self.dim_feedforward = dim_feedforward
self.output_name = output_name
self.query_name = query_name
self.key_name = key_name
self.value_name = value_name
self.mask_name = mask_name
self.arity = arity
self.mlp = mlp
def __call__(self):
terms = [f"X{i}" for i in range(self.arity)]
attn_name = f"{self.output_name}__mhattn"
norm_name = f"{self.output_name}__norm"
mlp_name = f"{self.output_name}__mlp"
output_rel = R.get(self.output_name)
dim = self.input_dim
data_name = self.query_name
attention = MultiheadAttention(
dim,
self.num_heads,
attn_name,
self.query_name,
self.key_name,
self.value_name,
mask_name=self.mask_name,
arity=self.arity,
)
if self.mlp:
dims = [dim, self.dim_feedforward, dim]
mlp = MLP(dims, mlp_name, norm_name, [Transformation.RELU, Transformation.IDENTITY], self.arity)
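            # Post-norm residual wiring: each NORM rule combines a sublayer
            # output with that sublayer's input (attention + block input, then
            # feed-forward + normalized attention output), while the
            # predicate-level IDENTITY metadata leaves the aggregated value as is.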
return [
*attention(),
(R.get(norm_name)(terms) <= (R.get(attn_name)(terms), R.get(data_name)(terms))) | [Transformation.NORM],
R.get(norm_name) / self.arity | [Transformation.IDENTITY],
*mlp(),
(output_rel(terms) <= (R.get(norm_name)(terms), R.get(mlp_name)(terms))) | [Transformation.NORM],
output_rel / self.arity | [Transformation.IDENTITY],
]
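        # Without the feed-forward sublayer (mlp=False), only the attention
        # rules and a single residual-normalization step are emitted.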
return [
*attention(),
(output_rel(terms) <= (R.get(attn_name)(terms), R.get(data_name)(terms))) | [Transformation.NORM],
output_rel / self.arity | [Transformation.IDENTITY],
]
class TransformerEncoder(EncoderBlock):
r"""
    A transformer encoder module based on `"Attention Is All You Need" <https://arxiv.org/abs/1706.03762>`_.

    Parameters
----------
input_dim : int
The number of expected features.
num_heads : int
The number of heads in the multi-head attention module.
dim_feedforward : int
The dimension of the feedforward network.
output_name : str
Output (head) predicate name of the module.
input_name : str
The name of the predicate of the input sequence.
mask_name : str, optional
The name of the predicate of the input sequence mask. Default: ``None``
arity : int
Arity of the input and output predicate. Default: ``1``
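
    Examples
    --------
    A minimal usage sketch (names are illustrative placeholders):

    >>> from neuralogic.core import Template
    >>> template = Template()
    >>> template.add_module(
    ...     TransformerEncoder(4, 2, 8, "encoder", "h_input")
    ... )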
"""
def __init__(
self,
input_dim: int,
num_heads: int,
dim_feedforward: int,
output_name: str,
input_name: str,
mask_name: Optional[str] = None,
arity: int = 1,
):
super().__init__(
input_dim,
num_heads,
dim_feedforward,
output_name,
input_name,
input_name,
input_name,
mask_name,
arity,
True,
)
class TransformerDecoder(Module):
r"""
    A transformer decoder module based on `"Attention Is All You Need" <https://arxiv.org/abs/1706.03762>`_.

    Parameters
----------
input_dim : int
The number of expected features.
num_heads : int
The number of heads in the multi-head attention module.
dim_feedforward : int
The dimension of the feedforward network.
output_name : str
Output (head) predicate name of the module.
    input_name : str
        The name of the predicate of the decoder input sequence.
    encoder_name : str
        The name of the output (head) predicate of the encoder.
mask_name : str, optional
The name of the predicate of the decoder input sequence mask. Default: ``None``
memory_mask_name : str, optional
The name of the predicate of the encoder output mask. Default: ``None``
arity : int
Arity of the input and output predicate. Default: ``1``
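
    Examples
    --------
    A minimal usage sketch (names are illustrative placeholders; ``"encoder"``
    refers to the output predicate of a previously added encoder):

    >>> from neuralogic.core import Template
    >>> template = Template()
    >>> template.add_module(
    ...     TransformerDecoder(4, 2, 8, "decoder", "h_tgt", "encoder")
    ... )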
"""
def __init__(
self,
input_dim: int,
num_heads: int,
dim_feedforward: int,
output_name: str,
input_name: str,
encoder_name: str,
mask_name: Optional[str] = None,
memory_mask_name: Optional[str] = None,
arity: int = 1,
):
self.input_dim = input_dim
self.num_heads = num_heads
self.dim_feedforward = dim_feedforward
self.output_name = output_name
self.input_name = input_name
self.encoder_name = encoder_name
self.mask_name = mask_name
self.memory_mask_name = memory_mask_name
self.arity = arity
def __call__(self):
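        # The decoder is built from two EncoderBlock instances: the first runs
        # (optionally masked) self-attention over the target sequence without
        # the feed-forward sublayer (mlp=False); the second attends over the
        # encoder output (cross-attention) and applies the feed-forward sublayer.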
data_name = self.input_name
dim = self.input_dim
tmp_encoder_out = f"{self.output_name}__decoder"
encoder_name = self.encoder_name
mlp_dim = self.dim_feedforward
enc_block_one = EncoderBlock(
dim,
self.num_heads,
mlp_dim,
tmp_encoder_out,
data_name,
data_name,
data_name,
self.mask_name,
self.arity,
False,
)
enc_block_two = EncoderBlock(
dim,
self.num_heads,
mlp_dim,
self.output_name,
tmp_encoder_out,
encoder_name,
encoder_name,
self.memory_mask_name,
self.arity,
)
return [
*enc_block_one(),
*enc_block_two(),
]