Add code
This commit is contained in:
.envrc.gitignore.localenv.localenv-bootstrap-conda.remoteenv.remoteignore.tomlREADME.mdablation.md
experiments
figures
ifield
__init__.pycli.pycli_utils.py
poetry.lockpyproject.tomldata
datasets
logging.pymodels
modules
param.pyutils
viewer
6
.envrc
Normal file
6
.envrc
Normal file
@ -0,0 +1,6 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# This file is automatically loaded with `direnv` if allowed.
|
||||
# It enters you into the venv.
|
||||
|
||||
source .localenv
|
9
.gitignore
vendored
Normal file
9
.gitignore
vendored
Normal file
@ -0,0 +1,9 @@
|
||||
__pycache__
|
||||
/data/models/
|
||||
/data/archives/
|
||||
/experiments/logdir/
|
||||
/.env/
|
||||
/.direnv/
|
||||
*.zip
|
||||
*.sh
|
||||
default.yaml # pandoc preview enhanced
|
71
.localenv
Normal file
71
.localenv
Normal file
@ -0,0 +1,71 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# =======================
|
||||
# bootstrap a venv
|
||||
# =======================
|
||||
|
||||
LOCAL_ENV_NAME="py310-$(basename $(pwd))"
|
||||
LOCAL_ENV_DIR="$(pwd)/.env/$LOCAL_ENV_NAME"
|
||||
mkdir -p "$LOCAL_ENV_DIR"
|
||||
|
||||
# make configs and caches a part of venv
|
||||
export POETRY_CACHE_DIR="$LOCAL_ENV_DIR/xdg/cache/poetry"
|
||||
export PIP_CACHE_DIR="$LOCAL_ENV_DIR/xdg/cache/pip"
|
||||
mkdir -p "$POETRY_CACHE_DIR" "$PIP_CACHE_DIR"
|
||||
|
||||
#export POETRY_VIRTUALENVS_IN_PROJECT=true # store venv in ./.venv/
|
||||
#export POETRY_VIRTUALENVS_CREATE=false # install globally
|
||||
export SETUPTOOLS_USE_DISTUTILS=stdlib # https://github.com/pre-commit/pre-commit/issues/2178#issuecomment-1002163763
|
||||
export IFIELD_PRETTY_TRACEBACK=1
|
||||
#export SHOW_LOCALS=1 # locals in tracebacks
|
||||
export PYTHON_KEYRING_BACKEND="keyring.backends.null.Keyring"
|
||||
|
||||
# ensure we have the correct python and poetry. Bootstrap via conda if missing
|
||||
if ! command -v python310 >/dev/null || ! command -v poetry >/dev/null; then
|
||||
source .localenv-bootstrap-conda
|
||||
|
||||
if command -v mamba >/dev/null; then
|
||||
CONDA=mamba
|
||||
elif command -v conda >/dev/null; then
|
||||
CONDA=conda
|
||||
else
|
||||
>&2 echo "ERROR: 'poetry' nor 'conda'/'mamba' could be found!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
function verbose {
|
||||
echo +"$(printf " %q" "$@")"
|
||||
"$@"
|
||||
}
|
||||
|
||||
if ! ($CONDA env list | grep -q "^$LOCAL_ENV_NAME "); then
|
||||
verbose $CONDA create --yes --name "$LOCAL_ENV_NAME" -c conda-forge \
|
||||
python==3.10.8 poetry==1.3.1 #python-lsp-server
|
||||
true
|
||||
fi
|
||||
|
||||
verbose conda activate "$LOCAL_ENV_NAME" || exit $?
|
||||
#verbose $CONDA activate "$LOCAL_ENV_NAME" || exit $?
|
||||
|
||||
unset -f verbose
|
||||
fi
|
||||
|
||||
|
||||
# enter poetry venv
|
||||
# source .envrc
|
||||
poetry run true # ensure venv exists
|
||||
#source "$(poetry env info -p)/bin/activate"
|
||||
export VIRTUAL_ENV=$(poetry env info --path)
|
||||
export POETRY_ACTIVE=1
|
||||
export PATH="$VIRTUAL_ENV/bin":"$PATH"
|
||||
# NOTE: poetry currently reuses and populates the conda venv.
|
||||
# See: https://github.com/python-poetry/poetry/issues/1724
|
||||
|
||||
|
||||
# ensure output dirs exist
|
||||
mkdir -p experiments/logdir
|
||||
|
||||
# first-time-setup poetry
|
||||
if ! command -v fix-my-functions >/dev/null; then
|
||||
poetry install
|
||||
fi
|
53
.localenv-bootstrap-conda
Normal file
53
.localenv-bootstrap-conda
Normal file
@ -0,0 +1,53 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# =======================
|
||||
# bootstrap a conda venv
|
||||
# =======================
|
||||
|
||||
CONDA_ENV_DIR="${LOCAL_ENV_DIR:-$(pwd)/.conda310}"
|
||||
mkdir -p "$CONDA_ENV_DIR"
|
||||
#touch "$HOME/.Xauthority"
|
||||
|
||||
MINICONDA_PY310_URL="https://repo.anaconda.com/miniconda/Miniconda3-py310_22.11.1-1-Linux-x86_64.sh"
|
||||
MINICONDA_PY310_HASH="00938c3534750a0e4069499baf8f4e6dc1c2e471c86a59caa0dd03f4a9269db6"
|
||||
|
||||
# Check if conda is available
|
||||
if ! command -v conda >/dev/null; then
|
||||
export PATH="$CONDA_ENV_DIR/conda/bin:$PATH"
|
||||
fi
|
||||
|
||||
# Check again if conda is available, install miniconda if not
|
||||
if ! command -v conda >/dev/null; then
|
||||
(set -e #x
|
||||
function verbose {
|
||||
echo +"$(printf " %q" "$@")"
|
||||
"$@"
|
||||
}
|
||||
|
||||
if command -v curl >/dev/null; then
|
||||
verbose curl -sLo "$CONDA_ENV_DIR/miniconda_py310.sh" "$MINICONDA_PY310_URL"
|
||||
elif command -v wget >/dev/null; then
|
||||
verbose wget -O "$CONDA_ENV_DIR/miniconda_py310.sh" "$MINICONDA_PY310_URL"
|
||||
else
|
||||
echo "ERROR: unable to download miniconda!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
verbose test "$(sha256sum "$CONDA_ENV_DIR/miniconda_py310.sh")" = "$MINICONDA_PY310_HASH"
|
||||
verbose chmod +x "$CONDA_ENV_DIR/miniconda_py310.sh"
|
||||
|
||||
verbose "$CONDA_ENV_DIR/miniconda_py310.sh" -b -u -p "$CONDA_ENV_DIR/conda"
|
||||
verbose rm "$CONDA_ENV_DIR/miniconda_py310.sh"
|
||||
|
||||
eval "$(conda shell.bash hook)" # basically `conda init`, without modifying .bashrc
|
||||
verbose conda install --yes --name base mamba -c conda-forge
|
||||
|
||||
) || exit $?
|
||||
fi
|
||||
|
||||
unset CONDA_ENV_DIR
|
||||
unset MINICONDA_PY310_URL
|
||||
unset MINICONDA_PY310_HASH
|
||||
|
||||
# Enter conda environment
|
||||
eval "$(conda shell.bash hook)" # basically `conda init`, without modifying .bashrc
|
29
.remoteenv
Normal file
29
.remoteenv
Normal file
@ -0,0 +1,29 @@
|
||||
#!/usr/bin/env bash
|
||||
# this file is used by remote-cli
|
||||
|
||||
# Assumes repo is put in a "remotes/name-hash" folder,
|
||||
# the default behaviour of remote-exec
|
||||
REMOTES_DIR="$(dirname $(pwd))"
|
||||
LOCAL_ENV_NAME="py310-$(basename $(pwd))"
|
||||
LOCAL_ENV_DIR="$REMOTES_DIR/envs/$REMOTE_ENV_NAME"
|
||||
|
||||
#export XDG_CACHE_HOME="$LOCAL_ENV_DIR/xdg/cache"
|
||||
#export XDG_DATA_HOME="$LOCAL_ENV_DIR/xdg/share"
|
||||
#export XDG_STATE_HOME="$LOCAL_ENV_DIR/xdg/state"
|
||||
#mkdir -p "$XDG_CACHE_HOME" "$XDG_DATA_HOME" "$XDG_STATE_HOME"
|
||||
export XDG_CONFIG_HOME="$LOCAL_ENV_DIR/xdg/config"
|
||||
mkdir -p "$XDG_CONFIG_HOME"
|
||||
|
||||
|
||||
export PYOPENGL_PLATFORM=egl # makes pyrender work headless
|
||||
#export PYOPENGL_PLATFORM=osmesa # makes pyrender work headless
|
||||
export SDL_VIDEODRIVER=dummy # pygame
|
||||
|
||||
source .localenv
|
||||
|
||||
# SLURM logs output dir
|
||||
if command -v sbatch >/dev/null; then
|
||||
mkdir -p slurm_logs
|
||||
test -L experiments/logdir/slurm_logs ||
|
||||
ln -s ../../slurm_logs experiments/logdir/
|
||||
fi
|
30
.remoteignore.toml
Normal file
30
.remoteignore.toml
Normal file
@ -0,0 +1,30 @@
|
||||
[push]
|
||||
exclude = [
|
||||
"*.egg-info",
|
||||
"*.pyc",
|
||||
".ipynb_checkpoints",
|
||||
".mypy_cache",
|
||||
".remote.toml",
|
||||
".remoteignore.toml",
|
||||
".venv",
|
||||
".wandb",
|
||||
"__pycache__",
|
||||
"data/models",
|
||||
"docs",
|
||||
"experiments/logdir",
|
||||
"poetry.toml",
|
||||
"slurm_logs",
|
||||
"tmp",
|
||||
]
|
||||
include = []
|
||||
|
||||
[pull]
|
||||
exclude = [
|
||||
"*",
|
||||
]
|
||||
include = []
|
||||
|
||||
[both]
|
||||
exclude = [
|
||||
]
|
||||
include = []
|
128
README.md
128
README.md
@ -1 +1,127 @@
|
||||
This is where the code for the paper _"MARF: The Medial Atom Ray Field Object Representation"_ will be published.
|
||||
# MARF: The Medial Atom Ray Field Object Representation
|
||||
|
||||
<center>
|
||||
|
||||

|
||||
|
||||
[Publication](https://doi.org/10.1016/j.cag.2023.06.032) | [Arxiv](https://arxiv.org/abs/2307.00037) | [Training data](https://mega.nz/file/9tsz3SbA#V6SIXpCFC4hbqWaFFvKmmS8BKir7rltXuhsqpEpE9wo) | [Network weights](https://mega.nz/file/t01AyTLK#7ZNMNgbqT9x2mhq5dxLuKeKyP7G0slfQX1RaZxifayw)
|
||||
|
||||
</center>
|
||||
|
||||
**TL;DR:** We achieve _fast_ surface rendering by predicting _n_ maximally inscribed spherical intersection candidates for each camera ray.
|
||||
|
||||
---
|
||||
|
||||
## Entering the Virtual Environment
|
||||
|
||||
The environment is defined in `pyproject.toml` using [Poetry](https://github.com/python-poetry/poetry) and reproducibly locked in `poetry.lock`.
|
||||
We propose three ways to enter the venv:
|
||||
|
||||
```shell
|
||||
# Requires Python 3.10 and Poetry
|
||||
poetry install
|
||||
poetry shell
|
||||
|
||||
# Will bootstrap a Miniconda 3.10 environment into .env/ if needed, then run poetry
|
||||
source .localenv
|
||||
```
|
||||
|
||||
|
||||
## Evaluation
|
||||
|
||||
### Pretrained models
|
||||
|
||||
You can download our pre-trained models` from <https://mega.nz/file/t01AyTLK#7ZNMNgbqT9x2mhq5dxLuKeKyP7G0slfQX1RaZxifayw>.
|
||||
It should be unpacked into the root directory, such that the `experiment` folder gets merged.
|
||||
|
||||
### The interactive renderer
|
||||
|
||||
We automatically create experiment names with a schema of `{{model}}-{{experiment-name}}-{{hparams-summary}}-{{date}}-{{random-uid}}`.
|
||||
You can load experiment weights using either the full path, or just the `random-uid` bit.
|
||||
|
||||
From the `experiments` directory:
|
||||
|
||||
```shell
|
||||
./marf.py model {{experiment}} viewer
|
||||
```
|
||||
|
||||
If you have downloaded our pre-trained network weights, consider trying:
|
||||
|
||||
```shell
|
||||
./marf.py model nqzh viewer # Stanford Bunny (single-shape)
|
||||
./marf.py model wznx viewer # Stanford Buddha (single-shape)
|
||||
./marf.py model mxwd viewer # Stanford Armadillo (single-shape)
|
||||
./marf.py model camo viewer # Stanford Dragon (single-shape)
|
||||
./marf.py model ksul viewer # Stanford Lucy (single-shape)
|
||||
./marf.py model oxrf viewer # COSEG four-legged (multi-shape)
|
||||
```
|
||||
|
||||
## Training and Evaluation Data
|
||||
|
||||
You can download a pre-computed archive from <https://mega.nz/file/9tsz3SbA#V6SIXpCFC4hbqWaFFvKmmS8BKir7rltXuhsqpEpE9wo>.
|
||||
It should be extracted into the root directory such that a `data` directory is added to the root directory.
|
||||
|
||||
<details>
|
||||
<summary>
|
||||
Optionally, you may compute the data yourself.
|
||||
</summary>
|
||||
|
||||
Single-shape training data:
|
||||
|
||||
```shell
|
||||
# takes takes about 23 minutes, mainly due to lucy
|
||||
download-stanford bunny happy_buddha dragon armadillo lucy
|
||||
preprocess-stanford bunny happy_buddha dragon armadillo lucy \
|
||||
--precompute-mesh-sv-scan-uv \
|
||||
--compute-miss-distances \
|
||||
--fill-missing-uv-points
|
||||
```
|
||||
|
||||
Multi-shape training data:
|
||||
|
||||
```shell
|
||||
# takes takes about 29 minutes
|
||||
download-coseg four-legged --shapes
|
||||
preprocess-coseg four-legged \
|
||||
--precompute-mesh-sv-scan-uv \
|
||||
--compute-miss-distances \
|
||||
--fill-missing-uv-points
|
||||
```
|
||||
|
||||
Evaluation data:
|
||||
|
||||
```shell
|
||||
# takes takes about 2 hour 20 minutes, mainly due to lucy
|
||||
preprocess-stanford bunny happy_buddha dragon armadillo lucy \
|
||||
--precompute-mesh-sphere-scan \
|
||||
--compute-miss-distances
|
||||
```
|
||||
|
||||
```shell
|
||||
# takes takes about 4 hours
|
||||
preprocess-coseg four-legged \
|
||||
--precompute-mesh-sphere-scan \
|
||||
--compute-miss-distances
|
||||
```
|
||||
</details>
|
||||
|
||||
|
||||
## Training
|
||||
|
||||
Our experiments are defined using YAML config files, optionally templated using Jinja2 as a preprocessor.
|
||||
These templates accept additional input from the command line in the form of `-Okey=value` options.
|
||||
Our whole experiment matrix is defined in `marf.yaml.j12`. We select between different experiment groups using `-Omode={single,ablation,multi}`, and which experiment using `-Oselect={{integer}}`
|
||||
|
||||
From the `experiments` directory:
|
||||
|
||||
CPU mode:
|
||||
|
||||
```shell
|
||||
./marf.py model marf.yaml.j2 -Oexperiment_name=cpu_test -Omode=single -Oselect=0 fit
|
||||
```
|
||||
|
||||
GPU mode:
|
||||
|
||||
```shell
|
||||
./marf.py model marf.yaml.j2 -Oexperiment_name=cpu_test -Omode=single -Oselect=0 fit --accelerator gpu --devices 1
|
||||
```
|
||||
|
139
ablation.md
Normal file
139
ablation.md
Normal file
@ -0,0 +1,139 @@
|
||||
### MARF
|
||||
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0010-nqzh`
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0312-wznx`
|
||||
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-1944-mxwd`
|
||||
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0529-camo`
|
||||
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0743-ksul`
|
||||
|
||||
### LFN encoding
|
||||
- `experiment-stanfordv12-dragon-plkr2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0539-xjte`
|
||||
- `experiment-stanfordv12-lucy-plkr2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0753-ayvt`
|
||||
- `experiment-stanfordv12-bunny-plkr2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0022-axft`
|
||||
- `experiment-stanfordv12-happy_buddha-plkr2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0322-xfoc`
|
||||
- `experiment-stanfordv12-armadillo-plkr2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2039-vbks`
|
||||
|
||||
### PRIF encoding
|
||||
- `experiment-stanfordv12-armadillo-prpft2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2033-nkxm`
|
||||
- `experiment-stanfordv12-happy_buddha-prpft2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0313-huci`
|
||||
- `experiment-stanfordv12-dragon-prpft2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0537-dxsb`
|
||||
- `experiment-stanfordv12-bunny-prpft2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0011-tzic`
|
||||
- `experiment-stanfordv12-lucy-prpft2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0744-hzvw`
|
||||
|
||||
### No init scheme.
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-10dmiss-nogeom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0444-uohy`
|
||||
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-10dmiss-nogeom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2307-wjcf`
|
||||
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-10dmiss-nogeom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0707-eanc`
|
||||
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-10dmiss-nogeom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0225-kcfw`
|
||||
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-10dmiss-nogeom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0852-lkfh`
|
||||
|
||||
### 1 atom candidate
|
||||
- `experiment-stanfordv12-lucy-both2marf-1atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0755-qzth`
|
||||
- `experiment-stanfordv12-bunny-both2marf-1atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0027-ycnl`
|
||||
- `experiment-stanfordv12-armadillo-both2marf-1atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2121-fwvo`
|
||||
- `experiment-stanfordv12-dragon-both2marf-1atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0541-nvhs`
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-1atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0324-cuyw`
|
||||
|
||||
### 4 atom candidates
|
||||
- `experiment-stanfordv12-armadillo-both2marf-4atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2122-qiwg`
|
||||
- `experiment-stanfordv12-dragon-both2marf-4atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0544-ihkx`
|
||||
- `experiment-stanfordv12-lucy-both2marf-4atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0757-jwxm`
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-4atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0328-chhs`
|
||||
- `experiment-stanfordv12-bunny-both2marf-4atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0038-zymb`
|
||||
|
||||
### 8 atom candidates
|
||||
- `experiment-stanfordv12-bunny-both2marf-8atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0055-ogpd`
|
||||
- `experiment-stanfordv12-lucy-both2marf-8atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0757-frxb`
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-8atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0337-twys`
|
||||
- `experiment-stanfordv12-dragon-both2marf-8atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0551-bubw`
|
||||
- `experiment-stanfordv12-armadillo-both2marf-8atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2137-nnlj`
|
||||
|
||||
### 32 atom candidates
|
||||
- `experiment-stanfordv12-bunny-both2marf-32atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0056-ourc`
|
||||
- `experiment-stanfordv12-armadillo-both2marf-32atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2141-byaj`
|
||||
- `experiment-stanfordv12-dragon-both2marf-32atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0554-zobg`
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-32atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0337-rmyq`
|
||||
- `experiment-stanfordv12-lucy-both2marf-32atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0800-lqen`
|
||||
|
||||
### 64 atom candidates
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-64atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0339-whcx`
|
||||
- `experiment-stanfordv12-bunny-both2marf-64atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0058-seen`
|
||||
- `experiment-stanfordv12-lucy-both2marf-64atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0806-ycxj`
|
||||
- `experiment-stanfordv12-armadillo-both2marf-64atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2153-wnfq`
|
||||
- `experiment-stanfordv12-dragon-both2marf-64atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0555-zgcb`
|
||||
|
||||
### No intersection loss
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-10dmiss-geom-0chit-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-1053-ydnh`
|
||||
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-10dmiss-geom-0chit-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-1111-fawl`
|
||||
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-10dmiss-geom-0chit-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-1045-umwl`
|
||||
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-10dmiss-geom-0chit-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-1103-lwmb`
|
||||
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-10dmiss-geom-0chit-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-1041-lhcc`
|
||||
|
||||
### No silhouette loss
|
||||
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-0dmiss-geom-20chit-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-1042-fsuw`
|
||||
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-0dmiss-geom-20chit-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-1046-nszw`
|
||||
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-0dmiss-geom-20chit-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-1111-mlal`
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-0dmiss-geom-20chit-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-1055-cvkg`
|
||||
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-0dmiss-geom-20chit-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-1114-pdyh`
|
||||
|
||||
### More silhouette loss
|
||||
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-50dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0157-yekm`
|
||||
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-50dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2243-nlrv`
|
||||
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-50dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0639-yros`
|
||||
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-50dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0842-xktg`
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-50dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0423-ibxs`
|
||||
|
||||
### No normal loss
|
||||
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-10dmiss-geom-nocnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0614-ttta`
|
||||
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-10dmiss-geom-nocnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0106-bnke`
|
||||
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-10dmiss-geom-nocnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2154-bxwl`
|
||||
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-10dmiss-geom-nocnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0811-qqgu`
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-10dmiss-geom-nocnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0357-gwca`
|
||||
|
||||
### No inscription loss
|
||||
- `experiment-stanfordv12-bunny-both2marf-16atom-noxinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0227-xrqt`
|
||||
- `experiment-stanfordv12-armadillo-both2marf-16atom-noxinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2312-cgzv`
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-noxinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0452-rerr`
|
||||
- `experiment-stanfordv12-dragon-both2marf-16atom-noxinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0709-tfgg`
|
||||
- `experiment-stanfordv12-lucy-both2marf-16atom-noxinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0856-ctvc`
|
||||
|
||||
### More inscription loss
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-250xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0459-kyyh`
|
||||
- `experiment-stanfordv12-bunny-both2marf-16atom-250xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0243-qqqj`
|
||||
- `experiment-stanfordv12-armadillo-both2marf-16atom-250xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2336-yclo`
|
||||
- `experiment-stanfordv12-lucy-both2marf-16atom-250xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0913-mulv`
|
||||
- `experiment-stanfordv12-dragon-both2marf-16atom-250xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0714-zugg`
|
||||
|
||||
### No maximality reg.
|
||||
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-0sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0842-cvln`
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-0sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0425-vpen`
|
||||
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-0sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0207-qpdb`
|
||||
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-0sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2251-zqvi`
|
||||
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-0sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0641-ucdo`
|
||||
|
||||
### More maximality reg.
|
||||
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-5000sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0659-bqvf`
|
||||
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-5000sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2256-escz`
|
||||
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-5000sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0208-wmvs`
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-5000sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0442-gdah`
|
||||
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-5000sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0845-halc`
|
||||
|
||||
### No specialization reg.
|
||||
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-nominatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0913-odyn`
|
||||
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-nominatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0251-xzig`
|
||||
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-nominatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0722-gxps`
|
||||
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-nominatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2342-zybo`
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-nominatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0507-tvlt`
|
||||
|
||||
### No multi-view loss
|
||||
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-nogradreg-nocond-100cwu500clr70tvs-2023-05-31-0310-wbqj`
|
||||
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-nogradreg-nocond-100cwu500clr70tvs-2023-05-30-2357-qnct`
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-nogradreg-nocond-100cwu500clr70tvs-2023-05-31-0527-psnk`
|
||||
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-nogradreg-nocond-100cwu500clr70tvs-2023-05-31-0927-wxcq`
|
||||
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-nogradreg-nocond-100cwu500clr70tvs-2023-05-31-0743-pdbc`
|
||||
|
||||
### More multi-view loss
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-20dmv-nocond-100cwu500clr70tvs-2023-05-31-0510-caah`
|
||||
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-20dmv-nocond-100cwu500clr70tvs-2023-05-31-0726-zkyg`
|
||||
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-20dmv-nocond-100cwu500clr70tvs-2023-05-31-0254-akbq`
|
||||
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-20dmv-nocond-100cwu500clr70tvs-2023-05-31-0924-aahb`
|
||||
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-20dmv-nocond-100cwu500clr70tvs-2023-05-30-2352-xlrn`
|
624
experiments/marf.py
Executable file
624
experiments/marf.py
Executable file
@ -0,0 +1,624 @@
|
||||
#!/usr/bin/env python3
|
||||
from abc import ABC, abstractmethod
|
||||
from argparse import Namespace
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from ifield import logging
|
||||
from ifield.cli import CliInterface
|
||||
from ifield.data.common.scan import SingleViewUVScan
|
||||
from ifield.data.coseg import read as coseg_read
|
||||
from ifield.data.stanford import read as stanford_read
|
||||
from ifield.datasets import stanford, coseg, common
|
||||
from ifield.models import intersection_fields
|
||||
from ifield.utils.operators import diff
|
||||
from ifield.viewer.ray_field import ModelViewer
|
||||
from munch import Munch
|
||||
from pathlib import Path
|
||||
from pytorch3d.loss.chamfer import chamfer_distance
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
from torch.nn import functional as F
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
from tqdm import tqdm
|
||||
from trimesh import Trimesh
|
||||
from typing import Union
|
||||
import builtins
|
||||
import itertools
|
||||
import json
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import rich
|
||||
import rich.pretty
|
||||
import statistics
|
||||
import torch
|
||||
pl.seed_everything(31337)
|
||||
torch.set_float32_matmul_precision('medium')
|
||||
|
||||
|
||||
IField = intersection_fields.IntersectionFieldAutoDecoderModel # brevity
|
||||
|
||||
|
||||
class RayFieldAdDataModuleBase(pl.LightningDataModule, ABC):
|
||||
@property
|
||||
@abstractmethod
|
||||
def observation_ids(self) -> list[str]:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def mk_ad_dataset(self) -> common.AutodecoderDataset:
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_trimesh_from_uid(uid) -> Trimesh:
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_sphere_scan_from_uid(uid) -> SingleViewUVScan:
|
||||
...
|
||||
|
||||
def setup(self, stage=None):
|
||||
assert stage in ["fit", None] # fit is for train/val, None is for all. "test" not supported ATM
|
||||
|
||||
if not self.hparams.data_dir is None:
|
||||
coseg.config.DATA_PATH = self.hparams.data_dir
|
||||
step = self.hparams.step # brevity
|
||||
|
||||
dataset = self.mk_ad_dataset()
|
||||
n_items_pre_step_mapping = len(dataset)
|
||||
|
||||
if step > 1:
|
||||
dataset = common.TransformExtendedDataset(dataset)
|
||||
|
||||
for sx in range(step):
|
||||
for sy in range(step):
|
||||
def make_slicer(sx, sy, step) -> callable: # the closure is required
|
||||
if step > 1:
|
||||
return lambda t: t[sx::step, sy::step]
|
||||
else:
|
||||
return lambda t: t
|
||||
@dataset.map(slicer=make_slicer(sx, sy, step))
|
||||
def unpack(sample: tuple[str, SingleViewUVScan], slicer: callable):
|
||||
scan: SingleViewUVScan = sample[1]
|
||||
assert not scan.hits.shape[0] % step, f"{scan.hits.shape[0]=} not divisible by {step=}"
|
||||
assert not scan.hits.shape[1] % step, f"{scan.hits.shape[1]=} not divisible by {step=}"
|
||||
|
||||
return {
|
||||
"z_uid" : sample[0],
|
||||
"origins" : scan.cam_pos,
|
||||
"dirs" : slicer(scan.ray_dirs),
|
||||
"points" : slicer(scan.points),
|
||||
"hits" : slicer(scan.hits),
|
||||
"miss" : slicer(scan.miss),
|
||||
"normals" : slicer(scan.normals),
|
||||
"distances" : slicer(scan.distances),
|
||||
}
|
||||
|
||||
# Split each object into train/val with SampleSplit
|
||||
n_items = len(dataset)
|
||||
n_val = int(n_items * self.hparams.val_fraction)
|
||||
n_train = n_items - n_val
|
||||
self.generator = torch.Generator().manual_seed(self.hparams.prng_seed)
|
||||
|
||||
# split the dataset such that all steps are in same part
|
||||
assert n_items == n_items_pre_step_mapping * step * step, (n_items, n_items_pre_step_mapping, step)
|
||||
indices = [
|
||||
i*step*step + sx*step + sy
|
||||
for i in torch.randperm(n_items_pre_step_mapping, generator=self.generator).tolist()
|
||||
for sx in range(step)
|
||||
for sy in range(step)
|
||||
]
|
||||
self.dataset_train = Subset(dataset, sorted(indices[:n_train], key=lambda x: torch.rand(1, generator=self.generator).tolist()[0]))
|
||||
self.dataset_val = Subset(dataset, sorted(indices[n_train:n_train+n_val], key=lambda x: torch.rand(1, generator=self.generator).tolist()[0]))
|
||||
|
||||
assert len(self.dataset_train) % self.hparams.batch_size == 0
|
||||
assert len(self.dataset_val) % self.hparams.batch_size == 0
|
||||
|
||||
def train_dataloader(self):
|
||||
return DataLoader(self.dataset_train,
|
||||
batch_size = self.hparams.batch_size,
|
||||
drop_last = self.hparams.drop_last,
|
||||
num_workers = self.hparams.num_workers,
|
||||
persistent_workers = self.hparams.persistent_workers,
|
||||
pin_memory = self.hparams.pin_memory,
|
||||
prefetch_factor = self.hparams.prefetch_factor,
|
||||
shuffle = self.hparams.shuffle,
|
||||
generator = self.generator,
|
||||
)
|
||||
|
||||
def val_dataloader(self):
|
||||
return DataLoader(self.dataset_val,
|
||||
batch_size = self.hparams.batch_size,
|
||||
drop_last = self.hparams.drop_last,
|
||||
num_workers = self.hparams.num_workers,
|
||||
persistent_workers = self.hparams.persistent_workers,
|
||||
pin_memory = self.hparams.pin_memory,
|
||||
prefetch_factor = self.hparams.prefetch_factor,
|
||||
generator = self.generator,
|
||||
)
|
||||
|
||||
|
||||
class StanfordUVDataModule(RayFieldAdDataModuleBase):
|
||||
skyward = "+Z"
|
||||
def __init__(self,
|
||||
data_dir : Union[str, Path, None] = None,
|
||||
obj_names : list[str] = ["bunny"], # empty means all
|
||||
|
||||
prng_seed : int = 1337,
|
||||
step : int = 2,
|
||||
batch_size : int = 5,
|
||||
drop_last : bool = False,
|
||||
num_workers : int = 8,
|
||||
persistent_workers : bool = True,
|
||||
pin_memory : int = True,
|
||||
prefetch_factor : int = 2,
|
||||
shuffle : bool = True,
|
||||
val_fraction : float = 0.30,
|
||||
):
|
||||
super().__init__()
|
||||
if not obj_names:
|
||||
obj_names = stanford_read.list_object_names()
|
||||
self.save_hyperparameters()
|
||||
|
||||
@property
|
||||
def observation_ids(self) -> list[str]:
|
||||
return self.hparams.obj_names
|
||||
|
||||
def mk_ad_dataset(self) -> common.AutodecoderDataset:
|
||||
return stanford.AutodecoderSingleViewUVScanDataset(
|
||||
obj_names = self.hparams.obj_names,
|
||||
data_path = self.hparams.data_dir,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_trimesh_from_uid(obj_name) -> Trimesh:
|
||||
import mesh_to_sdf
|
||||
mesh = stanford_read.read_mesh(obj_name)
|
||||
return mesh_to_sdf.scale_to_unit_sphere(mesh)
|
||||
|
||||
@staticmethod
|
||||
def get_sphere_scan_from_uid(obj_name) -> SingleViewUVScan:
|
||||
return stanford_read.read_mesh_mesh_sphere_scan(obj_name)
|
||||
|
||||
|
||||
class CosegUVDataModule(RayFieldAdDataModuleBase):
|
||||
skyward = "+Y"
|
||||
def __init__(self,
|
||||
data_dir : Union[str, Path, None] = None,
|
||||
object_sets : tuple[str] = ["tele-aliens"], # empty means all
|
||||
|
||||
prng_seed : int = 1337,
|
||||
step : int = 2,
|
||||
batch_size : int = 5,
|
||||
drop_last : bool = False,
|
||||
num_workers : int = 8,
|
||||
persistent_workers : bool = True,
|
||||
pin_memory : int = True,
|
||||
prefetch_factor : int = 2,
|
||||
shuffle : bool = True,
|
||||
val_fraction : float = 0.30,
|
||||
):
|
||||
super().__init__()
|
||||
if not object_sets:
|
||||
object_sets = coseg_read.list_object_sets()
|
||||
object_sets = tuple(object_sets)
|
||||
self.save_hyperparameters()
|
||||
|
||||
@property
|
||||
def observation_ids(self) -> list[str]:
|
||||
return coseg_read.list_model_id_strings(self.hparams.object_sets)
|
||||
|
||||
def mk_ad_dataset(self) -> common.AutodecoderDataset:
|
||||
return coseg.AutodecoderSingleViewUVScanDataset(
|
||||
object_sets = self.hparams.object_sets,
|
||||
data_path = self.hparams.data_dir,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_trimesh_from_uid(string_uid):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def get_sphere_scan_from_uid(string_uid) -> SingleViewUVScan:
|
||||
uid = coseg_read.model_id_string_to_uid(string_uid)
|
||||
return coseg_read.read_mesh_mesh_sphere_scan(*uid)
|
||||
|
||||
|
||||
def mk_cli(args=None) -> CliInterface:
|
||||
cli = CliInterface(
|
||||
module_cls = IField,
|
||||
datamodule_cls = [StanfordUVDataModule, CosegUVDataModule],
|
||||
workdir = Path(__file__).parent.resolve(),
|
||||
experiment_name_prefix = "ifield",
|
||||
)
|
||||
cli.trainer_defaults.update(dict(
|
||||
precision = 16,
|
||||
min_epochs = 5,
|
||||
))
|
||||
|
||||
@cli.register_pre_training_callback
|
||||
def populate_autodecoder_z_uids(args: Namespace, config: Munch, module: IField, trainer: pl.Trainer, datamodule: RayFieldAdDataModuleBase, logger: logging.Logger):
|
||||
module.set_observation_ids(datamodule.observation_ids)
|
||||
rank = getattr(rank_zero_only, "rank", 0)
|
||||
rich.print(f"[rank {rank}] {len(datamodule.observation_ids) = }")
|
||||
rich.print(f"[rank {rank}] {len(datamodule.observation_ids) > 1 = }")
|
||||
rich.print(f"[rank {rank}] {module.is_conditioned = }")
|
||||
|
||||
@cli.register_action(help="Interactive window with direct renderings from the model", args=[
|
||||
("--shading", dict(type=int, default=ModelViewer.vizmodes_shading .index("lambertian"), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_shading))}}}")),
|
||||
("--centroid", dict(type=int, default=ModelViewer.vizmodes_centroids.index("best-centroids-colored"), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_centroids))}}}")),
|
||||
("--spheres", dict(type=int, default=ModelViewer.vizmodes_spheres .index(None), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_spheres))}}}")),
|
||||
("--analytical-normals", dict(action="store_true")),
|
||||
("--ground-truth", dict(action="store_true")),
|
||||
("--solo-atom",dict(type=int, default=None, help="Rendering mode")),
|
||||
("--res", dict(type=int, nargs=2, default=(210, 160), help="Rendering resolution")),
|
||||
("--bg", dict(choices=["map", "white", "black"], default="map")),
|
||||
("--skyward", dict(type=str, default="+Z", help='one of: "+X", "-X", "+Y", "-Y", ["+Z"], "-Z"')),
|
||||
("--scale", dict(type=int, default=3, help="Rendering scale")),
|
||||
("--fps", dict(type=int, default=None, help="FPS upper limit")),
|
||||
("--cam-state",dict(type=str, default=None, help="json cam state, expored with CTRL+H")),
|
||||
("--write", dict(type=Path, default=None, help="Where to write a screenshot.")),
|
||||
])
|
||||
@torch.no_grad()
|
||||
def viewer(args: Namespace, config: Munch, model: IField):
|
||||
datamodule_cls: RayFieldAdDataModuleBase = cli.get_datamodule_cls_from_config(args, config)
|
||||
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() > 0:
|
||||
model.to("cuda")
|
||||
viewer = ModelViewer(model, start_uid=next(iter(model.keys())),
|
||||
name = config.experiment_name,
|
||||
screenshot_dir = Path(__file__).parent.parent / "images/pygame-viewer",
|
||||
res = args.res,
|
||||
skyward = args.skyward,
|
||||
scale = args.scale,
|
||||
mesh_gt_getter = datamodule_cls.get_trimesh_from_uid,
|
||||
)
|
||||
viewer.display_mode_shading = args.shading
|
||||
viewer.display_mode_centroid = args.centroid
|
||||
viewer.display_mode_spheres = args.spheres
|
||||
if args.ground_truth: viewer.display_mode_normals = viewer.vizmodes_normals.index("ground_truth")
|
||||
if args.analytical_normals: viewer.display_mode_normals = viewer.vizmodes_normals.index("analytical")
|
||||
viewer.atom_index_solo = args.solo_atom
|
||||
viewer.fps_cap = args.fps
|
||||
viewer.display_sphere_map_bg = { "map": True, "white": 255, "black": 0 }[args.bg]
|
||||
if args.cam_state is not None:
|
||||
viewer.cam_state = json.loads(args.cam_state)
|
||||
if args.write is None:
|
||||
viewer.run()
|
||||
else:
|
||||
assert args.write.suffix == ".png", args.write.name
|
||||
viewer.render_headless(args.write,
|
||||
n_frames = 1,
|
||||
fps = 1,
|
||||
state_callback = None,
|
||||
)
|
||||
|
||||
@cli.register_action(help="Prerender direct renderings from the model", args=[
|
||||
("output_path",dict(type=Path, help="Where to store the output. We recommend a .mp4 suffix.")),
|
||||
("uids", dict(type=str, nargs="*")),
|
||||
("--frames", dict(type=int, default=60, help="Number of per interpolation. Default is 60")),
|
||||
("--fps", dict(type=int, default=60, help="Default is 60")),
|
||||
("--shading", dict(type=int, default=ModelViewer.vizmodes_shading .index("lambertian"), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_shading))}}}")),
|
||||
("--centroid", dict(type=int, default=ModelViewer.vizmodes_centroids.index("best-centroids-colored"), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_centroids))}}}")),
|
||||
("--spheres", dict(type=int, default=ModelViewer.vizmodes_spheres .index(None), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_spheres))}}}")),
|
||||
("--analytical-normals", dict(action="store_true")),
|
||||
("--solo-atom",dict(type=int, default=None, help="Rendering mode")),
|
||||
("--res", dict(type=int, nargs=2, default=(240, 240), help="Rendering resolution. Default is 240 240")),
|
||||
("--bg", dict(choices=["map", "white", "black"], default="map")),
|
||||
("--skyward", dict(type=str, default="+Z", help='one of: "+X", "-X", "+Y", "-Y", ["+Z"], "-Z"')),
|
||||
("--bitrate", dict(type=str, default="1500k", help="Encoding bitrate. Default is 1500k")),
|
||||
("--cam-state",dict(type=str, default=None, help="json cam state, expored with CTRL+H")),
|
||||
])
|
||||
@torch.no_grad()
|
||||
def render_video_interpolation(args: Namespace, config: Munch, model: IField, **kw):
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() > 0:
|
||||
model.to("cuda")
|
||||
uids = args.uids or list(model.keys())
|
||||
assert len(uids) > 1
|
||||
if not args.uids: uids.append(uids[0])
|
||||
viewer = ModelViewer(model, uids[0],
|
||||
name = config.experiment_name,
|
||||
screenshot_dir = Path(__file__).parent.parent / "images/pygame-viewer",
|
||||
res = args.res,
|
||||
skyward = args.skyward,
|
||||
)
|
||||
if args.cam_state is not None:
|
||||
viewer.cam_state = json.loads(args.cam_state)
|
||||
viewer.display_mode_shading = args.shading
|
||||
viewer.display_mode_centroid = args.centroid
|
||||
viewer.display_mode_spheres = args.spheres
|
||||
if args.analytical_normals: viewer.display_mode_normals = viewer.vizmodes_normals.index("analytical")
|
||||
viewer.atom_index_solo = args.solo_atom
|
||||
viewer.display_sphere_map_bg = { "map": True, "white": 255, "black": 0 }[args.bg]
|
||||
def state_callback(self: ModelViewer, frame: int):
|
||||
if frame % args.frames:
|
||||
self.lambertian_color = (0.8, 0.8, 1.0)
|
||||
else:
|
||||
self.lambertian_color = (1.0, 1.0, 1.0)
|
||||
self.fps = args.frames
|
||||
idx = frame // args.frames + 1
|
||||
if idx != len(uids):
|
||||
self.current_uid = uids[idx]
|
||||
print(f"Writing video to {str(args.output_path)!r}...")
|
||||
viewer.render_headless(args.output_path,
|
||||
n_frames = args.frames * (len(uids)-1) + 1,
|
||||
fps = args.fps,
|
||||
state_callback = state_callback,
|
||||
bitrate = args.bitrate,
|
||||
)
|
||||
|
||||
@cli.register_action(help="Prerender direct renderings from the model", args=[
|
||||
("output_path",dict(type=Path, help="Where to store the output. We recommend a .mp4 suffix.")),
|
||||
("--frames", dict(type=int, default=180, help="Number of frames. Default is 180")),
|
||||
("--fps", dict(type=int, default=60, help="Default is 60")),
|
||||
("--shading", dict(type=int, default=ModelViewer.vizmodes_shading .index("lambertian"), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_shading))}}}")),
|
||||
("--centroid", dict(type=int, default=ModelViewer.vizmodes_centroids.index("best-centroids-colored"), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_centroids))}}}")),
|
||||
("--spheres", dict(type=int, default=ModelViewer.vizmodes_spheres .index(None), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_spheres))}}}")),
|
||||
("--analytical-normals", dict(action="store_true")),
|
||||
("--solo-atom",dict(type=int, default=None, help="Rendering mode")),
|
||||
("--res", dict(type=int, nargs=2, default=(320, 240), help="Rendering resolution. Default is 320 240")),
|
||||
("--bg", dict(choices=["map", "white", "black"], default="map")),
|
||||
("--skyward", dict(type=str, default="+Z", help='one of: "+X", "-X", "+Y", "-Y", ["+Z"], "-Z"')),
|
||||
("--bitrate", dict(type=str, default="1500k", help="Encoding bitrate. Default is 1500k")),
|
||||
("--cam-state",dict(type=str, default=None, help="json cam state, expored with CTRL+H")),
|
||||
])
|
||||
@torch.no_grad()
|
||||
def render_video_spin(args: Namespace, config: Munch, model: IField, **kw):
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() > 0:
|
||||
model.to("cuda")
|
||||
viewer = ModelViewer(model, start_uid=next(iter(model.keys())),
|
||||
name = config.experiment_name,
|
||||
screenshot_dir = Path(__file__).parent.parent / "images/pygame-viewer",
|
||||
res = args.res,
|
||||
skyward = args.skyward,
|
||||
)
|
||||
if args.cam_state is not None:
|
||||
viewer.cam_state = json.loads(args.cam_state)
|
||||
viewer.display_mode_shading = args.shading
|
||||
viewer.display_mode_centroid = args.centroid
|
||||
viewer.display_mode_spheres = args.spheres
|
||||
if args.analytical_normals: viewer.display_mode_normals = viewer.vizmodes_normals.index("analytical")
|
||||
viewer.atom_index_solo = args.solo_atom
|
||||
viewer.display_sphere_map_bg = { "map": True, "white": 255, "black": 0 }[args.bg]
|
||||
cam_rot_x_init = viewer.cam_rot_x
|
||||
def state_callback(self: ModelViewer, frame: int):
|
||||
self.cam_rot_x = cam_rot_x_init + 3.14 * (frame / args.frames) * 2
|
||||
print(f"Writing video to {str(args.output_path)!r}...")
|
||||
viewer.render_headless(args.output_path,
|
||||
n_frames = args.frames,
|
||||
fps = args.fps,
|
||||
state_callback = state_callback,
|
||||
bitrate = args.bitrate,
|
||||
)
|
||||
|
||||
@cli.register_action(help="foo", args=[
|
||||
("fname", dict(type=Path, help="where to write json")),
|
||||
("-t", "--transpose", dict(action="store_true", help="transpose the output")),
|
||||
("--single-shape", dict(action="store_true", help="break after first shape")),
|
||||
("--batch-size", dict(type=int, default=40_000, help="tradeoff between vram usage and efficiency")),
|
||||
("--n-cd", dict(type=int, default=30_000, help="Number of points to use when computing chamfer distance")),
|
||||
("--filter-outliers", dict(action="store_true", help="like in PRIF")),
|
||||
])
|
||||
@torch.enable_grad()
|
||||
def compute_scores(args: Namespace, config: Munch, model: IField, **kw):
|
||||
datamodule_cls: RayFieldAdDataModuleBase = cli.get_datamodule_cls_from_config(args, config)
|
||||
model.eval()
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() > 0:
|
||||
model.to("cuda")
|
||||
|
||||
def T(array: np.ndarray, **kw) -> torch.Tensor:
|
||||
if isinstance(array, torch.Tensor): return array
|
||||
return torch.tensor(array, device=model.device, dtype=model.dtype if isinstance(array, np.floating) else None, **kw)
|
||||
|
||||
MEDIAL = model.hparams.output_mode == "medial_sphere"
|
||||
if not MEDIAL: assert model.hparams.output_mode == "orthogonal_plane"
|
||||
|
||||
|
||||
uids = sorted(model.keys())
|
||||
if args.single_shape: uids = [uids[0]]
|
||||
rich.print(f"{datamodule_cls.__name__ = }")
|
||||
rich.print(f"{len(uids) = }")
|
||||
|
||||
# accumulators for IoU and F-Score, CD and COS
|
||||
|
||||
# sum reduction:
|
||||
n = defaultdict(int)
|
||||
n_gt_hits = defaultdict(int)
|
||||
n_gt_miss = defaultdict(int)
|
||||
n_gt_missing = defaultdict(int)
|
||||
n_outliers = defaultdict(int)
|
||||
p_mse = defaultdict(int)
|
||||
s_mse = defaultdict(int)
|
||||
cossim_med = defaultdict(int) # medial normals
|
||||
cossim_jac = defaultdict(int) # jacovian normals
|
||||
TP,FN,FP,TN = [defaultdict(int) for _ in range(4)] # IoU and f-score
|
||||
# mean reduction:
|
||||
cd_dist = {} # chamfer distance
|
||||
cd_cos_med = {} # chamfer medial normals
|
||||
cd_cos_jac = {} # chamfer jacovian normals
|
||||
all_metrics = dict(
|
||||
n=n, n_gt_hits=n_gt_hits, n_gt_miss=n_gt_miss, n_gt_missing=n_gt_missing, p_mse=p_mse,
|
||||
cossim_jac=cossim_jac,
|
||||
TP=TP, FN=FN, FP=FP, TN=TN, cd_dist=cd_dist,
|
||||
cd_cos_jac=cd_cos_jac,
|
||||
)
|
||||
if MEDIAL:
|
||||
all_metrics["s_mse"] = s_mse
|
||||
all_metrics["cossim_med"] = cossim_med
|
||||
all_metrics["cd_cos_med"] = cd_cos_med
|
||||
if args.filter_outliers:
|
||||
all_metrics["n_outliers"] = n_outliers
|
||||
|
||||
t = datetime.now()
|
||||
for uid in tqdm(uids, desc="Dataset", position=0, leave=True, disable=len(uids)<=1):
|
||||
sphere_scan_gt = datamodule_cls.get_sphere_scan_from_uid(uid)
|
||||
|
||||
z = model[uid].detach()
|
||||
|
||||
all_intersections = []
|
||||
all_medial_normals = []
|
||||
all_jacobian_normals = []
|
||||
|
||||
step = args.batch_size
|
||||
for i in tqdm(range(0, sphere_scan_gt.hits.shape[0], step), desc=f"Item {uid!r}", position=1, leave=False):
|
||||
# prepare batch and gt
|
||||
origins = T(sphere_scan_gt.cam_pos [i:i+step, :], requires_grad = True)
|
||||
dirs = T(sphere_scan_gt.ray_dirs [i:i+step, :])
|
||||
gt_hits = T(sphere_scan_gt.hits [i:i+step])
|
||||
gt_miss = T(sphere_scan_gt.miss [i:i+step])
|
||||
gt_missing = T(sphere_scan_gt.missing [i:i+step])
|
||||
gt_points = T(sphere_scan_gt.points [i:i+step, :])
|
||||
gt_normals = T(sphere_scan_gt.normals [i:i+step, :])
|
||||
gt_distances = T(sphere_scan_gt.distances[i:i+step])
|
||||
|
||||
# forward
|
||||
if MEDIAL:
|
||||
(
|
||||
depths,
|
||||
silhouettes,
|
||||
intersections,
|
||||
medial_normals,
|
||||
is_intersecting,
|
||||
sphere_centers,
|
||||
sphere_radii,
|
||||
) = model({
|
||||
"origins" : origins,
|
||||
"dirs" : dirs,
|
||||
}, z, intersections_only=False, allow_nans=False)
|
||||
else:
|
||||
silhouettes = medial_normals = None
|
||||
intersections, is_intersecting = model({
|
||||
"origins" : origins,
|
||||
"dirs" : dirs,
|
||||
}, z, normalize_origins = True)
|
||||
is_intersecting = is_intersecting > 0.5
|
||||
jac = diff.jacobian(intersections, origins, detach=True)
|
||||
|
||||
# outlier removal (PRIF)
|
||||
if args.filter_outliers:
|
||||
outliers = jac.norm(dim=-2).norm(dim=-1) > 5
|
||||
n_outliers[uid] += outliers[is_intersecting].sum().item()
|
||||
# We count filtered points as misses
|
||||
is_intersecting &= ~outliers
|
||||
|
||||
model.zero_grad()
|
||||
jacobian_normals = model.compute_normals_from_intersection_origin_jacobian(jac, dirs)
|
||||
|
||||
all_intersections .append(intersections .detach()[is_intersecting.detach(), :])
|
||||
all_medial_normals .append(medial_normals .detach()[is_intersecting.detach(), :]) if MEDIAL else None
|
||||
all_jacobian_normals.append(jacobian_normals.detach()[is_intersecting.detach(), :])
|
||||
|
||||
# accumulate metrics
|
||||
with torch.no_grad():
|
||||
n [uid] += dirs.shape[0]
|
||||
n_gt_hits [uid] += gt_hits.sum().item()
|
||||
n_gt_miss [uid] += gt_miss.sum().item()
|
||||
n_gt_missing [uid] += gt_missing.sum().item()
|
||||
p_mse [uid] += (gt_points [gt_hits, :] - intersections[gt_hits, :]).norm(2, dim=-1).pow(2).sum().item()
|
||||
if MEDIAL: s_mse [uid] += (gt_distances[gt_miss] - silhouettes [gt_miss] ) .pow(2).sum().item()
|
||||
if MEDIAL: cossim_med[uid] += (1-F.cosine_similarity(gt_normals[gt_hits, :], medial_normals [gt_hits, :], dim=-1).abs()).sum().item() # to match what pytorch3d does for CD
|
||||
cossim_jac [uid] += (1-F.cosine_similarity(gt_normals[gt_hits, :], jacobian_normals[gt_hits, :], dim=-1).abs()).sum().item() # to match what pytorch3d does for CD
|
||||
not_intersecting = ~is_intersecting
|
||||
TP [uid] += ((gt_hits | gt_missing) & is_intersecting).sum().item() # True Positive
|
||||
FN [uid] += ((gt_hits | gt_missing) & not_intersecting).sum().item() # False Negative
|
||||
FP [uid] += (gt_miss & is_intersecting).sum().item() # False Positive
|
||||
TN [uid] += (gt_miss & not_intersecting).sum().item() # True Negative
|
||||
|
||||
all_intersections = torch.cat(all_intersections, dim=0)
|
||||
all_medial_normals = torch.cat(all_medial_normals, dim=0) if MEDIAL else None
|
||||
all_jacobian_normals = torch.cat(all_jacobian_normals, dim=0)
|
||||
|
||||
hits = sphere_scan_gt.hits # brevity
|
||||
print()
|
||||
|
||||
assert all_intersections.shape[0] >= args.n_cd
|
||||
idx_cd_pred = torch.randperm(all_intersections.shape[0])[:args.n_cd]
|
||||
idx_cd_gt = torch.randperm(hits.sum()) [:args.n_cd]
|
||||
|
||||
print("cd... ", end="")
|
||||
tt = datetime.now()
|
||||
loss_cd, loss_cos_jac = chamfer_distance(
|
||||
x = all_intersections [None, :, :][:, idx_cd_pred, :].detach(),
|
||||
x_normals = all_jacobian_normals [None, :, :][:, idx_cd_pred, :].detach(),
|
||||
y = T(sphere_scan_gt.points [None, hits, :][:, idx_cd_gt, :]),
|
||||
y_normals = T(sphere_scan_gt.normals[None, hits, :][:, idx_cd_gt, :]),
|
||||
batch_reduction = "sum", point_reduction = "sum",
|
||||
)
|
||||
if MEDIAL: _, loss_cos_med = chamfer_distance(
|
||||
x = all_intersections [None, :, :][:, idx_cd_pred, :].detach(),
|
||||
x_normals = all_medial_normals [None, :, :][:, idx_cd_pred, :].detach(),
|
||||
y = T(sphere_scan_gt.points [None, hits, :][:, idx_cd_gt, :]),
|
||||
y_normals = T(sphere_scan_gt.normals[None, hits, :][:, idx_cd_gt, :]),
|
||||
batch_reduction = "sum", point_reduction = "sum",
|
||||
)
|
||||
print(datetime.now() - tt)
|
||||
|
||||
cd_dist [uid] = loss_cd.item()
|
||||
cd_cos_med [uid] = loss_cos_med.item() if MEDIAL else None
|
||||
cd_cos_jac [uid] = loss_cos_jac.item()
|
||||
|
||||
print()
|
||||
model.zero_grad(set_to_none=True)
|
||||
print("Total time:", datetime.now() - t)
|
||||
print("Time per item:", (datetime.now() - t) / len(uids)) if len(uids) > 1 else None
|
||||
|
||||
sum = lambda *xs: builtins .sum (itertools.chain(*(x.values() for x in xs)))
|
||||
mean = lambda *xs: statistics.mean (itertools.chain(*(x.values() for x in xs)))
|
||||
stdev = lambda *xs: statistics.stdev(itertools.chain(*(x.values() for x in xs)))
|
||||
n_cd = args.n_cd
|
||||
P = sum(TP)/(sum(TP, FP))
|
||||
R = sum(TP)/(sum(TP, FN))
|
||||
print(f"{mean(n) = :11.1f} (rays per object)")
|
||||
print(f"{mean(n_gt_hits) = :11.1f} (gt rays hitting per object)")
|
||||
print(f"{mean(n_gt_miss) = :11.1f} (gt rays missing per object)")
|
||||
print(f"{mean(n_gt_missing) = :11.1f} (gt rays unknown per object)")
|
||||
print(f"{mean(n_outliers) = :11.1f} (gt rays unknown per object)") if args.filter_outliers else None
|
||||
print(f"{n_cd = :11.0f} (cd rays per object)")
|
||||
print(f"{mean(n_gt_hits) / mean(n) = :11.8f} (fraction rays hitting per object)")
|
||||
print(f"{mean(n_gt_miss) / mean(n) = :11.8f} (fraction rays missing per object)")
|
||||
print(f"{mean(n_gt_missing)/ mean(n) = :11.8f} (fraction rays unknown per object)")
|
||||
print(f"{mean(n_outliers) / mean(n) = :11.8f} (fraction rays unknown per object)") if args.filter_outliers else None
|
||||
print(f"{sum(TP)/sum(n) = :11.8f} (total ray TP)")
|
||||
print(f"{sum(TN)/sum(n) = :11.8f} (total ray TN)")
|
||||
print(f"{sum(FP)/sum(n) = :11.8f} (total ray FP)")
|
||||
print(f"{sum(FN)/sum(n) = :11.8f} (total ray FN)")
|
||||
print(f"{sum(TP, FN, FP)/sum(n) = :11.8f} (total ray union)")
|
||||
print(f"{sum(TP)/sum(TP, FN, FP) = :11.8f} (total ray IoU)")
|
||||
print(f"{sum(TP)/(sum(TP, FP)) = :11.8f} -> P (total ray precision)")
|
||||
print(f"{sum(TP)/(sum(TP, FN)) = :11.8f} -> R (total ray recall)")
|
||||
print(f"{2*(P*R)/(P+R) = :11.8f} (total ray F-score)")
|
||||
print(f"{sum(p_mse)/sum(n_gt_hits) = :11.8f} (mean ray intersection mean squared error)")
|
||||
print(f"{sum(s_mse)/sum(n_gt_miss) = :11.8f} (mean ray silhoutette mean squared error)")
|
||||
print(f"{sum(cossim_med)/sum(n_gt_hits) = :11.8f} (mean ray medial reduced cosine similarity)") if MEDIAL else None
|
||||
print(f"{sum(cossim_jac)/sum(n_gt_hits) = :11.8f} (mean ray analytical reduced cosine similarity)")
|
||||
print(f"{mean(cd_dist) /n_cd * 1e3 = :11.8f} (mean chamfer distance)")
|
||||
print(f"{mean(cd_cos_med)/n_cd = :11.8f} (mean chamfer reduced medial cossim distance)") if MEDIAL else None
|
||||
print(f"{mean(cd_cos_jac)/n_cd = :11.8f} (mean chamfer reduced analytical cossim distance)")
|
||||
print(f"{stdev(cd_dist) /n_cd * 1e3 = :11.8f} (stdev chamfer distance)") if len(cd_dist) > 1 else None
|
||||
print(f"{stdev(cd_cos_med)/n_cd = :11.8f} (stdev chamfer reduced medial cossim distance)") if len(cd_cos_med) > 1 and MEDIAL else None
|
||||
print(f"{stdev(cd_cos_jac)/n_cd = :11.8f} (stdev chamfer reduced analytical cossim distance)") if len(cd_cos_jac) > 1 else None
|
||||
|
||||
if args.transpose:
|
||||
all_metrics, old_metrics = defaultdict(dict), all_metrics
|
||||
for m, table in old_metrics.items():
|
||||
for uid, vals in table.items():
|
||||
all_metrics[uid][m] = vals
|
||||
all_metrics["_hparams"] = dict(n_cd=args.n_cd)
|
||||
else:
|
||||
all_metrics["n_cd"] = args.n_cd
|
||||
|
||||
if str(args.fname) == "-":
|
||||
print("{", ',\n'.join(
|
||||
f" {json.dumps(k)}: {json.dumps(v)}"
|
||||
for k, v in all_metrics.items()
|
||||
), "}", sep="\n")
|
||||
else:
|
||||
args.fname.parent.mkdir(parents=True, exist_ok=True)
|
||||
with args.fname.open("w") as f:
|
||||
json.dump(all_metrics, f, indent=2)
|
||||
|
||||
return cli
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mk_cli().run()
|
263
experiments/marf.yaml.j2
Executable file
263
experiments/marf.yaml.j2
Executable file
@ -0,0 +1,263 @@
|
||||
#!/usr/bin/env -S python ./marf.py module
|
||||
{% do require_defined("select", select, 0, "$SLURM_ARRAY_TASK_ID") %}{# requires jinja2.ext.do #}
|
||||
{% do require_defined("mode", mode, "single", "ablation", "multi", strict=true, exchaustive=true) %}{# requires jinja2.ext.do #}
|
||||
{% set counter = itertools.count(start=0, step=1) %}
|
||||
{% set do_condition = mode == "multi" %}
|
||||
{% set do_ablation = mode == "ablation" %}
|
||||
|
||||
{% set hp_matrix = namespace() %}{# hyper parameter matrix #}
|
||||
|
||||
{% set hp_matrix.input_mode = [
|
||||
"both",
|
||||
"perp_foot",
|
||||
"plucker",
|
||||
] if do_ablation else [ "both" ] %}
|
||||
{% set hp_matrix.output_mode = ["medial_sphere", "orthogonal_plane"] %}{##}
|
||||
{% set hp_matrix.output_mode = ["medial_sphere"] %}{##}
|
||||
{% set hp_matrix.n_atoms = [16, 1, 4, 8, 32, 64] if do_ablation else [16] %}{##}
|
||||
{% set hp_matrix.normal_coeff = [0.25, 0] if do_ablation else [0.25] %}{##}
|
||||
{% set hp_matrix.dataset_item = [objname] if objname is defined else (["armadillo", "bunny", "happy_buddha", "dragon", "lucy"] if not do_condition else ["four-legged"]) %}{##}
|
||||
{% set hp_matrix.test_val_split_frac = [0.7] %}{##}
|
||||
{% set hp_matrix.lr_coeff = [5] %}{##}
|
||||
{% set hp_matrix.warmup_epochs = [1] if not do_condition else [0.1] %}{##}
|
||||
{% set hp_matrix.improve_miss_grads = [True] %}{##}
|
||||
{% set hp_matrix.normalize_ray_dirs = [True] %}{##}
|
||||
{% set hp_matrix.intersection_coeff = [2, 0] if do_ablation else [2] %}{##}
|
||||
{% set hp_matrix.miss_distance_coeff = [1, 0, 5] if do_ablation else [1] %}{##}
|
||||
{% set hp_matrix.relative_out = [False] %}{##}
|
||||
{% set hp_matrix.hidden_features = [512] %}{# like deepsdf and prif #}
|
||||
{% set hp_matrix.hidden_layers = [8] %}{# like deepsdf, nerf, prif #}
|
||||
{% set hp_matrix.nonlinearity = ["leaky_relu"] %}{##}
|
||||
{% set hp_matrix.omega = [30] %}{##}
|
||||
{% set hp_matrix.normalization = ["layernorm"] %}{##}
|
||||
{% set hp_matrix.dropout_percent = [1] %}{##}
|
||||
{% set hp_matrix.sphere_grow_reg_coeff = [500, 0, 5000] if do_ablation else [500] %}{##}
|
||||
{% set hp_matrix.geom_init = [True, False] if do_ablation else [True] %}{##}
|
||||
{% set hp_matrix.loss_inscription = [50, 0, 250] if do_ablation else [50] %}{##}
|
||||
{% set hp_matrix.atom_centroid_norm_std_reg_negexp = [0, None] if do_ablation else [0] %}{##}
|
||||
{% set hp_matrix.curvature_reg_coeff = [0.2] %}{##}
|
||||
{% set hp_matrix.multi_view_reg_coeff = [1, 2] if do_ablation else [1] %}{##}
|
||||
{% set hp_matrix.grad_reg = [ "multi_view", "nogradreg" ] if do_ablation else [ "multi_view" ] %}
|
||||
|
||||
{#% for hp in cartesian_hparams(hp_matrix) %}{##}
|
||||
{% for hp in ablation_hparams(hp_matrix, caartesian_keys=["output_mode", "dataset_item", "nonlinearity", "test_val_split_frac"]) %}
|
||||
|
||||
{% if hp.output_mode == "orthogonal_plane"%}
|
||||
{% if hp.normal_coeff == 0 %}{% set hp.normal_coeff = 0.25 %}
|
||||
{% elif hp.normal_coeff == 0.25 %}{% set hp.normal_coeff = 0 %}
|
||||
{% endif %}
|
||||
{% if hp.grad_reg == "multi_view" %}{% set hp.grad_reg = "nogradreg" %}
|
||||
{% elif hp.grad_reg == "nogradreg" %}{% set hp.grad_reg = "multi_view" %}
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
|
||||
{# filter bad/uninteresting hparam combos #}
|
||||
{% if ( hp.nonlinearity != "sine" and hp.omega != 30 )
|
||||
or ( hp.nonlinearity == "sine" and hp.normalization in ("layernorm", "layernorm_na") )
|
||||
or ( hp.multi_view_reg_coeff != 1 and "multi_view" not in hp.grad_reg )
|
||||
or ( "curvature" not in hp.grad_reg and hp.curvature_reg_coeff != 0.2 )
|
||||
or ( hp.output_mode == "orthogonal_plane" and hp.input_mode != "both" )
|
||||
or ( hp.output_mode == "orthogonal_plane" and hp.atom_centroid_norm_std_reg_negexp != 0 )
|
||||
or ( hp.output_mode == "orthogonal_plane" and hp.n_atoms != 16 )
|
||||
or ( hp.output_mode == "orthogonal_plane" and hp.sphere_grow_reg_coeff != 500 )
|
||||
or ( hp.output_mode == "orthogonal_plane" and hp.loss_inscription != 50 )
|
||||
or ( hp.output_mode == "orthogonal_plane" and hp.miss_distance_coeff != 1 )
|
||||
or ( hp.output_mode == "orthogonal_plane" and hp.test_val_split_frac != 0.7 )
|
||||
or ( hp.output_mode == "orthogonal_plane" and hp.lr_coeff != 5 )
|
||||
or ( hp.output_mode == "orthogonal_plane" and not hp.geom_init )
|
||||
or ( hp.output_mode == "orthogonal_plane" and not hp.intersection_coeff )
|
||||
%}
|
||||
{% continue %}{# requires jinja2.ext.loopcontrols #}
|
||||
{% endif %}
|
||||
|
||||
{% set index = next(counter) %}
|
||||
{% if select is not defined and index > 0 %}---{% endif %}
|
||||
{% if select is not defined or int(select) == index %}
|
||||
|
||||
trainer:
|
||||
gradient_clip_val : 1.0
|
||||
max_epochs : 200
|
||||
min_epochs : 200
|
||||
log_every_n_steps : 20
|
||||
|
||||
{% if not do_condition %}
|
||||
|
||||
StanfordUVDataModule:
|
||||
obj_names : ["{{ hp.dataset_item }}"]
|
||||
step : 4
|
||||
batch_size : 8
|
||||
val_fraction : {{ 1-hp.test_val_split_frac }}
|
||||
|
||||
{% else %}{# if do_condition #}
|
||||
|
||||
CosegUVDataModule:
|
||||
object_sets : ["{{ hp.dataset_item }}"]
|
||||
step : 4
|
||||
batch_size : 8
|
||||
val_fraction : {{ 1-hp.test_val_split_frac }}
|
||||
|
||||
{% endif %}{# if do_condition #}
|
||||
|
||||
logging:
|
||||
save_dir : logdir
|
||||
type : tensorboard
|
||||
project : ifield
|
||||
|
||||
{% autoescape false %}
|
||||
{% do require_defined("experiment_name", experiment_name, "single-shape" if do_condition else "multi-shape", strict=true) %}
|
||||
{% set input_mode_abbr = hp.input_mode
|
||||
.replace("plucker", "plkr")
|
||||
.replace("perp_foot", "prpft")
|
||||
%}
|
||||
{% set output_mode_abbr = hp.output_mode
|
||||
.replace("medial_sphere", "marf")
|
||||
.replace("orthogonal_plane", "prif")
|
||||
%}
|
||||
experiment_name: experiment-{{ "" if experiment_name is not defined else experiment_name }}
|
||||
{#--#}-{{ hp.dataset_item }}
|
||||
{#--#}-{{ input_mode_abbr }}2{{ output_mode_abbr }}
|
||||
{#--#}
|
||||
{%- if hp.output_mode == "medial_sphere" -%}
|
||||
{#--#}-{{ hp.n_atoms }}atom
|
||||
{#--# }-{{ "rel" if hp.relative_out else "norel" }}
|
||||
{#--# }-{{ "e" if hp.improve_miss_grads else "0" }}sqrt
|
||||
{#--#}-{{ int(hp.loss_inscription) if hp.loss_inscription else "no" }}xinscr
|
||||
{#--#}-{{ int(hp.miss_distance_coeff * 10) }}dmiss
|
||||
{#--#}-{{ "geom" if hp.geom_init else "nogeom" }}
|
||||
{#--#}{% if "curvature" in hp.grad_reg %}
|
||||
{#- -#}-{{ int(hp.curvature_reg_coeff*10) }}crv
|
||||
{#--#}{%- endif -%}
|
||||
{%- elif hp.output_mode == "orthogonal_plane" -%}
|
||||
{#--#}
|
||||
{%- endif -%}
|
||||
{#--#}-{{ int(hp.intersection_coeff*10) }}chit
|
||||
{#--#}-{{ int(hp.normal_coeff*100) or "no" }}cnrml
|
||||
{#--# }-{{ "do" if hp.normalize_ray_dirs else "no" }}raynorm
|
||||
{#--#}-{{ hp.hidden_layers }}x{{ hp.hidden_features }}fc
|
||||
{#--#}-{{ hp.nonlinearity or "linear" }}
|
||||
{#--#}
|
||||
{%- if hp.nonlinearity == "sine" -%}
|
||||
{#--#}-{{ hp.omega }}omega
|
||||
{#--#}
|
||||
{%- endif -%}
|
||||
{%- if hp.output_mode == "medial_sphere" -%}
|
||||
{#--#}-{{ str(hp.atom_centroid_norm_std_reg_negexp).replace(*"-n") if hp.atom_centroid_norm_std_reg_negexp is not none else 'no' }}minatomstdngxp
|
||||
{#--#}-{{ hp.sphere_grow_reg_coeff }}sphgrow
|
||||
{#--#}
|
||||
{%- endif -%}
|
||||
{#--#}-{{ int(hp.dropout_percent*10) }}mdrop
|
||||
{#--#}-{{ hp.normalization or "nonorm" }}
|
||||
{#--#}-{{ hp.grad_reg }}
|
||||
{#--#}{% if "multi_view" in hp.grad_reg %}
|
||||
{#- -#}-{{ int(hp.multi_view_reg_coeff*10) }}dmv
|
||||
{#--#}{%- endif -%}
|
||||
{#--#}-{{ "concat" if do_condition else "nocond" }}
|
||||
{#--#}-{{ int(hp.warmup_epochs*100) }}cwu{{ int(hp.lr_coeff*100) }}clr{{ int(hp.test_val_split_frac*100) }}tvs
|
||||
{#--#}-{{ gen_run_uid(4) }} # select with --Oselect={{ index }}
|
||||
{#--#}
|
||||
{##}
|
||||
|
||||
{% endautoescape %}
|
||||
IntersectionFieldAutoDecoderModel:
|
||||
_extra: # used for easier introspection with jq
|
||||
dataset_item: {{ hp.dataset_item | to_json}}
|
||||
dataset_test_val_frac: {{ hp.test_val_split_frac }}
|
||||
select: {{ index }}
|
||||
|
||||
input_mode : {{ hp.input_mode }} # in {plucker, perp_foot, both}
|
||||
output_mode : {{ hp.output_mode }} # in {medial_sphere, orthogonal_plane}
|
||||
#latent_features : 256 # int
|
||||
#latent_features : 128 # int
|
||||
latent_features : 16 # int
|
||||
hidden_features : {{ hp.hidden_features }} # int
|
||||
hidden_layers : {{ hp.hidden_layers }} # int
|
||||
|
||||
improve_miss_grads : {{ bool(hp.improve_miss_grads) | to_json }}
|
||||
normalize_ray_dirs : {{ bool(hp.normalize_ray_dirs) | to_json }}
|
||||
|
||||
loss_intersection : {{ hp.intersection_coeff }}
|
||||
loss_intersection_l2 : 0
|
||||
loss_intersection_proj : 0
|
||||
loss_intersection_proj_l2 : 0
|
||||
|
||||
loss_normal_cossim : {{ hp.normal_coeff }} * EaseSin(85, 15)
|
||||
loss_normal_euclid : 0
|
||||
loss_normal_cossim_proj : 0
|
||||
loss_normal_euclid_proj : 0
|
||||
|
||||
{% if "multi_view" in hp.grad_reg %}
|
||||
loss_multi_view_reg : 0.1 * {{ hp.multi_view_reg_coeff }} * Linear(50)
|
||||
{% else %}
|
||||
loss_multi_view_reg : 0
|
||||
{% endif %}
|
||||
|
||||
{% if hp.output_mode == "orthogonal_plane" %}
|
||||
|
||||
loss_hit_cross_entropy : 1
|
||||
|
||||
{% elif hp.output_mode == "medial_sphere" %}
|
||||
|
||||
loss_hit_nodistance_l1 : 0
|
||||
loss_hit_nodistance_l2 : 100 * {{ hp.miss_distance_coeff }}
|
||||
loss_miss_distance_l1 : 0
|
||||
loss_miss_distance_l2 : 10 * {{ hp.miss_distance_coeff }}
|
||||
|
||||
loss_inscription_hits : {{ 0.4 * hp.loss_inscription }}
|
||||
loss_inscription_miss : 0
|
||||
loss_inscription_hits_l2 : 0
|
||||
loss_inscription_miss_l2 : {{ 6 * hp.loss_inscription }}
|
||||
|
||||
loss_sphere_grow_reg : 1e-6 * {{ hp.sphere_grow_reg_coeff }} # constant
|
||||
loss_atom_centroid_norm_std_reg: (0.09*(1-Linear(40)) + 0.01) * {{ 10**(-hp.atom_centroid_norm_std_reg_negexp) if hp.atom_centroid_norm_std_reg_negexp is not none else 0 }}
|
||||
|
||||
{% else %}{#endif hp.output_mode == "medial_sphere" #}
|
||||
THIS IS INVALID YAML
|
||||
{% endif %}
|
||||
|
||||
loss_embedding_norm : 0.01**2 * Linear(30, 0.1)
|
||||
|
||||
opt_learning_rate : {{ hp.lr_coeff }} * 10**(-4-0.5*EaseSin(170, 30)) # layernorm
|
||||
opt_warmup : {{ hp.warmup_epochs }}
|
||||
opt_weight_decay : 5e-6 # float
|
||||
|
||||
{% if hp.output_mode == "medial_sphere" %}
|
||||
|
||||
# MedialAtomNet:
|
||||
n_atoms : {{ hp.n_atoms }} # int
|
||||
{% if hp.geom_init %}
|
||||
final_init_wrr: [0.05, 0.6, 0.1]
|
||||
{% else %}
|
||||
final_init_wrr: null
|
||||
{% endif %}
|
||||
|
||||
{% endif %}
|
||||
|
||||
|
||||
# FCBlock:
|
||||
normalization : {{ hp.normalization or "null" }} # in {null, layernorm, layernorm_na, weightnorm}
|
||||
nonlinearity : {{ hp.nonlinearity or "null" }} # in {null, relu, leaky_relu, silu, softplus, elu, selu, sine, sigmoid, tanh }
|
||||
{% set middle = 1 + hp.hidden_layers // 2 + (hp.hidden_layers % 2) %}{##}
|
||||
concat_skipped_layers : [{{ middle }}, -1]
|
||||
{% if do_condition %}
|
||||
concat_conditioned_layers : [0, {{ middle }}]
|
||||
{% else %}
|
||||
concat_conditioned_layers : []
|
||||
{% endif %}
|
||||
|
||||
# FCLayer:
|
||||
negative_slope : 0.01 # float
|
||||
omega_0 : {{ hp.omega }} # float
|
||||
residual_mode : null # in {null, identity}
|
||||
|
||||
{% endif %}{# -Oselect #}
|
||||
|
||||
|
||||
{% endfor %}
|
||||
|
||||
|
||||
{% set index = next(counter) %}
|
||||
# number of possible -Oselect: {{ index }}, from 0 to {{ index-1 }}
|
||||
# local: for select in {0..{{ index-1 }}}; do python ... -Omode={{ mode }} -Oselect=$select ... ; done
|
||||
# local: for select in {0..{{ index-1 }}}; do python -O {{ argv[0] }} model marf.yaml.j2 -Omode={{ mode }} -Oselect=$select -Oexperiment_name='{{ experiment_name }}' fit --accelerator gpu ; done
|
||||
# slurm: sbatch --array=0-{{ index-1 }} runcommand.slurm python ... -Omode={{ mode }} -Oselect=\$SLURM_ARRAY_TASK_ID ...
|
||||
# slurm: sbatch --array=0-{{ index-1 }} runcommand.slurm python -O {{ argv[0] }} model marf.yaml.j2 -Omode={{ mode }} -Oselect=\$SLURM_ARRAY_TASK_ID -Oexperiment_name='{{ experiment_name }}' fit --accelerator gpu --devices -1 --strategy ddp
|
849
experiments/summary.py
Executable file
849
experiments/summary.py
Executable file
@ -0,0 +1,849 @@
|
||||
#!/usr/bin/env python
|
||||
from concurrent.futures import ThreadPoolExecutor, Future, ProcessPoolExecutor
|
||||
from functools import partial
|
||||
from more_itertools import first, last, tail
|
||||
from munch import Munch, DefaultMunch, munchify, unmunchify
|
||||
from pathlib import Path
|
||||
from statistics import mean, StatisticsError
|
||||
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
||||
from typing import Iterable, Optional, Literal
|
||||
from math import isnan
|
||||
import json
|
||||
import stat
|
||||
import matplotlib
|
||||
import matplotlib.colors as mcolors
|
||||
import matplotlib.pyplot as plt
|
||||
import os, os.path
|
||||
import re
|
||||
import shlex
|
||||
import time
|
||||
import itertools
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import traceback
|
||||
import typer
|
||||
import warnings
|
||||
import yaml
|
||||
import tempfile
|
||||
|
||||
EXPERIMENTS = Path(__file__).resolve()
|
||||
LOGDIR = EXPERIMENTS / "logdir"
|
||||
TENSORBOARD = LOGDIR / "tensorboard"
|
||||
SLURM_LOGS = LOGDIR / "slurm_logs"
|
||||
CACHED_SUMMARIES = LOGDIR / "cached_summaries"
|
||||
COMPUTED_SCORES = LOGDIR / "computed_scores"
|
||||
|
||||
MISSING = object()
|
||||
|
||||
class SafeLoaderIgnoreUnknown(yaml.SafeLoader):
|
||||
def ignore_unknown(self, node):
|
||||
return None
|
||||
SafeLoaderIgnoreUnknown.add_constructor(None, SafeLoaderIgnoreUnknown.ignore_unknown)
|
||||
|
||||
def camel_to_snake_case(text: str, sep: str = "_", join_abbreviations: bool = False) -> str:
|
||||
parts = (
|
||||
part.lower()
|
||||
for part in re.split(r'(?=[A-Z])', text)
|
||||
if part
|
||||
)
|
||||
if join_abbreviations: # this operation is not reversible
|
||||
parts = list(parts)
|
||||
if len(parts) > 1:
|
||||
for i, (a, b) in list(enumerate(zip(parts[:-1], parts[1:])))[::-1]:
|
||||
if len(a) == len(b) == 1:
|
||||
parts[i] = parts[i] + parts.pop(i+1)
|
||||
return sep.join(parts)
|
||||
|
||||
def flatten_dict(data: dict, key_mapper: callable = lambda x: x) -> dict:
|
||||
if not any(isinstance(val, dict) for val in data.values()):
|
||||
return data
|
||||
else:
|
||||
return {
|
||||
k: v
|
||||
for k, v in data.items()
|
||||
if not isinstance(v, dict)
|
||||
} | {
|
||||
f"{key_mapper(p)}/{k}":v
|
||||
for p,d in data.items()
|
||||
if isinstance(d, dict)
|
||||
for k,v in d.items()
|
||||
}
|
||||
|
||||
def parse_jsonl(data: str) -> Iterable[dict]:
|
||||
yield from map(json.loads, (line for line in data.splitlines() if line.strip()))
|
||||
|
||||
def read_jsonl(path: Path) -> Iterable[dict]:
|
||||
with path.open("r") as f:
|
||||
data = f.read()
|
||||
yield from parse_jsonl(data)
|
||||
|
||||
def get_experiment_paths(filter: str | None, assert_dumped = False) -> Iterable[Path]:
|
||||
for path in TENSORBOARD.iterdir():
|
||||
if filter is not None and not re.search(filter, path.name): continue
|
||||
if not path.is_dir(): continue
|
||||
|
||||
if not (path / "hparams.yaml").is_file():
|
||||
warnings.warn(f"Missing hparams: {path}")
|
||||
continue
|
||||
if not any(path.glob("events.out.tfevents.*")):
|
||||
warnings.warn(f"Missing tfevents: {path}")
|
||||
continue
|
||||
|
||||
if __debug__ and assert_dumped:
|
||||
assert (path / "scalars/epoch.json").is_file(), path
|
||||
assert (path / "scalars/IntersectionFieldAutoDecoderModel.validation_step/loss.json").is_file(), path
|
||||
assert (path / "scalars/IntersectionFieldAutoDecoderModel.training_step/loss.json").is_file(), path
|
||||
|
||||
yield path
|
||||
|
||||
def dump_pl_tensorboard_hparams(experiment: Path):
|
||||
with (experiment / "hparams.yaml").open() as f:
|
||||
hparams = yaml.load(f, Loader=SafeLoaderIgnoreUnknown)
|
||||
|
||||
shebang = None
|
||||
with (experiment / "config.yaml").open("w") as f:
|
||||
raw_yaml = hparams.get('_pickled_cli_args', {}).get('_raw_yaml', "").replace("\n\r", "\n")
|
||||
if raw_yaml.startswith("#!"): # preserve shebang
|
||||
shebang, _, raw_yaml = raw_yaml.partition("\n")
|
||||
f.write(f"{shebang}\n")
|
||||
f.write(f"# {' '.join(map(shlex.quote, hparams.get('_pickled_cli_args', {}).get('sys_argv', ['None'])))}\n\n")
|
||||
f.write(raw_yaml)
|
||||
if shebang is not None:
|
||||
os.chmod(experiment / "config.yaml", (experiment / "config.yaml").stat().st_mode | stat.S_IXUSR)
|
||||
print(experiment / "config.yaml", "written!", file=sys.stderr)
|
||||
|
||||
with (experiment / "environ.yaml").open("w") as f:
|
||||
yaml.safe_dump(hparams.get('_pickled_cli_args', {}).get('host', {}).get('environ'), f)
|
||||
print(experiment / "environ.yaml", "written!", file=sys.stderr)
|
||||
|
||||
with (experiment / "repo.patch").open("w") as f:
|
||||
f.write(hparams.get('_pickled_cli_args', {}).get('host', {}).get('vcs', "None"))
|
||||
print(experiment / "repo.patch", "written!", file=sys.stderr)
|
||||
|
||||
def dump_simple_tf_events_to_jsonl(output_dir: Path, *tf_files: Path):
|
||||
from google.protobuf.json_format import MessageToDict
|
||||
import tensorboard.backend.event_processing.event_accumulator
|
||||
s, l = {}, [] # reused sentinels
|
||||
|
||||
#resource.setrlimit(resource.RLIMIT_NOFILE, (2**16,-1))
|
||||
file_handles = {}
|
||||
try:
|
||||
for tffile in tf_files:
|
||||
loader = tensorboard.backend.event_processing.event_file_loader.LegacyEventFileLoader(str(tffile))
|
||||
for event in loader.Load():
|
||||
for summary in MessageToDict(event).get("summary", s).get("value", l):
|
||||
if "simpleValue" in summary:
|
||||
tag = summary["tag"]
|
||||
if tag not in file_handles:
|
||||
fname = output_dir / f"{tag}.json"
|
||||
print(f"Opening {str(fname)!r}...", file=sys.stderr)
|
||||
fname.parent.mkdir(parents=True, exist_ok=True)
|
||||
file_handles[tag] = fname.open("w") # ("a")
|
||||
val = summary["simpleValue"]
|
||||
data = json.dumps({
|
||||
"step" : event.step,
|
||||
"value" : float(val) if isinstance(val, str) else val,
|
||||
"wall_time" : event.wall_time,
|
||||
})
|
||||
file_handles[tag].write(f"{data}\n")
|
||||
finally:
|
||||
if file_handles:
|
||||
print("Closing json files...", file=sys.stderr)
|
||||
for k, v in file_handles.items():
|
||||
v.close()
|
||||
|
||||
|
||||
NO_FILTER = {
|
||||
"__uid",
|
||||
"_minutes",
|
||||
"_epochs",
|
||||
"_hp_nonlinearity",
|
||||
"_val_uloss_intersection",
|
||||
"_val_uloss_normal_cossim",
|
||||
"_val_uloss_intersection",
|
||||
}
|
||||
def filter_jsonl_columns(data: Iterable[dict | None], no_filter=NO_FILTER) -> list[dict]:
|
||||
def merge_siren_omega(data: dict) -> dict:
|
||||
return {
|
||||
key: (
|
||||
f"{val}-{data.get('hp_omega_0', 'ERROR')}"
|
||||
if (key.removeprefix("_"), val) == ("hp_nonlinearity", "sine") else
|
||||
val
|
||||
)
|
||||
for key, val in data.items()
|
||||
if key != "hp_omega_0"
|
||||
}
|
||||
|
||||
def remove_uninteresting_cols(rows: list[dict]) -> Iterable[dict]:
|
||||
unique_vals = {}
|
||||
def register_val(key, val):
|
||||
unique_vals.setdefault(key, set()).add(repr(val))
|
||||
return val
|
||||
|
||||
whitelisted = {
|
||||
key
|
||||
for row in rows
|
||||
for key, val in row.items()
|
||||
if register_val(key, val) and val not in ("None", "0", "0.0")
|
||||
}
|
||||
for key in unique_vals:
|
||||
for row in rows:
|
||||
if key not in row:
|
||||
unique_vals[key].add(MISSING)
|
||||
for key, vals in unique_vals.items():
|
||||
if key not in whitelisted: continue
|
||||
if len(vals) == 1:
|
||||
whitelisted.remove(key)
|
||||
|
||||
whitelisted.update(no_filter)
|
||||
|
||||
yield from (
|
||||
{
|
||||
key: val
|
||||
for key, val in row.items()
|
||||
if key in whitelisted
|
||||
}
|
||||
for row in rows
|
||||
)
|
||||
|
||||
def pessemize_types(rows: list[dict]) -> Iterable[dict]:
|
||||
types = {}
|
||||
order = (str, float, int, bool, tuple, type(None))
|
||||
for row in rows:
|
||||
for key, val in row.items():
|
||||
if isinstance(val, list): val = tuple(val)
|
||||
assert type(val) in order, (type(val), val)
|
||||
index = order.index(type(val))
|
||||
types[key] = min(types.get(key, 999), index)
|
||||
|
||||
yield from (
|
||||
{
|
||||
key: order[types[key]](val) if val is not None else None
|
||||
for key, val in row.items()
|
||||
}
|
||||
for row in rows
|
||||
)
|
||||
|
||||
data = (row for row in data if row is not None)
|
||||
data = map(partial(flatten_dict, key_mapper=camel_to_snake_case), data)
|
||||
data = map(merge_siren_omega, data)
|
||||
data = remove_uninteresting_cols(list(data))
|
||||
data = pessemize_types(list(data))
|
||||
|
||||
return data
|
||||
|
||||
PlotMode = Literal["stackplot", "lineplot"]
|
||||
|
||||
def plot_losses(experiments: list[Path], mode: PlotMode, write: bool = False, dump: bool = False, training: bool = False, unscaled: bool = False, force=True):
|
||||
def get_losses(experiment: Path, training: bool = True, unscaled: bool = False) -> Iterable[Path]:
|
||||
if not training and unscaled:
|
||||
return experiment.glob("scalars/*.validation_step/unscaled_loss_*.json")
|
||||
elif not training and not unscaled:
|
||||
return experiment.glob("scalars/*.validation_step/loss_*.json")
|
||||
elif training and unscaled:
|
||||
return experiment.glob("scalars/*.training_step/unscaled_loss_*.json")
|
||||
elif training and not unscaled:
|
||||
return experiment.glob("scalars/*.training_step/loss_*.json")
|
||||
|
||||
print("Mapping colors...")
|
||||
configurations = [
|
||||
dict(unscaled=unscaled, training=training),
|
||||
] if not write else [
|
||||
dict(unscaled=False, training=False),
|
||||
dict(unscaled=False, training=True),
|
||||
dict(unscaled=True, training=False),
|
||||
dict(unscaled=True, training=True),
|
||||
]
|
||||
legends = set(
|
||||
f"""{
|
||||
loss.parent.name.split(".", 1)[0]
|
||||
}.{
|
||||
loss.name.removesuffix(loss.suffix).removeprefix("unscaled_")
|
||||
}"""
|
||||
for experiment in experiments
|
||||
for kw in configurations
|
||||
for loss in get_losses(experiment, **kw)
|
||||
)
|
||||
colormap = dict(zip(
|
||||
sorted(legends),
|
||||
itertools.cycle(mcolors.TABLEAU_COLORS),
|
||||
))
|
||||
|
||||
def mkplot(experiment: Path, training: bool = True, unscaled: bool = False) -> tuple[bool, str]:
|
||||
label = f"{'unscaled' if unscaled else 'scaled'} {'training' if training else 'validation'}"
|
||||
if write:
|
||||
old_savefig_fname = experiment / f"{label.replace(' ', '-')}-{mode}.png"
|
||||
savefig_fname = experiment / "plots" / f"{label.replace(' ', '-')}-{mode}.png"
|
||||
savefig_fname.parent.mkdir(exist_ok=True, parents=True)
|
||||
if old_savefig_fname.is_file():
|
||||
old_savefig_fname.rename(savefig_fname)
|
||||
if savefig_fname.is_file() and not force:
|
||||
return True, "savefig_fname already exists"
|
||||
|
||||
# Get and sort data
|
||||
losses = {}
|
||||
for loss in get_losses(experiment, training=training, unscaled=unscaled):
|
||||
model = loss.parent.name.split(".", 1)[0]
|
||||
name = loss.name.removesuffix(loss.suffix).removeprefix("unscaled_")
|
||||
losses[f"{model}.{name}"] = (loss, list(read_jsonl(loss)))
|
||||
losses = dict(sorted(losses.items())) # sort keys
|
||||
if not losses:
|
||||
return True, "no losses"
|
||||
|
||||
# unwrap
|
||||
steps = [i["step"] for i in first(losses.values())[1]]
|
||||
values = [
|
||||
[i["value"] if not isnan(i["value"]) else 0 for i in data]
|
||||
for name, (scalar, data) in losses.items()
|
||||
]
|
||||
|
||||
# normalize
|
||||
if mode == "stackplot":
|
||||
totals = list(map(sum, zip(*values)))
|
||||
values = [
|
||||
[i / t for i, t in zip(data, totals)]
|
||||
for data in values
|
||||
]
|
||||
|
||||
print(experiment.name, label)
|
||||
fig, ax = plt.subplots(figsize=(16, 12))
|
||||
|
||||
if mode == "stackplot":
|
||||
ax.stackplot(steps, values,
|
||||
colors = list(map(colormap.__getitem__, losses.keys())),
|
||||
labels = list(
|
||||
label.split(".", 1)[1].removeprefix("loss_")
|
||||
for label in losses.keys()
|
||||
),
|
||||
)
|
||||
ax.set_xlim(0, steps[-1])
|
||||
ax.set_ylim(0, 1)
|
||||
ax.invert_yaxis()
|
||||
|
||||
elif mode == "lineplot":
|
||||
for data, color, label in zip(
|
||||
values,
|
||||
map(colormap.__getitem__, losses.keys()),
|
||||
list(losses.keys()),
|
||||
):
|
||||
ax.plot(steps, data,
|
||||
color = color,
|
||||
label = label,
|
||||
)
|
||||
ax.set_xlim(0, steps[-1])
|
||||
|
||||
else:
|
||||
raise ValueError(f"{mode=}")
|
||||
|
||||
ax.legend()
|
||||
ax.set_title(f"{label} loss\n{experiment.name}")
|
||||
ax.set_xlabel("Step")
|
||||
ax.set_ylabel("loss%")
|
||||
|
||||
if mode == "stackplot":
|
||||
ax2 = make_axes_locatable(ax).append_axes("bottom", 0.8, pad=0.05, sharex=ax)
|
||||
ax2.stackplot( steps, totals )
|
||||
|
||||
for tl in ax.get_xticklabels(): tl.set_visible(False)
|
||||
|
||||
fig.tight_layout()
|
||||
|
||||
if write:
|
||||
fig.savefig(savefig_fname, dpi=300)
|
||||
print(savefig_fname)
|
||||
plt.close(fig)
|
||||
|
||||
return False, None
|
||||
|
||||
print("Plotting...")
|
||||
if write:
|
||||
matplotlib.use('agg') # fixes "WARNING: QApplication was not created in the main() thread."
|
||||
any_error = False
|
||||
if write:
|
||||
with ThreadPoolExecutor(max_workers=None) as pool:
|
||||
futures = [
|
||||
(experiment, pool.submit(mkplot, experiment, **kw))
|
||||
for experiment in experiments
|
||||
for kw in configurations
|
||||
]
|
||||
else:
|
||||
def mkfuture(item):
|
||||
f = Future()
|
||||
f.set_result(item)
|
||||
return f
|
||||
futures = [
|
||||
(experiment, mkfuture(mkplot(experiment, **kw)))
|
||||
for experiment in experiments
|
||||
for kw in configurations
|
||||
]
|
||||
|
||||
for experiment, future in futures:
|
||||
try:
|
||||
err, msg = future.result()
|
||||
except Exception:
|
||||
traceback.print_exc(file=sys.stderr)
|
||||
any_error = True
|
||||
continue
|
||||
if err:
|
||||
print(f"{msg}: {experiment.name}")
|
||||
any_error = True
|
||||
continue
|
||||
|
||||
if not any_error and not write: # show in main thread
|
||||
plt.show()
|
||||
elif not write:
|
||||
print("There were errors, will not show figure...", file=sys.stderr)
|
||||
|
||||
|
||||
|
||||
# =========
|
||||
|
||||
app = typer.Typer(no_args_is_help=True, add_completion=False)
|
||||
|
||||
@app.command(help="Dump simple tensorboard events to json and extract some pytorch lightning hparams")
|
||||
def tf_dump(tfevent_files: list[Path], j: int = typer.Option(1, "-j"), force: bool = False):
|
||||
# expand to all tfevents files (there may be more than one)
|
||||
tfevent_files = sorted(set([
|
||||
tffile
|
||||
for tffile in tfevent_files
|
||||
if tffile.name.startswith("events.out.tfevents.")
|
||||
] + [
|
||||
tffile
|
||||
for experiment_dir in tfevent_files
|
||||
if experiment_dir.is_dir()
|
||||
for tffile in experiment_dir.glob("events.out.tfevents.*")
|
||||
] + [
|
||||
tffile
|
||||
for hparam_file in tfevent_files
|
||||
if hparam_file.name in ("hparams.yaml", "config.yaml")
|
||||
for tffile in hparam_file.parent.glob("events.out.tfevents.*")
|
||||
]))
|
||||
|
||||
# filter already dumped
|
||||
if not force:
|
||||
tfevent_files = [
|
||||
tffile
|
||||
for tffile in tfevent_files
|
||||
if not (
|
||||
(tffile.parent / "scalars/epoch.json").is_file()
|
||||
and
|
||||
tffile.stat().st_mtime < (tffile.parent / "scalars/epoch.json").stat().st_mtime
|
||||
)
|
||||
]
|
||||
|
||||
if not tfevent_files:
|
||||
raise typer.BadParameter("Nothing to be done, consider --force")
|
||||
|
||||
jobs = {}
|
||||
for tffile in tfevent_files:
|
||||
if not tffile.is_file():
|
||||
print("ERROR: file not found:", tffile, file=sys.stderr)
|
||||
continue
|
||||
output_dir = tffile.parent / "scalars"
|
||||
jobs.setdefault(output_dir, []).append(tffile)
|
||||
with ProcessPoolExecutor() as p:
|
||||
for experiment in set(tffile.parent for tffile in tfevent_files):
|
||||
p.submit(dump_pl_tensorboard_hparams, experiment)
|
||||
for output_dir, tffiles in jobs.items():
|
||||
p.submit(dump_simple_tf_events_to_jsonl, output_dir, *tffiles)
|
||||
|
||||
@app.command(help="Propose experiment regexes")
|
||||
def propose(cmd: str = typer.Argument("summary"), null: bool = False):
|
||||
def get():
|
||||
for i in TENSORBOARD.iterdir():
|
||||
if not i.is_dir(): continue
|
||||
if not (i / "hparams.yaml").is_file(): continue
|
||||
prefix, name, *hparams, year, month, day, hhmm, uid = i.name.split("-")
|
||||
yield f"{name}.*-{year}-{month}-{day}"
|
||||
proposals = sorted(set(get()), key=lambda x: x.split(".*-", 1)[1])
|
||||
print("\n".join(
|
||||
f"{'>/dev/null ' if null else ''}{sys.argv[0]} {cmd or 'summary'} {shlex.quote(i)}"
|
||||
for i in proposals
|
||||
))
|
||||
|
||||
@app.command("list", help="List used experiment regexes")
|
||||
def list_cached_summaries(cmd: str = typer.Argument("summary")):
|
||||
if not CACHED_SUMMARIES.is_dir():
|
||||
cached = []
|
||||
else:
|
||||
cached = [
|
||||
i.name.removesuffix(".jsonl")
|
||||
for i in CACHED_SUMMARIES.iterdir()
|
||||
if i.suffix == ".jsonl"
|
||||
if i.is_file() and i.stat().st_size
|
||||
]
|
||||
def order(key: str) -> list[str]:
|
||||
return re.sub(r'[^0-9\-]', '', key.split(".*")[-1]).strip("-").split("-") + [key]
|
||||
|
||||
print("\n".join(
|
||||
f"{sys.argv[0]} {cmd or 'summary'} {shlex.quote(i)}"
|
||||
for i in sorted(cached, key=order)
|
||||
))
|
||||
|
||||
@app.command(help="Precompute the summary of a experiment regex")
|
||||
def compute_summary(filter: str, force: bool = False, dump: bool = False, no_cache: bool = False):
|
||||
cache = CACHED_SUMMARIES / f"{filter}.jsonl"
|
||||
if cache.is_file() and cache.stat().st_size:
|
||||
if not force:
|
||||
raise FileExistsError(cache)
|
||||
|
||||
def mk_summary(path: Path) -> dict | None:
|
||||
cache = path / "train_summary.json"
|
||||
if cache.is_file() and cache.stat().st_size and cache.stat().st_mtime > (path/"scalars/epoch.json").stat().st_mtime:
|
||||
with cache.open() as f:
|
||||
return json.load(f)
|
||||
else:
|
||||
with (path / "hparams.yaml").open() as f:
|
||||
hparams = munchify(yaml.load(f, Loader=SafeLoaderIgnoreUnknown), factory=partial(DefaultMunch, None))
|
||||
config = hparams._pickled_cli_args._raw_yaml
|
||||
config = munchify(yaml.load(config, Loader=SafeLoaderIgnoreUnknown), factory=partial(DefaultMunch, None))
|
||||
|
||||
try:
|
||||
train_loss = list(read_jsonl(path / "scalars/IntersectionFieldAutoDecoderModel.training_step/loss.json"))
|
||||
val_loss = list(read_jsonl(path / "scalars/IntersectionFieldAutoDecoderModel.validation_step/loss.json"))
|
||||
except:
|
||||
traceback.print_exc(file=sys.stderr)
|
||||
return None
|
||||
|
||||
out = Munch()
|
||||
out.uid = path.name.rsplit("-", 1)[-1]
|
||||
out.name = path.name
|
||||
out.date = "-".join(path.name.split("-")[-5:-1])
|
||||
out.epochs = int(last(read_jsonl(path / "scalars/epoch.json"))["value"])
|
||||
out.steps = val_loss[-1]["step"]
|
||||
out.gpu = hparams._pickled_cli_args.host.gpus[1][1]
|
||||
|
||||
if val_loss[-1]["wall_time"] - val_loss[0]["wall_time"] > 0:
|
||||
out.batches_per_second = val_loss[-1]["step"] / (val_loss[-1]["wall_time"] - val_loss[0]["wall_time"])
|
||||
else:
|
||||
out.batches_per_second = 0
|
||||
|
||||
out.minutes = (val_loss[-1]["wall_time"] - train_loss[0]["wall_time"]) / 60
|
||||
|
||||
if (path / "scalars/PsutilMonitor/gpu.00.memory.used.json").is_file():
|
||||
max(i["value"] for i in read_jsonl(path / "scalars/PsutilMonitor/gpu.00.memory.used.json"))
|
||||
|
||||
for metric_path in (path / "scalars/IntersectionFieldAutoDecoderModel.validation_step").glob("*.json"):
|
||||
if not metric_path.is_file() or not metric_path.stat().st_size: continue
|
||||
|
||||
metric_name = metric_path.name.removesuffix(".json")
|
||||
metric_data = read_jsonl(metric_path)
|
||||
try:
|
||||
out[f"val_{metric_name}"] = mean(i["value"] for i in tail(5, metric_data))
|
||||
except StatisticsError:
|
||||
out[f"val_{metric_name}"] = float('nan')
|
||||
|
||||
for metric_path in (path / "scalars/IntersectionFieldAutoDecoderModel.training_step").glob("*.json"):
|
||||
if not any(i in metric_path.name for i in ("miss_radius_grad", "sphere_center_grad", "loss_tangential_reg", "multi_view")): continue
|
||||
if not metric_path.is_file() or not metric_path.stat().st_size: continue
|
||||
|
||||
metric_name = metric_path.name.removesuffix(".json")
|
||||
metric_data = read_jsonl(metric_path)
|
||||
try:
|
||||
out[f"train_{metric_name}"] = mean(i["value"] for i in tail(5, metric_data))
|
||||
except StatisticsError:
|
||||
out[f"train_{metric_name}"] = float('nan')
|
||||
|
||||
out.hostname = hparams._pickled_cli_args.host.hostname
|
||||
|
||||
for key, val in config.IntersectionFieldAutoDecoderModel.items():
|
||||
if isinstance(val, dict):
|
||||
out.update({f"hp_{key}_{k}": v for k, v in val.items()})
|
||||
elif isinstance(val, float | int | str | bool | None):
|
||||
out[f"hp_{key}"] = val
|
||||
|
||||
with cache.open("w") as f:
|
||||
json.dump(unmunchify(out), f)
|
||||
|
||||
return dict(out)
|
||||
|
||||
experiments = list(get_experiment_paths(filter, assert_dumped=not dump))
|
||||
if not experiments:
|
||||
raise typer.BadParameter("No matching experiment")
|
||||
if dump:
|
||||
try:
|
||||
tf_dump(experiments) # force=force_dump)
|
||||
except typer.BadParameter:
|
||||
pass
|
||||
|
||||
# does literally nothing, thanks GIL
|
||||
with ThreadPoolExecutor() as p:
|
||||
results = list(p.map(mk_summary, experiments))
|
||||
|
||||
if any(result is None for result in results):
|
||||
if all(result is None for result in results):
|
||||
print("No summary succeeded", file=sys.stderr)
|
||||
raise typer.Exit(exit_code=1)
|
||||
warnings.warn("Some summaries failed:\n" + "\n".join(
|
||||
str(experiment)
|
||||
for result, experiment in zip(results, experiments)
|
||||
if result is None
|
||||
))
|
||||
|
||||
summaries = "\n".join( map(json.dumps, results) )
|
||||
if not no_cache:
|
||||
cache.parent.mkdir(parents=True, exist_ok=True)
|
||||
with cache.open("w") as f:
|
||||
f.write(summaries)
|
||||
return summaries
|
||||
|
||||
@app.command(help="Show the summary of a experiment regex, precompute it if needed")
|
||||
def summary(filter: Optional[str] = typer.Argument(None), force: bool = False, dump: bool = False, all: bool = False):
|
||||
if filter is None:
|
||||
return list_cached_summaries("summary")
|
||||
|
||||
def key_mangler(key: str) -> str:
|
||||
for pattern, sub in (
|
||||
(r'^val_unscaled_loss_', r'val_uloss_'),
|
||||
(r'^train_unscaled_loss_', r'train_uloss_'),
|
||||
(r'^val_loss_', r'val_sloss_'),
|
||||
(r'^train_loss_', r'train_sloss_'),
|
||||
):
|
||||
key = re.sub(pattern, sub, key)
|
||||
|
||||
return key
|
||||
|
||||
cache = CACHED_SUMMARIES / f"{filter}.jsonl"
|
||||
if force or not (cache.is_file() and cache.stat().st_size):
|
||||
compute_summary(filter, force=force, dump=dump)
|
||||
assert cache.is_file() and cache.stat().st_size, (cache, cache.stat())
|
||||
|
||||
if os.isatty(0) and os.isatty(1) and shutil.which("vd"):
|
||||
rows = read_jsonl(cache)
|
||||
rows = ({key_mangler(k): v for k, v in row.items()} if row is not None else None for row in rows)
|
||||
if not all:
|
||||
rows = filter_jsonl_columns(rows)
|
||||
rows = ({k: v for k, v in row.items() if not k.startswith(("val_sloss_", "train_sloss_"))} for row in rows)
|
||||
data = "\n".join(map(json.dumps, rows))
|
||||
subprocess.run(["vd",
|
||||
#"--play", EXPERIMENTS / "set-key-columns.vd",
|
||||
"-f", "jsonl"
|
||||
], input=data, text=True, check=True)
|
||||
else:
|
||||
with cache.open() as f:
|
||||
print(f.read())
|
||||
|
||||
@app.command(help="Filter uninteresting keys from jsonl stdin")
|
||||
def filter_cols():
|
||||
rows = map(json.loads, (line for line in sys.stdin.readlines() if line.strip()))
|
||||
rows = filter_jsonl_columns(rows)
|
||||
print(*map(json.dumps, rows), sep="\n")
|
||||
|
||||
@app.command(help="Run a command for each experiment matched by experiment regex")
|
||||
def exec(filter: str, cmd: list[str], j: int = typer.Option(1, "-j"), dumped: bool = False, undumped: bool = False):
|
||||
# inspired by fd / gnu parallel
|
||||
def populate_cmd(experiment: Path, cmd: Iterable[str]) -> Iterable[str]:
|
||||
any = False
|
||||
for i in cmd:
|
||||
if i == "{}":
|
||||
any = True
|
||||
yield str(experiment / "hparams.yaml")
|
||||
elif i == "{//}":
|
||||
any = True
|
||||
yield str(experiment)
|
||||
else:
|
||||
yield i
|
||||
if not any:
|
||||
yield str(experiment / "hparams.yaml")
|
||||
|
||||
with ThreadPoolExecutor(max_workers=j or None) as p:
|
||||
results = p.map(subprocess.run, (
|
||||
list(populate_cmd(experiment, cmd))
|
||||
for experiment in get_experiment_paths(filter)
|
||||
if not dumped or (experiment / "scalars/epoch.json").is_file()
|
||||
if not undumped or not (experiment / "scalars/epoch.json").is_file()
|
||||
))
|
||||
|
||||
if any(i.returncode for i in results):
|
||||
return typer.Exit(1)
|
||||
|
||||
@app.command(help="Show stackplot of experiment loss")
|
||||
def stackplot(filter: str, write: bool = False, dump: bool = False, training: bool = False, unscaled: bool = False, force: bool = False):
|
||||
experiments = list(get_experiment_paths(filter, assert_dumped=not dump))
|
||||
if not experiments:
|
||||
raise typer.BadParameter("No match")
|
||||
if dump:
|
||||
try:
|
||||
tf_dump(experiments)
|
||||
except typer.BadParameter:
|
||||
pass
|
||||
|
||||
plot_losses(experiments,
|
||||
mode = "stackplot",
|
||||
write = write,
|
||||
dump = dump,
|
||||
training = training,
|
||||
unscaled = unscaled,
|
||||
force = force,
|
||||
)
|
||||
|
||||
@app.command(help="Show stackplot of experiment loss")
|
||||
def lineplot(filter: str, write: bool = False, dump: bool = False, training: bool = False, unscaled: bool = False, force: bool = False):
|
||||
experiments = list(get_experiment_paths(filter, assert_dumped=not dump))
|
||||
if not experiments:
|
||||
raise typer.BadParameter("No match")
|
||||
if dump:
|
||||
try:
|
||||
tf_dump(experiments)
|
||||
except typer.BadParameter:
|
||||
pass
|
||||
|
||||
plot_losses(experiments,
|
||||
mode = "lineplot",
|
||||
write = write,
|
||||
dump = dump,
|
||||
training = training,
|
||||
unscaled = unscaled,
|
||||
force = force,
|
||||
)
|
||||
|
||||
@app.command(help="Open tensorboard for the experiments matching the regex")
|
||||
def tensorboard(filter: Optional[str] = typer.Argument(None), watch: bool = False):
|
||||
if filter is None:
|
||||
return list_cached_summaries("tensorboard")
|
||||
experiments = list(get_experiment_paths(filter, assert_dumped=False))
|
||||
if not experiments:
|
||||
raise typer.BadParameter("No match")
|
||||
|
||||
with tempfile.TemporaryDirectory(suffix=f"ifield-{filter}") as d:
|
||||
treefarm = Path(d)
|
||||
with ThreadPoolExecutor(max_workers=2) as p:
|
||||
for experiment in experiments:
|
||||
(treefarm / experiment.name).symlink_to(experiment)
|
||||
|
||||
cmd = ["tensorboard", "--logdir", d]
|
||||
print("+", *map(shlex.quote, cmd), file=sys.stderr)
|
||||
tensorboard = p.submit(subprocess.run, cmd, check=True)
|
||||
if not watch:
|
||||
tensorboard.result()
|
||||
|
||||
else:
|
||||
all_experiments = set(get_experiment_paths(None, assert_dumped=False))
|
||||
while not tensorboard.done():
|
||||
time.sleep(10)
|
||||
new_experiments = set(get_experiment_paths(None, assert_dumped=False)) - all_experiments
|
||||
if new_experiments:
|
||||
for experiment in new_experiments:
|
||||
print(f"Adding {experiment.name!r}...", file=sys.stderr)
|
||||
(treefarm / experiment.name).symlink_to(experiment)
|
||||
all_experiments.update(new_experiments)
|
||||
|
||||
@app.command(help="Compute evaluation metrics")
|
||||
def metrics(filter: Optional[str] = typer.Argument(None), dump: bool = False, dry: bool = False, prefix: Optional[str] = typer.Option(None), derive: bool = False, each: bool = False, no_total: bool = False):
|
||||
if filter is None:
|
||||
return list_cached_summaries("metrics --derive")
|
||||
experiments = list(get_experiment_paths(filter, assert_dumped=False))
|
||||
if not experiments:
|
||||
raise typer.BadParameter("No match")
|
||||
if dump:
|
||||
try:
|
||||
tf_dump(experiments)
|
||||
except typer.BadParameter:
|
||||
pass
|
||||
|
||||
def run(*cmd):
|
||||
if prefix is not None:
|
||||
cmd = [*shlex.split(prefix), *cmd]
|
||||
if dry:
|
||||
print(*map(shlex.quote, map(str, cmd)))
|
||||
else:
|
||||
print("+", *map(shlex.quote, map(str, cmd)))
|
||||
subprocess.run(cmd)
|
||||
|
||||
for experiment in experiments:
|
||||
if no_total: continue
|
||||
if not (experiment / "compute-scores/metrics.json").is_file():
|
||||
run(
|
||||
"python", "./marf.py", "module", "--best", experiment / "hparams.yaml",
|
||||
"compute-scores", experiment / "compute-scores/metrics.json",
|
||||
"--transpose",
|
||||
)
|
||||
if not (experiment / "compute-scores/metrics-last.json").is_file():
|
||||
run(
|
||||
"python", "./marf.py", "module", "--last", experiment / "hparams.yaml",
|
||||
"compute-scores", experiment / "compute-scores/metrics-last.json",
|
||||
"--transpose",
|
||||
)
|
||||
if "2prif-" not in experiment.name: continue
|
||||
if not (experiment / "compute-scores/metrics-sans_outliers.json").is_file():
|
||||
run(
|
||||
"python", "./marf.py", "module", "--best", experiment / "hparams.yaml",
|
||||
"compute-scores", experiment / "compute-scores/metrics-sans_outliers.json",
|
||||
"--transpose", "--filter-outliers"
|
||||
)
|
||||
if not (experiment / "compute-scores/metrics-last-sans_outliers.json").is_file():
|
||||
run(
|
||||
"python", "./marf.py", "module", "--last", experiment / "hparams.yaml",
|
||||
"compute-scores", experiment / "compute-scores/metrics-last-sans_outliers.json",
|
||||
"--transpose", "--filter-outliers"
|
||||
)
|
||||
|
||||
if dry: return
|
||||
if prefix is not None:
|
||||
print("prefix was used, assuming a job scheduler was used, will not print scores.", file=sys.stderr)
|
||||
return
|
||||
|
||||
metrics = [
|
||||
*(experiment / "compute-scores/metrics.json" for experiment in experiments),
|
||||
*(experiment / "compute-scores/metrics-last.json" for experiment in experiments),
|
||||
*(experiment / "compute-scores/metrics-sans_outliers.json" for experiment in experiments if "2prif-" in experiment.name),
|
||||
*(experiment / "compute-scores/metrics-last-sans_outliers.json" for experiment in experiments if "2prif-" in experiment.name),
|
||||
]
|
||||
if not no_total:
|
||||
assert all(metric.exists() for metric in metrics)
|
||||
else:
|
||||
metrics = (metric for metric in metrics if metric.exists())
|
||||
|
||||
out = []
|
||||
for metric in metrics:
|
||||
experiment = metric.parent.parent.name
|
||||
is_last = metric.name in ("metrics-last.json", "metrics-last-sans_outliers.json")
|
||||
with metric.open() as f:
|
||||
data = json.load(f)
|
||||
|
||||
if derive:
|
||||
derived = {}
|
||||
objs = [i for i in data.keys() if i != "_hparams"]
|
||||
for obj in (objs if each else []) + [None]:
|
||||
if obj is None:
|
||||
d = DefaultMunch(0)
|
||||
for obj in objs:
|
||||
for k, v in data[obj].items():
|
||||
d[k] += v
|
||||
obj = "_all_"
|
||||
n_cd = data["_hparams"]["n_cd"] * len(objs)
|
||||
n_emd = data["_hparams"]["n_emd"] * len(objs)
|
||||
else:
|
||||
d = munchify(data[obj])
|
||||
n_cd = data["_hparams"]["n_cd"]
|
||||
n_emd = data["_hparams"]["n_emd"]
|
||||
|
||||
precision = d.TP / (d.TP + d.FP)
|
||||
recall = d.TP / (d.TP + d.FN)
|
||||
derived[obj] = dict(
|
||||
filtered = d.n_outliers / d.n if "n_outliers" in d else None,
|
||||
iou = d.TP / (d.TP + d.FN + d.FP),
|
||||
precision = precision,
|
||||
recall = recall,
|
||||
f_score = 2 * (precision * recall) / (precision + recall),
|
||||
cd = d.cd_dist / n_cd,
|
||||
emd = d.emd / n_emd,
|
||||
cos_med = 1 - (d.cd_cos_med / n_cd) if "cd_cos_med" in d else None,
|
||||
cos_jac = 1 - (d.cd_cos_jac / n_cd),
|
||||
)
|
||||
data = derived if each else derived["_all_"]
|
||||
|
||||
data["uid"] = experiment.rsplit("-", 1)[-1]
|
||||
data["experiment_name"] = experiment
|
||||
data["is_last"] = is_last
|
||||
|
||||
out.append(json.dumps(data))
|
||||
|
||||
if derive and not each and os.isatty(0) and os.isatty(1) and shutil.which("vd"):
|
||||
subprocess.run(["vd", "-f", "jsonl"], input="\n".join(out), text=True, check=True)
|
||||
else:
|
||||
print("\n".join(out))
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
822
figures/nn-architecture.svg
Normal file
822
figures/nn-architecture.svg
Normal file
File diff suppressed because one or more lines are too long
After (image error) Size: 198 KiB |
57
ifield/__init__.py
Normal file
57
ifield/__init__.py
Normal file
@ -0,0 +1,57 @@
|
||||
def setup_print_hooks():
|
||||
import os
|
||||
if not os.environ.get("IFIELD_PRETTY_TRACEBACK", None):
|
||||
return
|
||||
|
||||
from rich.traceback import install
|
||||
from rich.console import Console
|
||||
import warnings, sys
|
||||
|
||||
if not os.isatty(2):
|
||||
# https://github.com/Textualize/rich/issues/1809
|
||||
os.environ.setdefault("COLUMNS", "120")
|
||||
|
||||
install(
|
||||
show_locals = bool(os.environ.get("SHOW_LOCALS", "")),
|
||||
width = None,
|
||||
)
|
||||
|
||||
# custom warnings
|
||||
# https://github.com/Textualize/rich/issues/433
|
||||
|
||||
from rich.traceback import install
|
||||
from rich.console import Console
|
||||
import warnings, sys
|
||||
|
||||
|
||||
def showwarning(message, category, filename, lineno, file=None, line=None):
|
||||
msg = warnings.WarningMessage(message, category, filename, lineno, file, line)
|
||||
|
||||
if file is None:
|
||||
file = sys.stderr
|
||||
if file is None:
|
||||
# sys.stderr is None when run with pythonw.exe:
|
||||
# warnings get lost
|
||||
return
|
||||
text = warnings._formatwarnmsg(msg)
|
||||
if file.isatty():
|
||||
Console(file=file, stderr=True).print(text)
|
||||
else:
|
||||
try:
|
||||
file.write(text)
|
||||
except OSError:
|
||||
# the file (probably stderr) is invalid - this warning gets lost.
|
||||
pass
|
||||
warnings.showwarning = showwarning
|
||||
|
||||
def warning_no_src_line(message, category, filename, lineno, file=None, line=None):
|
||||
if (file or sys.stderr) is not None:
|
||||
if (file or sys.stderr).isatty():
|
||||
if file is None or file is sys.stderr:
|
||||
return f"[yellow]{category.__name__}[/yellow]: {message}\n ({filename}:{lineno})"
|
||||
return f"{category.__name__}: {message} ({filename}:{lineno})\n"
|
||||
warnings.formatwarning = warning_no_src_line
|
||||
|
||||
|
||||
setup_print_hooks()
|
||||
del setup_print_hooks
|
1006
ifield/cli.py
Normal file
1006
ifield/cli.py
Normal file
File diff suppressed because it is too large
Load Diff
174
ifield/cli_utils.py
Normal file
174
ifield/cli_utils.py
Normal file
@ -0,0 +1,174 @@
|
||||
#!/usr/bin/env python3
|
||||
from .data.common.scan import SingleViewScan, SingleViewUVScan
|
||||
from datetime import datetime
|
||||
import re
|
||||
import click
|
||||
import gzip
|
||||
import h5py as h5
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pyrender
|
||||
import trimesh
|
||||
import trimesh.transformations as T
|
||||
|
||||
__doc__ = """
|
||||
Here are a bunch of helper scripts exposed as cli command by poetry
|
||||
"""
|
||||
|
||||
|
||||
# these entrypoints are exposed by poetry as shell commands
|
||||
|
||||
@click.command()
|
||||
@click.argument("h5file")
|
||||
@click.argument("key", default="")
|
||||
def show_h5_items(h5file: str, key: str):
|
||||
"Show contents of HDF5 dataset"
|
||||
f = h5.File(h5file, "r")
|
||||
if not key:
|
||||
mlen = max(map(len, f.keys()))
|
||||
for i in sorted(f.keys()):
|
||||
print(i.ljust(mlen), ":",
|
||||
str (f[i].dtype).ljust(10),
|
||||
repr(f[i].shape).ljust(16),
|
||||
"mean:", f[i][:].mean()
|
||||
)
|
||||
else:
|
||||
if not f[key].shape:
|
||||
print(f[key].value)
|
||||
else:
|
||||
print(f[key][:])
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("h5file")
|
||||
@click.argument("key", default="")
|
||||
def show_h5_img(h5file: str, key: str):
|
||||
"Show a 2D HDF5 dataset as an image"
|
||||
f = h5.File(h5file, "r")
|
||||
if not key:
|
||||
mlen = max(map(len, f.keys()))
|
||||
for i in sorted(f.keys()):
|
||||
print(i.ljust(mlen), ":", str(f[i].dtype).ljust(10), f[i].shape)
|
||||
else:
|
||||
plt.imshow(f[key])
|
||||
plt.show()
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("h5file")
|
||||
@click.option("--force-distances", is_flag=True, help="Always show miss distances.")
|
||||
@click.option("--uv", is_flag=True, help="Load as UV scan cloud and convert it.")
|
||||
@click.option("--show-unit-sphere", is_flag=True, help="Show the unit sphere.")
|
||||
@click.option("--missing", is_flag=True, help="Show miss points that are not hits nor misses as purple.")
|
||||
def show_h5_scan_cloud(
|
||||
h5file : str,
|
||||
force_distances : bool = False,
|
||||
uv : bool = False,
|
||||
missing : bool = False,
|
||||
show_unit_sphere = False,
|
||||
):
|
||||
"Show a SingleViewScan HDF5 dataset"
|
||||
print("Reading data...")
|
||||
t = datetime.now()
|
||||
if uv:
|
||||
scan = SingleViewUVScan.from_h5_file(h5file)
|
||||
if missing and scan.any_missing:
|
||||
if not scan.has_missing:
|
||||
scan.fill_missing_points()
|
||||
points_missing = scan.points[scan.missing]
|
||||
else:
|
||||
missing = False
|
||||
if not scan.is_single_view:
|
||||
scan.cam_pos = None
|
||||
scan = scan.to_scan()
|
||||
else:
|
||||
scan = SingleViewScan.from_h5_file(h5file)
|
||||
if missing:
|
||||
uvscan = scan.to_uv_scan()
|
||||
if scan.any_missing:
|
||||
uvscan.fill_missing_points()
|
||||
points_missing = uvscan.points[uvscan.missing]
|
||||
else:
|
||||
missing = False
|
||||
print("loadtime: ", datetime.now() - t)
|
||||
|
||||
if force_distances and not scan.has_miss_distances:
|
||||
print("Computing miss distances...")
|
||||
scan.compute_miss_distances()
|
||||
use_miss_distances = force_distances
|
||||
print("Constructing scene...")
|
||||
if not scan.has_colors:
|
||||
scan.colors_hit = np.zeros_like(scan.points_hit)
|
||||
scan.colors_miss = np.zeros_like(scan.points_miss)
|
||||
scan.colors_hit [:] = ( 31/255, 119/255, 180/255)
|
||||
scan.colors_miss[:] = (243/255, 156/255, 18/255)
|
||||
use_miss_distances = True
|
||||
if scan.has_miss_distances and use_miss_distances:
|
||||
sdm = scan.distances_miss / scan.distances_miss.max()
|
||||
sdm = sdm[..., None]
|
||||
scan.colors_miss \
|
||||
= np.array([0.8, 0, 0])[None, :] * sdm \
|
||||
+ np.array([0, 1, 0.2])[None, :] * (1-sdm)
|
||||
|
||||
|
||||
scene = pyrender.Scene()
|
||||
|
||||
scene.add(pyrender.Mesh.from_points(scan.points_hit, colors=scan.colors_hit, normals=scan.normals_hit))
|
||||
scene.add(pyrender.Mesh.from_points(scan.points_miss, colors=scan.colors_miss))
|
||||
|
||||
if missing:
|
||||
scene.add(pyrender.Mesh.from_points(points_missing, colors=(np.array((0xff, 0x00, 0xff))/255)[None, :].repeat(points_missing.shape[0], axis=0)))
|
||||
|
||||
# camera:
|
||||
if not scan.points_cam is None:
|
||||
camera_mesh = trimesh.creation.uv_sphere(radius=scan.points_hit_std.max()*0.2)
|
||||
camera_mesh.visual.vertex_colors = [0.0, 0.8, 0.0]
|
||||
tfs = np.tile(np.eye(4), (len(scan.points_cam), 1, 1))
|
||||
tfs[:,:3,3] = scan.points_cam
|
||||
scene.add(pyrender.Mesh.from_trimesh(camera_mesh, poses=tfs))
|
||||
|
||||
# UV sphere:
|
||||
if show_unit_sphere:
|
||||
unit_sphere_mesh = trimesh.creation.uv_sphere(radius=1)
|
||||
unit_sphere_mesh.invert()
|
||||
unit_sphere_mesh.visual.vertex_colors = [0.8, 0.8, 0.0]
|
||||
scene.add(pyrender.Mesh.from_trimesh(unit_sphere_mesh, poses=np.eye(4)[None, ...]))
|
||||
|
||||
print("Launch!")
|
||||
viewer = pyrender.Viewer(scene, use_raymond_lighting=True, point_size=2)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("meshfile")
|
||||
@click.option('--aabb', is_flag=True)
|
||||
@click.option('--z-skyward', is_flag=True)
|
||||
def show_model(
|
||||
meshfile : str,
|
||||
aabb : bool,
|
||||
z_skyward : bool,
|
||||
):
|
||||
"Show a 3D model with pyrender, supports .gz suffix"
|
||||
if meshfile.endswith(".gz"):
|
||||
with gzip.open(meshfile, "r") as f:
|
||||
mesh = trimesh.load(f, file_type=meshfile.split(".", 1)[1].removesuffix(".gz"))
|
||||
else:
|
||||
mesh = trimesh.load(meshfile)
|
||||
|
||||
if isinstance(mesh, trimesh.Scene):
|
||||
mesh = mesh.dump(concatenate=True)
|
||||
|
||||
if aabb:
|
||||
from .data.common.mesh import rotate_to_closest_axis_aligned_bounds
|
||||
mesh.apply_transform(rotate_to_closest_axis_aligned_bounds(mesh, fail_ok=True))
|
||||
|
||||
if z_skyward:
|
||||
mesh.apply_transform(T.rotation_matrix(np.pi/2, (1, 0, 0)))
|
||||
|
||||
print(
|
||||
*(i.strip() for i in pyrender.Viewer.__doc__.splitlines() if re.search(r"- ``[a-z0-9]``: ", i)),
|
||||
sep="\n"
|
||||
)
|
||||
|
||||
scene = pyrender.Scene()
|
||||
scene.add(pyrender.Mesh.from_trimesh(mesh))
|
||||
pyrender.Viewer(scene, use_raymond_lighting=True)
|
3
ifield/data/__init__.py
Normal file
3
ifield/data/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
__doc__ = """
|
||||
Submodules to read and process datasets
|
||||
"""
|
0
ifield/data/common/__init__.py
Normal file
0
ifield/data/common/__init__.py
Normal file
90
ifield/data/common/download.py
Normal file
90
ifield/data/common/download.py
Normal file
@ -0,0 +1,90 @@
|
||||
from ...utils.helpers import make_relative
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
from typing import Union, Optional
|
||||
import io
|
||||
import os
|
||||
import json
|
||||
import requests
|
||||
|
||||
PathLike = Union[os.PathLike, str]
|
||||
|
||||
__doc__ = """
|
||||
Here are some helper functions for processing data.
|
||||
"""
|
||||
|
||||
def check_url(url): # HTTP HEAD
|
||||
return requests.head(url).ok
|
||||
|
||||
def download_stream(
|
||||
url : str,
|
||||
file_object,
|
||||
block_size : int = 1024,
|
||||
silent : bool = False,
|
||||
label : Optional[str] = None,
|
||||
):
|
||||
resp = requests.get(url, stream=True)
|
||||
total_size = int(resp.headers.get("content-length", 0))
|
||||
if not silent:
|
||||
progress_bar = tqdm(total=total_size , unit="iB", unit_scale=True, desc=label)
|
||||
|
||||
for chunk in resp.iter_content(block_size):
|
||||
if not silent:
|
||||
progress_bar.update(len(chunk))
|
||||
file_object.write(chunk)
|
||||
|
||||
if not silent:
|
||||
progress_bar.close()
|
||||
if total_size != 0 and progress_bar.n != total_size:
|
||||
print("ERROR, something went wrong")
|
||||
|
||||
def download_data(
|
||||
url : str,
|
||||
block_size : int = 1024,
|
||||
silent : bool = False,
|
||||
label : Optional[str] = None,
|
||||
) -> bytearray:
|
||||
f = io.BytesIO()
|
||||
download_stream(url, f, block_size=block_size, silent=silent, label=label)
|
||||
f.seek(0)
|
||||
return bytearray(f.read())
|
||||
|
||||
def download_file(
|
||||
url : str,
|
||||
fname : Union[Path, str],
|
||||
block_size : int = 1024,
|
||||
silent = False,
|
||||
):
|
||||
if not isinstance(fname, Path):
|
||||
fname = Path(fname)
|
||||
with fname.open("wb") as f:
|
||||
download_stream(url, f, block_size=block_size, silent=silent, label=make_relative(fname, Path.cwd()).name)
|
||||
|
||||
def is_downloaded(
|
||||
target_dir : PathLike,
|
||||
url : str,
|
||||
*,
|
||||
add : bool = False,
|
||||
dbfiles : Union[list[PathLike], PathLike],
|
||||
):
|
||||
if not isinstance(target_dir, os.PathLike):
|
||||
target_dir = Path(target_dir)
|
||||
if not isinstance(dbfiles, list):
|
||||
dbfiles = [dbfiles]
|
||||
if not dbfiles:
|
||||
raise ValueError("'dbfiles' empty")
|
||||
downloaded = set()
|
||||
for dbfile_fname in dbfiles:
|
||||
dbfile_fname = target_dir / dbfile_fname
|
||||
if dbfile_fname.is_file():
|
||||
with open(dbfile_fname, "r") as f:
|
||||
downloaded.update(json.load(f)["downloaded"])
|
||||
|
||||
if add and url not in downloaded:
|
||||
downloaded.add(url)
|
||||
with open(dbfiles[0], "w") as f:
|
||||
data = {"downloaded": sorted(downloaded)}
|
||||
json.dump(data, f, indent=2, sort_keys=True)
|
||||
return True
|
||||
|
||||
return url in downloaded
|
370
ifield/data/common/h5_dataclasses.py
Normal file
370
ifield/data/common/h5_dataclasses.py
Normal file
@ -0,0 +1,370 @@
|
||||
#!/usr/bin/env python3
|
||||
from abc import abstractmethod, ABCMeta
|
||||
from collections import namedtuple
|
||||
from pathlib import Path
|
||||
import copy
|
||||
import dataclasses
|
||||
import functools
|
||||
import h5py as h5
|
||||
import hdf5plugin
|
||||
import numpy as np
|
||||
import operator
|
||||
import os
|
||||
import sys
|
||||
import typing
|
||||
|
||||
__all__ = [
|
||||
"DataclassMeta",
|
||||
"Dataclass",
|
||||
"H5Dataclass",
|
||||
"H5Array",
|
||||
"H5ArrayNoSlice",
|
||||
]
|
||||
|
||||
T = typing.TypeVar("T")
|
||||
NoneType = type(None)
|
||||
PathLike = typing.Union[os.PathLike, str]
|
||||
H5Array = typing._alias(np.ndarray, 0, inst=False, name="H5Array")
|
||||
H5ArrayNoSlice = typing._alias(np.ndarray, 0, inst=False, name="H5ArrayNoSlice")
|
||||
|
||||
DataclassField = namedtuple("DataclassField", [
|
||||
"name",
|
||||
"type",
|
||||
"is_optional",
|
||||
"is_array",
|
||||
"is_sliceable",
|
||||
"is_prefix",
|
||||
])
|
||||
|
||||
def strip_optional(val: type) -> type:
|
||||
if typing.get_origin(val) is typing.Union:
|
||||
union = set(typing.get_args(val))
|
||||
if len(union - {NoneType}) == 1:
|
||||
val, = union - {NoneType}
|
||||
else:
|
||||
raise TypeError(f"Non-'typing.Optional' 'typing.Union' is not supported: {typing._type_repr(val)!r}")
|
||||
return val
|
||||
|
||||
def is_array(val, *, _inner=False):
|
||||
"""
|
||||
Hacky way to check if a value or type is an array.
|
||||
The hack omits having to depend on large frameworks such as pytorch or pandas
|
||||
"""
|
||||
val = strip_optional(val)
|
||||
if val is H5Array or val is H5ArrayNoSlice:
|
||||
return True
|
||||
|
||||
if typing._type_repr(val) in (
|
||||
"numpy.ndarray",
|
||||
"torch.Tensor",
|
||||
):
|
||||
return True
|
||||
if not _inner:
|
||||
return is_array(type(val), _inner=True)
|
||||
return False
|
||||
|
||||
def prod(numbers: typing.Iterable[T], initial: typing.Optional[T] = None) -> T:
|
||||
if initial is not None:
|
||||
return functools.reduce(operator.mul, numbers, initial)
|
||||
else:
|
||||
return functools.reduce(operator.mul, numbers)
|
||||
|
||||
class DataclassMeta(type):
|
||||
def __new__(
|
||||
mcls,
|
||||
name : str,
|
||||
bases : tuple[type, ...],
|
||||
attrs : dict[str, typing.Any],
|
||||
**kwargs,
|
||||
):
|
||||
cls = super().__new__(mcls, name, bases, attrs, **kwargs)
|
||||
if sys.version_info[:2] >= (3, 10) and not hasattr(cls, "__slots__"):
|
||||
cls = dataclasses.dataclass(slots=True)(cls)
|
||||
else:
|
||||
cls = dataclasses.dataclass(cls)
|
||||
return cls
|
||||
|
||||
class DataclassABCMeta(DataclassMeta, ABCMeta):
|
||||
pass
|
||||
|
||||
class Dataclass(metaclass=DataclassMeta):
|
||||
def __getitem__(self, key: str) -> typing.Any:
|
||||
if key in self.keys():
|
||||
return getattr(self, key)
|
||||
raise KeyError(key)
|
||||
|
||||
def __setitem__(self, key: str, value: typing.Any):
|
||||
if key in self.keys():
|
||||
return setattr(self, key, value)
|
||||
raise KeyError(key)
|
||||
|
||||
def keys(self) -> typing.KeysView:
|
||||
return self.as_dict().keys()
|
||||
|
||||
def values(self) -> typing.ValuesView:
|
||||
return self.as_dict().values()
|
||||
|
||||
def items(self) -> typing.ItemsView:
|
||||
return self.as_dict().items()
|
||||
|
||||
def as_dict(self, properties_to_include: set[str] = None, **kw) -> dict[str, typing.Any]:
|
||||
out = dataclasses.asdict(self, **kw)
|
||||
for name in (properties_to_include or []):
|
||||
out[name] = getattr(self, name)
|
||||
return out
|
||||
|
||||
def as_tuple(self, properties_to_include: list[str]) -> tuple:
|
||||
out = dataclasses.astuple(self)
|
||||
if not properties_to_include:
|
||||
return out
|
||||
else:
|
||||
return (
|
||||
*out,
|
||||
*(getattr(self, name) for name in properties_to_include),
|
||||
)
|
||||
|
||||
def copy(self: T, *, deep=True) -> T:
|
||||
return (copy.deepcopy if deep else copy.copy)(self)
|
||||
|
||||
class H5Dataclass(Dataclass):
|
||||
# settable with class params:
|
||||
_prefix : str = dataclasses.field(init=False, repr=False, default="")
|
||||
_n_pages : int = dataclasses.field(init=False, repr=False, default=10)
|
||||
_require_all : bool = dataclasses.field(init=False, repr=False, default=False)
|
||||
|
||||
def __init_subclass__(cls,
|
||||
prefix : typing.Optional[str] = None,
|
||||
n_pages : typing.Optional[int] = None,
|
||||
require_all : typing.Optional[bool] = None,
|
||||
**kw,
|
||||
):
|
||||
super().__init_subclass__(**kw)
|
||||
assert dataclasses.is_dataclass(cls)
|
||||
if prefix is not None: cls._prefix = prefix
|
||||
if n_pages is not None: cls._n_pages = n_pages
|
||||
if require_all is not None: cls._require_all = require_all
|
||||
|
||||
@classmethod
|
||||
def _get_fields(cls) -> typing.Iterable[DataclassField]:
|
||||
for field in dataclasses.fields(cls):
|
||||
if not field.init:
|
||||
continue
|
||||
assert field.name not in ("_prefix", "_n_pages", "_require_all"), (
|
||||
f"{field.name!r} can not be in {cls.__qualname__}.__init__.\n"
|
||||
"Set it with dataclasses.field(default=YOUR_VALUE, init=False, repr=False)"
|
||||
)
|
||||
if isinstance(field.type, str):
|
||||
raise TypeError("Type hints are strings, perhaps avoid using `from __future__ import annotations`")
|
||||
|
||||
type_inner = strip_optional(field.type)
|
||||
is_prefix = typing.get_origin(type_inner) is dict and typing.get_args(type_inner)[:1] == (str,)
|
||||
field_type = typing.get_args(type_inner)[1] if is_prefix else field.type
|
||||
if field.default is None or typing.get_origin(field.type) is typing.Union and NoneType in typing.get_args(field.type):
|
||||
field_type = typing.Optional[field_type]
|
||||
|
||||
yield DataclassField(
|
||||
name = field.name,
|
||||
type = strip_optional(field_type),
|
||||
is_optional = typing.get_origin(field_type) is typing.Union and NoneType in typing.get_args(field_type),
|
||||
is_array = is_array(field_type),
|
||||
is_sliceable = is_array(field_type) and strip_optional(field_type) is not H5ArrayNoSlice,
|
||||
is_prefix = is_prefix,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_h5_file(cls : type[T],
|
||||
fname : typing.Union[PathLike, str],
|
||||
*,
|
||||
page : typing.Optional[int] = None,
|
||||
n_pages : typing.Optional[int] = None,
|
||||
read_slice : slice = slice(None),
|
||||
require_even_pages : bool = True,
|
||||
) -> T:
|
||||
if not isinstance(fname, Path):
|
||||
fname = Path(fname)
|
||||
if n_pages is None:
|
||||
n_pages = cls._n_pages
|
||||
if not fname.exists():
|
||||
raise FileNotFoundError(str(fname))
|
||||
if not h5.is_hdf5(fname):
|
||||
raise TypeError(f"Not a HDF5 file: {str(fname)!r}")
|
||||
|
||||
# if this class has no fields, print a example class:
|
||||
if not any(field.init for field in dataclasses.fields(cls)):
|
||||
with h5.File(fname, "r") as f:
|
||||
klen = max(map(len, f.keys()))
|
||||
example_cls = f"\nclass {cls.__name__}(Dataclass, require_all=True):\n" + "\n".join(
|
||||
f" {k.ljust(klen)} : "
|
||||
+ (
|
||||
"H5Array" if prod(v.shape, 1) > 1 else (
|
||||
"float" if issubclass(v.dtype.type, np.floating) else (
|
||||
"int" if issubclass(v.dtype.type, np.integer) else (
|
||||
"bool" if issubclass(v.dtype.type, np.bool_) else (
|
||||
"typing.Any"
|
||||
))))).ljust(14 + 1)
|
||||
+ f" #{repr(v).split(':', 1)[1].removesuffix('>')}"
|
||||
for k, v in f.items()
|
||||
)
|
||||
raise NotImplementedError(f"{cls!r} has no fields!\nPerhaps try the following:{example_cls}")
|
||||
|
||||
fields_consumed = set()
|
||||
|
||||
def make_kwarg(
|
||||
file : h5.File,
|
||||
keys : typing.KeysView,
|
||||
field : DataclassField,
|
||||
) -> tuple[str, typing.Any]:
|
||||
if field.is_optional:
|
||||
if field.name not in keys:
|
||||
return field.name, None
|
||||
if field.is_sliceable:
|
||||
if page is not None:
|
||||
n_items = int(f[cls._prefix + field.name].shape[0])
|
||||
page_len = n_items // n_pages
|
||||
modulus = n_items % n_pages
|
||||
if modulus: page_len += 1 # round up
|
||||
if require_even_pages and modulus:
|
||||
raise ValueError(f"Field {field.name!r} {tuple(f[cls._prefix + field.name].shape)} is not cleanly divisible into {n_pages} pages")
|
||||
this_slice = slice(
|
||||
start = page_len * page,
|
||||
stop = page_len * (page+1),
|
||||
step = read_slice.step, # inherit step
|
||||
)
|
||||
else:
|
||||
this_slice = read_slice
|
||||
else:
|
||||
this_slice = slice(None) # read all
|
||||
|
||||
# array or scalar?
|
||||
def read_dataset(var):
|
||||
# https://docs.h5py.org/en/stable/high/dataset.html#reading-writing-data
|
||||
if field.is_array:
|
||||
return var[this_slice]
|
||||
if var.shape == (1,):
|
||||
return var[0]
|
||||
else:
|
||||
return var[()]
|
||||
|
||||
if field.is_prefix:
|
||||
fields_consumed.update(
|
||||
key
|
||||
for key in keys if key.startswith(f"{cls._prefix}{field.name}_")
|
||||
)
|
||||
return field.name, {
|
||||
key.removeprefix(f"{cls._prefix}{field.name}_") : read_dataset(file[key])
|
||||
for key in keys if key.startswith(f"{cls._prefix}{field.name}_")
|
||||
}
|
||||
else:
|
||||
fields_consumed.add(cls._prefix + field.name)
|
||||
return field.name, read_dataset(file[cls._prefix + field.name])
|
||||
|
||||
with h5.File(fname, "r") as f:
|
||||
keys = f.keys()
|
||||
init_dict = dict( make_kwarg(f, keys, i) for i in cls._get_fields() )
|
||||
|
||||
try:
|
||||
out = cls(**init_dict)
|
||||
except Exception as e:
|
||||
class_attrs = set(field.name for field in dataclasses.fields(cls))
|
||||
file_attr = set(init_dict.keys())
|
||||
raise e.__class__(f"{e}. {class_attrs=}, {file_attr=}, diff={class_attrs.symmetric_difference(file_attr)}") from e
|
||||
|
||||
if cls._require_all:
|
||||
fields_not_consumed = set(keys) - fields_consumed
|
||||
if fields_not_consumed:
|
||||
raise ValueError(f"Not all HDF5 fields consumed: {fields_not_consumed!r}")
|
||||
|
||||
return out
|
||||
|
||||
def to_h5_file(self,
|
||||
fname : PathLike,
|
||||
mkdir : bool = False,
|
||||
):
|
||||
if not isinstance(fname, Path):
|
||||
fname = Path(fname)
|
||||
if not fname.parent.is_dir():
|
||||
if mkdir:
|
||||
fname.parent.mkdir(parents=True)
|
||||
else:
|
||||
raise NotADirectoryError(fname.parent)
|
||||
|
||||
with h5.File(fname, "w") as f:
|
||||
for field in type(self)._get_fields():
|
||||
if field.is_optional and getattr(self, field.name) is None:
|
||||
continue
|
||||
value = getattr(self, field.name)
|
||||
if field.is_array:
|
||||
if any(type(i) is not np.ndarray for i in (value.values() if field.is_prefix else [value])):
|
||||
raise TypeError(
|
||||
"When dumping a H5Dataclass, make sure the array fields are "
|
||||
f"numpy arrays (the type of {field.name!r} is {typing._type_repr(type(value))}).\n"
|
||||
"Example: h5dataclass.map_arrays(torch.Tensor.numpy)"
|
||||
)
|
||||
else:
|
||||
pass
|
||||
|
||||
def write_value(key: str, value: typing.Any):
|
||||
if field.is_array:
|
||||
f.create_dataset(key, data=value, **hdf5plugin.LZ4())
|
||||
else:
|
||||
f.create_dataset(key, data=value)
|
||||
|
||||
if field.is_prefix:
|
||||
for k, v in value.items():
|
||||
write_value(self._prefix + field.name + "_" + k, v)
|
||||
else:
|
||||
write_value(self._prefix + field.name, value)
|
||||
|
||||
def map_arrays(self: T, func: typing.Callable[[H5Array], H5Array], do_copy: bool = False) -> T:
|
||||
if do_copy: # shallow
|
||||
self = self.copy(deep=False)
|
||||
for field in type(self)._get_fields():
|
||||
if field.is_optional and getattr(self, field.name) is None:
|
||||
continue
|
||||
if field.is_prefix and field.is_array:
|
||||
setattr(self, field.name, {
|
||||
k : func(v)
|
||||
for k, v in getattr(self, field.name).items()
|
||||
})
|
||||
elif field.is_array:
|
||||
setattr(self, field.name, func(getattr(self, field.name)))
|
||||
|
||||
return self
|
||||
|
||||
def astype(self: T, t: type, do_copy: bool = False, convert_nonfloats: bool = False) -> T:
|
||||
return self.map_arrays(lambda x: x.astype(t) if convert_nonfloats or not np.issubdtype(x.dtype, int) else x)
|
||||
|
||||
def copy(self: T, *, deep=True) -> T:
|
||||
out = super().copy(deep=deep)
|
||||
if not deep:
|
||||
for field in type(self)._get_fields():
|
||||
if field.is_prefix:
|
||||
out[field.name] = copy.copy(field.name)
|
||||
return out
|
||||
|
||||
@property
|
||||
def shape(self) -> dict[str, tuple[int, ...]]:
|
||||
return {
|
||||
key: value.shape
|
||||
for key, value in self.items()
|
||||
if hasattr(value, "shape")
|
||||
}
|
||||
|
||||
class TransformableDataclassMixin(metaclass=DataclassABCMeta):
|
||||
|
||||
@abstractmethod
|
||||
def transform(self: T, mat4: np.ndarray, inplace=False) -> T:
|
||||
...
|
||||
|
||||
def transform_to(self: T, name: str, inverse_name: str = None, *, inplace=False) -> T:
|
||||
mtx = self.transforms[name]
|
||||
out = self.transform(mtx, inplace=inplace)
|
||||
out.transforms.pop(name) # consumed
|
||||
|
||||
inv = np.linalg.inv(mtx)
|
||||
for key in list(out.transforms.keys()): # maintain the other transforms
|
||||
out.transforms[key] = out.transforms[key] @ inv
|
||||
if inverse_name is not None: # store inverse
|
||||
out.transforms[inverse_name] = inv
|
||||
|
||||
return out
|
48
ifield/data/common/mesh.py
Normal file
48
ifield/data/common/mesh.py
Normal file
@ -0,0 +1,48 @@
|
||||
from math import pi
|
||||
from trimesh import Trimesh
|
||||
import numpy as np
|
||||
import os
|
||||
import trimesh
|
||||
import trimesh.transformations as T
|
||||
|
||||
DEBUG = bool(os.environ.get("IFIELD_DEBUG", ""))
|
||||
|
||||
__doc__ = """
|
||||
Here are some helper functions for processing data.
|
||||
"""
|
||||
|
||||
def rotate_to_closest_axis_aligned_bounds(
|
||||
mesh : Trimesh,
|
||||
order_axes : bool = True,
|
||||
fail_ok : bool = True,
|
||||
) -> np.ndarray:
|
||||
to_origin_mat4, extents = trimesh.bounds.oriented_bounds(mesh, ordered=not order_axes)
|
||||
to_aabb_rot_mat4 = T.euler_matrix(*T.decompose_matrix(to_origin_mat4)[3])
|
||||
|
||||
if not order_axes:
|
||||
return to_aabb_rot_mat4
|
||||
|
||||
v = pi / 4 * 1.01 # tolerance
|
||||
v2 = pi / 2
|
||||
|
||||
faces = (
|
||||
(0, 0),
|
||||
(1, 0),
|
||||
(2, 0),
|
||||
(3, 0),
|
||||
(0, 1),
|
||||
(0,-1),
|
||||
)
|
||||
orientations = [ # 6 faces x 4 rotations per face
|
||||
(f[0] * v2, f[1] * v2, i * v2)
|
||||
for i in range(4)
|
||||
for f in faces]
|
||||
|
||||
for x, y, z in orientations:
|
||||
mat4 = T.euler_matrix(x, y, z) @ to_aabb_rot_mat4
|
||||
ai, aj, ak = T.euler_from_matrix(mat4)
|
||||
if abs(ai) <= v and abs(aj) <= v and abs(ak) <= v:
|
||||
return mat4
|
||||
|
||||
if fail_ok: return to_aabb_rot_mat4
|
||||
raise Exception("Unable to orient mesh")
|
297
ifield/data/common/points.py
Normal file
297
ifield/data/common/points.py
Normal file
@ -0,0 +1,297 @@
|
||||
from __future__ import annotations
|
||||
from ...utils.helpers import compose
|
||||
from functools import reduce, lru_cache
|
||||
from math import ceil
|
||||
from typing import Iterable
|
||||
import numpy as np
|
||||
import operator
|
||||
|
||||
__doc__ = """
|
||||
Here are some helper functions for processing data.
|
||||
"""
|
||||
|
||||
|
||||
def img2col(img: np.ndarray, psize: int) -> np.ndarray:
|
||||
# based of ycb_generate_point_cloud.py provided by YCB
|
||||
|
||||
n_channels = 1 if len(img.shape) == 2 else img.shape[0]
|
||||
n_channels, rows, cols = (1,) * (3 - len(img.shape)) + img.shape
|
||||
|
||||
# pad the image
|
||||
img_pad = np.zeros((
|
||||
n_channels,
|
||||
int(ceil(1.0 * rows / psize) * psize),
|
||||
int(ceil(1.0 * cols / psize) * psize),
|
||||
))
|
||||
img_pad[:, 0:rows, 0:cols] = img
|
||||
|
||||
# allocate output buffer
|
||||
final = np.zeros((
|
||||
img_pad.shape[1],
|
||||
img_pad.shape[2],
|
||||
n_channels,
|
||||
psize,
|
||||
psize,
|
||||
))
|
||||
|
||||
for c in range(n_channels):
|
||||
for x in range(psize):
|
||||
for y in range(psize):
|
||||
img_shift = np.vstack((
|
||||
img_pad[c, x:],
|
||||
img_pad[c, :x]))
|
||||
img_shift = np.column_stack((
|
||||
img_shift[:, y:],
|
||||
img_shift[:, :y]))
|
||||
final[x::psize, y::psize, c] = np.swapaxes(
|
||||
img_shift.reshape(
|
||||
int(img_pad.shape[1] / psize), psize,
|
||||
int(img_pad.shape[2] / psize), psize),
|
||||
1,
|
||||
2)
|
||||
|
||||
# crop output and unwrap axes with size==1
|
||||
return np.squeeze(final[
|
||||
0:rows - psize + 1,
|
||||
0:cols - psize + 1])
|
||||
|
||||
def filter_depth_discontinuities(depth_map: np.ndarray, filt_size = 7, thresh = 1000) -> np.ndarray:
|
||||
"""
|
||||
Removes data close to discontinuities, with size filt_size.
|
||||
"""
|
||||
# based of ycb_generate_point_cloud.py provided by YCB
|
||||
|
||||
# Ensure that filter sizes are okay
|
||||
assert filt_size % 2, "Can only use odd filter sizes."
|
||||
|
||||
# Compute discontinuities
|
||||
offset = int(filt_size - 1) // 2
|
||||
patches = 1.0 * img2col(depth_map, filt_size)
|
||||
mids = patches[:, :, offset, offset]
|
||||
mins = np.min(patches, axis=(2, 3))
|
||||
maxes = np.max(patches, axis=(2, 3))
|
||||
|
||||
discont = np.maximum(
|
||||
np.abs(mins - mids),
|
||||
np.abs(maxes - mids))
|
||||
mark = discont > thresh
|
||||
|
||||
# Account for offsets
|
||||
final_mark = np.zeros(depth_map.shape, dtype=np.uint16)
|
||||
final_mark[offset:offset + mark.shape[0],
|
||||
offset:offset + mark.shape[1]] = mark
|
||||
|
||||
return depth_map * (1 - final_mark)
|
||||
|
||||
def reorient_depth_map(
|
||||
depth_map : np.ndarray,
|
||||
rgb_map : np.ndarray,
|
||||
depth_mat3 : np.ndarray, # 3x3 intrinsic camera matrix
|
||||
depth_vec5 : np.ndarray, # 5 distortion parameters (k1, k2, p1, p2, k3)
|
||||
rgb_mat3 : np.ndarray, # 3x3 intrinsic camera matrix
|
||||
rgb_vec5 : np.ndarray, # 5 distortion parameters (k1, k2, p1, p2, k3)
|
||||
ir_to_rgb_mat4 : np.ndarray, # extrinsic transformation matrix from depth to rgb camera viewpoint
|
||||
rgb_mask_map : np.ndarray = None,
|
||||
_output_points = False, # retval (H, W) if false else (N, XYZRGB)
|
||||
_output_hits_uvs = False, # retval[1] is dtype=bool of hits shaped like depth_map
|
||||
) -> np.ndarray:
|
||||
|
||||
"""
|
||||
Corrects depth_map to be from the same view as the rgb_map, with the same dimensions.
|
||||
If _output_points is True, the points returned are in the rgb camera space.
|
||||
"""
|
||||
# based of ycb_generate_point_cloud.py provided by YCB
|
||||
# now faster AND more easy on the GIL
|
||||
|
||||
height_old, width_old, *_ = depth_map.shape
|
||||
height, width, *_ = rgb_map.shape
|
||||
|
||||
|
||||
d_cx, r_cx = depth_mat3[0, 2], rgb_mat3[0, 2] # optical center
|
||||
d_cy, r_cy = depth_mat3[1, 2], rgb_mat3[1, 2]
|
||||
d_fx, r_fx = depth_mat3[0, 0], rgb_mat3[0, 0] # focal length
|
||||
d_fy, r_fy = depth_mat3[1, 1], rgb_mat3[1, 1]
|
||||
d_k1, d_k2, d_p1, d_p2, d_k3 = depth_vec5
|
||||
c_k1, c_k2, c_p1, c_p2, c_k3 = rgb_vec5
|
||||
|
||||
# make a UV grid over depth_map
|
||||
u, v = np.meshgrid(
|
||||
np.arange(width_old),
|
||||
np.arange(height_old),
|
||||
)
|
||||
|
||||
# compute xyz coordinates for all depths
|
||||
xyz_depth = np.stack((
|
||||
(u - d_cx) / d_fx,
|
||||
(v - d_cy) / d_fy,
|
||||
depth_map,
|
||||
np.ones(depth_map.shape)
|
||||
)).reshape((4, -1))
|
||||
xyz_depth = xyz_depth[:, xyz_depth[2] != 0]
|
||||
|
||||
# undistort depth coordinates
|
||||
d_x, d_y = xyz_depth[:2]
|
||||
r = np.linalg.norm(xyz_depth[:2], axis=0)
|
||||
xyz_depth[0, :] \
|
||||
= d_x / (1 + d_k1*r**2 + d_k2*r**4 + d_k3*r**6) \
|
||||
- (2*d_p1*d_x*d_y + d_p2*(r**2 + 2*d_x**2))
|
||||
xyz_depth[1, :] \
|
||||
= d_y / (1 + d_k1*r**2 + d_k2*r**4 + d_k3*r**6) \
|
||||
- (d_p1*(r**2 + 2*d_y**2) + 2*d_p2*d_x*d_y)
|
||||
|
||||
# unproject x and y
|
||||
xyz_depth[0, :] *= xyz_depth[2, :]
|
||||
xyz_depth[1, :] *= xyz_depth[2, :]
|
||||
|
||||
# convert depths to RGB camera viewpoint
|
||||
xyz_rgb = ir_to_rgb_mat4 @ xyz_depth
|
||||
|
||||
# project depths to RGB canvas
|
||||
rgb_z_inv = 1 / xyz_rgb[2] # perspective correction
|
||||
rgb_uv = np.stack((
|
||||
xyz_rgb[0] * rgb_z_inv * r_fx + r_cx + 0.5,
|
||||
xyz_rgb[1] * rgb_z_inv * r_fy + r_cy + 0.5,
|
||||
)).astype(np.int)
|
||||
|
||||
# mask of the rgb_xyz values within view of rgb_map
|
||||
mask = reduce(operator.and_, [
|
||||
rgb_uv[0] >= 0,
|
||||
rgb_uv[1] >= 0,
|
||||
rgb_uv[0] < width,
|
||||
rgb_uv[1] < height,
|
||||
])
|
||||
if rgb_mask_map is not None:
|
||||
mask[mask] &= rgb_mask_map[
|
||||
rgb_uv[1, mask],
|
||||
rgb_uv[0, mask]]
|
||||
|
||||
if not _output_points: # output image
|
||||
output = np.zeros((height, width), dtype=depth_map.dtype)
|
||||
output[
|
||||
rgb_uv[1, mask],
|
||||
rgb_uv[0, mask],
|
||||
] = xyz_rgb[2, mask]
|
||||
|
||||
else: # output pointcloud
|
||||
rgbs = rgb_map[ # lookup rgb values using rgb_uv
|
||||
rgb_uv[1, mask],
|
||||
rgb_uv[0, mask]]
|
||||
output = np.stack((
|
||||
xyz_rgb[0, mask], # x
|
||||
xyz_rgb[1, mask], # y
|
||||
xyz_rgb[2, mask], # z
|
||||
rgbs[:, 0], # r
|
||||
rgbs[:, 1], # g
|
||||
rgbs[:, 2], # b
|
||||
)).T
|
||||
|
||||
# output for realsies
|
||||
if not _output_hits_uvs: #raw
|
||||
return output
|
||||
else: # with hit mask
|
||||
uv = np.zeros((height, width), dtype=bool)
|
||||
# filter points overlapping in the depth map
|
||||
uv_indices = (
|
||||
rgb_uv[1, mask],
|
||||
rgb_uv[0, mask],
|
||||
)
|
||||
_, chosen = np.unique( uv_indices[0] << 32 | uv_indices[1], return_index=True )
|
||||
output = output[chosen, :]
|
||||
uv[uv_indices] = True
|
||||
return output, uv
|
||||
|
||||
def join_rgb_and_depth_to_points(*a, **kw) -> np.ndarray:
|
||||
return reorient_depth_map(*a, _output_points=True, **kw)
|
||||
|
||||
@compose(np.array) # block lru cache mutation
|
||||
@lru_cache(maxsize=1)
|
||||
@compose(list)
|
||||
def generate_equidistant_sphere_points(
|
||||
n : int,
|
||||
centroid : np.ndarray = (0, 0, 0),
|
||||
radius : float = 1,
|
||||
compute_sphere_coordinates : bool = False,
|
||||
compute_normals : bool = False,
|
||||
shift_theta : bool = False,
|
||||
) -> Iterable[tuple[float, ...]]:
|
||||
# Deserno M. How to generate equidistributed points on the surface of a sphere
|
||||
# https://www.cmu.edu/biolphys/deserno/pdf/sphere_equi.pdf
|
||||
|
||||
if compute_sphere_coordinates and compute_normals:
|
||||
raise ValueError(
|
||||
"'compute_sphere_coordinates' and 'compute_normals' are mutually exclusive"
|
||||
)
|
||||
|
||||
n_count = 0
|
||||
a = 4 * np.pi / n
|
||||
d = np.sqrt(a)
|
||||
n_theta = round(np.pi / d)
|
||||
d_theta = np.pi / n_theta
|
||||
d_phi = a / d_theta
|
||||
|
||||
for i in range(0, n_theta):
|
||||
theta = np.pi * (i + 0.5) / n_theta
|
||||
n_phi = round(2 * np.pi * np.sin(theta) / d_phi)
|
||||
|
||||
for j in range(0, n_phi):
|
||||
phi = 2 * np.pi * j / n_phi
|
||||
|
||||
if compute_sphere_coordinates: # (theta, phi)
|
||||
yield (
|
||||
theta if shift_theta else theta - 0.5*np.pi,
|
||||
phi,
|
||||
)
|
||||
elif compute_normals: # (x, y, z, nx, ny, nz)
|
||||
yield (
|
||||
centroid[0] + radius * np.sin(theta) * np.cos(phi),
|
||||
centroid[1] + radius * np.sin(theta) * np.sin(phi),
|
||||
centroid[2] + radius * np.cos(theta),
|
||||
np.sin(theta) * np.cos(phi),
|
||||
np.sin(theta) * np.sin(phi),
|
||||
np.cos(theta),
|
||||
)
|
||||
else: # (x, y, z)
|
||||
yield (
|
||||
centroid[0] + radius * np.sin(theta) * np.cos(phi),
|
||||
centroid[1] + radius * np.sin(theta) * np.sin(phi),
|
||||
centroid[2] + radius * np.cos(theta),
|
||||
)
|
||||
n_count += 1
|
||||
|
||||
|
||||
def generate_random_sphere_points(
|
||||
n : int,
|
||||
centroid : np.ndarray = (0, 0, 0),
|
||||
radius : float = 1,
|
||||
compute_sphere_coordinates : bool = False,
|
||||
compute_normals : bool = False,
|
||||
shift_theta : bool = False, # depends on convention
|
||||
) -> np.ndarray:
|
||||
if compute_sphere_coordinates and compute_normals:
|
||||
raise ValueError(
|
||||
"'compute_sphere_coordinates' and 'compute_normals' are mutually exclusive"
|
||||
)
|
||||
|
||||
theta = np.arcsin(np.random.uniform(-1, 1, n)) # inverse transform sampling
|
||||
phi = np.random.uniform(0, 2*np.pi, n)
|
||||
|
||||
if compute_sphere_coordinates: # (theta, phi)
|
||||
return np.stack((
|
||||
theta if not shift_theta else 0.5*np.pi + theta,
|
||||
phi,
|
||||
), axis=1)
|
||||
elif compute_normals: # (x, y, z, nx, ny, nz)
|
||||
return np.stack((
|
||||
centroid[0] + radius * np.cos(theta) * np.cos(phi),
|
||||
centroid[1] + radius * np.cos(theta) * np.sin(phi),
|
||||
centroid[2] + radius * np.sin(theta),
|
||||
np.cos(theta) * np.cos(phi),
|
||||
np.cos(theta) * np.sin(phi),
|
||||
np.sin(theta),
|
||||
), axis=1)
|
||||
else: # (x, y, z)
|
||||
return np.stack((
|
||||
centroid[0] + radius * np.cos(theta) * np.cos(phi),
|
||||
centroid[1] + radius * np.cos(theta) * np.sin(phi),
|
||||
centroid[2] + radius * np.sin(theta),
|
||||
), axis=1)
|
85
ifield/data/common/processing.py
Normal file
85
ifield/data/common/processing.py
Normal file
@ -0,0 +1,85 @@
|
||||
from .h5_dataclasses import H5Dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Hashable, Optional, Callable
|
||||
import os
|
||||
|
||||
DEBUG = bool(os.environ.get("IFIELD_DEBUG", ""))
|
||||
|
||||
__doc__ = """
|
||||
Here are some helper functions for processing data.
|
||||
"""
|
||||
|
||||
# multiprocessing does not work due to my rediculous use of closures, which seemingly cannot be pickled
|
||||
# paralelize it in the shell instead
|
||||
|
||||
def precompute_data(
|
||||
computer : Callable[[Hashable], Optional[H5Dataclass]],
|
||||
identifiers : list[Hashable],
|
||||
output_paths : list[Path],
|
||||
page : tuple[int, int] = (0, 1),
|
||||
*,
|
||||
force : bool = False,
|
||||
debug : bool = False,
|
||||
):
|
||||
"""
|
||||
precomputes data and stores them as HDF5 datasets using `.to_file(path: Path)`
|
||||
"""
|
||||
|
||||
page, n_pages = page
|
||||
assert len(identifiers) == len(output_paths)
|
||||
|
||||
total = len(identifiers)
|
||||
identifier_max_len = max(map(len, map(str, identifiers)))
|
||||
t_epoch = None
|
||||
def log(state: str, is_start = False):
|
||||
nonlocal t_epoch
|
||||
if is_start: t_epoch = datetime.now()
|
||||
td = timedelta(0) if is_start else datetime.now() - t_epoch
|
||||
print(" - "
|
||||
f"{str(index+1).rjust(len(str(total)))}/{total}: "
|
||||
f"{str(identifier).ljust(identifier_max_len)} @ {td}: {state}"
|
||||
)
|
||||
|
||||
print(f"precompute_data(computer={computer.__module__}.{computer.__qualname__}, identifiers=..., force={force}, page={page})")
|
||||
t_begin = datetime.now()
|
||||
failed = []
|
||||
|
||||
# pagination
|
||||
page_size = total // n_pages + bool(total % n_pages)
|
||||
jobs = list(zip(identifiers, output_paths))[page_size*page : page_size*(page+1)]
|
||||
|
||||
for index, (identifier, output_path) in enumerate(jobs, start=page_size*page):
|
||||
if not force and output_path.exists() and output_path.stat().st_size > 0:
|
||||
continue
|
||||
|
||||
log("compute", is_start=True)
|
||||
|
||||
# compute
|
||||
try:
|
||||
res = computer(identifier)
|
||||
except Exception as e:
|
||||
failed.append(identifier)
|
||||
log(f"failed compute: {e.__class__.__name__}: {e}")
|
||||
if DEBUG or debug: raise e
|
||||
continue
|
||||
if res is None:
|
||||
failed.append(identifier)
|
||||
log("no result")
|
||||
continue
|
||||
|
||||
# write to file
|
||||
try:
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
res.to_h5_file(output_path)
|
||||
except Exception as e:
|
||||
failed.append(identifier)
|
||||
log(f"failed write: {e.__class__.__name__}: {e}")
|
||||
if output_path.is_file(): output_path.unlink() # cleanup
|
||||
if DEBUG or debug: raise e
|
||||
continue
|
||||
|
||||
log("done")
|
||||
|
||||
print("precompute_data finished in", datetime.now() - t_begin)
|
||||
print("failed:", failed or None)
|
768
ifield/data/common/scan.py
Normal file
768
ifield/data/common/scan.py
Normal file
@ -0,0 +1,768 @@
|
||||
from ...utils.helpers import compose
|
||||
from . import points
|
||||
from .h5_dataclasses import H5Dataclass, H5Array, H5ArrayNoSlice, TransformableDataclassMixin
|
||||
from methodtools import lru_cache
|
||||
from sklearn.neighbors import BallTree
|
||||
import faiss
|
||||
from trimesh import Trimesh
|
||||
from typing import Iterable
|
||||
from typing import Optional, TypeVar
|
||||
import mesh_to_sdf
|
||||
import mesh_to_sdf.scan as sdf_scan
|
||||
import numpy as np
|
||||
import trimesh
|
||||
import trimesh.transformations as T
|
||||
import warnings
|
||||
|
||||
__doc__ = """
|
||||
Here are some helper types for data.
|
||||
"""
|
||||
|
||||
_T = TypeVar("T")
|
||||
|
||||
class InvalidateLRUOnWriteMixin:
|
||||
def __setattr__(self, key, value):
|
||||
if not key.startswith("__wire|"):
|
||||
for attr in dir(self):
|
||||
if attr.startswith("__wire|"):
|
||||
getattr(self, attr).cache_clear()
|
||||
return super().__setattr__(key, value)
|
||||
def lru_property(func):
|
||||
return lru_cache(maxsize=1)(property(func))
|
||||
|
||||
class SingleViewScan(H5Dataclass, TransformableDataclassMixin, InvalidateLRUOnWriteMixin, require_all=True):
|
||||
points_hit : H5ArrayNoSlice # (N, 3)
|
||||
normals_hit : Optional[H5ArrayNoSlice] # (N, 3)
|
||||
points_miss : H5ArrayNoSlice # (M, 3)
|
||||
distances_miss : Optional[H5ArrayNoSlice] # (M)
|
||||
colors_hit : Optional[H5ArrayNoSlice] # (N, 3)
|
||||
colors_miss : Optional[H5ArrayNoSlice] # (M, 3)
|
||||
uv_hits : Optional[H5ArrayNoSlice] # (H, W) dtype=bool
|
||||
uv_miss : Optional[H5ArrayNoSlice] # (H, W) dtype=bool (the reason we store both is due to missing data depth sensor data or filtered backfaces)
|
||||
cam_pos : H5ArrayNoSlice # (3)
|
||||
cam_mat4 : Optional[H5ArrayNoSlice] # (4, 4)
|
||||
proj_mat4 : Optional[H5ArrayNoSlice] # (4, 4)
|
||||
transforms : dict[str, H5ArrayNoSlice] # a map of 4x4 transformation matrices
|
||||
|
||||
def transform(self: _T, mat4: np.ndarray, inplace=False) -> _T:
|
||||
scale_xyz = mat4[:3, :3].sum(axis=0) # https://math.stackexchange.com/a/1463487
|
||||
assert all(scale_xyz - scale_xyz[0] < 1e-8), f"differenty scaled axes: {scale_xyz}"
|
||||
|
||||
out = self if inplace else self.copy(deep=False)
|
||||
out.points_hit = T.transform_points(self.points_hit, mat4)
|
||||
out.normals_hit = T.transform_points(self.normals_hit, mat4) if self.normals_hit is not None else None
|
||||
out.points_miss = T.transform_points(self.points_miss, mat4)
|
||||
out.distances_miss = self.distances_miss * scale_xyz
|
||||
out.cam_pos = T.transform_points(self.points_cam, mat4)[-1]
|
||||
out.cam_mat4 = (mat4 @ self.cam_mat4) if self.cam_mat4 is not None else None
|
||||
out.proj_mat4 = (mat4 @ self.proj_mat4) if self.proj_mat4 is not None else None
|
||||
return out
|
||||
|
||||
def compute_miss_distances(self: _T, *, copy: bool = False, deep: bool = False) -> _T:
|
||||
assert not self.has_miss_distances
|
||||
if not self.is_hitting:
|
||||
raise ValueError("No hits to compute the ray distance towards")
|
||||
|
||||
out = self.copy(deep=deep) if copy else self
|
||||
out.distances_miss \
|
||||
= distance_from_rays_to_point_cloud(
|
||||
ray_origins = out.points_cam,
|
||||
ray_dirs = out.ray_dirs_miss,
|
||||
points = out.points_hit,
|
||||
).astype(out.points_cam.dtype)
|
||||
|
||||
return out
|
||||
|
||||
@lru_property
|
||||
def points(self) -> np.ndarray: # (N+M+1, 3)
|
||||
return np.concatenate((
|
||||
self.points_hit,
|
||||
self.points_miss,
|
||||
self.points_cam,
|
||||
))
|
||||
|
||||
@lru_property
|
||||
def uv_points(self) -> np.ndarray: # (N+M+1, 3)
|
||||
if not self.has_uv: raise ValueError
|
||||
out = np.full((*self.uv_hits.shape, 3), np.nan, dtype=self.points_hit.dtype)
|
||||
out[self.uv_hits, :] = self.points_hit
|
||||
out[self.uv_miss, :] = self.points_miss
|
||||
return out
|
||||
|
||||
@lru_property
|
||||
def uv_normals(self) -> np.ndarray: # (N+M+1, 3)
|
||||
if not self.has_uv: raise ValueError
|
||||
out = np.full((*self.uv_hits.shape, 3), np.nan, dtype=self.normals_hit.dtype)
|
||||
out[self.uv_hits, :] = self.normals_hit
|
||||
return out
|
||||
|
||||
@lru_property
|
||||
def points_cam(self) -> Optional[np.ndarray]: # (1, 3)
|
||||
if self.cam_pos is None: return None
|
||||
return self.cam_pos[None, :]
|
||||
|
||||
@lru_property
|
||||
def points_hit_centroid(self) -> np.ndarray:
|
||||
return self.points_hit.mean(axis=0)
|
||||
|
||||
@lru_property
|
||||
def points_hit_std(self) -> np.ndarray:
|
||||
return self.points_hit.std(axis=0)
|
||||
|
||||
@lru_property
|
||||
def is_hitting(self) -> bool:
|
||||
return len(self.points_hit) > 0
|
||||
|
||||
@lru_property
|
||||
def is_empty(self) -> bool:
|
||||
return not (len(self.points_hit) or len(self.points_miss))
|
||||
|
||||
@lru_property
|
||||
def has_colors(self) -> bool:
|
||||
return self.colors_hit is not None or self.colors_miss is not None
|
||||
|
||||
@lru_property
|
||||
def has_normals(self) -> bool:
|
||||
return self.normals_hit is not None
|
||||
|
||||
@lru_property
|
||||
def has_uv(self) -> bool:
|
||||
return self.uv_hits is not None
|
||||
|
||||
@lru_property
|
||||
def has_miss_distances(self) -> bool:
|
||||
return self.distances_miss is not None
|
||||
|
||||
@lru_property
|
||||
def xyzrgb_hit(self) -> np.ndarray: # (N, 6)
|
||||
if self.colors_hit is None: raise ValueError
|
||||
return np.concatenate([self.points_hit, self.colors_hit], axis=1)
|
||||
|
||||
@lru_property
|
||||
def xyzrgb_miss(self) -> np.ndarray: # (M, 6)
|
||||
if self.colors_miss is None: raise ValueError
|
||||
return np.concatenate([self.points_miss, self.colors_miss], axis=1)
|
||||
|
||||
@lru_property
|
||||
def ray_dirs_hit(self) -> np.ndarray: # (N, 3)
|
||||
out = self.points_hit - self.points_cam
|
||||
out /= np.linalg.norm(out, axis=-1)[:, None] # normalize
|
||||
return out
|
||||
|
||||
@lru_property
|
||||
def ray_dirs_miss(self) -> np.ndarray: # (N, 3)
|
||||
out = self.points_miss - self.points_cam
|
||||
out /= np.linalg.norm(out, axis=-1)[:, None] # normalize
|
||||
return out
|
||||
|
||||
@classmethod
|
||||
def from_mesh_single_view(cls, mesh: Trimesh, *, compute_miss_distances: bool = False, **kw) -> "SingleViewScan":
|
||||
if "phi" not in kw and not "theta" in kw:
|
||||
kw["theta"], kw["phi"] = points.generate_random_sphere_points(1, compute_sphere_coordinates=True)[0]
|
||||
scan = sample_single_view_scan_from_mesh(mesh, **kw)
|
||||
if compute_miss_distances and scan.is_hitting:
|
||||
scan.compute_miss_distances()
|
||||
return scan
|
||||
|
||||
def to_uv_scan(self) -> "SingleViewUVScan":
|
||||
return SingleViewUVScan.from_scan(self)
|
||||
|
||||
@classmethod
|
||||
def from_uv_scan(self, uvscan: "SingleViewUVScan") -> "SingleViewUVScan":
|
||||
return uvscan.to_scan()
|
||||
|
||||
# The same, but with support for pagination (should have been this way since the start...)
|
||||
class SingleViewUVScan(H5Dataclass, TransformableDataclassMixin, InvalidateLRUOnWriteMixin, require_all=True):
|
||||
# B may be (N) or (H, W), the latter may be flattened
|
||||
hits : H5Array # (*B) dtype=bool
|
||||
miss : H5Array # (*B) dtype=bool (the reason we store both is due to missing data depth sensor data or filtered backface hits)
|
||||
points : H5Array # (*B, 3) on far plane if miss, NaN if neither hit or miss
|
||||
normals : Optional[H5Array] # (*B, 3) NaN if not hit
|
||||
colors : Optional[H5Array] # (*B, 3)
|
||||
distances : Optional[H5Array] # (*B) NaN if not miss
|
||||
cam_pos : Optional[H5ArrayNoSlice] # (3) or (*B, 3)
|
||||
cam_mat4 : Optional[H5ArrayNoSlice] # (4, 4)
|
||||
proj_mat4 : Optional[H5ArrayNoSlice] # (4, 4)
|
||||
transforms : dict[str, H5ArrayNoSlice] # a map of 4x4 transformation matrices
|
||||
|
||||
@classmethod
|
||||
def from_scan(cls, scan: SingleViewScan):
|
||||
if not scan.has_uv:
|
||||
raise ValueError("Scan cloud has no UV data")
|
||||
hits, miss = scan.uv_hits, scan.uv_miss
|
||||
dtype = scan.points_hit.dtype
|
||||
assert hits.ndim in (1, 2), hits.ndim
|
||||
assert hits.shape == miss.shape, (hits.shape, miss.shape)
|
||||
|
||||
points = np.full((*hits.shape, 3), np.nan, dtype=dtype)
|
||||
points[hits, :] = scan.points_hit
|
||||
points[miss, :] = scan.points_miss
|
||||
|
||||
normals = None
|
||||
if scan.has_normals:
|
||||
normals = np.full((*hits.shape, 3), np.nan, dtype=dtype)
|
||||
normals[hits, :] = scan.normals_hit
|
||||
|
||||
distances = None
|
||||
if scan.has_miss_distances:
|
||||
distances = np.full(hits.shape, np.nan, dtype=dtype)
|
||||
distances[miss] = scan.distances_miss
|
||||
|
||||
colors = None
|
||||
if scan.has_colors:
|
||||
colors = np.full((*hits.shape, 3), np.nan, dtype=dtype)
|
||||
if scan.colors_hit is not None:
|
||||
colors[hits, :] = scan.colors_hit
|
||||
if scan.colors_miss is not None:
|
||||
colors[miss, :] = scan.colors_miss
|
||||
|
||||
return cls(
|
||||
hits = hits,
|
||||
miss = miss,
|
||||
points = points,
|
||||
normals = normals,
|
||||
colors = colors,
|
||||
distances = distances,
|
||||
cam_pos = scan.cam_pos,
|
||||
cam_mat4 = scan.cam_mat4,
|
||||
proj_mat4 = scan.proj_mat4,
|
||||
transforms = scan.transforms,
|
||||
)
|
||||
|
||||
def to_scan(self) -> "SingleViewScan":
|
||||
if not self.is_single_view: raise ValueError
|
||||
return SingleViewScan(
|
||||
points_hit = self.points [self.hits, :],
|
||||
points_miss = self.points [self.miss, :],
|
||||
normals_hit = self.normals [self.hits, :] if self.has_normals else None,
|
||||
distances_miss = self.distances[self.miss] if self.has_miss_distances else None,
|
||||
colors_hit = self.colors [self.hits, :] if self.has_colors else None,
|
||||
colors_miss = self.colors [self.miss, :] if self.has_colors else None,
|
||||
uv_hits = self.hits,
|
||||
uv_miss = self.miss,
|
||||
cam_pos = self.cam_pos,
|
||||
cam_mat4 = self.cam_mat4,
|
||||
proj_mat4 = self.proj_mat4,
|
||||
transforms = self.transforms,
|
||||
)
|
||||
|
||||
def to_mesh(self) -> trimesh.Trimesh:
|
||||
faces: list[(tuple[int, int],)*3] = []
|
||||
for x in range(self.hits.shape[0]-1):
|
||||
for y in range(self.hits.shape[1]-1):
|
||||
c11 = x, y
|
||||
c12 = x, y+1
|
||||
c22 = x+1, y+1
|
||||
c21 = x+1, y
|
||||
|
||||
n = sum(map(self.hits.__getitem__, (c11, c12, c22, c21)))
|
||||
if n == 3:
|
||||
faces.append((*filter(self.hits.__getitem__, (c11, c12, c22, c21)),))
|
||||
elif n == 4:
|
||||
faces.append((c11, c12, c22))
|
||||
faces.append((c11, c22, c21))
|
||||
xy2idx = {c:i for i, c in enumerate(set(k for j in faces for k in j))}
|
||||
assert self.colors is not None
|
||||
return trimesh.Trimesh(
|
||||
vertices = [self.points[i] for i in xy2idx.keys()],
|
||||
vertex_colors = [self.colors[i] for i in xy2idx.keys()] if self.colors is not None else None,
|
||||
faces = [tuple(xy2idx[i] for i in face) for face in faces],
|
||||
)
|
||||
|
||||
def transform(self: _T, mat4: np.ndarray, inplace=False) -> _T:
|
||||
scale_xyz = mat4[:3, :3].sum(axis=0) # https://math.stackexchange.com/a/1463487
|
||||
assert all(scale_xyz - scale_xyz[0] < 1e-8), f"differenty scaled axes: {scale_xyz}"
|
||||
|
||||
unflat = self.hits.shape
|
||||
flat = np.product(unflat)
|
||||
|
||||
out = self if inplace else self.copy(deep=False)
|
||||
out.points = T.transform_points(self.points .reshape((*flat, 3)), mat4).reshape((*unflat, 3))
|
||||
out.normals = T.transform_points(self.normals.reshape((*flat, 3)), mat4).reshape((*unflat, 3)) if self.normals_hit is not None else None
|
||||
out.distances = self.distances_miss * scale_xyz
|
||||
out.cam_pos = T.transform_points(self.cam_pos[None, ...], mat4)[0]
|
||||
out.cam_mat4 = (mat4 @ self.cam_mat4) if self.cam_mat4 is not None else None
|
||||
out.proj_mat4 = (mat4 @ self.proj_mat4) if self.proj_mat4 is not None else None
|
||||
return out
|
||||
|
||||
def compute_miss_distances(self: _T, *, copy: bool = False, deep: bool = False, surface_points: Optional[np.ndarray] = None) -> _T:
|
||||
assert not self.has_miss_distances
|
||||
|
||||
shape = self.hits.shape
|
||||
|
||||
out = self.copy(deep=deep) if copy else self
|
||||
out.distances = np.zeros(shape, dtype=self.points.dtype)
|
||||
if self.is_hitting:
|
||||
out.distances[self.miss] \
|
||||
= distance_from_rays_to_point_cloud(
|
||||
ray_origins = self.cam_pos_unsqueezed_miss,
|
||||
ray_dirs = self.ray_dirs_miss,
|
||||
points = surface_points if surface_points is not None else self.points[self.hits],
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
def fill_missing_points(self: _T, *, copy: bool = False, deep: bool = False) -> _T:
|
||||
"""
|
||||
Fill in missing points as hitting the far plane.
|
||||
"""
|
||||
if not self.is_2d:
|
||||
raise ValueError("Cannot fill missing points for non-2d scan!")
|
||||
if not self.is_single_view:
|
||||
raise ValueError("Cannot fill missing points for non-single-view scans!")
|
||||
if self.cam_mat4 is None:
|
||||
raise ValueError("cam_mat4 is None")
|
||||
if self.proj_mat4 is None:
|
||||
raise ValueError("proj_mat4 is None")
|
||||
|
||||
uv = np.argwhere(self.missing).astype(self.points.dtype)
|
||||
uv[:, 0] /= (self.missing.shape[1] - 1) / 2
|
||||
uv[:, 1] /= (self.missing.shape[0] - 1) / 2
|
||||
uv -= 1
|
||||
uv = np.stack((
|
||||
uv[:, 1],
|
||||
-uv[:, 0],
|
||||
np.ones(uv.shape[0]), # far clipping plane
|
||||
np.ones(uv.shape[0]), # homogeneous coordinate
|
||||
), axis=-1)
|
||||
uv = uv @ (self.cam_mat4 @ np.linalg.inv(self.proj_mat4)).T
|
||||
|
||||
out = self.copy(deep=deep) if copy else self
|
||||
out.points[self.missing, :] = uv[:, :3] / uv[:, 3][:, None]
|
||||
return out
|
||||
|
||||
@lru_property
|
||||
def is_hitting(self) -> bool:
|
||||
return np.any(self.hits)
|
||||
|
||||
@lru_property
|
||||
def has_colors(self) -> bool:
|
||||
return not self.colors is None
|
||||
|
||||
@lru_property
|
||||
def has_normals(self) -> bool:
|
||||
return not self.normals is None
|
||||
|
||||
@lru_property
|
||||
def has_miss_distances(self) -> bool:
|
||||
return not self.distances is None
|
||||
|
||||
@lru_property
|
||||
def any_missing(self) -> bool:
|
||||
return np.any(self.missing)
|
||||
|
||||
@lru_property
|
||||
def has_missing(self) -> bool:
|
||||
return self.any_missing and not np.any(np.isnan(self.points[self.missing]))
|
||||
|
||||
@lru_property
|
||||
def cam_pos_unsqueezed(self) -> H5Array:
|
||||
if self.cam_pos.ndim != 1:
|
||||
return self.cam_pos
|
||||
else:
|
||||
cam_pos = self.cam_pos
|
||||
for _ in range(self.hits.ndim):
|
||||
cam_pos = cam_pos[None, ...]
|
||||
return cam_pos
|
||||
|
||||
@lru_property
|
||||
def cam_pos_unsqueezed_hit(self) -> H5Array:
|
||||
if self.cam_pos.ndim != 1:
|
||||
return self.cam_pos[self.hits, :]
|
||||
else:
|
||||
return self.cam_pos[None, :]
|
||||
|
||||
@lru_property
|
||||
def cam_pos_unsqueezed_miss(self) -> H5Array:
|
||||
if self.cam_pos.ndim != 1:
|
||||
return self.cam_pos[self.miss, :]
|
||||
else:
|
||||
return self.cam_pos[None, :]
|
||||
|
||||
@lru_property
|
||||
def ray_dirs(self) -> H5Array:
|
||||
return (self.points - self.cam_pos_unsqueezed) * (1 / self.depths[..., None])
|
||||
|
||||
@lru_property
|
||||
def ray_dirs_hit(self) -> H5Array:
|
||||
out = self.points[self.hits, :] - self.cam_pos_unsqueezed_hit
|
||||
out /= np.linalg.norm(out, axis=-1)[..., None] # normalize
|
||||
return out
|
||||
|
||||
@lru_property
|
||||
def ray_dirs_miss(self) -> H5Array:
|
||||
out = self.points[self.miss, :] - self.cam_pos_unsqueezed_miss
|
||||
out /= np.linalg.norm(out, axis=-1)[..., None] # normalize
|
||||
return out
|
||||
|
||||
@lru_property
|
||||
def depths(self) -> H5Array:
|
||||
return np.linalg.norm(self.points - self.cam_pos_unsqueezed, axis=-1)
|
||||
|
||||
@lru_property
|
||||
def missing(self) -> H5Array:
|
||||
return ~(self.hits | self.miss)
|
||||
|
||||
@classmethod
|
||||
def from_mesh_single_view(cls, mesh: Trimesh, *, compute_miss_distances: bool = False, **kw) -> "SingleViewUVScan":
|
||||
if "phi" not in kw and not "theta" in kw:
|
||||
kw["theta"], kw["phi"] = points.generate_random_sphere_points(1, compute_sphere_coordinates=True)[0]
|
||||
scan = sample_single_view_scan_from_mesh(mesh, **kw).to_uv_scan()
|
||||
if compute_miss_distances:
|
||||
scan.compute_miss_distances()
|
||||
assert scan.is_2d
|
||||
return scan
|
||||
|
||||
@classmethod
|
||||
def from_mesh_sphere_view(cls, mesh: Trimesh, *, compute_miss_distances: bool = False, **kw) -> "SingleViewUVScan":
|
||||
scan = sample_sphere_view_scan_from_mesh(mesh, **kw)
|
||||
if compute_miss_distances:
|
||||
surface_points = None
|
||||
if scan.hits.sum() > mesh.vertices.shape[0]:
|
||||
surface_points = mesh.vertices.astype(scan.points.dtype)
|
||||
if not kw.get("no_unit_sphere", False):
|
||||
translation, scale = compute_unit_sphere_transform(mesh, dtype=scan.points.dtype)
|
||||
surface_points = (surface_points + translation) * scale
|
||||
scan.compute_miss_distances(surface_points=surface_points)
|
||||
assert scan.is_flat
|
||||
return scan
|
||||
|
||||
def flatten_and_permute_(self: _T, copy=False) -> _T: # inplace by default
|
||||
n_items = np.product(self.hits.shape)
|
||||
permutation = np.random.permutation(n_items)
|
||||
|
||||
out = self.copy(deep=False) if copy else self
|
||||
out.hits = out.hits .reshape((n_items, ))[permutation]
|
||||
out.miss = out.miss .reshape((n_items, ))[permutation]
|
||||
out.points = out.points .reshape((n_items, 3))[permutation, :]
|
||||
out.normals = out.normals .reshape((n_items, 3))[permutation, :] if out.has_normals else None
|
||||
out.colors = out.colors .reshape((n_items, 3))[permutation, :] if out.has_colors else None
|
||||
out.distances = out.distances.reshape((n_items, ))[permutation] if out.has_miss_distances else None
|
||||
return out
|
||||
|
||||
@property
|
||||
def is_single_view(self) -> bool:
|
||||
return np.product(self.cam_pos.shape[:-1]) == 1 if not self.cam_pos is None else True
|
||||
|
||||
@property
|
||||
def is_flat(self) -> bool:
|
||||
return len(self.hits.shape) == 1
|
||||
|
||||
@property
|
||||
def is_2d(self) -> bool:
|
||||
return len(self.hits.shape) == 2
|
||||
|
||||
|
||||
# transforms can be found in pytorch3d.transforms and in open3d
|
||||
# and in trimesh.transformations
|
||||
|
||||
def sample_single_view_scans_from_mesh(
|
||||
mesh : Trimesh,
|
||||
*,
|
||||
n_batches : int,
|
||||
scan_resolution : int = 400,
|
||||
compute_normals : bool = False,
|
||||
fov : float = 1.0472, # 60 degrees in radians, vertical field of view.
|
||||
camera_distance : float = 2,
|
||||
no_filter_backhits : bool = False,
|
||||
) -> Iterable[SingleViewScan]:
|
||||
|
||||
normalized_mesh_cache = []
|
||||
|
||||
for _ in range(n_batches):
|
||||
theta, phi = points.generate_random_sphere_points(1, compute_sphere_coordinates=True)[0]
|
||||
|
||||
yield sample_single_view_scan_from_mesh(
|
||||
mesh = mesh,
|
||||
phi = phi,
|
||||
theta = theta,
|
||||
_mesh_is_normalized = False,
|
||||
scan_resolution = scan_resolution,
|
||||
compute_normals = compute_normals,
|
||||
fov = fov,
|
||||
camera_distance = camera_distance,
|
||||
no_filter_backhits = no_filter_backhits,
|
||||
_mesh_cache = normalized_mesh_cache,
|
||||
)
|
||||
|
||||
def sample_single_view_scan_from_mesh(
|
||||
mesh : Trimesh,
|
||||
*,
|
||||
phi : float,
|
||||
theta : float,
|
||||
scan_resolution : int = 200,
|
||||
compute_normals : bool = False,
|
||||
fov : float = 1.0472, # 60 degrees in radians, vertical field of view.
|
||||
camera_distance : float = 2,
|
||||
no_filter_backhits : bool = False,
|
||||
no_unit_sphere : bool = False,
|
||||
dtype : type = np.float32,
|
||||
_mesh_cache : Optional[list] = None, # provide a list if mesh is reused
|
||||
) -> SingleViewScan:
|
||||
|
||||
# scale and center to unit sphere
|
||||
is_cache = isinstance(_mesh_cache, list)
|
||||
if is_cache and _mesh_cache and _mesh_cache[0] is mesh:
|
||||
_, mesh, translation, scale = _mesh_cache
|
||||
else:
|
||||
if is_cache:
|
||||
if _mesh_cache:
|
||||
_mesh_cache.clear()
|
||||
_mesh_cache.append(mesh)
|
||||
translation, scale = compute_unit_sphere_transform(mesh)
|
||||
mesh = mesh_to_sdf.scale_to_unit_sphere(mesh)
|
||||
if is_cache:
|
||||
_mesh_cache.extend((mesh, translation, scale))
|
||||
|
||||
z_near = 1
|
||||
z_far = 3
|
||||
cam_mat4 = sdf_scan.get_camera_transform_looking_at_origin(phi, theta, camera_distance=camera_distance)
|
||||
cam_pos = cam_mat4 @ np.array([0, 0, 0, 1])
|
||||
|
||||
scan = sdf_scan.Scan(mesh,
|
||||
camera_transform = cam_mat4,
|
||||
resolution = scan_resolution,
|
||||
calculate_normals = compute_normals,
|
||||
fov = fov,
|
||||
z_near = z_near,
|
||||
z_far = z_far,
|
||||
no_flip_backfaced_normals = True
|
||||
)
|
||||
|
||||
# all the scan rays that hit the far plane, based on sdf_scan.Scan.__init__
|
||||
misses = np.argwhere(scan.depth_buffer == 0)
|
||||
points_miss = np.ones((misses.shape[0], 4))
|
||||
points_miss[:, [1, 0]] = misses.astype(float) / (scan_resolution -1) * 2 - 1
|
||||
points_miss[:, 1] *= -1
|
||||
points_miss[:, 2] = 1 # far plane in clipping space
|
||||
points_miss = points_miss @ (cam_mat4 @ np.linalg.inv(scan.projection_matrix)).T
|
||||
points_miss /= points_miss[:, 3][:, np.newaxis]
|
||||
points_miss = points_miss[:, :3]
|
||||
|
||||
uv_hits = scan.depth_buffer != 0
|
||||
uv_miss = ~uv_hits
|
||||
|
||||
if not no_filter_backhits:
|
||||
if not compute_normals:
|
||||
raise ValueError("not `no_filter_backhits` requires `compute_normals`")
|
||||
# inner product
|
||||
mask = np.einsum('ij,ij->i', scan.points - cam_pos[:3][None, :], scan.normals) < 0
|
||||
scan.points = scan.points [mask, :]
|
||||
scan.normals = scan.normals[mask, :]
|
||||
uv_hits[uv_hits] = mask
|
||||
|
||||
transforms = {}
|
||||
|
||||
# undo unit-sphere transform
|
||||
if no_unit_sphere:
|
||||
scan.points = scan.points * (1 / scale) - translation
|
||||
points_miss = points_miss * (1 / scale) - translation
|
||||
cam_pos[:3] = cam_pos[:3] * (1 / scale) - translation
|
||||
cam_mat4[:3, :] *= 1 / scale
|
||||
cam_mat4[:3, 3] -= translation
|
||||
|
||||
transforms["unit_sphere"] = T.scale_and_translate(scale=scale, translate=translation)
|
||||
transforms["model"] = np.eye(4)
|
||||
else:
|
||||
transforms["model"] = np.linalg.inv(T.scale_and_translate(scale=scale, translate=translation))
|
||||
transforms["unit_sphere"] = np.eye(4)
|
||||
|
||||
return SingleViewScan(
|
||||
normals_hit = scan.normals .astype(dtype),
|
||||
points_hit = scan.points .astype(dtype),
|
||||
points_miss = points_miss .astype(dtype),
|
||||
distances_miss = None,
|
||||
colors_hit = None,
|
||||
colors_miss = None,
|
||||
uv_hits = uv_hits .astype(bool),
|
||||
uv_miss = uv_miss .astype(bool),
|
||||
cam_pos = cam_pos[:3] .astype(dtype),
|
||||
cam_mat4 = cam_mat4 .astype(dtype),
|
||||
proj_mat4 = scan.projection_matrix .astype(dtype),
|
||||
transforms = {k:v.astype(dtype) for k, v in transforms.items()},
|
||||
)
|
||||
|
||||
def sample_sphere_view_scan_from_mesh(
|
||||
mesh : Trimesh,
|
||||
*,
|
||||
sphere_points : int = 4000, # resulting rays are n*(n-1)
|
||||
compute_normals : bool = False,
|
||||
no_filter_backhits : bool = False,
|
||||
no_unit_sphere : bool = False,
|
||||
no_permute : bool = False,
|
||||
dtype : type = np.float32,
|
||||
**kw,
|
||||
) -> SingleViewUVScan:
|
||||
translation, scale = compute_unit_sphere_transform(mesh, dtype=dtype)
|
||||
|
||||
# get unit-sphere points, then transform to model space
|
||||
two_sphere = generate_equidistant_sphere_rays(sphere_points, **kw).astype(dtype) # (n*(n-1), 2, 3)
|
||||
two_sphere = two_sphere / scale - translation # we transform after cache lookup
|
||||
|
||||
if mesh.ray.__class__.__module__.split(".")[-1] != "ray_pyembree":
|
||||
warnings.warn("Pyembree not found, the ray-tracing will be SLOW!")
|
||||
|
||||
(
|
||||
locations,
|
||||
index_ray,
|
||||
index_tri,
|
||||
) = mesh.ray.intersects_location(
|
||||
two_sphere[:, 0, :],
|
||||
two_sphere[:, 1, :] - two_sphere[:, 0, :], # direction, not target coordinate
|
||||
multiple_hits=False,
|
||||
)
|
||||
|
||||
|
||||
if compute_normals:
|
||||
location_normals = mesh.face_normals[index_tri]
|
||||
|
||||
batch = two_sphere.shape[:1]
|
||||
hits = np.zeros((*batch,), dtype=np.bool)
|
||||
miss = np.ones((*batch,), dtype=np.bool)
|
||||
cam_pos = two_sphere[:, 0, :]
|
||||
intersections = two_sphere[:, 1, :] # far-plane, effectively
|
||||
normals = np.zeros((*batch, 3), dtype=dtype)
|
||||
|
||||
index_ray_front = index_ray
|
||||
if not no_filter_backhits:
|
||||
if not compute_normals:
|
||||
raise ValueError("not `no_filter_backhits` requires `compute_normals`")
|
||||
mask = ((intersections[index_ray] - cam_pos[index_ray]) * location_normals).sum(axis=-1) <= 0
|
||||
index_ray_front = index_ray[mask]
|
||||
|
||||
|
||||
hits[index_ray_front] = True
|
||||
miss[index_ray] = False
|
||||
intersections[index_ray] = locations
|
||||
normals[index_ray] = location_normals
|
||||
|
||||
|
||||
if not no_permute:
|
||||
assert len(batch) == 1, batch
|
||||
permutation = np.random.permutation(*batch)
|
||||
hits = hits [permutation]
|
||||
miss = miss [permutation]
|
||||
intersections = intersections[permutation, :]
|
||||
normals = normals [permutation, :]
|
||||
cam_pos = cam_pos [permutation, :]
|
||||
|
||||
# apply unit sphere transform
|
||||
if not no_unit_sphere:
|
||||
intersections = (intersections + translation) * scale
|
||||
cam_pos = (cam_pos + translation) * scale
|
||||
|
||||
return SingleViewUVScan(
|
||||
hits = hits,
|
||||
miss = miss,
|
||||
points = intersections,
|
||||
normals = normals,
|
||||
colors = None, # colors
|
||||
distances = None,
|
||||
cam_pos = cam_pos,
|
||||
cam_mat4 = None,
|
||||
proj_mat4 = None,
|
||||
transforms = {},
|
||||
)
|
||||
|
||||
def distance_from_rays_to_point_cloud(
|
||||
ray_origins : np.ndarray, # (*A, 3)
|
||||
ray_dirs : np.ndarray, # (*A, 3)
|
||||
points : np.ndarray, # (*B, 3)
|
||||
dirs_normalized : bool = False,
|
||||
n_steps : int = 40,
|
||||
) -> np.ndarray: # (A)
|
||||
|
||||
# anything outside of this volume will never constribute to the result
|
||||
max_norm = max(
|
||||
np.linalg.norm(ray_origins, axis=-1).max(),
|
||||
np.linalg.norm(points, axis=-1).max(),
|
||||
) * 1.02
|
||||
|
||||
if not dirs_normalized:
|
||||
ray_dirs = ray_dirs / np.linalg.norm(ray_dirs, axis=-1)[..., None]
|
||||
|
||||
|
||||
# deal with single-view clouds
|
||||
if ray_origins.shape != ray_dirs.shape:
|
||||
ray_origins = np.broadcast_to(ray_origins, ray_dirs.shape)
|
||||
|
||||
n_points = np.product(points.shape[:-1])
|
||||
use_faiss = n_points > 160000*4
|
||||
if not use_faiss:
|
||||
index = BallTree(points)
|
||||
else:
|
||||
# http://ann-benchmarks.com/index.html
|
||||
assert np.issubdtype(points.dtype, np.float32)
|
||||
assert np.issubdtype(ray_origins.dtype, np.float32)
|
||||
assert np.issubdtype(ray_dirs.dtype, np.float32)
|
||||
index = faiss.index_factory(points.shape[-1], "NSG32,Flat") # https://github.com/facebookresearch/faiss/wiki/The-index-factory
|
||||
|
||||
index.nprobe = 5 # 10 # default is 1
|
||||
index.train(points)
|
||||
index.add(points)
|
||||
|
||||
if not use_faiss:
|
||||
min_d, min_n = index.query(ray_origins, k=1, return_distance=True)
|
||||
else:
|
||||
min_d, min_n = index.search(ray_origins, k=1)
|
||||
min_d = np.sqrt(min_d)
|
||||
acc_d = min_d.copy()
|
||||
|
||||
for step in range(1, n_steps+1):
|
||||
query_points = ray_origins + acc_d * ray_dirs
|
||||
if max_norm is not None:
|
||||
qmask = np.linalg.norm(query_points, axis=-1) < max_norm
|
||||
if not qmask.any(): break
|
||||
query_points = query_points[qmask]
|
||||
else:
|
||||
qmask = slice(None)
|
||||
if not use_faiss:
|
||||
current_d, current_n = index.query(query_points, k=1, return_distance=True)
|
||||
else:
|
||||
current_d, current_n = index.search(query_points, k=1)
|
||||
current_d = np.sqrt(current_d)
|
||||
if max_norm is not None:
|
||||
min_d[qmask] = np.minimum(current_d, min_d[qmask])
|
||||
new_min_mask = min_d[qmask] == current_d
|
||||
qmask2 = qmask.copy()
|
||||
qmask2[qmask2] = new_min_mask[..., 0]
|
||||
min_n[qmask2] = current_n[new_min_mask[..., 0]]
|
||||
acc_d[qmask] += current_d * 0.25
|
||||
else:
|
||||
np.minimum(current_d, min_d, out=min_d)
|
||||
new_min_mask = min_d == current_d
|
||||
min_n[new_min_mask] = current_n[new_min_mask]
|
||||
acc_d += current_d * 0.25
|
||||
|
||||
closest_points = points[min_n[:, 0], :] # k=1
|
||||
distances = np.linalg.norm(np.cross(closest_points - ray_origins, ray_dirs, axis=-1), axis=-1)
|
||||
return distances
|
||||
|
||||
# helpers
|
||||
|
||||
@compose(np.array) # make copy to avoid lru cache mutation
|
||||
@lru_cache(maxsize=1)
|
||||
def generate_equidistant_sphere_rays(n : int, **kw) -> np.ndarray: # output (n*n(-1)) rays, n may be off
|
||||
sphere_points = points.generate_equidistant_sphere_points(n=n, **kw)
|
||||
|
||||
indices = np.indices((len(sphere_points),))[0] # (N)
|
||||
# cartesian product
|
||||
cprod = np.transpose([np.tile(indices, len(indices)), np.repeat(indices, len(indices))]) # (N**2, 2)
|
||||
# filter repeated combinations
|
||||
permutations = cprod[cprod[:, 0] != cprod[:, 1], :] # (N*(N-1), 2)
|
||||
# lookup sphere points
|
||||
two_sphere = sphere_points[permutations, :] # (N*(N-1), 2, 3)
|
||||
|
||||
return two_sphere
|
||||
|
||||
def compute_unit_sphere_transform(mesh: Trimesh, *, dtype=type) -> tuple[np.ndarray, float]:
|
||||
"""
|
||||
returns translation and scale which mesh_to_sdf applies to meshes before computing their SDF cloud
|
||||
"""
|
||||
# the transformation applied by mesh_to_sdf.scale_to_unit_sphere(mesh)
|
||||
translation = -mesh.bounding_box.centroid
|
||||
scale = 1 / np.max(np.linalg.norm(mesh.vertices + translation, axis=1))
|
||||
if dtype is not None:
|
||||
translation = translation.astype(dtype)
|
||||
scale = scale .astype(dtype)
|
||||
return translation, scale
|
6
ifield/data/common/types.py
Normal file
6
ifield/data/common/types.py
Normal file
@ -0,0 +1,6 @@
|
||||
__doc__ = """
|
||||
Some helper types.
|
||||
"""
|
||||
|
||||
class MalformedMesh(Exception):
|
||||
pass
|
28
ifield/data/config.py
Normal file
28
ifield/data/config.py
Normal file
@ -0,0 +1,28 @@
|
||||
from ..utils.helpers import make_relative
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import os
|
||||
import warnings
|
||||
|
||||
|
||||
def data_path_get(dataset_name: str, no_warn: bool = False) -> Path:
|
||||
dataset_envvar = f"IFIELD_DATA_MODELS_{dataset_name.replace(*'-_').upper()}"
|
||||
if dataset_envvar in os.environ:
|
||||
data_path = Path(os.environ[dataset_envvar])
|
||||
elif "IFIELD_DATA_MODELS" in os.environ:
|
||||
data_path = Path(os.environ["IFIELD_DATA_MODELS"]) / dataset_name
|
||||
else:
|
||||
data_path = Path(__file__).resolve().parent.parent.parent / "data" / "models" / dataset_name
|
||||
if not data_path.is_dir() and not no_warn:
|
||||
warnings.warn(f"{make_relative(data_path, Path.cwd()).__str__()!r} is not a directory!")
|
||||
return data_path
|
||||
|
||||
def data_path_persist(dataset_name: Optional[str], path: os.PathLike) -> os.PathLike:
|
||||
"Persist the datapath, ensuring subprocesses also will use it. The path passes through."
|
||||
|
||||
if dataset_name is None:
|
||||
os.environ["IFIELD_DATA_MODELS"] = str(path)
|
||||
else:
|
||||
os.environ[f"IFIELD_DATA_MODELS_{dataset_name.replace(*'-_').upper()}"] = str(path)
|
||||
|
||||
return path
|
56
ifield/data/coseg/__init__.py
Normal file
56
ifield/data/coseg/__init__.py
Normal file
@ -0,0 +1,56 @@
|
||||
from ..config import data_path_get, data_path_persist
|
||||
from collections import namedtuple
|
||||
import os
|
||||
|
||||
|
||||
# Data source:
|
||||
# http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/ssd.htm
|
||||
|
||||
__ALL__ = ["config", "Model", "MODELS"]
|
||||
|
||||
Archive = namedtuple("Archive", "url fname download_size_str")
|
||||
|
||||
@(lambda x: x()) # singleton
|
||||
class config:
|
||||
DATA_PATH = property(
|
||||
doc = """
|
||||
Path to the dataset. The following envvars override it:
|
||||
${IFIELD_DATA_MODELS}/coseg
|
||||
${IFIELD_DATA_MODELS_COSEG}
|
||||
""",
|
||||
fget = lambda self: data_path_get ("coseg"),
|
||||
fset = lambda self, path: data_path_persist("coseg", path),
|
||||
)
|
||||
|
||||
@property
|
||||
def IS_DOWNLOADED_DB(self) -> list[os.PathLike]:
|
||||
return [
|
||||
self.DATA_PATH / "downloaded.json",
|
||||
]
|
||||
|
||||
SHAPES: dict[str, Archive] = {
|
||||
"candelabra" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Candelabra/shapes.zip", "candelabra-shapes.zip", "3,3M"),
|
||||
"chair" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Chair/shapes.zip", "chair-shapes.zip", "3,2M"),
|
||||
"four-legged" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Four-legged/shapes.zip", "four-legged-shapes.zip", "2,9M"),
|
||||
"goblets" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Goblets/shapes.zip", "goblets-shapes.zip", "500K"),
|
||||
"guitars" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Guitars/shapes.zip", "guitars-shapes.zip", "1,9M"),
|
||||
"lampes" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Lampes/shapes.zip", "lampes-shapes.zip", "2,4M"),
|
||||
"vases" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Vases/shapes.zip", "vases-shapes.zip", "5,5M"),
|
||||
"irons" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Irons/shapes.zip", "irons-shapes.zip", "1,2M"),
|
||||
"tele-aliens" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Tele-aliens/shapes.zip", "tele-aliens-shapes.zip", "15M"),
|
||||
"large-vases" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Large-Vases/shapes.zip", "large-vases-shapes.zip", "6,2M"),
|
||||
"large-chairs": Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Large-Chairs/shapes.zip", "large-chairs-shapes.zip", "14M"),
|
||||
}
|
||||
GROUND_TRUTHS: dict[str, Archive] = {
|
||||
"candelabra" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Candelabra/gt.zip", "candelabra-gt.zip", "68K"),
|
||||
"chair" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Chair/gt.zip", "chair-gt.zip", "20K"),
|
||||
"four-legged" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Four-legged/gt.zip", "four-legged-gt.zip", "24K"),
|
||||
"goblets" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Goblets/gt.zip", "goblets-gt.zip", "4,0K"),
|
||||
"guitars" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Guitars/gt.zip", "guitars-gt.zip", "12K"),
|
||||
"lampes" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Lampes/gt.zip", "lampes-gt.zip", "60K"),
|
||||
"vases" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Vases/gt.zip", "vases-gt.zip", "40K"),
|
||||
"irons" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Irons/gt.zip", "irons-gt.zip", "8,0K"),
|
||||
"tele-aliens" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Tele-aliens/gt.zip", "tele-aliens-gt.zip", "72K"),
|
||||
"large-vases" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Large-Vases/gt.zip", "large-vases-gt.zip", "68K"),
|
||||
"large-chairs": Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Large-Chairs/gt.zip", "large-chairs-gt.zip", "116K"),
|
||||
}
|
135
ifield/data/coseg/download.py
Normal file
135
ifield/data/coseg/download.py
Normal file
@ -0,0 +1,135 @@
|
||||
#!/usr/bin/env python3
|
||||
from . import config
|
||||
from ...utils.helpers import make_relative
|
||||
from ..common import download
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
import argparse
|
||||
import io
|
||||
import zipfile
|
||||
|
||||
|
||||
|
||||
def is_downloaded(*a, **kw):
|
||||
return download.is_downloaded(*a, dbfiles=config.IS_DOWNLOADED_DB, **kw)
|
||||
|
||||
def download_and_extract(target_dir: Path, url_dict: dict[str, str], *, force=False, silent=False) -> bool:
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
ret = False
|
||||
for url, fname in url_dict.items():
|
||||
if not force:
|
||||
if is_downloaded(target_dir, url): continue
|
||||
if not download.check_url(url):
|
||||
print("ERROR:", url)
|
||||
continue
|
||||
ret = True
|
||||
|
||||
if force or not (target_dir / "archives" / fname).is_file():
|
||||
|
||||
data = download.download_data(url, silent=silent, label=fname)
|
||||
assert url.endswith(".zip")
|
||||
|
||||
print("writing...")
|
||||
|
||||
(target_dir / "archives").mkdir(parents=True, exist_ok=True)
|
||||
with (target_dir / "archives" / fname).open("wb") as f:
|
||||
f.write(data)
|
||||
del data
|
||||
|
||||
print(f"extracting {fname}...")
|
||||
|
||||
with zipfile.ZipFile(target_dir / "archives" / fname, 'r') as f:
|
||||
f.extractall(target_dir / Path(fname).stem.removesuffix("-shapes").removesuffix("-gt"))
|
||||
|
||||
is_downloaded(target_dir, url, add=True)
|
||||
|
||||
return ret
|
||||
|
||||
def make_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(description=dedent("""
|
||||
Download The COSEG Shape Dataset.
|
||||
More info: http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/ssd.htm
|
||||
|
||||
Example:
|
||||
|
||||
download-coseg --shapes chairs
|
||||
"""), formatter_class=argparse.RawTextHelpFormatter)
|
||||
|
||||
arg = parser.add_argument
|
||||
|
||||
arg("sets", nargs="*", default=[],
|
||||
help="Which set to download, defaults to none.")
|
||||
arg("--all", action="store_true",
|
||||
help="Download all sets")
|
||||
arg("--dir", default=str(config.DATA_PATH),
|
||||
help=f"The target directory. Default is {make_relative(config.DATA_PATH, Path.cwd()).__str__()!r}")
|
||||
|
||||
arg("--shapes", action="store_true",
|
||||
help="Download the 3d shapes for each chosen set")
|
||||
arg("--gts", action="store_true",
|
||||
help="Download the ground-truth segmentation data for each chosen set")
|
||||
|
||||
arg("--list", action="store_true",
|
||||
help="Lists all the sets")
|
||||
arg("--list-urls", action="store_true",
|
||||
help="Lists the urls to download")
|
||||
arg("--list-sizes", action="store_true",
|
||||
help="Lists the download size of each set")
|
||||
arg("--silent", action="store_true",
|
||||
help="")
|
||||
arg("--force", action="store_true",
|
||||
help="Download again even if already downloaded")
|
||||
|
||||
return parser
|
||||
|
||||
# entrypoint
|
||||
def cli(parser=make_parser()):
|
||||
args = parser.parse_args()
|
||||
|
||||
assert set(config.SHAPES.keys()) == set(config.GROUND_TRUTHS.keys())
|
||||
|
||||
set_names = sorted(set(args.sets))
|
||||
if args.all:
|
||||
assert not set_names, "--all is mutually exclusive from manually selected sets"
|
||||
set_names = sorted(config.SHAPES.keys())
|
||||
|
||||
if args.list:
|
||||
print(*config.SHAPES.keys(), sep="\n")
|
||||
exit()
|
||||
|
||||
if args.list_sizes:
|
||||
print(*(f"{set_name:<15}{config.SHAPES[set_name].download_size_str}" for set_name in (set_names or config.SHAPES.keys())), sep="\n")
|
||||
exit()
|
||||
|
||||
try:
|
||||
url_dict \
|
||||
= {config.SHAPES[set_name].url : config.SHAPES[set_name].fname for set_name in set_names if args.shapes} \
|
||||
| {config.GROUND_TRUTHS[set_name].url : config.GROUND_TRUTHS[set_name].fname for set_name in set_names if args.gts}
|
||||
except KeyError:
|
||||
print("Error: unrecognized object name:", *set(set_names).difference(config.SHAPES.keys()), sep="\n")
|
||||
exit(1)
|
||||
|
||||
if not url_dict:
|
||||
if set_names and not (args.shapes or args.gts):
|
||||
print("Error: Provide at least one of --shapes of --gts")
|
||||
else:
|
||||
print("Error: No object set was selected for download!")
|
||||
exit(1)
|
||||
|
||||
if args.list_urls:
|
||||
print(*url_dict.keys(), sep="\n")
|
||||
exit()
|
||||
|
||||
print("Download start")
|
||||
any_downloaded = download_and_extract(
|
||||
target_dir = Path(args.dir),
|
||||
url_dict = url_dict,
|
||||
force = args.force,
|
||||
silent = args.silent,
|
||||
)
|
||||
if not any_downloaded:
|
||||
print("Everything has already been downloaded, skipping.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
137
ifield/data/coseg/preprocess.py
Normal file
137
ifield/data/coseg/preprocess.py
Normal file
@ -0,0 +1,137 @@
|
||||
#!/usr/bin/env python3
|
||||
import os; os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
|
||||
from . import config, read
|
||||
from ...utils.helpers import make_relative
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
import argparse
|
||||
|
||||
|
||||
def make_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(description=dedent("""
|
||||
Preprocess the COSEG dataset. Depends on `download-coseg --shapes ...` having been run.
|
||||
"""), formatter_class=argparse.RawTextHelpFormatter)
|
||||
|
||||
arg = parser.add_argument # brevity
|
||||
|
||||
arg("items", nargs="*", default=[],
|
||||
help="Which object-set[/model-id] to process, defaults to all downloaded. Format: OBJECT-SET[/MODEL-ID]")
|
||||
arg("--dir", default=str(config.DATA_PATH),
|
||||
help=f"The target directory. Default is {make_relative(config.DATA_PATH, Path.cwd()).__str__()!r}")
|
||||
arg("--force", action="store_true",
|
||||
help="Overwrite existing files")
|
||||
arg("--list-models", action="store_true",
|
||||
help="List the downloaded models available for preprocessing")
|
||||
arg("--list-object-sets", action="store_true",
|
||||
help="List the downloaded object-sets available for preprocessing")
|
||||
arg("--list-pages", type=int, default=None,
|
||||
help="List the downloaded models available for preprocessing, paginated into N pages.")
|
||||
arg("--page", nargs=2, type=int, default=[0, 1],
|
||||
help="Subset of parts to compute. Use to parallelize. (page, total), page is 0 indexed")
|
||||
|
||||
arg2 = parser.add_argument_group("preprocessing targets").add_argument # brevity
|
||||
arg2("--precompute-mesh-sv-scan-clouds", action="store_true",
|
||||
help="Compute single-view hit+miss point clouds from 100 synthetic scans.")
|
||||
arg2("--precompute-mesh-sv-scan-uvs", action="store_true",
|
||||
help="Compute single-view hit+miss UV clouds from 100 synthetic scans.")
|
||||
arg2("--precompute-mesh-sphere-scan", action="store_true",
|
||||
help="Compute a sphere-view hit+miss cloud cast from n to n unit sphere points.")
|
||||
|
||||
arg3 = parser.add_argument_group("modifiers").add_argument # brevity
|
||||
arg3("--n-sphere-points", type=int, default=4000,
|
||||
help="The number of unit-sphere points to sample rays from. Final result: n*(n-1).")
|
||||
arg3("--compute-miss-distances", action="store_true",
|
||||
help="Compute the distance to the nearest hit for each miss in the hit+miss clouds.")
|
||||
arg3("--fill-missing-uv-points", action="store_true",
|
||||
help="TODO")
|
||||
arg3("--no-filter-backhits", action="store_true",
|
||||
help="Do not filter scan hits on backside of mesh faces.")
|
||||
arg3("--no-unit-sphere", action="store_true",
|
||||
help="Do not center the objects to the unit sphere.")
|
||||
arg3("--convert-ok", action="store_true",
|
||||
help="Allow reusing point clouds for uv clouds and vice versa. (does not account for other hparams)")
|
||||
arg3("--debug", action="store_true",
|
||||
help="Abort on failiure.")
|
||||
|
||||
return parser
|
||||
|
||||
# entrypoint
|
||||
def cli(parser=make_parser()):
|
||||
args = parser.parse_args()
|
||||
if not any(getattr(args, k) for k in dir(args) if k.startswith("precompute_")) and not (args.list_models or args.list_object_sets or args.list_pages):
|
||||
parser.error("no preprocessing target selected") # exits
|
||||
|
||||
config.DATA_PATH = Path(args.dir)
|
||||
|
||||
object_sets = [i for i in args.items if "/" not in i]
|
||||
models = [i.split("/") for i in args.items if "/" in i]
|
||||
|
||||
# convert/expand synsets to models
|
||||
# they are mutually exclusive
|
||||
if object_sets: assert not models
|
||||
if models: assert not object_sets
|
||||
if not models:
|
||||
models = read.list_model_ids(tuple(object_sets) or None)
|
||||
|
||||
if args.list_models:
|
||||
try:
|
||||
print(*(f"{object_set_id}/{model_id}" for object_set_id, model_id in models), sep="\n")
|
||||
except BrokenPipeError:
|
||||
pass
|
||||
parser.exit()
|
||||
|
||||
if args.list_object_sets:
|
||||
try:
|
||||
print(*sorted(set(object_set_id for object_set_id, model_id in models)), sep="\n")
|
||||
except BrokenPipeError:
|
||||
pass
|
||||
parser.exit()
|
||||
|
||||
if args.list_pages is not None:
|
||||
try:
|
||||
print(*(
|
||||
f"--page {i} {args.list_pages} {object_set_id}/{model_id}"
|
||||
for object_set_id, model_id in models
|
||||
for i in range(args.list_pages)
|
||||
), sep="\n")
|
||||
except BrokenPipeError:
|
||||
pass
|
||||
parser.exit()
|
||||
|
||||
if args.precompute_mesh_sv_scan_clouds:
|
||||
read.precompute_mesh_scan_point_clouds(
|
||||
models,
|
||||
compute_miss_distances = args.compute_miss_distances,
|
||||
no_filter_backhits = args.no_filter_backhits,
|
||||
no_unit_sphere = args.no_unit_sphere,
|
||||
convert_ok = args.convert_ok,
|
||||
page = args.page,
|
||||
force = args.force,
|
||||
debug = args.debug,
|
||||
)
|
||||
if args.precompute_mesh_sv_scan_uvs:
|
||||
read.precompute_mesh_scan_uvs(
|
||||
models,
|
||||
compute_miss_distances = args.compute_miss_distances,
|
||||
fill_missing_points = args.fill_missing_uv_points,
|
||||
no_filter_backhits = args.no_filter_backhits,
|
||||
no_unit_sphere = args.no_unit_sphere,
|
||||
convert_ok = args.convert_ok,
|
||||
page = args.page,
|
||||
force = args.force,
|
||||
debug = args.debug,
|
||||
)
|
||||
if args.precompute_mesh_sphere_scan:
|
||||
read.precompute_mesh_sphere_scan(
|
||||
models,
|
||||
sphere_points = args.n_sphere_points,
|
||||
compute_miss_distances = args.compute_miss_distances,
|
||||
no_filter_backhits = args.no_filter_backhits,
|
||||
no_unit_sphere = args.no_unit_sphere,
|
||||
page = args.page,
|
||||
force = args.force,
|
||||
debug = args.debug,
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
290
ifield/data/coseg/read.py
Normal file
290
ifield/data/coseg/read.py
Normal file
@ -0,0 +1,290 @@
|
||||
from . import config
|
||||
from ..common import points
|
||||
from ..common import processing
|
||||
from ..common.scan import SingleViewScan, SingleViewUVScan
|
||||
from ..common.types import MalformedMesh
|
||||
from functools import lru_cache
|
||||
from typing import Optional, Iterable
|
||||
import numpy as np
|
||||
import trimesh
|
||||
import trimesh.transformations as T
|
||||
|
||||
__doc__ = """
|
||||
Here are functions for reading and preprocessing coseg benchmark data
|
||||
|
||||
There are essentially a few sets per object:
|
||||
"img" - meaning the RGBD images (none found in coseg)
|
||||
"mesh_scans" - meaning synthetic scans of a mesh
|
||||
"""
|
||||
|
||||
MESH_TRANSFORM_SKYWARD = T.rotation_matrix(np.pi/2, (1, 0, 0)) # rotate to be upright in pyrender
|
||||
MESH_POSE_CORRECTIONS = { # to gain a shared canonical orientation
|
||||
("four-legged", 381): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 382): T.rotation_matrix( 1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 383): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 384): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 385): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 386): T.rotation_matrix( 1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 387): T.rotation_matrix(-0.2*np.pi/2, (0, 1, 0))@T.rotation_matrix(1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 388): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 389): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 390): T.rotation_matrix( 0*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 391): T.rotation_matrix( 0*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 392): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 393): T.rotation_matrix( 0*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 394): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 395): T.rotation_matrix(-0.2*np.pi/2, (0, 1, 0))@T.rotation_matrix(1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 396): T.rotation_matrix( 1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 397): T.rotation_matrix( 0*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 398): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 399): T.rotation_matrix( 0*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 400): T.rotation_matrix( 0*np.pi/2, (0, 0, 1)),
|
||||
}
|
||||
|
||||
|
||||
ModelUid = tuple[str, int]
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def list_object_sets() -> list[str]:
|
||||
return sorted(
|
||||
object_set.name
|
||||
for object_set in config.DATA_PATH.iterdir()
|
||||
if (object_set / "shapes").is_dir() and object_set.name != "archive"
|
||||
)
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def list_model_ids(object_sets: Optional[tuple[str]] = None) -> list[ModelUid]:
|
||||
return sorted(
|
||||
(object_set.name, int(model.stem))
|
||||
for object_set in config.DATA_PATH.iterdir()
|
||||
if (object_set / "shapes").is_dir() and object_set.name != "archive" and (object_sets is None or object_set.name in object_sets)
|
||||
for model in (object_set / "shapes").iterdir()
|
||||
if model.is_file() and model.suffix == ".off"
|
||||
)
|
||||
|
||||
def list_model_id_strings(object_sets: Optional[tuple[str]] = None) -> list[str]:
|
||||
return [model_uid_to_string(object_set_id, model_id) for object_set_id, model_id in list_model_ids(object_sets)]
|
||||
|
||||
def model_uid_to_string(object_set_id: str, model_id: int) -> str:
|
||||
return f"{object_set_id}-{model_id}"
|
||||
|
||||
def model_id_string_to_uid(model_string_uid: str) -> ModelUid:
|
||||
object_set, split, model = model_string_uid.rpartition("-")
|
||||
assert split == "-"
|
||||
return (object_set, int(model))
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def list_mesh_scan_sphere_coords(n_poses: int = 50) -> list[tuple[float, float]]: # (theta, phi)
|
||||
return points.generate_equidistant_sphere_points(n_poses, compute_sphere_coordinates=True)
|
||||
|
||||
def mesh_scan_identifier(*, phi: float, theta: float) -> str:
|
||||
return (
|
||||
f"{'np'[theta>=0]}{abs(theta):.2f}"
|
||||
f"{'np'[phi >=0]}{abs(phi) :.2f}"
|
||||
).replace(".", "d")
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def list_mesh_scan_identifiers(n_poses: int = 50) -> list[str]:
|
||||
out = [
|
||||
mesh_scan_identifier(phi=phi, theta=theta)
|
||||
for theta, phi in list_mesh_scan_sphere_coords(n_poses)
|
||||
]
|
||||
assert len(out) == len(set(out))
|
||||
return out
|
||||
|
||||
# ===
|
||||
|
||||
def read_mesh(object_set_id: str, model_id: int) -> trimesh.Trimesh:
|
||||
path = config.DATA_PATH / object_set_id / "shapes" / f"{model_id}.off"
|
||||
if not path.is_file():
|
||||
raise FileNotFoundError(f"{path = }")
|
||||
try:
|
||||
mesh = trimesh.load(path, force="mesh")
|
||||
except Exception as e:
|
||||
raise MalformedMesh(f"Trimesh raised: {e.__class__.__name__}: {e}") from e
|
||||
|
||||
pose = MESH_POSE_CORRECTIONS.get((object_set_id, int(model_id)))
|
||||
mesh.apply_transform(pose @ MESH_TRANSFORM_SKYWARD if pose is not None else MESH_TRANSFORM_SKYWARD)
|
||||
return mesh
|
||||
|
||||
# === single-view scan clouds
|
||||
|
||||
def compute_mesh_scan_point_cloud(
|
||||
object_set_id : str,
|
||||
model_id : int,
|
||||
phi : float,
|
||||
theta : float,
|
||||
*,
|
||||
compute_miss_distances : bool = False,
|
||||
fill_missing_points : bool = False,
|
||||
compute_normals : bool = True,
|
||||
convert_ok : bool = False,
|
||||
**kw,
|
||||
) -> SingleViewScan:
|
||||
|
||||
if convert_ok:
|
||||
try:
|
||||
return read_mesh_scan_uv(object_set_id, model_id, phi=phi, theta=theta).to_scan()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
mesh = read_mesh(object_set_id, model_id)
|
||||
scan = SingleViewScan.from_mesh_single_view(mesh,
|
||||
phi = phi,
|
||||
theta = theta,
|
||||
compute_normals = compute_normals,
|
||||
**kw,
|
||||
)
|
||||
if compute_miss_distances:
|
||||
scan.compute_miss_distances()
|
||||
if fill_missing_points:
|
||||
scan.fill_missing_points()
|
||||
|
||||
return scan
|
||||
|
||||
def precompute_mesh_scan_point_clouds(models: Iterable[ModelUid], *, n_poses: int = 50, page: tuple[int, int] = (0, 1), force = False, debug = False, **kw):
|
||||
"precomputes all single-view scan clouds and stores them as HDF5 datasets"
|
||||
cam_poses = list_mesh_scan_sphere_coords(n_poses=n_poses)
|
||||
pose_identifiers = list_mesh_scan_identifiers (n_poses=n_poses)
|
||||
assert len(cam_poses) == len(pose_identifiers)
|
||||
paths = list_mesh_scan_point_cloud_h5_fnames(models, pose_identifiers, n_poses=n_poses)
|
||||
mlen_syn = max(len(object_set_id) for object_set_id, model_id in models)
|
||||
mlen_mod = max(len(str(model_id)) for object_set_id, model_id in models)
|
||||
pretty_identifiers = [
|
||||
f"{object_set_id.ljust(mlen_syn)} @ {str(model_id).ljust(mlen_mod)} @ {i:>5} @ ({itentifier}: {theta:.2f}, {phi:.2f})"
|
||||
for object_set_id, model_id in models
|
||||
for i, (itentifier, (theta, phi)) in enumerate(zip(pose_identifiers, cam_poses))
|
||||
]
|
||||
mesh_cache = []
|
||||
def computer(pretty_identifier: str) -> SingleViewScan:
|
||||
object_set_id, model_id, index, _ = map(str.strip, pretty_identifier.split("@"))
|
||||
theta, phi = cam_poses[int(index)]
|
||||
return compute_mesh_scan_point_cloud(object_set_id, int(model_id), phi=phi, theta=theta, _mesh_cache=mesh_cache, **kw)
|
||||
return processing.precompute_data(computer, pretty_identifiers, paths, page=page, force=force, debug=debug)
|
||||
|
||||
def read_mesh_scan_point_cloud(object_set_id: str, model_id: int, *, identifier: str = None, phi: float = None, theta: float = None) -> SingleViewScan:
|
||||
if identifier is None:
|
||||
if phi is None or theta is None:
|
||||
raise ValueError("Provide either phi+theta or an identifier!")
|
||||
identifier = mesh_scan_identifier(phi=phi, theta=theta)
|
||||
file = config.DATA_PATH / object_set_id / "uv_scan_clouds" / f"{model_id}_normalized_{identifier}.h5"
|
||||
return SingleViewScan.from_h5_file(file)
|
||||
|
||||
def list_mesh_scan_point_cloud_h5_fnames(models: Iterable[ModelUid], identifiers: Optional[Iterable[str]] = None, **kw):
|
||||
if identifiers is None:
|
||||
identifiers = list_mesh_scan_identifiers(**kw)
|
||||
return [
|
||||
config.DATA_PATH / object_set_id / "uv_scan_clouds" / f"{model_id}_normalized_{identifier}.h5"
|
||||
for object_set_id, model_id in models
|
||||
for identifier in identifiers
|
||||
]
|
||||
|
||||
|
||||
# === single-view UV scan clouds
|
||||
|
||||
def compute_mesh_scan_uv(
|
||||
object_set_id : str,
|
||||
model_id : int,
|
||||
phi : float,
|
||||
theta : float,
|
||||
*,
|
||||
compute_miss_distances : bool = False,
|
||||
fill_missing_points : bool = False,
|
||||
compute_normals : bool = True,
|
||||
convert_ok : bool = False,
|
||||
**kw,
|
||||
) -> SingleViewUVScan:
|
||||
|
||||
if convert_ok:
|
||||
try:
|
||||
return read_mesh_scan_point_cloud(object_set_id, model_id, phi=phi, theta=theta).to_uv_scan()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
mesh = read_mesh(object_set_id, model_id)
|
||||
scan = SingleViewUVScan.from_mesh_single_view(mesh,
|
||||
phi = phi,
|
||||
theta = theta,
|
||||
compute_normals = compute_normals,
|
||||
**kw,
|
||||
)
|
||||
if compute_miss_distances:
|
||||
scan.compute_miss_distances()
|
||||
if fill_missing_points:
|
||||
scan.fill_missing_points()
|
||||
|
||||
return scan
|
||||
|
||||
def precompute_mesh_scan_uvs(models: Iterable[ModelUid], *, n_poses: int = 50, page: tuple[int, int] = (0, 1), force = False, debug = False, **kw):
|
||||
"precomputes all single-view scan clouds and stores them as HDF5 datasets"
|
||||
cam_poses = list_mesh_scan_sphere_coords(n_poses=n_poses)
|
||||
pose_identifiers = list_mesh_scan_identifiers (n_poses=n_poses)
|
||||
assert len(cam_poses) == len(pose_identifiers)
|
||||
paths = list_mesh_scan_uv_h5_fnames(models, pose_identifiers, n_poses=n_poses)
|
||||
mlen_syn = max(len(object_set_id) for object_set_id, model_id in models)
|
||||
mlen_mod = max(len(str(model_id)) for object_set_id, model_id in models)
|
||||
pretty_identifiers = [
|
||||
f"{object_set_id.ljust(mlen_syn)} @ {str(model_id).ljust(mlen_mod)} @ {i:>5} @ ({itentifier}: {theta:.2f}, {phi:.2f})"
|
||||
for object_set_id, model_id in models
|
||||
for i, (itentifier, (theta, phi)) in enumerate(zip(pose_identifiers, cam_poses))
|
||||
]
|
||||
mesh_cache = []
|
||||
def computer(pretty_identifier: str) -> SingleViewUVScan:
|
||||
object_set_id, model_id, index, _ = map(str.strip, pretty_identifier.split("@"))
|
||||
theta, phi = cam_poses[int(index)]
|
||||
return compute_mesh_scan_uv(object_set_id, int(model_id), phi=phi, theta=theta, _mesh_cache=mesh_cache, **kw)
|
||||
return processing.precompute_data(computer, pretty_identifiers, paths, page=page, force=force, debug=debug)
|
||||
|
||||
def read_mesh_scan_uv(object_set_id: str, model_id: int, *, identifier: str = None, phi: float = None, theta: float = None) -> SingleViewUVScan:
|
||||
if identifier is None:
|
||||
if phi is None or theta is None:
|
||||
raise ValueError("Provide either phi+theta or an identifier!")
|
||||
identifier = mesh_scan_identifier(phi=phi, theta=theta)
|
||||
file = config.DATA_PATH / object_set_id / "uv_scan_clouds" / f"{model_id}_normalized_{identifier}.h5"
|
||||
|
||||
return SingleViewUVScan.from_h5_file(file)
|
||||
|
||||
def list_mesh_scan_uv_h5_fnames(models: Iterable[ModelUid], identifiers: Optional[Iterable[str]] = None, **kw):
|
||||
if identifiers is None:
|
||||
identifiers = list_mesh_scan_identifiers(**kw)
|
||||
return [
|
||||
config.DATA_PATH / object_set_id / "uv_scan_clouds" / f"{model_id}_normalized_{identifier}.h5"
|
||||
for object_set_id, model_id in models
|
||||
for identifier in identifiers
|
||||
]
|
||||
|
||||
|
||||
# === sphere-view (UV) scan clouds
|
||||
|
||||
def compute_mesh_sphere_scan(
|
||||
object_set_id : str,
|
||||
model_id : int,
|
||||
*,
|
||||
compute_normals : bool = True,
|
||||
**kw,
|
||||
) -> SingleViewUVScan:
|
||||
mesh = read_mesh(object_set_id, model_id)
|
||||
scan = SingleViewUVScan.from_mesh_sphere_view(mesh,
|
||||
compute_normals = compute_normals,
|
||||
**kw,
|
||||
)
|
||||
return scan
|
||||
|
||||
def precompute_mesh_sphere_scan(models: Iterable[ModelUid], *, page: tuple[int, int] = (0, 1), force: bool = False, debug: bool = False, n_points: int = 4000, **kw):
|
||||
"precomputes all sphere scan clouds and stores them as HDF5 datasets"
|
||||
paths = list_mesh_sphere_scan_h5_fnames(models)
|
||||
identifiers = [model_uid_to_string(*i) for i in models]
|
||||
def computer(identifier: str) -> SingleViewScan:
|
||||
object_set_id, model_id = model_id_string_to_uid(identifier)
|
||||
return compute_mesh_sphere_scan(object_set_id, model_id, **kw)
|
||||
return processing.precompute_data(computer, identifiers, paths, page=page, force=force, debug=debug)
|
||||
|
||||
def read_mesh_mesh_sphere_scan(object_set_id: str, model_id: int) -> SingleViewUVScan:
|
||||
file = config.DATA_PATH / object_set_id / "sphere_scan_clouds" / f"{model_id}_normalized.h5"
|
||||
return SingleViewUVScan.from_h5_file(file)
|
||||
|
||||
def list_mesh_sphere_scan_h5_fnames(models: Iterable[ModelUid]) -> list[str]:
|
||||
return [
|
||||
config.DATA_PATH / object_set_id / "sphere_scan_clouds" / f"{model_id}_normalized.h5"
|
||||
for object_set_id, model_id in models
|
||||
]
|
76
ifield/data/stanford/__init__.py
Normal file
76
ifield/data/stanford/__init__.py
Normal file
@ -0,0 +1,76 @@
|
||||
from ..config import data_path_get, data_path_persist
|
||||
from collections import namedtuple
|
||||
import os
|
||||
|
||||
|
||||
# Data source:
|
||||
# http://graphics.stanford.edu/data/3Dscanrep/
|
||||
|
||||
__ALL__ = ["config", "Model", "MODELS"]
|
||||
|
||||
@(lambda x: x()) # singleton
|
||||
class config:
|
||||
DATA_PATH = property(
|
||||
doc = """
|
||||
Path to the dataset. The following envvars override it:
|
||||
${IFIELD_DATA_MODELS}/stanford
|
||||
${IFIELD_DATA_MODELS_STANFORD}
|
||||
""",
|
||||
fget = lambda self: data_path_get ("stanford"),
|
||||
fset = lambda self, path: data_path_persist("stanford", path),
|
||||
)
|
||||
|
||||
@property
|
||||
def IS_DOWNLOADED_DB(self) -> list[os.PathLike]:
|
||||
return [
|
||||
self.DATA_PATH / "downloaded.json",
|
||||
]
|
||||
|
||||
Model = namedtuple("Model", "url mesh_fname download_size_str")
|
||||
MODELS: dict[str, Model] = {
|
||||
"bunny": Model(
|
||||
"http://graphics.stanford.edu/pub/3Dscanrep/bunny.tar.gz",
|
||||
"bunny/reconstruction/bun_zipper.ply",
|
||||
"4.89M",
|
||||
),
|
||||
"drill_bit": Model(
|
||||
"http://graphics.stanford.edu/pub/3Dscanrep/drill.tar.gz",
|
||||
"drill/reconstruction/drill_shaft_vrip.ply",
|
||||
"555k",
|
||||
),
|
||||
"happy_buddha": Model(
|
||||
# religious symbol
|
||||
"http://graphics.stanford.edu/pub/3Dscanrep/happy/happy_recon.tar.gz",
|
||||
"happy_recon/happy_vrip.ply",
|
||||
"14.5M",
|
||||
),
|
||||
"dragon": Model(
|
||||
# symbol of Chinese culture
|
||||
"http://graphics.stanford.edu/pub/3Dscanrep/dragon/dragon_recon.tar.gz",
|
||||
"dragon_recon/dragon_vrip.ply",
|
||||
"11.2M",
|
||||
),
|
||||
"armadillo": Model(
|
||||
"http://graphics.stanford.edu/pub/3Dscanrep/armadillo/Armadillo.ply.gz",
|
||||
"armadillo.ply.gz",
|
||||
"3.87M",
|
||||
),
|
||||
"lucy": Model(
|
||||
# Christian angel
|
||||
"http://graphics.stanford.edu/data/3Dscanrep/lucy.tar.gz",
|
||||
"lucy.ply",
|
||||
"322M",
|
||||
),
|
||||
"asian_dragon": Model(
|
||||
# symbol of Chinese culture
|
||||
"http://graphics.stanford.edu/data/3Dscanrep/xyzrgb/xyzrgb_dragon.ply.gz",
|
||||
"xyzrgb_dragon.ply.gz",
|
||||
"70.5M",
|
||||
),
|
||||
"thai_statue": Model(
|
||||
# Hindu religious significance
|
||||
"http://graphics.stanford.edu/data/3Dscanrep/xyzrgb/xyzrgb_statuette.ply.gz",
|
||||
"xyzrgb_statuette.ply.gz",
|
||||
"106M",
|
||||
),
|
||||
}
|
129
ifield/data/stanford/download.py
Normal file
129
ifield/data/stanford/download.py
Normal file
@ -0,0 +1,129 @@
|
||||
#!/usr/bin/env python3
|
||||
from . import config
|
||||
from ...utils.helpers import make_relative
|
||||
from ..common import download
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
from typing import Iterable
|
||||
import argparse
|
||||
import io
|
||||
import tarfile
|
||||
|
||||
|
||||
def is_downloaded(*a, **kw):
|
||||
return download.is_downloaded(*a, dbfiles=config.IS_DOWNLOADED_DB, **kw)
|
||||
|
||||
def download_and_extract(target_dir: Path, url_list: Iterable[str], *, force=False, silent=False) -> bool:
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
ret = False
|
||||
for url in url_list:
|
||||
if not force:
|
||||
if is_downloaded(target_dir, url): continue
|
||||
if not download.check_url(url):
|
||||
print("ERROR:", url)
|
||||
continue
|
||||
ret = True
|
||||
|
||||
data = download.download_data(url, silent=silent, label=str(Path(url).name))
|
||||
|
||||
print("extracting...")
|
||||
if url.endswith(".ply.gz"):
|
||||
fname = target_dir / "meshes" / url.split("/")[-1].lower()
|
||||
fname.parent.mkdir(parents=True, exist_ok=True)
|
||||
with fname.open("wb") as f:
|
||||
f.write(data)
|
||||
elif url.endswith(".tar.gz"):
|
||||
with tarfile.open(fileobj=io.BytesIO(data)) as tar:
|
||||
for member in tar.getmembers():
|
||||
if not member.isfile(): continue
|
||||
if member.name.startswith("/"): continue
|
||||
if member.name.startswith("."): continue
|
||||
if Path(member.name).name.startswith("."): continue
|
||||
tar.extract(member, target_dir / "meshes")
|
||||
del tar
|
||||
else:
|
||||
raise NotImplementedError(f"Extraction for {str(Path(url).name)} unknown")
|
||||
|
||||
is_downloaded(target_dir, url, add=True)
|
||||
del data
|
||||
|
||||
return ret
|
||||
|
||||
def make_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(description=dedent("""
|
||||
Download The Stanford 3D Scanning Repository models.
|
||||
More info: http://graphics.stanford.edu/data/3Dscanrep/
|
||||
|
||||
Example:
|
||||
|
||||
download-stanford bunny
|
||||
"""), formatter_class=argparse.RawTextHelpFormatter)
|
||||
|
||||
arg = parser.add_argument
|
||||
|
||||
arg("objects", nargs="*", default=[],
|
||||
help="Which objects to download, defaults to none.")
|
||||
arg("--all", action="store_true",
|
||||
help="Download all objects")
|
||||
arg("--dir", default=str(config.DATA_PATH),
|
||||
help=f"The target directory. Default is {make_relative(config.DATA_PATH, Path.cwd()).__str__()!r}")
|
||||
|
||||
arg("--list", action="store_true",
|
||||
help="Lists all the objects")
|
||||
arg("--list-urls", action="store_true",
|
||||
help="Lists the urls to download")
|
||||
arg("--list-sizes", action="store_true",
|
||||
help="Lists the download size of each model")
|
||||
arg("--silent", action="store_true",
|
||||
help="")
|
||||
arg("--force", action="store_true",
|
||||
help="Download again even if already downloaded")
|
||||
|
||||
return parser
|
||||
|
||||
# entrypoint
|
||||
def cli(parser=make_parser()):
|
||||
args = parser.parse_args()
|
||||
|
||||
obj_names = sorted(set(args.objects))
|
||||
if args.all:
|
||||
assert not obj_names
|
||||
obj_names = sorted(config.MODELS.keys())
|
||||
if not obj_names and args.list_urls: config.MODELS.keys()
|
||||
|
||||
if args.list:
|
||||
print(*config.MODELS.keys(), sep="\n")
|
||||
exit()
|
||||
|
||||
if args.list_sizes:
|
||||
print(*(f"{obj_name:<15}{config.MODELS[obj_name].download_size_str}" for obj_name in (obj_names or config.MODELS.keys())), sep="\n")
|
||||
exit()
|
||||
|
||||
try:
|
||||
url_list = [config.MODELS[obj_name].url for obj_name in obj_names]
|
||||
except KeyError:
|
||||
print("Error: unrecognized object name:", *set(obj_names).difference(config.MODELS.keys()), sep="\n")
|
||||
exit(1)
|
||||
|
||||
if not url_list:
|
||||
print("Error: No object set was selected for download!")
|
||||
exit(1)
|
||||
|
||||
if args.list_urls:
|
||||
print(*url_list, sep="\n")
|
||||
exit()
|
||||
|
||||
|
||||
print("Download start")
|
||||
any_downloaded = download_and_extract(
|
||||
target_dir = Path(args.dir),
|
||||
url_list = url_list,
|
||||
force = args.force,
|
||||
silent = args.silent,
|
||||
)
|
||||
if not any_downloaded:
|
||||
print("Everything has already been downloaded, skipping.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
118
ifield/data/stanford/preprocess.py
Normal file
118
ifield/data/stanford/preprocess.py
Normal file
@ -0,0 +1,118 @@
|
||||
#!/usr/bin/env python3
|
||||
import os; os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
|
||||
from . import config, read
|
||||
from ...utils.helpers import make_relative
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
import argparse
|
||||
|
||||
|
||||
|
||||
def make_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(description=dedent("""
|
||||
Preprocess the Stanford models. Depends on `download-stanford` having been run.
|
||||
"""), formatter_class=argparse.RawTextHelpFormatter)
|
||||
|
||||
arg = parser.add_argument # brevity
|
||||
|
||||
arg("objects", nargs="*", default=[],
|
||||
help="Which objects to process, defaults to all downloaded")
|
||||
arg("--dir", default=str(config.DATA_PATH),
|
||||
help=f"The target directory. Default is {make_relative(config.DATA_PATH, Path.cwd()).__str__()!r}")
|
||||
arg("--force", action="store_true",
|
||||
help="Overwrite existing files")
|
||||
arg("--list", action="store_true",
|
||||
help="List the downloaded models available for preprocessing")
|
||||
arg("--list-pages", type=int, default=None,
|
||||
help="List the downloaded models available for preprocessing, paginated into N pages.")
|
||||
arg("--page", nargs=2, type=int, default=[0, 1],
|
||||
help="Subset of parts to compute. Use to parallelize. (page, total), page is 0 indexed")
|
||||
|
||||
arg2 = parser.add_argument_group("preprocessing targets").add_argument # brevity
|
||||
arg2("--precompute-mesh-sv-scan-clouds", action="store_true",
|
||||
help="Compute single-view hit+miss point clouds from 100 synthetic scans.")
|
||||
arg2("--precompute-mesh-sv-scan-uvs", action="store_true",
|
||||
help="Compute single-view hit+miss UV clouds from 100 synthetic scans.")
|
||||
arg2("--precompute-mesh-sphere-scan", action="store_true",
|
||||
help="Compute a sphere-view hit+miss cloud cast from n to n unit sphere points.")
|
||||
|
||||
arg3 = parser.add_argument_group("ray-scan modifiers").add_argument # brevity
|
||||
arg3("--n-sphere-points", type=int, default=4000,
|
||||
help="The number of unit-sphere points to sample rays from. Final result: n*(n-1).")
|
||||
arg3("--compute-miss-distances", action="store_true",
|
||||
help="Compute the distance to the nearest hit for each miss in the hit+miss clouds.")
|
||||
arg3("--fill-missing-uv-points", action="store_true",
|
||||
help="TODO")
|
||||
arg3("--no-filter-backhits", action="store_true",
|
||||
help="Do not filter scan hits on backside of mesh faces.")
|
||||
arg3("--no-unit-sphere", action="store_true",
|
||||
help="Do not center the objects to the unit sphere.")
|
||||
arg3("--convert-ok", action="store_true",
|
||||
help="Allow reusing point clouds for uv clouds and vice versa. (does not account for other hparams)")
|
||||
arg3("--debug", action="store_true",
|
||||
help="Abort on failiure.")
|
||||
|
||||
arg5 = parser.add_argument_group("Shared modifiers").add_argument # brevity
|
||||
arg5("--scan-resolution", type=int, default=400,
|
||||
help="The resolution of the depth map rendered to sample points. Becomes x*x")
|
||||
|
||||
return parser
|
||||
|
||||
# entrypoint
|
||||
def cli(parser: argparse.ArgumentParser = make_parser()):
|
||||
args = parser.parse_args()
|
||||
if not any(getattr(args, k) for k in dir(args) if k.startswith("precompute_")) and not (args.list or args.list_pages):
|
||||
parser.error("no preprocessing target selected") # exits
|
||||
|
||||
config.DATA_PATH = Path(args.dir)
|
||||
obj_names = args.objects or read.list_object_names()
|
||||
|
||||
if args.list:
|
||||
print(*obj_names, sep="\n")
|
||||
parser.exit()
|
||||
|
||||
if args.list_pages is not None:
|
||||
print(*(
|
||||
f"--page {i} {args.list_pages} {obj_name}"
|
||||
for obj_name in obj_names
|
||||
for i in range(args.list_pages)
|
||||
), sep="\n")
|
||||
parser.exit()
|
||||
|
||||
if args.precompute_mesh_sv_scan_clouds:
|
||||
read.precompute_mesh_scan_point_clouds(
|
||||
obj_names,
|
||||
compute_miss_distances = args.compute_miss_distances,
|
||||
no_filter_backhits = args.no_filter_backhits,
|
||||
no_unit_sphere = args.no_unit_sphere,
|
||||
convert_ok = args.convert_ok,
|
||||
page = args.page,
|
||||
force = args.force,
|
||||
debug = args.debug,
|
||||
)
|
||||
if args.precompute_mesh_sv_scan_uvs:
|
||||
read.precompute_mesh_scan_uvs(
|
||||
obj_names,
|
||||
compute_miss_distances = args.compute_miss_distances,
|
||||
fill_missing_points = args.fill_missing_uv_points,
|
||||
no_filter_backhits = args.no_filter_backhits,
|
||||
no_unit_sphere = args.no_unit_sphere,
|
||||
convert_ok = args.convert_ok,
|
||||
page = args.page,
|
||||
force = args.force,
|
||||
debug = args.debug,
|
||||
)
|
||||
if args.precompute_mesh_sphere_scan:
|
||||
read.precompute_mesh_sphere_scan(
|
||||
obj_names,
|
||||
sphere_points = args.n_sphere_points,
|
||||
compute_miss_distances = args.compute_miss_distances,
|
||||
no_filter_backhits = args.no_filter_backhits,
|
||||
no_unit_sphere = args.no_unit_sphere,
|
||||
page = args.page,
|
||||
force = args.force,
|
||||
debug = args.debug,
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
251
ifield/data/stanford/read.py
Normal file
251
ifield/data/stanford/read.py
Normal file
@ -0,0 +1,251 @@
|
||||
from . import config
|
||||
from ..common import points
|
||||
from ..common import processing
|
||||
from ..common.scan import SingleViewScan, SingleViewUVScan
|
||||
from ..common.types import MalformedMesh
|
||||
from functools import lru_cache, wraps
|
||||
from typing import Optional, Iterable
|
||||
from pathlib import Path
|
||||
import gzip
|
||||
import numpy as np
|
||||
import trimesh
|
||||
import trimesh.transformations as T
|
||||
|
||||
__doc__ = """
|
||||
Here are functions for reading and preprocessing shapenet benchmark data
|
||||
|
||||
There are essentially a few sets per object:
|
||||
"img" - meaning the RGBD images (none found in stanford)
|
||||
"mesh_scans" - meaning synthetic scans of a mesh
|
||||
"""
|
||||
|
||||
MESH_TRANSFORM_SKYWARD = T.rotation_matrix(np.pi/2, (1, 0, 0))
|
||||
MESH_TRANSFORM_CANONICAL = { # to gain a shared canonical orientation
|
||||
"armadillo" : T.rotation_matrix(np.pi, (0, 0, 1)) @ MESH_TRANSFORM_SKYWARD,
|
||||
"asian_dragon" : T.rotation_matrix(-np.pi/2, (0, 0, 1)) @ MESH_TRANSFORM_SKYWARD,
|
||||
"bunny" : MESH_TRANSFORM_SKYWARD,
|
||||
"dragon" : MESH_TRANSFORM_SKYWARD,
|
||||
"drill_bit" : MESH_TRANSFORM_SKYWARD,
|
||||
"happy_buddha" : MESH_TRANSFORM_SKYWARD,
|
||||
"lucy" : T.rotation_matrix(np.pi, (0, 0, 1)),
|
||||
"thai_statue" : MESH_TRANSFORM_SKYWARD,
|
||||
}
|
||||
|
||||
def list_object_names() -> list[str]:
|
||||
# downloaded only:
|
||||
return [
|
||||
i for i, v in config.MODELS.items()
|
||||
if (config.DATA_PATH / "meshes" / v.mesh_fname).is_file()
|
||||
]
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def list_mesh_scan_sphere_coords(n_poses: int = 50) -> list[tuple[float, float]]: # (theta, phi)
|
||||
return points.generate_equidistant_sphere_points(n_poses, compute_sphere_coordinates=True)#, shift_theta=True
|
||||
|
||||
def mesh_scan_identifier(*, phi: float, theta: float) -> str:
|
||||
return (
|
||||
f"{'np'[theta>=0]}{abs(theta):.2f}"
|
||||
f"{'np'[phi >=0]}{abs(phi) :.2f}"
|
||||
).replace(".", "d")
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def list_mesh_scan_identifiers(n_poses: int = 50) -> list[str]:
|
||||
out = [
|
||||
mesh_scan_identifier(phi=phi, theta=theta)
|
||||
for theta, phi in list_mesh_scan_sphere_coords(n_poses)
|
||||
]
|
||||
assert len(out) == len(set(out))
|
||||
return out
|
||||
|
||||
# ===
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def read_mesh(obj_name: str) -> trimesh.Trimesh:
|
||||
path = config.DATA_PATH / "meshes" / config.MODELS[obj_name].mesh_fname
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"{obj_name = } -> {str(path) = }")
|
||||
try:
|
||||
if path.suffixes[-1] == ".gz":
|
||||
with gzip.open(path, "r") as f:
|
||||
mesh = trimesh.load(f, file_type="".join(path.suffixes[:-1])[1:])
|
||||
else:
|
||||
mesh = trimesh.load(path)
|
||||
except Exception as e:
|
||||
raise MalformedMesh(f"Trimesh raised: {e.__class__.__name__}: {e}") from e
|
||||
|
||||
# rotate to be upright in pyrender
|
||||
mesh.apply_transform(MESH_TRANSFORM_CANONICAL.get(obj_name, MESH_TRANSFORM_SKYWARD))
|
||||
|
||||
return mesh
|
||||
|
||||
# === single-view scan clouds
|
||||
|
||||
def compute_mesh_scan_point_cloud(
|
||||
obj_name : str,
|
||||
*,
|
||||
phi : float,
|
||||
theta : float,
|
||||
compute_miss_distances : bool = False,
|
||||
compute_normals : bool = True,
|
||||
convert_ok : bool = False, # this does not respect the other hparams
|
||||
**kw,
|
||||
) -> SingleViewScan:
|
||||
|
||||
if convert_ok:
|
||||
try:
|
||||
return read_mesh_scan_uv(obj_name, phi=phi, theta=theta).to_scan()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
mesh = read_mesh(obj_name)
|
||||
return SingleViewScan.from_mesh_single_view(mesh,
|
||||
phi = phi,
|
||||
theta = theta,
|
||||
compute_normals = compute_normals,
|
||||
compute_miss_distances = compute_miss_distances,
|
||||
**kw,
|
||||
)
|
||||
|
||||
def precompute_mesh_scan_point_clouds(obj_names, *, page: tuple[int, int] = (0, 1), force: bool = False, debug: bool = False, n_poses: int = 50, **kw):
|
||||
"precomputes all single-view scan clouds and stores them as HDF5 datasets"
|
||||
cam_poses = list_mesh_scan_sphere_coords(n_poses)
|
||||
pose_identifiers = list_mesh_scan_identifiers (n_poses)
|
||||
assert len(cam_poses) == len(pose_identifiers)
|
||||
paths = list_mesh_scan_point_cloud_h5_fnames(obj_names, pose_identifiers)
|
||||
mlen = max(map(len, config.MODELS.keys()))
|
||||
pretty_identifiers = [
|
||||
f"{obj_name.ljust(mlen)} @ {i:>5} @ ({itentifier}: {theta:.2f}, {phi:.2f})"
|
||||
for obj_name in obj_names
|
||||
for i, (itentifier, (theta, phi)) in enumerate(zip(pose_identifiers, cam_poses))
|
||||
]
|
||||
mesh_cache = []
|
||||
@wraps(compute_mesh_scan_point_cloud)
|
||||
def computer(pretty_identifier: str) -> SingleViewScan:
|
||||
obj_name, index, _ = map(str.strip, pretty_identifier.split("@"))
|
||||
theta, phi = cam_poses[int(index)]
|
||||
return compute_mesh_scan_point_cloud(obj_name, phi=phi, theta=theta, _mesh_cache=mesh_cache, **kw)
|
||||
return processing.precompute_data(computer, pretty_identifiers, paths, page=page, force=force, debug=debug)
|
||||
|
||||
def read_mesh_scan_point_cloud(obj_name, *, identifier: str = None, phi: float = None, theta: float = None) -> SingleViewScan:
|
||||
if identifier is None:
|
||||
if phi is None or theta is None:
|
||||
raise ValueError("Provide either phi+theta or an identifier!")
|
||||
identifier = mesh_scan_identifier(phi=phi, theta=theta)
|
||||
file = config.DATA_PATH / "clouds" / obj_name / f"mesh_scan_{identifier}_clouds.h5"
|
||||
if not file.exists(): raise FileNotFoundError(str(file))
|
||||
return SingleViewScan.from_h5_file(file)
|
||||
|
||||
def list_mesh_scan_point_cloud_h5_fnames(obj_names: Iterable[str], identifiers: Optional[Iterable[str]] = None, **kw) -> list[Path]:
|
||||
if identifiers is None:
|
||||
identifiers = list_mesh_scan_identifiers(**kw)
|
||||
return [
|
||||
config.DATA_PATH / "clouds" / obj_name / f"mesh_scan_{identifier}_clouds.h5"
|
||||
for obj_name in obj_names
|
||||
for identifier in identifiers
|
||||
]
|
||||
|
||||
# === single-view UV scan clouds
|
||||
|
||||
def compute_mesh_scan_uv(
|
||||
obj_name : str,
|
||||
*,
|
||||
phi : float,
|
||||
theta : float,
|
||||
compute_miss_distances : bool = False,
|
||||
fill_missing_points : bool = False,
|
||||
compute_normals : bool = True,
|
||||
convert_ok : bool = False,
|
||||
**kw,
|
||||
) -> SingleViewUVScan:
|
||||
|
||||
if convert_ok:
|
||||
try:
|
||||
return read_mesh_scan_point_cloud(obj_name, phi=phi, theta=theta).to_uv_scan()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
mesh = read_mesh(obj_name)
|
||||
scan = SingleViewUVScan.from_mesh_single_view(mesh,
|
||||
phi = phi,
|
||||
theta = theta,
|
||||
compute_normals = compute_normals,
|
||||
**kw,
|
||||
)
|
||||
if compute_miss_distances:
|
||||
scan.compute_miss_distances()
|
||||
if fill_missing_points:
|
||||
scan.fill_missing_points()
|
||||
|
||||
return scan
|
||||
|
||||
def precompute_mesh_scan_uvs(obj_names, *, page: tuple[int, int] = (0, 1), force: bool = False, debug: bool = False, n_poses: int = 50, **kw):
|
||||
"precomputes all single-view scan clouds and stores them as HDF5 datasets"
|
||||
cam_poses = list_mesh_scan_sphere_coords(n_poses)
|
||||
pose_identifiers = list_mesh_scan_identifiers (n_poses)
|
||||
assert len(cam_poses) == len(pose_identifiers)
|
||||
paths = list_mesh_scan_uv_h5_fnames(obj_names, pose_identifiers)
|
||||
mlen = max(map(len, config.MODELS.keys()))
|
||||
pretty_identifiers = [
|
||||
f"{obj_name.ljust(mlen)} @ {i:>5} @ ({itentifier}: {theta:.2f}, {phi:.2f})"
|
||||
for obj_name in obj_names
|
||||
for i, (itentifier, (theta, phi)) in enumerate(zip(pose_identifiers, cam_poses))
|
||||
]
|
||||
mesh_cache = []
|
||||
@wraps(compute_mesh_scan_uv)
|
||||
def computer(pretty_identifier: str) -> SingleViewScan:
|
||||
obj_name, index, _ = map(str.strip, pretty_identifier.split("@"))
|
||||
theta, phi = cam_poses[int(index)]
|
||||
return compute_mesh_scan_uv(obj_name, phi=phi, theta=theta, _mesh_cache=mesh_cache, **kw)
|
||||
return processing.precompute_data(computer, pretty_identifiers, paths, page=page, force=force, debug=debug)
|
||||
|
||||
def read_mesh_scan_uv(obj_name, *, identifier: str = None, phi: float = None, theta: float = None) -> SingleViewUVScan:
|
||||
if identifier is None:
|
||||
if phi is None or theta is None:
|
||||
raise ValueError("Provide either phi+theta or an identifier!")
|
||||
identifier = mesh_scan_identifier(phi=phi, theta=theta)
|
||||
file = config.DATA_PATH / "clouds" / obj_name / f"mesh_scan_{identifier}_uv.h5"
|
||||
if not file.exists(): raise FileNotFoundError(str(file))
|
||||
return SingleViewUVScan.from_h5_file(file)
|
||||
|
||||
def list_mesh_scan_uv_h5_fnames(obj_names: Iterable[str], identifiers: Optional[Iterable[str]] = None, **kw) -> list[Path]:
|
||||
if identifiers is None:
|
||||
identifiers = list_mesh_scan_identifiers(**kw)
|
||||
return [
|
||||
config.DATA_PATH / "clouds" / obj_name / f"mesh_scan_{identifier}_uv.h5"
|
||||
for obj_name in obj_names
|
||||
for identifier in identifiers
|
||||
]
|
||||
|
||||
# === sphere-view (UV) scan clouds
|
||||
|
||||
def compute_mesh_sphere_scan(
|
||||
obj_name : str,
|
||||
*,
|
||||
compute_normals : bool = True,
|
||||
**kw,
|
||||
) -> SingleViewUVScan:
|
||||
mesh = read_mesh(obj_name)
|
||||
scan = SingleViewUVScan.from_mesh_sphere_view(mesh,
|
||||
compute_normals = compute_normals,
|
||||
**kw,
|
||||
)
|
||||
return scan
|
||||
|
||||
def precompute_mesh_sphere_scan(obj_names, *, page: tuple[int, int] = (0, 1), force: bool = False, debug: bool = False, n_points: int = 4000, **kw):
|
||||
"precomputes all single-view scan clouds and stores them as HDF5 datasets"
|
||||
paths = list_mesh_sphere_scan_h5_fnames(obj_names)
|
||||
@wraps(compute_mesh_sphere_scan)
|
||||
def computer(obj_name: str) -> SingleViewScan:
|
||||
return compute_mesh_sphere_scan(obj_name, **kw)
|
||||
return processing.precompute_data(computer, obj_names, paths, page=page, force=force, debug=debug)
|
||||
|
||||
def read_mesh_mesh_sphere_scan(obj_name) -> SingleViewUVScan:
|
||||
file = config.DATA_PATH / "clouds" / obj_name / "mesh_sphere_scan.h5"
|
||||
if not file.exists(): raise FileNotFoundError(str(file))
|
||||
return SingleViewUVScan.from_h5_file(file)
|
||||
|
||||
def list_mesh_sphere_scan_h5_fnames(obj_names: Iterable[str]) -> list[Path]:
|
||||
return [
|
||||
config.DATA_PATH / "clouds" / obj_name / "mesh_sphere_scan.h5"
|
||||
for obj_name in obj_names
|
||||
]
|
3
ifield/datasets/__init__.py
Normal file
3
ifield/datasets/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
__doc__ = """
|
||||
Submodules defining various `torch.utils.data.Dataset`
|
||||
"""
|
196
ifield/datasets/common.py
Normal file
196
ifield/datasets/common.py
Normal file
@ -0,0 +1,196 @@
|
||||
from ..data.common.h5_dataclasses import H5Dataclass, PathLike
|
||||
from torch.utils.data import Dataset, IterableDataset
|
||||
from typing import Any, Iterable, Hashable, TypeVar, Iterator, Callable
|
||||
from functools import partial, lru_cache
|
||||
import inspect
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
T_H5 = TypeVar("T_H5", bound=H5Dataclass)
|
||||
|
||||
|
||||
class TransformableDatasetMixin:
|
||||
def __init_subclass__(cls):
|
||||
if getattr(cls, "_transformable_mixin_no_override_getitem", False):
|
||||
pass
|
||||
elif issubclass(cls, Dataset):
|
||||
if cls.__getitem__ is not cls._transformable_mixin_getitem_wrapper:
|
||||
cls._transformable_mixin_inner_getitem = cls.__getitem__
|
||||
cls.__getitem__ = cls._transformable_mixin_getitem_wrapper
|
||||
elif issubclass(cls, IterableDataset):
|
||||
if cls.__iter__ is not cls._transformable_mixin_iter_wrapper:
|
||||
cls._transformable_mixin_inner_iter = cls.__iter__
|
||||
cls.__iter__ = cls._transformable_mixin_iter_wrapper
|
||||
else:
|
||||
raise TypeError(f"{cls.__name__!r} is neither a Dataset nor a IterableDataset!")
|
||||
|
||||
def __init__(self, *a, **kw):
|
||||
super().__init__(*a, **kw)
|
||||
self._transforms = []
|
||||
|
||||
# works as a decorator
|
||||
def map(self: T, func: callable = None, /, args=[], **kw) -> T:
|
||||
def wrapper(func) -> T:
|
||||
if args or kw:
|
||||
func = partial(func, *args, **kw)
|
||||
self._transforms.append(func)
|
||||
return self
|
||||
|
||||
if func is None:
|
||||
return wrapper
|
||||
else:
|
||||
return wrapper(func)
|
||||
|
||||
|
||||
def _transformable_mixin_getitem_wrapper(self, index: int):
|
||||
if not self._transforms:
|
||||
out = self._transformable_mixin_inner_getitem(index) # (TransformableDatasetMixin, no transforms)
|
||||
else:
|
||||
out = self._transformable_mixin_inner_getitem(index) # (TransformableDatasetMixin, has transforms)
|
||||
for f in self._transforms:
|
||||
out = f(out) # (TransformableDatasetMixin)
|
||||
return out
|
||||
|
||||
def _transformable_mixin_iter_wrapper(self):
|
||||
if not self._transforms:
|
||||
out = self._transformable_mixin_inner_iter() # (TransformableDatasetMixin, no transforms)
|
||||
else:
|
||||
out = self._transformable_mixin_inner_iter() # (TransformableDatasetMixin, has transforms)
|
||||
for f in self._transforms:
|
||||
out = map(f, out) # (TransformableDatasetMixin)
|
||||
return out
|
||||
|
||||
|
||||
class TransformedDataset(Dataset, TransformableDatasetMixin):
|
||||
# used to wrap an another dataset
|
||||
def __init__(self, dataset: Dataset, transforms: Iterable[callable]):
|
||||
super().__init__()
|
||||
self.dataset = dataset
|
||||
for i in transforms:
|
||||
self.map(i)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
return self.dataset[index] # (TransformedDataset)
|
||||
|
||||
|
||||
class TransformExtendedDataset(Dataset, TransformableDatasetMixin):
|
||||
_transformable_mixin_no_override_getitem = True
|
||||
def __init__(self, dataset: Dataset):
|
||||
super().__init__()
|
||||
self.dataset = dataset
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset) * len(self._transforms)
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
n = len(self._transforms)
|
||||
assert n > 0, f"{len(self._transforms) = }"
|
||||
|
||||
item = index // n
|
||||
transform = self._transforms[index % n]
|
||||
return transform(self.dataset[item])
|
||||
|
||||
|
||||
class CachedDataset(Dataset):
|
||||
# used to wrap an another dataset
|
||||
def __init__(self, dataset: Dataset, cache_size: int | None):
|
||||
super().__init__()
|
||||
self.dataset = dataset
|
||||
if cache_size is not None and cache_size > 0:
|
||||
self.cached_getter = lru_cache(cache_size, self.dataset.__getitem__)
|
||||
else:
|
||||
self.cached_getter = self.dataset.__getitem__
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
return self.cached_getter(index)
|
||||
|
||||
|
||||
class AutodecoderDataset(Dataset, TransformableDatasetMixin):
|
||||
def __init__(self,
|
||||
keys : Iterable[Hashable],
|
||||
dataset : Dataset,
|
||||
):
|
||||
super().__init__()
|
||||
self.ad_mapping = list(keys)
|
||||
self.dataset = dataset
|
||||
if len(self.ad_mapping) != len(dataset):
|
||||
raise ValueError(f"__len__ mismatch between keys and dataset: {len(self.ad_mapping)} != {len(dataset)}")
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.dataset)
|
||||
|
||||
def __getitem__(self, index: int) -> tuple[Hashable, Any]:
|
||||
return self.ad_mapping[index], self.dataset[index] # (AutodecoderDataset)
|
||||
|
||||
def keys(self) -> list[Hashable]:
|
||||
return self.ad_mapping
|
||||
|
||||
def values(self) -> Iterator:
|
||||
return iter(self.dataset)
|
||||
|
||||
def items(self) -> Iterable[tuple[Hashable, Any]]:
|
||||
return zip(self.ad_mapping, self.dataset)
|
||||
|
||||
|
||||
class FunctionDataset(Dataset, TransformableDatasetMixin):
|
||||
def __init__(self,
|
||||
getter : Callable[[Hashable], T],
|
||||
keys : list[Hashable],
|
||||
cache_size : int | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
if cache_size is not None and cache_size > 0:
|
||||
getter = lru_cache(cache_size)(getter)
|
||||
self.getter = getter
|
||||
self.keys = keys
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.keys)
|
||||
|
||||
def __getitem__(self, index: int) -> T:
|
||||
return self.getter(self.keys[index])
|
||||
|
||||
class H5Dataset(FunctionDataset):
|
||||
def __init__(self,
|
||||
h5_dataclass_cls : type[T_H5],
|
||||
fnames : list[PathLike],
|
||||
**kw,
|
||||
):
|
||||
super().__init__(
|
||||
getter = h5_dataclass_cls.from_h5_file,
|
||||
keys = fnames,
|
||||
**kw,
|
||||
)
|
||||
|
||||
class PaginatedH5Dataset(Dataset, TransformableDatasetMixin):
|
||||
def __init__(self,
|
||||
h5_dataclass_cls : type[T_H5],
|
||||
fnames : list[PathLike],
|
||||
n_pages : int = 10,
|
||||
require_even_pages : bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.h5_dataclass_cls = h5_dataclass_cls
|
||||
self.fnames = fnames
|
||||
self.n_pages = n_pages
|
||||
self.require_even_pages = require_even_pages
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.fnames) * self.n_pages
|
||||
|
||||
def __getitem__(self, index: int) -> T_H5:
|
||||
item = index // self.n_pages
|
||||
page = index % self.n_pages
|
||||
|
||||
return self.h5_dataclass_cls.from_h5_file( # (PaginatedH5Dataset)
|
||||
fname = self.fname[item],
|
||||
page = page,
|
||||
n_pages = self.n_pages,
|
||||
require_even_pages = self.require_even_pages,
|
||||
)
|
40
ifield/datasets/coseg.py
Normal file
40
ifield/datasets/coseg.py
Normal file
@ -0,0 +1,40 @@
|
||||
from . import common
|
||||
from ..data.coseg import config
|
||||
from ..data.coseg import read
|
||||
from ..data.common import scan
|
||||
from typing import Iterable, Optional, Union
|
||||
import os
|
||||
|
||||
|
||||
class SingleViewUVScanDataset(common.H5Dataset):
|
||||
def __init__(self,
|
||||
object_sets : tuple[str],
|
||||
identifiers : Optional[Iterable[str]] = None,
|
||||
data_path : Union[str, os.PathLike, None] = None,
|
||||
):
|
||||
if not object_sets:
|
||||
raise ValueError("'object_sets' cannot be empty!")
|
||||
if identifiers is None:
|
||||
identifiers = read.list_mesh_scan_identifiers()
|
||||
if data_path is not None:
|
||||
config.DATA_PATH = data_path
|
||||
models = read.list_model_ids(object_sets)
|
||||
fnames = read.list_mesh_scan_uv_h5_fnames(models, identifiers)
|
||||
super().__init__(
|
||||
h5_dataclass_cls = scan.SingleViewUVScan,
|
||||
fnames = fnames,
|
||||
)
|
||||
|
||||
class AutodecoderSingleViewUVScanDataset(common.AutodecoderDataset):
|
||||
def __init__(self,
|
||||
object_sets : tuple[str],
|
||||
identifiers : Optional[Iterable[str]] = None,
|
||||
data_path : Union[str, os.PathLike, None] = None,
|
||||
):
|
||||
if identifiers is None:
|
||||
identifiers = read.list_mesh_scan_identifiers()
|
||||
# here do this step first, such that all the duplicate strings reference the same object
|
||||
super().__init__(
|
||||
keys = [key for key in read.list_model_id_strings(object_sets) for _ in range(len(identifiers))],
|
||||
dataset = SingleViewUVScanDataset(object_sets, identifiers, data_path=data_path),
|
||||
)
|
64
ifield/datasets/stanford.py
Normal file
64
ifield/datasets/stanford.py
Normal file
@ -0,0 +1,64 @@
|
||||
from . import common
|
||||
from ..data.stanford import config
|
||||
from ..data.stanford import read
|
||||
from ..data.common import scan
|
||||
from typing import Iterable, Optional, Union
|
||||
import os
|
||||
|
||||
|
||||
class SingleViewUVScanDataset(common.H5Dataset):
|
||||
def __init__(self,
|
||||
obj_names : Iterable[str],
|
||||
identifiers : Optional[Iterable[str]] = None,
|
||||
data_path : Union[str, os.PathLike, None] = None,
|
||||
):
|
||||
if not obj_names:
|
||||
raise ValueError("'obj_names' cannot be empty!")
|
||||
if identifiers is None:
|
||||
identifiers = read.list_mesh_scan_identifiers()
|
||||
if data_path is not None:
|
||||
config.DATA_PATH = data_path
|
||||
fnames = read.list_mesh_scan_uv_h5_fnames(obj_names, identifiers)
|
||||
super().__init__(
|
||||
h5_dataclass_cls = scan.SingleViewUVScan,
|
||||
fnames = fnames,
|
||||
)
|
||||
|
||||
class AutodecoderSingleViewUVScanDataset(common.AutodecoderDataset):
|
||||
def __init__(self,
|
||||
obj_names : Iterable[str],
|
||||
identifiers : Optional[Iterable[str]] = None,
|
||||
data_path : Union[str, os.PathLike, None] = None,
|
||||
):
|
||||
if identifiers is None:
|
||||
identifiers = read.list_mesh_scan_identifiers()
|
||||
super().__init__(
|
||||
keys = [obj_name for obj_name in obj_names for _ in range(len(identifiers))],
|
||||
dataset = SingleViewUVScanDataset(obj_names, identifiers, data_path=data_path),
|
||||
)
|
||||
|
||||
|
||||
class SphereScanDataset(common.H5Dataset):
|
||||
def __init__(self,
|
||||
obj_names : Iterable[str],
|
||||
data_path : Union[str, os.PathLike, None] = None,
|
||||
):
|
||||
if not obj_names:
|
||||
raise ValueError("'obj_names' cannot be empty!")
|
||||
if data_path is not None:
|
||||
config.DATA_PATH = data_path
|
||||
fnames = read.list_mesh_sphere_scan_h5_fnames(obj_names)
|
||||
super().__init__(
|
||||
h5_dataclass_cls = scan.SingleViewUVScan,
|
||||
fnames = fnames,
|
||||
)
|
||||
|
||||
class AutodecoderSphereScanDataset(common.AutodecoderDataset):
|
||||
def __init__(self,
|
||||
obj_names : Iterable[str],
|
||||
data_path : Union[str, os.PathLike, None] = None,
|
||||
):
|
||||
super().__init__(
|
||||
keys = obj_names,
|
||||
dataset = SphereScanDataset(obj_names, data_path=data_path),
|
||||
)
|
258
ifield/logging.py
Normal file
258
ifield/logging.py
Normal file
@ -0,0 +1,258 @@
|
||||
from . import param
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from typing import Union, Literal, Optional, TypeVar
|
||||
import concurrent.futures
|
||||
import psutil
|
||||
import pytorch_lightning as pl
|
||||
import statistics
|
||||
import threading
|
||||
import time
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
# from https://github.com/yaml/pyyaml/issues/240#issuecomment-1018712495
|
||||
def str_presenter(dumper, data):
|
||||
"""configures yaml for dumping multiline strings
|
||||
Ref: https://stackoverflow.com/questions/8640959/how-can-i-control-what-scalar-form-pyyaml-uses-for-my-data"""
|
||||
if len(data.splitlines()) > 1: # check for multiline string
|
||||
return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|')
|
||||
return dumper.represent_scalar('tag:yaml.org,2002:str', data)
|
||||
yaml.add_representer(str, str_presenter)
|
||||
|
||||
|
||||
LoggerStr = Literal[
|
||||
#"csv",
|
||||
"tensorboard",
|
||||
#"mlflow",
|
||||
#"comet",
|
||||
#"neptune",
|
||||
#"wandb",
|
||||
None]
|
||||
try:
|
||||
Logger = TypeVar("L", bound=pl.loggers.Logger)
|
||||
except AttributeError:
|
||||
Logger = TypeVar("L", bound=pl.loggers.base.LightningLoggerBase)
|
||||
|
||||
def make_logger(
|
||||
experiment_name : str,
|
||||
default_root_dir : Union[str, Path], # from pl.Trainer
|
||||
save_dir : Union[str, Path],
|
||||
type : LoggerStr = "tensorboard",
|
||||
project : str = "ifield",
|
||||
) -> Optional[Logger]:
|
||||
if type is None:
|
||||
return None
|
||||
elif type == "tensorboard":
|
||||
return pl.loggers.TensorBoardLogger(
|
||||
name = "tensorboard",
|
||||
save_dir = Path(default_root_dir) / save_dir,
|
||||
version = experiment_name,
|
||||
log_graph = True,
|
||||
)
|
||||
raise ValueError(f"make_logger({type=})")
|
||||
|
||||
def make_jinja_template(*, save_dir: Union[None, str, Path], **kw) -> str:
|
||||
return param.make_jinja_template(make_logger,
|
||||
defaults = dict(
|
||||
save_dir = save_dir,
|
||||
),
|
||||
exclude_list = {
|
||||
"experiment_name",
|
||||
"default_root_dir",
|
||||
},
|
||||
**({"name": "logging"} | kw),
|
||||
)
|
||||
|
||||
def get_checkpoints(experiment_name, default_root_dir, save_dir, type, project) -> list[Path]:
|
||||
if type is None:
|
||||
return None
|
||||
if type == "tensorboard":
|
||||
folder = Path(default_root_dir) / save_dir / "tensorboard" / experiment_name
|
||||
return folder.glob("*.ckpt")
|
||||
if type == "mlflow":
|
||||
raise NotImplementedError(f"{type=}")
|
||||
if type == "wandb":
|
||||
raise NotImplementedError(f"{type=}")
|
||||
raise ValueError(f"get_checkpoint({type=})")
|
||||
|
||||
|
||||
def log_config(_logger: Logger, **kwargs: Union[str, dict, list, int, float]):
|
||||
assert isinstance(_logger, pl.loggers.Logger) \
|
||||
or isinstance(_logger, pl.loggers.base.LightningLoggerBase), _logger
|
||||
|
||||
_logger: pl.loggers.TensorBoardLogger
|
||||
_logger.log_hyperparams(params=kwargs)
|
||||
|
||||
@dataclass
|
||||
class ModelOutputMonitor(pl.callbacks.Callback):
|
||||
log_training : bool = True
|
||||
log_validation : bool = True
|
||||
|
||||
def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: Optional[str] = None) -> None:
|
||||
if not trainer.loggers:
|
||||
raise MisconfigurationException(f"Cannot use {self._class__.__name__} callback with Trainer that has no logger.")
|
||||
|
||||
@staticmethod
|
||||
def _log_outputs(trainer: pl.Trainer, pl_module: pl.LightningModule, outputs, fname: str):
|
||||
if outputs is None:
|
||||
return
|
||||
elif isinstance(outputs, list) or isinstance(outputs, tuple):
|
||||
outputs = {
|
||||
f"loss[{i}]": v
|
||||
for i, v in enumerate(outputs)
|
||||
}
|
||||
elif isinstance(outputs, torch.Tensor):
|
||||
outputs = {
|
||||
"loss": outputs,
|
||||
}
|
||||
elif isinstance(outputs, dict):
|
||||
pass
|
||||
else:
|
||||
raise ValueError
|
||||
sep = trainer.logger.group_separator
|
||||
pl_module.log_dict({
|
||||
f"{pl_module.__class__.__qualname__}.{fname}{sep}{k}":
|
||||
float(v.item()) if isinstance(v, torch.Tensor) else float(v)
|
||||
for k, v in outputs.items()
|
||||
}, sync_dist=True)
|
||||
|
||||
def on_train_batch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs, batch, batch_idx, unused=0):
|
||||
if self.log_training:
|
||||
self._log_outputs(trainer, pl_module, outputs, "training_step")
|
||||
|
||||
def on_validation_batch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs, batch, batch_idx, dataloader_idx=0):
|
||||
if self.log_validation:
|
||||
self._log_outputs(trainer, pl_module, outputs, "validation_step")
|
||||
|
||||
class EpochTimeMonitor(pl.callbacks.Callback):
|
||||
__slots__ = [
|
||||
"epoch_start",
|
||||
"epoch_start_train",
|
||||
"epoch_start_validation",
|
||||
"epoch_start_test",
|
||||
"epoch_start_predict",
|
||||
]
|
||||
|
||||
def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: Optional[str] = None) -> None:
|
||||
if not trainer.loggers:
|
||||
raise MisconfigurationException(f"Cannot use {self._class__.__name__} callback with Trainer that has no logger.")
|
||||
|
||||
|
||||
@rank_zero_only
|
||||
def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||
self.epoch_start_train = time.time()
|
||||
|
||||
@rank_zero_only
|
||||
def on_validation_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||
self.epoch_start_validation = time.time()
|
||||
|
||||
@rank_zero_only
|
||||
def on_test_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||
self.epoch_start_test = time.time()
|
||||
|
||||
@rank_zero_only
|
||||
def on_predict_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||
self.epoch_start_predict = time.time()
|
||||
|
||||
@rank_zero_only
|
||||
def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||
t = time.time() - self.epoch_start_train
|
||||
del self.epoch_start_train
|
||||
sep = trainer.logger.group_separator
|
||||
trainer.logger.log_metrics({f"{self.__class__.__qualname__}{sep}epoch_train_time" : t}, step=trainer.global_step)
|
||||
|
||||
@rank_zero_only
|
||||
def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||
t = time.time() - self.epoch_start_validation
|
||||
del self.epoch_start_validation
|
||||
sep = trainer.logger.group_separator
|
||||
trainer.logger.log_metrics({f"{self.__class__.__qualname__}{sep}epoch_validation_time" : t}, step=trainer.global_step)
|
||||
|
||||
@rank_zero_only
|
||||
def on_test_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||
t = time.time() - self.epoch_start_test
|
||||
del self.epoch_start_validation
|
||||
sep = trainer.logger.group_separator
|
||||
trainer.logger.log_metrics({f"{self.__class__.__qualname__}{sep}epoch_test_time" : t}, step=trainer.global_step)
|
||||
|
||||
@rank_zero_only
|
||||
def on_predict_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||
t = time.time() - self.epoch_start_predict
|
||||
del self.epoch_start_validation
|
||||
sep = trainer.logger.group_separator
|
||||
trainer.logger.log_metrics({f"{self.__class__.__qualname__}{sep}epoch_predict_time" : t}, step=trainer.global_step)
|
||||
|
||||
@dataclass
|
||||
class PsutilMonitor(pl.callbacks.Callback):
|
||||
sample_rate : float = 0.2 # times per second
|
||||
|
||||
_should_stop = False
|
||||
|
||||
@rank_zero_only
|
||||
def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||
if not trainer.loggers:
|
||||
raise MisconfigurationException(f"Cannot use {self._class__.__name__} callback with Trainer that has no logger.")
|
||||
assert not hasattr(self, "_thread")
|
||||
|
||||
self._should_stop = False
|
||||
self._thread = threading.Thread(
|
||||
target = self.thread_target,
|
||||
name = self.thread_target.__qualname__,
|
||||
args = [trainer],
|
||||
daemon=True,
|
||||
)
|
||||
self._thread.start()
|
||||
|
||||
@rank_zero_only
|
||||
def on_fit_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||
assert getattr(self, "_thread", None) is not None
|
||||
self._should_stop = True
|
||||
del self._thread
|
||||
|
||||
def thread_target(self, trainer: pl.Trainer):
|
||||
uses_gpu = isinstance(trainer.accelerator, (pl.accelerators.GPUAccelerator, pl.accelerators.CUDAAccelerator))
|
||||
gpu_ids = trainer.device_ids
|
||||
|
||||
prefix = f"{self.__class__.__qualname__}{trainer.logger.group_separator}"
|
||||
|
||||
while not self._should_stop:
|
||||
step = trainer.global_step
|
||||
p = psutil.Process()
|
||||
|
||||
meminfo = p.memory_info()
|
||||
rss_ram = meminfo.rss / 1024**2 # MB
|
||||
vms_ram = meminfo.vms / 1024**2 # MB
|
||||
|
||||
util_per_cpu = psutil.cpu_percent(percpu=True)
|
||||
|
||||
util_per_cpu = [util_per_cpu[i] for i in p.cpu_affinity()]
|
||||
util_total = statistics.mean(util_per_cpu)
|
||||
|
||||
if uses_gpu:
|
||||
with concurrent.futures.ThreadPoolExecutor() as e:
|
||||
if hasattr(pl.accelerators, "cuda"):
|
||||
gpu_stats = e.map(pl.accelerators.cuda.get_nvidia_gpu_stats, gpu_ids)
|
||||
else:
|
||||
gpu_stats = e.map(pl.accelerators.gpu.get_nvidia_gpu_stats, gpu_ids)
|
||||
trainer.logger.log_metrics({
|
||||
f"{prefix}ram.rss" : rss_ram,
|
||||
f"{prefix}ram.vms" : vms_ram,
|
||||
f"{prefix}cpu.total" : util_total,
|
||||
**{ f"{prefix}cpu.{i:03}.utilization" : stat for i, stat in enumerate(util_per_cpu) },
|
||||
**{
|
||||
f"{prefix}gpu.{gpu_idx:02}.{key.split(' ',1)[0]}" : stat
|
||||
for gpu_idx, stats in zip(gpu_ids, gpu_stats)
|
||||
for key, stat in stats.items()
|
||||
},
|
||||
}, step = step)
|
||||
else:
|
||||
trainer.logger.log_metrics({
|
||||
f"{prefix}cpu.total" : util_total,
|
||||
**{ f"{prefix}cpu.{i:03}.utilization" : stat for i, stat in enumerate(util_per_cpu) },
|
||||
}, step = step)
|
||||
|
||||
time.sleep(1 / self.sample_rate)
|
||||
print("DAEMON END")
|
3
ifield/models/__init__.py
Normal file
3
ifield/models/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
__doc__ = """
|
||||
Contains Pytorch Models
|
||||
"""
|
159
ifield/models/conditioning.py
Normal file
159
ifield/models/conditioning.py
Normal file
@ -0,0 +1,159 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from torch import nn, Tensor
|
||||
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
|
||||
from typing import Hashable, Union, Optional, KeysView, ValuesView, ItemsView, Any, Sequence
|
||||
import torch
|
||||
|
||||
|
||||
class RequiresConditioner(nn.Module, ABC): # mixin
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def n_latent_features(self) -> int:
|
||||
"This should provide the width of the conditioning feature vector"
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def latent_embeddings_init_std(self) -> float:
|
||||
"This should provide the standard deviation to initialize the latent features with. DeepSDF uses 0.01."
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def latent_embeddings() -> Optional[Tensor]:
|
||||
"""This property should return a tensor cotnaining all stored embeddings, for use in computing auto-decoder losses"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def encode(self, batch: Any, batch_idx: int, optimizer_idx: int) -> Tensor:
|
||||
"This should, given a training batch, return the encoded conditioning vector"
|
||||
...
|
||||
|
||||
|
||||
class AutoDecoderModuleMixin(RequiresConditioner, ABC):
|
||||
"""
|
||||
Populates dunder methods making it behave as a mapping.
|
||||
The mapping indexes into a stored set of learnable embedding vectors.
|
||||
|
||||
Based on the auto-decoder architecture of
|
||||
J.J. Park, P. Florence, J. Straub, R. Newcombe, S. Lovegrove, DeepSDF:
|
||||
Learning Continuous Signed Distance Functions for Shape Representation, in:
|
||||
2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR),
|
||||
IEEE, Long Beach, CA, USA, 2019: pp. 165–174.
|
||||
https://doi.org/10.1109/CVPR.2019.00025.
|
||||
"""
|
||||
|
||||
_autodecoder_mapping: dict[Hashable, int]
|
||||
autodecoder_embeddings: nn.Parameter
|
||||
|
||||
def __init__(self, *a, **kw):
|
||||
super().__init__(*a, **kw)
|
||||
|
||||
@self._register_load_state_dict_pre_hook
|
||||
def hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
||||
if f"{prefix}_autodecoder_mapping" in state_dict:
|
||||
state_dict[f"{prefix}{_EXTRA_STATE_KEY_SUFFIX}"] = state_dict.pop(f"{prefix}_autodecoder_mapping")
|
||||
|
||||
class ICanBeLoadedFromCheckpointsAndChangeShapeStopBotheringMePyTorchAndSitInTheCornerIKnowWhatIAmDoing(nn.UninitializedParameter):
|
||||
def copy_(self, other):
|
||||
self.materialize(other.shape, other.device, other.dtype)
|
||||
return self.copy_(other)
|
||||
self.autodecoder_embeddings = ICanBeLoadedFromCheckpointsAndChangeShapeStopBotheringMePyTorchAndSitInTheCornerIKnowWhatIAmDoing()
|
||||
|
||||
# nn.Module interface
|
||||
|
||||
def get_extra_state(self):
|
||||
return {
|
||||
"ad_uids": getattr(self, "_autodecoder_mapping", {}),
|
||||
}
|
||||
|
||||
def set_extra_state(self, obj):
|
||||
if "ad_uids" not in obj: # backward compat
|
||||
self._autodecoder_mapping = obj
|
||||
else:
|
||||
self._autodecoder_mapping = obj["ad_uids"]
|
||||
|
||||
# RequiresConditioner interface
|
||||
|
||||
@property
|
||||
def latent_embeddings(self) -> Tensor:
|
||||
return self.autodecoder_embeddings
|
||||
|
||||
# my interface
|
||||
|
||||
def set_observation_ids(self, z_uids: set[Hashable]):
|
||||
assert self.latent_embeddings_init_std is not None, f"{self.__module__}.{self.__class__.__qualname__}.latent_embeddings_init_std"
|
||||
assert self.n_latent_features is not None, f"{self.__module__}.{self.__class__.__qualname__}.n_latent_features"
|
||||
assert self.latent_embeddings_init_std > 0, self.latent_embeddings_init_std
|
||||
assert self.n_latent_features > 0, self.n_latent_features
|
||||
|
||||
self._autodecoder_mapping = {
|
||||
k: i
|
||||
for i, k in enumerate(sorted(set(z_uids)))
|
||||
}
|
||||
|
||||
if not len(z_uids) == len(self._autodecoder_mapping):
|
||||
raise ValueError(f"Observation identifiers are not unique! {z_uids = }")
|
||||
|
||||
self.autodecoder_embeddings = nn.Parameter(
|
||||
torch.Tensor(len(self._autodecoder_mapping), self.n_latent_features)
|
||||
.normal_(mean=0, std=self.latent_embeddings_init_std)
|
||||
.to(self.device, self.dtype)
|
||||
)
|
||||
|
||||
def add_key(self, z_uid: Hashable, z: Optional[Tensor] = None):
|
||||
if z_uid in self._autodecoder_mapping:
|
||||
raise ValueError(f"Observation identifier {z_uid!r} not unique!")
|
||||
|
||||
self._autodecoder_mapping[z_uid] = len(self._autodecoder_mapping)
|
||||
self.autodecoder_embeddings
|
||||
raise NotImplementedError
|
||||
|
||||
def __delitem__(self, z_uid: Hashable):
|
||||
i = self._autodecoder_mapping.pop(z_uid)
|
||||
for k, v in list(self._autodecoder_mapping.items()):
|
||||
if v > i:
|
||||
self._autodecoder_mapping[k] -= 1
|
||||
|
||||
with torch.no_grad():
|
||||
self.autodecoder_embeddings = nn.Parameter(torch.cat((
|
||||
self.autodecoder_embeddings.detach()[:i, :],
|
||||
self.autodecoder_embeddings.detach()[i+1:, :],
|
||||
), dim=0))
|
||||
|
||||
def __contains__(self, z_uid: Hashable) -> bool:
|
||||
return z_uid in self._autodecoder_mapping
|
||||
|
||||
def __getitem__(self, z_uids: Union[Hashable, Sequence[Hashable]]) -> Tensor:
|
||||
if isinstance(z_uids, tuple) or isinstance(z_uids, list):
|
||||
key = tuple(map(self._autodecoder_mapping.__getitem__, z_uids))
|
||||
else:
|
||||
key = self._autodecoder_mapping[z_uids]
|
||||
return self.autodecoder_embeddings[key, :]
|
||||
|
||||
def __iter__(self):
|
||||
return self._autodecoder_mapping.keys()
|
||||
|
||||
def keys(self) -> KeysView[Hashable]:
|
||||
"""
|
||||
lists the identifiers of each code
|
||||
"""
|
||||
return self._autodecoder_mapping.keys()
|
||||
|
||||
def values(self) -> ValuesView[Tensor]:
|
||||
return list(self.autodecoder_embeddings)
|
||||
|
||||
def items(self) -> ItemsView[Hashable, Tensor]:
|
||||
"""
|
||||
lists all the learned codes / latent vectors with their identifiers as keys
|
||||
"""
|
||||
return {
|
||||
k : self.autodecoder_embeddings[i]
|
||||
for k, i in self._autodecoder_mapping.items()
|
||||
}.items()
|
||||
|
||||
class EncoderModuleMixin(RequiresConditioner, ABC):
|
||||
@property
|
||||
def latent_embeddings(self) -> None:
|
||||
return None
|
589
ifield/models/intersection_fields.py
Normal file
589
ifield/models/intersection_fields.py
Normal file
@ -0,0 +1,589 @@
|
||||
from .. import param
|
||||
from ..modules.dtype import DtypeMixin
|
||||
from ..utils import geometry
|
||||
from ..utils.helpers import compose
|
||||
from ..utils.loss import Schedulable, ensure_schedulables, HParamSchedule, HParamScheduleBase, Linear
|
||||
from ..utils.operators import diff
|
||||
from .conditioning import RequiresConditioner, AutoDecoderModuleMixin
|
||||
from .medial_atoms import MedialAtomNet
|
||||
from .orthogonal_plane import OrthogonalPlaneNet
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from torch import Tensor
|
||||
from torch.nn import functional as F
|
||||
from typing import TypedDict, Literal, Union, Hashable, Optional
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import os
|
||||
|
||||
LOG_ALL_METRICS = bool(int(os.environ.get("IFIELD_LOG_ALL_METRICS", "1")))
|
||||
|
||||
if __debug__:
|
||||
def broadcast_tensors(*tensors: torch.Tensor) -> list[torch.Tensor]:
|
||||
try:
|
||||
return torch.broadcast_tensors(*tensors)
|
||||
except RuntimeError as e:
|
||||
shapes = ", ".join(f"{chr(c)}.size={tuple(t.shape)}" for c, t in enumerate(tensors, ord("a")))
|
||||
raise ValueError(f"Could not broadcast tensors {shapes}.\n{str(e)}")
|
||||
else:
|
||||
broadcast_tensors = torch.broadcast_tensors
|
||||
|
||||
|
||||
class ForwardDepthMapsBatch(TypedDict):
|
||||
cam2world : Tensor # (B, 4, 4)
|
||||
uv : Tensor # (B, H, W)
|
||||
intrinsics : Tensor # (B, 3, 3)
|
||||
|
||||
class ForwardScanRaysBatch(TypedDict):
|
||||
origins : Tensor # (B, H, W, 3) or (B, 3)
|
||||
dirs : Tensor # (B, H, W, 3)
|
||||
|
||||
class LossBatch(TypedDict):
|
||||
hits : Tensor # (B, H, W) dtype=bool
|
||||
miss : Tensor # (B, H, W) dtype=bool
|
||||
depths : Tensor # (B, H, W)
|
||||
normals : Tensor # (B, H, W, 3) NaN if not hit
|
||||
distances : Tensor # (B, H, W, 1) NaN if not miss
|
||||
|
||||
class LabeledBatch(TypedDict):
|
||||
z_uid : list[Hashable]
|
||||
|
||||
ForwardBatch = Union[ForwardDepthMapsBatch, ForwardScanRaysBatch]
|
||||
TrainingBatch = Union[ForwardBatch, LossBatch, LabeledBatch]
|
||||
|
||||
|
||||
IntersectionMode = Literal[
|
||||
"medial_sphere",
|
||||
"orthogonal_plane",
|
||||
]
|
||||
|
||||
class IntersectionFieldModel(pl.LightningModule, RequiresConditioner, DtypeMixin):
|
||||
net: Union[MedialAtomNet, OrthogonalPlaneNet]
|
||||
|
||||
@ensure_schedulables
|
||||
def __init__(self,
|
||||
# mode
|
||||
input_mode : geometry.RayEmbedding = "plucker",
|
||||
output_mode : IntersectionMode = "medial_sphere",
|
||||
|
||||
# network
|
||||
latent_features : int = 256,
|
||||
hidden_features : int = 512,
|
||||
hidden_layers : int = 8,
|
||||
improve_miss_grads: bool = True,
|
||||
normalize_ray_dirs: bool = False, # the dataset is usually already normalized, but this could still be important for backprop
|
||||
|
||||
# orthogonal plane
|
||||
loss_hit_cross_entropy : Schedulable = 1.0,
|
||||
|
||||
# medial atoms
|
||||
loss_intersection : Schedulable = 1,
|
||||
loss_intersection_l2 : Schedulable = 0,
|
||||
loss_intersection_proj : Schedulable = 0,
|
||||
loss_intersection_proj_l2 : Schedulable = 0,
|
||||
loss_normal_cossim : Schedulable = 0.25, # supervise target normal cosine similarity
|
||||
loss_normal_euclid : Schedulable = 0, # supervise target normal l2 distance
|
||||
loss_normal_cossim_proj : Schedulable = 0, # supervise target normal cosine similarity
|
||||
loss_normal_euclid_proj : Schedulable = 0, # supervise target normal l2 distance
|
||||
loss_hit_nodistance_l1 : Schedulable = 0, # constrain no miss distance for hits
|
||||
loss_hit_nodistance_l2 : Schedulable = 32, # constrain no miss distance for hits
|
||||
loss_miss_distance_l1 : Schedulable = 0, # supervise target miss distance for misses
|
||||
loss_miss_distance_l2 : Schedulable = 0, # supervise target miss distance for misses
|
||||
loss_inscription_hits : Schedulable = 0, # Penalize atom candidates using the supervision data of a different ray
|
||||
loss_inscription_hits_l2: Schedulable = 0, # Penalize atom candidates using the supervision data of a different ray
|
||||
loss_inscription_miss : Schedulable = 0, # Penalize atom candidates using the supervision data of a different ray
|
||||
loss_inscription_miss_l2: Schedulable = 0, # Penalize atom candidates using the supervision data of a different ray
|
||||
loss_sphere_grow_reg : Schedulable = 0, # maximialize sphere size
|
||||
loss_sphere_grow_reg_hit: Schedulable = 0, # maximialize sphere size
|
||||
loss_embedding_norm : Schedulable = "0.01**2 * Linear(15)", # DeepSDF schedules over 150 epochs. DeepSDF use 0.01**2, irobot uses 0.04**2
|
||||
loss_multi_view_reg : Schedulable = 0, # minimize gradient w.r.t. delta ray dir, when ray origin = intersection
|
||||
loss_atom_centroid_norm_std_reg : Schedulable = 0, # minimize per-atom centroid std
|
||||
|
||||
# optimization
|
||||
opt_learning_rate : Schedulable = 1e-5,
|
||||
opt_weight_decay : float = 0,
|
||||
opt_warmup : float = 0,
|
||||
**kw,
|
||||
):
|
||||
super().__init__()
|
||||
opt_warmup = Linear(opt_warmup)
|
||||
opt_warmup._param_name = "opt_warmup"
|
||||
self.save_hyperparameters()
|
||||
|
||||
|
||||
if "half" in input_mode:
|
||||
assert output_mode == "medial_sphere" and kw.get("n_atoms", 1) > 1
|
||||
|
||||
assert output_mode in ["medial_sphere", "orthogonal_plane"]
|
||||
assert opt_weight_decay >= 0, opt_weight_decay
|
||||
|
||||
if output_mode == "orthogonal_plane":
|
||||
self.net = OrthogonalPlaneNet(
|
||||
in_features = self.n_input_embedding_features,
|
||||
hidden_layers = hidden_layers,
|
||||
hidden_features = hidden_features,
|
||||
latent_features = latent_features,
|
||||
**kw,
|
||||
)
|
||||
elif output_mode == "medial_sphere":
|
||||
self.net = MedialAtomNet(
|
||||
in_features = self.n_input_embedding_features,
|
||||
hidden_layers = hidden_layers,
|
||||
hidden_features = hidden_features,
|
||||
latent_features = latent_features,
|
||||
**kw,
|
||||
)
|
||||
|
||||
def on_fit_start(self):
|
||||
if __debug__:
|
||||
for k, v in self.hparams.items():
|
||||
if isinstance(v, HParamScheduleBase):
|
||||
v.assert_positive(self.trainer.max_epochs)
|
||||
|
||||
@property
|
||||
def n_input_embedding_features(self) -> int:
|
||||
return geometry.ray_input_embedding_length(self.hparams.input_mode)
|
||||
|
||||
@property
|
||||
def n_latent_features(self) -> int:
|
||||
return self.hparams.latent_features
|
||||
|
||||
@property
|
||||
def latent_embeddings_init_std(self) -> float:
|
||||
return 0.01
|
||||
|
||||
@property
|
||||
def is_conditioned(self):
|
||||
return self.net.is_conditioned
|
||||
|
||||
@property
|
||||
def is_double_backprop(self) -> bool:
|
||||
return self.is_double_backprop_origins or self.is_double_backprop_dirs
|
||||
|
||||
@property
|
||||
def is_double_backprop_origins(self) -> bool:
|
||||
prif = self.hparams.output_mode == "orthogonal_plane"
|
||||
return prif and self.hparams.loss_normal_cossim
|
||||
|
||||
@property
|
||||
def is_double_backprop_dirs(self) -> bool:
|
||||
return self.hparams.loss_multi_view_reg
|
||||
|
||||
@classmethod
|
||||
@compose("\n".join)
|
||||
def make_jinja_template(cls, *, exclude_list: set[str] = {}, top_level: bool = True, **kw) -> str:
|
||||
yield param.make_jinja_template(cls, top_level=top_level, **kw)
|
||||
yield MedialAtomNet.make_jinja_template(top_level=False, exclude_list={
|
||||
"in_features",
|
||||
"hidden_layers",
|
||||
"hidden_features",
|
||||
"latent_features",
|
||||
})
|
||||
|
||||
def batch2rays(self, batch: ForwardBatch) -> tuple[Tensor, Tensor]:
|
||||
if "uv" in batch:
|
||||
raise NotImplementedError
|
||||
assert not (self.hparams.loss_multi_view_reg and self.training)
|
||||
ray_origins, \
|
||||
ray_dirs, \
|
||||
= geometry.camera_uv_to_rays(
|
||||
cam2world = batch["cam2world"],
|
||||
uv = batch["uv"],
|
||||
intrinsics = batch["intrinsics"],
|
||||
)
|
||||
else:
|
||||
ray_origins = batch["points" if self.hparams.loss_multi_view_reg and self.training else "origins"]
|
||||
ray_dirs = batch["dirs"]
|
||||
return ray_origins, ray_dirs
|
||||
|
||||
def forward(self,
|
||||
batch : ForwardBatch,
|
||||
z : Optional[Tensor] = None, # latent code
|
||||
*,
|
||||
return_input : bool = False,
|
||||
allow_nans : bool = False, # in output
|
||||
**kw,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
(
|
||||
ray_origins, # (B, 3)
|
||||
ray_dirs, # (B, H, W, 3)
|
||||
) = self.batch2rays(batch)
|
||||
|
||||
# Ensure rays are normalized
|
||||
# NOTICE: this is slow, make sure to train with optimizations!
|
||||
assert ray_dirs.detach().norm(dim=-1).allclose(torch.ones(ray_dirs.shape[:-1], **self.device_and_dtype)),\
|
||||
ray_dirs.detach().norm(dim=-1)
|
||||
|
||||
if ray_origins.ndim + 2 == ray_dirs.ndim:
|
||||
ray_origins = ray_origins[..., None, None, :]
|
||||
|
||||
ray_origins, ray_dirs = broadcast_tensors(ray_origins, ray_dirs)
|
||||
|
||||
if self.is_double_backprop and self.training:
|
||||
if self.is_double_backprop_dirs:
|
||||
ray_dirs.requires_grad = True
|
||||
if self.is_double_backprop_origins:
|
||||
ray_origins.requires_grad = True
|
||||
assert ray_origins.requires_grad or ray_dirs.requires_grad
|
||||
|
||||
input = geometry.ray_input_embedding(
|
||||
ray_origins, ray_dirs,
|
||||
mode = self.hparams.input_mode,
|
||||
normalize_dirs = self.hparams.normalize_ray_dirs,
|
||||
is_training = self.training,
|
||||
)
|
||||
assert not input.detach().isnan().any()
|
||||
|
||||
predictions = self.net(input, z)
|
||||
|
||||
intersections = self.net.compute_intersections(
|
||||
ray_origins, ray_dirs, predictions,
|
||||
allow_nans = allow_nans and not self.training, **kw
|
||||
)
|
||||
if return_input:
|
||||
return ray_origins, ray_dirs, input, intersections
|
||||
else:
|
||||
return intersections
|
||||
|
||||
def training_step(self, batch: TrainingBatch, batch_idx: int, *, is_validation=False) -> Tensor:
|
||||
z = self.encode(batch) if self.is_conditioned else None
|
||||
assert self.is_conditioned or len(set(batch["z_uid"])) <= 1, \
|
||||
f"Network is unconditioned, but the batch has multiple uids: {set(batch['z_uid'])!r}"
|
||||
|
||||
# unpack
|
||||
target_hits = batch["hits"] # (B, H, W) dtype=bool
|
||||
target_miss = batch["miss"] # (B, H, W) dtype=bool
|
||||
target_points = batch["points"] # (B, H, W, 3)
|
||||
target_normals = batch["normals"] # (B, H, W, 3) NaN if not hit
|
||||
target_distances = batch["distances"] # (B, H, W) NaN if not miss
|
||||
assert not target_normals [target_hits].isnan().any()
|
||||
assert not target_distances[target_miss].isnan().any()
|
||||
target_normals[target_normals.isnan()] = 0
|
||||
assert not target_normals .isnan().any()
|
||||
|
||||
# make z fit batch scheme
|
||||
if z is not None:
|
||||
z = z[..., None, None, :]
|
||||
|
||||
losses = {}
|
||||
metrics = {}
|
||||
zeros = torch.zeros_like(target_distances)
|
||||
|
||||
if self.hparams.output_mode == "medial_sphere":
|
||||
assert isinstance(self.net, MedialAtomNet)
|
||||
ray_origins, ray_dirs, plucker, (
|
||||
depths, # (...) float, projection if not hit
|
||||
silhouettes, # (...) float
|
||||
intersections, # (..., 3) float, projection or NaN if not hit
|
||||
intersection_normals, # (..., 3) float, rejection or NaN if not hit
|
||||
is_intersecting, # (...) bool, true if hit
|
||||
sphere_centers, # (..., 3) network output
|
||||
sphere_radii, # (...) network output
|
||||
|
||||
atom_indices,
|
||||
all_intersections, # (..., N_ATOMS) float, projection or NaN if not hit
|
||||
all_intersection_normals, # (..., N_ATOMS, 3) float, rejection or NaN if not hit
|
||||
all_depths, # (..., N_ATOMS) float, projection if not hit
|
||||
all_silhouettes, # (..., N_ATOMS, 3) float, projection or NaN if not hit
|
||||
all_is_intersecting, # (..., N_ATOMS) bool, true if hit
|
||||
all_sphere_centers, # (..., N_ATOMS, 3) network output
|
||||
all_sphere_radii, # (..., N_ATOMS) network output
|
||||
) = self(batch, z,
|
||||
intersections_only = False,
|
||||
return_all_atoms = True,
|
||||
allow_nans = False,
|
||||
return_input = True,
|
||||
improve_miss_grads = True,
|
||||
)
|
||||
|
||||
# target hit supervision
|
||||
if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_intersection: # scores true hits
|
||||
losses["loss_intersection"] = (
|
||||
(target_points - intersections).norm(dim=-1)
|
||||
).where(target_hits & is_intersecting, zeros).mean()
|
||||
if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_intersection_l2: # scores true hits
|
||||
losses["loss_intersection_l2"] = (
|
||||
(target_points - intersections).pow(2).sum(dim=-1)
|
||||
).where(target_hits & is_intersecting, zeros).mean()
|
||||
if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_intersection_proj: # scores misses as if they were hits, using the projection
|
||||
losses["loss_intersection_proj"] = (
|
||||
(target_points - intersections).norm(dim=-1)
|
||||
).where(target_hits, zeros).mean()
|
||||
if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_intersection_proj_l2: # scores misses as if they were hits, using the projection
|
||||
losses["loss_intersection_proj_l2"] = (
|
||||
(target_points - intersections).pow(2).sum(dim=-1)
|
||||
).where(target_hits, zeros).mean()
|
||||
|
||||
# target hit normal supervision
|
||||
if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_normal_cossim: # scores true hits
|
||||
losses["loss_normal_cossim"] = (
|
||||
1 - torch.cosine_similarity(target_normals, intersection_normals, dim=-1)
|
||||
).where(target_hits & is_intersecting, zeros).mean()
|
||||
if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_normal_euclid: # scores true hits
|
||||
losses["loss_normal_euclid"] = (
|
||||
(target_normals - intersection_normals).norm(dim=-1)
|
||||
).where(target_hits & is_intersecting, zeros).mean()
|
||||
if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_normal_cossim_proj: # scores misses as if they were hits
|
||||
losses["loss_normal_cossim_proj"] = (
|
||||
1 - torch.cosine_similarity(target_normals, intersection_normals, dim=-1)
|
||||
).where(target_hits, zeros).mean()
|
||||
if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_normal_euclid_proj: # scores misses as if they were hits
|
||||
losses["loss_normal_euclid_proj"] = (
|
||||
(target_normals - intersection_normals).norm(dim=-1)
|
||||
).where(target_hits, zeros).mean()
|
||||
|
||||
# target sufficient hit radius
|
||||
if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_hit_nodistance_l1: # ensures hits become hits, instead of relying on the projection being right
|
||||
losses["loss_hit_nodistance_l1"] = (
|
||||
silhouettes
|
||||
).where(target_hits & (silhouettes > 0), zeros).mean()
|
||||
if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_hit_nodistance_l2: # ensures hits become hits, instead of relying on the projection being right
|
||||
losses["loss_hit_nodistance_l2"] = (
|
||||
silhouettes
|
||||
).where(target_hits & (silhouettes > 0), zeros).pow(2).mean()
|
||||
|
||||
# target miss supervision
|
||||
if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_miss_distance_l1: # only positive misses reinforcement
|
||||
losses["loss_miss_distance_l1"] = (
|
||||
target_distances - silhouettes
|
||||
).where(target_miss, zeros).abs().mean()
|
||||
if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_miss_distance_l2: # only positive misses reinforcement
|
||||
losses["loss_miss_distance_l2"] = (
|
||||
target_distances - silhouettes
|
||||
).where(target_miss, zeros).pow(2).mean()
|
||||
|
||||
# incentivise maximal spheres
|
||||
if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_sphere_grow_reg: # all atoms
|
||||
losses["loss_sphere_grow_reg"] = ((all_sphere_radii.detach() + 1) - all_sphere_radii).abs().mean()
|
||||
if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_sphere_grow_reg_hit: # true hits only
|
||||
losses["loss_sphere_grow_reg_hit"] = ((sphere_radii.detach() + 1) - sphere_radii).where(target_hits & is_intersecting, zeros).abs().mean()
|
||||
|
||||
# spherical latent prior
|
||||
if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_embedding_norm:
|
||||
losses["loss_embedding_norm"] = self.latent_embeddings.norm(dim=-1).mean()
|
||||
|
||||
|
||||
is_grad_enabled = torch.is_grad_enabled()
|
||||
|
||||
# multi-view regularization: atom should not change when view changes
|
||||
if self.hparams.loss_multi_view_reg and is_grad_enabled:
|
||||
assert ray_dirs.requires_grad, ray_dirs
|
||||
assert plucker.requires_grad, plucker
|
||||
assert intersections.grad_fn is not None
|
||||
assert intersection_normals.grad_fn is not None
|
||||
|
||||
*center_grads, radii_grads = diff.gradients(
|
||||
sphere_centers[..., 0],
|
||||
sphere_centers[..., 1],
|
||||
sphere_centers[..., 2],
|
||||
sphere_radii,
|
||||
wrt=ray_dirs,
|
||||
)
|
||||
|
||||
losses["loss_multi_view_reg"] = (
|
||||
sum(
|
||||
i.pow(2).sum(dim=-1)
|
||||
for i in center_grads
|
||||
).where(target_hits & is_intersecting, zeros).mean()
|
||||
+
|
||||
radii_grads.pow(2).sum(dim=-1)
|
||||
.where(target_hits & is_intersecting, zeros).mean()
|
||||
)
|
||||
|
||||
# minimize the volume spanned by each atom
|
||||
if self.hparams.loss_atom_centroid_norm_std_reg and self.net.n_atoms > 1:
|
||||
assert len(all_sphere_centers.shape) == 5, all_sphere_centers.shape
|
||||
losses["loss_atom_centroid_norm_std_reg"] \
|
||||
= ((
|
||||
all_sphere_centers
|
||||
- all_sphere_centers
|
||||
.mean(dim=(1, 2), keepdim=True)
|
||||
).pow(2).sum(dim=-1) - 0.05**2).clamp(0, None).mean()
|
||||
|
||||
# prif is l1, LSMAT is l2
|
||||
if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_inscription_hits or self.hparams.loss_inscription_miss or self.hparams.loss_inscription_hits_l2 or self.hparams.loss_inscription_miss_l2:
|
||||
b = target_hits.shape[0] # number of objects
|
||||
n = target_hits.shape[1:].numel() # rays per object
|
||||
perm = torch.randperm(n, device=self.device) # ray2ray permutation
|
||||
flatten = dict(start_dim=1, end_dim=len(target_hits.shape) - 1)
|
||||
|
||||
(
|
||||
inscr_sphere_center_projs, # (b, n, n_atoms, 3)
|
||||
inscr_intersections_near, # (b, n, n_atoms, 3)
|
||||
inscr_intersections_far, # (b, n, n_atoms, 3)
|
||||
inscr_is_intersecting, # (b, n, n_atoms) dtype=bool
|
||||
) = geometry.ray_sphere_intersect(
|
||||
ray_origins.flatten(**flatten)[:, perm, None, :],
|
||||
ray_dirs .flatten(**flatten)[:, perm, None, :],
|
||||
all_sphere_centers.flatten(**flatten),
|
||||
all_sphere_radii .flatten(**flatten),
|
||||
return_parts = True,
|
||||
allow_nans = False,
|
||||
improve_miss_grads = self.hparams.improve_miss_grads,
|
||||
)
|
||||
assert inscr_sphere_center_projs.shape == (b, n, self.net.n_atoms, 3), \
|
||||
(inscr_sphere_center_projs.shape, (b, n, self.net.n_atoms, 3))
|
||||
inscr_silhouettes = (
|
||||
inscr_sphere_center_projs - all_sphere_centers.flatten(**flatten)
|
||||
).norm(dim=-1) - all_sphere_radii.flatten(**flatten)
|
||||
|
||||
loss_inscription_hits = (
|
||||
(
|
||||
(inscr_intersections_near - target_points.flatten(**flatten)[:, perm, None, :])
|
||||
* ray_dirs.flatten(**flatten)[:, perm, None, :]
|
||||
).sum(dim=-1)
|
||||
).where(target_hits.flatten(**flatten)[:, perm, None] & inscr_is_intersecting,
|
||||
torch.zeros(inscr_intersections_near.shape[:-1], **self.device_and_dtype),
|
||||
).clamp(None, 0)
|
||||
loss_inscription_miss = (
|
||||
inscr_silhouettes - target_distances.flatten(**flatten)[:, perm, None]
|
||||
).where(target_miss.flatten(**flatten)[:, perm, None],
|
||||
torch.zeros_like(inscr_silhouettes)
|
||||
).clamp(None, 0)
|
||||
|
||||
if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_inscription_hits:
|
||||
losses["loss_inscription_hits"] = loss_inscription_hits.neg().mean()
|
||||
if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_inscription_miss:
|
||||
losses["loss_inscription_miss"] = loss_inscription_miss.neg().mean()
|
||||
if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_inscription_hits_l2:
|
||||
losses["loss_inscription_hits_l2"] = loss_inscription_hits.pow(2).mean()
|
||||
if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_inscription_miss_l2:
|
||||
losses["loss_inscription_miss_l2"] = loss_inscription_miss.pow(2).mean()
|
||||
|
||||
# metrics
|
||||
metrics["iou"] = (
|
||||
((~target_miss) & is_intersecting.detach()).sum() /
|
||||
((~target_miss) | is_intersecting.detach()).sum()
|
||||
)
|
||||
metrics["radii"] = sphere_radii.detach().mean() # with the constant applied pressure, we need to measure it this way instead
|
||||
|
||||
elif self.hparams.output_mode == "orthogonal_plane":
|
||||
assert isinstance(self.net, OrthogonalPlaneNet)
|
||||
ray_origins, ray_dirs, input_embedding, (
|
||||
intersections, # (..., 3) dtype=float
|
||||
is_intersecting, # (...) dtype=float
|
||||
) = self(batch, z, return_input=True, normalize_origins=True)
|
||||
|
||||
if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_intersection:
|
||||
losses["loss_intersection"] = (
|
||||
(intersections - target_points).norm(dim=-1)
|
||||
).where(target_hits, zeros).mean()
|
||||
if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_intersection_l2:
|
||||
losses["loss_intersection_l2"] = (
|
||||
(intersections - target_points).pow(2).sum(dim=-1)
|
||||
).where(target_hits, zeros).mean()
|
||||
|
||||
if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_hit_cross_entropy:
|
||||
losses["loss_hit_cross_entropy"] = (
|
||||
F.binary_cross_entropy_with_logits(is_intersecting, (~target_miss).to(self.dtype))
|
||||
).mean()
|
||||
|
||||
if self.hparams.loss_normal_cossim and torch.is_grad_enabled():
|
||||
jac = diff.jacobian(intersections, ray_origins)
|
||||
intersection_normals = self.compute_normals_from_intersection_origin_jacobian(jac, ray_dirs)
|
||||
losses["loss_normal_cossim"] = (
|
||||
1 - torch.cosine_similarity(target_normals, intersection_normals, dim=-1)
|
||||
).where(target_hits, zeros).mean()
|
||||
|
||||
if self.hparams.loss_normal_euclid and torch.is_grad_enabled():
|
||||
jac = diff.jacobian(intersections, ray_origins)
|
||||
intersection_normals = self.compute_normals_from_intersection_origin_jacobian(jac, ray_dirs)
|
||||
losses["loss_normal_euclid"] = (
|
||||
(target_normals - intersection_normals).norm(dim=-1)
|
||||
).where(target_hits, zeros).mean()
|
||||
|
||||
if self.hparams.loss_multi_view_reg and torch.is_grad_enabled():
|
||||
assert ray_dirs .requires_grad, ray_dirs
|
||||
assert intersections.grad_fn is not None
|
||||
grads = diff.gradients(
|
||||
intersections[..., 0],
|
||||
intersections[..., 1],
|
||||
intersections[..., 2],
|
||||
wrt=ray_dirs,
|
||||
)
|
||||
losses["loss_multi_view_reg"] = sum(
|
||||
i.pow(2).sum(dim=-1)
|
||||
for i in grads
|
||||
).where(target_hits, zeros).mean()
|
||||
|
||||
metrics["iou"] = (
|
||||
((~target_miss) & (is_intersecting>0.5).detach()).sum() /
|
||||
((~target_miss) | (is_intersecting>0.5).detach()).sum()
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(self.hparams.output_mode)
|
||||
|
||||
# output losses and metrics
|
||||
|
||||
# apply scaling:
|
||||
losses_unscaled = losses.copy() # shallow copy
|
||||
for k in list(losses.keys()):
|
||||
assert losses[k].numel() == 1, f"losses[{k!r}] shape: {losses[k].shape}"
|
||||
val_schedule: HParamSchedule = self.hparams[k]
|
||||
val = val_schedule.get(self)
|
||||
if val == 0:
|
||||
if (__debug__ or LOG_ALL_METRICS) and val_schedule.is_const:
|
||||
del losses[k] # it was only added for unscaled logging, do not backprop
|
||||
else:
|
||||
losses[k] = 0
|
||||
elif val != 1:
|
||||
losses[k] = losses[k] * val
|
||||
|
||||
if not losses:
|
||||
raise MisconfigurationException("no loss was computed")
|
||||
|
||||
losses["loss"] = sum(losses.values()) * self.hparams.opt_warmup.get(self)
|
||||
losses.update({f"unscaled_{k}": v.detach() for k, v in losses_unscaled.items()})
|
||||
losses.update({f"metric_{k}": v.detach() for k, v in metrics.items()})
|
||||
return losses
|
||||
|
||||
|
||||
# used by pl.callbacks.EarlyStopping, via cli.py
|
||||
@property
|
||||
def metric_early_stop(self): return (
|
||||
"unscaled_loss_intersection_proj"
|
||||
if self.hparams.output_mode == "medial_sphere" else
|
||||
"unscaled_loss_intersection"
|
||||
)
|
||||
|
||||
def validation_step(self, batch: TrainingBatch, batch_idx: int) -> dict[str, Tensor]:
|
||||
losses = self.training_step(batch, batch_idx, is_validation=True)
|
||||
return losses
|
||||
|
||||
def configure_optimizers(self):
|
||||
adam = torch.optim.Adam(self.parameters(),
|
||||
lr=1 if not self.hparams.opt_learning_rate.is_const else self.hparams.opt_learning_rate.get_train_value(0),
|
||||
weight_decay=self.hparams.opt_weight_decay)
|
||||
schedules = []
|
||||
if not self.hparams.opt_learning_rate.is_const:
|
||||
schedules = [
|
||||
torch.optim.lr_scheduler.LambdaLR(adam,
|
||||
lambda epoch: self.hparams.opt_learning_rate.get_train_value(epoch),
|
||||
),
|
||||
]
|
||||
return [adam], schedules
|
||||
|
||||
@property
|
||||
def example_input_array(self) -> tuple[dict[str, Tensor], Tensor]:
|
||||
return (
|
||||
{ # see self.batch2rays
|
||||
"origins" : torch.zeros(1, 3), # most commonly used
|
||||
"points" : torch.zeros(1, 3), # used if self.training and self.hparams.loss_multi_view_reg
|
||||
"dirs" : torch.ones(1, 3) * torch.rsqrt(torch.tensor(3)),
|
||||
},
|
||||
torch.ones(1, self.hparams.latent_features),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def compute_normals_from_intersection_origin_jacobian(origin_jac: Tensor, ray_dirs: Tensor) -> Tensor:
|
||||
normals = sum((
|
||||
torch.cross(origin_jac[..., 0], origin_jac[..., 1], dim=-1) * -ray_dirs[..., [2]],
|
||||
torch.cross(origin_jac[..., 1], origin_jac[..., 2], dim=-1) * -ray_dirs[..., [0]],
|
||||
torch.cross(origin_jac[..., 2], origin_jac[..., 0], dim=-1) * -ray_dirs[..., [1]],
|
||||
))
|
||||
return normals / normals.norm(dim=-1, keepdim=True)
|
||||
|
||||
|
||||
class IntersectionFieldAutoDecoderModel(IntersectionFieldModel, AutoDecoderModuleMixin):
|
||||
def encode(self, batch: LabeledBatch) -> Tensor:
|
||||
assert not isinstance(self.trainer.strategy, pl.strategies.DataParallelStrategy)
|
||||
return self[batch["z_uid"]] # [N, Z_n]
|
186
ifield/models/medial_atoms.py
Normal file
186
ifield/models/medial_atoms.py
Normal file
@ -0,0 +1,186 @@
|
||||
from .. import param
|
||||
from ..modules import fc
|
||||
from ..data.common import points
|
||||
from ..utils import geometry
|
||||
from ..utils.helpers import compose
|
||||
from textwrap import indent, dedent
|
||||
from torch import nn, Tensor
|
||||
from typing import Optional
|
||||
import torch
|
||||
import warnings
|
||||
|
||||
# generalize this into a HypoHyperConcat net? ConditionedNet?
|
||||
class MedialAtomNet(nn.Module):
|
||||
def __init__(self,
|
||||
in_features : int,
|
||||
latent_features : int,
|
||||
hidden_features : int,
|
||||
hidden_layers : int,
|
||||
n_atoms : int = 1,
|
||||
final_init_wrr : tuple[float, float] | None = (0.05, 0.6, 0.1),
|
||||
**kw,
|
||||
):
|
||||
super().__init__()
|
||||
assert n_atoms >= 1, n_atoms
|
||||
self.n_atoms = n_atoms
|
||||
|
||||
self.fc = fc.FCBlock(
|
||||
in_features = in_features,
|
||||
hidden_layers = hidden_layers,
|
||||
hidden_features = hidden_features,
|
||||
out_features = n_atoms * 4, # n_atoms * (x, y, z, r)
|
||||
outermost_linear = True,
|
||||
latent_features = latent_features,
|
||||
**kw,
|
||||
)
|
||||
|
||||
if final_init_wrr is not None:
|
||||
with torch.no_grad():
|
||||
w, r1, r2 = final_init_wrr
|
||||
if w != 1: self.fc[-1].linear.weight *= w
|
||||
dtype = self.fc[-1].linear.bias.dtype
|
||||
self.fc[-1].linear.bias[..., [4*n+i for n in range(n_atoms) for i in range(3)]] = torch.tensor(points.generate_random_sphere_points(n_atoms, radius=r1), dtype=dtype).flatten()
|
||||
self.fc[-1].linear.bias[..., 3::4] = r2
|
||||
|
||||
@property
|
||||
def is_conditioned(self):
|
||||
return self.fc.is_conditioned
|
||||
|
||||
@classmethod
|
||||
@compose("\n".join)
|
||||
def make_jinja_template(cls, *, exclude_list: set[str] = {}, top_level: bool = True, **kw) -> str:
|
||||
yield param.make_jinja_template(cls, top_level=top_level, exclude_list=exclude_list, **kw)
|
||||
yield fc.FCBlock.make_jinja_template(top_level=False, exclude_list={
|
||||
"in_features",
|
||||
"hidden_layers",
|
||||
"hidden_features",
|
||||
"out_features",
|
||||
"outermost_linear",
|
||||
"latent_features",
|
||||
})
|
||||
|
||||
def forward(self, x: Tensor, z: Optional[Tensor] = None):
|
||||
if __debug__ and self.is_conditioned and z is None:
|
||||
warnings.warn(f"{self.__class__.__qualname__} is conditioned, but the forward pass was not supplied with a conditioning tensor.")
|
||||
return self.fc(x, z)
|
||||
|
||||
def compute_intersections(self,
|
||||
ray_origins : Tensor, # (..., 3)
|
||||
ray_dirs : Tensor, # (..., 3)
|
||||
medial_atoms : Tensor, # (..., 4*self.n_atoms)
|
||||
*,
|
||||
intersections_only : bool = True,
|
||||
return_all_atoms : bool = False, # only applies if intersections_only=False
|
||||
allow_nans : bool = True,
|
||||
improve_miss_grads : bool = False,
|
||||
) -> tuple[(Tensor,)*5]:
|
||||
assert ray_origins.shape[:-1] == ray_dirs.shape[:-1] == medial_atoms.shape[:-1], \
|
||||
(ray_origins.shape, ray_dirs.shape, medial_atoms.shape)
|
||||
assert medial_atoms.shape[-1] % 4 == 0, \
|
||||
medial_atoms.shape
|
||||
assert ray_origins.shape[-1] == ray_dirs.shape[-1] == 3, \
|
||||
(ray_origins.shape, ray_dirs.shape)
|
||||
|
||||
#n_atoms = medial_atoms.shape[-1] // 4
|
||||
n_atoms = medial_atoms.shape[-1] >> 2
|
||||
|
||||
# reshape (..., n_atoms * d) to (..., n_atoms, d)
|
||||
medial_atoms = medial_atoms.view(*medial_atoms.shape[:-1], n_atoms, 4)
|
||||
ray_origins = ray_origins.unsqueeze(-2).broadcast_to([*ray_origins.shape[:-1], n_atoms, 3])
|
||||
ray_dirs = ray_dirs .unsqueeze(-2).broadcast_to([*ray_dirs .shape[:-1], n_atoms, 3])
|
||||
|
||||
# unpack atoms
|
||||
sphere_centers = medial_atoms[..., :3]
|
||||
sphere_radii = medial_atoms[..., 3].abs()
|
||||
|
||||
assert not ray_origins .detach().isnan().any()
|
||||
assert not ray_dirs .detach().isnan().any()
|
||||
assert not sphere_centers.detach().isnan().any()
|
||||
assert not sphere_radii .detach().isnan().any()
|
||||
|
||||
# compute intersections
|
||||
(
|
||||
sphere_center_projs, # (..., 3)
|
||||
intersections_near, # (..., 3)
|
||||
intersections_far, # (..., 3)
|
||||
is_intersecting, # (...) bool
|
||||
) = geometry.ray_sphere_intersect(
|
||||
ray_origins,
|
||||
ray_dirs,
|
||||
sphere_centers,
|
||||
sphere_radii,
|
||||
return_parts = True,
|
||||
allow_nans = allow_nans,
|
||||
improve_miss_grads = improve_miss_grads,
|
||||
)
|
||||
|
||||
# early return
|
||||
if intersections_only and n_atoms == 1:
|
||||
return intersections_near.squeeze(-2), is_intersecting.squeeze(-1)
|
||||
|
||||
# compute how close each hit and miss are
|
||||
depths = ((intersections_near - ray_origins) * ray_dirs).sum(-1)
|
||||
silhouettes = torch.linalg.norm(sphere_center_projs - sphere_centers, dim=-1) - sphere_radii
|
||||
|
||||
if return_all_atoms:
|
||||
intersections_near_all = intersections_near
|
||||
depths_all = depths
|
||||
silhouettes_all = silhouettes
|
||||
is_intersecting_all = is_intersecting
|
||||
sphere_centers_all = sphere_centers
|
||||
sphere_radii_all = sphere_radii
|
||||
|
||||
# collapse n_atoms
|
||||
if n_atoms > 1:
|
||||
atom_indices = torch.where(is_intersecting.any(dim=-1, keepdim=True),
|
||||
torch.where(is_intersecting, depths.detach(), depths.detach()+100).argmin(dim=-1, keepdim=True),
|
||||
silhouettes.detach().argmin(dim=-1, keepdim=True),
|
||||
)
|
||||
|
||||
intersections_near = intersections_near.take_along_dim(atom_indices[..., None], -2).squeeze(-2)
|
||||
depths = depths .take_along_dim(atom_indices, -1).squeeze(-1)
|
||||
silhouettes = silhouettes .take_along_dim(atom_indices, -1).squeeze(-1)
|
||||
is_intersecting = is_intersecting .take_along_dim(atom_indices, -1).squeeze(-1)
|
||||
sphere_centers = sphere_centers .take_along_dim(atom_indices[..., None], -2).squeeze(-2)
|
||||
sphere_radii = sphere_radii .take_along_dim(atom_indices, -1).squeeze(-1)
|
||||
else:
|
||||
atom_indices = None
|
||||
intersections_near = intersections_near.squeeze(-2)
|
||||
depths = depths .squeeze(-1)
|
||||
silhouettes = silhouettes .squeeze(-1)
|
||||
is_intersecting = is_intersecting .squeeze(-1)
|
||||
sphere_centers = sphere_centers .squeeze(-2)
|
||||
sphere_radii = sphere_radii .squeeze(-1)
|
||||
|
||||
# early return
|
||||
if intersections_only:
|
||||
return intersections_near, is_intersecting
|
||||
|
||||
# compute sphere normals
|
||||
intersection_normals = intersections_near - sphere_centers
|
||||
intersection_normals = intersection_normals / (intersection_normals.norm(dim=-1)[..., None] + 1e-9)
|
||||
|
||||
if return_all_atoms:
|
||||
intersection_normals_all = intersections_near_all - sphere_centers_all
|
||||
intersection_normals_all = intersection_normals_all / (intersection_normals_all.norm(dim=-1)[..., None] + 1e-9)
|
||||
|
||||
|
||||
return (
|
||||
depths, # (...) valid if hit, based on 'intersections'
|
||||
silhouettes, # (...) always valid
|
||||
intersections_near, # (..., 3) valid if hit, projection if not
|
||||
intersection_normals, # (..., 3) valid if hit, rejection if not
|
||||
is_intersecting, # (...) dtype=bool
|
||||
sphere_centers, # (..., 3) network output
|
||||
sphere_radii, # (...) network output
|
||||
*(() if not return_all_atoms else (
|
||||
|
||||
atom_indices,
|
||||
intersections_near_all, # (..., N_ATOMS) valid if hit, based on 'intersections'
|
||||
intersection_normals_all, # (..., N_ATOMS, 3) valid if hit, rejection if not
|
||||
depths_all, # (..., N_ATOMS) always valid
|
||||
silhouettes_all, # (..., N_ATOMS, 3) valid if hit, projection if not
|
||||
is_intersecting_all, # (..., N_ATOMS) dtype=bool
|
||||
sphere_centers_all, # (..., N_ATOMS, 3) network output
|
||||
sphere_radii_all, # (..., N_ATOMS) network output
|
||||
)))
|
101
ifield/models/orthogonal_plane.py
Normal file
101
ifield/models/orthogonal_plane.py
Normal file
@ -0,0 +1,101 @@
|
||||
from .. import param
|
||||
from ..modules import fc
|
||||
from ..utils import geometry
|
||||
from ..utils.helpers import compose
|
||||
from textwrap import indent, dedent
|
||||
from torch import nn, Tensor
|
||||
from typing import Optional
|
||||
import warnings
|
||||
|
||||
class OrthogonalPlaneNet(nn.Module):
|
||||
"""
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features : int,
|
||||
latent_features : int,
|
||||
hidden_features : int,
|
||||
hidden_layers : int,
|
||||
**kw,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.fc = fc.FCBlock(
|
||||
in_features = in_features,
|
||||
hidden_layers = hidden_layers,
|
||||
hidden_features = hidden_features,
|
||||
out_features = 2, # (plane_offset, is_intersecting)
|
||||
outermost_linear = True,
|
||||
latent_features = latent_features,
|
||||
**kw,
|
||||
)
|
||||
|
||||
@property
|
||||
def is_conditioned(self):
|
||||
return self.fc.is_conditioned
|
||||
|
||||
@classmethod
|
||||
@compose("\n".join)
|
||||
def make_jinja_template(cls, *, exclude_list: set[str] = {}, top_level: bool = True, **kw) -> str:
|
||||
yield param.make_jinja_template(cls, top_level=top_level, exclude_list=exclude_list, **kw)
|
||||
yield param.make_jinja_template(fc.FCBlock, top_level=False, exclude_list={
|
||||
"in_features",
|
||||
"hidden_layers",
|
||||
"hidden_features",
|
||||
"out_features",
|
||||
"outermost_linear",
|
||||
})
|
||||
|
||||
def forward(self, x: Tensor, z: Optional[Tensor] = None) -> Tensor:
|
||||
if __debug__ and self.is_conditioned and z is None:
|
||||
warnings.warn(f"{self.__class__.__qualname__} is conditioned, but the forward pass was not supplied with a conditioning tensor.")
|
||||
return self.fc(x, z)
|
||||
|
||||
@staticmethod
|
||||
def compute_intersections(
|
||||
ray_origins : Tensor, # (..., 3)
|
||||
ray_dirs : Tensor, # (..., 3)
|
||||
predictions : Tensor, # (..., 2)
|
||||
*,
|
||||
normalize_origins = True,
|
||||
return_signed_displacements = False,
|
||||
allow_nans = False, # MARF compat
|
||||
atom_random_prob = None, # MARF compat
|
||||
atom_dropout_prob = None, # MARF compat
|
||||
) -> tuple[(Tensor,)*5]:
|
||||
assert ray_origins.shape[:-1] == ray_dirs.shape[:-1] == predictions.shape[:-1], \
|
||||
(ray_origins.shape, ray_dirs.shape, predictions.shape)
|
||||
assert predictions.shape[-1] == 2, \
|
||||
predictions.shape
|
||||
|
||||
assert not allow_nans
|
||||
|
||||
if normalize_origins:
|
||||
ray_origins = geometry.project_point_on_ray(0, ray_origins, ray_dirs)
|
||||
|
||||
# unpack predictions
|
||||
signed_displacements = predictions[..., 0]
|
||||
is_intersecting = predictions[..., 1]
|
||||
|
||||
# compute intersections
|
||||
intersections = ray_origins - signed_displacements[..., None] * ray_dirs
|
||||
|
||||
return (
|
||||
intersections,
|
||||
is_intersecting,
|
||||
*((signed_displacements,) if return_signed_displacements else ()),
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
OrthogonalPlaneNet.__doc__ = __doc__ = f"""
|
||||
{dedent(OrthogonalPlaneNet.__doc__).strip()}
|
||||
|
||||
# Config template:
|
||||
|
||||
```yaml
|
||||
{OrthogonalPlaneNet.make_jinja_template()}
|
||||
```
|
||||
"""
|
3
ifield/modules/__init__.py
Normal file
3
ifield/modules/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
__doc__ = """
|
||||
Contains Pytorch Modules
|
||||
"""
|
22
ifield/modules/dtype.py
Normal file
22
ifield/modules/dtype.py
Normal file
@ -0,0 +1,22 @@
|
||||
import pytorch_lightning as pl
|
||||
|
||||
|
||||
class DtypeMixin:
|
||||
def __init_subclass__(cls):
|
||||
assert issubclass(cls, pl.LightningModule), \
|
||||
f"{cls.__name__!r} is not a subclass of 'pytorch_lightning.LightningModule'!"
|
||||
|
||||
@property
|
||||
def device_and_dtype(self) -> dict:
|
||||
"""
|
||||
Examples:
|
||||
```
|
||||
torch.tensor(1337, **self.device_and_dtype)
|
||||
some_tensor.to(**self.device_and_dtype)
|
||||
```
|
||||
"""
|
||||
|
||||
return {
|
||||
"dtype": self.dtype,
|
||||
"device": self.device,
|
||||
}
|
424
ifield/modules/fc.py
Normal file
424
ifield/modules/fc.py
Normal file
@ -0,0 +1,424 @@
|
||||
from . import siren
|
||||
from .. import param
|
||||
from ..utils.helpers import compose, run_length_encode, MetaModuleProxy
|
||||
from collections import OrderedDict
|
||||
from pytorch_lightning.core.mixins import HyperparametersMixin
|
||||
from torch import nn, Tensor
|
||||
from torch.nn.utils.weight_norm import WeightNorm
|
||||
from torchmeta.modules import MetaModule, MetaSequential
|
||||
from typing import Iterable, Literal, Optional, Union, Callable
|
||||
import itertools
|
||||
import math
|
||||
import torch
|
||||
|
||||
__doc__ = """
|
||||
`fc` is short for "Fully Connected"
|
||||
"""
|
||||
|
||||
def broadcast_tensors_except(*tensors: Tensor, dim: int) -> list[Tensor]:
|
||||
if dim == -1:
|
||||
shapes = [ i.shape[:dim] for i in tensors ]
|
||||
else:
|
||||
shapes = [ (*i.shape[:dim], i.shape[dim+1:]) for i in tensors ]
|
||||
target_shape = list(torch.broadcast_shapes(*shapes))
|
||||
if dim == -1:
|
||||
target_shape.append(-1)
|
||||
elif dim < 0:
|
||||
target_shape.insert(dim+1, -1)
|
||||
else:
|
||||
target_shape.insert(dim, -1)
|
||||
|
||||
return [ i.broadcast_to(target_shape) for i in tensors ]
|
||||
|
||||
|
||||
EPS = 1e-8
|
||||
|
||||
Nonlinearity = Literal[
|
||||
None,
|
||||
"relu",
|
||||
"leaky_relu",
|
||||
"silu",
|
||||
"softplus",
|
||||
"elu",
|
||||
"selu",
|
||||
"sine",
|
||||
"sigmoid",
|
||||
"tanh",
|
||||
]
|
||||
|
||||
Normalization = Literal[
|
||||
None,
|
||||
"batchnorm",
|
||||
"batchnorm_na",
|
||||
"layernorm",
|
||||
"layernorm_na",
|
||||
"weightnorm",
|
||||
]
|
||||
|
||||
class ReprHyperparametersMixin(HyperparametersMixin):
|
||||
def extra_repr(self):
|
||||
this = ", ".join(f"{k}={v!r}" for k, v in self.hparams.items())
|
||||
rest = super().extra_repr()
|
||||
if rest:
|
||||
return f"{this}, {rest}"
|
||||
else:
|
||||
return this
|
||||
|
||||
class MultilineReprHyperparametersMixin(HyperparametersMixin):
|
||||
def extra_repr(self):
|
||||
items = [f"{k}={v!r}" for k, v in self.hparams.items()]
|
||||
this = "\n".join(
|
||||
", ".join(filter(bool, i)) + ","
|
||||
for i in itertools.zip_longest(items[0::3], items[1::3], items[2::3])
|
||||
)
|
||||
rest = super().extra_repr()
|
||||
if rest:
|
||||
return f"{this}, {rest}"
|
||||
else:
|
||||
return this
|
||||
|
||||
|
||||
class BatchLinear(nn.Linear):
|
||||
"""
|
||||
A linear (meta-)layer that can deal with batched weight matrices and biases,
|
||||
as for instance output by a hypernetwork.
|
||||
"""
|
||||
__doc__ = nn.Linear.__doc__
|
||||
_meta_forward_pre_hooks = None
|
||||
|
||||
def register_forward_pre_hook(self, hook: Callable) -> torch.utils.hooks.RemovableHandle:
|
||||
if not isinstance(hook, WeightNorm):
|
||||
return super().register_forward_pre_hook(hook)
|
||||
|
||||
if self._meta_forward_pre_hooks is None:
|
||||
self._meta_forward_pre_hooks = OrderedDict()
|
||||
|
||||
handle = torch.utils.hooks.RemovableHandle(self._meta_forward_pre_hooks)
|
||||
self._meta_forward_pre_hooks[handle.id] = hook
|
||||
return handle
|
||||
|
||||
def forward(self, input: Tensor, params: Optional[dict[str, Tensor]]=None):
|
||||
if params is None or not isinstance(self, MetaModule):
|
||||
params = OrderedDict(self.named_parameters())
|
||||
if self._meta_forward_pre_hooks is not None:
|
||||
proxy = MetaModuleProxy(self, params)
|
||||
for hook in self._meta_forward_pre_hooks.values():
|
||||
hook(proxy, [input])
|
||||
|
||||
weight = params["weight"]
|
||||
bias = params.get("bias", None)
|
||||
|
||||
# transpose weights
|
||||
weight = weight.permute(*range(len(weight.shape) - 2), -1, -2) # does not jit
|
||||
|
||||
output = input.unsqueeze(-2).matmul(weight).squeeze(-2)
|
||||
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class MetaBatchLinear(BatchLinear, MetaModule):
|
||||
pass
|
||||
|
||||
|
||||
class CallbackConcatLayer(nn.Module):
|
||||
"A tricky way to enable skip connections in sequentials models"
|
||||
def __init__(self, tensor_getter: Callable[[], tuple[Tensor, ...]]):
|
||||
super().__init__()
|
||||
self.tensor_getter = tensor_getter
|
||||
|
||||
def forward(self, x):
|
||||
ys = self.tensor_getter()
|
||||
return torch.cat(broadcast_tensors_except(x, *ys, dim=-1), dim=-1)
|
||||
|
||||
|
||||
class ResidualSkipConnectionEndLayer(nn.Module):
|
||||
"""
|
||||
Residual skip connections that can be added to a nn.Sequential
|
||||
"""
|
||||
|
||||
class ResidualSkipConnectionStartLayer(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._stored_tensor = None
|
||||
|
||||
def forward(self, x):
|
||||
assert self._stored_tensor is None
|
||||
self._stored_tensor = x
|
||||
return x
|
||||
|
||||
def get(self):
|
||||
assert self._stored_tensor is not None
|
||||
x = self._stored_tensor
|
||||
self._stored_tensor = None
|
||||
return x
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._stored_tensor = None
|
||||
self._start = self.ResidualSkipConnectionStartLayer()
|
||||
|
||||
def forward(self, x):
|
||||
skip = self._start.get()
|
||||
return x + skip
|
||||
|
||||
@property
|
||||
def start(self) -> ResidualSkipConnectionStartLayer:
|
||||
return self._start
|
||||
|
||||
@property
|
||||
def end(self) -> "ResidualSkipConnectionEndLayer":
|
||||
return self
|
||||
|
||||
|
||||
ResidualMode = Literal[
|
||||
None,
|
||||
"identity",
|
||||
]
|
||||
|
||||
class FCLayer(MultilineReprHyperparametersMixin, MetaSequential):
|
||||
"""
|
||||
A single fully connected (FC) layer
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features : int,
|
||||
out_features : int,
|
||||
*,
|
||||
nonlinearity : Nonlinearity = "relu",
|
||||
normalization : Normalization = None,
|
||||
is_first : bool = False, # used for SIREN initialization
|
||||
is_final : bool = False, # used for fan_out init
|
||||
dropout_prob : float = 0.0,
|
||||
negative_slope : float = 0.01, # only for nonlinearity="leaky_relu", default is normally 0.01
|
||||
omega_0 : float = 30, # only for nonlinearity="sine"
|
||||
residual_mode : ResidualMode = None,
|
||||
_no_meta : bool = False, # set to true in hypernetworks
|
||||
**_
|
||||
):
|
||||
super().__init__()
|
||||
self.save_hyperparameters()
|
||||
|
||||
# improve repr
|
||||
if nonlinearity != "leaky_relu":
|
||||
self.hparams.pop("negative_slope")
|
||||
if nonlinearity != "sine":
|
||||
self.hparams.pop("omega_0")
|
||||
|
||||
Linear = nn.Linear if _no_meta else MetaBatchLinear
|
||||
|
||||
def make_layer() -> Iterable[nn.Module]:
|
||||
# residual start
|
||||
if residual_mode is not None:
|
||||
residual_layer = ResidualSkipConnectionEndLayer()
|
||||
yield "res_a", residual_layer.start
|
||||
|
||||
linear = Linear(in_features, out_features)
|
||||
|
||||
# initialize
|
||||
if nonlinearity in {"relu", "leaky_relu", "silu", "softplus"}:
|
||||
nn.init.kaiming_uniform_(linear.weight, a=negative_slope, nonlinearity=nonlinearity, mode="fan_in" if not is_final else "fan_out")
|
||||
elif nonlinearity == "elu":
|
||||
nn.init.normal_(linear.weight, std=math.sqrt(1.5505188080679277) / math.sqrt(linear.weight.size(-1)))
|
||||
elif nonlinearity == "selu":
|
||||
nn.init.normal_(linear.weight, std=1 / math.sqrt(linear.weight.size(-1)))
|
||||
elif nonlinearity == "sine":
|
||||
siren.init_weights_(linear, omega_0, is_first)
|
||||
elif nonlinearity in {"sigmoid", "tanh"}:
|
||||
nn.init.xavier_normal_(linear.weight)
|
||||
elif nonlinearity is None:
|
||||
pass # this is effectively uniform(-1/sqrt(in_features), 1/sqrt(in_features))
|
||||
else:
|
||||
raise NotImplementedError(nonlinearity)
|
||||
|
||||
# linear + normalize
|
||||
if normalization is None:
|
||||
yield "linear", linear
|
||||
elif normalization == "batchnorm":
|
||||
yield "linear", linear
|
||||
yield "norm", nn.BatchNorm1d(out_features, affine=True)
|
||||
elif normalization == "batchnorm_na":
|
||||
yield "linear", linear
|
||||
yield "norm", nn.BatchNorm1d(out_features, affine=False)
|
||||
elif normalization == "layernorm":
|
||||
yield "linear", linear
|
||||
yield "norm", nn.LayerNorm([out_features], elementwise_affine=True)
|
||||
elif normalization == "layernorm_na":
|
||||
yield "linear", linear
|
||||
yield "norm", nn.LayerNorm([out_features], elementwise_affine=False)
|
||||
elif normalization == "weightnorm":
|
||||
yield "linear", nn.utils.weight_norm(linear)
|
||||
else:
|
||||
raise NotImplementedError(normalization)
|
||||
|
||||
# activation
|
||||
inplace = False
|
||||
if nonlinearity is None : pass
|
||||
elif nonlinearity == "relu" : yield nonlinearity, nn.ReLU(inplace=inplace)
|
||||
elif nonlinearity == "leaky_relu" : yield nonlinearity, nn.LeakyReLU(negative_slope=negative_slope, inplace=inplace)
|
||||
elif nonlinearity == "silu" : yield nonlinearity, nn.SiLU(inplace=inplace)
|
||||
elif nonlinearity == "softplus" : yield nonlinearity, nn.Softplus()
|
||||
elif nonlinearity == "elu" : yield nonlinearity, nn.ELU(inplace=inplace)
|
||||
elif nonlinearity == "selu" : yield nonlinearity, nn.SELU(inplace=inplace)
|
||||
elif nonlinearity == "sine" : yield nonlinearity, siren.Sine(omega_0)
|
||||
elif nonlinearity == "sigmoid" : yield nonlinearity, nn.Sigmoid()
|
||||
elif nonlinearity == "tanh" : yield nonlinearity, nn.Tanh()
|
||||
else : raise NotImplementedError(f"{nonlinearity=}")
|
||||
|
||||
# dropout
|
||||
if dropout_prob > 0:
|
||||
if nonlinearity == "selu":
|
||||
yield "adropout", nn.AlphaDropout(p=dropout_prob)
|
||||
else:
|
||||
yield "dropout", nn.Dropout(p=dropout_prob)
|
||||
|
||||
# residual end
|
||||
if residual_mode is not None:
|
||||
yield "res_b", residual_layer.end
|
||||
|
||||
for name, module in make_layer():
|
||||
self.add_module(name.replace("-", "_"), module)
|
||||
|
||||
@property
|
||||
def nonlinearity(self) -> Optional[nn.Module]:
|
||||
"alias to the activation function submodule"
|
||||
if self.hparams.nonlinearity is None:
|
||||
return None
|
||||
return getattr(self, self.hparams.nonlinearity.replace("-", "_"))
|
||||
|
||||
def initialize_weights():
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class FCBlock(MultilineReprHyperparametersMixin, MetaSequential):
|
||||
"""
|
||||
A block of FC layers
|
||||
"""
|
||||
def __init__(self,
|
||||
in_features : int,
|
||||
hidden_features : int,
|
||||
hidden_layers : int,
|
||||
out_features : int,
|
||||
normalization : Normalization = None,
|
||||
nonlinearity : Nonlinearity = "relu",
|
||||
dropout_prob : float = 0.0,
|
||||
outermost_linear : bool = True, # whether last linear is nonlinear
|
||||
latent_features : Optional[int] = None,
|
||||
concat_skipped_layers : Union[list[int], bool] = [],
|
||||
concat_conditioned_layers : Union[list[int], bool] = [],
|
||||
**kw,
|
||||
):
|
||||
super().__init__()
|
||||
self.save_hyperparameters()
|
||||
|
||||
if isinstance(concat_skipped_layers, bool):
|
||||
concat_skipped_layers = list(range(hidden_layers+2)) if concat_skipped_layers else []
|
||||
if isinstance(concat_conditioned_layers, bool):
|
||||
concat_conditioned_layers = list(range(hidden_layers+2)) if concat_conditioned_layers else []
|
||||
if len(concat_conditioned_layers) != 0 and latent_features is None:
|
||||
raise ValueError("Layers marked to be conditioned without known number of latent features")
|
||||
concat_skipped_layers = [i if i >= 0 else hidden_layers+2-abs(i) for i in concat_skipped_layers]
|
||||
concat_conditioned_layers = [i if i >= 0 else hidden_layers+2-abs(i) for i in concat_conditioned_layers]
|
||||
self._concat_x_layers: frozenset[int] = frozenset(concat_skipped_layers)
|
||||
self._concat_z_layers: frozenset[int] = frozenset(concat_conditioned_layers)
|
||||
if len(self._concat_x_layers) != len(concat_skipped_layers):
|
||||
raise ValueError(f"Duplicates found in {concat_skipped_layers = }")
|
||||
if len(self._concat_z_layers) != len(concat_conditioned_layers):
|
||||
raise ValueError(f"Duplicates found in {concat_conditioned_layers = }")
|
||||
if not all(isinstance(i, int) for i in self._concat_x_layers):
|
||||
raise TypeError(f"Expected only integers in {concat_skipped_layers = }")
|
||||
if not all(isinstance(i, int) for i in self._concat_z_layers):
|
||||
raise TypeError(f"Expected only integers in {concat_conditioned_layers = }")
|
||||
|
||||
def make_layers() -> Iterable[nn.Module]:
|
||||
def make_concat_layer(*idxs: int) -> int:
|
||||
x_condition_this_layer = any(idx in self._concat_x_layers for idx in idxs)
|
||||
z_condition_this_layer = any(idx in self._concat_z_layers for idx in idxs)
|
||||
if x_condition_this_layer and z_condition_this_layer:
|
||||
yield CallbackConcatLayer(lambda: (self._current_x, self._current_z))
|
||||
elif x_condition_this_layer:
|
||||
yield CallbackConcatLayer(lambda: (self._current_x,))
|
||||
elif z_condition_this_layer:
|
||||
yield CallbackConcatLayer(lambda: (self._current_z,))
|
||||
|
||||
return in_features*x_condition_this_layer + (latent_features or 0)*z_condition_this_layer
|
||||
|
||||
added = yield from make_concat_layer(0)
|
||||
|
||||
yield FCLayer(
|
||||
in_features = in_features + added,
|
||||
out_features = hidden_features,
|
||||
nonlinearity = nonlinearity,
|
||||
normalization = normalization,
|
||||
dropout_prob = dropout_prob,
|
||||
is_first = True,
|
||||
is_final = False,
|
||||
**kw,
|
||||
)
|
||||
|
||||
for i in range(hidden_layers):
|
||||
added = yield from make_concat_layer(i+1)
|
||||
|
||||
yield FCLayer(
|
||||
in_features = hidden_features + added,
|
||||
out_features = hidden_features,
|
||||
nonlinearity = nonlinearity,
|
||||
normalization = normalization,
|
||||
dropout_prob = dropout_prob,
|
||||
is_first = False,
|
||||
is_final = False,
|
||||
**kw,
|
||||
)
|
||||
|
||||
added = yield from make_concat_layer(hidden_layers+1)
|
||||
|
||||
nl = nonlinearity
|
||||
|
||||
yield FCLayer(
|
||||
in_features = hidden_features + added,
|
||||
out_features = out_features,
|
||||
nonlinearity = None if outermost_linear else nl,
|
||||
normalization = None if outermost_linear else normalization,
|
||||
dropout_prob = 0.0 if outermost_linear else dropout_prob,
|
||||
is_first = False,
|
||||
is_final = True,
|
||||
**kw,
|
||||
)
|
||||
|
||||
for i, module in enumerate(make_layers()):
|
||||
self.add_module(str(i), module)
|
||||
|
||||
@property
|
||||
def is_conditioned(self) -> bool:
|
||||
"Whether z is used or not"
|
||||
return bool(self._concat_z_layers)
|
||||
|
||||
@classmethod
|
||||
@compose("\n".join)
|
||||
def make_jinja_template(cls, *, exclude_list: set[str] = {}, top_level: bool = True, **kw) -> str:
|
||||
@compose(" ".join)
|
||||
def as_jexpr(values: Union[list[int]]):
|
||||
yield "{{"
|
||||
for val, count in run_length_encode(values):
|
||||
yield f"[{val!r}]*{count!r}"
|
||||
yield "}}"
|
||||
yield param.make_jinja_template(cls, top_level=top_level, exclude_list=exclude_list)
|
||||
yield param.make_jinja_template(FCLayer, top_level=False, exclude_list=exclude_list | {
|
||||
"in_features",
|
||||
"out_features",
|
||||
"nonlinearity",
|
||||
"normalization",
|
||||
"dropout_prob",
|
||||
"is_first",
|
||||
"is_final",
|
||||
})
|
||||
|
||||
def forward(self, input: Tensor, z: Optional[Tensor] = None, *, params: Optional[dict[str, Tensor]]=None):
|
||||
assert not self.is_conditioned or z is not None
|
||||
if z is not None and z.ndim < input.ndim:
|
||||
z = z[(*(None,)*(input.ndim - z.ndim), ...)]
|
||||
self._current_x = input
|
||||
self._current_z = z
|
||||
return super().forward(input, params=params)
|
25
ifield/modules/siren.py
Normal file
25
ifield/modules/siren.py
Normal file
@ -0,0 +1,25 @@
|
||||
from math import sqrt
|
||||
from torch import nn
|
||||
import torch
|
||||
|
||||
class Sine(nn.Module):
|
||||
def __init__(self, omega_0: float):
|
||||
super().__init__()
|
||||
self.omega_0 = omega_0
|
||||
|
||||
def forward(self, input):
|
||||
if self.omega_0 == 1:
|
||||
return torch.sin(input)
|
||||
else:
|
||||
return torch.sin(input * self.omega_0)
|
||||
|
||||
|
||||
def init_weights_(module: nn.Linear, omega_0: float, is_first: bool = True):
|
||||
assert isinstance(module, nn.Linear), module
|
||||
with torch.no_grad():
|
||||
mag = (
|
||||
1 / module.in_features
|
||||
if is_first else
|
||||
sqrt(6 / module.in_features) / omega_0
|
||||
)
|
||||
module.weight.uniform_(-mag, mag)
|
231
ifield/param.py
Normal file
231
ifield/param.py
Normal file
@ -0,0 +1,231 @@
|
||||
from .utils.helpers import compose, elementwise_max
|
||||
from datetime import datetime
|
||||
from torch import nn
|
||||
from typing import Any, Literal, Iterable, Union, Callable, Optional
|
||||
import inspect
|
||||
import jinja2
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import shlex
|
||||
import string
|
||||
import sys
|
||||
import time
|
||||
import typing
|
||||
import warnings
|
||||
import yaml
|
||||
|
||||
_UNDEFINED = " I AM UNDEFINED "
|
||||
|
||||
def _yaml_encode_value(val) -> str:
|
||||
if isinstance(val, tuple):
|
||||
val = list(val)
|
||||
elif isinstance(val, set):
|
||||
val = list(val)
|
||||
if isinstance(val, list):
|
||||
return json.dumps(val)
|
||||
elif isinstance(val, dict):
|
||||
return json.dumps(val)
|
||||
else:
|
||||
return yaml.dump(val).removesuffix("\n...\n").rstrip("\n")
|
||||
|
||||
def _raise(val: Union[Exception, str]):
|
||||
if isinstance(val, str):
|
||||
val = jinja2.TemplateError(val)
|
||||
raise val
|
||||
|
||||
def make_jinja_globals(*, enable_require_defined: bool) -> dict:
|
||||
import builtins
|
||||
import functools
|
||||
import itertools
|
||||
import operator
|
||||
import json
|
||||
|
||||
def require_defined(name, value, *defaults, failed: bool = False, strict: bool=False, exchaustive=False):
|
||||
if not defaults:
|
||||
raise ValueError("`require_defined` requires at least one valid value provided")
|
||||
if jinja2.is_undefined(value):
|
||||
assert value._undefined_name == name, \
|
||||
f"Name mismatch: {value._undefined_name=}, {name=}"
|
||||
if failed or jinja2.is_undefined(value):
|
||||
if enable_require_defined or strict:
|
||||
raise ValueError(
|
||||
f"Required variable {name!r} "
|
||||
f"is {'incorrect' if failed else 'undefined'}! "
|
||||
f"Try providing:\n" + "\n".join(
|
||||
f"-O{shlex.quote(name)}={shlex.quote(str(default))}"
|
||||
for default in defaults
|
||||
)
|
||||
)
|
||||
else:
|
||||
warnings.warn(
|
||||
f"Required variable {name!r} "
|
||||
f"is {'incorrect' if failed else 'undefined'}! "
|
||||
f"Try providing:\n" + "\n".join(
|
||||
f"-O{shlex.quote(name)}={shlex.quote(str(default))}"
|
||||
for default in defaults
|
||||
)
|
||||
)
|
||||
if exchaustive and not jinja2.is_undefined(value) and value not in defaults:
|
||||
raise ValueError(
|
||||
f"Variable {name!r} not in list of allowed values: {defaults!r}"
|
||||
)
|
||||
|
||||
def gen_run_uid(n: int, _choice = random.Random(time.time_ns()).choice):
|
||||
"""
|
||||
generates a UID for the experiment run, nice for regexes, grepping and timekeeping.
|
||||
"""
|
||||
# we have _choice, since most likely, pl.seed_everything has been run by this point
|
||||
# we store it as a default parameter to reuse it, on the off-chance of two calls to this function being run withion the same ns
|
||||
code = ''.join(_choice(string.ascii_lowercase) for _ in range(n))
|
||||
return f"{datetime.now():%Y-%m-%d-%H%M}-{code}"
|
||||
return f"{datetime.now():%Y%m%d-%H%M}-{code}"
|
||||
|
||||
def cartesian_hparams(_map=None, **kw: dict[str, list]) -> Iterable[jinja2.utils.Namespace]:
|
||||
"Use this to bypass the common error 'SyntaxError: too many statically nested blocks'"
|
||||
if isinstance(_map, jinja2.utils.Namespace):
|
||||
kw = _map._Namespace__attrs | kw
|
||||
elif isinstance(_map, dict):
|
||||
kw = _map._Namespace__attrs | kw
|
||||
keys, vals = zip(*kw.items())
|
||||
for i in itertools.product(*vals):
|
||||
yield jinja2.utils.Namespace(zip(keys, i))
|
||||
|
||||
def ablation_hparams(_map=None, *, caartesian_keys: list[str] = None, **kw: dict[str, list]) -> Iterable[jinja2.utils.Namespace]:
|
||||
"Use this to bypass the common error 'SyntaxError: too many statically nested blocks'"
|
||||
if isinstance(_map, jinja2.utils.Namespace):
|
||||
kw = _map._Namespace__attrs | kw
|
||||
elif isinstance(_map, dict):
|
||||
kw = _map._Namespace__attrs | kw
|
||||
keys = list(kw.keys())
|
||||
|
||||
caartesian_keys = [k for k in keys if k in caartesian_keys] if caartesian_keys else []
|
||||
ablation_keys = [k for k in keys if k not in caartesian_keys]
|
||||
caartesian_vals = list(map(kw.__getitem__, caartesian_keys))
|
||||
ablation_vals = list(map(kw.__getitem__, ablation_keys))
|
||||
|
||||
for base_vals in itertools.product(*caartesian_vals):
|
||||
base = list(itertools.chain(zip(caartesian_keys, base_vals), zip(ablation_keys, [i[0] for i in ablation_vals])))
|
||||
yield jinja2.utils.Namespace(base)
|
||||
for ablation_key, ablation_val in zip(ablation_keys, ablation_vals):
|
||||
for val in ablation_val[1:]:
|
||||
yield jinja2.utils.Namespace(base, **{ablation_key: val}) # ablation variation
|
||||
|
||||
return {
|
||||
**locals(),
|
||||
**vars(builtins),
|
||||
"argv": sys.argv,
|
||||
"raise": _raise,
|
||||
}
|
||||
|
||||
def make_jinja_env(globals = make_jinja_globals(enable_require_defined=True), allow_undef=False) -> jinja2.Environment:
|
||||
env = jinja2.Environment(
|
||||
loader = jinja2.FileSystemLoader([os.getcwd(), "/"], followlinks=True),
|
||||
autoescape = False,
|
||||
trim_blocks = True,
|
||||
lstrip_blocks = True,
|
||||
undefined = jinja2.Undefined if allow_undef else jinja2.StrictUndefined,
|
||||
extensions = [
|
||||
"jinja2.ext.do", # statements with side-effects
|
||||
"jinja2.ext.loopcontrols", # break and continue
|
||||
],
|
||||
)
|
||||
env.globals.update(globals)
|
||||
env.filters.update({
|
||||
"defined": lambda x: _raise(f"{x._undefined_name!r} is not defined!") if jinja2.is_undefined(x) else x,
|
||||
"repr": repr,
|
||||
"to_json": json.dumps,
|
||||
"bool": lambda x: json.dumps(bool(x)),
|
||||
"int": lambda x: json.dumps(int(x)),
|
||||
"float": lambda x: json.dumps(float(x)),
|
||||
"str": lambda x: json.dumps(str(x)),
|
||||
})
|
||||
return env
|
||||
|
||||
def list_func_params(func: callable, exclude_list: set[str], defaults: dict={}) -> Iterable[tuple[str, Any, str]]:
|
||||
signature = inspect.signature(func)
|
||||
for i, (k, v) in enumerate(signature.parameters.items()):
|
||||
if not i and k in {"self", "cls"}:
|
||||
continue
|
||||
if k in exclude_list:
|
||||
continue
|
||||
if k.startswith("_"):
|
||||
continue
|
||||
if v.kind is v.VAR_POSITIONAL or v.kind is v.VAR_KEYWORD:
|
||||
continue
|
||||
has_default = not defaults.get(k, v.default) is v.empty
|
||||
has_annotation = not v.annotation is v.empty
|
||||
allowed_literals = f"{{{', '.join(map(_yaml_encode_value, typing.get_args(v.annotation)))}}}" \
|
||||
if typing.get_origin(v.annotation) is Literal else None
|
||||
|
||||
assert has_annotation, f"param {k!r} has no type annotation"
|
||||
yield (
|
||||
k,
|
||||
defaults.get(k, v.default) if has_default else _UNDEFINED,
|
||||
f"in {allowed_literals}" if allowed_literals else typing._type_repr(v.annotation),
|
||||
)
|
||||
|
||||
@compose("\n".join)
|
||||
def make_jinja_template(
|
||||
network_cls: nn.Module,
|
||||
*,
|
||||
exclude_list: set[str] = set(),
|
||||
defaults: dict[str, Any]={},
|
||||
top_level: bool = True,
|
||||
commented: bool = False,
|
||||
name=None,
|
||||
comment: Optional[str] = None,
|
||||
special_encoders: dict[str, Callable[[Any], str]]={},
|
||||
) -> str:
|
||||
c = "#" if commented else ""
|
||||
if name is None:
|
||||
name = network_cls.__name__
|
||||
|
||||
if comment is not None:
|
||||
if "\n" in comment:
|
||||
raise ValueError("newline in jinja template comment is not allowed")
|
||||
|
||||
hparams = [*list_func_params(network_cls, exclude_list, defaults=defaults)]
|
||||
if not hparams:
|
||||
if top_level:
|
||||
yield f"{name}:"
|
||||
else:
|
||||
yield f" # {name}:"
|
||||
return
|
||||
|
||||
|
||||
encoded_hparams = [
|
||||
(key, _yaml_encode_value(value) if value is not _UNDEFINED else "", comment)
|
||||
if key not in special_encoders else
|
||||
(key, special_encoders[key](value) if value is not _UNDEFINED else "", comment)
|
||||
for key, value, comment in hparams
|
||||
]
|
||||
|
||||
ml_key, ml_value = elementwise_max(
|
||||
(
|
||||
len(key),
|
||||
len(value),
|
||||
)
|
||||
for key, value, comment in encoded_hparams
|
||||
)
|
||||
|
||||
if top_level:
|
||||
yield f"{name}:" if not comment else f"{name}: # {comment}"
|
||||
else:
|
||||
yield f" # {name}:" if not comment else f" # {name}: # {comment}"
|
||||
|
||||
for key, value, comment in encoded_hparams:
|
||||
if key in exclude_list:
|
||||
continue
|
||||
pad_key = ml_key - len(key)
|
||||
pad_value = ml_value - len(value)
|
||||
|
||||
yield f" {c}{key}{' '*pad_key} : {value}{' '*pad_value} # {comment}"
|
||||
|
||||
yield ""
|
||||
|
||||
# helpers:
|
||||
|
||||
def squash_newlines(data: str) -> str:
|
||||
return re.sub(r'\n\n\n+', '\n\n', data)
|
0
ifield/utils/__init__.py
Normal file
0
ifield/utils/__init__.py
Normal file
197
ifield/utils/geometry.py
Normal file
197
ifield/utils/geometry.py
Normal file
@ -0,0 +1,197 @@
|
||||
from torch import Tensor
|
||||
from torch.nn import functional as F
|
||||
from typing import Optional, Literal
|
||||
import torch
|
||||
from .helpers import compose
|
||||
|
||||
|
||||
def get_ray_origins(cam2world: Tensor):
|
||||
return cam2world[..., :3, 3]
|
||||
|
||||
def camera_uv_to_rays(
|
||||
cam2world : Tensor,
|
||||
uv : Tensor,
|
||||
intrinsics : Tensor,
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Computes rays and origins from batched cam2world & intrinsics matrices, as well as pixel coordinates
|
||||
cam2world: (..., 4, 4)
|
||||
intrinsics: (..., 3, 3)
|
||||
uv: (..., n, 2)
|
||||
"""
|
||||
ray_dirs = get_ray_directions(uv, cam2world=cam2world, intrinsics=intrinsics)
|
||||
ray_origins = get_ray_origins(cam2world)
|
||||
ray_origins = ray_origins[..., None, :].expand([*uv.shape[:-1], 3])
|
||||
return ray_origins, ray_dirs
|
||||
|
||||
RayEmbedding = Literal[
|
||||
"plucker", # LFN
|
||||
"perp_foot", # PRIF
|
||||
"both",
|
||||
]
|
||||
|
||||
@compose(torch.cat, dim=-1)
|
||||
@compose(tuple)
|
||||
def ray_input_embedding(ray_origins: Tensor, ray_dirs: Tensor, mode: RayEmbedding = "plucker", normalize_dirs=False, is_training=False):
|
||||
"""
|
||||
Computes the plucker coordinates / perpendicular foot from ray origins and directions, appending it to direction
|
||||
"""
|
||||
assert ray_origins.shape[-1] == ray_dirs.shape[-1] == 3, \
|
||||
f"{ray_dirs.shape = }, {ray_origins.shape = }"
|
||||
|
||||
if normalize_dirs:
|
||||
ray_dirs = ray_dirs / ray_dirs.norm(dim=-1, keepdim=True)
|
||||
|
||||
yield ray_dirs
|
||||
|
||||
do_moment = mode in ("plucker", "both")
|
||||
do_perp_feet = mode in ("perp_foot", "both")
|
||||
assert do_moment or do_perp_feet
|
||||
|
||||
moment = torch.cross(ray_origins, ray_dirs, dim=-1)
|
||||
if do_moment:
|
||||
yield moment
|
||||
|
||||
if do_perp_feet:
|
||||
perp_feet = torch.cross(ray_dirs, moment, dim=-1)
|
||||
yield perp_feet
|
||||
|
||||
def ray_input_embedding_length(mode: RayEmbedding = "plucker") -> int:
|
||||
do_moment = mode in ("plucker", "both")
|
||||
do_perp_feet = mode in ("perp_foot", "both")
|
||||
assert do_moment or do_perp_feet
|
||||
|
||||
out = 3 # ray_dirs
|
||||
if do_moment:
|
||||
out += 3 # moment
|
||||
if do_perp_feet:
|
||||
out += 3 # perp foot
|
||||
return out
|
||||
|
||||
def parse_intrinsics(intrinsics, return_dict=False):
|
||||
fx = intrinsics[..., 0, 0:1]
|
||||
fy = intrinsics[..., 1, 1:2]
|
||||
cx = intrinsics[..., 0, 2:3]
|
||||
cy = intrinsics[..., 1, 2:3]
|
||||
if return_dict:
|
||||
return {"fx": fx, "fy": fy, "cx": cx, "cy": cy}
|
||||
else:
|
||||
return fx, fy, cx, cy
|
||||
|
||||
def expand_as(x, y):
|
||||
if len(x.shape) == len(y.shape):
|
||||
return x
|
||||
|
||||
for i in range(len(y.shape) - len(x.shape)):
|
||||
x = x.unsqueeze(-1)
|
||||
|
||||
return x
|
||||
|
||||
def lift(x, y, z, intrinsics, homogeneous=False):
|
||||
"""
|
||||
|
||||
:param self:
|
||||
:param x: Shape (batch_size, num_points)
|
||||
:param y:
|
||||
:param z:
|
||||
:param intrinsics:
|
||||
:return:
|
||||
"""
|
||||
fx, fy, cx, cy = parse_intrinsics(intrinsics)
|
||||
|
||||
x_lift = (x - expand_as(cx, x)) / expand_as(fx, x) * z
|
||||
y_lift = (y - expand_as(cy, y)) / expand_as(fy, y) * z
|
||||
|
||||
if homogeneous:
|
||||
return torch.stack((x_lift, y_lift, z, torch.ones_like(z).to(x.device)), dim=-1)
|
||||
else:
|
||||
return torch.stack((x_lift, y_lift, z), dim=-1)
|
||||
|
||||
def project(x, y, z, intrinsics):
|
||||
"""
|
||||
|
||||
:param self:
|
||||
:param x: Shape (batch_size, num_points)
|
||||
:param y:
|
||||
:param z:
|
||||
:param intrinsics:
|
||||
:return:
|
||||
"""
|
||||
fx, fy, cx, cy = parse_intrinsics(intrinsics)
|
||||
|
||||
x_proj = expand_as(fx, x) * x / z + expand_as(cx, x)
|
||||
y_proj = expand_as(fy, y) * y / z + expand_as(cy, y)
|
||||
|
||||
return torch.stack((x_proj, y_proj, z), dim=-1)
|
||||
|
||||
def world_from_xy_depth(xy, depth, cam2world, intrinsics):
|
||||
batch_size, *_ = cam2world.shape
|
||||
|
||||
x_cam = xy[..., 0]
|
||||
y_cam = xy[..., 1]
|
||||
z_cam = depth
|
||||
|
||||
pixel_points_cam = lift(x_cam, y_cam, z_cam, intrinsics=intrinsics, homogeneous=True)
|
||||
world_coords = torch.einsum("b...ij,b...kj->b...ki", cam2world, pixel_points_cam)[..., :3]
|
||||
|
||||
return world_coords
|
||||
|
||||
def project_point_on_ray(projection_point, ray_origin, ray_dir):
|
||||
dot = torch.einsum("...j,...j", projection_point-ray_origin, ray_dir)
|
||||
return ray_origin + dot[..., None] * ray_dir
|
||||
|
||||
def get_ray_directions(
|
||||
xy : Tensor, # (..., N, 2)
|
||||
cam2world : Tensor, # (..., 4, 4)
|
||||
intrinsics : Tensor, # (..., 3, 3)
|
||||
):
|
||||
z_cam = torch.ones(xy.shape[:-1]).to(xy.device)
|
||||
pixel_points = world_from_xy_depth(xy, z_cam, intrinsics=intrinsics, cam2world=cam2world) # (batch, num_samples, 3)
|
||||
|
||||
cam_pos = cam2world[..., :3, 3]
|
||||
ray_dirs = pixel_points - cam_pos[..., None, :] # (batch, num_samples, 3)
|
||||
ray_dirs = F.normalize(ray_dirs, dim=-1)
|
||||
return ray_dirs
|
||||
|
||||
def ray_sphere_intersect(
|
||||
ray_origins : Tensor, # (..., 3)
|
||||
ray_dirs : Tensor, # (..., 3)
|
||||
sphere_centers : Optional[Tensor] = None, # (..., 3)
|
||||
sphere_radii : Optional[Tensor] = 1, # (...)
|
||||
*,
|
||||
return_parts : bool = False,
|
||||
allow_nans : bool = True,
|
||||
improve_miss_grads : bool = False,
|
||||
) -> tuple[Tensor, ...]:
|
||||
if improve_miss_grads: assert not allow_nans, "improve_miss_grads does not work with allow_nans"
|
||||
if sphere_centers is None:
|
||||
ray_origins_centered = ray_origins #- torch.zeros_like(ray_origins)
|
||||
else:
|
||||
ray_origins_centered = ray_origins - sphere_centers
|
||||
|
||||
ray_dir_dot_origins = (ray_dirs * ray_origins_centered).sum(dim=-1, keepdim=True)
|
||||
discriminants2 = ray_dir_dot_origins**2 - ((ray_origins_centered * ray_origins_centered).sum(dim=-1) - sphere_radii**2)[..., None]
|
||||
if not allow_nans or return_parts:
|
||||
is_intersecting = discriminants2 > 0
|
||||
if allow_nans:
|
||||
discriminants = torch.sqrt(discriminants2)
|
||||
else:
|
||||
discriminants = torch.sqrt(torch.where(is_intersecting, discriminants2,
|
||||
discriminants2 - discriminants2.detach() + 0.001
|
||||
if improve_miss_grads else
|
||||
torch.zeros_like(discriminants2)
|
||||
))
|
||||
assert not discriminants.detach().isnan().any() # slow, use optimizations!
|
||||
|
||||
if not return_parts:
|
||||
return (
|
||||
ray_origins + ray_dirs * (- ray_dir_dot_origins - discriminants),
|
||||
ray_origins + ray_dirs * (- ray_dir_dot_origins + discriminants),
|
||||
)
|
||||
else:
|
||||
return (
|
||||
ray_origins + ray_dirs * (- ray_dir_dot_origins),
|
||||
ray_origins + ray_dirs * (- ray_dir_dot_origins - discriminants),
|
||||
ray_origins + ray_dirs * (- ray_dir_dot_origins + discriminants),
|
||||
is_intersecting.squeeze(-1),
|
||||
)
|
205
ifield/utils/helpers.py
Normal file
205
ifield/utils/helpers.py
Normal file
@ -0,0 +1,205 @@
|
||||
from functools import wraps, reduce, partial
|
||||
from itertools import zip_longest, groupby
|
||||
from pathlib import Path
|
||||
from typing import Iterable, TypeVar, Callable, Union, Optional, Mapping, Hashable
|
||||
import collections
|
||||
import operator
|
||||
import re
|
||||
|
||||
Numeric = Union[int, float, complex]
|
||||
T = TypeVar("T")
|
||||
S = TypeVar("S")
|
||||
|
||||
# decorator
|
||||
def compose(outer_func: Callable[[..., S], T], *outer_a, **outer_kw) -> Callable[..., T]:
|
||||
def wrapper(inner_func: Callable[..., S]):
|
||||
@wraps(inner_func)
|
||||
def wrapped(*a, **kw):
|
||||
return outer_func(*outer_a, inner_func(*a, **kw), **outer_kw)
|
||||
return wrapped
|
||||
return wrapper
|
||||
|
||||
def compose_star(outer_func: Callable[[..., S], T], *outer_a, **outer_kw) -> Callable[..., T]:
|
||||
def wrapper(inner_func: Callable[..., S]):
|
||||
@wraps(inner_func)
|
||||
def wrapped(*a, **kw):
|
||||
return outer_func(*outer_a, *inner_func(*a, **kw), **outer_kw)
|
||||
return wrapped
|
||||
return wrapper
|
||||
|
||||
|
||||
# itertools
|
||||
|
||||
def elementwise_max(iterable: Iterable[Iterable[T]]) -> Iterable[T]:
|
||||
return reduce(lambda xs, ys: [*map(max, zip(xs, ys))], iterable)
|
||||
|
||||
def prod(numbers: Iterable[T], initial: Optional[T] = None) -> T:
|
||||
if initial is not None:
|
||||
return reduce(operator.mul, numbers, initial)
|
||||
else:
|
||||
return reduce(operator.mul, numbers)
|
||||
|
||||
def run_length_encode(data: Iterable[T]) -> Iterable[tuple[T, int]]:
|
||||
return (
|
||||
(x, len(y))
|
||||
for x, y in groupby(data)
|
||||
)
|
||||
|
||||
|
||||
# text conversion
|
||||
|
||||
def camel_to_snake_case(text: str, sep: str = "_", join_abbreviations: bool = False) -> str:
|
||||
parts = (
|
||||
part.lower()
|
||||
for part in re.split(r'(?=[A-Z])', text)
|
||||
if part
|
||||
)
|
||||
if join_abbreviations:
|
||||
parts = list(parts)
|
||||
if len(parts) > 1:
|
||||
for i, (a, b) in list(enumerate(zip(parts[:-1], parts[1:])))[::-1]:
|
||||
if len(a) == len(b) == 1:
|
||||
parts[i] = parts[i] + parts.pop(i+1)
|
||||
return sep.join(parts)
|
||||
|
||||
def snake_to_camel_case(text: str) -> str:
|
||||
return "".join(
|
||||
part.captialize()
|
||||
for part in text.split("_")
|
||||
if part
|
||||
)
|
||||
|
||||
|
||||
# textwrap
|
||||
|
||||
def columnize_dict(data: dict, n_columns=2, prefix="", sep=" ") -> str:
|
||||
sub = (len(data) + 1) // n_columns
|
||||
return reduce(partial(columnize, sep=sep),
|
||||
(
|
||||
columnize(
|
||||
"\n".join([f"{'' if n else prefix}{i!r}" for i in data.keys() ][n*sub : (n+1)*sub]),
|
||||
"\n".join([f": {i!r}," for i in data.values()][n*sub : (n+1)*sub]),
|
||||
)
|
||||
for n in range(n_columns)
|
||||
)
|
||||
)
|
||||
|
||||
def columnize(left: str, right: str, prefix="", sep=" ") -> str:
|
||||
left = left .split("\n")
|
||||
right = right.split("\n")
|
||||
width = max(map(len, left)) if left else 0
|
||||
return "\n".join(
|
||||
f"{prefix}{a.ljust(width)}{sep}{b}"
|
||||
if b else
|
||||
f"{prefix}{a}"
|
||||
for a, b in zip_longest(left, right, fillvalue="")
|
||||
)
|
||||
|
||||
|
||||
# pathlib
|
||||
|
||||
def make_relative(path: Union[Path, str], parent: Path = None) -> Path:
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
if parent is None:
|
||||
parent = Path.cwd()
|
||||
try:
|
||||
return path.relative_to(parent)
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
return ".." / path.relative_to(parent.parent)
|
||||
except ValueError:
|
||||
pass
|
||||
return path
|
||||
|
||||
|
||||
# dictionaries
|
||||
|
||||
def update_recursive(target: dict, source: dict):
|
||||
""" Update two config dictionaries recursively. """
|
||||
for k, v in source.items():
|
||||
if isinstance(v, dict):
|
||||
if k not in target:
|
||||
target[k] = type(target)()
|
||||
update_recursive(target[k], v)
|
||||
else:
|
||||
target[k] = v
|
||||
|
||||
def map_tree(func: Callable[[T], S], val: Union[Mapping[Hashable, T], tuple[T, ...], list[T], T]) -> Union[Mapping[Hashable, S], tuple[S, ...], list[S], S]:
|
||||
if isinstance(val, collections.abc.Mapping):
|
||||
return {
|
||||
k: map_tree(func, subval)
|
||||
for k, subval in val.items()
|
||||
}
|
||||
elif isinstance(val, tuple):
|
||||
return tuple(
|
||||
map_tree(func, subval)
|
||||
for subval in val
|
||||
)
|
||||
elif isinstance(val, list):
|
||||
return [
|
||||
map_tree(func, subval)
|
||||
for subval in val
|
||||
]
|
||||
else:
|
||||
return func(val)
|
||||
|
||||
def flatten_tree(val, *, sep=".", prefix=None):
|
||||
if isinstance(val, collections.abc.Mapping):
|
||||
return {
|
||||
k: v
|
||||
for subkey, subval in val.items()
|
||||
for k, v in flatten_tree(subval, sep=sep, prefix=f"{prefix}{sep}{subkey}" if prefix else subkey).items()
|
||||
}
|
||||
elif isinstance(val, tuple) or isinstance(val, list):
|
||||
return {
|
||||
k: v
|
||||
for index, subval in enumerate(val)
|
||||
for k, v in flatten_tree(subval, sep=sep, prefix=f"{prefix}{sep}[{index}]" if prefix else f"[{index}]").items()
|
||||
}
|
||||
elif prefix:
|
||||
return {prefix: val}
|
||||
else:
|
||||
return val
|
||||
|
||||
# conversions
|
||||
|
||||
def hex2tuple(data: str) -> tuple[int]:
|
||||
data = data.removeprefix("#")
|
||||
return (*(
|
||||
int(data[i:i+2], 16)
|
||||
for i in range(0, len(data), 2)
|
||||
),)
|
||||
|
||||
|
||||
# repr shims
|
||||
|
||||
class CustomRepr:
|
||||
def __init__(self, repr_str: str):
|
||||
self.repr_str = repr_str
|
||||
def __str__(self):
|
||||
return self.repr_str
|
||||
def __repr__(self):
|
||||
return self.repr_str
|
||||
|
||||
|
||||
# Meta Params Module proxy
|
||||
|
||||
class MetaModuleProxy:
|
||||
def __init__(self, module, params):
|
||||
self._module = module
|
||||
self._params = params
|
||||
|
||||
def __getattr__(self, name):
|
||||
params = super().__getattribute__("_params")
|
||||
if name in params:
|
||||
return params[name]
|
||||
else:
|
||||
return getattr(super().__getattribute__("_module"), name)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name not in ("_params", "_module"):
|
||||
super().__getattribute__("_params")[name] = value
|
||||
else:
|
||||
super().__setattr__(name, value)
|
590
ifield/utils/loss.py
Normal file
590
ifield/utils/loss.py
Normal file
@ -0,0 +1,590 @@
|
||||
from abc import abstractmethod, ABC
|
||||
from dataclasses import dataclass, field, fields, MISSING
|
||||
from functools import wraps
|
||||
from matplotlib import pyplot as plt
|
||||
from matplotlib.artist import Artist
|
||||
from tabulate import tabulate
|
||||
from torch import nn
|
||||
from typing import Optional, TypeVar, Union
|
||||
import inspect
|
||||
import math
|
||||
import pytorch_lightning as pl
|
||||
import typing
|
||||
import warnings
|
||||
|
||||
|
||||
HParamSchedule = TypeVar("HParamSchedule", bound="HParamScheduleBase")
|
||||
Schedulable = Union[HParamSchedule, int, float, str]
|
||||
|
||||
class HParamScheduleBase(ABC):
|
||||
_subclasses = {} # shared reference intended
|
||||
def __init_subclass__(cls):
|
||||
if not cls.__name__.startswith("_"):
|
||||
cls._subclasses[cls.__name__] = cls
|
||||
|
||||
_infix : Optional[str] = field(init=False, repr=False, default=None)
|
||||
_param_name : Optional[str] = field(init=False, repr=False, default=None)
|
||||
_expr : Optional[str] = field(init=False, repr=False, default=None)
|
||||
|
||||
def get(self, module: nn.Module, *, trainer: Optional[pl.Trainer] = None) -> float:
|
||||
if module.training:
|
||||
if trainer is None:
|
||||
trainer = module.trainer # this assumes `module` is a pl.LightningModule
|
||||
value = self.get_train_value(
|
||||
epoch = trainer.current_epoch + (trainer.fit_loop.epoch_loop.batch_progress.current.processed / trainer.num_training_batches),
|
||||
)
|
||||
if trainer.logger is not None and self._param_name is not None and self.__class__ is not Const and trainer.global_step % 15 == 0:
|
||||
trainer.logger.log_metrics({
|
||||
f"HParamSchedule/{self._param_name}": value,
|
||||
}, step=trainer.global_step)
|
||||
return value
|
||||
else:
|
||||
return self.get_eval_value()
|
||||
|
||||
def _gen_data(self, n_epochs, steps_per_epoch=1000):
|
||||
global_steps = 0
|
||||
for epoch in range(n_epochs):
|
||||
for step in range(steps_per_epoch):
|
||||
yield (
|
||||
epoch + step/steps_per_epoch,
|
||||
self.get_train_value(epoch + step/steps_per_epoch),
|
||||
)
|
||||
global_steps += steps_per_epoch
|
||||
|
||||
def plot(self, *a, ax: Optional[plt.Axes] = None, **kw) -> Artist:
|
||||
if ax is None: ax = plt.gca()
|
||||
out = ax.plot(*zip(*self._gen_data(*a, **kw)), label=self._expr)
|
||||
ax.set_title(self._param_name)
|
||||
ax.set_xlabel("Epoch")
|
||||
ax.set_ylabel("Value")
|
||||
ax.legend()
|
||||
return out
|
||||
|
||||
def assert_positive(self, *a, **kw):
|
||||
for epoch, val in self._gen_data(*a, **kw):
|
||||
assert val >= 0, f"{epoch=}, {val=}"
|
||||
|
||||
@abstractmethod
|
||||
def get_eval_value(self) -> float:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_train_value(self, epoch: float) -> float:
|
||||
...
|
||||
|
||||
def __add__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "+":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __radd__(self, lhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "+":
|
||||
return cls(lhs, self)
|
||||
return NotImplemented
|
||||
|
||||
def __sub__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "-":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __rsub__(self, lhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "-":
|
||||
return cls(lhs, self)
|
||||
return NotImplemented
|
||||
|
||||
def __mul__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "*":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __rmul__(self, lhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "*":
|
||||
return cls(lhs, self)
|
||||
return NotImplemented
|
||||
|
||||
def __matmul__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "@":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __rmatmul__(self, lhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "@":
|
||||
return cls(lhs, self)
|
||||
return NotImplemented
|
||||
|
||||
def __truediv__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "/":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __rtruediv__(self, lhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "/":
|
||||
return cls(lhs, self)
|
||||
return NotImplemented
|
||||
|
||||
def __floordiv__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "//":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __rfloordiv__(self, lhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "//":
|
||||
return cls(lhs, self)
|
||||
return NotImplemented
|
||||
|
||||
def __mod__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "%":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __rmod__(self, lhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "%":
|
||||
return cls(lhs, self)
|
||||
return NotImplemented
|
||||
|
||||
def __pow__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "**":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __rpow__(self, lhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "**":
|
||||
return cls(lhs, self)
|
||||
return NotImplemented
|
||||
|
||||
def __lshift__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "<<":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __rlshift__(self, lhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "<<":
|
||||
return cls(lhs, self)
|
||||
return NotImplemented
|
||||
|
||||
def __rshift__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == ">>":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __rrshift__(self, lhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == ">>":
|
||||
return cls(lhs, self)
|
||||
return NotImplemented
|
||||
|
||||
def __and__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "&":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __rand__(self, lhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "&":
|
||||
return cls(lhs, self)
|
||||
return NotImplemented
|
||||
|
||||
def __xor__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "^":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __rxor__(self, lhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "^":
|
||||
return cls(lhs, self)
|
||||
return NotImplemented
|
||||
|
||||
def __or__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "|":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __ror__(self, lhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "|":
|
||||
return cls(lhs, self)
|
||||
return NotImplemented
|
||||
|
||||
def __ge__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == ">=":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __gt__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == ">":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __le__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "<=":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __lt__(self, rhs):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "<":
|
||||
return cls(self, rhs)
|
||||
return NotImplemented
|
||||
|
||||
def __bool__(self):
|
||||
return True
|
||||
|
||||
def __neg__(self):
|
||||
for cls in self._subclasses.values():
|
||||
if cls._infix == "-":
|
||||
return cls(0, self)
|
||||
return NotImplemented
|
||||
|
||||
@property
|
||||
def is_const(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
|
||||
def parse_dsl(config: Schedulable, name=None) -> HParamSchedule:
|
||||
if isinstance(config, HParamScheduleBase):
|
||||
return config
|
||||
elif isinstance(config, str):
|
||||
out = eval(config, {"__builtins__": {}, "lg": math.log10}, HParamScheduleBase._subclasses)
|
||||
if not isinstance(out, HParamScheduleBase):
|
||||
out = Const(out)
|
||||
else:
|
||||
out = Const(config)
|
||||
out._expr = config
|
||||
out._param_name = name
|
||||
return out
|
||||
|
||||
|
||||
# decorator
|
||||
def ensure_schedulables(func):
|
||||
signature = inspect.signature(func)
|
||||
module_name = func.__qualname__.removesuffix(".__init__")
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*a, **kw):
|
||||
bound_args = signature.bind(*a, **kw)
|
||||
|
||||
for param_name, param in signature.parameters.items():
|
||||
type_origin = typing.get_origin(param.annotation)
|
||||
type_args = typing.get_args (param.annotation)
|
||||
|
||||
if type_origin is HParamSchedule or (type_origin is Union and (HParamSchedule in type_args or HParamScheduleBase in type_args)):
|
||||
if param_name in bound_args.arguments:
|
||||
bound_args.arguments[param_name] = parse_dsl(bound_args.arguments[param_name], name=f"{module_name}.{param_name}")
|
||||
elif param.default is not param.empty:
|
||||
bound_args.arguments[param_name] = parse_dsl(param.default, name=f"{module_name}.{param_name}")
|
||||
|
||||
return func(
|
||||
*bound_args.args,
|
||||
**bound_args.kwargs,
|
||||
)
|
||||
return wrapper
|
||||
|
||||
# https://easings.net/
|
||||
|
||||
@dataclass
|
||||
class _InfixBase(HParamScheduleBase):
|
||||
l : Union[HParamSchedule, int, float]
|
||||
r : Union[HParamSchedule, int, float]
|
||||
|
||||
def _operation(self, l: float, r: float) -> float:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_eval_value(self) -> float:
|
||||
return self._operation(
|
||||
self.l.get_eval_value() if isinstance(self.l, HParamScheduleBase) else self.l,
|
||||
self.r.get_eval_value() if isinstance(self.r, HParamScheduleBase) else self.r,
|
||||
)
|
||||
|
||||
def get_train_value(self, epoch: float) -> float:
|
||||
return self._operation(
|
||||
self.l.get_train_value(epoch) if isinstance(self.l, HParamScheduleBase) else self.l,
|
||||
self.r.get_train_value(epoch) if isinstance(self.r, HParamScheduleBase) else self.r,
|
||||
)
|
||||
|
||||
def __bool__(self):
|
||||
if self.is_const:
|
||||
return bool(self.get_eval_value())
|
||||
else:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_const(self) -> bool:
|
||||
return (self.l.is_const if isinstance(self.l, HParamScheduleBase) else True) \
|
||||
and (self.r.is_const if isinstance(self.r, HParamScheduleBase) else True)
|
||||
|
||||
@dataclass
|
||||
class Add(_InfixBase):
|
||||
""" adds the results of two schedules """
|
||||
_infix : Optional[str] = field(init=False, repr=False, default="+")
|
||||
def _operation(self, l: float, r: float) -> float:
|
||||
return l + r
|
||||
|
||||
|
||||
@dataclass
|
||||
class Sub(_InfixBase):
|
||||
""" subtracts the results of two schedules """
|
||||
_infix : Optional[str] = field(init=False, repr=False, default="-")
|
||||
def _operation(self, l: float, r: float) -> float:
|
||||
return l - r
|
||||
|
||||
|
||||
@dataclass
|
||||
class Prod(_InfixBase):
|
||||
""" multiplies the results of two schedules """
|
||||
_infix : Optional[str] = field(init=False, repr=False, default="*")
|
||||
def _operation(self, l: float, r: float) -> float:
|
||||
return l * r
|
||||
@property
|
||||
def is_const(self) -> bool: # propagate identity
|
||||
l = self.l.get_eval_value() if isinstance(self.l, HParamScheduleBase) and self.l.is_const else self.l
|
||||
r = self.r.get_eval_value() if isinstance(self.r, HParamScheduleBase) and self.r.is_const else self.r
|
||||
return l == 0 or r == 0 or super().is_const
|
||||
|
||||
|
||||
@dataclass
|
||||
class Div(_InfixBase):
|
||||
""" divides the results of two schedules """
|
||||
_infix : Optional[str] = field(init=False, repr=False, default="/")
|
||||
def _operation(self, l: float, r: float) -> float:
|
||||
return l / r
|
||||
|
||||
|
||||
@dataclass
|
||||
class Pow(_InfixBase):
|
||||
""" raises the results of one schedule to the other """
|
||||
_infix : Optional[str] = field(init=False, repr=False, default="**")
|
||||
def _operation(self, l: float, r: float) -> float:
|
||||
return l ** r
|
||||
|
||||
|
||||
@dataclass
|
||||
class Gt(_InfixBase):
|
||||
""" compares the results of two schedules """
|
||||
_infix : Optional[str] = field(init=False, repr=False, default=">")
|
||||
def _operation(self, l: float, r: float) -> float:
|
||||
return l > r
|
||||
|
||||
|
||||
@dataclass
|
||||
class Lt(_InfixBase):
|
||||
""" compares the results of two schedules """
|
||||
_infix : Optional[str] = field(init=False, repr=False, default="<")
|
||||
def _operation(self, l: float, r: float) -> float:
|
||||
return l < r
|
||||
|
||||
|
||||
@dataclass
|
||||
class Ge(_InfixBase):
|
||||
""" compares the results of two schedules """
|
||||
_infix : Optional[str] = field(init=False, repr=False, default=">=")
|
||||
def _operation(self, l: float, r: float) -> float:
|
||||
return l >= r
|
||||
|
||||
|
||||
@dataclass
|
||||
class Le(_InfixBase):
|
||||
""" compares the results of two schedules """
|
||||
_infix : Optional[str] = field(init=False, repr=False, default="<=")
|
||||
def _operation(self, l: float, r: float) -> float:
|
||||
return l <= r
|
||||
|
||||
|
||||
@dataclass
|
||||
class Const(HParamScheduleBase):
|
||||
""" A way to ensure .get(...) exists """
|
||||
|
||||
c : Union[int, float]
|
||||
|
||||
def get_eval_value(self) -> float:
|
||||
return self.c
|
||||
|
||||
def get_train_value(self, epoch: float) -> float:
|
||||
return self.c
|
||||
|
||||
def __bool__(self):
|
||||
return bool(self.get_eval_value())
|
||||
|
||||
@property
|
||||
def is_const(self) -> bool:
|
||||
return True
|
||||
|
||||
@dataclass
|
||||
class Step(HParamScheduleBase):
|
||||
""" steps from 0 to 1 at `epoch` """
|
||||
|
||||
epoch : float
|
||||
|
||||
def get_eval_value(self) -> float:
|
||||
return 1
|
||||
|
||||
def get_train_value(self, epoch: float) -> float:
|
||||
return 1 if epoch >= self.epoch else 0
|
||||
|
||||
@dataclass
|
||||
class Linear(HParamScheduleBase):
|
||||
""" linear from 0 to 1 over `n_epochs`, delayed by `offset` """
|
||||
|
||||
n_epochs : float
|
||||
offset : float = 0
|
||||
|
||||
def get_eval_value(self) -> float:
|
||||
return 1
|
||||
|
||||
def get_train_value(self, epoch: float) -> float:
|
||||
if self.n_epochs <= 0: return 1
|
||||
return min(max(epoch - self.offset, 0) / self.n_epochs, 1)
|
||||
|
||||
@dataclass
|
||||
class EaseSin(HParamScheduleBase): # effectively 1-CosineAnnealing
|
||||
""" sinusoidal ease in-out from 0 to 1 over `n_epochs`, delayed by `offset` """
|
||||
|
||||
n_epochs : float
|
||||
offset : float = 0
|
||||
|
||||
def get_eval_value(self) -> float:
|
||||
return 1
|
||||
|
||||
def get_train_value(self, epoch: float) -> float:
|
||||
x = min(max(epoch - self.offset, 0) / self.n_epochs, 1)
|
||||
return -(math.cos(math.pi * x) - 1) / 2
|
||||
|
||||
@dataclass
|
||||
class EaseExp(HParamScheduleBase):
|
||||
""" exponential ease in-out from 0 to 1 over `n_epochs`, delayed by `offset` """
|
||||
|
||||
n_epochs : float
|
||||
offset : float = 0
|
||||
|
||||
def get_eval_value(self) -> float:
|
||||
return 1
|
||||
|
||||
def get_train_value(self, epoch: float) -> float:
|
||||
if (epoch-self.offset) < 0:
|
||||
return 0
|
||||
if (epoch-self.offset) > self.n_epochs:
|
||||
return 1
|
||||
x = min(max(epoch - self.offset, 0) / self.n_epochs, 1)
|
||||
return (
|
||||
2**(20*x-10) / 2
|
||||
if x < 0.5 else
|
||||
(2 - 2**(-20*x+10)) / 2
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class Steps(HParamScheduleBase):
|
||||
""" Starts at 1, multiply by gamma every n epochs. Models StepLR in pytorch """
|
||||
step_size: float
|
||||
gamma: float = 0.1
|
||||
|
||||
def get_eval_value(self) -> float:
|
||||
return 1
|
||||
def get_train_value(self, epoch: float) -> float:
|
||||
return self.gamma**int(epoch / self.step_size)
|
||||
|
||||
@dataclass
|
||||
class MultiStep(HParamScheduleBase):
|
||||
""" Starts at 1, multiply by gamma every milstone epoch. Models MultiStepLR in pytorch """
|
||||
milestones: list[float]
|
||||
gamma: float = 0.1
|
||||
|
||||
def get_eval_value(self) -> float:
|
||||
return 1
|
||||
def get_train_value(self, epoch: float) -> float:
|
||||
for i, m in list(enumerate(self.milestones))[::-1]:
|
||||
if epoch >= m:
|
||||
return self.gamma**(i+1)
|
||||
return 1
|
||||
|
||||
@dataclass
|
||||
class Epoch(HParamScheduleBase):
|
||||
""" The current epoch, starting at 0 """
|
||||
|
||||
def get_eval_value(self) -> float:
|
||||
return 0
|
||||
def get_train_value(self, epoch: float) -> float:
|
||||
return epoch
|
||||
|
||||
@dataclass
|
||||
class Offset(HParamScheduleBase):
|
||||
""" Offsets the epoch for the subexpression, clamped above 0. Positive offsets makes it happen later """
|
||||
expr : Union[HParamSchedule, int, float]
|
||||
offset : float
|
||||
|
||||
def get_eval_value(self) -> float:
|
||||
return self.expr.get_eval_value() if isinstance(self.expr, HParamScheduleBase) else self.expr
|
||||
def get_train_value(self, epoch: float) -> float:
|
||||
return self.expr.get_train_value(max(epoch - self.offset, 0)) if isinstance(self.expr, HParamScheduleBase) else self.expr
|
||||
|
||||
@dataclass
|
||||
class Mod(HParamScheduleBase):
|
||||
""" The epoch in the subexptression is subject to a modulus. Use for warm restarts """
|
||||
|
||||
modulus : float
|
||||
expr : Union[HParamSchedule, int, float]
|
||||
|
||||
def get_eval_value(self) -> float:
|
||||
return self.expr.get_eval_value() if isinstance(self.expr, HParamScheduleBase) else self.expr
|
||||
def get_train_value(self, epoch: float) -> float:
|
||||
return self.expr.get_train_value(epoch % self.modulus) if isinstance(self.expr, HParamScheduleBase) else self.expr
|
||||
|
||||
|
||||
def main():
|
||||
import sys, rich.pretty
|
||||
if not sys.argv[2:]:
|
||||
print(f"Usage: {sys.argv[0]} n_epochs 'expression'")
|
||||
print("Available operations:")
|
||||
def mk_ops():
|
||||
for name, cls in HParamScheduleBase._subclasses.items():
|
||||
if isinstance(cls._infix, str):
|
||||
yield (cls._infix, f"(infix) {cls.__doc__.strip()}")
|
||||
else:
|
||||
yield (
|
||||
f"""{name}({', '.join(
|
||||
i.name
|
||||
if i.default is MISSING else
|
||||
f"{i.name}={i.default!r}"
|
||||
for i in fields(cls)
|
||||
)})""",
|
||||
cls.__doc__.strip(),
|
||||
)
|
||||
rich.print(tabulate(sorted(mk_ops()), tablefmt="plain"))
|
||||
else:
|
||||
n_epochs = int(sys.argv[1])
|
||||
schedules = [parse_dsl(arg, name="cli arg") for arg in sys.argv[2:]]
|
||||
ax = plt.gca()
|
||||
print("[")
|
||||
for schedule in schedules:
|
||||
rich.print(f" {schedule}, # {schedule.is_const = }")
|
||||
schedule.plot(n_epochs, ax=ax)
|
||||
print("]")
|
||||
plt.show()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
0
ifield/utils/operators/__init__.py
Normal file
0
ifield/utils/operators/__init__.py
Normal file
96
ifield/utils/operators/diff.py
Normal file
96
ifield/utils/operators/diff.py
Normal file
@ -0,0 +1,96 @@
|
||||
import torch
|
||||
from torch.autograd import grad
|
||||
|
||||
|
||||
def hessian(y: torch.Tensor, x: torch.Tensor, check=False, detach=False) -> torch.Tensor:
|
||||
"""
|
||||
hessian of y wrt x
|
||||
y: shape (..., Y)
|
||||
x: shape (..., X)
|
||||
return: shape (..., Y, X, X)
|
||||
"""
|
||||
assert x.requires_grad
|
||||
assert y.grad_fn
|
||||
|
||||
grad_y = torch.ones_like(y[..., 0]).to(y.device) # reuse -> less memory
|
||||
|
||||
hess = torch.stack([
|
||||
# calculate hessian on y for each x value
|
||||
torch.stack(
|
||||
gradients(
|
||||
*(dydx[..., j] for j in range(x.shape[-1])),
|
||||
wrt=x,
|
||||
grad_outputs=[grad_y]*x.shape[-1],
|
||||
detach=detach,
|
||||
),
|
||||
dim = -2,
|
||||
)
|
||||
# calculate dydx over batches for each feature value of y
|
||||
for dydx in gradients(*(y[..., i] for i in range(y.shape[-1])), wrt=x)
|
||||
], dim=-3)
|
||||
|
||||
if check:
|
||||
assert hess.isnan().any()
|
||||
return hess
|
||||
|
||||
def laplace(y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||
return divergence(*gradients(y, wrt=x), x)
|
||||
|
||||
def divergence(y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||
assert x.requires_grad
|
||||
assert y.grad_fn
|
||||
return sum(
|
||||
grad(
|
||||
y[..., i],
|
||||
x,
|
||||
torch.ones_like(y[..., i]),
|
||||
create_graph=True
|
||||
)[0][..., i:i+1]
|
||||
for i in range(y.shape[-1])
|
||||
)
|
||||
|
||||
def gradients(*ys, wrt, grad_outputs=None, detach=False) -> list[torch.Tensor]:
|
||||
assert wrt.requires_grad
|
||||
assert all(y.grad_fn for y in ys)
|
||||
if grad_outputs is None:
|
||||
grad_outputs = [torch.ones_like(y) for y in ys]
|
||||
|
||||
grads = (
|
||||
grad(
|
||||
[y],
|
||||
[wrt],
|
||||
grad_outputs=y_grad,
|
||||
create_graph=True,
|
||||
)[0]
|
||||
for y, y_grad in zip(ys, grad_outputs)
|
||||
)
|
||||
if detach:
|
||||
grads = map(torch.detach, grads)
|
||||
|
||||
return [*grads]
|
||||
|
||||
def jacobian(y: torch.Tensor, x: torch.Tensor, check=False, detach=False) -> torch.Tensor:
|
||||
"""
|
||||
jacobian of `y` w.r.t. `x`
|
||||
|
||||
y: shape (..., Y)
|
||||
x: shape (..., X)
|
||||
return: shape (..., Y, X)
|
||||
"""
|
||||
assert x.requires_grad
|
||||
assert y.grad_fn
|
||||
|
||||
y_grad = torch.ones_like(y[..., 0])
|
||||
jac = torch.stack(
|
||||
gradients(
|
||||
*(y[..., i] for i in range(y.shape[-1])),
|
||||
wrt=x,
|
||||
grad_outputs=[y_grad]*x.shape[-1],
|
||||
detach=detach,
|
||||
),
|
||||
dim=-2,
|
||||
)
|
||||
|
||||
if check:
|
||||
assert jac.isnan().any()
|
||||
return jac
|
0
ifield/viewer/__init__.py
Normal file
0
ifield/viewer/__init__.py
Normal file
BIN
ifield/viewer/assets/texturify_pano-1-4.jpg
Normal file
BIN
ifield/viewer/assets/texturify_pano-1-4.jpg
Normal file
Binary file not shown.
After ![]() (image error) Size: 944 KiB |
430
ifield/viewer/common.py
Normal file
430
ifield/viewer/common.py
Normal file
@ -0,0 +1,430 @@
|
||||
from ..utils import geometry
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from pytorch3d.transforms import euler_angles_to_matrix
|
||||
from tqdm import tqdm
|
||||
from typing import Sequence, Callable, TypedDict
|
||||
import imageio
|
||||
import shlex
|
||||
import json
|
||||
import numpy as np
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "1"
|
||||
import pygame
|
||||
|
||||
IVec2 = tuple[int, int]
|
||||
IVec3 = tuple[int, int, int]
|
||||
Vec2 = tuple[float|int, float|int]
|
||||
Vec3 = tuple[float|int, float|int, float|int]
|
||||
|
||||
class CamState(TypedDict, total=False):
|
||||
distance : float
|
||||
pos_x : float
|
||||
pos_y : float
|
||||
pos_z : float
|
||||
rot_x : float
|
||||
rot_y : float
|
||||
fov_y : float
|
||||
|
||||
|
||||
|
||||
class InteractiveViewer(ABC):
|
||||
constants = pygame.constants # saves an import
|
||||
|
||||
# realtime
|
||||
t : float # time since start
|
||||
td : float # time delta since last frame
|
||||
|
||||
# offline
|
||||
is_headless : bool
|
||||
fps : int
|
||||
frame_idx : int
|
||||
|
||||
fill_color = (255, 255, 255)
|
||||
|
||||
def __init__(self, name: str, res: IVec2 = (640, 480), scale: int= 1, screenshot_dir: Path = "."):
|
||||
self.name = name
|
||||
self.res = res
|
||||
self.scale = scale
|
||||
self.screenshot_dir = Path(screenshot_dir)
|
||||
|
||||
self.is_headless = False
|
||||
|
||||
self.cam_distance = 2.0
|
||||
self.cam_pos_x = 0.0 # look-at and rotation pivot
|
||||
self.cam_pos_y = 0.0 # look-at and rotation pivot
|
||||
self.cam_pos_z = 0.0 # look-at and rotation pivot
|
||||
self.cam_rot_x = 0.5 * torch.pi # radians
|
||||
self.cam_rot_y = -0.5 * torch.pi # radians
|
||||
self.cam_fov_y = 60.0 / 180.0 * 3.1415 # radians
|
||||
self.keep_rotating = False
|
||||
self.initial_camera_state = self.cam_state
|
||||
self.fps_cap = None
|
||||
|
||||
@property
|
||||
def cam_state(self) -> CamState:
|
||||
return dict(
|
||||
distance = self.cam_distance,
|
||||
pos_x = self.cam_pos_x,
|
||||
pos_y = self.cam_pos_y,
|
||||
pos_z = self.cam_pos_z,
|
||||
rot_x = self.cam_rot_x,
|
||||
rot_y = self.cam_rot_y,
|
||||
fov_y = self.cam_fov_y,
|
||||
)
|
||||
|
||||
@cam_state.setter
|
||||
def cam_state(self, new_state: CamState):
|
||||
self.cam_distance = new_state.get("distance", self.cam_distance)
|
||||
self.cam_pos_x = new_state.get("pos_x", self.cam_pos_x)
|
||||
self.cam_pos_y = new_state.get("pos_y", self.cam_pos_y)
|
||||
self.cam_pos_z = new_state.get("pos_z", self.cam_pos_z)
|
||||
self.cam_rot_x = new_state.get("rot_x", self.cam_rot_x)
|
||||
self.cam_rot_y = new_state.get("rot_y", self.cam_rot_y)
|
||||
self.cam_fov_y = new_state.get("fov_y", self.cam_fov_y)
|
||||
|
||||
@property
|
||||
def scaled_res(self) -> IVec2:
|
||||
return (
|
||||
self.res[0] * self.scale,
|
||||
self.res[1] * self.scale,
|
||||
)
|
||||
|
||||
def setup(self):
|
||||
pass
|
||||
|
||||
def teardown(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def render_frame(self, pixel_view: np.ndarray): # (W, H, 3) dtype=uint8
|
||||
...
|
||||
|
||||
def handle_key_up(self, key: int, keys_pressed: Sequence[bool]):
|
||||
pass
|
||||
|
||||
def handle_key_down(self, key: int, keys_pressed: Sequence[bool]):
|
||||
mod = keys_pressed[pygame.K_LSHIFT] or keys_pressed[pygame.K_RSHIFT]
|
||||
mod2 = keys_pressed[pygame.K_LCTRL] or keys_pressed[pygame.K_RCTRL]
|
||||
if key == pygame.K_r:
|
||||
self.keep_rotating = True
|
||||
self.cam_rot_x += self.td
|
||||
if key == pygame.K_MINUS:
|
||||
self.scale += 1
|
||||
if __debug__: print()
|
||||
print(f"== Scale = {self.scale} ==")
|
||||
if key == pygame.K_PLUS and self.scale > 1:
|
||||
self.scale -= 1
|
||||
if __debug__: print()
|
||||
print(f"== Scale = {self.scale} ==")
|
||||
if key == pygame.K_RETURN:
|
||||
self.cam_state = self.initial_camera_state
|
||||
if key == pygame.K_h:
|
||||
if mod2:
|
||||
print(shlex.quote(json.dumps(self.cam_state)))
|
||||
elif mod:
|
||||
with (self.screenshot_dir / "camera.json").open("w") as f:
|
||||
json.dump(self.cam_state, f)
|
||||
print("Wrote", self.screenshot_dir / "camera.json")
|
||||
else:
|
||||
with (self.screenshot_dir / "camera.json").open("r") as f:
|
||||
self.cam_state = json.load(f)
|
||||
print("Read", self.screenshot_dir / "camera.json")
|
||||
|
||||
def handle_keys_pressed(self, pressed: Sequence[bool]) -> float:
|
||||
mod1 = pressed[pygame.K_LCTRL] or pressed[pygame.K_RCTRL]
|
||||
mod2 = pressed[pygame.K_LSHIFT] or pressed[pygame.K_RSHIFT]
|
||||
mod3 = pressed[pygame.K_LALT] or pressed[pygame.K_RALT]
|
||||
td = self.td * (0.5 if mod2 else (6 if mod1 else 2))
|
||||
|
||||
if pressed[pygame.K_UP]: self.cam_rot_y += td
|
||||
if pressed[pygame.K_DOWN]: self.cam_rot_y -= td
|
||||
if pressed[pygame.K_LEFT]: self.cam_rot_x += td
|
||||
if pressed[pygame.K_RIGHT]: self.cam_rot_x -= td
|
||||
if pressed[pygame.K_PAGEUP] and mod3: self.cam_distance -= td
|
||||
if pressed[pygame.K_PAGEDOWN] and mod3: self.cam_distance += td
|
||||
|
||||
if any(pressed[i] for i in [pygame.K_UP, pygame.K_DOWN, pygame.K_LEFT, pygame.K_RIGHT]):
|
||||
self.keep_rotating = False
|
||||
if self.keep_rotating: self.cam_rot_x += self.td * 0.25
|
||||
|
||||
if pressed[pygame.K_w]: self.cam_pos_x -= td * np.cos(-self.cam_rot_x)
|
||||
if pressed[pygame.K_w]: self.cam_pos_y += td * np.sin(-self.cam_rot_x)
|
||||
if pressed[pygame.K_s]: self.cam_pos_x += td * np.cos(-self.cam_rot_x)
|
||||
if pressed[pygame.K_s]: self.cam_pos_y -= td * np.sin(-self.cam_rot_x)
|
||||
if pressed[pygame.K_a]: self.cam_pos_x += td * np.sin(self.cam_rot_x)
|
||||
if pressed[pygame.K_a]: self.cam_pos_y -= td * np.cos(self.cam_rot_x)
|
||||
if pressed[pygame.K_d]: self.cam_pos_x -= td * np.sin(self.cam_rot_x)
|
||||
if pressed[pygame.K_d]: self.cam_pos_y += td * np.cos(self.cam_rot_x)
|
||||
if pressed[pygame.K_PAGEUP] and not mod3: self.cam_pos_z -= td
|
||||
if pressed[pygame.K_PAGEDOWN] and not mod3: self.cam_pos_z += td
|
||||
|
||||
return td
|
||||
|
||||
def handle_mouse_button_up(self, pos: IVec2, button: int, keys_pressed: Sequence[bool]):
|
||||
pass
|
||||
|
||||
def handle_mouse_button_down(self, pos: IVec2, button: int, keys_pressed: Sequence[bool]):
|
||||
pass
|
||||
|
||||
def handle_mouse_motion(self, pos: IVec2, rel: IVec2, buttons: Sequence[bool], keys_pressed: Sequence[bool]):
|
||||
pass
|
||||
|
||||
def handle_mousewheel(self, flipped: bool, x: int, y: int, keys_pressed: Sequence[bool]):
|
||||
if keys_pressed[pygame.K_LALT] or keys_pressed[pygame.K_RALT]:
|
||||
self.cam_fov_y -= y * 0.015
|
||||
else:
|
||||
self.cam_distance -= y * 0.2
|
||||
|
||||
_current_caption = None
|
||||
def set_caption(self, title: str, *a, **kw):
|
||||
if self._current_caption != title and not self.is_headless:
|
||||
print(f"set_caption: {title!r}")
|
||||
self._current_caption = title
|
||||
return pygame.display.set_caption(title, *a, **kw)
|
||||
|
||||
@property
|
||||
def mouse_position(self) -> IVec2:
|
||||
mx, my = pygame.mouse.get_pos() if not self.is_headless else (0, 0)
|
||||
return (
|
||||
mx // self.scale,
|
||||
my // self.scale,
|
||||
)
|
||||
|
||||
@property
|
||||
def uvs(self) -> torch.Tensor: # (w, h, 2) dtype=float32
|
||||
res = tuple(self.res)
|
||||
if not getattr(self, "_uvs_res", None) == res:
|
||||
U, V = torch.meshgrid(
|
||||
torch.arange(self.res[1]).to(torch.float32),
|
||||
torch.arange(self.res[0]).to(torch.float32),
|
||||
indexing="xy",
|
||||
)
|
||||
self._uvs_res, self._uvs = res, torch.stack((U, V), dim=-1)
|
||||
return self._uvs
|
||||
|
||||
@property
|
||||
def cam2world(self) -> torch.Tensor: # (4, 4) dtype=float32
|
||||
if getattr(self, "_cam2world_cam_rot_y", None) is not self.cam_rot_y \
|
||||
or getattr(self, "_cam2world_cam_rot_x", None) is not self.cam_rot_x \
|
||||
or getattr(self, "_cam2world_cam_pos_x", None) is not self.cam_pos_x \
|
||||
or getattr(self, "_cam2world_cam_pos_y", None) is not self.cam_pos_y \
|
||||
or getattr(self, "_cam2world_cam_pos_z", None) is not self.cam_pos_z \
|
||||
or getattr(self, "_cam2world_cam_distance", None) is not self.cam_distance:
|
||||
self._cam2world_cam_rot_y = self.cam_rot_y
|
||||
self._cam2world_cam_rot_x = self.cam_rot_x
|
||||
self._cam2world_cam_pos_x = self.cam_pos_x
|
||||
self._cam2world_cam_pos_y = self.cam_pos_y
|
||||
self._cam2world_cam_pos_z = self.cam_pos_z
|
||||
self._cam2world_cam_distance = self.cam_distance
|
||||
|
||||
a = torch.eye(4)
|
||||
a[2, 3] = self.cam_distance
|
||||
b = torch.eye(4)
|
||||
b[:3, :3] = euler_angles_to_matrix(torch.tensor((self.cam_rot_x, self.cam_rot_y, 0)), "ZYX")
|
||||
b[0:3, 3] -= torch.tensor(( self.cam_pos_x, self.cam_pos_y, self.cam_pos_z, ))
|
||||
self._cam2world = b @ a
|
||||
|
||||
self._cam2world_inv = None
|
||||
return self._cam2world
|
||||
|
||||
@property
|
||||
def cam2world_inv(self) -> torch.Tensor: # (4, 4) dtype=float32
|
||||
if getattr(self, "_cam2world_inv", None) is None:
|
||||
self._cam2world_inv = torch.linalg.inv(self._cam2world)
|
||||
return self._cam2world_inv
|
||||
|
||||
@property
|
||||
def intrinsics(self) -> torch.Tensor: # (3, 3) dtype=float32
|
||||
if getattr(self, "_intrinsics_res", None) is not self.res \
|
||||
or getattr(self, "_intrinsics_cam_fov_y", None) is not self.cam_fov_y:
|
||||
self._intrinsics_res = res = self.res
|
||||
self._intrinsics_cam_fov_y = cam_fov_y = self.cam_fov_y
|
||||
|
||||
self._intrinsics = torch.eye(3)
|
||||
p = torch.sin(torch.tensor(cam_fov_y / 2))
|
||||
s = (res[1] / 2)
|
||||
self._intrinsics[0, 0] = s/p # fx - focal length x
|
||||
self._intrinsics[1, 1] = s/p # fy - focal length y
|
||||
self._intrinsics[0, 2] = (res[1] - 1) / 2 # cx - optical center x
|
||||
self._intrinsics[1, 2] = (res[0] - 1) / 2 # cy - optical center y
|
||||
return self._intrinsics
|
||||
|
||||
@property
|
||||
def raydirs_and_cam(self) -> tuple[torch.Tensor, torch.Tensor]: # (w, h, 3) and (3) dtype=float32
|
||||
if getattr(self, "_raydirs_and_cam_cam2world", None) is not self.cam2world \
|
||||
or getattr(self, "_raydirs_and_cam_intrinsics", None) is not self.intrinsics \
|
||||
or getattr(self, "_raydirs_and_cam_uvs", None) is not self.uvs:
|
||||
self._raydirs_and_cam_cam2world = cam2world = self.cam2world
|
||||
self._raydirs_and_cam_intrinsics = intrinsics = self.intrinsics
|
||||
self._raydirs_and_cam_uvs = uvs = self.uvs
|
||||
|
||||
#cam_pos = (cam2world @ torch.tensor([0, 0, 0, 1], dtype=torch.float32))[:3]
|
||||
cam_pos = cam2world[:3, -1]
|
||||
|
||||
dirs = -geometry.get_ray_directions(uvs, cam2world[None, ...], intrinsics[None, ...]).squeeze(-1)
|
||||
|
||||
self._raydirs_and_cam = (dirs, cam_pos)
|
||||
return (
|
||||
self._raydirs_and_cam[0],
|
||||
self._raydirs_and_cam[1],
|
||||
)
|
||||
|
||||
def run(self):
|
||||
self.is_headless = False
|
||||
pygame.display.init() # we do not use the mixer, which often hangs on quit
|
||||
try:
|
||||
window = pygame.display.set_mode(self.scaled_res, flags=pygame.RESIZABLE)
|
||||
buffer = pygame.surface.Surface(self.res)
|
||||
|
||||
window.fill(self.fill_color)
|
||||
buffer.fill(self.fill_color)
|
||||
pygame.display.flip()
|
||||
|
||||
pixel_view = pygame.surfarray.pixels3d(buffer) # (W, H, 3)
|
||||
|
||||
current_scale = self.scale
|
||||
def remake_window_buffer(window_size: IVec2):
|
||||
nonlocal buffer, pixel_view, current_scale
|
||||
self.res = (
|
||||
window_size[0] // self.scale,
|
||||
window_size[1] // self.scale,
|
||||
)
|
||||
buffer = pygame.surface.Surface(self.res)
|
||||
pixel_view = pygame.surfarray.pixels3d(buffer)
|
||||
current_scale = self.scale
|
||||
|
||||
print()
|
||||
|
||||
self.setup()
|
||||
|
||||
is_running = True
|
||||
clock = pygame.time.Clock()
|
||||
epoch = t_prev = time.time()
|
||||
self.frame_idx = -1
|
||||
while is_running:
|
||||
self.frame_idx += 1
|
||||
if not self.fps_cap is None: clock.tick(self.fps_cap)
|
||||
t = time.time()
|
||||
self.td = t - t_prev
|
||||
t_prev = t
|
||||
self.t = t - epoch
|
||||
print("\rFPS:", 1/self.td, " "*10, end="")
|
||||
|
||||
self.render_frame(pixel_view)
|
||||
|
||||
pygame.transform.scale(buffer, window.get_size(), window)
|
||||
pygame.display.flip()
|
||||
|
||||
keys_pressed = pygame.key.get_pressed()
|
||||
self.handle_keys_pressed(keys_pressed)
|
||||
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.VIDEORESIZE:
|
||||
print()
|
||||
print("== resize window ==")
|
||||
remake_window_buffer(event.size)
|
||||
elif event.type == pygame.QUIT:
|
||||
is_running = False
|
||||
elif event.type == pygame.KEYUP:
|
||||
self.handle_key_up(event.key, keys_pressed)
|
||||
elif event.type == pygame.KEYDOWN:
|
||||
self.handle_key_down(event.key, keys_pressed)
|
||||
if event.key == pygame.K_q:
|
||||
is_running = False
|
||||
elif event.key == pygame.K_y:
|
||||
fname = self.mk_dump_fname("png")
|
||||
fname.parent.mkdir(parents=True, exist_ok=True)
|
||||
pygame.image.save(buffer.copy(), fname)
|
||||
print()
|
||||
print("Saved", fname)
|
||||
elif event.type == pygame.MOUSEBUTTONUP:
|
||||
self.handle_mouse_button_up(event.pos, event.button, keys_pressed)
|
||||
elif event.type == pygame.MOUSEBUTTONDOWN:
|
||||
self.handle_mouse_button_down(event.pos, event.button, keys_pressed)
|
||||
elif event.type == pygame.MOUSEMOTION:
|
||||
self.handle_mouse_motion(event.pos, event.rel, event.buttons, keys_pressed)
|
||||
elif event.type == pygame.MOUSEWHEEL:
|
||||
self.handle_mousewheel(event.flipped, event.x, event.y, keys_pressed)
|
||||
|
||||
if current_scale != self.scale:
|
||||
remake_window_buffer(window.get_size())
|
||||
|
||||
finally:
|
||||
self.teardown()
|
||||
print()
|
||||
pygame.quit()
|
||||
|
||||
def render_headless(self, output_path: str, *, n_frames: int, fps: int, state_callback: Callable[["InteractiveViewer", int], None] | None, resolution=None, bitrate=None, **kw):
|
||||
self.is_headless = True
|
||||
self.fps = fps
|
||||
|
||||
buffer = pygame.surface.Surface(self.res if resolution is None else resolution)
|
||||
pixel_view = pygame.surfarray.pixels3d(buffer) # (W, H, 3)
|
||||
|
||||
def do():
|
||||
try:
|
||||
self.setup()
|
||||
for frame in tqdm(range(n_frames), **kw, disable=n_frames==1):
|
||||
self.frame_idx = frame
|
||||
if state_callback is not None:
|
||||
state_callback(self, frame)
|
||||
|
||||
self.render_frame(pixel_view)
|
||||
|
||||
yield pixel_view.copy().swapaxes(0,1)
|
||||
finally:
|
||||
self.teardown()
|
||||
|
||||
output_path = Path(output_path)
|
||||
if output_path.suffix == ".png":
|
||||
if n_frames > 1 and "%" not in output_path.name: raise ValueError
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
for i, framebuffer in enumerate(do()):
|
||||
with imageio.get_writer(output_path.parent / output_path.name.replace("%", f"{i:04}")) as writer:
|
||||
writer.append_data(framebuffer)
|
||||
else: # ffmpeg - https://imageio.readthedocs.io/en/v2.9.0/format_ffmpeg.html#ffmpeg
|
||||
with imageio.get_writer(output_path, fps=fps, bitrate=bitrate) as writer:
|
||||
for framebuffer in do():
|
||||
writer.append_data(framebuffer)
|
||||
|
||||
def load_sphere_map(self, fname):
|
||||
self._sphere_surf = pygame.image.load(fname)
|
||||
self._sphere_map = pygame.surfarray.pixels3d(self._sphere_surf)
|
||||
|
||||
def lookup_sphere_map_dirs(self, dirs, origins):
|
||||
near, far = geometry.ray_sphere_intersect(
|
||||
torch.tensor(origins),
|
||||
torch.tensor(dirs),
|
||||
sphere_radii = torch.tensor(origins).norm(dim=-1) * 2,
|
||||
)
|
||||
hits = far.detach()
|
||||
|
||||
x = hits[..., 0]
|
||||
y = hits[..., 1]
|
||||
z = hits[..., 2]
|
||||
theta = (z / hits.norm(dim=-1)).acos()
|
||||
phi = (y/x).atan()
|
||||
phi[(x<0) & (y>=0)] += 3.14
|
||||
phi[(x<0) & (y< 0)] -= 3.14
|
||||
|
||||
w, h = self._sphere_map.shape[:2]
|
||||
|
||||
return self._sphere_map[
|
||||
((phi / (2*torch.pi) * w).int() % w).cpu(),
|
||||
((theta / (1*torch.pi) * h).int() % h).cpu(),
|
||||
]
|
||||
|
||||
def blit_sphere_map_mask(self, pixel_view, mask=None):
|
||||
dirs, origin = self.raydirs_and_cam
|
||||
if mask is None: mask = (slice(None), slice(None))
|
||||
pixel_view[mask] \
|
||||
= self.lookup_sphere_map_dirs(dirs, origin[None, None, :])
|
||||
|
||||
def mk_dump_fname(self, suffix: str, uid=None) -> Path:
|
||||
name = self.name.split("-")[-1] if len(self.name) > 160 else self.name
|
||||
if uid is not None: name = f"{name}-{uid}"
|
||||
return self.screenshot_dir / f"pygame-viewer-{datetime.now():%Y%m%d-%H%M%S}-{name}.{suffix}"
|
792
ifield/viewer/ray_field.py
Normal file
792
ifield/viewer/ray_field.py
Normal file
@ -0,0 +1,792 @@
|
||||
from ..data.common.scan import SingleViewUVScan
|
||||
import mesh_to_sdf.scan as sdf_scan
|
||||
from ..models import intersection_fields
|
||||
from ..utils import geometry, helpers
|
||||
from ..utils.operators import diff
|
||||
from .common import InteractiveViewer
|
||||
from matplotlib import cm
|
||||
import matplotlib.colors as mcolors
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from textwrap import dedent
|
||||
from typing import Hashable, Optional, Callable
|
||||
from munch import Munch
|
||||
import functools
|
||||
import itertools
|
||||
import numpy as np
|
||||
import random
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import subprocess
|
||||
import torch
|
||||
from trimesh import Trimesh
|
||||
import trimesh.transformations as T
|
||||
|
||||
|
||||
class ModelViewer(InteractiveViewer):
|
||||
lambertian_color = (1.0, 1.0, 1.0)
|
||||
max_cols = 200
|
||||
max_cols = 32
|
||||
|
||||
def __init__(self,
|
||||
model : intersection_fields.IntersectionFieldAutoDecoderModel,
|
||||
start_uid : Hashable,
|
||||
skyward : str = "+Z",
|
||||
mesh_gt_getter: Callable[[Hashable], Trimesh] | None = None,
|
||||
*a, **kw):
|
||||
self.model = model
|
||||
self.model.eval()
|
||||
self.current_uid = self._prev_uid = start_uid
|
||||
self.all_uids = list(model.keys())
|
||||
|
||||
self.mesh_gt_getter = mesh_gt_getter
|
||||
self.current_gt_mesh: tuple[Hashable, Trimesh] = (None, None)
|
||||
|
||||
self.display_mode_normals = self.vizmodes_normals .index("medial" if self.model.hparams.output_mode == "medial_sphere" else "analytical")
|
||||
self.display_mode_shading = self.vizmodes_shading .index("lambertian")
|
||||
self.display_mode_centroid = self.vizmodes_centroids.index("best-centroids-colored")
|
||||
self.display_mode_spheres = self.vizmodes_spheres .index(None)
|
||||
self.display_mode_variation = 0
|
||||
|
||||
self.display_sphere_map_bg = True
|
||||
self.atom_radius_offset = 0
|
||||
self.atom_index_solo = None
|
||||
self.export_medial_surface_mesh = False
|
||||
|
||||
self.light_angle1 = 0
|
||||
self.light_angle2 = 0
|
||||
|
||||
self.obj_rot = {
|
||||
"-X": torch.tensor(T.rotation_matrix(angle= np.pi/2, direction=(0, 1, 0))[:3, :3], **model.device_and_dtype).T,
|
||||
"+X": torch.tensor(T.rotation_matrix(angle=-np.pi/2, direction=(0, 1, 0))[:3, :3], **model.device_and_dtype).T,
|
||||
"-Y": torch.tensor(T.rotation_matrix(angle= np.pi/2, direction=(1, 0, 0))[:3, :3], **model.device_and_dtype).T,
|
||||
"+Y": torch.tensor(T.rotation_matrix(angle=-np.pi/2, direction=(1, 0, 0))[:3, :3], **model.device_and_dtype).T,
|
||||
"-Z": torch.tensor(T.rotation_matrix(angle= np.pi, direction=(1, 0, 0))[:3, :3], **model.device_and_dtype).T,
|
||||
"+Z": torch.eye(3, **model.device_and_dtype),
|
||||
}[str(skyward).upper()]
|
||||
self.obj_rot_inv = torch.linalg.inv(self.obj_rot)
|
||||
|
||||
super().__init__(*a, **kw)
|
||||
|
||||
vizmodes_normals = (
|
||||
"medial",
|
||||
"analytical",
|
||||
"ground_truth",
|
||||
)
|
||||
vizmodes_shading = (
|
||||
None, # just atoms or medial axis
|
||||
"colored-lambertian",
|
||||
"lambertian",
|
||||
"shade-best-radii",
|
||||
"shade-all-radii",
|
||||
"translucent",
|
||||
"normal",
|
||||
"centroid-grad-norm", # backprop
|
||||
"anisotropic", # backprop
|
||||
"curvature", # backprop
|
||||
"glass",
|
||||
"double-glass",
|
||||
)
|
||||
vizmodes_centroids = (
|
||||
None,
|
||||
"best-centroids",
|
||||
"all-centroids",
|
||||
"best-centroids-colored",
|
||||
"all-centroids-colored",
|
||||
"miss-centroids-colored",
|
||||
"all-miss-centroids-colored",
|
||||
)
|
||||
vizmodes_spheres = (
|
||||
None,
|
||||
"intersecting-sphere",
|
||||
"intersecting-sphere-colored",
|
||||
"best-sphere",
|
||||
"best-sphere-colored",
|
||||
"all-spheres-colored",
|
||||
)
|
||||
|
||||
def get_display_mode(self) -> tuple[str, str, Optional[str], Optional[str]]:
|
||||
MARF = self.model.hparams.output_mode == "medial_sphere"
|
||||
if isinstance(self.display_mode_normals, str): self.display_mode_normals = self.vizmodes_shading .index(self.display_mode_normals)
|
||||
if isinstance(self.display_mode_shading, str): self.display_mode_shading = self.vizmodes_shading .index(self.display_mode_shading)
|
||||
if isinstance(self.display_mode_centroid, str): self.display_mode_centroid = self.vizmodes_centroids.index(self.display_mode_centroid)
|
||||
if isinstance(self.display_mode_spheres, str): self.display_mode_spheres = self.vizmodes_spheres .index(self.display_mode_spheres)
|
||||
out = (
|
||||
self.vizmodes_normals [self.display_mode_normals % len(self.vizmodes_normals)],
|
||||
self.vizmodes_shading [self.display_mode_shading % len(self.vizmodes_shading)],
|
||||
self.vizmodes_centroids[self.display_mode_centroid % len(self.vizmodes_centroids)] if MARF else None,
|
||||
self.vizmodes_spheres [self.display_mode_spheres % len(self.vizmodes_spheres)] if MARF else None,
|
||||
)
|
||||
self.set_caption(" & ".join(i for i in out if i is not None))
|
||||
return out
|
||||
|
||||
@property
|
||||
def cam_state(self):
|
||||
return super().cam_state | {
|
||||
"light_angle1" : self.light_angle1,
|
||||
"light_angle2" : self.light_angle2,
|
||||
}
|
||||
|
||||
@cam_state.setter
|
||||
def cam_state(self, new_state):
|
||||
InteractiveViewer.cam_state.fset(self, new_state)
|
||||
self.light_angle1 = new_state.get("light_angle1", self.light_angle1)
|
||||
self.light_angle2 = new_state.get("light_angle2", self.light_angle2)
|
||||
|
||||
def get_current_conditioning(self) -> Optional[torch.Tensor]:
|
||||
if not self.model.is_conditioned:
|
||||
return None
|
||||
|
||||
prev_uid = self._prev_uid # to determine if target has changed
|
||||
next_z = self.model[prev_uid].detach() # interpolation target
|
||||
prev_z = getattr(self, "_prev_z", next_z) # interpolation source
|
||||
epoch = getattr(self, "_prev_epoch", 0) # interpolation factor
|
||||
|
||||
if not self.is_headless:
|
||||
now = self.t
|
||||
t = (now - epoch) / 1 # 1 second
|
||||
else:
|
||||
now = self.frame_idx
|
||||
t = (now - epoch) / self.fps # 1 second
|
||||
assert t >= 0
|
||||
|
||||
if t < 1:
|
||||
next_z = next_z*t + prev_z*(1-t)
|
||||
|
||||
if prev_uid != self.current_uid:
|
||||
self._prev_uid = self.current_uid
|
||||
self._prev_z = next_z
|
||||
self._prev_epoch = now
|
||||
|
||||
return next_z
|
||||
|
||||
def get_current_ground_truth(self) -> Trimesh | None:
|
||||
if self.mesh_gt_getter is None:
|
||||
return None
|
||||
uid, mesh = self.current_gt_mesh
|
||||
try:
|
||||
if uid != self.current_uid:
|
||||
print("Loading ground truth mesh...")
|
||||
mesh = self.mesh_gt_getter(self.current_uid)
|
||||
self.current_gt_mesh = self.current_uid, mesh
|
||||
except NotImplementedError:
|
||||
self.current_gt_mesh = self.current_uid, None
|
||||
return None
|
||||
return mesh
|
||||
|
||||
def handle_keys_pressed(self, pressed):
|
||||
td = super().handle_keys_pressed(pressed)
|
||||
mod = pressed[self.constants.K_LALT] or pressed[self.constants.K_RALT]
|
||||
if not mod and pressed[self.constants.K_f]: self.light_angle1 -= td * 0.5
|
||||
if not mod and pressed[self.constants.K_g]: self.light_angle1 += td * 0.5
|
||||
if mod and pressed[self.constants.K_f]: self.light_angle2 += td * 0.5
|
||||
if mod and pressed[self.constants.K_g]: self.light_angle2 -= td * 0.5
|
||||
return td
|
||||
|
||||
def handle_key_down(self, key, keys_pressed):
|
||||
super().handle_key_down(key, keys_pressed)
|
||||
shift = keys_pressed[self.constants.K_LSHIFT] or keys_pressed[self.constants.K_RSHIFT]
|
||||
if key == self.constants.K_o:
|
||||
i = self.all_uids.index(self.current_uid)
|
||||
i = (i - 1) % len(self.all_uids)
|
||||
self.current_uid = self.all_uids[i]
|
||||
print(self.current_uid)
|
||||
if key == self.constants.K_p:
|
||||
i = self.all_uids.index(self.current_uid)
|
||||
i = (i + 1) % len(self.all_uids)
|
||||
self.current_uid = self.all_uids[i]
|
||||
print(self.current_uid)
|
||||
if key == self.constants.K_SPACE:
|
||||
self.display_sphere_map_bg = {
|
||||
True : 255,
|
||||
255 : 0,
|
||||
0 : True,
|
||||
}.get(self.display_sphere_map_bg, True)
|
||||
if key == self.constants.K_u: self.export_medial_surface_mesh = True
|
||||
if key == self.constants.K_x: self.display_mode_normals += -1 if shift else 1
|
||||
if key == self.constants.K_c: self.display_mode_shading += -1 if shift else 1
|
||||
if key == self.constants.K_v: self.display_mode_centroid += -1 if shift else 1
|
||||
if key == self.constants.K_b: self.display_mode_spheres += -1 if shift else 1
|
||||
if key == self.constants.K_e: self.display_mode_variation+= -1 if shift else 1
|
||||
if key == self.constants.K_c: self.display_mode_variation = 0
|
||||
if key == self.constants.K_0: self.atom_index_solo = None
|
||||
if key == self.constants.K_1: self.atom_index_solo = 0 if self.atom_index_solo != 0 else None
|
||||
if key == self.constants.K_2: self.atom_index_solo = 1 if self.atom_index_solo != 1 else None
|
||||
if key == self.constants.K_3: self.atom_index_solo = 2 if self.atom_index_solo != 2 else None
|
||||
if key == self.constants.K_4: self.atom_index_solo = 3 if self.atom_index_solo != 3 else None
|
||||
if key == self.constants.K_5: self.atom_index_solo = 4 if self.atom_index_solo != 4 else None
|
||||
if key == self.constants.K_6: self.atom_index_solo = 5 if self.atom_index_solo != 5 else None
|
||||
if key == self.constants.K_7: self.atom_index_solo = 6 if self.atom_index_solo != 6 else None
|
||||
if key == self.constants.K_8: self.atom_index_solo = 7 if self.atom_index_solo != 7 else None
|
||||
if key == self.constants.K_9: self.atom_index_solo = self.atom_index_solo + (-1 if shift else 1) if self.atom_index_solo is not None else 0
|
||||
|
||||
def handle_mouse_button_down(self, pos, button, keys_pressed):
|
||||
super().handle_mouse_button_down(pos, button, keys_pressed)
|
||||
if button in (1, 3):
|
||||
self.display_mode_spheres += 1 if button == 1 else -1
|
||||
|
||||
def handle_mousewheel(self, flipped, x, y, keys_pressed):
|
||||
shift = keys_pressed[self.constants.K_LSHIFT] or keys_pressed[self.constants.K_RSHIFT]
|
||||
if not shift:
|
||||
super().handle_mousewheel(flipped, x, y, keys_pressed)
|
||||
else:
|
||||
self.atom_radius_offset += 0.005 * y
|
||||
print()
|
||||
print("atom_radius_offset:", self.atom_radius_offset)
|
||||
|
||||
def setup(self):
|
||||
if not self.is_headless:
|
||||
print(dedent("""
|
||||
WASD + PG Up/Down - translate
|
||||
ARROWS - rotate
|
||||
|
||||
(SHIFT+) C - Next/(Prev) shading mode
|
||||
(SHIFT+) V - Next/(Prev) centroids mode
|
||||
(SHIFT+) B - Next/(Prev) sphere mode
|
||||
Mouse L/ R - Next/ Prev sphere mode
|
||||
(SHIFT+) E - Next/(Prev) variation (for quick experimentation within a shading mode)
|
||||
SHIFT + Scroll - Offset atom radius
|
||||
ALT + Scroll - Modify FoV (_true_ zoom)
|
||||
Mouse Scroll - Translate in/out ("zoom", moves camera to/from to point of focus)
|
||||
Alt+PG Up/Down - Translate in/out ("zoom", moves camera to/from to point of focus)
|
||||
|
||||
F / G - rotate light left / right
|
||||
ALT+ F / G - rotate light up / down
|
||||
CTRL / SHIFT - faster/slower rotation
|
||||
O / P - prev/next object
|
||||
1-9 - solo atom
|
||||
0 - show all atoms
|
||||
+ / - - decrease/increase pixel scale
|
||||
R - rotate continuously
|
||||
H / SHIFT+H / CTRL+H - load/save/print camera state
|
||||
Enter - reset camera state
|
||||
Y - save screenshot
|
||||
U - save mesh of centroids
|
||||
Space - cycle sphere map background
|
||||
Q - quit
|
||||
""").strip())
|
||||
|
||||
fname = Path(__file__).parent.resolve() / "assets/texturify_pano-1-4.jpg"
|
||||
self.load_sphere_map(fname)
|
||||
|
||||
if self.model.hparams.output_mode == "medial_sphere":
|
||||
@self.model.net.register_forward_hook
|
||||
def atom_offset_radius_and_solo(model, input, output):
|
||||
slice = (..., [i+3 for i in range(0, output.shape[-1], 4)])
|
||||
output[slice] += self.atom_radius_offset * output[slice].sign()
|
||||
if self.atom_index_solo is not None:
|
||||
x = self.atom_index_solo * 4
|
||||
x = x % output.shape[-1]
|
||||
output = output[..., list(range(x, x+4))]
|
||||
return output
|
||||
self._atom_offset_radius_and_solo_hook = atom_offset_radius_and_solo
|
||||
|
||||
def teardown(self):
|
||||
if hasattr(self, "_atom_offset_radius_and_solo_hook"):
|
||||
self._atom_offset_radius_and_solo_hook.remove()
|
||||
del self._atom_offset_radius_and_solo_hook
|
||||
|
||||
@torch.no_grad()
|
||||
def render_frame(self, pixel_view: np.ndarray): # (W, H, 3) dtype=uint8
|
||||
MARF = self.model.hparams.output_mode == "medial_sphere"
|
||||
PRIF = self.model.hparams.output_mode == "orthogonal_plane"
|
||||
assert (MARF or PRIF) and MARF != PRIF
|
||||
device_and_dtype = self.model.device_and_dtype
|
||||
device = self.model.device
|
||||
dtype = self.model.dtype
|
||||
|
||||
(
|
||||
vizmode_normals,
|
||||
vizmode_shading,
|
||||
vizmode_centroids,
|
||||
vizmode_spheres,
|
||||
) = self.get_display_mode()
|
||||
|
||||
dirs, origins = self.raydirs_and_cam
|
||||
origins = origins.detach().clone().to(**device_and_dtype)
|
||||
dirs = dirs .detach().clone().to(**device_and_dtype)
|
||||
|
||||
if vizmode_normals != "ground_truth" or self.get_current_ground_truth() is None:
|
||||
|
||||
# enable grad or not
|
||||
do_jac = PRIF or vizmode_normals == "analytical"
|
||||
do_jac_medial = MARF and "centroid-grad-norm" in (vizmode_shading or "")
|
||||
do_shape_operator = "anisotropic" in (vizmode_shading or "") or "curvature" in (vizmode_shading or "")
|
||||
do_grad = do_jac or do_jac_medial or do_shape_operator
|
||||
if do_grad:
|
||||
origins = origins.broadcast_to(dirs.shape)
|
||||
|
||||
self.model.eval()
|
||||
latent = self.get_current_conditioning()
|
||||
if self.max_cols is None or self.max_cols > dirs.shape[0]:
|
||||
chunks = [slice(None)]
|
||||
else:
|
||||
chunks = [slice(col, col+self.max_cols) for col in range(0, dirs.shape[0], self.max_cols)]
|
||||
forward_chunks = []
|
||||
for chunk in chunks:
|
||||
self.model.zero_grad()
|
||||
origins_chunk = origins[chunk if origins.ndim != 1 else slice(None)] @ self.obj_rot
|
||||
dirs_chunk = dirs [chunk] @ self.obj_rot
|
||||
if do_grad:
|
||||
origins_chunk.requires_grad = dirs_chunk.requires_grad = True
|
||||
|
||||
@forward_chunks.append
|
||||
@(lambda f: f(origins_chunk, dirs_chunk))
|
||||
@torch.set_grad_enabled(do_grad)
|
||||
def forward_chunk(origins, dirs) -> Munch:
|
||||
if PRIF:
|
||||
intersections, is_intersecting = self.model(dict(origins=origins, dirs=dirs), z=latent, normalize_origins=True)
|
||||
is_intersecting = is_intersecting > 0.5
|
||||
elif MARF:
|
||||
(
|
||||
depths, silhouettes, intersections,
|
||||
intersection_normals, is_intersecting,
|
||||
sphere_centers, sphere_radii,
|
||||
|
||||
atom_indices,
|
||||
all_intersections, all_intersection_normals, all_depths, all_silhouettes, all_is_intersecting,
|
||||
all_sphere_centers, all_sphere_radii,
|
||||
) = self.model.forward(dict(origins=origins, dirs=dirs), z=latent,
|
||||
intersections_only = False,
|
||||
return_all_atoms = True,
|
||||
)
|
||||
|
||||
if do_jac:
|
||||
jac = diff.jacobian(intersections, origins, detach=not do_shape_operator)
|
||||
intersection_normals = self.model.compute_normals_from_intersection_origin_jacobian(jac, dirs.detach())
|
||||
|
||||
if do_jac_medial:
|
||||
sphere_centers_jac = diff.jacobian(sphere_centers, origins, detach=True)
|
||||
|
||||
if do_shape_operator:
|
||||
hess = diff.jacobian(intersection_normals, origins, detach=True)[is_intersecting, :, :]
|
||||
N = intersection_normals.detach()[is_intersecting, :]
|
||||
TM = (torch.eye(3, device=device) - N[..., None, :]*N[..., :, None]) # projection onto tangent plane
|
||||
# shape operator, i.e. total derivative of the surface normal w.r.t. the tangent space
|
||||
shape_operator = hess @ TM
|
||||
|
||||
return Munch((k, v.detach()) for k, v in locals().items() if isinstance(v, torch.Tensor))
|
||||
|
||||
intersections = torch.cat([chunk.intersections for chunk in forward_chunks], dim=0)
|
||||
is_intersecting = torch.cat([chunk.is_intersecting for chunk in forward_chunks], dim=0)
|
||||
intersection_normals = torch.cat([chunk.intersection_normals for chunk in forward_chunks], dim=0)
|
||||
if MARF:
|
||||
all_sphere_centers = torch.cat([chunk.all_sphere_centers for chunk in forward_chunks], dim=0)
|
||||
all_sphere_radii = torch.cat([chunk.all_sphere_radii for chunk in forward_chunks], dim=0)
|
||||
atom_indices = torch.cat([chunk.atom_indices for chunk in forward_chunks], dim=0)
|
||||
silhouettes = torch.cat([chunk.silhouettes for chunk in forward_chunks], dim=0)
|
||||
sphere_centers = torch.cat([chunk.sphere_centers for chunk in forward_chunks], dim=0)
|
||||
sphere_radii = torch.cat([chunk.sphere_radii for chunk in forward_chunks], dim=0)
|
||||
if do_jac_medial:
|
||||
sphere_centers_jac = torch.cat([chunk.sphere_centers_jac for chunk in forward_chunks], dim=0)
|
||||
if do_shape_operator:
|
||||
shape_operator = torch.cat([chunk.shape_operator for chunk in forward_chunks], dim=0)
|
||||
|
||||
n_atoms = all_sphere_centers.shape[-2] if MARF else 1
|
||||
|
||||
intersections = intersections @ self.obj_rot_inv
|
||||
intersection_normals = intersection_normals @ self.obj_rot_inv
|
||||
sphere_centers = sphere_centers @ self.obj_rot_inv if sphere_centers is not None else None
|
||||
all_sphere_centers = all_sphere_centers @ self.obj_rot_inv if all_sphere_centers is not None else None
|
||||
|
||||
else: # render ground truth mesh
|
||||
# HACK: we use a thread to not break the pygame opengl context
|
||||
with ThreadPoolExecutor(max_workers=1) as p:
|
||||
scan = p.submit(sdf_scan.Scan, self.get_current_ground_truth(),
|
||||
camera_transform = self.cam2world.numpy(),
|
||||
resolution = self.res[1],
|
||||
calculate_normals = True,
|
||||
fov = self.cam_fov_y,
|
||||
z_near = 0.001,
|
||||
z_far = 50,
|
||||
no_flip_backfaced_normals = True
|
||||
).result()
|
||||
n_atoms, MARF, PRIF = 1, False, True
|
||||
is_intersecting = torch.zeros(self.res, dtype=bool)
|
||||
is_intersecting[ (self.res[0]-self.res[1]) // 2 : (self.res[0]-self.res[1]) // 2 + self.res[1], : ] = torch.tensor(scan.depth_buffer != 0, dtype=bool)
|
||||
intersections = torch.zeros((*is_intersecting.shape, 3), dtype=dtype)
|
||||
intersection_normals = torch.zeros((*is_intersecting.shape, 3), dtype=dtype)
|
||||
intersections [is_intersecting] = torch.tensor(scan.points, dtype=dtype)
|
||||
intersection_normals[is_intersecting] = torch.tensor(scan.normals, dtype=dtype)
|
||||
is_intersecting = is_intersecting .flip(1).to(device)
|
||||
intersections = intersections .flip(1).to(device)
|
||||
intersection_normals = intersection_normals.flip(1).to(device)
|
||||
|
||||
mask = is_intersecting.cpu()
|
||||
|
||||
mx, my = self.mouse_position
|
||||
w, h = dirs.shape[:2]
|
||||
|
||||
# fill white
|
||||
if self.display_sphere_map_bg == True:
|
||||
self.blit_sphere_map_mask(pixel_view)
|
||||
else:
|
||||
pixel_view[:] = self.display_sphere_map_bg
|
||||
|
||||
# draw to buffer
|
||||
|
||||
to_cam = -dirs.detach()
|
||||
|
||||
# light direction
|
||||
extra = np.pi if vizmode_shading == "translucent" else 0
|
||||
LM = torch.tensor(T.rotation_matrix(angle=self.light_angle2, direction=(0, 1, 0))[:3, :3], dtype=dtype)
|
||||
LM = torch.tensor(T.rotation_matrix(angle=self.light_angle1 + extra, direction=(1, 0, 0))[:3, :3], dtype=dtype) @ LM
|
||||
to_light = (self.cam2world[:3, :3] @ LM @ torch.tensor((1, 1, 3), dtype=dtype)).to(device)[None, :]
|
||||
to_light = to_light / to_light.norm(dim=-1, keepdim=True)
|
||||
|
||||
# used to color different atom candidates
|
||||
color_set = tuple(map(helpers.hex2tuple,
|
||||
itertools.chain(
|
||||
mcolors.TABLEAU_COLORS.values(),
|
||||
#list(mcolors.TABLEAU_COLORS.values())[::-1],
|
||||
#['#f8481c', '#c20078', '#35530a', '#010844', '#a8ff04'],
|
||||
mcolors.XKCD_COLORS.values(),
|
||||
)
|
||||
))
|
||||
color_per_atom = (*zip(*zip(range(n_atoms), itertools.cycle(color_set))),)[1]
|
||||
|
||||
|
||||
# shade hits
|
||||
|
||||
if vizmode_shading is None:
|
||||
pass
|
||||
elif vizmode_shading == "colored-lambertian":
|
||||
if n_atoms > 1:
|
||||
color = torch.tensor(color_per_atom, device=device)[(*atom_indices[is_intersecting].T,)]
|
||||
else:
|
||||
color = torch.tensor(color_set[(0 if self.atom_index_solo is None else self.atom_index_solo) % len(color_set)], device=device)
|
||||
lambertian = torch.einsum("id,id->i",
|
||||
intersection_normals[is_intersecting, :],
|
||||
to_light,
|
||||
)[..., None]
|
||||
|
||||
pixel_view[mask, :] = (color *
|
||||
torch.einsum("id,id->i",
|
||||
intersection_normals[is_intersecting, :],
|
||||
to_cam[is_intersecting, :],
|
||||
)[..., None]).int().cpu()
|
||||
pixel_view[mask, :] = (
|
||||
255 * lambertian.clamp(0, 1).pow(32) +
|
||||
color * (lambertian + 0.25).clamp(0, 1) * (1-lambertian.clamp(0, 1).pow(32))
|
||||
).cpu()
|
||||
elif vizmode_shading == "lambertian":
|
||||
lambertian = torch.einsum("id,id->i",
|
||||
intersection_normals[is_intersecting, :],
|
||||
to_light,
|
||||
)[..., None].clamp(0, 1)
|
||||
|
||||
if self.lambertian_color == (1.0, 1.0, 1.0):
|
||||
pixel_view[mask, :] = (255 * lambertian).cpu()
|
||||
else:
|
||||
color = 255*torch.tensor(self.lambertian_color, device=device)
|
||||
pixel_view[mask, :] = (color *
|
||||
torch.einsum("id,id->i",
|
||||
intersection_normals[is_intersecting, :],
|
||||
to_cam[is_intersecting, :],
|
||||
)[..., None]).int().cpu()
|
||||
pixel_view[mask, :] = (
|
||||
255 * lambertian.clamp(0, 1).pow(32) +
|
||||
color * (lambertian + 0.25).clamp(0, 1) * (1-lambertian.clamp(0, 1).pow(32))
|
||||
).cpu()
|
||||
elif vizmode_shading == "translucent" and MARF:
|
||||
lambertian = torch.einsum("id,id->i",
|
||||
intersection_normals[is_intersecting, :],
|
||||
to_light,
|
||||
)[..., None].abs().clamp(0, 1)
|
||||
|
||||
distortion = 0.08
|
||||
power = 16
|
||||
ambient = 0
|
||||
thickness = sphere_radii[is_intersecting].detach()
|
||||
if self.display_mode_variation % 2:
|
||||
thickness = thickness.mean()
|
||||
|
||||
color1 = torch.tensor((1, 0.5, 0.5), **device_and_dtype) # subsurface
|
||||
color2 = torch.tensor((0, 1, 1), **device_and_dtype) # diffuse
|
||||
|
||||
l = to_light + intersection_normals[is_intersecting, :] * distortion
|
||||
d = (to_cam[is_intersecting, :] * -l).sum(dim=-1).clamp(0, None).pow(power)
|
||||
f = (d + ambient) * (1/(0.05 + thickness))
|
||||
|
||||
pixel_view[((dirs * to_light).sum(dim=-1) > 0.99).cpu(), :] = 255 # draw light source
|
||||
|
||||
pixel_view[mask, :] = (255 * (
|
||||
color2 * (0.05 + lambertian*0.15) +
|
||||
color1 * 0.3 * f[..., None]
|
||||
).clamp(0, 1)).cpu()
|
||||
elif vizmode_shading == "anisotropic" and vizmode_normals != "ground_truth":
|
||||
eigvals, eigvecs = torch.linalg.eig(shape_operator.mT) # slow, complex output, not sorted
|
||||
eigvals, indices = eigvals.abs().sort(dim=-1)
|
||||
eigvecs = (eigvecs.abs() * eigvecs.real.sign()).take_along_dim(indices[..., None, :], dim=-1)
|
||||
eigvecs = eigvecs.mT
|
||||
|
||||
s = self.display_mode_variation % 5
|
||||
if s in (0, 1):
|
||||
# try to keep these below 0.2:
|
||||
if s == 0: a1, a2 = 0.05, 0.3
|
||||
if s == 1: a1, a2 = 0.3, 0.05
|
||||
|
||||
# == Ward anisotropic specular reflectance ==
|
||||
|
||||
# G.J. Ward, Measuring and modeling anisotropic reflection, in:
|
||||
# Proceedings of the 19th Annual Conference on Computer Graphics and
|
||||
# Interactive Techniques, 1992: pp. 265–272.
|
||||
|
||||
eigvecs /= eigvecs.norm(dim=-1, keepdim=True)
|
||||
|
||||
N = intersection_normals[is_intersecting, :]
|
||||
H = to_cam[is_intersecting, :] + to_light
|
||||
H = H / H.norm(dim=-1, keepdim=True)
|
||||
specular = (1/(4*torch.pi * a1*a2 * torch.sqrt((
|
||||
(N * to_cam[is_intersecting, :]).sum(dim=-1) *
|
||||
(N * to_light ).sum(dim=-1)
|
||||
)))) * torch.exp(
|
||||
-2 * (
|
||||
((H * eigvecs[..., 2, :]).sum(dim=-1) / a1).pow(2)
|
||||
+
|
||||
((H * eigvecs[..., 1, :]).sum(dim=-1) / a2).pow(2)
|
||||
) / (
|
||||
1 + (N * H).sum(dim=-1)
|
||||
)
|
||||
)
|
||||
specular = specular.clamp(0, None).nan_to_num(0, 0, 0)
|
||||
lambertian = torch.einsum("id,id->i", N, to_light ).clamp(0, None)
|
||||
|
||||
color1 = 0.4 * torch.tensor((1, 1, 1), **device_and_dtype) # specular
|
||||
color2 = 0.4 * torch.tensor((0, 1, 1), **device_and_dtype) # diffuse
|
||||
pixel_view[mask, :] = (255 * (
|
||||
color1 * specular [..., None] +
|
||||
color2 * lambertian[..., None]
|
||||
).clamp(0, 1)).int().cpu()
|
||||
if s == 2:
|
||||
pixel_view[mask, :] = (255 * (
|
||||
eigvecs[..., 2, :].abs().clamp(0, 1) # orientation only
|
||||
)).int().cpu()
|
||||
elif s == 3:
|
||||
pixel_view[mask, :] = (255 * (
|
||||
eigvecs[..., 1, :].abs().clamp(0, 1) # orientation only
|
||||
)).int().cpu()
|
||||
elif s == 4:
|
||||
pixel_view[mask, :] = (255 * (
|
||||
eigvecs[..., 0, :].abs().clamp(0, 1) # orientation only
|
||||
)).int().cpu()
|
||||
elif vizmode_shading == "shade-best-radii" and MARF:
|
||||
lambertian = torch.einsum("id,id->i",
|
||||
intersection_normals[is_intersecting, :],
|
||||
to_light,
|
||||
)[..., None]
|
||||
|
||||
radii = sphere_radii[is_intersecting]
|
||||
radii = radii - 0.04
|
||||
radii = radii / 0.4
|
||||
|
||||
colors = cm.plasma(radii.clamp(0, 1).cpu())[..., :3]
|
||||
pixel_view[mask, :] = 255 * (
|
||||
lambertian.pow(32).clamp(0, 1).cpu().numpy() +
|
||||
colors * (lambertian + 0.25).clamp(0, 1).cpu().numpy() * (1-lambertian.pow(32).clamp(0, 1)).cpu().numpy()
|
||||
)
|
||||
elif vizmode_shading == "shade-all-radii" and MARF:
|
||||
radii = sphere_radii[is_intersecting][..., None]
|
||||
radii /= radii.max()
|
||||
if n_atoms > 1:
|
||||
color = torch.tensor(color_per_atom, device=device)[(*atom_indices[is_intersecting].T,)]
|
||||
else:
|
||||
color = torch.tensor(color_set[(0 if self.atom_index_solo is None else self.atom_index_solo) % len(color_set)], device=device)
|
||||
pixel_view[mask, :] = (color * radii).int().cpu()
|
||||
elif vizmode_shading == "normal":
|
||||
normal = intersection_normals[is_intersecting, :]
|
||||
pixel_view[mask, :] = (255 * (normal * 0.5 + 0.5) ).int().cpu()
|
||||
elif vizmode_shading == "curvature" and vizmode_normals != "ground_truth":
|
||||
eigvals = torch.linalg.eigvals(shape_operator.mT) # complex output, not sorted
|
||||
|
||||
# we sort them by absolute magnitude, not the real component
|
||||
_, indices = (eigvals.abs() * eigvals.real.sign()).sort(dim=-1)
|
||||
eigvals = eigvals.real.take_along_dim(indices, dim=-1)
|
||||
|
||||
s = self.display_mode_variation % (6 if MARF else 5)
|
||||
if s==0: out = (eigvals[..., [0, 2]].mean(dim=-1, keepdim=True) / 25).tanh() # mean curvature
|
||||
if s==1: out = (eigvals[..., [0, 2]].prod(dim=-1, keepdim=True) / 25).tanh() # gaussian curvature
|
||||
if s==2: out = (eigvals[..., [2]] / 25).tanh() # maximum principal curvature - k1
|
||||
if s==3: out = (eigvals[..., [1]] / 25).tanh() # some curvature
|
||||
if s==4: out = (eigvals[..., [0]] / 25).tanh() # minimum principal curvature - k2
|
||||
if s==5: out = ((sphere_radii[is_intersecting][..., None].detach() - 1 / eigvals[..., [2]].clamp(1e-8, None)) * 5).tanh().clamp(0, None)
|
||||
|
||||
lambertian = torch.einsum("id,id->i",
|
||||
intersection_normals[is_intersecting, :],
|
||||
to_light,
|
||||
)[..., None]
|
||||
|
||||
pixel_view[mask, :] = (255 * (lambertian+0.5).clamp(0, 1) * torch.cat((
|
||||
1+out.clamp(-1, 0),
|
||||
1-out.abs(),
|
||||
1-out.clamp(0, 1),
|
||||
), dim=-1)).int().cpu()
|
||||
elif vizmode_shading == "centroid-grad-norm" and MARF:
|
||||
asd = sphere_centers_jac[is_intersecting, :, :].norm(dim=-2).mean(dim=-1, keepdim=True)
|
||||
asd -= asd.min()
|
||||
asd /= asd.max()
|
||||
pixel_view[mask, :] = (255 * asd).cpu()
|
||||
elif "glass" in vizmode_shading:
|
||||
normals = intersection_normals[is_intersecting, :]
|
||||
to_cam_ = to_cam [is_intersecting, :]
|
||||
# "Empiricial Approximation" of fresnel
|
||||
# https://developer.download.nvidia.com/CgTutorial/cg_tutorial_chapter07.html via
|
||||
# http://kylehalladay.com/blog/tutorial/2014/02/18/Fresnel-Shaders-From-The-Ground-Up.html
|
||||
cos = torch.einsum("id,id->i", normals, to_cam_ )[..., None]
|
||||
bias, scale, power = 0, 4, 3
|
||||
fresnel = (bias + scale*(1-cos)**power).clamp(0, 1)
|
||||
|
||||
#reflection
|
||||
reflection = -to_cam_ - 2*(-cos)*normals
|
||||
|
||||
#refraction
|
||||
r = 1 / 1.5 # refractive index, air -> glass
|
||||
refraction = -r*to_cam_ + (r*cos - (1-r**2*(1-cos**2)).sqrt()) * normals
|
||||
exit_point = intersections[is_intersecting, :]
|
||||
|
||||
# reflect the refraction over the plane defined by the refraction direction and the sphere center, resulting in the second refraction
|
||||
if vizmode_shading == "double-glass" and MARF:
|
||||
cos2 = torch.einsum("id,id->i", refraction, -to_cam_ )[..., None]
|
||||
pn = -to_cam_ - cos2*refraction
|
||||
pn /= pn.norm(dim=-1, keepdim=True)
|
||||
|
||||
refraction = -to_cam_ - 2*torch.einsum("id,id->i", pn, -to_cam_ )[..., None]*pn
|
||||
|
||||
exit_point -= sphere_centers[is_intersecting, :]
|
||||
exit_point = exit_point - 2*torch.einsum("id,id->i", pn, exit_point )[..., None]*pn
|
||||
exit_point += sphere_centers[is_intersecting, :]
|
||||
|
||||
fresnel = np.asanyarray(fresnel.cpu())
|
||||
pixel_view[mask, :] \
|
||||
= self.lookup_sphere_map_dirs(reflection, intersections[is_intersecting, :]) * fresnel \
|
||||
+ self.lookup_sphere_map_dirs(refraction, exit_point) * (1-fresnel)
|
||||
else: # flat
|
||||
pixel_view[mask, :] = 80
|
||||
|
||||
if not MARF: return
|
||||
|
||||
# overlay medial atoms
|
||||
|
||||
if vizmode_spheres is not None:
|
||||
# show miss distance in red
|
||||
s = silhouettes.detach()[~is_intersecting].clamp(0, 1)
|
||||
s /= s.max()
|
||||
pixel_view[~mask, 1] = (s * 255).cpu()
|
||||
pixel_view[:, 2] = pixel_view[:, 1]
|
||||
|
||||
mouse_hits = 0 <= mx < w and 0 <= my < h and mask[mx, my]
|
||||
draw_intersecting = "intersecting-sphere" in vizmode_spheres
|
||||
draw_best = "best-sphere" in vizmode_spheres
|
||||
draw_color = "-sphere-colored" in vizmode_spheres
|
||||
draw_all = "all-spheres-colored" in vizmode_spheres
|
||||
|
||||
def get_nears():
|
||||
if draw_all:
|
||||
projected, near, far, is_intersecting = geometry.ray_sphere_intersect(
|
||||
torch.tensor(origins),
|
||||
torch.tensor(dirs[..., None, :]),
|
||||
sphere_centers = all_sphere_centers[mx, my][None, None, ...],
|
||||
sphere_radii = all_sphere_radii [mx, my][None, None, ...],
|
||||
allow_nans = False,
|
||||
return_parts = True,
|
||||
)
|
||||
|
||||
depths = (near - origins).norm(dim=-1)
|
||||
atom_indices_ = torch.where(is_intersecting, depths.detach(), depths.detach()+100).argmin(dim=-1, keepdim=True)
|
||||
is_intersecting = is_intersecting.any(dim=-1)
|
||||
projected = None
|
||||
near = near.take_along_dim(atom_indices_[..., None], -2).squeeze(-2)
|
||||
far = None
|
||||
sphere_centers_ = all_sphere_centers[mx, my][None, None, ...].take_along_dim(atom_indices_[..., None], -2).squeeze(-2)
|
||||
|
||||
normals = near[is_intersecting, :] - sphere_centers_[is_intersecting, :]
|
||||
normals /= torch.linalg.norm(normals, dim=-1)[..., None]
|
||||
|
||||
color = torch.tensor(color_per_atom, device=device)[(*atom_indices_[is_intersecting].T,)]
|
||||
yield color, projected, near, far, is_intersecting, normals
|
||||
|
||||
if (mouse_hits and draw_intersecting) or draw_best:
|
||||
projected, near, far, is_intersecting = geometry.ray_sphere_intersect(
|
||||
torch.tensor(origins),
|
||||
torch.tensor(dirs),
|
||||
# unit-sphere by default
|
||||
sphere_centers = sphere_centers[mx, my][None, None, ...],
|
||||
sphere_radii = sphere_radii [mx, my][None, None, ...],
|
||||
return_parts = True,
|
||||
)
|
||||
|
||||
normals = near[is_intersecting, :] - sphere_centers[mx, my][None, ...]
|
||||
normals /= torch.linalg.norm(normals, dim=-1)[..., None]
|
||||
color = (255, 255, 255) if not draw_color else color_per_atom[atom_indices[mx, my]]
|
||||
yield torch.tensor(color, device=device), projected, near, far, is_intersecting, normals
|
||||
|
||||
# draw sphere with lambertian shading
|
||||
for color, projected, near, far, is_intersecting_2, normals in get_nears():
|
||||
lambertian = torch.einsum("...id,...id->...i", normals, to_light )[..., None]
|
||||
pixel_view[is_intersecting_2.cpu(), :] = (
|
||||
255*lambertian.pow(32).clamp(0, 1) +
|
||||
color * (lambertian + 0.25).clamp(0, 1) * (1-lambertian.pow(32).clamp(0, 1))
|
||||
).cpu()
|
||||
|
||||
# overlay points / sphere centers
|
||||
|
||||
if vizmode_centroids is not None:
|
||||
cam2world_inv = torch.tensor(self.cam2world_inv, **device_and_dtype)
|
||||
intrinsics = torch.tensor(self.intrinsics, **device_and_dtype)
|
||||
|
||||
def get_coords():
|
||||
miss_centroid = "miss-centroids" in vizmode_centroids
|
||||
mask = is_intersecting if not miss_centroid else ~is_intersecting
|
||||
if vizmode_centroids in ("all-centroids-colored", "all-miss-centroids-colored"):
|
||||
# we use temporal dithering to the show all overlapping centers
|
||||
for color, atom_index in sorted(zip(itertools.chain(color_set), range(n_atoms)), key=lambda x: random.random()):
|
||||
yield color, all_sphere_centers[..., atom_index, :][mask], mask
|
||||
elif "all-centroids" in vizmode_centroids:
|
||||
yield (80, 150, 80), all_sphere_centers[mask].reshape(-1, 3), mask # [:, 3]
|
||||
|
||||
if "centroids-colored" in vizmode_centroids:
|
||||
if n_atoms == 1:
|
||||
color = color_set[(0 if self.atom_index_solo is None else self.atom_index_solo) % len(color_set)]
|
||||
else:
|
||||
color = torch.tensor(color_per_atom, device=device)[(*atom_indices[mask].T,)].cpu()
|
||||
else:
|
||||
color = (0, 0, 0)
|
||||
yield color, sphere_centers[mask], mask
|
||||
|
||||
for i, (color, coords, coord_mask) in enumerate(get_coords()):
|
||||
if self.export_medial_surface_mesh:
|
||||
fname = self.mk_dump_fname("ply", uid=i)
|
||||
p = torch.zeros_like(sphere_centers)
|
||||
c = torch.zeros_like(sphere_centers)
|
||||
p[coord_mask, :] = coords
|
||||
c[coord_mask, :] = torch.tensor(color, device=p.device) / 255
|
||||
SingleViewUVScan(
|
||||
hits = ( mask).numpy(),
|
||||
miss = (~mask).numpy(),
|
||||
points = p.cpu().numpy(),
|
||||
colors = c.cpu().numpy(),
|
||||
normals=None, distances=None, cam_pos=None,
|
||||
cam_mat4=None, proj_mat4=None, transforms=None,
|
||||
).to_mesh().export(str(fname), file_type="ply")
|
||||
print("dumped", fname)
|
||||
if shutil.which("f3d"):
|
||||
subprocess.Popen(["f3d", "-gsy", "--up=+z", "--bg-color=1,1,1", fname], close_fds=True)
|
||||
|
||||
coords = torch.cat((coords, torch.ones((*coords.shape[:-1], 1), **device_and_dtype)), dim=-1)
|
||||
|
||||
coords = torch.einsum("...ij,...kj->...ki", cam2world_inv, coords)[..., :3]
|
||||
coords = geometry.project(coords[..., 0], coords[..., 1], coords[..., 2], intrinsics)
|
||||
|
||||
in_view = functools.reduce(torch.mul, (
|
||||
coords[:, 0] < pixel_view.shape[1],
|
||||
coords[:, 0] >= 0,
|
||||
coords[:, 1] < pixel_view.shape[0],
|
||||
coords[:, 1] >= 0,
|
||||
)).cpu()
|
||||
|
||||
coords = coords[in_view, :]
|
||||
if not isinstance(color, tuple):
|
||||
color = color[in_view, :]
|
||||
|
||||
pixel_view[(*coords[..., [1, 0]].int().T.cpu(),)] = color
|
||||
|
||||
self.export_medial_surface_mesh = False
|
6369
poetry.lock
generated
Normal file
6369
poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
80
pyproject.toml
Normal file
80
pyproject.toml
Normal file
@ -0,0 +1,80 @@
|
||||
[tool.poetry]
|
||||
name = "ifield"
|
||||
version = "0.2.0"
|
||||
description = ""
|
||||
authors = ["Peder Bergebakken Sundt <pbsds@hotmail.com>"]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0", "setuptools>=60"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<3.11"
|
||||
faiss-cpu = "^1.7.3"
|
||||
geomloss = "0.2.4" # 0.2.5 has no bdist on pypi
|
||||
h5py = "^3.7.0"
|
||||
hdf5plugin = "^4.0.1"
|
||||
imageio = "^2.23.0"
|
||||
jinja2 = "^3.1.2"
|
||||
matplotlib = "^3.6.2"
|
||||
mesh-to-sdf = {git = "https://github.com/pbsds/mesh_to_sdf", rev = "no_flip_normals"}
|
||||
methodtools = "^0.4.5"
|
||||
more-itertools = "^9.1.0"
|
||||
munch = "^2.5.0"
|
||||
numpy = "^1.23.0"
|
||||
pyembree = {url = "https://folk.ntnu.no/pederbs/pypy/pep503/pyembree/pyembree-0.2.11-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"}
|
||||
pygame = "^2.1.2"
|
||||
pykeops = "^2.1.1"
|
||||
pytorch3d = [
|
||||
{url = "https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu116_pyt1130/pytorch3d-0.7.2-cp310-cp310-linux_x86_64.whl"},
|
||||
]
|
||||
pyqt5 = "^5.15.7"
|
||||
pyrender = "^0.1.45"
|
||||
pytorch-lightning = "^1.8.6"
|
||||
pyyaml = "^6.0"
|
||||
rich = "^13.3.2"
|
||||
rtree = "^1.0.1"
|
||||
scikit-image = "^0.19.3"
|
||||
scikit-learn = "^1.2.0"
|
||||
seaborn = "^0.12.1"
|
||||
serve-me-once = "^0.1.2"
|
||||
torch = "^1.13.0"
|
||||
torchmeta = {git = "https://github.com/pbsds/pytorch-meta", rev = "upgrade"}
|
||||
torchviz = "^0.0.2"
|
||||
tqdm = "^4.64.1"
|
||||
trimesh = "^3.17.1"
|
||||
typer = "^0.7.0"
|
||||
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
python-lsp-server = {extras = ["all"], version = "^1.6.0"}
|
||||
fix-my-functions = "^0.1.3"
|
||||
imageio-ffmpeg = "^0.4.7"
|
||||
jupyter = "^1.0.0"
|
||||
jupyter-contrib-nbextensions = "^0.7.0"
|
||||
jupyterlab = "^3.5.2"
|
||||
jupyterthemes = "^0.20.0"
|
||||
llvmlite = "^0.39.1" # only to make poetry install the python3.10 wheels instead of building them
|
||||
nbconvert = "<=6.5.0" # https://github.com/jupyter/nbconvert/issues/1894
|
||||
numba = "^0.56.4" # only to make poetry install the python3.10 wheels instead of building them
|
||||
papermill = "^2.4.0"
|
||||
pdoc = "^12.3.0"
|
||||
pdoc3 = "^0.10.0"
|
||||
ptpython = "^3.0.22"
|
||||
pudb = "^2022.1.3"
|
||||
remote-exec = {git = "https://github.com/pbsds/remote", rev = "whitespace-push"} # https://github.com/remote-cli/remote/pull/52
|
||||
shapely = "^2.0.0"
|
||||
sympy = "^1.11.1"
|
||||
tensorboard = "^2.11.0"
|
||||
visidata = "^2.11"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
show-schedule = 'ifield.utils.loss:main'
|
||||
show-h5-items = 'ifield.cli_utils:show_h5_items'
|
||||
show-h5-img = 'ifield.cli_utils:show_h5_img'
|
||||
show-h5-scan-cloud = 'ifield.cli_utils:show_h5_scan_cloud'
|
||||
show-model = 'ifield.cli_utils:show_model'
|
||||
download-stanford = 'ifield.data.stanford.download:cli'
|
||||
download-coseg = 'ifield.data.coseg.download:cli'
|
||||
preprocess-stanford = 'ifield.data.stanford.preprocess:cli'
|
||||
preprocess-coseg = 'ifield.data.coseg.preprocess:cli'
|
Reference in New Issue
Block a user