Files
marf/ifield/utils/loss.py
2025-01-09 15:43:11 +01:00

591 lines
18 KiB
Python

from abc import abstractmethod, ABC
from dataclasses import dataclass, field, fields, MISSING
from functools import wraps
from matplotlib import pyplot as plt
from matplotlib.artist import Artist
from tabulate import tabulate
from torch import nn
from typing import Optional, TypeVar, Union
import inspect
import math
import pytorch_lightning as pl
import typing
import warnings
# Any concrete schedule object. The bound is a forward reference because
# HParamScheduleBase is defined right below.
HParamSchedule = TypeVar("HParamSchedule", bound="HParamScheduleBase")
# Everything `parse_dsl` accepts: a ready schedule, a plain number, or a DSL string.
Schedulable = typing.Union[HParamSchedule, int, float, str]
class HParamScheduleBase(ABC):
    """Base class for hyper-parameter schedules.

    A schedule maps a (fractional) training epoch to a value through
    :meth:`get_train_value` and yields a fixed value during evaluation through
    :meth:`get_eval_value`.

    Schedules combine with Python operators. ``a + b`` builds the registered
    subclass whose ``_infix`` is ``"+"``, and likewise for the other operators.
    When no registered subclass implements an operator, the overload returns
    ``NotImplemented``.
    """

    # Registry of all public subclasses, keyed by class name. It is one dict
    # shared by every subclass on purpose: it serves as the DSL namespace for
    # `parse_dsl` and as the lookup table for the operator overloads below.
    _subclasses = {}

    # Plain class-level defaults, overridden per subclass / per instance.
    # NOTE: this base is not a dataclass. Using `dataclasses.field(...)` here
    # would leave the Field object itself as the attribute value, so e.g. a
    # schedule built without `parse_dsl` would report a Field as its name.
    _infix: Optional[str] = None       # operator symbol implemented by an infix subclass
    _param_name: Optional[str] = None  # logging key suffix / plot title (set by parse_dsl)
    _expr: Optional[str] = None        # spec the schedule was parsed from (set by parse_dsl)

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # underscore-prefixed helpers (e.g. _InfixBase) stay out of the registry
        if not cls.__name__.startswith("_"):
            cls._subclasses[cls.__name__] = cls

    def get(self, module: nn.Module, *, trainer: Optional[pl.Trainer] = None) -> float:
        """Return the schedule value appropriate for `module`'s current mode.

        In training mode, the value is taken at the fractional epoch
        ``current_epoch + processed_batches / num_training_batches``. Every
        15 global steps, a named, non-`Const` schedule also logs the value
        to the trainer's logger as ``HParamSchedule/<param name>``.

        In eval mode, :meth:`get_eval_value` is returned.
        """
        if not module.training:
            return self.get_eval_value()
        if trainer is None:
            trainer = module.trainer  # this assumes `module` is a pl.LightningModule
        batches_done = trainer.fit_loop.epoch_loop.batch_progress.current.processed
        value = self.get_train_value(
            epoch=trainer.current_epoch + batches_done / trainer.num_training_batches,
        )
        if (
            trainer.logger is not None
            and self._param_name is not None
            and self.__class__ is not Const
            and trainer.global_step % 15 == 0
        ):
            trainer.logger.log_metrics(
                {f"HParamSchedule/{self._param_name}": value},
                step=trainer.global_step,
            )
        return value

    def _gen_data(self, n_epochs, steps_per_epoch=1000):
        """Yield ``(epoch, train_value)`` at `steps_per_epoch` evenly spaced points per epoch."""
        for epoch in range(n_epochs):
            for step in range(steps_per_epoch):
                t = epoch + step / steps_per_epoch
                yield t, self.get_train_value(t)

    def plot(self, *a, ax: Optional[plt.Axes] = None, **kw) -> list[Artist]:
        """Plot the training-time curve on `ax` (default: current axes).

        Positional/keyword args are forwarded to :meth:`_gen_data`
        (``n_epochs``, ``steps_per_epoch``). Returns the artists created by
        ``ax.plot``.
        """
        if ax is None:
            ax = plt.gca()
        # schedules built directly (not via parse_dsl) have no source
        # expression; fall back to their repr so the legend stays informative
        label = self._expr if self._expr is not None else repr(self)
        out = ax.plot(*zip(*self._gen_data(*a, **kw)), label=label)
        ax.set_title(self._param_name)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Value")
        ax.legend()
        return out

    def assert_positive(self, *a, **kw):
        """Assert every sampled training value is non-negative (args as for :meth:`_gen_data`)."""
        for epoch, val in self._gen_data(*a, **kw):
            assert val >= 0, f"{epoch=}, {val=}"

    @abstractmethod
    def get_eval_value(self) -> float:
        ...

    @abstractmethod
    def get_train_value(self, epoch: float) -> float:
        ...

    def _combine(self, infix: str, lhs, rhs):
        """Build ``cls(lhs, rhs)`` for the first registered subclass whose `_infix` equals `infix`.

        Returns ``NotImplemented`` when no subclass implements `infix`, so
        Python's normal binary-operator fallback applies.
        """
        for cls in self._subclasses.values():
            if cls._infix == infix:
                return cls(lhs, rhs)
        return NotImplemented

    # --- arithmetic / bitwise operators (forward and reflected) ---------------
    def __add__(self, rhs):
        return self._combine("+", self, rhs)

    def __radd__(self, lhs):
        return self._combine("+", lhs, self)

    def __sub__(self, rhs):
        return self._combine("-", self, rhs)

    def __rsub__(self, lhs):
        return self._combine("-", lhs, self)

    def __mul__(self, rhs):
        return self._combine("*", self, rhs)

    def __rmul__(self, lhs):
        return self._combine("*", lhs, self)

    def __matmul__(self, rhs):
        return self._combine("@", self, rhs)

    def __rmatmul__(self, lhs):
        return self._combine("@", lhs, self)

    def __truediv__(self, rhs):
        return self._combine("/", self, rhs)

    def __rtruediv__(self, lhs):
        return self._combine("/", lhs, self)

    def __floordiv__(self, rhs):
        return self._combine("//", self, rhs)

    def __rfloordiv__(self, lhs):
        return self._combine("//", lhs, self)

    def __mod__(self, rhs):
        return self._combine("%", self, rhs)

    def __rmod__(self, lhs):
        return self._combine("%", lhs, self)

    def __pow__(self, rhs):
        return self._combine("**", self, rhs)

    def __rpow__(self, lhs):
        return self._combine("**", lhs, self)

    def __lshift__(self, rhs):
        return self._combine("<<", self, rhs)

    def __rlshift__(self, lhs):
        return self._combine("<<", lhs, self)

    def __rshift__(self, rhs):
        return self._combine(">>", self, rhs)

    def __rrshift__(self, lhs):
        return self._combine(">>", lhs, self)

    def __and__(self, rhs):
        return self._combine("&", self, rhs)

    def __rand__(self, lhs):
        return self._combine("&", lhs, self)

    def __xor__(self, rhs):
        return self._combine("^", self, rhs)

    def __rxor__(self, lhs):
        return self._combine("^", lhs, self)

    def __or__(self, rhs):
        return self._combine("|", self, rhs)

    def __ror__(self, lhs):
        return self._combine("|", lhs, self)

    # --- comparisons (Python reflects these automatically) ---------------------
    def __ge__(self, rhs):
        return self._combine(">=", self, rhs)

    def __gt__(self, rhs):
        return self._combine(">", self, rhs)

    def __le__(self, rhs):
        return self._combine("<=", self, rhs)

    def __lt__(self, rhs):
        return self._combine("<", self, rhs)

    def __bool__(self):
        # a time-varying schedule has no single truth value; treat it as set
        return True

    def __neg__(self):
        # -x is expressed as (0 - x); unary operators have no NotImplemented
        # fallback, so report a missing subtraction class as a TypeError
        result = self._combine("-", 0, self)
        if result is NotImplemented:
            raise TypeError(f"bad operand type for unary -: {type(self).__name__!r}")
        return result

    @property
    def is_const(self) -> bool:
        """Whether the value never changes over training (conservatively False here)."""
        return False
def parse_dsl(config: Schedulable, name=None) -> HParamSchedule:
    """Turn a schedule spec into a schedule object.

    - An existing schedule is returned unchanged; its name and expression are
      left untouched.
    - A string is evaluated as a DSL expression. Its namespace holds every
      registered schedule class plus ``lg`` (``math.log10``). A result that is
      not a schedule (e.g. ``"1e-3"``) is wrapped in `Const`.
    - Anything else (int/float) is wrapped in `Const`.

    In the last two cases, the original spec is stored as ``_expr`` and `name`
    as ``_param_name`` on the returned schedule.

    SECURITY: strings go through `eval`. An empty ``__builtins__`` is not a
    sandbox, so only pass trusted (config / CLI) strings here.
    """
    if isinstance(config, HParamScheduleBase):
        return config
    if isinstance(config, str):
        # Evaluate against a *copy* of the registry. `eval` writes top-level
        # assignment expressions such as "(Linear := 0)" into the locals
        # mapping, which would otherwise corrupt the shared class registry.
        namespace = dict(HParamScheduleBase._subclasses)
        out = eval(config, {"__builtins__": {}, "lg": math.log10}, namespace)
        if not isinstance(out, HParamScheduleBase):
            out = Const(out)
    else:
        out = Const(config)
    out._expr = config
    out._param_name = name
    return out
def ensure_schedulables(func):
    """Decorator: run schedule-typed arguments of `func` through `parse_dsl`.

    A parameter qualifies when it is annotated with `HParamSchedule` or
    `HParamScheduleBase`, or with a union containing either. This covers
    `Schedulable`, ``Optional[...]`` and ``X | Y`` unions.

    On every call, each qualifying argument is converted. When the caller
    omits it, its default is converted instead. The resulting schedule is
    named ``<qualname>.<param>``, with a trailing ``.__init__`` stripped from
    the qualname.
    """
    import types  # local: only needed to recognise PEP 604 (`X | Y`) unions

    signature = inspect.signature(func)
    module_name = func.__qualname__.removesuffix(".__init__")
    schedule_types = (HParamSchedule, HParamScheduleBase)
    # `types.UnionType` (Python >= 3.10) is the origin of `X | Y` unions
    union_origins = (Union, getattr(types, "UnionType", Union))

    def is_schedulable(annotation) -> bool:
        # a bare TypeVar/class has no origin, so test it directly
        if any(annotation is t for t in schedule_types):
            return True
        if typing.get_origin(annotation) in union_origins:
            args = typing.get_args(annotation)
            return any(arg is t for arg in args for t in schedule_types)
        return False

    # annotations are fixed at decoration time, so classify parameters once
    schedulable_params = [
        (param_name, param)
        for param_name, param in signature.parameters.items()
        if is_schedulable(param.annotation)
    ]

    @wraps(func)
    def wrapper(*a, **kw):
        bound_args = signature.bind(*a, **kw)
        for param_name, param in schedulable_params:
            qualified_name = f"{module_name}.{param_name}"
            if param_name in bound_args.arguments:
                bound_args.arguments[param_name] = parse_dsl(bound_args.arguments[param_name], name=qualified_name)
            elif param.default is not param.empty:
                # parsed per call so each instance gets its own schedule object
                bound_args.arguments[param_name] = parse_dsl(param.default, name=qualified_name)
        return func(
            *bound_args.args,
            **bound_args.kwargs,
        )
    return wrapper
# Easing curves below (EaseSin, EaseExp) follow the ease-in-out formulas at https://easings.net/
@dataclass
class _InfixBase(HParamScheduleBase):
    """Shared machinery for binary operators over schedules and/or plain numbers.

    Subclasses set ``_infix`` to their operator symbol and implement
    :meth:`_operation` on the two already-resolved operand values. The leading
    underscore keeps this helper out of the DSL registry.
    """
    l : Union[HParamSchedule, int, float]  # left operand: a schedule or a plain number
    r : Union[HParamSchedule, int, float]  # right operand: a schedule or a plain number

    def _operation(self, l: float, r: float) -> float:
        raise NotImplementedError  # supplied by each concrete operator subclass

    def _resolve(self, method: str, *args):
        """Return the (left, right) values, calling ``<method>(*args)`` on schedule operands."""
        return tuple(
            getattr(side, method)(*args) if isinstance(side, HParamScheduleBase) else side
            for side in (self.l, self.r)
        )

    def get_eval_value(self) -> float:
        return self._operation(*self._resolve("get_eval_value"))

    def get_train_value(self, epoch: float) -> float:
        return self._operation(*self._resolve("get_train_value", epoch))

    def __bool__(self):
        # only a fully constant expression has a meaningful truth value
        return bool(self.get_eval_value()) if self.is_const else True

    @property
    def is_const(self) -> bool:
        # plain numbers are constant; schedules report their own constness
        return all(
            side.is_const if isinstance(side, HParamScheduleBase) else True
            for side in (self.l, self.r)
        )
@dataclass
class Add(_InfixBase):
    """ adds the results of two schedules """
    # `+` operator; init=False keeps it out of the constructor
    _infix : Optional[str] = field(init=False, repr=False, default="+")

    def _operation(self, l: float, r: float) -> float:
        total = l + r
        return total
@dataclass
class Sub(_InfixBase):
    """ subtracts the results of two schedules """
    # `-` operator (also used by unary negation as 0 - x)
    _infix : Optional[str] = field(init=False, repr=False, default="-")

    def _operation(self, l: float, r: float) -> float:
        difference = l - r
        return difference
@dataclass
class Prod(_InfixBase):
    """ multiplies the results of two schedules """
    # `*` operator
    _infix : Optional[str] = field(init=False, repr=False, default="*")

    def _operation(self, l: float, r: float) -> float:
        product = l * r
        return product

    @property
    def is_const(self) -> bool:
        """Constant if both sides are, or if either side is a constant zero (product pinned to 0)."""
        def resolved(side):
            # constant schedules collapse to their value; anything else is kept as-is
            if isinstance(side, HParamScheduleBase) and side.is_const:
                return side.get_eval_value()
            return side
        lhs = resolved(self.l)
        rhs = resolved(self.r)
        return lhs == 0 or rhs == 0 or super().is_const
@dataclass
class Div(_InfixBase):
    """ divides the results of two schedules """
    # `/` operator (true division)
    _infix : Optional[str] = field(init=False, repr=False, default="/")

    def _operation(self, l: float, r: float) -> float:
        quotient = l / r
        return quotient
@dataclass
class Pow(_InfixBase):
    """ raises the results of one schedule to the other """
    # `**` operator: left operand is the base, right the exponent
    _infix : Optional[str] = field(init=False, repr=False, default="**")

    def _operation(self, l: float, r: float) -> float:
        return pow(l, r)
@dataclass
class Gt(_InfixBase):
    """ compares the results of two schedules """
    # `>` operator; yields a bool, usable as a 0/1 gate in arithmetic
    _infix : Optional[str] = field(init=False, repr=False, default=">")

    def _operation(self, l: float, r: float) -> float:
        is_greater = l > r
        return is_greater
@dataclass
class Lt(_InfixBase):
    """ compares the results of two schedules """
    # `<` operator; yields a bool, usable as a 0/1 gate in arithmetic
    _infix : Optional[str] = field(init=False, repr=False, default="<")

    def _operation(self, l: float, r: float) -> float:
        is_less = l < r
        return is_less
@dataclass
class Ge(_InfixBase):
    """ compares the results of two schedules """
    # `>=` operator; yields a bool, usable as a 0/1 gate in arithmetic
    _infix : Optional[str] = field(init=False, repr=False, default=">=")

    def _operation(self, l: float, r: float) -> float:
        at_least = l >= r
        return at_least
@dataclass
class Le(_InfixBase):
    """ compares the results of two schedules """
    # `<=` operator; yields a bool, usable as a 0/1 gate in arithmetic
    _infix : Optional[str] = field(init=False, repr=False, default="<=")

    def _operation(self, l: float, r: float) -> float:
        at_most = l <= r
        return at_most
@dataclass
class Const(HParamScheduleBase):
    """ A fixed value wrapped as a schedule, so .get(...) always exists """
    c : Union[int, float]  # the value returned in every mode

    def get_eval_value(self) -> float:
        return self.c

    def get_train_value(self, epoch: float) -> float:
        return self.c  # independent of the epoch

    def __bool__(self):
        # truthiness follows the wrapped value (e.g. Const(0) is falsy)
        value = self.get_eval_value()
        return bool(value)

    @property
    def is_const(self) -> bool:
        return True
@dataclass
class Step(HParamScheduleBase):
    """ steps from 0 to 1 at `epoch` """
    epoch : float  # first epoch at which the value is 1

    def get_eval_value(self) -> float:
        return 1  # evaluation uses the post-step value

    def get_train_value(self, epoch: float) -> float:
        # bool -> int: 1 once the threshold is reached, else 0
        return int(epoch >= self.epoch)
@dataclass
class Linear(HParamScheduleBase):
    """ linear from 0 to 1 over `n_epochs`, delayed by `offset` """
    n_epochs : float   # ramp duration in epochs
    offset : float = 0  # epochs to wait before the ramp starts

    def get_eval_value(self) -> float:
        return 1  # evaluation uses the fully ramped-up value

    def get_train_value(self, epoch: float) -> float:
        # a non-positive duration disables the ramp: always 1
        if self.n_epochs <= 0:
            return 1
        elapsed = max(epoch - self.offset, 0)
        return min(elapsed / self.n_epochs, 1)
@dataclass
class EaseSin(HParamScheduleBase): # effectively 1-CosineAnnealing
    """ sinusoidal ease in-out from 0 to 1 over `n_epochs`, delayed by `offset` """
    n_epochs : float   # ramp duration in epochs
    offset : float = 0  # epochs to wait before the ramp starts

    def get_eval_value(self) -> float:
        return 1  # evaluation uses the fully ramped-up value

    def get_train_value(self, epoch: float) -> float:
        # Non-positive duration: same convention as Linear (always 1). Without
        # this guard, n_epochs == 0 raised ZeroDivisionError.
        if self.n_epochs <= 0:
            return 1
        # progress through the ramp, clamped to [0, 1]
        x = min(max(epoch - self.offset, 0) / self.n_epochs, 1)
        # easeInOutSine: 0 at x=0, 1 at x=1
        return -(math.cos(math.pi * x) - 1) / 2
@dataclass
class EaseExp(HParamScheduleBase):
    """ exponential ease in-out from 0 to 1 over `n_epochs`, delayed by `offset` """
    n_epochs : float   # ramp duration in epochs
    offset : float = 0  # epochs to wait before the ramp starts

    def get_eval_value(self) -> float:
        return 1  # evaluation uses the fully ramped-up value

    def get_train_value(self, epoch: float) -> float:
        # Non-positive duration: same convention as Linear (always 1). Without
        # this guard, n_epochs == 0 at epoch == offset raised ZeroDivisionError.
        if self.n_epochs <= 0:
            return 1
        x = (epoch - self.offset) / self.n_epochs
        # Pin the endpoints, as easeInOutExpo does: the exponential branches
        # only approach 0 and 1 (they give ~0.0005 / ~0.9995 at x = 0 / 1).
        if x <= 0:
            return 0
        if x >= 1:
            return 1
        if x < 0.5:
            return 2**(20*x - 10) / 2
        return (2 - 2**(-20*x + 10)) / 2
@dataclass
class Steps(HParamScheduleBase):
    """ Starts at 1, multiply by gamma every step_size epochs. Models StepLR in pytorch """
    step_size: float    # epochs between decays
    gamma: float = 0.1  # multiplicative factor per decay

    def get_eval_value(self) -> float:
        return 1

    def get_train_value(self, epoch: float) -> float:
        # number of whole step_size periods completed so far
        decays = int(epoch / self.step_size)
        return self.gamma ** decays
@dataclass
class MultiStep(HParamScheduleBase):
    """ Starts at 1, multiply by gamma at every milestone epoch. Models MultiStepLR in pytorch """
    milestones: list[float]  # epochs at which the value is multiplied by gamma
    gamma: float = 0.1       # multiplicative factor per milestone

    def get_eval_value(self) -> float:
        return 1

    def get_train_value(self, epoch: float) -> float:
        # Count the milestones already reached. This equals the former
        # reverse-index scan for sorted milestones, but stays correct when
        # they are unsorted and avoids copying the list on every call.
        passed = sum(1 for milestone in self.milestones if epoch >= milestone)
        return self.gamma ** passed if passed else 1
@dataclass
class Epoch(HParamScheduleBase):
    """ The current epoch, starting at 0 """

    def get_eval_value(self) -> float:
        return 0  # no training progress is defined outside training

    def get_train_value(self, epoch: float) -> float:
        current = epoch  # identity: expose the (fractional) epoch itself
        return current
@dataclass
class Offset(HParamScheduleBase):
    """ Offsets the epoch for the subexpression, clamped above 0. Positive offsets makes it happen later """
    expr : Union[HParamSchedule, int, float]  # wrapped schedule, or a plain number
    offset : float                            # epochs to shift the subexpression by

    def get_eval_value(self) -> float:
        if not isinstance(self.expr, HParamScheduleBase):
            return self.expr
        return self.expr.get_eval_value()

    def get_train_value(self, epoch: float) -> float:
        if not isinstance(self.expr, HParamScheduleBase):
            return self.expr
        # shifted epoch never goes below 0
        shifted = max(epoch - self.offset, 0)
        return self.expr.get_train_value(shifted)
@dataclass
class Mod(HParamScheduleBase):
    """ The epoch in the subexpression is subject to a modulus. Use for warm restarts """
    modulus : float                            # period, in epochs, after which the subexpression restarts
    expr : Union[HParamSchedule, int, float]  # wrapped schedule, or a plain number

    def get_eval_value(self) -> float:
        if not isinstance(self.expr, HParamScheduleBase):
            return self.expr
        return self.expr.get_eval_value()

    def get_train_value(self, epoch: float) -> float:
        if not isinstance(self.expr, HParamScheduleBase):
            return self.expr
        wrapped = epoch % self.modulus
        return self.expr.get_train_value(wrapped)
def main():
    """Command-line entry point.

    With fewer than two arguments, print usage and a table of every registered
    schedule class (infix symbol or constructor signature, plus its docstring).
    Otherwise parse each expression argument with `parse_dsl`, print it with
    its `is_const` flag, and plot all of them over `n_epochs` epochs.
    """
    import sys
    from dataclasses import is_dataclass
    import rich.pretty

    if not sys.argv[2:]:
        print(f"Usage: {sys.argv[0]} n_epochs 'expression'")
        print("Available operations:")

        def describe(cls) -> str:
            # registered subclasses defined elsewhere may lack a docstring
            return (cls.__doc__ or "").strip()

        def constructor_of(name, cls) -> str:
            # `fields()` only works on dataclasses; show the bare name otherwise
            if not is_dataclass(cls):
                return name
            params = ", ".join(
                fld.name if fld.default is MISSING else f"{fld.name}={fld.default!r}"
                for fld in fields(cls)
            )
            return f"{name}({params})"

        def mk_ops():
            for name, cls in HParamScheduleBase._subclasses.items():
                if isinstance(cls._infix, str):
                    yield (cls._infix, f"(infix) {describe(cls)}")
                else:
                    yield (constructor_of(name, cls), describe(cls))

        rich.print(tabulate(sorted(mk_ops()), tablefmt="plain"))
    else:
        n_epochs = int(sys.argv[1])
        schedules = [parse_dsl(arg, name="cli arg") for arg in sys.argv[2:]]
        ax = plt.gca()
        print("[")
        for schedule in schedules:
            rich.print(f" {schedule}, # {schedule.is_const = }")
            schedule.plot(n_epochs, ax=ax)
        print("]")
        plt.show()


if __name__ == "__main__":
    main()