Add code
This commit is contained in:
96
ifield/utils/operators/diff.py
Normal file
96
ifield/utils/operators/diff.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import torch
|
||||
from torch.autograd import grad
|
||||
|
||||
|
||||
def hessian(y: torch.Tensor, x: torch.Tensor, check=False, detach=False) -> torch.Tensor:
|
||||
"""
|
||||
hessian of y wrt x
|
||||
y: shape (..., Y)
|
||||
x: shape (..., X)
|
||||
return: shape (..., Y, X, X)
|
||||
"""
|
||||
assert x.requires_grad
|
||||
assert y.grad_fn
|
||||
|
||||
grad_y = torch.ones_like(y[..., 0]).to(y.device) # reuse -> less memory
|
||||
|
||||
hess = torch.stack([
|
||||
# calculate hessian on y for each x value
|
||||
torch.stack(
|
||||
gradients(
|
||||
*(dydx[..., j] for j in range(x.shape[-1])),
|
||||
wrt=x,
|
||||
grad_outputs=[grad_y]*x.shape[-1],
|
||||
detach=detach,
|
||||
),
|
||||
dim = -2,
|
||||
)
|
||||
# calculate dydx over batches for each feature value of y
|
||||
for dydx in gradients(*(y[..., i] for i in range(y.shape[-1])), wrt=x)
|
||||
], dim=-3)
|
||||
|
||||
if check:
|
||||
assert hess.isnan().any()
|
||||
return hess
|
||||
|
||||
def laplace(y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||
return divergence(*gradients(y, wrt=x), x)
|
||||
|
||||
def divergence(y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||
assert x.requires_grad
|
||||
assert y.grad_fn
|
||||
return sum(
|
||||
grad(
|
||||
y[..., i],
|
||||
x,
|
||||
torch.ones_like(y[..., i]),
|
||||
create_graph=True
|
||||
)[0][..., i:i+1]
|
||||
for i in range(y.shape[-1])
|
||||
)
|
||||
|
||||
def gradients(*ys, wrt, grad_outputs=None, detach=False) -> list[torch.Tensor]:
|
||||
assert wrt.requires_grad
|
||||
assert all(y.grad_fn for y in ys)
|
||||
if grad_outputs is None:
|
||||
grad_outputs = [torch.ones_like(y) for y in ys]
|
||||
|
||||
grads = (
|
||||
grad(
|
||||
[y],
|
||||
[wrt],
|
||||
grad_outputs=y_grad,
|
||||
create_graph=True,
|
||||
)[0]
|
||||
for y, y_grad in zip(ys, grad_outputs)
|
||||
)
|
||||
if detach:
|
||||
grads = map(torch.detach, grads)
|
||||
|
||||
return [*grads]
|
||||
|
||||
def jacobian(y: torch.Tensor, x: torch.Tensor, check=False, detach=False) -> torch.Tensor:
|
||||
"""
|
||||
jacobian of `y` w.r.t. `x`
|
||||
|
||||
y: shape (..., Y)
|
||||
x: shape (..., X)
|
||||
return: shape (..., Y, X)
|
||||
"""
|
||||
assert x.requires_grad
|
||||
assert y.grad_fn
|
||||
|
||||
y_grad = torch.ones_like(y[..., 0])
|
||||
jac = torch.stack(
|
||||
gradients(
|
||||
*(y[..., i] for i in range(y.shape[-1])),
|
||||
wrt=x,
|
||||
grad_outputs=[y_grad]*x.shape[-1],
|
||||
detach=detach,
|
||||
),
|
||||
dim=-2,
|
||||
)
|
||||
|
||||
if check:
|
||||
assert jac.isnan().any()
|
||||
return jac
|
||||
Reference in New Issue
Block a user