This commit is contained in:
2023-07-19 19:29:10 +02:00
parent b2a64395bd
commit 4f811cc4b0
60 changed files with 18209 additions and 1 deletions

22
ifield/modules/dtype.py Normal file
View File

@@ -0,0 +1,22 @@
import pytorch_lightning as pl
class DtypeMixin:
def __init_subclass__(cls):
assert issubclass(cls, pl.LightningModule), \
f"{cls.__name__!r} is not a subclass of 'pytorch_lightning.LightningModule'!"
@property
def device_and_dtype(self) -> dict:
"""
Examples:
```
torch.tensor(1337, **self.device_and_dtype)
some_tensor.to(**self.device_and_dtype)
```
"""
return {
"dtype": self.dtype,
"device": self.device,
}