Add code
This commit is contained in:
0
ifield/utils/__init__.py
Normal file
0
ifield/utils/__init__.py
Normal file
197
ifield/utils/geometry.py
Normal file
197
ifield/utils/geometry.py
Normal file
@@ -0,0 +1,197 @@
|
||||
from torch import Tensor
|
||||
from torch.nn import functional as F
|
||||
from typing import Optional, Literal
|
||||
import torch
|
||||
from .helpers import compose
|
||||
|
||||
|
||||
def get_ray_origins(cam2world: Tensor):
|
||||
return cam2world[..., :3, 3]
|
||||
|
||||
def camera_uv_to_rays(
|
||||
cam2world : Tensor,
|
||||
uv : Tensor,
|
||||
intrinsics : Tensor,
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Computes rays and origins from batched cam2world & intrinsics matrices, as well as pixel coordinates
|
||||
cam2world: (..., 4, 4)
|
||||
intrinsics: (..., 3, 3)
|
||||
uv: (..., n, 2)
|
||||
"""
|
||||
ray_dirs = get_ray_directions(uv, cam2world=cam2world, intrinsics=intrinsics)
|
||||
ray_origins = get_ray_origins(cam2world)
|
||||
ray_origins = ray_origins[..., None, :].expand([*uv.shape[:-1], 3])
|
||||
return ray_origins, ray_dirs
|
||||
|
||||
RayEmbedding = Literal[
|
||||
"plucker", # LFN
|
||||
"perp_foot", # PRIF
|
||||
"both",
|
||||
]
|
||||
|
||||
@compose(torch.cat, dim=-1)
|
||||
@compose(tuple)
|
||||
def ray_input_embedding(ray_origins: Tensor, ray_dirs: Tensor, mode: RayEmbedding = "plucker", normalize_dirs=False, is_training=False):
|
||||
"""
|
||||
Computes the plucker coordinates / perpendicular foot from ray origins and directions, appending it to direction
|
||||
"""
|
||||
assert ray_origins.shape[-1] == ray_dirs.shape[-1] == 3, \
|
||||
f"{ray_dirs.shape = }, {ray_origins.shape = }"
|
||||
|
||||
if normalize_dirs:
|
||||
ray_dirs = ray_dirs / ray_dirs.norm(dim=-1, keepdim=True)
|
||||
|
||||
yield ray_dirs
|
||||
|
||||
do_moment = mode in ("plucker", "both")
|
||||
do_perp_feet = mode in ("perp_foot", "both")
|
||||
assert do_moment or do_perp_feet
|
||||
|
||||
moment = torch.cross(ray_origins, ray_dirs, dim=-1)
|
||||
if do_moment:
|
||||
yield moment
|
||||
|
||||
if do_perp_feet:
|
||||
perp_feet = torch.cross(ray_dirs, moment, dim=-1)
|
||||
yield perp_feet
|
||||
|
||||
def ray_input_embedding_length(mode: RayEmbedding = "plucker") -> int:
|
||||
do_moment = mode in ("plucker", "both")
|
||||
do_perp_feet = mode in ("perp_foot", "both")
|
||||
assert do_moment or do_perp_feet
|
||||
|
||||
out = 3 # ray_dirs
|
||||
if do_moment:
|
||||
out += 3 # moment
|
||||
if do_perp_feet:
|
||||
out += 3 # perp foot
|
||||
return out
|
||||
|
||||
def parse_intrinsics(intrinsics, return_dict=False):
|
||||
fx = intrinsics[..., 0, 0:1]
|
||||
fy = intrinsics[..., 1, 1:2]
|
||||
cx = intrinsics[..., 0, 2:3]
|
||||
cy = intrinsics[..., 1, 2:3]
|
||||
if return_dict:
|
||||
return {"fx": fx, "fy": fy, "cx": cx, "cy": cy}
|
||||
else:
|
||||
return fx, fy, cx, cy
|
||||
|
||||
def expand_as(x, y):
|
||||
if len(x.shape) == len(y.shape):
|
||||
return x
|
||||
|
||||
for i in range(len(y.shape) - len(x.shape)):
|
||||
x = x.unsqueeze(-1)
|
||||
|
||||
return x
|
||||
|
||||
def lift(x, y, z, intrinsics, homogeneous=False):
|
||||
"""
|
||||
|
||||
:param self:
|
||||
:param x: Shape (batch_size, num_points)
|
||||
:param y:
|
||||
:param z:
|
||||
:param intrinsics:
|
||||
:return:
|
||||
"""
|
||||
fx, fy, cx, cy = parse_intrinsics(intrinsics)
|
||||
|
||||
x_lift = (x - expand_as(cx, x)) / expand_as(fx, x) * z
|
||||
y_lift = (y - expand_as(cy, y)) / expand_as(fy, y) * z
|
||||
|
||||
if homogeneous:
|
||||
return torch.stack((x_lift, y_lift, z, torch.ones_like(z).to(x.device)), dim=-1)
|
||||
else:
|
||||
return torch.stack((x_lift, y_lift, z), dim=-1)
|
||||
|
||||
def project(x, y, z, intrinsics):
|
||||
"""
|
||||
|
||||
:param self:
|
||||
:param x: Shape (batch_size, num_points)
|
||||
:param y:
|
||||
:param z:
|
||||
:param intrinsics:
|
||||
:return:
|
||||
"""
|
||||
fx, fy, cx, cy = parse_intrinsics(intrinsics)
|
||||
|
||||
x_proj = expand_as(fx, x) * x / z + expand_as(cx, x)
|
||||
y_proj = expand_as(fy, y) * y / z + expand_as(cy, y)
|
||||
|
||||
return torch.stack((x_proj, y_proj, z), dim=-1)
|
||||
|
||||
def world_from_xy_depth(xy, depth, cam2world, intrinsics):
|
||||
batch_size, *_ = cam2world.shape
|
||||
|
||||
x_cam = xy[..., 0]
|
||||
y_cam = xy[..., 1]
|
||||
z_cam = depth
|
||||
|
||||
pixel_points_cam = lift(x_cam, y_cam, z_cam, intrinsics=intrinsics, homogeneous=True)
|
||||
world_coords = torch.einsum("b...ij,b...kj->b...ki", cam2world, pixel_points_cam)[..., :3]
|
||||
|
||||
return world_coords
|
||||
|
||||
def project_point_on_ray(projection_point, ray_origin, ray_dir):
|
||||
dot = torch.einsum("...j,...j", projection_point-ray_origin, ray_dir)
|
||||
return ray_origin + dot[..., None] * ray_dir
|
||||
|
||||
def get_ray_directions(
|
||||
xy : Tensor, # (..., N, 2)
|
||||
cam2world : Tensor, # (..., 4, 4)
|
||||
intrinsics : Tensor, # (..., 3, 3)
|
||||
):
|
||||
z_cam = torch.ones(xy.shape[:-1]).to(xy.device)
|
||||
pixel_points = world_from_xy_depth(xy, z_cam, intrinsics=intrinsics, cam2world=cam2world) # (batch, num_samples, 3)
|
||||
|
||||
cam_pos = cam2world[..., :3, 3]
|
||||
ray_dirs = pixel_points - cam_pos[..., None, :] # (batch, num_samples, 3)
|
||||
ray_dirs = F.normalize(ray_dirs, dim=-1)
|
||||
return ray_dirs
|
||||
|
||||
def ray_sphere_intersect(
|
||||
ray_origins : Tensor, # (..., 3)
|
||||
ray_dirs : Tensor, # (..., 3)
|
||||
sphere_centers : Optional[Tensor] = None, # (..., 3)
|
||||
sphere_radii : Optional[Tensor] = 1, # (...)
|
||||
*,
|
||||
return_parts : bool = False,
|
||||
allow_nans : bool = True,
|
||||
improve_miss_grads : bool = False,
|
||||
) -> tuple[Tensor, ...]:
|
||||
if improve_miss_grads: assert not allow_nans, "improve_miss_grads does not work with allow_nans"
|
||||
if sphere_centers is None:
|
||||
ray_origins_centered = ray_origins #- torch.zeros_like(ray_origins)
|
||||
else:
|
||||
ray_origins_centered = ray_origins - sphere_centers
|
||||
|
||||
ray_dir_dot_origins = (ray_dirs * ray_origins_centered).sum(dim=-1, keepdim=True)
|
||||
discriminants2 = ray_dir_dot_origins**2 - ((ray_origins_centered * ray_origins_centered).sum(dim=-1) - sphere_radii**2)[..., None]
|
||||
if not allow_nans or return_parts:
|
||||
is_intersecting = discriminants2 > 0
|
||||
if allow_nans:
|
||||
discriminants = torch.sqrt(discriminants2)
|
||||
else:
|
||||
discriminants = torch.sqrt(torch.where(is_intersecting, discriminants2,
|
||||
discriminants2 - discriminants2.detach() + 0.001
|
||||
if improve_miss_grads else
|
||||
torch.zeros_like(discriminants2)
|
||||
))
|
||||
assert not discriminants.detach().isnan().any() # slow, use optimizations!
|
||||
|
||||
if not return_parts:
|
||||
return (
|
||||
ray_origins + ray_dirs * (- ray_dir_dot_origins - discriminants),
|
||||
ray_origins + ray_dirs * (- ray_dir_dot_origins + discriminants),
|
||||
)
|
||||
else:
|
||||
return (
|
||||
ray_origins + ray_dirs * (- ray_dir_dot_origins),
|
||||
ray_origins + ray_dirs * (- ray_dir_dot_origins - discriminants),
|
||||
ray_origins + ray_dirs * (- ray_dir_dot_origins + discriminants),
|
||||
is_intersecting.squeeze(-1),
|
||||
)
|
||||
205
ifield/utils/helpers.py
Normal file
205
ifield/utils/helpers.py
Normal file
@@ -0,0 +1,205 @@
|
||||
from functools import wraps, reduce, partial
|
||||
from itertools import zip_longest, groupby
|
||||
from pathlib import Path
|
||||
from typing import Iterable, TypeVar, Callable, Union, Optional, Mapping, Hashable
|
||||
import collections
|
||||
import operator
|
||||
import re
|
||||
|
||||
Numeric = Union[int, float, complex]
|
||||
T = TypeVar("T")
|
||||
S = TypeVar("S")
|
||||
|
||||
# decorator
|
||||
def compose(outer_func: Callable[[..., S], T], *outer_a, **outer_kw) -> Callable[..., T]:
|
||||
def wrapper(inner_func: Callable[..., S]):
|
||||
@wraps(inner_func)
|
||||
def wrapped(*a, **kw):
|
||||
return outer_func(*outer_a, inner_func(*a, **kw), **outer_kw)
|
||||
return wrapped
|
||||
return wrapper
|
||||
|
||||
def compose_star(outer_func: Callable[[..., S], T], *outer_a, **outer_kw) -> Callable[..., T]:
|
||||
def wrapper(inner_func: Callable[..., S]):
|
||||
@wraps(inner_func)
|
||||
def wrapped(*a, **kw):
|
||||
return outer_func(*outer_a, *inner_func(*a, **kw), **outer_kw)
|
||||
return wrapped
|
||||
return wrapper
|
||||
|
||||
|
||||
# itertools
|
||||
|
||||
def elementwise_max(iterable: Iterable[Iterable[T]]) -> Iterable[T]:
|
||||
return reduce(lambda xs, ys: [*map(max, zip(xs, ys))], iterable)
|
||||
|
||||
def prod(numbers: Iterable[T], initial: Optional[T] = None) -> T:
|
||||
if initial is not None:
|
||||
return reduce(operator.mul, numbers, initial)
|
||||
else:
|
||||
return reduce(operator.mul, numbers)
|
||||
|
||||
def run_length_encode(data: Iterable[T]) -> Iterable[tuple[T, int]]:
|
||||
return (
|
||||
(x, len(y))
|
||||
for x, y in groupby(data)
|
||||
)
|
||||
|
||||
|
||||
# text conversion
|
||||
|
||||
def camel_to_snake_case(text: str, sep: str = "_", join_abbreviations: bool = False) -> str:
|
||||
parts = (
|
||||
part.lower()
|
||||
for part in re.split(r'(?=[A-Z])', text)
|
||||
if part
|
||||
)
|
||||
if join_abbreviations:
|
||||
parts = list(parts)
|
||||
if len(parts) > 1:
|
||||
for i, (a, b) in list(enumerate(zip(parts[:-1], parts[1:])))[::-1]:
|
||||
if len(a) == len(b) == 1:
|
||||
parts[i] = parts[i] + parts.pop(i+1)
|
||||
return sep.join(parts)
|
||||
|
||||
def snake_to_camel_case(text: str) -> str:
|
||||
return "".join(
|
||||
part.captialize()
|
||||
for part in text.split("_")
|
||||
if part
|
||||
)
|
||||
|
||||
|
||||
# textwrap
|
||||
|
||||
def columnize_dict(data: dict, n_columns=2, prefix="", sep=" ") -> str:
|
||||
sub = (len(data) + 1) // n_columns
|
||||
return reduce(partial(columnize, sep=sep),
|
||||
(
|
||||
columnize(
|
||||
"\n".join([f"{'' if n else prefix}{i!r}" for i in data.keys() ][n*sub : (n+1)*sub]),
|
||||
"\n".join([f": {i!r}," for i in data.values()][n*sub : (n+1)*sub]),
|
||||
)
|
||||
for n in range(n_columns)
|
||||
)
|
||||
)
|
||||
|
||||
def columnize(left: str, right: str, prefix="", sep=" ") -> str:
|
||||
left = left .split("\n")
|
||||
right = right.split("\n")
|
||||
width = max(map(len, left)) if left else 0
|
||||
return "\n".join(
|
||||
f"{prefix}{a.ljust(width)}{sep}{b}"
|
||||
if b else
|
||||
f"{prefix}{a}"
|
||||
for a, b in zip_longest(left, right, fillvalue="")
|
||||
)
|
||||
|
||||
|
||||
# pathlib
|
||||
|
||||
def make_relative(path: Union[Path, str], parent: Path = None) -> Path:
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
if parent is None:
|
||||
parent = Path.cwd()
|
||||
try:
|
||||
return path.relative_to(parent)
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
return ".." / path.relative_to(parent.parent)
|
||||
except ValueError:
|
||||
pass
|
||||
return path
|
||||
|
||||
|
||||
# dictionaries
|
||||
|
||||
def update_recursive(target: dict, source: dict):
|
||||
""" Update two config dictionaries recursively. """
|
||||
for k, v in source.items():
|
||||
if isinstance(v, dict):
|
||||
if k not in target:
|
||||
target[k] = type(target)()
|
||||
update_recursive(target[k], v)
|
||||
else:
|
||||
target[k] = v
|
||||
|
||||
def map_tree(func: Callable[[T], S], val: Union[Mapping[Hashable, T], tuple[T, ...], list[T], T]) -> Union[Mapping[Hashable, S], tuple[S, ...], list[S], S]:
|
||||
if isinstance(val, collections.abc.Mapping):
|
||||
return {
|
||||
k: map_tree(func, subval)
|
||||
for k, subval in val.items()
|
||||
}
|
||||
elif isinstance(val, tuple):
|
||||
return tuple(
|
||||
map_tree(func, subval)
|
||||
for subval in val
|
||||
)
|
||||
elif isinstance(val, list):
|
||||
return [
|
||||
map_tree(func, subval)
|
||||
for subval in val
|
||||
]
|
||||
else:
|
||||
return func(val)
|
||||
|
||||
def flatten_tree(val, *, sep=".", prefix=None):
|
||||
if isinstance(val, collections.abc.Mapping):
|
||||
return {
|
||||
k: v
|
||||
for subkey, subval in val.items()
|
||||
for k, v in flatten_tree(subval, sep=sep, prefix=f"{prefix}{sep}{subkey}" if prefix else subkey).items()
|
||||
}
|
||||
elif isinstance(val, tuple) or isinstance(val, list):
|
||||
return {
|
||||
k: v
|
||||
for index, subval in enumerate(val)
|
||||
for k, v in flatten_tree(subval, sep=sep, prefix=f"{prefix}{sep}[{index}]" if prefix else f"[{index}]").items()
|
||||
}
|
||||
elif prefix:
|
||||
return {prefix: val}
|
||||
else:
|
||||
return val
|
||||
|
||||
# conversions
|
||||
|
||||
def hex2tuple(data: str) -> tuple[int]:
|
||||
data = data.removeprefix("#")
|
||||
return (*(
|
||||
int(data[i:i+2], 16)
|
||||
for i in range(0, len(data), 2)
|
||||
),)
|
||||
|
||||
|
||||
# repr shims
|
||||
|
||||
class CustomRepr:
|
||||
def __init__(self, repr_str: str):
|
||||
self.repr_str = repr_str
|
||||
def __str__(self):
|
||||
return self.repr_str
|
||||
def __repr__(self):
|
||||
return self.repr_str
|
||||
|
||||
|
||||
# Meta Params Module proxy
|
||||
|
||||
class MetaModuleProxy:
|
||||
def __init__(self, module, params):
|
||||
self._module = module
|
||||
self._params = params
|
||||
|
||||
def __getattr__(self, name):
|
||||
params = super().__getattribute__("_params")
|
||||
if name in params:
|
||||
return params[name]
|
||||
else:
|
||||
return getattr(super().__getattribute__("_module"), name)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name not in ("_params", "_module"):
|
||||
super().__getattribute__("_params")[name] = value
|
||||
else:
|
||||
super().__setattr__(name, value)
|
||||
590
ifield/utils/loss.py
Normal file
590
ifield/utils/loss.py
Normal file
@@ -0,0 +1,590 @@
|
||||
from abc import abstractmethod, ABC
|
||||
from dataclasses import dataclass, field, fields, MISSING
|
||||
from functools import wraps
|
||||
from matplotlib import pyplot as plt
|
||||
from matplotlib.artist import Artist
|
||||
from tabulate import tabulate
|
||||
from torch import nn
|
||||
from typing import Optional, TypeVar, Union
|
||||
import inspect
|
||||
import math
|
||||
import pytorch_lightning as pl
|
||||
import typing
|
||||
import warnings
|
||||
|
||||
|
||||
HParamSchedule = TypeVar("HParamSchedule", bound="HParamScheduleBase")
|
||||
Schedulable = Union[HParamSchedule, int, float, str]
|
||||
|
||||
class HParamScheduleBase(ABC):
|
||||
_subclasses = {} # shared reference intended
|
||||
def __init_subclass__(cls):
|
||||
if not cls.__name__.startswith("_"):
|
||||
cls._subclasses[cls.__name__] = cls
|
||||
|
||||
_infix : Optional[str] = field(init=False, repr=False, default=None)
|
||||
_param_name : Optional[str] = field(init=False, repr=False, default=None)
|
||||
_expr : Optional[str] = field(init=False, repr=False, default=None)
|
||||
|
||||
def get(self, module: nn.Module, *, trainer: Optional[pl.Trainer] = None) -> float:
|
||||
if module.training:
|
||||
if trainer is None:
|
||||
trainer = module.trainer # this assumes `module` is a pl.LightningModule
|
||||
value = self.get_train_value(
|
||||
epoch = trainer.current_epoch + (trainer.fit_loop.epoch_loop.batch_progress.current.processed / trainer.num_training_batches),
|
||||
)
|
||||
if trainer.logger is not None and self._param_name is not None and self.__class__ is not Const and trainer.global_step % 15 == 0:
|
||||
trainer.logger.log_metrics({
|
||||
f"HParamSchedule/{self._param_name}": value,
|
||||
}, step=trainer.global_step)
|
||||
return value
|
||||
else:
|
||||
return self.get_eval_value()
|
||||
|
||||
def _gen_data(self, n_epochs, steps_per_epoch=1000):
|
||||
global_steps = 0
|
||||
for epoch in range(n_epochs):
|
||||
for step in range(steps_per_epoch):
|
||||
yield (
|
||||
epoch + step/steps_per_epoch,
|
||||
self.get_train_value(epoch + step/steps_per_epoch),
|
||||
)
|
||||
global_steps += steps_per_epoch
|
||||
|
||||
def plot(self, *a, ax: Optional[plt.Axes] = None, **kw) -> Artist:
|
||||
if ax is None: ax = plt.gca()
|
||||
out = ax.plot(*zip(*self._gen_data(*a, **kw)), label=self._expr)
|
||||
ax.set_title(self._param_name)
|
||||
ax.set_xlabel("Epoch")
|
||||
ax.set_ylabel("Value")
|
||||
ax.legend()
|
||||
return out
|
||||
|
||||
def assert_positive(self, *a, **kw):
|
||||
for epoch, val in self._gen_data(*a, **kw):
|
||||
assert val >= 0, f"{epoch=}, {val=}"
|
||||
|
||||
@abstractmethod
|
||||
def get_eval_value(self) -> float:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_train_value(self, epoch: float) -> float:
|
||||
...
|
||||
|
||||
def __add__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "+":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __radd__(self, lhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "+":
|
||||
return cls(lhs, self)
|
||||
return NotImplemented
|
||||
|
||||
def __sub__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "-":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __rsub__(self, lhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "-":
|
||||
return cls(lhs, self)
|
||||
return NotImplemented
|
||||
|
||||
def __mul__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "*":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __rmul__(self, lhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "*":
|
||||
return cls(lhs, self)
|
||||
return NotImplemented
|
||||
|
||||
def __matmul__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "@":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __rmatmul__(self, lhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "@":
|
||||
return cls(lhs, self)
|
||||
return NotImplemented
|
||||
|
||||
def __truediv__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "/":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __rtruediv__(self, lhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "/":
|
||||
return cls(lhs, self)
|
||||
return NotImplemented
|
||||
|
||||
def __floordiv__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "//":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __rfloordiv__(self, lhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "//":
|
||||
return cls(lhs, self)
|
||||
return NotImplemented
|
||||
|
||||
def __mod__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "%":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __rmod__(self, lhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "%":
|
||||
return cls(lhs, self)
|
||||
return NotImplemented
|
||||
|
||||
def __pow__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "**":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __rpow__(self, lhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "**":
|
||||
return cls(lhs, self)
|
||||
return NotImplemented
|
||||
|
||||
def __lshift__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "<<":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __rlshift__(self, lhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "<<":
|
||||
return cls(lhs, self)
|
||||
return NotImplemented
|
||||
|
||||
def __rshift__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == ">>":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __rrshift__(self, lhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == ">>":
|
||||
return cls(lhs, self)
|
||||
return NotImplemented
|
||||
|
||||
def __and__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "&":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __rand__(self, lhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "&":
|
||||
return cls(lhs, self)
|
||||
return NotImplemented
|
||||
|
||||
def __xor__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "^":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __rxor__(self, lhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "^":
|
||||
return cls(lhs, self)
|
||||
return NotImplemented
|
||||
|
||||
def __or__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "|":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __ror__(self, lhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "|":
|
||||
return cls(lhs, self)
|
||||
return NotImplemented
|
||||
|
||||
def __ge__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == ">=":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __gt__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == ">":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __le__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "<=":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __lt__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "<":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __bool__(self):
|
||||
return True
|
||||
|
||||
def __neg__(self):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "-":
|
||||
return cls(0, self)
|
||||
return NotImplemented
|
||||
|
||||
@property
|
||||
def is_const(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
|
||||
def parse_dsl(config: Schedulable, name=None) -> HParamSchedule:
|
||||
if isinstance(config, HParamScheduleBase):
|
||||
return config
|
||||
elif isinstance(config, str):
|
||||
out = eval(config, {"__builtins__": {}, "lg": math.log10}, HParamScheduleBase._subclasses)
|
||||
if not isinstance(out, HParamScheduleBase):
|
||||
out = Const(out)
|
||||
else:
|
||||
out = Const(config)
|
||||
out._expr = config
|
||||
out._param_name = name
|
||||
return out
|
||||
|
||||
|
||||
# decorator
|
||||
def ensure_schedulables(func):
|
||||
signature = inspect.signature(func)
|
||||
module_name = func.__qualname__.removesuffix(".__init__")
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*a, **kw):
|
||||
bound_args = signature.bind(*a, **kw)
|
||||
|
||||
for param_name, param in signature.parameters.items():
|
||||
type_origin = typing.get_origin(param.annotation)
|
||||
type_args = typing.get_args (param.annotation)
|
||||
|
||||
if type_origin is HParamSchedule or (type_origin is Union and (HParamSchedule in type_args or HParamScheduleBase in type_args)):
|
||||
if param_name in bound_args.arguments:
|
||||
bound_args.arguments[param_name] = parse_dsl(bound_args.arguments[param_name], name=f"{module_name}.{param_name}")
|
||||
elif param.default is not param.empty:
|
||||
bound_args.arguments[param_name] = parse_dsl(param.default, name=f"{module_name}.{param_name}")
|
||||
|
||||
return func(
|
||||
*bound_args.args,
|
||||
**bound_args.kwargs,
|
||||
)
|
||||
return wrapper
|
||||
|
||||
# https://easings.net/
|
||||
|
||||
@dataclass
|
||||
class _InfixBase(HParamScheduleBase):
|
||||
l : Union[HParamSchedule, int, float]
|
||||
r : Union[HParamSchedule, int, float]
|
||||
|
||||
def _operation(self, l: float, r: float) -> float:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_eval_value(self) -> float:
|
||||
return self._operation(
|
||||
self.l.get_eval_value() if isinstance(self.l, HParamScheduleBase) else self.l,
|
||||
self.r.get_eval_value() if isinstance(self.r, HParamScheduleBase) else self.r,
|
||||
)
|
||||
|
||||
def get_train_value(self, epoch: float) -> float:
|
||||
return self._operation(
|
||||
self.l.get_train_value(epoch) if isinstance(self.l, HParamScheduleBase) else self.l,
|
||||
self.r.get_train_value(epoch) if isinstance(self.r, HParamScheduleBase) else self.r,
|
||||
)
|
||||
|
||||
def __bool__(self):
|
||||
if self.is_const:
|
||||
return bool(self.get_eval_value())
|
||||
else:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_const(self) -> bool:
|
||||
return (self.l.is_const if isinstance(self.l, HParamScheduleBase) else True) \
|
||||
and (self.r.is_const if isinstance(self.r, HParamScheduleBase) else True)
|
||||
|
||||
@dataclass
|
||||
class Add(_InfixBase):
|
||||
""" adds the results of two schedules """
|
||||
_infix : Optional[str] = field(init=False, repr=False, default="+")
|
||||
def _operation(self, l: float, r: float) -> float:
|
||||
return l + r
|
||||
|
||||
|
||||
@dataclass
|
||||
class Sub(_InfixBase):
|
||||
""" subtracts the results of two schedules """
|
||||
_infix : Optional[str] = field(init=False, repr=False, default="-")
|
||||
def _operation(self, l: float, r: float) -> float:
|
||||
return l - r
|
||||
|
||||
|
||||
@dataclass
|
||||
class Prod(_InfixBase):
|
||||
""" multiplies the results of two schedules """
|
||||
_infix : Optional[str] = field(init=False, repr=False, default="*")
|
||||
def _operation(self, l: float, r: float) -> float:
|
||||
return l * r
|
||||
@property
|
||||
def is_const(self) -> bool: # propagate identity
|
||||
l = self.l.get_eval_value() if isinstance(self.l, HParamScheduleBase) and self.l.is_const else self.l
|
||||
r = self.r.get_eval_value() if isinstance(self.r, HParamScheduleBase) and self.r.is_const else self.r
|
||||
return l == 0 or r == 0 or super().is_const
|
||||
|
||||
|
||||
@dataclass
|
||||
class Div(_InfixBase):
|
||||
""" divides the results of two schedules """
|
||||
_infix : Optional[str] = field(init=False, repr=False, default="/")
|
||||
def _operation(self, l: float, r: float) -> float:
|
||||
return l / r
|
||||
|
||||
|
||||
@dataclass
|
||||
class Pow(_InfixBase):
|
||||
""" raises the results of one schedule to the other """
|
||||
_infix : Optional[str] = field(init=False, repr=False, default="**")
|
||||
def _operation(self, l: float, r: float) -> float:
|
||||
return l ** r
|
||||
|
||||
|
||||
@dataclass
|
||||
class Gt(_InfixBase):
|
||||
""" compares the results of two schedules """
|
||||
_infix : Optional[str] = field(init=False, repr=False, default=">")
|
||||
def _operation(self, l: float, r: float) -> float:
|
||||
return l > r
|
||||
|
||||
|
||||
@dataclass
|
||||
class Lt(_InfixBase):
|
||||
""" compares the results of two schedules """
|
||||
_infix : Optional[str] = field(init=False, repr=False, default="<")
|
||||
def _operation(self, l: float, r: float) -> float:
|
||||
return l < r
|
||||
|
||||
|
||||
@dataclass
|
||||
class Ge(_InfixBase):
|
||||
""" compares the results of two schedules """
|
||||
_infix : Optional[str] = field(init=False, repr=False, default=">=")
|
||||
def _operation(self, l: float, r: float) -> float:
|
||||
return l >= r
|
||||
|
||||
|
||||
@dataclass
|
||||
class Le(_InfixBase):
|
||||
""" compares the results of two schedules """
|
||||
_infix : Optional[str] = field(init=False, repr=False, default="<=")
|
||||
def _operation(self, l: float, r: float) -> float:
|
||||
return l <= r
|
||||
|
||||
|
||||
@dataclass
|
||||
class Const(HParamScheduleBase):
|
||||
""" A way to ensure .get(...) exists """
|
||||
|
||||
c : Union[int, float]
|
||||
|
||||
def get_eval_value(self) -> float:
|
||||
return self.c
|
||||
|
||||
def get_train_value(self, epoch: float) -> float:
|
||||
return self.c
|
||||
|
||||
def __bool__(self):
|
||||
return bool(self.get_eval_value())
|
||||
|
||||
@property
|
||||
def is_const(self) -> bool:
|
||||
return True
|
||||
|
||||
@dataclass
|
||||
class Step(HParamScheduleBase):
|
||||
""" steps from 0 to 1 at `epoch` """
|
||||
|
||||
epoch : float
|
||||
|
||||
def get_eval_value(self) -> float:
|
||||
return 1
|
||||
|
||||
def get_train_value(self, epoch: float) -> float:
|
||||
return 1 if epoch >= self.epoch else 0
|
||||
|
||||
@dataclass
|
||||
class Linear(HParamScheduleBase):
|
||||
""" linear from 0 to 1 over `n_epochs`, delayed by `offset` """
|
||||
|
||||
n_epochs : float
|
||||
offset : float = 0
|
||||
|
||||
def get_eval_value(self) -> float:
|
||||
return 1
|
||||
|
||||
def get_train_value(self, epoch: float) -> float:
|
||||
if self.n_epochs <= 0: return 1
|
||||
return min(max(epoch - self.offset, 0) / self.n_epochs, 1)
|
||||
|
||||
@dataclass
|
||||
class EaseSin(HParamScheduleBase): # effectively 1-CosineAnnealing
|
||||
""" sinusoidal ease in-out from 0 to 1 over `n_epochs`, delayed by `offset` """
|
||||
|
||||
n_epochs : float
|
||||
offset : float = 0
|
||||
|
||||
def get_eval_value(self) -> float:
|
||||
return 1
|
||||
|
||||
def get_train_value(self, epoch: float) -> float:
|
||||
x = min(max(epoch - self.offset, 0) / self.n_epochs, 1)
|
||||
return -(math.cos(math.pi * x) - 1) / 2
|
||||
|
||||
@dataclass
|
||||
class EaseExp(HParamScheduleBase):
|
||||
""" exponential ease in-out from 0 to 1 over `n_epochs`, delayed by `offset` """
|
||||
|
||||
n_epochs : float
|
||||
offset : float = 0
|
||||
|
||||
def get_eval_value(self) -> float:
|
||||
return 1
|
||||
|
||||
def get_train_value(self, epoch: float) -> float:
|
||||
if (epoch-self.offset) < 0:
|
||||
return 0
|
||||
if (epoch-self.offset) > self.n_epochs:
|
||||
return 1
|
||||
x = min(max(epoch - self.offset, 0) / self.n_epochs, 1)
|
||||
return (
|
||||
2**(20*x-10) / 2
|
||||
if x < 0.5 else
|
||||
(2 - 2**(-20*x+10)) / 2
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class Steps(HParamScheduleBase):
|
||||
""" Starts at 1, multiply by gamma every n epochs. Models StepLR in pytorch """
|
||||
step_size: float
|
||||
gamma: float = 0.1
|
||||
|
||||
def get_eval_value(self) -> float:
|
||||
return 1
|
||||
def get_train_value(self, epoch: float) -> float:
|
||||
return self.gamma**int(epoch / self.step_size)
|
||||
|
||||
@dataclass
|
||||
class MultiStep(HParamScheduleBase):
|
||||
""" Starts at 1, multiply by gamma every milstone epoch. Models MultiStepLR in pytorch """
|
||||
milestones: list[float]
|
||||
gamma: float = 0.1
|
||||
|
||||
def get_eval_value(self) -> float:
|
||||
return 1
|
||||
def get_train_value(self, epoch: float) -> float:
|
||||
for i, m in list(enumerate(self.milestones))[::-1]:
|
||||
if epoch >= m:
|
||||
return self.gamma**(i+1)
|
||||
return 1
|
||||
|
||||
@dataclass
|
||||
class Epoch(HParamScheduleBase):
|
||||
""" The current epoch, starting at 0 """
|
||||
|
||||
def get_eval_value(self) -> float:
|
||||
return 0
|
||||
def get_train_value(self, epoch: float) -> float:
|
||||
return epoch
|
||||
|
||||
@dataclass
|
||||
class Offset(HParamScheduleBase):
|
||||
""" Offsets the epoch for the subexpression, clamped above 0. Positive offsets makes it happen later """
|
||||
expr : Union[HParamSchedule, int, float]
|
||||
offset : float
|
||||
|
||||
def get_eval_value(self) -> float:
|
||||
return self.expr.get_eval_value() if isinstance(self.expr, HParamScheduleBase) else self.expr
|
||||
def get_train_value(self, epoch: float) -> float:
|
||||
return self.expr.get_train_value(max(epoch - self.offset, 0)) if isinstance(self.expr, HParamScheduleBase) else self.expr
|
||||
|
||||
@dataclass
|
||||
class Mod(HParamScheduleBase):
|
||||
""" The epoch in the subexptression is subject to a modulus. Use for warm restarts """
|
||||
|
||||
modulus : float
|
||||
expr : Union[HParamSchedule, int, float]
|
||||
|
||||
def get_eval_value(self) -> float:
|
||||
return self.expr.get_eval_value() if isinstance(self.expr, HParamScheduleBase) else self.expr
|
||||
def get_train_value(self, epoch: float) -> float:
|
||||
return self.expr.get_train_value(epoch % self.modulus) if isinstance(self.expr, HParamScheduleBase) else self.expr
|
||||
|
||||
|
||||
def main():
|
||||
import sys, rich.pretty
|
||||
if not sys.argv[2:]:
|
||||
print(f"Usage: {sys.argv[0]} n_epochs 'expression'")
|
||||
print("Available operations:")
|
||||
def mk_ops():
|
||||
for name, cls in HParamScheduleBase._subclasses.items():
|
||||
if isinstance(cls._infix, str):
|
||||
yield (cls._infix, f"(infix) {cls.__doc__.strip()}")
|
||||
else:
|
||||
yield (
|
||||
f"""{name}({', '.join(
|
||||
i.name
|
||||
if i.default is MISSING else
|
||||
f"{i.name}={i.default!r}"
|
||||
for i in fields(cls)
|
||||
)})""",
|
||||
cls.__doc__.strip(),
|
||||
)
|
||||
rich.print(tabulate(sorted(mk_ops()), tablefmt="plain"))
|
||||
else:
|
||||
n_epochs = int(sys.argv[1])
|
||||
schedules = [parse_dsl(arg, name="cli arg") for arg in sys.argv[2:]]
|
||||
ax = plt.gca()
|
||||
print("[")
|
||||
for schedule in schedules:
|
||||
rich.print(f" {schedule}, # {schedule.is_const = }")
|
||||
schedule.plot(n_epochs, ax=ax)
|
||||
print("]")
|
||||
plt.show()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
0
ifield/utils/operators/__init__.py
Normal file
0
ifield/utils/operators/__init__.py
Normal file
96
ifield/utils/operators/diff.py
Normal file
96
ifield/utils/operators/diff.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import torch
|
||||
from torch.autograd import grad
|
||||
|
||||
|
||||
def hessian(y: torch.Tensor, x: torch.Tensor, check=False, detach=False) -> torch.Tensor:
|
||||
"""
|
||||
hessian of y wrt x
|
||||
y: shape (..., Y)
|
||||
x: shape (..., X)
|
||||
return: shape (..., Y, X, X)
|
||||
"""
|
||||
assert x.requires_grad
|
||||
assert y.grad_fn
|
||||
|
||||
grad_y = torch.ones_like(y[..., 0]).to(y.device) # reuse -> less memory
|
||||
|
||||
hess = torch.stack([
|
||||
# calculate hessian on y for each x value
|
||||
torch.stack(
|
||||
gradients(
|
||||
*(dydx[..., j] for j in range(x.shape[-1])),
|
||||
wrt=x,
|
||||
grad_outputs=[grad_y]*x.shape[-1],
|
||||
detach=detach,
|
||||
),
|
||||
dim = -2,
|
||||
)
|
||||
# calculate dydx over batches for each feature value of y
|
||||
for dydx in gradients(*(y[..., i] for i in range(y.shape[-1])), wrt=x)
|
||||
], dim=-3)
|
||||
|
||||
if check:
|
||||
assert hess.isnan().any()
|
||||
return hess
|
||||
|
||||
def laplace(y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||
return divergence(*gradients(y, wrt=x), x)
|
||||
|
||||
def divergence(y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||
assert x.requires_grad
|
||||
assert y.grad_fn
|
||||
return sum(
|
||||
grad(
|
||||
y[..., i],
|
||||
x,
|
||||
torch.ones_like(y[..., i]),
|
||||
create_graph=True
|
||||
)[0][..., i:i+1]
|
||||
for i in range(y.shape[-1])
|
||||
)
|
||||
|
||||
def gradients(*ys, wrt, grad_outputs=None, detach=False) -> list[torch.Tensor]:
|
||||
assert wrt.requires_grad
|
||||
assert all(y.grad_fn for y in ys)
|
||||
if grad_outputs is None:
|
||||
grad_outputs = [torch.ones_like(y) for y in ys]
|
||||
|
||||
grads = (
|
||||
grad(
|
||||
[y],
|
||||
[wrt],
|
||||
grad_outputs=y_grad,
|
||||
create_graph=True,
|
||||
)[0]
|
||||
for y, y_grad in zip(ys, grad_outputs)
|
||||
)
|
||||
if detach:
|
||||
grads = map(torch.detach, grads)
|
||||
|
||||
return [*grads]
|
||||
|
||||
def jacobian(y: torch.Tensor, x: torch.Tensor, check=False, detach=False) -> torch.Tensor:
|
||||
"""
|
||||
jacobian of `y` w.r.t. `x`
|
||||
|
||||
y: shape (..., Y)
|
||||
x: shape (..., X)
|
||||
return: shape (..., Y, X)
|
||||
"""
|
||||
assert x.requires_grad
|
||||
assert y.grad_fn
|
||||
|
||||
y_grad = torch.ones_like(y[..., 0])
|
||||
jac = torch.stack(
|
||||
gradients(
|
||||
*(y[..., i] for i in range(y.shape[-1])),
|
||||
wrt=x,
|
||||
grad_outputs=[y_grad]*x.shape[-1],
|
||||
detach=detach,
|
||||
),
|
||||
dim=-2,
|
||||
)
|
||||
|
||||
if check:
|
||||
assert jac.isnan().any()
|
||||
return jac
|
||||
Reference in New Issue
Block a user