This commit is contained in:
2023-07-19 19:29:10 +02:00
parent b2a64395bd
commit 4f811cc4b0
60 changed files with 18209 additions and 1 deletions

View File

@@ -0,0 +1,96 @@
import torch
from torch.autograd import grad
def hessian(y: torch.Tensor, x: torch.Tensor, check=False, detach=False) -> torch.Tensor:
"""
hessian of y wrt x
y: shape (..., Y)
x: shape (..., X)
return: shape (..., Y, X, X)
"""
assert x.requires_grad
assert y.grad_fn
grad_y = torch.ones_like(y[..., 0]).to(y.device) # reuse -> less memory
hess = torch.stack([
# calculate hessian on y for each x value
torch.stack(
gradients(
*(dydx[..., j] for j in range(x.shape[-1])),
wrt=x,
grad_outputs=[grad_y]*x.shape[-1],
detach=detach,
),
dim = -2,
)
# calculate dydx over batches for each feature value of y
for dydx in gradients(*(y[..., i] for i in range(y.shape[-1])), wrt=x)
], dim=-3)
if check:
assert hess.isnan().any()
return hess
def laplace(y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
return divergence(*gradients(y, wrt=x), x)
def divergence(y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
assert x.requires_grad
assert y.grad_fn
return sum(
grad(
y[..., i],
x,
torch.ones_like(y[..., i]),
create_graph=True
)[0][..., i:i+1]
for i in range(y.shape[-1])
)
def gradients(*ys, wrt, grad_outputs=None, detach=False) -> list[torch.Tensor]:
assert wrt.requires_grad
assert all(y.grad_fn for y in ys)
if grad_outputs is None:
grad_outputs = [torch.ones_like(y) for y in ys]
grads = (
grad(
[y],
[wrt],
grad_outputs=y_grad,
create_graph=True,
)[0]
for y, y_grad in zip(ys, grad_outputs)
)
if detach:
grads = map(torch.detach, grads)
return [*grads]
def jacobian(y: torch.Tensor, x: torch.Tensor, check=False, detach=False) -> torch.Tensor:
"""
jacobian of `y` w.r.t. `x`
y: shape (..., Y)
x: shape (..., X)
return: shape (..., Y, X)
"""
assert x.requires_grad
assert y.grad_fn
y_grad = torch.ones_like(y[..., 0])
jac = torch.stack(
gradients(
*(y[..., i] for i in range(y.shape[-1])),
wrt=x,
grad_outputs=[y_grad]*x.shape[-1],
detach=detach,
),
dim=-2,
)
if check:
assert jac.isnan().any()
return jac