This commit is contained in:
2023-07-19 19:29:10 +02:00
parent b2a64395bd
commit 4f811cc4b0
60 changed files with 18209 additions and 1 deletions

0
ifield/utils/__init__.py Normal file
View File

197
ifield/utils/geometry.py Normal file
View File

@@ -0,0 +1,197 @@
from torch import Tensor
from torch.nn import functional as F
from typing import Optional, Literal
import torch
from .helpers import compose
def get_ray_origins(cam2world: Tensor):
    """Extract the ray/camera origin (the translation column) from (..., 4, 4) cam2world matrices."""
    translation = cam2world[..., 0:3, 3]
    return translation
def camera_uv_to_rays(
        cam2world  : Tensor,
        uv         : Tensor,
        intrinsics : Tensor,
        ) -> tuple[Tensor, Tensor]:
    """
    Compute per-pixel ray origins and directions from batched camera matrices.

    cam2world:  (..., 4, 4) camera-to-world transforms
    intrinsics: (..., 3, 3) pinhole intrinsics
    uv:         (..., n, 2) pixel coordinates
    Returns (ray_origins, ray_dirs), each (..., n, 3).
    """
    directions = get_ray_directions(uv, cam2world=cam2world, intrinsics=intrinsics)
    # all rays of one camera share that camera's position -> broadcast it per pixel
    origins = get_ray_origins(cam2world)
    origins = origins[..., None, :].expand([*uv.shape[:-1], 3])
    return origins, directions
# Names of the supported ray input parameterizations (see ray_input_embedding).
RayEmbedding = Literal[
    "plucker",   # Plücker moment, as used by LFN
    "perp_foot", # perpendicular foot, as used by PRIF
    "both",      # concatenation of both embeddings
]
@compose(torch.cat, dim=-1)  # outer: concatenate the parts along the feature dim
@compose(tuple)              # inner: materialize the generator so torch.cat gets a tuple
def ray_input_embedding(ray_origins: Tensor, ray_dirs: Tensor, mode: RayEmbedding = "plucker", normalize_dirs=False, is_training=False):
    """
    Computes the plucker coordinates / perpendicular foot from ray origins and directions, appending it to direction

    The function body is a generator; the decorators turn it into a plain
    function returning torch.cat(tuple(<yielded parts>), dim=-1).
    NOTE(review): `is_training` is currently unused.
    """
    assert ray_origins.shape[-1] == ray_dirs.shape[-1] == 3, \
        f"{ray_dirs.shape = }, {ray_origins.shape = }"
    if normalize_dirs:
        ray_dirs = ray_dirs / ray_dirs.norm(dim=-1, keepdim=True)
    yield ray_dirs
    do_moment    = mode in ("plucker", "both")
    do_perp_feet = mode in ("perp_foot", "both")
    assert do_moment or do_perp_feet
    # Plücker moment m = o × d; needed for the perpendicular foot as well
    moment = torch.cross(ray_origins, ray_dirs, dim=-1)
    if do_moment:
        yield moment
    if do_perp_feet:
        # perpendicular foot d × m (point on the ray closest to the world origin)
        perp_feet = torch.cross(ray_dirs, moment, dim=-1)
        yield perp_feet
def ray_input_embedding_length(mode: RayEmbedding = "plucker") -> int:
    """Number of feature channels produced by ray_input_embedding for `mode`."""
    with_moment    = mode in ("plucker", "both")
    with_perp_foot = mode in ("perp_foot", "both")
    assert with_moment or with_perp_foot
    # 3 for the ray direction itself, plus 3 per enabled embedding part
    return 3 + (3 if with_moment else 0) + (3 if with_perp_foot else 0)
def parse_intrinsics(intrinsics, return_dict=False):
    """Split (..., 3, 3) pinhole intrinsics into fx, fy, cx, cy (each keeps a trailing singleton dim for broadcasting)."""
    parts = {
        "fx": intrinsics[..., 0, 0:1],
        "fy": intrinsics[..., 1, 1:2],
        "cx": intrinsics[..., 0, 2:3],
        "cy": intrinsics[..., 1, 2:3],
    }
    if return_dict:
        return parts
    return parts["fx"], parts["fy"], parts["cx"], parts["cy"]
def expand_as(x, y):
    """Append trailing singleton dims to x until it has as many dims as y (returns x unchanged if already equal)."""
    missing = len(y.shape) - len(x.shape)
    for _ in range(missing):
        x = x[..., None]
    return x
def lift(x, y, z, intrinsics, homogeneous=False):
    """
    Back-project pixel coordinates to camera-space 3D points at depth z.

    :param x: pixel x coordinates, shape (batch_size, num_points)
    :param y: pixel y coordinates, same shape as x
    :param z: depth per pixel, same shape as x
    :param intrinsics: (..., 3, 3) pinhole intrinsics
    :param homogeneous: if True, append a trailing 1 to each point
    :return: points of shape (..., num_points, 3) or (..., num_points, 4)
    """
    fx, fy, cx, cy = parse_intrinsics(intrinsics)
    # invert the pinhole projection: (pixel - principal_point) / focal * depth
    x_lift = (x - expand_as(cx, x)) / expand_as(fx, x) * z
    y_lift = (y - expand_as(cy, y)) / expand_as(fy, y) * z
    if homogeneous:
        # NOTE(review): .to(x.device) is redundant — ones_like(z) is already on z's device
        return torch.stack((x_lift, y_lift, z, torch.ones_like(z).to(x.device)), dim=-1)
    else:
        return torch.stack((x_lift, y_lift, z), dim=-1)
def project(x, y, z, intrinsics):
    """
    Project camera-space 3D points to pixel coordinates (pinhole model).

    :param x: camera-space x coordinates, shape (batch_size, num_points)
    :param y: camera-space y coordinates, same shape as x
    :param z: camera-space depth, same shape as x
    :param intrinsics: (..., 3, 3) pinhole intrinsics
    :return: stacked (u, v, z) of shape (..., num_points, 3); depth is passed through unchanged
    """
    fx, fy, cx, cy = parse_intrinsics(intrinsics)
    x_proj = expand_as(fx, x) * x / z + expand_as(cx, x)
    y_proj = expand_as(fy, y) * y / z + expand_as(cy, y)
    return torch.stack((x_proj, y_proj, z), dim=-1)
def world_from_xy_depth(xy, depth, cam2world, intrinsics):
    """
    Back-project pixel coordinates with depths into world-space 3D points.

    xy:         (..., n, 2) pixel coordinates
    depth:      (..., n) depth per pixel
    cam2world:  (..., 4, 4) camera-to-world transforms
    intrinsics: (..., 3, 3) pinhole intrinsics
    returns:    (..., n, 3) world-space points
    """
    # removed unused `batch_size` unpacking of cam2world.shape
    x_cam = xy[..., 0]
    y_cam = xy[..., 1]
    # lift to homogeneous camera space, then apply the rigid transform;
    # the einsum is cam2world @ point for every batched point (k indexes points)
    pixel_points_cam = lift(x_cam, y_cam, depth, intrinsics=intrinsics, homogeneous=True)
    world_coords = torch.einsum("b...ij,b...kj->b...ki", cam2world, pixel_points_cam)[..., :3]
    return world_coords
def project_point_on_ray(projection_point, ray_origin, ray_dir):
    """Orthogonally project a point onto a ray (ray_dir assumed unit length)."""
    offset = projection_point - ray_origin
    # signed distance of the projection along the (unit) direction
    along = torch.einsum("...j,...j", offset, ray_dir)
    return ray_origin + along[..., None] * ray_dir
def get_ray_directions(
        xy         : Tensor, # (..., N, 2)
        cam2world  : Tensor, # (..., 4, 4)
        intrinsics : Tensor, # (..., 3, 3)
        ):
    """Compute unit-length ray directions through the given pixels. Returns (..., N, 3)."""
    # construct on the right device AND dtype directly; the previous
    # torch.ones(...).to(xy.device) produced float32 even for double inputs
    z_cam = torch.ones(xy.shape[:-1], device=xy.device, dtype=xy.dtype)
    pixel_points = world_from_xy_depth(xy, z_cam, intrinsics=intrinsics, cam2world=cam2world) # (batch, num_samples, 3)
    cam_pos = cam2world[..., :3, 3]
    ray_dirs = pixel_points - cam_pos[..., None, :] # (batch, num_samples, 3)
    ray_dirs = F.normalize(ray_dirs, dim=-1)
    return ray_dirs
def ray_sphere_intersect(
    ray_origins    : Tensor,                  # (..., 3)
    ray_dirs       : Tensor,                  # (..., 3), assumed unit length — TODO confirm
    sphere_centers : Optional[Tensor] = None, # (..., 3), None means the origin
    sphere_radii   : Optional[Tensor] = 1,    # (...)
    *,
    return_parts : bool = False,
    allow_nans : bool = True,
    improve_miss_grads : bool = False,
    ) -> tuple[Tensor, ...]:
    """
    Analytic ray-sphere intersection via the quadratic formula.

    Returns (near_point, far_point); with return_parts=True instead returns
    (perpendicular foot, near_point, far_point, hit mask).  Rays that miss
    produce NaN points unless allow_nans=False, in which case the negative
    discriminant is replaced by 0 — or, with improve_miss_grads=True, by a
    value that is 0.001 in the forward pass but keeps the discriminant's
    gradient alive for missed rays.
    """
    if improve_miss_grads: assert not allow_nans, "improve_miss_grads does not work with allow_nans"
    if sphere_centers is None:
        ray_origins_centered = ray_origins #- torch.zeros_like(ray_origins)
    else:
        ray_origins_centered = ray_origins - sphere_centers
    # half the quadratic's linear coefficient: <d, o - c>; roots are t = -b/2 ± sqrt(disc)
    ray_dir_dot_origins = (ray_dirs * ray_origins_centered).sum(dim=-1, keepdim=True)
    discriminants2 = ray_dir_dot_origins**2 - ((ray_origins_centered * ray_origins_centered).sum(dim=-1) - sphere_radii**2)[..., None]
    if not allow_nans or return_parts:
        is_intersecting = discriminants2 > 0
    if allow_nans:
        discriminants = torch.sqrt(discriminants2) # NaN wherever the ray misses
    else:
        # `x - x.detach() + 0.001` evaluates to 0.001 but carries x's gradient
        discriminants = torch.sqrt(torch.where(is_intersecting, discriminants2,
            discriminants2 - discriminants2.detach() + 0.001
            if improve_miss_grads else
            torch.zeros_like(discriminants2)
        ))
        assert not discriminants.detach().isnan().any() # slow, use optimizations!
    if not return_parts:
        return (
            ray_origins + ray_dirs * (- ray_dir_dot_origins - discriminants),
            ray_origins + ray_dirs * (- ray_dir_dot_origins + discriminants),
        )
    else:
        return (
            # perpendicular foot: closest point on the ray to the sphere center
            ray_origins + ray_dirs * (- ray_dir_dot_origins),
            ray_origins + ray_dirs * (- ray_dir_dot_origins - discriminants),
            ray_origins + ray_dirs * (- ray_dir_dot_origins + discriminants),
            is_intersecting.squeeze(-1),
        )

205
ifield/utils/helpers.py Normal file
View File

@@ -0,0 +1,205 @@
from functools import wraps, reduce, partial
from itertools import zip_longest, groupby
from pathlib import Path
from typing import Iterable, TypeVar, Callable, Union, Optional, Mapping, Hashable
import collections
import operator
import re
# shared type aliases / variables for the generic helpers below
# NOTE(review): Numeric appears unused in this file — possibly consumed elsewhere
Numeric = Union[int, float, complex]
T = TypeVar("T")
S = TypeVar("S")
# decorator
# decorator
def compose(outer_func: Callable[[..., S], T], *outer_a, **outer_kw) -> Callable[..., T]:
    """Decorator factory: post-process the decorated function's result as outer_func(*outer_a, result, **outer_kw)."""
    def wrapper(inner_func: Callable[..., S]):
        @wraps(inner_func)
        def wrapped(*a, **kw):
            result = inner_func(*a, **kw)
            return outer_func(*outer_a, result, **outer_kw)
        return wrapped
    return wrapper
def compose_star(outer_func: Callable[[..., S], T], *outer_a, **outer_kw) -> Callable[..., T]:
    """Like compose, but the inner result is unpacked (star-expanded) into outer_func."""
    def wrapper(inner_func: Callable[..., S]):
        @wraps(inner_func)
        def wrapped(*a, **kw):
            results = inner_func(*a, **kw)
            return outer_func(*outer_a, *results, **outer_kw)
        return wrapped
    return wrapper
# itertools
# itertools
def elementwise_max(iterable: Iterable[Iterable[T]]) -> Iterable[T]:
    """Reduce rows to a single list holding the per-position maximum."""
    def pairwise_max(row_a, row_b):
        return [max(a, b) for a, b in zip(row_a, row_b)]
    return reduce(pairwise_max, iterable)
def prod(numbers: Iterable[T], initial: Optional[T] = None) -> T:
    """Multiply all numbers together; `initial` seeds the product (required for empty input)."""
    if initial is None:
        return reduce(operator.mul, numbers)
    return reduce(operator.mul, numbers, initial)
def run_length_encode(data: Iterable[T]) -> Iterable[tuple[T, int]]:
    """
    Lazily yield (value, run_length) pairs for consecutive runs in `data`.

    Bugfix: the original called len() on groupby's group objects, which do
    not support len() and raised TypeError — groups must be consumed to count.
    """
    return (
        (value, sum(1 for _ in group))
        for value, group in groupby(data)
    )
# text conversion
# text conversion
def camel_to_snake_case(text: str, sep: str = "_", join_abbreviations: bool = False) -> str:
    """Convert CamelCase to snake_case; optionally collapse runs of single letters (abbreviations) into one part."""
    parts = [chunk.lower() for chunk in re.split(r'(?=[A-Z])', text) if chunk]
    if join_abbreviations and len(parts) > 1:
        # pair lengths come from the PRE-merge list, so a whole run of single
        # letters collapses right-to-left into one part (h,t,t,p -> http)
        original_pairs = list(enumerate(zip(parts[:-1], parts[1:])))
        for idx, (left, right) in reversed(original_pairs):
            if len(left) == 1 and len(right) == 1:
                parts[idx] = parts[idx] + parts.pop(idx + 1)
    return sep.join(parts)
def snake_to_camel_case(text: str) -> str:
    """
    Convert snake_case to CamelCase.

    Bugfix: the original called `part.captialize()` (typo), which raised
    AttributeError on every invocation; the correct method is `capitalize`.
    """
    return "".join(
        part.capitalize()
        for part in text.split("_")
        if part
    )
# textwrap
# textwrap
def columnize_dict(data: dict, n_columns=2, prefix="", sep=" ") -> str:
    """
    Render a dict repr-style across `n_columns` side-by-side columns.

    Each column holds a contiguous slice of the items (keys left, values
    right); the per-column strings are then pasted together pairwise with
    `columnize`.  Only the first column's key lines carry `prefix`.
    """
    # items per column, rounded so the leading columns take the remainder
    sub = (len(data) + 1) // n_columns
    return reduce(partial(columnize, sep=sep),
        (
            columnize(
                "\n".join([f"{'' if n else prefix}{i!r}" for i in data.keys() ][n*sub : (n+1)*sub]),
                "\n".join([f": {i!r}," for i in data.values()][n*sub : (n+1)*sub]),
            )
            for n in range(n_columns)
        )
    )
def columnize(left: str, right: str, prefix="", sep=" ") -> str:
    """Join two multi-line strings side by side; the left column is padded to its widest line."""
    left_lines  = left.split("\n")
    right_lines = right.split("\n")
    pad = max(map(len, left_lines)) if left_lines else 0
    rows = []
    for l, r in zip_longest(left_lines, right_lines, fillvalue=""):
        # rows with no right-hand text are not padded (avoids trailing spaces)
        rows.append(f"{prefix}{l.ljust(pad)}{sep}{r}" if r else f"{prefix}{l}")
    return "\n".join(rows)
# pathlib
# pathlib
def make_relative(path: Union[Path, str], parent: Path = None) -> Path:
    """Express `path` relative to `parent` (default: cwd), allowing one `..` step up; otherwise return it unchanged."""
    path = Path(path) if isinstance(path, str) else path
    parent = Path.cwd() if parent is None else parent
    # first try directly under parent, then one level up (prefixed with "..")
    for base, up in ((parent, None), (parent.parent, Path(""))):
        try:
            rel = path.relative_to(base)
        except ValueError:
            continue
        return rel if up is None else ".." / rel
    return path
# dictionaries
# dictionaries
def update_recursive(target: dict, source: dict):
    """ Update two config dictionaries recursively. """
    # merge `source` into `target` in place; nested dicts merge instead of replace
    for key, value in source.items():
        if not isinstance(value, dict):
            target[key] = value
            continue
        if key not in target:
            # preserve the target's mapping type (could be a dict subclass)
            target[key] = type(target)()
        update_recursive(target[key], value)
def map_tree(func: Callable[[T], S], val: Union[Mapping[Hashable, T], tuple[T, ...], list[T], T]) -> Union[Mapping[Hashable, S], tuple[S, ...], list[S], S]:
    """Apply `func` to every leaf of a nested structure of mappings, tuples and lists, keeping the structure."""
    if isinstance(val, collections.abc.Mapping):
        return {key: map_tree(func, item) for key, item in val.items()}
    if isinstance(val, tuple):
        return tuple(map_tree(func, item) for item in val)
    if isinstance(val, list):
        return [map_tree(func, item) for item in val]
    # anything else is treated as a leaf
    return func(val)
def flatten_tree(val, *, sep=".", prefix=None):
    """Flatten nested mappings/sequences into one dict keyed by sep-joined paths; list/tuple indices render as "[i]"."""
    def join(base, part):
        return f"{base}{sep}{part}" if base else part
    if isinstance(val, collections.abc.Mapping):
        return {
            flat_key: leaf
            for key, sub in val.items()
            for flat_key, leaf in flatten_tree(sub, sep=sep, prefix=join(prefix, key)).items()
        }
    if isinstance(val, (tuple, list)):
        return {
            flat_key: leaf
            for idx, sub in enumerate(val)
            for flat_key, leaf in flatten_tree(sub, sep=sep, prefix=join(prefix, f"[{idx}]")).items()
        }
    # a bare leaf with no accumulated prefix is returned unchanged, not wrapped
    return {prefix: val} if prefix else val
# conversions
# conversions
def hex2tuple(data: str) -> tuple[int]:
    """Parse a hex string such as "#aabbcc" into a tuple of byte values."""
    digits = data.removeprefix("#")
    return tuple(
        int(digits[pos:pos + 2], 16)
        for pos in range(0, len(digits), 2)
    )
# repr shims
# repr shims
class CustomRepr:
    """Wrap a pre-rendered string so that both str() and repr() return it verbatim."""
    def __init__(self, repr_str: str):
        self.repr_str = repr_str
    def __str__(self):
        return self.repr_str
    # repr intentionally identical to str
    __repr__ = __str__
# Meta Params Module proxy
# Meta Params Module proxy
class MetaModuleProxy:
    """
    Proxy that shadows a wrapped module's attributes with entries from a
    separate `params` mapping (presumably fast weights for meta-learning —
    verify against callers).  Reads check `params` first and fall through to
    the module; writes of any name other than the two internal slots go to
    `params`.
    """
    def __init__(self, module, params):
        self._module = module
        self._params = params
    def __getattr__(self, name):
        # only called when normal lookup fails (i.e. not for _module/_params);
        # object.__getattribute__ avoids re-entering this hook
        params = super().__getattribute__("_params")
        if name in params:
            return params[name]
        else:
            return getattr(super().__getattribute__("_module"), name)
    def __setattr__(self, name, value):
        # everything except the two internal slots is stored in the params mapping
        if name not in ("_params", "_module"):
            super().__getattribute__("_params")[name] = value
        else:
            super().__setattr__(name, value)

590
ifield/utils/loss.py Normal file
View File

@@ -0,0 +1,590 @@
from abc import abstractmethod, ABC
from dataclasses import dataclass, field, fields, MISSING
from functools import wraps
from matplotlib import pyplot as plt
from matplotlib.artist import Artist
from tabulate import tabulate
from torch import nn
from typing import Optional, TypeVar, Union
import inspect
import math
import pytorch_lightning as pl
import typing
import warnings
# type variable bound to the schedule base class (usable in annotations before the class is defined)
HParamSchedule = TypeVar("HParamSchedule", bound="HParamScheduleBase")
# anything parse_dsl accepts: a schedule, a plain number, or a DSL expression string
Schedulable = Union[HParamSchedule, int, float, str]
class HParamScheduleBase(ABC):
    """
    Base class for hyper-parameter schedules, forming a small expression DSL.

    Subclasses auto-register themselves by name in `_subclasses` (that dict is
    the eval namespace used by `parse_dsl`).  Infix subclasses additionally
    declare the operator token they implement in `_infix`; the operator
    dunders below look that token up in the registry, so e.g. `a + b` builds
    whichever registered node declares `_infix == "+"`.
    """
    _subclasses = {} # shared reference intended
    def __init_subclass__(cls):
        # private (underscore-prefixed) subclasses stay out of the DSL namespace
        if not cls.__name__.startswith("_"):
            cls._subclasses[cls.__name__] = cls
    # dataclass field templates inherited by the @dataclass subclasses
    _infix      : Optional[str] = field(init=False, repr=False, default=None)
    _param_name : Optional[str] = field(init=False, repr=False, default=None)
    _expr       : Optional[str] = field(init=False, repr=False, default=None)
    def get(self, module: nn.Module, *, trainer: Optional[pl.Trainer] = None) -> float:
        """Evaluate the schedule for `module`; during training the value follows the trainer's fractional epoch."""
        if not module.training:
            return self.get_eval_value()
        if trainer is None:
            trainer = module.trainer # this assumes `module` is a pl.LightningModule
        # fractional epoch: completed epochs plus progress through the current one
        value = self.get_train_value(
            epoch = trainer.current_epoch + (trainer.fit_loop.epoch_loop.batch_progress.current.processed / trainer.num_training_batches),
        )
        # periodically log non-constant schedules so their trajectory is visible
        if trainer.logger is not None and self._param_name is not None and self.__class__ is not Const and trainer.global_step % 15 == 0:
            trainer.logger.log_metrics({
                f"HParamSchedule/{self._param_name}": value,
            }, step=trainer.global_step)
        return value
    def _gen_data(self, n_epochs, steps_per_epoch=1000):
        """Yield (fractional_epoch, train value) samples over `n_epochs`, for plotting and validation."""
        for epoch in range(n_epochs):
            for step in range(steps_per_epoch):
                x = epoch + step / steps_per_epoch
                yield x, self.get_train_value(x)
    def plot(self, *a, ax: Optional[plt.Axes] = None, **kw) -> Artist:
        """Plot the schedule's training-time trajectory; positional/keyword args are forwarded to `_gen_data`."""
        if ax is None:
            ax = plt.gca()
        out = ax.plot(*zip(*self._gen_data(*a, **kw)), label=self._expr)
        ax.set_title(self._param_name)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Value")
        ax.legend()
        return out
    def assert_positive(self, *a, **kw):
        """Assert the schedule never goes negative over the sampled range."""
        for epoch, val in self._gen_data(*a, **kw):
            assert val >= 0, f"{epoch=}, {val=}"
    @abstractmethod
    def get_eval_value(self) -> float:
        """Value used outside training."""
        ...
    @abstractmethod
    def get_train_value(self, epoch: float) -> float:
        """Value at the given fractional training epoch."""
        ...
    # -- operator overloads ------------------------------------------------
    # Every operator dispatches through the registry: the subclass whose
    # `_infix` matches the token is instantiated with (lhs, rhs).  This one
    # helper replaces ~40 copies of the same four-line lookup loop.
    def _dispatch_infix(self, token, lhs, rhs):
        for cls in self._subclasses.values():
            if cls._infix == token:
                return cls(lhs, rhs)
        return NotImplemented
    def __add__      (self, rhs): return self._dispatch_infix("+",  self, rhs)
    def __radd__     (self, lhs): return self._dispatch_infix("+",  lhs, self)
    def __sub__      (self, rhs): return self._dispatch_infix("-",  self, rhs)
    def __rsub__     (self, lhs): return self._dispatch_infix("-",  lhs, self)
    def __mul__      (self, rhs): return self._dispatch_infix("*",  self, rhs)
    def __rmul__     (self, lhs): return self._dispatch_infix("*",  lhs, self)
    def __matmul__   (self, rhs): return self._dispatch_infix("@",  self, rhs)
    def __rmatmul__  (self, lhs): return self._dispatch_infix("@",  lhs, self)
    def __truediv__  (self, rhs): return self._dispatch_infix("/",  self, rhs)
    def __rtruediv__ (self, lhs): return self._dispatch_infix("/",  lhs, self)
    def __floordiv__ (self, rhs): return self._dispatch_infix("//", self, rhs)
    def __rfloordiv__(self, lhs): return self._dispatch_infix("//", lhs, self)
    def __mod__      (self, rhs): return self._dispatch_infix("%",  self, rhs)
    def __rmod__     (self, lhs): return self._dispatch_infix("%",  lhs, self)
    def __pow__      (self, rhs): return self._dispatch_infix("**", self, rhs)
    def __rpow__     (self, lhs): return self._dispatch_infix("**", lhs, self)
    def __lshift__   (self, rhs): return self._dispatch_infix("<<", self, rhs)
    def __rlshift__  (self, lhs): return self._dispatch_infix("<<", lhs, self)
    def __rshift__   (self, rhs): return self._dispatch_infix(">>", self, rhs)
    def __rrshift__  (self, lhs): return self._dispatch_infix(">>", lhs, self)
    def __and__      (self, rhs): return self._dispatch_infix("&",  self, rhs)
    def __rand__     (self, lhs): return self._dispatch_infix("&",  lhs, self)
    def __xor__      (self, rhs): return self._dispatch_infix("^",  self, rhs)
    def __rxor__     (self, lhs): return self._dispatch_infix("^",  lhs, self)
    def __or__       (self, rhs): return self._dispatch_infix("|",  self, rhs)
    def __ror__      (self, lhs): return self._dispatch_infix("|",  lhs, self)
    # comparisons deliberately have no reflected variants (matches the original)
    def __ge__(self, rhs): return self._dispatch_infix(">=", self, rhs)
    def __gt__(self, rhs): return self._dispatch_infix(">",  self, rhs)
    def __le__(self, rhs): return self._dispatch_infix("<=", self, rhs)
    def __lt__(self, rhs): return self._dispatch_infix("<",  self, rhs)
    def __bool__(self):
        # schedules are truthy by default, so `if schedule:` means "is present"
        return True
    def __neg__(self):
        # unary minus is expressed through the subtraction node as (0 - self)
        return self._dispatch_infix("-", 0, self)
    @property
    def is_const(self) -> bool:
        # overridden by Const and by infix nodes whose children are all constant
        return False
def parse_dsl(config: Schedulable, name=None) -> HParamSchedule:
    """
    Parse a schedule config into a schedule object.

    Accepts an already-built schedule (returned as-is), a DSL expression
    string (evaluated with the registered schedule classes as namespace), or
    a plain number (wrapped in Const).  `name` labels the schedule for logging.
    """
    if isinstance(config, HParamScheduleBase):
        return config
    elif isinstance(config, str):
        # NOTE(review): eval of config strings — builtins are stripped, but
        # this still must never be fed untrusted input
        out = eval(config, {"__builtins__": {}, "lg": math.log10}, HParamScheduleBase._subclasses)
        if not isinstance(out, HParamScheduleBase):
            out = Const(out)
    else:
        out = Const(config)
    out._expr = config
    out._param_name = name
    return out
# decorator
# decorator
def ensure_schedulables(func):
    """
    Decorator for __init__-like functions: every parameter annotated as
    HParamSchedule (or a Union containing it) is passed through parse_dsl,
    so callers may supply plain numbers or DSL expression strings.
    """
    signature = inspect.signature(func)
    module_name = func.__qualname__.removesuffix(".__init__")
    @wraps(func)
    def wrapper(*a, **kw):
        bound_args = signature.bind(*a, **kw)
        for param_name, param in signature.parameters.items():
            # match both bare HParamSchedule annotations and Optional/Union forms
            type_origin = typing.get_origin(param.annotation)
            type_args = typing.get_args (param.annotation)
            if type_origin is HParamSchedule or (type_origin is Union and (HParamSchedule in type_args or HParamScheduleBase in type_args)):
                if param_name in bound_args.arguments:
                    bound_args.arguments[param_name] = parse_dsl(bound_args.arguments[param_name], name=f"{module_name}.{param_name}")
                elif param.default is not param.empty:
                    # unbound params with defaults are parsed too, so the
                    # wrapped function always receives schedule objects
                    bound_args.arguments[param_name] = parse_dsl(param.default, name=f"{module_name}.{param_name}")
        return func(
            *bound_args.args,
            **bound_args.kwargs,
        )
    return wrapper
# https://easings.net/
@dataclass
class _InfixBase(HParamScheduleBase):
    """Base for binary-operator schedule nodes; either operand may be a schedule or a plain number."""
    l : Union[HParamSchedule, int, float]
    r : Union[HParamSchedule, int, float]
    def _operation(self, l: float, r: float) -> float:
        # subclasses implement the actual arithmetic on evaluated operands
        raise NotImplementedError
    def get_eval_value(self) -> float:
        # plain numbers pass through; schedules are evaluated first
        return self._operation(
            self.l.get_eval_value() if isinstance(self.l, HParamScheduleBase) else self.l,
            self.r.get_eval_value() if isinstance(self.r, HParamScheduleBase) else self.r,
        )
    def get_train_value(self, epoch: float) -> float:
        return self._operation(
            self.l.get_train_value(epoch) if isinstance(self.l, HParamScheduleBase) else self.l,
            self.r.get_train_value(epoch) if isinstance(self.r, HParamScheduleBase) else self.r,
        )
    def __bool__(self):
        # constant subtrees are truth-tested by value; dynamic ones are always truthy
        if self.is_const:
            return bool(self.get_eval_value())
        else:
            return True
    @property
    def is_const(self) -> bool:
        # an infix node is constant iff both children are
        return (self.l.is_const if isinstance(self.l, HParamScheduleBase) else True) \
            and (self.r.is_const if isinstance(self.r, HParamScheduleBase) else True)
@dataclass
class Add(_InfixBase):
    """ adds the results of two schedules """
    # registers this node for the "+" operator in the schedule DSL
    _infix : Optional[str] = field(init=False, repr=False, default="+")
    def _operation(self, l: float, r: float) -> float:
        return l + r
@dataclass
class Sub(_InfixBase):
    """ subtracts the results of two schedules """
    # registers this node for the "-" operator (also used for unary negation as 0 - x)
    _infix : Optional[str] = field(init=False, repr=False, default="-")
    def _operation(self, l: float, r: float) -> float:
        return l - r
@dataclass
class Prod(_InfixBase):
    """ multiplies the results of two schedules """
    # registers this node for the "*" operator
    _infix : Optional[str] = field(init=False, repr=False, default="*")
    def _operation(self, l: float, r: float) -> float:
        return l * r
    @property
    def is_const(self) -> bool: # a constant zero factor annihilates the product, making the node constant regardless of the other side
        l = self.l.get_eval_value() if isinstance(self.l, HParamScheduleBase) and self.l.is_const else self.l
        r = self.r.get_eval_value() if isinstance(self.r, HParamScheduleBase) and self.r.is_const else self.r
        return l == 0 or r == 0 or super().is_const
@dataclass
class Div(_InfixBase):
    """ divides the results of two schedules """
    # registers this node for the "/" operator; division by a zero-valued schedule raises
    _infix : Optional[str] = field(init=False, repr=False, default="/")
    def _operation(self, l: float, r: float) -> float:
        return l / r
@dataclass
class Pow(_InfixBase):
    """ raises the results of one schedule to the other """
    # registers this node for the "**" operator
    _infix : Optional[str] = field(init=False, repr=False, default="**")
    def _operation(self, l: float, r: float) -> float:
        return l ** r
@dataclass
class Gt(_InfixBase):
    """ compares the results of two schedules """
    # registered for ">"; the boolean result behaves as 0/1 in arithmetic
    _infix : Optional[str] = field(init=False, repr=False, default=">")
    def _operation(self, l: float, r: float) -> float:
        return l > r
@dataclass
class Lt(_InfixBase):
    """ compares the results of two schedules """
    # registered for "<"; the boolean result behaves as 0/1 in arithmetic
    _infix : Optional[str] = field(init=False, repr=False, default="<")
    def _operation(self, l: float, r: float) -> float:
        return l < r
@dataclass
class Ge(_InfixBase):
    """ compares the results of two schedules """
    # registered for ">="; the boolean result behaves as 0/1 in arithmetic
    _infix : Optional[str] = field(init=False, repr=False, default=">=")
    def _operation(self, l: float, r: float) -> float:
        return l >= r
@dataclass
class Le(_InfixBase):
    """ compares the results of two schedules """
    # registered for "<="; the boolean result behaves as 0/1 in arithmetic
    _infix : Optional[str] = field(init=False, repr=False, default="<=")
    def _operation(self, l: float, r: float) -> float:
        return l <= r
@dataclass
class Const(HParamScheduleBase):
    """ A way to ensure .get(...) exists """
    # the constant value, returned identically in train and eval
    c : Union[int, float]
    def get_eval_value(self) -> float:
        return self.c
    def get_train_value(self, epoch: float) -> float:
        return self.c
    def __bool__(self):
        # truthiness follows the wrapped value, so Const(0) is falsy
        return bool(self.get_eval_value())
    @property
    def is_const(self) -> bool:
        return True
@dataclass
class Step(HParamScheduleBase):
    """ steps from 0 to 1 at `epoch` """
    # epoch at which the output switches from 0 to 1 (inclusive)
    epoch : float
    def get_eval_value(self) -> float:
        # eval assumes training finished, i.e. the step has happened
        return 1
    def get_train_value(self, epoch: float) -> float:
        return 1 if epoch >= self.epoch else 0
@dataclass
class Linear(HParamScheduleBase):
    """ linear from 0 to 1 over `n_epochs`, delayed by `offset` """
    # epochs to ramp from 0 to 1
    n_epochs : float
    # epochs to wait before starting the ramp
    offset : float = 0
    def get_eval_value(self) -> float:
        return 1
    def get_train_value(self, epoch: float) -> float:
        # degenerate (non-positive length) schedules count as already finished
        if self.n_epochs <= 0: return 1
        return min(max(epoch - self.offset, 0) / self.n_epochs, 1)
@dataclass
class EaseSin(HParamScheduleBase): # effectively 1-CosineAnnealing
    """ sinusoidal ease in-out from 0 to 1 over `n_epochs`, delayed by `offset` """
    # epochs to ease from 0 to 1
    n_epochs : float
    # epochs to wait before starting
    offset : float = 0
    def get_eval_value(self) -> float:
        return 1
    def get_train_value(self, epoch: float) -> float:
        # guard against division by zero, matching Linear's behavior for
        # degenerate (non-positive length) schedules
        if self.n_epochs <= 0: return 1
        x = min(max(epoch - self.offset, 0) / self.n_epochs, 1)
        return -(math.cos(math.pi * x) - 1) / 2
@dataclass
class EaseExp(HParamScheduleBase):
    """ exponential ease in-out from 0 to 1 over `n_epochs`, delayed by `offset` """
    # epochs to ease from 0 to 1
    n_epochs : float
    # epochs to wait before starting
    offset : float = 0
    def get_eval_value(self) -> float:
        return 1
    def get_train_value(self, epoch: float) -> float:
        # guard against division by zero, matching Linear's behavior
        if self.n_epochs <= 0:
            return 1
        t = epoch - self.offset
        # clamp both endpoints exactly: the closed-form below yields ~0.0005
        # at x == 0 and ~0.9995 at x == 1, so the boundaries are special-cased
        # (standard easeInOutExpo behavior)
        if t <= 0:
            return 0
        if t >= self.n_epochs:
            return 1
        x = t / self.n_epochs
        return (
            2**(20*x - 10) / 2
            if x < 0.5 else
            (2 - 2**(-20*x + 10)) / 2
        )
@dataclass
class Steps(HParamScheduleBase):
    """ Starts at 1, multiply by gamma every n epochs. Models StepLR in pytorch """
    # epochs between successive decays
    step_size: float
    # multiplicative decay factor applied at each step
    gamma: float = 0.1
    def get_eval_value(self) -> float:
        return 1
    def get_train_value(self, epoch: float) -> float:
        # the number of completed steps determines the exponent
        return self.gamma**int(epoch / self.step_size)
@dataclass
class MultiStep(HParamScheduleBase):
    """ Starts at 1, multiply by gamma every milestone epoch. Models MultiStepLR in pytorch """
    # epochs at which the value decays by another factor of gamma (assumed ascending)
    milestones: list[float]
    # multiplicative decay factor per milestone
    gamma: float = 0.1
    def get_eval_value(self) -> float:
        return 1
    def get_train_value(self, epoch: float) -> float:
        # scan milestones from the last one down; the first one already passed
        # tells how many decays have occurred
        for i, m in list(enumerate(self.milestones))[::-1]:
            if epoch >= m:
                return self.gamma**(i+1)
        return 1
@dataclass
class Epoch(HParamScheduleBase):
    """ The current epoch, starting at 0 """
    def get_eval_value(self) -> float:
        # outside training there is no meaningful epoch; report 0
        return 0
    def get_train_value(self, epoch: float) -> float:
        return epoch
@dataclass
class Offset(HParamScheduleBase):
    """ Offsets the epoch for the subexpression, clamped above 0. Positive offsets makes it happen later """
    # the schedule (or plain number) to delay
    expr : Union[HParamSchedule, int, float]
    # epochs to shift the subexpression by
    offset : float
    def get_eval_value(self) -> float:
        return self.expr.get_eval_value() if isinstance(self.expr, HParamScheduleBase) else self.expr
    def get_train_value(self, epoch: float) -> float:
        # plain numbers pass through; schedules see a shifted, never-negative epoch
        return self.expr.get_train_value(max(epoch - self.offset, 0)) if isinstance(self.expr, HParamScheduleBase) else self.expr
@dataclass
class Mod(HParamScheduleBase):
    """ The epoch in the subexpression is subject to a modulus. Use for warm restarts """
    # period of the restart cycle, in epochs
    modulus : float
    # schedule evaluated with the wrapped-around epoch
    expr : Union[HParamSchedule, int, float]
    def get_eval_value(self) -> float:
        return self.expr.get_eval_value() if isinstance(self.expr, HParamScheduleBase) else self.expr
    def get_train_value(self, epoch: float) -> float:
        return self.expr.get_train_value(epoch % self.modulus) if isinstance(self.expr, HParamScheduleBase) else self.expr
def main():
    """CLI: with no expression arguments, list the available DSL operations; otherwise plot each expression over n_epochs."""
    import sys, rich.pretty
    if not sys.argv[2:]:
        print(f"Usage: {sys.argv[0]} n_epochs 'expression'")
        print("Available operations:")
        def mk_ops():
            # yield (syntax, description) rows for every registered schedule class
            for name, cls in HParamScheduleBase._subclasses.items():
                if isinstance(cls._infix, str):
                    yield (cls._infix, f"(infix) {cls.__doc__.strip()}")
                else:
                    # render a call signature from the dataclass fields
                    yield (
                        f"""{name}({', '.join(
                            i.name
                            if i.default is MISSING else
                            f"{i.name}={i.default!r}"
                            for i in fields(cls)
                        )})""",
                        cls.__doc__.strip(),
                    )
        rich.print(tabulate(sorted(mk_ops()), tablefmt="plain"))
    else:
        n_epochs = int(sys.argv[1])
        schedules = [parse_dsl(arg, name="cli arg") for arg in sys.argv[2:]]
        # all schedules share one axes object so the plots overlay
        ax = plt.gca()
        print("[")
        for schedule in schedules:
            rich.print(f" {schedule}, # {schedule.is_const = }")
            schedule.plot(n_epochs, ax=ax)
        print("]")
        plt.show()
if __name__ == "__main__":
    main()

View File

View File

@@ -0,0 +1,96 @@
import torch
from torch.autograd import grad
def hessian(y: torch.Tensor, x: torch.Tensor, check=False, detach=False) -> torch.Tensor:
    """
    hessian of y wrt x
    y: shape (..., Y)
    x: shape (..., X)
    check: if True, assert the result contains no NaNs
    detach: if True, detach the second-order gradients from the graph
    return: shape (..., Y, X, X)
    """
    assert x.requires_grad
    assert y.grad_fn
    grad_y = torch.ones_like(y[..., 0]).to(y.device) # reuse -> less memory
    hess = torch.stack([
        # second derivative: differentiate each component of dydx wrt x again
        torch.stack(
            gradients(
                *(dydx[..., j] for j in range(x.shape[-1])),
                wrt=x,
                grad_outputs=[grad_y]*x.shape[-1],
                detach=detach,
            ),
            dim = -2,
        )
        # first derivative of each output feature of y wrt x
        for dydx in gradients(*(y[..., i] for i in range(y.shape[-1])), wrt=x)
    ], dim=-3)
    if check:
        # bugfix: previously `assert hess.isnan().any()` asserted that NaNs
        # WERE present; `check` is meant to verify the hessian is NaN-free
        assert not hess.isnan().any()
    return hess
def laplace(y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Laplacian of y wrt x: the divergence of the gradient."""
    return divergence(*gradients(y, wrt=x), x)
def divergence(y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Divergence of y wrt x: sum_i d(y_i)/d(x_i), with a trailing singleton dim."""
    assert x.requires_grad
    assert y.grad_fn
    terms = []
    for i in range(y.shape[-1]):
        component = y[..., i]
        # graph is kept so the result remains differentiable
        (dydx_i,) = grad(component, x, torch.ones_like(component), create_graph=True)
        terms.append(dydx_i[..., i:i+1])
    return sum(terms)
def gradients(*ys, wrt, grad_outputs=None, detach=False) -> list[torch.Tensor]:
    """Gradient of each tensor in ys wrt `wrt`; the graph is retained so higher-order derivatives remain possible."""
    assert wrt.requires_grad
    assert all(y.grad_fn for y in ys)
    if grad_outputs is None:
        grad_outputs = [torch.ones_like(y) for y in ys]
    out = []
    for y, y_grad in zip(ys, grad_outputs):
        (g,) = grad(
            [y],
            [wrt],
            grad_outputs=y_grad,
            create_graph=True,
        )
        out.append(g.detach() if detach else g)
    return out
def jacobian(y: torch.Tensor, x: torch.Tensor, check=False, detach=False) -> torch.Tensor:
    """
    jacobian of `y` w.r.t. `x`
    y: shape (..., Y)
    x: shape (..., X)
    check: if True, assert the result contains no NaNs
    detach: if True, detach the gradients from the graph
    return: shape (..., Y, X)
    """
    assert x.requires_grad
    assert y.grad_fn
    y_grad = torch.ones_like(y[..., 0])
    jac = torch.stack(
        gradients(
            *(y[..., i] for i in range(y.shape[-1])),
            wrt=x,
            # bugfix: one grad_output per y-component.  The original used
            # x.shape[-1] here; since gradients() zips ys with grad_outputs,
            # rows were silently dropped whenever Y > X.
            grad_outputs=[y_grad]*y.shape[-1],
            detach=detach,
        ),
        dim=-2,
    )
    if check:
        # bugfix: previously asserted that NaNs WERE present
        assert not jac.isnan().any()
    return jac