from typing import Dict, Any
class InitializerNames:
    """String identifiers for the supported initializer kinds.

    The initializer classes in this module return one of these constants
    from ``__str__``; ``Initializer.get_settings`` stores that string under
    the ``"initializer"`` key.
    """

    UNIFORM = "UNIFORM"
    NORMAL = "NORMAL"
    CONSTANT = "CONSTANT"
    LONGTAIL = "LONGTAIL"
    GLOROT = "GLOROT"
    HE = "HE"
class Initializer:
    """Base class for learnable-parameter initializers.

    Subclasses define ``__str__`` to return their name (one of the
    ``InitializerNames`` constants) and may override :meth:`get_settings`
    to add their own options.
    """

    def is_simple(self) -> bool:
        """Return ``True``; subclasses may override to return ``False``.

        NOTE(review): what "simple" means to the consumer of this flag is
        not defined in this module -- confirm against the caller.
        """
        return True

    def get_settings(self) -> Dict[str, Any]:
        """Return the settings dictionary describing this initializer.

        The base implementation stores ``str(self)`` under the
        ``"initializer"`` key.
        """
        return {"initializer": str(self)}
class Normal(Initializer):
    """Initializes learnable parameters with random samples from a normal
    (Gaussian) distribution.
    """

    def __str__(self) -> str:
        # This name is what ``get_settings`` stores under "initializer".
        return InitializerNames.NORMAL
class Constant(Initializer):
    r"""Initializes learnable parameters with the ``value``.

    Parameters
    ----------
    value : float
        Value to fill weights with. Default: ``0.1``
    """

    def __init__(self, value: float = 0.1):
        self.value = value

    def get_settings(self) -> Dict[str, Any]:
        """Return the initializer name together with the constant value.

        The value is stored under the ``"initializer_const"`` key.
        """
        return {
            "initializer": str(self),
            "initializer_const": self.value,
        }

    def __str__(self) -> str:
        return InitializerNames.CONSTANT
class Longtail(Initializer):
    """Initializes learnable parameters with random samples from a long tail
    distribution.
    """

    def __str__(self) -> str:
        # This name is what ``get_settings`` stores under "initializer".
        return InitializerNames.LONGTAIL
class Glorot(Initializer):
    r"""Initializes learnable parameters with samples from a uniform distribution (from the interval
    ``[-scale / 2, scale / 2]``) using the Glorot method.

    Parameters
    ----------
    scale : float
        Scale of a uniform distribution interval ``[-scale / 2, scale / 2]``. Default: ``2``
    """

    def __init__(self, scale: float = 2):
        self.scale = scale

    def is_simple(self) -> bool:
        """Return ``False``, overriding the base-class default of ``True``."""
        return False

    def get_settings(self) -> Dict[str, Any]:
        """Return the initializer name together with the scale.

        The scale is stored under the ``"initializer_uniform_scale"`` key.
        """
        return {
            "initializer": str(self),
            "initializer_uniform_scale": self.scale,
        }

    def __str__(self) -> str:
        return InitializerNames.GLOROT
class He(Initializer):
    r"""Initializes learnable parameters with samples from a uniform distribution (from the interval
    ``[-scale / 2, scale / 2]``) using the He method.

    Parameters
    ----------
    scale : float
        Scale of a uniform distribution interval ``[-scale / 2, scale / 2]``. Default: ``2``
    """

    def __init__(self, scale: float = 2):
        self.scale = scale

    def is_simple(self) -> bool:
        """Return ``False``, overriding the base-class default of ``True``."""
        return False

    def get_settings(self) -> Dict[str, Any]:
        """Return the initializer name together with the scale.

        The scale is stored under the ``"initializer_uniform_scale"`` key.
        """
        return {
            "initializer": str(self),
            "initializer_uniform_scale": self.scale,
        }

    def __str__(self) -> str:
        return InitializerNames.HE
# Names exported by ``from <module> import *``.
# NOTE(review): "Uniform" is listed here, but no ``Uniform`` class is defined
# in the visible source -- confirm it exists, otherwise a star-import of this
# module raises AttributeError.
__all__ = [
    "Normal",
    "Uniform",
    "Constant",
    "Longtail",
    "Glorot",
    "He",
    "Initializer",
    "InitializerNames",
]