Add code
This commit is contained in:
25
ifield/modules/siren.py
Normal file
25
ifield/modules/siren.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from math import sqrt
|
||||
from torch import nn
|
||||
import torch
|
||||
|
||||
class Sine(nn.Module):
|
||||
def __init__(self, omega_0: float):
|
||||
super().__init__()
|
||||
self.omega_0 = omega_0
|
||||
|
||||
def forward(self, input):
|
||||
if self.omega_0 == 1:
|
||||
return torch.sin(input)
|
||||
else:
|
||||
return torch.sin(input * self.omega_0)
|
||||
|
||||
|
||||
def init_weights_(module: nn.Linear, omega_0: float, is_first: bool = True):
|
||||
assert isinstance(module, nn.Linear), module
|
||||
with torch.no_grad():
|
||||
mag = (
|
||||
1 / module.in_features
|
||||
if is_first else
|
||||
sqrt(6 / module.in_features) / omega_0
|
||||
)
|
||||
module.weight.uniform_(-mag, mag)
|
||||
Reference in New Issue
Block a user