This commit is contained in:
2023-07-19 19:29:10 +02:00
parent b2a64395bd
commit 4f811cc4b0
60 changed files with 18209 additions and 1 deletions

View File

@@ -0,0 +1,159 @@
from abc import ABC, abstractmethod
from torch import nn, Tensor
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
from typing import Hashable, Union, Optional, KeysView, ValuesView, ItemsView, Any, Sequence
import torch
class RequiresConditioner(nn.Module, ABC): # mixin
@property
@abstractmethod
def n_latent_features(self) -> int:
"This should provide the width of the conditioning feature vector"
...
@property
@abstractmethod
def latent_embeddings_init_std(self) -> float:
"This should provide the standard deviation to initialize the latent features with. DeepSDF uses 0.01."
...
@property
@abstractmethod
def latent_embeddings() -> Optional[Tensor]:
"""This property should return a tensor cotnaining all stored embeddings, for use in computing auto-decoder losses"""
...
@abstractmethod
def encode(self, batch: Any, batch_idx: int, optimizer_idx: int) -> Tensor:
"This should, given a training batch, return the encoded conditioning vector"
...
class AutoDecoderModuleMixin(RequiresConditioner, ABC):
"""
Populates dunder methods making it behave as a mapping.
The mapping indexes into a stored set of learnable embedding vectors.
Based on the auto-decoder architecture of
J.J. Park, P. Florence, J. Straub, R. Newcombe, S. Lovegrove, DeepSDF:
Learning Continuous Signed Distance Functions for Shape Representation, in:
2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR),
IEEE, Long Beach, CA, USA, 2019: pp. 165174.
https://doi.org/10.1109/CVPR.2019.00025.
"""
_autodecoder_mapping: dict[Hashable, int]
autodecoder_embeddings: nn.Parameter
def __init__(self, *a, **kw):
super().__init__(*a, **kw)
@self._register_load_state_dict_pre_hook
def hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
if f"{prefix}_autodecoder_mapping" in state_dict:
state_dict[f"{prefix}{_EXTRA_STATE_KEY_SUFFIX}"] = state_dict.pop(f"{prefix}_autodecoder_mapping")
class ICanBeLoadedFromCheckpointsAndChangeShapeStopBotheringMePyTorchAndSitInTheCornerIKnowWhatIAmDoing(nn.UninitializedParameter):
def copy_(self, other):
self.materialize(other.shape, other.device, other.dtype)
return self.copy_(other)
self.autodecoder_embeddings = ICanBeLoadedFromCheckpointsAndChangeShapeStopBotheringMePyTorchAndSitInTheCornerIKnowWhatIAmDoing()
# nn.Module interface
def get_extra_state(self):
return {
"ad_uids": getattr(self, "_autodecoder_mapping", {}),
}
def set_extra_state(self, obj):
if "ad_uids" not in obj: # backward compat
self._autodecoder_mapping = obj
else:
self._autodecoder_mapping = obj["ad_uids"]
# RequiresConditioner interface
@property
def latent_embeddings(self) -> Tensor:
return self.autodecoder_embeddings
# my interface
def set_observation_ids(self, z_uids: set[Hashable]):
assert self.latent_embeddings_init_std is not None, f"{self.__module__}.{self.__class__.__qualname__}.latent_embeddings_init_std"
assert self.n_latent_features is not None, f"{self.__module__}.{self.__class__.__qualname__}.n_latent_features"
assert self.latent_embeddings_init_std > 0, self.latent_embeddings_init_std
assert self.n_latent_features > 0, self.n_latent_features
self._autodecoder_mapping = {
k: i
for i, k in enumerate(sorted(set(z_uids)))
}
if not len(z_uids) == len(self._autodecoder_mapping):
raise ValueError(f"Observation identifiers are not unique! {z_uids = }")
self.autodecoder_embeddings = nn.Parameter(
torch.Tensor(len(self._autodecoder_mapping), self.n_latent_features)
.normal_(mean=0, std=self.latent_embeddings_init_std)
.to(self.device, self.dtype)
)
def add_key(self, z_uid: Hashable, z: Optional[Tensor] = None):
if z_uid in self._autodecoder_mapping:
raise ValueError(f"Observation identifier {z_uid!r} not unique!")
self._autodecoder_mapping[z_uid] = len(self._autodecoder_mapping)
self.autodecoder_embeddings
raise NotImplementedError
def __delitem__(self, z_uid: Hashable):
i = self._autodecoder_mapping.pop(z_uid)
for k, v in list(self._autodecoder_mapping.items()):
if v > i:
self._autodecoder_mapping[k] -= 1
with torch.no_grad():
self.autodecoder_embeddings = nn.Parameter(torch.cat((
self.autodecoder_embeddings.detach()[:i, :],
self.autodecoder_embeddings.detach()[i+1:, :],
), dim=0))
def __contains__(self, z_uid: Hashable) -> bool:
return z_uid in self._autodecoder_mapping
def __getitem__(self, z_uids: Union[Hashable, Sequence[Hashable]]) -> Tensor:
if isinstance(z_uids, tuple) or isinstance(z_uids, list):
key = tuple(map(self._autodecoder_mapping.__getitem__, z_uids))
else:
key = self._autodecoder_mapping[z_uids]
return self.autodecoder_embeddings[key, :]
def __iter__(self):
return self._autodecoder_mapping.keys()
def keys(self) -> KeysView[Hashable]:
"""
lists the identifiers of each code
"""
return self._autodecoder_mapping.keys()
def values(self) -> ValuesView[Tensor]:
return list(self.autodecoder_embeddings)
def items(self) -> ItemsView[Hashable, Tensor]:
"""
lists all the learned codes / latent vectors with their identifiers as keys
"""
return {
k : self.autodecoder_embeddings[i]
for k, i in self._autodecoder_mapping.items()
}.items()
class EncoderModuleMixin(RequiresConditioner, ABC):
@property
def latent_embeddings(self) -> None:
return None