431 lines
17 KiB
Python
431 lines
17 KiB
Python
from ..utils import geometry
|
|
from abc import ABC, abstractmethod
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from pytorch3d.transforms import euler_angles_to_matrix
|
|
from tqdm import tqdm
|
|
from typing import Sequence, Callable, TypedDict
|
|
import imageio
|
|
import shlex
|
|
import json
|
|
import numpy as np
|
|
import os
|
|
import time
|
|
import torch
|
|
os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "1"
|
|
import pygame
|
|
|
|
IVec2 = tuple[int, int]
|
|
IVec3 = tuple[int, int, int]
|
|
Vec2 = tuple[float|int, float|int]
|
|
Vec3 = tuple[float|int, float|int, float|int]
|
|
|
|
class CamState(TypedDict, total=False):
    """Camera parameters as a plain dict; ``total=False`` makes every key optional.

    The keys are the viewer's ``cam_*`` attribute names without the prefix.
    """

    # distance used for the camera offset along its local z axis
    distance: float

    # look-at point / rotation pivot
    pos_x: float
    pos_y: float
    pos_z: float

    # rotation angles in radians
    rot_x: float
    rot_y: float

    # vertical field of view in radians
    fov_y: float
|
|
|
|
|
|
|
|
class InteractiveViewer(ABC):
    """Abstract pygame viewer with an orbiting camera.

    Subclasses implement ``render_frame``; ``run`` drives it interactively in
    a window and ``render_headless`` drives it offline into image/video files.
    """

    # Exposes pygame's constants through the viewer (saves an import).
    constants = pygame.constants

    # -- realtime: assigned once per frame by ``run`` --
    t: float   # time since start
    td: float  # time delta since last frame

    # -- offline --
    is_headless: bool  # set True by ``render_headless`` and False by ``run``
    fps: int           # frame rate handed to ``render_headless``
    frame_idx: int     # index of the frame currently being rendered

    # Colour the window and the frame buffer are filled with when ``run`` starts.
    fill_color = (255, 255, 255)
|
|
|
|
def __init__(self, name: str, res: IVec2 = (640, 480), scale: int = 1, screenshot_dir: Path | str = "."):
    """Store the viewer configuration and put the camera in its default pose.

    Args:
        name: Label used when building screenshot file names.
        res: Render resolution ``(width, height)`` of the frame buffer.
        scale: Integer factor between the frame buffer and the window size.
        screenshot_dir: Directory for screenshots and ``camera.json``; it is
            converted to a ``Path``.
    """
    self.name = name
    self.res = res
    self.scale = scale
    self.screenshot_dir = Path(screenshot_dir)

    self.is_headless = False

    self.cam_distance = 2.0
    self.cam_pos_x = 0.0  # look-at and rotation pivot
    self.cam_pos_y = 0.0  # look-at and rotation pivot
    self.cam_pos_z = 0.0  # look-at and rotation pivot
    self.cam_rot_x = 0.5 * torch.pi  # radians
    self.cam_rot_y = -0.5 * torch.pi  # radians
    # 60 degrees in radians; uses torch.pi like the rotations above instead
    # of a truncated literal.
    self.cam_fov_y = 60.0 / 180.0 * torch.pi
    self.keep_rotating = False
    # Snapshot of the default pose, restored when Return is pressed.
    self.initial_camera_state = self.cam_state
    self.fps_cap = None  # None leaves the interactive loop uncapped
|
|
|
|
@property
def cam_state(self) -> CamState:
    """Return the current camera attributes as a dict with ``CamState`` keys."""
    keys = ("distance", "pos_x", "pos_y", "pos_z", "rot_x", "rot_y", "fov_y")
    return {key: getattr(self, f"cam_{key}") for key in keys}

@cam_state.setter
def cam_state(self, new_state: CamState):
    """Apply ``new_state``; keys it does not contain keep their current value."""
    for key in ("distance", "pos_x", "pos_y", "pos_z", "rot_x", "rot_y", "fov_y"):
        attr = f"cam_{key}"
        setattr(self, attr, new_state.get(key, getattr(self, attr)))
|
|
|
|
@property
def scaled_res(self) -> IVec2:
    """Window size in pixels: both components of ``res`` multiplied by ``scale``."""
    width, height = self.res[0], self.res[1]
    return (width * self.scale, height * self.scale)
|
|
|
|
def setup(self):
|
|
pass
|
|
|
|
def teardown(self):
|
|
pass
|
|
|
|
@abstractmethod
def render_frame(self, pixel_view: np.ndarray):  # (W, H, 3) dtype=uint8
    """Draw the current frame into ``pixel_view``; subclasses must override.

    ``pixel_view`` is the ``pygame.surfarray.pixels3d`` view of the frame
    buffer, so writes to it are what ends up on screen or in the output file.
    """
|
|
|
|
def handle_key_up(self, key: int, keys_pressed: Sequence[bool]):
    """Hook for key-release events; the base implementation ignores them."""
|
|
|
|
def handle_key_down(self, key: int, keys_pressed: Sequence[bool]):
    """React to a single key-press event.

    Bindings handled here:
        r         start continuous rotation
        -         increase ``scale`` by one
        +         decrease ``scale`` by one, never below 1
        Return    restore the camera state captured in ``__init__``
        h         load the camera from ``<screenshot_dir>/camera.json``
        Shift+h   save the camera to ``<screenshot_dir>/camera.json``
        Ctrl+h    print the camera as shell-quoted JSON

    Args:
        key: pygame key code of the pressed key.
        keys_pressed: Key-state sequence indexed by pygame key code.
    """
    shift_held = keys_pressed[pygame.K_LSHIFT] or keys_pressed[pygame.K_RSHIFT]
    ctrl_held = keys_pressed[pygame.K_LCTRL] or keys_pressed[pygame.K_RCTRL]

    if key == pygame.K_r:
        self.keep_rotating = True
        self.cam_rot_x += self.td

    if key == pygame.K_MINUS:
        self.scale += 1
        if __debug__: print()  # break out of the "\rFPS:" status line
        print(f"== Scale = {self.scale} ==")

    if key == pygame.K_PLUS and self.scale > 1:
        self.scale -= 1
        if __debug__: print()
        print(f"== Scale = {self.scale} ==")

    if key == pygame.K_RETURN:
        self.cam_state = self.initial_camera_state

    if key == pygame.K_h:
        camera_file = self.screenshot_dir / "camera.json"
        if ctrl_held:
            print(shlex.quote(json.dumps(self.cam_state)))
        elif shift_held:
            # Create the directory first, as the screenshot path in run() does.
            camera_file.parent.mkdir(parents=True, exist_ok=True)
            with camera_file.open("w") as f:
                json.dump(self.cam_state, f)
            print("Wrote", camera_file)
        else:
            # A missing or malformed file must not take down the viewer.
            try:
                with camera_file.open("r") as f:
                    loaded = json.load(f)
            except (OSError, ValueError) as err:  # JSONDecodeError is a ValueError
                print("Could not read", camera_file, "-", err)
            else:
                if isinstance(loaded, dict):
                    self.cam_state = loaded
                    print("Read", camera_file)
                else:
                    print("Ignored", camera_file, "- expected a JSON object")
|
|
|
|
def handle_keys_pressed(self, pressed: Sequence[bool]) -> float:
    """Apply continuous camera motion for the keys that are currently held.

    Arrow keys rotate, W/A/S/D move the pivot in the xy-plane relative to
    ``cam_rot_x``, PageUp/PageDown move the pivot along z, and with Alt held
    PageUp/PageDown change ``cam_distance`` instead. The step is the frame
    delta times 0.5 with Shift, 6 with Ctrl, and 2 otherwise.

    Returns:
        The scaled time step that was applied.
    """
    ctrl = pressed[pygame.K_LCTRL] or pressed[pygame.K_RCTRL]
    shift = pressed[pygame.K_LSHIFT] or pressed[pygame.K_RSHIFT]
    alt = pressed[pygame.K_LALT] or pressed[pygame.K_RALT]

    if shift:
        speed = 0.5
    elif ctrl:
        speed = 6
    else:
        speed = 2
    td = self.td * speed

    # -- rotation --
    if pressed[pygame.K_UP]:
        self.cam_rot_y += td
    if pressed[pygame.K_DOWN]:
        self.cam_rot_y -= td
    if pressed[pygame.K_LEFT]:
        self.cam_rot_x += td
    if pressed[pygame.K_RIGHT]:
        self.cam_rot_x -= td

    # Any manual rotation cancels the automatic spin started with "r".
    if pressed[pygame.K_UP] or pressed[pygame.K_DOWN] or pressed[pygame.K_LEFT] or pressed[pygame.K_RIGHT]:
        self.keep_rotating = False
    if self.keep_rotating:
        self.cam_rot_x += self.td * 0.25

    # -- pivot translation in the xy-plane --
    if pressed[pygame.K_w]:
        self.cam_pos_x -= td * np.cos(-self.cam_rot_x)
        self.cam_pos_y += td * np.sin(-self.cam_rot_x)
    if pressed[pygame.K_s]:
        self.cam_pos_x += td * np.cos(-self.cam_rot_x)
        self.cam_pos_y -= td * np.sin(-self.cam_rot_x)
    if pressed[pygame.K_a]:
        self.cam_pos_x += td * np.sin(self.cam_rot_x)
        self.cam_pos_y -= td * np.cos(self.cam_rot_x)
    if pressed[pygame.K_d]:
        self.cam_pos_x -= td * np.sin(self.cam_rot_x)
        self.cam_pos_y += td * np.cos(self.cam_rot_x)

    # -- PageUp/PageDown: distance with Alt, pivot height without --
    if pressed[pygame.K_PAGEUP]:
        if alt:
            self.cam_distance -= td
        else:
            self.cam_pos_z -= td
    if pressed[pygame.K_PAGEDOWN]:
        if alt:
            self.cam_distance += td
        else:
            self.cam_pos_z += td

    return td
|
|
|
|
def handle_mouse_button_up(self, pos: IVec2, button: int, keys_pressed: Sequence[bool]):
    """Hook for mouse-button release events; the base implementation ignores them."""
|
|
|
|
def handle_mouse_button_down(self, pos: IVec2, button: int, keys_pressed: Sequence[bool]):
    """Hook for mouse-button press events; the base implementation ignores them."""
|
|
|
|
def handle_mouse_motion(self, pos: IVec2, rel: IVec2, buttons: Sequence[bool], keys_pressed: Sequence[bool]):
    """Hook for mouse-motion events; the base implementation ignores them."""
|
|
|
|
def handle_mousewheel(self, flipped: bool, x: int, y: int, keys_pressed: Sequence[bool]):
    """Zoom with the mouse wheel.

    With Alt held the vertical scroll amount changes ``cam_fov_y``;
    otherwise it changes ``cam_distance``. ``flipped`` and ``x`` are unused.
    """
    alt_held = keys_pressed[pygame.K_LALT] or keys_pressed[pygame.K_RALT]
    if not alt_held:
        self.cam_distance -= y * 0.2
        return
    self.cam_fov_y -= y * 0.015
|
|
|
|
# Last title handed to pygame; class-level default until the first change.
_current_caption = None

def set_caption(self, title: str, *a, **kw):
    """Update the window title, skipping headless mode and repeated titles.

    Extra arguments are forwarded to ``pygame.display.set_caption``. Returns
    that call's result, or ``None`` when nothing was changed.
    """
    if self.is_headless or self._current_caption == title:
        return None
    print(f"set_caption: {title!r}")
    self._current_caption = title
    return pygame.display.set_caption(title, *a, **kw)
|
|
|
|
@property
def mouse_position(self) -> IVec2:
    """Mouse position in frame-buffer pixels (window position floor-divided by ``scale``).

    Headless mode has no window, so the position is taken as ``(0, 0)``.
    """
    if self.is_headless:
        mx, my = 0, 0
    else:
        mx, my = pygame.mouse.get_pos()
    return (mx // self.scale, my // self.scale)
|
|
|
|
@property
def uvs(self) -> torch.Tensor:  # (w, h, 2) dtype=float32
    """Per-pixel coordinate grid for the current ``res``, cached until ``res`` changes.

    The last axis holds ``(U, V)``, where U counts along ``res[1]`` and V
    counts along ``res[0]``.
    """
    res = tuple(self.res)
    if getattr(self, "_uvs_res", None) != res:
        axis_u = torch.arange(self.res[1]).to(torch.float32)
        axis_v = torch.arange(self.res[0]).to(torch.float32)
        grid_u, grid_v = torch.meshgrid(axis_u, axis_v, indexing="xy")
        stacked = torch.stack((grid_u, grid_v), dim=-1)
        self._uvs_res, self._uvs = res, stacked
    return self._uvs
|
|
|
|
@property
def cam2world(self) -> torch.Tensor:  # (4, 4) dtype=float32
    """Camera-to-world matrix built from the ``cam_*`` attributes, cached.

    Staleness is detected by object identity: the matrix is rebuilt whenever
    one of the tracked attributes has been rebound since the last build. A
    rebuild also clears the cached inverse.
    """
    tracked = ("cam_rot_y", "cam_rot_x", "cam_pos_x", "cam_pos_y", "cam_pos_z", "cam_distance")
    stale = any(
        getattr(self, f"_cam2world_{name}", None) is not getattr(self, name)
        for name in tracked
    )
    if stale:
        # Remember the exact objects this matrix is built from.
        for name in tracked:
            setattr(self, f"_cam2world_{name}", getattr(self, name))

        # Offset of cam_distance along the local z axis.
        offset = torch.eye(4)
        offset[2, 3] = self.cam_distance

        # Rotation about the pivot, translated by the negated pivot position.
        pose = torch.eye(4)
        angles = torch.tensor((self.cam_rot_x, self.cam_rot_y, 0))
        pose[:3, :3] = euler_angles_to_matrix(angles, "ZYX")
        pose[0:3, 3] -= torch.tensor((self.cam_pos_x, self.cam_pos_y, self.cam_pos_z))

        self._cam2world = pose @ offset
        self._cam2world_inv = None  # invalidate the cached inverse
    return self._cam2world
|
|
|
|
@property
def cam2world_inv(self) -> torch.Tensor:  # (4, 4) dtype=float32
    """Inverse of ``cam2world``, cached until the camera changes.

    The matrix is fetched through the ``cam2world`` property rather than the
    private cache. That guarantees it exists on first use, and lets the
    property rebuild and reset the cached inverse if the camera has moved
    since the last access.
    """
    cam2world = self.cam2world
    if getattr(self, "_cam2world_inv", None) is None:
        self._cam2world_inv = torch.linalg.inv(cam2world)
    return self._cam2world_inv
|
|
|
|
@property
def intrinsics(self) -> torch.Tensor:  # (3, 3) dtype=float32
    """Camera intrinsics matrix, cached until ``res`` or ``cam_fov_y`` is rebound.

    NOTE(review): the focal length divides by sin(fov/2) where a pinhole
    model would use tan(fov/2), and cx/cy are taken from res[1]/res[0]
    respectively; presumably this matches what the ray-direction helper
    expects. Confirm before changing.
    """
    res_changed = getattr(self, "_intrinsics_res", None) is not self.res
    fov_changed = getattr(self, "_intrinsics_cam_fov_y", None) is not self.cam_fov_y
    if res_changed or fov_changed:
        res = self._intrinsics_res = self.res
        fov_y = self._intrinsics_cam_fov_y = self.cam_fov_y

        self._intrinsics = matrix = torch.eye(3)
        half_extent = res[1] / 2
        focal = half_extent / torch.sin(torch.tensor(fov_y / 2))
        matrix[0, 0] = focal                # fx - focal length x
        matrix[1, 1] = focal                # fy - focal length y
        matrix[0, 2] = (res[1] - 1) / 2     # cx - optical center x
        matrix[1, 2] = (res[0] - 1) / 2     # cy - optical center y
    return self._intrinsics
|
|
|
|
@property
def raydirs_and_cam(self) -> tuple[torch.Tensor, torch.Tensor]:  # (w, h, 3) and (3) dtype=float32
    """Per-pixel ray directions and the camera position, cached.

    The pair is recomputed whenever ``cam2world``, ``intrinsics`` or ``uvs``
    returns a different tensor object than the one used for the cached pair.
    """
    cam2world = self.cam2world
    intrinsics = self.intrinsics
    uvs = self.uvs

    up_to_date = (
        getattr(self, "_raydirs_and_cam_cam2world", None) is cam2world
        and getattr(self, "_raydirs_and_cam_intrinsics", None) is intrinsics
        and getattr(self, "_raydirs_and_cam_uvs", None) is uvs
    )
    if not up_to_date:
        self._raydirs_and_cam_cam2world = cam2world
        self._raydirs_and_cam_intrinsics = intrinsics
        self._raydirs_and_cam_uvs = uvs

        # The translation column is the camera origin in world space
        # (same as transforming the homogeneous point [0, 0, 0, 1]).
        cam_pos = cam2world[:3, -1]

        # The helper's directions are negated before use.
        raw_dirs = geometry.get_ray_directions(uvs, cam2world[None, ...], intrinsics[None, ...])
        dirs = -raw_dirs.squeeze(-1)

        self._raydirs_and_cam = (dirs, cam_pos)

    dirs, cam_pos = self._raydirs_and_cam[0], self._raydirs_and_cam[1]
    return (dirs, cam_pos)
|
|
|
|
def run(self):
    """Open a resizable window and run the interactive loop until the user quits.

    Each iteration updates ``frame_idx``, ``td`` and ``t``, calls
    ``render_frame`` on the frame buffer, scales the buffer onto the window,
    applies held keys via ``handle_keys_pressed`` and dispatches queued events
    to the ``handle_*`` hooks. Built-in keys: ``q`` (or closing the window)
    quits, ``y`` saves a PNG screenshot named by ``mk_dump_fname``.

    ``teardown`` and ``pygame.quit`` always run on the way out, including
    when window creation or ``setup`` raised.
    """
    self.is_headless = False
    pygame.display.init()  # we do not use the mixer, which often hangs on quit
    try:
        window = pygame.display.set_mode(self.scaled_res, flags=pygame.RESIZABLE)
        buffer = pygame.surface.Surface(self.res)

        window.fill(self.fill_color)
        buffer.fill(self.fill_color)
        pygame.display.flip()

        pixel_view = pygame.surfarray.pixels3d(buffer)  # (W, H, 3)

        current_scale = self.scale

        def remake_window_buffer(window_size: IVec2):
            """Resize the frame buffer to ``window_size`` divided by the current scale."""
            nonlocal buffer, pixel_view, current_scale
            self.res = (
                window_size[0] // self.scale,
                window_size[1] // self.scale,
            )
            buffer = pygame.surface.Surface(self.res)
            pixel_view = pygame.surfarray.pixels3d(buffer)
            current_scale = self.scale

        print()

        self.setup()

        is_running = True
        clock = pygame.time.Clock()
        # A monotonic clock keeps td non-negative if the wall clock is adjusted.
        epoch = t_prev = time.perf_counter()
        self.frame_idx = -1
        while is_running:
            self.frame_idx += 1
            if self.fps_cap is not None:
                clock.tick(self.fps_cap)
            t = time.perf_counter()
            self.td = t - t_prev
            t_prev = t
            self.t = t - epoch
            # Two frames may share a timestamp; do not divide by zero.
            fps = 1 / self.td if self.td > 0 else float("inf")
            print("\rFPS:", fps, " "*10, end="")

            self.render_frame(pixel_view)

            pygame.transform.scale(buffer, window.get_size(), window)
            pygame.display.flip()

            keys_pressed = pygame.key.get_pressed()
            self.handle_keys_pressed(keys_pressed)

            for event in pygame.event.get():
                if event.type == pygame.VIDEORESIZE:
                    print()
                    print("== resize window ==")
                    remake_window_buffer(event.size)
                elif event.type == pygame.QUIT:
                    is_running = False
                elif event.type == pygame.KEYUP:
                    self.handle_key_up(event.key, keys_pressed)
                elif event.type == pygame.KEYDOWN:
                    self.handle_key_down(event.key, keys_pressed)
                    if event.key == pygame.K_q:
                        is_running = False
                    elif event.key == pygame.K_y:
                        fname = self.mk_dump_fname("png")
                        fname.parent.mkdir(parents=True, exist_ok=True)
                        pygame.image.save(buffer.copy(), fname)
                        print()
                        print("Saved", fname)
                elif event.type == pygame.MOUSEBUTTONUP:
                    self.handle_mouse_button_up(event.pos, event.button, keys_pressed)
                elif event.type == pygame.MOUSEBUTTONDOWN:
                    self.handle_mouse_button_down(event.pos, event.button, keys_pressed)
                elif event.type == pygame.MOUSEMOTION:
                    self.handle_mouse_motion(event.pos, event.rel, event.buttons, keys_pressed)
                elif event.type == pygame.MOUSEWHEEL:
                    self.handle_mousewheel(event.flipped, event.x, event.y, keys_pressed)

            # A key handler may have changed the scale; rebuild the buffer to match.
            if current_scale != self.scale:
                remake_window_buffer(window.get_size())

    finally:
        self.teardown()
        print()
        pygame.quit()
|
|
|
|
def render_headless(self, output_path: str, *, n_frames: int, fps: int, state_callback: Callable[["InteractiveViewer", int], None] | None, resolution=None, bitrate=None, **kw):
    """Render ``n_frames`` frames without a window and write them to disk.

    Args:
        output_path: Destination file. A ``.png`` suffix writes one image per
            frame, with ``%`` in the file name replaced by the zero-padded
            frame index; any other suffix writes a single video through
            ``imageio.get_writer``.
        n_frames: Number of frames to render.
        fps: Stored on ``self.fps`` and passed to the video writer.
        state_callback: If not ``None``, called as ``state_callback(self, frame)``
            before each frame is rendered.
        resolution: Frame-buffer size; defaults to ``self.res``.
        bitrate: Passed to the video writer.
        **kw: Forwarded to ``tqdm``.

    Raises:
        ValueError: More than one PNG frame was requested but the file name
            has no ``%`` placeholder.
    """
    self.is_headless = True
    self.fps = fps

    # NOTE(review): self.res is not updated when `resolution` is given, so
    # properties derived from self.res (uvs, intrinsics) keep the old size.
    # Confirm that callers set self.res themselves.
    buffer = pygame.surface.Surface(self.res if resolution is None else resolution)
    pixel_view = pygame.surfarray.pixels3d(buffer)  # (W, H, 3)

    def do():
        """Yield each rendered frame with its first two axes swapped; teardown runs when done."""
        try:
            self.setup()
            for frame in tqdm(range(n_frames), **kw, disable=n_frames==1):
                self.frame_idx = frame
                if state_callback is not None:
                    state_callback(self, frame)

                self.render_frame(pixel_view)

                # Copy first: pixel_view is overwritten by the next frame.
                yield pixel_view.copy().swapaxes(0,1)
        finally:
            self.teardown()

    output_path = Path(output_path)
    is_png = output_path.suffix == ".png"
    if is_png and n_frames > 1 and "%" not in output_path.name:
        raise ValueError(
            f"Writing {n_frames} PNG frames requires a '%' placeholder "
            f"in the file name, got {output_path.name!r}"
        )

    # Create the destination directory for both output kinds.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    if is_png:
        for i, framebuffer in enumerate(do()):
            with imageio.get_writer(output_path.parent / output_path.name.replace("%", f"{i:04}")) as writer:
                writer.append_data(framebuffer)
    else:  # ffmpeg - https://imageio.readthedocs.io/en/v2.9.0/format_ffmpeg.html#ffmpeg
        with imageio.get_writer(output_path, fps=fps, bitrate=bitrate) as writer:
            for framebuffer in do():
                writer.append_data(framebuffer)
|
|
|
|
def load_sphere_map(self, fname):
    """Load the image ``fname`` as the sphere map.

    Both the pygame surface and its ``pixels3d`` array view are stored on
    the instance.
    """
    surface = pygame.image.load(fname)
    self._sphere_surf = surface
    self._sphere_map = pygame.surfarray.pixels3d(surface)
|
|
|
|
def lookup_sphere_map_dirs(self, dirs, origins):
    """Sample the loaded sphere map along the given rays.

    Each ray is intersected with a sphere whose radius is twice the norm of
    its origin. The far result of ``geometry.ray_sphere_intersect`` is
    converted to spherical angles, which index the sphere map loaded by
    ``load_sphere_map``.

    Args:
        dirs: Ray directions (tensor or array-like).
        origins: Ray origins (tensor or array-like).

    Returns:
        The sphere-map pixels selected for each ray.
    """
    # as_tensor(...).detach().clone() keeps the private, graph-free copy that
    # torch.tensor(...) produced, without the copy-construct warning when the
    # input is already a tensor.
    origins = torch.as_tensor(origins).detach().clone()
    dirs = torch.as_tensor(dirs).detach().clone()

    _near, far = geometry.ray_sphere_intersect(
        origins,
        dirs,
        sphere_radii = origins.norm(dim=-1) * 2,
    )
    hits = far.detach()

    x = hits[..., 0]
    y = hits[..., 1]
    z = hits[..., 2]
    theta = (z / hits.norm(dim=-1)).acos()  # polar angle
    # atan2 resolves the quadrant exactly and is defined at x == 0. The old
    # atan(y/x) with a +/-3.14 correction left a seam and gave NaN at x == y == 0.
    phi = torch.atan2(y, x)

    w, h = self._sphere_map.shape[:2]

    return self._sphere_map[
        ((phi / (2*torch.pi) * w).int() % w).cpu(),
        ((theta / (1*torch.pi) * h).int() % h).cpu(),
    ]
|
|
|
|
def blit_sphere_map_mask(self, pixel_view, mask=None):
    """Write sphere-map colours into ``pixel_view`` where ``mask`` selects.

    Args:
        pixel_view: Frame-buffer array to draw into.
        mask: Index applied to the first two axes of ``pixel_view``;
            ``None`` selects every pixel.
    """
    dirs, origin = self.raydirs_and_cam
    if mask is None:
        mask = (slice(None), slice(None))
    colors = self.lookup_sphere_map_dirs(dirs, origin[None, None, :])
    # Index both sides with the mask so their shapes agree. Assigning the
    # full image to a masked selection only works when the mask selects
    # everything.
    pixel_view[mask] = colors[mask]
|
|
|
|
def mk_dump_fname(self, suffix: str, uid=None) -> Path:
    """Build a timestamped output path inside ``screenshot_dir``.

    The file name is ``pygame-viewer-<YYYYmmdd-HHMMSS>-<name>.<suffix>``.
    Viewer names longer than 160 characters are cut down to the part after
    their last ``-``. When ``uid`` is given, ``-<uid>`` is appended to the name.
    """
    label = self.name
    if len(label) > 160:
        label = label.split("-")[-1]
    if uid is not None:
        label = f"{label}-{uid}"
    stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    return self.screenshot_dir / f"pygame-viewer-{stamp}-{label}.{suffix}"
|