Files
marf/ifield/viewer/ray_field.py
2025-01-09 15:43:11 +01:00

793 lines
39 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from ..data.common.scan import SingleViewUVScan
import mesh_to_sdf.scan as sdf_scan
from ..models import intersection_fields
from ..utils import geometry, helpers
from ..utils.operators import diff
from .common import InteractiveViewer
from matplotlib import cm
import matplotlib.colors as mcolors
from concurrent.futures import ThreadPoolExecutor
from textwrap import dedent
from typing import Hashable, Optional, Callable
from munch import Munch
import functools
import itertools
import numpy as np
import random
from pathlib import Path
import shutil
import subprocess
import torch
from trimesh import Trimesh
import trimesh.transformations as T
class ModelViewer(InteractiveViewer):
lambertian_color = (1.0, 1.0, 1.0)
max_cols = 200
max_cols = 32
def __init__(self,
model : intersection_fields.IntersectionFieldAutoDecoderModel,
start_uid : Hashable,
skyward : str = "+Z",
mesh_gt_getter: Callable[[Hashable], Trimesh] | None = None,
*a, **kw):
self.model = model
self.model.eval()
self.current_uid = self._prev_uid = start_uid
self.all_uids = list(model.keys())
self.mesh_gt_getter = mesh_gt_getter
self.current_gt_mesh: tuple[Hashable, Trimesh] = (None, None)
self.display_mode_normals = self.vizmodes_normals .index("medial" if self.model.hparams.output_mode == "medial_sphere" else "analytical")
self.display_mode_shading = self.vizmodes_shading .index("lambertian")
self.display_mode_centroid = self.vizmodes_centroids.index("best-centroids-colored")
self.display_mode_spheres = self.vizmodes_spheres .index(None)
self.display_mode_variation = 0
self.display_sphere_map_bg = True
self.atom_radius_offset = 0
self.atom_index_solo = None
self.export_medial_surface_mesh = False
self.light_angle1 = 0
self.light_angle2 = 0
self.obj_rot = {
"-X": torch.tensor(T.rotation_matrix(angle= np.pi/2, direction=(0, 1, 0))[:3, :3], **model.device_and_dtype).T,
"+X": torch.tensor(T.rotation_matrix(angle=-np.pi/2, direction=(0, 1, 0))[:3, :3], **model.device_and_dtype).T,
"-Y": torch.tensor(T.rotation_matrix(angle= np.pi/2, direction=(1, 0, 0))[:3, :3], **model.device_and_dtype).T,
"+Y": torch.tensor(T.rotation_matrix(angle=-np.pi/2, direction=(1, 0, 0))[:3, :3], **model.device_and_dtype).T,
"-Z": torch.tensor(T.rotation_matrix(angle= np.pi, direction=(1, 0, 0))[:3, :3], **model.device_and_dtype).T,
"+Z": torch.eye(3, **model.device_and_dtype),
}[str(skyward).upper()]
self.obj_rot_inv = torch.linalg.inv(self.obj_rot)
super().__init__(*a, **kw)
vizmodes_normals = (
"medial",
"analytical",
"ground_truth",
)
vizmodes_shading = (
None, # just atoms or medial axis
"colored-lambertian",
"lambertian",
"shade-best-radii",
"shade-all-radii",
"translucent",
"normal",
"centroid-grad-norm", # backprop
"anisotropic", # backprop
"curvature", # backprop
"glass",
"double-glass",
)
vizmodes_centroids = (
None,
"best-centroids",
"all-centroids",
"best-centroids-colored",
"all-centroids-colored",
"miss-centroids-colored",
"all-miss-centroids-colored",
)
vizmodes_spheres = (
None,
"intersecting-sphere",
"intersecting-sphere-colored",
"best-sphere",
"best-sphere-colored",
"all-spheres-colored",
)
def get_display_mode(self) -> tuple[str, str, Optional[str], Optional[str]]:
MARF = self.model.hparams.output_mode == "medial_sphere"
if isinstance(self.display_mode_normals, str): self.display_mode_normals = self.vizmodes_shading .index(self.display_mode_normals)
if isinstance(self.display_mode_shading, str): self.display_mode_shading = self.vizmodes_shading .index(self.display_mode_shading)
if isinstance(self.display_mode_centroid, str): self.display_mode_centroid = self.vizmodes_centroids.index(self.display_mode_centroid)
if isinstance(self.display_mode_spheres, str): self.display_mode_spheres = self.vizmodes_spheres .index(self.display_mode_spheres)
out = (
self.vizmodes_normals [self.display_mode_normals % len(self.vizmodes_normals)],
self.vizmodes_shading [self.display_mode_shading % len(self.vizmodes_shading)],
self.vizmodes_centroids[self.display_mode_centroid % len(self.vizmodes_centroids)] if MARF else None,
self.vizmodes_spheres [self.display_mode_spheres % len(self.vizmodes_spheres)] if MARF else None,
)
self.set_caption(" & ".join(i for i in out if i is not None))
return out
@property
def cam_state(self):
return super().cam_state | {
"light_angle1" : self.light_angle1,
"light_angle2" : self.light_angle2,
}
@cam_state.setter
def cam_state(self, new_state):
InteractiveViewer.cam_state.fset(self, new_state)
self.light_angle1 = new_state.get("light_angle1", self.light_angle1)
self.light_angle2 = new_state.get("light_angle2", self.light_angle2)
def get_current_conditioning(self) -> Optional[torch.Tensor]:
if not self.model.is_conditioned:
return None
prev_uid = self._prev_uid # to determine if target has changed
next_z = self.model[prev_uid].detach() # interpolation target
prev_z = getattr(self, "_prev_z", next_z) # interpolation source
epoch = getattr(self, "_prev_epoch", 0) # interpolation factor
if not self.is_headless:
now = self.t
t = (now - epoch) / 1 # 1 second
else:
now = self.frame_idx
t = (now - epoch) / self.fps # 1 second
assert t >= 0
if t < 1:
next_z = next_z*t + prev_z*(1-t)
if prev_uid != self.current_uid:
self._prev_uid = self.current_uid
self._prev_z = next_z
self._prev_epoch = now
return next_z
def get_current_ground_truth(self) -> Trimesh | None:
if self.mesh_gt_getter is None:
return None
uid, mesh = self.current_gt_mesh
try:
if uid != self.current_uid:
print("Loading ground truth mesh...")
mesh = self.mesh_gt_getter(self.current_uid)
self.current_gt_mesh = self.current_uid, mesh
except NotImplementedError:
self.current_gt_mesh = self.current_uid, None
return None
return mesh
def handle_keys_pressed(self, pressed):
td = super().handle_keys_pressed(pressed)
mod = pressed[self.constants.K_LALT] or pressed[self.constants.K_RALT]
if not mod and pressed[self.constants.K_f]: self.light_angle1 -= td * 0.5
if not mod and pressed[self.constants.K_g]: self.light_angle1 += td * 0.5
if mod and pressed[self.constants.K_f]: self.light_angle2 += td * 0.5
if mod and pressed[self.constants.K_g]: self.light_angle2 -= td * 0.5
return td
def handle_key_down(self, key, keys_pressed):
super().handle_key_down(key, keys_pressed)
shift = keys_pressed[self.constants.K_LSHIFT] or keys_pressed[self.constants.K_RSHIFT]
if key == self.constants.K_o:
i = self.all_uids.index(self.current_uid)
i = (i - 1) % len(self.all_uids)
self.current_uid = self.all_uids[i]
print(self.current_uid)
if key == self.constants.K_p:
i = self.all_uids.index(self.current_uid)
i = (i + 1) % len(self.all_uids)
self.current_uid = self.all_uids[i]
print(self.current_uid)
if key == self.constants.K_SPACE:
self.display_sphere_map_bg = {
True : 255,
255 : 0,
0 : True,
}.get(self.display_sphere_map_bg, True)
if key == self.constants.K_u: self.export_medial_surface_mesh = True
if key == self.constants.K_x: self.display_mode_normals += -1 if shift else 1
if key == self.constants.K_c: self.display_mode_shading += -1 if shift else 1
if key == self.constants.K_v: self.display_mode_centroid += -1 if shift else 1
if key == self.constants.K_b: self.display_mode_spheres += -1 if shift else 1
if key == self.constants.K_e: self.display_mode_variation+= -1 if shift else 1
if key == self.constants.K_c: self.display_mode_variation = 0
if key == self.constants.K_0: self.atom_index_solo = None
if key == self.constants.K_1: self.atom_index_solo = 0 if self.atom_index_solo != 0 else None
if key == self.constants.K_2: self.atom_index_solo = 1 if self.atom_index_solo != 1 else None
if key == self.constants.K_3: self.atom_index_solo = 2 if self.atom_index_solo != 2 else None
if key == self.constants.K_4: self.atom_index_solo = 3 if self.atom_index_solo != 3 else None
if key == self.constants.K_5: self.atom_index_solo = 4 if self.atom_index_solo != 4 else None
if key == self.constants.K_6: self.atom_index_solo = 5 if self.atom_index_solo != 5 else None
if key == self.constants.K_7: self.atom_index_solo = 6 if self.atom_index_solo != 6 else None
if key == self.constants.K_8: self.atom_index_solo = 7 if self.atom_index_solo != 7 else None
if key == self.constants.K_9: self.atom_index_solo = self.atom_index_solo + (-1 if shift else 1) if self.atom_index_solo is not None else 0
def handle_mouse_button_down(self, pos, button, keys_pressed):
super().handle_mouse_button_down(pos, button, keys_pressed)
if button in (1, 3):
self.display_mode_spheres += 1 if button == 1 else -1
def handle_mousewheel(self, flipped, x, y, keys_pressed):
shift = keys_pressed[self.constants.K_LSHIFT] or keys_pressed[self.constants.K_RSHIFT]
if not shift:
super().handle_mousewheel(flipped, x, y, keys_pressed)
else:
self.atom_radius_offset += 0.005 * y
print()
print("atom_radius_offset:", self.atom_radius_offset)
def setup(self):
if not self.is_headless:
print(dedent("""
WASD + PG Up/Down - translate
ARROWS - rotate
(SHIFT+) C - Next/(Prev) shading mode
(SHIFT+) V - Next/(Prev) centroids mode
(SHIFT+) B - Next/(Prev) sphere mode
Mouse L/ R - Next/ Prev sphere mode
(SHIFT+) E - Next/(Prev) variation (for quick experimentation within a shading mode)
SHIFT + Scroll - Offset atom radius
ALT + Scroll - Modify FoV (_true_ zoom)
Mouse Scroll - Translate in/out ("zoom", moves camera to/from to point of focus)
Alt+PG Up/Down - Translate in/out ("zoom", moves camera to/from to point of focus)
F / G - rotate light left / right
ALT+ F / G - rotate light up / down
CTRL / SHIFT - faster/slower rotation
O / P - prev/next object
1-9 - solo atom
0 - show all atoms
+ / - - decrease/increase pixel scale
R - rotate continuously
H / SHIFT+H / CTRL+H - load/save/print camera state
Enter - reset camera state
Y - save screenshot
U - save mesh of centroids
Space - cycle sphere map background
Q - quit
""").strip())
fname = Path(__file__).parent.resolve() / "assets/texturify_pano-1-4.jpg"
self.load_sphere_map(fname)
if self.model.hparams.output_mode == "medial_sphere":
@self.model.net.register_forward_hook
def atom_offset_radius_and_solo(model, input, output):
slice = (..., [i+3 for i in range(0, output.shape[-1], 4)])
output[slice] += self.atom_radius_offset * output[slice].sign()
if self.atom_index_solo is not None:
x = self.atom_index_solo * 4
x = x % output.shape[-1]
output = output[..., list(range(x, x+4))]
return output
self._atom_offset_radius_and_solo_hook = atom_offset_radius_and_solo
def teardown(self):
if hasattr(self, "_atom_offset_radius_and_solo_hook"):
self._atom_offset_radius_and_solo_hook.remove()
del self._atom_offset_radius_and_solo_hook
@torch.no_grad()
def render_frame(self, pixel_view: np.ndarray): # (W, H, 3) dtype=uint8
MARF = self.model.hparams.output_mode == "medial_sphere"
PRIF = self.model.hparams.output_mode == "orthogonal_plane"
assert (MARF or PRIF) and MARF != PRIF
device_and_dtype = self.model.device_and_dtype
device = self.model.device
dtype = self.model.dtype
(
vizmode_normals,
vizmode_shading,
vizmode_centroids,
vizmode_spheres,
) = self.get_display_mode()
dirs, origins = self.raydirs_and_cam
origins = origins.detach().clone().to(**device_and_dtype)
dirs = dirs .detach().clone().to(**device_and_dtype)
if vizmode_normals != "ground_truth" or self.get_current_ground_truth() is None:
# enable grad or not
do_jac = PRIF or vizmode_normals == "analytical"
do_jac_medial = MARF and "centroid-grad-norm" in (vizmode_shading or "")
do_shape_operator = "anisotropic" in (vizmode_shading or "") or "curvature" in (vizmode_shading or "")
do_grad = do_jac or do_jac_medial or do_shape_operator
if do_grad:
origins = origins.broadcast_to(dirs.shape)
self.model.eval()
latent = self.get_current_conditioning()
if self.max_cols is None or self.max_cols > dirs.shape[0]:
chunks = [slice(None)]
else:
chunks = [slice(col, col+self.max_cols) for col in range(0, dirs.shape[0], self.max_cols)]
forward_chunks = []
for chunk in chunks:
self.model.zero_grad()
origins_chunk = origins[chunk if origins.ndim != 1 else slice(None)] @ self.obj_rot
dirs_chunk = dirs [chunk] @ self.obj_rot
if do_grad:
origins_chunk.requires_grad = dirs_chunk.requires_grad = True
@forward_chunks.append
@(lambda f: f(origins_chunk, dirs_chunk))
@torch.set_grad_enabled(do_grad)
def forward_chunk(origins, dirs) -> Munch:
if PRIF:
intersections, is_intersecting = self.model(dict(origins=origins, dirs=dirs), z=latent, normalize_origins=True)
is_intersecting = is_intersecting > 0.5
elif MARF:
(
depths, silhouettes, intersections,
intersection_normals, is_intersecting,
sphere_centers, sphere_radii,
atom_indices,
all_intersections, all_intersection_normals, all_depths, all_silhouettes, all_is_intersecting,
all_sphere_centers, all_sphere_radii,
) = self.model.forward(dict(origins=origins, dirs=dirs), z=latent,
intersections_only = False,
return_all_atoms = True,
)
if do_jac:
jac = diff.jacobian(intersections, origins, detach=not do_shape_operator)
intersection_normals = self.model.compute_normals_from_intersection_origin_jacobian(jac, dirs.detach())
if do_jac_medial:
sphere_centers_jac = diff.jacobian(sphere_centers, origins, detach=True)
if do_shape_operator:
hess = diff.jacobian(intersection_normals, origins, detach=True)[is_intersecting, :, :]
N = intersection_normals.detach()[is_intersecting, :]
TM = (torch.eye(3, device=device) - N[..., None, :]*N[..., :, None]) # projection onto tangent plane
# shape operator, i.e. total derivative of the surface normal w.r.t. the tangent space
shape_operator = hess @ TM
return Munch((k, v.detach()) for k, v in locals().items() if isinstance(v, torch.Tensor))
intersections = torch.cat([chunk.intersections for chunk in forward_chunks], dim=0)
is_intersecting = torch.cat([chunk.is_intersecting for chunk in forward_chunks], dim=0)
intersection_normals = torch.cat([chunk.intersection_normals for chunk in forward_chunks], dim=0)
if MARF:
all_sphere_centers = torch.cat([chunk.all_sphere_centers for chunk in forward_chunks], dim=0)
all_sphere_radii = torch.cat([chunk.all_sphere_radii for chunk in forward_chunks], dim=0)
atom_indices = torch.cat([chunk.atom_indices for chunk in forward_chunks], dim=0)
silhouettes = torch.cat([chunk.silhouettes for chunk in forward_chunks], dim=0)
sphere_centers = torch.cat([chunk.sphere_centers for chunk in forward_chunks], dim=0)
sphere_radii = torch.cat([chunk.sphere_radii for chunk in forward_chunks], dim=0)
if do_jac_medial:
sphere_centers_jac = torch.cat([chunk.sphere_centers_jac for chunk in forward_chunks], dim=0)
if do_shape_operator:
shape_operator = torch.cat([chunk.shape_operator for chunk in forward_chunks], dim=0)
n_atoms = all_sphere_centers.shape[-2] if MARF else 1
intersections = intersections @ self.obj_rot_inv
intersection_normals = intersection_normals @ self.obj_rot_inv
sphere_centers = sphere_centers @ self.obj_rot_inv if sphere_centers is not None else None
all_sphere_centers = all_sphere_centers @ self.obj_rot_inv if all_sphere_centers is not None else None
else: # render ground truth mesh
# HACK: we use a thread to not break the pygame opengl context
with ThreadPoolExecutor(max_workers=1) as p:
scan = p.submit(sdf_scan.Scan, self.get_current_ground_truth(),
camera_transform = self.cam2world.numpy(),
resolution = self.res[1],
calculate_normals = True,
fov = self.cam_fov_y,
z_near = 0.001,
z_far = 50,
no_flip_backfaced_normals = True
).result()
n_atoms, MARF, PRIF = 1, False, True
is_intersecting = torch.zeros(self.res, dtype=bool)
is_intersecting[ (self.res[0]-self.res[1]) // 2 : (self.res[0]-self.res[1]) // 2 + self.res[1], : ] = torch.tensor(scan.depth_buffer != 0, dtype=bool)
intersections = torch.zeros((*is_intersecting.shape, 3), dtype=dtype)
intersection_normals = torch.zeros((*is_intersecting.shape, 3), dtype=dtype)
intersections [is_intersecting] = torch.tensor(scan.points, dtype=dtype)
intersection_normals[is_intersecting] = torch.tensor(scan.normals, dtype=dtype)
is_intersecting = is_intersecting .flip(1).to(device)
intersections = intersections .flip(1).to(device)
intersection_normals = intersection_normals.flip(1).to(device)
mask = is_intersecting.cpu()
mx, my = self.mouse_position
w, h = dirs.shape[:2]
# fill white
if self.display_sphere_map_bg == True:
self.blit_sphere_map_mask(pixel_view)
else:
pixel_view[:] = self.display_sphere_map_bg
# draw to buffer
to_cam = -dirs.detach()
# light direction
extra = np.pi if vizmode_shading == "translucent" else 0
LM = torch.tensor(T.rotation_matrix(angle=self.light_angle2, direction=(0, 1, 0))[:3, :3], dtype=dtype)
LM = torch.tensor(T.rotation_matrix(angle=self.light_angle1 + extra, direction=(1, 0, 0))[:3, :3], dtype=dtype) @ LM
to_light = (self.cam2world[:3, :3] @ LM @ torch.tensor((1, 1, 3), dtype=dtype)).to(device)[None, :]
to_light = to_light / to_light.norm(dim=-1, keepdim=True)
# used to color different atom candidates
color_set = tuple(map(helpers.hex2tuple,
itertools.chain(
mcolors.TABLEAU_COLORS.values(),
#list(mcolors.TABLEAU_COLORS.values())[::-1],
#['#f8481c', '#c20078', '#35530a', '#010844', '#a8ff04'],
mcolors.XKCD_COLORS.values(),
)
))
color_per_atom = (*zip(*zip(range(n_atoms), itertools.cycle(color_set))),)[1]
# shade hits
if vizmode_shading is None:
pass
elif vizmode_shading == "colored-lambertian":
if n_atoms > 1:
color = torch.tensor(color_per_atom, device=device)[(*atom_indices[is_intersecting].T,)]
else:
color = torch.tensor(color_set[(0 if self.atom_index_solo is None else self.atom_index_solo) % len(color_set)], device=device)
lambertian = torch.einsum("id,id->i",
intersection_normals[is_intersecting, :],
to_light,
)[..., None]
pixel_view[mask, :] = (color *
torch.einsum("id,id->i",
intersection_normals[is_intersecting, :],
to_cam[is_intersecting, :],
)[..., None]).int().cpu()
pixel_view[mask, :] = (
255 * lambertian.clamp(0, 1).pow(32) +
color * (lambertian + 0.25).clamp(0, 1) * (1-lambertian.clamp(0, 1).pow(32))
).cpu()
elif vizmode_shading == "lambertian":
lambertian = torch.einsum("id,id->i",
intersection_normals[is_intersecting, :],
to_light,
)[..., None].clamp(0, 1)
if self.lambertian_color == (1.0, 1.0, 1.0):
pixel_view[mask, :] = (255 * lambertian).cpu()
else:
color = 255*torch.tensor(self.lambertian_color, device=device)
pixel_view[mask, :] = (color *
torch.einsum("id,id->i",
intersection_normals[is_intersecting, :],
to_cam[is_intersecting, :],
)[..., None]).int().cpu()
pixel_view[mask, :] = (
255 * lambertian.clamp(0, 1).pow(32) +
color * (lambertian + 0.25).clamp(0, 1) * (1-lambertian.clamp(0, 1).pow(32))
).cpu()
elif vizmode_shading == "translucent" and MARF:
lambertian = torch.einsum("id,id->i",
intersection_normals[is_intersecting, :],
to_light,
)[..., None].abs().clamp(0, 1)
distortion = 0.08
power = 16
ambient = 0
thickness = sphere_radii[is_intersecting].detach()
if self.display_mode_variation % 2:
thickness = thickness.mean()
color1 = torch.tensor((1, 0.5, 0.5), **device_and_dtype) # subsurface
color2 = torch.tensor((0, 1, 1), **device_and_dtype) # diffuse
l = to_light + intersection_normals[is_intersecting, :] * distortion
d = (to_cam[is_intersecting, :] * -l).sum(dim=-1).clamp(0, None).pow(power)
f = (d + ambient) * (1/(0.05 + thickness))
pixel_view[((dirs * to_light).sum(dim=-1) > 0.99).cpu(), :] = 255 # draw light source
pixel_view[mask, :] = (255 * (
color2 * (0.05 + lambertian*0.15) +
color1 * 0.3 * f[..., None]
).clamp(0, 1)).cpu()
elif vizmode_shading == "anisotropic" and vizmode_normals != "ground_truth":
eigvals, eigvecs = torch.linalg.eig(shape_operator.mT) # slow, complex output, not sorted
eigvals, indices = eigvals.abs().sort(dim=-1)
eigvecs = (eigvecs.abs() * eigvecs.real.sign()).take_along_dim(indices[..., None, :], dim=-1)
eigvecs = eigvecs.mT
s = self.display_mode_variation % 5
if s in (0, 1):
# try to keep these below 0.2:
if s == 0: a1, a2 = 0.05, 0.3
if s == 1: a1, a2 = 0.3, 0.05
# == Ward anisotropic specular reflectance ==
# G.J. Ward, Measuring and modeling anisotropic reflection, in:
# Proceedings of the 19th Annual Conference on Computer Graphics and
# Interactive Techniques, 1992: pp. 265272.
eigvecs /= eigvecs.norm(dim=-1, keepdim=True)
N = intersection_normals[is_intersecting, :]
H = to_cam[is_intersecting, :] + to_light
H = H / H.norm(dim=-1, keepdim=True)
specular = (1/(4*torch.pi * a1*a2 * torch.sqrt((
(N * to_cam[is_intersecting, :]).sum(dim=-1) *
(N * to_light ).sum(dim=-1)
)))) * torch.exp(
-2 * (
((H * eigvecs[..., 2, :]).sum(dim=-1) / a1).pow(2)
+
((H * eigvecs[..., 1, :]).sum(dim=-1) / a2).pow(2)
) / (
1 + (N * H).sum(dim=-1)
)
)
specular = specular.clamp(0, None).nan_to_num(0, 0, 0)
lambertian = torch.einsum("id,id->i", N, to_light ).clamp(0, None)
color1 = 0.4 * torch.tensor((1, 1, 1), **device_and_dtype) # specular
color2 = 0.4 * torch.tensor((0, 1, 1), **device_and_dtype) # diffuse
pixel_view[mask, :] = (255 * (
color1 * specular [..., None] +
color2 * lambertian[..., None]
).clamp(0, 1)).int().cpu()
if s == 2:
pixel_view[mask, :] = (255 * (
eigvecs[..., 2, :].abs().clamp(0, 1) # orientation only
)).int().cpu()
elif s == 3:
pixel_view[mask, :] = (255 * (
eigvecs[..., 1, :].abs().clamp(0, 1) # orientation only
)).int().cpu()
elif s == 4:
pixel_view[mask, :] = (255 * (
eigvecs[..., 0, :].abs().clamp(0, 1) # orientation only
)).int().cpu()
elif vizmode_shading == "shade-best-radii" and MARF:
lambertian = torch.einsum("id,id->i",
intersection_normals[is_intersecting, :],
to_light,
)[..., None]
radii = sphere_radii[is_intersecting]
radii = radii - 0.04
radii = radii / 0.4
colors = cm.plasma(radii.clamp(0, 1).cpu())[..., :3]
pixel_view[mask, :] = 255 * (
lambertian.pow(32).clamp(0, 1).cpu().numpy() +
colors * (lambertian + 0.25).clamp(0, 1).cpu().numpy() * (1-lambertian.pow(32).clamp(0, 1)).cpu().numpy()
)
elif vizmode_shading == "shade-all-radii" and MARF:
radii = sphere_radii[is_intersecting][..., None]
radii /= radii.max()
if n_atoms > 1:
color = torch.tensor(color_per_atom, device=device)[(*atom_indices[is_intersecting].T,)]
else:
color = torch.tensor(color_set[(0 if self.atom_index_solo is None else self.atom_index_solo) % len(color_set)], device=device)
pixel_view[mask, :] = (color * radii).int().cpu()
elif vizmode_shading == "normal":
normal = intersection_normals[is_intersecting, :]
pixel_view[mask, :] = (255 * (normal * 0.5 + 0.5) ).int().cpu()
elif vizmode_shading == "curvature" and vizmode_normals != "ground_truth":
eigvals = torch.linalg.eigvals(shape_operator.mT) # complex output, not sorted
# we sort them by absolute magnitude, not the real component
_, indices = (eigvals.abs() * eigvals.real.sign()).sort(dim=-1)
eigvals = eigvals.real.take_along_dim(indices, dim=-1)
s = self.display_mode_variation % (6 if MARF else 5)
if s==0: out = (eigvals[..., [0, 2]].mean(dim=-1, keepdim=True) / 25).tanh() # mean curvature
if s==1: out = (eigvals[..., [0, 2]].prod(dim=-1, keepdim=True) / 25).tanh() # gaussian curvature
if s==2: out = (eigvals[..., [2]] / 25).tanh() # maximum principal curvature - k1
if s==3: out = (eigvals[..., [1]] / 25).tanh() # some curvature
if s==4: out = (eigvals[..., [0]] / 25).tanh() # minimum principal curvature - k2
if s==5: out = ((sphere_radii[is_intersecting][..., None].detach() - 1 / eigvals[..., [2]].clamp(1e-8, None)) * 5).tanh().clamp(0, None)
lambertian = torch.einsum("id,id->i",
intersection_normals[is_intersecting, :],
to_light,
)[..., None]
pixel_view[mask, :] = (255 * (lambertian+0.5).clamp(0, 1) * torch.cat((
1+out.clamp(-1, 0),
1-out.abs(),
1-out.clamp(0, 1),
), dim=-1)).int().cpu()
elif vizmode_shading == "centroid-grad-norm" and MARF:
asd = sphere_centers_jac[is_intersecting, :, :].norm(dim=-2).mean(dim=-1, keepdim=True)
asd -= asd.min()
asd /= asd.max()
pixel_view[mask, :] = (255 * asd).cpu()
elif "glass" in vizmode_shading:
normals = intersection_normals[is_intersecting, :]
to_cam_ = to_cam [is_intersecting, :]
# "Empiricial Approximation" of fresnel
# https://developer.download.nvidia.com/CgTutorial/cg_tutorial_chapter07.html via
# http://kylehalladay.com/blog/tutorial/2014/02/18/Fresnel-Shaders-From-The-Ground-Up.html
cos = torch.einsum("id,id->i", normals, to_cam_ )[..., None]
bias, scale, power = 0, 4, 3
fresnel = (bias + scale*(1-cos)**power).clamp(0, 1)
#reflection
reflection = -to_cam_ - 2*(-cos)*normals
#refraction
r = 1 / 1.5 # refractive index, air -> glass
refraction = -r*to_cam_ + (r*cos - (1-r**2*(1-cos**2)).sqrt()) * normals
exit_point = intersections[is_intersecting, :]
# reflect the refraction over the plane defined by the refraction direction and the sphere center, resulting in the second refraction
if vizmode_shading == "double-glass" and MARF:
cos2 = torch.einsum("id,id->i", refraction, -to_cam_ )[..., None]
pn = -to_cam_ - cos2*refraction
pn /= pn.norm(dim=-1, keepdim=True)
refraction = -to_cam_ - 2*torch.einsum("id,id->i", pn, -to_cam_ )[..., None]*pn
exit_point -= sphere_centers[is_intersecting, :]
exit_point = exit_point - 2*torch.einsum("id,id->i", pn, exit_point )[..., None]*pn
exit_point += sphere_centers[is_intersecting, :]
fresnel = np.asanyarray(fresnel.cpu())
pixel_view[mask, :] \
= self.lookup_sphere_map_dirs(reflection, intersections[is_intersecting, :]) * fresnel \
+ self.lookup_sphere_map_dirs(refraction, exit_point) * (1-fresnel)
else: # flat
pixel_view[mask, :] = 80
if not MARF: return
# overlay medial atoms
if vizmode_spheres is not None:
# show miss distance in red
s = silhouettes.detach()[~is_intersecting].clamp(0, 1)
s /= s.max()
pixel_view[~mask, 1] = (s * 255).cpu()
pixel_view[:, 2] = pixel_view[:, 1]
mouse_hits = 0 <= mx < w and 0 <= my < h and mask[mx, my]
draw_intersecting = "intersecting-sphere" in vizmode_spheres
draw_best = "best-sphere" in vizmode_spheres
draw_color = "-sphere-colored" in vizmode_spheres
draw_all = "all-spheres-colored" in vizmode_spheres
def get_nears():
if draw_all:
projected, near, far, is_intersecting = geometry.ray_sphere_intersect(
torch.tensor(origins),
torch.tensor(dirs[..., None, :]),
sphere_centers = all_sphere_centers[mx, my][None, None, ...],
sphere_radii = all_sphere_radii [mx, my][None, None, ...],
allow_nans = False,
return_parts = True,
)
depths = (near - origins).norm(dim=-1)
atom_indices_ = torch.where(is_intersecting, depths.detach(), depths.detach()+100).argmin(dim=-1, keepdim=True)
is_intersecting = is_intersecting.any(dim=-1)
projected = None
near = near.take_along_dim(atom_indices_[..., None], -2).squeeze(-2)
far = None
sphere_centers_ = all_sphere_centers[mx, my][None, None, ...].take_along_dim(atom_indices_[..., None], -2).squeeze(-2)
normals = near[is_intersecting, :] - sphere_centers_[is_intersecting, :]
normals /= torch.linalg.norm(normals, dim=-1)[..., None]
color = torch.tensor(color_per_atom, device=device)[(*atom_indices_[is_intersecting].T,)]
yield color, projected, near, far, is_intersecting, normals
if (mouse_hits and draw_intersecting) or draw_best:
projected, near, far, is_intersecting = geometry.ray_sphere_intersect(
torch.tensor(origins),
torch.tensor(dirs),
# unit-sphere by default
sphere_centers = sphere_centers[mx, my][None, None, ...],
sphere_radii = sphere_radii [mx, my][None, None, ...],
return_parts = True,
)
normals = near[is_intersecting, :] - sphere_centers[mx, my][None, ...]
normals /= torch.linalg.norm(normals, dim=-1)[..., None]
color = (255, 255, 255) if not draw_color else color_per_atom[atom_indices[mx, my]]
yield torch.tensor(color, device=device), projected, near, far, is_intersecting, normals
# draw sphere with lambertian shading
for color, projected, near, far, is_intersecting_2, normals in get_nears():
lambertian = torch.einsum("...id,...id->...i", normals, to_light )[..., None]
pixel_view[is_intersecting_2.cpu(), :] = (
255*lambertian.pow(32).clamp(0, 1) +
color * (lambertian + 0.25).clamp(0, 1) * (1-lambertian.pow(32).clamp(0, 1))
).cpu()
# overlay points / sphere centers
if vizmode_centroids is not None:
cam2world_inv = torch.tensor(self.cam2world_inv, **device_and_dtype)
intrinsics = torch.tensor(self.intrinsics, **device_and_dtype)
def get_coords():
miss_centroid = "miss-centroids" in vizmode_centroids
mask = is_intersecting if not miss_centroid else ~is_intersecting
if vizmode_centroids in ("all-centroids-colored", "all-miss-centroids-colored"):
# we use temporal dithering to the show all overlapping centers
for color, atom_index in sorted(zip(itertools.chain(color_set), range(n_atoms)), key=lambda x: random.random()):
yield color, all_sphere_centers[..., atom_index, :][mask], mask
elif "all-centroids" in vizmode_centroids:
yield (80, 150, 80), all_sphere_centers[mask].reshape(-1, 3), mask # [:, 3]
if "centroids-colored" in vizmode_centroids:
if n_atoms == 1:
color = color_set[(0 if self.atom_index_solo is None else self.atom_index_solo) % len(color_set)]
else:
color = torch.tensor(color_per_atom, device=device)[(*atom_indices[mask].T,)].cpu()
else:
color = (0, 0, 0)
yield color, sphere_centers[mask], mask
for i, (color, coords, coord_mask) in enumerate(get_coords()):
if self.export_medial_surface_mesh:
fname = self.mk_dump_fname("ply", uid=i)
p = torch.zeros_like(sphere_centers)
c = torch.zeros_like(sphere_centers)
p[coord_mask, :] = coords
c[coord_mask, :] = torch.tensor(color, device=p.device) / 255
SingleViewUVScan(
hits = ( mask).numpy(),
miss = (~mask).numpy(),
points = p.cpu().numpy(),
colors = c.cpu().numpy(),
normals=None, distances=None, cam_pos=None,
cam_mat4=None, proj_mat4=None, transforms=None,
).to_mesh().export(str(fname), file_type="ply")
print("dumped", fname)
if shutil.which("f3d"):
subprocess.Popen(["f3d", "-gsy", "--up=+z", "--bg-color=1,1,1", fname], close_fds=True)
coords = torch.cat((coords, torch.ones((*coords.shape[:-1], 1), **device_and_dtype)), dim=-1)
coords = torch.einsum("...ij,...kj->...ki", cam2world_inv, coords)[..., :3]
coords = geometry.project(coords[..., 0], coords[..., 1], coords[..., 2], intrinsics)
in_view = functools.reduce(torch.mul, (
coords[:, 0] < pixel_view.shape[1],
coords[:, 0] >= 0,
coords[:, 1] < pixel_view.shape[0],
coords[:, 1] >= 0,
)).cpu()
coords = coords[in_view, :]
if not isinstance(color, tuple):
color = color[in_view, :]
pixel_view[(*coords[..., [1, 0]].int().T.cpu(),)] = color
self.export_medial_surface_mesh = False