from typing import Optional, Literal

import torch
from torch import Tensor
from torch.nn import functional as F

from .helpers import compose


def get_ray_origins(cam2world: Tensor):
    """Extract camera positions (the translation column) from (..., 4, 4) cam2world matrices."""
    return cam2world[..., :3, 3]


def camera_uv_to_rays(
    cam2world : Tensor,
    uv : Tensor,
    intrinsics : Tensor,
) -> tuple[Tensor, Tensor]:
    """
    Computes ray origins and directions from batched cam2world and intrinsics
    matrices together with pixel coordinates.

    cam2world:  (..., 4, 4)
    intrinsics: (..., 3, 3)
    uv:         (..., n, 2)

    Returns (ray_origins, ray_dirs), each of shape (..., n, 3).
    """
    ray_dirs = get_ray_directions(uv, cam2world=cam2world, intrinsics=intrinsics)
    ray_origins = get_ray_origins(cam2world)
    ray_origins = ray_origins[..., None, :].expand([*uv.shape[:-1], 3])
    return ray_origins, ray_dirs


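# Illustrative usage sketch (assumptions, not part of the module's API): given a
# batch of (4, 4) cam2world poses and matching (3, 3) intrinsics, a dense H x W
# pixel grid can be turned into rays as below. The pixel convention (uv = (x, y)
# in pixel units) is an assumption about the expected input.
#
#   H, W = 128, 128
#   ys, xs = torch.meshgrid(
#       torch.arange(H, dtype=torch.float32),
#       torch.arange(W, dtype=torch.float32),
#       indexing="ij",
#   )
#   uv = torch.stack((xs, ys), dim=-1).reshape(1, H * W, 2)       # (1, n, 2)
#   origins, dirs = camera_uv_to_rays(cam2world, uv, intrinsics)  # (1, n, 3) each

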
RayEmbedding = Literal[
    "plucker",    # LFN
    "perp_foot",  # PRIF
    "both",
]


@compose(torch.cat, dim=-1)
@compose(tuple)
def ray_input_embedding(ray_origins: Tensor, ray_dirs: Tensor, mode: RayEmbedding = "plucker", normalize_dirs=False, is_training=False):
    """
    Computes the Plucker coordinates and/or the perpendicular foot from ray origins
    and directions, and concatenates them to the direction along the last dimension.
    """
    assert ray_origins.shape[-1] == ray_dirs.shape[-1] == 3, \
        f"{ray_dirs.shape = }, {ray_origins.shape = }"

    if normalize_dirs:
        ray_dirs = ray_dirs / ray_dirs.norm(dim=-1, keepdim=True)

    yield ray_dirs

    do_moment = mode in ("plucker", "both")
    do_perp_feet = mode in ("perp_foot", "both")
    assert do_moment or do_perp_feet

    moment = torch.cross(ray_origins, ray_dirs, dim=-1)
    if do_moment:
        yield moment

    if do_perp_feet:
        # perpendicular foot: the point on the ray closest to the origin (for unit directions)
        perp_feet = torch.cross(ray_dirs, moment, dim=-1)
        yield perp_feet


def ray_input_embedding_length(mode: RayEmbedding = "plucker") -> int:
    do_moment = mode in ("plucker", "both")
    do_perp_feet = mode in ("perp_foot", "both")
    assert do_moment or do_perp_feet

    out = 3  # ray_dirs
    if do_moment:
        out += 3  # moment
    if do_perp_feet:
        out += 3  # perp foot
    return out


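# Background sketch (illustrative, with assumptions): for a unit direction d and
# origin o, the Plucker moment is m = o x d, which is invariant to sliding o along
# the ray, and the perpendicular foot p = d x m is the point on the ray closest to
# the world origin. Assuming the `compose` decorators from `.helpers` concatenate
# the yielded tensors along the last dimension, a quick self-check looks like:
#
#   o = torch.tensor([0.0, 2.0, 0.0])
#   d = torch.tensor([1.0, 0.0, 0.0])             # already unit-norm
#   emb = ray_input_embedding(o, d, mode="both")  # (9,) = dir | moment | foot
#   assert emb.shape[-1] == ray_input_embedding_length("both")
#   assert torch.allclose(emb[6:], torch.tensor([0.0, 2.0, 0.0]))  # foot = (0, 2, 0)

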
def parse_intrinsics(intrinsics, return_dict=False):
    fx = intrinsics[..., 0, 0:1]
    fy = intrinsics[..., 1, 1:2]
    cx = intrinsics[..., 0, 2:3]
    cy = intrinsics[..., 1, 2:3]
    if return_dict:
        return {"fx": fx, "fy": fy, "cx": cx, "cy": cy}
    else:
        return fx, fy, cx, cy


def expand_as(x, y):
    """Right-pad x with singleton dimensions until it has as many dims as y (for broadcasting)."""
    if len(x.shape) == len(y.shape):
        return x

    for _ in range(len(y.shape) - len(x.shape)):
        x = x.unsqueeze(-1)

    return x


def lift(x, y, z, intrinsics, homogeneous=False):
    """
    Lift pixel coordinates and depth to camera-space 3D points.

    :param x: pixel x coordinates, shape (batch_size, num_points)
    :param y: pixel y coordinates, same shape as x
    :param z: depth values, same shape as x
    :param intrinsics: pinhole intrinsics, shape (..., 3, 3)
    :param homogeneous: if True, append a fourth homogeneous coordinate of ones
    :return: camera-space points, shape (batch_size, num_points, 3), or (..., 4) if homogeneous
    """
    fx, fy, cx, cy = parse_intrinsics(intrinsics)

    x_lift = (x - expand_as(cx, x)) / expand_as(fx, x) * z
    y_lift = (y - expand_as(cy, y)) / expand_as(fy, y) * z

    if homogeneous:
        return torch.stack((x_lift, y_lift, z, torch.ones_like(z).to(x.device)), dim=-1)
    else:
        return torch.stack((x_lift, y_lift, z), dim=-1)


def project(x, y, z, intrinsics):
    """
    Project camera-space points onto the image plane.

    :param x: camera-space x coordinates, shape (batch_size, num_points)
    :param y: camera-space y coordinates, same shape as x
    :param z: camera-space depth, same shape as x
    :param intrinsics: pinhole intrinsics, shape (..., 3, 3)
    :return: (pixel_x, pixel_y, z) stacked along the last dimension
    """
    fx, fy, cx, cy = parse_intrinsics(intrinsics)

    x_proj = expand_as(fx, x) * x / z + expand_as(cx, x)
    y_proj = expand_as(fy, y) * y / z + expand_as(cy, y)

    return torch.stack((x_proj, y_proj, z), dim=-1)


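# Illustrative consistency check (a sketch with made-up numbers, not part of the
# module): `project` should invert `lift` for the same intrinsics, up to floating
# point error.
#
#   K = torch.tensor([[500.0,   0.0, 320.0],
#                     [  0.0, 500.0, 240.0],
#                     [  0.0,   0.0,   1.0]])
#   x = torch.tensor([100.0]); y = torch.tensor([200.0]); z = torch.tensor([2.0])
#   cam_pts = lift(x, y, z, K)                                             # (1, 3)
#   reproj = project(cam_pts[..., 0], cam_pts[..., 1], cam_pts[..., 2], K)
#   assert torch.allclose(reproj[..., :2], torch.stack((x, y), dim=-1))

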
def world_from_xy_depth(xy, depth, cam2world, intrinsics):
    """Map pixel coordinates (..., n, 2) and depth (..., n) to world-space points (..., n, 3)."""
    x_cam = xy[..., 0]
    y_cam = xy[..., 1]
    z_cam = depth

    pixel_points_cam = lift(x_cam, y_cam, z_cam, intrinsics=intrinsics, homogeneous=True)
    world_coords = torch.einsum("b...ij,b...kj->b...ki", cam2world, pixel_points_cam)[..., :3]

    return world_coords


def project_point_on_ray(projection_point, ray_origin, ray_dir):
    """Orthogonally project a point onto a ray (ray_dir is assumed to be unit-norm)."""
    dot = torch.einsum("...j,...j", projection_point - ray_origin, ray_dir)
    return ray_origin + dot[..., None] * ray_dir


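# Side note (illustrative): projecting the world origin onto a ray reproduces the
# "perp_foot" component of `ray_input_embedding`, again assuming unit directions.
#
#   o = torch.tensor([0.0, 2.0, 0.0]); d = torch.tensor([1.0, 0.0, 0.0])
#   foot = project_point_on_ray(torch.zeros(3), o, d)   # -> tensor([0., 2., 0.])

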
def get_ray_directions(
    xy : Tensor,          # (..., N, 2)
    cam2world : Tensor,   # (..., 4, 4)
    intrinsics : Tensor,  # (..., 3, 3)
):
    z_cam = torch.ones(xy.shape[:-1], device=xy.device)
    pixel_points = world_from_xy_depth(xy, z_cam, intrinsics=intrinsics, cam2world=cam2world)  # (..., N, 3)

    cam_pos = cam2world[..., :3, 3]
    ray_dirs = pixel_points - cam_pos[..., None, :]  # (..., N, 3)
    ray_dirs = F.normalize(ray_dirs, dim=-1)
    return ray_dirs


def ray_sphere_intersect(
    ray_origins : Tensor,                      # (..., 3)
    ray_dirs : Tensor,                         # (..., 3), assumed unit-norm
    sphere_centers : Optional[Tensor] = None,  # (..., 3); None means the origin
    sphere_radii : Tensor | float = 1,         # (...)
    *,
    return_parts : bool = False,
    allow_nans : bool = True,
    improve_miss_grads : bool = False,
) -> tuple[Tensor, ...]:
    if improve_miss_grads:
        assert not allow_nans, "improve_miss_grads does not work with allow_nans"
    if sphere_centers is None:
        ray_origins_centered = ray_origins
    else:
        ray_origins_centered = ray_origins - sphere_centers

    # Quadratic for t along the ray: t = -(o . d) -/+ sqrt((o . d)^2 - (o . o - r^2))
    ray_dir_dot_origins = (ray_dirs * ray_origins_centered).sum(dim=-1, keepdim=True)
    discriminants2 = ray_dir_dot_origins**2 - ((ray_origins_centered * ray_origins_centered).sum(dim=-1) - sphere_radii**2)[..., None]
    if not allow_nans or return_parts:
        is_intersecting = discriminants2 > 0
    if allow_nans:
        discriminants = torch.sqrt(discriminants2)
    else:
        # Clamp misses so sqrt stays finite; with improve_miss_grads, misses get a small
        # constant forward value while keeping the discriminant's gradient.
        discriminants = torch.sqrt(torch.where(
            is_intersecting,
            discriminants2,
            discriminants2 - discriminants2.detach() + 0.001
            if improve_miss_grads else
            torch.zeros_like(discriminants2),
        ))
        assert not discriminants.detach().isnan().any()  # slow, use optimizations!

    if not return_parts:
        return (
            ray_origins + ray_dirs * (- ray_dir_dot_origins - discriminants),
            ray_origins + ray_dirs * (- ray_dir_dot_origins + discriminants),
        )
    else:
        return (
            ray_origins + ray_dirs * (- ray_dir_dot_origins),  # closest point to the sphere center
            ray_origins + ray_dirs * (- ray_dir_dot_origins - discriminants),
            ray_origins + ray_dirs * (- ray_dir_dot_origins + discriminants),
            is_intersecting.squeeze(-1),
        )
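

# Usage sketch (illustrative values only, not part of the module): intersect two
# rays with the unit sphere at the origin and mask out the miss via `return_parts`.
#
#   origins = torch.tensor([[0.0, 0.0, -3.0],
#                           [0.0, 5.0, -3.0]])   # second ray misses the sphere
#   dirs = torch.tensor([[0.0, 0.0, 1.0],
#                        [0.0, 0.0, 1.0]])
#   mid, near, far, hit = ray_sphere_intersect(
#       origins, dirs, return_parts=True, allow_nans=False,
#   )
#   # hit == tensor([True, False]); near[0] == (0, 0, -1); far[0] == (0, 0, 1)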