This commit is contained in:
2023-07-19 19:29:10 +02:00
parent b2a64395bd
commit 4f811cc4b0
60 changed files with 18209 additions and 1 deletions

6
.envrc Normal file

@ -0,0 +1,6 @@
#!/usr/bin/env bash
# This file is automatically loaded with `direnv` if allowed (`direnv allow`).
# It enters you into the venv.
# Use an explicit ./ prefix: a bare `source .localenv` makes bash search
# $PATH before the current directory, which could pick up the wrong file.
source ./.localenv

9
.gitignore vendored Normal file

@ -0,0 +1,9 @@
__pycache__
/data/models/
/data/archives/
/experiments/logdir/
/.env/
/.direnv/
*.zip
*.sh
default.yaml # pandoc preview enhanced

71
.localenv Normal file

@ -0,0 +1,71 @@
#!/usr/bin/env bash
# =======================
# bootstrap a venv
# =======================
# NOTE(review): this file is meant to be *sourced* (see .envrc / .remoteenv).
# The `exit` calls below terminate the sourcing shell on failure — fine for
# direnv/CI subshells, but beware when sourcing from an interactive shell.
LOCAL_ENV_NAME="py310-$(basename "$(pwd)")"  # env name derived from the repo dir
LOCAL_ENV_DIR="$(pwd)/.env/$LOCAL_ENV_NAME"
mkdir -p "$LOCAL_ENV_DIR"
# make configs and caches a part of venv
export POETRY_CACHE_DIR="$LOCAL_ENV_DIR/xdg/cache/poetry"
export PIP_CACHE_DIR="$LOCAL_ENV_DIR/xdg/cache/pip"
mkdir -p "$POETRY_CACHE_DIR" "$PIP_CACHE_DIR"
#export POETRY_VIRTUALENVS_IN_PROJECT=true # store venv in ./.venv/
#export POETRY_VIRTUALENVS_CREATE=false # install globally
export SETUPTOOLS_USE_DISTUTILS=stdlib # https://github.com/pre-commit/pre-commit/issues/2178#issuecomment-1002163763
export IFIELD_PRETTY_TRACEBACK=1
#export SHOW_LOCALS=1 # locals in tracebacks
export PYTHON_KEYRING_BACKEND="keyring.backends.null.Keyring"
# ensure we have the correct python and poetry. Bootstrap via conda if missing
# (fix: the interpreter installed below is `python3.10`, not `python310`;
#  the old check never succeeded, so the bootstrap branch re-ran every time)
if ! command -v python3.10 >/dev/null || ! command -v poetry >/dev/null; then
  source .localenv-bootstrap-conda
  if command -v mamba >/dev/null; then
    CONDA=mamba
  elif command -v conda >/dev/null; then
    CONDA=conda
  else
    # fix: at this point only conda/mamba are being resolved, not poetry
    >&2 echo "ERROR: neither 'conda' nor 'mamba' could be found!"
    exit 1
  fi
  # log each command (shell-quoted) before running it
  function verbose {
    echo +"$(printf " %q" "$@")"
    "$@"
  }
  # create the conda env only if it does not exist yet
  if ! ($CONDA env list | grep -q "^$LOCAL_ENV_NAME "); then
    verbose $CONDA create --yes --name "$LOCAL_ENV_NAME" -c conda-forge \
      python==3.10.8 poetry==1.3.1 #python-lsp-server
    true
  fi
  verbose conda activate "$LOCAL_ENV_NAME" || exit $?
  #verbose $CONDA activate "$LOCAL_ENV_NAME" || exit $?
  unset -f verbose
fi
# enter poetry venv
# source .envrc
poetry run true # ensure venv exists
#source "$(poetry env info -p)/bin/activate"
export VIRTUAL_ENV="$(poetry env info --path)"
export POETRY_ACTIVE=1
export PATH="$VIRTUAL_ENV/bin":"$PATH"
# NOTE: poetry currently reuses and populates the conda venv.
# See: https://github.com/python-poetry/poetry/issues/1724
# ensure output dirs exist
mkdir -p experiments/logdir
# first-time-setup poetry (fix-my-functions is a project console script;
# its absence means `poetry install` has not run yet)
if ! command -v fix-my-functions >/dev/null; then
  poetry install
fi

53
.localenv-bootstrap-conda Normal file

@ -0,0 +1,53 @@
#!/usr/bin/env bash
# =======================
# bootstrap a conda venv
# =======================
# NOTE(review): meant to be *sourced* by .localenv; LOCAL_ENV_DIR may be
# inherited from the sourcing script.
CONDA_ENV_DIR="${LOCAL_ENV_DIR:-$(pwd)/.conda310}"
mkdir -p "$CONDA_ENV_DIR"
#touch "$HOME/.Xauthority"
MINICONDA_PY310_URL="https://repo.anaconda.com/miniconda/Miniconda3-py310_22.11.1-1-Linux-x86_64.sh"
MINICONDA_PY310_HASH="00938c3534750a0e4069499baf8f4e6dc1c2e471c86a59caa0dd03f4a9269db6"
# Check if conda is available
if ! command -v conda >/dev/null; then
  export PATH="$CONDA_ENV_DIR/conda/bin:$PATH"
fi
# Check again if conda is available, install miniconda if not
if ! command -v conda >/dev/null; then
  (set -e #x
  # log each command (shell-quoted) before running it
  function verbose {
    echo +"$(printf " %q" "$@")"
    "$@"
  }
  if command -v curl >/dev/null; then
    verbose curl -sLo "$CONDA_ENV_DIR/miniconda_py310.sh" "$MINICONDA_PY310_URL"
  elif command -v wget >/dev/null; then
    verbose wget -O "$CONDA_ENV_DIR/miniconda_py310.sh" "$MINICONDA_PY310_URL"
  else
    # fix: diagnostics go to stderr (consistent with .localenv)
    >&2 echo "ERROR: unable to download miniconda!"
    exit 1
  fi
  # fix: `sha256sum` prints "<hash>  <filename>", so comparing its whole
  # output against the bare hash always failed; compare only the hash field
  verbose test "$(sha256sum "$CONDA_ENV_DIR/miniconda_py310.sh" | cut -d' ' -f1)" = "$MINICONDA_PY310_HASH"
  verbose chmod +x "$CONDA_ENV_DIR/miniconda_py310.sh"
  verbose "$CONDA_ENV_DIR/miniconda_py310.sh" -b -u -p "$CONDA_ENV_DIR/conda"
  verbose rm "$CONDA_ENV_DIR/miniconda_py310.sh"
  eval "$(conda shell.bash hook)" # basically `conda init`, without modifying .bashrc
  verbose conda install --yes --name base mamba -c conda-forge
  ) || exit $?
fi
unset CONDA_ENV_DIR
unset MINICONDA_PY310_URL
unset MINICONDA_PY310_HASH
# Enter conda environment
eval "$(conda shell.bash hook)" # basically `conda init`, without modifying .bashrc

29
.remoteenv Normal file

@ -0,0 +1,29 @@
#!/usr/bin/env bash
# this file is used by remote-cli
# Assumes repo is put in a "remotes/name-hash" folder,
# the default behaviour of remote-exec
REMOTES_DIR="$(dirname "$(pwd)")"
LOCAL_ENV_NAME="py310-$(basename "$(pwd)")"
# fix: was "$REMOTE_ENV_NAME" (undefined), which collapsed every project's
# env dir into the same "$REMOTES_DIR/envs/" directory
LOCAL_ENV_DIR="$REMOTES_DIR/envs/$LOCAL_ENV_NAME"
#export XDG_CACHE_HOME="$LOCAL_ENV_DIR/xdg/cache"
#export XDG_DATA_HOME="$LOCAL_ENV_DIR/xdg/share"
#export XDG_STATE_HOME="$LOCAL_ENV_DIR/xdg/state"
#mkdir -p "$XDG_CACHE_HOME" "$XDG_DATA_HOME" "$XDG_STATE_HOME"
export XDG_CONFIG_HOME="$LOCAL_ENV_DIR/xdg/config"
mkdir -p "$XDG_CONFIG_HOME"
export PYOPENGL_PLATFORM=egl # makes pyrender work headless
#export PYOPENGL_PLATFORM=osmesa # makes pyrender work headless
export SDL_VIDEODRIVER=dummy # pygame
source .localenv
# SLURM logs output dir
if command -v sbatch >/dev/null; then
  mkdir -p slurm_logs
  test -L experiments/logdir/slurm_logs ||
    ln -s ../../slurm_logs experiments/logdir/
fi

30
.remoteignore.toml Normal file

@ -0,0 +1,30 @@
# Sync filters for remote-cli / remote-exec (see .remoteenv).
# Patterns listed under `exclude` are skipped when syncing in that direction.

# push: local -> remote
[push]
# never upload caches, local envs, logs, or generated artifacts
exclude = [
"*.egg-info",
"*.pyc",
".ipynb_checkpoints",
".mypy_cache",
".remote.toml",
".remoteignore.toml",
".venv",
".wandb",
"__pycache__",
"data/models",
"docs",
"experiments/logdir",
"poetry.toml",
"slurm_logs",
"tmp",
]
include = []

# pull: remote -> local
[pull]
# "*" excludes everything: nothing is pulled back by default
exclude = [
"*",
]
include = []

# both: applied in both directions
[both]
exclude = [
]
include = []

128
README.md

@ -1 +1,127 @@
This is where the code for the paper _"MARF: The Medial Atom Ray Field Object Representation"_ will be published.
# MARF: The Medial Atom Ray Field Object Representation
<center>
![](figures/nn-architecture.svg)
[Publication](https://doi.org/10.1016/j.cag.2023.06.032) | [Arxiv](https://arxiv.org/abs/2307.00037) | [Training data](https://mega.nz/file/9tsz3SbA#V6SIXpCFC4hbqWaFFvKmmS8BKir7rltXuhsqpEpE9wo) | [Network weights](https://mega.nz/file/t01AyTLK#7ZNMNgbqT9x2mhq5dxLuKeKyP7G0slfQX1RaZxifayw)
</center>
**TL;DR:** We achieve _fast_ surface rendering by predicting _n_ maximally inscribed spherical intersection candidates for each camera ray.
---
## Entering the Virtual Environment
The environment is defined in `pyproject.toml` using [Poetry](https://github.com/python-poetry/poetry) and reproducibly locked in `poetry.lock`.
We propose three ways to enter the venv:
```shell
# Requires Python 3.10 and Poetry
poetry install
poetry shell
# Will bootstrap a Miniconda 3.10 environment into .env/ if needed, then run poetry
source .localenv
```
## Evaluation
### Pretrained models
You can download our pre-trained models from <https://mega.nz/file/t01AyTLK#7ZNMNgbqT9x2mhq5dxLuKeKyP7G0slfQX1RaZxifayw>.
It should be unpacked into the root directory, such that the `experiment` folder gets merged.
### The interactive renderer
We automatically create experiment names with a schema of `{{model}}-{{experiment-name}}-{{hparams-summary}}-{{date}}-{{random-uid}}`.
You can load experiment weights using either the full path, or just the `random-uid` bit.
From the `experiments` directory:
```shell
./marf.py model {{experiment}} viewer
```
If you have downloaded our pre-trained network weights, consider trying:
```shell
./marf.py model nqzh viewer # Stanford Bunny (single-shape)
./marf.py model wznx viewer # Stanford Buddha (single-shape)
./marf.py model mxwd viewer # Stanford Armadillo (single-shape)
./marf.py model camo viewer # Stanford Dragon (single-shape)
./marf.py model ksul viewer # Stanford Lucy (single-shape)
./marf.py model oxrf viewer # COSEG four-legged (multi-shape)
```
## Training and Evaluation Data
You can download a pre-computed archive from <https://mega.nz/file/9tsz3SbA#V6SIXpCFC4hbqWaFFvKmmS8BKir7rltXuhsqpEpE9wo>.
It should be extracted into the root directory such that a `data` directory is added to the root directory.
<details>
<summary>
Optionally, you may compute the data yourself.
</summary>
Single-shape training data:
```shell
# takes about 23 minutes, mainly due to lucy
download-stanford bunny happy_buddha dragon armadillo lucy
preprocess-stanford bunny happy_buddha dragon armadillo lucy \
--precompute-mesh-sv-scan-uv \
--compute-miss-distances \
--fill-missing-uv-points
```
Multi-shape training data:
```shell
# takes about 29 minutes
download-coseg four-legged --shapes
preprocess-coseg four-legged \
--precompute-mesh-sv-scan-uv \
--compute-miss-distances \
--fill-missing-uv-points
```
Evaluation data:
```shell
# takes about 2 hours 20 minutes, mainly due to lucy
preprocess-stanford bunny happy_buddha dragon armadillo lucy \
--precompute-mesh-sphere-scan \
--compute-miss-distances
```
```shell
# takes about 4 hours
preprocess-coseg four-legged \
--precompute-mesh-sphere-scan \
--compute-miss-distances
```
</details>
## Training
Our experiments are defined using YAML config files, optionally templated using Jinja2 as a preprocessor.
These templates accept additional input from the command line in the form of `-Okey=value` options.
Our whole experiment matrix is defined in `marf.yaml.j2`. We select between different experiment groups using `-Omode={single,ablation,multi}`, and which experiment using `-Oselect={{integer}}`.
From the `experiments` directory:
CPU mode:
```shell
./marf.py model marf.yaml.j2 -Oexperiment_name=cpu_test -Omode=single -Oselect=0 fit
```
GPU mode:
```shell
./marf.py model marf.yaml.j2 -Oexperiment_name=gpu_test -Omode=single -Oselect=0 fit --accelerator gpu --devices 1
```

139
ablation.md Normal file

@ -0,0 +1,139 @@
### MARF
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0010-nqzh`
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0312-wznx`
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-1944-mxwd`
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0529-camo`
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0743-ksul`
### LFN encoding
- `experiment-stanfordv12-dragon-plkr2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0539-xjte`
- `experiment-stanfordv12-lucy-plkr2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0753-ayvt`
- `experiment-stanfordv12-bunny-plkr2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0022-axft`
- `experiment-stanfordv12-happy_buddha-plkr2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0322-xfoc`
- `experiment-stanfordv12-armadillo-plkr2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2039-vbks`
### PRIF encoding
- `experiment-stanfordv12-armadillo-prpft2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2033-nkxm`
- `experiment-stanfordv12-happy_buddha-prpft2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0313-huci`
- `experiment-stanfordv12-dragon-prpft2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0537-dxsb`
- `experiment-stanfordv12-bunny-prpft2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0011-tzic`
- `experiment-stanfordv12-lucy-prpft2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0744-hzvw`
### No init scheme.
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-10dmiss-nogeom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0444-uohy`
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-10dmiss-nogeom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2307-wjcf`
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-10dmiss-nogeom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0707-eanc`
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-10dmiss-nogeom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0225-kcfw`
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-10dmiss-nogeom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0852-lkfh`
### 1 atom candidate
- `experiment-stanfordv12-lucy-both2marf-1atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0755-qzth`
- `experiment-stanfordv12-bunny-both2marf-1atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0027-ycnl`
- `experiment-stanfordv12-armadillo-both2marf-1atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2121-fwvo`
- `experiment-stanfordv12-dragon-both2marf-1atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0541-nvhs`
- `experiment-stanfordv12-happy_buddha-both2marf-1atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0324-cuyw`
### 4 atom candidates
- `experiment-stanfordv12-armadillo-both2marf-4atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2122-qiwg`
- `experiment-stanfordv12-dragon-both2marf-4atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0544-ihkx`
- `experiment-stanfordv12-lucy-both2marf-4atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0757-jwxm`
- `experiment-stanfordv12-happy_buddha-both2marf-4atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0328-chhs`
- `experiment-stanfordv12-bunny-both2marf-4atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0038-zymb`
### 8 atom candidates
- `experiment-stanfordv12-bunny-both2marf-8atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0055-ogpd`
- `experiment-stanfordv12-lucy-both2marf-8atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0757-frxb`
- `experiment-stanfordv12-happy_buddha-both2marf-8atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0337-twys`
- `experiment-stanfordv12-dragon-both2marf-8atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0551-bubw`
- `experiment-stanfordv12-armadillo-both2marf-8atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2137-nnlj`
### 32 atom candidates
- `experiment-stanfordv12-bunny-both2marf-32atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0056-ourc`
- `experiment-stanfordv12-armadillo-both2marf-32atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2141-byaj`
- `experiment-stanfordv12-dragon-both2marf-32atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0554-zobg`
- `experiment-stanfordv12-happy_buddha-both2marf-32atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0337-rmyq`
- `experiment-stanfordv12-lucy-both2marf-32atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0800-lqen`
### 64 atom candidates
- `experiment-stanfordv12-happy_buddha-both2marf-64atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0339-whcx`
- `experiment-stanfordv12-bunny-both2marf-64atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0058-seen`
- `experiment-stanfordv12-lucy-both2marf-64atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0806-ycxj`
- `experiment-stanfordv12-armadillo-both2marf-64atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2153-wnfq`
- `experiment-stanfordv12-dragon-both2marf-64atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0555-zgcb`
### No intersection loss
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-10dmiss-geom-0chit-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-1053-ydnh`
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-10dmiss-geom-0chit-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-1111-fawl`
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-10dmiss-geom-0chit-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-1045-umwl`
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-10dmiss-geom-0chit-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-1103-lwmb`
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-10dmiss-geom-0chit-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-1041-lhcc`
### No silhouette loss
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-0dmiss-geom-20chit-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-1042-fsuw`
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-0dmiss-geom-20chit-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-1046-nszw`
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-0dmiss-geom-20chit-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-1111-mlal`
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-0dmiss-geom-20chit-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-1055-cvkg`
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-0dmiss-geom-20chit-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-1114-pdyh`
### More silhouette loss
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-50dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0157-yekm`
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-50dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2243-nlrv`
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-50dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0639-yros`
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-50dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0842-xktg`
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-50dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0423-ibxs`
### No normal loss
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-10dmiss-geom-nocnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0614-ttta`
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-10dmiss-geom-nocnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0106-bnke`
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-10dmiss-geom-nocnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2154-bxwl`
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-10dmiss-geom-nocnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0811-qqgu`
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-10dmiss-geom-nocnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0357-gwca`
### No inscription loss
- `experiment-stanfordv12-bunny-both2marf-16atom-noxinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0227-xrqt`
- `experiment-stanfordv12-armadillo-both2marf-16atom-noxinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2312-cgzv`
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-noxinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0452-rerr`
- `experiment-stanfordv12-dragon-both2marf-16atom-noxinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0709-tfgg`
- `experiment-stanfordv12-lucy-both2marf-16atom-noxinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0856-ctvc`
### More inscription loss
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-250xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0459-kyyh`
- `experiment-stanfordv12-bunny-both2marf-16atom-250xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0243-qqqj`
- `experiment-stanfordv12-armadillo-both2marf-16atom-250xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2336-yclo`
- `experiment-stanfordv12-lucy-both2marf-16atom-250xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0913-mulv`
- `experiment-stanfordv12-dragon-both2marf-16atom-250xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0714-zugg`
### No maximality reg.
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-0sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0842-cvln`
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-0sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0425-vpen`
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-0sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0207-qpdb`
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-0sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2251-zqvi`
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-0sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0641-ucdo`
### More maximality reg.
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-5000sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0659-bqvf`
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-5000sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2256-escz`
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-5000sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0208-wmvs`
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-5000sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0442-gdah`
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-5000sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0845-halc`
### No specialization reg.
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-nominatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0913-odyn`
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-nominatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0251-xzig`
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-nominatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0722-gxps`
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-nominatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2342-zybo`
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-nominatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0507-tvlt`
### No multi-view loss
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-nogradreg-nocond-100cwu500clr70tvs-2023-05-31-0310-wbqj`
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-nogradreg-nocond-100cwu500clr70tvs-2023-05-30-2357-qnct`
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-nogradreg-nocond-100cwu500clr70tvs-2023-05-31-0527-psnk`
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-nogradreg-nocond-100cwu500clr70tvs-2023-05-31-0927-wxcq`
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-nogradreg-nocond-100cwu500clr70tvs-2023-05-31-0743-pdbc`
### More multi-view loss
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-20dmv-nocond-100cwu500clr70tvs-2023-05-31-0510-caah`
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-20dmv-nocond-100cwu500clr70tvs-2023-05-31-0726-zkyg`
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-20dmv-nocond-100cwu500clr70tvs-2023-05-31-0254-akbq`
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-20dmv-nocond-100cwu500clr70tvs-2023-05-31-0924-aahb`
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-20dmv-nocond-100cwu500clr70tvs-2023-05-30-2352-xlrn`

624
experiments/marf.py Executable file

@ -0,0 +1,624 @@
#!/usr/bin/env python3
from abc import ABC, abstractmethod
from argparse import Namespace
from collections import defaultdict
from datetime import datetime
from ifield import logging
from ifield.cli import CliInterface
from ifield.data.common.scan import SingleViewUVScan
from ifield.data.coseg import read as coseg_read
from ifield.data.stanford import read as stanford_read
from ifield.datasets import stanford, coseg, common
from ifield.models import intersection_fields
from ifield.utils.operators import diff
from ifield.viewer.ray_field import ModelViewer
from munch import Munch
from pathlib import Path
from pytorch3d.loss.chamfer import chamfer_distance
from pytorch_lightning.utilities import rank_zero_only
from torch.nn import functional as F
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
from trimesh import Trimesh
from typing import Union
import builtins
import itertools
import json
import numpy as np
import pytorch_lightning as pl
import rich
import rich.pretty
import statistics
import torch
# Global run configuration: fixed seed for reproducibility, relaxed matmul
# precision for speed on Ampere+ GPUs, and a short alias for the model class.
pl.seed_everything(31337)
torch.set_float32_matmul_precision('medium')
IField = intersection_fields.IntersectionFieldAutoDecoderModel # brevity
class RayFieldAdDataModuleBase(pl.LightningDataModule, ABC):
    """Shared LightningDataModule logic for autodecoder ray-field datasets.

    Subclasses provide the concrete dataset (`mk_ad_dataset`), the list of
    observation ids used to size the autodecoder embedding table, and ground
    truth lookups for evaluation. `setup` multiplies each UV scan into
    `step*step` strided sub-scans and splits train/val so that all sub-scans
    of one scan land in the same split.
    """

    @property
    @abstractmethod
    def observation_ids(self) -> list[str]:
        """Ids (one per shape observation) used as autodecoder latent keys."""
        ...

    @abstractmethod
    def mk_ad_dataset(self) -> common.AutodecoderDataset:
        """Construct the raw (uid, SingleViewUVScan) dataset."""
        ...

    @staticmethod
    @abstractmethod
    def get_trimesh_from_uid(uid) -> Trimesh:
        """Ground-truth mesh for a uid (used by the viewer)."""
        ...

    @staticmethod
    @abstractmethod
    def get_sphere_scan_from_uid(uid) -> SingleViewUVScan:
        """Ground-truth sphere scan for a uid (used by `compute-scores`)."""
        ...

    def setup(self, stage=None):
        assert stage in ["fit", None] # fit is for train/val, None is for all. "test" not supported ATM
        if not self.hparams.data_dir is None:
            # NOTE(review): this mutates the coseg config even for the Stanford
            # datamodule — confirm whether that is intentional.
            coseg.config.DATA_PATH = self.hparams.data_dir
        step = self.hparams.step # brevity
        dataset = self.mk_ad_dataset()
        n_items_pre_step_mapping = len(dataset)
        if step > 1:
            dataset = common.TransformExtendedDataset(dataset)
        # Register one unpacking transform per (sx, sy) stride offset; each
        # registration extends the dataset, multiplying its length by step*step.
        # NOTE(review): nesting reconstructed after indentation loss — with
        # step == 1 this still registers the identity map on the unwrapped
        # dataset; confirm `dataset.map` exists in that case.
        for sx in range(step):
            for sy in range(step):
                def make_slicer(sx, sy, step) -> callable: # the closure is required
                    if step > 1:
                        return lambda t: t[sx::step, sy::step]
                    else:
                        return lambda t: t
                @dataset.map(slicer=make_slicer(sx, sy, step))
                def unpack(sample: tuple[str, SingleViewUVScan], slicer: callable):
                    # Turn a (uid, scan) tuple into the dict of tensors the
                    # model consumes, sub-sampled by the strided slicer.
                    scan: SingleViewUVScan = sample[1]
                    assert not scan.hits.shape[0] % step, f"{scan.hits.shape[0]=} not divisible by {step=}"
                    assert not scan.hits.shape[1] % step, f"{scan.hits.shape[1]=} not divisible by {step=}"
                    return {
                        "z_uid"     : sample[0],
                        "origins"   : scan.cam_pos,
                        "dirs"      : slicer(scan.ray_dirs),
                        "points"    : slicer(scan.points),
                        "hits"      : slicer(scan.hits),
                        "miss"      : slicer(scan.miss),
                        "normals"   : slicer(scan.normals),
                        "distances" : slicer(scan.distances),
                    }
        # Split each object into train/val with SampleSplit
        n_items = len(dataset)
        n_val = int(n_items * self.hparams.val_fraction)
        n_train = n_items - n_val
        self.generator = torch.Generator().manual_seed(self.hparams.prng_seed)
        # split the dataset such that all steps are in same part
        assert n_items == n_items_pre_step_mapping * step * step, (n_items, n_items_pre_step_mapping, step)
        indices = [
            i*step*step + sx*step + sy
            for i in torch.randperm(n_items_pre_step_mapping, generator=self.generator).tolist()
            for sx in range(step)
            for sy in range(step)
        ]
        # sorted() with a random key = reshuffle each split independently
        self.dataset_train = Subset(dataset, sorted(indices[:n_train], key=lambda x: torch.rand(1, generator=self.generator).tolist()[0]))
        self.dataset_val = Subset(dataset, sorted(indices[n_train:n_train+n_val], key=lambda x: torch.rand(1, generator=self.generator).tolist()[0]))
        assert len(self.dataset_train) % self.hparams.batch_size == 0
        assert len(self.dataset_val) % self.hparams.batch_size == 0

    def train_dataloader(self):
        """Loader over the training split; shuffling controlled by hparams."""
        return DataLoader(self.dataset_train,
            batch_size         = self.hparams.batch_size,
            drop_last          = self.hparams.drop_last,
            num_workers        = self.hparams.num_workers,
            persistent_workers = self.hparams.persistent_workers,
            pin_memory         = self.hparams.pin_memory,
            prefetch_factor    = self.hparams.prefetch_factor,
            shuffle            = self.hparams.shuffle,
            generator          = self.generator,
        )

    def val_dataloader(self):
        """Loader over the validation split; never shuffled."""
        return DataLoader(self.dataset_val,
            batch_size         = self.hparams.batch_size,
            drop_last          = self.hparams.drop_last,
            num_workers        = self.hparams.num_workers,
            persistent_workers = self.hparams.persistent_workers,
            pin_memory         = self.hparams.pin_memory,
            prefetch_factor    = self.hparams.prefetch_factor,
            generator          = self.generator,
        )
class StanfordUVDataModule(RayFieldAdDataModuleBase):
    """Datamodule over single-view UV scans of the Stanford scan meshes."""
    # up-axis convention for this dataset, consumed by the viewer
    skyward = "+Z"

    def __init__(self,
            data_dir           : Union[str, Path, None] = None,
            obj_names          : list[str] = ["bunny"], # empty means all
            prng_seed          : int   = 1337,
            step               : int   = 2,
            batch_size         : int   = 5,
            drop_last          : bool  = False,
            num_workers        : int   = 8,
            persistent_workers : bool  = True,
            pin_memory         : bool  = True,
            prefetch_factor    : int   = 2,
            shuffle            : bool  = True,
            val_fraction       : float = 0.30,
            ):
        super().__init__()
        if not obj_names:
            obj_names = stanford_read.list_object_names()
        # save_hyperparameters() introspects this frame: every parameter above
        # (with the resolved obj_names) becomes self.hparams.*
        self.save_hyperparameters()

    @property
    def observation_ids(self) -> list[str]:
        return self.hparams.obj_names

    def mk_ad_dataset(self) -> common.AutodecoderDataset:
        return stanford.AutodecoderSingleViewUVScanDataset(
            obj_names = self.hparams.obj_names,
            data_path = self.hparams.data_dir,
        )

    @staticmethod
    def get_trimesh_from_uid(obj_name) -> Trimesh:
        # lazy import: mesh_to_sdf is only needed for ground-truth viewing
        import mesh_to_sdf
        mesh = stanford_read.read_mesh(obj_name)
        return mesh_to_sdf.scale_to_unit_sphere(mesh)

    @staticmethod
    def get_sphere_scan_from_uid(obj_name) -> SingleViewUVScan:
        return stanford_read.read_mesh_mesh_sphere_scan(obj_name)
class CosegUVDataModule(RayFieldAdDataModuleBase):
    """Datamodule over single-view UV scans of the COSEG shape sets."""
    # up-axis convention for this dataset, consumed by the viewer
    skyward = "+Y"

    def __init__(self,
            data_dir           : Union[str, Path, None] = None,
            object_sets        : tuple[str, ...] = ["tele-aliens"], # empty means all
            prng_seed          : int   = 1337,
            step               : int   = 2,
            batch_size         : int   = 5,
            drop_last          : bool  = False,
            num_workers        : int   = 8,
            persistent_workers : bool  = True,
            pin_memory         : bool  = True,
            prefetch_factor    : int   = 2,
            shuffle            : bool  = True,
            val_fraction       : float = 0.30,
            ):
        super().__init__()
        if not object_sets:
            object_sets = coseg_read.list_object_sets()
        # normalize to a hashable tuple before it is stored in hparams
        object_sets = tuple(object_sets)
        self.save_hyperparameters()

    @property
    def observation_ids(self) -> list[str]:
        return coseg_read.list_model_id_strings(self.hparams.object_sets)

    def mk_ad_dataset(self) -> common.AutodecoderDataset:
        return coseg.AutodecoderSingleViewUVScanDataset(
            object_sets = self.hparams.object_sets,
            data_path   = self.hparams.data_dir,
        )

    @staticmethod
    def get_trimesh_from_uid(string_uid):
        # no ground-truth mesh pipeline for COSEG yet
        raise NotImplementedError

    @staticmethod
    def get_sphere_scan_from_uid(string_uid) -> SingleViewUVScan:
        uid = coseg_read.model_id_string_to_uid(string_uid)
        return coseg_read.read_mesh_mesh_sphere_scan(*uid)
def mk_cli(args=None) -> CliInterface:
    """Build the command line interface for training/inspecting MARF models.

    Registers the pre-training callback plus the `viewer`,
    `render-video-interpolation`, `render-video-spin` and `compute-scores`
    actions on a CliInterface, and returns it; call `.run()` to execute.
    """
    cli = CliInterface(
        module_cls             = IField,
        datamodule_cls         = [StanfordUVDataModule, CosegUVDataModule],
        workdir                = Path(__file__).parent.resolve(),
        experiment_name_prefix = "ifield",
    )
    cli.trainer_defaults.update(dict(
        precision  = 16,
        min_epochs = 5,
    ))

    @cli.register_pre_training_callback
    def populate_autodecoder_z_uids(args: Namespace, config: Munch, module: IField, trainer: pl.Trainer, datamodule: RayFieldAdDataModuleBase, logger: logging.Logger):
        # Size the autodecoder latent table from the datamodule before training.
        module.set_observation_ids(datamodule.observation_ids)
        rank = getattr(rank_zero_only, "rank", 0)
        rich.print(f"[rank {rank}] {len(datamodule.observation_ids) = }")
        rich.print(f"[rank {rank}] {len(datamodule.observation_ids) > 1 = }")
        rich.print(f"[rank {rank}] {module.is_conditioned = }")

    @cli.register_action(help="Interactive window with direct renderings from the model", args=[
        ("--shading", dict(type=int, default=ModelViewer.vizmodes_shading .index("lambertian"), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_shading))}}}")),
        ("--centroid", dict(type=int, default=ModelViewer.vizmodes_centroids.index("best-centroids-colored"), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_centroids))}}}")),
        ("--spheres", dict(type=int, default=ModelViewer.vizmodes_spheres .index(None), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_spheres))}}}")),
        ("--analytical-normals", dict(action="store_true")),
        ("--ground-truth", dict(action="store_true")),
        ("--solo-atom",dict(type=int, default=None, help="Rendering mode")),
        ("--res", dict(type=int, nargs=2, default=(210, 160), help="Rendering resolution")),
        ("--bg", dict(choices=["map", "white", "black"], default="map")),
        ("--skyward", dict(type=str, default="+Z", help='one of: "+X", "-X", "+Y", "-Y", ["+Z"], "-Z"')),
        ("--scale", dict(type=int, default=3, help="Rendering scale")),
        ("--fps", dict(type=int, default=None, help="FPS upper limit")),
        ("--cam-state",dict(type=str, default=None, help="json cam state, expored with CTRL+H")),
        ("--write", dict(type=Path, default=None, help="Where to write a screenshot.")),
    ])
    @torch.no_grad()
    def viewer(args: Namespace, config: Munch, model: IField):
        """Open the interactive pygame viewer, or write a single screenshot."""
        datamodule_cls: RayFieldAdDataModuleBase = cli.get_datamodule_cls_from_config(args, config)
        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            model.to("cuda")
        viewer = ModelViewer(model, start_uid=next(iter(model.keys())),
            name           = config.experiment_name,
            screenshot_dir = Path(__file__).parent.parent / "images/pygame-viewer",
            res            = args.res,
            skyward        = args.skyward,
            scale          = args.scale,
            mesh_gt_getter = datamodule_cls.get_trimesh_from_uid,
        )
        # apply the CLI-selected render modes
        viewer.display_mode_shading = args.shading
        viewer.display_mode_centroid = args.centroid
        viewer.display_mode_spheres = args.spheres
        if args.ground_truth: viewer.display_mode_normals = viewer.vizmodes_normals.index("ground_truth")
        if args.analytical_normals: viewer.display_mode_normals = viewer.vizmodes_normals.index("analytical")
        viewer.atom_index_solo = args.solo_atom
        viewer.fps_cap = args.fps
        viewer.display_sphere_map_bg = { "map": True, "white": 255, "black": 0 }[args.bg]
        if args.cam_state is not None:
            viewer.cam_state = json.loads(args.cam_state)
        if args.write is None:
            viewer.run()
        else:
            # headless single-frame screenshot mode
            assert args.write.suffix == ".png", args.write.name
            viewer.render_headless(args.write,
                n_frames       = 1,
                fps            = 1,
                state_callback = None,
            )

    @cli.register_action(help="Prerender direct renderings from the model", args=[
        ("output_path",dict(type=Path, help="Where to store the output. We recommend a .mp4 suffix.")),
        ("uids", dict(type=str, nargs="*")),
        ("--frames", dict(type=int, default=60, help="Number of per interpolation. Default is 60")),
        ("--fps", dict(type=int, default=60, help="Default is 60")),
        ("--shading", dict(type=int, default=ModelViewer.vizmodes_shading .index("lambertian"), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_shading))}}}")),
        ("--centroid", dict(type=int, default=ModelViewer.vizmodes_centroids.index("best-centroids-colored"), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_centroids))}}}")),
        ("--spheres", dict(type=int, default=ModelViewer.vizmodes_spheres .index(None), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_spheres))}}}")),
        ("--analytical-normals", dict(action="store_true")),
        ("--solo-atom",dict(type=int, default=None, help="Rendering mode")),
        ("--res", dict(type=int, nargs=2, default=(240, 240), help="Rendering resolution. Default is 240 240")),
        ("--bg", dict(choices=["map", "white", "black"], default="map")),
        ("--skyward", dict(type=str, default="+Z", help='one of: "+X", "-X", "+Y", "-Y", ["+Z"], "-Z"')),
        ("--bitrate", dict(type=str, default="1500k", help="Encoding bitrate. Default is 1500k")),
        ("--cam-state",dict(type=str, default=None, help="json cam state, expored with CTRL+H")),
    ])
    @torch.no_grad()
    def render_video_interpolation(args: Namespace, config: Munch, model: IField, **kw):
        """Render a video interpolating the latent code through the given uids."""
        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            model.to("cuda")
        uids = args.uids or list(model.keys())
        # NOTE(review): with an explicit single uid this assert fires before the
        # loop-back append below can help — confirm single-uid use is unsupported.
        assert len(uids) > 1
        if not args.uids: uids.append(uids[0])
        viewer = ModelViewer(model, uids[0],
            name           = config.experiment_name,
            screenshot_dir = Path(__file__).parent.parent / "images/pygame-viewer",
            res            = args.res,
            skyward        = args.skyward,
        )
        if args.cam_state is not None:
            viewer.cam_state = json.loads(args.cam_state)
        viewer.display_mode_shading = args.shading
        viewer.display_mode_centroid = args.centroid
        viewer.display_mode_spheres = args.spheres
        if args.analytical_normals: viewer.display_mode_normals = viewer.vizmodes_normals.index("analytical")
        viewer.atom_index_solo = args.solo_atom
        viewer.display_sphere_map_bg = { "map": True, "white": 255, "black": 0 }[args.bg]

        def state_callback(self: ModelViewer, frame: int):
            # NOTE(review): nesting reconstructed after indentation loss — tint
            # in-between frames, and at each keyframe (frame % frames == 0)
            # flash white and advance to the next uid; confirm against original.
            if frame % args.frames:
                self.lambertian_color = (0.8, 0.8, 1.0)
            else:
                self.lambertian_color = (1.0, 1.0, 1.0)
                self.fps = args.frames
                idx = frame // args.frames + 1
                if idx != len(uids):
                    self.current_uid = uids[idx]

        print(f"Writing video to {str(args.output_path)!r}...")
        viewer.render_headless(args.output_path,
            n_frames       = args.frames * (len(uids)-1) + 1,
            fps            = args.fps,
            state_callback = state_callback,
            bitrate        = args.bitrate,
        )

    @cli.register_action(help="Prerender direct renderings from the model", args=[
        ("output_path",dict(type=Path, help="Where to store the output. We recommend a .mp4 suffix.")),
        ("--frames", dict(type=int, default=180, help="Number of frames. Default is 180")),
        ("--fps", dict(type=int, default=60, help="Default is 60")),
        ("--shading", dict(type=int, default=ModelViewer.vizmodes_shading .index("lambertian"), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_shading))}}}")),
        ("--centroid", dict(type=int, default=ModelViewer.vizmodes_centroids.index("best-centroids-colored"), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_centroids))}}}")),
        ("--spheres", dict(type=int, default=ModelViewer.vizmodes_spheres .index(None), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_spheres))}}}")),
        ("--analytical-normals", dict(action="store_true")),
        ("--solo-atom",dict(type=int, default=None, help="Rendering mode")),
        ("--res", dict(type=int, nargs=2, default=(320, 240), help="Rendering resolution. Default is 320 240")),
        ("--bg", dict(choices=["map", "white", "black"], default="map")),
        ("--skyward", dict(type=str, default="+Z", help='one of: "+X", "-X", "+Y", "-Y", ["+Z"], "-Z"')),
        ("--bitrate", dict(type=str, default="1500k", help="Encoding bitrate. Default is 1500k")),
        ("--cam-state",dict(type=str, default=None, help="json cam state, expored with CTRL+H")),
    ])
    @torch.no_grad()
    def render_video_spin(args: Namespace, config: Munch, model: IField, **kw):
        """Render a video of one full camera revolution around the shape."""
        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            model.to("cuda")
        viewer = ModelViewer(model, start_uid=next(iter(model.keys())),
            name           = config.experiment_name,
            screenshot_dir = Path(__file__).parent.parent / "images/pygame-viewer",
            res            = args.res,
            skyward        = args.skyward,
        )
        if args.cam_state is not None:
            viewer.cam_state = json.loads(args.cam_state)
        viewer.display_mode_shading = args.shading
        viewer.display_mode_centroid = args.centroid
        viewer.display_mode_spheres = args.spheres
        if args.analytical_normals: viewer.display_mode_normals = viewer.vizmodes_normals.index("analytical")
        viewer.atom_index_solo = args.solo_atom
        viewer.display_sphere_map_bg = { "map": True, "white": 255, "black": 0 }[args.bg]
        cam_rot_x_init = viewer.cam_rot_x

        def state_callback(self: ModelViewer, frame: int):
            # one full 2*pi turn over args.frames frames (3.14 approximates pi)
            self.cam_rot_x = cam_rot_x_init + 3.14 * (frame / args.frames) * 2

        print(f"Writing video to {str(args.output_path)!r}...")
        viewer.render_headless(args.output_path,
            n_frames       = args.frames,
            fps            = args.fps,
            state_callback = state_callback,
            bitrate        = args.bitrate,
        )

    # TODO(review): help="foo" looks like a placeholder description
    @cli.register_action(help="foo", args=[
        ("fname", dict(type=Path, help="where to write json")),
        ("-t", "--transpose", dict(action="store_true", help="transpose the output")),
        ("--single-shape", dict(action="store_true", help="break after first shape")),
        ("--batch-size", dict(type=int, default=40_000, help="tradeoff between vram usage and efficiency")),
        ("--n-cd", dict(type=int, default=30_000, help="Number of points to use when computing chamfer distance")),
        ("--filter-outliers", dict(action="store_true", help="like in PRIF")),
    ])
    @torch.enable_grad()
    def compute_scores(args: Namespace, config: Munch, model: IField, **kw):
        """Evaluate the model on ground-truth sphere scans and dump metrics.

        Accumulates per-shape ray classification counts (TP/FN/FP/TN for
        IoU/F-score), intersection/silhouette MSE, normal cosine similarity,
        and chamfer distances, prints a summary, and writes all metrics as
        json to `fname` ("-" for stdout). Gradients are enabled because
        analytical normals come from the intersection/origin jacobian.
        """
        datamodule_cls: RayFieldAdDataModuleBase = cli.get_datamodule_cls_from_config(args, config)
        model.eval()
        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            model.to("cuda")

        def T(array: np.ndarray, **kw) -> torch.Tensor:
            # Move numpy data onto the model's device.
            # NOTE(review): isinstance(array, np.floating) is False for float
            # *ndarrays* (it matches numpy scalars), so dtype likely falls back
            # to None here — confirm whether model.dtype was meant to apply.
            if isinstance(array, torch.Tensor): return array
            return torch.tensor(array, device=model.device, dtype=model.dtype if isinstance(array, np.floating) else None, **kw)

        MEDIAL = model.hparams.output_mode == "medial_sphere"
        if not MEDIAL: assert model.hparams.output_mode == "orthogonal_plane"

        uids = sorted(model.keys())
        if args.single_shape: uids = [uids[0]]
        rich.print(f"{datamodule_cls.__name__ = }")
        rich.print(f"{len(uids) = }")

        # accumulators for IoU and F-Score, CD and COS
        # sum reduction:
        n            = defaultdict(int)
        n_gt_hits    = defaultdict(int)
        n_gt_miss    = defaultdict(int)
        n_gt_missing = defaultdict(int)
        n_outliers   = defaultdict(int)
        p_mse        = defaultdict(int)
        s_mse        = defaultdict(int)
        cossim_med   = defaultdict(int) # medial normals
        cossim_jac   = defaultdict(int) # jacovian normals
        TP,FN,FP,TN  = [defaultdict(int) for _ in range(4)] # IoU and f-score
        # mean reduction:
        cd_dist    = {} # chamfer distance
        cd_cos_med = {} # chamfer medial normals
        cd_cos_jac = {} # chamfer jacovian normals
        all_metrics = dict(
            n=n, n_gt_hits=n_gt_hits, n_gt_miss=n_gt_miss, n_gt_missing=n_gt_missing, p_mse=p_mse,
            cossim_jac=cossim_jac,
            TP=TP, FN=FN, FP=FP, TN=TN, cd_dist=cd_dist,
            cd_cos_jac=cd_cos_jac,
        )
        if MEDIAL:
            all_metrics["s_mse"] = s_mse
            all_metrics["cossim_med"] = cossim_med
            all_metrics["cd_cos_med"] = cd_cos_med
        if args.filter_outliers:
            all_metrics["n_outliers"] = n_outliers

        t = datetime.now()
        for uid in tqdm(uids, desc="Dataset", position=0, leave=True, disable=len(uids)<=1):
            sphere_scan_gt = datamodule_cls.get_sphere_scan_from_uid(uid)
            z = model[uid].detach()
            all_intersections    = []
            all_medial_normals   = []
            all_jacobian_normals = []
            step = args.batch_size # rays per forward pass (shadows hparam step)
            for i in tqdm(range(0, sphere_scan_gt.hits.shape[0], step), desc=f"Item {uid!r}", position=1, leave=False):
                # prepare batch and gt
                origins      = T(sphere_scan_gt.cam_pos  [i:i+step, :], requires_grad = True)
                dirs         = T(sphere_scan_gt.ray_dirs [i:i+step, :])
                gt_hits      = T(sphere_scan_gt.hits     [i:i+step])
                gt_miss      = T(sphere_scan_gt.miss     [i:i+step])
                gt_missing   = T(sphere_scan_gt.missing  [i:i+step])
                gt_points    = T(sphere_scan_gt.points   [i:i+step, :])
                gt_normals   = T(sphere_scan_gt.normals  [i:i+step, :])
                gt_distances = T(sphere_scan_gt.distances[i:i+step])
                # forward
                if MEDIAL:
                    (
                        depths,
                        silhouettes,
                        intersections,
                        medial_normals,
                        is_intersecting,
                        sphere_centers,
                        sphere_radii,
                    ) = model({
                        "origins" : origins,
                        "dirs"    : dirs,
                    }, z, intersections_only=False, allow_nans=False)
                else:
                    silhouettes = medial_normals = None
                    intersections, is_intersecting = model({
                        "origins" : origins,
                        "dirs"    : dirs,
                    }, z, normalize_origins = True)
                # binarize the hit prediction
                # NOTE(review): placement reconstructed — applied to both output
                # modes here; harmless if the medial path already yields bools.
                is_intersecting = is_intersecting > 0.5
                jac = diff.jacobian(intersections, origins, detach=True)
                # outlier removal (PRIF)
                if args.filter_outliers:
                    outliers = jac.norm(dim=-2).norm(dim=-1) > 5
                    n_outliers[uid] += outliers[is_intersecting].sum().item()
                    # We count filtered points as misses
                    is_intersecting &= ~outliers
                model.zero_grad()
                jacobian_normals = model.compute_normals_from_intersection_origin_jacobian(jac, dirs)
                all_intersections   .append(intersections   .detach()[is_intersecting.detach(), :])
                all_medial_normals  .append(medial_normals  .detach()[is_intersecting.detach(), :]) if MEDIAL else None
                all_jacobian_normals.append(jacobian_normals.detach()[is_intersecting.detach(), :])
                # accumulate metrics
                with torch.no_grad():
                    n           [uid] += dirs.shape[0]
                    n_gt_hits   [uid] += gt_hits.sum().item()
                    n_gt_miss   [uid] += gt_miss.sum().item()
                    n_gt_missing[uid] += gt_missing.sum().item()
                    p_mse       [uid] += (gt_points [gt_hits, :] - intersections[gt_hits, :]).norm(2, dim=-1).pow(2).sum().item()
                    if MEDIAL: s_mse [uid] += (gt_distances[gt_miss] - silhouettes [gt_miss] ) .pow(2).sum().item()
                    if MEDIAL: cossim_med[uid] += (1-F.cosine_similarity(gt_normals[gt_hits, :], medial_normals [gt_hits, :], dim=-1).abs()).sum().item() # to match what pytorch3d does for CD
                    cossim_jac [uid] += (1-F.cosine_similarity(gt_normals[gt_hits, :], jacobian_normals[gt_hits, :], dim=-1).abs()).sum().item() # to match what pytorch3d does for CD
                    not_intersecting = ~is_intersecting
                    TP [uid] += ((gt_hits | gt_missing) & is_intersecting).sum().item() # True Positive
                    FN [uid] += ((gt_hits | gt_missing) & not_intersecting).sum().item() # False Negative
                    FP [uid] += (gt_miss & is_intersecting).sum().item() # False Positive
                    TN [uid] += (gt_miss & not_intersecting).sum().item() # True Negative
            # chamfer distance on a random subset of predicted vs gt surface points
            all_intersections    = torch.cat(all_intersections, dim=0)
            all_medial_normals   = torch.cat(all_medial_normals, dim=0) if MEDIAL else None
            all_jacobian_normals = torch.cat(all_jacobian_normals, dim=0)
            hits = sphere_scan_gt.hits # brevity
            print()
            assert all_intersections.shape[0] >= args.n_cd
            idx_cd_pred = torch.randperm(all_intersections.shape[0])[:args.n_cd]
            idx_cd_gt   = torch.randperm(hits.sum())               [:args.n_cd]
            print("cd... ", end="")
            tt = datetime.now()
            loss_cd, loss_cos_jac = chamfer_distance(
                x         = all_intersections   [None, :, :][:, idx_cd_pred, :].detach(),
                x_normals = all_jacobian_normals[None, :, :][:, idx_cd_pred, :].detach(),
                y         = T(sphere_scan_gt.points [None, hits, :][:, idx_cd_gt, :]),
                y_normals = T(sphere_scan_gt.normals[None, hits, :][:, idx_cd_gt, :]),
                batch_reduction = "sum", point_reduction = "sum",
            )
            if MEDIAL: _, loss_cos_med = chamfer_distance(
                x         = all_intersections   [None, :, :][:, idx_cd_pred, :].detach(),
                x_normals = all_medial_normals  [None, :, :][:, idx_cd_pred, :].detach(),
                y         = T(sphere_scan_gt.points [None, hits, :][:, idx_cd_gt, :]),
                y_normals = T(sphere_scan_gt.normals[None, hits, :][:, idx_cd_gt, :]),
                batch_reduction = "sum", point_reduction = "sum",
            )
            print(datetime.now() - tt)
            cd_dist   [uid] = loss_cd.item()
            cd_cos_med[uid] = loss_cos_med.item() if MEDIAL else None
            cd_cos_jac[uid] = loss_cos_jac.item()
            print()
            model.zero_grad(set_to_none=True)

        print("Total time:", datetime.now() - t)
        print("Time per item:", (datetime.now() - t) / len(uids)) if len(uids) > 1 else None

        # reductions over the per-object dicts (deliberately shadow builtins;
        # the real builtins are reached via the `builtins` module)
        sum   = lambda *xs: builtins  .sum  (itertools.chain(*(x.values() for x in xs)))
        mean  = lambda *xs: statistics.mean (itertools.chain(*(x.values() for x in xs)))
        stdev = lambda *xs: statistics.stdev(itertools.chain(*(x.values() for x in xs)))
        n_cd = args.n_cd
        P = sum(TP)/(sum(TP, FP))
        R = sum(TP)/(sum(TP, FN))
        print(f"{mean(n) = :11.1f} (rays per object)")
        print(f"{mean(n_gt_hits) = :11.1f} (gt rays hitting per object)")
        print(f"{mean(n_gt_miss) = :11.1f} (gt rays missing per object)")
        print(f"{mean(n_gt_missing) = :11.1f} (gt rays unknown per object)")
        print(f"{mean(n_outliers) = :11.1f} (gt rays unknown per object)") if args.filter_outliers else None
        print(f"{n_cd = :11.0f} (cd rays per object)")
        print(f"{mean(n_gt_hits) / mean(n) = :11.8f} (fraction rays hitting per object)")
        print(f"{mean(n_gt_miss) / mean(n) = :11.8f} (fraction rays missing per object)")
        print(f"{mean(n_gt_missing)/ mean(n) = :11.8f} (fraction rays unknown per object)")
        print(f"{mean(n_outliers) / mean(n) = :11.8f} (fraction rays unknown per object)") if args.filter_outliers else None
        print(f"{sum(TP)/sum(n) = :11.8f} (total ray TP)")
        print(f"{sum(TN)/sum(n) = :11.8f} (total ray TN)")
        print(f"{sum(FP)/sum(n) = :11.8f} (total ray FP)")
        print(f"{sum(FN)/sum(n) = :11.8f} (total ray FN)")
        print(f"{sum(TP, FN, FP)/sum(n) = :11.8f} (total ray union)")
        print(f"{sum(TP)/sum(TP, FN, FP) = :11.8f} (total ray IoU)")
        print(f"{sum(TP)/(sum(TP, FP)) = :11.8f} -> P (total ray precision)")
        print(f"{sum(TP)/(sum(TP, FN)) = :11.8f} -> R (total ray recall)")
        print(f"{2*(P*R)/(P+R) = :11.8f} (total ray F-score)")
        print(f"{sum(p_mse)/sum(n_gt_hits) = :11.8f} (mean ray intersection mean squared error)")
        print(f"{sum(s_mse)/sum(n_gt_miss) = :11.8f} (mean ray silhoutette mean squared error)")
        print(f"{sum(cossim_med)/sum(n_gt_hits) = :11.8f} (mean ray medial reduced cosine similarity)") if MEDIAL else None
        print(f"{sum(cossim_jac)/sum(n_gt_hits) = :11.8f} (mean ray analytical reduced cosine similarity)")
        print(f"{mean(cd_dist) /n_cd * 1e3 = :11.8f} (mean chamfer distance)")
        print(f"{mean(cd_cos_med)/n_cd = :11.8f} (mean chamfer reduced medial cossim distance)") if MEDIAL else None
        print(f"{mean(cd_cos_jac)/n_cd = :11.8f} (mean chamfer reduced analytical cossim distance)")
        print(f"{stdev(cd_dist) /n_cd * 1e3 = :11.8f} (stdev chamfer distance)") if len(cd_dist) > 1 else None
        print(f"{stdev(cd_cos_med)/n_cd = :11.8f} (stdev chamfer reduced medial cossim distance)") if len(cd_cos_med) > 1 and MEDIAL else None
        print(f"{stdev(cd_cos_jac)/n_cd = :11.8f} (stdev chamfer reduced analytical cossim distance)") if len(cd_cos_jac) > 1 else None

        if args.transpose:
            # metric -> uid -> value becomes uid -> metric -> value
            all_metrics, old_metrics = defaultdict(dict), all_metrics
            for m, table in old_metrics.items():
                for uid, vals in table.items():
                    all_metrics[uid][m] = vals
            all_metrics["_hparams"] = dict(n_cd=args.n_cd)
        else:
            all_metrics["n_cd"] = args.n_cd
        if str(args.fname) == "-":
            # stream json to stdout, one metric per line
            print("{", ',\n'.join(
                f" {json.dumps(k)}: {json.dumps(v)}"
                for k, v in all_metrics.items()
            ), "}", sep="\n")
        else:
            args.fname.parent.mkdir(parents=True, exist_ok=True)
            with args.fname.open("w") as f:
                json.dump(all_metrics, f, indent=2)

    return cli
def _main() -> None:
    """Entry point: build the CLI and dispatch the requested action."""
    mk_cli().run()


if __name__ == "__main__":
    _main()

263
experiments/marf.yaml.j2 Executable file

@ -0,0 +1,263 @@
#!/usr/bin/env -S python ./marf.py module
{% do require_defined("select", select, 0, "$SLURM_ARRAY_TASK_ID") %}{# requires jinja2.ext.do #}
{% do require_defined("mode", mode, "single", "ablation", "multi", strict=true, exchaustive=true) %}{# requires jinja2.ext.do #}
{% set counter = itertools.count(start=0, step=1) %}
{% set do_condition = mode == "multi" %}
{% set do_ablation = mode == "ablation" %}
{% set hp_matrix = namespace() %}{# hyper parameter matrix #}
{% set hp_matrix.input_mode = [
"both",
"perp_foot",
"plucker",
] if do_ablation else [ "both" ] %}
{% set hp_matrix.output_mode = ["medial_sphere", "orthogonal_plane"] %}{##}
{% set hp_matrix.output_mode = ["medial_sphere"] %}{##}
{% set hp_matrix.n_atoms = [16, 1, 4, 8, 32, 64] if do_ablation else [16] %}{##}
{% set hp_matrix.normal_coeff = [0.25, 0] if do_ablation else [0.25] %}{##}
{% set hp_matrix.dataset_item = [objname] if objname is defined else (["armadillo", "bunny", "happy_buddha", "dragon", "lucy"] if not do_condition else ["four-legged"]) %}{##}
{% set hp_matrix.test_val_split_frac = [0.7] %}{##}
{% set hp_matrix.lr_coeff = [5] %}{##}
{% set hp_matrix.warmup_epochs = [1] if not do_condition else [0.1] %}{##}
{% set hp_matrix.improve_miss_grads = [True] %}{##}
{% set hp_matrix.normalize_ray_dirs = [True] %}{##}
{% set hp_matrix.intersection_coeff = [2, 0] if do_ablation else [2] %}{##}
{% set hp_matrix.miss_distance_coeff = [1, 0, 5] if do_ablation else [1] %}{##}
{% set hp_matrix.relative_out = [False] %}{##}
{% set hp_matrix.hidden_features = [512] %}{# like deepsdf and prif #}
{% set hp_matrix.hidden_layers = [8] %}{# like deepsdf, nerf, prif #}
{% set hp_matrix.nonlinearity = ["leaky_relu"] %}{##}
{% set hp_matrix.omega = [30] %}{##}
{% set hp_matrix.normalization = ["layernorm"] %}{##}
{% set hp_matrix.dropout_percent = [1] %}{##}
{% set hp_matrix.sphere_grow_reg_coeff = [500, 0, 5000] if do_ablation else [500] %}{##}
{% set hp_matrix.geom_init = [True, False] if do_ablation else [True] %}{##}
{% set hp_matrix.loss_inscription = [50, 0, 250] if do_ablation else [50] %}{##}
{% set hp_matrix.atom_centroid_norm_std_reg_negexp = [0, None] if do_ablation else [0] %}{##}
{% set hp_matrix.curvature_reg_coeff = [0.2] %}{##}
{% set hp_matrix.multi_view_reg_coeff = [1, 2] if do_ablation else [1] %}{##}
{% set hp_matrix.grad_reg = [ "multi_view", "nogradreg" ] if do_ablation else [ "multi_view" ] %}
{#% for hp in cartesian_hparams(hp_matrix) %}{##}
{% for hp in ablation_hparams(hp_matrix, caartesian_keys=["output_mode", "dataset_item", "nonlinearity", "test_val_split_frac"]) %}
{% if hp.output_mode == "orthogonal_plane"%}
{% if hp.normal_coeff == 0 %}{% set hp.normal_coeff = 0.25 %}
{% elif hp.normal_coeff == 0.25 %}{% set hp.normal_coeff = 0 %}
{% endif %}
{% if hp.grad_reg == "multi_view" %}{% set hp.grad_reg = "nogradreg" %}
{% elif hp.grad_reg == "nogradreg" %}{% set hp.grad_reg = "multi_view" %}
{% endif %}
{% endif %}
{# filter bad/uninteresting hparam combos #}
{% if ( hp.nonlinearity != "sine" and hp.omega != 30 )
or ( hp.nonlinearity == "sine" and hp.normalization in ("layernorm", "layernorm_na") )
or ( hp.multi_view_reg_coeff != 1 and "multi_view" not in hp.grad_reg )
or ( "curvature" not in hp.grad_reg and hp.curvature_reg_coeff != 0.2 )
or ( hp.output_mode == "orthogonal_plane" and hp.input_mode != "both" )
or ( hp.output_mode == "orthogonal_plane" and hp.atom_centroid_norm_std_reg_negexp != 0 )
or ( hp.output_mode == "orthogonal_plane" and hp.n_atoms != 16 )
or ( hp.output_mode == "orthogonal_plane" and hp.sphere_grow_reg_coeff != 500 )
or ( hp.output_mode == "orthogonal_plane" and hp.loss_inscription != 50 )
or ( hp.output_mode == "orthogonal_plane" and hp.miss_distance_coeff != 1 )
or ( hp.output_mode == "orthogonal_plane" and hp.test_val_split_frac != 0.7 )
or ( hp.output_mode == "orthogonal_plane" and hp.lr_coeff != 5 )
or ( hp.output_mode == "orthogonal_plane" and not hp.geom_init )
or ( hp.output_mode == "orthogonal_plane" and not hp.intersection_coeff )
%}
{% continue %}{# requires jinja2.ext.loopcontrols #}
{% endif %}
{% set index = next(counter) %}
{% if select is not defined and index > 0 %}---{% endif %}
{% if select is not defined or int(select) == index %}
trainer:
gradient_clip_val : 1.0
max_epochs : 200
min_epochs : 200
log_every_n_steps : 20
{% if not do_condition %}
StanfordUVDataModule:
obj_names : ["{{ hp.dataset_item }}"]
step : 4
batch_size : 8
val_fraction : {{ 1-hp.test_val_split_frac }}
{% else %}{# if do_condition #}
CosegUVDataModule:
object_sets : ["{{ hp.dataset_item }}"]
step : 4
batch_size : 8
val_fraction : {{ 1-hp.test_val_split_frac }}
{% endif %}{# if do_condition #}
logging:
save_dir : logdir
type : tensorboard
project : ifield
{% autoescape false %}
{% do require_defined("experiment_name", experiment_name, "single-shape" if do_condition else "multi-shape", strict=true) %}
{% set input_mode_abbr = hp.input_mode
.replace("plucker", "plkr")
.replace("perp_foot", "prpft")
%}
{% set output_mode_abbr = hp.output_mode
.replace("medial_sphere", "marf")
.replace("orthogonal_plane", "prif")
%}
experiment_name: experiment-{{ "" if experiment_name is not defined else experiment_name }}
{#--#}-{{ hp.dataset_item }}
{#--#}-{{ input_mode_abbr }}2{{ output_mode_abbr }}
{#--#}
{%- if hp.output_mode == "medial_sphere" -%}
{#--#}-{{ hp.n_atoms }}atom
{#--# }-{{ "rel" if hp.relative_out else "norel" }}
{#--# }-{{ "e" if hp.improve_miss_grads else "0" }}sqrt
{#--#}-{{ int(hp.loss_inscription) if hp.loss_inscription else "no" }}xinscr
{#--#}-{{ int(hp.miss_distance_coeff * 10) }}dmiss
{#--#}-{{ "geom" if hp.geom_init else "nogeom" }}
{#--#}{% if "curvature" in hp.grad_reg %}
{#- -#}-{{ int(hp.curvature_reg_coeff*10) }}crv
{#--#}{%- endif -%}
{%- elif hp.output_mode == "orthogonal_plane" -%}
{#--#}
{%- endif -%}
{#--#}-{{ int(hp.intersection_coeff*10) }}chit
{#--#}-{{ int(hp.normal_coeff*100) or "no" }}cnrml
{#--# }-{{ "do" if hp.normalize_ray_dirs else "no" }}raynorm
{#--#}-{{ hp.hidden_layers }}x{{ hp.hidden_features }}fc
{#--#}-{{ hp.nonlinearity or "linear" }}
{#--#}
{%- if hp.nonlinearity == "sine" -%}
{#--#}-{{ hp.omega }}omega
{#--#}
{%- endif -%}
{%- if hp.output_mode == "medial_sphere" -%}
{#--#}-{{ str(hp.atom_centroid_norm_std_reg_negexp).replace(*"-n") if hp.atom_centroid_norm_std_reg_negexp is not none else 'no' }}minatomstdngxp
{#--#}-{{ hp.sphere_grow_reg_coeff }}sphgrow
{#--#}
{%- endif -%}
{#--#}-{{ int(hp.dropout_percent*10) }}mdrop
{#--#}-{{ hp.normalization or "nonorm" }}
{#--#}-{{ hp.grad_reg }}
{#--#}{% if "multi_view" in hp.grad_reg %}
{#- -#}-{{ int(hp.multi_view_reg_coeff*10) }}dmv
{#--#}{%- endif -%}
{#--#}-{{ "concat" if do_condition else "nocond" }}
{#--#}-{{ int(hp.warmup_epochs*100) }}cwu{{ int(hp.lr_coeff*100) }}clr{{ int(hp.test_val_split_frac*100) }}tvs
{#--#}-{{ gen_run_uid(4) }} # select with --Oselect={{ index }}
{#--#}
{##}
{% endautoescape %}
IntersectionFieldAutoDecoderModel:
_extra: # used for easier introspection with jq
dataset_item: {{ hp.dataset_item | to_json}}
dataset_test_val_frac: {{ hp.test_val_split_frac }}
select: {{ index }}
input_mode : {{ hp.input_mode }} # in {plucker, perp_foot, both}
output_mode : {{ hp.output_mode }} # in {medial_sphere, orthogonal_plane}
#latent_features : 256 # int
#latent_features : 128 # int
latent_features : 16 # int
hidden_features : {{ hp.hidden_features }} # int
hidden_layers : {{ hp.hidden_layers }} # int
improve_miss_grads : {{ bool(hp.improve_miss_grads) | to_json }}
normalize_ray_dirs : {{ bool(hp.normalize_ray_dirs) | to_json }}
loss_intersection : {{ hp.intersection_coeff }}
loss_intersection_l2 : 0
loss_intersection_proj : 0
loss_intersection_proj_l2 : 0
loss_normal_cossim : {{ hp.normal_coeff }} * EaseSin(85, 15)
loss_normal_euclid : 0
loss_normal_cossim_proj : 0
loss_normal_euclid_proj : 0
{% if "multi_view" in hp.grad_reg %}
loss_multi_view_reg : 0.1 * {{ hp.multi_view_reg_coeff }} * Linear(50)
{% else %}
loss_multi_view_reg : 0
{% endif %}
{% if hp.output_mode == "orthogonal_plane" %}
loss_hit_cross_entropy : 1
{% elif hp.output_mode == "medial_sphere" %}
loss_hit_nodistance_l1 : 0
loss_hit_nodistance_l2 : 100 * {{ hp.miss_distance_coeff }}
loss_miss_distance_l1 : 0
loss_miss_distance_l2 : 10 * {{ hp.miss_distance_coeff }}
loss_inscription_hits : {{ 0.4 * hp.loss_inscription }}
loss_inscription_miss : 0
loss_inscription_hits_l2 : 0
loss_inscription_miss_l2 : {{ 6 * hp.loss_inscription }}
loss_sphere_grow_reg : 1e-6 * {{ hp.sphere_grow_reg_coeff }} # constant
loss_atom_centroid_norm_std_reg: (0.09*(1-Linear(40)) + 0.01) * {{ 10**(-hp.atom_centroid_norm_std_reg_negexp) if hp.atom_centroid_norm_std_reg_negexp is not none else 0 }}
{% else %}{#endif hp.output_mode == "medial_sphere" #}
THIS IS INVALID YAML
{% endif %}
loss_embedding_norm : 0.01**2 * Linear(30, 0.1)
opt_learning_rate : {{ hp.lr_coeff }} * 10**(-4-0.5*EaseSin(170, 30)) # layernorm
opt_warmup : {{ hp.warmup_epochs }}
opt_weight_decay : 5e-6 # float
{% if hp.output_mode == "medial_sphere" %}
# MedialAtomNet:
n_atoms : {{ hp.n_atoms }} # int
{% if hp.geom_init %}
final_init_wrr: [0.05, 0.6, 0.1]
{% else %}
final_init_wrr: null
{% endif %}
{% endif %}
# FCBlock:
normalization : {{ hp.normalization or "null" }} # in {null, layernorm, layernorm_na, weightnorm}
nonlinearity : {{ hp.nonlinearity or "null" }} # in {null, relu, leaky_relu, silu, softplus, elu, selu, sine, sigmoid, tanh }
{% set middle = 1 + hp.hidden_layers // 2 + (hp.hidden_layers % 2) %}{##}
concat_skipped_layers : [{{ middle }}, -1]
{% if do_condition %}
concat_conditioned_layers : [0, {{ middle }}]
{% else %}
concat_conditioned_layers : []
{% endif %}
# FCLayer:
negative_slope : 0.01 # float
omega_0 : {{ hp.omega }} # float
residual_mode : null # in {null, identity}
{% endif %}{# -Oselect #}
{% endfor %}
{% set index = next(counter) %}
# number of possible -Oselect: {{ index }}, from 0 to {{ index-1 }}
# local: for select in {0..{{ index-1 }}}; do python ... -Omode={{ mode }} -Oselect=$select ... ; done
# local: for select in {0..{{ index-1 }}}; do python -O {{ argv[0] }} model marf.yaml.j2 -Omode={{ mode }} -Oselect=$select -Oexperiment_name='{{ experiment_name }}' fit --accelerator gpu ; done
# slurm: sbatch --array=0-{{ index-1 }} runcommand.slurm python ... -Omode={{ mode }} -Oselect=\$SLURM_ARRAY_TASK_ID ...
# slurm: sbatch --array=0-{{ index-1 }} runcommand.slurm python -O {{ argv[0] }} model marf.yaml.j2 -Omode={{ mode }} -Oselect=\$SLURM_ARRAY_TASK_ID -Oexperiment_name='{{ experiment_name }}' fit --accelerator gpu --devices -1 --strategy ddp

849
experiments/summary.py Executable file

@ -0,0 +1,849 @@
#!/usr/bin/env python
from concurrent.futures import ThreadPoolExecutor, Future, ProcessPoolExecutor
from functools import partial
from more_itertools import first, last, tail
from munch import Munch, DefaultMunch, munchify, unmunchify
from pathlib import Path
from statistics import mean, StatisticsError
from mpl_toolkits.axes_grid1 import make_axes_locatable
from typing import Iterable, Optional, Literal
from math import isnan
import json
import stat
import matplotlib
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import os, os.path
import re
import shlex
import time
import itertools
import shutil
import subprocess
import sys
import traceback
import typer
import warnings
import yaml
import tempfile
# Directory layout used by every subcommand below.
# BUG FIX: Path(__file__).resolve() is the *file* itself, so e.g.
# EXPERIMENTS / "logdir" would have been ".../summary.py/logdir".
# The script lives in the experiments/ directory, so take .parent.
EXPERIMENTS      = Path(__file__).resolve().parent
LOGDIR           = EXPERIMENTS / "logdir"
TENSORBOARD      = LOGDIR / "tensorboard"
SLURM_LOGS       = LOGDIR / "slurm_logs"
CACHED_SUMMARIES = LOGDIR / "cached_summaries"
COMPUTED_SCORES  = LOGDIR / "computed_scores"

# Unique sentinel used to mark "key absent from row" (distinct from None).
MISSING = object()
class SafeLoaderIgnoreUnknown(yaml.SafeLoader):
    """A yaml.SafeLoader that maps any unknown tag to None instead of raising.

    Used to read pytorch-lightning hparams.yaml files, which may contain
    python-object tags that SafeLoader cannot (and should not) construct.
    """
    def ignore_unknown(self, node):
        # Constructor for unrecognized tags: discard the value entirely.
        return None
# Register as the fallback constructor (tag=None matches any unknown tag).
SafeLoaderIgnoreUnknown.add_constructor(None, SafeLoaderIgnoreUnknown.ignore_unknown)
def camel_to_snake_case(text: str, sep: str = "_", join_abbreviations: bool = False) -> str:
    """Convert CamelCase `text` to snake_case, splitting before each uppercase letter.

    With join_abbreviations=True, adjacent single-letter fragments (e.g. the
    "A", "B" of "ABCat") are fused into one fragment ("ab_cat"). That merge
    is lossy: the original casing cannot be reconstructed afterwards.
    """
    chunks = [chunk.lower() for chunk in re.split(r'(?=[A-Z])', text) if chunk]
    if join_abbreviations:  # NOTE: this operation is not reversible
        neighbours = list(zip(chunks[:-1], chunks[1:]))
        # Walk right-to-left so popping does not disturb pending indices.
        for idx in reversed(range(len(neighbours))):
            left, right = neighbours[idx]
            if len(left) == 1 and len(right) == 1:
                chunks[idx] = chunks[idx] + chunks.pop(idx + 1)
    return sep.join(chunks)
def flatten_dict(data: dict, key_mapper: callable = lambda x: x) -> dict:
    """Flatten one level of nesting: {"a": 1, "b": {"c": 2}} -> {"a": 1, "b/c": 2}.

    Parent keys of nested dicts are passed through `key_mapper` before being
    joined with "/". If no value is a dict, `data` is returned unchanged.
    Only one level is flattened; deeper dicts are kept as-is.
    """
    if not any(isinstance(value, dict) for value in data.values()):
        return data
    flattened = {
        key: value
        for key, value in data.items()
        if not isinstance(value, dict)
    }
    for parent, nested in data.items():
        if isinstance(nested, dict):
            for key, value in nested.items():
                flattened[f"{key_mapper(parent)}/{key}"] = value
    return flattened
def parse_jsonl(data: str) -> Iterable[dict]:
    """Yield one parsed JSON object per non-blank line of JSON-Lines text."""
    for line in data.splitlines():
        if line.strip():
            yield json.loads(line)
def read_jsonl(path: Path) -> Iterable[dict]:
    """Read `path` and yield each JSON-Lines record (blank lines skipped)."""
    with path.open("r") as handle:
        contents = handle.read()
    for line in contents.splitlines():
        if line.strip():
            yield json.loads(line)
def get_experiment_paths(filter: str | None, assert_dumped = False) -> Iterable[Path]:
    """Yield experiment directories under TENSORBOARD, optionally regex-filtered.

    Entries missing hparams.yaml or a tfevents file are skipped with a warning.
    With assert_dumped=True (and python not run with -O), asserts that the
    scalar json dumps produced by `tf_dump` exist for each experiment.
    """
    for path in TENSORBOARD.iterdir():
        # `filter` is a regex matched anywhere in the directory name.
        if filter is not None and not re.search(filter, path.name): continue
        if not path.is_dir(): continue
        if not (path / "hparams.yaml").is_file():
            warnings.warn(f"Missing hparams: {path}")
            continue
        if not any(path.glob("events.out.tfevents.*")):
            warnings.warn(f"Missing tfevents: {path}")
            continue
        if __debug__ and assert_dumped:
            # these files are written by `tf_dump`; run it first if this trips
            assert (path / "scalars/epoch.json").is_file(), path
            assert (path / "scalars/IntersectionFieldAutoDecoderModel.validation_step/loss.json").is_file(), path
            assert (path / "scalars/IntersectionFieldAutoDecoderModel.training_step/loss.json").is_file(), path
        yield path
def dump_pl_tensorboard_hparams(experiment: Path):
    """Extract files embedded in a pytorch-lightning hparams.yaml.

    Writes three sibling files into the experiment directory:
      config.yaml  - the raw config used for the run (shebang + argv preserved)
      environ.yaml - the recorded host environment variables
      repo.patch   - the recorded VCS diff of the working tree
    """
    with (experiment / "hparams.yaml").open() as f:
        hparams = yaml.load(f, Loader=SafeLoaderIgnoreUnknown)
    shebang = None
    with (experiment / "config.yaml").open("w") as f:
        raw_yaml = hparams.get('_pickled_cli_args', {}).get('_raw_yaml', "").replace("\n\r", "\n")
        if raw_yaml.startswith("#!"): # preserve shebang
            shebang, _, raw_yaml = raw_yaml.partition("\n")
            f.write(f"{shebang}\n")
        # record the original command line as a comment for reproducibility
        f.write(f"# {' '.join(map(shlex.quote, hparams.get('_pickled_cli_args', {}).get('sys_argv', ['None'])))}\n\n")
        f.write(raw_yaml)
    if shebang is not None:
        # the config had a shebang: make the dumped copy executable too
        os.chmod(experiment / "config.yaml", (experiment / "config.yaml").stat().st_mode | stat.S_IXUSR)
    print(experiment / "config.yaml", "written!", file=sys.stderr)
    with (experiment / "environ.yaml").open("w") as f:
        yaml.safe_dump(hparams.get('_pickled_cli_args', {}).get('host', {}).get('environ'), f)
    print(experiment / "environ.yaml", "written!", file=sys.stderr)
    with (experiment / "repo.patch").open("w") as f:
        f.write(hparams.get('_pickled_cli_args', {}).get('host', {}).get('vcs', "None"))
    print(experiment / "repo.patch", "written!", file=sys.stderr)
def dump_simple_tf_events_to_jsonl(output_dir: Path, *tf_files: Path):
    """Stream all scalar ("simpleValue") summaries from tfevents files into
    one JSON-Lines file per tag under `output_dir` ("<tag>.json").

    Each record is {"step", "value", "wall_time"}. File handles are opened
    lazily per tag and closed in the `finally` block.
    """
    from google.protobuf.json_format import MessageToDict
    # NOTE(review): event_accumulator appears imported for its side effect of
    # making tensorboard...event_file_loader importable below — confirm.
    import tensorboard.backend.event_processing.event_accumulator
    s, l = {}, [] # reused sentinels
    #resource.setrlimit(resource.RLIMIT_NOFILE, (2**16,-1))
    file_handles = {}
    try:
        for tffile in tf_files:
            loader = tensorboard.backend.event_processing.event_file_loader.LegacyEventFileLoader(str(tffile))
            for event in loader.Load():
                for summary in MessageToDict(event).get("summary", s).get("value", l):
                    if "simpleValue" in summary:
                        tag = summary["tag"]
                        if tag not in file_handles:
                            fname = output_dir / f"{tag}.json"
                            print(f"Opening {str(fname)!r}...", file=sys.stderr)
                            fname.parent.mkdir(parents=True, exist_ok=True)
                            file_handles[tag] = fname.open("w") # ("a")
                        val = summary["simpleValue"]
                        data = json.dumps({
                            "step" : event.step,
                            "value" : float(val) if isinstance(val, str) else val,
                            "wall_time" : event.wall_time,
                        })
                        file_handles[tag].write(f"{data}\n")
    finally:
        if file_handles:
            print("Closing json files...", file=sys.stderr)
            for k, v in file_handles.items():
                v.close()
# Columns that filter_jsonl_columns always keeps, even when they would
# otherwise be dropped as constant/uninteresting.
NO_FILTER = {
    "__uid",
    "_minutes",
    "_epochs",
    "_hp_nonlinearity",
    "_val_uloss_intersection",
    "_val_uloss_normal_cossim",
    "_val_uloss_intersection",  # NOTE(review): duplicate of the entry two lines up — harmless in a set literal, but possibly a typo for another metric name; confirm intent
}
def filter_jsonl_columns(data: Iterable[dict | None], no_filter=NO_FILTER) -> list[dict]:
    """Clean up summary rows for tabular viewing.

    Pipeline: drop None rows -> flatten one level of nesting -> fold
    hp_omega_0 into the sine nonlinearity column -> drop columns that are
    constant across all rows or always falsy (except `no_filter` keys) ->
    coerce each column to the most permissive type seen in it.
    Returns a lazy iterator despite the list[dict] annotation.
    """
    def merge_siren_omega(data: dict) -> dict:
        # "sine" nonlinearity becomes e.g. "sine-30" using hp_omega_0,
        # and the hp_omega_0 column itself is removed.
        return {
            key: (
                f"{val}-{data.get('hp_omega_0', 'ERROR')}"
                if (key.removeprefix("_"), val) == ("hp_nonlinearity", "sine") else
                val
            )
            for key, val in data.items()
            if key != "hp_omega_0"
        }
    def remove_uninteresting_cols(rows: list[dict]) -> Iterable[dict]:
        unique_vals = {}
        def register_val(key, val):
            # side effect: record every distinct repr per column
            unique_vals.setdefault(key, set()).add(repr(val))
            return val
        # keep columns that have at least one truthy, non-"zero-ish" value
        whitelisted = {
            key
            for row in rows
            for key, val in row.items()
            if register_val(key, val) and val not in ("None", "0", "0.0")
        }
        # a column absent from some rows counts as having an extra value
        for key in unique_vals:
            for row in rows:
                if key not in row:
                    unique_vals[key].add(MISSING)
        # drop columns that are constant across all rows
        for key, vals in unique_vals.items():
            if key not in whitelisted: continue
            if len(vals) == 1:
                whitelisted.remove(key)
        whitelisted.update(no_filter)
        yield from (
            {
                key: val
                for key, val in row.items()
                if key in whitelisted
            }
            for row in rows
        )
    def pessemize_types(rows: list[dict]) -> Iterable[dict]:
        # For each column, find the "widest" type seen (order below, leftmost
        # wins) and cast every value in that column to it.
        types = {}
        order = (str, float, int, bool, tuple, type(None))
        for row in rows:
            for key, val in row.items():
                if isinstance(val, list): val = tuple(val)
                assert type(val) in order, (type(val), val)
                index = order.index(type(val))
                types[key] = min(types.get(key, 999), index)
        yield from (
            {
                key: order[types[key]](val) if val is not None else None
                for key, val in row.items()
            }
            for row in rows
        )
    data = (row for row in data if row is not None)
    data = map(partial(flatten_dict, key_mapper=camel_to_snake_case), data)
    data = map(merge_siren_omega, data)
    data = remove_uninteresting_cols(list(data))
    data = pessemize_types(list(data))
    return data
# The two supported plot styles.
PlotMode = Literal["stackplot", "lineplot"]

def plot_losses(experiments: list[Path], mode: PlotMode, write: bool = False, dump: bool = False, training: bool = False, unscaled: bool = False, force=True):
    """Plot per-loss curves for each experiment, either interactively or to PNG.

    With write=True every (scaled/unscaled x training/validation) combination
    is rendered to <experiment>/plots/ using a thread pool; otherwise only the
    requested combination is plotted and shown.
    NOTE(review): the `dump` parameter is accepted but never read here —
    dumping is handled by the callers; confirm before removing.
    """
    def get_losses(experiment: Path, training: bool = True, unscaled: bool = False) -> Iterable[Path]:
        # locate the per-loss jsonl dumps produced by tf_dump
        if not training and unscaled:
            return experiment.glob("scalars/*.validation_step/unscaled_loss_*.json")
        elif not training and not unscaled:
            return experiment.glob("scalars/*.validation_step/loss_*.json")
        elif training and unscaled:
            return experiment.glob("scalars/*.training_step/unscaled_loss_*.json")
        elif training and not unscaled:
            return experiment.glob("scalars/*.training_step/loss_*.json")
    print("Mapping colors...")
    configurations = [
        dict(unscaled=unscaled, training=training),
    ] if not write else [
        dict(unscaled=False, training=False),
        dict(unscaled=False, training=True),
        dict(unscaled=True, training=False),
        dict(unscaled=True, training=True),
    ]
    # stable "model.loss_name" -> color mapping shared by all figures
    legends = set(
        f"""{
            loss.parent.name.split(".", 1)[0]
        }.{
            loss.name.removesuffix(loss.suffix).removeprefix("unscaled_")
        }"""
        for experiment in experiments
        for kw in configurations
        for loss in get_losses(experiment, **kw)
    )
    colormap = dict(zip(
        sorted(legends),
        itertools.cycle(mcolors.TABLEAU_COLORS),
    ))
    def mkplot(experiment: Path, training: bool = True, unscaled: bool = False) -> tuple[bool, str]:
        # returns (skipped_or_failed, reason)
        label = f"{'unscaled' if unscaled else 'scaled'} {'training' if training else 'validation'}"
        if write:
            old_savefig_fname = experiment / f"{label.replace(' ', '-')}-{mode}.png"
            savefig_fname = experiment / "plots" / f"{label.replace(' ', '-')}-{mode}.png"
            savefig_fname.parent.mkdir(exist_ok=True, parents=True)
            if old_savefig_fname.is_file():
                # migrate plots saved by an older layout into plots/
                old_savefig_fname.rename(savefig_fname)
            if savefig_fname.is_file() and not force:
                return True, "savefig_fname already exists"
        # Get and sort data
        losses = {}
        for loss in get_losses(experiment, training=training, unscaled=unscaled):
            model = loss.parent.name.split(".", 1)[0]
            name = loss.name.removesuffix(loss.suffix).removeprefix("unscaled_")
            losses[f"{model}.{name}"] = (loss, list(read_jsonl(loss)))
        losses = dict(sorted(losses.items())) # sort keys
        if not losses:
            return True, "no losses"
        # unwrap
        steps = [i["step"] for i in first(losses.values())[1]]
        values = [
            [i["value"] if not isnan(i["value"]) else 0 for i in data]
            for name, (scalar, data) in losses.items()
        ]
        # normalize
        if mode == "stackplot":
            totals = list(map(sum, zip(*values)))
            values = [
                [i / t for i, t in zip(data, totals)]
                for data in values
            ]
        print(experiment.name, label)
        fig, ax = plt.subplots(figsize=(16, 12))
        if mode == "stackplot":
            ax.stackplot(steps, values,
                colors = list(map(colormap.__getitem__, losses.keys())),
                labels = list(
                    label.split(".", 1)[1].removeprefix("loss_")
                    for label in losses.keys()
                ),
            )
            ax.set_xlim(0, steps[-1])
            ax.set_ylim(0, 1)
            ax.invert_yaxis()
        elif mode == "lineplot":
            for data, color, label in zip(
                values,
                map(colormap.__getitem__, losses.keys()),
                list(losses.keys()),
            ):
                ax.plot(steps, data,
                    color = color,
                    label = label,
                )
            ax.set_xlim(0, steps[-1])
        else:
            raise ValueError(f"{mode=}")
        ax.legend()
        ax.set_title(f"{label} loss\n{experiment.name}")
        ax.set_xlabel("Step")
        ax.set_ylabel("loss%")
        if mode == "stackplot":
            # add a small absolute-total strip below the normalized stack
            ax2 = make_axes_locatable(ax).append_axes("bottom", 0.8, pad=0.05, sharex=ax)
            ax2.stackplot( steps, totals )
            for tl in ax.get_xticklabels(): tl.set_visible(False)
        fig.tight_layout()
        if write:
            fig.savefig(savefig_fname, dpi=300)
            print(savefig_fname)
            plt.close(fig)
        return False, None
    print("Plotting...")
    if write:
        matplotlib.use('agg') # fixes "WARNING: QApplication was not created in the main() thread."
    any_error = False
    if write:
        with ThreadPoolExecutor(max_workers=None) as pool:
            futures = [
                (experiment, pool.submit(mkplot, experiment, **kw))
                for experiment in experiments
                for kw in configurations
            ]
    else:
        # interactive mode: run synchronously but keep the same
        # (experiment, future) interface as the threaded path
        def mkfuture(item):
            f = Future()
            f.set_result(item)
            return f
        futures = [
            (experiment, mkfuture(mkplot(experiment, **kw)))
            for experiment in experiments
            for kw in configurations
        ]
    for experiment, future in futures:
        try:
            err, msg = future.result()
        except Exception:
            traceback.print_exc(file=sys.stderr)
            any_error = True
            continue
        if err:
            print(f"{msg}: {experiment.name}")
            any_error = True
            continue
    if not any_error and not write: # show in main thread
        plt.show()
    elif not write:
        print("There were errors, will not show figure...", file=sys.stderr)
# =========

# The typer CLI application; every @app.command below becomes a subcommand.
app = typer.Typer(no_args_is_help=True, add_completion=False)
@app.command(help="Dump simple tensorboard events to json and extract some pytorch lightning hparams")
def tf_dump(tfevent_files: list[Path], j: int = typer.Option(1, "-j"), force: bool = False):
    """Dump scalars + hparams for the given tfevents files / experiment dirs.

    Accepts tfevents files directly, experiment directories, or
    hparams.yaml/config.yaml paths (resolved to their sibling tfevents).
    Skips files already dumped unless --force.
    NOTE(review): the -j option is accepted but not read in this body —
    ProcessPoolExecutor uses its default worker count; confirm intent.
    """
    # expand to all tfevents files (there may be more than one)
    tfevent_files = sorted(set([
        tffile
        for tffile in tfevent_files
        if tffile.name.startswith("events.out.tfevents.")
    ] + [
        tffile
        for experiment_dir in tfevent_files
        if experiment_dir.is_dir()
        for tffile in experiment_dir.glob("events.out.tfevents.*")
    ] + [
        tffile
        for hparam_file in tfevent_files
        if hparam_file.name in ("hparams.yaml", "config.yaml")
        for tffile in hparam_file.parent.glob("events.out.tfevents.*")
    ]))
    # filter already dumped
    if not force:
        tfevent_files = [
            tffile
            for tffile in tfevent_files
            if not (
                (tffile.parent / "scalars/epoch.json").is_file()
                and
                tffile.stat().st_mtime < (tffile.parent / "scalars/epoch.json").stat().st_mtime
            )
        ]
    if not tfevent_files:
        raise typer.BadParameter("Nothing to be done, consider --force")
    # group tfevents files by their output directory
    jobs = {}
    for tffile in tfevent_files:
        if not tffile.is_file():
            print("ERROR: file not found:", tffile, file=sys.stderr)
            continue
        output_dir = tffile.parent / "scalars"
        jobs.setdefault(output_dir, []).append(tffile)
    with ProcessPoolExecutor() as p:
        for experiment in set(tffile.parent for tffile in tfevent_files):
            p.submit(dump_pl_tensorboard_hparams, experiment)
        for output_dir, tffiles in jobs.items():
            p.submit(dump_simple_tf_events_to_jsonl, output_dir, *tffiles)
@app.command(help="Propose experiment regexes")
def propose(cmd: str = typer.Argument("summary"), null: bool = False):
    """Print ready-to-run command lines, one per distinct name+date regex.

    Experiment directory names are assumed to follow the pattern
    "prefix-name-...-YYYY-MM-DD-hhmm-uid" (dash-separated) — derived from
    the split below; with --null the commands discard stdout.
    """
    def get():
        for i in TENSORBOARD.iterdir():
            if not i.is_dir(): continue
            if not (i / "hparams.yaml").is_file(): continue
            prefix, name, *hparams, year, month, day, hhmm, uid = i.name.split("-")
            yield f"{name}.*-{year}-{month}-{day}"
    # sort by the date part of the regex
    proposals = sorted(set(get()), key=lambda x: x.split(".*-", 1)[1])
    print("\n".join(
        f"{'>/dev/null ' if null else ''}{sys.argv[0]} {cmd or 'summary'} {shlex.quote(i)}"
        for i in proposals
    ))
@app.command("list", help="List used experiment regexes")
def list_cached_summaries(cmd: str = typer.Argument("summary")):
    """Print one `<argv0> <cmd> <regex>` line per cached summary file.

    The regexes are recovered from the *.jsonl filenames in CACHED_SUMMARIES
    and sorted by the numeric date fragments embedded in them.
    """
    if not CACHED_SUMMARIES.is_dir():
        cached = []
    else:
        cached = [
            i.name.removesuffix(".jsonl")
            for i in CACHED_SUMMARIES.iterdir()
            if i.suffix == ".jsonl"
            if i.is_file() and i.stat().st_size
        ]
    def order(key: str) -> list[str]:
        # sort key: the digit/dash tail of the regex, split on dashes
        return re.sub(r'[^0-9\-]', '', key.split(".*")[-1]).strip("-").split("-") + [key]
    print("\n".join(
        f"{sys.argv[0]} {cmd or 'summary'} {shlex.quote(i)}"
        for i in sorted(cached, key=order)
    ))
@app.command(help="Precompute the summary of a experiment regex")
def compute_summary(filter: str, force: bool = False, dump: bool = False, no_cache: bool = False):
    """Build one summary row per experiment matching `filter` and cache them.

    Each row mixes run metadata (uid, date, epochs, throughput, host) with
    tail-averaged validation/training metrics and flattened hyperparameters.
    Per-experiment rows are cached in <experiment>/train_summary.json; the
    joined JSON-Lines output is cached under CACHED_SUMMARIES unless
    --no-cache. Raises FileExistsError when the cache exists and --force
    is not given.
    """
    cache = CACHED_SUMMARIES / f"{filter}.jsonl"
    if cache.is_file() and cache.stat().st_size:
        if not force:
            raise FileExistsError(cache)
    def mk_summary(path: Path) -> dict | None:
        # per-experiment cache, invalidated when the scalar dump is newer
        cache = path / "train_summary.json"
        if cache.is_file() and cache.stat().st_size and cache.stat().st_mtime > (path/"scalars/epoch.json").stat().st_mtime:
            with cache.open() as f:
                return json.load(f)
        else:
            with (path / "hparams.yaml").open() as f:
                hparams = munchify(yaml.load(f, Loader=SafeLoaderIgnoreUnknown), factory=partial(DefaultMunch, None))
            config = hparams._pickled_cli_args._raw_yaml
            config = munchify(yaml.load(config, Loader=SafeLoaderIgnoreUnknown), factory=partial(DefaultMunch, None))
            try:
                train_loss = list(read_jsonl(path / "scalars/IntersectionFieldAutoDecoderModel.training_step/loss.json"))
                val_loss = list(read_jsonl(path / "scalars/IntersectionFieldAutoDecoderModel.validation_step/loss.json"))
            except Exception:  # FIX: was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit
                traceback.print_exc(file=sys.stderr)
                return None
            out = Munch()
            out.uid = path.name.rsplit("-", 1)[-1]
            out.name = path.name
            out.date = "-".join(path.name.split("-")[-5:-1])
            out.epochs = int(last(read_jsonl(path / "scalars/epoch.json"))["value"])
            out.steps = val_loss[-1]["step"]
            out.gpu = hparams._pickled_cli_args.host.gpus[1][1]
            if val_loss[-1]["wall_time"] - val_loss[0]["wall_time"] > 0:
                out.batches_per_second = val_loss[-1]["step"] / (val_loss[-1]["wall_time"] - val_loss[0]["wall_time"])
            else:
                out.batches_per_second = 0
            out.minutes = (val_loss[-1]["wall_time"] - train_loss[0]["wall_time"]) / 60
            if (path / "scalars/PsutilMonitor/gpu.00.memory.used.json").is_file():
                # FIX: this max() was computed but discarded; record it in the summary
                out.gpu_memory_max = max(i["value"] for i in read_jsonl(path / "scalars/PsutilMonitor/gpu.00.memory.used.json"))
            # validation metrics: mean of the last 5 logged values
            for metric_path in (path / "scalars/IntersectionFieldAutoDecoderModel.validation_step").glob("*.json"):
                if not metric_path.is_file() or not metric_path.stat().st_size: continue
                metric_name = metric_path.name.removesuffix(".json")
                metric_data = read_jsonl(metric_path)
                try:
                    out[f"val_{metric_name}"] = mean(i["value"] for i in tail(5, metric_data))
                except StatisticsError:
                    out[f"val_{metric_name}"] = float('nan')
            # a hand-picked subset of training metrics, same tail-average
            for metric_path in (path / "scalars/IntersectionFieldAutoDecoderModel.training_step").glob("*.json"):
                if not any(i in metric_path.name for i in ("miss_radius_grad", "sphere_center_grad", "loss_tangential_reg", "multi_view")): continue
                if not metric_path.is_file() or not metric_path.stat().st_size: continue
                metric_name = metric_path.name.removesuffix(".json")
                metric_data = read_jsonl(metric_path)
                try:
                    out[f"train_{metric_name}"] = mean(i["value"] for i in tail(5, metric_data))
                except StatisticsError:
                    out[f"train_{metric_name}"] = float('nan')
            out.hostname = hparams._pickled_cli_args.host.hostname
            # flatten model hyperparameters into hp_* columns
            for key, val in config.IntersectionFieldAutoDecoderModel.items():
                if isinstance(val, dict):
                    out.update({f"hp_{key}_{k}": v for k, v in val.items()})
                elif isinstance(val, float | int | str | bool | None):
                    out[f"hp_{key}"] = val
            with cache.open("w") as f:
                json.dump(unmunchify(out), f)
            return dict(out)
    experiments = list(get_experiment_paths(filter, assert_dumped=not dump))
    if not experiments:
        raise typer.BadParameter("No matching experiment")
    if dump:
        try:
            tf_dump(experiments) # force=force_dump)
        except typer.BadParameter:
            pass
    # does literally nothing, thanks GIL
    with ThreadPoolExecutor() as p:
        results = list(p.map(mk_summary, experiments))
    if any(result is None for result in results):
        if all(result is None for result in results):
            print("No summary succeeded", file=sys.stderr)
            raise typer.Exit(1)  # FIX: was typer.Exit(exit_code=1) — typer.Exit takes `code`, so that raised TypeError
        warnings.warn("Some summaries failed:\n" + "\n".join(
            str(experiment)
            for result, experiment in zip(results, experiments)
            if result is None
        ))
    summaries = "\n".join( map(json.dumps, results) )
    if not no_cache:
        cache.parent.mkdir(parents=True, exist_ok=True)
        with cache.open("w") as f:
            f.write(summaries)
    return summaries
@app.command(help="Show the summary of a experiment regex, precompute it if needed")
def summary(filter: Optional[str] = typer.Argument(None), force: bool = False, dump: bool = False, all: bool = False):
    """Show the cached summary table in visidata (or print raw jsonl).

    With no filter, lists the available regexes instead. Recomputes via
    compute_summary when the cache is missing/empty or --force is given.
    """
    if filter is None:
        return list_cached_summaries("summary")
    def key_mangler(key: str) -> str:
        # shorten loss column names for display: *_unscaled_loss_ -> *_uloss_ etc.
        for pattern, sub in (
            (r'^val_unscaled_loss_', r'val_uloss_'),
            (r'^train_unscaled_loss_', r'train_uloss_'),
            (r'^val_loss_', r'val_sloss_'),
            (r'^train_loss_', r'train_sloss_'),
        ):
            key = re.sub(pattern, sub, key)
        return key
    cache = CACHED_SUMMARIES / f"{filter}.jsonl"
    if force or not (cache.is_file() and cache.stat().st_size):
        compute_summary(filter, force=force, dump=dump)
    assert cache.is_file() and cache.stat().st_size, (cache, cache.stat())
    # only open the interactive viewer when attached to a terminal and vd exists
    if os.isatty(0) and os.isatty(1) and shutil.which("vd"):
        rows = read_jsonl(cache)
        rows = ({key_mangler(k): v for k, v in row.items()} if row is not None else None for row in rows)
        if not all:
            rows = filter_jsonl_columns(rows)
            # scaled losses are hidden by default; --all keeps them
            rows = ({k: v for k, v in row.items() if not k.startswith(("val_sloss_", "train_sloss_"))} for row in rows)
        data = "\n".join(map(json.dumps, rows))
        subprocess.run(["vd",
            #"--play", EXPERIMENTS / "set-key-columns.vd",
            "-f", "jsonl"
        ], input=data, text=True, check=True)
    else:
        with cache.open() as f:
            print(f.read())
@app.command(help="Filter uninteresting keys from jsonl stdin")
def filter_cols():
    """Read JSON-Lines rows from stdin, drop uninteresting columns, print the rest."""
    records = [json.loads(line) for line in sys.stdin.readlines() if line.strip()]
    filtered = filter_jsonl_columns(records)
    print("\n".join(map(json.dumps, filtered)))
@app.command(help="Run a command for each experiment matched by experiment regex")
def exec(filter: str, cmd: list[str], j: int = typer.Option(1, "-j"), dumped: bool = False, undumped: bool = False):
    """Run `cmd` once per matching experiment, fd/GNU-parallel style.

    In `cmd`, "{}" expands to the experiment's hparams.yaml and "{//}" to
    the experiment directory; when neither appears, the hparams.yaml path
    is appended. --dumped/--undumped restrict to experiments with/without
    dumped scalars. Exits 1 if any command returns non-zero.
    """
    def populate_cmd(experiment: Path, cmd: Iterable[str]) -> Iterable[str]:
        substituted = False  # FIX: renamed from `any`, which shadowed the builtin
        for arg in cmd:
            if arg == "{}":
                substituted = True
                yield str(experiment / "hparams.yaml")
            elif arg == "{//}":
                substituted = True
                yield str(experiment)
            else:
                yield arg
        if not substituted:
            yield str(experiment / "hparams.yaml")
    with ThreadPoolExecutor(max_workers=j or None) as p:
        results = p.map(subprocess.run, (
            list(populate_cmd(experiment, cmd))
            for experiment in get_experiment_paths(filter)
            if not dumped or (experiment / "scalars/epoch.json").is_file()
            if not undumped or not (experiment / "scalars/epoch.json").is_file()
        ))
    if any(i.returncode for i in results):
        # FIX: was `return typer.Exit(1)` — click/typer ignore return values,
        # so the non-zero exit status was silently dropped; it must be raised.
        raise typer.Exit(1)
@app.command(help="Show stackplot of experiment loss")
def stackplot(filter: str, write: bool = False, dump: bool = False, training: bool = False, unscaled: bool = False, force: bool = False):
    """Render normalized stacked loss plots for every experiment matching `filter`."""
    matching = list(get_experiment_paths(filter, assert_dumped=not dump))
    if not matching:
        raise typer.BadParameter("No match")
    if dump:
        try:
            tf_dump(matching)
        except typer.BadParameter:
            pass  # nothing needed dumping
    plot_losses(
        matching,
        mode     = "stackplot",
        write    = write,
        dump     = dump,
        training = training,
        unscaled = unscaled,
        force    = force,
    )
# FIX: help text said "stackplot" — copy-paste from the command above.
@app.command(help="Show lineplot of experiment loss")
def lineplot(filter: str, write: bool = False, dump: bool = False, training: bool = False, unscaled: bool = False, force: bool = False):
    """Render per-loss line plots for every experiment matching `filter`."""
    experiments = list(get_experiment_paths(filter, assert_dumped=not dump))
    if not experiments:
        raise typer.BadParameter("No match")
    if dump:
        try:
            tf_dump(experiments)
        except typer.BadParameter:
            pass
    plot_losses(experiments,
        mode     = "lineplot",
        write    = write,
        dump     = dump,
        training = training,
        unscaled = unscaled,
        force    = force,
    )
@app.command(help="Open tensorboard for the experiments matching the regex")
def tensorboard(filter: Optional[str] = typer.Argument(None), watch: bool = False):
    """Launch tensorboard on a temporary symlink-farm of matching experiments.

    With --watch, polls every 10s for newly created experiments and links
    them into the farm while tensorboard is still running.
    """
    if filter is None:
        return list_cached_summaries("tensorboard")
    experiments = list(get_experiment_paths(filter, assert_dumped=False))
    if not experiments:
        raise typer.BadParameter("No match")
    with tempfile.TemporaryDirectory(suffix=f"ifield-{filter}") as d:
        treefarm = Path(d)
        with ThreadPoolExecutor(max_workers=2) as p:
            for experiment in experiments:
                (treefarm / experiment.name).symlink_to(experiment)
            cmd = ["tensorboard", "--logdir", d]
            print("+", *map(shlex.quote, cmd), file=sys.stderr)
            # run tensorboard in the pool so the watch loop below can poll it
            tensorboard = p.submit(subprocess.run, cmd, check=True)
            if not watch:
                tensorboard.result()
            else:
                all_experiments = set(get_experiment_paths(None, assert_dumped=False))
                while not tensorboard.done():
                    time.sleep(10)
                    new_experiments = set(get_experiment_paths(None, assert_dumped=False)) - all_experiments
                    if new_experiments:
                        for experiment in new_experiments:
                            print(f"Adding {experiment.name!r}...", file=sys.stderr)
                            (treefarm / experiment.name).symlink_to(experiment)
                        all_experiments.update(new_experiments)
@app.command(help="Compute evaluation metrics")
def metrics(filter: Optional[str] = typer.Argument(None), dump: bool = False, dry: bool = False, prefix: Optional[str] = typer.Option(None), derive: bool = False, each: bool = False, no_total: bool = False):
    """Compute (via ./marf.py subprocesses) and display evaluation metrics.

    Missing compute-scores/*.json files are produced first; --prefix prepends
    e.g. a job-scheduler command, --dry only prints the commands. With
    --derive, raw TP/FP/... counts are turned into iou/precision/recall/
    f-score/chamfer/EMD; --each keeps per-object rows instead of only the
    aggregate. --no-total skips generating the base metrics files.
    """
    if filter is None:
        return list_cached_summaries("metrics --derive")
    experiments = list(get_experiment_paths(filter, assert_dumped=False))
    if not experiments:
        raise typer.BadParameter("No match")
    if dump:
        try:
            tf_dump(experiments)
        except typer.BadParameter:
            pass
    def run(*cmd):
        # run (or just print, with --dry) a single scoring command
        if prefix is not None:
            cmd = [*shlex.split(prefix), *cmd]
        if dry:
            print(*map(shlex.quote, map(str, cmd)))
        else:
            print("+", *map(shlex.quote, map(str, cmd)))
            subprocess.run(cmd)
    for experiment in experiments:
        if no_total: continue
        if not (experiment / "compute-scores/metrics.json").is_file():
            run(
                "python", "./marf.py", "module", "--best", experiment / "hparams.yaml",
                "compute-scores", experiment / "compute-scores/metrics.json",
                "--transpose",
            )
        if not (experiment / "compute-scores/metrics-last.json").is_file():
            run(
                "python", "./marf.py", "module", "--last", experiment / "hparams.yaml",
                "compute-scores", experiment / "compute-scores/metrics-last.json",
                "--transpose",
            )
        # outlier-filtered variants only exist for PRIF-style experiments
        if "2prif-" not in experiment.name: continue
        if not (experiment / "compute-scores/metrics-sans_outliers.json").is_file():
            run(
                "python", "./marf.py", "module", "--best", experiment / "hparams.yaml",
                "compute-scores", experiment / "compute-scores/metrics-sans_outliers.json",
                "--transpose", "--filter-outliers"
            )
        if not (experiment / "compute-scores/metrics-last-sans_outliers.json").is_file():
            run(
                "python", "./marf.py", "module", "--last", experiment / "hparams.yaml",
                "compute-scores", experiment / "compute-scores/metrics-last-sans_outliers.json",
                "--transpose", "--filter-outliers"
            )
    if dry: return
    if prefix is not None:
        print("prefix was used, assuming a job scheduler was used, will not print scores.", file=sys.stderr)
        return
    # NOTE(review): this local shadows the `metrics` function name (harmless here)
    metrics = [
        *(experiment / "compute-scores/metrics.json" for experiment in experiments),
        *(experiment / "compute-scores/metrics-last.json" for experiment in experiments),
        *(experiment / "compute-scores/metrics-sans_outliers.json" for experiment in experiments if "2prif-" in experiment.name),
        *(experiment / "compute-scores/metrics-last-sans_outliers.json" for experiment in experiments if "2prif-" in experiment.name),
    ]
    if not no_total:
        assert all(metric.exists() for metric in metrics)
    else:
        metrics = (metric for metric in metrics if metric.exists())
    out = []
    for metric in metrics:
        experiment = metric.parent.parent.name
        is_last = metric.name in ("metrics-last.json", "metrics-last-sans_outliers.json")
        with metric.open() as f:
            data = json.load(f)
        if derive:
            derived = {}
            objs = [i for i in data.keys() if i != "_hparams"]
            # per-object rows first (if --each), then the "_all_" aggregate
            for obj in (objs if each else []) + [None]:
                if obj is None:
                    # aggregate: sum the raw counts over all objects
                    d = DefaultMunch(0)
                    for obj in objs:
                        for k, v in data[obj].items():
                            d[k] += v
                    obj = "_all_"
                    n_cd = data["_hparams"]["n_cd"] * len(objs)
                    n_emd = data["_hparams"]["n_emd"] * len(objs)
                else:
                    d = munchify(data[obj])
                    n_cd = data["_hparams"]["n_cd"]
                    n_emd = data["_hparams"]["n_emd"]
                precision = d.TP / (d.TP + d.FP)
                recall = d.TP / (d.TP + d.FN)
                derived[obj] = dict(
                    filtered = d.n_outliers / d.n if "n_outliers" in d else None,
                    iou = d.TP / (d.TP + d.FN + d.FP),
                    precision = precision,
                    recall = recall,
                    f_score = 2 * (precision * recall) / (precision + recall),
                    cd = d.cd_dist / n_cd,
                    emd = d.emd / n_emd,
                    cos_med = 1 - (d.cd_cos_med / n_cd) if "cd_cos_med" in d else None,
                    cos_jac = 1 - (d.cd_cos_jac / n_cd),
                )
            data = derived if each else derived["_all_"]
        data["uid"] = experiment.rsplit("-", 1)[-1]
        data["experiment_name"] = experiment
        data["is_last"] = is_last
        out.append(json.dumps(data))
    if derive and not each and os.isatty(0) and os.isatty(1) and shutil.which("vd"):
        subprocess.run(["vd", "-f", "jsonl"], input="\n".join(out), text=True, check=True)
    else:
        print("\n".join(out))
if __name__ == "__main__":
    # CLI entry point: dispatch to the typer application defined above.
    app()

822
figures/nn-architecture.svg Normal file

File diff suppressed because one or more lines are too long

After

(image error) Size: 198 KiB

57
ifield/__init__.py Normal file

@ -0,0 +1,57 @@
def setup_print_hooks():
    """Install rich tracebacks and a compact warning format, opt-in via the
    IFIELD_PRETTY_TRACEBACK environment variable (no-op when unset).

    Runs once at package import time and is then deleted from the namespace.
    """
    import os
    if not os.environ.get("IFIELD_PRETTY_TRACEBACK", None):
        return
    # FIX: these three imports were duplicated further down in the original;
    # the second copies were redundant and have been removed.
    from rich.traceback import install
    from rich.console import Console
    import warnings, sys
    if not os.isatty(2):
        # stderr is not a terminal: pin a width so rich renders sanely
        # https://github.com/Textualize/rich/issues/1809
        os.environ.setdefault("COLUMNS", "120")
    install(
        show_locals = bool(os.environ.get("SHOW_LOCALS", "")),
        width       = None,
    )
    # custom warnings
    # https://github.com/Textualize/rich/issues/433
    def showwarning(message, category, filename, lineno, file=None, line=None):
        # render warnings through rich when the target is a terminal
        msg = warnings.WarningMessage(message, category, filename, lineno, file, line)
        if file is None:
            file = sys.stderr
            if file is None:
                # sys.stderr is None when run with pythonw.exe:
                # warnings get lost
                return
        text = warnings._formatwarnmsg(msg)
        if file.isatty():
            Console(file=file, stderr=True).print(text)
        else:
            try:
                file.write(text)
            except OSError:
                # the file (probably stderr) is invalid - this warning gets lost.
                pass
    warnings.showwarning = showwarning
    def warning_no_src_line(message, category, filename, lineno, file=None, line=None):
        # rich markup for interactive stderr, plain text otherwise
        stream = file or sys.stderr
        if stream is not None and stream.isatty():
            if file is None or file is sys.stderr:
                return f"[yellow]{category.__name__}[/yellow]: {message}\n ((unknown):{lineno})"
        return f"{category.__name__}: {message} ((unknown):{lineno})\n"
    warnings.formatwarning = warning_no_src_line

setup_print_hooks()
del setup_print_hooks

1006
ifield/cli.py Normal file

File diff suppressed because it is too large Load Diff

174
ifield/cli_utils.py Normal file

@ -0,0 +1,174 @@
#!/usr/bin/env python3
from .data.common.scan import SingleViewScan, SingleViewUVScan
from datetime import datetime
import re
import click
import gzip
import h5py as h5
import matplotlib.pyplot as plt
import numpy as np
import pyrender
import trimesh
import trimesh.transformations as T
__doc__ = """
Here are a bunch of helper scripts exposed as cli command by poetry
"""
# these entrypoints are exposed by poetry as shell commands
@click.command()
@click.argument("h5file")
@click.argument("key", default="")
def show_h5_items(h5file: str, key: str):
    """Show contents of HDF5 dataset.

    With no KEY: list every top-level dataset with its dtype, shape and mean.
    With a KEY: print that dataset's contents.
    """
    # use a context manager so the file handle is closed (previously leaked)
    with h5.File(h5file, "r") as f:
        if not key:
            mlen = max(map(len, f.keys()))
            for i in sorted(f.keys()):
                print(i.ljust(mlen), ":",
                    str (f[i].dtype).ljust(10),
                    repr(f[i].shape).ljust(16),
                    "mean:", f[i][:].mean()
                )
        elif not f[key].shape:
            # scalar dataset: Dataset.value was removed in h5py 3.0, read with [()]
            print(f[key][()])
        else:
            print(f[key][:])
@click.command()
@click.argument("h5file")
@click.argument("key", default="")
def show_h5_img(h5file: str, key: str):
    """Show a 2D HDF5 dataset as an image; with no KEY, list the datasets."""
    f = h5.File(h5file, "r")
    if key:
        plt.imshow(f[key])
        plt.show()
    else:
        width = max(len(name) for name in f.keys())
        for name in sorted(f.keys()):
            print(name.ljust(width), ":", str(f[name].dtype).ljust(10), f[name].shape)
@click.command()
@click.argument("h5file")
@click.option("--force-distances", is_flag=True, help="Always show miss distances.")
@click.option("--uv", is_flag=True, help="Load as UV scan cloud and convert it.")
@click.option("--show-unit-sphere", is_flag=True, help="Show the unit sphere.")
@click.option("--missing", is_flag=True, help="Show miss points that are not hits nor misses as purple.")
def show_h5_scan_cloud(
        h5file           : str,
        force_distances  : bool = False,
        uv               : bool = False,
        missing          : bool = False,
        show_unit_sphere = False,
        ):
    "Show a SingleViewScan HDF5 dataset"
    print("Reading data...")
    t = datetime.now()
    if uv:
        scan = SingleViewUVScan.from_h5_file(h5file)
        # optionally collect points that are neither hits nor misses
        if missing and scan.any_missing:
            if not scan.has_missing:
                scan.fill_missing_points()
            points_missing = scan.points[scan.missing]
        else:
            # no missing points to show: disable the purple overlay below
            missing = False
        if not scan.is_single_view:
            scan.cam_pos = None
        scan = scan.to_scan()
    else:
        scan = SingleViewScan.from_h5_file(h5file)
        if missing:
            # the missing-point distinction only exists in the UV representation
            uvscan = scan.to_uv_scan()
            if scan.any_missing:
                uvscan.fill_missing_points()
                points_missing = uvscan.points[uvscan.missing]
            else:
                missing = False
    print("loadtime: ", datetime.now() - t)
    if force_distances and not scan.has_miss_distances:
        print("Computing miss distances...")
        scan.compute_miss_distances()
    use_miss_distances = force_distances
    print("Constructing scene...")
    if not scan.has_colors:
        # no colors stored: synthesize a palette (blue hits, orange misses) and
        # force the distance-based miss coloring below when available
        scan.colors_hit  = np.zeros_like(scan.points_hit)
        scan.colors_miss = np.zeros_like(scan.points_miss)
        scan.colors_hit [:] = ( 31/255, 119/255, 180/255)
        scan.colors_miss[:] = (243/255, 156/255,  18/255)
        use_miss_distances = True
    if scan.has_miss_distances and use_miss_distances:
        # color misses by normalized miss distance: red = far, green = near
        sdm = scan.distances_miss / scan.distances_miss.max()
        sdm = sdm[..., None]
        scan.colors_miss \
            = np.array([0.8, 0, 0])[None, :] * sdm \
            + np.array([0, 1, 0.2])[None, :] * (1-sdm)
    scene = pyrender.Scene()
    scene.add(pyrender.Mesh.from_points(scan.points_hit,  colors=scan.colors_hit, normals=scan.normals_hit))
    scene.add(pyrender.Mesh.from_points(scan.points_miss, colors=scan.colors_miss))
    if missing:
        # missing points rendered in purple (0xff00ff)
        scene.add(pyrender.Mesh.from_points(points_missing, colors=(np.array((0xff, 0x00, 0xff))/255)[None, :].repeat(points_missing.shape[0], axis=0)))
    # camera:
    if not scan.points_cam is None:
        # green sphere per camera position, sized relative to the hit spread
        camera_mesh = trimesh.creation.uv_sphere(radius=scan.points_hit_std.max()*0.2)
        camera_mesh.visual.vertex_colors = [0.0, 0.8, 0.0]
        tfs = np.tile(np.eye(4), (len(scan.points_cam), 1, 1))
        tfs[:,:3,3] = scan.points_cam
        scene.add(pyrender.Mesh.from_trimesh(camera_mesh, poses=tfs))
    # UV sphere:
    if show_unit_sphere:
        # inverted so the sphere is visible from the inside
        unit_sphere_mesh = trimesh.creation.uv_sphere(radius=1)
        unit_sphere_mesh.invert()
        unit_sphere_mesh.visual.vertex_colors = [0.8, 0.8, 0.0]
        scene.add(pyrender.Mesh.from_trimesh(unit_sphere_mesh, poses=np.eye(4)[None, ...]))
    print("Launch!")
    viewer = pyrender.Viewer(scene, use_raymond_lighting=True, point_size=2)
@click.command()
@click.argument("meshfile")
@click.option('--aabb', is_flag=True)
@click.option('--z-skyward', is_flag=True)
def show_model(
        meshfile  : str,
        aabb      : bool,
        z_skyward : bool,
        ):
    "Show a 3D model with pyrender, supports .gz suffix"
    if meshfile.endswith(".gz"):
        import os.path
        # derive trimesh's file_type from the basename only: splitting the whole
        # path on "." broke for paths with dots in directory components
        # (e.g. "./models/foo.obj.gz" yielded "/models/foo.obj")
        file_type = os.path.basename(meshfile).split(".", 1)[1].removesuffix(".gz")
        with gzip.open(meshfile, "r") as f:
            mesh = trimesh.load(f, file_type=file_type)
    else:
        mesh = trimesh.load(meshfile)
    if isinstance(mesh, trimesh.Scene):
        # flatten a multi-geometry scene into a single mesh
        mesh = mesh.dump(concatenate=True)
    if aabb:
        from .data.common.mesh import rotate_to_closest_axis_aligned_bounds
        mesh.apply_transform(rotate_to_closest_axis_aligned_bounds(mesh, fail_ok=True))
    if z_skyward:
        # rotate 90 degrees about +x so the z axis points up
        mesh.apply_transform(T.rotation_matrix(np.pi/2, (1, 0, 0)))
    # print the viewer's single-key bindings, scraped from its docstring
    print(
        *(i.strip() for i in pyrender.Viewer.__doc__.splitlines() if re.search(r"- ``[a-z0-9]``: ", i)),
        sep="\n"
    )
    scene = pyrender.Scene()
    scene.add(pyrender.Mesh.from_trimesh(mesh))
    pyrender.Viewer(scene, use_raymond_lighting=True)

3
ifield/data/__init__.py Normal file

@ -0,0 +1,3 @@
__doc__ = """
Submodules to read and process datasets
"""

@ -0,0 +1,90 @@
from ...utils.helpers import make_relative
from pathlib import Path
from tqdm import tqdm
from typing import Union, Optional
import io
import os
import json
import requests
PathLike = Union[os.PathLike, str]
__doc__ = """
Here are some helper functions for processing data.
"""
def check_url(url):
    """Return True when an HTTP HEAD request for `url` succeeds (status < 400)."""
    response = requests.head(url)
    return response.ok
def download_stream(
        url          : str,
        file_object,
        block_size   : int  = 1024,
        silent       : bool = False,
        label        : Optional[str] = None,
        ):
    """Stream `url` into `file_object` in chunks of `block_size` bytes.

    Shows a tqdm progress bar unless `silent`; `label` is the bar description.
    A size mismatch against the Content-Length header is reported to stderr.
    """
    resp = requests.get(url, stream=True)
    total_size = int(resp.headers.get("content-length", 0))
    n_written = 0  # count independently of the progress bar
    if not silent:
        progress_bar = tqdm(total=total_size , unit="iB", unit_scale=True, desc=label)
    for chunk in resp.iter_content(block_size):
        if not silent:
            progress_bar.update(len(chunk))
        n_written += len(chunk)
        file_object.write(chunk)
    if not silent:
        progress_bar.close()
    # verify completeness even in silent mode (previously only checked when the
    # progress bar was shown), and report to stderr rather than stdout
    if total_size != 0 and n_written != total_size:
        import sys
        print("ERROR, something went wrong", file=sys.stderr)
def download_data(
        url        : str,
        block_size : int  = 1024,
        silent     : bool = False,
        label      : Optional[str] = None,
        ) -> bytearray:
    """Download `url` into memory and return the payload as a bytearray."""
    buffer = io.BytesIO()
    download_stream(url, buffer, block_size=block_size, silent=silent, label=label)
    return bytearray(buffer.getvalue())
def download_file(
        url        : str,
        fname      : Union[Path, str],
        block_size : int = 1024,
        silent     = False,
        ):
    """Download `url` to the path `fname`, labeling the progress bar with its name."""
    path = fname if isinstance(fname, Path) else Path(fname)
    with path.open("wb") as file_object:
        download_stream(url, file_object, block_size=block_size, silent=silent,
            label=make_relative(path, Path.cwd()).name)
def is_downloaded(
        target_dir : PathLike,
        url        : str,
        *,
        add        : bool = False,
        dbfiles    : Union[list[PathLike], PathLike],
        ):
    """Check (and optionally record) whether `url` has been downloaded.

    Each entry of `dbfiles` is a path relative to `target_dir` naming a JSON
    file shaped {"downloaded": [url, ...]}. Returns True when `url` occurs in
    any of them. With add=True, an unseen `url` is recorded in the first dbfile.

    Raises ValueError if `dbfiles` is empty.
    """
    if not isinstance(target_dir, os.PathLike):
        target_dir = Path(target_dir)
    if not isinstance(dbfiles, list):
        dbfiles = [dbfiles]
    if not dbfiles:
        raise ValueError("'dbfiles' empty")
    downloaded = set()
    for dbfile_fname in dbfiles:
        dbfile_fname = target_dir / dbfile_fname
        if dbfile_fname.is_file():
            with open(dbfile_fname, "r") as f:
                downloaded.update(json.load(f)["downloaded"])
    if add and url not in downloaded:
        downloaded.add(url)
        # write the db under target_dir, matching where it is read from
        # (previously wrote to the bare dbfiles[0] path, i.e. the CWD)
        with open(target_dir / dbfiles[0], "w") as f:
            data = {"downloaded": sorted(downloaded)}
            json.dump(data, f, indent=2, sort_keys=True)
        return True
    return url in downloaded

@ -0,0 +1,370 @@
#!/usr/bin/env python3
from abc import abstractmethod, ABCMeta
from collections import namedtuple
from pathlib import Path
import copy
import dataclasses
import functools
import h5py as h5
import hdf5plugin
import numpy as np
import operator
import os
import sys
import typing
__all__ = [
"DataclassMeta",
"Dataclass",
"H5Dataclass",
"H5Array",
"H5ArrayNoSlice",
]
T = typing.TypeVar("T")
NoneType = type(None)
PathLike = typing.Union[os.PathLike, str]
H5Array = typing._alias(np.ndarray, 0, inst=False, name="H5Array")
H5ArrayNoSlice = typing._alias(np.ndarray, 0, inst=False, name="H5ArrayNoSlice")
# Normalized description of a single dataclass field, as consumed by the
# H5Dataclass (de)serialization machinery below.
DataclassField = namedtuple("DataclassField", [
    "name",          # field name on the dataclass
    "type",          # field type with any Optional[...] wrapper stripped
    "is_optional",   # annotated Optional / Union with None, or default None
    "is_array",      # holds an array (H5Array, numpy.ndarray, torch.Tensor)
    "is_sliceable",  # array that may be read with a slice (not H5ArrayNoSlice)
    "is_prefix",     # dict[str, T] field mapped to multiple prefixed HDF5 keys
])
def strip_optional(val: type) -> type:
    """Unwrap typing.Optional[X] to X; non-Optional Unions raise TypeError."""
    if typing.get_origin(val) is not typing.Union:
        return val
    members = set(typing.get_args(val)) - {type(None)}
    if len(members) != 1:
        raise TypeError(f"Non-'typing.Optional' 'typing.Union' is not supported: {typing._type_repr(val)!r}")
    member, = members
    return member
def is_array(val, *, _inner=False):
    """
    Hacky way to check if a value or type is an array.
    The hack omits having to depend on large frameworks such as pytorch or pandas
    """
    val = strip_optional(val)
    if val is H5Array or val is H5ArrayNoSlice:
        return True
    known_array_reprs = (
        "numpy.ndarray",
        "torch.Tensor",
    )
    if typing._type_repr(val) in known_array_reprs:
        return True
    # a value (not a type) may have been passed: retry once with its type
    return is_array(type(val), _inner=True) if not _inner else False
def prod(numbers: typing.Iterable[T], initial: typing.Optional[T] = None) -> T:
    """Multiply `numbers` together, optionally seeded with `initial`."""
    if initial is None:
        return functools.reduce(operator.mul, numbers)
    return functools.reduce(operator.mul, numbers, initial)
class DataclassMeta(type):
    """Metaclass that turns every class created with it into a dataclass."""
    def __new__(
            mcls,
            name  : str,
            bases : tuple[type, ...],
            attrs : dict[str, typing.Any],
            **kwargs,
            ):
        cls = super().__new__(mcls, name, bases, attrs, **kwargs)
        # prefer slotted dataclasses where supported (Python 3.10+) and the
        # class does not already declare __slots__; dataclass(slots=True)
        # returns a re-created class
        if sys.version_info[:2] >= (3, 10) and not hasattr(cls, "__slots__"):
            cls = dataclasses.dataclass(slots=True)(cls)
        else:
            cls = dataclasses.dataclass(cls)
        return cls
class DataclassABCMeta(DataclassMeta, ABCMeta):
    # combined metaclass so a dataclass (via DataclassMeta) can also declare
    # abstract methods (via ABCMeta) without a metaclass conflict
    pass
class Dataclass(metaclass=DataclassMeta):
    """Base class adding a dict-like interface on top of dataclass fields."""

    def __getitem__(self, key: str) -> typing.Any:
        if key not in self.keys():
            raise KeyError(key)
        return getattr(self, key)

    def __setitem__(self, key: str, value: typing.Any):
        if key not in self.keys():
            raise KeyError(key)
        return setattr(self, key, value)

    def keys(self) -> typing.KeysView:
        return self.as_dict().keys()

    def values(self) -> typing.ValuesView:
        return self.as_dict().values()

    def items(self) -> typing.ItemsView:
        return self.as_dict().items()

    def as_dict(self, properties_to_include: set[str] = None, **kw) -> dict[str, typing.Any]:
        """The dataclass fields as a dict, plus any named properties."""
        out = dataclasses.asdict(self, **kw)
        for name in (properties_to_include or []):
            out[name] = getattr(self, name)
        return out

    def as_tuple(self, properties_to_include: list[str]) -> tuple:
        """The dataclass fields as a tuple, plus any named properties."""
        fields = dataclasses.astuple(self)
        if not properties_to_include:
            return fields
        extras = tuple(getattr(self, name) for name in properties_to_include)
        return fields + extras

    def copy(self: T, *, deep=True) -> T:
        """Return a copy of this instance (deep by default)."""
        return copy.deepcopy(self) if deep else copy.copy(self)
class H5Dataclass(Dataclass):
    """A Dataclass that can be read from / written to a HDF5 file.

    Subclass keyword parameters (class Sub(H5Dataclass, prefix="x_")):
      prefix      -- string prepended to every HDF5 dataset name (default "")
      n_pages     -- default page count for paginated reads (default 10)
      require_all -- from_h5_file() raises when datasets remain unconsumed

    Field annotations drive the mapping:
      H5Array / H5ArrayNoSlice        -- array datasets (latter is never sliced)
      dict[str, T]                    -- all datasets named "<prefix><field>_<key>"
      Optional[...] or a None default -- dataset may be absent from the file
    """
    # settable with class params:
    _prefix      : str  = dataclasses.field(init=False, repr=False, default="")
    _n_pages     : int  = dataclasses.field(init=False, repr=False, default=10)
    _require_all : bool = dataclasses.field(init=False, repr=False, default=False)

    def __init_subclass__(cls,
            prefix      : typing.Optional[str]  = None,
            n_pages     : typing.Optional[int]  = None,
            require_all : typing.Optional[bool] = None,
            **kw,
            ):
        # forward the class-keyword configuration into class attributes
        super().__init_subclass__(**kw)
        assert dataclasses.is_dataclass(cls)
        if prefix      is not None: cls._prefix      = prefix
        if n_pages     is not None: cls._n_pages     = n_pages
        if require_all is not None: cls._require_all = require_all

    @classmethod
    def _get_fields(cls) -> typing.Iterable[DataclassField]:
        """Yield a normalized DataclassField for every __init__ field."""
        for field in dataclasses.fields(cls):
            if not field.init:
                continue
            assert field.name not in ("_prefix", "_n_pages", "_require_all"), (
                f"{field.name!r} can not be in {cls.__qualname__}.__init__.\n"
                "Set it with dataclasses.field(default=YOUR_VALUE, init=False, repr=False)"
            )
            if isinstance(field.type, str):
                raise TypeError("Type hints are strings, perhaps avoid using `from __future__ import annotations`")
            type_inner = strip_optional(field.type)
            # dict[str, T] fields gather every dataset named "<prefix><field>_<key>"
            is_prefix = typing.get_origin(type_inner) is dict and typing.get_args(type_inner)[:1] == (str,)
            field_type = typing.get_args(type_inner)[1] if is_prefix else field.type
            # a None default also marks the field as optional
            if field.default is None or typing.get_origin(field.type) is typing.Union and NoneType in typing.get_args(field.type):
                field_type = typing.Optional[field_type]
            yield DataclassField(
                name         = field.name,
                type         = strip_optional(field_type),
                is_optional  = typing.get_origin(field_type) is typing.Union and NoneType in typing.get_args(field_type),
                is_array     = is_array(field_type),
                is_sliceable = is_array(field_type) and strip_optional(field_type) is not H5ArrayNoSlice,
                is_prefix    = is_prefix,
            )

    @classmethod
    def from_h5_file(cls : type[T],
            fname              : typing.Union[PathLike, str],
            *,
            page               : typing.Optional[int] = None,
            n_pages            : typing.Optional[int] = None,
            read_slice         : slice                = slice(None),
            require_even_pages : bool                 = True,
            ) -> T:
        """Read an instance of `cls` from the HDF5 file at `fname`.

        page / n_pages     -- read only one page of each sliceable array
        read_slice         -- slice applied to sliceable arrays when not paging
        require_even_pages -- raise when an array does not divide evenly

        Raises FileNotFoundError on a missing file, TypeError on a non-HDF5
        file, and ValueError on pagination/consumption errors.
        """
        if not isinstance(fname, Path):
            fname = Path(fname)
        if n_pages is None:
            n_pages = cls._n_pages
        if not fname.exists():
            raise FileNotFoundError(str(fname))
        if not h5.is_hdf5(fname):
            raise TypeError(f"Not a HDF5 file: {str(fname)!r}")
        # if this class has no fields, print a example class:
        if not any(field.init for field in dataclasses.fields(cls)):
            with h5.File(fname, "r") as f:
                klen = max(map(len, f.keys()))
                example_cls = f"\nclass {cls.__name__}(Dataclass, require_all=True):\n" + "\n".join(
                    f"    {k.ljust(klen)} : "
                    + (
                        "H5Array" if prod(v.shape, 1) > 1 else (
                        "float"   if issubclass(v.dtype.type, np.floating) else (
                        "int"     if issubclass(v.dtype.type, np.integer) else (
                        "bool"    if issubclass(v.dtype.type, np.bool_) else (
                        "typing.Any"
                    ))))).ljust(14 + 1)
                    + f" #{repr(v).split(':', 1)[1].removesuffix('>')}"
                    for k, v in f.items()
                )
            raise NotImplementedError(f"{cls!r} has no fields!\nPerhaps try the following:{example_cls}")
        fields_consumed = set()
        def make_kwarg(
                file  : h5.File,
                keys  : typing.KeysView,
                field : DataclassField,
                ) -> tuple[str, typing.Any]:
            # build one (name, value) pair for cls(**...) from the open file
            if field.is_optional:
                if field.name not in keys:
                    return field.name, None
            if field.is_sliceable:
                if page is not None:
                    n_items  = int(file[cls._prefix + field.name].shape[0])
                    page_len = n_items // n_pages
                    modulus  = n_items % n_pages
                    if modulus: page_len += 1 # round up
                    if require_even_pages and modulus:
                        raise ValueError(f"Field {field.name!r} {tuple(file[cls._prefix + field.name].shape)} is not cleanly divisible into {n_pages} pages")
                    # BUGFIX: the slice type does not accept keyword arguments;
                    # slice(start=..., stop=..., step=...) raised TypeError
                    # whenever a page was requested
                    this_slice = slice(
                        page_len * page,      # start
                        page_len * (page+1),  # stop
                        read_slice.step,      # inherit step
                    )
                else:
                    this_slice = read_slice
            else:
                this_slice = slice(None) # read all
            # array or scalar?
            def read_dataset(var):
                # https://docs.h5py.org/en/stable/high/dataset.html#reading-writing-data
                if field.is_array:
                    return var[this_slice]
                if var.shape == (1,):
                    return var[0]
                else:
                    return var[()]
            if field.is_prefix:
                fields_consumed.update(
                    key
                    for key in keys if key.startswith(f"{cls._prefix}{field.name}_")
                )
                return field.name, {
                    key.removeprefix(f"{cls._prefix}{field.name}_") : read_dataset(file[key])
                    for key in keys if key.startswith(f"{cls._prefix}{field.name}_")
                }
            else:
                fields_consumed.add(cls._prefix + field.name)
                return field.name, read_dataset(file[cls._prefix + field.name])
        with h5.File(fname, "r") as f:
            keys = f.keys()
            init_dict = dict( make_kwarg(f, keys, i) for i in cls._get_fields() )
            try:
                out = cls(**init_dict)
            except Exception as e:
                # re-raise with a field diff to ease debugging schema mismatches
                class_attrs = set(field.name for field in dataclasses.fields(cls))
                file_attr   = set(init_dict.keys())
                raise e.__class__(f"{e}. {class_attrs=}, {file_attr=}, diff={class_attrs.symmetric_difference(file_attr)}") from e
            if cls._require_all:
                fields_not_consumed = set(keys) - fields_consumed
                if fields_not_consumed:
                    raise ValueError(f"Not all HDF5 fields consumed: {fields_not_consumed!r}")
        return out

    def to_h5_file(self,
            fname : PathLike,
            mkdir : bool = False,
            ):
        """Write this instance to the HDF5 file `fname` (arrays LZ4-compressed).

        mkdir -- create missing parent directories instead of raising
        Raises NotADirectoryError / TypeError on bad paths or non-numpy arrays.
        """
        if not isinstance(fname, Path):
            fname = Path(fname)
        if not fname.parent.is_dir():
            if mkdir:
                fname.parent.mkdir(parents=True)
            else:
                raise NotADirectoryError(fname.parent)
        with h5.File(fname, "w") as f:
            for field in type(self)._get_fields():
                if field.is_optional and getattr(self, field.name) is None:
                    continue # absent optional field: nothing to write
                value = getattr(self, field.name)
                if field.is_array:
                    # arrays must already be numpy (not torch/etc) at this point
                    if any(type(i) is not np.ndarray for i in (value.values() if field.is_prefix else [value])):
                        raise TypeError(
                            "When dumping a H5Dataclass, make sure the array fields are "
                            f"numpy arrays (the type of {field.name!r} is {typing._type_repr(type(value))}).\n"
                            "Example: h5dataclass.map_arrays(torch.Tensor.numpy)"
                        )
                def write_value(key: str, value: typing.Any):
                    # LZ4-compress array datasets; store scalars plainly
                    if field.is_array:
                        f.create_dataset(key, data=value, **hdf5plugin.LZ4())
                    else:
                        f.create_dataset(key, data=value)
                if field.is_prefix:
                    for k, v in value.items():
                        write_value(self._prefix + field.name + "_" + k, v)
                else:
                    write_value(self._prefix + field.name, value)

    def map_arrays(self: T, func: typing.Callable[[H5Array], H5Array], do_copy: bool = False) -> T:
        """Apply `func` to every array field (including arrays in prefix dicts)."""
        if do_copy: # shallow
            self = self.copy(deep=False)
        for field in type(self)._get_fields():
            if field.is_optional and getattr(self, field.name) is None:
                continue
            if field.is_prefix and field.is_array:
                setattr(self, field.name, {
                    k : func(v)
                    for k, v in getattr(self, field.name).items()
                })
            elif field.is_array:
                setattr(self, field.name, func(getattr(self, field.name)))
        return self

    def astype(self: T, t: type, do_copy: bool = False, convert_nonfloats: bool = False) -> T:
        """Cast array fields to dtype `t` (int arrays only when convert_nonfloats)."""
        # BUGFIX: forward do_copy (it was previously accepted but ignored)
        return self.map_arrays(
            lambda x: x.astype(t) if convert_nonfloats or not np.issubdtype(x.dtype, int) else x,
            do_copy = do_copy,
        )

    def copy(self: T, *, deep=True) -> T:
        """Copy; a shallow copy still duplicates the prefix-dict containers."""
        out = super().copy(deep=deep)
        if not deep:
            for field in type(self)._get_fields():
                # BUGFIX: previously copied the field *name string*
                # (copy.copy(field.name)) into the field, destroying the dict
                out[field.name] = copy.copy(out[field.name])
        return out

    @property
    def shape(self) -> dict[str, tuple[int, ...]]:
        """Map of field name -> array shape for every field that has a shape."""
        return {
            key: value.shape
            for key, value in self.items()
            if hasattr(value, "shape")
        }
class TransformableDataclassMixin(metaclass=DataclassABCMeta):
    """Mixin for dataclasses carrying a `transforms` dict of named 4x4 matrices."""
    @abstractmethod
    def transform(self: T, mat4: np.ndarray, inplace=False) -> T:
        """Apply the 4x4 matrix `mat4` to the data. Subclasses must implement."""
        ...
    def transform_to(self: T, name: str, inverse_name: str = None, *, inplace=False) -> T:
        """Apply the stored transform `name`, then remove it from `transforms`.

        The remaining stored transforms are right-multiplied by the inverse so
        they compose correctly from the new frame; with `inverse_name`, the
        inverse matrix is stored back under that name.
        """
        mtx = self.transforms[name]
        out = self.transform(mtx, inplace=inplace)
        out.transforms.pop(name) # consumed
        inv = np.linalg.inv(mtx)
        for key in list(out.transforms.keys()): # maintain the other transforms
            out.transforms[key] = out.transforms[key] @ inv
        if inverse_name is not None: # store inverse
            out.transforms[inverse_name] = inv
        return out

@ -0,0 +1,48 @@
from math import pi
from trimesh import Trimesh
import numpy as np
import os
import trimesh
import trimesh.transformations as T
DEBUG = bool(os.environ.get("IFIELD_DEBUG", ""))
__doc__ = """
Here are some helper functions for processing data.
"""
def rotate_to_closest_axis_aligned_bounds(
        mesh       : Trimesh,
        order_axes : bool = True,
        fail_ok    : bool = True,
        ) -> np.ndarray:
    """Return a 4x4 rotation aligning the mesh's oriented bounds with the axes.

    With order_axes, search the 24 axis-aligned orientations for the one
    closest to the identity; otherwise return the raw OBB rotation.
    Raises Exception if no orientation qualifies and fail_ok is False.
    """
    # NOTE(review): `ordered=not order_axes` inverts the flag passed to
    # trimesh.bounds.oriented_bounds -- confirm the negation is intended
    to_origin_mat4, extents = trimesh.bounds.oriented_bounds(mesh, ordered=not order_axes)
    # keep only the rotational part of the OBB transform
    to_aabb_rot_mat4 = T.euler_matrix(*T.decompose_matrix(to_origin_mat4)[3])
    if not order_axes:
        return to_aabb_rot_mat4
    v = pi / 4 * 1.01 # tolerance
    v2 = pi / 2
    faces = (
        (0, 0),
        (1, 0),
        (2, 0),
        (3, 0),
        (0, 1),
        (0,-1),
    )
    orientations = [ # 6 faces x 4 rotations per face
        (f[0] * v2, f[1] * v2, i * v2)
        for i in range(4)
        for f in faces]
    # pick the first candidate whose residual Euler angles are all within tolerance
    for x, y, z in orientations:
        mat4 = T.euler_matrix(x, y, z) @ to_aabb_rot_mat4
        ai, aj, ak = T.euler_from_matrix(mat4)
        if abs(ai) <= v and abs(aj) <= v and abs(ak) <= v:
            return mat4
    if fail_ok: return to_aabb_rot_mat4
    raise Exception("Unable to orient mesh")

@ -0,0 +1,297 @@
from __future__ import annotations
from ...utils.helpers import compose
from functools import reduce, lru_cache
from math import ceil
from typing import Iterable
import numpy as np
import operator
__doc__ = """
Here are some helper functions for processing data.
"""
def img2col(img: np.ndarray, psize: int) -> np.ndarray:
# based of ycb_generate_point_cloud.py provided by YCB
n_channels = 1 if len(img.shape) == 2 else img.shape[0]
n_channels, rows, cols = (1,) * (3 - len(img.shape)) + img.shape
# pad the image
img_pad = np.zeros((
n_channels,
int(ceil(1.0 * rows / psize) * psize),
int(ceil(1.0 * cols / psize) * psize),
))
img_pad[:, 0:rows, 0:cols] = img
# allocate output buffer
final = np.zeros((
img_pad.shape[1],
img_pad.shape[2],
n_channels,
psize,
psize,
))
for c in range(n_channels):
for x in range(psize):
for y in range(psize):
img_shift = np.vstack((
img_pad[c, x:],
img_pad[c, :x]))
img_shift = np.column_stack((
img_shift[:, y:],
img_shift[:, :y]))
final[x::psize, y::psize, c] = np.swapaxes(
img_shift.reshape(
int(img_pad.shape[1] / psize), psize,
int(img_pad.shape[2] / psize), psize),
1,
2)
# crop output and unwrap axes with size==1
return np.squeeze(final[
0:rows - psize + 1,
0:cols - psize + 1])
def filter_depth_discontinuities(depth_map: np.ndarray, filt_size = 7, thresh = 1000) -> np.ndarray:
    """
    Removes data close to discontinuities, with size filt_size.
    """
    # based of ycb_generate_point_cloud.py provided by YCB
    assert filt_size % 2, "Can only use odd filter sizes."
    offset = int(filt_size - 1) // 2

    # per-pixel neighborhood extrema via im2col patches
    patches = 1.0 * img2col(depth_map, filt_size)
    centers = patches[:, :, offset, offset]
    lo = np.min(patches, axis=(2, 3))
    hi = np.max(patches, axis=(2, 3))
    jump = np.maximum(np.abs(lo - centers), np.abs(hi - centers))

    # mark pixels whose neighborhood jump exceeds thresh; patches only cover
    # the interior, so re-embed the mark with the filter offset
    mask = np.zeros(depth_map.shape, dtype=np.uint16)
    mask[offset:offset + jump.shape[0],
         offset:offset + jump.shape[1]] = jump > thresh
    return depth_map * (1 - mask)
def reorient_depth_map(
        depth_map      : np.ndarray,
        rgb_map        : np.ndarray,
        depth_mat3     : np.ndarray, # 3x3 intrinsic camera matrix
        depth_vec5     : np.ndarray, # 5 distortion parameters (k1, k2, p1, p2, k3)
        rgb_mat3       : np.ndarray, # 3x3 intrinsic camera matrix
        rgb_vec5       : np.ndarray, # 5 distortion parameters (k1, k2, p1, p2, k3)
        ir_to_rgb_mat4 : np.ndarray, # extrinsic transformation matrix from depth to rgb camera viewpoint
        rgb_mask_map   : np.ndarray = None,
        _output_points   = False, # retval (H, W) if false else (N, XYZRGB)
        _output_hits_uvs = False, # retval[1] is dtype=bool of hits shaped like depth_map
        ) -> np.ndarray:
    """
    Corrects depth_map to be from the same view as the rgb_map, with the same dimensions.
    If _output_points is True, the points returned are in the rgb camera space.
    """
    # based of ycb_generate_point_cloud.py provided by YCB
    # now faster AND more easy on the GIL
    height_old, width_old, *_ = depth_map.shape
    height,     width,     *_ = rgb_map.shape
    d_cx, r_cx = depth_mat3[0, 2], rgb_mat3[0, 2] # optical center
    d_cy, r_cy = depth_mat3[1, 2], rgb_mat3[1, 2]
    d_fx, r_fx = depth_mat3[0, 0], rgb_mat3[0, 0] # focal length
    d_fy, r_fy = depth_mat3[1, 1], rgb_mat3[1, 1]
    d_k1, d_k2, d_p1, d_p2, d_k3 = depth_vec5
    c_k1, c_k2, c_p1, c_p2, c_k3 = rgb_vec5
    # make a UV grid over depth_map
    u, v = np.meshgrid(
        np.arange(width_old),
        np.arange(height_old),
    )
    # compute xyz coordinates for all depths (homogeneous, one column per pixel)
    xyz_depth = np.stack((
        (u - d_cx) / d_fx,
        (v - d_cy) / d_fy,
        depth_map,
        np.ones(depth_map.shape)
    )).reshape((4, -1))
    # drop pixels with no depth reading
    xyz_depth = xyz_depth[:, xyz_depth[2] != 0]
    # undistort depth coordinates (radial k1..k3 and tangential p1, p2 terms)
    d_x, d_y = xyz_depth[:2]
    r = np.linalg.norm(xyz_depth[:2], axis=0)
    xyz_depth[0, :] \
        = d_x / (1 + d_k1*r**2 + d_k2*r**4 + d_k3*r**6) \
        - (2*d_p1*d_x*d_y + d_p2*(r**2 + 2*d_x**2))
    xyz_depth[1, :] \
        = d_y / (1 + d_k1*r**2 + d_k2*r**4 + d_k3*r**6) \
        - (d_p1*(r**2 + 2*d_y**2) + 2*d_p2*d_x*d_y)
    # unproject x and y
    xyz_depth[0, :] *= xyz_depth[2, :]
    xyz_depth[1, :] *= xyz_depth[2, :]
    # convert depths to RGB camera viewpoint
    xyz_rgb = ir_to_rgb_mat4 @ xyz_depth
    # project depths to RGB canvas
    rgb_z_inv = 1 / xyz_rgb[2] # perspective correction
    # BUGFIX: np.int was removed in NumPy 1.24; it was an alias of the builtin
    # int, so cast with int (truncation of the +0.5-rounded coordinates)
    rgb_uv = np.stack((
        xyz_rgb[0] * rgb_z_inv * r_fx + r_cx + 0.5,
        xyz_rgb[1] * rgb_z_inv * r_fy + r_cy + 0.5,
    )).astype(int)
    # mask of the rgb_xyz values within view of rgb_map
    mask = reduce(operator.and_, [
        rgb_uv[0] >= 0,
        rgb_uv[1] >= 0,
        rgb_uv[0] < width,
        rgb_uv[1] < height,
    ])
    if rgb_mask_map is not None:
        # additionally require the projected pixel to be inside the rgb mask
        mask[mask] &= rgb_mask_map[
            rgb_uv[1, mask],
            rgb_uv[0, mask]]
    if not _output_points: # output image
        output = np.zeros((height, width), dtype=depth_map.dtype)
        output[
            rgb_uv[1, mask],
            rgb_uv[0, mask],
        ] = xyz_rgb[2, mask]
    else: # output pointcloud
        rgbs = rgb_map[ # lookup rgb values using rgb_uv
            rgb_uv[1, mask],
            rgb_uv[0, mask]]
        output = np.stack((
            xyz_rgb[0, mask], # x
            xyz_rgb[1, mask], # y
            xyz_rgb[2, mask], # z
            rgbs[:, 0],       # r
            rgbs[:, 1],       # g
            rgbs[:, 2],       # b
        )).T
    # output for realsies
    if not _output_hits_uvs: #raw
        return output
    else: # with hit mask
        uv = np.zeros((height, width), dtype=bool)
        # filter points overlapping in the depth map
        uv_indices = (
            rgb_uv[1, mask],
            rgb_uv[0, mask],
        )
        _, chosen = np.unique( uv_indices[0] << 32 | uv_indices[1], return_index=True )
        output = output[chosen, :]
        uv[uv_indices] = True
        return output, uv
def join_rgb_and_depth_to_points(*a, **kw) -> np.ndarray:
    # convenience wrapper: reorient_depth_map in point-cloud mode (N, XYZRGB)
    return reorient_depth_map(*a, _output_points=True, **kw)
@compose(np.array) # block lru cache mutation
@lru_cache(maxsize=1)
@compose(list)
def generate_equidistant_sphere_points(
        n                          : int,
        centroid                   : np.ndarray = (0, 0, 0),
        radius                     : float = 1,
        compute_sphere_coordinates : bool = False,
        compute_normals            : bool = False,
        shift_theta                : bool = False,
        ) -> Iterable[tuple[float, ...]]:
    """Yield approximately `n` near-equidistant points on a sphere.

    Output per point depends on the flags (mutually exclusive):
      default                    -- (x, y, z)
      compute_normals            -- (x, y, z, nx, ny, nz)
      compute_sphere_coordinates -- (theta, phi)
    NOTE(review): the actual number of yielded points may differ slightly
    from `n` due to rounding of the band/point counts.
    """
    # Deserno M. How to generate equidistributed points on the surface of a sphere
    # https://www.cmu.edu/biolphys/deserno/pdf/sphere_equi.pdf
    if compute_sphere_coordinates and compute_normals:
        raise ValueError(
            "'compute_sphere_coordinates' and 'compute_normals' are mutually exclusive"
        )
    n_count = 0  # NOTE(review): incremented but never used or returned
    a = 4 * np.pi / n           # target surface area per point
    d = np.sqrt(a)
    n_theta = round(np.pi / d)  # number of latitude bands
    d_theta = np.pi / n_theta
    d_phi = a / d_theta
    for i in range(0, n_theta):
        theta = np.pi * (i + 0.5) / n_theta
        # points in this band, proportional to the band circumference
        n_phi = round(2 * np.pi * np.sin(theta) / d_phi)
        for j in range(0, n_phi):
            phi = 2 * np.pi * j / n_phi
            if compute_sphere_coordinates: # (theta, phi)
                yield (
                    theta if shift_theta else theta - 0.5*np.pi,
                    phi,
                )
            elif compute_normals: # (x, y, z, nx, ny, nz)
                yield (
                    centroid[0] + radius * np.sin(theta) * np.cos(phi),
                    centroid[1] + radius * np.sin(theta) * np.sin(phi),
                    centroid[2] + radius * np.cos(theta),
                    np.sin(theta) * np.cos(phi),
                    np.sin(theta) * np.sin(phi),
                    np.cos(theta),
                )
            else: # (x, y, z)
                yield (
                    centroid[0] + radius * np.sin(theta) * np.cos(phi),
                    centroid[1] + radius * np.sin(theta) * np.sin(phi),
                    centroid[2] + radius * np.cos(theta),
                )
            n_count += 1
def generate_random_sphere_points(
        n                          : int,
        centroid                   : np.ndarray = (0, 0, 0),
        radius                     : float = 1,
        compute_sphere_coordinates : bool = False,
        compute_normals            : bool = False,
        shift_theta                : bool = False, # depends on convention
        ) -> np.ndarray:
    """Sample `n` uniformly distributed points on a sphere.

    Returns (n, 2) sphere coordinates, (n, 6) points-with-normals, or
    (n, 3) points, depending on the (mutually exclusive) flags.
    """
    if compute_sphere_coordinates and compute_normals:
        raise ValueError(
            "'compute_sphere_coordinates' and 'compute_normals' are mutually exclusive"
        )
    # inverse transform sampling: arcsin latitude gives uniform area density
    theta = np.arcsin(np.random.uniform(-1, 1, n))
    phi   = np.random.uniform(0, 2*np.pi, n)

    if compute_sphere_coordinates: # (theta, phi)
        latitude = 0.5*np.pi + theta if shift_theta else theta
        return np.stack((latitude, phi), axis=1)

    # unit directions from the sphere coordinates
    normals = np.stack((
        np.cos(theta) * np.cos(phi),
        np.cos(theta) * np.sin(phi),
        np.sin(theta),
    ), axis=1)
    xyz = np.asarray(centroid) + radius * normals
    if compute_normals: # (x, y, z, nx, ny, nz)
        return np.concatenate((xyz, normals), axis=1)
    return xyz # (x, y, z)

@ -0,0 +1,85 @@
from .h5_dataclasses import H5Dataclass
from datetime import datetime, timedelta
from pathlib import Path
from typing import Hashable, Optional, Callable
import os
DEBUG = bool(os.environ.get("IFIELD_DEBUG", ""))
__doc__ = """
Here are some helper functions for processing data.
"""
# multiprocessing does not work due to my rediculous use of closures, which seemingly cannot be pickled
# paralelize it in the shell instead
def precompute_data(
        computer     : Callable[[Hashable], Optional[H5Dataclass]],
        identifiers  : list[Hashable],
        output_paths : list[Path],
        page         : tuple[int, int] = (0, 1),
        *,
        force        : bool = False,
        debug        : bool = False,
        ):
    """
    precomputes data and stores them as HDF5 datasets using `.to_file(path: Path)`
    """
    # `page` is (page_index, n_pages); this invocation handles only its share
    page, n_pages = page
    assert len(identifiers) == len(output_paths)
    total = len(identifiers)
    identifier_max_len = max(map(len, map(str, identifiers)))
    t_epoch = None
    # progress line for the current job; reads `index`/`identifier` from the
    # loop below via closure (only valid while inside the loop)
    def log(state: str, is_start = False):
        nonlocal t_epoch
        if is_start: t_epoch = datetime.now()
        td = timedelta(0) if is_start else datetime.now() - t_epoch
        print(" - "
            f"{str(index+1).rjust(len(str(total)))}/{total}: "
            f"{str(identifier).ljust(identifier_max_len)} @ {td}: {state}"
        )
    print(f"precompute_data(computer={computer.__module__}.{computer.__qualname__}, identifiers=..., force={force}, page={page})")
    t_begin = datetime.now()
    failed = []
    # pagination
    page_size = total // n_pages + bool(total % n_pages)  # ceil division
    jobs = list(zip(identifiers, output_paths))[page_size*page : page_size*(page+1)]
    for index, (identifier, output_path) in enumerate(jobs, start=page_size*page):
        # skip results that already exist and are non-empty (unless forced)
        if not force and output_path.exists() and output_path.stat().st_size > 0:
            continue
        log("compute", is_start=True)
        # compute
        try:
            res = computer(identifier)
        except Exception as e:
            failed.append(identifier)
            log(f"failed compute: {e.__class__.__name__}: {e}")
            if DEBUG or debug: raise e
            continue
        if res is None:
            # computer signalled "nothing to produce" for this identifier
            failed.append(identifier)
            log("no result")
            continue
        # write to file
        try:
            output_path.parent.mkdir(parents=True, exist_ok=True)
            res.to_h5_file(output_path)
        except Exception as e:
            failed.append(identifier)
            log(f"failed write: {e.__class__.__name__}: {e}")
            if output_path.is_file(): output_path.unlink() # cleanup
            if DEBUG or debug: raise e
            continue
        log("done")
    print("precompute_data finished in", datetime.now() - t_begin)
    print("failed:", failed or None)

768
ifield/data/common/scan.py Normal file

@ -0,0 +1,768 @@
from ...utils.helpers import compose
from . import points
from .h5_dataclasses import H5Dataclass, H5Array, H5ArrayNoSlice, TransformableDataclassMixin
from methodtools import lru_cache
from sklearn.neighbors import BallTree
import faiss
from trimesh import Trimesh
from typing import Iterable
from typing import Optional, TypeVar
import mesh_to_sdf
import mesh_to_sdf.scan as sdf_scan
import numpy as np
import trimesh
import trimesh.transformations as T
import warnings
__doc__ = """
Here are some helper types for data.
"""
_T = TypeVar("T")
class InvalidateLRUOnWriteMixin:
    """Drop memoized values whenever a regular attribute is assigned.

    Cache objects live under attribute names prefixed with "__wire|"
    (presumably the storage used by the methodtools lru caches in this
    module -- confirm against methodtools internals).
    """
    def __setattr__(self, key, value):
        is_cache_slot = key.startswith("__wire|")
        if not is_cache_slot:
            # writing real state: clear every memoized value on this instance
            for attr_name in dir(self):
                if attr_name.startswith("__wire|"):
                    getattr(self, attr_name).cache_clear()
        return super().__setattr__(key, value)
def lru_property(func):
    # read-only property memoized with a cache of size 1 (methodtools'
    # lru_cache, imported above, is applied around the property object)
    return lru_cache(maxsize=1)(property(func))
class SingleViewScan(H5Dataclass, TransformableDataclassMixin, InvalidateLRUOnWriteMixin, require_all=True):
points_hit : H5ArrayNoSlice # (N, 3)
normals_hit : Optional[H5ArrayNoSlice] # (N, 3)
points_miss : H5ArrayNoSlice # (M, 3)
distances_miss : Optional[H5ArrayNoSlice] # (M)
colors_hit : Optional[H5ArrayNoSlice] # (N, 3)
colors_miss : Optional[H5ArrayNoSlice] # (M, 3)
uv_hits : Optional[H5ArrayNoSlice] # (H, W) dtype=bool
uv_miss : Optional[H5ArrayNoSlice] # (H, W) dtype=bool (the reason we store both is due to missing data depth sensor data or filtered backfaces)
cam_pos : H5ArrayNoSlice # (3)
cam_mat4 : Optional[H5ArrayNoSlice] # (4, 4)
proj_mat4 : Optional[H5ArrayNoSlice] # (4, 4)
transforms : dict[str, H5ArrayNoSlice] # a map of 4x4 transformation matrices
def transform(self: _T, mat4: np.ndarray, inplace=False) -> _T:
scale_xyz = mat4[:3, :3].sum(axis=0) # https://math.stackexchange.com/a/1463487
assert all(scale_xyz - scale_xyz[0] < 1e-8), f"differenty scaled axes: {scale_xyz}"
out = self if inplace else self.copy(deep=False)
out.points_hit = T.transform_points(self.points_hit, mat4)
out.normals_hit = T.transform_points(self.normals_hit, mat4) if self.normals_hit is not None else None
out.points_miss = T.transform_points(self.points_miss, mat4)
out.distances_miss = self.distances_miss * scale_xyz
out.cam_pos = T.transform_points(self.points_cam, mat4)[-1]
out.cam_mat4 = (mat4 @ self.cam_mat4) if self.cam_mat4 is not None else None
out.proj_mat4 = (mat4 @ self.proj_mat4) if self.proj_mat4 is not None else None
return out
def compute_miss_distances(self: _T, *, copy: bool = False, deep: bool = False) -> _T:
assert not self.has_miss_distances
if not self.is_hitting:
raise ValueError("No hits to compute the ray distance towards")
out = self.copy(deep=deep) if copy else self
out.distances_miss \
= distance_from_rays_to_point_cloud(
ray_origins = out.points_cam,
ray_dirs = out.ray_dirs_miss,
points = out.points_hit,
).astype(out.points_cam.dtype)
return out
@lru_property
def points(self) -> np.ndarray: # (N+M+1, 3)
return np.concatenate((
self.points_hit,
self.points_miss,
self.points_cam,
))
@lru_property
def uv_points(self) -> np.ndarray: # (H, W, 3)
    """Scatter the hit/miss points back onto the (H, W) UV grid; NaN elsewhere."""
    if not self.has_uv:
        raise ValueError
    grid = np.full((*self.uv_hits.shape, 3), np.nan, dtype=self.points_hit.dtype)
    grid[self.uv_hits, :] = self.points_hit
    grid[self.uv_miss, :] = self.points_miss
    return grid
@lru_property
def uv_normals(self) -> np.ndarray: # (H, W, 3)
    """Scatter the hit normals onto the (H, W) UV grid; NaN where there is no hit."""
    if not self.has_uv: raise ValueError
    # BUG fixed: without this guard a scan without normals died with an
    # AttributeError on `None.dtype` instead of a meaningful error.
    if not self.has_normals: raise ValueError("scan has no normals")
    out = np.full((*self.uv_hits.shape, 3), np.nan, dtype=self.normals_hit.dtype)
    out[self.uv_hits, :] = self.normals_hit
    return out
@lru_property
def points_cam(self) -> Optional[np.ndarray]: # (1, 3)
    """The camera position as a one-row point array (None when unknown)."""
    cam = self.cam_pos
    return None if cam is None else cam[None, :]
@lru_property
def points_hit_centroid(self) -> np.ndarray:
    """Mean of the hit points along the sample axis."""
    return np.mean(self.points_hit, axis=0)

@lru_property
def points_hit_std(self) -> np.ndarray:
    """Per-axis standard deviation of the hit points."""
    return np.std(self.points_hit, axis=0)

@lru_property
def is_hitting(self) -> bool:
    """True when at least one ray hit the surface."""
    return len(self.points_hit) != 0

@lru_property
def is_empty(self) -> bool:
    """True when the scan recorded neither hits nor misses."""
    return len(self.points_hit) == 0 and len(self.points_miss) == 0

@lru_property
def has_colors(self) -> bool:
    """True when color data is present for hits and/or misses."""
    return not (self.colors_hit is None and self.colors_miss is None)

@lru_property
def has_normals(self) -> bool:
    """True when hit normals were computed."""
    return self.normals_hit is not None

@lru_property
def has_uv(self) -> bool:
    """True when the UV hit/miss masks are stored."""
    return self.uv_hits is not None

@lru_property
def has_miss_distances(self) -> bool:
    """True when miss-ray distances were computed."""
    return self.distances_miss is not None
@lru_property
def xyzrgb_hit(self) -> np.ndarray: # (N, 6)
    """Hit points with their colors side by side."""
    if self.colors_hit is None:
        raise ValueError
    return np.hstack([self.points_hit, self.colors_hit])

@lru_property
def xyzrgb_miss(self) -> np.ndarray: # (M, 6)
    """Miss points with their colors side by side."""
    if self.colors_miss is None:
        raise ValueError
    return np.hstack([self.points_miss, self.colors_miss])
@lru_property
def ray_dirs_hit(self) -> np.ndarray: # (N, 3)
    """Unit direction from the camera to each hit point."""
    dirs = self.points_hit - self.points_cam
    dirs /= np.linalg.norm(dirs, axis=-1)[:, None]
    return dirs

@lru_property
def ray_dirs_miss(self) -> np.ndarray: # (M, 3)
    """Unit direction from the camera to each miss (far-plane) point."""
    dirs = self.points_miss - self.points_cam
    dirs /= np.linalg.norm(dirs, axis=-1)[:, None]
    return dirs
@classmethod
def from_mesh_single_view(cls, mesh: Trimesh, *, compute_miss_distances: bool = False, **kw) -> "SingleViewScan":
    """Synthesize one scan of `mesh`; picks a random viewpoint unless phi/theta are given."""
    if "phi" not in kw and "theta" not in kw:
        kw["theta"], kw["phi"] = points.generate_random_sphere_points(1, compute_sphere_coordinates=True)[0]
    scan = sample_single_view_scan_from_mesh(mesh, **kw)
    if compute_miss_distances and scan.is_hitting:
        scan.compute_miss_distances()
    return scan
def to_uv_scan(self) -> "SingleViewUVScan":
    """Convert to the dense UV-grid representation."""
    return SingleViewUVScan.from_scan(self)

@classmethod
def from_uv_scan(cls, uvscan: "SingleViewUVScan") -> "SingleViewScan":
    """Convert a dense UV scan back into a sparse scan.

    BUG fixed: the classmethod's first parameter was named `self`, and the
    return annotation claimed SingleViewUVScan although `to_scan()` returns
    a SingleViewScan.
    """
    return uvscan.to_scan()
# The same, but with support for pagination (should have been this way since the start...)
class SingleViewUVScan(H5Dataclass, TransformableDataclassMixin, InvalidateLRUOnWriteMixin, require_all=True):
    # B may be (N) or (H, W), the latter may be flattened
    hits       : H5Array                   # (*B) dtype=bool, True where the ray hit the surface
    miss       : H5Array                   # (*B) dtype=bool (both masks are stored since a cell may be neither: missing depth-sensor data or filtered backface hits)
    points     : H5Array                   # (*B, 3) hit location; on far plane if miss, NaN if neither hit nor miss
    normals    : Optional[H5Array]         # (*B, 3) NaN if not hit
    colors     : Optional[H5Array]         # (*B, 3)
    distances  : Optional[H5Array]         # (*B) miss-ray distance to the surface, NaN if not miss
    cam_pos    : Optional[H5ArrayNoSlice]  # (3) shared camera position, or (*B, 3) per-ray origins
    cam_mat4   : Optional[H5ArrayNoSlice]  # (4, 4) camera pose
    proj_mat4  : Optional[H5ArrayNoSlice]  # (4, 4) projection matrix
    transforms : dict[str, H5ArrayNoSlice] # a map of named 4x4 transformation matrices
@classmethod
def from_scan(cls, scan: SingleViewScan):
    # Build the dense UV representation by scattering the sparse hit/miss
    # clouds onto the (flat or 2d) UV grid; cells that are neither hit nor
    # miss stay NaN.
    if not scan.has_uv:
        raise ValueError("Scan cloud has no UV data")
    hits, miss = scan.uv_hits, scan.uv_miss
    dtype = scan.points_hit.dtype
    assert hits.ndim in (1, 2), hits.ndim
    assert hits.shape == miss.shape, (hits.shape, miss.shape)
    points = np.full((*hits.shape, 3), np.nan, dtype=dtype)
    points[hits, :] = scan.points_hit
    points[miss, :] = scan.points_miss
    normals = None
    if scan.has_normals:
        normals = np.full((*hits.shape, 3), np.nan, dtype=dtype)
        normals[hits, :] = scan.normals_hit
    distances = None
    if scan.has_miss_distances:
        distances = np.full(hits.shape, np.nan, dtype=dtype)
        distances[miss] = scan.distances_miss
    colors = None
    if scan.has_colors:
        # either of colors_hit/colors_miss may be absent; the other cells stay NaN
        colors = np.full((*hits.shape, 3), np.nan, dtype=dtype)
        if scan.colors_hit is not None:
            colors[hits, :] = scan.colors_hit
        if scan.colors_miss is not None:
            colors[miss, :] = scan.colors_miss
    return cls(
        hits       = hits,
        miss       = miss,
        points     = points,
        normals    = normals,
        colors     = colors,
        distances  = distances,
        cam_pos    = scan.cam_pos,
        cam_mat4   = scan.cam_mat4,
        proj_mat4  = scan.proj_mat4,
        transforms = scan.transforms,
    )
def to_scan(self) -> "SingleViewScan":
    """Convert back to the sparse representation (single-view scans only)."""
    if not self.is_single_view: raise ValueError
    hits, miss = self.hits, self.miss
    return SingleViewScan(
        points_hit     = self.points[hits, :],
        points_miss    = self.points[miss, :],
        normals_hit    = self.normals[hits, :] if self.has_normals else None,
        distances_miss = self.distances[miss] if self.has_miss_distances else None,
        colors_hit     = self.colors[hits, :] if self.has_colors else None,
        colors_miss    = self.colors[miss, :] if self.has_colors else None,
        uv_hits        = hits,
        uv_miss        = miss,
        cam_pos        = self.cam_pos,
        cam_mat4       = self.cam_mat4,
        proj_mat4      = self.proj_mat4,
        transforms     = self.transforms,
    )
def to_mesh(self) -> trimesh.Trimesh:
    """
    Triangulate the 2d UV grid of hit cells into a mesh.

    Each 2x2 neighbourhood with 3 hit corners yields one triangle, with 4 hit
    corners two. Vertex colors are attached when available.
    """
    faces: list[tuple[tuple[int, int], ...]] = []  # fixed: old annotation was not valid typing
    height, width = self.hits.shape  # requires a 2d (H, W) grid
    for x in range(height - 1):
        for y in range(width - 1):
            c11 = x, y
            c12 = x, y+1
            c22 = x+1, y+1
            c21 = x+1, y
            corners = (c11, c12, c22, c21)
            n = sum(map(self.hits.__getitem__, corners))
            if n == 3:
                faces.append(tuple(filter(self.hits.__getitem__, corners)))
            elif n == 4:
                faces.append((c11, c12, c22))
                faces.append((c11, c22, c21))
    # map each used (x, y) grid coordinate to a vertex index
    xy2idx = {c: i for i, c in enumerate(set(k for j in faces for k in j))}
    # BUG fixed: an `assert self.colors is not None` here contradicted the
    # conditional below and made colorless scans fail; colors are optional.
    return trimesh.Trimesh(
        vertices      = [self.points[i] for i in xy2idx.keys()],
        vertex_colors = [self.colors[i] for i in xy2idx.keys()] if self.colors is not None else None,
        faces         = [tuple(xy2idx[i] for i in face) for face in faces],
    )
def transform(self: _T, mat4: np.ndarray, inplace=False) -> _T:
    """
    Apply a homogeneous 4x4 transform (uniform scale + rotation + translation)
    to all spatial data. Returns self when inplace, else a shallow copy.
    """
    # Per-axis scale is the *norm* of each linear-part column
    # (https://math.stackexchange.com/a/1463487); they must agree (uniform scale).
    scale_xyz = np.linalg.norm(mat4[:3, :3], axis=0)
    assert all(abs(scale_xyz - scale_xyz[0]) < 1e-8), f"differently scaled axes: {scale_xyz}"
    unflat = self.hits.shape
    out = self if inplace else self.copy(deep=False)
    # Flatten the (possibly 2d) batch dims for transform_points, then restore.
    # BUG fixed: the old code did `reshape((*np.product(shape), 3))`, which
    # tries to unpack a scalar, and read `self.normals_hit`/`self.distances_miss`
    # -- fields that only exist on SingleViewScan, not on this class.
    out.points = T.transform_points(self.points.reshape((-1, 3)), mat4).reshape((*unflat, 3))
    # NOTE(review): transform_points applies translation too, which is only
    # valid for normals under pure rotation + uniform scale -- confirm T's semantics.
    out.normals = T.transform_points(self.normals.reshape((-1, 3)), mat4).reshape((*unflat, 3)) if self.normals is not None else None
    # distances are scalars: scale by the (uniform) scalar factor
    out.distances = (self.distances * scale_xyz[0]) if self.distances is not None else None
    out.cam_pos = T.transform_points(self.cam_pos[None, ...], mat4)[0] if self.cam_pos is not None else None
    out.cam_mat4 = (mat4 @ self.cam_mat4) if self.cam_mat4 is not None else None
    out.proj_mat4 = (mat4 @ self.proj_mat4) if self.proj_mat4 is not None else None
    return out
def compute_miss_distances(self: _T, *, copy: bool = False, deep: bool = False, surface_points: Optional[np.ndarray] = None) -> _T:
    # Fill `distances` with, for each miss ray, its closest approach to the
    # surface (`surface_points` if given, else the hit cloud). Non-miss cells
    # get 0. Mutates self unless copy=True.
    # NOTE(review): unlike SingleViewScan.compute_miss_distances, the result is
    # not cast back with .astype -- confirm the dtype mismatch is intentional.
    assert not self.has_miss_distances
    shape = self.hits.shape
    out = self.copy(deep=deep) if copy else self
    out.distances = np.zeros(shape, dtype=self.points.dtype)
    if self.is_hitting:
        out.distances[self.miss] \
            = distance_from_rays_to_point_cloud(
                ray_origins = self.cam_pos_unsqueezed_miss,
                ray_dirs    = self.ray_dirs_miss,
                points      = surface_points if surface_points is not None else self.points[self.hits],
            )
    return out
def fill_missing_points(self: _T, *, copy: bool = False, deep: bool = False) -> _T:
    """
    Fill in missing points as hitting the far plane.
    """
    if not self.is_2d:
        raise ValueError("Cannot fill missing points for non-2d scan!")
    if not self.is_single_view:
        raise ValueError("Cannot fill missing points for non-single-view scans!")
    if self.cam_mat4 is None:
        raise ValueError("cam_mat4 is None")
    if self.proj_mat4 is None:
        raise ValueError("proj_mat4 is None")
    # grid indices of the missing cells, mapped to NDC in [-1, 1]
    uv = np.argwhere(self.missing).astype(self.points.dtype)
    uv[:, 0] /= (self.missing.shape[1] - 1) / 2
    uv[:, 1] /= (self.missing.shape[0] - 1) / 2
    uv -= 1
    # NOTE(review): the row index is scaled by shape[1] and the column by
    # shape[0]; for non-square grids confirm the axes are not swapped.
    uv = np.stack((
        uv[:, 1],
        -uv[:, 0],
        np.ones(uv.shape[0]), # far clipping plane
        np.ones(uv.shape[0]), # homogeneous coordinate
    ), axis=-1)
    # unproject from clip space to world space: inverse projection, then camera pose
    uv = uv @ (self.cam_mat4 @ np.linalg.inv(self.proj_mat4)).T
    out = self.copy(deep=deep) if copy else self
    # perspective divide, then write the world-space far-plane points in place
    out.points[self.missing, :] = uv[:, :3] / uv[:, 3][:, None]
    return out
@lru_property
def is_hitting(self) -> bool:
    """True when at least one ray hit the surface."""
    return self.hits.any()

@lru_property
def has_colors(self) -> bool:
    """True when per-cell colors are stored."""
    return self.colors is not None

@lru_property
def has_normals(self) -> bool:
    """True when per-cell normals are stored."""
    return self.normals is not None

@lru_property
def has_miss_distances(self) -> bool:
    """True when miss-ray distances were computed."""
    return self.distances is not None

@lru_property
def any_missing(self) -> bool:
    """True when some cells are neither hit nor miss."""
    return self.missing.any()

@lru_property
def has_missing(self) -> bool:
    """True when missing cells exist and have been filled with finite points."""
    return self.any_missing and not np.isnan(self.points[self.missing]).any()
@lru_property
def cam_pos_unsqueezed(self) -> H5Array:
    """Camera position broadcastable against the batch dims of `hits`."""
    # NOTE: cam_pos is Optional; callers are expected to only use this when set.
    cam = self.cam_pos
    if cam.ndim == 1:
        for _ in range(self.hits.ndim):
            cam = cam[None, ...]
    return cam

@lru_property
def cam_pos_unsqueezed_hit(self) -> H5Array:
    """Per-hit ray origins (or the single shared camera, broadcastable)."""
    if self.cam_pos.ndim == 1:
        return self.cam_pos[None, :]
    return self.cam_pos[self.hits, :]

@lru_property
def cam_pos_unsqueezed_miss(self) -> H5Array:
    """Per-miss ray origins (or the single shared camera, broadcastable)."""
    if self.cam_pos.ndim == 1:
        return self.cam_pos[None, :]
    return self.cam_pos[self.miss, :]
@lru_property
def ray_dirs(self) -> H5Array:
    """Unit direction of every ray (every stored point lies on its ray)."""
    offsets = self.points - self.cam_pos_unsqueezed
    return offsets * (1 / self.depths[..., None])

@lru_property
def ray_dirs_hit(self) -> H5Array:
    """Unit direction from the camera to each hit point."""
    dirs = self.points[self.hits, :] - self.cam_pos_unsqueezed_hit
    dirs /= np.linalg.norm(dirs, axis=-1)[..., None]
    return dirs

@lru_property
def ray_dirs_miss(self) -> H5Array:
    """Unit direction from the camera to each miss (far-plane) point."""
    dirs = self.points[self.miss, :] - self.cam_pos_unsqueezed_miss
    dirs /= np.linalg.norm(dirs, axis=-1)[..., None]
    return dirs

@lru_property
def depths(self) -> H5Array:
    """Euclidean distance from the camera to each stored point."""
    return np.linalg.norm(self.points - self.cam_pos_unsqueezed, axis=-1)

@lru_property
def missing(self) -> H5Array:
    """Cells that are neither hit nor miss (no data recorded)."""
    return ~self.hits & ~self.miss
@classmethod
def from_mesh_single_view(cls, mesh: Trimesh, *, compute_miss_distances: bool = False, **kw) -> "SingleViewUVScan":
    """Synthesize one UV scan of `mesh`; picks a random viewpoint unless phi/theta are given."""
    if "phi" not in kw and "theta" not in kw:
        kw["theta"], kw["phi"] = points.generate_random_sphere_points(1, compute_sphere_coordinates=True)[0]
    scan = sample_single_view_scan_from_mesh(mesh, **kw).to_uv_scan()
    if compute_miss_distances:
        scan.compute_miss_distances()
    assert scan.is_2d
    return scan
@classmethod
def from_mesh_sphere_view(cls, mesh: Trimesh, *, compute_miss_distances: bool = False, **kw) -> "SingleViewUVScan":
    """Synthesize a flat scan of `mesh` from rays between unit-sphere point pairs."""
    scan = sample_sphere_view_scan_from_mesh(mesh, **kw)
    if compute_miss_distances:
        surface_points = None
        # use the raw vertex set as distance target when it is smaller than the hit cloud
        if scan.hits.sum() > mesh.vertices.shape[0]:
            surface_points = mesh.vertices.astype(scan.points.dtype)
            if not kw.get("no_unit_sphere", False):
                # BUG fixed: this transform previously ran even when
                # surface_points was None, raising a TypeError.
                translation, scale = compute_unit_sphere_transform(mesh, dtype=scan.points.dtype)
                surface_points = (surface_points + translation) * scale
        scan.compute_miss_distances(surface_points=surface_points)
    assert scan.is_flat
    return scan
def flatten_and_permute_(self: _T, copy=False) -> _T: # inplace by default
    """
    Flatten the batch dims to a single axis and shuffle all per-cell arrays
    with one shared random permutation. Mutates self unless copy=True.
    """
    # BUG fixed: np.product is removed in numpy >= 2.0; use np.prod.
    n_items = int(np.prod(self.hits.shape))
    permutation = np.random.permutation(n_items)
    out = self.copy(deep=False) if copy else self
    out.hits      = out.hits     .reshape((n_items, ))[permutation]
    out.miss      = out.miss     .reshape((n_items, ))[permutation]
    out.points    = out.points   .reshape((n_items, 3))[permutation, :]
    out.normals   = out.normals  .reshape((n_items, 3))[permutation, :] if out.has_normals else None
    out.colors    = out.colors   .reshape((n_items, 3))[permutation, :] if out.has_colors else None
    out.distances = out.distances.reshape((n_items, ))[permutation] if out.has_miss_distances else None
    return out
@property
def is_single_view(self) -> bool:
    """True when there is a single camera position (or none recorded at all)."""
    if self.cam_pos is None:
        return True
    # BUG fixed: np.product is removed in numpy >= 2.0; use np.prod.
    return np.prod(self.cam_pos.shape[:-1]) == 1

@property
def is_flat(self) -> bool:
    """True when the batch dims are flattened to a single axis."""
    return len(self.hits.shape) == 1

@property
def is_2d(self) -> bool:
    """True when the batch dims form an image grid (H, W)."""
    return len(self.hits.shape) == 2
# transforms can be found in pytorch3d.transforms and in open3d
# and in trimesh.transformations
def sample_single_view_scans_from_mesh(
        mesh               : Trimesh,
        *,
        n_batches          : int,
        scan_resolution    : int   = 400,
        compute_normals    : bool  = False,
        fov                : float = 1.0472, # 60 degrees in radians, vertical field of view.
        camera_distance    : float = 2,
        no_filter_backhits : bool  = False,
        ) -> Iterable[SingleViewScan]:
    """
    Yield `n_batches` synthetic single-view scans of `mesh`, each taken from a
    uniformly random viewpoint on the sphere.
    """
    # shared cache so the unit-sphere normalization of `mesh` happens only once
    normalized_mesh_cache = []
    for _ in range(n_batches):
        theta, phi = points.generate_random_sphere_points(1, compute_sphere_coordinates=True)[0]
        # BUG fixed: previously also passed `_mesh_is_normalized=False`, a
        # keyword sample_single_view_scan_from_mesh does not accept (TypeError).
        yield sample_single_view_scan_from_mesh(
            mesh               = mesh,
            phi                = phi,
            theta              = theta,
            scan_resolution    = scan_resolution,
            compute_normals    = compute_normals,
            fov                = fov,
            camera_distance    = camera_distance,
            no_filter_backhits = no_filter_backhits,
            _mesh_cache        = normalized_mesh_cache,
        )
def sample_single_view_scan_from_mesh(
        mesh : Trimesh,
        *,
        phi : float,
        theta : float,
        scan_resolution : int = 200,
        compute_normals : bool = False,
        fov : float = 1.0472, # 60 degrees in radians, vertical field of view.
        camera_distance : float = 2,
        no_filter_backhits : bool = False,
        no_unit_sphere : bool = False,
        dtype : type = np.float32,
        _mesh_cache : Optional[list] = None, # provide a list if mesh is reused
        ) -> SingleViewScan:
    # Render a synthetic depth scan of `mesh` from the spherical viewpoint
    # (phi, theta) and split it into hit/miss point clouds plus UV masks.
    # scale and center to unit sphere
    is_cache = isinstance(_mesh_cache, list)
    if is_cache and _mesh_cache and _mesh_cache[0] is mesh:
        # cache hit: reuse the normalized mesh and its transform
        _, mesh, translation, scale = _mesh_cache
    else:
        if is_cache:
            if _mesh_cache:
                _mesh_cache.clear()
            _mesh_cache.append(mesh)  # original mesh is the cache key
        translation, scale = compute_unit_sphere_transform(mesh)
        mesh = mesh_to_sdf.scale_to_unit_sphere(mesh)
        if is_cache:
            _mesh_cache.extend((mesh, translation, scale))
    z_near = 1
    z_far = 3
    cam_mat4 = sdf_scan.get_camera_transform_looking_at_origin(phi, theta, camera_distance=camera_distance)
    cam_pos = cam_mat4 @ np.array([0, 0, 0, 1])  # camera origin in world space
    scan = sdf_scan.Scan(mesh,
        camera_transform = cam_mat4,
        resolution = scan_resolution,
        calculate_normals = compute_normals,
        fov = fov,
        z_near = z_near,
        z_far = z_far,
        no_flip_backfaced_normals = True
    )
    # all the scan rays that hit the far plane, based on sdf_scan.Scan.__init__
    misses = np.argwhere(scan.depth_buffer == 0)
    points_miss = np.ones((misses.shape[0], 4))
    # pixel indices -> NDC in [-1, 1] (note the swapped row/column order)
    points_miss[:, [1, 0]] = misses.astype(float) / (scan_resolution -1) * 2 - 1
    points_miss[:, 1] *= -1
    points_miss[:, 2] = 1 # far plane in clipping space
    # unproject from clip space to world space, then perspective-divide
    points_miss = points_miss @ (cam_mat4 @ np.linalg.inv(scan.projection_matrix)).T
    points_miss /= points_miss[:, 3][:, np.newaxis]
    points_miss = points_miss[:, :3]
    uv_hits = scan.depth_buffer != 0
    uv_miss = ~uv_hits
    if not no_filter_backhits:
        if not compute_normals:
            raise ValueError("not `no_filter_backhits` requires `compute_normals`")
        # inner product
        mask = np.einsum('ij,ij->i', scan.points - cam_pos[:3][None, :], scan.normals) < 0
        scan.points  = scan.points [mask, :]
        scan.normals = scan.normals[mask, :]
        uv_hits[uv_hits] = mask  # filtered backface hits become "neither hit nor miss"
    transforms = {}
    # undo unit-sphere transform
    if no_unit_sphere:
        scan.points = scan.points * (1 / scale) - translation
        points_miss = points_miss * (1 / scale) - translation
        cam_pos[:3] = cam_pos[:3] * (1 / scale) - translation
        cam_mat4[:3, :] *= 1 / scale
        cam_mat4[:3, 3] -= translation
        transforms["unit_sphere"] = T.scale_and_translate(scale=scale, translate=translation)
        transforms["model"] = np.eye(4)
    else:
        transforms["model"] = np.linalg.inv(T.scale_and_translate(scale=scale, translate=translation))
        transforms["unit_sphere"] = np.eye(4)
    # NOTE(review): `scan.normals` is read unconditionally below -- confirm
    # sdf_scan.Scan provides it when calculate_normals=False.
    return SingleViewScan(
        normals_hit    = scan.normals          .astype(dtype),
        points_hit     = scan.points           .astype(dtype),
        points_miss    = points_miss           .astype(dtype),
        distances_miss = None,
        colors_hit     = None,
        colors_miss    = None,
        uv_hits        = uv_hits               .astype(bool),
        uv_miss        = uv_miss               .astype(bool),
        cam_pos        = cam_pos[:3]           .astype(dtype),
        cam_mat4       = cam_mat4              .astype(dtype),
        proj_mat4      = scan.projection_matrix.astype(dtype),
        transforms     = {k:v.astype(dtype) for k, v in transforms.items()},
    )
def sample_sphere_view_scan_from_mesh(
        mesh               : Trimesh,
        *,
        sphere_points      : int  = 4000, # resulting rays are n*(n-1)
        compute_normals    : bool = False,
        no_filter_backhits : bool = False,
        no_unit_sphere     : bool = False,
        no_permute         : bool = False,
        dtype              : type = np.float32,
        **kw,
        ) -> SingleViewUVScan:
    """
    Ray-trace `mesh` along all rays between pairs of equidistant unit-sphere
    points, producing a flat SingleViewUVScan with one "pixel" per ray.
    """
    translation, scale = compute_unit_sphere_transform(mesh, dtype=dtype)
    # get unit-sphere points, then transform to model space
    two_sphere = generate_equidistant_sphere_rays(sphere_points, **kw).astype(dtype) # (n*(n-1), 2, 3)
    two_sphere = two_sphere / scale - translation # we transform after cache lookup
    if mesh.ray.__class__.__module__.split(".")[-1] != "ray_pyembree":
        warnings.warn("Pyembree not found, the ray-tracing will be SLOW!")
    (
        locations,
        index_ray,
        index_tri,
    ) = mesh.ray.intersects_location(
        two_sphere[:, 0, :],
        two_sphere[:, 1, :] - two_sphere[:, 0, :], # direction, not target coordinate
        multiple_hits=False,
    )
    # BUG fixed: `location_normals` was only bound when compute_normals was
    # set, yet read unconditionally below (NameError when False).
    location_normals = mesh.face_normals[index_tri] if compute_normals else None
    batch = two_sphere.shape[:1]
    # BUG fixed: np.bool was removed from numpy; use the builtin bool dtype.
    hits = np.zeros((*batch,), dtype=bool)
    miss = np.ones((*batch,), dtype=bool)
    cam_pos = two_sphere[:, 0, :]
    intersections = two_sphere[:, 1, :] # far-plane, effectively
    normals = np.zeros((*batch, 3), dtype=dtype) if compute_normals else None
    index_ray_front = index_ray
    if not no_filter_backhits:
        if not compute_normals:
            raise ValueError("not `no_filter_backhits` requires `compute_normals`")
        # keep only front-face hits (ray direction against the face normal)
        mask = ((intersections[index_ray] - cam_pos[index_ray]) * location_normals).sum(axis=-1) <= 0
        index_ray_front = index_ray[mask]
    hits[index_ray_front] = True
    miss[index_ray] = False # filtered backface hits end up neither hit nor miss
    intersections[index_ray] = locations
    if compute_normals:
        normals[index_ray] = location_normals
    if not no_permute:
        assert len(batch) == 1, batch
        permutation = np.random.permutation(*batch)
        hits          = hits         [permutation]
        miss          = miss         [permutation]
        intersections = intersections[permutation, :]
        cam_pos       = cam_pos      [permutation, :]
        if normals is not None:
            normals = normals[permutation, :]
    # apply unit sphere transform
    if not no_unit_sphere:
        intersections = (intersections + translation) * scale
        cam_pos       = (cam_pos + translation) * scale
    return SingleViewUVScan(
        hits       = hits,
        miss       = miss,
        points     = intersections,
        normals    = normals,
        colors     = None, # colors
        distances  = None,
        cam_pos    = cam_pos,
        cam_mat4   = None,
        proj_mat4  = None,
        transforms = {},
    )
def distance_from_rays_to_point_cloud(
        ray_origins     : np.ndarray, # (*A, 3)
        ray_dirs        : np.ndarray, # (*A, 3)
        points          : np.ndarray, # (*B, 3)
        dirs_normalized : bool = False,
        n_steps         : int  = 40,
        ) -> np.ndarray: # (A)
    """
    Approximate, for each ray, the perpendicular distance to its nearest point
    in `points`, by marching along the ray against a nearest-neighbour index
    (BallTree for small clouds, faiss for large ones) and keeping the closest
    point encountered.
    """
    # anything outside of this volume will never contribute to the result
    max_norm = max(
        np.linalg.norm(ray_origins, axis=-1).max(),
        np.linalg.norm(points, axis=-1).max(),
    ) * 1.02
    if not dirs_normalized:
        ray_dirs = ray_dirs / np.linalg.norm(ray_dirs, axis=-1)[..., None]
    # deal with single-view clouds
    if ray_origins.shape != ray_dirs.shape:
        ray_origins = np.broadcast_to(ray_origins, ray_dirs.shape)
    # BUG fixed: np.product is removed in numpy >= 2.0; use np.prod.
    n_points = np.prod(points.shape[:-1])
    use_faiss = n_points > 160000*4
    if not use_faiss:
        index = BallTree(points)
    else:
        # http://ann-benchmarks.com/index.html
        assert np.issubdtype(points.dtype, np.float32)
        assert np.issubdtype(ray_origins.dtype, np.float32)
        assert np.issubdtype(ray_dirs.dtype, np.float32)
        index = faiss.index_factory(points.shape[-1], "NSG32,Flat") # https://github.com/facebookresearch/faiss/wiki/The-index-factory
        index.nprobe = 5 # 10 # default is 1
        index.train(points)
        index.add(points)
    # initial nearest-neighbour query from the ray origins
    if not use_faiss:
        min_d, min_n = index.query(ray_origins, k=1, return_distance=True)
    else:
        min_d, min_n = index.search(ray_origins, k=1)
        min_d = np.sqrt(min_d)  # faiss returns squared L2 distances
    acc_d = min_d.copy()  # accumulated march distance along each ray
    for step in range(1, n_steps+1):
        query_points = ray_origins + acc_d * ray_dirs
        if max_norm is not None:
            # only keep marching rays still inside the bounding volume
            qmask = np.linalg.norm(query_points, axis=-1) < max_norm
            if not qmask.any(): break
            query_points = query_points[qmask]
        else:
            qmask = slice(None)
        if not use_faiss:
            current_d, current_n = index.query(query_points, k=1, return_distance=True)
        else:
            current_d, current_n = index.search(query_points, k=1)
            current_d = np.sqrt(current_d)
        if max_norm is not None:
            min_d[qmask] = np.minimum(current_d, min_d[qmask])
            new_min_mask = min_d[qmask] == current_d
            qmask2 = qmask.copy()
            qmask2[qmask2] = new_min_mask[..., 0]
            min_n[qmask2] = current_n[new_min_mask[..., 0]]
            # under-step (25% of the NN distance) to avoid overshooting the surface
            acc_d[qmask] += current_d * 0.25
        else:
            np.minimum(current_d, min_d, out=min_d)
            new_min_mask = min_d == current_d
            min_n[new_min_mask] = current_n[new_min_mask]
            acc_d += current_d * 0.25
    # final distance: perpendicular distance from each ray to its closest point
    closest_points = points[min_n[:, 0], :] # k=1
    distances = np.linalg.norm(np.cross(closest_points - ray_origins, ray_dirs, axis=-1), axis=-1)
    return distances
# helpers

@compose(np.array) # make copy to avoid lru cache mutation
@lru_cache(maxsize=1)
def generate_equidistant_sphere_rays(n : int, **kw) -> np.ndarray: # output (n*(n-1)) rays, n may be off
    # All ordered pairs of (approximately) n equidistant unit-sphere points:
    # each pair is one ray (origin, target). Cached since the result only
    # depends on n (and kw).
    sphere_points = points.generate_equidistant_sphere_points(n=n, **kw)
    indices = np.indices((len(sphere_points),))[0] # (N)
    # cartesian product
    cprod = np.transpose([np.tile(indices, len(indices)), np.repeat(indices, len(indices))]) # (N**2, 2)
    # filter repeated combinations
    permutations = cprod[cprod[:, 0] != cprod[:, 1], :] # (N*(N-1), 2)
    # lookup sphere points
    two_sphere = sphere_points[permutations, :] # (N*(N-1), 2, 3)
    return two_sphere
def compute_unit_sphere_transform(mesh: "Trimesh", *, dtype=None) -> tuple[np.ndarray, float]:
    """
    returns translation and scale which mesh_to_sdf applies to meshes before computing their SDF cloud

    BUG fixed: the dtype default was the *builtin* `type`, so calling without
    an explicit dtype crashed in `.astype(type)`; the default is now None
    (meaning: keep the native float64).
    """
    # the transformation applied by mesh_to_sdf.scale_to_unit_sphere(mesh)
    translation = -mesh.bounding_box.centroid
    scale = 1 / np.max(np.linalg.norm(mesh.vertices + translation, axis=1))
    if dtype is not None:
        translation = translation.astype(dtype)
        scale       = scale      .astype(dtype)
    return translation, scale

@ -0,0 +1,6 @@
__doc__ = """
Some helper types.
"""


class MalformedMesh(Exception):
    """Raised when a mesh is malformed or cannot be processed."""

28
ifield/data/config.py Normal file

@ -0,0 +1,28 @@
from ..utils.helpers import make_relative
from pathlib import Path
from typing import Optional
import os
import warnings
def data_path_get(dataset_name: str, no_warn: bool = False) -> Path:
    """
    Resolve the on-disk directory for `dataset_name`.

    Resolution order: per-dataset envvar, then the shared IFIELD_DATA_MODELS
    envvar, then the repository's data/models/ directory. Warns when the
    resolved path does not exist (unless no_warn).
    """
    per_dataset_key = f"IFIELD_DATA_MODELS_{dataset_name.replace('-', '_').upper()}"
    if per_dataset_key in os.environ:
        data_path = Path(os.environ[per_dataset_key])
    elif "IFIELD_DATA_MODELS" in os.environ:
        data_path = Path(os.environ["IFIELD_DATA_MODELS"]) / dataset_name
    else:
        data_path = Path(__file__).resolve().parent.parent.parent / "data" / "models" / dataset_name
    if not data_path.is_dir() and not no_warn:
        warnings.warn(f"{make_relative(data_path, Path.cwd()).__str__()!r} is not a directory!")
    return data_path
def data_path_persist(dataset_name: Optional[str], path: os.PathLike) -> os.PathLike:
    "Persist the datapath, ensuring subprocesses also will use it. The path passes through."
    if dataset_name is None:
        key = "IFIELD_DATA_MODELS"
    else:
        key = f"IFIELD_DATA_MODELS_{dataset_name.replace(*'-_').upper()}"
    os.environ[key] = str(path)
    return path

@ -0,0 +1,56 @@
from ..config import data_path_get, data_path_persist
from collections import namedtuple
import os
# Data source:
# http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/ssd.htm
# NOTE(review): `__ALL__` is ignored by Python -- the name that controls
# `from ... import *` is lowercase `__all__`. Also "Model" and "MODELS" are
# not defined anywhere visible in this file; confirm they exist before
# renaming, since a wrong `__all__` breaks star-imports.
__ALL__ = ["config", "Model", "MODELS"]
# url: download URL; fname: local archive file name; download_size_str: human-readable size
Archive = namedtuple("Archive", "url fname download_size_str")
@(lambda x: x()) # singleton: the class object is replaced by its sole instance
class config:
    # Property on the singleton *instance*, backed by envvars via
    # data_path_get / data_path_persist.
    DATA_PATH = property(
        doc = """
        Path to the dataset. The following envvars override it:
            ${IFIELD_DATA_MODELS}/coseg
            ${IFIELD_DATA_MODELS_COSEG}
        """,
        fget = lambda self:       data_path_get    ("coseg"),
        fset = lambda self, path: data_path_persist("coseg", path),
    )
    @property
    def IS_DOWNLOADED_DB(self) -> list[os.PathLike]:
        # bookkeeping file(s) recording which archives were already downloaded
        return [
            self.DATA_PATH / "downloaded.json",
        ]
    # object-set name -> downloadable shapes archive
    SHAPES: dict[str, Archive] = {
        "candelabra"  : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Candelabra/shapes.zip",   "candelabra-shapes.zip",   "3,3M"),
        "chair"       : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Chair/shapes.zip",        "chair-shapes.zip",        "3,2M"),
        "four-legged" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Four-legged/shapes.zip",  "four-legged-shapes.zip",  "2,9M"),
        "goblets"     : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Goblets/shapes.zip",      "goblets-shapes.zip",      "500K"),
        "guitars"     : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Guitars/shapes.zip",      "guitars-shapes.zip",      "1,9M"),
        "lampes"      : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Lampes/shapes.zip",       "lampes-shapes.zip",       "2,4M"),
        "vases"       : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Vases/shapes.zip",        "vases-shapes.zip",        "5,5M"),
        "irons"       : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Irons/shapes.zip",        "irons-shapes.zip",        "1,2M"),
        "tele-aliens" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Tele-aliens/shapes.zip",  "tele-aliens-shapes.zip",  "15M"),
        "large-vases" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Large-Vases/shapes.zip",  "large-vases-shapes.zip",  "6,2M"),
        "large-chairs": Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Large-Chairs/shapes.zip", "large-chairs-shapes.zip", "14M"),
    }
    # object-set name -> ground-truth segmentation archive
    GROUND_TRUTHS: dict[str, Archive] = {
        "candelabra"  : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Candelabra/gt.zip",   "candelabra-gt.zip",   "68K"),
        "chair"       : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Chair/gt.zip",        "chair-gt.zip",        "20K"),
        "four-legged" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Four-legged/gt.zip",  "four-legged-gt.zip",  "24K"),
        "goblets"     : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Goblets/gt.zip",      "goblets-gt.zip",      "4,0K"),
        "guitars"     : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Guitars/gt.zip",      "guitars-gt.zip",      "12K"),
        "lampes"      : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Lampes/gt.zip",       "lampes-gt.zip",       "60K"),
        "vases"       : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Vases/gt.zip",        "vases-gt.zip",        "40K"),
        "irons"       : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Irons/gt.zip",        "irons-gt.zip",        "8,0K"),
        "tele-aliens" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Tele-aliens/gt.zip",  "tele-aliens-gt.zip",  "72K"),
        "large-vases" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Large-Vases/gt.zip",  "large-vases-gt.zip",  "68K"),
        "large-chairs": Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Large-Chairs/gt.zip", "large-chairs-gt.zip", "116K"),
    }

@ -0,0 +1,135 @@
#!/usr/bin/env python3
from . import config
from ...utils.helpers import make_relative
from ..common import download
from pathlib import Path
from textwrap import dedent
import argparse
import io
import zipfile
def is_downloaded(*a, **kw):
    # Delegate to the shared download bookkeeping, pinned to this dataset's db files.
    return download.is_downloaded(*a, dbfiles=config.IS_DOWNLOADED_DB, **kw)
def download_and_extract(target_dir: Path, url_dict: dict[str, str], *, force=False, silent=False) -> bool:
    # Download each archive (skipping ones already recorded), extract it into a
    # per-set directory, and record it as downloaded. Returns True when at
    # least one archive was processed.
    target_dir.mkdir(parents=True, exist_ok=True)
    ret = False
    for url, fname in url_dict.items():
        if not force:
            if is_downloaded(target_dir, url): continue
        if not download.check_url(url):
            print("ERROR:", url)  # unreachable URL: report and move on
            continue
        ret = True
        # fetch the archive unless a local copy already exists
        if force or not (target_dir / "archives" / fname).is_file():
            data = download.download_data(url, silent=silent, label=fname)
            assert url.endswith(".zip")
            print("writing...")
            (target_dir / "archives").mkdir(parents=True, exist_ok=True)
            with (target_dir / "archives" / fname).open("wb") as f:
                f.write(data)
            del data  # free the in-memory archive before extraction
        print(f"extracting {fname}...")
        with zipfile.ZipFile(target_dir / "archives" / fname, 'r') as f:
            # strip the "-shapes"/"-gt" suffix to get the object-set directory
            f.extractall(target_dir / Path(fname).stem.removesuffix("-shapes").removesuffix("-gt"))
        is_downloaded(target_dir, url, add=True)
    return ret
def make_parser() -> argparse.ArgumentParser:
    # Build the CLI for the COSEG download entrypoint.
    parser = argparse.ArgumentParser(description=dedent("""
        Download The COSEG Shape Dataset.
        More info: http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/ssd.htm
        Example:
        download-coseg --shapes chairs
    """), formatter_class=argparse.RawTextHelpFormatter)
    arg = parser.add_argument  # brevity
    arg("sets", nargs="*", default=[],
        help="Which set to download, defaults to none.")
    arg("--all", action="store_true",
        help="Download all sets")
    # NOTE(review): the default is evaluated at import time; config.DATA_PATH
    # changes via envvars after import will not be reflected in --help.
    arg("--dir", default=str(config.DATA_PATH),
        help=f"The target directory. Default is {make_relative(config.DATA_PATH, Path.cwd()).__str__()!r}")
    arg("--shapes", action="store_true",
        help="Download the 3d shapes for each chosen set")
    arg("--gts", action="store_true",
        help="Download the ground-truth segmentation data for each chosen set")
    arg("--list", action="store_true",
        help="Lists all the sets")
    arg("--list-urls", action="store_true",
        help="Lists the urls to download")
    arg("--list-sizes", action="store_true",
        help="Lists the download size of each set")
    arg("--silent", action="store_true",
        help="")
    arg("--force", action="store_true",
        help="Download again even if already downloaded")
    return parser
# entrypoint
def cli(parser=make_parser()):
    # Parse arguments, resolve the selected sets into archive URLs, then
    # download and extract them. The --list* flags print and exit early.
    args = parser.parse_args()
    assert set(config.SHAPES.keys()) == set(config.GROUND_TRUTHS.keys())
    set_names = sorted(set(args.sets))
    if args.all:
        assert not set_names, "--all is mutually exclusive from manually selected sets"
        set_names = sorted(config.SHAPES.keys())
    if args.list:
        print(*config.SHAPES.keys(), sep="\n")
        exit()
    if args.list_sizes:
        print(*(f"{set_name:<15}{config.SHAPES[set_name].download_size_str}" for set_name in (set_names or config.SHAPES.keys())), sep="\n")
        exit()
    try:
        # merge shapes and/or ground-truth archives depending on the flags
        url_dict \
            = {config.SHAPES       [set_name].url : config.SHAPES       [set_name].fname for set_name in set_names if args.shapes} \
            | {config.GROUND_TRUTHS[set_name].url : config.GROUND_TRUTHS[set_name].fname for set_name in set_names if args.gts}
    except KeyError:
        print("Error: unrecognized object name:", *set(set_names).difference(config.SHAPES.keys()), sep="\n")
        exit(1)
    if not url_dict:
        if set_names and not (args.shapes or args.gts):
            # NOTE(review): "of" below looks like a typo for "or" (runtime
            # string, left untouched here).
            print("Error: Provide at least one of --shapes of --gts")
        else:
            print("Error: No object set was selected for download!")
        exit(1)
    if args.list_urls:
        print(*url_dict.keys(), sep="\n")
        exit()
    print("Download start")
    any_downloaded = download_and_extract(
        target_dir = Path(args.dir),
        url_dict   = url_dict,
        force      = args.force,
        silent     = args.silent,
    )
    if not any_downloaded:
        print("Everything has already been downloaded, skipping.")

if __name__ == "__main__":
    cli()

@ -0,0 +1,137 @@
#!/usr/bin/env python3
import os; os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
from . import config, read
from ...utils.helpers import make_relative
from pathlib import Path
from textwrap import dedent
import argparse
def make_parser() -> argparse.ArgumentParser:
    # Build the CLI for the COSEG preprocessing entrypoint.
    parser = argparse.ArgumentParser(description=dedent("""
        Preprocess the COSEG dataset. Depends on `download-coseg --shapes ...` having been run.
    """), formatter_class=argparse.RawTextHelpFormatter)
    arg = parser.add_argument # brevity
    arg("items", nargs="*", default=[],
        help="Which object-set[/model-id] to process, defaults to all downloaded. Format: OBJECT-SET[/MODEL-ID]")
    arg("--dir", default=str(config.DATA_PATH),
        help=f"The target directory. Default is {make_relative(config.DATA_PATH, Path.cwd()).__str__()!r}")
    arg("--force", action="store_true",
        help="Overwrite existing files")
    arg("--list-models", action="store_true",
        help="List the downloaded models available for preprocessing")
    arg("--list-object-sets", action="store_true",
        help="List the downloaded object-sets available for preprocessing")
    arg("--list-pages", type=int, default=None,
        help="List the downloaded models available for preprocessing, paginated into N pages.")
    arg("--page", nargs=2, type=int, default=[0, 1],
        help="Subset of parts to compute. Use to parallelize. (page, total), page is 0 indexed")
    # what to compute
    arg2 = parser.add_argument_group("preprocessing targets").add_argument # brevity
    arg2("--precompute-mesh-sv-scan-clouds", action="store_true",
        help="Compute single-view hit+miss point clouds from 100 synthetic scans.")
    arg2("--precompute-mesh-sv-scan-uvs", action="store_true",
        help="Compute single-view hit+miss UV clouds from 100 synthetic scans.")
    arg2("--precompute-mesh-sphere-scan", action="store_true",
        help="Compute a sphere-view hit+miss cloud cast from n to n unit sphere points.")
    # options altering how the targets are computed
    arg3 = parser.add_argument_group("modifiers").add_argument # brevity
    arg3("--n-sphere-points", type=int, default=4000,
        help="The number of unit-sphere points to sample rays from. Final result: n*(n-1).")
    arg3("--compute-miss-distances", action="store_true",
        help="Compute the distance to the nearest hit for each miss in the hit+miss clouds.")
    arg3("--fill-missing-uv-points", action="store_true",
        help="TODO")
    arg3("--no-filter-backhits", action="store_true",
        help="Do not filter scan hits on backside of mesh faces.")
    arg3("--no-unit-sphere", action="store_true",
        help="Do not center the objects to the unit sphere.")
    arg3("--convert-ok", action="store_true",
        help="Allow reusing point clouds for uv clouds and vice versa. (does not account for other hparams)")
    arg3("--debug", action="store_true",
        help="Abort on failiure.")
    return parser
# entrypoint
def cli(parser=make_parser()):
    """CLI entrypoint: parse arguments and run the selected preprocessing targets.

    NOTE: the default `parser=make_parser()` is evaluated once at import time;
    intentional here so the parser is built lazily exactly once.
    """
    args = parser.parse_args()
    # Require at least one --precompute-* target unless the user only wants a listing.
    # NOTE(review): `args.list_pages` is int-or-None, so `--list-pages 0` is falsy
    # and would trip this error — presumably N>=1 is always expected; confirm.
    if not any(getattr(args, k) for k in dir(args) if k.startswith("precompute_")) and not (args.list_models or args.list_object_sets or args.list_pages):
        parser.error("no preprocessing target selected") # exits
    config.DATA_PATH = Path(args.dir)
    # Positional items: "set" selects a whole object-set, "set/model" a single model.
    object_sets = [i for i in args.items if "/" not in i]
    models = [i.split("/") for i in args.items if "/" in i]
    # convert/expand synsets to models
    # they are mutually exclusive
    if object_sets: assert not models
    if models: assert not object_sets
    if not models:
        # expand the requested object-sets (or everything, if none given) to model uids
        models = read.list_model_ids(tuple(object_sets) or None)
    # Listing modes print and exit; BrokenPipeError is swallowed so piping into
    # e.g. `head` does not produce a traceback.
    if args.list_models:
        try:
            print(*(f"{object_set_id}/{model_id}" for object_set_id, model_id in models), sep="\n")
        except BrokenPipeError:
            pass
        parser.exit()
    if args.list_object_sets:
        try:
            print(*sorted(set(object_set_id for object_set_id, model_id in models)), sep="\n")
        except BrokenPipeError:
            pass
        parser.exit()
    if args.list_pages is not None:
        # emit one ready-made "--page i N set/model" line per (model, page) pair
        try:
            print(*(
                f"--page {i} {args.list_pages} {object_set_id}/{model_id}"
                for object_set_id, model_id in models
                for i in range(args.list_pages)
            ), sep="\n")
        except BrokenPipeError:
            pass
        parser.exit()
    # The three precompute targets are independent and may all run in one invocation.
    if args.precompute_mesh_sv_scan_clouds:
        read.precompute_mesh_scan_point_clouds(
            models,
            compute_miss_distances = args.compute_miss_distances,
            no_filter_backhits = args.no_filter_backhits,
            no_unit_sphere = args.no_unit_sphere,
            convert_ok = args.convert_ok,
            page = args.page,
            force = args.force,
            debug = args.debug,
        )
    if args.precompute_mesh_sv_scan_uvs:
        read.precompute_mesh_scan_uvs(
            models,
            compute_miss_distances = args.compute_miss_distances,
            fill_missing_points = args.fill_missing_uv_points,
            no_filter_backhits = args.no_filter_backhits,
            no_unit_sphere = args.no_unit_sphere,
            convert_ok = args.convert_ok,
            page = args.page,
            force = args.force,
            debug = args.debug,
        )
    if args.precompute_mesh_sphere_scan:
        read.precompute_mesh_sphere_scan(
            models,
            sphere_points = args.n_sphere_points,
            compute_miss_distances = args.compute_miss_distances,
            no_filter_backhits = args.no_filter_backhits,
            no_unit_sphere = args.no_unit_sphere,
            page = args.page,
            force = args.force,
            debug = args.debug,
        )
if __name__ == "__main__":
    cli()

290
ifield/data/coseg/read.py Normal file

@ -0,0 +1,290 @@
from . import config
from ..common import points
from ..common import processing
from ..common.scan import SingleViewScan, SingleViewUVScan
from ..common.types import MalformedMesh
from functools import lru_cache
from typing import Optional, Iterable
import numpy as np
import trimesh
import trimesh.transformations as T
__doc__ = """
Here are functions for reading and preprocessing coseg benchmark data
There are essentially a few sets per object:
"img" - meaning the RGBD images (none found in coseg)
"mesh_scans" - meaning synthetic scans of a mesh
"""
# Rotate meshes so +z is "up" when rendered in pyrender.
MESH_TRANSFORM_SKYWARD = T.rotation_matrix(np.pi/2, (1, 0, 0)) # rotate to be upright in pyrender
# Per-model extra rotations applied on top of MESH_TRANSFORM_SKYWARD so all
# models of a set share one canonical orientation (hand-tuned per model).
MESH_POSE_CORRECTIONS = { # to gain a shared canonical orientation
    ("four-legged", 381): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
    ("four-legged", 382): T.rotation_matrix( 1*np.pi/2, (0, 0, 1)),
    ("four-legged", 383): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
    ("four-legged", 384): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
    ("four-legged", 385): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
    ("four-legged", 386): T.rotation_matrix( 1*np.pi/2, (0, 0, 1)),
    ("four-legged", 387): T.rotation_matrix(-0.2*np.pi/2, (0, 1, 0))@T.rotation_matrix(1*np.pi/2, (0, 0, 1)),
    ("four-legged", 388): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
    ("four-legged", 389): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
    ("four-legged", 390): T.rotation_matrix( 0*np.pi/2, (0, 0, 1)),
    ("four-legged", 391): T.rotation_matrix( 0*np.pi/2, (0, 0, 1)),
    ("four-legged", 392): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
    ("four-legged", 393): T.rotation_matrix( 0*np.pi/2, (0, 0, 1)),
    ("four-legged", 394): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
    ("four-legged", 395): T.rotation_matrix(-0.2*np.pi/2, (0, 1, 0))@T.rotation_matrix(1*np.pi/2, (0, 0, 1)),
    ("four-legged", 396): T.rotation_matrix( 1*np.pi/2, (0, 0, 1)),
    ("four-legged", 397): T.rotation_matrix( 0*np.pi/2, (0, 0, 1)),
    ("four-legged", 398): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
    ("four-legged", 399): T.rotation_matrix( 0*np.pi/2, (0, 0, 1)),
    ("four-legged", 400): T.rotation_matrix( 0*np.pi/2, (0, 0, 1)),
}
# A model is uniquely identified by (object_set_id, model_id).
ModelUid = tuple[str, int]
@lru_cache(maxsize=1)
def list_object_sets() -> list[str]:
    """Sorted names of object-set directories under DATA_PATH that contain a
    "shapes" subdirectory, excluding "archive". Result is cached."""
    names = []
    for entry in config.DATA_PATH.iterdir():
        if entry.name == "archive":
            continue
        if not (entry / "shapes").is_dir():
            continue
        names.append(entry.name)
    names.sort()
    return names
@lru_cache(maxsize=1)
def list_model_ids(object_sets: Optional[tuple[str]] = None) -> list[ModelUid]:
    """Sorted (object_set_id, model_id) pairs for every ".off" mesh on disk.

    object_sets: if given, restrict the listing to these object-set names.
    Result is cached per argument.
    """
    found = []
    for object_set in config.DATA_PATH.iterdir():
        if object_set.name == "archive":
            continue
        if object_sets is not None and object_set.name not in object_sets:
            continue
        shapes_dir = object_set / "shapes"
        if not shapes_dir.is_dir():
            continue
        for model in shapes_dir.iterdir():
            if model.is_file() and model.suffix == ".off":
                found.append((object_set.name, int(model.stem)))
    return sorted(found)
def list_model_id_strings(object_sets: Optional[tuple[str]] = None) -> list[str]:
    """Same as list_model_ids, but each uid rendered as a "set-model" string."""
    return [
        model_uid_to_string(object_set_id, model_id)
        for object_set_id, model_id in list_model_ids(object_sets)
    ]
def model_uid_to_string(object_set_id: str, model_id: int) -> str:
    """Join an (object_set, model) uid into the canonical "set-model" string."""
    return "{}-{}".format(object_set_id, model_id)
def model_id_string_to_uid(model_string_uid: str) -> ModelUid:
    """Split a "set-model" string back into (set, int(model)).

    The model id is everything after the LAST dash, so set names may themselves
    contain dashes (e.g. "four-legged-381" -> ("four-legged", 381)).
    """
    head, sep, tail = model_string_uid.rpartition("-")
    assert sep == "-"
    return (head, int(tail))
@lru_cache(maxsize=1)
def list_mesh_scan_sphere_coords(n_poses: int = 50) -> list[tuple[float, float]]: # (theta, phi)
    # Evenly distributed camera poses on the unit sphere, as spherical coordinates.
    return points.generate_equidistant_sphere_points(n_poses, compute_sphere_coordinates=True)
def mesh_scan_identifier(*, phi: float, theta: float) -> str:
    """Filename-safe identifier for a (theta, phi) camera pose.

    Sign is encoded as 'n' (negative) / 'p' (non-negative), '.' becomes 'd',
    and theta comes first: theta=-0.5, phi=1.0 -> "n0d50p1d00".
    """
    def encode(value: float) -> str:
        sign = "p" if value >= 0 else "n"
        return ("%s%.2f" % (sign, abs(value))).replace(".", "d")
    return encode(theta) + encode(phi)
@lru_cache(maxsize=1)
def list_mesh_scan_identifiers(n_poses: int = 50) -> list[str]:
    """Pose identifier string for each of the n_poses camera coordinates.

    Asserts that the identifiers are unique (the .2f rounding could in theory
    collide for very dense pose sets).
    """
    identifiers = []
    for theta, phi in list_mesh_scan_sphere_coords(n_poses):
        identifiers.append(mesh_scan_identifier(phi=phi, theta=theta))
    assert len(identifiers) == len(set(identifiers))
    return identifiers
# ===
def read_mesh(object_set_id: str, model_id: int) -> trimesh.Trimesh:
    """Load `<DATA_PATH>/<set>/shapes/<model>.off`, oriented upright and with
    the per-model canonical pose correction applied when one is registered.

    Raises FileNotFoundError if the file is missing, MalformedMesh if trimesh
    fails to parse it.
    """
    path = config.DATA_PATH / object_set_id / "shapes" / f"{model_id}.off"
    if not path.is_file():
        raise FileNotFoundError(f"{path = }")
    try:
        mesh = trimesh.load(path, force="mesh")
    except Exception as e:
        raise MalformedMesh(f"Trimesh raised: {e.__class__.__name__}: {e}") from e
    transform = MESH_TRANSFORM_SKYWARD
    correction = MESH_POSE_CORRECTIONS.get((object_set_id, int(model_id)))
    if correction is not None:
        transform = correction @ transform
    mesh.apply_transform(transform)
    return mesh
# === single-view scan clouds
def compute_mesh_scan_point_cloud(
    object_set_id : str,
    model_id : int,
    phi : float,
    theta : float,
    *,
    compute_miss_distances : bool = False,
    fill_missing_points : bool = False,
    compute_normals : bool = True,
    convert_ok : bool = False,
    **kw,
) -> SingleViewScan:
    """Synthetically scan one mesh from a single (theta, phi) viewpoint.

    convert_ok: if True, first try to derive the scan from an existing UV scan
        on disk (does not account for the other hyper-parameters).
    kw: forwarded to SingleViewScan.from_mesh_single_view.
    """
    if convert_ok:
        try:
            return read_mesh_scan_uv(object_set_id, model_id, phi=phi, theta=theta).to_scan()
        except FileNotFoundError:
            pass # no UV scan on disk; fall through to a fresh scan
    mesh = read_mesh(object_set_id, model_id)
    scan = SingleViewScan.from_mesh_single_view(mesh,
        phi = phi,
        theta = theta,
        compute_normals = compute_normals,
        **kw,
    )
    # post-processing applies only to freshly computed scans, not converted ones
    if compute_miss_distances:
        scan.compute_miss_distances()
    if fill_missing_points:
        scan.fill_missing_points()
    return scan
def precompute_mesh_scan_point_clouds(models: Iterable[ModelUid], *, n_poses: int = 50, page: tuple[int, int] = (0, 1), force = False, debug = False, **kw):
    """Precompute all single-view scan clouds and store them as HDF5 datasets.

    models:  (object_set_id, model_id) pairs to process
    n_poses: number of camera poses on the unit sphere per model
    page:    (page_index, total_pages) slice for parallelization
    kw:      forwarded to compute_mesh_scan_point_cloud
    """
    # Materialize: `models` is iterated several times below, so a generator
    # argument would otherwise be silently exhausted after the first pass.
    models = list(models)
    cam_poses = list_mesh_scan_sphere_coords(n_poses=n_poses)
    pose_identifiers = list_mesh_scan_identifiers(n_poses=n_poses)
    assert len(cam_poses) == len(pose_identifiers)
    paths = list_mesh_scan_point_cloud_h5_fnames(models, pose_identifiers, n_poses=n_poses)
    # column widths for aligned progress labels (default=0 tolerates an empty list)
    mlen_syn = max((len(object_set_id) for object_set_id, model_id in models), default=0)
    mlen_mod = max((len(str(model_id)) for object_set_id, model_id in models), default=0)
    # NOTE: computer() parses these strings back with split("@"), so object-set
    # names must not contain "@".
    pretty_identifiers = [
        f"{object_set_id.ljust(mlen_syn)} @ {str(model_id).ljust(mlen_mod)} @ {i:>5} @ ({identifier}: {theta:.2f}, {phi:.2f})"
        for object_set_id, model_id in models
        for i, (identifier, (theta, phi)) in enumerate(zip(pose_identifiers, cam_poses))
    ]
    mesh_cache = [] # shared scratch so consecutive poses of one model reuse the loaded mesh
    def computer(pretty_identifier: str) -> SingleViewScan:
        object_set_id, model_id, index, _ = map(str.strip, pretty_identifier.split("@"))
        theta, phi = cam_poses[int(index)]
        return compute_mesh_scan_point_cloud(object_set_id, int(model_id), phi=phi, theta=theta, _mesh_cache=mesh_cache, **kw)
    return processing.precompute_data(computer, pretty_identifiers, paths, page=page, force=force, debug=debug)
def read_mesh_scan_point_cloud(object_set_id: str, model_id: int, *, identifier: str = None, phi: float = None, theta: float = None) -> SingleViewScan:
    """Load a precomputed single-view scan cloud from disk.

    Provide either `identifier` or both `phi` and `theta`.
    Raises FileNotFoundError when the file is missing — the convert_ok
    fallbacks depend on catching exactly that exception.
    """
    if identifier is None:
        if phi is None or theta is None:
            raise ValueError("Provide either phi+theta or an identifier!")
        identifier = mesh_scan_identifier(phi=phi, theta=theta)
    file = config.DATA_PATH / object_set_id / "uv_scan_clouds" / f"{model_id}_normalized_{identifier}.h5"
    # Explicit existence check (matches the stanford reader) so callers get a
    # FileNotFoundError instead of whatever the h5 layer raises.
    if not file.is_file(): raise FileNotFoundError(str(file))
    # NOTE(review): this reads from "uv_scan_clouds" with the same filename as
    # read_mesh_scan_uv — point clouds and UV clouds share paths here; confirm intended.
    return SingleViewScan.from_h5_file(file)
def list_mesh_scan_point_cloud_h5_fnames(models: Iterable[ModelUid], identifiers: Optional[Iterable[str]] = None, **kw):
    """Expected HDF5 path for every (model, pose-identifier) combination.

    kw is forwarded to list_mesh_scan_identifiers when identifiers is None.
    NOTE(review): same directory and filename pattern as the UV variant —
    point clouds and UV clouds share paths in this dataset; confirm intended.
    """
    if identifiers is None:
        identifiers = list_mesh_scan_identifiers(**kw)
    paths = []
    for object_set_id, model_id in models:
        for identifier in identifiers:
            paths.append(config.DATA_PATH / object_set_id / "uv_scan_clouds" / f"{model_id}_normalized_{identifier}.h5")
    return paths
# === single-view UV scan clouds
def compute_mesh_scan_uv(
    object_set_id : str,
    model_id : int,
    phi : float,
    theta : float,
    *,
    compute_miss_distances : bool = False,
    fill_missing_points : bool = False,
    compute_normals : bool = True,
    convert_ok : bool = False,
    **kw,
) -> SingleViewUVScan:
    """Synthetically scan one mesh from a single (theta, phi) viewpoint as a UV cloud.

    convert_ok: if True, first try to derive the UV scan from an existing point
        cloud scan on disk (does not account for the other hyper-parameters).
    kw: forwarded to SingleViewUVScan.from_mesh_single_view.
    """
    if convert_ok:
        try:
            return read_mesh_scan_point_cloud(object_set_id, model_id, phi=phi, theta=theta).to_uv_scan()
        except FileNotFoundError:
            pass # no point cloud on disk; fall through to a fresh scan
    mesh = read_mesh(object_set_id, model_id)
    scan = SingleViewUVScan.from_mesh_single_view(mesh,
        phi = phi,
        theta = theta,
        compute_normals = compute_normals,
        **kw,
    )
    # post-processing applies only to freshly computed scans, not converted ones
    if compute_miss_distances:
        scan.compute_miss_distances()
    if fill_missing_points:
        scan.fill_missing_points()
    return scan
def precompute_mesh_scan_uvs(models: Iterable[ModelUid], *, n_poses: int = 50, page: tuple[int, int] = (0, 1), force = False, debug = False, **kw):
    """Precompute all single-view UV scan clouds and store them as HDF5 datasets.

    models:  (object_set_id, model_id) pairs to process
    n_poses: number of camera poses on the unit sphere per model
    page:    (page_index, total_pages) slice for parallelization
    kw:      forwarded to compute_mesh_scan_uv
    """
    # Materialize: `models` is iterated several times below, so a generator
    # argument would otherwise be silently exhausted after the first pass.
    models = list(models)
    cam_poses = list_mesh_scan_sphere_coords(n_poses=n_poses)
    pose_identifiers = list_mesh_scan_identifiers(n_poses=n_poses)
    assert len(cam_poses) == len(pose_identifiers)
    paths = list_mesh_scan_uv_h5_fnames(models, pose_identifiers, n_poses=n_poses)
    # column widths for aligned progress labels (default=0 tolerates an empty list)
    mlen_syn = max((len(object_set_id) for object_set_id, model_id in models), default=0)
    mlen_mod = max((len(str(model_id)) for object_set_id, model_id in models), default=0)
    # NOTE: computer() parses these strings back with split("@"), so object-set
    # names must not contain "@".
    pretty_identifiers = [
        f"{object_set_id.ljust(mlen_syn)} @ {str(model_id).ljust(mlen_mod)} @ {i:>5} @ ({identifier}: {theta:.2f}, {phi:.2f})"
        for object_set_id, model_id in models
        for i, (identifier, (theta, phi)) in enumerate(zip(pose_identifiers, cam_poses))
    ]
    mesh_cache = [] # shared scratch so consecutive poses of one model reuse the loaded mesh
    def computer(pretty_identifier: str) -> SingleViewUVScan:
        object_set_id, model_id, index, _ = map(str.strip, pretty_identifier.split("@"))
        theta, phi = cam_poses[int(index)]
        return compute_mesh_scan_uv(object_set_id, int(model_id), phi=phi, theta=theta, _mesh_cache=mesh_cache, **kw)
    return processing.precompute_data(computer, pretty_identifiers, paths, page=page, force=force, debug=debug)
def read_mesh_scan_uv(object_set_id: str, model_id: int, *, identifier: str = None, phi: float = None, theta: float = None) -> SingleViewUVScan:
    """Load a precomputed single-view UV scan from disk.

    Provide either `identifier` or both `phi` and `theta`.
    Raises FileNotFoundError when the file is missing — the convert_ok
    fallbacks depend on catching exactly that exception.
    """
    if identifier is None:
        if phi is None or theta is None:
            raise ValueError("Provide either phi+theta or an identifier!")
        identifier = mesh_scan_identifier(phi=phi, theta=theta)
    file = config.DATA_PATH / object_set_id / "uv_scan_clouds" / f"{model_id}_normalized_{identifier}.h5"
    # Explicit existence check (matches the stanford reader) so callers get a
    # FileNotFoundError instead of whatever the h5 layer raises.
    if not file.is_file(): raise FileNotFoundError(str(file))
    return SingleViewUVScan.from_h5_file(file)
def list_mesh_scan_uv_h5_fnames(models: Iterable[ModelUid], identifiers: Optional[Iterable[str]] = None, **kw):
    """Expected HDF5 path for every (model, pose-identifier) UV-scan combination.

    kw is forwarded to list_mesh_scan_identifiers when identifiers is None.
    """
    if identifiers is None:
        identifiers = list_mesh_scan_identifiers(**kw)
    paths = []
    for object_set_id, model_id in models:
        for identifier in identifiers:
            paths.append(config.DATA_PATH / object_set_id / "uv_scan_clouds" / f"{model_id}_normalized_{identifier}.h5")
    return paths
# === sphere-view (UV) scan clouds
def compute_mesh_sphere_scan(
    object_set_id : str,
    model_id : int,
    *,
    compute_normals : bool = True,
    **kw,
) -> SingleViewUVScan:
    """Scan one mesh with rays cast between unit-sphere points.

    kw is forwarded to SingleViewUVScan.from_mesh_sphere_view.
    """
    loaded_mesh = read_mesh(object_set_id, model_id)
    return SingleViewUVScan.from_mesh_sphere_view(
        loaded_mesh,
        compute_normals = compute_normals,
        **kw,
    )
def precompute_mesh_sphere_scan(models: Iterable[ModelUid], *, page: tuple[int, int] = (0, 1), force: bool = False, debug: bool = False, n_points: int = 4000, **kw):
    """Precompute all sphere scan clouds and store them as HDF5 datasets.

    models: (object_set_id, model_id) pairs to process
    page:   (page_index, total_pages) slice for parallelization
    kw:     forwarded to compute_mesh_sphere_scan
    """
    # Materialize: `models` is iterated twice below, so a generator argument
    # would otherwise be silently exhausted after the first pass.
    models = list(models)
    paths = list_mesh_sphere_scan_h5_fnames(models)
    identifiers = [model_uid_to_string(*i) for i in models]
    # NOTE(review): n_points is accepted but never forwarded; the CLI passes the
    # sphere point count through **kw as `sphere_points` instead — confirm intent.
    def computer(identifier: str) -> SingleViewUVScan:
        object_set_id, model_id = model_id_string_to_uid(identifier)
        return compute_mesh_sphere_scan(object_set_id, model_id, **kw)
    return processing.precompute_data(computer, identifiers, paths, page=page, force=force, debug=debug)
def read_mesh_mesh_sphere_scan(object_set_id: str, model_id: int) -> SingleViewUVScan:
    """Load a precomputed sphere scan from disk.

    Raises FileNotFoundError when the file is missing (explicit check, matching
    the stanford reader, instead of whatever the h5 layer raises).
    """
    file = config.DATA_PATH / object_set_id / "sphere_scan_clouds" / f"{model_id}_normalized.h5"
    if not file.is_file(): raise FileNotFoundError(str(file))
    return SingleViewUVScan.from_h5_file(file)
def list_mesh_sphere_scan_h5_fnames(models: Iterable[ModelUid]) -> list[str]:
    """Expected HDF5 path for each model's sphere scan.

    NOTE: despite the annotation, elements are pathlib.Path objects.
    """
    paths = []
    for object_set_id, model_id in models:
        paths.append(config.DATA_PATH / object_set_id / "sphere_scan_clouds" / f"{model_id}_normalized.h5")
    return paths

@ -0,0 +1,76 @@
from ..config import data_path_get, data_path_persist
from collections import namedtuple
import os
# Data source:
# http://graphics.stanford.edu/data/3Dscanrep/

# Public API for `from ... import *`.
# Fixed: this was misspelled `__ALL__`, which Python ignores entirely.
__all__ = ["config", "Model", "MODELS"]
@(lambda x: x()) # singleton: the class object is replaced by its sole instance
class config:
    # property defined on the class; resolved through the singleton instance's type
    DATA_PATH = property(
        doc = """
        Path to the dataset. The following envvars override it:
            ${IFIELD_DATA_MODELS}/stanford
            ${IFIELD_DATA_MODELS_STANFORD}
        """,
        fget = lambda self: data_path_get ("stanford"),
        fset = lambda self, path: data_path_persist("stanford", path),
    )
    @property
    def IS_DOWNLOADED_DB(self) -> list[os.PathLike]:
        # Files recording which archives have already been downloaded.
        return [
            self.DATA_PATH / "downloaded.json",
        ]
# One downloadable model: archive url, path of the mesh inside the extracted
# archive, and a human-readable download size for listings.
Model = namedtuple("Model", "url mesh_fname download_size_str")
# All models offered by The Stanford 3D Scanning Repository.
MODELS: dict[str, Model] = {
    "bunny": Model(
        "http://graphics.stanford.edu/pub/3Dscanrep/bunny.tar.gz",
        "bunny/reconstruction/bun_zipper.ply",
        "4.89M",
    ),
    "drill_bit": Model(
        "http://graphics.stanford.edu/pub/3Dscanrep/drill.tar.gz",
        "drill/reconstruction/drill_shaft_vrip.ply",
        "555k",
    ),
    "happy_buddha": Model(
        # religious symbol
        "http://graphics.stanford.edu/pub/3Dscanrep/happy/happy_recon.tar.gz",
        "happy_recon/happy_vrip.ply",
        "14.5M",
    ),
    "dragon": Model(
        # symbol of Chinese culture
        "http://graphics.stanford.edu/pub/3Dscanrep/dragon/dragon_recon.tar.gz",
        "dragon_recon/dragon_vrip.ply",
        "11.2M",
    ),
    "armadillo": Model(
        "http://graphics.stanford.edu/pub/3Dscanrep/armadillo/Armadillo.ply.gz",
        "armadillo.ply.gz",
        "3.87M",
    ),
    "lucy": Model(
        # Christian angel
        "http://graphics.stanford.edu/data/3Dscanrep/lucy.tar.gz",
        "lucy.ply",
        "322M",
    ),
    "asian_dragon": Model(
        # symbol of Chinese culture
        "http://graphics.stanford.edu/data/3Dscanrep/xyzrgb/xyzrgb_dragon.ply.gz",
        "xyzrgb_dragon.ply.gz",
        "70.5M",
    ),
    "thai_statue": Model(
        # Hindu religious significance
        "http://graphics.stanford.edu/data/3Dscanrep/xyzrgb/xyzrgb_statuette.ply.gz",
        "xyzrgb_statuette.ply.gz",
        "106M",
    ),
}

@ -0,0 +1,129 @@
#!/usr/bin/env python3
from . import config
from ...utils.helpers import make_relative
from ..common import download
from pathlib import Path
from textwrap import dedent
from typing import Iterable
import argparse
import io
import tarfile
def is_downloaded(*args, **kwargs):
    """Check (or record, with add=True) download state against this dataset's DB files."""
    return download.is_downloaded(*args, dbfiles=config.IS_DOWNLOADED_DB, **kwargs)
def download_and_extract(target_dir: Path, url_list: Iterable[str], *, force=False, silent=False) -> bool:
    """Download each url and extract/store it under `target_dir`.

    Already-downloaded urls are skipped unless `force`. Unreachable urls are
    reported and skipped. Returns True if anything was actually downloaded.
    Raises NotImplementedError for unrecognized archive extensions.
    """
    target_dir.mkdir(parents=True, exist_ok=True)
    ret = False
    for url in url_list:
        if not force:
            if is_downloaded(target_dir, url): continue
        if not download.check_url(url):
            print("ERROR:", url)
            continue
        ret = True
        data = download.download_data(url, silent=silent, label=str(Path(url).name))
        print("extracting...")
        if url.endswith(".ply.gz"):
            # store the gzipped mesh as-is (read_mesh decompresses on load)
            fname = target_dir / "meshes" / url.split("/")[-1].lower()
            fname.parent.mkdir(parents=True, exist_ok=True)
            with fname.open("wb") as f:
                f.write(data)
        elif url.endswith(".tar.gz"):
            with tarfile.open(fileobj=io.BytesIO(data)) as tar:
                for member in tar.getmembers():
                    # extract regular files only, and reject unsafe member paths
                    if not member.isfile(): continue
                    if member.name.startswith("/"): continue
                    if member.name.startswith("."): continue
                    # Fixed: also reject ".." path components, which would
                    # otherwise allow extraction outside target_dir.
                    if ".." in Path(member.name).parts: continue
                    if Path(member.name).name.startswith("."): continue
                    tar.extract(member, target_dir / "meshes")
            del tar
        else:
            raise NotImplementedError(f"Extraction for {str(Path(url).name)} unknown")
        is_downloaded(target_dir, url, add=True) # record success in the db
        del data # archives can be hundreds of MB; free eagerly
    return ret
def make_parser() -> argparse.ArgumentParser:
    """Build the argument parser for the stanford download CLI."""
    parser = argparse.ArgumentParser(description=dedent("""
        Download The Stanford 3D Scanning Repository models.
        More info: http://graphics.stanford.edu/data/3Dscanrep/

        Example:
            download-stanford bunny
    """), formatter_class=argparse.RawTextHelpFormatter)
    arg = parser.add_argument
    arg("objects", nargs="*", default=[],
        help="Which objects to download, defaults to none.")
    arg("--all", action="store_true",
        help="Download all objects")
    arg("--dir", default=str(config.DATA_PATH),
        help=f"The target directory. Default is {make_relative(config.DATA_PATH, Path.cwd()).__str__()!r}")
    arg("--list", action="store_true",
        help="Lists all the objects")
    arg("--list-urls", action="store_true",
        help="Lists the urls to download")
    arg("--list-sizes", action="store_true",
        help="Lists the download size of each model")
    arg("--silent", action="store_true",
        help="")
    arg("--force", action="store_true",
        help="Download again even if already downloaded")
    return parser
# entrypoint
def cli(parser=make_parser()):
args = parser.parse_args()
obj_names = sorted(set(args.objects))
if args.all:
assert not obj_names
obj_names = sorted(config.MODELS.keys())
if not obj_names and args.list_urls: config.MODELS.keys()
if args.list:
print(*config.MODELS.keys(), sep="\n")
exit()
if args.list_sizes:
print(*(f"{obj_name:<15}{config.MODELS[obj_name].download_size_str}" for obj_name in (obj_names or config.MODELS.keys())), sep="\n")
exit()
try:
url_list = [config.MODELS[obj_name].url for obj_name in obj_names]
except KeyError:
print("Error: unrecognized object name:", *set(obj_names).difference(config.MODELS.keys()), sep="\n")
exit(1)
if not url_list:
print("Error: No object set was selected for download!")
exit(1)
if args.list_urls:
print(*url_list, sep="\n")
exit()
print("Download start")
any_downloaded = download_and_extract(
target_dir = Path(args.dir),
url_list = url_list,
force = args.force,
silent = args.silent,
)
if not any_downloaded:
print("Everything has already been downloaded, skipping.")
if __name__ == "__main__":
cli()

@ -0,0 +1,118 @@
#!/usr/bin/env python3
import os; os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
from . import config, read
from ...utils.helpers import make_relative
from pathlib import Path
from textwrap import dedent
import argparse
def make_parser() -> argparse.ArgumentParser:
    """Build the argument parser for the stanford preprocessing CLI."""
    parser = argparse.ArgumentParser(description=dedent("""
        Preprocess the Stanford models. Depends on `download-stanford` having been run.
    """), formatter_class=argparse.RawTextHelpFormatter)
    arg = parser.add_argument # brevity
    arg("objects", nargs="*", default=[],
        help="Which objects to process, defaults to all downloaded")
    arg("--dir", default=str(config.DATA_PATH),
        help=f"The target directory. Default is {make_relative(config.DATA_PATH, Path.cwd()).__str__()!r}")
    arg("--force", action="store_true",
        help="Overwrite existing files")
    arg("--list", action="store_true",
        help="List the downloaded models available for preprocessing")
    arg("--list-pages", type=int, default=None,
        help="List the downloaded models available for preprocessing, paginated into N pages.")
    arg("--page", nargs=2, type=int, default=[0, 1],
        help="Subset of parts to compute. Use to parallelize. (page, total), page is 0 indexed")
    arg2 = parser.add_argument_group("preprocessing targets").add_argument # brevity
    arg2("--precompute-mesh-sv-scan-clouds", action="store_true",
        help="Compute single-view hit+miss point clouds from 100 synthetic scans.")
    arg2("--precompute-mesh-sv-scan-uvs", action="store_true",
        help="Compute single-view hit+miss UV clouds from 100 synthetic scans.")
    arg2("--precompute-mesh-sphere-scan", action="store_true",
        help="Compute a sphere-view hit+miss cloud cast from n to n unit sphere points.")
    arg3 = parser.add_argument_group("ray-scan modifiers").add_argument # brevity
    arg3("--n-sphere-points", type=int, default=4000,
        help="The number of unit-sphere points to sample rays from. Final result: n*(n-1).")
    arg3("--compute-miss-distances", action="store_true",
        help="Compute the distance to the nearest hit for each miss in the hit+miss clouds.")
    arg3("--fill-missing-uv-points", action="store_true",
        help="TODO")
    arg3("--no-filter-backhits", action="store_true",
        help="Do not filter scan hits on backside of mesh faces.")
    arg3("--no-unit-sphere", action="store_true",
        help="Do not center the objects to the unit sphere.")
    arg3("--convert-ok", action="store_true",
        help="Allow reusing point clouds for uv clouds and vice versa. (does not account for other hparams)")
    arg3("--debug", action="store_true",
        help="Abort on failure.") # fixed user-facing typo: was "failiure"
    arg5 = parser.add_argument_group("Shared modifiers").add_argument # brevity
    arg5("--scan-resolution", type=int, default=400,
        help="The resolution of the depth map rendered to sample points. Becomes x*x")
    return parser
# entrypoint
def cli(parser: argparse.ArgumentParser = make_parser()):
args = parser.parse_args()
if not any(getattr(args, k) for k in dir(args) if k.startswith("precompute_")) and not (args.list or args.list_pages):
parser.error("no preprocessing target selected") # exits
config.DATA_PATH = Path(args.dir)
obj_names = args.objects or read.list_object_names()
if args.list:
print(*obj_names, sep="\n")
parser.exit()
if args.list_pages is not None:
print(*(
f"--page {i} {args.list_pages} {obj_name}"
for obj_name in obj_names
for i in range(args.list_pages)
), sep="\n")
parser.exit()
if args.precompute_mesh_sv_scan_clouds:
read.precompute_mesh_scan_point_clouds(
obj_names,
compute_miss_distances = args.compute_miss_distances,
no_filter_backhits = args.no_filter_backhits,
no_unit_sphere = args.no_unit_sphere,
convert_ok = args.convert_ok,
page = args.page,
force = args.force,
debug = args.debug,
)
if args.precompute_mesh_sv_scan_uvs:
read.precompute_mesh_scan_uvs(
obj_names,
compute_miss_distances = args.compute_miss_distances,
fill_missing_points = args.fill_missing_uv_points,
no_filter_backhits = args.no_filter_backhits,
no_unit_sphere = args.no_unit_sphere,
convert_ok = args.convert_ok,
page = args.page,
force = args.force,
debug = args.debug,
)
if args.precompute_mesh_sphere_scan:
read.precompute_mesh_sphere_scan(
obj_names,
sphere_points = args.n_sphere_points,
compute_miss_distances = args.compute_miss_distances,
no_filter_backhits = args.no_filter_backhits,
no_unit_sphere = args.no_unit_sphere,
page = args.page,
force = args.force,
debug = args.debug,
)
if __name__ == "__main__":
cli()

@ -0,0 +1,251 @@
from . import config
from ..common import points
from ..common import processing
from ..common.scan import SingleViewScan, SingleViewUVScan
from ..common.types import MalformedMesh
from functools import lru_cache, wraps
from typing import Optional, Iterable
from pathlib import Path
import gzip
import numpy as np
import trimesh
import trimesh.transformations as T
__doc__ = """
Here are functions for reading and preprocessing stanford benchmark data

There are essentially a few sets per object:

    "img"        - meaning the RGBD images (none found in stanford)
    "mesh_scans" - meaning synthetic scans of a mesh
"""
# Fixed: the docstring previously said "shapenet", but this is the stanford module.
# Rotate meshes so +z is "up" when rendered in pyrender.
MESH_TRANSFORM_SKYWARD = T.rotation_matrix(np.pi/2, (1, 0, 0))
# Full per-object transform (replaces, not composes with, the default) giving
# every model a shared canonical orientation.
MESH_TRANSFORM_CANONICAL = { # to gain a shared canonical orientation
    "armadillo" : T.rotation_matrix(np.pi, (0, 0, 1)) @ MESH_TRANSFORM_SKYWARD,
    "asian_dragon" : T.rotation_matrix(-np.pi/2, (0, 0, 1)) @ MESH_TRANSFORM_SKYWARD,
    "bunny" : MESH_TRANSFORM_SKYWARD,
    "dragon" : MESH_TRANSFORM_SKYWARD,
    "drill_bit" : MESH_TRANSFORM_SKYWARD,
    "happy_buddha" : MESH_TRANSFORM_SKYWARD,
    "lucy" : T.rotation_matrix(np.pi, (0, 0, 1)),
    "thai_statue" : MESH_TRANSFORM_SKYWARD,
}
def list_object_names() -> list[str]:
    """Names of the models whose mesh file exists on disk (downloaded only),
    in MODELS declaration order."""
    downloaded = []
    for name, model in config.MODELS.items():
        if (config.DATA_PATH / "meshes" / model.mesh_fname).is_file():
            downloaded.append(name)
    return downloaded
@lru_cache(maxsize=1)
def list_mesh_scan_sphere_coords(n_poses: int = 50) -> list[tuple[float, float]]: # (theta, phi)
    # Evenly distributed camera poses on the unit sphere, as spherical coordinates.
    return points.generate_equidistant_sphere_points(n_poses, compute_sphere_coordinates=True)#, shift_theta=True
def mesh_scan_identifier(*, phi: float, theta: float) -> str:
    """Filename-safe identifier for a (theta, phi) camera pose: sign as
    'n'/'p', two decimals, '.' replaced with 'd', theta before phi."""
    parts = []
    for value in (theta, phi):
        parts.append("n" if value < 0 else "p")
        parts.append(f"{abs(value):.2f}".replace(".", "d"))
    return "".join(parts)
@lru_cache(maxsize=1)
def list_mesh_scan_identifiers(n_poses: int = 50) -> list[str]:
    """Pose identifier string for each of the n_poses camera coordinates,
    asserting uniqueness (the .2f rounding could in theory collide)."""
    identifiers = []
    for theta, phi in list_mesh_scan_sphere_coords(n_poses):
        identifiers.append(mesh_scan_identifier(phi=phi, theta=theta))
    assert len(identifiers) == len(set(identifiers))
    return identifiers
# ===
@lru_cache(maxsize=1)
def read_mesh(obj_name: str) -> trimesh.Trimesh:
    """Load a downloaded stanford mesh (decompressing .gz in-process) and
    apply its canonical orientation.

    NOTE(review): maxsize=1 caches only the most recent mesh, and the cached
    object is returned by reference — callers mutating it would affect
    subsequent reads of the same name; confirm callers treat it as read-only.
    """
    path = config.DATA_PATH / "meshes" / config.MODELS[obj_name].mesh_fname
    if not path.exists():
        raise FileNotFoundError(f"{obj_name = } -> {str(path) = }")
    try:
        if path.suffixes[-1] == ".gz":
            # decompress and tell trimesh the real file type (e.g. "ply")
            with gzip.open(path, "r") as f:
                mesh = trimesh.load(f, file_type="".join(path.suffixes[:-1])[1:])
        else:
            mesh = trimesh.load(path)
    except Exception as e:
        raise MalformedMesh(f"Trimesh raised: {e.__class__.__name__}: {e}") from e
    # rotate to be upright in pyrender
    mesh.apply_transform(MESH_TRANSFORM_CANONICAL.get(obj_name, MESH_TRANSFORM_SKYWARD))
    return mesh
# === single-view scan clouds
def compute_mesh_scan_point_cloud(
    obj_name : str,
    *,
    phi : float,
    theta : float,
    compute_miss_distances : bool = False,
    compute_normals : bool = True,
    convert_ok : bool = False, # this does not respect the other hparams
    **kw,
) -> SingleViewScan:
    """Synthetically scan one mesh from a single (theta, phi) viewpoint.

    convert_ok: if True, first try to derive the scan from an existing UV scan
        on disk.
    kw: forwarded to SingleViewScan.from_mesh_single_view.
    NOTE(review): unlike the UV variant below (and the coseg twin),
    compute_miss_distances is delegated to from_mesh_single_view here —
    confirm the asymmetry is intended.
    """
    if convert_ok:
        try:
            return read_mesh_scan_uv(obj_name, phi=phi, theta=theta).to_scan()
        except FileNotFoundError:
            pass # no UV scan on disk; fall through to a fresh scan
    mesh = read_mesh(obj_name)
    return SingleViewScan.from_mesh_single_view(mesh,
        phi = phi,
        theta = theta,
        compute_normals = compute_normals,
        compute_miss_distances = compute_miss_distances,
        **kw,
    )
def precompute_mesh_scan_point_clouds(obj_names, *, page: tuple[int, int] = (0, 1), force: bool = False, debug: bool = False, n_poses: int = 50, **kw):
    """Precompute all single-view scan clouds and store them as HDF5 datasets.

    obj_names: object names to process
    n_poses:   number of camera poses on the unit sphere per object
    page:      (page_index, total_pages) slice for parallelization
    kw:        forwarded to compute_mesh_scan_point_cloud
    """
    # Materialize: `obj_names` is iterated more than once below, so a generator
    # argument would otherwise be silently exhausted after the first pass.
    obj_names = list(obj_names)
    cam_poses = list_mesh_scan_sphere_coords(n_poses)
    pose_identifiers = list_mesh_scan_identifiers(n_poses)
    assert len(cam_poses) == len(pose_identifiers)
    paths = list_mesh_scan_point_cloud_h5_fnames(obj_names, pose_identifiers)
    mlen = max(map(len, config.MODELS.keys())) # column width for aligned labels
    # NOTE: computer() parses these strings back with split("@"), so object
    # names must not contain "@".
    pretty_identifiers = [
        f"{obj_name.ljust(mlen)} @ {i:>5} @ ({identifier}: {theta:.2f}, {phi:.2f})"
        for obj_name in obj_names
        for i, (identifier, (theta, phi)) in enumerate(zip(pose_identifiers, cam_poses))
    ]
    mesh_cache = [] # shared scratch so consecutive poses of one object reuse the loaded mesh
    @wraps(compute_mesh_scan_point_cloud)
    def computer(pretty_identifier: str) -> SingleViewScan:
        obj_name, index, _ = map(str.strip, pretty_identifier.split("@"))
        theta, phi = cam_poses[int(index)]
        return compute_mesh_scan_point_cloud(obj_name, phi=phi, theta=theta, _mesh_cache=mesh_cache, **kw)
    return processing.precompute_data(computer, pretty_identifiers, paths, page=page, force=force, debug=debug)
def read_mesh_scan_point_cloud(obj_name, *, identifier: str = None, phi: float = None, theta: float = None) -> SingleViewScan:
    """Load a precomputed single-view scan cloud; provide either `identifier`
    or both `phi` and `theta`. Raises FileNotFoundError when missing."""
    if identifier is None:
        if phi is None or theta is None:
            raise ValueError("Provide either phi+theta or an identifier!")
        identifier = mesh_scan_identifier(phi=phi, theta=theta)
    h5_path = config.DATA_PATH / "clouds" / obj_name / f"mesh_scan_{identifier}_clouds.h5"
    if not h5_path.exists():
        raise FileNotFoundError(str(h5_path))
    return SingleViewScan.from_h5_file(h5_path)
def list_mesh_scan_point_cloud_h5_fnames(obj_names: Iterable[str], identifiers: Optional[Iterable[str]] = None, **kw) -> list[Path]:
    """Expected HDF5 path for every (object, pose-identifier) combination.

    kw is forwarded to list_mesh_scan_identifiers when identifiers is None.
    """
    if identifiers is None:
        identifiers = list_mesh_scan_identifiers(**kw)
    paths = []
    for obj_name in obj_names:
        for identifier in identifiers:
            paths.append(config.DATA_PATH / "clouds" / obj_name / f"mesh_scan_{identifier}_clouds.h5")
    return paths
# === single-view UV scan clouds
def compute_mesh_scan_uv(
    obj_name : str,
    *,
    phi : float,
    theta : float,
    compute_miss_distances : bool = False,
    fill_missing_points : bool = False,
    compute_normals : bool = True,
    convert_ok : bool = False,
    **kw,
) -> SingleViewUVScan:
    """Synthetically scan one mesh from a single (theta, phi) viewpoint as a UV cloud.

    convert_ok: if True, first try to derive the UV scan from an existing point
        cloud scan on disk (does not account for the other hyper-parameters).
    kw: forwarded to SingleViewUVScan.from_mesh_single_view.
    """
    if convert_ok:
        try:
            return read_mesh_scan_point_cloud(obj_name, phi=phi, theta=theta).to_uv_scan()
        except FileNotFoundError:
            pass # no point cloud on disk; fall through to a fresh scan
    mesh = read_mesh(obj_name)
    scan = SingleViewUVScan.from_mesh_single_view(mesh,
        phi = phi,
        theta = theta,
        compute_normals = compute_normals,
        **kw,
    )
    # post-processing applies only to freshly computed scans, not converted ones
    if compute_miss_distances:
        scan.compute_miss_distances()
    if fill_missing_points:
        scan.fill_missing_points()
    return scan
def precompute_mesh_scan_uvs(obj_names, *, page: tuple[int, int] = (0, 1), force: bool = False, debug: bool = False, n_poses: int = 50, **kw):
    """Precompute all single-view UV scan clouds and store them as HDF5 datasets.

    obj_names: object names to process
    n_poses:   number of camera poses on the unit sphere per object
    page:      (page_index, total_pages) slice for parallelization
    kw:        forwarded to compute_mesh_scan_uv
    """
    # Materialize: `obj_names` is iterated more than once below, so a generator
    # argument would otherwise be silently exhausted after the first pass.
    obj_names = list(obj_names)
    cam_poses = list_mesh_scan_sphere_coords(n_poses)
    pose_identifiers = list_mesh_scan_identifiers(n_poses)
    assert len(cam_poses) == len(pose_identifiers)
    paths = list_mesh_scan_uv_h5_fnames(obj_names, pose_identifiers)
    mlen = max(map(len, config.MODELS.keys())) # column width for aligned labels
    # NOTE: computer() parses these strings back with split("@"), so object
    # names must not contain "@".
    pretty_identifiers = [
        f"{obj_name.ljust(mlen)} @ {i:>5} @ ({identifier}: {theta:.2f}, {phi:.2f})"
        for obj_name in obj_names
        for i, (identifier, (theta, phi)) in enumerate(zip(pose_identifiers, cam_poses))
    ]
    mesh_cache = [] # shared scratch so consecutive poses of one object reuse the loaded mesh
    @wraps(compute_mesh_scan_uv)
    def computer(pretty_identifier: str) -> SingleViewScan:
        obj_name, index, _ = map(str.strip, pretty_identifier.split("@"))
        theta, phi = cam_poses[int(index)]
        return compute_mesh_scan_uv(obj_name, phi=phi, theta=theta, _mesh_cache=mesh_cache, **kw)
    return processing.precompute_data(computer, pretty_identifiers, paths, page=page, force=force, debug=debug)
def read_mesh_scan_uv(obj_name, *, identifier: str = None, phi: float = None, theta: float = None) -> SingleViewUVScan:
    """Load a precomputed single-view UV scan; provide either `identifier` or
    both `phi` and `theta`. Raises FileNotFoundError when missing."""
    if identifier is None:
        if phi is None or theta is None:
            raise ValueError("Provide either phi+theta or an identifier!")
        identifier = mesh_scan_identifier(phi=phi, theta=theta)
    h5_path = config.DATA_PATH / "clouds" / obj_name / f"mesh_scan_{identifier}_uv.h5"
    if not h5_path.exists():
        raise FileNotFoundError(str(h5_path))
    return SingleViewUVScan.from_h5_file(h5_path)
def list_mesh_scan_uv_h5_fnames(obj_names: Iterable[str], identifiers: Optional[Iterable[str]] = None, **kw) -> list[Path]:
    """Expected HDF5 path for every (object, pose-identifier) UV-scan combination.

    kw is forwarded to list_mesh_scan_identifiers when identifiers is None.
    """
    if identifiers is None:
        identifiers = list_mesh_scan_identifiers(**kw)
    paths = []
    for obj_name in obj_names:
        for identifier in identifiers:
            paths.append(config.DATA_PATH / "clouds" / obj_name / f"mesh_scan_{identifier}_uv.h5")
    return paths
# === sphere-view (UV) scan clouds
def compute_mesh_sphere_scan(
    obj_name : str,
    *,
    compute_normals : bool = True,
    **kw,
) -> SingleViewUVScan:
    """Scan one mesh with rays cast between unit-sphere points.

    kw is forwarded to SingleViewUVScan.from_mesh_sphere_view.
    """
    loaded_mesh = read_mesh(obj_name)
    return SingleViewUVScan.from_mesh_sphere_view(
        loaded_mesh,
        compute_normals = compute_normals,
        **kw,
    )
def precompute_mesh_sphere_scan(obj_names, *, page: tuple[int, int] = (0, 1), force: bool = False, debug: bool = False, n_points: int = 4000, **kw):
    """Precompute all sphere scan clouds and store them as HDF5 datasets.

    obj_names: object names to process
    page:      (page_index, total_pages) slice for parallelization
    kw:        forwarded to compute_mesh_sphere_scan
    """
    # Materialize: `obj_names` is iterated twice (paths + identifiers), so a
    # generator argument would otherwise be silently exhausted.
    obj_names = list(obj_names)
    paths = list_mesh_sphere_scan_h5_fnames(obj_names)
    # NOTE(review): n_points is accepted but never forwarded; the CLI passes the
    # sphere point count through **kw as `sphere_points` instead — confirm intent.
    @wraps(compute_mesh_sphere_scan)
    def computer(obj_name: str) -> SingleViewScan:
        return compute_mesh_sphere_scan(obj_name, **kw)
    return processing.precompute_data(computer, obj_names, paths, page=page, force=force, debug=debug)
def read_mesh_mesh_sphere_scan(obj_name) -> SingleViewUVScan:
    """Load a precomputed sphere scan; raises FileNotFoundError when missing."""
    h5_path = config.DATA_PATH / "clouds" / obj_name / "mesh_sphere_scan.h5"
    if not h5_path.exists():
        raise FileNotFoundError(str(h5_path))
    return SingleViewUVScan.from_h5_file(h5_path)
def list_mesh_sphere_scan_h5_fnames(obj_names: Iterable[str]) -> list[Path]:
    """List the sphere-scan HDF5 path for each object name."""
    clouds_dir = config.DATA_PATH / "clouds"
    return [clouds_dir / name / "mesh_sphere_scan.h5" for name in obj_names]

@ -0,0 +1,3 @@
__doc__ = """
Submodules defining various `torch.utils.data.Dataset`
"""

196
ifield/datasets/common.py Normal file

@ -0,0 +1,196 @@
from ..data.common.h5_dataclasses import H5Dataclass, PathLike
from torch.utils.data import Dataset, IterableDataset
from typing import Any, Iterable, Hashable, TypeVar, Iterator, Callable
from functools import partial, lru_cache
import inspect
T = TypeVar("T")
T_H5 = TypeVar("T_H5", bound=H5Dataclass)
class TransformableDatasetMixin:
    """Mixin that lets datasets accumulate item transforms via `.map(...)`.

    On subclassing, the subclass' `__getitem__` (map-style `Dataset`) or
    `__iter__` (`IterableDataset`) is wrapped so that every registered
    transform is applied to each produced item.
    """
    def __init_subclass__(cls):
        # Hook the accessor once per subclass; the `is not` guard prevents
        # re-wrapping when a subclass merely inherits the wrapper.
        if getattr(cls, "_transformable_mixin_no_override_getitem", False):
            pass  # subclass opted out; it applies self._transforms itself
        elif issubclass(cls, Dataset):
            if cls.__getitem__ is not cls._transformable_mixin_getitem_wrapper:
                cls._transformable_mixin_inner_getitem = cls.__getitem__
                cls.__getitem__ = cls._transformable_mixin_getitem_wrapper
        elif issubclass(cls, IterableDataset):
            if cls.__iter__ is not cls._transformable_mixin_iter_wrapper:
                cls._transformable_mixin_inner_iter = cls.__iter__
                cls.__iter__ = cls._transformable_mixin_iter_wrapper
        else:
            raise TypeError(f"{cls.__name__!r} is neither a Dataset nor a IterableDataset!")
    def __init__(self, *a, **kw):
        super().__init__(*a, **kw)
        # transforms are applied in registration order
        self._transforms = []
    # works as a decorator
    def map(self: T, func: callable = None, /, args=[], **kw) -> T:
        """Register a transform and return self (chainable).

        Usable directly (`ds.map(f)`) or as a decorator (`@ds.map`).
        NOTE(review): `args`/`**kw` are bound with functools.partial, so
        `args` are passed *before* the item — confirm callers expect that.
        """
        def wrapper(func) -> T:
            if args or kw:
                func = partial(func, *args, **kw)
            self._transforms.append(func)
            return self
        if func is None:
            return wrapper  # decorator-factory usage
        else:
            return wrapper(func)
    def _transformable_mixin_getitem_wrapper(self, index: int):
        # Both branches fetch identically; the parenthesised comments
        # presumably distinguish call sites in tracebacks — confirm intent.
        if not self._transforms:
            out = self._transformable_mixin_inner_getitem(index) # (TransformableDatasetMixin, no transforms)
        else:
            out = self._transformable_mixin_inner_getitem(index) # (TransformableDatasetMixin, has transforms)
            for f in self._transforms:
                out = f(out) # (TransformableDatasetMixin)
        return out
    def _transformable_mixin_iter_wrapper(self):
        if not self._transforms:
            out = self._transformable_mixin_inner_iter() # (TransformableDatasetMixin, no transforms)
        else:
            out = self._transformable_mixin_inner_iter() # (TransformableDatasetMixin, has transforms)
            for f in self._transforms:
                # transforms are applied lazily to the iterator stream
                out = map(f, out) # (TransformableDatasetMixin)
        return out
class TransformedDataset(Dataset, TransformableDatasetMixin):
    """Wraps another dataset and applies a fixed sequence of transforms."""
    # used to wrap an another dataset
    def __init__(self, dataset: Dataset, transforms: Iterable[callable]):
        super().__init__()
        self.dataset = dataset
        for transform in transforms:
            self.map(transform)
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, index: int):
        # Raw item only; the mixin wrapper applies the registered transforms.
        return self.dataset[index]
class TransformExtendedDataset(Dataset, TransformableDatasetMixin):
    """Extends a dataset so each item appears once per registered transform.

    Consecutive indices cycle through all transforms of one underlying item
    before advancing to the next item.
    """
    _transformable_mixin_no_override_getitem = True  # transforms applied manually
    def __init__(self, dataset: Dataset):
        super().__init__()
        self.dataset = dataset
    def __len__(self):
        return len(self.dataset) * len(self._transforms)
    def __getitem__(self, index: int):
        n_transforms = len(self._transforms)
        assert n_transforms > 0, f"{len(self._transforms) = }"
        inner_index, transform_index = divmod(index, n_transforms)
        return self._transforms[transform_index](self.dataset[inner_index])
class CachedDataset(Dataset):
# used to wrap an another dataset
def __init__(self, dataset: Dataset, cache_size: int | None):
super().__init__()
self.dataset = dataset
if cache_size is not None and cache_size > 0:
self.cached_getter = lru_cache(cache_size, self.dataset.__getitem__)
else:
self.cached_getter = self.dataset.__getitem__
def __len__(self):
return len(self.dataset)
def __getitem__(self, index: int):
return self.cached_getter(index)
class AutodecoderDataset(Dataset, TransformableDatasetMixin):
    """Pairs each dataset item with a hashable auto-decoder key (uid)."""
    def __init__(self,
            keys    : Iterable[Hashable],
            dataset : Dataset,
            ):
        super().__init__()
        self.ad_mapping = list(keys)
        self.dataset = dataset
        if len(self.ad_mapping) != len(dataset):
            raise ValueError(f"__len__ mismatch between keys and dataset: {len(self.ad_mapping)} != {len(dataset)}")
    def __len__(self) -> int:
        return len(self.dataset)
    def __getitem__(self, index: int) -> tuple[Hashable, Any]:
        # Raw (key, item) pair; the mixin wrapper applies any transforms.
        key = self.ad_mapping[index]
        return key, self.dataset[index]
    def keys(self) -> list[Hashable]:
        """Every item's key, in index order."""
        return self.ad_mapping
    def values(self) -> Iterator:
        """The underlying items, in index order."""
        return iter(self.dataset)
    def items(self) -> Iterable[tuple[Hashable, Any]]:
        """(key, item) pairs, in index order."""
        return zip(self.ad_mapping, self.dataset)
class FunctionDataset(Dataset, TransformableDatasetMixin):
    """Dataset backed by a getter function applied to a fixed list of keys."""
    def __init__(self,
            getter     : Callable[[Hashable], T],
            keys       : list[Hashable],
            cache_size : int | None = None,
            ):
        super().__init__()
        # A positive cache_size memoizes the getter; None / <= 0 disables it.
        if cache_size is not None and cache_size > 0:
            getter = lru_cache(cache_size)(getter)
        self.getter = getter
        self.keys = keys
    def __len__(self) -> int:
        return len(self.keys)
    def __getitem__(self, index: int) -> T:
        key = self.keys[index]
        return self.getter(key)
class H5Dataset(FunctionDataset):
    """FunctionDataset that loads H5Dataclass instances from HDF5 files."""
    def __init__(self,
            h5_dataclass_cls : type[T_H5],
            fnames           : list[PathLike],
            **kw,
            ):
        # Filenames act as keys; loading is delegated to the dataclass reader.
        super().__init__(
            keys   = fnames,
            getter = h5_dataclass_cls.from_h5_file,
            **kw,
        )
class PaginatedH5Dataset(Dataset, TransformableDatasetMixin):
    """Splits each HDF5 file into `n_pages` independently loadable items.

    Consecutive indices cycle through the pages of one file before moving
    on to the next file.
    """
    def __init__(self,
            h5_dataclass_cls   : type[T_H5],
            fnames             : list[PathLike],
            n_pages            : int = 10,
            require_even_pages : bool = True,
            ):
        super().__init__()
        self.h5_dataclass_cls = h5_dataclass_cls
        self.fnames = fnames
        self.n_pages = n_pages
        self.require_even_pages = require_even_pages
    def __len__(self) -> int:
        return len(self.fnames) * self.n_pages
    def __getitem__(self, index: int) -> T_H5:
        item = index // self.n_pages
        page = index % self.n_pages
        # BUG FIX: the attribute is `fnames`; `self.fname` raised
        # AttributeError on every access.
        return self.h5_dataclass_cls.from_h5_file( # (PaginatedH5Dataset)
            fname              = self.fnames[item],
            page               = page,
            n_pages            = self.n_pages,
            require_even_pages = self.require_even_pages,
        )

40
ifield/datasets/coseg.py Normal file

@ -0,0 +1,40 @@
from . import common
from ..data.coseg import config
from ..data.coseg import read
from ..data.common import scan
from typing import Iterable, Optional, Union
import os
class SingleViewUVScanDataset(common.H5Dataset):
    """UV scans of every model in the given COSEG object sets."""
    def __init__(self,
            object_sets : tuple[str],
            identifiers : Optional[Iterable[str]] = None,
            data_path   : Union[str, os.PathLike, None] = None,
            ):
        if not object_sets:
            raise ValueError("'object_sets' cannot be empty!")
        if identifiers is None:
            identifiers = read.list_mesh_scan_identifiers()
        if data_path is not None:
            # NOTE: mutates the module-level config for all later readers
            config.DATA_PATH = data_path
        model_ids = read.list_model_ids(object_sets)
        scan_fnames = read.list_mesh_scan_uv_h5_fnames(model_ids, identifiers)
        super().__init__(
            h5_dataclass_cls = scan.SingleViewUVScan,
            fnames           = scan_fnames,
        )
class AutodecoderSingleViewUVScanDataset(common.AutodecoderDataset):
    """(uid, scan) pairs, one key per (model, identifier) combination."""
    def __init__(self,
            object_sets : tuple[str],
            identifiers : Optional[Iterable[str]] = None,
            data_path   : Union[str, os.PathLike, None] = None,
            ):
        if identifiers is None:
            identifiers = read.list_mesh_scan_identifiers()
        # here do this step first, such that all the duplicate strings reference the same object
        uid_keys = [
            key
            for key in read.list_model_id_strings(object_sets)
            for _ in range(len(identifiers))
        ]
        super().__init__(
            keys    = uid_keys,
            dataset = SingleViewUVScanDataset(object_sets, identifiers, data_path=data_path),
        )

@ -0,0 +1,64 @@
from . import common
from ..data.stanford import config
from ..data.stanford import read
from ..data.common import scan
from typing import Iterable, Optional, Union
import os
class SingleViewUVScanDataset(common.H5Dataset):
    """UV scans of the named Stanford objects."""
    def __init__(self,
            obj_names   : Iterable[str],
            identifiers : Optional[Iterable[str]] = None,
            data_path   : Union[str, os.PathLike, None] = None,
            ):
        if not obj_names:
            raise ValueError("'obj_names' cannot be empty!")
        if identifiers is None:
            identifiers = read.list_mesh_scan_identifiers()
        if data_path is not None:
            # NOTE: mutates the module-level config for all later readers
            config.DATA_PATH = data_path
        scan_fnames = read.list_mesh_scan_uv_h5_fnames(obj_names, identifiers)
        super().__init__(
            h5_dataclass_cls = scan.SingleViewUVScan,
            fnames           = scan_fnames,
        )
class AutodecoderSingleViewUVScanDataset(common.AutodecoderDataset):
    """(obj_name, scan) pairs, one key per (object, identifier) combination."""
    def __init__(self,
            obj_names   : Iterable[str],
            identifiers : Optional[Iterable[str]] = None,
            data_path   : Union[str, os.PathLike, None] = None,
            ):
        if identifiers is None:
            identifiers = read.list_mesh_scan_identifiers()
        repeated_names = [
            obj_name
            for obj_name in obj_names
            for _ in range(len(identifiers))
        ]
        super().__init__(
            keys    = repeated_names,
            dataset = SingleViewUVScanDataset(obj_names, identifiers, data_path=data_path),
        )
class SphereScanDataset(common.H5Dataset):
    """Sphere scans of the named Stanford objects, one per object."""
    def __init__(self,
            obj_names : Iterable[str],
            data_path : Union[str, os.PathLike, None] = None,
            ):
        if not obj_names:
            raise ValueError("'obj_names' cannot be empty!")
        if data_path is not None:
            # NOTE: mutates the module-level config for all later readers
            config.DATA_PATH = data_path
        scan_fnames = read.list_mesh_sphere_scan_h5_fnames(obj_names)
        super().__init__(
            h5_dataclass_cls = scan.SingleViewUVScan,
            fnames           = scan_fnames,
        )
class AutodecoderSphereScanDataset(common.AutodecoderDataset):
    """(obj_name, sphere scan) pairs, keyed directly by object name."""
    def __init__(self,
            obj_names : Iterable[str],
            data_path : Union[str, os.PathLike, None] = None,
            ):
        inner = SphereScanDataset(obj_names, data_path=data_path)
        super().__init__(
            keys    = obj_names,
            dataset = inner,
        )

258
ifield/logging.py Normal file

@ -0,0 +1,258 @@
from . import param
from dataclasses import dataclass
from pathlib import Path
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from typing import Union, Literal, Optional, TypeVar
import concurrent.futures
import psutil
import pytorch_lightning as pl
import statistics
import threading
import time
import torch
import yaml
# from https://github.com/yaml/pyyaml/issues/240#issuecomment-1018712495
def str_presenter(dumper, data):
    """configures yaml for dumping multiline strings
    Ref: https://stackoverflow.com/questions/8640959/how-can-i-control-what-scalar-form-pyyaml-uses-for-my-data"""
    is_multiline = len(data.splitlines()) > 1
    if is_multiline:
        # literal block style ("|") keeps the line breaks readable
        return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|')
    return dumper.represent_scalar('tag:yaml.org,2002:str', data)
# Register the custom presenter for all str values dumped through pyyaml.
yaml.add_representer(str, str_presenter)
# Logger backends accepted by make_logger; only tensorboard is wired up.
LoggerStr = Literal[
    #"csv",
    "tensorboard",
    #"mlflow",
    #"comet",
    #"neptune",
    #"wandb",
    None]
# Newer pytorch_lightning exposes pl.loggers.Logger; fall back to the
# pre-rename base class on older versions.
try:
    Logger = TypeVar("L", bound=pl.loggers.Logger)
except AttributeError:
    Logger = TypeVar("L", bound=pl.loggers.base.LightningLoggerBase)
def make_logger(
        experiment_name  : str,
        default_root_dir : Union[str, Path], # from pl.Trainer
        save_dir         : Union[str, Path],
        type             : LoggerStr = "tensorboard",
        project          : str = "ifield",
        ) -> Optional[Logger]:
    """Construct the experiment logger, or return None when disabled.

    Raises ValueError for unsupported backend names.
    """
    if type is None:
        return None
    if type == "tensorboard":
        return pl.loggers.TensorBoardLogger(
            name      = "tensorboard",
            save_dir  = Path(default_root_dir) / save_dir,
            version   = experiment_name,
            log_graph = True,
        )
    raise ValueError(f"make_logger({type=})")
def make_jinja_template(*, save_dir: Union[None, str, Path], **kw) -> str:
    """Render the jinja config template for make_logger's parameters."""
    # caller-provided kw overrides the default template name
    template_kw = {"name": "logging", **kw}
    return param.make_jinja_template(make_logger,
        defaults = dict(
            save_dir = save_dir,
        ),
        exclude_list = {
            "experiment_name",
            "default_root_dir",
        },
        **template_kw,
    )
def get_checkpoints(experiment_name, default_root_dir, save_dir, type, project) -> list[Path]:
    """List checkpoint files for an experiment, or None when logging is off.

    Raises NotImplementedError for the mlflow/wandb backends and ValueError
    for unknown backend names.
    """
    if type is None:
        return None
    if type == "tensorboard":
        folder = Path(default_root_dir) / save_dir / "tensorboard" / experiment_name
        # BUG FIX: Path.glob returns a one-shot generator although the
        # function is annotated list[Path]; materialize it (sorted, for a
        # deterministic order).
        return sorted(folder.glob("*.ckpt"))
    if type == "mlflow":
        raise NotImplementedError(f"{type=}")
    if type == "wandb":
        raise NotImplementedError(f"{type=}")
    raise ValueError(f"get_checkpoint({type=})")
def log_config(_logger: Logger, **kwargs: Union[str, dict, list, int, float]):
    """Log a configuration dict as hyperparameters on the given logger."""
    # NOTE(review): the `or` short-circuit matters — the second isinstance
    # targets the pre-rename pl API and may raise AttributeError on newer
    # pytorch_lightning if actually evaluated; confirm supported versions.
    assert isinstance(_logger, pl.loggers.Logger) \
        or isinstance(_logger, pl.loggers.base.LightningLoggerBase), _logger
    _logger: pl.loggers.TensorBoardLogger
    _logger.log_hyperparams(params=kwargs)
@dataclass
class ModelOutputMonitor(pl.callbacks.Callback):
    """Logs the outputs of training/validation steps as scalar metrics."""
    log_training   : bool = True
    log_validation : bool = True
    def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: Optional[str] = None) -> None:
        if not trainer.loggers:
            # BUG FIX: `self._class__` raised AttributeError; use `__class__`.
            raise MisconfigurationException(f"Cannot use {self.__class__.__name__} callback with Trainer that has no logger.")
    @staticmethod
    def _log_outputs(trainer: pl.Trainer, pl_module: pl.LightningModule, outputs, fname: str):
        """Normalize `outputs` to a dict of floats and log it via pl_module."""
        if outputs is None:
            return
        elif isinstance(outputs, (list, tuple)):
            outputs = {
                f"loss[{i}]": v
                for i, v in enumerate(outputs)
            }
        elif isinstance(outputs, torch.Tensor):
            outputs = {
                "loss": outputs,
            }
        elif isinstance(outputs, dict):
            pass
        else:
            # BUG FIX: the bare `raise ValueError` carried no diagnostic.
            raise ValueError(f"Unsupported step output type: {type(outputs).__name__}")
        sep = trainer.logger.group_separator
        pl_module.log_dict({
            f"{pl_module.__class__.__qualname__}.{fname}{sep}{k}":
                float(v.item()) if isinstance(v, torch.Tensor) else float(v)
            for k, v in outputs.items()
        }, sync_dist=True)
    def on_train_batch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs, batch, batch_idx, unused=0):
        if self.log_training:
            self._log_outputs(trainer, pl_module, outputs, "training_step")
    def on_validation_batch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs, batch, batch_idx, dataloader_idx=0):
        if self.log_validation:
            self._log_outputs(trainer, pl_module, outputs, "validation_step")
class EpochTimeMonitor(pl.callbacks.Callback):
    """Logs the wall-clock duration of train/validation/test/predict epochs."""
    __slots__ = [
        "epoch_start",
        "epoch_start_train",
        "epoch_start_validation",
        "epoch_start_test",
        "epoch_start_predict",
    ]
    def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: Optional[str] = None) -> None:
        if not trainer.loggers:
            # BUG FIX: `self._class__` raised AttributeError; use `__class__`.
            raise MisconfigurationException(f"Cannot use {self.__class__.__name__} callback with Trainer that has no logger.")
    @rank_zero_only
    def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        self.epoch_start_train = time.time()
    @rank_zero_only
    def on_validation_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        self.epoch_start_validation = time.time()
    @rank_zero_only
    def on_test_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        self.epoch_start_test = time.time()
    @rank_zero_only
    def on_predict_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        self.epoch_start_predict = time.time()
    @rank_zero_only
    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        t = time.time() - self.epoch_start_train
        del self.epoch_start_train
        sep = trainer.logger.group_separator
        trainer.logger.log_metrics({f"{self.__class__.__qualname__}{sep}epoch_train_time" : t}, step=trainer.global_step)
    @rank_zero_only
    def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        t = time.time() - self.epoch_start_validation
        del self.epoch_start_validation
        sep = trainer.logger.group_separator
        trainer.logger.log_metrics({f"{self.__class__.__qualname__}{sep}epoch_validation_time" : t}, step=trainer.global_step)
    @rank_zero_only
    def on_test_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        t = time.time() - self.epoch_start_test
        # BUG FIX: copy-paste deleted `epoch_start_validation` here, raising
        # AttributeError outside validation and leaking `epoch_start_test`.
        del self.epoch_start_test
        sep = trainer.logger.group_separator
        trainer.logger.log_metrics({f"{self.__class__.__qualname__}{sep}epoch_test_time" : t}, step=trainer.global_step)
    @rank_zero_only
    def on_predict_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        t = time.time() - self.epoch_start_predict
        # BUG FIX: same copy-paste; delete `epoch_start_predict` instead.
        del self.epoch_start_predict
        sep = trainer.logger.group_separator
        trainer.logger.log_metrics({f"{self.__class__.__qualname__}{sep}epoch_predict_time" : t}, step=trainer.global_step)
@dataclass
class PsutilMonitor(pl.callbacks.Callback):
    """Samples process RAM, CPU, and (when used) GPU utilization in a daemon thread."""
    sample_rate : float = 0.2 # times per second
    _should_stop = False  # flipped by on_fit_end to stop the sampler loop
    @rank_zero_only
    def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        if not trainer.loggers:
            # BUG FIX: `self._class__` raised AttributeError; use `__class__`.
            raise MisconfigurationException(f"Cannot use {self.__class__.__name__} callback with Trainer that has no logger.")
        assert not hasattr(self, "_thread")
        self._should_stop = False
        self._thread = threading.Thread(
            target = self.thread_target,
            name   = self.thread_target.__qualname__,
            args   = [trainer],
            daemon = True,
        )
        self._thread.start()
    @rank_zero_only
    def on_fit_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        assert getattr(self, "_thread", None) is not None
        self._should_stop = True
        # NOTE(review): the thread is signalled but never joined; as a daemon
        # it may emit one final sample while shutting down — confirm intended.
        del self._thread
    def thread_target(self, trainer: pl.Trainer):
        """Sampling loop; runs until on_fit_end sets `_should_stop`."""
        uses_gpu = isinstance(trainer.accelerator, (pl.accelerators.GPUAccelerator, pl.accelerators.CUDAAccelerator))
        gpu_ids = trainer.device_ids
        prefix = f"{self.__class__.__qualname__}{trainer.logger.group_separator}"
        while not self._should_stop:
            step = trainer.global_step
            p = psutil.Process()
            meminfo = p.memory_info()
            rss_ram = meminfo.rss / 1024**2 # MB
            vms_ram = meminfo.vms / 1024**2 # MB
            # restrict per-cpu readings to the cores this process may run on
            util_per_cpu = psutil.cpu_percent(percpu=True)
            util_per_cpu = [util_per_cpu[i] for i in p.cpu_affinity()]
            util_total = statistics.mean(util_per_cpu)
            if uses_gpu:
                # query every GPU concurrently; the stats helper moved between
                # pytorch_lightning versions (accelerators.gpu -> .cuda)
                with concurrent.futures.ThreadPoolExecutor() as e:
                    if hasattr(pl.accelerators, "cuda"):
                        gpu_stats = e.map(pl.accelerators.cuda.get_nvidia_gpu_stats, gpu_ids)
                    else:
                        gpu_stats = e.map(pl.accelerators.gpu.get_nvidia_gpu_stats, gpu_ids)
                trainer.logger.log_metrics({
                    f"{prefix}ram.rss"   : rss_ram,
                    f"{prefix}ram.vms"   : vms_ram,
                    f"{prefix}cpu.total" : util_total,
                    **{ f"{prefix}cpu.{i:03}.utilization" : stat for i, stat in enumerate(util_per_cpu) },
                    **{
                        f"{prefix}gpu.{gpu_idx:02}.{key.split(' ',1)[0]}" : stat
                        for gpu_idx, stats in zip(gpu_ids, gpu_stats)
                        for key, stat in stats.items()
                    },
                }, step = step)
            else:
                trainer.logger.log_metrics({
                    # BUG FIX: RAM metrics were computed every iteration but
                    # only logged on the GPU branch; log them here too.
                    f"{prefix}ram.rss"   : rss_ram,
                    f"{prefix}ram.vms"   : vms_ram,
                    f"{prefix}cpu.total" : util_total,
                    **{ f"{prefix}cpu.{i:03}.utilization" : stat for i, stat in enumerate(util_per_cpu) },
                }, step = step)
            time.sleep(1 / self.sample_rate)
        print("DAEMON END")

@ -0,0 +1,3 @@
__doc__ = """
Contains Pytorch Models
"""

@ -0,0 +1,159 @@
from abc import ABC, abstractmethod
from torch import nn, Tensor
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
from typing import Hashable, Union, Optional, KeysView, ValuesView, ItemsView, Any, Sequence
import torch
class RequiresConditioner(nn.Module, ABC): # mixin
    """Interface for modules conditioned on a latent feature vector."""
    @property
    @abstractmethod
    def n_latent_features(self) -> int:
        "This should provide the width of the conditioning feature vector"
        ...
    @property
    @abstractmethod
    def latent_embeddings_init_std(self) -> float:
        "This should provide the standard deviation to initialize the latent features with. DeepSDF uses 0.01."
        ...
    @property
    @abstractmethod
    def latent_embeddings(self) -> Optional[Tensor]:
        """This property should return a tensor containing all stored embeddings, for use in computing auto-decoder losses"""
        # BUG FIX: was declared without `self`, so it could never dispatch as
        # an instance property. (Also fixed "cotnaining" typo.)
        ...
    @abstractmethod
    def encode(self, batch: Any, batch_idx: int, optimizer_idx: int) -> Tensor:
        "This should, given a training batch, return the encoded conditioning vector"
        ...
class AutoDecoderModuleMixin(RequiresConditioner, ABC):
    """
    Populates dunder methods making it behave as a mapping.
    The mapping indexes into a stored set of learnable embedding vectors.
    Based on the auto-decoder architecture of
    J.J. Park, P. Florence, J. Straub, R. Newcombe, S. Lovegrove, DeepSDF:
    Learning Continuous Signed Distance Functions for Shape Representation, in:
    2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR),
    IEEE, Long Beach, CA, USA, 2019: pp. 165174.
    https://doi.org/10.1109/CVPR.2019.00025.
    """
    _autodecoder_mapping: dict[Hashable, int]  # observation uid -> embedding row
    autodecoder_embeddings: nn.Parameter
    def __init__(self, *a, **kw):
        super().__init__(*a, **kw)
        # Migrate old checkpoints that stored the mapping under its own key
        # instead of the module "extra state" key.
        @self._register_load_state_dict_pre_hook
        def hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
            if f"{prefix}_autodecoder_mapping" in state_dict:
                state_dict[f"{prefix}{_EXTRA_STATE_KEY_SUFFIX}"] = state_dict.pop(f"{prefix}_autodecoder_mapping")
        # Lazy parameter that materializes to whatever shape the checkpoint
        # holds, bypassing pytorch's fixed-shape check. After materialize()
        # the instance is a regular Parameter, so the recursive copy_ call
        # dispatches to the real implementation instead of looping.
        class ICanBeLoadedFromCheckpointsAndChangeShapeStopBotheringMePyTorchAndSitInTheCornerIKnowWhatIAmDoing(nn.UninitializedParameter):
            def copy_(self, other):
                self.materialize(other.shape, other.device, other.dtype)
                return self.copy_(other)
        self.autodecoder_embeddings = ICanBeLoadedFromCheckpointsAndChangeShapeStopBotheringMePyTorchAndSitInTheCornerIKnowWhatIAmDoing()
    # nn.Module interface
    def get_extra_state(self):
        return {
            "ad_uids": getattr(self, "_autodecoder_mapping", {}),
        }
    def set_extra_state(self, obj):
        if "ad_uids" not in obj: # backward compat
            self._autodecoder_mapping = obj
        else:
            self._autodecoder_mapping = obj["ad_uids"]
    # RequiresConditioner interface
    @property
    def latent_embeddings(self) -> Tensor:
        return self.autodecoder_embeddings
    # my interface
    def set_observation_ids(self, z_uids: set[Hashable]):
        """Allocate one freshly initialized embedding row per unique uid."""
        assert self.latent_embeddings_init_std is not None, f"{self.__module__}.{self.__class__.__qualname__}.latent_embeddings_init_std"
        assert self.n_latent_features is not None, f"{self.__module__}.{self.__class__.__qualname__}.n_latent_features"
        assert self.latent_embeddings_init_std > 0, self.latent_embeddings_init_std
        assert self.n_latent_features > 0, self.n_latent_features
        # sorted() gives a deterministic uid -> row assignment
        self._autodecoder_mapping = {
            k: i
            for i, k in enumerate(sorted(set(z_uids)))
        }
        if not len(z_uids) == len(self._autodecoder_mapping):
            raise ValueError(f"Observation identifiers are not unique! {z_uids = }")
        self.autodecoder_embeddings = nn.Parameter(
            torch.Tensor(len(self._autodecoder_mapping), self.n_latent_features)
            .normal_(mean=0, std=self.latent_embeddings_init_std)
            .to(self.device, self.dtype)
        )
    def add_key(self, z_uid: Hashable, z: Optional[Tensor] = None):
        # unfinished: the embedding matrix is never grown to match the mapping
        if z_uid in self._autodecoder_mapping:
            raise ValueError(f"Observation identifier {z_uid!r} not unique!")
        self._autodecoder_mapping[z_uid] = len(self._autodecoder_mapping)
        self.autodecoder_embeddings
        raise NotImplementedError
    def __delitem__(self, z_uid: Hashable):
        """Remove a uid and its embedding row, re-indexing the remainder."""
        i = self._autodecoder_mapping.pop(z_uid)
        for k, v in list(self._autodecoder_mapping.items()):
            if v > i:
                self._autodecoder_mapping[k] -= 1
        with torch.no_grad():
            self.autodecoder_embeddings = nn.Parameter(torch.cat((
                self.autodecoder_embeddings.detach()[:i, :],
                self.autodecoder_embeddings.detach()[i+1:, :],
            ), dim=0))
    def __contains__(self, z_uid: Hashable) -> bool:
        return z_uid in self._autodecoder_mapping
    def __getitem__(self, z_uids: Union[Hashable, Sequence[Hashable]]) -> Tensor:
        if isinstance(z_uids, (tuple, list)):
            key = tuple(map(self._autodecoder_mapping.__getitem__, z_uids))
        else:
            key = self._autodecoder_mapping[z_uids]
        return self.autodecoder_embeddings[key, :]
    def __iter__(self):
        # BUG FIX: returning dict.keys() directly (an iterable *view*, not an
        # iterator) made `iter(self)` raise TypeError; wrap it in iter().
        return iter(self._autodecoder_mapping.keys())
    def keys(self) -> KeysView[Hashable]:
        """
        lists the identifiers of each code
        """
        return self._autodecoder_mapping.keys()
    def values(self) -> ValuesView[Tensor]:
        # NOTE(review): returns a list, not a ValuesView as annotated.
        return list(self.autodecoder_embeddings)
    def items(self) -> ItemsView[Hashable, Tensor]:
        """
        lists all the learned codes / latent vectors with their identifiers as keys
        """
        return {
            k : self.autodecoder_embeddings[i]
            for k, i in self._autodecoder_mapping.items()
        }.items()
class EncoderModuleMixin(RequiresConditioner, ABC):
    """Conditioner fed by an encoder; stores no per-observation embeddings."""
    @property
    def latent_embeddings(self) -> None:
        # No embedding table exists, so there is nothing to regularize.
        return None

@ -0,0 +1,589 @@
from .. import param
from ..modules.dtype import DtypeMixin
from ..utils import geometry
from ..utils.helpers import compose
from ..utils.loss import Schedulable, ensure_schedulables, HParamSchedule, HParamScheduleBase, Linear
from ..utils.operators import diff
from .conditioning import RequiresConditioner, AutoDecoderModuleMixin
from .medial_atoms import MedialAtomNet
from .orthogonal_plane import OrthogonalPlaneNet
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import Tensor
from torch.nn import functional as F
from typing import TypedDict, Literal, Union, Hashable, Optional
import pytorch_lightning as pl
import torch
import os
# Opt-in switch (default on) to log every metric, including diagnostics.
LOG_ALL_METRICS = bool(int(os.environ.get("IFIELD_LOG_ALL_METRICS", "1")))
if __debug__:
    def broadcast_tensors(*tensors: torch.Tensor) -> list[torch.Tensor]:
        """torch.broadcast_tensors with an error message listing each shape."""
        try:
            return torch.broadcast_tensors(*tensors)
        except RuntimeError as e:
            # label the tensors a, b, c, ... alongside their shapes
            shapes = ", ".join(f"{chr(c)}.size={tuple(t.shape)}" for c, t in enumerate(tensors, ord("a")))
            # BUG FIX: chain the original exception (`from e`) so the full
            # traceback survives the re-raise.
            raise ValueError(f"Could not broadcast tensors {shapes}.\n{str(e)}") from e
else:
    # -O mode: use the raw torch function, skipping the wrapper cost.
    broadcast_tensors = torch.broadcast_tensors
class ForwardDepthMapsBatch(TypedDict):
    """Camera-parameterized batch; rays are derived from uv + intrinsics."""
    cam2world  : Tensor # (B, 4, 4)
    uv         : Tensor # (B, H, W)
    intrinsics : Tensor # (B, 3, 3)
class ForwardScanRaysBatch(TypedDict):
    """Ray-parameterized batch with explicit origins and directions."""
    origins : Tensor # (B, H, W, 3) or (B, 3)
    dirs    : Tensor # (B, H, W, 3)
class LossBatch(TypedDict):
    """Supervision targets for the training losses."""
    hits      : Tensor # (B, H, W) dtype=bool
    miss      : Tensor # (B, H, W) dtype=bool
    depths    : Tensor # (B, H, W)
    normals   : Tensor # (B, H, W, 3) NaN if not hit
    distances : Tensor # (B, H, W, 1) NaN if not miss
class LabeledBatch(TypedDict):
    """Auto-decoder uid labels for the batch items."""
    z_uid : list[Hashable]
# Unions describing what forward() / training_step() accept.
ForwardBatch = Union[ForwardDepthMapsBatch, ForwardScanRaysBatch]
TrainingBatch = Union[ForwardBatch, LossBatch, LabeledBatch]
# Supported network output heads.
IntersectionMode = Literal[
    "medial_sphere",
    "orthogonal_plane",
]
class IntersectionFieldModel(pl.LightningModule, RequiresConditioner, DtypeMixin):
net: Union[MedialAtomNet, OrthogonalPlaneNet]
    @ensure_schedulables
    def __init__(self,
            # mode
            input_mode  : geometry.RayEmbedding = "plucker",
            output_mode : IntersectionMode = "medial_sphere",
            # network
            latent_features   : int = 256,
            hidden_features   : int = 512,
            hidden_layers     : int = 8,
            improve_miss_grads: bool = True,
            normalize_ray_dirs: bool = False, # the dataset is usually already normalized, but this could still be important for backprop
            # orthogonal plane
            loss_hit_cross_entropy : Schedulable = 1.0,
            # medial atoms
            loss_intersection         : Schedulable = 1,
            loss_intersection_l2      : Schedulable = 0,
            loss_intersection_proj    : Schedulable = 0,
            loss_intersection_proj_l2 : Schedulable = 0,
            loss_normal_cossim        : Schedulable = 0.25, # supervise target normal cosine similarity
            loss_normal_euclid        : Schedulable = 0, # supervise target normal l2 distance
            loss_normal_cossim_proj   : Schedulable = 0, # supervise target normal cosine similarity
            loss_normal_euclid_proj   : Schedulable = 0, # supervise target normal l2 distance
            loss_hit_nodistance_l1    : Schedulable = 0, # constrain no miss distance for hits
            loss_hit_nodistance_l2    : Schedulable = 32, # constrain no miss distance for hits
            loss_miss_distance_l1     : Schedulable = 0, # supervise target miss distance for misses
            loss_miss_distance_l2     : Schedulable = 0, # supervise target miss distance for misses
            loss_inscription_hits     : Schedulable = 0, # Penalize atom candidates using the supervision data of a different ray
            loss_inscription_hits_l2  : Schedulable = 0, # Penalize atom candidates using the supervision data of a different ray
            loss_inscription_miss     : Schedulable = 0, # Penalize atom candidates using the supervision data of a different ray
            loss_inscription_miss_l2  : Schedulable = 0, # Penalize atom candidates using the supervision data of a different ray
            loss_sphere_grow_reg      : Schedulable = 0, # maximialize sphere size
            loss_sphere_grow_reg_hit  : Schedulable = 0, # maximialize sphere size
            loss_embedding_norm       : Schedulable = "0.01**2 * Linear(15)", # DeepSDF schedules over 150 epochs. DeepSDF use 0.01**2, irobot uses 0.04**2
            loss_multi_view_reg       : Schedulable = 0, # minimize gradient w.r.t. delta ray dir, when ray origin = intersection
            loss_atom_centroid_norm_std_reg : Schedulable = 0, # minimize per-atom centroid std
            # optimization
            opt_learning_rate : Schedulable = 1e-5,
            opt_weight_decay  : float = 0,
            opt_warmup        : float = 0,
            **kw,
            ):
        """Build the intersection-field network for the chosen output mode.

        Remaining keyword arguments are forwarded to the selected net class
        (MedialAtomNet or OrthogonalPlaneNet).
        """
        super().__init__()
        # wrap the warmup length in a Linear schedule like the other
        # schedulable hyperparameters
        opt_warmup = Linear(opt_warmup)
        opt_warmup._param_name = "opt_warmup"
        self.save_hyperparameters()
        # per the assert: "half" ray embeddings require medial spheres with
        # more than one atom
        if "half" in input_mode:
            assert output_mode == "medial_sphere" and kw.get("n_atoms", 1) > 1
        assert output_mode in ["medial_sphere", "orthogonal_plane"]
        assert opt_weight_decay >= 0, opt_weight_decay
        if output_mode == "orthogonal_plane":
            self.net = OrthogonalPlaneNet(
                in_features     = self.n_input_embedding_features,
                hidden_layers   = hidden_layers,
                hidden_features = hidden_features,
                latent_features = latent_features,
                **kw,
            )
        elif output_mode == "medial_sphere":
            self.net = MedialAtomNet(
                in_features     = self.n_input_embedding_features,
                hidden_layers   = hidden_layers,
                hidden_features = hidden_features,
                latent_features = latent_features,
                **kw,
            )
    def on_fit_start(self):
        """Validate that every schedulable hyperparameter stays positive."""
        if __debug__:
            for k, v in self.hparams.items():
                if isinstance(v, HParamScheduleBase):
                    v.assert_positive(self.trainer.max_epochs)
    @property
    def n_input_embedding_features(self) -> int:
        # width of the ray embedding produced for the configured input mode
        return geometry.ray_input_embedding_length(self.hparams.input_mode)
    @property
    def n_latent_features(self) -> int:
        # RequiresConditioner interface
        return self.hparams.latent_features
    @property
    def latent_embeddings_init_std(self) -> float:
        # RequiresConditioner interface; 0.01 follows DeepSDF
        return 0.01
    @property
    def is_conditioned(self):
        return self.net.is_conditioned
    @property
    def is_double_backprop(self) -> bool:
        # True when training needs gradients w.r.t. ray origins or directions
        return self.is_double_backprop_origins or self.is_double_backprop_dirs
    @property
    def is_double_backprop_origins(self) -> bool:
        prif = self.hparams.output_mode == "orthogonal_plane"
        return prif and self.hparams.loss_normal_cossim
    @property
    def is_double_backprop_dirs(self) -> bool:
        return self.hparams.loss_multi_view_reg
    @classmethod
    @compose("\n".join)
    def make_jinja_template(cls, *, exclude_list: set[str] = {}, top_level: bool = True, **kw) -> str:
        """Yield template sections; @compose joins them into a single string."""
        yield param.make_jinja_template(cls, top_level=top_level, **kw)
        # the net's own parameters, minus those the model supplies itself
        yield MedialAtomNet.make_jinja_template(top_level=False, exclude_list={
            "in_features",
            "hidden_layers",
            "hidden_features",
            "latent_features",
        })
    def batch2rays(self, batch: ForwardBatch) -> tuple[Tensor, Tensor]:
        """Extract (ray_origins, ray_dirs) from either batch flavor."""
        if "uv" in batch:
            raise NotImplementedError
            # NOTE(review): dead code below (unreachable after the raise),
            # kept as-is — presumably a stub for the camera-batch path.
            assert not (self.hparams.loss_multi_view_reg and self.training)
            ray_origins, \
            ray_dirs, \
            = geometry.camera_uv_to_rays(
                cam2world  = batch["cam2world"],
                uv         = batch["uv"],
                intrinsics = batch["intrinsics"],
            )
        else:
            # multi-view regularization trains with origins placed at the
            # intersection points ("points") instead of the scan origins
            ray_origins = batch["points" if self.hparams.loss_multi_view_reg and self.training else "origins"]
            ray_dirs = batch["dirs"]
        return ray_origins, ray_dirs
    def forward(self,
            batch        : ForwardBatch,
            z            : Optional[Tensor] = None, # latent code
            *,
            return_input : bool = False,
            allow_nans   : bool = False, # in output
            **kw,
            ) -> tuple[torch.Tensor, ...]:
        """Embed the batch rays, run the net, and compute intersections.

        Returns the net's intersections; when `return_input` is set, returns
        (ray_origins, ray_dirs, input, intersections) instead. Extra keywords
        are forwarded to `net.compute_intersections`.
        """
        (
            ray_origins, # (B, 3)
            ray_dirs,    # (B, H, W, 3)
        ) = self.batch2rays(batch)
        # Ensure rays are normalized
        # NOTICE: this is slow, make sure to train with optimizations!
        assert ray_dirs.detach().norm(dim=-1).allclose(torch.ones(ray_dirs.shape[:-1], **self.device_and_dtype)),\
            ray_dirs.detach().norm(dim=-1)
        # broadcast a single origin per batch item against the (H, W) ray grid
        if ray_origins.ndim + 2 == ray_dirs.ndim:
            ray_origins = ray_origins[..., None, None, :]
        ray_origins, ray_dirs = broadcast_tensors(ray_origins, ray_dirs)
        # enable input gradients when a loss differentiates through them
        if self.is_double_backprop and self.training:
            if self.is_double_backprop_dirs:
                ray_dirs.requires_grad = True
            if self.is_double_backprop_origins:
                ray_origins.requires_grad = True
            assert ray_origins.requires_grad or ray_dirs.requires_grad
        input = geometry.ray_input_embedding(
            ray_origins, ray_dirs,
            mode           = self.hparams.input_mode,
            normalize_dirs = self.hparams.normalize_ray_dirs,
            is_training    = self.training,
        )
        assert not input.detach().isnan().any()
        predictions = self.net(input, z)
        # NaNs are only permitted outside training, and only when requested
        intersections = self.net.compute_intersections(
            ray_origins, ray_dirs, predictions,
            allow_nans = allow_nans and not self.training, **kw
        )
        if return_input:
            return ray_origins, ray_dirs, input, intersections
        else:
            return intersections
def training_step(self, batch: TrainingBatch, batch_idx: int, *, is_validation=False) -> dict[str, Tensor]:
    """Compute every configured loss term and metric for one batch.

    Returns a dict containing:
      * "loss"       - the scaled sum used for backpropagation,
      * each individual (scaled) loss term,
      * "unscaled_*" - detached raw loss terms (for logging),
      * "metric_*"   - detached metrics such as IoU (for logging).

    Under `__debug__` (or `LOG_ALL_METRICS`) all loss terms are computed
    even when their weight is zero, purely for logging; zero-weighted
    constant terms are dropped again before backprop below.

    NOTE(review): `is_validation` is never read in this body (it is set by
    `validation_step`) — confirm whether that is intentional.
    """
    # latent code: only conditioned networks have an encoder/embedding
    z = self.encode(batch) if self.is_conditioned else None
    assert self.is_conditioned or len(set(batch["z_uid"])) <= 1, \
        f"Network is unconditioned, but the batch has multiple uids: {set(batch['z_uid'])!r}"

    # unpack
    target_hits      = batch["hits"]      # (B, H, W) dtype=bool
    target_miss      = batch["miss"]      # (B, H, W) dtype=bool
    target_points    = batch["points"]    # (B, H, W, 3)
    target_normals   = batch["normals"]   # (B, H, W, 3) NaN if not hit
    target_distances = batch["distances"] # (B, H, W)    NaN if not miss
    assert not target_normals  [target_hits].isnan().any()
    assert not target_distances[target_miss].isnan().any()
    # zero out NaN normals so the arithmetic below stays finite
    # NOTE(review): this mutates the batch tensor in place
    target_normals[target_normals.isnan()] = 0
    assert not target_normals  .isnan().any()

    # make z fit batch scheme
    if z is not None:
        z = z[..., None, None, :]

    losses  = {}
    metrics = {}
    zeros = torch.zeros_like(target_distances) # broadcastable "no loss" filler
    if self.hparams.output_mode == "medial_sphere":
        assert isinstance(self.net, MedialAtomNet)
        # self(...) with return_input=True -> (origins, dirs, embedding, net outputs)
        ray_origins, ray_dirs, plucker, (
            depths,                   # (...) float, projection if not hit
            silhouettes,              # (...) float
            intersections,            # (..., 3) float, projection or NaN if not hit
            intersection_normals,     # (..., 3) float, rejection or NaN if not hit
            is_intersecting,          # (...) bool, true if hit
            sphere_centers,           # (..., 3) network output
            sphere_radii,             # (...) network output
            atom_indices,
            all_intersections,        # (..., N_ATOMS, 3) float, projection or NaN if not hit
            all_intersection_normals, # (..., N_ATOMS, 3) float, rejection or NaN if not hit
            all_depths,               # (..., N_ATOMS) float, projection if not hit
            all_silhouettes,          # (..., N_ATOMS) float
            all_is_intersecting,      # (..., N_ATOMS) bool, true if hit
            all_sphere_centers,       # (..., N_ATOMS, 3) network output
            all_sphere_radii,         # (..., N_ATOMS) network output
        ) = self(batch, z,
            intersections_only = False,
            return_all_atoms   = True,
            allow_nans         = False,
            return_input       = True,
            improve_miss_grads = True,
        )
        # target hit supervision
        if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_intersection: # scores true hits
            losses["loss_intersection"] = (
                (target_points - intersections).norm(dim=-1)
            ).where(target_hits & is_intersecting, zeros).mean()
        if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_intersection_l2: # scores true hits
            losses["loss_intersection_l2"] = (
                (target_points - intersections).pow(2).sum(dim=-1)
            ).where(target_hits & is_intersecting, zeros).mean()
        if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_intersection_proj: # scores misses as if they were hits, using the projection
            losses["loss_intersection_proj"] = (
                (target_points - intersections).norm(dim=-1)
            ).where(target_hits, zeros).mean()
        if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_intersection_proj_l2: # scores misses as if they were hits, using the projection
            losses["loss_intersection_proj_l2"] = (
                (target_points - intersections).pow(2).sum(dim=-1)
            ).where(target_hits, zeros).mean()
        # target hit normal supervision
        if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_normal_cossim: # scores true hits
            losses["loss_normal_cossim"] = (
                1 - torch.cosine_similarity(target_normals, intersection_normals, dim=-1)
            ).where(target_hits & is_intersecting, zeros).mean()
        if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_normal_euclid: # scores true hits
            losses["loss_normal_euclid"] = (
                (target_normals - intersection_normals).norm(dim=-1)
            ).where(target_hits & is_intersecting, zeros).mean()
        if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_normal_cossim_proj: # scores misses as if they were hits
            losses["loss_normal_cossim_proj"] = (
                1 - torch.cosine_similarity(target_normals, intersection_normals, dim=-1)
            ).where(target_hits, zeros).mean()
        if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_normal_euclid_proj: # scores misses as if they were hits
            losses["loss_normal_euclid_proj"] = (
                (target_normals - intersection_normals).norm(dim=-1)
            ).where(target_hits, zeros).mean()
        # target sufficient hit radius
        if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_hit_nodistance_l1: # ensures hits become hits, instead of relying on the projection being right
            losses["loss_hit_nodistance_l1"] = (
                silhouettes
            ).where(target_hits & (silhouettes > 0), zeros).mean()
        if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_hit_nodistance_l2: # ensures hits become hits, instead of relying on the projection being right
            losses["loss_hit_nodistance_l2"] = (
                silhouettes
            ).where(target_hits & (silhouettes > 0), zeros).pow(2).mean()
        # target miss supervision
        if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_miss_distance_l1: # only positive misses reinforcement
            losses["loss_miss_distance_l1"] = (
                target_distances - silhouettes
            ).where(target_miss, zeros).abs().mean()
        if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_miss_distance_l2: # only positive misses reinforcement
            losses["loss_miss_distance_l2"] = (
                target_distances - silhouettes
            ).where(target_miss, zeros).pow(2).mean()
        # incentivise maximal spheres
        # (detach trick: constant "+1 pressure" gradient on the radii)
        if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_sphere_grow_reg: # all atoms
            losses["loss_sphere_grow_reg"] = ((all_sphere_radii.detach() + 1) - all_sphere_radii).abs().mean()
        if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_sphere_grow_reg_hit: # true hits only
            losses["loss_sphere_grow_reg_hit"] = ((sphere_radii.detach() + 1) - sphere_radii).where(target_hits & is_intersecting, zeros).abs().mean()
        # spherical latent prior
        if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_embedding_norm:
            losses["loss_embedding_norm"] = self.latent_embeddings.norm(dim=-1).mean()
        is_grad_enabled = torch.is_grad_enabled()
        # multi-view regularization: atom should not change when view changes
        if self.hparams.loss_multi_view_reg and is_grad_enabled:
            assert ray_dirs.requires_grad, ray_dirs
            assert plucker.requires_grad, plucker
            assert intersections.grad_fn is not None
            assert intersection_normals.grad_fn is not None
            *center_grads, radii_grads = diff.gradients(
                sphere_centers[..., 0],
                sphere_centers[..., 1],
                sphere_centers[..., 2],
                sphere_radii,
                wrt=ray_dirs,
            )
            losses["loss_multi_view_reg"] = (
                sum(
                    i.pow(2).sum(dim=-1)
                    for i in center_grads
                ).where(target_hits & is_intersecting, zeros).mean()
                +
                radii_grads.pow(2).sum(dim=-1)
                .where(target_hits & is_intersecting, zeros).mean()
            )
        # minimize the volume spanned by each atom
        if self.hparams.loss_atom_centroid_norm_std_reg and self.net.n_atoms > 1:
            assert len(all_sphere_centers.shape) == 5, all_sphere_centers.shape
            # hinge on the squared distance from each atom's mean center; free below 0.05
            losses["loss_atom_centroid_norm_std_reg"] \
                = ((
                    all_sphere_centers
                    - all_sphere_centers
                        .mean(dim=(1, 2), keepdim=True)
                ).pow(2).sum(dim=-1) - 0.05**2).clamp(0, None).mean()
        # prif is l1, LSMAT is l2
        if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_inscription_hits or self.hparams.loss_inscription_miss or self.hparams.loss_inscription_hits_l2 or self.hparams.loss_inscription_miss_l2:
            # intersect every atom against a random permutation of the rays,
            # penalizing spheres that poke through the target surface
            b = target_hits.shape[0]          # number of objects
            n = target_hits.shape[1:].numel() # rays per object
            perm = torch.randperm(n, device=self.device) # ray2ray permutation
            flatten = dict(start_dim=1, end_dim=len(target_hits.shape) - 1)
            (
                inscr_sphere_center_projs, # (b, n, n_atoms, 3)
                inscr_intersections_near,  # (b, n, n_atoms, 3)
                inscr_intersections_far,   # (b, n, n_atoms, 3)
                inscr_is_intersecting,     # (b, n, n_atoms) dtype=bool
            ) = geometry.ray_sphere_intersect(
                ray_origins.flatten(**flatten)[:, perm, None, :],
                ray_dirs   .flatten(**flatten)[:, perm, None, :],
                all_sphere_centers.flatten(**flatten),
                all_sphere_radii  .flatten(**flatten),
                return_parts       = True,
                allow_nans         = False,
                improve_miss_grads = self.hparams.improve_miss_grads,
            )
            assert inscr_sphere_center_projs.shape == (b, n, self.net.n_atoms, 3), \
                (inscr_sphere_center_projs.shape, (b, n, self.net.n_atoms, 3))
            inscr_silhouettes = (
                inscr_sphere_center_projs - all_sphere_centers.flatten(**flatten)
            ).norm(dim=-1) - all_sphere_radii.flatten(**flatten)
            # signed distance (along the ray) by which a sphere overshoots a hit;
            # only the negative (overshooting) part is kept
            loss_inscription_hits = (
                (
                    (inscr_intersections_near - target_points.flatten(**flatten)[:, perm, None, :])
                    * ray_dirs.flatten(**flatten)[:, perm, None, :]
                ).sum(dim=-1)
            ).where(target_hits.flatten(**flatten)[:, perm, None] & inscr_is_intersecting,
                torch.zeros(inscr_intersections_near.shape[:-1], **self.device_and_dtype),
            ).clamp(None, 0)
            loss_inscription_miss = (
                inscr_silhouettes - target_distances.flatten(**flatten)[:, perm, None]
            ).where(target_miss.flatten(**flatten)[:, perm, None],
                torch.zeros_like(inscr_silhouettes)
            ).clamp(None, 0)
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_inscription_hits:
                losses["loss_inscription_hits"] = loss_inscription_hits.neg().mean()
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_inscription_miss:
                losses["loss_inscription_miss"] = loss_inscription_miss.neg().mean()
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_inscription_hits_l2:
                losses["loss_inscription_hits_l2"] = loss_inscription_hits.pow(2).mean()
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_inscription_miss_l2:
                losses["loss_inscription_miss_l2"] = loss_inscription_miss.pow(2).mean()
        # metrics
        metrics["iou"] = (
            ((~target_miss) & is_intersecting.detach()).sum() /
            ((~target_miss) | is_intersecting.detach()).sum()
        )
        metrics["radii"] = sphere_radii.detach().mean() # with the constant applied pressure, we need to measure it this way instead
    elif self.hparams.output_mode == "orthogonal_plane":
        assert isinstance(self.net, OrthogonalPlaneNet)
        ray_origins, ray_dirs, input_embedding, (
            intersections,   # (..., 3) dtype=float
            is_intersecting, # (...) dtype=float
        ) = self(batch, z, return_input=True, normalize_origins=True)
        if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_intersection:
            losses["loss_intersection"] = (
                (intersections - target_points).norm(dim=-1)
            ).where(target_hits, zeros).mean()
        if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_intersection_l2:
            losses["loss_intersection_l2"] = (
                (intersections - target_points).pow(2).sum(dim=-1)
            ).where(target_hits, zeros).mean()
        if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_hit_cross_entropy:
            # is_intersecting is a logit here, not a bool
            losses["loss_hit_cross_entropy"] = (
                F.binary_cross_entropy_with_logits(is_intersecting, (~target_miss).to(self.dtype))
            ).mean()
        # normals come from differentiating the intersection w.r.t. the origins
        if self.hparams.loss_normal_cossim and torch.is_grad_enabled():
            jac = diff.jacobian(intersections, ray_origins)
            intersection_normals = self.compute_normals_from_intersection_origin_jacobian(jac, ray_dirs)
            losses["loss_normal_cossim"] = (
                1 - torch.cosine_similarity(target_normals, intersection_normals, dim=-1)
            ).where(target_hits, zeros).mean()
        if self.hparams.loss_normal_euclid and torch.is_grad_enabled():
            jac = diff.jacobian(intersections, ray_origins)
            intersection_normals = self.compute_normals_from_intersection_origin_jacobian(jac, ray_dirs)
            losses["loss_normal_euclid"] = (
                (target_normals - intersection_normals).norm(dim=-1)
            ).where(target_hits, zeros).mean()
        if self.hparams.loss_multi_view_reg and torch.is_grad_enabled():
            assert ray_dirs     .requires_grad, ray_dirs
            assert intersections.grad_fn is not None
            grads = diff.gradients(
                intersections[..., 0],
                intersections[..., 1],
                intersections[..., 2],
                wrt=ray_dirs,
            )
            losses["loss_multi_view_reg"] = sum(
                i.pow(2).sum(dim=-1)
                for i in grads
            ).where(target_hits, zeros).mean()
        metrics["iou"] = (
            ((~target_miss) & (is_intersecting>0.5).detach()).sum() /
            ((~target_miss) | (is_intersecting>0.5).detach()).sum()
        )
    else:
        raise NotImplementedError(self.hparams.output_mode)

    # output losses and metrics
    # apply scaling:
    losses_unscaled = losses.copy() # shallow copy
    for k in list(losses.keys()):
        assert losses[k].numel() == 1, f"losses[{k!r}] shape: {losses[k].shape}"
        # each loss term's weight is a (possibly scheduled) hyperparameter of the same name
        val_schedule: HParamSchedule = self.hparams[k]
        val = val_schedule.get(self)
        if val == 0:
            if (__debug__ or LOG_ALL_METRICS) and val_schedule.is_const:
                del losses[k] # it was only added for unscaled logging, do not backprop
            else:
                losses[k] = 0
        elif val != 1:
            losses[k] = losses[k] * val
    if not losses:
        raise MisconfigurationException("no loss was computed")
    losses["loss"] = sum(losses.values()) * self.hparams.opt_warmup.get(self)
    losses.update({f"unscaled_{k}": v.detach() for k, v in losses_unscaled.items()})
    losses.update({f"metric_{k}": v.detach() for k, v in metrics.items()})
    return losses
# used by pl.callbacks.EarlyStopping, via cli.py
@property
def metric_early_stop(self):
    """Name of the logged metric that early stopping should monitor."""
    if self.hparams.output_mode == "medial_sphere":
        return "unscaled_loss_intersection_proj"
    return "unscaled_loss_intersection"
def validation_step(self, batch: TrainingBatch, batch_idx: int) -> dict[str, Tensor]:
    """Validation reuses the full training loss computation; no extra work."""
    return self.training_step(batch, batch_idx, is_validation=True)
def configure_optimizers(self):
    """Build the Adam optimizer, plus a LambdaLR schedule when the
    learning-rate hyperparameter is not constant.

    With a non-constant schedule the optimizer's base lr is 1 and the
    LambdaLR multiplier supplies the actual rate each epoch; otherwise
    the constant rate is baked into the optimizer directly.
    """
    lr_schedule = self.hparams.opt_learning_rate
    optimizer = torch.optim.Adam(
        self.parameters(),
        lr           = lr_schedule.get_train_value(0) if lr_schedule.is_const else 1,
        weight_decay = self.hparams.opt_weight_decay,
    )
    schedulers = []
    if not lr_schedule.is_const:
        schedulers.append(torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lambda epoch: lr_schedule.get_train_value(epoch),
        ))
    return [optimizer], schedulers
@property
def example_input_array(self) -> tuple[dict[str, Tensor], Tensor]:
    """A minimal `(batch, z)` example pair (single unit-diagonal ray)."""
    unit_diagonal = torch.ones(1, 3) * torch.rsqrt(torch.tensor(3))
    example_batch = { # see self.batch2rays
        "origins": torch.zeros(1, 3), # most commonly used
        "points":  torch.zeros(1, 3), # used if self.training and self.hparams.loss_multi_view_reg
        "dirs":    unit_diagonal,
    }
    return example_batch, torch.ones(1, self.hparams.latent_features)
@staticmethod
def compute_normals_from_intersection_origin_jacobian(origin_jac: Tensor, ray_dirs: Tensor) -> Tensor:
normals = sum((
torch.cross(origin_jac[..., 0], origin_jac[..., 1], dim=-1) * -ray_dirs[..., [2]],
torch.cross(origin_jac[..., 1], origin_jac[..., 2], dim=-1) * -ray_dirs[..., [0]],
torch.cross(origin_jac[..., 2], origin_jac[..., 0], dim=-1) * -ray_dirs[..., [1]],
))
return normals / normals.norm(dim=-1, keepdim=True)
class IntersectionFieldAutoDecoderModel(IntersectionFieldModel, AutoDecoderModuleMixin):
    """Auto-decoder variant: the latent code is looked up per object uid
    (via the AutoDecoderModuleMixin indexing) instead of being produced
    by an encoder network."""
    def encode(self, batch: LabeledBatch) -> Tensor:
        # DataParallel would replicate/split the embedding lookup, which this
        # auto-decoder setup does not support
        assert not isinstance(self.trainer.strategy, pl.strategies.DataParallelStrategy)
        return self[batch["z_uid"]] # [N, Z_n]

@ -0,0 +1,186 @@
from .. import param
from ..modules import fc
from ..data.common import points
from ..utils import geometry
from ..utils.helpers import compose
from textwrap import indent, dedent
from torch import nn, Tensor
from typing import Optional
import torch
import warnings
# generalize this into a HypoHyperConcat net? ConditionedNet?
class MedialAtomNet(nn.Module):
    """Predicts `n_atoms` medial spheres `(x, y, z, r)` per input, and
    analytically intersects rays against them, selecting the best atom
    per ray."""
    def __init__(self,
            in_features     : int,
            latent_features : int,
            hidden_features : int,
            hidden_layers   : int,
            n_atoms         : int = 1,
            # (weight scale, center init radius, sphere radius init)
            # NOTE: annotation widened from tuple[float, float] — the default
            # and the unpacking below use three values
            final_init_wrr  : tuple[float, float, float] | None = (0.05, 0.6, 0.1),
            **kw,
            ):
        """Build the FC backbone and optionally re-initialize its final layer
        so atoms start as small spheres scattered on a sphere of radius r1."""
        super().__init__()
        assert n_atoms >= 1, n_atoms
        self.n_atoms = n_atoms
        self.fc = fc.FCBlock(
            in_features      = in_features,
            hidden_layers    = hidden_layers,
            hidden_features  = hidden_features,
            out_features     = n_atoms * 4, # n_atoms * (x, y, z, r)
            outermost_linear = True,
            latent_features  = latent_features,
            **kw,
        )
        if final_init_wrr is not None:
            with torch.no_grad():
                w, r1, r2 = final_init_wrr
                # dampen the final layer, then bias the (x, y, z) outputs to
                # random points at radius r1 and the radius outputs to r2
                if w != 1: self.fc[-1].linear.weight *= w
                dtype = self.fc[-1].linear.bias.dtype
                self.fc[-1].linear.bias[..., [4*n+i for n in range(n_atoms) for i in range(3)]] = torch.tensor(points.generate_random_sphere_points(n_atoms, radius=r1), dtype=dtype).flatten()
                self.fc[-1].linear.bias[..., 3::4] = r2

    @property
    def is_conditioned(self):
        # whether the underlying FC block consumes a latent code
        return self.fc.is_conditioned

    @classmethod
    @compose("\n".join)
    def make_jinja_template(cls, *, exclude_list: set[str] = {}, top_level: bool = True, **kw) -> str:
        # this class' config template followed by the nested FCBlock template,
        # minus the fields that __init__ fixes itself
        yield param.make_jinja_template(cls, top_level=top_level, exclude_list=exclude_list, **kw)
        yield fc.FCBlock.make_jinja_template(top_level=False, exclude_list={
            "in_features",
            "hidden_layers",
            "hidden_features",
            "out_features",
            "outermost_linear",
            "latent_features",
        })

    def forward(self, x: Tensor, z: Optional[Tensor] = None):
        """Predict the raw `(..., 4*n_atoms)` atom parameters for `x`."""
        if __debug__ and self.is_conditioned and z is None:
            warnings.warn(f"{self.__class__.__qualname__} is conditioned, but the forward pass was not supplied with a conditioning tensor.")
        return self.fc(x, z)

    def compute_intersections(self,
            ray_origins  : Tensor, # (..., 3)
            ray_dirs     : Tensor, # (..., 3)
            medial_atoms : Tensor, # (..., 4*self.n_atoms)
            *,
            intersections_only : bool = True,
            return_all_atoms   : bool = False, # only applies if intersections_only=False
            allow_nans         : bool = True,
            improve_miss_grads : bool = False,
            ) -> tuple[Tensor, ...]:
        """Intersect rays with the predicted spheres.

        Returns `(intersections, is_intersecting)` when
        `intersections_only=True`, otherwise the full per-ray tuple
        documented at the return statement below (optionally including
        the per-atom tensors when `return_all_atoms=True`).
        """
        assert ray_origins.shape[:-1] == ray_dirs.shape[:-1] == medial_atoms.shape[:-1], \
            (ray_origins.shape, ray_dirs.shape, medial_atoms.shape)
        assert medial_atoms.shape[-1] % 4 == 0, \
            medial_atoms.shape
        assert ray_origins.shape[-1] == ray_dirs.shape[-1] == 3, \
            (ray_origins.shape, ray_dirs.shape)
        #n_atoms = medial_atoms.shape[-1] // 4
        n_atoms = medial_atoms.shape[-1] >> 2
        # reshape (..., n_atoms * d) to (..., n_atoms, d)
        medial_atoms = medial_atoms.view(*medial_atoms.shape[:-1], n_atoms, 4)
        ray_origins  = ray_origins.unsqueeze(-2).broadcast_to([*ray_origins.shape[:-1], n_atoms, 3])
        ray_dirs     = ray_dirs   .unsqueeze(-2).broadcast_to([*ray_dirs   .shape[:-1], n_atoms, 3])
        # unpack atoms (radii are made non-negative here)
        sphere_centers = medial_atoms[..., :3]
        sphere_radii   = medial_atoms[..., 3].abs()
        assert not ray_origins   .detach().isnan().any()
        assert not ray_dirs      .detach().isnan().any()
        assert not sphere_centers.detach().isnan().any()
        assert not sphere_radii  .detach().isnan().any()
        # compute intersections
        (
            sphere_center_projs, # (..., 3)
            intersections_near,  # (..., 3)
            intersections_far,   # (..., 3)
            is_intersecting,     # (...) bool
        ) = geometry.ray_sphere_intersect(
            ray_origins,
            ray_dirs,
            sphere_centers,
            sphere_radii,
            return_parts       = True,
            allow_nans         = allow_nans,
            improve_miss_grads = improve_miss_grads,
        )
        # early return
        if intersections_only and n_atoms == 1:
            return intersections_near.squeeze(-2), is_intersecting.squeeze(-1)
        # compute how close each hit and miss are
        depths      = ((intersections_near - ray_origins) * ray_dirs).sum(-1)
        silhouettes = torch.linalg.norm(sphere_center_projs - sphere_centers, dim=-1) - sphere_radii
        if return_all_atoms:
            intersections_near_all = intersections_near
            depths_all             = depths
            silhouettes_all        = silhouettes
            is_intersecting_all    = is_intersecting
            sphere_centers_all     = sphere_centers
            sphere_radii_all       = sphere_radii
        # collapse n_atoms: pick per ray the nearest hitting atom, or the
        # atom with the smallest silhouette distance if nothing hits
        if n_atoms > 1:
            atom_indices = torch.where(is_intersecting.any(dim=-1, keepdim=True),
                # +100 pushes non-hitting atoms behind hitting ones in the argmin
                torch.where(is_intersecting, depths.detach(), depths.detach()+100).argmin(dim=-1, keepdim=True),
                silhouettes.detach().argmin(dim=-1, keepdim=True),
            )
            intersections_near = intersections_near.take_along_dim(atom_indices[..., None], -2).squeeze(-2)
            depths             = depths            .take_along_dim(atom_indices, -1).squeeze(-1)
            silhouettes        = silhouettes       .take_along_dim(atom_indices, -1).squeeze(-1)
            is_intersecting    = is_intersecting   .take_along_dim(atom_indices, -1).squeeze(-1)
            sphere_centers     = sphere_centers    .take_along_dim(atom_indices[..., None], -2).squeeze(-2)
            sphere_radii       = sphere_radii      .take_along_dim(atom_indices, -1).squeeze(-1)
        else:
            atom_indices = None
            intersections_near = intersections_near.squeeze(-2)
            depths             = depths            .squeeze(-1)
            silhouettes        = silhouettes       .squeeze(-1)
            is_intersecting    = is_intersecting   .squeeze(-1)
            sphere_centers     = sphere_centers    .squeeze(-2)
            sphere_radii       = sphere_radii      .squeeze(-1)
        # early return
        if intersections_only:
            return intersections_near, is_intersecting
        # compute sphere normals (epsilon guards the hitless zero-vector case)
        intersection_normals = intersections_near - sphere_centers
        intersection_normals = intersection_normals / (intersection_normals.norm(dim=-1)[..., None] + 1e-9)
        if return_all_atoms:
            intersection_normals_all = intersections_near_all - sphere_centers_all
            intersection_normals_all = intersection_normals_all / (intersection_normals_all.norm(dim=-1)[..., None] + 1e-9)
        return (
            depths,               # (...) valid if hit, based on 'intersections'
            silhouettes,          # (...) always valid
            intersections_near,   # (..., 3) valid if hit, projection if not
            intersection_normals, # (..., 3) valid if hit, rejection if not
            is_intersecting,      # (...) dtype=bool
            sphere_centers,       # (..., 3) network output
            sphere_radii,         # (...) network output
            *(() if not return_all_atoms else (
            atom_indices,
            intersections_near_all,   # (..., N_ATOMS, 3) valid if hit, projection if not
            intersection_normals_all, # (..., N_ATOMS, 3) valid if hit, rejection if not
            depths_all,               # (..., N_ATOMS) always valid
            silhouettes_all,          # (..., N_ATOMS) always valid
            is_intersecting_all,      # (..., N_ATOMS) dtype=bool
            sphere_centers_all,       # (..., N_ATOMS, 3) network output
            sphere_radii_all,         # (..., N_ATOMS) network output
        )))

@ -0,0 +1,101 @@
from .. import param
from ..modules import fc
from ..utils import geometry
from ..utils.helpers import compose
from textwrap import indent, dedent
from torch import nn, Tensor
from typing import Optional
import warnings
class OrthogonalPlaneNet(nn.Module):
    """
    Predicts, per ray, a signed displacement along the ray (an orthogonal
    plane offset) plus an intersection logit, and converts the pair into
    a 3D intersection point.
    """
    def __init__(self,
            in_features     : int,
            latent_features : int,
            hidden_features : int,
            hidden_layers   : int,
            **kw,
            ):
        """Build the FC backbone with a fixed 2-channel output head."""
        super().__init__()
        self.fc = fc.FCBlock(
            in_features      = in_features,
            hidden_layers    = hidden_layers,
            hidden_features  = hidden_features,
            out_features     = 2, # (plane_offset, is_intersecting)
            outermost_linear = True,
            latent_features  = latent_features,
            **kw,
        )

    @property
    def is_conditioned(self):
        # whether the underlying FC block consumes a latent code
        return self.fc.is_conditioned

    @classmethod
    @compose("\n".join)
    def make_jinja_template(cls, *, exclude_list: set[str] = {}, top_level: bool = True, **kw) -> str:
        # this class' config template followed by the nested FCBlock template,
        # minus the fields that __init__ fixes itself
        yield param.make_jinja_template(cls, top_level=top_level, exclude_list=exclude_list, **kw)
        yield param.make_jinja_template(fc.FCBlock, top_level=False, exclude_list={
            "in_features",
            "hidden_layers",
            "hidden_features",
            "out_features",
            "outermost_linear",
        })

    def forward(self, x: Tensor, z: Optional[Tensor] = None) -> Tensor:
        """Predict the raw `(..., 2)` outputs for the embedded rays `x`."""
        if __debug__ and self.is_conditioned and z is None:
            warnings.warn(f"{self.__class__.__qualname__} is conditioned, but the forward pass was not supplied with a conditioning tensor.")
        return self.fc(x, z)

    @staticmethod
    def compute_intersections(
            ray_origins : Tensor, # (..., 3)
            ray_dirs    : Tensor, # (..., 3)
            predictions : Tensor, # (..., 2)
            *,
            normalize_origins           = True,
            return_signed_displacements = False,
            allow_nans        = False, # MARF compat
            atom_random_prob  = None,  # MARF compat
            atom_dropout_prob = None,  # MARF compat
            ) -> tuple[Tensor, ...]:
        """Turn `(plane_offset, logit)` predictions into intersection points.

        Returns `(intersections, is_intersecting)`, plus the raw signed
        displacements when `return_signed_displacements=True`.
        The `MARF compat` keywords are accepted but unused here.
        """
        assert ray_origins.shape[:-1] == ray_dirs.shape[:-1] == predictions.shape[:-1], \
            (ray_origins.shape, ray_dirs.shape, predictions.shape)
        assert predictions.shape[-1] == 2, \
            predictions.shape
        assert not allow_nans
        if normalize_origins:
            # slide each origin to the point on its ray closest to the world origin
            ray_origins = geometry.project_point_on_ray(0, ray_origins, ray_dirs)
        # unpack predictions
        signed_displacements = predictions[..., 0]
        is_intersecting      = predictions[..., 1] # a logit, not a bool
        # compute intersections (displacement is measured against the ray direction)
        intersections = ray_origins - signed_displacements[..., None] * ray_dirs
        return (
            intersections,
            is_intersecting,
            *((signed_displacements,) if return_signed_displacements else ()),
        )
OrthogonalPlaneNet.__doc__ = __doc__ = f"""
{dedent(OrthogonalPlaneNet.__doc__).strip()}
# Config template:
```yaml
{OrthogonalPlaneNet.make_jinja_template()}
```
"""

@ -0,0 +1,3 @@
__doc__ = """
Contains Pytorch Modules
"""

22
ifield/modules/dtype.py Normal file

@ -0,0 +1,22 @@
import pytorch_lightning as pl
class DtypeMixin:
    """Mixin for LightningModules exposing device/dtype as ready-made kwargs."""
    def __init_subclass__(cls):
        # enforce at class-definition time that this mixin is only used on
        # LightningModules (which provide .dtype and .device)
        assert issubclass(cls, pl.LightningModule), \
            f"{cls.__name__!r} is not a subclass of 'pytorch_lightning.LightningModule'!"
    @property
    def device_and_dtype(self) -> dict:
        """
        The module's dtype and device as a kwargs dict.

        Examples:
        ```
        torch.tensor(1337, **self.device_and_dtype)
        some_tensor.to(**self.device_and_dtype)
        ```
        """
        return dict(dtype=self.dtype, device=self.device)

424
ifield/modules/fc.py Normal file

@ -0,0 +1,424 @@
from . import siren
from .. import param
from ..utils.helpers import compose, run_length_encode, MetaModuleProxy
from collections import OrderedDict
from pytorch_lightning.core.mixins import HyperparametersMixin
from torch import nn, Tensor
from torch.nn.utils.weight_norm import WeightNorm
from torchmeta.modules import MetaModule, MetaSequential
from typing import Iterable, Literal, Optional, Union, Callable
import itertools
import math
import torch
__doc__ = """
`fc` is short for "Fully Connected"
"""
def broadcast_tensors_except(*tensors: Tensor, dim: int) -> list[Tensor]:
    """Broadcast `tensors` against each other on every axis except `dim`.

    All axes other than `dim` are broadcast to a common shape; along `dim`
    each tensor keeps its own size (the sizes there need not match).

    Args:
        *tensors: tensors of equal rank.
        dim: the axis excluded from broadcasting (may be negative).

    Returns:
        A list of broadcast views, one per input tensor.
    """
    if dim == -1:
        shapes = [ i.shape[:dim] for i in tensors ]
    else:
        # Fix: the tail slice must be splatted into the shape tuple; previously
        # it was appended as a nested tuple, which made torch.broadcast_shapes
        # fail for every dim != -1.
        shapes = [ (*i.shape[:dim], *i.shape[dim+1:]) for i in tensors ]
    target_shape = list(torch.broadcast_shapes(*shapes))
    # re-insert a -1 wildcard at `dim` so broadcast_to keeps each tensor's own size there
    if dim == -1:
        target_shape.append(-1)
    elif dim < 0:
        target_shape.insert(dim+1, -1)
    else:
        target_shape.insert(dim, -1)
    return [ i.broadcast_to(target_shape) for i in tensors ]
EPS = 1e-8  # small constant guarding divisions/norms against zero

# Supported activation functions for FCLayer/FCBlock (None = purely linear).
Nonlinearity = Literal[
    None,
    "relu",
    "leaky_relu",
    "silu",
    "softplus",
    "elu",
    "selu",
    "sine",
    "sigmoid",
    "tanh",
]

# Supported normalization schemes ("_na" suffix = non-affine variant).
Normalization = Literal[
    None,
    "batchnorm",
    "batchnorm_na",
    "layernorm",
    "layernorm_na",
    "weightnorm",
]
class ReprHyperparametersMixin(HyperparametersMixin):
    """Appends the saved hyperparameters to the module's `extra_repr` on one line."""
    def extra_repr(self):
        hparams_repr = ", ".join(f"{k}={v!r}" for k, v in self.hparams.items())
        inherited = super().extra_repr()
        return f"{hparams_repr}, {inherited}" if inherited else hparams_repr
class MultilineReprHyperparametersMixin(HyperparametersMixin):
    """Appends the saved hyperparameters to `extra_repr`, three per line."""
    def extra_repr(self):
        entries = [f"{k}={v!r}" for k, v in self.hparams.items()]
        # lay the entries out three per line, each line comma-terminated
        rows = [
            ", ".join(entries[start:start+3]) + ","
            for start in range(0, len(entries), 3)
        ]
        hparams_repr = "\n".join(rows)
        inherited = super().extra_repr()
        return f"{hparams_repr}, {inherited}" if inherited else hparams_repr
class BatchLinear(nn.Linear):
    """
    A linear (meta-)layer that can deal with batched weight matrices and biases,
    as for instance output by a hypernetwork.
    """
    __doc__ = nn.Linear.__doc__  # NOTE: deliberately overrides the docstring above at runtime
    # WeightNorm hooks are intercepted here instead of the regular hook dict,
    # so they can be applied against hypernetwork-supplied params (see forward)
    _meta_forward_pre_hooks = None
    def register_forward_pre_hook(self, hook: Callable) -> torch.utils.hooks.RemovableHandle:
        # non-WeightNorm hooks go through the normal nn.Module machinery
        if not isinstance(hook, WeightNorm):
            return super().register_forward_pre_hook(hook)
        if self._meta_forward_pre_hooks is None:
            self._meta_forward_pre_hooks = OrderedDict()
        handle = torch.utils.hooks.RemovableHandle(self._meta_forward_pre_hooks)
        self._meta_forward_pre_hooks[handle.id] = hook
        return handle
    def forward(self, input: Tensor, params: Optional[dict[str, Tensor]]=None):
        """Apply the layer using `params` ("weight"/"bias") when given,
        falling back to this module's own parameters otherwise."""
        if params is None or not isinstance(self, MetaModule):
            params = OrderedDict(self.named_parameters())
        if self._meta_forward_pre_hooks is not None:
            # run captured WeightNorm hooks against the (possibly external) params
            proxy = MetaModuleProxy(self, params)
            for hook in self._meta_forward_pre_hooks.values():
                hook(proxy, [input])
        weight = params["weight"]
        bias = params.get("bias", None)
        # transpose weights
        weight = weight.permute(*range(len(weight.shape) - 2), -1, -2) # does not jit
        # batched matmul: (..., 1, in) @ (..., in, out) -> (..., out)
        output = input.unsqueeze(-2).matmul(weight).squeeze(-2)
        if bias is not None:
            output = output + bias
        return output
class MetaBatchLinear(BatchLinear, MetaModule):
    """`BatchLinear` registered as a torchmeta `MetaModule`, enabling
    externally supplied parameters via `forward(..., params=...)`."""
    pass
class CallbackConcatLayer(nn.Module):
    "A tricky way to enable skip connections in sequentials models"
    def __init__(self, tensor_getter: Callable[[], tuple[Tensor, ...]]):
        # tensor_getter is called at forward time to fetch the tensors to concat
        super().__init__()
        self.tensor_getter = tensor_getter
    def forward(self, x):
        ys = self.tensor_getter()
        # broadcast batch dims, then concatenate along the feature axis
        return torch.cat(broadcast_tensors_except(x, *ys, dim=-1), dim=-1)
class ResidualSkipConnectionEndLayer(nn.Module):
    """
    Residual skip connections that can be added to a nn.Sequential
    """
    class ResidualSkipConnectionStartLayer(nn.Module):
        """Identity layer that remembers its input for the matching end layer."""
        def __init__(self):
            super().__init__()
            self._stored_tensor = None
        def forward(self, x):
            # storing twice without a get() means start/end layers are mismatched
            assert self._stored_tensor is None
            self._stored_tensor = x
            return x
        def get(self):
            assert self._stored_tensor is not None
            stored, self._stored_tensor = self._stored_tensor, None
            return stored
    def __init__(self):
        super().__init__()
        self._stored_tensor = None  # NOTE(review): appears unused on the end layer
        self._start = self.ResidualSkipConnectionStartLayer()
    def forward(self, x):
        # add the tensor captured by the paired start layer
        return x + self._start.get()
    @property
    def start(self) -> ResidualSkipConnectionStartLayer:
        return self._start
    @property
    def end(self) -> "ResidualSkipConnectionEndLayer":
        return self
# Residual connection modes for FCLayer (None = no skip, "identity" = x + f(x)).
ResidualMode = Literal[
    None,
    "identity",
]
class FCLayer(MultilineReprHyperparametersMixin, MetaSequential):
    """
    A single fully connected (FC) layer

    Assembled as a small sequential of named submodules:
    [res_a] -> linear -> [norm] -> [activation] -> [dropout] -> [res_b],
    with per-nonlinearity weight initialization.
    """
    def __init__(self,
            in_features    : int,
            out_features   : int,
            *,
            nonlinearity   : Nonlinearity  = "relu",
            normalization  : Normalization = None,
            is_first       : bool          = False, # used for SIREN initialization
            is_final       : bool          = False, # used for fan_out init
            dropout_prob   : float         = 0.0,
            negative_slope : float         = 0.01, # only for nonlinearity="leaky_relu", default is normally 0.01
            omega_0        : float         = 30, # only for nonlinearity="sine"
            residual_mode  : ResidualMode  = None,
            _no_meta       : bool          = False, # set to true in hypernetworks
            **_
            ):
        super().__init__()
        self.save_hyperparameters()
        # improve repr: drop hyperparameters irrelevant to the chosen activation
        if nonlinearity != "leaky_relu":
            self.hparams.pop("negative_slope")
        if nonlinearity != "sine":
            self.hparams.pop("omega_0")
        Linear = nn.Linear if _no_meta else MetaBatchLinear
        def make_layer() -> Iterable[nn.Module]:
            # yields (name, module) pairs in forward order
            # residual start
            if residual_mode is not None:
                residual_layer = ResidualSkipConnectionEndLayer()
                yield "res_a", residual_layer.start
            linear = Linear(in_features, out_features)
            # initialize according to the activation that follows
            if nonlinearity in {"relu", "leaky_relu", "silu", "softplus"}:
                nn.init.kaiming_uniform_(linear.weight, a=negative_slope, nonlinearity=nonlinearity, mode="fan_in" if not is_final else "fan_out")
            elif nonlinearity == "elu":
                nn.init.normal_(linear.weight, std=math.sqrt(1.5505188080679277) / math.sqrt(linear.weight.size(-1)))
            elif nonlinearity == "selu":
                nn.init.normal_(linear.weight, std=1 / math.sqrt(linear.weight.size(-1)))
            elif nonlinearity == "sine":
                siren.init_weights_(linear, omega_0, is_first)
            elif nonlinearity in {"sigmoid", "tanh"}:
                nn.init.xavier_normal_(linear.weight)
            elif nonlinearity is None:
                pass # this is effectively uniform(-1/sqrt(in_features), 1/sqrt(in_features))
            else:
                raise NotImplementedError(nonlinearity)
            # linear + normalize
            if normalization is None:
                yield "linear", linear
            elif normalization == "batchnorm":
                yield "linear", linear
                yield "norm", nn.BatchNorm1d(out_features, affine=True)
            elif normalization == "batchnorm_na":
                yield "linear", linear
                yield "norm", nn.BatchNorm1d(out_features, affine=False)
            elif normalization == "layernorm":
                yield "linear", linear
                yield "norm", nn.LayerNorm([out_features], elementwise_affine=True)
            elif normalization == "layernorm_na":
                yield "linear", linear
                yield "norm", nn.LayerNorm([out_features], elementwise_affine=False)
            elif normalization == "weightnorm":
                yield "linear", nn.utils.weight_norm(linear)
            else:
                raise NotImplementedError(normalization)
            # activation (submodule is named after the nonlinearity itself)
            inplace = False
            if nonlinearity is None           : pass
            elif nonlinearity == "relu"       : yield nonlinearity, nn.ReLU(inplace=inplace)
            elif nonlinearity == "leaky_relu" : yield nonlinearity, nn.LeakyReLU(negative_slope=negative_slope, inplace=inplace)
            elif nonlinearity == "silu"       : yield nonlinearity, nn.SiLU(inplace=inplace)
            elif nonlinearity == "softplus"   : yield nonlinearity, nn.Softplus()
            elif nonlinearity == "elu"        : yield nonlinearity, nn.ELU(inplace=inplace)
            elif nonlinearity == "selu"       : yield nonlinearity, nn.SELU(inplace=inplace)
            elif nonlinearity == "sine"       : yield nonlinearity, siren.Sine(omega_0)
            elif nonlinearity == "sigmoid"    : yield nonlinearity, nn.Sigmoid()
            elif nonlinearity == "tanh"       : yield nonlinearity, nn.Tanh()
            else                              : raise NotImplementedError(f"{nonlinearity=}")
            # dropout
            if dropout_prob > 0:
                if nonlinearity == "selu":
                    # AlphaDropout preserves SELU's self-normalizing property
                    yield "adropout", nn.AlphaDropout(p=dropout_prob)
                else:
                    yield "dropout", nn.Dropout(p=dropout_prob)
            # residual end
            if residual_mode is not None:
                yield "res_b", residual_layer.end
        for name, module in make_layer():
            self.add_module(name.replace("-", "_"), module)
    @property
    def nonlinearity(self) -> Optional[nn.Module]:
        "alias to the activation function submodule"
        if self.hparams.nonlinearity is None:
            return None
        return getattr(self, self.hparams.nonlinearity.replace("-", "_"))
    # NOTE(review): missing `self` — calling this on an instance raises
    # TypeError rather than the intended NotImplementedError; confirm intent.
    def initialize_weights():
        raise NotImplementedError
class FCBlock(MultilineReprHyperparametersMixin, MetaSequential):
    """
    A block of FC layers

    Layer positions are indexed 0 .. hidden_layers+1 (first layer, hidden
    layers, final layer).  `concat_skipped_layers` concatenates the block
    input x onto the chosen layers' inputs (skip connections);
    `concat_conditioned_layers` does the same with a latent code z.
    `True` selects every layer position; negative indices count from the end.
    """
    def __init__(self,
            in_features : int,
            hidden_features : int,
            hidden_layers : int,
            out_features : int,
            normalization : Normalization = None,
            nonlinearity : Nonlinearity = "relu",
            dropout_prob : float = 0.0,
            outermost_linear : bool = True, # whether last linear is nonlinear
            latent_features : Optional[int] = None,
            concat_skipped_layers : Union[list[int], bool] = [],
            concat_conditioned_layers : Union[list[int], bool] = [],
            **kw,
            ):
        super().__init__()
        self.save_hyperparameters()
        # bool shorthands: True = all layer positions, False = none
        if isinstance(concat_skipped_layers, bool):
            concat_skipped_layers = list(range(hidden_layers+2)) if concat_skipped_layers else []
        if isinstance(concat_conditioned_layers, bool):
            concat_conditioned_layers = list(range(hidden_layers+2)) if concat_conditioned_layers else []
        if len(concat_conditioned_layers) != 0 and latent_features is None:
            raise ValueError("Layers marked to be conditioned without known number of latent features")
        # resolve negative indices relative to the last position (hidden_layers+1)
        concat_skipped_layers = [i if i >= 0 else hidden_layers+2-abs(i) for i in concat_skipped_layers]
        concat_conditioned_layers = [i if i >= 0 else hidden_layers+2-abs(i) for i in concat_conditioned_layers]
        # layer index -> concat x (skip) / concat z (condition) at that layer
        self._concat_x_layers: frozenset[int] = frozenset(concat_skipped_layers)
        self._concat_z_layers: frozenset[int] = frozenset(concat_conditioned_layers)
        if len(self._concat_x_layers) != len(concat_skipped_layers):
            raise ValueError(f"Duplicates found in {concat_skipped_layers = }")
        if len(self._concat_z_layers) != len(concat_conditioned_layers):
            raise ValueError(f"Duplicates found in {concat_conditioned_layers = }")
        if not all(isinstance(i, int) for i in self._concat_x_layers):
            raise TypeError(f"Expected only integers in {concat_skipped_layers = }")
        if not all(isinstance(i, int) for i in self._concat_z_layers):
            raise TypeError(f"Expected only integers in {concat_conditioned_layers = }")
        def make_layers() -> Iterable[nn.Module]:
            # generator assembling the sequential children in order
            def make_concat_layer(*idxs: int) -> int:
                # sub-generator used via `yield from`: optionally emits a
                # CallbackConcatLayer, and *returns* how many extra input
                # features the following FCLayer must accept
                x_condition_this_layer = any(idx in self._concat_x_layers for idx in idxs)
                z_condition_this_layer = any(idx in self._concat_z_layers for idx in idxs)
                if x_condition_this_layer and z_condition_this_layer:
                    yield CallbackConcatLayer(lambda: (self._current_x, self._current_z))
                elif x_condition_this_layer:
                    yield CallbackConcatLayer(lambda: (self._current_x,))
                elif z_condition_this_layer:
                    yield CallbackConcatLayer(lambda: (self._current_z,))
                return in_features*x_condition_this_layer + (latent_features or 0)*z_condition_this_layer
            added = yield from make_concat_layer(0)
            yield FCLayer(
                in_features = in_features + added,
                out_features = hidden_features,
                nonlinearity = nonlinearity,
                normalization = normalization,
                dropout_prob = dropout_prob,
                is_first = True,
                is_final = False,
                **kw,
            )
            for i in range(hidden_layers):
                added = yield from make_concat_layer(i+1)
                yield FCLayer(
                    in_features = hidden_features + added,
                    out_features = hidden_features,
                    nonlinearity = nonlinearity,
                    normalization = normalization,
                    dropout_prob = dropout_prob,
                    is_first = False,
                    is_final = False,
                    **kw,
                )
            added = yield from make_concat_layer(hidden_layers+1)
            nl = nonlinearity
            # final layer: plain linear unless outermost_linear is disabled
            yield FCLayer(
                in_features = hidden_features + added,
                out_features = out_features,
                nonlinearity = None if outermost_linear else nl,
                normalization = None if outermost_linear else normalization,
                dropout_prob = 0.0 if outermost_linear else dropout_prob,
                is_first = False,
                is_final = True,
                **kw,
            )
        for i, module in enumerate(make_layers()):
            self.add_module(str(i), module)

    @property
    def is_conditioned(self) -> bool:
        "Whether z is used or not"
        return bool(self._concat_z_layers)

    @classmethod
    @compose("\n".join)
    def make_jinja_template(cls, *, exclude_list: set[str] = set(), top_level: bool = True, **kw) -> str:
        # BUG FIX: the default used to be `{}` (an empty *dict*, not a set);
        # `exclude_list | {...}` below then raised TypeError whenever the
        # default was used.  It now matches param.make_jinja_template's set().
        @compose(" ".join)
        def as_jexpr(values: Union[list[int]]):
            # NOTE(review): appears unused within this method — kept as-is
            yield "{{"
            for val, count in run_length_encode(values):
                yield f"[{val!r}]*{count!r}"
            yield "}}"
        yield param.make_jinja_template(cls, top_level=top_level, exclude_list=exclude_list)
        # FCLayer's own hyperparameters, minus those this block already manages
        yield param.make_jinja_template(FCLayer, top_level=False, exclude_list=exclude_list | {
            "in_features",
            "out_features",
            "nonlinearity",
            "normalization",
            "dropout_prob",
            "is_first",
            "is_final",
        })

    def forward(self, input: Tensor, z: Optional[Tensor] = None, *, params: Optional[dict[str, Tensor]]=None):
        # a conditioned block cannot run without a latent code
        assert not self.is_conditioned or z is not None
        if z is not None and z.ndim < input.ndim:
            # left-pad z with singleton dims so it broadcasts against input
            z = z[(*(None,)*(input.ndim - z.ndim), ...)]
        # stash the current tensors for the CallbackConcatLayer closures
        self._current_x = input
        self._current_z = z
        return super().forward(input, params=params)

25
ifield/modules/siren.py Normal file

@ -0,0 +1,25 @@
from math import sqrt
from torch import nn
import torch
class Sine(nn.Module):
    """Sine activation used by SIREN: x -> sin(omega_0 * x)."""

    def __init__(self, omega_0: float):
        super().__init__()
        self.omega_0 = omega_0

    def forward(self, input):
        # skip the multiplication when the frequency scale is the identity
        if self.omega_0 == 1:
            return torch.sin(input)
        return torch.sin(input * self.omega_0)
def init_weights_(module: nn.Linear, omega_0: float, is_first: bool = True):
    """SIREN weight initialization, applied in place to a nn.Linear.

    First layers use U(-1/fan_in, 1/fan_in); later layers use
    U(-sqrt(6/fan_in)/omega_0, sqrt(6/fan_in)/omega_0).
    """
    assert isinstance(module, nn.Linear), module
    with torch.no_grad():
        if is_first:
            bound = 1 / module.in_features
        else:
            bound = sqrt(6 / module.in_features) / omega_0
        module.weight.uniform_(-bound, bound)

231
ifield/param.py Normal file

@ -0,0 +1,231 @@
from .utils.helpers import compose, elementwise_max
from datetime import datetime
from torch import nn
from typing import Any, Literal, Iterable, Union, Callable, Optional
import inspect
import jinja2
import json
import os
import random
import re
import shlex
import string
import sys
import time
import typing
import warnings
import yaml
# sentinel marking "this hyperparameter has no default value"; compared with `is`
_UNDEFINED = " I AM UNDEFINED "
def _yaml_encode_value(val) -> str:
if isinstance(val, tuple):
val = list(val)
elif isinstance(val, set):
val = list(val)
if isinstance(val, list):
return json.dumps(val)
elif isinstance(val, dict):
return json.dumps(val)
else:
return yaml.dump(val).removesuffix("\n...\n").rstrip("\n")
def _raise(val: Union[Exception, str]):
if isinstance(val, str):
val = jinja2.TemplateError(val)
raise val
def make_jinja_globals(*, enable_require_defined: bool) -> dict:
    """Build the global namespace handed to jinja templates.

    Returns a dict of template helper functions plus all python builtins and
    `argv`, meant for `jinja2.Environment.globals.update(...)`.

    enable_require_defined: when False, `require_defined` merely warns about
    missing variables instead of raising (unless called with strict=True).
    """
    import builtins
    import functools
    import itertools
    import operator
    import json
    def require_defined(name, value, *defaults, failed: bool = False, strict: bool=False, exchaustive=False):
        """Template-side assertion that the variable `name` was provided.

        NOTE: `exchaustive` [sic] is kept misspelled for template
        compatibility; when set, `value` must additionally be in `defaults`.
        """
        if not defaults:
            raise ValueError("`require_defined` requires at least one valid value provided")
        if jinja2.is_undefined(value):
            assert value._undefined_name == name, \
                f"Name mismatch: {value._undefined_name=}, {name=}"
        if failed or jinja2.is_undefined(value):
            # one message for both the raising and the warning path
            message = (
                f"Required variable {name!r} "
                f"is {'incorrect' if failed else 'undefined'}! "
                f"Try providing:\n" + "\n".join(
                    f"-O{shlex.quote(name)}={shlex.quote(str(default))}"
                    for default in defaults
                )
            )
            if enable_require_defined or strict:
                raise ValueError(message)
            else:
                warnings.warn(message)
        if exchaustive and not jinja2.is_undefined(value) and value not in defaults:
            raise ValueError(
                f"Variable {name!r} not in list of allowed values: {defaults!r}"
            )
    def gen_run_uid(n: int, _choice = random.Random(time.time_ns()).choice):
        """
        generates a UID for the experiment run, nice for regexes, grepping and timekeeping.
        """
        # we have _choice, since most likely, pl.seed_everything has been run by this point
        # we store it as a default parameter to reuse it, on the off-chance of
        # two calls to this function being run within the same ns
        # BUG FIX: an unreachable duplicate `return` with an alternate date
        # format used to follow the return below; it has been removed.
        code = ''.join(_choice(string.ascii_lowercase) for _ in range(n))
        return f"{datetime.now():%Y-%m-%d-%H%M}-{code}"
    def cartesian_hparams(_map=None, **kw: dict[str, list]) -> Iterable[jinja2.utils.Namespace]:
        "Use this to bypass the common error 'SyntaxError: too many statically nested blocks'"
        if isinstance(_map, jinja2.utils.Namespace):
            kw = _map._Namespace__attrs | kw
        elif isinstance(_map, dict):
            # BUG FIX: plain dicts have no _Namespace__attrs attribute
            # (AttributeError before); merge the mapping directly instead
            kw = _map | kw
        keys, vals = zip(*kw.items())
        for i in itertools.product(*vals):
            yield jinja2.utils.Namespace(zip(keys, i))
    def ablation_hparams(_map=None, *, caartesian_keys: list[str] = None, **kw: dict[str, list]) -> Iterable[jinja2.utils.Namespace]:
        "Use this to bypass the common error 'SyntaxError: too many statically nested blocks'"
        # NOTE: `caartesian_keys` [sic] kept misspelled for template compatibility
        if isinstance(_map, jinja2.utils.Namespace):
            kw = _map._Namespace__attrs | kw
        elif isinstance(_map, dict):
            # BUG FIX: same dict-unwrapping bug as in cartesian_hparams above
            kw = _map | kw
        keys = list(kw.keys())
        caartesian_keys = [k for k in keys if k in caartesian_keys] if caartesian_keys else []
        ablation_keys = [k for k in keys if k not in caartesian_keys]
        caartesian_vals = list(map(kw.__getitem__, caartesian_keys))
        ablation_vals = list(map(kw.__getitem__, ablation_keys))
        for base_vals in itertools.product(*caartesian_vals):
            # baseline: the first value of every ablation key
            base = list(itertools.chain(zip(caartesian_keys, base_vals), zip(ablation_keys, [i[0] for i in ablation_vals])))
            yield jinja2.utils.Namespace(base)
            for ablation_key, ablation_val in zip(ablation_keys, ablation_vals):
                for val in ablation_val[1:]:
                    yield jinja2.utils.Namespace(base, **{ablation_key: val}) # ablation variation
    return {
        **locals(),
        **vars(builtins),
        "argv": sys.argv,
        "raise": _raise,
    }
def make_jinja_env(globals = make_jinja_globals(enable_require_defined=True), allow_undef=False) -> jinja2.Environment:
    """Construct the project's jinja2 Environment.

    Loads templates from the cwd or absolute paths, installs the helper
    globals and a set of encoding filters.  NOTE: the default `globals`
    dict is built once, at import time.
    """
    undefined_cls = jinja2.Undefined if allow_undef else jinja2.StrictUndefined
    env = jinja2.Environment(
        loader = jinja2.FileSystemLoader([os.getcwd(), "/"], followlinks=True),
        autoescape = False,
        trim_blocks = True,
        lstrip_blocks = True,
        undefined = undefined_cls,
        extensions = [
            "jinja2.ext.do", # statements with side-effects
            "jinja2.ext.loopcontrols", # break and continue
        ],
    )
    env.globals.update(globals)
    extra_filters = {
        "defined": lambda x: _raise(f"{x._undefined_name!r} is not defined!") if jinja2.is_undefined(x) else x,
        "repr": repr,
        "to_json": json.dumps,
        "bool": lambda x: json.dumps(bool(x)),
        "int": lambda x: json.dumps(int(x)),
        "float": lambda x: json.dumps(float(x)),
        "str": lambda x: json.dumps(str(x)),
    }
    env.filters.update(extra_filters)
    return env
def list_func_params(func: callable, exclude_list: set[str], defaults: dict={}) -> Iterable[tuple[str, Any, str]]:
    """Yield (name, default, type-description) for each configurable parameter of `func`.

    Skips a leading self/cls, excluded names, underscore-prefixed names and
    *args/**kwargs.  `defaults` overrides the signature's own defaults;
    parameters without any default yield the _UNDEFINED sentinel.  Every
    parameter must carry a type annotation.
    """
    signature = inspect.signature(func)
    for index, (name, parameter) in enumerate(signature.parameters.items()):
        if index == 0 and name in {"self", "cls"}:
            continue
        if name in exclude_list or name.startswith("_"):
            continue
        if parameter.kind in (parameter.VAR_POSITIONAL, parameter.VAR_KEYWORD):
            continue
        assert parameter.annotation is not parameter.empty, f"param {name!r} has no type annotation"
        if typing.get_origin(parameter.annotation) is Literal:
            # Literal[...] renders as an explicit set of allowed values
            literals = ", ".join(map(_yaml_encode_value, typing.get_args(parameter.annotation)))
            type_desc = f"in {{{literals}}}"
        else:
            type_desc = typing._type_repr(parameter.annotation)
        default = defaults.get(name, parameter.default)
        yield (
            name,
            default if default is not parameter.empty else _UNDEFINED,
            type_desc,
        )
@compose("\n".join)
def make_jinja_template(
        network_cls: nn.Module,
        *,
        exclude_list: set[str] = set(),
        defaults: dict[str, Any]={},
        top_level: bool = True,
        commented: bool = False,
        name=None,
        comment: Optional[str] = None,
        special_encoders: dict[str, Callable[[Any], str]]={},
        ) -> str:
    """Yield the lines of a YAML/jinja config section for `network_cls`'s
    hyperparameters; @compose joins them into a single string.

    commented: prefix every parameter line with '#'.
    special_encoders: per-key value encoders overriding _yaml_encode_value.
    NOTE(review): the mutable defaults (set()/{}) are shared across calls
    but never mutated here.
    """
    c = "#" if commented else ""
    if name is None:
        name = network_cls.__name__
    if comment is not None:
        if "\n" in comment:
            raise ValueError("newline in jinja template comment is not allowed")
    hparams = [*list_func_params(network_cls, exclude_list, defaults=defaults)]
    if not hparams:
        # nothing to document: emit just the (possibly nested) header line
        if top_level:
            yield f"{name}:"
        else:
            yield f" # {name}:"
        return
    # pre-render values; the _UNDEFINED sentinel renders as an empty string
    # (the loop variable `comment` here shadows the parameter only inside
    # the comprehension scope)
    encoded_hparams = [
        (key, _yaml_encode_value(value) if value is not _UNDEFINED else "", comment)
        if key not in special_encoders else
        (key, special_encoders[key](value) if value is not _UNDEFINED else "", comment)
        for key, value, comment in hparams
    ]
    # column widths so keys and values line up
    ml_key, ml_value = elementwise_max(
        (
            len(key),
            len(value),
        )
        for key, value, comment in encoded_hparams
    )
    if top_level:
        yield f"{name}:" if not comment else f"{name}: # {comment}"
    else:
        yield f" # {name}:" if not comment else f" # {name}: # {comment}"
    for key, value, comment in encoded_hparams:
        if key in exclude_list:
            continue
        pad_key = ml_key - len(key)
        pad_value = ml_value - len(value)
        yield f" {c}{key}{' '*pad_key} : {value}{' '*pad_value} # {comment}"
    yield ""
# helpers:
def squash_newlines(data: str) -> str:
    """Collapse every run of three-or-more newlines down to exactly two."""
    return re.sub(r'\n\n\n+', '\n\n', data)

0
ifield/utils/__init__.py Normal file

197
ifield/utils/geometry.py Normal file

@ -0,0 +1,197 @@
from torch import Tensor
from torch.nn import functional as F
from typing import Optional, Literal
import torch
from .helpers import compose
def get_ray_origins(cam2world: Tensor):
    """Extract the camera position (translation column) from cam2world matrices: (..., 4, 4) -> (..., 3)."""
    translation = cam2world[..., :3, 3]
    return translation
def camera_uv_to_rays(
        cam2world : Tensor,
        uv : Tensor,
        intrinsics : Tensor,
        ) -> tuple[Tensor, Tensor]:
    """
    Computes rays and origins from batched cam2world & intrinsics matrices, as well as pixel coordinates
    cam2world: (..., 4, 4)
    intrinsics: (..., 3, 3)
    uv: (..., n, 2)
    returns (ray_origins, ray_dirs), both (..., n, 3)
    """
    origins = get_ray_origins(cam2world)
    # one (shared) origin per ray: broadcast (..., 3) -> (..., n, 3)
    origins = origins[..., None, :].expand([*uv.shape[:-1], 3])
    directions = get_ray_directions(uv, cam2world=cam2world, intrinsics=intrinsics)
    return origins, directions
# How a ray (origin, direction) is encoded as network input; see ray_input_embedding
RayEmbedding = Literal[
    "plucker", # LFN
    "perp_foot", # PRIF
    "both",
]
def ray_input_embedding(ray_origins: Tensor, ray_dirs: Tensor, mode: RayEmbedding = "plucker", normalize_dirs=False, is_training=False):
    """
    Computes the plucker coordinates / perpendicular foot from ray origins and directions, appending it to direction

    (Rewritten without the @compose decorators: the parts are collected and
    concatenated along the last axis directly.)
    """
    assert ray_origins.shape[-1] == ray_dirs.shape[-1] == 3, \
        f"{ray_dirs.shape = }, {ray_origins.shape = }"
    if normalize_dirs:
        ray_dirs = ray_dirs / ray_dirs.norm(dim=-1, keepdim=True)
    use_moment    = mode in ("plucker", "both")
    use_perp_foot = mode in ("perp_foot", "both")
    assert use_moment or use_perp_foot
    parts = [ray_dirs]
    # the moment is needed for both embeddings
    moment = torch.cross(ray_origins, ray_dirs, dim=-1)
    if use_moment:
        parts.append(moment)
    if use_perp_foot:
        parts.append(torch.cross(ray_dirs, moment, dim=-1))
    return torch.cat(tuple(parts), dim=-1)
def ray_input_embedding_length(mode: RayEmbedding = "plucker") -> int:
    """Number of channels produced by ray_input_embedding for the given mode."""
    use_moment    = mode in ("plucker", "both")
    use_perp_foot = mode in ("perp_foot", "both")
    assert use_moment or use_perp_foot
    # 3 for the direction itself, plus 3 per enabled embedding component
    return 3 + 3 * use_moment + 3 * use_perp_foot
def parse_intrinsics(intrinsics, return_dict=False):
    """Split pinhole intrinsics matrices (..., 3, 3) into fx, fy, cx, cy (each (..., 1))."""
    parts = {
        "fx": intrinsics[..., 0, 0:1],
        "fy": intrinsics[..., 1, 1:2],
        "cx": intrinsics[..., 0, 2:3],
        "cy": intrinsics[..., 1, 2:3],
    }
    if return_dict:
        return parts
    return parts["fx"], parts["fy"], parts["cx"], parts["cy"]
def expand_as(x, y):
    """Append singleton dims to `x` until it has as many dims as `y` (no-op if x already has >= dims)."""
    missing = len(y.shape) - len(x.shape)
    if missing <= 0:
        return x
    return x[(..., *(None,) * missing)]
def lift(x, y, z, intrinsics, homogeneous=False):
    """Back-project pixel coordinates at given depths into camera space.

    x, y: pixel coordinates, shape (batch_size, num_points)
    z: depth values, broadcastable to x/y
    intrinsics: pinhole intrinsics (..., 3, 3)
    homogeneous: append a ones-component, returning (..., 4) instead of (..., 3)
    """
    fx, fy, cx, cy = parse_intrinsics(intrinsics)
    x_lifted = (x - expand_as(cx, x)) / expand_as(fx, x) * z
    y_lifted = (y - expand_as(cy, y)) / expand_as(fy, y) * z
    if not homogeneous:
        return torch.stack((x_lifted, y_lifted, z), dim=-1)
    ones = torch.ones_like(z).to(x.device)
    return torch.stack((x_lifted, y_lifted, z, ones), dim=-1)
def project(x, y, z, intrinsics):
    """Project camera-space points to pixel coordinates (inverse of `lift`).

    x, y, z: camera-space coordinates, shape (batch_size, num_points)
    intrinsics: pinhole intrinsics (..., 3, 3)
    returns stacked (u, v, z) with shape (..., 3)
    """
    fx, fy, cx, cy = parse_intrinsics(intrinsics)
    u = expand_as(fx, x) * x / z + expand_as(cx, x)
    v = expand_as(fy, y) * y / z + expand_as(cy, y)
    return torch.stack((u, v, z), dim=-1)
def world_from_xy_depth(xy, depth, cam2world, intrinsics):
    """Back-project pixel coordinates + depths into world space.

    xy: (..., n, 2) pixel coordinates; depth: (..., n)
    cam2world: (..., 4, 4); intrinsics: (..., 3, 3)
    returns world coordinates (..., n, 3)

    (Removed an unused `batch_size, *_ = cam2world.shape` unpacking.)
    """
    x_cam = xy[..., 0]
    y_cam = xy[..., 1]
    z_cam = depth
    # lift to homogeneous camera-space points, then apply the cam2world transform
    pixel_points_cam = lift(x_cam, y_cam, z_cam, intrinsics=intrinsics, homogeneous=True)
    world_coords = torch.einsum("b...ij,b...kj->b...ki", cam2world, pixel_points_cam)[..., :3]
    return world_coords
def project_point_on_ray(projection_point, ray_origin, ray_dir):
    """Orthogonally project a point onto a ray (assumes unit-length ray_dir — confirm at call sites)."""
    along = torch.einsum("...j,...j", projection_point - ray_origin, ray_dir)
    return ray_origin + along[..., None] * ray_dir
def get_ray_directions(
        xy : Tensor, # (..., N, 2)
        cam2world : Tensor, # (..., 4, 4)
        intrinsics : Tensor, # (..., 3, 3)
        ):
    """Unit-length world-space ray directions through the given pixel coordinates."""
    # lift every pixel at depth 1 into world space, then point from the camera to it
    unit_depth = torch.ones(xy.shape[:-1]).to(xy.device)
    points_world = world_from_xy_depth(xy, unit_depth, intrinsics=intrinsics, cam2world=cam2world) # (batch, num_samples, 3)
    cam_pos = cam2world[..., :3, 3]
    directions = points_world - cam_pos[..., None, :] # (batch, num_samples, 3)
    return F.normalize(directions, dim=-1)
def ray_sphere_intersect(
        ray_origins : Tensor, # (..., 3)
        ray_dirs : Tensor, # (..., 3)
        sphere_centers : Optional[Tensor] = None, # (..., 3)
        sphere_radii : Optional[Tensor] = 1, # (...)
        *,
        return_parts : bool = False,
        allow_nans : bool = True,
        improve_miss_grads : bool = False,
        ) -> tuple[Tensor, ...]:
    """Analytic ray/sphere intersection via the quadratic formula.

    The quadratic omits the |d|^2 term, so ray_dirs are presumably expected
    to be unit-length — confirm at call sites.

    Returns (near_hit, far_hit) points, or with return_parts=True:
    (closest_point_on_ray, near_hit, far_hit, is_intersecting_mask).
    allow_nans=True leaves sqrt of negative discriminants as NaN for misses;
    allow_nans=False clamps misses to 0 (or, with improve_miss_grads, to a
    small constant while routing gradients through the raw discriminant).
    """
    if improve_miss_grads: assert not allow_nans, "improve_miss_grads does not work with allow_nans"
    if sphere_centers is None:
        ray_origins_centered = ray_origins #- torch.zeros_like(ray_origins)
    else:
        ray_origins_centered = ray_origins - sphere_centers
    # half-b of the quadratic: <d, o>; discriminant^2 = <d,o>^2 - (|o|^2 - r^2)
    ray_dir_dot_origins = (ray_dirs * ray_origins_centered).sum(dim=-1, keepdim=True)
    discriminants2 = ray_dir_dot_origins**2 - ((ray_origins_centered * ray_origins_centered).sum(dim=-1) - sphere_radii**2)[..., None]
    if not allow_nans or return_parts:
        is_intersecting = discriminants2 > 0
    if allow_nans:
        discriminants = torch.sqrt(discriminants2)
    else:
        # for misses: value-shift trick keeps grad of discriminants2 while
        # evaluating sqrt at 0.001 (improve_miss_grads) or plain zero
        discriminants = torch.sqrt(torch.where(is_intersecting, discriminants2,
            discriminants2 - discriminants2.detach() + 0.001
            if improve_miss_grads else
            torch.zeros_like(discriminants2)
        ))
        assert not discriminants.detach().isnan().any() # slow, use optimizations!
    if not return_parts:
        return (
            ray_origins + ray_dirs * (- ray_dir_dot_origins - discriminants),
            ray_origins + ray_dirs * (- ray_dir_dot_origins + discriminants),
        )
    else:
        return (
            ray_origins + ray_dirs * (- ray_dir_dot_origins),
            ray_origins + ray_dirs * (- ray_dir_dot_origins - discriminants),
            ray_origins + ray_dirs * (- ray_dir_dot_origins + discriminants),
            is_intersecting.squeeze(-1),
        )

205
ifield/utils/helpers.py Normal file

@ -0,0 +1,205 @@
from functools import wraps, reduce, partial
from itertools import zip_longest, groupby
from pathlib import Path
from typing import Iterable, TypeVar, Callable, Union, Optional, Mapping, Hashable
import collections
import operator
import re
# any built-in scalar number type
Numeric = Union[int, float, complex]

T = TypeVar("T")
S = TypeVar("S")

# decorator
def compose(outer_func: Callable[[..., S], T], *outer_a, **outer_kw) -> Callable[..., T]:
    """Decorator factory: calls become outer_func(*outer_a, inner(...), **outer_kw).

    e.g. @compose("\\n".join) turns a generator of lines into one string.
    """
    def wrapper(inner_func: Callable[..., S]):
        @wraps(inner_func)
        def wrapped(*a, **kw):
            return outer_func(*outer_a, inner_func(*a, **kw), **outer_kw)
        return wrapped
    return wrapper

def compose_star(outer_func: Callable[[..., S], T], *outer_a, **outer_kw) -> Callable[..., T]:
    """Like `compose`, but unpacks the inner result: outer_func(*outer_a, *inner(...), **outer_kw)."""
    def wrapper(inner_func: Callable[..., S]):
        @wraps(inner_func)
        def wrapped(*a, **kw):
            return outer_func(*outer_a, *inner_func(*a, **kw), **outer_kw)
        return wrapped
    return wrapper

# itertools
def elementwise_max(iterable: Iterable[Iterable[T]]) -> Iterable[T]:
    """Element-wise maximum across rows (rows truncated to the shortest length)."""
    return reduce(lambda xs, ys: [*map(max, zip(xs, ys))], iterable)

def prod(numbers: Iterable[T], initial: Optional[T] = None) -> T:
    """Product of `numbers`; `initial` seeds the reduction (and is the result for empty input)."""
    if initial is not None:
        return reduce(operator.mul, numbers, initial)
    else:
        return reduce(operator.mul, numbers)

def run_length_encode(data: Iterable[T]) -> Iterable[tuple[T, int]]:
    """Lazily yield (value, run_length) for each run of consecutive equal items.

    BUG FIX: the original took len() of the itertools.groupby group, which is
    an iterator without __len__ and raised TypeError on first use; the group
    is now counted by consuming it.
    """
    return (
        (value, sum(1 for _ in group))
        for value, group in groupby(data)
    )
# text conversion
def camel_to_snake_case(text: str, sep: str = "_", join_abbreviations: bool = False) -> str:
    """CamelCase -> camel_case.  With join_abbreviations, runs of single
    capitals are merged into one part (e.g. "ABCDef" -> "abc_def")."""
    parts = (
        chunk.lower()
        for chunk in re.split(r'(?=[A-Z])', text)
        if chunk
    )
    if join_abbreviations:
        parts = list(parts)
        if len(parts) > 1:
            # walk neighbour pairs right-to-left (snapshot taken before any
            # mutation), merging each pair of adjacent one-letter parts
            neighbour_pairs = list(enumerate(zip(parts[:-1], parts[1:])))
            for idx, (left, right) in reversed(neighbour_pairs):
                if len(left) == 1 and len(right) == 1:
                    parts[idx] = parts[idx] + parts.pop(idx + 1)
    return sep.join(parts)
def snake_to_camel_case(text: str) -> str:
    """snake_case -> SnakeCase (upper camel case); empty parts are dropped.

    BUG FIX: the original called the non-existent str method `captialize()`
    (typo for `capitalize()`), raising AttributeError on any non-empty input.
    """
    return "".join(
        part.capitalize()
        for part in text.split("_")
        if part
    )
# textwrap
def columnize_dict(data: dict, n_columns=2, prefix="", sep=" ") -> str:
    """Render `data` repr-style, spread over `n_columns` side-by-side columns.

    Keys/values are sliced into `n_columns` runs of ~len/n rows; each run is
    pasted into a key/value column pair (inner `columnize`, default sep),
    then the runs are pasted next to each other (outer reduce, given `sep`).
    `prefix` is prepended to the rows of the first run only.
    """
    # rows per column, rounded so the first columns take the surplus
    sub = (len(data) + 1) // n_columns
    return reduce(partial(columnize, sep=sep),
        (
            columnize(
                "\n".join([f"{'' if n else prefix}{i!r}" for i in data.keys() ][n*sub : (n+1)*sub]),
                "\n".join([f": {i!r}," for i in data.values()][n*sub : (n+1)*sub]),
            )
            for n in range(n_columns)
        )
    )
def columnize(left: str, right: str, prefix="", sep=" ") -> str:
    """Paste two multi-line strings side by side.

    The left column is padded to its widest line; a row whose right cell is
    empty is emitted without padding or separator.
    """
    left_rows = left.split("\n")
    right_rows = right.split("\n")
    width = max(map(len, left_rows)) if left_rows else 0
    merged = []
    for l, r in zip_longest(left_rows, right_rows, fillvalue=""):
        if r:
            merged.append(f"{prefix}{l.ljust(width)}{sep}{r}")
        else:
            merged.append(f"{prefix}{l}")
    return "\n".join(merged)
# pathlib
def make_relative(path: Union[Path, str], parent: Path = None) -> Path:
    """Express `path` relative to `parent` (default: cwd), allowing a single
    leading ".." hop via the parent's parent; otherwise return `path` as-is."""
    if isinstance(path, str):
        path = Path(path)
    if parent is None:
        parent = Path.cwd()
    try:
        return path.relative_to(parent)
    except ValueError:
        pass
    try:
        return ".." / path.relative_to(parent.parent)
    except ValueError:
        return path
# dictionaries
def update_recursive(target: dict, source: dict):
    """Deep-merge `source` into `target` in place (config dictionaries)."""
    for key, value in source.items():
        if not isinstance(value, dict):
            target[key] = value
            continue
        if key not in target:
            # preserve the concrete mapping subtype of `target`
            target[key] = type(target)()
        update_recursive(target[key], value)
def map_tree(func: Callable[[T], S], val: Union[Mapping[Hashable, T], tuple[T, ...], list[T], T]) -> Union[Mapping[Hashable, S], tuple[S, ...], list[S], S]:
    """Apply `func` to every leaf of a nested structure of mappings, tuples and
    lists, rebuilding the same container shapes (mappings become plain dicts)."""
    if isinstance(val, collections.abc.Mapping):
        return {key: map_tree(func, sub) for key, sub in val.items()}
    if isinstance(val, tuple):
        return tuple(map_tree(func, sub) for sub in val)
    if isinstance(val, list):
        return [map_tree(func, sub) for sub in val]
    # leaf
    return func(val)
def flatten_tree(val, *, sep=".", prefix=None):
    """Flatten nested mappings/sequences into one dict keyed by joined paths
    ("a.b", "a.[0]").  A bare leaf with no prefix is returned unchanged."""
    if isinstance(val, collections.abc.Mapping):
        branches = ((f"{prefix}{sep}{key}" if prefix else key, sub) for key, sub in val.items())
    elif isinstance(val, (tuple, list)):
        branches = ((f"{prefix}{sep}[{idx}]" if prefix else f"[{idx}]", sub) for idx, sub in enumerate(val))
    elif prefix:
        return {prefix: val}
    else:
        return val
    return {
        key: leaf
        for subprefix, sub in branches
        for key, leaf in flatten_tree(sub, sep=sep, prefix=subprefix).items()
    }
# conversions
def hex2tuple(data: str) -> tuple[int]:
    """Parse a hex color string like "#ff00a0" into a tuple of ints (255, 0, 160)."""
    digits = data.removeprefix("#")
    return tuple(
        int(digits[i:i+2], 16)
        for i in range(0, len(digits), 2)
    )
# repr shims
class CustomRepr:
    """Wraps a fixed string so both str() and repr() return it verbatim."""
    def __init__(self, repr_str: str):
        self.repr_str = repr_str
    def __repr__(self):
        return self.repr_str
    # str() should render identically to repr()
    __str__ = __repr__
# Meta Params Module proxy
class MetaModuleProxy:
    """Attribute proxy over a module: reads prefer the `_params` mapping and
    fall back to the wrapped `_module`; writes (other than the two internal
    slots) always land in `_params`."""
    def __init__(self, module, params):
        self._module = module
        self._params = params
    def __getattr__(self, name):
        # only invoked for names not found on the proxy instance itself
        params = super().__getattribute__("_params")
        if name in params:
            return params[name]
        return getattr(super().__getattribute__("_module"), name)
    def __setattr__(self, name, value):
        if name in ("_params", "_module"):
            super().__setattr__(name, value)
        else:
            super().__getattribute__("_params")[name] = value

590
ifield/utils/loss.py Normal file

@ -0,0 +1,590 @@
from abc import abstractmethod, ABC
from dataclasses import dataclass, field, fields, MISSING
from functools import wraps
from matplotlib import pyplot as plt
from matplotlib.artist import Artist
from tabulate import tabulate
from torch import nn
from typing import Optional, TypeVar, Union
import inspect
import math
import pytorch_lightning as pl
import typing
import warnings
# typevar bound to the schedule base class, usable in annotations before the class exists
HParamSchedule = TypeVar("HParamSchedule", bound="HParamScheduleBase")
# anything parse_dsl accepts: a built schedule, a plain constant, or a DSL string
Schedulable = Union[HParamSchedule, int, float, str]
class HParamScheduleBase(ABC):
    """
    Base class for schedulable hyperparameters.

    Subclasses are @dataclass expression-tree nodes; every subclass whose
    name does not start with "_" is auto-registered in `_subclasses` and so
    becomes available in the `parse_dsl` expression namespace.  The operator
    overloads below build trees by looking up the registered class whose
    `_infix` matches the operator, so schedules compose with plain numbers.
    """
    _subclasses = {} # shared reference intended
    def __init_subclass__(cls):
        # register concrete schedule types for the parse_dsl namespace
        if not cls.__name__.startswith("_"):
            cls._subclasses[cls.__name__] = cls
    # defaults for the dataclass fields declared by subclasses (hidden from init/repr):
    _infix : Optional[str] = field(init=False, repr=False, default=None) # operator symbol for infix nodes
    _param_name : Optional[str] = field(init=False, repr=False, default=None) # set by parse_dsl; used for logging
    _expr : Optional[str] = field(init=False, repr=False, default=None) # original DSL expression; used as plot label
    def get(self, module: nn.Module, *, trainer: Optional[pl.Trainer] = None) -> float:
        """Current value: fractional-epoch train value while training, else the eval value."""
        if module.training:
            if trainer is None:
                trainer = module.trainer # this assumes `module` is a pl.LightningModule
            # fractional epoch = completed epochs + progress through the current one
            value = self.get_train_value(
                epoch = trainer.current_epoch + (trainer.fit_loop.epoch_loop.batch_progress.current.processed / trainer.num_training_batches),
            )
            # periodically log non-constant named schedules
            if trainer.logger is not None and self._param_name is not None and self.__class__ is not Const and trainer.global_step % 15 == 0:
                trainer.logger.log_metrics({
                    f"HParamSchedule/{self._param_name}": value,
                }, step=trainer.global_step)
            return value
        else:
            return self.get_eval_value()
    def _gen_data(self, n_epochs, steps_per_epoch=1000):
        # (fractional_epoch, value) samples for plotting/validation
        global_steps = 0
        for epoch in range(n_epochs):
            for step in range(steps_per_epoch):
                yield (
                    epoch + step/steps_per_epoch,
                    self.get_train_value(epoch + step/steps_per_epoch),
                )
            global_steps += steps_per_epoch
    def plot(self, *a, ax: Optional[plt.Axes] = None, **kw) -> Artist:
        """Plot the schedule curve (positional args are forwarded to _gen_data)."""
        if ax is None: ax = plt.gca()
        out = ax.plot(*zip(*self._gen_data(*a, **kw)), label=self._expr)
        ax.set_title(self._param_name)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Value")
        ax.legend()
        return out
    def assert_positive(self, *a, **kw):
        """Assert the schedule never goes negative (args forwarded to _gen_data)."""
        for epoch, val in self._gen_data(*a, **kw):
            assert val >= 0, f"{epoch=}, {val=}"
    @abstractmethod
    def get_eval_value(self) -> float:
        # value used outside training (e.g. validation/test)
        ...
    @abstractmethod
    def get_train_value(self, epoch: float) -> float:
        # value at the given fractional epoch during training
        ...
    # --- operator overloads -------------------------------------------------
    # Each overload scans the registered subclasses for the node class whose
    # `_infix` matches the operator and builds it; the r-variants put the
    # non-schedule operand on the left.
    def __add__(self, rhs):
        for cls in self._subclasses.values():
            if cls._infix == "+":
                return cls(self, rhs)
        return NotImplemented
    def __radd__(self, lhs):
        for cls in self._subclasses.values():
            if cls._infix == "+":
                return cls(lhs, self)
        return NotImplemented
    def __sub__(self, rhs):
        for cls in self._subclasses.values():
            if cls._infix == "-":
                return cls(self, rhs)
        return NotImplemented
    def __rsub__(self, lhs):
        for cls in self._subclasses.values():
            if cls._infix == "-":
                return cls(lhs, self)
        return NotImplemented
    def __mul__(self, rhs):
        for cls in self._subclasses.values():
            if cls._infix == "*":
                return cls(self, rhs)
        return NotImplemented
    def __rmul__(self, lhs):
        for cls in self._subclasses.values():
            if cls._infix == "*":
                return cls(lhs, self)
        return NotImplemented
    def __matmul__(self, rhs):
        for cls in self._subclasses.values():
            if cls._infix == "@":
                return cls(self, rhs)
        return NotImplemented
    def __rmatmul__(self, lhs):
        for cls in self._subclasses.values():
            if cls._infix == "@":
                return cls(lhs, self)
        return NotImplemented
    def __truediv__(self, rhs):
        for cls in self._subclasses.values():
            if cls._infix == "/":
                return cls(self, rhs)
        return NotImplemented
    def __rtruediv__(self, lhs):
        for cls in self._subclasses.values():
            if cls._infix == "/":
                return cls(lhs, self)
        return NotImplemented
    def __floordiv__(self, rhs):
        for cls in self._subclasses.values():
            if cls._infix == "//":
                return cls(self, rhs)
        return NotImplemented
    def __rfloordiv__(self, lhs):
        for cls in self._subclasses.values():
            if cls._infix == "//":
                return cls(lhs, self)
        return NotImplemented
    def __mod__(self, rhs):
        for cls in self._subclasses.values():
            if cls._infix == "%":
                return cls(self, rhs)
        return NotImplemented
    def __rmod__(self, lhs):
        for cls in self._subclasses.values():
            if cls._infix == "%":
                return cls(lhs, self)
        return NotImplemented
    def __pow__(self, rhs):
        for cls in self._subclasses.values():
            if cls._infix == "**":
                return cls(self, rhs)
        return NotImplemented
    def __rpow__(self, lhs):
        for cls in self._subclasses.values():
            if cls._infix == "**":
                return cls(lhs, self)
        return NotImplemented
    def __lshift__(self, rhs):
        for cls in self._subclasses.values():
            if cls._infix == "<<":
                return cls(self, rhs)
        return NotImplemented
    def __rlshift__(self, lhs):
        for cls in self._subclasses.values():
            if cls._infix == "<<":
                return cls(lhs, self)
        return NotImplemented
    def __rshift__(self, rhs):
        for cls in self._subclasses.values():
            if cls._infix == ">>":
                return cls(self, rhs)
        return NotImplemented
    def __rrshift__(self, lhs):
        for cls in self._subclasses.values():
            if cls._infix == ">>":
                return cls(lhs, self)
        return NotImplemented
    def __and__(self, rhs):
        for cls in self._subclasses.values():
            if cls._infix == "&":
                return cls(self, rhs)
        return NotImplemented
    def __rand__(self, lhs):
        for cls in self._subclasses.values():
            if cls._infix == "&":
                return cls(lhs, self)
        return NotImplemented
    def __xor__(self, rhs):
        for cls in self._subclasses.values():
            if cls._infix == "^":
                return cls(self, rhs)
        return NotImplemented
    def __rxor__(self, lhs):
        for cls in self._subclasses.values():
            if cls._infix == "^":
                return cls(lhs, self)
        return NotImplemented
    def __or__(self, rhs):
        for cls in self._subclasses.values():
            if cls._infix == "|":
                return cls(self, rhs)
        return NotImplemented
    def __ror__(self, lhs):
        for cls in self._subclasses.values():
            if cls._infix == "|":
                return cls(lhs, self)
        return NotImplemented
    # comparisons build nodes too (no reflected variants defined)
    def __ge__(self, rhs):
        for cls in self._subclasses.values():
            if cls._infix == ">=":
                return cls(self, rhs)
        return NotImplemented
    def __gt__(self, rhs):
        for cls in self._subclasses.values():
            if cls._infix == ">":
                return cls(self, rhs)
        return NotImplemented
    def __le__(self, rhs):
        for cls in self._subclasses.values():
            if cls._infix == "<=":
                return cls(self, rhs)
        return NotImplemented
    def __lt__(self, rhs):
        for cls in self._subclasses.values():
            if cls._infix == "<":
                return cls(self, rhs)
        return NotImplemented
    def __bool__(self):
        # a schedule object is always truthy; Const/_InfixBase refine this
        return True
    def __neg__(self):
        # -x is represented as (0 - x)
        for cls in self._subclasses.values():
            if cls._infix == "-":
                return cls(0, self)
        return NotImplemented
    @property
    def is_const(self) -> bool:
        # overridden by Const (True) and infix nodes (recursive check)
        return False
def parse_dsl(config: Schedulable, name=None) -> HParamSchedule:
    """Coerce `config` into a schedule object.

    Schedules pass through untouched (keeping their existing _expr/_param_name);
    strings are evaluated as DSL expressions with the registered schedule
    classes (plus `lg` = log10) as the only available names; any other value
    becomes a Const.
    """
    if isinstance(config, HParamScheduleBase):
        return config
    if isinstance(config, str):
        # builtins are stripped so only the schedule DSL names resolve
        parsed = eval(config, {"__builtins__": {}, "lg": math.log10}, HParamScheduleBase._subclasses)
        schedule = parsed if isinstance(parsed, HParamScheduleBase) else Const(parsed)
    else:
        schedule = Const(config)
    schedule._expr = config
    schedule._param_name = name
    return schedule
# decorator
def ensure_schedulables(func):
    """Decorator for (typically) __init__ methods: every argument whose
    annotation involves HParamSchedule is coerced through parse_dsl,
    including defaults the caller left unspecified."""
    signature = inspect.signature(func)
    module_name = func.__qualname__.removesuffix(".__init__")
    @wraps(func)
    def wrapper(*a, **kw):
        bound_args = signature.bind(*a, **kw)
        for param_name, param in signature.parameters.items():
            type_origin = typing.get_origin(param.annotation)
            type_args = typing.get_args (param.annotation)
            # match HParamSchedule-typed params, incl. Optional/Union variants
            if type_origin is HParamSchedule or (type_origin is Union and (HParamSchedule in type_args or HParamScheduleBase in type_args)):
                if param_name in bound_args.arguments:
                    bound_args.arguments[param_name] = parse_dsl(bound_args.arguments[param_name], name=f"{module_name}.{param_name}")
                elif param.default is not param.empty:
                    # materialize the default so it is parsed as well
                    bound_args.arguments[param_name] = parse_dsl(param.default, name=f"{module_name}.{param_name}")
        return func(
            *bound_args.args,
            **bound_args.kwargs,
        )
    return wrapper
# https://easings.net/
@dataclass
class _InfixBase(HParamScheduleBase):
    """Base for binary-operator schedule nodes; `l` and `r` may each be a
    schedule or a plain number (leading "_" keeps it out of the DSL)."""
    l : Union[HParamSchedule, int, float]
    r : Union[HParamSchedule, int, float]
    def _operation(self, l: float, r: float) -> float:
        # subclasses implement the actual arithmetic
        raise NotImplementedError
    def get_eval_value(self) -> float:
        return self._operation(
            self.l.get_eval_value() if isinstance(self.l, HParamScheduleBase) else self.l,
            self.r.get_eval_value() if isinstance(self.r, HParamScheduleBase) else self.r,
        )
    def get_train_value(self, epoch: float) -> float:
        return self._operation(
            self.l.get_train_value(epoch) if isinstance(self.l, HParamScheduleBase) else self.l,
            self.r.get_train_value(epoch) if isinstance(self.r, HParamScheduleBase) else self.r,
        )
    def __bool__(self):
        # constant trees report the truthiness of their value; dynamic trees are truthy
        if self.is_const:
            return bool(self.get_eval_value())
        else:
            return True
    @property
    def is_const(self) -> bool:
        # a node is constant iff both operands are
        return (self.l.is_const if isinstance(self.l, HParamScheduleBase) else True) \
            and (self.r.is_const if isinstance(self.r, HParamScheduleBase) else True)
@dataclass
class Add(_InfixBase):
""" adds the results of two schedules """
_infix : Optional[str] = field(init=False, repr=False, default="+")
def _operation(self, l: float, r: float) -> float:
return l + r
@dataclass
class Sub(_InfixBase):
""" subtracts the results of two schedules """
_infix : Optional[str] = field(init=False, repr=False, default="-")
def _operation(self, l: float, r: float) -> float:
return l - r
@dataclass
class Prod(_InfixBase):
    """ multiplies the results of two schedules """
    _infix : Optional[str] = field(init=False, repr=False, default="*")
    def _operation(self, l: float, r: float) -> float:
        return l * r
    @property
    def is_const(self) -> bool: # propagate identity
        # A product with a constant-zero operand is constant no matter what
        # the other operand does (0 is absorbing).  Non-constant operands are
        # left as schedule objects here, so `l == 0` / `r == 0` is False for
        # them and only the generic base-class check applies.
        l = self.l.get_eval_value() if isinstance(self.l, HParamScheduleBase) and self.l.is_const else self.l
        r = self.r.get_eval_value() if isinstance(self.r, HParamScheduleBase) and self.r.is_const else self.r
        return l == 0 or r == 0 or super().is_const
@dataclass
class Div(_InfixBase):
    """ divides the results of two schedules """
    _infix : Optional[str] = field(init=False, repr=False, default="/")
    def _operation(self, numerator: float, denominator: float) -> float:
        # NOTE(review): no guard against a zero-valued right operand — that
        # would raise ZeroDivisionError at evaluation time
        return numerator / denominator
@dataclass
class Pow(_InfixBase):
    """ raises the results of one schedule to the other """
    _infix : Optional[str] = field(init=False, repr=False, default="**")
    def _operation(self, base: float, exponent: float) -> float:
        # left operand raised to the right operand
        return base ** exponent
@dataclass
class Gt(_InfixBase):
    """ compares the results of two schedules """
    _infix : Optional[str] = field(init=False, repr=False, default=">")
    def _operation(self, lhs: float, rhs: float) -> float:
        # NOTE(review): returns a bool despite the float annotation; in
        # arithmetic contexts it behaves as 0/1
        return lhs > rhs
@dataclass
class Lt(_InfixBase):
    """ compares the results of two schedules """
    _infix : Optional[str] = field(init=False, repr=False, default="<")
    def _operation(self, lhs: float, rhs: float) -> float:
        # NOTE(review): bool result despite the float annotation (acts as 0/1)
        return lhs < rhs
@dataclass
class Ge(_InfixBase):
    """ compares the results of two schedules """
    _infix : Optional[str] = field(init=False, repr=False, default=">=")
    def _operation(self, lhs: float, rhs: float) -> float:
        # NOTE(review): bool result despite the float annotation (acts as 0/1)
        return lhs >= rhs
@dataclass
class Le(_InfixBase):
    """ compares the results of two schedules """
    _infix : Optional[str] = field(init=False, repr=False, default="<=")
    def _operation(self, lhs: float, rhs: float) -> float:
        # NOTE(review): bool result despite the float annotation (acts as 0/1)
        return lhs <= rhs
@dataclass
class Const(HParamScheduleBase):
    """ A way to ensure .get(...) exists """
    c : Union[int, float]
    def get_eval_value(self) -> float:
        return self.c
    def get_train_value(self, epoch: float) -> float:
        # training progress is irrelevant for a constant
        return self.c
    def __bool__(self):
        # truthiness follows the wrapped value
        return bool(self.c)
    @property
    def is_const(self) -> bool:
        return True
@dataclass
class Step(HParamScheduleBase):
    """ steps from 0 to 1 at `epoch` """
    epoch : float
    def get_eval_value(self) -> float:
        # evaluation assumes the step has already happened
        return 1
    def get_train_value(self, epoch: float) -> float:
        if epoch >= self.epoch:
            return 1
        return 0
@dataclass
class Linear(HParamScheduleBase):
    """ linear from 0 to 1 over `n_epochs`, delayed by `offset` """
    n_epochs : float
    offset : float = 0
    def get_eval_value(self) -> float:
        # evaluation uses the fully-ramped value
        return 1
    def get_train_value(self, epoch: float) -> float:
        # degenerate duration: treat as already finished
        if self.n_epochs <= 0:
            return 1
        progress = max(epoch - self.offset, 0) / self.n_epochs
        return min(progress, 1)
@dataclass
class EaseSin(HParamScheduleBase): # effectively 1-CosineAnnealing
    """ sinusoidal ease in-out from 0 to 1 over `n_epochs`, delayed by `offset` """
    n_epochs : float
    offset : float = 0
    def get_eval_value(self) -> float:
        # evaluation uses the fully-annealed value
        return 1
    def get_train_value(self, epoch: float) -> float:
        # fixed: guard degenerate duration (same convention as Linear) to
        # avoid ZeroDivisionError when n_epochs == 0
        if self.n_epochs <= 0: return 1
        # normalized progress in [0, 1], then cosine ease in-out
        x = min(max(epoch - self.offset, 0) / self.n_epochs, 1)
        return -(math.cos(math.pi * x) - 1) / 2
@dataclass
class EaseExp(HParamScheduleBase):
    """ exponential ease in-out from 0 to 1 over `n_epochs`, delayed by `offset` """
    n_epochs : float
    offset : float = 0
    def get_eval_value(self) -> float:
        # evaluation uses the fully-annealed value
        return 1
    def get_train_value(self, epoch: float) -> float:
        # fixed: guard degenerate duration (same convention as Linear) to
        # avoid ZeroDivisionError when n_epochs == 0
        if self.n_epochs <= 0:
            return 1
        if (epoch-self.offset) < 0:
            return 0
        if (epoch-self.offset) > self.n_epochs:
            return 1
        # normalized progress in [0, 1]
        x = min(max(epoch - self.offset, 0) / self.n_epochs, 1)
        # exponential ease: first half doubles up towards 0.5, second half
        # mirrors it approaching 1
        return (
            2**(20*x-10) / 2
            if x < 0.5 else
            (2 - 2**(-20*x+10)) / 2
        )
@dataclass
class Steps(HParamScheduleBase):
    """ Starts at 1, multiply by gamma every n epochs. Models StepLR in pytorch """
    step_size: float
    gamma: float = 0.1
    def get_eval_value(self) -> float:
        return 1
    def get_train_value(self, epoch: float) -> float:
        # number of completed decay intervals so far
        n_decays = int(epoch / self.step_size)
        return self.gamma**n_decays
@dataclass
class MultiStep(HParamScheduleBase):
    """ Starts at 1, multiply by gamma every milestone epoch. Models MultiStepLR in pytorch """
    # fixed: docstring typo "milstone" -> "milestone" (this text is shown by
    # the CLI operation listing in main())
    milestones: list[float]
    gamma: float = 0.1
    def get_eval_value(self) -> float:
        return 1
    def get_train_value(self, epoch: float) -> float:
        # walk the milestones from the last one backwards; the first milestone
        # already reached tells how many decays have been applied.
        # NOTE(review): assumes `milestones` is sorted ascending — confirm at
        # call sites
        for i, m in list(enumerate(self.milestones))[::-1]:
            if epoch >= m:
                return self.gamma**(i+1)
        return 1
@dataclass
class Epoch(HParamScheduleBase):
    """ The current epoch, starting at 0 """
    def get_eval_value(self) -> float:
        # there is no meaningful epoch at evaluation time
        return 0
    def get_train_value(self, epoch: float) -> float:
        # pass the training progress straight through
        return epoch
@dataclass
class Offset(HParamScheduleBase):
    """ Offsets the epoch for the subexpression, clamped above 0. Positive offsets makes it happen later """
    expr : Union[HParamSchedule, int, float]
    offset : float
    def get_eval_value(self) -> float:
        if isinstance(self.expr, HParamScheduleBase):
            return self.expr.get_eval_value()
        return self.expr
    def get_train_value(self, epoch: float) -> float:
        if isinstance(self.expr, HParamScheduleBase):
            # clamp so the subexpression never sees a negative epoch
            return self.expr.get_train_value(max(epoch - self.offset, 0))
        return self.expr
@dataclass
class Mod(HParamScheduleBase):
    """ The epoch in the subexptression is subject to a modulus. Use for warm restarts """
    modulus : float
    expr : Union[HParamSchedule, int, float]
    def get_eval_value(self) -> float:
        if isinstance(self.expr, HParamScheduleBase):
            return self.expr.get_eval_value()
        return self.expr
    def get_train_value(self, epoch: float) -> float:
        if isinstance(self.expr, HParamScheduleBase):
            # wrap the epoch so the subexpression restarts every `modulus`
            return self.expr.get_train_value(epoch % self.modulus)
        return self.expr
def main():
    # CLI entry point: with fewer than two args, print a usage line plus a
    # table of every registered schedule operation; otherwise parse each
    # remaining argument as a schedule DSL expression, print it, and plot all
    # of them on a shared matplotlib axis.
    import sys, rich.pretty
    if not sys.argv[2:]:
        print(f"Usage: {sys.argv[0]} n_epochs 'expression'")
        print("Available operations:")
        def mk_ops():
            # one row per registered schedule: (signature or infix symbol, docstring)
            for name, cls in HParamScheduleBase._subclasses.items():
                if isinstance(cls._infix, str):
                    yield (cls._infix, f"(infix) {cls.__doc__.strip()}")
                else:
                    # render a call signature from the dataclass fields,
                    # including defaults where present
                    yield (
                        f"""{name}({', '.join(
                            i.name
                            if i.default is MISSING else
                            f"{i.name}={i.default!r}"
                            for i in fields(cls)
                        )})""",
                        cls.__doc__.strip(),
                    )
        rich.print(tabulate(sorted(mk_ops()), tablefmt="plain"))
    else:
        n_epochs = int(sys.argv[1])
        schedules = [parse_dsl(arg, name="cli arg") for arg in sys.argv[2:]]
        ax = plt.gca()
        print("[")
        for schedule in schedules:
            rich.print(f" {schedule}, # {schedule.is_const = }")
            schedule.plot(n_epochs, ax=ax)
        print("]")
        plt.show()
# run the schedule-plotting CLI when this module is invoked directly
if __name__ == "__main__":
    main()

@ -0,0 +1,96 @@
import torch
from torch.autograd import grad
def hessian(y: torch.Tensor, x: torch.Tensor, check=False, detach=False) -> torch.Tensor:
    """
    hessian of y wrt x
    y: shape (..., Y)
    x: shape (..., X)
    return: shape (..., Y, X, X)

    check:  if True, assert the result contains no NaNs
    detach: if True, detach the second-derivative tensors from the graph
    """
    assert x.requires_grad
    assert y.grad_fn
    # one shared grad_outputs tensor for every scalar component -> less memory
    # (ones_like already lives on y's device, so no extra .to() is needed)
    grad_y = torch.ones_like(y[..., 0])
    hess = torch.stack([
        # second derivatives: differentiate each component of dydx wrt x
        torch.stack(
            gradients(
                *(dydx[..., j] for j in range(x.shape[-1])),
                wrt=x,
                grad_outputs=[grad_y]*x.shape[-1],
                detach=detach,
            ),
            dim = -2,
        )
        # first derivatives: one dydx per feature of y (kept attached so the
        # second differentiation above can run)
        for dydx in gradients(*(y[..., i] for i in range(y.shape[-1])), wrt=x)
    ], dim=-3)
    if check:
        # fixed: the check previously asserted that NaNs WERE present
        # (`assert hess.isnan().any()`), i.e. it was inverted
        assert not hess.isnan().any()
    return hess
def laplace(y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Laplacian as the divergence of the gradient field of y
    (dy_dx,) = gradients(y, wrt=x)
    return divergence(dy_dx, x)
def divergence(y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Divergence of `y` w.r.t. `x`: sum_i d(y_i)/d(x_i), shape (..., 1)."""
    assert x.requires_grad
    assert y.grad_fn
    total = 0
    for i in range(y.shape[-1]):
        # derivative of the i-th component of y w.r.t. x; keep the graph so
        # the result itself remains differentiable
        (dyi_dx,) = grad(
            y[..., i],
            x,
            torch.ones_like(y[..., i]),
            create_graph=True,
        )
        # only the matching coordinate contributes to the divergence
        total = total + dyi_dx[..., i:i+1]
    return total
def gradients(*ys, wrt, grad_outputs=None, detach=False) -> list[torch.Tensor]:
    """Gradient of each tensor in `ys` w.r.t. `wrt`, returned as a list.

    grad_outputs: optional per-y seed tensors (defaults to ones).
    detach:       if True, detach each gradient from the autograd graph.
    """
    assert wrt.requires_grad
    assert all(y.grad_fn for y in ys)
    if grad_outputs is None:
        grad_outputs = [torch.ones_like(y) for y in ys]
    out = []
    for y, y_grad in zip(ys, grad_outputs):
        (g,) = grad(
            [y],
            [wrt],
            grad_outputs=y_grad,
            create_graph=True,
        )
        out.append(g.detach() if detach else g)
    return out
def jacobian(y: torch.Tensor, x: torch.Tensor, check=False, detach=False) -> torch.Tensor:
    """
    jacobian of `y` w.r.t. `x`
    y: shape (..., Y)
    x: shape (..., X)
    return: shape (..., Y, X)

    check:  if True, assert the result contains no NaNs
    detach: if True, detach the gradients from the graph
    """
    assert x.requires_grad
    assert y.grad_fn
    # one shared grad_outputs tensor, reused for each y-component
    y_grad = torch.ones_like(y[..., 0])
    jac = torch.stack(
        gradients(
            *(y[..., i] for i in range(y.shape[-1])),
            wrt=x,
            # fixed: the list must have one entry per y-component; it used
            # x.shape[-1], which silently truncated via zip when Y != X
            grad_outputs=[y_grad]*y.shape[-1],
            detach=detach,
        ),
        dim=-2,
    )
    if check:
        # fixed: the check previously asserted that NaNs WERE present
        # (`assert jac.isnan().any()`), i.e. it was inverted
        assert not jac.isnan().any()
    return jac

Binary file not shown.

After

(image error) Size: 944 KiB

430
ifield/viewer/common.py Normal file

@ -0,0 +1,430 @@
from ..utils import geometry
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
from pytorch3d.transforms import euler_angles_to_matrix
from tqdm import tqdm
from typing import Sequence, Callable, TypedDict
import imageio
import shlex
import json
import numpy as np
import os
import time
import torch
# silence pygame's import-time support banner; must be set before the import
os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "1"
import pygame
# small tuple aliases used for resolutions, pixel coordinates and colors
IVec2 = tuple[int, int]
IVec3 = tuple[int, int, int]
Vec2 = tuple[float|int, float|int]
Vec3 = tuple[float|int, float|int, float|int]
class CamState(TypedDict, total=False):
    # serializable camera-state snapshot; all keys optional (total=False) so
    # partial updates are allowed by the cam_state setter
    distance : float  # camera distance from the pivot point
    pos_x : float     # look-at / rotation pivot position
    pos_y : float
    pos_z : float
    rot_x : float     # radians; presumably azimuth — confirm against cam2world
    rot_y : float     # radians; presumably elevation — confirm against cam2world
    fov_y : float     # vertical field of view, radians
class InteractiveViewer(ABC):
    """Pygame-backed orbit-camera viewer base class.

    Subclasses implement `render_frame` to fill a (W, H, 3) uint8 pixel view.
    `run()` drives an interactive, resizable window with keyboard/mouse camera
    controls; `render_headless()` renders a fixed number of frames to a png
    sequence or a video file instead.
    """
    constants = pygame.constants # saves an import

    # realtime
    t : float # time since start
    td : float # time delta since last frame

    # offline
    is_headless : bool
    fps : int
    frame_idx : int

    fill_color = (255, 255, 255) # clear color for window and backbuffer

    def __init__(self, name: str, res: IVec2 = (640, 480), scale: int= 1, screenshot_dir: Path = "."):
        # `scale` is an integer pixel-upscaling factor: rendering happens at
        # `res`, the window is `res * scale`
        self.name = name
        self.res = res
        self.scale = scale
        self.screenshot_dir = Path(screenshot_dir)
        self.is_headless = False
        self.cam_distance = 2.0
        self.cam_pos_x = 0.0 # look-at and rotation pivot
        self.cam_pos_y = 0.0 # look-at and rotation pivot
        self.cam_pos_z = 0.0 # look-at and rotation pivot
        self.cam_rot_x = 0.5 * torch.pi # radians
        self.cam_rot_y = -0.5 * torch.pi # radians
        self.cam_fov_y = 60.0 / 180.0 * 3.1415 # radians (NB: 3.1415, not math.pi)
        self.keep_rotating = False
        self.initial_camera_state = self.cam_state # snapshot used by the Enter-key reset
        self.fps_cap = None

    @property
    def cam_state(self) -> CamState:
        """Serializable snapshot of the current camera parameters."""
        return dict(
            distance = self.cam_distance,
            pos_x = self.cam_pos_x,
            pos_y = self.cam_pos_y,
            pos_z = self.cam_pos_z,
            rot_x = self.cam_rot_x,
            rot_y = self.cam_rot_y,
            fov_y = self.cam_fov_y,
        )

    @cam_state.setter
    def cam_state(self, new_state: CamState):
        # partial updates allowed: missing keys keep their current value
        self.cam_distance = new_state.get("distance", self.cam_distance)
        self.cam_pos_x = new_state.get("pos_x", self.cam_pos_x)
        self.cam_pos_y = new_state.get("pos_y", self.cam_pos_y)
        self.cam_pos_z = new_state.get("pos_z", self.cam_pos_z)
        self.cam_rot_x = new_state.get("rot_x", self.cam_rot_x)
        self.cam_rot_y = new_state.get("rot_y", self.cam_rot_y)
        self.cam_fov_y = new_state.get("fov_y", self.cam_fov_y)

    @property
    def scaled_res(self) -> IVec2:
        """Window resolution: backbuffer resolution times the pixel scale."""
        return (
            self.res[0] * self.scale,
            self.res[1] * self.scale,
        )

    def setup(self):
        """Hook called once before the first frame."""
        pass

    def teardown(self):
        """Hook called once after the last frame (also on error)."""
        pass

    @abstractmethod
    def render_frame(self, pixel_view: np.ndarray): # (W, H, 3) dtype=uint8
        """Fill `pixel_view` in-place with the next frame."""
        ...

    def handle_key_up(self, key: int, keys_pressed: Sequence[bool]):
        pass

    def handle_key_down(self, key: int, keys_pressed: Sequence[bool]):
        mod = keys_pressed[pygame.K_LSHIFT] or keys_pressed[pygame.K_RSHIFT]
        mod2 = keys_pressed[pygame.K_LCTRL] or keys_pressed[pygame.K_RCTRL]
        if key == pygame.K_r:
            # start continuous rotation (cleared again by arrow keys)
            self.keep_rotating = True
            self.cam_rot_x += self.td
        if key == pygame.K_MINUS:
            # coarser pixels (bigger scale -> lower render resolution)
            self.scale += 1
            if __debug__: print()
            print(f"== Scale = {self.scale} ==")
        if key == pygame.K_PLUS and self.scale > 1:
            # finer pixels
            self.scale -= 1
            if __debug__: print()
            print(f"== Scale = {self.scale} ==")
        if key == pygame.K_RETURN:
            # reset camera to the state captured at construction time
            self.cam_state = self.initial_camera_state
        if key == pygame.K_h:
            # camera-state persistence: CTRL+H prints it as a shell-quoted
            # JSON string, SHIFT+H saves camera.json, plain H loads it
            if mod2:
                print(shlex.quote(json.dumps(self.cam_state)))
            elif mod:
                with (self.screenshot_dir / "camera.json").open("w") as f:
                    json.dump(self.cam_state, f)
                print("Wrote", self.screenshot_dir / "camera.json")
            else:
                with (self.screenshot_dir / "camera.json").open("r") as f:
                    self.cam_state = json.load(f)
                print("Read", self.screenshot_dir / "camera.json")

    def handle_keys_pressed(self, pressed: Sequence[bool]) -> float:
        """Per-frame continuous input handling; returns the speed-scaled time step."""
        mod1 = pressed[pygame.K_LCTRL] or pressed[pygame.K_RCTRL]
        mod2 = pressed[pygame.K_LSHIFT] or pressed[pygame.K_RSHIFT]
        mod3 = pressed[pygame.K_LALT] or pressed[pygame.K_RALT]
        # SHIFT = slower, CTRL = faster, otherwise default speed
        td = self.td * (0.5 if mod2 else (6 if mod1 else 2))
        if pressed[pygame.K_UP]: self.cam_rot_y += td
        if pressed[pygame.K_DOWN]: self.cam_rot_y -= td
        if pressed[pygame.K_LEFT]: self.cam_rot_x += td
        if pressed[pygame.K_RIGHT]: self.cam_rot_x -= td
        if pressed[pygame.K_PAGEUP] and mod3: self.cam_distance -= td
        if pressed[pygame.K_PAGEDOWN] and mod3: self.cam_distance += td
        if any(pressed[i] for i in [pygame.K_UP, pygame.K_DOWN, pygame.K_LEFT, pygame.K_RIGHT]):
            # manual rotation cancels the continuous-rotation mode
            self.keep_rotating = False
        if self.keep_rotating: self.cam_rot_x += self.td * 0.25
        # WASD moves the pivot in the horizontal plane, relative to the
        # current camera heading
        if pressed[pygame.K_w]: self.cam_pos_x -= td * np.cos(-self.cam_rot_x)
        if pressed[pygame.K_w]: self.cam_pos_y += td * np.sin(-self.cam_rot_x)
        if pressed[pygame.K_s]: self.cam_pos_x += td * np.cos(-self.cam_rot_x)
        if pressed[pygame.K_s]: self.cam_pos_y -= td * np.sin(-self.cam_rot_x)
        if pressed[pygame.K_a]: self.cam_pos_x += td * np.sin(self.cam_rot_x)
        if pressed[pygame.K_a]: self.cam_pos_y -= td * np.cos(self.cam_rot_x)
        if pressed[pygame.K_d]: self.cam_pos_x -= td * np.sin(self.cam_rot_x)
        if pressed[pygame.K_d]: self.cam_pos_y += td * np.cos(self.cam_rot_x)
        if pressed[pygame.K_PAGEUP] and not mod3: self.cam_pos_z -= td
        if pressed[pygame.K_PAGEDOWN] and not mod3: self.cam_pos_z += td
        return td

    def handle_mouse_button_up(self, pos: IVec2, button: int, keys_pressed: Sequence[bool]):
        pass

    def handle_mouse_button_down(self, pos: IVec2, button: int, keys_pressed: Sequence[bool]):
        pass

    def handle_mouse_motion(self, pos: IVec2, rel: IVec2, buttons: Sequence[bool], keys_pressed: Sequence[bool]):
        pass

    def handle_mousewheel(self, flipped: bool, x: int, y: int, keys_pressed: Sequence[bool]):
        # ALT+wheel changes the FoV (true zoom), plain wheel dollies the camera
        if keys_pressed[pygame.K_LALT] or keys_pressed[pygame.K_RALT]:
            self.cam_fov_y -= y * 0.015
        else:
            self.cam_distance -= y * 0.2

    _current_caption = None
    def set_caption(self, title: str, *a, **kw):
        # avoid spamming stdout / the window manager with identical captions
        if self._current_caption != title and not self.is_headless:
            print(f"set_caption: {title!r}")
            self._current_caption = title
        return pygame.display.set_caption(title, *a, **kw)

    @property
    def mouse_position(self) -> IVec2:
        # mouse position in backbuffer coordinates (window coords / scale)
        mx, my = pygame.mouse.get_pos() if not self.is_headless else (0, 0)
        return (
            mx // self.scale,
            my // self.scale,
        )

    @property
    def uvs(self) -> torch.Tensor: # (w, h, 2) dtype=float32
        # cached per-pixel coordinate grid; rebuilt when the resolution changes
        res = tuple(self.res)
        if not getattr(self, "_uvs_res", None) == res:
            U, V = torch.meshgrid(
                torch.arange(self.res[1]).to(torch.float32),
                torch.arange(self.res[0]).to(torch.float32),
                indexing="xy",
            )
            self._uvs_res, self._uvs = res, torch.stack((U, V), dim=-1)
        return self._uvs

    @property
    def cam2world(self) -> torch.Tensor: # (4, 4) dtype=float32
        # cached camera-to-world transform, rebuilt when any camera parameter
        # changes.
        # NOTE(review): the staleness test uses `is not` (object identity) on
        # float attributes; it works because the exact same float objects are
        # stored back below, but equality (`!=`) would be the conventional check
        if getattr(self, "_cam2world_cam_rot_y", None) is not self.cam_rot_y \
        or getattr(self, "_cam2world_cam_rot_x", None) is not self.cam_rot_x \
        or getattr(self, "_cam2world_cam_pos_x", None) is not self.cam_pos_x \
        or getattr(self, "_cam2world_cam_pos_y", None) is not self.cam_pos_y \
        or getattr(self, "_cam2world_cam_pos_z", None) is not self.cam_pos_z \
        or getattr(self, "_cam2world_cam_distance", None) is not self.cam_distance:
            self._cam2world_cam_rot_y = self.cam_rot_y
            self._cam2world_cam_rot_x = self.cam_rot_x
            self._cam2world_cam_pos_x = self.cam_pos_x
            self._cam2world_cam_pos_y = self.cam_pos_y
            self._cam2world_cam_pos_z = self.cam_pos_z
            self._cam2world_cam_distance = self.cam_distance
            # a: dolly the camera out along its z axis
            a = torch.eye(4)
            a[2, 3] = self.cam_distance
            # b: orbit rotation plus pivot translation
            b = torch.eye(4)
            b[:3, :3] = euler_angles_to_matrix(torch.tensor((self.cam_rot_x, self.cam_rot_y, 0)), "ZYX")
            b[0:3, 3] -= torch.tensor(( self.cam_pos_x, self.cam_pos_y, self.cam_pos_z, ))
            self._cam2world = b @ a
            self._cam2world_inv = None # invalidate the cached inverse
        return self._cam2world

    @property
    def cam2world_inv(self) -> torch.Tensor: # (4, 4) dtype=float32
        # lazily computed inverse of the most recently built cam2world.
        # NOTE(review): reads self._cam2world directly — the cam2world property
        # must have been accessed first for this to be up to date
        if getattr(self, "_cam2world_inv", None) is None:
            self._cam2world_inv = torch.linalg.inv(self._cam2world)
        return self._cam2world_inv

    @property
    def intrinsics(self) -> torch.Tensor: # (3, 3) dtype=float32
        # cached pinhole intrinsics; rebuilt when resolution or FoV changes
        if getattr(self, "_intrinsics_res", None) is not self.res \
        or getattr(self, "_intrinsics_cam_fov_y", None) is not self.cam_fov_y:
            self._intrinsics_res = res = self.res
            self._intrinsics_cam_fov_y = cam_fov_y = self.cam_fov_y
            self._intrinsics = torch.eye(3)
            # NOTE(review): uses sin(fov/2) where the usual pinhole focal
            # length uses tan(fov/2) — confirm this is intentional
            p = torch.sin(torch.tensor(cam_fov_y / 2))
            s = (res[1] / 2)
            self._intrinsics[0, 0] = s/p # fx - focal length x
            self._intrinsics[1, 1] = s/p # fy - focal length y
            self._intrinsics[0, 2] = (res[1] - 1) / 2 # cx - optical center x
            self._intrinsics[1, 2] = (res[0] - 1) / 2 # cy - optical center y
        return self._intrinsics

    @property
    def raydirs_and_cam(self) -> tuple[torch.Tensor, torch.Tensor]: # (w, h, 3) and (3) dtype=float32
        # cached per-pixel ray directions plus camera position, rebuilt when
        # any of the inputs (transform, intrinsics, pixel grid) is rebuilt
        if getattr(self, "_raydirs_and_cam_cam2world", None) is not self.cam2world \
        or getattr(self, "_raydirs_and_cam_intrinsics", None) is not self.intrinsics \
        or getattr(self, "_raydirs_and_cam_uvs", None) is not self.uvs:
            self._raydirs_and_cam_cam2world = cam2world = self.cam2world
            self._raydirs_and_cam_intrinsics = intrinsics = self.intrinsics
            self._raydirs_and_cam_uvs = uvs = self.uvs
            #cam_pos = (cam2world @ torch.tensor([0, 0, 0, 1], dtype=torch.float32))[:3]
            cam_pos = cam2world[:3, -1] # translation column == transformed origin
            dirs = -geometry.get_ray_directions(uvs, cam2world[None, ...], intrinsics[None, ...]).squeeze(-1)
            self._raydirs_and_cam = (dirs, cam_pos)
        return (
            self._raydirs_and_cam[0],
            self._raydirs_and_cam[1],
        )

    def run(self):
        """Open a window and drive the interactive render/event loop until quit."""
        self.is_headless = False
        pygame.display.init() # we do not use the mixer, which often hangs on quit
        try:
            window = pygame.display.set_mode(self.scaled_res, flags=pygame.RESIZABLE)
            buffer = pygame.surface.Surface(self.res)
            window.fill(self.fill_color)
            buffer.fill(self.fill_color)
            pygame.display.flip()
            pixel_view = pygame.surfarray.pixels3d(buffer) # (W, H, 3)
            current_scale = self.scale
            def remake_window_buffer(window_size: IVec2):
                # recreate the backbuffer after a window resize or scale change
                nonlocal buffer, pixel_view, current_scale
                self.res = (
                    window_size[0] // self.scale,
                    window_size[1] // self.scale,
                )
                buffer = pygame.surface.Surface(self.res)
                pixel_view = pygame.surfarray.pixels3d(buffer)
                current_scale = self.scale
                print()
            self.setup()
            is_running = True
            clock = pygame.time.Clock()
            epoch = t_prev = time.time()
            self.frame_idx = -1
            while is_running:
                self.frame_idx += 1
                if not self.fps_cap is None: clock.tick(self.fps_cap)
                # update wall-clock bookkeeping used by input handlers
                t = time.time()
                self.td = t - t_prev
                t_prev = t
                self.t = t - epoch
                print("\rFPS:", 1/self.td, " "*10, end="")
                self.render_frame(pixel_view)
                # upscale the backbuffer into the window and present it
                pygame.transform.scale(buffer, window.get_size(), window)
                pygame.display.flip()
                keys_pressed = pygame.key.get_pressed()
                self.handle_keys_pressed(keys_pressed)
                for event in pygame.event.get():
                    if event.type == pygame.VIDEORESIZE:
                        print()
                        print("== resize window ==")
                        remake_window_buffer(event.size)
                    elif event.type == pygame.QUIT:
                        is_running = False
                    elif event.type == pygame.KEYUP:
                        self.handle_key_up(event.key, keys_pressed)
                    elif event.type == pygame.KEYDOWN:
                        self.handle_key_down(event.key, keys_pressed)
                        if event.key == pygame.K_q: # quit
                            is_running = False
                        elif event.key == pygame.K_y: # save a screenshot
                            fname = self.mk_dump_fname("png")
                            fname.parent.mkdir(parents=True, exist_ok=True)
                            pygame.image.save(buffer.copy(), fname)
                            print()
                            print("Saved", fname)
                    elif event.type == pygame.MOUSEBUTTONUP:
                        self.handle_mouse_button_up(event.pos, event.button, keys_pressed)
                    elif event.type == pygame.MOUSEBUTTONDOWN:
                        self.handle_mouse_button_down(event.pos, event.button, keys_pressed)
                    elif event.type == pygame.MOUSEMOTION:
                        self.handle_mouse_motion(event.pos, event.rel, event.buttons, keys_pressed)
                    elif event.type == pygame.MOUSEWHEEL:
                        self.handle_mousewheel(event.flipped, event.x, event.y, keys_pressed)
                if current_scale != self.scale:
                    # scale changed by a key handler this frame
                    remake_window_buffer(window.get_size())
        finally:
            self.teardown()
            print()
            pygame.quit()

    def render_headless(self, output_path: str, *, n_frames: int, fps: int, state_callback: Callable[["InteractiveViewer", int], None] | None, resolution=None, bitrate=None, **kw):
        """Render `n_frames` frames without opening a window.

        Writes a png sequence when `output_path` ends in .png (a "%" in the
        name is replaced by the frame index), otherwise a video via
        imageio/ffmpeg. `state_callback(viewer, frame)` may mutate the camera
        between frames; extra kwargs are forwarded to tqdm.
        """
        self.is_headless = True
        self.fps = fps
        buffer = pygame.surface.Surface(self.res if resolution is None else resolution)
        pixel_view = pygame.surfarray.pixels3d(buffer) # (W, H, 3)
        def do():
            # generator yielding one (H, W, 3) frame copy at a time
            try:
                self.setup()
                for frame in tqdm(range(n_frames), **kw, disable=n_frames==1):
                    self.frame_idx = frame
                    if state_callback is not None:
                        state_callback(self, frame)
                    self.render_frame(pixel_view)
                    yield pixel_view.copy().swapaxes(0,1)
            finally:
                self.teardown()
        output_path = Path(output_path)
        if output_path.suffix == ".png":
            # multi-frame png output needs a "%" placeholder in the file name
            if n_frames > 1 and "%" not in output_path.name: raise ValueError
            output_path.parent.mkdir(parents=True, exist_ok=True)
            for i, framebuffer in enumerate(do()):
                with imageio.get_writer(output_path.parent / output_path.name.replace("%", f"{i:04}")) as writer:
                    writer.append_data(framebuffer)
        else: # ffmpeg - https://imageio.readthedocs.io/en/v2.9.0/format_ffmpeg.html#ffmpeg
            with imageio.get_writer(output_path, fps=fps, bitrate=bitrate) as writer:
                for framebuffer in do():
                    writer.append_data(framebuffer)

    def load_sphere_map(self, fname):
        """Load a panorama image used as the spherical background."""
        self._sphere_surf = pygame.image.load(fname)
        self._sphere_map = pygame.surfarray.pixels3d(self._sphere_surf)

    def lookup_sphere_map_dirs(self, dirs, origins):
        # intersect the rays with a large enclosing sphere and sample the
        # panorama at the spherical coordinates of the far hit points
        near, far = geometry.ray_sphere_intersect(
            torch.tensor(origins),
            torch.tensor(dirs),
            sphere_radii = torch.tensor(origins).norm(dim=-1) * 2,
        )
        hits = far.detach()
        x = hits[..., 0]
        y = hits[..., 1]
        z = hits[..., 2]
        # polar angle from z, azimuth from x/y
        theta = (z / hits.norm(dim=-1)).acos()
        phi = (y/x).atan()
        # shift atan results into the correct quadrant
        # NOTE(review): uses 3.14 rather than torch.pi — tiny seam at the wrap
        phi[(x<0) & (y>=0)] += 3.14
        phi[(x<0) & (y< 0)] -= 3.14
        w, h = self._sphere_map.shape[:2]
        return self._sphere_map[
            ((phi / (2*torch.pi) * w).int() % w).cpu(),
            ((theta / (1*torch.pi) * h).int() % h).cpu(),
        ]

    def blit_sphere_map_mask(self, pixel_view, mask=None):
        """Fill `pixel_view` (or only its masked part) with the sphere map."""
        dirs, origin = self.raydirs_and_cam
        if mask is None: mask = (slice(None), slice(None))
        pixel_view[mask] \
            = self.lookup_sphere_map_dirs(dirs, origin[None, None, :])

    def mk_dump_fname(self, suffix: str, uid=None) -> Path:
        """Build a timestamped screenshot/dump path under `screenshot_dir`."""
        # overly long names are truncated to their last dash-separated part
        name = self.name.split("-")[-1] if len(self.name) > 160 else self.name
        if uid is not None: name = f"{name}-{uid}"
        return self.screenshot_dir / f"pygame-viewer-{datetime.now():%Y%m%d-%H%M%S}-{name}.{suffix}"

792
ifield/viewer/ray_field.py Normal file

@ -0,0 +1,792 @@
from ..data.common.scan import SingleViewUVScan
import mesh_to_sdf.scan as sdf_scan
from ..models import intersection_fields
from ..utils import geometry, helpers
from ..utils.operators import diff
from .common import InteractiveViewer
from matplotlib import cm
import matplotlib.colors as mcolors
from concurrent.futures import ThreadPoolExecutor
from textwrap import dedent
from typing import Hashable, Optional, Callable
from munch import Munch
import functools
import itertools
import numpy as np
import random
from pathlib import Path
import shutil
import subprocess
import torch
from trimesh import Trimesh
import trimesh.transformations as T
class ModelViewer(InteractiveViewer):
lambertian_color = (1.0, 1.0, 1.0)
max_cols = 200
max_cols = 32
def __init__(self,
model : intersection_fields.IntersectionFieldAutoDecoderModel,
start_uid : Hashable,
skyward : str = "+Z",
mesh_gt_getter: Callable[[Hashable], Trimesh] | None = None,
*a, **kw):
self.model = model
self.model.eval()
self.current_uid = self._prev_uid = start_uid
self.all_uids = list(model.keys())
self.mesh_gt_getter = mesh_gt_getter
self.current_gt_mesh: tuple[Hashable, Trimesh] = (None, None)
self.display_mode_normals = self.vizmodes_normals .index("medial" if self.model.hparams.output_mode == "medial_sphere" else "analytical")
self.display_mode_shading = self.vizmodes_shading .index("lambertian")
self.display_mode_centroid = self.vizmodes_centroids.index("best-centroids-colored")
self.display_mode_spheres = self.vizmodes_spheres .index(None)
self.display_mode_variation = 0
self.display_sphere_map_bg = True
self.atom_radius_offset = 0
self.atom_index_solo = None
self.export_medial_surface_mesh = False
self.light_angle1 = 0
self.light_angle2 = 0
self.obj_rot = {
"-X": torch.tensor(T.rotation_matrix(angle= np.pi/2, direction=(0, 1, 0))[:3, :3], **model.device_and_dtype).T,
"+X": torch.tensor(T.rotation_matrix(angle=-np.pi/2, direction=(0, 1, 0))[:3, :3], **model.device_and_dtype).T,
"-Y": torch.tensor(T.rotation_matrix(angle= np.pi/2, direction=(1, 0, 0))[:3, :3], **model.device_and_dtype).T,
"+Y": torch.tensor(T.rotation_matrix(angle=-np.pi/2, direction=(1, 0, 0))[:3, :3], **model.device_and_dtype).T,
"-Z": torch.tensor(T.rotation_matrix(angle= np.pi, direction=(1, 0, 0))[:3, :3], **model.device_and_dtype).T,
"+Z": torch.eye(3, **model.device_and_dtype),
}[str(skyward).upper()]
self.obj_rot_inv = torch.linalg.inv(self.obj_rot)
super().__init__(*a, **kw)
vizmodes_normals = (
"medial",
"analytical",
"ground_truth",
)
vizmodes_shading = (
None, # just atoms or medial axis
"colored-lambertian",
"lambertian",
"shade-best-radii",
"shade-all-radii",
"translucent",
"normal",
"centroid-grad-norm", # backprop
"anisotropic", # backprop
"curvature", # backprop
"glass",
"double-glass",
)
vizmodes_centroids = (
None,
"best-centroids",
"all-centroids",
"best-centroids-colored",
"all-centroids-colored",
"miss-centroids-colored",
"all-miss-centroids-colored",
)
vizmodes_spheres = (
None,
"intersecting-sphere",
"intersecting-sphere-colored",
"best-sphere",
"best-sphere-colored",
"all-spheres-colored",
)
def get_display_mode(self) -> tuple[str, str, Optional[str], Optional[str]]:
MARF = self.model.hparams.output_mode == "medial_sphere"
if isinstance(self.display_mode_normals, str): self.display_mode_normals = self.vizmodes_shading .index(self.display_mode_normals)
if isinstance(self.display_mode_shading, str): self.display_mode_shading = self.vizmodes_shading .index(self.display_mode_shading)
if isinstance(self.display_mode_centroid, str): self.display_mode_centroid = self.vizmodes_centroids.index(self.display_mode_centroid)
if isinstance(self.display_mode_spheres, str): self.display_mode_spheres = self.vizmodes_spheres .index(self.display_mode_spheres)
out = (
self.vizmodes_normals [self.display_mode_normals % len(self.vizmodes_normals)],
self.vizmodes_shading [self.display_mode_shading % len(self.vizmodes_shading)],
self.vizmodes_centroids[self.display_mode_centroid % len(self.vizmodes_centroids)] if MARF else None,
self.vizmodes_spheres [self.display_mode_spheres % len(self.vizmodes_spheres)] if MARF else None,
)
self.set_caption(" & ".join(i for i in out if i is not None))
return out
@property
def cam_state(self):
return super().cam_state | {
"light_angle1" : self.light_angle1,
"light_angle2" : self.light_angle2,
}
@cam_state.setter
def cam_state(self, new_state):
InteractiveViewer.cam_state.fset(self, new_state)
self.light_angle1 = new_state.get("light_angle1", self.light_angle1)
self.light_angle2 = new_state.get("light_angle2", self.light_angle2)
def get_current_conditioning(self) -> Optional[torch.Tensor]:
if not self.model.is_conditioned:
return None
prev_uid = self._prev_uid # to determine if target has changed
next_z = self.model[prev_uid].detach() # interpolation target
prev_z = getattr(self, "_prev_z", next_z) # interpolation source
epoch = getattr(self, "_prev_epoch", 0) # interpolation factor
if not self.is_headless:
now = self.t
t = (now - epoch) / 1 # 1 second
else:
now = self.frame_idx
t = (now - epoch) / self.fps # 1 second
assert t >= 0
if t < 1:
next_z = next_z*t + prev_z*(1-t)
if prev_uid != self.current_uid:
self._prev_uid = self.current_uid
self._prev_z = next_z
self._prev_epoch = now
return next_z
def get_current_ground_truth(self) -> Trimesh | None:
if self.mesh_gt_getter is None:
return None
uid, mesh = self.current_gt_mesh
try:
if uid != self.current_uid:
print("Loading ground truth mesh...")
mesh = self.mesh_gt_getter(self.current_uid)
self.current_gt_mesh = self.current_uid, mesh
except NotImplementedError:
self.current_gt_mesh = self.current_uid, None
return None
return mesh
def handle_keys_pressed(self, pressed):
td = super().handle_keys_pressed(pressed)
mod = pressed[self.constants.K_LALT] or pressed[self.constants.K_RALT]
if not mod and pressed[self.constants.K_f]: self.light_angle1 -= td * 0.5
if not mod and pressed[self.constants.K_g]: self.light_angle1 += td * 0.5
if mod and pressed[self.constants.K_f]: self.light_angle2 += td * 0.5
if mod and pressed[self.constants.K_g]: self.light_angle2 -= td * 0.5
return td
def handle_key_down(self, key, keys_pressed):
super().handle_key_down(key, keys_pressed)
shift = keys_pressed[self.constants.K_LSHIFT] or keys_pressed[self.constants.K_RSHIFT]
if key == self.constants.K_o:
i = self.all_uids.index(self.current_uid)
i = (i - 1) % len(self.all_uids)
self.current_uid = self.all_uids[i]
print(self.current_uid)
if key == self.constants.K_p:
i = self.all_uids.index(self.current_uid)
i = (i + 1) % len(self.all_uids)
self.current_uid = self.all_uids[i]
print(self.current_uid)
if key == self.constants.K_SPACE:
self.display_sphere_map_bg = {
True : 255,
255 : 0,
0 : True,
}.get(self.display_sphere_map_bg, True)
if key == self.constants.K_u: self.export_medial_surface_mesh = True
if key == self.constants.K_x: self.display_mode_normals += -1 if shift else 1
if key == self.constants.K_c: self.display_mode_shading += -1 if shift else 1
if key == self.constants.K_v: self.display_mode_centroid += -1 if shift else 1
if key == self.constants.K_b: self.display_mode_spheres += -1 if shift else 1
if key == self.constants.K_e: self.display_mode_variation+= -1 if shift else 1
if key == self.constants.K_c: self.display_mode_variation = 0
if key == self.constants.K_0: self.atom_index_solo = None
if key == self.constants.K_1: self.atom_index_solo = 0 if self.atom_index_solo != 0 else None
if key == self.constants.K_2: self.atom_index_solo = 1 if self.atom_index_solo != 1 else None
if key == self.constants.K_3: self.atom_index_solo = 2 if self.atom_index_solo != 2 else None
if key == self.constants.K_4: self.atom_index_solo = 3 if self.atom_index_solo != 3 else None
if key == self.constants.K_5: self.atom_index_solo = 4 if self.atom_index_solo != 4 else None
if key == self.constants.K_6: self.atom_index_solo = 5 if self.atom_index_solo != 5 else None
if key == self.constants.K_7: self.atom_index_solo = 6 if self.atom_index_solo != 6 else None
if key == self.constants.K_8: self.atom_index_solo = 7 if self.atom_index_solo != 7 else None
if key == self.constants.K_9: self.atom_index_solo = self.atom_index_solo + (-1 if shift else 1) if self.atom_index_solo is not None else 0
def handle_mouse_button_down(self, pos, button, keys_pressed):
super().handle_mouse_button_down(pos, button, keys_pressed)
if button in (1, 3):
self.display_mode_spheres += 1 if button == 1 else -1
def handle_mousewheel(self, flipped, x, y, keys_pressed):
shift = keys_pressed[self.constants.K_LSHIFT] or keys_pressed[self.constants.K_RSHIFT]
if not shift:
super().handle_mousewheel(flipped, x, y, keys_pressed)
else:
self.atom_radius_offset += 0.005 * y
print()
print("atom_radius_offset:", self.atom_radius_offset)
def setup(self):
if not self.is_headless:
print(dedent("""
WASD + PG Up/Down - translate
ARROWS - rotate
(SHIFT+) C - Next/(Prev) shading mode
(SHIFT+) V - Next/(Prev) centroids mode
(SHIFT+) B - Next/(Prev) sphere mode
Mouse L/ R - Next/ Prev sphere mode
(SHIFT+) E - Next/(Prev) variation (for quick experimentation within a shading mode)
SHIFT + Scroll - Offset atom radius
ALT + Scroll - Modify FoV (_true_ zoom)
Mouse Scroll - Translate in/out ("zoom", moves camera to/from to point of focus)
Alt+PG Up/Down - Translate in/out ("zoom", moves camera to/from to point of focus)
F / G - rotate light left / right
ALT+ F / G - rotate light up / down
CTRL / SHIFT - faster/slower rotation
O / P - prev/next object
1-9 - solo atom
0 - show all atoms
+ / - - decrease/increase pixel scale
R - rotate continuously
H / SHIFT+H / CTRL+H - load/save/print camera state
Enter - reset camera state
Y - save screenshot
U - save mesh of centroids
Space - cycle sphere map background
Q - quit
""").strip())
fname = Path(__file__).parent.resolve() / "assets/texturify_pano-1-4.jpg"
self.load_sphere_map(fname)
if self.model.hparams.output_mode == "medial_sphere":
@self.model.net.register_forward_hook
def atom_offset_radius_and_solo(model, input, output):
slice = (..., [i+3 for i in range(0, output.shape[-1], 4)])
output[slice] += self.atom_radius_offset * output[slice].sign()
if self.atom_index_solo is not None:
x = self.atom_index_solo * 4
x = x % output.shape[-1]
output = output[..., list(range(x, x+4))]
return output
self._atom_offset_radius_and_solo_hook = atom_offset_radius_and_solo
def teardown(self):
if hasattr(self, "_atom_offset_radius_and_solo_hook"):
self._atom_offset_radius_and_solo_hook.remove()
del self._atom_offset_radius_and_solo_hook
@torch.no_grad()
def render_frame(self, pixel_view: np.ndarray): # (W, H, 3) dtype=uint8
    """Render one frame of the interactive viewer into `pixel_view` (in place).

    Either traces the camera rays through the model (MARF = medial spheres,
    PRIF = orthogonal planes) or rasterizes the ground-truth mesh, then
    shades the hits according to the active display mode, and finally
    overlays medial spheres and/or centroids when those modes are enabled.
    """
    MARF = self.model.hparams.output_mode == "medial_sphere"
    PRIF = self.model.hparams.output_mode == "orthogonal_plane"
    assert (MARF or PRIF) and MARF != PRIF
    device_and_dtype = self.model.device_and_dtype
    device = self.model.device
    dtype = self.model.dtype
    (
        vizmode_normals,
        vizmode_shading,
        vizmode_centroids,
        vizmode_spheres,
    ) = self.get_display_mode()
    dirs, origins = self.raydirs_and_cam
    origins = origins.detach().clone().to(**device_and_dtype)
    dirs = dirs .detach().clone().to(**device_and_dtype)
    if vizmode_normals != "ground_truth" or self.get_current_ground_truth() is None:
        # enable grad or not — gradients are only needed for analytical
        # normals (jacobians), the centroid-grad-norm mode, or the shape
        # operator (anisotropic/curvature shading)
        do_jac = PRIF or vizmode_normals == "analytical"
        do_jac_medial = MARF and "centroid-grad-norm" in (vizmode_shading or "")
        do_shape_operator = "anisotropic" in (vizmode_shading or "") or "curvature" in (vizmode_shading or "")
        do_grad = do_jac or do_jac_medial or do_shape_operator
        if do_grad:
            origins = origins.broadcast_to(dirs.shape)
        self.model.eval()
        latent = self.get_current_conditioning()
        # split the image into column chunks to bound peak memory when
        # differentiating; a single full-image chunk otherwise
        if self.max_cols is None or self.max_cols > dirs.shape[0]:
            chunks = [slice(None)]
        else:
            chunks = [slice(col, col+self.max_cols) for col in range(0, dirs.shape[0], self.max_cols)]
        forward_chunks = []
        for chunk in chunks:
            self.model.zero_grad()
            origins_chunk = origins[chunk if origins.ndim != 1 else slice(None)] @ self.obj_rot
            dirs_chunk = dirs [chunk] @ self.obj_rot
            if do_grad:
                origins_chunk.requires_grad = dirs_chunk.requires_grad = True
            # NB: the decorator stack below is an immediately-invoked function:
            # the middle decorator *calls* forward_chunk with this chunk (with
            # grad enabled as requested) and the outer one appends its Munch
            # result to forward_chunks. `forward_chunk` is never a callback.
            @forward_chunks.append
            @(lambda f: f(origins_chunk, dirs_chunk))
            @torch.set_grad_enabled(do_grad)
            def forward_chunk(origins, dirs) -> Munch:
                if PRIF:
                    intersections, is_intersecting = self.model(dict(origins=origins, dirs=dirs), z=latent, normalize_origins=True)
                    is_intersecting = is_intersecting > 0.5
                elif MARF:
                    (
                        depths, silhouettes, intersections,
                        intersection_normals, is_intersecting,
                        sphere_centers, sphere_radii,
                        atom_indices,
                        all_intersections, all_intersection_normals, all_depths, all_silhouettes, all_is_intersecting,
                        all_sphere_centers, all_sphere_radii,
                    ) = self.model.forward(dict(origins=origins, dirs=dirs), z=latent,
                        intersections_only = False,
                        return_all_atoms = True,
                    )
                if do_jac:
                    # analytical surface normals from d(intersection)/d(origin)
                    jac = diff.jacobian(intersections, origins, detach=not do_shape_operator)
                    intersection_normals = self.model.compute_normals_from_intersection_origin_jacobian(jac, dirs.detach())
                if do_jac_medial:
                    sphere_centers_jac = diff.jacobian(sphere_centers, origins, detach=True)
                if do_shape_operator:
                    hess = diff.jacobian(intersection_normals, origins, detach=True)[is_intersecting, :, :]
                    N = intersection_normals.detach()[is_intersecting, :]
                    TM = (torch.eye(3, device=device) - N[..., None, :]*N[..., :, None]) # projection onto tangent plane
                    # shape operator, i.e. total derivative of the surface normal w.r.t. the tangent space
                    shape_operator = hess @ TM
                # export every local tensor by name, detached
                return Munch((k, v.detach()) for k, v in locals().items() if isinstance(v, torch.Tensor))
        # stitch the per-chunk results back into full-image tensors
        intersections = torch.cat([chunk.intersections for chunk in forward_chunks], dim=0)
        is_intersecting = torch.cat([chunk.is_intersecting for chunk in forward_chunks], dim=0)
        intersection_normals = torch.cat([chunk.intersection_normals for chunk in forward_chunks], dim=0)
        if MARF:
            all_sphere_centers = torch.cat([chunk.all_sphere_centers for chunk in forward_chunks], dim=0)
            all_sphere_radii = torch.cat([chunk.all_sphere_radii for chunk in forward_chunks], dim=0)
            atom_indices = torch.cat([chunk.atom_indices for chunk in forward_chunks], dim=0)
            silhouettes = torch.cat([chunk.silhouettes for chunk in forward_chunks], dim=0)
            sphere_centers = torch.cat([chunk.sphere_centers for chunk in forward_chunks], dim=0)
            sphere_radii = torch.cat([chunk.sphere_radii for chunk in forward_chunks], dim=0)
            if do_jac_medial:
                sphere_centers_jac = torch.cat([chunk.sphere_centers_jac for chunk in forward_chunks], dim=0)
            if do_shape_operator:
                shape_operator = torch.cat([chunk.shape_operator for chunk in forward_chunks], dim=0)
        n_atoms = all_sphere_centers.shape[-2] if MARF else 1
        # undo the object rotation applied to the rays before the forward pass
        intersections = intersections @ self.obj_rot_inv
        intersection_normals = intersection_normals @ self.obj_rot_inv
        sphere_centers = sphere_centers @ self.obj_rot_inv if sphere_centers is not None else None
        all_sphere_centers = all_sphere_centers @ self.obj_rot_inv if all_sphere_centers is not None else None
    else: # render ground truth mesh
        # HACK: we use a thread to not break the pygame opengl context
        with ThreadPoolExecutor(max_workers=1) as p:
            scan = p.submit(sdf_scan.Scan, self.get_current_ground_truth(),
                camera_transform = self.cam2world.numpy(),
                resolution = self.res[1],
                calculate_normals = True,
                fov = self.cam_fov_y,
                z_near = 0.001,
                z_far = 50,
                no_flip_backfaced_normals = True
            ).result()
        # present the scan as if it were a single-atom PRIF forward pass so
        # the shading code below needs no special-casing; the square scan is
        # centered horizontally in the (possibly wider) viewport
        n_atoms, MARF, PRIF = 1, False, True
        is_intersecting = torch.zeros(self.res, dtype=bool)
        is_intersecting[ (self.res[0]-self.res[1]) // 2 : (self.res[0]-self.res[1]) // 2 + self.res[1], : ] = torch.tensor(scan.depth_buffer != 0, dtype=bool)
        intersections = torch.zeros((*is_intersecting.shape, 3), dtype=dtype)
        intersection_normals = torch.zeros((*is_intersecting.shape, 3), dtype=dtype)
        intersections [is_intersecting] = torch.tensor(scan.points, dtype=dtype)
        intersection_normals[is_intersecting] = torch.tensor(scan.normals, dtype=dtype)
        is_intersecting = is_intersecting .flip(1).to(device)
        intersections = intersections .flip(1).to(device)
        intersection_normals = intersection_normals.flip(1).to(device)
    mask = is_intersecting.cpu()
    mx, my = self.mouse_position
    w, h = dirs.shape[:2]
    # fill white / background
    # NB: `== True` (not truthiness) — the attribute may instead hold a color
    if self.display_sphere_map_bg == True:
        self.blit_sphere_map_mask(pixel_view)
    else:
        pixel_view[:] = self.display_sphere_map_bg
    # draw to buffer
    to_cam = -dirs.detach()
    # light direction: two rotations about camera axes, flipped to backlight
    # in translucent mode
    extra = np.pi if vizmode_shading == "translucent" else 0
    LM = torch.tensor(T.rotation_matrix(angle=self.light_angle2, direction=(0, 1, 0))[:3, :3], dtype=dtype)
    LM = torch.tensor(T.rotation_matrix(angle=self.light_angle1 + extra, direction=(1, 0, 0))[:3, :3], dtype=dtype) @ LM
    to_light = (self.cam2world[:3, :3] @ LM @ torch.tensor((1, 1, 3), dtype=dtype)).to(device)[None, :]
    to_light = to_light / to_light.norm(dim=-1, keepdim=True)
    # used to color different atom candidates
    color_set = tuple(map(helpers.hex2tuple,
        itertools.chain(
            mcolors.TABLEAU_COLORS.values(),
            #list(mcolors.TABLEAU_COLORS.values())[::-1],
            #['#f8481c', '#c20078', '#35530a', '#010844', '#a8ff04'],
            mcolors.XKCD_COLORS.values(),
        )
    ))
    # one color per atom index, cycling through color_set
    color_per_atom = (*zip(*zip(range(n_atoms), itertools.cycle(color_set))),)[1]
    # shade hits
    if vizmode_shading is None:
        pass
    elif vizmode_shading == "colored-lambertian":
        if n_atoms > 1:
            color = torch.tensor(color_per_atom, device=device)[(*atom_indices[is_intersecting].T,)]
        else:
            color = torch.tensor(color_set[(0 if self.atom_index_solo is None else self.atom_index_solo) % len(color_set)], device=device)
        lambertian = torch.einsum("id,id->i",
            intersection_normals[is_intersecting, :],
            to_light,
        )[..., None]
        # NOTE(review): this first write appears to be dead code — it is
        # immediately overwritten by the write below; confirm before removing
        pixel_view[mask, :] = (color *
            torch.einsum("id,id->i",
                intersection_normals[is_intersecting, :],
                to_cam[is_intersecting, :],
            )[..., None]).int().cpu()
        # diffuse term plus a white specular-ish highlight (lambertian^32)
        pixel_view[mask, :] = (
            255 * lambertian.clamp(0, 1).pow(32) +
            color * (lambertian + 0.25).clamp(0, 1) * (1-lambertian.clamp(0, 1).pow(32))
        ).cpu()
    elif vizmode_shading == "lambertian":
        lambertian = torch.einsum("id,id->i",
            intersection_normals[is_intersecting, :],
            to_light,
        )[..., None].clamp(0, 1)
        if self.lambertian_color == (1.0, 1.0, 1.0):
            pixel_view[mask, :] = (255 * lambertian).cpu()
        else:
            color = 255*torch.tensor(self.lambertian_color, device=device)
            # NOTE(review): same dead double-write pattern as in
            # "colored-lambertian" above — the next statement overwrites this
            pixel_view[mask, :] = (color *
                torch.einsum("id,id->i",
                    intersection_normals[is_intersecting, :],
                    to_cam[is_intersecting, :],
                )[..., None]).int().cpu()
            pixel_view[mask, :] = (
                255 * lambertian.clamp(0, 1).pow(32) +
                color * (lambertian + 0.25).clamp(0, 1) * (1-lambertian.clamp(0, 1).pow(32))
            ).cpu()
    elif vizmode_shading == "translucent" and MARF:
        # fake subsurface scattering, thickness taken from the sphere radii
        lambertian = torch.einsum("id,id->i",
            intersection_normals[is_intersecting, :],
            to_light,
        )[..., None].abs().clamp(0, 1)
        distortion = 0.08
        power = 16
        ambient = 0
        thickness = sphere_radii[is_intersecting].detach()
        if self.display_mode_variation % 2:
            thickness = thickness.mean()
        color1 = torch.tensor((1, 0.5, 0.5), **device_and_dtype) # subsurface
        color2 = torch.tensor((0, 1, 1), **device_and_dtype) # diffuse
        l = to_light + intersection_normals[is_intersecting, :] * distortion
        d = (to_cam[is_intersecting, :] * -l).sum(dim=-1).clamp(0, None).pow(power)
        f = (d + ambient) * (1/(0.05 + thickness))
        pixel_view[((dirs * to_light).sum(dim=-1) > 0.99).cpu(), :] = 255 # draw light source
        pixel_view[mask, :] = (255 * (
            color2 * (0.05 + lambertian*0.15) +
            color1 * 0.3 * f[..., None]
        ).clamp(0, 1)).cpu()
    elif vizmode_shading == "anisotropic" and vizmode_normals != "ground_truth":
        # principal curvature directions = eigenvectors of the shape operator
        eigvals, eigvecs = torch.linalg.eig(shape_operator.mT) # slow, complex output, not sorted
        eigvals, indices = eigvals.abs().sort(dim=-1)
        eigvecs = (eigvecs.abs() * eigvecs.real.sign()).take_along_dim(indices[..., None, :], dim=-1)
        eigvecs = eigvecs.mT
        s = self.display_mode_variation % 5
        if s in (0, 1):
            # try to keep these below 0.2:
            if s == 0: a1, a2 = 0.05, 0.3
            if s == 1: a1, a2 = 0.3, 0.05
            # == Ward anisotropic specular reflectance ==
            # G.J. Ward, Measuring and modeling anisotropic reflection, in:
            # Proceedings of the 19th Annual Conference on Computer Graphics and
            # Interactive Techniques, 1992: pp. 265-272.
            eigvecs /= eigvecs.norm(dim=-1, keepdim=True)
            N = intersection_normals[is_intersecting, :]
            H = to_cam[is_intersecting, :] + to_light
            H = H / H.norm(dim=-1, keepdim=True)
            specular = (1/(4*torch.pi * a1*a2 * torch.sqrt((
                (N * to_cam[is_intersecting, :]).sum(dim=-1) *
                (N * to_light ).sum(dim=-1)
            )))) * torch.exp(
                -2 * (
                    ((H * eigvecs[..., 2, :]).sum(dim=-1) / a1).pow(2)
                    +
                    ((H * eigvecs[..., 1, :]).sum(dim=-1) / a2).pow(2)
                ) / (
                    1 + (N * H).sum(dim=-1)
                )
            )
            specular = specular.clamp(0, None).nan_to_num(0, 0, 0)
            lambertian = torch.einsum("id,id->i", N, to_light ).clamp(0, None)
            color1 = 0.4 * torch.tensor((1, 1, 1), **device_and_dtype) # specular
            color2 = 0.4 * torch.tensor((0, 1, 1), **device_and_dtype) # diffuse
            pixel_view[mask, :] = (255 * (
                color1 * specular [..., None] +
                color2 * lambertian[..., None]
            ).clamp(0, 1)).int().cpu()
        # variations 2-4: visualize the eigenvector orientations directly
        if s == 2:
            pixel_view[mask, :] = (255 * (
                eigvecs[..., 2, :].abs().clamp(0, 1) # orientation only
            )).int().cpu()
        elif s == 3:
            pixel_view[mask, :] = (255 * (
                eigvecs[..., 1, :].abs().clamp(0, 1) # orientation only
            )).int().cpu()
        elif s == 4:
            pixel_view[mask, :] = (255 * (
                eigvecs[..., 0, :].abs().clamp(0, 1) # orientation only
            )).int().cpu()
    elif vizmode_shading == "shade-best-radii" and MARF:
        # colormap the radius of the winning sphere at each hit
        lambertian = torch.einsum("id,id->i",
            intersection_normals[is_intersecting, :],
            to_light,
        )[..., None]
        radii = sphere_radii[is_intersecting]
        radii = radii - 0.04
        radii = radii / 0.4
        colors = cm.plasma(radii.clamp(0, 1).cpu())[..., :3]
        pixel_view[mask, :] = 255 * (
            lambertian.pow(32).clamp(0, 1).cpu().numpy() +
            colors * (lambertian + 0.25).clamp(0, 1).cpu().numpy() * (1-lambertian.pow(32).clamp(0, 1)).cpu().numpy()
        )
    elif vizmode_shading == "shade-all-radii" and MARF:
        # per-atom color scaled by normalized radius
        radii = sphere_radii[is_intersecting][..., None]
        radii /= radii.max()
        if n_atoms > 1:
            color = torch.tensor(color_per_atom, device=device)[(*atom_indices[is_intersecting].T,)]
        else:
            color = torch.tensor(color_set[(0 if self.atom_index_solo is None else self.atom_index_solo) % len(color_set)], device=device)
        pixel_view[mask, :] = (color * radii).int().cpu()
    elif vizmode_shading == "normal":
        # map normal components [-1, 1] -> [0, 255] RGB
        normal = intersection_normals[is_intersecting, :]
        pixel_view[mask, :] = (255 * (normal * 0.5 + 0.5) ).int().cpu()
    elif vizmode_shading == "curvature" and vizmode_normals != "ground_truth":
        eigvals = torch.linalg.eigvals(shape_operator.mT) # complex output, not sorted
        # we sort them by absolute magnitude, not the real component
        _, indices = (eigvals.abs() * eigvals.real.sign()).sort(dim=-1)
        eigvals = eigvals.real.take_along_dim(indices, dim=-1)
        s = self.display_mode_variation % (6 if MARF else 5)
        if s==0: out = (eigvals[..., [0, 2]].mean(dim=-1, keepdim=True) / 25).tanh() # mean curvature
        if s==1: out = (eigvals[..., [0, 2]].prod(dim=-1, keepdim=True) / 25).tanh() # gaussian curvature
        if s==2: out = (eigvals[..., [2]] / 25).tanh() # maximum principal curvature - k1
        if s==3: out = (eigvals[..., [1]] / 25).tanh() # some curvature
        if s==4: out = (eigvals[..., [0]] / 25).tanh() # minimum principal curvature - k2
        if s==5: out = ((sphere_radii[is_intersecting][..., None].detach() - 1 / eigvals[..., [2]].clamp(1e-8, None)) * 5).tanh().clamp(0, None)
        lambertian = torch.einsum("id,id->i",
            intersection_normals[is_intersecting, :],
            to_light,
        )[..., None]
        # signed curvature -> red/green/blue diverging colormap
        pixel_view[mask, :] = (255 * (lambertian+0.5).clamp(0, 1) * torch.cat((
            1+out.clamp(-1, 0),
            1-out.abs(),
            1-out.clamp(0, 1),
        ), dim=-1)).int().cpu()
    elif vizmode_shading == "centroid-grad-norm" and MARF:
        # grayscale of the (normalized) centroid jacobian norm
        asd = sphere_centers_jac[is_intersecting, :, :].norm(dim=-2).mean(dim=-1, keepdim=True)
        asd -= asd.min()
        asd /= asd.max()
        pixel_view[mask, :] = (255 * asd).cpu()
    elif "glass" in vizmode_shading:
        normals = intersection_normals[is_intersecting, :]
        to_cam_ = to_cam [is_intersecting, :]
        # "Empirical Approximation" of fresnel
        # https://developer.download.nvidia.com/CgTutorial/cg_tutorial_chapter07.html via
        # http://kylehalladay.com/blog/tutorial/2014/02/18/Fresnel-Shaders-From-The-Ground-Up.html
        cos = torch.einsum("id,id->i", normals, to_cam_ )[..., None]
        bias, scale, power = 0, 4, 3
        fresnel = (bias + scale*(1-cos)**power).clamp(0, 1)
        #reflection
        reflection = -to_cam_ - 2*(-cos)*normals
        #refraction
        r = 1 / 1.5 # refractive index, air -> glass
        refraction = -r*to_cam_ + (r*cos - (1-r**2*(1-cos**2)).sqrt()) * normals
        exit_point = intersections[is_intersecting, :]
        # reflect the refraction over the plane defined by the refraction direction and the sphere center, resulting in the second refraction
        if vizmode_shading == "double-glass" and MARF:
            cos2 = torch.einsum("id,id->i", refraction, -to_cam_ )[..., None]
            pn = -to_cam_ - cos2*refraction
            pn /= pn.norm(dim=-1, keepdim=True)
            refraction = -to_cam_ - 2*torch.einsum("id,id->i", pn, -to_cam_ )[..., None]*pn
            exit_point -= sphere_centers[is_intersecting, :]
            exit_point = exit_point - 2*torch.einsum("id,id->i", pn, exit_point )[..., None]*pn
            exit_point += sphere_centers[is_intersecting, :]
        fresnel = np.asanyarray(fresnel.cpu())
        # blend environment lookups of the reflected and refracted rays
        pixel_view[mask, :] \
            = self.lookup_sphere_map_dirs(reflection, intersections[is_intersecting, :]) * fresnel \
            + self.lookup_sphere_map_dirs(refraction, exit_point) * (1-fresnel)
    else: # flat
        pixel_view[mask, :] = 80
    if not MARF: return
    # overlay medial atoms
    if vizmode_spheres is not None:
        # show miss distance in red
        s = silhouettes.detach()[~is_intersecting].clamp(0, 1)
        s /= s.max()
        pixel_view[~mask, 1] = (s * 255).cpu()
        # NOTE(review): this copies image *column* 2, not the blue channel —
        # pixel_view[..., 2] = pixel_view[..., 1] was likely intended; verify
        pixel_view[:, 2] = pixel_view[:, 1]
        mouse_hits = 0 <= mx < w and 0 <= my < h and mask[mx, my]
        draw_intersecting = "intersecting-sphere" in vizmode_spheres
        draw_best = "best-sphere" in vizmode_spheres
        draw_color = "-sphere-colored" in vizmode_spheres
        draw_all = "all-spheres-colored" in vizmode_spheres
        def get_nears():
            # Yields (color, projected, near, far, is_intersecting, normals)
            # tuples for each set of spheres to draw, by intersecting the
            # camera rays with the sphere(s) under the mouse cursor.
            if draw_all:
                # NOTE(review): indexes [mx, my] without checking mouse_hits —
                # presumably safe because mx/my are clamped upstream; confirm
                projected, near, far, is_intersecting = geometry.ray_sphere_intersect(
                    torch.tensor(origins),
                    torch.tensor(dirs[..., None, :]),
                    sphere_centers = all_sphere_centers[mx, my][None, None, ...],
                    sphere_radii = all_sphere_radii [mx, my][None, None, ...],
                    allow_nans = False,
                    return_parts = True,
                )
                # pick the nearest intersected atom per ray (misses pushed back by +100)
                depths = (near - origins).norm(dim=-1)
                atom_indices_ = torch.where(is_intersecting, depths.detach(), depths.detach()+100).argmin(dim=-1, keepdim=True)
                is_intersecting = is_intersecting.any(dim=-1)
                projected = None
                near = near.take_along_dim(atom_indices_[..., None], -2).squeeze(-2)
                far = None
                sphere_centers_ = all_sphere_centers[mx, my][None, None, ...].take_along_dim(atom_indices_[..., None], -2).squeeze(-2)
                normals = near[is_intersecting, :] - sphere_centers_[is_intersecting, :]
                normals /= torch.linalg.norm(normals, dim=-1)[..., None]
                color = torch.tensor(color_per_atom, device=device)[(*atom_indices_[is_intersecting].T,)]
                yield color, projected, near, far, is_intersecting, normals
            if (mouse_hits and draw_intersecting) or draw_best:
                projected, near, far, is_intersecting = geometry.ray_sphere_intersect(
                    torch.tensor(origins),
                    torch.tensor(dirs),
                    # unit-sphere by default
                    sphere_centers = sphere_centers[mx, my][None, None, ...],
                    sphere_radii = sphere_radii [mx, my][None, None, ...],
                    return_parts = True,
                )
                normals = near[is_intersecting, :] - sphere_centers[mx, my][None, ...]
                normals /= torch.linalg.norm(normals, dim=-1)[..., None]
                color = (255, 255, 255) if not draw_color else color_per_atom[atom_indices[mx, my]]
                yield torch.tensor(color, device=device), projected, near, far, is_intersecting, normals
        # draw sphere with lambertian shading
        for color, projected, near, far, is_intersecting_2, normals in get_nears():
            lambertian = torch.einsum("...id,...id->...i", normals, to_light )[..., None]
            pixel_view[is_intersecting_2.cpu(), :] = (
                255*lambertian.pow(32).clamp(0, 1) +
                color * (lambertian + 0.25).clamp(0, 1) * (1-lambertian.pow(32).clamp(0, 1))
            ).cpu()
    # overlay points / sphere centers
    if vizmode_centroids is not None:
        cam2world_inv = torch.tensor(self.cam2world_inv, **device_and_dtype)
        intrinsics = torch.tensor(self.intrinsics, **device_and_dtype)
        def get_coords():
            # Yields (color, world-space points, mask) tuples for each
            # centroid set to splat into the frame.
            miss_centroid = "miss-centroids" in vizmode_centroids
            mask = is_intersecting if not miss_centroid else ~is_intersecting
            if vizmode_centroids in ("all-centroids-colored", "all-miss-centroids-colored"):
                # we use temporal dithering to the show all overlapping centers
                for color, atom_index in sorted(zip(itertools.chain(color_set), range(n_atoms)), key=lambda x: random.random()):
                    yield color, all_sphere_centers[..., atom_index, :][mask], mask
            elif "all-centroids" in vizmode_centroids:
                yield (80, 150, 80), all_sphere_centers[mask].reshape(-1, 3), mask # [:, 3]
            if "centroids-colored" in vizmode_centroids:
                if n_atoms == 1:
                    color = color_set[(0 if self.atom_index_solo is None else self.atom_index_solo) % len(color_set)]
                else:
                    color = torch.tensor(color_per_atom, device=device)[(*atom_indices[mask].T,)].cpu()
            else:
                color = (0, 0, 0)
            yield color, sphere_centers[mask], mask
        for i, (color, coords, coord_mask) in enumerate(get_coords()):
            if self.export_medial_surface_mesh:
                # dump this centroid set as a colored PLY mesh (U key)
                fname = self.mk_dump_fname("ply", uid=i)
                p = torch.zeros_like(sphere_centers)
                c = torch.zeros_like(sphere_centers)
                p[coord_mask, :] = coords
                c[coord_mask, :] = torch.tensor(color, device=p.device) / 255
                SingleViewUVScan(
                    hits = ( mask).numpy(),
                    miss = (~mask).numpy(),
                    points = p.cpu().numpy(),
                    colors = c.cpu().numpy(),
                    normals=None, distances=None, cam_pos=None,
                    cam_mat4=None, proj_mat4=None, transforms=None,
                ).to_mesh().export(str(fname), file_type="ply")
                print("dumped", fname)
                if shutil.which("f3d"):
                    # preview the dump if the f3d viewer is installed
                    subprocess.Popen(["f3d", "-gsy", "--up=+z", "--bg-color=1,1,1", fname], close_fds=True)
            # project the world-space centroids into pixel coordinates
            coords = torch.cat((coords, torch.ones((*coords.shape[:-1], 1), **device_and_dtype)), dim=-1)
            coords = torch.einsum("...ij,...kj->...ki", cam2world_inv, coords)[..., :3]
            coords = geometry.project(coords[..., 0], coords[..., 1], coords[..., 2], intrinsics)
            # cull centroids that fall outside the viewport
            in_view = functools.reduce(torch.mul, (
                coords[:, 0] < pixel_view.shape[1],
                coords[:, 0] >= 0,
                coords[:, 1] < pixel_view.shape[0],
                coords[:, 1] >= 0,
            )).cpu()
            coords = coords[in_view, :]
            if not isinstance(color, tuple):
                color = color[in_view, :]
            pixel_view[(*coords[..., [1, 0]].int().T.cpu(),)] = color
        self.export_medial_surface_mesh = False

6369
poetry.lock generated Normal file

File diff suppressed because it is too large Load Diff

80
pyproject.toml Normal file

@ -0,0 +1,80 @@
# Poetry project manifest for the "ifield" package.
[tool.poetry]
name = "ifield"
version = "0.2.0"
description = ""
authors = ["Peder Bergebakken Sundt <pbsds@hotmail.com>"]

[build-system]
requires = ["poetry-core>=1.0.0", "setuptools>=60"]
build-backend = "poetry.core.masonry.api"

# Runtime dependencies. Several wheels are pinned by direct URL or git rev
# because no suitable bdist exists on PyPI for py3.10 / the required CUDA.
[tool.poetry.dependencies]
python = ">=3.10,<3.11"
faiss-cpu = "^1.7.3"
geomloss = "0.2.4" # 0.2.5 has no bdist on pypi
h5py = "^3.7.0"
hdf5plugin = "^4.0.1"
imageio = "^2.23.0"
jinja2 = "^3.1.2"
matplotlib = "^3.6.2"
mesh-to-sdf = {git = "https://github.com/pbsds/mesh_to_sdf", rev = "no_flip_normals"}
methodtools = "^0.4.5"
more-itertools = "^9.1.0"
munch = "^2.5.0"
numpy = "^1.23.0"
pyembree = {url = "https://folk.ntnu.no/pederbs/pypy/pep503/pyembree/pyembree-0.2.11-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"}
pygame = "^2.1.2"
pykeops = "^2.1.1"
pytorch3d = [
    {url = "https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu116_pyt1130/pytorch3d-0.7.2-cp310-cp310-linux_x86_64.whl"},
]
pyqt5 = "^5.15.7"
pyrender = "^0.1.45"
pytorch-lightning = "^1.8.6"
pyyaml = "^6.0"
rich = "^13.3.2"
rtree = "^1.0.1"
scikit-image = "^0.19.3"
scikit-learn = "^1.2.0"
seaborn = "^0.12.1"
serve-me-once = "^0.1.2"
torch = "^1.13.0"
torchmeta = {git = "https://github.com/pbsds/pytorch-meta", rev = "upgrade"}
torchviz = "^0.0.2"
tqdm = "^4.64.1"
trimesh = "^3.17.1"
typer = "^0.7.0"

# Development-only tooling (editors, notebooks, docs, debugging).
[tool.poetry.dev-dependencies]
python-lsp-server = {extras = ["all"], version = "^1.6.0"}
fix-my-functions = "^0.1.3"
imageio-ffmpeg = "^0.4.7"
jupyter = "^1.0.0"
jupyter-contrib-nbextensions = "^0.7.0"
jupyterlab = "^3.5.2"
jupyterthemes = "^0.20.0"
llvmlite = "^0.39.1" # only to make poetry install the python3.10 wheels instead of building them
nbconvert = "<=6.5.0" # https://github.com/jupyter/nbconvert/issues/1894
numba = "^0.56.4" # only to make poetry install the python3.10 wheels instead of building them
papermill = "^2.4.0"
pdoc = "^12.3.0"
pdoc3 = "^0.10.0"
ptpython = "^3.0.22"
pudb = "^2022.1.3"
remote-exec = {git = "https://github.com/pbsds/remote", rev = "whitespace-push"} # https://github.com/remote-cli/remote/pull/52
shapely = "^2.0.0"
sympy = "^1.11.1"
tensorboard = "^2.11.0"
visidata = "^2.11"

# Console entry points exposed by the package.
[tool.poetry.scripts]
show-schedule = 'ifield.utils.loss:main'
show-h5-items = 'ifield.cli_utils:show_h5_items'
show-h5-img = 'ifield.cli_utils:show_h5_img'
show-h5-scan-cloud = 'ifield.cli_utils:show_h5_scan_cloud'
show-model = 'ifield.cli_utils:show_model'
download-stanford = 'ifield.data.stanford.download:cli'
download-coseg = 'ifield.data.coseg.download:cli'
preprocess-stanford = 'ifield.data.stanford.preprocess:cli'
preprocess-coseg = 'ifield.data.coseg.preprocess:cli'