from abc import ABC, abstractmethod
from torch import nn, Tensor
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
from typing import Hashable, Iterable, Union, Optional, KeysView, ValuesView, ItemsView, Any, Sequence
import torch


class RequiresConditioner(nn.Module, ABC):  # mixin

    @property
    @abstractmethod
    def n_latent_features(self) -> int:
        "This should provide the width of the conditioning feature vector"
        ...

    @property
    @abstractmethod
    def latent_embeddings_init_std(self) -> float:
        "This should provide the standard deviation to initialize the latent features with. DeepSDF uses 0.01."
        ...

    @property
    @abstractmethod
    def latent_embeddings(self) -> Optional[Tensor]:
        """This property should return a tensor containing all stored embeddings, for use in computing auto-decoder losses"""
        ...

    @abstractmethod
    def encode(self, batch: Any, batch_idx: int, optimizer_idx: int) -> Tensor:
        "This should, given a training batch, return the encoded conditioning vector"
        ...
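

# The two mixins below implement this interface in opposite ways: the
# auto-decoder stores one learnable code per observation (DeepSDF-style),
# while the encoder computes codes on the fly. A minimal concrete subclass
# only needs to pin down the hyperparameters and the batch handling, e.g.
# (sketch only; the batch layout `batch["z_uid"]` is hypothetical):
#
#     class MyConditioner(AutoDecoderModuleMixin):
#         n_latent_features = 64
#         latent_embeddings_init_std = 0.01
#
#         def encode(self, batch, batch_idx, optimizer_idx):
#             return self[batch["z_uid"]]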


class AutoDecoderModuleMixin(RequiresConditioner, ABC):
    """
    Populates dunder methods making the module behave as a mapping from
    observation identifiers to a stored set of learnable embedding vectors.

    Based on the auto-decoder architecture of
    J.J. Park, P. Florence, J. Straub, R. Newcombe, S. Lovegrove, DeepSDF:
    Learning Continuous Signed Distance Functions for Shape Representation, in:
    2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR),
    IEEE, Long Beach, CA, USA, 2019: pp. 165–174.
    https://doi.org/10.1109/CVPR.2019.00025.
    """
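
    # Given a concrete subclass, the mapping behaves like (sketch):
    #
    #     module.set_observation_ids(["a", "b"])
    #     module["a"]           # Tensor of shape (n_latent_features,)
    #     module[["a", "b"]]    # Tensor of shape (2, n_latent_features)
    #     "a" in module         # True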

    _autodecoder_mapping: dict[Hashable, int]
    autodecoder_embeddings: nn.Parameter

    def __init__(self, *a, **kw):
        super().__init__(*a, **kw)

        # Older checkpoints stored the mapping under its attribute name; move
        # it to the extra-state key so that `set_extra_state` receives it.
        @self._register_load_state_dict_pre_hook
        def hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
            if f"{prefix}_autodecoder_mapping" in state_dict:
                state_dict[f"{prefix}{_EXTRA_STATE_KEY_SUFFIX}"] = state_dict.pop(f"{prefix}_autodecoder_mapping")

        # A lazy parameter that adopts whatever shape a checkpoint provides:
        # `materialize` swaps the instance's class for `nn.Parameter`, so the
        # recursive `copy_` call dispatches to the ordinary tensor copy.
        class ICanBeLoadedFromCheckpointsAndChangeShapeStopBotheringMePyTorchAndSitInTheCornerIKnowWhatIAmDoing(nn.UninitializedParameter):
            def copy_(self, other):
                self.materialize(other.shape, other.device, other.dtype)
                return self.copy_(other)

        self.autodecoder_embeddings = ICanBeLoadedFromCheckpointsAndChangeShapeStopBotheringMePyTorchAndSitInTheCornerIKnowWhatIAmDoing()

    # nn.Module interface

    def get_extra_state(self):
        return {
            "ad_uids": getattr(self, "_autodecoder_mapping", {}),
        }

    def set_extra_state(self, obj):
        if "ad_uids" not in obj:  # backward compat
            self._autodecoder_mapping = obj
        else:
            self._autodecoder_mapping = obj["ad_uids"]

    # RequiresConditioner interface

    @property
    def latent_embeddings(self) -> Tensor:
        return self.autodecoder_embeddings

    # my interface

    def set_observation_ids(self, z_uids: Iterable[Hashable]):
        assert self.latent_embeddings_init_std is not None, f"{self.__module__}.{self.__class__.__qualname__}.latent_embeddings_init_std"
        assert self.n_latent_features is not None, f"{self.__module__}.{self.__class__.__qualname__}.n_latent_features"
        assert self.latent_embeddings_init_std > 0, self.latent_embeddings_init_std
        assert self.n_latent_features > 0, self.n_latent_features

        z_uids = list(z_uids)
        self._autodecoder_mapping = {
            k: i
            for i, k in enumerate(sorted(set(z_uids)))
        }

        if not len(z_uids) == len(self._autodecoder_mapping):
            raise ValueError(f"Observation identifiers are not unique! {z_uids = }")

        # One embedding row per observation, initialized as in DeepSDF; assumes
        # the module exposes `device` and `dtype` (e.g. a LightningModule).
        self.autodecoder_embeddings = nn.Parameter(
            torch.empty(len(self._autodecoder_mapping), self.n_latent_features)
            .normal_(mean=0, std=self.latent_embeddings_init_std)
            .to(self.device, self.dtype)
        )

    def add_key(self, z_uid: Hashable, z: Optional[Tensor] = None):
        if z_uid in self._autodecoder_mapping:
            raise ValueError(f"Observation identifier {z_uid!r} not unique!")

        self._autodecoder_mapping[z_uid] = len(self._autodecoder_mapping)
        # Append a row for the new key: use `z` if given, otherwise draw a
        # fresh code the same way `set_observation_ids` does.
        old = self.autodecoder_embeddings
        if z is None:
            z = torch.empty(1, self.n_latent_features).normal_(mean=0, std=self.latent_embeddings_init_std)
        with torch.no_grad():
            z = z.reshape(1, -1).to(old.device, old.dtype)
            self.autodecoder_embeddings = nn.Parameter(torch.cat((old.detach(), z), dim=0))

    def __delitem__(self, z_uid: Hashable):
        # Drop the key, close the gap in the index space, and remove the
        # corresponding embedding row.
        i = self._autodecoder_mapping.pop(z_uid)
        for k, v in list(self._autodecoder_mapping.items()):
            if v > i:
                self._autodecoder_mapping[k] -= 1

        with torch.no_grad():
            self.autodecoder_embeddings = nn.Parameter(torch.cat((
                self.autodecoder_embeddings.detach()[:i, :],
                self.autodecoder_embeddings.detach()[i+1:, :],
            ), dim=0))

    def __contains__(self, z_uid: Hashable) -> bool:
        return z_uid in self._autodecoder_mapping

    def __getitem__(self, z_uids: Union[Hashable, Sequence[Hashable]]) -> Tensor:
        if isinstance(z_uids, (tuple, list)):
            key = tuple(map(self._autodecoder_mapping.__getitem__, z_uids))
        else:
            key = self._autodecoder_mapping[z_uids]
        return self.autodecoder_embeddings[key, :]

    def __iter__(self):
        # `.keys()` returns a view, not an iterator, so wrap it to make this a
        # valid `__iter__`.
        return iter(self._autodecoder_mapping)

    def keys(self) -> KeysView[Hashable]:
        """
        The identifiers of the stored latent codes.
        """
        return self._autodecoder_mapping.keys()

    def values(self) -> ValuesView[Tensor]:
        return {
            k: self.autodecoder_embeddings[i]
            for k, i in self._autodecoder_mapping.items()
        }.values()

    def items(self) -> ItemsView[Hashable, Tensor]:
        """
        All the learned codes / latent vectors, with their identifiers as keys.
        """
        return {
            k: self.autodecoder_embeddings[i]
            for k, i in self._autodecoder_mapping.items()
        }.items()


class EncoderModuleMixin(RequiresConditioner, ABC):
    @property
    def latent_embeddings(self) -> None:
        # An encoder computes codes on the fly, so there are no stored
        # embeddings to return for an auto-decoder regularization loss.
        return None
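

if __name__ == "__main__":
    # Minimal usage sketch / smoke test; `_DemoAutoDecoder` and its constants
    # are illustrative only, and `device`/`dtype` are plain class attributes
    # here because this demo subclasses bare nn.Module, not a LightningModule.
    class _DemoAutoDecoder(AutoDecoderModuleMixin):
        n_latent_features = 8
        latent_embeddings_init_std = 0.01
        device = torch.device("cpu")
        dtype = torch.float32

        def encode(self, batch, batch_idx, optimizer_idx):
            return self[batch]

    m = _DemoAutoDecoder()
    m.set_observation_ids(["a", "b", "c"])
    assert "b" in m and list(m.keys()) == ["a", "b", "c"]
    assert m["a"].shape == (8,)
    assert m[["a", "c"]].shape == (2, 8)
    del m["b"]
    assert "b" not in m and m["c"].shape == (8,)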