Add code
This commit is contained in:
590
ifield/utils/loss.py
Normal file
590
ifield/utils/loss.py
Normal file
@@ -0,0 +1,590 @@
|
||||
from abc import abstractmethod, ABC
|
||||
from dataclasses import dataclass, field, fields, MISSING
|
||||
from functools import wraps
|
||||
from matplotlib import pyplot as plt
|
||||
from matplotlib.artist import Artist
|
||||
from tabulate import tabulate
|
||||
from torch import nn
|
||||
from typing import Optional, TypeVar, Union
|
||||
import inspect
|
||||
import math
|
||||
import pytorch_lightning as pl
|
||||
import typing
|
||||
import warnings
|
||||
|
||||
|
||||
# Type variable bound (by forward reference) to `HParamScheduleBase`; used in annotations to mean "some schedule object".
HParamSchedule = TypeVar("HParamSchedule", bound="HParamScheduleBase")
|
||||
# Anything `parse_dsl` accepts: a schedule object, a plain number, or a DSL expression string.
Schedulable = Union[HParamSchedule, int, float, str]
|
||||
|
||||
class HParamScheduleBase(ABC):
    """Abstract base class of all hyper-parameter schedules.

    A schedule maps a (fractional) training epoch to a value via
    `get_train_value`, and provides a fixed value for evaluation via
    `get_eval_value`.

    Every subclass whose name does not start with an underscore is recorded
    in the shared `_subclasses` registry. The Python operators defined here
    look up the registered class whose `_infix` equals the operator symbol
    and build an instance of it from the two operands, so `a + b` yields an
    expression node instead of a number. Operators with no registered class
    return `NotImplemented`.
    """

    _subclasses = {} # shared reference intended

    def __init_subclass__(cls, **kwargs):
        """Record every subclass with a public name in the shared registry."""
        super().__init_subclass__(**kwargs)
        if not cls.__name__.startswith("_"):
            cls._subclasses[cls.__name__] = cls

    # Plain class-level defaults. This base class is not a dataclass, so a
    # `dataclasses.field(...)` here would leave a `Field` object (never None)
    # as the inherited default and defeat the `is not None` check in `get`.
    _infix      : Optional[str] = None # operator symbol handled by the class, if any
    _param_name : Optional[str] = None # name used as logging key and plot title
    _expr       : Optional[str] = None # what the schedule was built from; used as plot label

    def get(self, module: "nn.Module", *, trainer: "Optional[pl.Trainer]" = None) -> float:
        """Return the schedule value for the current state of `module`.

        When `module.training` is false the evaluation value is returned.
        Otherwise the schedule is sampled at the fractional epoch
        `current_epoch + processed_batches / num_training_batches`, and the
        value is sent to the trainer's logger on every 15th global step
        (only for named, non-`Const` schedules).
        """
        if not module.training:
            return self.get_eval_value()

        if trainer is None:
            trainer = module.trainer # this assumes `module` is a pl.LightningModule
        value = self.get_train_value(
            epoch = trainer.current_epoch + (trainer.fit_loop.epoch_loop.batch_progress.current.processed / trainer.num_training_batches),
        )
        if trainer.logger is not None and self._param_name is not None and self.__class__ is not Const and trainer.global_step % 15 == 0:
            trainer.logger.log_metrics({
                f"HParamSchedule/{self._param_name}": value,
            }, step=trainer.global_step)
        return value

    def _gen_data(self, n_epochs, steps_per_epoch=1000):
        """Yield `(fractional_epoch, train_value)` pairs, `steps_per_epoch` samples per epoch."""
        for epoch in range(n_epochs):
            for step in range(steps_per_epoch):
                t = epoch + step/steps_per_epoch
                yield (t, self.get_train_value(t))

    def plot(self, *a, ax: "Optional[plt.Axes]" = None, **kw) -> "Artist":
        """Plot the training-time values onto `ax` (default: `plt.gca()`).

        Positional and remaining keyword arguments are forwarded to
        `_gen_data`. Returns whatever `ax.plot` returns.
        """
        if ax is None: ax = plt.gca()
        out = ax.plot(*zip(*self._gen_data(*a, **kw)), label=self._expr)
        ax.set_title(self._param_name)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Value")
        ax.legend()
        return out

    def assert_positive(self, *a, **kw):
        """Raise `AssertionError` if any sampled training value is not `>= 0`.

        Arguments are forwarded to `_gen_data`. The error is raised
        explicitly so the check is not stripped when Python runs with `-O`.
        """
        for epoch, val in self._gen_data(*a, **kw):
            if not val >= 0:
                raise AssertionError(f"{epoch=}, {val=}")

    @abstractmethod
    def get_eval_value(self) -> float:
        """Return the value used when the module is not training."""
        ...

    @abstractmethod
    def get_train_value(self, epoch: float) -> float:
        """Return the value at the given (fractional) training epoch."""
        ...

    def _build_infix(self, symbol, lhs, rhs):
        """Instantiate the first registered class whose `_infix` equals `symbol`.

        Returns `NotImplemented` if no such class is registered, which makes
        Python raise its usual `TypeError` for binary operators.
        """
        for cls in self._subclasses.values():
            if cls._infix == symbol:
                return cls(lhs, rhs)
        return NotImplemented

    # --- arithmetic operators (forward and reflected) ---

    def __add__(self, rhs):
        return self._build_infix("+", self, rhs)

    def __radd__(self, lhs):
        return self._build_infix("+", lhs, self)

    def __sub__(self, rhs):
        return self._build_infix("-", self, rhs)

    def __rsub__(self, lhs):
        return self._build_infix("-", lhs, self)

    def __mul__(self, rhs):
        return self._build_infix("*", self, rhs)

    def __rmul__(self, lhs):
        return self._build_infix("*", lhs, self)

    def __matmul__(self, rhs):
        return self._build_infix("@", self, rhs)

    def __rmatmul__(self, lhs):
        return self._build_infix("@", lhs, self)

    def __truediv__(self, rhs):
        return self._build_infix("/", self, rhs)

    def __rtruediv__(self, lhs):
        return self._build_infix("/", lhs, self)

    def __floordiv__(self, rhs):
        return self._build_infix("//", self, rhs)

    def __rfloordiv__(self, lhs):
        return self._build_infix("//", lhs, self)

    def __mod__(self, rhs):
        return self._build_infix("%", self, rhs)

    def __rmod__(self, lhs):
        return self._build_infix("%", lhs, self)

    def __pow__(self, rhs):
        return self._build_infix("**", self, rhs)

    def __rpow__(self, lhs):
        return self._build_infix("**", lhs, self)

    # --- bitwise operators (forward and reflected) ---

    def __lshift__(self, rhs):
        return self._build_infix("<<", self, rhs)

    def __rlshift__(self, lhs):
        return self._build_infix("<<", lhs, self)

    def __rshift__(self, rhs):
        return self._build_infix(">>", self, rhs)

    def __rrshift__(self, lhs):
        return self._build_infix(">>", lhs, self)

    def __and__(self, rhs):
        return self._build_infix("&", self, rhs)

    def __rand__(self, lhs):
        return self._build_infix("&", lhs, self)

    def __xor__(self, rhs):
        return self._build_infix("^", self, rhs)

    def __rxor__(self, lhs):
        return self._build_infix("^", lhs, self)

    def __or__(self, rhs):
        return self._build_infix("|", self, rhs)

    def __ror__(self, lhs):
        return self._build_infix("|", lhs, self)

    # --- comparisons: build expression nodes rather than returning booleans ---

    def __ge__(self, rhs):
        return self._build_infix(">=", self, rhs)

    def __gt__(self, rhs):
        return self._build_infix(">", self, rhs)

    def __le__(self, rhs):
        return self._build_infix("<=", self, rhs)

    def __lt__(self, rhs):
        return self._build_infix("<", self, rhs)

    def __bool__(self):
        """A schedule is truthy by default; subclasses may refine this."""
        return True

    def __neg__(self):
        """Return `0 - self` using the class registered for "-".

        Python gives `NotImplemented` no special meaning for unary operators,
        so a missing "-" class is reported as a `TypeError` instead of
        silently handing the `NotImplemented` singleton to the caller.
        """
        out = self._build_infix("-", 0, self)
        if out is NotImplemented:
            raise TypeError(f"bad operand type for unary -: {type(self).__name__!r}")
        return out

    @property
    def is_const(self) -> bool:
        """Whether the schedule is known to be constant; `False` unless overridden."""
        return False
|
||||
|
||||
|
||||
|
||||
def parse_dsl(config: Schedulable, name=None) -> HParamSchedule:
    """Turn `config` into a schedule object.

    * A schedule instance is returned unchanged (its name and expression
      are left as they are).
    * A string is evaluated as an expression in which the registered
      schedule classes and `lg` (base-10 logarithm) are available; a
      non-schedule result is wrapped in `Const`.
    * Anything else is wrapped in `Const`.

    Newly built schedules remember `config` in `_expr` and `name` in
    `_param_name`.
    """
    if isinstance(config, HParamScheduleBase):
        return config

    if not isinstance(config, str):
        schedule = Const(config)
    else:
        # NOTE(security): emptying `__builtins__` does not make `eval` a
        # sandbox. Only pass expressions from trusted configuration here.
        namespace = {"__builtins__": {}, "lg": math.log10}
        schedule = eval(config, namespace, HParamScheduleBase._subclasses)
        if not isinstance(schedule, HParamScheduleBase):
            schedule = Const(schedule)

    schedule._expr = config
    schedule._param_name = name
    return schedule
|
||||
|
||||
|
||||
# decorator
|
||||
def ensure_schedulables(func):
    """Decorator converting schedulable arguments of `func` into schedule objects.

    A parameter is considered schedulable when its annotation is
    `HParamSchedule` / `HParamScheduleBase` itself, or a `Union` that lists
    one of them (e.g. `Schedulable`). For each such parameter the passed
    value, or the declared default when the argument is omitted, is run
    through `parse_dsl` and named `"<qualname without .__init__>.<param>"`.
    All other arguments are forwarded untouched.
    """
    signature = inspect.signature(func)
    module_name = func.__qualname__.removesuffix(".__init__")

    def is_schedulable(annotation) -> bool:
        # `typing.get_origin` never returns a TypeVar or a plain class, so a
        # bare annotation has to be compared against the annotation itself.
        if annotation is HParamSchedule or annotation is HParamScheduleBase:
            return True
        if typing.get_origin(annotation) is Union:
            type_args = typing.get_args(annotation)
            return HParamSchedule in type_args or HParamScheduleBase in type_args
        return False

    # The signature is fixed, so the annotation scan is done once here
    # rather than on every call of the wrapper.
    schedulable = [
        (param_name, param.default)
        for param_name, param in signature.parameters.items()
        if is_schedulable(param.annotation)
    ]

    @wraps(func)
    def wrapper(*a, **kw):
        bound_args = signature.bind(*a, **kw)

        for param_name, default in schedulable:
            label = f"{module_name}.{param_name}"
            if param_name in bound_args.arguments:
                bound_args.arguments[param_name] = parse_dsl(bound_args.arguments[param_name], name=label)
            elif default is not inspect.Parameter.empty:
                # Omitted argument: convert the declared default as well.
                bound_args.arguments[param_name] = parse_dsl(default, name=label)

        return func(
            *bound_args.args,
            **bound_args.kwargs,
        )
    return wrapper
|
||||
|
||||
# https://easings.net/
|
||||
|
||||
@dataclass
class _InfixBase(HParamScheduleBase):
    """Common machinery for binary-operator nodes.

    Holds a left and a right operand, each either a schedule or a plain
    number. Subclasses supply `_operation`, which combines the two
    already-evaluated operand values.
    """
    l : Union[HParamSchedule, int, float]
    r : Union[HParamSchedule, int, float]

    def _operation(self, l: float, r: float) -> float:
        """Combine two evaluated operands; must be provided by subclasses."""
        raise NotImplementedError

    def get_eval_value(self) -> float:
        """Apply the operation to the evaluation-time values of both operands."""
        lhs = self.l.get_eval_value() if isinstance(self.l, HParamScheduleBase) else self.l
        rhs = self.r.get_eval_value() if isinstance(self.r, HParamScheduleBase) else self.r
        return self._operation(lhs, rhs)

    def get_train_value(self, epoch: float) -> float:
        """Apply the operation to the values of both operands at `epoch`."""
        lhs = self.l.get_train_value(epoch) if isinstance(self.l, HParamScheduleBase) else self.l
        rhs = self.r.get_train_value(epoch) if isinstance(self.r, HParamScheduleBase) else self.r
        return self._operation(lhs, rhs)

    def __bool__(self):
        """Constant expressions take the truthiness of their value; others are truthy."""
        if not self.is_const:
            return True
        return bool(self.get_eval_value())

    @property
    def is_const(self) -> bool:
        """True when both operands are constant (plain numbers count as constant)."""
        left_const = self.l.is_const if isinstance(self.l, HParamScheduleBase) else True
        if not left_const:
            return left_const
        return self.r.is_const if isinstance(self.r, HParamScheduleBase) else True
|
||||
|
||||
@dataclass
class Add(_InfixBase):
    """ adds the results of two schedules """
    # Symbol under which the base class operators find this node type.
    _infix : Optional[str] = field(init=False, repr=False, default="+")

    def _operation(self, l: float, r: float) -> float:
        """Return the sum of the two evaluated operands."""
        total = l + r
        return total
|
||||
|
||||
|
||||
@dataclass
class Sub(_InfixBase):
    """ subtracts the results of two schedules """
    # Symbol under which the base class operators find this node type.
    _infix : Optional[str] = field(init=False, repr=False, default="-")

    def _operation(self, l: float, r: float) -> float:
        """Return the left operand minus the right operand."""
        difference = l - r
        return difference
|
||||
|
||||
|
||||
@dataclass
class Prod(_InfixBase):
    """ multiplies the results of two schedules """
    # Symbol under which the base class operators find this node type.
    _infix : Optional[str] = field(init=False, repr=False, default="*")

    def _operation(self, l: float, r: float) -> float:
        """Return the product of the two evaluated operands."""
        product = l * r
        return product

    @property
    def is_const(self) -> bool: # propagate identity
        """True if either operand is a constant zero, or both operands are constant."""
        # Constant schedules are reduced to their value so they can be
        # compared with 0; everything else is compared as-is.
        left, right = [
            operand.get_eval_value()
            if isinstance(operand, HParamScheduleBase) and operand.is_const else
            operand
            for operand in (self.l, self.r)
        ]
        return left == 0 or right == 0 or super().is_const
|
||||
|
||||
|
||||
@dataclass
class Div(_InfixBase):
    """ divides the results of two schedules """
    # Symbol under which the base class operators find this node type.
    _infix : Optional[str] = field(init=False, repr=False, default="/")

    def _operation(self, l: float, r: float) -> float:
        """Return the left operand divided by the right operand (true division)."""
        quotient = l / r
        return quotient
|
||||
|
||||
|
||||
@dataclass
class Pow(_InfixBase):
    """ raises the results of one schedule to the other """
    # Symbol under which the base class operators find this node type.
    _infix : Optional[str] = field(init=False, repr=False, default="**")

    def _operation(self, l: float, r: float) -> float:
        """Return the left operand raised to the power of the right operand."""
        power = l ** r
        return power
|
||||
|
||||
|
||||
@dataclass
class Gt(_InfixBase):
    """ compares the results of two schedules """
    # Symbol under which the base class operators find this node type.
    _infix : Optional[str] = field(init=False, repr=False, default=">")

    def _operation(self, l: float, r: float) -> float:
        """Return the result of `l > r` for the two evaluated operands."""
        is_greater = l > r
        return is_greater
|
||||
|
||||
|
||||
@dataclass
class Lt(_InfixBase):
    """ compares the results of two schedules """
    # Symbol under which the base class operators find this node type.
    _infix : Optional[str] = field(init=False, repr=False, default="<")

    def _operation(self, l: float, r: float) -> float:
        """Return the result of `l < r` for the two evaluated operands."""
        is_less = l < r
        return is_less
|
||||
|
||||
|
||||
@dataclass
class Ge(_InfixBase):
    """ compares the results of two schedules """
    # Symbol under which the base class operators find this node type.
    _infix : Optional[str] = field(init=False, repr=False, default=">=")

    def _operation(self, l: float, r: float) -> float:
        """Return the result of `l >= r` for the two evaluated operands."""
        is_at_least = l >= r
        return is_at_least
|
||||
|
||||
|
||||
@dataclass
class Le(_InfixBase):
    """ compares the results of two schedules """
    # Symbol under which the base class operators find this node type.
    _infix : Optional[str] = field(init=False, repr=False, default="<=")

    def _operation(self, l: float, r: float) -> float:
        """Return the result of `l <= r` for the two evaluated operands."""
        is_at_most = l <= r
        return is_at_most
|
||||
|
||||
|
||||
@dataclass
class Const(HParamScheduleBase):
    """ A way to ensure .get(...) exists """

    # The fixed value reported in both training and evaluation.
    c : Union[int, float]

    def get_train_value(self, epoch: float) -> float:
        """Return the stored constant; `epoch` is ignored."""
        return self.c

    def get_eval_value(self) -> float:
        """Return the stored constant."""
        return self.c

    @property
    def is_const(self) -> bool:
        """A constant is always constant."""
        return True

    def __bool__(self):
        """Truthiness follows the evaluation value."""
        current = self.get_eval_value()
        return bool(current)
|
||||
|
||||
@dataclass
class Step(HParamScheduleBase):
    """ steps from 0 to 1 at `epoch` """

    # Epoch at (and after) which the training value is 1.
    epoch : float

    def get_train_value(self, epoch: float) -> float:
        """Return 1 once `epoch` has reached the threshold, 0 before that."""
        if epoch >= self.epoch:
            return 1
        return 0

    def get_eval_value(self) -> float:
        """Evaluation always sees the post-step value 1."""
        return 1
|
||||
|
||||
@dataclass
class Linear(HParamScheduleBase):
    """ linear from 0 to 1 over `n_epochs`, delayed by `offset` """

    n_epochs : float     # length of the ramp
    offset   : float = 0 # epochs to wait before the ramp starts

    def get_train_value(self, epoch: float) -> float:
        """Return the ramp position at `epoch`, clamped to [0, 1]."""
        # A ramp without length is treated as already finished.
        if self.n_epochs <= 0:
            return 1
        elapsed = max(epoch - self.offset, 0)
        return min(elapsed / self.n_epochs, 1)

    def get_eval_value(self) -> float:
        """Evaluation always sees the final value 1."""
        return 1
|
||||
|
||||
@dataclass
class EaseSin(HParamScheduleBase): # effectively 1-CosineAnnealing
    """ sinusoidal ease in-out from 0 to 1 over `n_epochs`, delayed by `offset` """

    n_epochs : float     # length of the transition
    offset   : float = 0 # epochs to wait before the transition starts

    def get_eval_value(self) -> float:
        """Evaluation always sees the final value 1."""
        return 1

    def get_train_value(self, epoch: float) -> float:
        """Return the eased value at `epoch`.

        A transition of non-positive length is treated as already finished
        (same convention as `Linear`), instead of dividing by zero.
        """
        if self.n_epochs <= 0: return 1
        # Progress through the transition, clamped to [0, 1].
        x = min(max(epoch - self.offset, 0) / self.n_epochs, 1)
        return -(math.cos(math.pi * x) - 1) / 2
|
||||
|
||||
@dataclass
class EaseExp(HParamScheduleBase):
    """ exponential ease in-out from 0 to 1 over `n_epochs`, delayed by `offset` """

    n_epochs : float     # length of the transition
    offset   : float = 0 # epochs to wait before the transition starts

    def get_eval_value(self) -> float:
        """Evaluation always sees the final value 1."""
        return 1

    def get_train_value(self, epoch: float) -> float:
        """Return the eased value at `epoch`.

        Returns 0 before the transition starts and 1 after it has ended.
        A transition of non-positive length counts as ended as soon as it
        starts, which avoids a division by zero when `n_epochs` is 0 and
        `epoch` equals `offset`.
        """
        elapsed = epoch - self.offset
        if elapsed < 0:
            return 0
        if self.n_epochs <= 0 or elapsed > self.n_epochs:
            return 1
        # Here 0 <= elapsed <= n_epochs, so the progress already lies in [0, 1].
        x = elapsed / self.n_epochs
        return (
            2**(20*x-10) / 2
            if x < 0.5 else
            (2 - 2**(-20*x+10)) / 2
        )
|
||||
|
||||
@dataclass
class Steps(HParamScheduleBase):
    """ Starts at 1, multiply by gamma every n epochs. Models StepLR in pytorch """
    step_size: float    # number of epochs between two decays
    gamma: float = 0.1  # factor applied at every decay

    def get_train_value(self, epoch: float) -> float:
        """Return `gamma` raised to the number of whole step periods contained in `epoch`."""
        completed_periods = int(epoch / self.step_size)
        return self.gamma ** completed_periods

    def get_eval_value(self) -> float:
        """Evaluation always sees the value 1."""
        return 1
|
||||
|
||||
@dataclass
class MultiStep(HParamScheduleBase):
    """ Starts at 1, multiply by gamma every milstone epoch. Models MultiStepLR in pytorch """
    milestones: list[float] # epochs at which a decay happens
    gamma: float = 0.1      # factor applied at every milestone

    def get_train_value(self, epoch: float) -> float:
        """Return `gamma**(i+1)` for the last listed milestone `i` that `epoch` has reached, else 1."""
        # Walk the milestones from the last entry to the first so that the
        # latest reached one decides the exponent.
        for position, milestone in reversed(tuple(enumerate(self.milestones))):
            if epoch >= milestone:
                return self.gamma ** (position + 1)
        return 1

    def get_eval_value(self) -> float:
        """Evaluation always sees the value 1."""
        return 1
|
||||
|
||||
@dataclass
class Epoch(HParamScheduleBase):
    """ The current epoch, starting at 0 """

    def get_train_value(self, epoch: float) -> float:
        """Return `epoch` itself."""
        current = epoch
        return current

    def get_eval_value(self) -> float:
        """Evaluation always sees epoch 0."""
        return 0
|
||||
|
||||
@dataclass
class Offset(HParamScheduleBase):
    """ Offsets the epoch for the subexpression, clamped above 0. Positive offsets makes it happen later """
    expr   : Union[HParamSchedule, int, float] # wrapped schedule, or a plain number
    offset : float                             # amount subtracted from the epoch

    def get_eval_value(self) -> float:
        """Return the wrapped expression's evaluation value (plain numbers are returned as-is)."""
        if not isinstance(self.expr, HParamScheduleBase):
            return self.expr
        return self.expr.get_eval_value()

    def get_train_value(self, epoch: float) -> float:
        """Sample the wrapped expression at `epoch - offset`, never below epoch 0."""
        if not isinstance(self.expr, HParamScheduleBase):
            return self.expr
        shifted_epoch = max(epoch - self.offset, 0)
        return self.expr.get_train_value(shifted_epoch)
|
||||
|
||||
@dataclass
class Mod(HParamScheduleBase):
    """ The epoch in the subexptression is subject to a modulus. Use for warm restarts """

    modulus : float                             # period after which the epoch wraps around
    expr    : Union[HParamSchedule, int, float] # wrapped schedule, or a plain number

    def get_eval_value(self) -> float:
        """Return the wrapped expression's evaluation value (plain numbers are returned as-is)."""
        if not isinstance(self.expr, HParamScheduleBase):
            return self.expr
        return self.expr.get_eval_value()

    def get_train_value(self, epoch: float) -> float:
        """Sample the wrapped expression at `epoch % modulus`."""
        if not isinstance(self.expr, HParamScheduleBase):
            return self.expr
        wrapped_epoch = epoch % self.modulus
        return self.expr.get_train_value(wrapped_epoch)
|
||||
|
||||
|
||||
def main():
    """Command-line entry point.

    With fewer than two arguments, prints a usage line and a table of all
    registered operations. Otherwise the first argument is the number of
    epochs and every further argument is a schedule expression; each
    schedule is printed and plotted on a shared axis, then the figure is
    shown.
    """
    import sys
    import rich.pretty

    def describe_field(spec):
        # A field without default is shown by name only.
        if spec.default is MISSING:
            return spec.name
        return f"{spec.name}={spec.default!r}"

    def mk_ops():
        # One (label, description) row per registered schedule class.
        for name, cls in HParamScheduleBase._subclasses.items():
            if isinstance(cls._infix, str):
                yield (cls._infix, f"(infix) {cls.__doc__.strip()}")
                continue
            arguments = ', '.join(describe_field(i) for i in fields(cls))
            yield (f"{name}({arguments})", cls.__doc__.strip())

    expressions = sys.argv[2:]
    if not expressions:
        print(f"Usage: {sys.argv[0]} n_epochs 'expression'")
        print("Available operations:")
        rich.print(tabulate(sorted(mk_ops()), tablefmt="plain"))
        return

    n_epochs = int(sys.argv[1])
    schedules = [parse_dsl(arg, name="cli arg") for arg in expressions]
    ax = plt.gca()
    print("[")
    for schedule in schedules:
        rich.print(f" {schedule}, # {schedule.is_const = }")
        schedule.plot(n_epochs, ax=ax)
    print("]")
    plt.show()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user