from . import siren
from .. import param
from ..utils.helpers import compose, run_length_encode, MetaModuleProxy
from collections import OrderedDict
from pytorch_lightning.core.mixins import HyperparametersMixin
from torch import nn, Tensor
from torch.nn.utils.weight_norm import WeightNorm
from torchmeta.modules import MetaModule, MetaSequential
from typing import Generator, Iterable, Literal, Optional, Union, Callable
import itertools
import math
import torch


__doc__ = """
`fc` is short for "Fully Connected"
"""
def broadcast_tensors_except(*tensors: Tensor, dim: int) -> list[Tensor]:
    "Broadcast all tensors to a common shape in every dimension except `dim`, which each tensor keeps."
    if dim == -1:
        shapes = [ i.shape[:dim] for i in tensors ]
    else:
        shapes = [ (*i.shape[:dim], *i.shape[dim+1:]) for i in tensors ]
    target_shape = list(torch.broadcast_shapes(*shapes))
    if dim == -1:
        target_shape.append(-1)
    elif dim < 0:
        target_shape.insert(dim+1, -1)
    else:
        target_shape.insert(dim, -1)

    return [ i.broadcast_to(target_shape) for i in tensors ]
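

# A minimal illustration (kept as a comment, not executed at import time) of how
# `broadcast_tensors_except` is used, e.g. by CallbackConcatLayer below: every
# dimension except `dim` is broadcast to a common shape, while `dim` is left alone.
#
#   a = torch.randn(4, 1, 3)   # feature size 3
#   b = torch.randn(1, 5, 7)   # feature size 7
#   a2, b2 = broadcast_tensors_except(a, b, dim=-1)
#   # a2.shape == (4, 5, 3) and b2.shape == (4, 5, 7)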


EPS = 1e-8

Nonlinearity = Literal[
    None,
    "relu",
    "leaky_relu",
    "silu",
    "softplus",
    "elu",
    "selu",
    "sine",
    "sigmoid",
    "tanh",
]

Normalization = Literal[
    None,
    "batchnorm",
    "batchnorm_na",
    "layernorm",
    "layernorm_na",
    "weightnorm",
]


class ReprHyperparametersMixin(HyperparametersMixin):
    def extra_repr(self):
        this = ", ".join(f"{k}={v!r}" for k, v in self.hparams.items())
        rest = super().extra_repr()
        if rest:
            return f"{this}, {rest}"
        else:
            return this


class MultilineReprHyperparametersMixin(HyperparametersMixin):
    def extra_repr(self):
        items = [f"{k}={v!r}" for k, v in self.hparams.items()]
        this = "\n".join(
            ", ".join(filter(bool, i)) + ","
            for i in itertools.zip_longest(items[0::3], items[1::3], items[2::3])
        )
        rest = super().extra_repr()
        if rest:
            return f"{this}, {rest}"
        else:
            return this
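

# For reference, MultilineReprHyperparametersMixin renders hyperparameters three per
# line, each line ending in a comma. For a module with hparams (a=1, b=2, c=3, d=4)
# the extra_repr would look like (illustrative only):
#
#   a=1, b=2, c=3,
#   d=4,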


class BatchLinear(nn.Linear):
    """
    A linear (meta-)layer that can deal with batched weight matrices and biases,
    as for instance output by a hypernetwork.
    """
    __doc__ = nn.Linear.__doc__

    _meta_forward_pre_hooks = None

    def register_forward_pre_hook(self, hook: Callable) -> torch.utils.hooks.RemovableHandle:
        if not isinstance(hook, WeightNorm):
            return super().register_forward_pre_hook(hook)

        if self._meta_forward_pre_hooks is None:
            self._meta_forward_pre_hooks = OrderedDict()

        handle = torch.utils.hooks.RemovableHandle(self._meta_forward_pre_hooks)
        self._meta_forward_pre_hooks[handle.id] = hook
        return handle

    def forward(self, input: Tensor, params: Optional[dict[str, Tensor]] = None):
        if params is None or not isinstance(self, MetaModule):
            params = OrderedDict(self.named_parameters())
        if self._meta_forward_pre_hooks is not None:
            proxy = MetaModuleProxy(self, params)
            for hook in self._meta_forward_pre_hooks.values():
                hook(proxy, [input])

        weight = params["weight"]
        bias = params.get("bias", None)

        # transpose weights
        weight = weight.permute(*range(len(weight.shape) - 2), -1, -2)  # does not jit

        output = input.unsqueeze(-2).matmul(weight).squeeze(-2)

        if bias is not None:
            output = output + bias

        return output


class MetaBatchLinear(BatchLinear, MetaModule):
    pass
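

# Sketch of a hypernetwork-style call (illustrative only, not executed at import time):
# the `params` dict may carry a leading batch dimension on `weight`/`bias`, which
# BatchLinear.forward handles via the unsqueeze/matmul/squeeze above.
#
#   layer = MetaBatchLinear(3, 8)
#   x = torch.randn(16, 100, 3)                # (batch, points, in_features)
#   params = OrderedDict(
#       weight = torch.randn(16, 1, 8, 3),     # per-sample weights
#       bias   = torch.randn(16, 1, 8),
#   )
#   y = layer(x, params=params)                # y.shape == (16, 100, 8)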


class CallbackConcatLayer(nn.Module):
    "A tricky way to enable skip connections in sequential models"
    def __init__(self, tensor_getter: Callable[[], tuple[Tensor, ...]]):
        super().__init__()
        self.tensor_getter = tensor_getter

    def forward(self, x):
        ys = self.tensor_getter()
        return torch.cat(broadcast_tensors_except(x, *ys, dim=-1), dim=-1)


class ResidualSkipConnectionEndLayer(nn.Module):
    """
    Residual skip connections that can be added to a nn.Sequential
    """

    class ResidualSkipConnectionStartLayer(nn.Module):
        def __init__(self):
            super().__init__()
            self._stored_tensor = None

        def forward(self, x):
            assert self._stored_tensor is None
            self._stored_tensor = x
            return x

        def get(self):
            assert self._stored_tensor is not None
            x = self._stored_tensor
            self._stored_tensor = None
            return x

    def __init__(self):
        super().__init__()
        self._stored_tensor = None
        self._start = self.ResidualSkipConnectionStartLayer()

    def forward(self, x):
        skip = self._start.get()
        return x + skip

    @property
    def start(self) -> ResidualSkipConnectionStartLayer:
        return self._start

    @property
    def end(self) -> "ResidualSkipConnectionEndLayer":
        return self
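

# How the start/end pair is meant to sit inside a Sequential (illustrative only):
# the start layer records its input, the end layer pops it and adds it to its own
# input, giving y = f(x) + x without leaving nn.Sequential.
#
#   res = ResidualSkipConnectionEndLayer()
#   block = nn.Sequential(res.start, nn.Linear(8, 8), nn.ReLU(), res.end)
#   y = block(torch.randn(5, 8))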


ResidualMode = Literal[
    None,
    "identity",
]


class FCLayer(MultilineReprHyperparametersMixin, MetaSequential):
    """
    A single fully connected (FC) layer
    """

    def __init__(self,
        in_features    : int,
        out_features   : int,
        *,
        nonlinearity   : Nonlinearity  = "relu",
        normalization  : Normalization = None,
        is_first       : bool          = False,  # used for SIREN initialization
        is_final       : bool          = False,  # used for fan_out init
        dropout_prob   : float         = 0.0,
        negative_slope : float         = 0.01,   # only used with nonlinearity="leaky_relu" (0.01 is PyTorch's default)
        omega_0        : float         = 30,     # only used with nonlinearity="sine"
        residual_mode  : ResidualMode  = None,
        _no_meta       : bool          = False,  # set to True inside hypernetworks
        **_
    ):
        super().__init__()
        self.save_hyperparameters()

        # improve repr
        if nonlinearity != "leaky_relu":
            self.hparams.pop("negative_slope")
        if nonlinearity != "sine":
            self.hparams.pop("omega_0")

        Linear = nn.Linear if _no_meta else MetaBatchLinear

        def make_layer() -> Iterable[tuple[str, nn.Module]]:
            # residual start
            if residual_mode is not None:
                residual_layer = ResidualSkipConnectionEndLayer()
                yield "res_a", residual_layer.start

            linear = Linear(in_features, out_features)

            # initialize
            if nonlinearity in {"relu", "leaky_relu", "silu", "softplus"}:
                # nn.init.calculate_gain only knows "relu"/"leaky_relu"; reuse the ReLU gain for SiLU/Softplus
                gain_nonlinearity = nonlinearity if nonlinearity in {"relu", "leaky_relu"} else "relu"
                nn.init.kaiming_uniform_(linear.weight, a=negative_slope, nonlinearity=gain_nonlinearity, mode="fan_in" if not is_final else "fan_out")
            elif nonlinearity == "elu":
                nn.init.normal_(linear.weight, std=math.sqrt(1.5505188080679277) / math.sqrt(linear.weight.size(-1)))
            elif nonlinearity == "selu":
                nn.init.normal_(linear.weight, std=1 / math.sqrt(linear.weight.size(-1)))
            elif nonlinearity == "sine":
                siren.init_weights_(linear, omega_0, is_first)
            elif nonlinearity in {"sigmoid", "tanh"}:
                nn.init.xavier_normal_(linear.weight)
            elif nonlinearity is None:
                pass  # keep the nn.Linear default, which is effectively uniform(-1/sqrt(in_features), 1/sqrt(in_features))
            else:
                raise NotImplementedError(nonlinearity)

            # linear + normalize
            if normalization is None:
                yield "linear", linear
            elif normalization == "batchnorm":
                yield "linear", linear
                yield "norm", nn.BatchNorm1d(out_features, affine=True)
            elif normalization == "batchnorm_na":
                yield "linear", linear
                yield "norm", nn.BatchNorm1d(out_features, affine=False)
            elif normalization == "layernorm":
                yield "linear", linear
                yield "norm", nn.LayerNorm([out_features], elementwise_affine=True)
            elif normalization == "layernorm_na":
                yield "linear", linear
                yield "norm", nn.LayerNorm([out_features], elementwise_affine=False)
            elif normalization == "weightnorm":
                yield "linear", nn.utils.weight_norm(linear)
            else:
                raise NotImplementedError(normalization)

            # activation
            inplace = False
            if nonlinearity is None           : pass
            elif nonlinearity == "relu"       : yield nonlinearity, nn.ReLU(inplace=inplace)
            elif nonlinearity == "leaky_relu" : yield nonlinearity, nn.LeakyReLU(negative_slope=negative_slope, inplace=inplace)
            elif nonlinearity == "silu"       : yield nonlinearity, nn.SiLU(inplace=inplace)
            elif nonlinearity == "softplus"   : yield nonlinearity, nn.Softplus()
            elif nonlinearity == "elu"        : yield nonlinearity, nn.ELU(inplace=inplace)
            elif nonlinearity == "selu"       : yield nonlinearity, nn.SELU(inplace=inplace)
            elif nonlinearity == "sine"       : yield nonlinearity, siren.Sine(omega_0)
            elif nonlinearity == "sigmoid"    : yield nonlinearity, nn.Sigmoid()
            elif nonlinearity == "tanh"       : yield nonlinearity, nn.Tanh()
            else                              : raise NotImplementedError(f"{nonlinearity=}")

            # dropout
            if dropout_prob > 0:
                if nonlinearity == "selu":
                    yield "adropout", nn.AlphaDropout(p=dropout_prob)
                else:
                    yield "dropout", nn.Dropout(p=dropout_prob)

            # residual end
            if residual_mode is not None:
                yield "res_b", residual_layer.end

        for name, module in make_layer():
            self.add_module(name.replace("-", "_"), module)

    @property
    def nonlinearity(self) -> Optional[nn.Module]:
        "alias to the activation function submodule"
        if self.hparams.nonlinearity is None:
            return None
        return getattr(self, self.hparams.nonlinearity.replace("-", "_"))

    def initialize_weights(self):
        raise NotImplementedError
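

# A quick, illustrative look at what FCLayer builds (not executed at import time):
#
#   layer = FCLayer(2, 64, nonlinearity="sine", is_first=True, omega_0=30)
#   # -> a MetaSequential with submodules "linear" (MetaBatchLinear) and "sine" (siren.Sine);
#   #    layer.nonlinearity aliases the "sine" submodule.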


class FCBlock(MultilineReprHyperparametersMixin, MetaSequential):
    """
    A block of FC layers, with optional skip connections back to the input
    (`concat_skipped_layers`) and optional concatenation of a latent code z
    (`concat_conditioned_layers`).
    """
    def __init__(self,
        in_features               : int,
        hidden_features           : int,
        hidden_layers             : int,
        out_features              : int,
        normalization             : Normalization = None,
        nonlinearity              : Nonlinearity  = "relu",
        dropout_prob              : float         = 0.0,
        outermost_linear          : bool          = True,  # whether the final layer is purely linear (no nonlinearity, normalization or dropout)
        latent_features           : Optional[int] = None,
        concat_skipped_layers     : Union[list[int], bool] = [],  # layer indices that re-receive the input x
        concat_conditioned_layers : Union[list[int], bool] = [],  # layer indices that receive the latent code z
        **kw,
    ):
        super().__init__()
        self.save_hyperparameters()

        if isinstance(concat_skipped_layers, bool):
            concat_skipped_layers = list(range(hidden_layers+2)) if concat_skipped_layers else []
        if isinstance(concat_conditioned_layers, bool):
            concat_conditioned_layers = list(range(hidden_layers+2)) if concat_conditioned_layers else []
        if len(concat_conditioned_layers) != 0 and latent_features is None:
            raise ValueError("Layers marked to be conditioned without a known number of latent features")
        # negative indices count from the end of the hidden_layers+2 layers, as in regular Python indexing
        concat_skipped_layers = [i if i >= 0 else hidden_layers+2-abs(i) for i in concat_skipped_layers]
        concat_conditioned_layers = [i if i >= 0 else hidden_layers+2-abs(i) for i in concat_conditioned_layers]
        self._concat_x_layers: frozenset[int] = frozenset(concat_skipped_layers)
        self._concat_z_layers: frozenset[int] = frozenset(concat_conditioned_layers)
        if len(self._concat_x_layers) != len(concat_skipped_layers):
            raise ValueError(f"Duplicates found in {concat_skipped_layers = }")
        if len(self._concat_z_layers) != len(concat_conditioned_layers):
            raise ValueError(f"Duplicates found in {concat_conditioned_layers = }")
        if not all(isinstance(i, int) for i in self._concat_x_layers):
            raise TypeError(f"Expected only integers in {concat_skipped_layers = }")
        if not all(isinstance(i, int) for i in self._concat_z_layers):
            raise TypeError(f"Expected only integers in {concat_conditioned_layers = }")

        def make_layers() -> Iterable[nn.Module]:
            def make_concat_layer(*idxs: int) -> Generator[nn.Module, None, int]:
                x_condition_this_layer = any(idx in self._concat_x_layers for idx in idxs)
                z_condition_this_layer = any(idx in self._concat_z_layers for idx in idxs)
                if x_condition_this_layer and z_condition_this_layer:
                    yield CallbackConcatLayer(lambda: (self._current_x, self._current_z))
                elif x_condition_this_layer:
                    yield CallbackConcatLayer(lambda: (self._current_x,))
                elif z_condition_this_layer:
                    yield CallbackConcatLayer(lambda: (self._current_z,))

                # number of extra input features added by the concatenation(s) above
                return in_features*x_condition_this_layer + (latent_features or 0)*z_condition_this_layer

            added = yield from make_concat_layer(0)

            yield FCLayer(
                in_features   = in_features + added,
                out_features  = hidden_features,
                nonlinearity  = nonlinearity,
                normalization = normalization,
                dropout_prob  = dropout_prob,
                is_first      = True,
                is_final      = False,
                **kw,
            )

            for i in range(hidden_layers):
                added = yield from make_concat_layer(i+1)

                yield FCLayer(
                    in_features   = hidden_features + added,
                    out_features  = hidden_features,
                    nonlinearity  = nonlinearity,
                    normalization = normalization,
                    dropout_prob  = dropout_prob,
                    is_first      = False,
                    is_final      = False,
                    **kw,
                )

            added = yield from make_concat_layer(hidden_layers+1)

            nl = nonlinearity

            yield FCLayer(
                in_features   = hidden_features + added,
                out_features  = out_features,
                nonlinearity  = None if outermost_linear else nl,
                normalization = None if outermost_linear else normalization,
                dropout_prob  = 0.0  if outermost_linear else dropout_prob,
                is_first      = False,
                is_final      = True,
                **kw,
            )

        for i, module in enumerate(make_layers()):
            self.add_module(str(i), module)

    @property
    def is_conditioned(self) -> bool:
        "Whether z is used or not"
        return bool(self._concat_z_layers)

    @classmethod
    @compose("\n".join)
    def make_jinja_template(cls, *, exclude_list: frozenset[str] = frozenset(), top_level: bool = True, **kw) -> str:
        @compose(" ".join)
        def as_jexpr(values: list[int]):
            yield "{{"
            for val, count in run_length_encode(values):
                yield f"[{val!r}]*{count!r}"
            yield "}}"
        yield param.make_jinja_template(cls, top_level=top_level, exclude_list=exclude_list)
        yield param.make_jinja_template(FCLayer, top_level=False, exclude_list=exclude_list | {
            "in_features",
            "out_features",
            "nonlinearity",
            "normalization",
            "dropout_prob",
            "is_first",
            "is_final",
        })

    def forward(self, input: Tensor, z: Optional[Tensor] = None, *, params: Optional[dict[str, Tensor]] = None):
        assert not self.is_conditioned or z is not None
        if z is not None and z.ndim < input.ndim:
            # prepend singleton dimensions so z has the same ndim as input
            z = z[(*(None,)*(input.ndim - z.ndim), ...)]
        self._current_x = input
        self._current_z = z
        return super().forward(input, params=params)
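

# End-to-end sketch of an FCBlock (illustrative only, not executed at import time;
# the numbers are hypothetical). concat_skipped_layers=[-1] feeds the raw input x to
# the last layer again, concat_conditioned_layers=True concatenates z at every layer.
#
#   net = FCBlock(
#       in_features               = 3,
#       hidden_features           = 64,
#       hidden_layers              = 2,
#       out_features              = 1,
#       nonlinearity              = "relu",
#       latent_features           = 16,
#       concat_skipped_layers     = [-1],
#       concat_conditioned_layers = True,
#   )
#   x = torch.randn(8, 1024, 3)     # (batch, points, in_features)
#   z = torch.randn(8, 1, 16)       # (batch, 1, latent_features), broadcast over points
#   y = net(x, z)                   # y.shape == (8, 1024, 1)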