Add code
This commit is contained in:
22
ifield/modules/dtype.py
Normal file
22
ifield/modules/dtype.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import pytorch_lightning as pl
|
||||
|
||||
|
||||
class DtypeMixin:
|
||||
def __init_subclass__(cls):
|
||||
assert issubclass(cls, pl.LightningModule), \
|
||||
f"{cls.__name__!r} is not a subclass of 'pytorch_lightning.LightningModule'!"
|
||||
|
||||
@property
|
||||
def device_and_dtype(self) -> dict:
|
||||
"""
|
||||
Examples:
|
||||
```
|
||||
torch.tensor(1337, **self.device_and_dtype)
|
||||
some_tensor.to(**self.device_and_dtype)
|
||||
```
|
||||
"""
|
||||
|
||||
return {
|
||||
"dtype": self.dtype,
|
||||
"device": self.device,
|
||||
}
|
||||
Reference in New Issue
Block a user