# marf/ifield/models/conditioning.py

from abc import ABC, abstractmethod
from torch import nn, Tensor
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
from typing import Hashable, Union, Optional, KeysView, ValuesView, ItemsView, Any, Sequence
import torch


class RequiresConditioner(nn.Module, ABC):  # mixin
    @property
    @abstractmethod
    def n_latent_features(self) -> int:
        "This should provide the width of the conditioning feature vector"
        ...

    @property
    @abstractmethod
    def latent_embeddings_init_std(self) -> float:
        "This should provide the standard deviation to initialize the latent features with. DeepSDF uses 0.01."
        ...

    @property
    @abstractmethod
    def latent_embeddings(self) -> Optional[Tensor]:
        """This property should return a tensor containing all stored embeddings, for use in computing auto-decoder losses"""
        ...

    @abstractmethod
    def encode(self, batch: Any, batch_idx: int, optimizer_idx: int) -> Tensor:
        "This should, given a training batch, return the encoded conditioning vector"
        ...


class AutoDecoderModuleMixin(RequiresConditioner, ABC):
    """
    Populates dunder methods making it behave as a mapping.
    The mapping indexes into a stored set of learnable embedding vectors.
    Based on the auto-decoder architecture of

    J.J. Park, P. Florence, J. Straub, R. Newcombe, S. Lovegrove, DeepSDF:
    Learning Continuous Signed Distance Functions for Shape Representation, in:
    2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR),
    IEEE, Long Beach, CA, USA, 2019: pp. 165–174.
    https://doi.org/10.1109/CVPR.2019.00025.
    """

    _autodecoder_mapping: dict[Hashable, int]
    autodecoder_embeddings: nn.Parameter

    def __init__(self, *a, **kw):
        super().__init__(*a, **kw)

        # Migrate old checkpoints: `_autodecoder_mapping` used to be stored
        # under its own key, but now lives in the module's extra state.
        @self._register_load_state_dict_pre_hook
        def hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
            if f"{prefix}_autodecoder_mapping" in state_dict:
                state_dict[f"{prefix}{_EXTRA_STATE_KEY_SUFFIX}"] = state_dict.pop(f"{prefix}_autodecoder_mapping")

        # A lazy parameter that accepts checkpoints of any shape: the first
        # `copy_` materializes storage matching the incoming tensor, which
        # turns `self` into a regular `nn.Parameter`, so the recursive call
        # dispatches to `Tensor.copy_` instead of back here.
        class ICanBeLoadedFromCheckpointsAndChangeShapeStopBotheringMePyTorchAndSitInTheCornerIKnowWhatIAmDoing(nn.UninitializedParameter):
            def copy_(self, other):
                self.materialize(other.shape, other.device, other.dtype)
                return self.copy_(other)
        self.autodecoder_embeddings = ICanBeLoadedFromCheckpointsAndChangeShapeStopBotheringMePyTorchAndSitInTheCornerIKnowWhatIAmDoing()
    # nn.Module interface

    def get_extra_state(self):
        return {
            "ad_uids": getattr(self, "_autodecoder_mapping", {}),
        }

    def set_extra_state(self, obj):
        if "ad_uids" not in obj:  # backward compat
            self._autodecoder_mapping = obj
        else:
            self._autodecoder_mapping = obj["ad_uids"]

    # RequiresConditioner interface

    @property
    def latent_embeddings(self) -> Tensor:
        return self.autodecoder_embeddings
    # my interface

    def set_observation_ids(self, z_uids: set[Hashable]):
        assert self.latent_embeddings_init_std is not None, f"{self.__module__}.{self.__class__.__qualname__}.latent_embeddings_init_std"
        assert self.n_latent_features is not None, f"{self.__module__}.{self.__class__.__qualname__}.n_latent_features"
        assert self.latent_embeddings_init_std > 0, self.latent_embeddings_init_std
        assert self.n_latent_features > 0, self.n_latent_features
        self._autodecoder_mapping = {
            k: i
            for i, k in enumerate(sorted(set(z_uids)))
        }
        if len(z_uids) != len(self._autodecoder_mapping):
            raise ValueError(f"Observation identifiers are not unique! {z_uids = }")
        # `self.device` and `self.dtype` are assumed to be provided by the
        # host module (e.g. a pytorch-lightning LightningModule).
        self.autodecoder_embeddings = nn.Parameter(
            torch.Tensor(len(self._autodecoder_mapping), self.n_latent_features)
            .normal_(mean=0, std=self.latent_embeddings_init_std)
            .to(self.device, self.dtype)
        )
    def add_key(self, z_uid: Hashable, z: Optional[Tensor] = None):
        if z_uid in self._autodecoder_mapping:
            raise ValueError(f"Observation identifier {z_uid!r} not unique!")
        self._autodecoder_mapping[z_uid] = len(self._autodecoder_mapping)
        if z is None:  # draw a fresh code, as in set_observation_ids
            z = torch.Tensor(1, self.n_latent_features) \
                .normal_(mean=0, std=self.latent_embeddings_init_std)
        with torch.no_grad():
            self.autodecoder_embeddings = nn.Parameter(torch.cat((
                self.autodecoder_embeddings.detach(),
                z.reshape(1, -1).to(self.autodecoder_embeddings.device, self.autodecoder_embeddings.dtype),
            ), dim=0))
    def __delitem__(self, z_uid: Hashable):
        i = self._autodecoder_mapping.pop(z_uid)
        # shift down the indices of all codes stored after the removed one
        for k, v in list(self._autodecoder_mapping.items()):
            if v > i:
                self._autodecoder_mapping[k] -= 1
        with torch.no_grad():
            self.autodecoder_embeddings = nn.Parameter(torch.cat((
                self.autodecoder_embeddings.detach()[:i, :],
                self.autodecoder_embeddings.detach()[i+1:, :],
            ), dim=0))
    def __contains__(self, z_uid: Hashable) -> bool:
        return z_uid in self._autodecoder_mapping

    def __getitem__(self, z_uids: Union[Hashable, Sequence[Hashable]]) -> Tensor:
        if isinstance(z_uids, (tuple, list)):
            key = tuple(map(self._autodecoder_mapping.__getitem__, z_uids))
        else:
            key = self._autodecoder_mapping[z_uids]
        return self.autodecoder_embeddings[key, :]
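
    # Note: __getitem__ accepts a single identifier or a tuple/list of them,
    # so e.g. `model[batch_uids]` (hypothetical collated batch of ids) fetches
    # a (batch, n_latent_features) stack of codes in one lookup.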
    def __iter__(self):
        return iter(self._autodecoder_mapping)
    def keys(self) -> KeysView[Hashable]:
        """
        lists the identifiers of each code
        """
        return self._autodecoder_mapping.keys()
    def values(self) -> ValuesView[Tensor]:
        """
        lists all the learned codes / latent vectors
        """
        return {
            k: self.autodecoder_embeddings[i]
            for k, i in self._autodecoder_mapping.items()
        }.values()
    def items(self) -> ItemsView[Hashable, Tensor]:
        """
        lists all the learned codes / latent vectors with their identifiers as keys
        """
        return {
            k: self.autodecoder_embeddings[i]
            for k, i in self._autodecoder_mapping.items()
        }.items()


class EncoderModuleMixin(RequiresConditioner, ABC):
    @property
    def latent_embeddings(self) -> None:
        return None
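

# A hedged sketch (not part of the original module) of a concrete encoder;
# the point-cloud batch layout and the max-pooling architecture below are
# assumptions chosen only to illustrate the EncoderModuleMixin contract.
class _ExamplePointEncoder(EncoderModuleMixin):
    def __init__(self, n_latent_features: int = 256):
        super().__init__()
        self._n_latent_features = n_latent_features
        # per-point MLP followed by a permutation-invariant max-pool
        self.pointwise = nn.Sequential(
            nn.Linear(3, 128), nn.ReLU(),
            nn.Linear(128, 128), nn.ReLU(),
        )
        self.head = nn.Linear(128, n_latent_features)

    @property
    def n_latent_features(self) -> int:
        return self._n_latent_features

    @property
    def latent_embeddings_init_std(self) -> float:
        return 0.01  # unused here: encoders store no per-observation codes

    def encode(self, batch: Any, batch_idx: int, optimizer_idx: int) -> Tensor:
        points = batch["points"]                # assumed (B, N, 3) point cloud
        h = self.pointwise(points)              # (B, N, 128) per-point features
        return self.head(h.max(dim=1).values)   # (B, n_latent_features)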