# marf/ifield/utils/geometry.py
# (snapshot: 2025-01-09 15:43:11 +01:00, 198 lines, 6.4 KiB, Python)
from torch import Tensor
from torch.nn import functional as F
from typing import Optional, Literal
import torch
from .helpers import compose
def get_ray_origins(cam2world: Tensor):
    """Extract the camera positions (translation column) from (..., 4, 4) cam2world matrices."""
    translation = cam2world[..., :3, 3]
    return translation
def camera_uv_to_rays(
    cam2world  : Tensor,
    uv         : Tensor,
    intrinsics : Tensor,
) -> tuple[Tensor, Tensor]:
    """
    Build per-pixel ray origins and directions from batched camera matrices.

    cam2world:  (..., 4, 4) camera-to-world transforms
    intrinsics: (..., 3, 3) pinhole intrinsics
    uv:         (..., n, 2) pixel coordinates

    Returns (ray_origins, ray_dirs), each shaped (..., n, 3); origins are the
    camera position broadcast over the n pixels.
    """
    dirs = get_ray_directions(uv, cam2world=cam2world, intrinsics=intrinsics)
    origins = get_ray_origins(cam2world)
    # broadcast the single camera position to one origin per pixel
    origins = origins[..., None, :].expand([*uv.shape[:-1], 3])
    return origins, dirs
# Supported ray parameterizations for the network input embedding.
RayEmbedding = Literal[
    "plucker",   # Plücker coordinates (direction + moment), as used by LFN
    "perp_foot", # perpendicular-foot point, as used by PRIF
    "both",      # concatenation of the two above
]
@compose(torch.cat, dim=-1)
@compose(tuple)
def ray_input_embedding(ray_origins: Tensor, ray_dirs: Tensor, mode: RayEmbedding = "plucker", normalize_dirs=False, is_training=False):
    """
    Yield the parts of the ray embedding; the decorators collect them into a
    tuple and concatenate along the last dimension.

    The embedding always starts with the ray direction, followed by the
    Plücker moment ("plucker"), the perpendicular foot ("perp_foot"), or
    both ("both").  `is_training` is accepted but currently unused.
    """
    assert ray_origins.shape[-1] == ray_dirs.shape[-1] == 3, \
        f"{ray_dirs.shape = }, {ray_origins.shape = }"
    if normalize_dirs:
        ray_dirs = ray_dirs / torch.linalg.norm(ray_dirs, dim=-1, keepdim=True)
    yield ray_dirs

    want_moment    = mode in ("plucker", "both")
    want_perp_foot = mode in ("perp_foot", "both")
    assert want_moment or want_perp_foot

    # The moment m = o × d is needed for the perpendicular foot too: foot = d × m
    moment = torch.cross(ray_origins, ray_dirs, dim=-1)
    if want_moment:
        yield moment
    if want_perp_foot:
        yield torch.cross(ray_dirs, moment, dim=-1)
def ray_input_embedding_length(mode: RayEmbedding = "plucker") -> int:
    """Number of channels produced by ray_input_embedding for the given mode."""
    has_moment    = mode in ("plucker", "both")
    has_perp_foot = mode in ("perp_foot", "both")
    assert has_moment or has_perp_foot
    # 3 for the direction itself, plus 3 per optional component
    return 3 + 3 * has_moment + 3 * has_perp_foot
def parse_intrinsics(intrinsics, return_dict=False):
    """
    Split a (..., 3, 3) pinhole intrinsics matrix into focal lengths and
    principal point, each keeping a trailing singleton dimension.

    Returns (fx, fy, cx, cy) as a tuple, or a dict of the same when
    return_dict is True.
    """
    parts = {
        "fx": intrinsics[..., 0, 0:1],
        "fy": intrinsics[..., 1, 1:2],
        "cx": intrinsics[..., 0, 2:3],
        "cy": intrinsics[..., 1, 2:3],
    }
    if return_dict:
        return parts
    return parts["fx"], parts["fy"], parts["cx"], parts["cy"]
def expand_as(x, y):
    """
    Right-pad x with trailing singleton dimensions until it has as many
    dimensions as y (returns x unchanged if it already does).
    """
    while len(x.shape) < len(y.shape):
        x = x[..., None]
    return x
def lift(x, y, z, intrinsics, homogeneous=False):
    """
    Back-project pixel coordinates into camera space at the given depths.

    x, y: pixel coordinates, shape (batch_size, num_points)
    z: depth along the camera axis, same shape as x and y
    intrinsics: (..., 3, 3) pinhole intrinsics
    homogeneous: when True, append a homogeneous 1-coordinate

    Returns camera-space points stacked on a new last dim: (..., 3) or (..., 4).
    """
    fx, fy, cx, cy = parse_intrinsics(intrinsics)
    # standard pinhole unprojection: (pixel - principal point) / focal * depth
    x_lift = (x - expand_as(cx, x)) / expand_as(fx, x) * z
    y_lift = (y - expand_as(cy, y)) / expand_as(fy, y) * z
    coords = [x_lift, y_lift, z]
    if homogeneous:
        coords.append(torch.ones_like(z).to(x.device))
    return torch.stack(coords, dim=-1)
def project(x, y, z, intrinsics):
    """
    Project camera-space points onto the image plane with a pinhole model.

    x, y, z: camera-space coordinates, shape (batch_size, num_points)
    intrinsics: (..., 3, 3) pinhole intrinsics

    Returns stacked (u, v, z) of shape (..., 3) — depth is passed through.
    """
    fx, fy, cx, cy = parse_intrinsics(intrinsics)
    # perspective divide, then scale by focal length and shift by principal point
    u = expand_as(fx, x) * x / z + expand_as(cx, x)
    v = expand_as(fy, y) * y / z + expand_as(cy, y)
    return torch.stack((u, v, z), dim=-1)
def world_from_xy_depth(xy, depth, cam2world, intrinsics):
    """
    Transform pixel coordinates plus depth into world-space 3D points.

    xy: (..., n, 2) pixel coordinates
    depth: (..., n) depth values along the camera axis
    cam2world: (b, ..., 4, 4) camera-to-world transforms
    intrinsics: (b, ..., 3, 3) pinhole intrinsics

    Returns world-space points of shape (..., n, 3).
    """
    # (removed an unused `batch_size, *_ = cam2world.shape` unpack)
    x_cam = xy[..., 0]
    y_cam = xy[..., 1]
    # lift to homogeneous camera-space points, then apply cam2world and
    # drop the homogeneous coordinate
    pixel_points_cam = lift(x_cam, y_cam, depth, intrinsics=intrinsics, homogeneous=True)
    world_coords = torch.einsum("b...ij,b...kj->b...ki", cam2world, pixel_points_cam)[..., :3]
    return world_coords
def project_point_on_ray(projection_point, ray_origin, ray_dir):
    """
    Orthogonally project a point onto a ray.

    All arguments are (..., 3); ray_dir is assumed unit length — TODO confirm
    with callers, otherwise the parameter along the ray is scaled by |dir|^2.
    """
    offset = projection_point - ray_origin
    along = torch.einsum("...j,...j", offset, ray_dir)
    return ray_origin + along[..., None] * ray_dir
def get_ray_directions(
    xy         : Tensor, # (..., N, 2)
    cam2world  : Tensor, # (..., 4, 4)
    intrinsics : Tensor, # (..., 3, 3)
):
    """
    Compute normalized world-space ray directions through the given pixels.

    Unprojects each pixel at unit depth into world space and subtracts the
    camera position.  Returns unit direction vectors of shape (..., N, 3).
    """
    # ones_like keeps xy's device AND dtype; the previous
    # torch.ones(...).to(xy.device) allocated float32 on CPU and then copied.
    z_cam = torch.ones_like(xy[..., 0])
    pixel_points = world_from_xy_depth(xy, z_cam, intrinsics=intrinsics, cam2world=cam2world) # (batch, num_samples, 3)
    cam_pos = cam2world[..., :3, 3]
    ray_dirs = pixel_points - cam_pos[..., None, :] # (batch, num_samples, 3)
    ray_dirs = F.normalize(ray_dirs, dim=-1)
    return ray_dirs
def ray_sphere_intersect(
    ray_origins : Tensor, # (..., 3)
    ray_dirs : Tensor, # (..., 3) — the quadratic below omits the |d|^2 term, so unit length is assumed
    sphere_centers : Optional[Tensor] = None, # (..., 3); None means spheres centered at the origin
    sphere_radii : Optional[Tensor] = 1, # (...); scalar or broadcastable tensor of radii
    *,
    return_parts : bool = False,
    allow_nans : bool = True,
    improve_miss_grads : bool = False,
) -> tuple[Tensor, ...]:
    """
    Intersect rays with spheres (unit sphere at the origin by default).

    Solves t^2 + 2*b*t + c = 0 with b = d·o and c = |o|^2 - r^2, i.e.
    t = -b ∓ sqrt(b^2 - c).  NOTE(review): this form is only valid for
    unit-length ray_dirs — confirm callers normalize.

    allow_nans: when True, misses yield NaN intersection points (sqrt of a
        negative discriminant); when False, misses are clamped so outputs
        stay finite.
    improve_miss_grads: for misses, keep the gradient of the discriminant
        alive instead of zeroing it (requires allow_nans=False).
    return_parts: additionally return the perpendicular foot (closest point
        on the ray to the center) and a boolean hit mask.

    Returns (near, far) intersection points, each (..., 3); or, when
    return_parts, (perp_foot, near, far, is_intersecting) with
    is_intersecting of shape (...).
    """
    if improve_miss_grads: assert not allow_nans, "improve_miss_grads does not work with allow_nans"
    # re-center ray origins so the spheres sit at the origin
    if sphere_centers is None:
        ray_origins_centered = ray_origins #- torch.zeros_like(ray_origins)
    else:
        ray_origins_centered = ray_origins - sphere_centers
    # b = d·o, kept with a trailing singleton dim for broadcasting against (..., 3)
    ray_dir_dot_origins = (ray_dirs * ray_origins_centered).sum(dim=-1, keepdim=True)
    # squared discriminant b^2 - c; positive where the ray actually hits the sphere
    discriminants2 = ray_dir_dot_origins**2 - ((ray_origins_centered * ray_origins_centered).sum(dim=-1) - sphere_radii**2)[..., None]
    if not allow_nans or return_parts:
        is_intersecting = discriminants2 > 0
    if allow_nans:
        # fast path: misses produce NaN via sqrt of a negative value
        discriminants = torch.sqrt(discriminants2)
    else:
        # For misses, substitute either a constant 0.001 whose gradient is that
        # of discriminants2 (the x - x.detach() + 0.001 trick), or plain zero.
        discriminants = torch.sqrt(torch.where(is_intersecting, discriminants2,
            discriminants2 - discriminants2.detach() + 0.001
            if improve_miss_grads else
            torch.zeros_like(discriminants2)
        ))
        assert not discriminants.detach().isnan().any() # slow, use optimizations!
    # entry/exit points at t = -b ∓ sqrt(disc); perp foot at t = -b
    if not return_parts:
        return (
            ray_origins + ray_dirs * (- ray_dir_dot_origins - discriminants),
            ray_origins + ray_dirs * (- ray_dir_dot_origins + discriminants),
        )
    else:
        return (
            ray_origins + ray_dirs * (- ray_dir_dot_origins),
            ray_origins + ray_dirs * (- ray_dir_dot_origins - discriminants),
            ray_origins + ray_dirs * (- ray_dir_dot_origins + discriminants),
            is_intersecting.squeeze(-1),
        )