This commit is contained in:
2023-07-19 19:29:10 +02:00
parent b2a64395bd
commit 4f811cc4b0
60 changed files with 18209 additions and 1 deletions

25
ifield/modules/siren.py Normal file
View File

@@ -0,0 +1,25 @@
from math import sqrt
from torch import nn
import torch
class Sine(nn.Module):
def __init__(self, omega_0: float):
super().__init__()
self.omega_0 = omega_0
def forward(self, input):
if self.omega_0 == 1:
return torch.sin(input)
else:
return torch.sin(input * self.omega_0)
def init_weights_(module: nn.Linear, omega_0: float, is_first: bool = True):
assert isinstance(module, nn.Linear), module
with torch.no_grad():
mag = (
1 / module.in_features
if is_first else
sqrt(6 / module.in_features) / omega_0
)
module.weight.uniform_(-mag, mag)