Add code
This commit is contained in:
6
.envrc
Normal file
6
.envrc
Normal file
@@ -0,0 +1,6 @@
|
||||
#!/usr/bin/env bash
# direnv entry point: loaded automatically (once allowed via `direnv allow`)
# whenever the shell enters this directory. It delegates all the real work
# of entering the virtual environment to .localenv.

. .localenv
|
||||
9
.gitignore
vendored
Normal file
9
.gitignore
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
__pycache__
|
||||
/data/models/
|
||||
/data/archives/
|
||||
/experiments/logdir/
|
||||
/.env/
|
||||
/.direnv/
|
||||
*.zip
|
||||
*.sh
|
||||
# pandoc preview enhanced
default.yaml
|
||||
71
.localenv
Normal file
71
.localenv
Normal file
@@ -0,0 +1,71 @@
|
||||
#!/usr/bin/env bash

# =======================
# bootstrap a venv
# =======================
#
# Sourced (not executed) by .envrc and .remoteenv: ensures a project-local
# conda environment with python 3.10 + poetry exists, then puts the poetry
# venv's bin/ first on PATH.

LOCAL_ENV_NAME="py310-$(basename "$(pwd)")"
LOCAL_ENV_DIR="$(pwd)/.env/$LOCAL_ENV_NAME"
mkdir -p "$LOCAL_ENV_DIR"

# make configs and caches a part of venv
export POETRY_CACHE_DIR="$LOCAL_ENV_DIR/xdg/cache/poetry"
export PIP_CACHE_DIR="$LOCAL_ENV_DIR/xdg/cache/pip"
mkdir -p "$POETRY_CACHE_DIR" "$PIP_CACHE_DIR"

#export POETRY_VIRTUALENVS_IN_PROJECT=true # store venv in ./.venv/
#export POETRY_VIRTUALENVS_CREATE=false # install globally
export SETUPTOOLS_USE_DISTUTILS=stdlib # https://github.com/pre-commit/pre-commit/issues/2178#issuecomment-1002163763
export IFIELD_PRETTY_TRACEBACK=1
#export SHOW_LOCALS=1 # locals in tracebacks
export PYTHON_KEYRING_BACKEND="keyring.backends.null.Keyring"

# ensure we have the correct python and poetry. Bootstrap via conda if missing
# BUGFIX(review): this used to check `command -v python310`, but the
# interpreter installed by conda is named `python3.10` — checking a name that
# never exists forced the conda bootstrap to run on every load.
if ! command -v python3.10 >/dev/null || ! command -v poetry >/dev/null; then
	source .localenv-bootstrap-conda

	if command -v mamba >/dev/null; then
		CONDA=mamba
	elif command -v conda >/dev/null; then
		CONDA=conda
	else
		>&2 echo "ERROR: neither 'poetry' nor 'conda'/'mamba' could be found!"
		exit 1
	fi

	# echo each command (shell-quoted) before running it
	function verbose {
		echo +"$(printf " %q" "$@")"
		"$@"
	}

	# create the conda env only if it does not exist yet
	if ! ($CONDA env list | grep -q "^$LOCAL_ENV_NAME "); then
		verbose $CONDA create --yes --name "$LOCAL_ENV_NAME" -c conda-forge \
			python==3.10.8 poetry==1.3.1 #python-lsp-server
		true
	fi

	# `conda activate` works here because the bootstrap script eval'd the
	# conda shell hook; mamba does not provide `activate` in all versions.
	verbose conda activate "$LOCAL_ENV_NAME" || exit $?
	#verbose $CONDA activate "$LOCAL_ENV_NAME" || exit $?

	unset -f verbose
fi


# enter poetry venv
# source .envrc
poetry run true # ensure venv exists
#source "$(poetry env info -p)/bin/activate"
export VIRTUAL_ENV="$(poetry env info --path)"
export POETRY_ACTIVE=1
export PATH="$VIRTUAL_ENV/bin":"$PATH"
# NOTE: poetry currently reuses and populates the conda venv.
# See: https://github.com/python-poetry/poetry/issues/1724


# ensure output dirs exist
mkdir -p experiments/logdir

# first-time-setup poetry (fix-my-functions is an entry point installed by
# this project, so its absence means `poetry install` has not run yet)
if ! command -v fix-my-functions >/dev/null; then
	poetry install
fi
|
||||
53
.localenv-bootstrap-conda
Normal file
53
.localenv-bootstrap-conda
Normal file
@@ -0,0 +1,53 @@
|
||||
#!/usr/bin/env bash

# =======================
# bootstrap a conda venv
# =======================
#
# Sourced by .localenv: installs Miniconda into the project directory when no
# `conda` is found on PATH, then loads the conda shell integration.

CONDA_ENV_DIR="${LOCAL_ENV_DIR:-$(pwd)/.conda310}"
mkdir -p "$CONDA_ENV_DIR"
#touch "$HOME/.Xauthority"

MINICONDA_PY310_URL="https://repo.anaconda.com/miniconda/Miniconda3-py310_22.11.1-1-Linux-x86_64.sh"
MINICONDA_PY310_HASH="00938c3534750a0e4069499baf8f4e6dc1c2e471c86a59caa0dd03f4a9269db6"

# Check if conda is available
if ! command -v conda >/dev/null; then
	export PATH="$CONDA_ENV_DIR/conda/bin:$PATH"
fi

# Check again if conda is available, install miniconda if not
if ! command -v conda >/dev/null; then
	(set -e #x
	# echo each command (shell-quoted) before running it
	function verbose {
		echo +"$(printf " %q" "$@")"
		"$@"
	}

	if command -v curl >/dev/null; then
		verbose curl -sLo "$CONDA_ENV_DIR/miniconda_py310.sh" "$MINICONDA_PY310_URL"
	elif command -v wget >/dev/null; then
		verbose wget -O "$CONDA_ENV_DIR/miniconda_py310.sh" "$MINICONDA_PY310_URL"
	else
		>&2 echo "ERROR: unable to download miniconda!"
		exit 1
	fi

	# BUGFIX: `sha256sum` prints "<hash>  <filename>", so comparing its raw
	# output to the bare hash could never succeed and the bootstrap always
	# aborted here under `set -e`. Compare only the hash field.
	verbose test "$(sha256sum "$CONDA_ENV_DIR/miniconda_py310.sh" | cut -d' ' -f1)" = "$MINICONDA_PY310_HASH"
	verbose chmod +x "$CONDA_ENV_DIR/miniconda_py310.sh"

	# -b: batch (no prompts), -u: update existing, -p: install prefix
	verbose "$CONDA_ENV_DIR/miniconda_py310.sh" -b -u -p "$CONDA_ENV_DIR/conda"
	verbose rm "$CONDA_ENV_DIR/miniconda_py310.sh"

	eval "$(conda shell.bash hook)" # basically `conda init`, without modifying .bashrc
	verbose conda install --yes --name base mamba -c conda-forge

	) || exit $?
fi

unset CONDA_ENV_DIR
unset MINICONDA_PY310_URL
unset MINICONDA_PY310_HASH

# Enter conda environment
eval "$(conda shell.bash hook)" # basically `conda init`, without modifying .bashrc
|
||||
29
.remoteenv
Normal file
29
.remoteenv
Normal file
@@ -0,0 +1,29 @@
|
||||
#!/usr/bin/env bash
# this file is used by remote-cli

# Assumes repo is put in a "remotes/name-hash" folder,
# the default behaviour of remote-exec
REMOTES_DIR="$(dirname "$(pwd)")"
LOCAL_ENV_NAME="py310-$(basename "$(pwd)")"
# BUGFIX: this interpolated "$REMOTE_ENV_NAME", a variable that is never set,
# which collapsed the path to "$REMOTES_DIR/envs/" for every project.
LOCAL_ENV_DIR="$REMOTES_DIR/envs/$LOCAL_ENV_NAME"

#export XDG_CACHE_HOME="$LOCAL_ENV_DIR/xdg/cache"
#export XDG_DATA_HOME="$LOCAL_ENV_DIR/xdg/share"
#export XDG_STATE_HOME="$LOCAL_ENV_DIR/xdg/state"
#mkdir -p "$XDG_CACHE_HOME" "$XDG_DATA_HOME" "$XDG_STATE_HOME"
export XDG_CONFIG_HOME="$LOCAL_ENV_DIR/xdg/config"
mkdir -p "$XDG_CONFIG_HOME"


export PYOPENGL_PLATFORM=egl # makes pyrender work headless
#export PYOPENGL_PLATFORM=osmesa # makes pyrender work headless
export SDL_VIDEODRIVER=dummy # pygame

source .localenv

# SLURM logs output dir (symlinked into the experiment logdir)
if command -v sbatch >/dev/null; then
	mkdir -p slurm_logs
	test -L experiments/logdir/slurm_logs ||
		ln -s ../../slurm_logs experiments/logdir/
fi
|
||||
30
.remoteignore.toml
Normal file
30
.remoteignore.toml
Normal file
@@ -0,0 +1,30 @@
|
||||
[push]
|
||||
exclude = [
|
||||
"*.egg-info",
|
||||
"*.pyc",
|
||||
".ipynb_checkpoints",
|
||||
".mypy_cache",
|
||||
".remote.toml",
|
||||
".remoteignore.toml",
|
||||
".venv",
|
||||
".wandb",
|
||||
"__pycache__",
|
||||
"data/models",
|
||||
"docs",
|
||||
"experiments/logdir",
|
||||
"poetry.toml",
|
||||
"slurm_logs",
|
||||
"tmp",
|
||||
]
|
||||
include = []
|
||||
|
||||
[pull]
|
||||
exclude = [
|
||||
"*",
|
||||
]
|
||||
include = []
|
||||
|
||||
[both]
|
||||
exclude = [
|
||||
]
|
||||
include = []
|
||||
128
README.md
128
README.md
@@ -1 +1,127 @@
|
||||
This is where the code for the paper _"MARF: The Medial Atom Ray Field Object Representation"_ will be published.
|
||||
# MARF: The Medial Atom Ray Field Object Representation
|
||||
|
||||
<center>
|
||||
|
||||

|
||||
|
||||
[Publication](https://doi.org/10.1016/j.cag.2023.06.032) | [Arxiv](https://arxiv.org/abs/2307.00037) | [Training data](https://mega.nz/file/9tsz3SbA#V6SIXpCFC4hbqWaFFvKmmS8BKir7rltXuhsqpEpE9wo) | [Network weights](https://mega.nz/file/t01AyTLK#7ZNMNgbqT9x2mhq5dxLuKeKyP7G0slfQX1RaZxifayw)
|
||||
|
||||
</center>
|
||||
|
||||
**TL;DR:** We achieve _fast_ surface rendering by predicting _n_ maximally inscribed spherical intersection candidates for each camera ray.
|
||||
|
||||
---
|
||||
|
||||
## Entering the Virtual Environment
|
||||
|
||||
The environment is defined in `pyproject.toml` using [Poetry](https://github.com/python-poetry/poetry) and reproducibly locked in `poetry.lock`.
|
||||
We propose three ways to enter the venv:
|
||||
|
||||
```shell
|
||||
# Requires Python 3.10 and Poetry
|
||||
poetry install
|
||||
poetry shell
|
||||
|
||||
# Will bootstrap a Miniconda 3.10 environment into .env/ if needed, then run poetry
|
||||
source .localenv
|
||||
```
|
||||
|
||||
|
||||
## Evaluation
|
||||
|
||||
### Pretrained models
|
||||
|
||||
You can download our pre-trained models from <https://mega.nz/file/t01AyTLK#7ZNMNgbqT9x2mhq5dxLuKeKyP7G0slfQX1RaZxifayw>.
|
||||
It should be unpacked into the root directory, such that the `experiment` folder gets merged.
|
||||
|
||||
### The interactive renderer
|
||||
|
||||
We automatically create experiment names with a schema of `{{model}}-{{experiment-name}}-{{hparams-summary}}-{{date}}-{{random-uid}}`.
|
||||
You can load experiment weights using either the full path, or just the `random-uid` bit.
|
||||
|
||||
From the `experiments` directory:
|
||||
|
||||
```shell
|
||||
./marf.py model {{experiment}} viewer
|
||||
```
|
||||
|
||||
If you have downloaded our pre-trained network weights, consider trying:
|
||||
|
||||
```shell
|
||||
./marf.py model nqzh viewer # Stanford Bunny (single-shape)
|
||||
./marf.py model wznx viewer # Stanford Buddha (single-shape)
|
||||
./marf.py model mxwd viewer # Stanford Armadillo (single-shape)
|
||||
./marf.py model camo viewer # Stanford Dragon (single-shape)
|
||||
./marf.py model ksul viewer # Stanford Lucy (single-shape)
|
||||
./marf.py model oxrf viewer # COSEG four-legged (multi-shape)
|
||||
```
|
||||
|
||||
## Training and Evaluation Data
|
||||
|
||||
You can download a pre-computed archive from <https://mega.nz/file/9tsz3SbA#V6SIXpCFC4hbqWaFFvKmmS8BKir7rltXuhsqpEpE9wo>.
|
||||
It should be extracted into the root directory such that a `data` directory is added to the root directory.
|
||||
|
||||
<details>
|
||||
<summary>
|
||||
Optionally, you may compute the data yourself.
|
||||
</summary>
|
||||
|
||||
Single-shape training data:
|
||||
|
||||
```shell
|
||||
# takes about 23 minutes, mainly due to lucy
|
||||
download-stanford bunny happy_buddha dragon armadillo lucy
|
||||
preprocess-stanford bunny happy_buddha dragon armadillo lucy \
|
||||
--precompute-mesh-sv-scan-uv \
|
||||
--compute-miss-distances \
|
||||
--fill-missing-uv-points
|
||||
```
|
||||
|
||||
Multi-shape training data:
|
||||
|
||||
```shell
|
||||
# takes about 29 minutes
|
||||
download-coseg four-legged --shapes
|
||||
preprocess-coseg four-legged \
|
||||
--precompute-mesh-sv-scan-uv \
|
||||
--compute-miss-distances \
|
||||
--fill-missing-uv-points
|
||||
```
|
||||
|
||||
Evaluation data:
|
||||
|
||||
```shell
|
||||
# takes about 2 hours 20 minutes, mainly due to lucy
|
||||
preprocess-stanford bunny happy_buddha dragon armadillo lucy \
|
||||
--precompute-mesh-sphere-scan \
|
||||
--compute-miss-distances
|
||||
```
|
||||
|
||||
```shell
|
||||
# takes about 4 hours
|
||||
preprocess-coseg four-legged \
|
||||
--precompute-mesh-sphere-scan \
|
||||
--compute-miss-distances
|
||||
```
|
||||
</details>
|
||||
|
||||
|
||||
## Training
|
||||
|
||||
Our experiments are defined using YAML config files, optionally templated using Jinja2 as a preprocessor.
|
||||
These templates accept additional input from the command line in the form of `-Okey=value` options.
|
||||
Our whole experiment matrix is defined in `marf.yaml.j2`. We select between different experiment groups using `-Omode={single,ablation,multi}`, and which experiment using `-Oselect={{integer}}`.
|
||||
|
||||
From the `experiments` directory:
|
||||
|
||||
CPU mode:
|
||||
|
||||
```shell
|
||||
./marf.py model marf.yaml.j2 -Oexperiment_name=cpu_test -Omode=single -Oselect=0 fit
|
||||
```
|
||||
|
||||
GPU mode:
|
||||
|
||||
```shell
|
||||
./marf.py model marf.yaml.j2 -Oexperiment_name=cpu_test -Omode=single -Oselect=0 fit --accelerator gpu --devices 1
|
||||
```
|
||||
|
||||
139
ablation.md
Normal file
139
ablation.md
Normal file
@@ -0,0 +1,139 @@
|
||||
### MARF
|
||||
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0010-nqzh`
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0312-wznx`
|
||||
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-1944-mxwd`
|
||||
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0529-camo`
|
||||
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0743-ksul`
|
||||
|
||||
### LFN encoding
|
||||
- `experiment-stanfordv12-dragon-plkr2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0539-xjte`
|
||||
- `experiment-stanfordv12-lucy-plkr2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0753-ayvt`
|
||||
- `experiment-stanfordv12-bunny-plkr2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0022-axft`
|
||||
- `experiment-stanfordv12-happy_buddha-plkr2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0322-xfoc`
|
||||
- `experiment-stanfordv12-armadillo-plkr2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2039-vbks`
|
||||
|
||||
### PRIF encoding
|
||||
- `experiment-stanfordv12-armadillo-prpft2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2033-nkxm`
|
||||
- `experiment-stanfordv12-happy_buddha-prpft2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0313-huci`
|
||||
- `experiment-stanfordv12-dragon-prpft2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0537-dxsb`
|
||||
- `experiment-stanfordv12-bunny-prpft2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0011-tzic`
|
||||
- `experiment-stanfordv12-lucy-prpft2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0744-hzvw`
|
||||
|
||||
### No init scheme.
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-10dmiss-nogeom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0444-uohy`
|
||||
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-10dmiss-nogeom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2307-wjcf`
|
||||
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-10dmiss-nogeom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0707-eanc`
|
||||
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-10dmiss-nogeom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0225-kcfw`
|
||||
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-10dmiss-nogeom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0852-lkfh`
|
||||
|
||||
### 1 atom candidate
|
||||
- `experiment-stanfordv12-lucy-both2marf-1atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0755-qzth`
|
||||
- `experiment-stanfordv12-bunny-both2marf-1atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0027-ycnl`
|
||||
- `experiment-stanfordv12-armadillo-both2marf-1atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2121-fwvo`
|
||||
- `experiment-stanfordv12-dragon-both2marf-1atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0541-nvhs`
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-1atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0324-cuyw`
|
||||
|
||||
### 4 atom candidates
|
||||
- `experiment-stanfordv12-armadillo-both2marf-4atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2122-qiwg`
|
||||
- `experiment-stanfordv12-dragon-both2marf-4atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0544-ihkx`
|
||||
- `experiment-stanfordv12-lucy-both2marf-4atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0757-jwxm`
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-4atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0328-chhs`
|
||||
- `experiment-stanfordv12-bunny-both2marf-4atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0038-zymb`
|
||||
|
||||
### 8 atom candidates
|
||||
- `experiment-stanfordv12-bunny-both2marf-8atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0055-ogpd`
|
||||
- `experiment-stanfordv12-lucy-both2marf-8atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0757-frxb`
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-8atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0337-twys`
|
||||
- `experiment-stanfordv12-dragon-both2marf-8atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0551-bubw`
|
||||
- `experiment-stanfordv12-armadillo-both2marf-8atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2137-nnlj`
|
||||
|
||||
### 32 atom candidates
|
||||
- `experiment-stanfordv12-bunny-both2marf-32atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0056-ourc`
|
||||
- `experiment-stanfordv12-armadillo-both2marf-32atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2141-byaj`
|
||||
- `experiment-stanfordv12-dragon-both2marf-32atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0554-zobg`
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-32atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0337-rmyq`
|
||||
- `experiment-stanfordv12-lucy-both2marf-32atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0800-lqen`
|
||||
|
||||
### 64 atom candidates
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-64atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0339-whcx`
|
||||
- `experiment-stanfordv12-bunny-both2marf-64atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0058-seen`
|
||||
- `experiment-stanfordv12-lucy-both2marf-64atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0806-ycxj`
|
||||
- `experiment-stanfordv12-armadillo-both2marf-64atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2153-wnfq`
|
||||
- `experiment-stanfordv12-dragon-both2marf-64atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0555-zgcb`
|
||||
|
||||
### No intersection loss
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-10dmiss-geom-0chit-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-1053-ydnh`
|
||||
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-10dmiss-geom-0chit-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-1111-fawl`
|
||||
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-10dmiss-geom-0chit-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-1045-umwl`
|
||||
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-10dmiss-geom-0chit-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-1103-lwmb`
|
||||
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-10dmiss-geom-0chit-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-1041-lhcc`
|
||||
|
||||
### No silhouette loss
|
||||
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-0dmiss-geom-20chit-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-1042-fsuw`
|
||||
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-0dmiss-geom-20chit-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-1046-nszw`
|
||||
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-0dmiss-geom-20chit-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-1111-mlal`
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-0dmiss-geom-20chit-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-1055-cvkg`
|
||||
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-0dmiss-geom-20chit-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-1114-pdyh`
|
||||
|
||||
### More silhouette loss
|
||||
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-50dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0157-yekm`
|
||||
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-50dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2243-nlrv`
|
||||
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-50dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0639-yros`
|
||||
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-50dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0842-xktg`
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-50dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0423-ibxs`
|
||||
|
||||
### No normal loss
|
||||
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-10dmiss-geom-nocnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0614-ttta`
|
||||
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-10dmiss-geom-nocnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0106-bnke`
|
||||
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-10dmiss-geom-nocnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2154-bxwl`
|
||||
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-10dmiss-geom-nocnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0811-qqgu`
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-10dmiss-geom-nocnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0357-gwca`
|
||||
|
||||
### No inscription loss
|
||||
- `experiment-stanfordv12-bunny-both2marf-16atom-noxinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0227-xrqt`
|
||||
- `experiment-stanfordv12-armadillo-both2marf-16atom-noxinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2312-cgzv`
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-noxinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0452-rerr`
|
||||
- `experiment-stanfordv12-dragon-both2marf-16atom-noxinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0709-tfgg`
|
||||
- `experiment-stanfordv12-lucy-both2marf-16atom-noxinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0856-ctvc`
|
||||
|
||||
### More inscription loss
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-250xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0459-kyyh`
|
||||
- `experiment-stanfordv12-bunny-both2marf-16atom-250xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0243-qqqj`
|
||||
- `experiment-stanfordv12-armadillo-both2marf-16atom-250xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2336-yclo`
|
||||
- `experiment-stanfordv12-lucy-both2marf-16atom-250xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0913-mulv`
|
||||
- `experiment-stanfordv12-dragon-both2marf-16atom-250xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0714-zugg`
|
||||
|
||||
### No maximality reg.
|
||||
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-0sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0842-cvln`
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-0sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0425-vpen`
|
||||
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-0sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0207-qpdb`
|
||||
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-0sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2251-zqvi`
|
||||
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-0sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0641-ucdo`
|
||||
|
||||
### More maximality reg.
|
||||
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-5000sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0659-bqvf`
|
||||
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-5000sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2256-escz`
|
||||
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-5000sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0208-wmvs`
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-5000sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0442-gdah`
|
||||
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-5000sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0845-halc`
|
||||
|
||||
### No specialization reg.
|
||||
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-nominatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0913-odyn`
|
||||
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-nominatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0251-xzig`
|
||||
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-nominatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0722-gxps`
|
||||
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-nominatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-30-2342-zybo`
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-nominatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-10dmv-nocond-100cwu500clr70tvs-2023-05-31-0507-tvlt`
|
||||
|
||||
### No multi-view loss
|
||||
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-nogradreg-nocond-100cwu500clr70tvs-2023-05-31-0310-wbqj`
|
||||
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-nogradreg-nocond-100cwu500clr70tvs-2023-05-30-2357-qnct`
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-nogradreg-nocond-100cwu500clr70tvs-2023-05-31-0527-psnk`
|
||||
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-nogradreg-nocond-100cwu500clr70tvs-2023-05-31-0927-wxcq`
|
||||
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-nogradreg-nocond-100cwu500clr70tvs-2023-05-31-0743-pdbc`
|
||||
|
||||
### More multi-view loss
|
||||
- `experiment-stanfordv12-happy_buddha-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-20dmv-nocond-100cwu500clr70tvs-2023-05-31-0510-caah`
|
||||
- `experiment-stanfordv12-dragon-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-20dmv-nocond-100cwu500clr70tvs-2023-05-31-0726-zkyg`
|
||||
- `experiment-stanfordv12-bunny-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-20dmv-nocond-100cwu500clr70tvs-2023-05-31-0254-akbq`
|
||||
- `experiment-stanfordv12-lucy-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-20dmv-nocond-100cwu500clr70tvs-2023-05-31-0924-aahb`
|
||||
- `experiment-stanfordv12-armadillo-both2marf-16atom-50xinscr-10dmiss-geom-25cnrml-8x512fc-leaky_relu-hit-0minatomstdngxp-500sphgrow-10mdrop-layernorm-multi_view-20dmv-nocond-100cwu500clr70tvs-2023-05-30-2352-xlrn`
|
||||
624
experiments/marf.py
Executable file
624
experiments/marf.py
Executable file
@@ -0,0 +1,624 @@
|
||||
#!/usr/bin/env python3
|
||||
from abc import ABC, abstractmethod
|
||||
from argparse import Namespace
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from ifield import logging
|
||||
from ifield.cli import CliInterface
|
||||
from ifield.data.common.scan import SingleViewUVScan
|
||||
from ifield.data.coseg import read as coseg_read
|
||||
from ifield.data.stanford import read as stanford_read
|
||||
from ifield.datasets import stanford, coseg, common
|
||||
from ifield.models import intersection_fields
|
||||
from ifield.utils.operators import diff
|
||||
from ifield.viewer.ray_field import ModelViewer
|
||||
from munch import Munch
|
||||
from pathlib import Path
|
||||
from pytorch3d.loss.chamfer import chamfer_distance
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
from torch.nn import functional as F
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
from tqdm import tqdm
|
||||
from trimesh import Trimesh
|
||||
from typing import Union
|
||||
import builtins
|
||||
import itertools
|
||||
import json
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import rich
|
||||
import rich.pretty
|
||||
import statistics
|
||||
import torch
|
||||
pl.seed_everything(31337)  # seed python/numpy/torch RNGs for reproducibility
torch.set_float32_matmul_precision('medium')  # allow TF32 matmuls on Ampere+ for speed


IField = intersection_fields.IntersectionFieldAutoDecoderModel # brevity
|
||||
|
||||
|
||||
class RayFieldAdDataModuleBase(pl.LightningDataModule, ABC):
    """Abstract LightningDataModule for auto-decoder ray-field training.

    Subclasses provide the dataset (``mk_ad_dataset``), the list of
    observation ids used as auto-decoder latent keys, and ground-truth
    accessors for evaluation. ``setup`` optionally subsamples each UV scan
    into ``step * step`` interleaved sub-scans and splits them into
    train/val such that all sub-scans of one scan land in the same split.
    """

    @property
    @abstractmethod
    def observation_ids(self) -> list[str]:
        """Ids of the observations; one auto-decoder latent code per id."""
        ...

    @abstractmethod
    def mk_ad_dataset(self) -> common.AutodecoderDataset:
        """Construct the underlying (uid, SingleViewUVScan) dataset."""
        ...

    @staticmethod
    @abstractmethod
    def get_trimesh_from_uid(uid) -> Trimesh:
        """Ground-truth mesh for a uid (used by the viewer/evaluation)."""
        ...

    @staticmethod
    @abstractmethod
    def get_sphere_scan_from_uid(uid) -> SingleViewUVScan:
        """Ground-truth sphere scan for a uid (used by evaluation)."""
        ...

    def setup(self, stage=None):
        assert stage in ["fit", None] # fit is for train/val, None is for all. "test" not supported ATM

        if not self.hparams.data_dir is None:
            coseg.config.DATA_PATH = self.hparams.data_dir
        step = self.hparams.step # brevity

        dataset = self.mk_ad_dataset()
        n_items_pre_step_mapping = len(dataset)

        if step > 1:
            # NOTE(review): wrapping appears to make the `.map` extension below
            # multiply the dataset by step*step — confirm against
            # common.TransformExtendedDataset.
            dataset = common.TransformExtendedDataset(dataset)

        # Register one mapped view per (sx, sy) sub-grid offset. For step == 1
        # the slicer is the identity and the map only unpacks the sample.
        for sx in range(step):
            for sy in range(step):
                def make_slicer(sx, sy, step) -> callable: # the closure is required
                    if step > 1:
                        # every step'th ray in both UV directions, offset (sx, sy)
                        return lambda t: t[sx::step, sy::step]
                    else:
                        return lambda t: t
                @dataset.map(slicer=make_slicer(sx, sy, step))
                def unpack(sample: tuple[str, SingleViewUVScan], slicer: callable):
                    scan: SingleViewUVScan = sample[1]
                    assert not scan.hits.shape[0] % step, f"{scan.hits.shape[0]=} not divisible by {step=}"
                    assert not scan.hits.shape[1] % step, f"{scan.hits.shape[1]=} not divisible by {step=}"

                    return {
                        "z_uid"     : sample[0],
                        "origins"   : scan.cam_pos,
                        "dirs"      : slicer(scan.ray_dirs),
                        "points"    : slicer(scan.points),
                        "hits"      : slicer(scan.hits),
                        "miss"      : slicer(scan.miss),
                        "normals"   : slicer(scan.normals),
                        "distances" : slicer(scan.distances),
                    }

        # Split each object into train/val with SampleSplit
        n_items = len(dataset)
        n_val = int(n_items * self.hparams.val_fraction)
        n_train = n_items - n_val
        self.generator = torch.Generator().manual_seed(self.hparams.prng_seed)

        # split the dataset such that all steps are in same part:
        # permute the ORIGINAL scans, then expand each to its step*step
        # consecutive mapped indices so a scan never straddles the split.
        assert n_items == n_items_pre_step_mapping * step * step, (n_items, n_items_pre_step_mapping, step)
        indices = [
            i*step*step + sx*step + sy
            for i in torch.randperm(n_items_pre_step_mapping, generator=self.generator).tolist()
            for sx in range(step)
            for sy in range(step)
        ]
        # sorting by a fresh random key per element re-shuffles within each split
        self.dataset_train = Subset(dataset, sorted(indices[:n_train], key=lambda x: torch.rand(1, generator=self.generator).tolist()[0]))
        self.dataset_val = Subset(dataset, sorted(indices[n_train:n_train+n_val], key=lambda x: torch.rand(1, generator=self.generator).tolist()[0]))

        # both splits must tile evenly into batches (see drop_last hparam)
        assert len(self.dataset_train) % self.hparams.batch_size == 0
        assert len(self.dataset_val) % self.hparams.batch_size == 0

    def train_dataloader(self):
        return DataLoader(self.dataset_train,
            batch_size = self.hparams.batch_size,
            drop_last = self.hparams.drop_last,
            num_workers = self.hparams.num_workers,
            persistent_workers = self.hparams.persistent_workers,
            pin_memory = self.hparams.pin_memory,
            prefetch_factor = self.hparams.prefetch_factor,
            shuffle = self.hparams.shuffle,
            generator = self.generator,
        )

    def val_dataloader(self):
        # same as train_dataloader, but never shuffled
        return DataLoader(self.dataset_val,
            batch_size = self.hparams.batch_size,
            drop_last = self.hparams.drop_last,
            num_workers = self.hparams.num_workers,
            persistent_workers = self.hparams.persistent_workers,
            pin_memory = self.hparams.pin_memory,
            prefetch_factor = self.hparams.prefetch_factor,
            generator = self.generator,
        )
|
||||
|
||||
|
||||
class StanfordUVDataModule(RayFieldAdDataModuleBase):
    """Data module serving single-view UV scans of the Stanford 3D models."""
    skyward = "+Z"  # up-axis of the Stanford meshes, used by the viewer
    def __init__(self,
            data_dir           : Union[str, Path, None] = None,
            obj_names          : list[str]              = ["bunny"], # empty means all

            prng_seed          : int   = 1337,
            step               : int   = 2,     # subsample each scan into step*step interleaved grids
            batch_size         : int   = 5,
            drop_last          : bool  = False,
            num_workers        : int   = 8,
            persistent_workers : bool  = True,
            pin_memory         : bool  = True,
            prefetch_factor    : int   = 2,
            shuffle            : bool  = True,
            val_fraction       : float = 0.30,
            ):
        super().__init__()
        if not obj_names:
            obj_names = stanford_read.list_object_names()
        # captures all init arguments above into self.hparams
        self.save_hyperparameters()

    @property
    def observation_ids(self) -> list[str]:
        # one latent code per object name
        return self.hparams.obj_names

    def mk_ad_dataset(self) -> common.AutodecoderDataset:
        return stanford.AutodecoderSingleViewUVScanDataset(
            obj_names = self.hparams.obj_names,
            data_path = self.hparams.data_dir,
        )

    @staticmethod
    def get_trimesh_from_uid(obj_name) -> Trimesh:
        # imported lazily: mesh_to_sdf is only needed for ground-truth viewing
        import mesh_to_sdf
        mesh = stanford_read.read_mesh(obj_name)
        return mesh_to_sdf.scale_to_unit_sphere(mesh)

    @staticmethod
    def get_sphere_scan_from_uid(obj_name) -> SingleViewUVScan:
        return stanford_read.read_mesh_mesh_sphere_scan(obj_name)
|
||||
|
||||
|
||||
class CosegUVDataModule(RayFieldAdDataModuleBase):
    """Data module serving single-view UV scans of the COSEG shape sets."""
    skyward = "+Y"  # up-axis of the COSEG meshes, used by the viewer
    def __init__(self,
            data_dir           : Union[str, Path, None] = None,
            object_sets        : tuple[str, ...]        = ["tele-aliens"], # empty means all

            prng_seed          : int   = 1337,
            step               : int   = 2,     # subsample each scan into step*step interleaved grids
            batch_size         : int   = 5,
            drop_last          : bool  = False,
            num_workers        : int   = 8,
            persistent_workers : bool  = True,
            pin_memory         : bool  = True,
            prefetch_factor    : int   = 2,
            shuffle            : bool  = True,
            val_fraction       : float = 0.30,
            ):
        super().__init__()
        if not object_sets:
            object_sets = coseg_read.list_object_sets()
        # normalized to a tuple so the hparam is hashable/stable
        object_sets = tuple(object_sets)
        # captures all init arguments above into self.hparams
        self.save_hyperparameters()

    @property
    def observation_ids(self) -> list[str]:
        # one latent code per model in the selected sets
        return coseg_read.list_model_id_strings(self.hparams.object_sets)

    def mk_ad_dataset(self) -> common.AutodecoderDataset:
        return coseg.AutodecoderSingleViewUVScanDataset(
            object_sets = self.hparams.object_sets,
            data_path = self.hparams.data_dir,
        )

    @staticmethod
    def get_trimesh_from_uid(string_uid):
        # ground-truth mesh viewing is not supported for COSEG
        raise NotImplementedError

    @staticmethod
    def get_sphere_scan_from_uid(string_uid) -> SingleViewUVScan:
        uid = coseg_read.model_id_string_to_uid(string_uid)
        return coseg_read.read_mesh_mesh_sphere_scan(*uid)
|
||||
|
||||
|
||||
def mk_cli(args=None) -> CliInterface:
    """Build the CLI: the IField model, its datamodules, and all actions.

    Registers a pre-training callback plus four actions: an interactive
    viewer, two video renderers (latent interpolation and camera spin),
    and a metrics computation (`compute_scores`).

    Returns:
        The configured CliInterface; call `.run()` to dispatch.
    """
    cli = CliInterface(
        module_cls = IField,
        datamodule_cls = [StanfordUVDataModule, CosegUVDataModule],
        workdir = Path(__file__).parent.resolve(),
        experiment_name_prefix = "ifield",
    )
    cli.trainer_defaults.update(dict(
        precision = 16,
        min_epochs = 5,
    ))

    @cli.register_pre_training_callback
    def populate_autodecoder_z_uids(args: Namespace, config: Munch, module: IField, trainer: pl.Trainer, datamodule: RayFieldAdDataModuleBase, logger: logging.Logger):
        # The auto-decoder allocates one latent code per observation id,
        # so the ids must be known before training starts.
        module.set_observation_ids(datamodule.observation_ids)
        rank = getattr(rank_zero_only, "rank", 0)
        rich.print(f"[rank {rank}] {len(datamodule.observation_ids) = }")
        rich.print(f"[rank {rank}] {len(datamodule.observation_ids) > 1 = }")
        rich.print(f"[rank {rank}] {module.is_conditioned = }")

    @cli.register_action(help="Interactive window with direct renderings from the model", args=[
        ("--shading",  dict(type=int, default=ModelViewer.vizmodes_shading .index("lambertian"), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_shading))}}}")),
        ("--centroid", dict(type=int, default=ModelViewer.vizmodes_centroids.index("best-centroids-colored"), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_centroids))}}}")),
        ("--spheres",  dict(type=int, default=ModelViewer.vizmodes_spheres .index(None), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_spheres))}}}")),
        ("--analytical-normals", dict(action="store_true")),
        ("--ground-truth", dict(action="store_true")),
        ("--solo-atom",dict(type=int, default=None, help="Rendering mode")),
        ("--res",      dict(type=int, nargs=2, default=(210, 160), help="Rendering resolution")),
        ("--bg",       dict(choices=["map", "white", "black"], default="map")),
        ("--skyward",  dict(type=str, default="+Z", help='one of: "+X", "-X", "+Y", "-Y", ["+Z"], "-Z"')),
        ("--scale",    dict(type=int, default=3, help="Rendering scale")),
        ("--fps",      dict(type=int, default=None, help="FPS upper limit")),
        ("--cam-state",dict(type=str, default=None, help="json cam state, exported with CTRL+H")),
        ("--write",    dict(type=Path, default=None, help="Where to write a screenshot.")),
    ])
    @torch.no_grad()
    def viewer(args: Namespace, config: Munch, model: IField):
        datamodule_cls: RayFieldAdDataModuleBase = cli.get_datamodule_cls_from_config(args, config)

        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            model.to("cuda")
        viewer = ModelViewer(model, start_uid=next(iter(model.keys())),
            name = config.experiment_name,
            screenshot_dir = Path(__file__).parent.parent / "images/pygame-viewer",
            res = args.res,
            skyward = args.skyward,
            scale = args.scale,
            mesh_gt_getter = datamodule_cls.get_trimesh_from_uid,
        )
        viewer.display_mode_shading = args.shading
        viewer.display_mode_centroid = args.centroid
        viewer.display_mode_spheres = args.spheres
        if args.ground_truth: viewer.display_mode_normals = viewer.vizmodes_normals.index("ground_truth")
        if args.analytical_normals: viewer.display_mode_normals = viewer.vizmodes_normals.index("analytical")
        viewer.atom_index_solo = args.solo_atom
        viewer.fps_cap = args.fps
        viewer.display_sphere_map_bg = { "map": True, "white": 255, "black": 0 }[args.bg]
        if args.cam_state is not None:
            viewer.cam_state = json.loads(args.cam_state)
        if args.write is None:
            viewer.run()  # interactive loop
        else:
            # headless single-frame screenshot instead of interactive window
            assert args.write.suffix == ".png", args.write.name
            viewer.render_headless(args.write,
                n_frames = 1,
                fps = 1,
                state_callback = None,
            )

    @cli.register_action(help="Prerender direct renderings from the model", args=[
        ("output_path",dict(type=Path, help="Where to store the output. We recommend a .mp4 suffix.")),
        ("uids",       dict(type=str, nargs="*")),
        ("--frames",   dict(type=int, default=60, help="Number of per interpolation. Default is 60")),
        ("--fps",      dict(type=int, default=60, help="Default is 60")),
        ("--shading",  dict(type=int, default=ModelViewer.vizmodes_shading .index("lambertian"), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_shading))}}}")),
        ("--centroid", dict(type=int, default=ModelViewer.vizmodes_centroids.index("best-centroids-colored"), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_centroids))}}}")),
        ("--spheres",  dict(type=int, default=ModelViewer.vizmodes_spheres .index(None), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_spheres))}}}")),
        ("--analytical-normals", dict(action="store_true")),
        ("--solo-atom",dict(type=int, default=None, help="Rendering mode")),
        ("--res",      dict(type=int, nargs=2, default=(240, 240), help="Rendering resolution. Default is 240 240")),
        ("--bg",       dict(choices=["map", "white", "black"], default="map")),
        ("--skyward",  dict(type=str, default="+Z", help='one of: "+X", "-X", "+Y", "-Y", ["+Z"], "-Z"')),
        ("--bitrate",  dict(type=str, default="1500k", help="Encoding bitrate. Default is 1500k")),
        ("--cam-state",dict(type=str, default=None, help="json cam state, exported with CTRL+H")),
    ])
    @torch.no_grad()
    def render_video_interpolation(args: Namespace, config: Munch, model: IField, **kw):
        # Render a video interpolating the latent code through the given uids.
        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            model.to("cuda")
        uids = args.uids or list(model.keys())
        assert len(uids) > 1
        if not args.uids: uids.append(uids[0])  # loop back to the first uid
        viewer = ModelViewer(model, uids[0],
            name = config.experiment_name,
            screenshot_dir = Path(__file__).parent.parent / "images/pygame-viewer",
            res = args.res,
            skyward = args.skyward,
        )
        if args.cam_state is not None:
            viewer.cam_state = json.loads(args.cam_state)
        viewer.display_mode_shading = args.shading
        viewer.display_mode_centroid = args.centroid
        viewer.display_mode_spheres = args.spheres
        if args.analytical_normals: viewer.display_mode_normals = viewer.vizmodes_normals.index("analytical")
        viewer.atom_index_solo = args.solo_atom
        viewer.display_sphere_map_bg = { "map": True, "white": 255, "black": 0 }[args.bg]
        def state_callback(self: ModelViewer, frame: int):
            # tint mid-interpolation frames; flash white on each keyframe
            if frame % args.frames:
                self.lambertian_color = (0.8, 0.8, 1.0)
            else:
                self.lambertian_color = (1.0, 1.0, 1.0)
            self.fps = args.frames
            idx = frame // args.frames + 1
            if idx != len(uids):
                self.current_uid = uids[idx]
        print(f"Writing video to {str(args.output_path)!r}...")
        viewer.render_headless(args.output_path,
            n_frames = args.frames * (len(uids)-1) + 1,
            fps = args.fps,
            state_callback = state_callback,
            bitrate = args.bitrate,
        )

    @cli.register_action(help="Prerender direct renderings from the model", args=[
        ("output_path",dict(type=Path, help="Where to store the output. We recommend a .mp4 suffix.")),
        ("--frames",   dict(type=int, default=180, help="Number of frames. Default is 180")),
        ("--fps",      dict(type=int, default=60, help="Default is 60")),
        ("--shading",  dict(type=int, default=ModelViewer.vizmodes_shading .index("lambertian"), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_shading))}}}")),
        ("--centroid", dict(type=int, default=ModelViewer.vizmodes_centroids.index("best-centroids-colored"), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_centroids))}}}")),
        ("--spheres",  dict(type=int, default=ModelViewer.vizmodes_spheres .index(None), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_spheres))}}}")),
        ("--analytical-normals", dict(action="store_true")),
        ("--solo-atom",dict(type=int, default=None, help="Rendering mode")),
        ("--res",      dict(type=int, nargs=2, default=(320, 240), help="Rendering resolution. Default is 320 240")),
        ("--bg",       dict(choices=["map", "white", "black"], default="map")),
        ("--skyward",  dict(type=str, default="+Z", help='one of: "+X", "-X", "+Y", "-Y", ["+Z"], "-Z"')),
        ("--bitrate",  dict(type=str, default="1500k", help="Encoding bitrate. Default is 1500k")),
        ("--cam-state",dict(type=str, default=None, help="json cam state, exported with CTRL+H")),
    ])
    @torch.no_grad()
    def render_video_spin(args: Namespace, config: Munch, model: IField, **kw):
        # Render a video spinning the camera one full turn around the object.
        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            model.to("cuda")
        viewer = ModelViewer(model, start_uid=next(iter(model.keys())),
            name = config.experiment_name,
            screenshot_dir = Path(__file__).parent.parent / "images/pygame-viewer",
            res = args.res,
            skyward = args.skyward,
        )
        if args.cam_state is not None:
            viewer.cam_state = json.loads(args.cam_state)
        viewer.display_mode_shading = args.shading
        viewer.display_mode_centroid = args.centroid
        viewer.display_mode_spheres = args.spheres
        if args.analytical_normals: viewer.display_mode_normals = viewer.vizmodes_normals.index("analytical")
        viewer.atom_index_solo = args.solo_atom
        viewer.display_sphere_map_bg = { "map": True, "white": 255, "black": 0 }[args.bg]
        cam_rot_x_init = viewer.cam_rot_x
        def state_callback(self: ModelViewer, frame: int):
            # one full revolution (2*pi) over args.frames frames
            self.cam_rot_x = cam_rot_x_init + 3.14 * (frame / args.frames) * 2
        print(f"Writing video to {str(args.output_path)!r}...")
        viewer.render_headless(args.output_path,
            n_frames = args.frames,
            fps = args.fps,
            state_callback = state_callback,
            bitrate = args.bitrate,
        )

    @cli.register_action(help="foo", args=[
        ("fname",      dict(type=Path, help="where to write json")),
        ("-t", "--transpose", dict(action="store_true", help="transpose the output")),
        ("--single-shape", dict(action="store_true", help="break after first shape")),
        ("--batch-size", dict(type=int, default=40_000, help="tradeoff between vram usage and efficiency")),
        ("--n-cd",     dict(type=int, default=30_000, help="Number of points to use when computing chamfer distance")),
        ("--filter-outliers", dict(action="store_true", help="like in PRIF")),
    ])
    @torch.enable_grad()  # grad needed for the jacobian-based normals
    def compute_scores(args: Namespace, config: Munch, model: IField, **kw):
        # Evaluate the model against ground-truth sphere scans: IoU/F-score,
        # intersection/silhouette MSE, cosine similarity of normals, and
        # chamfer distance. Writes a JSON report (or prints it for fname "-").
        datamodule_cls: RayFieldAdDataModuleBase = cli.get_datamodule_cls_from_config(args, config)
        model.eval()
        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            model.to("cuda")

        def T(array: np.ndarray, **kw) -> torch.Tensor:
            # Move a ground-truth array onto the model device, casting float
            # arrays to the model dtype.
            if isinstance(array, torch.Tensor): return array
            # fix: was `isinstance(array, np.floating)`, which is never true
            # for an ndarray, so float arrays were never cast to model.dtype.
            return torch.tensor(array, device=model.device, dtype=model.dtype if np.issubdtype(array.dtype, np.floating) else None, **kw)

        MEDIAL = model.hparams.output_mode == "medial_sphere"
        if not MEDIAL: assert model.hparams.output_mode == "orthogonal_plane"


        uids = sorted(model.keys())
        if args.single_shape: uids = [uids[0]]
        rich.print(f"{datamodule_cls.__name__ = }")
        rich.print(f"{len(uids) = }")

        # accumulators for IoU and F-Score, CD and COS

        # sum reduction:
        n            = defaultdict(int)
        n_gt_hits    = defaultdict(int)
        n_gt_miss    = defaultdict(int)
        n_gt_missing = defaultdict(int)
        n_outliers   = defaultdict(int)
        p_mse        = defaultdict(int)
        s_mse        = defaultdict(int)
        cossim_med   = defaultdict(int) # medial normals
        cossim_jac   = defaultdict(int) # jacobian normals
        TP,FN,FP,TN  = [defaultdict(int) for _ in range(4)] # IoU and f-score
        # mean reduction:
        cd_dist    = {} # chamfer distance
        cd_cos_med = {} # chamfer medial normals
        cd_cos_jac = {} # chamfer jacobian normals
        all_metrics = dict(
            n=n, n_gt_hits=n_gt_hits, n_gt_miss=n_gt_miss, n_gt_missing=n_gt_missing, p_mse=p_mse,
            cossim_jac=cossim_jac,
            TP=TP, FN=FN, FP=FP, TN=TN, cd_dist=cd_dist,
            cd_cos_jac=cd_cos_jac,
        )
        if MEDIAL:
            all_metrics["s_mse"] = s_mse
            all_metrics["cossim_med"] = cossim_med
            all_metrics["cd_cos_med"] = cd_cos_med
        if args.filter_outliers:
            all_metrics["n_outliers"] = n_outliers

        t = datetime.now()
        for uid in tqdm(uids, desc="Dataset", position=0, leave=True, disable=len(uids)<=1):
            sphere_scan_gt = datamodule_cls.get_sphere_scan_from_uid(uid)

            z = model[uid].detach()  # the auto-decoded latent code

            all_intersections    = []
            all_medial_normals   = []
            all_jacobian_normals = []

            step = args.batch_size
            for i in tqdm(range(0, sphere_scan_gt.hits.shape[0], step), desc=f"Item {uid!r}", position=1, leave=False):
                # prepare batch and gt
                origins      = T(sphere_scan_gt.cam_pos  [i:i+step, :], requires_grad = True)
                dirs         = T(sphere_scan_gt.ray_dirs [i:i+step, :])
                gt_hits      = T(sphere_scan_gt.hits     [i:i+step])
                gt_miss      = T(sphere_scan_gt.miss     [i:i+step])
                gt_missing   = T(sphere_scan_gt.missing  [i:i+step])
                gt_points    = T(sphere_scan_gt.points   [i:i+step, :])
                gt_normals   = T(sphere_scan_gt.normals  [i:i+step, :])
                gt_distances = T(sphere_scan_gt.distances[i:i+step])

                # forward
                if MEDIAL:
                    (
                        depths,
                        silhouettes,
                        intersections,
                        medial_normals,
                        is_intersecting,
                        sphere_centers,
                        sphere_radii,
                    ) = model({
                        "origins" : origins,
                        "dirs"    : dirs,
                    }, z, intersections_only=False, allow_nans=False)
                else:
                    silhouettes = medial_normals = None
                    intersections, is_intersecting = model({
                        "origins" : origins,
                        "dirs"    : dirs,
                    }, z, normalize_origins = True)
                    is_intersecting = is_intersecting > 0.5
                # normals via the jacobian of the intersection w.r.t. the origin
                jac = diff.jacobian(intersections, origins, detach=True)

                # outlier removal (PRIF)
                if args.filter_outliers:
                    outliers = jac.norm(dim=-2).norm(dim=-1) > 5
                    n_outliers[uid] += outliers[is_intersecting].sum().item()
                    # We count filtered points as misses
                    is_intersecting &= ~outliers

                model.zero_grad()
                jacobian_normals = model.compute_normals_from_intersection_origin_jacobian(jac, dirs)

                all_intersections   .append(intersections   .detach()[is_intersecting.detach(), :])
                all_medial_normals  .append(medial_normals  .detach()[is_intersecting.detach(), :]) if MEDIAL else None
                all_jacobian_normals.append(jacobian_normals.detach()[is_intersecting.detach(), :])

                # accumulate metrics
                with torch.no_grad():
                    n           [uid] += dirs.shape[0]
                    n_gt_hits   [uid] += gt_hits.sum().item()
                    n_gt_miss   [uid] += gt_miss.sum().item()
                    n_gt_missing[uid] += gt_missing.sum().item()
                    p_mse       [uid] += (gt_points   [gt_hits, :] - intersections[gt_hits, :]).norm(2, dim=-1).pow(2).sum().item()
                    if MEDIAL: s_mse [uid] += (gt_distances[gt_miss] - silhouettes  [gt_miss]   )               .pow(2).sum().item()
                    if MEDIAL: cossim_med[uid] += (1-F.cosine_similarity(gt_normals[gt_hits, :], medial_normals  [gt_hits, :], dim=-1).abs()).sum().item() # to match what pytorch3d does for CD
                    cossim_jac  [uid] += (1-F.cosine_similarity(gt_normals[gt_hits, :], jacobian_normals[gt_hits, :], dim=-1).abs()).sum().item() # to match what pytorch3d does for CD
                    not_intersecting = ~is_intersecting
                    TP          [uid] += ((gt_hits | gt_missing) & is_intersecting).sum().item()  # True Positive
                    FN          [uid] += ((gt_hits | gt_missing) & not_intersecting).sum().item() # False Negative
                    FP          [uid] += (gt_miss & is_intersecting).sum().item()                 # False Positive
                    TN          [uid] += (gt_miss & not_intersecting).sum().item()                # True Negative

            all_intersections    = torch.cat(all_intersections,    dim=0)
            all_medial_normals   = torch.cat(all_medial_normals,   dim=0) if MEDIAL else None
            all_jacobian_normals = torch.cat(all_jacobian_normals, dim=0)

            hits = sphere_scan_gt.hits # brevity
            print()

            # subsample predictions and ground truth for the chamfer distance
            assert all_intersections.shape[0] >= args.n_cd
            idx_cd_pred = torch.randperm(all_intersections.shape[0])[:args.n_cd]
            idx_cd_gt   = torch.randperm(hits.sum())                [:args.n_cd]

            print("cd... ", end="")
            tt = datetime.now()
            loss_cd, loss_cos_jac = chamfer_distance(
                x         = all_intersections    [None, :, :][:, idx_cd_pred, :].detach(),
                x_normals = all_jacobian_normals [None, :, :][:, idx_cd_pred, :].detach(),
                y         = T(sphere_scan_gt.points [None, hits, :][:, idx_cd_gt, :]),
                y_normals = T(sphere_scan_gt.normals[None, hits, :][:, idx_cd_gt, :]),
                batch_reduction = "sum", point_reduction = "sum",
            )
            if MEDIAL: _, loss_cos_med = chamfer_distance(
                x         = all_intersections    [None, :, :][:, idx_cd_pred, :].detach(),
                x_normals = all_medial_normals   [None, :, :][:, idx_cd_pred, :].detach(),
                y         = T(sphere_scan_gt.points [None, hits, :][:, idx_cd_gt, :]),
                y_normals = T(sphere_scan_gt.normals[None, hits, :][:, idx_cd_gt, :]),
                batch_reduction = "sum", point_reduction = "sum",
            )
            print(datetime.now() - tt)

            cd_dist    [uid] = loss_cd.item()
            cd_cos_med [uid] = loss_cos_med.item() if MEDIAL else None
            cd_cos_jac [uid] = loss_cos_jac.item()

            print()
            model.zero_grad(set_to_none=True)  # free the autograd graphs per item
        print("Total time:", datetime.now() - t)
        print("Time per item:", (datetime.now() - t) / len(uids)) if len(uids) > 1 else None

        # reductions across uids (`sum` intentionally shadows the builtin here)
        sum   = lambda *xs: builtins  .sum  (itertools.chain(*(x.values() for x in xs)))
        mean  = lambda *xs: statistics.mean (itertools.chain(*(x.values() for x in xs)))
        stdev = lambda *xs: statistics.stdev(itertools.chain(*(x.values() for x in xs)))
        n_cd = args.n_cd
        P = sum(TP)/(sum(TP, FP))
        R = sum(TP)/(sum(TP, FN))
        print(f"{mean(n)                     = :11.1f} (rays per object)")
        print(f"{mean(n_gt_hits)             = :11.1f} (gt rays hitting per object)")
        print(f"{mean(n_gt_miss)             = :11.1f} (gt rays missing per object)")
        print(f"{mean(n_gt_missing)          = :11.1f} (gt rays unknown per object)")
        # fix: this line previously reused the "gt rays unknown" label
        print(f"{mean(n_outliers)            = :11.1f} (outlier rays filtered per object)") if args.filter_outliers else None
        print(f"{n_cd                        = :11.0f} (cd rays per object)")
        print(f"{mean(n_gt_hits)   / mean(n) = :11.8f} (fraction rays hitting per object)")
        print(f"{mean(n_gt_miss)   / mean(n) = :11.8f} (fraction rays missing per object)")
        print(f"{mean(n_gt_missing)/ mean(n) = :11.8f} (fraction rays unknown per object)")
        # fix: this line previously reused the "fraction rays unknown" label
        print(f"{mean(n_outliers)  / mean(n) = :11.8f} (fraction rays filtered as outliers per object)") if args.filter_outliers else None
        print(f"{sum(TP)/sum(n)              = :11.8f} (total ray TP)")
        print(f"{sum(TN)/sum(n)              = :11.8f} (total ray TN)")
        print(f"{sum(FP)/sum(n)              = :11.8f} (total ray FP)")
        print(f"{sum(FN)/sum(n)              = :11.8f} (total ray FN)")
        print(f"{sum(TP, FN, FP)/sum(n)      = :11.8f} (total ray union)")
        print(f"{sum(TP)/sum(TP, FN, FP)     = :11.8f} (total ray IoU)")
        print(f"{sum(TP)/(sum(TP, FP))       = :11.8f} -> P (total ray precision)")
        print(f"{sum(TP)/(sum(TP, FN))       = :11.8f} -> R (total ray recall)")
        print(f"{2*(P*R)/(P+R)               = :11.8f} (total ray F-score)")
        print(f"{sum(p_mse)/sum(n_gt_hits)   = :11.8f} (mean ray intersection mean squared error)")
        # fix: s_mse is only accumulated when MEDIAL; printing it otherwise
        # showed a misleading 0. Also fixes the "silhoutette" typo.
        print(f"{sum(s_mse)/sum(n_gt_miss)   = :11.8f} (mean ray silhouette mean squared error)") if MEDIAL else None
        print(f"{sum(cossim_med)/sum(n_gt_hits) = :11.8f} (mean ray medial reduced cosine similarity)") if MEDIAL else None
        print(f"{sum(cossim_jac)/sum(n_gt_hits) = :11.8f} (mean ray analytical reduced cosine similarity)")
        print(f"{mean(cd_dist)   /n_cd * 1e3 = :11.8f} (mean chamfer distance)")
        print(f"{mean(cd_cos_med)/n_cd       = :11.8f} (mean chamfer reduced medial cossim distance)") if MEDIAL else None
        print(f"{mean(cd_cos_jac)/n_cd       = :11.8f} (mean chamfer reduced analytical cossim distance)")
        print(f"{stdev(cd_dist)   /n_cd * 1e3 = :11.8f} (stdev chamfer distance)") if len(cd_dist) > 1 else None
        print(f"{stdev(cd_cos_med)/n_cd      = :11.8f} (stdev chamfer reduced medial cossim distance)") if len(cd_cos_med) > 1 and MEDIAL else None
        print(f"{stdev(cd_cos_jac)/n_cd      = :11.8f} (stdev chamfer reduced analytical cossim distance)") if len(cd_cos_jac) > 1 else None

        if args.transpose:
            # pivot from metric->uid->value to uid->metric->value
            all_metrics, old_metrics = defaultdict(dict), all_metrics
            for m, table in old_metrics.items():
                for uid, vals in table.items():
                    all_metrics[uid][m] = vals
            all_metrics["_hparams"] = dict(n_cd=args.n_cd)
        else:
            all_metrics["n_cd"] = args.n_cd

        if str(args.fname) == "-":
            # print to stdout, one metric table per line
            print("{", ',\n'.join(
                f"  {json.dumps(k)}: {json.dumps(v)}"
                for k, v in all_metrics.items()
            ), "}", sep="\n")
        else:
            args.fname.parent.mkdir(parents=True, exist_ok=True)
            with args.fname.open("w") as f:
                json.dump(all_metrics, f, indent=2)

    return cli
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Build the CLI and dispatch the requested action (train/viewer/etc.).
    mk_cli().run()
|
||||
263
experiments/marf.yaml.j2
Executable file
263
experiments/marf.yaml.j2
Executable file
@@ -0,0 +1,263 @@
|
||||
#!/usr/bin/env -S python ./marf.py module
{# Jinja2-templated YAML experiment config. Renders one "---"-separated YAML
   document per surviving hyper-parameter combination; -Oselect picks a single
   one. Requires the jinja2 extensions `do` and `loopcontrols`. #}
{% do require_defined("select", select, 0, "$SLURM_ARRAY_TASK_ID") %}{# requires jinja2.ext.do #}
{# NOTE(review): "exchaustive" below looks like a typo for "exhaustive" — confirm
   against the require_defined() signature before renaming. #}
{% do require_defined("mode", mode, "single", "ablation", "multi", strict=true, exchaustive=true) %}{# requires jinja2.ext.do #}
{% set counter = itertools.count(start=0, step=1) %}
{% set do_condition = mode == "multi" %}
{% set do_ablation = mode == "ablation" %}

{% set hp_matrix = namespace() %}{# hyper parameter matrix #}

{% set hp_matrix.input_mode = [
  "both",
  "perp_foot",
  "plucker",
] if do_ablation else [ "both" ] %}
{# NOTE(review): the next assignment is immediately overridden by the one after it;
   presumably kept as an easy toggle. #}
{% set hp_matrix.output_mode = ["medial_sphere", "orthogonal_plane"] %}{##}
{% set hp_matrix.output_mode = ["medial_sphere"] %}{##}
{% set hp_matrix.n_atoms = [16, 1, 4, 8, 32, 64] if do_ablation else [16] %}{##}
{% set hp_matrix.normal_coeff = [0.25, 0] if do_ablation else [0.25] %}{##}
{% set hp_matrix.dataset_item = [objname] if objname is defined else (["armadillo", "bunny", "happy_buddha", "dragon", "lucy"] if not do_condition else ["four-legged"]) %}{##}
{% set hp_matrix.test_val_split_frac = [0.7] %}{##}
{% set hp_matrix.lr_coeff = [5] %}{##}
{% set hp_matrix.warmup_epochs = [1] if not do_condition else [0.1] %}{##}
{% set hp_matrix.improve_miss_grads = [True] %}{##}
{% set hp_matrix.normalize_ray_dirs = [True] %}{##}
{% set hp_matrix.intersection_coeff = [2, 0] if do_ablation else [2] %}{##}
{% set hp_matrix.miss_distance_coeff = [1, 0, 5] if do_ablation else [1] %}{##}
{% set hp_matrix.relative_out = [False] %}{##}
{% set hp_matrix.hidden_features = [512] %}{# like deepsdf and prif #}
{% set hp_matrix.hidden_layers = [8] %}{# like deepsdf, nerf, prif #}
{% set hp_matrix.nonlinearity = ["leaky_relu"] %}{##}
{% set hp_matrix.omega = [30] %}{##}
{% set hp_matrix.normalization = ["layernorm"] %}{##}
{% set hp_matrix.dropout_percent = [1] %}{##}
{% set hp_matrix.sphere_grow_reg_coeff = [500, 0, 5000] if do_ablation else [500] %}{##}
{% set hp_matrix.geom_init = [True, False] if do_ablation else [True] %}{##}
{% set hp_matrix.loss_inscription = [50, 0, 250] if do_ablation else [50] %}{##}
{% set hp_matrix.atom_centroid_norm_std_reg_negexp = [0, None] if do_ablation else [0] %}{##}
{% set hp_matrix.curvature_reg_coeff = [0.2] %}{##}
{% set hp_matrix.multi_view_reg_coeff = [1, 2] if do_ablation else [1] %}{##}
{% set hp_matrix.grad_reg = [ "multi_view", "nogradreg" ] if do_ablation else [ "multi_view" ] %}

{#% for hp in cartesian_hparams(hp_matrix) %}{##}
{# NOTE(review): "caartesian_keys" looks like a typo for "cartesian_keys" — confirm
   against the ablation_hparams() signature before renaming. #}
{% for hp in ablation_hparams(hp_matrix, caartesian_keys=["output_mode", "dataset_item", "nonlinearity", "test_val_split_frac"]) %}

{# For the orthogonal_plane (PRIF) baseline, swap which end of the ablated
   ranges acts as the default. #}
{% if hp.output_mode == "orthogonal_plane"%}
{% if hp.normal_coeff == 0 %}{% set hp.normal_coeff = 0.25 %}
{% elif hp.normal_coeff == 0.25 %}{% set hp.normal_coeff = 0 %}
{% endif %}
{% if hp.grad_reg == "multi_view" %}{% set hp.grad_reg = "nogradreg" %}
{% elif hp.grad_reg == "nogradreg" %}{% set hp.grad_reg = "multi_view" %}
{% endif %}
{% endif %}

{# filter bad/uninteresting hparam combos #}
{% if ( hp.nonlinearity != "sine" and hp.omega != 30 )
   or ( hp.nonlinearity == "sine" and hp.normalization in ("layernorm", "layernorm_na") )
   or ( hp.multi_view_reg_coeff != 1 and "multi_view" not in hp.grad_reg )
   or ( "curvature" not in hp.grad_reg and hp.curvature_reg_coeff != 0.2 )
   or ( hp.output_mode == "orthogonal_plane" and hp.input_mode != "both" )
   or ( hp.output_mode == "orthogonal_plane" and hp.atom_centroid_norm_std_reg_negexp != 0 )
   or ( hp.output_mode == "orthogonal_plane" and hp.n_atoms != 16 )
   or ( hp.output_mode == "orthogonal_plane" and hp.sphere_grow_reg_coeff != 500 )
   or ( hp.output_mode == "orthogonal_plane" and hp.loss_inscription != 50 )
   or ( hp.output_mode == "orthogonal_plane" and hp.miss_distance_coeff != 1 )
   or ( hp.output_mode == "orthogonal_plane" and hp.test_val_split_frac != 0.7 )
   or ( hp.output_mode == "orthogonal_plane" and hp.lr_coeff != 5 )
   or ( hp.output_mode == "orthogonal_plane" and not hp.geom_init )
   or ( hp.output_mode == "orthogonal_plane" and not hp.intersection_coeff )
%}
{% continue %}{# requires jinja2.ext.loopcontrols #}
{% endif %}

{% set index = next(counter) %}
{% if select is not defined and index > 0 %}---{% endif %}
{% if select is not defined or int(select) == index %}

trainer:
  gradient_clip_val : 1.0
  max_epochs : 200
  min_epochs : 200
  log_every_n_steps : 20

{% if not do_condition %}

StanfordUVDataModule:
  obj_names : ["{{ hp.dataset_item }}"]
  step : 4
  batch_size : 8
  val_fraction : {{ 1-hp.test_val_split_frac }}

{% else %}{# if do_condition #}

CosegUVDataModule:
  object_sets : ["{{ hp.dataset_item }}"]
  step : 4
  batch_size : 8
  val_fraction : {{ 1-hp.test_val_split_frac }}

{% endif %}{# if do_condition #}

logging:
  save_dir : logdir
  type : tensorboard
  project : ifield

{% autoescape false %}
{% do require_defined("experiment_name", experiment_name, "single-shape" if do_condition else "multi-shape", strict=true) %}
{% set input_mode_abbr = hp.input_mode
  .replace("plucker", "plkr")
  .replace("perp_foot", "prpft")
%}
{% set output_mode_abbr = hp.output_mode
  .replace("medial_sphere", "marf")
  .replace("orthogonal_plane", "prif")
%}
experiment_name: experiment-{{ "" if experiment_name is not defined else experiment_name }}
{#--#}-{{ hp.dataset_item }}
{#--#}-{{ input_mode_abbr }}2{{ output_mode_abbr }}
{#--#}
{%- if hp.output_mode == "medial_sphere" -%}
{#--#}-{{ hp.n_atoms }}atom
{#--# }-{{ "rel" if hp.relative_out else "norel" }}
{#--# }-{{ "e" if hp.improve_miss_grads else "0" }}sqrt
{#--#}-{{ int(hp.loss_inscription) if hp.loss_inscription else "no" }}xinscr
{#--#}-{{ int(hp.miss_distance_coeff * 10) }}dmiss
{#--#}-{{ "geom" if hp.geom_init else "nogeom" }}
{#--#}{% if "curvature" in hp.grad_reg %}
{#- -#}-{{ int(hp.curvature_reg_coeff*10) }}crv
{#--#}{%- endif -%}
{%- elif hp.output_mode == "orthogonal_plane" -%}
{#--#}
{%- endif -%}
{#--#}-{{ int(hp.intersection_coeff*10) }}chit
{#--#}-{{ int(hp.normal_coeff*100) or "no" }}cnrml
{#--# }-{{ "do" if hp.normalize_ray_dirs else "no" }}raynorm
{#--#}-{{ hp.hidden_layers }}x{{ hp.hidden_features }}fc
{#--#}-{{ hp.nonlinearity or "linear" }}
{#--#}
{%- if hp.nonlinearity == "sine" -%}
{#--#}-{{ hp.omega }}omega
{#--#}
{%- endif -%}
{%- if hp.output_mode == "medial_sphere" -%}
{#--#}-{{ str(hp.atom_centroid_norm_std_reg_negexp).replace(*"-n") if hp.atom_centroid_norm_std_reg_negexp is not none else 'no' }}minatomstdngxp
{#--#}-{{ hp.sphere_grow_reg_coeff }}sphgrow
{#--#}
{%- endif -%}
{#--#}-{{ int(hp.dropout_percent*10) }}mdrop
{#--#}-{{ hp.normalization or "nonorm" }}
{#--#}-{{ hp.grad_reg }}
{#--#}{% if "multi_view" in hp.grad_reg %}
{#- -#}-{{ int(hp.multi_view_reg_coeff*10) }}dmv
{#--#}{%- endif -%}
{#--#}-{{ "concat" if do_condition else "nocond" }}
{#--#}-{{ int(hp.warmup_epochs*100) }}cwu{{ int(hp.lr_coeff*100) }}clr{{ int(hp.test_val_split_frac*100) }}tvs
{#--#}-{{ gen_run_uid(4) }} # select with --Oselect={{ index }}
{#--#}
{##}

{% endautoescape %}
IntersectionFieldAutoDecoderModel:
  _extra: # used for easier introspection with jq
    dataset_item: {{ hp.dataset_item | to_json}}
    dataset_test_val_frac: {{ hp.test_val_split_frac }}
    select: {{ index }}

  input_mode : {{ hp.input_mode }} # in {plucker, perp_foot, both}
  output_mode : {{ hp.output_mode }} # in {medial_sphere, orthogonal_plane}
  #latent_features : 256 # int
  #latent_features : 128 # int
  latent_features : 16 # int
  hidden_features : {{ hp.hidden_features }} # int
  hidden_layers : {{ hp.hidden_layers }} # int

  improve_miss_grads : {{ bool(hp.improve_miss_grads) | to_json }}
  normalize_ray_dirs : {{ bool(hp.normalize_ray_dirs) | to_json }}

  loss_intersection : {{ hp.intersection_coeff }}
  loss_intersection_l2 : 0
  loss_intersection_proj : 0
  loss_intersection_proj_l2 : 0

  loss_normal_cossim : {{ hp.normal_coeff }} * EaseSin(85, 15)
  loss_normal_euclid : 0
  loss_normal_cossim_proj : 0
  loss_normal_euclid_proj : 0

{% if "multi_view" in hp.grad_reg %}
  loss_multi_view_reg : 0.1 * {{ hp.multi_view_reg_coeff }} * Linear(50)
{% else %}
  loss_multi_view_reg : 0
{% endif %}

{% if hp.output_mode == "orthogonal_plane" %}

  loss_hit_cross_entropy : 1

{% elif hp.output_mode == "medial_sphere" %}

  loss_hit_nodistance_l1 : 0
  loss_hit_nodistance_l2 : 100 * {{ hp.miss_distance_coeff }}
  loss_miss_distance_l1 : 0
  loss_miss_distance_l2 : 10 * {{ hp.miss_distance_coeff }}

  loss_inscription_hits : {{ 0.4 * hp.loss_inscription }}
  loss_inscription_miss : 0
  loss_inscription_hits_l2 : 0
  loss_inscription_miss_l2 : {{ 6 * hp.loss_inscription }}

  loss_sphere_grow_reg : 1e-6 * {{ hp.sphere_grow_reg_coeff }} # constant
  loss_atom_centroid_norm_std_reg: (0.09*(1-Linear(40)) + 0.01) * {{ 10**(-hp.atom_centroid_norm_std_reg_negexp) if hp.atom_centroid_norm_std_reg_negexp is not none else 0 }}

{% else %}{#endif hp.output_mode == "medial_sphere" #}
{# deliberate render-time poison pill: an unexpected output_mode must not
   silently produce a runnable config #}
THIS IS INVALID YAML
{% endif %}

  loss_embedding_norm : 0.01**2 * Linear(30, 0.1)

  opt_learning_rate : {{ hp.lr_coeff }} * 10**(-4-0.5*EaseSin(170, 30)) # layernorm
  opt_warmup : {{ hp.warmup_epochs }}
  opt_weight_decay : 5e-6 # float

{% if hp.output_mode == "medial_sphere" %}

  # MedialAtomNet:
  n_atoms : {{ hp.n_atoms }} # int
{% if hp.geom_init %}
  final_init_wrr: [0.05, 0.6, 0.1]
{% else %}
  final_init_wrr: null
{% endif %}

{% endif %}


  # FCBlock:
  normalization : {{ hp.normalization or "null" }} # in {null, layernorm, layernorm_na, weightnorm}
  nonlinearity : {{ hp.nonlinearity or "null" }} # in {null, relu, leaky_relu, silu, softplus, elu, selu, sine, sigmoid, tanh }
{% set middle = 1 + hp.hidden_layers // 2 + (hp.hidden_layers % 2) %}{##}
  concat_skipped_layers : [{{ middle }}, -1]
{% if do_condition %}
  concat_conditioned_layers : [0, {{ middle }}]
{% else %}
  concat_conditioned_layers : []
{% endif %}

  # FCLayer:
  negative_slope : 0.01 # float
  omega_0 : {{ hp.omega }} # float
  residual_mode : null # in {null, identity}

{% endif %}{# -Oselect #}


{% endfor %}


{% set index = next(counter) %}
# number of possible -Oselect: {{ index }}, from 0 to {{ index-1 }}
# local: for select in {0..{{ index-1 }}}; do python ... -Omode={{ mode }} -Oselect=$select ... ; done
# local: for select in {0..{{ index-1 }}}; do python -O {{ argv[0] }} model marf.yaml.j2 -Omode={{ mode }} -Oselect=$select -Oexperiment_name='{{ experiment_name }}' fit --accelerator gpu ; done
# slurm: sbatch --array=0-{{ index-1 }} runcommand.slurm python ... -Omode={{ mode }} -Oselect=\$SLURM_ARRAY_TASK_ID ...
# slurm: sbatch --array=0-{{ index-1 }} runcommand.slurm python -O {{ argv[0] }} model marf.yaml.j2 -Omode={{ mode }} -Oselect=\$SLURM_ARRAY_TASK_ID -Oexperiment_name='{{ experiment_name }}' fit --accelerator gpu --devices -1 --strategy ddp
|
||||
849
experiments/summary.py
Executable file
849
experiments/summary.py
Executable file
@@ -0,0 +1,849 @@
|
||||
#!/usr/bin/env python
|
||||
from concurrent.futures import ThreadPoolExecutor, Future, ProcessPoolExecutor
|
||||
from functools import partial
|
||||
from more_itertools import first, last, tail
|
||||
from munch import Munch, DefaultMunch, munchify, unmunchify
|
||||
from pathlib import Path
|
||||
from statistics import mean, StatisticsError
|
||||
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
||||
from typing import Iterable, Optional, Literal
|
||||
from math import isnan
|
||||
import json
|
||||
import stat
|
||||
import matplotlib
|
||||
import matplotlib.colors as mcolors
|
||||
import matplotlib.pyplot as plt
|
||||
import os, os.path
|
||||
import re
|
||||
import shlex
|
||||
import time
|
||||
import itertools
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import traceback
|
||||
import typer
|
||||
import warnings
|
||||
import yaml
|
||||
import tempfile
|
||||
|
||||
# Paths are anchored to this file's directory (experiments/).
# BUG FIX: was `Path(__file__).resolve()`, which is the summary.py *file*
# itself, making every derived path below a child of a file
# (e.g. ".../experiments/summary.py/logdir") that can never exist.
EXPERIMENTS = Path(__file__).resolve().parent
LOGDIR = EXPERIMENTS / "logdir"
TENSORBOARD = LOGDIR / "tensorboard"
SLURM_LOGS = LOGDIR / "slurm_logs"
CACHED_SUMMARIES = LOGDIR / "cached_summaries"
COMPUTED_SCORES = LOGDIR / "computed_scores"

# Unique sentinel used to distinguish "key absent" from any real value
# (including None) when comparing summary rows.
MISSING = object()
|
||||
|
||||
class SafeLoaderIgnoreUnknown(yaml.SafeLoader):
    """A SafeLoader variant that maps every unknown YAML tag to None instead
    of raising, so files containing tags SafeLoader has no constructor for
    (e.g. the pytorch-lightning hparams.yaml loaded below) still parse."""

    def ignore_unknown(self, node):
        # Fallback constructor: discard the tagged node entirely.
        return None

# Registering under tag=None installs ignore_unknown as the catch-all constructor.
SafeLoaderIgnoreUnknown.add_constructor(None, SafeLoaderIgnoreUnknown.ignore_unknown)
|
||||
|
||||
def camel_to_snake_case(text: str, sep: str = "_", join_abbreviations: bool = False) -> str:
    """Convert CamelCase `text` into lowercase parts joined by `sep`.

    The text is split before every uppercase letter. With
    `join_abbreviations=True`, adjacent parts that were BOTH single letters
    before any merging are glued together, so "ABCWord" becomes "abc_word"
    rather than "a_b_c_word". That merge is not reversible.
    """
    pieces = [chunk.lower() for chunk in re.split(r'(?=[A-Z])', text) if chunk]
    if join_abbreviations:
        # Decide merges from a snapshot of the pre-merge list, and walk
        # right-to-left so earlier merges can chain ("a","b","c" -> "abc").
        snapshot = list(pieces)
        for i in range(len(snapshot) - 2, -1, -1):
            if len(snapshot[i]) == 1 and len(snapshot[i + 1]) == 1:
                pieces[i] = pieces[i] + pieces.pop(i + 1)
    return sep.join(pieces)
|
||||
|
||||
def flatten_dict(data: dict, key_mapper: callable = lambda x: x) -> dict:
    """Flatten exactly one level of dict nesting.

    Entries whose value is a dict are replaced by "{key_mapper(parent)}/{child}"
    entries; all other entries are kept as-is, listed before the flattened ones.
    Grandchildren dicts are left untouched. When no value is a dict, `data`
    itself is returned (not a copy).
    """
    if not any(isinstance(value, dict) for value in data.values()):
        return data
    plain = {
        key: value
        for key, value in data.items()
        if not isinstance(value, dict)
    }
    nested = {
        f"{key_mapper(parent)}/{child}": value
        for parent, sub in data.items()
        if isinstance(sub, dict)
        for child, value in sub.items()
    }
    # `|` keeps plain entries first; nested entries override on key collision.
    return plain | nested
|
||||
|
||||
def parse_jsonl(data: str) -> Iterable[dict]:
    """Yield one parsed JSON object per non-blank line of `data` (JSON-Lines)."""
    for line in data.splitlines():
        if line.strip():
            yield json.loads(line)
|
||||
|
||||
def read_jsonl(path: Path) -> Iterable[dict]:
    """Read the whole file at `path`, then yield each JSON-Lines record.

    Reading eagerly (before yielding) means the file handle is closed even if
    the caller never exhausts the generator.
    """
    with path.open("r") as handle:
        contents = handle.read()
    yield from parse_jsonl(contents)
|
||||
|
||||
def get_experiment_paths(filter: str | None, assert_dumped = False) -> Iterable[Path]:
    # Yield experiment run directories under TENSORBOARD whose name matches
    # `filter` (a regex applied with re.search; None disables filtering).
    # Runs missing hparams.yaml or any tfevents file are skipped with a warning.
    for path in TENSORBOARD.iterdir():
        if filter is not None and not re.search(filter, path.name): continue
        if not path.is_dir(): continue

        if not (path / "hparams.yaml").is_file():
            warnings.warn(f"Missing hparams: {path}")
            continue
        if not any(path.glob("events.out.tfevents.*")):
            warnings.warn(f"Missing tfevents: {path}")
            continue

        # Optionally assert that `tf_dump` has already produced the json
        # scalar dumps for this run (no-op under `python -O`).
        if __debug__ and assert_dumped:
            assert (path / "scalars/epoch.json").is_file(), path
            assert (path / "scalars/IntersectionFieldAutoDecoderModel.validation_step/loss.json").is_file(), path
            assert (path / "scalars/IntersectionFieldAutoDecoderModel.training_step/loss.json").is_file(), path

        yield path
|
||||
|
||||
def dump_pl_tensorboard_hparams(experiment: Path):
    # Extract artifacts embedded in pytorch-lightning's hparams.yaml into
    # standalone files next to it: config.yaml (the raw experiment yaml,
    # made owner-executable if it carried a shebang), environ.yaml (the host's
    # environment variables) and repo.patch (the vcs diff at run time).
    with (experiment / "hparams.yaml").open() as f:
        hparams = yaml.load(f, Loader=SafeLoaderIgnoreUnknown)

    shebang = None
    with (experiment / "config.yaml").open("w") as f:
        # NOTE(review): replaces "\n\r" — the usual Windows line ending is
        # "\r\n"; confirm this ordering is intentional.
        raw_yaml = hparams.get('_pickled_cli_args', {}).get('_raw_yaml', "").replace("\n\r", "\n")
        if raw_yaml.startswith("#!"): # preserve shebang
            shebang, _, raw_yaml = raw_yaml.partition("\n")
            f.write(f"{shebang}\n")
        # Record the original command line as a comment below the shebang.
        f.write(f"# {' '.join(map(shlex.quote, hparams.get('_pickled_cli_args', {}).get('sys_argv', ['None'])))}\n\n")
        f.write(raw_yaml)
    if shebang is not None:
        # Mirror the template's executability: add owner-execute permission.
        os.chmod(experiment / "config.yaml", (experiment / "config.yaml").stat().st_mode | stat.S_IXUSR)
    print(experiment / "config.yaml", "written!", file=sys.stderr)

    with (experiment / "environ.yaml").open("w") as f:
        yaml.safe_dump(hparams.get('_pickled_cli_args', {}).get('host', {}).get('environ'), f)
    print(experiment / "environ.yaml", "written!", file=sys.stderr)

    with (experiment / "repo.patch").open("w") as f:
        f.write(hparams.get('_pickled_cli_args', {}).get('host', {}).get('vcs', "None"))
    print(experiment / "repo.patch", "written!", file=sys.stderr)
|
||||
|
||||
def dump_simple_tf_events_to_jsonl(output_dir: Path, *tf_files: Path):
    # Convert all "simple value" scalar summaries found in the given
    # tensorboard event files into one JSON-Lines file per tag under
    # `output_dir`, each line holding {step, value, wall_time}.
    from google.protobuf.json_format import MessageToDict
    # NOTE(review): imported for its side effect of making the
    # event_file_loader submodule importable below — confirm.
    import tensorboard.backend.event_processing.event_accumulator
    s, l = {}, [] # reused sentinels

    #resource.setrlimit(resource.RLIMIT_NOFILE, (2**16,-1))
    # One open handle per tag, kept for the whole run and closed in `finally`.
    file_handles = {}
    try:
        for tffile in tf_files:
            loader = tensorboard.backend.event_processing.event_file_loader.LegacyEventFileLoader(str(tffile))
            for event in loader.Load():
                for summary in MessageToDict(event).get("summary", s).get("value", l):
                    if "simpleValue" in summary:
                        tag = summary["tag"]
                        if tag not in file_handles:
                            fname = output_dir / f"{tag}.json"
                            print(f"Opening {str(fname)!r}...", file=sys.stderr)
                            fname.parent.mkdir(parents=True, exist_ok=True)
                            file_handles[tag] = fname.open("w") # ("a")
                        val = summary["simpleValue"]
                        # MessageToDict renders some floats (inf/nan) as strings.
                        data = json.dumps({
                            "step" : event.step,
                            "value" : float(val) if isinstance(val, str) else val,
                            "wall_time" : event.wall_time,
                        })
                        file_handles[tag].write(f"{data}\n")
    finally:
        if file_handles:
            print("Closing json files...", file=sys.stderr)
            for k, v in file_handles.items():
                v.close()
|
||||
|
||||
|
||||
# Summary columns that filter_jsonl_columns always keeps, even when their
# value is constant across every row.
NO_FILTER = {
    "__uid",
    "_minutes",
    "_epochs",
    "_hp_nonlinearity",
    "_val_uloss_intersection",
    "_val_uloss_normal_cossim",
    "_val_uloss_intersection",  # NOTE(review): duplicate of the entry two lines up — harmless in a set literal, but possibly a typo for a different "_val_uloss_*" key; confirm
}
|
||||
def filter_jsonl_columns(data: Iterable[dict | None], no_filter=NO_FILTER) -> list[dict]:
    # Post-process summary rows for tabulation: drop None rows, flatten one
    # level of nesting (snake_casing the parent key), fold omega into the
    # nonlinearity column, drop columns that are constant across rows (unless
    # listed in `no_filter`), and coerce each column to a common type.
    # NOTE(review): the annotation says list[dict] but a generator is returned
    # (pessemize_types yields) — confirm which is intended.
    def merge_siren_omega(data: dict) -> dict:
        # Rewrite hp_nonlinearity "sine" -> "sine-<hp_omega_0>" and drop the
        # hp_omega_0 column itself.
        return {
            key: (
                f"{val}-{data.get('hp_omega_0', 'ERROR')}"
                if (key.removeprefix("_"), val) == ("hp_nonlinearity", "sine") else
                val
            )
            for key, val in data.items()
            if key != "hp_omega_0"
        }

    def remove_uninteresting_cols(rows: list[dict]) -> Iterable[dict]:
        # Keep only columns whose value varies across rows (a key missing from
        # some rows counts as variation), plus everything in `no_filter`.
        unique_vals = {}
        def register_val(key, val):
            # Side effect: record the value; returns val so the set
            # comprehension below also excludes falsy values from the whitelist.
            unique_vals.setdefault(key, set()).add(repr(val))
            return val

        whitelisted = {
            key
            for row in rows
            for key, val in row.items()
            if register_val(key, val) and val not in ("None", "0", "0.0")
        }
        for key in unique_vals:
            for row in rows:
                if key not in row:
                    unique_vals[key].add(MISSING)
        for key, vals in unique_vals.items():
            if key not in whitelisted: continue
            if len(vals) == 1:
                whitelisted.remove(key)

        whitelisted.update(no_filter)

        yield from (
            {
                key: val
                for key, val in row.items()
                if key in whitelisted
            }
            for row in rows
        )

    def pessemize_types(rows: list[dict]) -> Iterable[dict]:
        # For each column, pick the type with the LOWEST index in `order` seen
        # in any row (str being the most permissive), then cast every non-None
        # value of that column to it. Lists are treated as tuples.
        types = {}
        order = (str, float, int, bool, tuple, type(None))
        for row in rows:
            for key, val in row.items():
                if isinstance(val, list): val = tuple(val)
                assert type(val) in order, (type(val), val)
                index = order.index(type(val))
                types[key] = min(types.get(key, 999), index)

        yield from (
            {
                key: order[types[key]](val) if val is not None else None
                for key, val in row.items()
            }
            for row in rows
        )

    data = (row for row in data if row is not None)
    data = map(partial(flatten_dict, key_mapper=camel_to_snake_case), data)
    data = map(merge_siren_omega, data)
    data = remove_uninteresting_cols(list(data))
    data = pessemize_types(list(data))

    return data
|
||||
|
||||
# Accepted values for plot_losses' `mode` argument.
PlotMode = Literal["stackplot", "lineplot"]

def plot_losses(experiments: list[Path], mode: PlotMode, write: bool = False, dump: bool = False, training: bool = False, unscaled: bool = False, force=True):
    # Plot per-loss-term curves for each experiment, either as a normalized
    # stackplot (each term's share of the total loss per step, with the raw
    # total on a secondary axis) or as a plain lineplot.
    # write=True renders every (scaled/unscaled x training/validation) combo
    # to <experiment>/plots/*.png in a thread pool; write=False renders only
    # the requested combo and shows it interactively.
    # NOTE(review): the `dump` parameter is unused in this function — confirm.
    def get_losses(experiment: Path, training: bool = True, unscaled: bool = False) -> Iterable[Path]:
        # Json scalar files previously produced by `tf_dump`, one per loss term.
        if not training and unscaled:
            return experiment.glob("scalars/*.validation_step/unscaled_loss_*.json")
        elif not training and not unscaled:
            return experiment.glob("scalars/*.validation_step/loss_*.json")
        elif training and unscaled:
            return experiment.glob("scalars/*.training_step/unscaled_loss_*.json")
        elif training and not unscaled:
            return experiment.glob("scalars/*.training_step/loss_*.json")

    print("Mapping colors...")
    configurations = [
        dict(unscaled=unscaled, training=training),
    ] if not write else [
        dict(unscaled=False, training=False),
        dict(unscaled=False, training=True),
        dict(unscaled=True, training=False),
        dict(unscaled=True, training=True),
    ]
    # Collect every "<Model>.<loss_name>" label up front so each gets a stable
    # color across all experiments and plot variants.
    legends = set(
        f"""{
            loss.parent.name.split(".", 1)[0]
        }.{
            loss.name.removesuffix(loss.suffix).removeprefix("unscaled_")
        }"""
        for experiment in experiments
        for kw in configurations
        for loss in get_losses(experiment, **kw)
    )
    colormap = dict(zip(
        sorted(legends),
        itertools.cycle(mcolors.TABLEAU_COLORS),
    ))

    def mkplot(experiment: Path, training: bool = True, unscaled: bool = False) -> tuple[bool, str]:
        # Render one figure; returns (skipped_or_failed, reason).
        label = f"{'unscaled' if unscaled else 'scaled'} {'training' if training else 'validation'}"
        if write:
            old_savefig_fname = experiment / f"{label.replace(' ', '-')}-{mode}.png"
            savefig_fname = experiment / "plots" / f"{label.replace(' ', '-')}-{mode}.png"
            savefig_fname.parent.mkdir(exist_ok=True, parents=True)
            if old_savefig_fname.is_file():
                # Migrate plots from the old flat location into plots/.
                old_savefig_fname.rename(savefig_fname)
            if savefig_fname.is_file() and not force:
                return True, "savefig_fname already exists"

        # Get and sort data
        losses = {}
        for loss in get_losses(experiment, training=training, unscaled=unscaled):
            model = loss.parent.name.split(".", 1)[0]
            name = loss.name.removesuffix(loss.suffix).removeprefix("unscaled_")
            losses[f"{model}.{name}"] = (loss, list(read_jsonl(loss)))
        losses = dict(sorted(losses.items())) # sort keys
        if not losses:
            return True, "no losses"

        # unwrap
        # Steps are taken from the first series; NaN values plot as 0.
        steps = [i["step"] for i in first(losses.values())[1]]
        values = [
            [i["value"] if not isnan(i["value"]) else 0 for i in data]
            for name, (scalar, data) in losses.items()
        ]

        # normalize
        if mode == "stackplot":
            # Per-step total, used both for normalization and the bottom axis.
            totals = list(map(sum, zip(*values)))
            values = [
                [i / t for i, t in zip(data, totals)]
                for data in values
            ]

        print(experiment.name, label)
        fig, ax = plt.subplots(figsize=(16, 12))

        if mode == "stackplot":
            ax.stackplot(steps, values,
                colors = list(map(colormap.__getitem__, losses.keys())),
                labels = list(
                    label.split(".", 1)[1].removeprefix("loss_")
                    for label in losses.keys()
                ),
            )
            ax.set_xlim(0, steps[-1])
            ax.set_ylim(0, 1)
            ax.invert_yaxis()

        elif mode == "lineplot":
            for data, color, label in zip(
                values,
                map(colormap.__getitem__, losses.keys()),
                list(losses.keys()),
            ):
                ax.plot(steps, data,
                    color = color,
                    label = label,
                )
            ax.set_xlim(0, steps[-1])

        else:
            raise ValueError(f"{mode=}")

        ax.legend()
        ax.set_title(f"{label} loss\n{experiment.name}")
        ax.set_xlabel("Step")
        ax.set_ylabel("loss%")

        if mode == "stackplot":
            # Secondary axis below the main one, showing the unnormalized total.
            ax2 = make_axes_locatable(ax).append_axes("bottom", 0.8, pad=0.05, sharex=ax)
            ax2.stackplot( steps, totals )

            for tl in ax.get_xticklabels(): tl.set_visible(False)

        fig.tight_layout()

        if write:
            fig.savefig(savefig_fname, dpi=300)
            print(savefig_fname)
            plt.close(fig)

        return False, None

    print("Plotting...")
    if write:
        matplotlib.use('agg') # fixes "WARNING: QApplication was not created in the main() thread."
    any_error = False
    if write:
        with ThreadPoolExecutor(max_workers=None) as pool:
            futures = [
                (experiment, pool.submit(mkplot, experiment, **kw))
                for experiment in experiments
                for kw in configurations
            ]
    else:
        # Interactive path: run mkplot eagerly in this thread, but wrap results
        # in resolved Futures so the reporting loop below is shared.
        def mkfuture(item):
            f = Future()
            f.set_result(item)
            return f
        futures = [
            (experiment, mkfuture(mkplot(experiment, **kw)))
            for experiment in experiments
            for kw in configurations
        ]

    for experiment, future in futures:
        try:
            err, msg = future.result()
        except Exception:
            traceback.print_exc(file=sys.stderr)
            any_error = True
            continue
        if err:
            print(f"{msg}: {experiment.name}")
            any_error = True
            continue

    if not any_error and not write: # show in main thread
        plt.show()
    elif not write:
        print("There were errors, will not show figure...", file=sys.stderr)
|
||||
|
||||
|
||||
|
||||
# =========
|
||||
|
||||
# CLI entry point; subcommands are registered with @app.command below.
app = typer.Typer(no_args_is_help=True, add_completion=False)

@app.command(help="Dump simple tensorboard events to json and extract some pytorch lightning hparams")
def tf_dump(tfevent_files: list[Path], j: int = typer.Option(1, "-j"), force: bool = False):
    # Accepts tfevents files, experiment directories, or hparams/config.yaml
    # paths; expands them all to the underlying tfevents files and dumps their
    # scalars + hparams via a process pool.
    # NOTE(review): the `-j` option is accepted but never used — confirm
    # whether it was meant to bound the ProcessPoolExecutor.
    # expand to all tfevents files (there may be more than one)
    tfevent_files = sorted(set([
        tffile
        for tffile in tfevent_files
        if tffile.name.startswith("events.out.tfevents.")
    ] + [
        tffile
        for experiment_dir in tfevent_files
        if experiment_dir.is_dir()
        for tffile in experiment_dir.glob("events.out.tfevents.*")
    ] + [
        tffile
        for hparam_file in tfevent_files
        if hparam_file.name in ("hparams.yaml", "config.yaml")
        for tffile in hparam_file.parent.glob("events.out.tfevents.*")
    ]))

    # filter already dumped
    # (a dump is considered current when scalars/epoch.json is newer than the
    # event file)
    if not force:
        tfevent_files = [
            tffile
            for tffile in tfevent_files
            if not (
                (tffile.parent / "scalars/epoch.json").is_file()
                and
                tffile.stat().st_mtime < (tffile.parent / "scalars/epoch.json").stat().st_mtime
            )
        ]

    if not tfevent_files:
        raise typer.BadParameter("Nothing to be done, consider --force")

    # Group event files by their output directory, then fan the work out.
    jobs = {}
    for tffile in tfevent_files:
        if not tffile.is_file():
            print("ERROR: file not found:", tffile, file=sys.stderr)
            continue
        output_dir = tffile.parent / "scalars"
        jobs.setdefault(output_dir, []).append(tffile)
    with ProcessPoolExecutor() as p:
        for experiment in set(tffile.parent for tffile in tfevent_files):
            p.submit(dump_pl_tensorboard_hparams, experiment)
        for output_dir, tffiles in jobs.items():
            p.submit(dump_simple_tf_events_to_jsonl, output_dir, *tffiles)
|
||||
|
||||
@app.command(help="Propose experiment regexes")
def propose(cmd: str = typer.Argument("summary"), null: bool = False):
    # Print one ready-to-run shell command per (experiment name, date) group
    # found in the tensorboard logdir, using a regex matching the whole group.
    # `null` prefixes each line with ">/dev/null " to silence the command.
    def get():
        for i in TENSORBOARD.iterdir():
            if not i.is_dir(): continue
            if not (i / "hparams.yaml").is_file(): continue
            # Run dir name layout (see the experiment_name template):
            # <prefix>-<name>-<hparams...>-<year>-<month>-<day>-<hhmm>-<uid>
            prefix, name, *hparams, year, month, day, hhmm, uid = i.name.split("-")
            yield f"{name}.*-{year}-{month}-{day}"
    # Sort by the date part (text after ".*-").
    proposals = sorted(set(get()), key=lambda x: x.split(".*-", 1)[1])
    print("\n".join(
        f"{'>/dev/null ' if null else ''}{sys.argv[0]} {cmd or 'summary'} {shlex.quote(i)}"
        for i in proposals
    ))
|
||||
|
||||
@app.command("list", help="List used experiment regexes")
def list_cached_summaries(cmd: str = typer.Argument("summary")):
    # Print one ready-to-run shell command per regex that already has a
    # non-empty cached summary jsonl, sorted by the date embedded in the regex.
    if not CACHED_SUMMARIES.is_dir():
        cached = []
    else:
        cached = [
            i.name.removesuffix(".jsonl")
            for i in CACHED_SUMMARIES.iterdir()
            if i.suffix == ".jsonl"
            if i.is_file() and i.stat().st_size
        ]
    def order(key: str) -> list[str]:
        # Sort key: the date digits after the ".*" wildcard, then the full key
        # as a tie-breaker.
        return re.sub(r'[^0-9\-]', '', key.split(".*")[-1]).strip("-").split("-") + [key]

    print("\n".join(
        f"{sys.argv[0]} {cmd or 'summary'} {shlex.quote(i)}"
        for i in sorted(cached, key=order)
    ))
|
||||
|
||||
@app.command(help="Precompute the summary of a experiment regex")
def compute_summary(filter: str, force: bool = False, dump: bool = False, no_cache: bool = False):
    """Build (and cache) a one-row-per-experiment JSONL summary for `filter`.

    Raises FileExistsError when a non-empty cache already exists and --force
    was not given. --dump refreshes tensorboard scalar dumps first;
    --no-cache skips writing the aggregate cache file. Returns the JSONL text.
    """
    cache = CACHED_SUMMARIES / f"{filter}.jsonl"
    if cache.is_file() and cache.stat().st_size:
        if not force:
            raise FileExistsError(cache)

    def mk_summary(path: Path) -> dict | None:
        """Summarize one experiment directory, reusing its own train_summary.json
        cache when newer than the dumped scalars. Returns None on failure."""
        cache = path / "train_summary.json"  # deliberately shadows the outer cache
        if cache.is_file() and cache.stat().st_size and cache.stat().st_mtime > (path/"scalars/epoch.json").stat().st_mtime:
            with cache.open() as f:
                return json.load(f)
        else:
            with (path / "hparams.yaml").open() as f:
                hparams = munchify(yaml.load(f, Loader=SafeLoaderIgnoreUnknown), factory=partial(DefaultMunch, None))
            # the original CLI config is embedded as a raw yaml string in hparams
            config = hparams._pickled_cli_args._raw_yaml
            config = munchify(yaml.load(config, Loader=SafeLoaderIgnoreUnknown), factory=partial(DefaultMunch, None))

            try:
                train_loss = list(read_jsonl(path / "scalars/IntersectionFieldAutoDecoderModel.training_step/loss.json"))
                val_loss = list(read_jsonl(path / "scalars/IntersectionFieldAutoDecoderModel.validation_step/loss.json"))
            except Exception:  # was a bare `except:` — don't swallow KeyboardInterrupt/SystemExit
                traceback.print_exc(file=sys.stderr)
                return None

            out = Munch()
            out.uid = path.name.rsplit("-", 1)[-1]
            out.name = path.name
            out.date = "-".join(path.name.split("-")[-5:-1])
            out.epochs = int(last(read_jsonl(path / "scalars/epoch.json"))["value"])
            out.steps = val_loss[-1]["step"]
            out.gpu = hparams._pickled_cli_args.host.gpus[1][1]

            # guard against a zero wall-time span (e.g. a single validation step)
            if val_loss[-1]["wall_time"] - val_loss[0]["wall_time"] > 0:
                out.batches_per_second = val_loss[-1]["step"] / (val_loss[-1]["wall_time"] - val_loss[0]["wall_time"])
            else:
                out.batches_per_second = 0

            out.minutes = (val_loss[-1]["wall_time"] - train_loss[0]["wall_time"]) / 60

            if (path / "scalars/PsutilMonitor/gpu.00.memory.used.json").is_file():
                # NOTE(review): this maximum is computed but never stored —
                # presumably it was meant to be assigned to `out` (GPU memory
                # peak). Kept as-is to avoid changing the summary schema.
                max(i["value"] for i in read_jsonl(path / "scalars/PsutilMonitor/gpu.00.memory.used.json"))

            # mean of the last 5 values for each validation metric
            for metric_path in (path / "scalars/IntersectionFieldAutoDecoderModel.validation_step").glob("*.json"):
                if not metric_path.is_file() or not metric_path.stat().st_size: continue

                metric_name = metric_path.name.removesuffix(".json")
                metric_data = read_jsonl(metric_path)
                try:
                    out[f"val_{metric_name}"] = mean(i["value"] for i in tail(5, metric_data))
                except StatisticsError:
                    out[f"val_{metric_name}"] = float('nan')

            # only a curated subset of training metrics is summarized
            for metric_path in (path / "scalars/IntersectionFieldAutoDecoderModel.training_step").glob("*.json"):
                if not any(i in metric_path.name for i in ("miss_radius_grad", "sphere_center_grad", "loss_tangential_reg", "multi_view")): continue
                if not metric_path.is_file() or not metric_path.stat().st_size: continue

                metric_name = metric_path.name.removesuffix(".json")
                metric_data = read_jsonl(metric_path)
                try:
                    out[f"train_{metric_name}"] = mean(i["value"] for i in tail(5, metric_data))
                except StatisticsError:
                    out[f"train_{metric_name}"] = float('nan')

            out.hostname = hparams._pickled_cli_args.host.hostname

            # flatten model hyperparameters into hp_* columns (one level deep)
            for key, val in config.IntersectionFieldAutoDecoderModel.items():
                if isinstance(val, dict):
                    out.update({f"hp_{key}_{k}": v for k, v in val.items()})
                elif isinstance(val, float | int | str | bool | None):
                    out[f"hp_{key}"] = val

            with cache.open("w") as f:
                json.dump(unmunchify(out), f)

            return dict(out)

    experiments = list(get_experiment_paths(filter, assert_dumped=not dump))
    if not experiments:
        raise typer.BadParameter("No matching experiment")
    if dump:
        try:
            tf_dump(experiments) # force=force_dump)
        except typer.BadParameter:
            pass

    # does literally nothing, thanks GIL
    with ThreadPoolExecutor() as p:
        results = list(p.map(mk_summary, experiments))

    if any(result is None for result in results):
        if all(result is None for result in results):
            print("No summary succeeded", file=sys.stderr)
            # BUGFIX: typer.Exit takes `code`, not `exit_code`; the old call
            # raised TypeError instead of exiting with status 1.
            raise typer.Exit(1)
        warnings.warn("Some summaries failed:\n" + "\n".join(
            str(experiment)
            for result, experiment in zip(results, experiments)
            if result is None
        ))

    # NOTE(review): failed summaries serialize as `null` lines here;
    # downstream readers must tolerate them.
    summaries = "\n".join( map(json.dumps, results) )
    if not no_cache:
        cache.parent.mkdir(parents=True, exist_ok=True)
        with cache.open("w") as f:
            f.write(summaries)
    return summaries
|
||||
|
||||
@app.command(help="Show the summary of a experiment regex, precompute it if needed")
def summary(filter: Optional[str] = typer.Argument(None), force: bool = False, dump: bool = False, all: bool = False):
    """Show the cached summary for `filter`, computing it first when missing.

    With no filter: list the available cached regexes instead.
    Interactive TTY + visidata installed: open the rows in `vd`;
    otherwise the raw cached JSONL is printed.
    """
    if filter is None:
        return list_cached_summaries("summary")

    def key_mangler(key: str) -> str:
        # Shorten loss-column prefixes for display (uloss = unscaled, sloss = scaled).
        for pattern, sub in (
            (r'^val_unscaled_loss_', r'val_uloss_'),
            (r'^train_unscaled_loss_', r'train_uloss_'),
            (r'^val_loss_', r'val_sloss_'),
            (r'^train_loss_', r'train_sloss_'),
        ):
            key = re.sub(pattern, sub, key)

        return key

    cache = CACHED_SUMMARIES / f"{filter}.jsonl"
    # (Re)compute when forced or when the cache is missing/empty.
    if force or not (cache.is_file() and cache.stat().st_size):
        compute_summary(filter, force=force, dump=dump)
    assert cache.is_file() and cache.stat().st_size, (cache, cache.stat())

    if os.isatty(0) and os.isatty(1) and shutil.which("vd"):
        rows = read_jsonl(cache)
        # Rows may be None for failed summaries; preserve them as-is here.
        rows = ({key_mangler(k): v for k, v in row.items()} if row is not None else None for row in rows)
        if not all:
            rows = filter_jsonl_columns(rows)
            # NOTE(review): this assumes filter_jsonl_columns never yields None;
            # a None row here would raise AttributeError — TODO confirm.
            rows = ({k: v for k, v in row.items() if not k.startswith(("val_sloss_", "train_sloss_"))} for row in rows)
        data = "\n".join(map(json.dumps, rows))
        subprocess.run(["vd",
            #"--play", EXPERIMENTS / "set-key-columns.vd",
            "-f", "jsonl"
        ], input=data, text=True, check=True)
    else:
        # Non-interactive: just dump the cached JSONL verbatim.
        with cache.open() as f:
            print(f.read())
|
||||
|
||||
@app.command(help="Filter uninteresting keys from jsonl stdin")
def filter_cols():
    """Read JSONL rows from stdin, drop uninteresting columns, print JSONL to stdout."""
    nonblank = [line for line in sys.stdin.readlines() if line.strip()]
    filtered = filter_jsonl_columns(json.loads(line) for line in nonblank)
    print(*map(json.dumps, filtered), sep="\n")
|
||||
|
||||
@app.command(help="Run a command for each experiment matched by experiment regex")
def exec(filter: str, cmd: list[str], j: int = typer.Option(1, "-j"), dumped: bool = False, undumped: bool = False):
    """Run `cmd` once per matched experiment, in up to -j parallel workers.

    Placeholders in `cmd` (inspired by fd / gnu parallel):
      {}   -> the experiment's hparams.yaml path
      {//} -> the experiment directory
    If no placeholder occurs, the hparams.yaml path is appended.
    --dumped / --undumped keep only experiments with / without dumped scalars.
    Exits non-zero when any command fails.
    """
    # inspired by fd / gnu parallel
    def populate_cmd(experiment: Path, cmd: Iterable[str]) -> Iterable[str]:
        # renamed from `any` — it shadowed the builtin used below
        substituted = False
        for arg in cmd:
            if arg == "{}":
                substituted = True
                yield str(experiment / "hparams.yaml")
            elif arg == "{//}":
                substituted = True
                yield str(experiment)
            else:
                yield arg
        if not substituted:
            yield str(experiment / "hparams.yaml")

    with ThreadPoolExecutor(max_workers=j or None) as p:
        results = p.map(subprocess.run, (
            list(populate_cmd(experiment, cmd))
            for experiment in get_experiment_paths(filter)
            if not dumped or (experiment / "scalars/epoch.json").is_file()
            if not undumped or not (experiment / "scalars/epoch.json").is_file()
        ))

    if any(i.returncode for i in results):
        # BUGFIX: typer.Exit must be raised to set the exit status;
        # returning it from a command is silently ignored.
        raise typer.Exit(1)
|
||||
|
||||
@app.command(help="Show stackplot of experiment loss")
def stackplot(filter: str, write: bool = False, dump: bool = False, training: bool = False, unscaled: bool = False, force: bool = False):
    """Render a stacked loss plot for every experiment matching `filter`."""
    matched = list(get_experiment_paths(filter, assert_dumped=not dump))
    if not matched:
        raise typer.BadParameter("No match")
    if dump:
        # Best effort: refresh the dumped scalars first; ignore "nothing to dump".
        try:
            tf_dump(matched)
        except typer.BadParameter:
            pass

    plot_losses(
        matched,
        mode      = "stackplot",
        write     = write,
        dump      = dump,
        training  = training,
        unscaled  = unscaled,
        force     = force,
    )
|
||||
|
||||
@app.command(help="Show lineplot of experiment loss")
def lineplot(filter: str, write: bool = False, dump: bool = False, training: bool = False, unscaled: bool = False, force: bool = False):
    """Render a line plot of losses for every experiment matching `filter`.

    BUGFIX: the help text previously said "stackplot" — copy-paste from the
    sibling command above.
    """
    experiments = list(get_experiment_paths(filter, assert_dumped=not dump))
    if not experiments:
        raise typer.BadParameter("No match")
    if dump:
        # Best effort: refresh dumped scalars first; ignore "nothing to dump".
        try:
            tf_dump(experiments)
        except typer.BadParameter:
            pass

    plot_losses(experiments,
        mode      = "lineplot",
        write     = write,
        dump      = dump,
        training  = training,
        unscaled  = unscaled,
        force     = force,
    )
|
||||
|
||||
@app.command(help="Open tensorboard for the experiments matching the regex")
def tensorboard(filter: Optional[str] = typer.Argument(None), watch: bool = False):
    """Launch tensorboard over a temp dir of symlinks to the matched runs.

    With --watch, new experiment directories appearing while tensorboard runs
    are symlinked in every 10 seconds. With no filter: list cached regexes.
    """
    if filter is None:
        return list_cached_summaries("tensorboard")
    experiments = list(get_experiment_paths(filter, assert_dumped=False))
    if not experiments:
        raise typer.BadParameter("No match")

    # A "treefarm" of symlinks lets tensorboard see only the matched runs.
    with tempfile.TemporaryDirectory(suffix=f"ifield-{filter}") as d:
        treefarm = Path(d)
        with ThreadPoolExecutor(max_workers=2) as p:
            for experiment in experiments:
                (treefarm / experiment.name).symlink_to(experiment)

            cmd = ["tensorboard", "--logdir", d]
            print("+", *map(shlex.quote, cmd), file=sys.stderr)
            # tensorboard runs in the pool so this thread can keep polling below
            tensorboard = p.submit(subprocess.run, cmd, check=True)
            if not watch:
                tensorboard.result()

            else:
                all_experiments = set(get_experiment_paths(None, assert_dumped=False))
                # Poll for new experiments until the tensorboard process exits.
                while not tensorboard.done():
                    time.sleep(10)
                    new_experiments = set(get_experiment_paths(None, assert_dumped=False)) - all_experiments
                    if new_experiments:
                        for experiment in new_experiments:
                            print(f"Adding {experiment.name!r}...", file=sys.stderr)
                            (treefarm / experiment.name).symlink_to(experiment)
                        all_experiments.update(new_experiments)
|
||||
|
||||
@app.command(help="Compute evaluation metrics")
def metrics(filter: Optional[str] = typer.Argument(None), dump: bool = False, dry: bool = False, prefix: Optional[str] = typer.Option(None), derive: bool = False, each: bool = False, no_total: bool = False):
    """Compute (or schedule) evaluation metrics for matched experiments, then print them.

    --dry prints the commands instead of running them; --prefix prepends e.g. a
    job-scheduler command; --derive post-processes raw counts into IoU/F-score/
    CD/EMD; --each also reports per-object rows; --no-total skips the
    best/last checkpoint score computation and tolerates missing files.
    With no filter: list cached regexes.
    """
    if filter is None:
        return list_cached_summaries("metrics --derive")
    experiments = list(get_experiment_paths(filter, assert_dumped=False))
    if not experiments:
        raise typer.BadParameter("No match")
    if dump:
        try:
            tf_dump(experiments)
        except typer.BadParameter:
            pass

    def run(*cmd):
        # Optionally wrap with the scheduler prefix; echo before executing.
        if prefix is not None:
            cmd = [*shlex.split(prefix), *cmd]
        if dry:
            print(*map(shlex.quote, map(str, cmd)))
        else:
            print("+", *map(shlex.quote, map(str, cmd)))
            subprocess.run(cmd)

    # Compute any missing score files: best & last checkpoint, and for
    # "2prif-" experiments also outlier-filtered variants.
    for experiment in experiments:
        if no_total: continue
        if not (experiment / "compute-scores/metrics.json").is_file():
            run(
                "python", "./marf.py", "module", "--best", experiment / "hparams.yaml",
                "compute-scores", experiment / "compute-scores/metrics.json",
                "--transpose",
            )
        if not (experiment / "compute-scores/metrics-last.json").is_file():
            run(
                "python", "./marf.py", "module", "--last", experiment / "hparams.yaml",
                "compute-scores", experiment / "compute-scores/metrics-last.json",
                "--transpose",
            )
        if "2prif-" not in experiment.name: continue
        if not (experiment / "compute-scores/metrics-sans_outliers.json").is_file():
            run(
                "python", "./marf.py", "module", "--best", experiment / "hparams.yaml",
                "compute-scores", experiment / "compute-scores/metrics-sans_outliers.json",
                "--transpose", "--filter-outliers"
            )
        if not (experiment / "compute-scores/metrics-last-sans_outliers.json").is_file():
            run(
                "python", "./marf.py", "module", "--last", experiment / "hparams.yaml",
                "compute-scores", experiment / "compute-scores/metrics-last-sans_outliers.json",
                "--transpose", "--filter-outliers"
            )

    if dry: return
    if prefix is not None:
        print("prefix was used, assuming a job scheduler was used, will not print scores.", file=sys.stderr)
        return

    # Collect the metric files produced above (shadows the command name).
    metrics = [
        *(experiment / "compute-scores/metrics.json" for experiment in experiments),
        *(experiment / "compute-scores/metrics-last.json" for experiment in experiments),
        *(experiment / "compute-scores/metrics-sans_outliers.json" for experiment in experiments if "2prif-" in experiment.name),
        *(experiment / "compute-scores/metrics-last-sans_outliers.json" for experiment in experiments if "2prif-" in experiment.name),
    ]
    if not no_total:
        assert all(metric.exists() for metric in metrics)
    else:
        metrics = (metric for metric in metrics if metric.exists())

    out = []
    for metric in metrics:
        experiment = metric.parent.parent.name
        is_last = metric.name in ("metrics-last.json", "metrics-last-sans_outliers.json")
        with metric.open() as f:
            data = json.load(f)

        if derive:
            # Turn raw per-object counters into derived scores; the "_all_"
            # pseudo-object aggregates counters over every object.
            derived = {}
            objs = [i for i in data.keys() if i != "_hparams"]
            for obj in (objs if each else []) + [None]:
                if obj is None:
                    d = DefaultMunch(0)  # missing keys default to 0 while summing
                    for obj in objs:
                        for k, v in data[obj].items():
                            d[k] += v
                    obj = "_all_"
                    n_cd = data["_hparams"]["n_cd"] * len(objs)
                    n_emd = data["_hparams"]["n_emd"] * len(objs)
                else:
                    d = munchify(data[obj])
                    n_cd = data["_hparams"]["n_cd"]
                    n_emd = data["_hparams"]["n_emd"]

                precision = d.TP / (d.TP + d.FP)
                recall = d.TP / (d.TP + d.FN)
                derived[obj] = dict(
                    filtered = d.n_outliers / d.n if "n_outliers" in d else None,
                    iou = d.TP / (d.TP + d.FN + d.FP),
                    precision = precision,
                    recall = recall,
                    f_score = 2 * (precision * recall) / (precision + recall),
                    cd = d.cd_dist / n_cd,
                    emd = d.emd / n_emd,
                    cos_med = 1 - (d.cd_cos_med / n_cd) if "cd_cos_med" in d else None,
                    cos_jac = 1 - (d.cd_cos_jac / n_cd),
                )
            data = derived if each else derived["_all_"]

        data["uid"] = experiment.rsplit("-", 1)[-1]
        data["experiment_name"] = experiment
        data["is_last"] = is_last

        out.append(json.dumps(data))

    # Interactive + visidata available: browse; otherwise print JSONL.
    if derive and not each and os.isatty(0) and os.isatty(1) and shutil.which("vd"):
        subprocess.run(["vd", "-f", "jsonl"], input="\n".join(out), text=True, check=True)
    else:
        print("\n".join(out))
|
||||
|
||||
# Script entry point: dispatch to the typer CLI application.
if __name__ == "__main__":
    app()
|
||||
822
figures/nn-architecture.svg
Normal file
822
figures/nn-architecture.svg
Normal file
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 198 KiB |
57
ifield/__init__.py
Normal file
57
ifield/__init__.py
Normal file
@@ -0,0 +1,57 @@
|
||||
def setup_print_hooks():
    """Install rich-based traceback and warning hooks.

    No-op unless the IFIELD_PRETTY_TRACEBACK environment variable is set.
    """
    import os
    if not os.environ.get("IFIELD_PRETTY_TRACEBACK", None):
        return

    from rich.traceback import install
    from rich.console import Console
    import warnings, sys

    if not os.isatty(2):
        # https://github.com/Textualize/rich/issues/1809
        os.environ.setdefault("COLUMNS", "120")

    install(
        show_locals = bool(os.environ.get("SHOW_LOCALS", "")),
        width = None,
    )

    # custom warnings
    # https://github.com/Textualize/rich/issues/433

    # NOTE(review): these imports duplicate the ones above — harmless, but
    # could be removed.
    from rich.traceback import install
    from rich.console import Console
    import warnings, sys


    def showwarning(message, category, filename, lineno, file=None, line=None):
        # Replacement for warnings.showwarning: renders via rich on a tty,
        # plain text otherwise.
        msg = warnings.WarningMessage(message, category, filename, lineno, file, line)

        if file is None:
            file = sys.stderr
        if file is None:
            # sys.stderr is None when run with pythonw.exe:
            # warnings get lost
            return
        text = warnings._formatwarnmsg(msg)
        if file.isatty():
            Console(file=file, stderr=True).print(text)
        else:
            try:
                file.write(text)
            except OSError:
                # the file (probably stderr) is invalid - this warning gets lost.
                pass
    warnings.showwarning = showwarning

    def warning_no_src_line(message, category, filename, lineno, file=None, line=None):
        # Compact warning format without the source line.
        # NOTE(review): falls through to the plain format when the target is a
        # tty but not stderr; when the target is not a tty it returns None —
        # looks unintended, confirm against warnings.formatwarning's contract.
        if (file or sys.stderr) is not None:
            if (file or sys.stderr).isatty():
                if file is None or file is sys.stderr:
                    return f"[yellow]{category.__name__}[/yellow]: {message}\n ((unknown):{lineno})"
                return f"{category.__name__}: {message} ((unknown):{lineno})\n"
    warnings.formatwarning = warning_no_src_line


# Install the hooks on import, then drop the helper from the module namespace.
setup_print_hooks()
del setup_print_hooks
|
||||
1006
ifield/cli.py
Normal file
1006
ifield/cli.py
Normal file
File diff suppressed because it is too large
Load Diff
174
ifield/cli_utils.py
Normal file
174
ifield/cli_utils.py
Normal file
@@ -0,0 +1,174 @@
|
||||
#!/usr/bin/env python3
|
||||
from .data.common.scan import SingleViewScan, SingleViewUVScan
|
||||
from datetime import datetime
|
||||
import re
|
||||
import click
|
||||
import gzip
|
||||
import h5py as h5
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pyrender
|
||||
import trimesh
|
||||
import trimesh.transformations as T
|
||||
|
||||
__doc__ = """
|
||||
Here are a bunch of helper scripts exposed as cli command by poetry
|
||||
"""
|
||||
|
||||
|
||||
# these entrypoints are exposed by poetry as shell commands
|
||||
|
||||
@click.command()
@click.argument("h5file")
@click.argument("key", default="")
def show_h5_items(h5file: str, key: str):
    "Show contents of HDF5 dataset"
    # Use a context manager so the file handle is always released
    # (previously the file was opened and never closed).
    with h5.File(h5file, "r") as f:
        if not key:
            # No key: list every dataset with dtype, shape and mean value.
            mlen = max(map(len, f.keys()))
            for i in sorted(f.keys()):
                print(i.ljust(mlen), ":",
                    str (f[i].dtype).ljust(10),
                    repr(f[i].shape).ljust(16),
                    "mean:", f[i][:].mean()
                )
        else:
            if not f[key].shape:
                # BUGFIX: `.value` was removed in h5py 3.x; `[()]` is the
                # supported way to read a scalar dataset.
                print(f[key][()])
            else:
                print(f[key][:])
|
||||
|
||||
|
||||
@click.command()
@click.argument("h5file")
@click.argument("key", default="")
def show_h5_img(h5file: str, key: str):
    "Show a 2D HDF5 dataset as an image"
    # Without a key: list the datasets. With a key: render it via matplotlib.
    handle = h5.File(h5file, "r")
    if key:
        plt.imshow(handle[key])
        plt.show()
    else:
        width = max(map(len, handle.keys()))
        for name in sorted(handle.keys()):
            print(name.ljust(width), ":", str(handle[name].dtype).ljust(10), handle[name].shape)
|
||||
|
||||
|
||||
@click.command()
@click.argument("h5file")
@click.option("--force-distances", is_flag=True, help="Always show miss distances.")
@click.option("--uv", is_flag=True, help="Load as UV scan cloud and convert it.")
@click.option("--show-unit-sphere", is_flag=True, help="Show the unit sphere.")
@click.option("--missing", is_flag=True, help="Show miss points that are not hits nor misses as purple.")
def show_h5_scan_cloud(
        h5file          : str,
        force_distances : bool = False,
        uv              : bool = False,
        missing         : bool = False,
        show_unit_sphere = False,
        ):
    "Show a SingleViewScan HDF5 dataset"
    print("Reading data...")
    t = datetime.now()
    if uv:
        # UV layout: optionally extract "missing" points, then convert to a
        # plain scan for rendering.
        scan = SingleViewUVScan.from_h5_file(h5file)
        if missing and scan.any_missing:
            if not scan.has_missing:
                scan.fill_missing_points()
            points_missing = scan.points[scan.missing]
        else:
            missing = False
        if not scan.is_single_view:
            scan.cam_pos = None
        scan = scan.to_scan()
    else:
        scan = SingleViewScan.from_h5_file(h5file)
        if missing:
            # Missing points only exist in the UV representation.
            uvscan = scan.to_uv_scan()
            if scan.any_missing:
                uvscan.fill_missing_points()
                points_missing = uvscan.points[uvscan.missing]
            else:
                missing = False
    print("loadtime: ", datetime.now() - t)

    if force_distances and not scan.has_miss_distances:
        print("Computing miss distances...")
        scan.compute_miss_distances()
    use_miss_distances = force_distances
    print("Constructing scene...")
    if not scan.has_colors:
        # No stored colors: default palette (blue hits, orange misses) and
        # enable the distance-based recoloring of misses below.
        scan.colors_hit = np.zeros_like(scan.points_hit)
        scan.colors_miss = np.zeros_like(scan.points_miss)
        scan.colors_hit [:] = ( 31/255, 119/255, 180/255)
        scan.colors_miss[:] = (243/255, 156/255, 18/255)
        use_miss_distances = True
    if scan.has_miss_distances and use_miss_distances:
        # Lerp miss colors from green (near) to red (far) by normalized distance.
        sdm = scan.distances_miss / scan.distances_miss.max()
        sdm = sdm[..., None]
        scan.colors_miss \
            = np.array([0.8, 0, 0])[None, :] * sdm \
            + np.array([0, 1, 0.2])[None, :] * (1-sdm)


    scene = pyrender.Scene()

    scene.add(pyrender.Mesh.from_points(scan.points_hit, colors=scan.colors_hit, normals=scan.normals_hit))
    scene.add(pyrender.Mesh.from_points(scan.points_miss, colors=scan.colors_miss))

    if missing:
        # Purple markers for points that are neither hits nor misses.
        scene.add(pyrender.Mesh.from_points(points_missing, colors=(np.array((0xff, 0x00, 0xff))/255)[None, :].repeat(points_missing.shape[0], axis=0)))

    # camera:
    if not scan.points_cam is None:
        # Green spheres at each camera position, sized relative to the cloud spread.
        camera_mesh = trimesh.creation.uv_sphere(radius=scan.points_hit_std.max()*0.2)
        camera_mesh.visual.vertex_colors = [0.0, 0.8, 0.0]
        tfs = np.tile(np.eye(4), (len(scan.points_cam), 1, 1))
        tfs[:,:3,3] = scan.points_cam
        scene.add(pyrender.Mesh.from_trimesh(camera_mesh, poses=tfs))

    # UV sphere:
    if show_unit_sphere:
        # Inverted so it is visible from the inside.
        unit_sphere_mesh = trimesh.creation.uv_sphere(radius=1)
        unit_sphere_mesh.invert()
        unit_sphere_mesh.visual.vertex_colors = [0.8, 0.8, 0.0]
        scene.add(pyrender.Mesh.from_trimesh(unit_sphere_mesh, poses=np.eye(4)[None, ...]))

    print("Launch!")
    viewer = pyrender.Viewer(scene, use_raymond_lighting=True, point_size=2)
|
||||
|
||||
|
||||
@click.command()
@click.argument("meshfile")
@click.option('--aabb', is_flag=True)
@click.option('--z-skyward', is_flag=True)
def show_model(
        meshfile  : str,
        aabb      : bool,
        z_skyward : bool,
        ):
    "Show a 3D model with pyrender, supports .gz suffix"
    if meshfile.endswith(".gz"):
        # Decompress on the fly; trimesh needs the inner file type explicitly.
        with gzip.open(meshfile, "r") as f:
            mesh = trimesh.load(f, file_type=meshfile.split(".", 1)[1].removesuffix(".gz"))
    else:
        mesh = trimesh.load(meshfile)

    # Multi-geometry scenes are flattened into a single mesh.
    if isinstance(mesh, trimesh.Scene):
        mesh = mesh.dump(concatenate=True)

    if aabb:
        # Rotate so the bounding box aligns with the axes as closely as possible.
        from .data.common.mesh import rotate_to_closest_axis_aligned_bounds
        mesh.apply_transform(rotate_to_closest_axis_aligned_bounds(mesh, fail_ok=True))

    if z_skyward:
        # Convert from y-up to z-up by rotating 90 degrees about the x axis.
        mesh.apply_transform(T.rotation_matrix(np.pi/2, (1, 0, 0)))

    # Print pyrender's single-key viewer shortcuts scraped from its docstring.
    print(
        *(i.strip() for i in pyrender.Viewer.__doc__.splitlines() if re.search(r"- ``[a-z0-9]``: ", i)),
        sep="\n"
    )

    scene = pyrender.Scene()
    scene.add(pyrender.Mesh.from_trimesh(mesh))
    pyrender.Viewer(scene, use_raymond_lighting=True)
|
||||
3
ifield/data/__init__.py
Normal file
3
ifield/data/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
__doc__ = """
|
||||
Submodules to read and process datasets
|
||||
"""
|
||||
0
ifield/data/common/__init__.py
Normal file
0
ifield/data/common/__init__.py
Normal file
90
ifield/data/common/download.py
Normal file
90
ifield/data/common/download.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from ...utils.helpers import make_relative
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
from typing import Union, Optional
|
||||
import io
|
||||
import os
|
||||
import json
|
||||
import requests
|
||||
|
||||
PathLike = Union[os.PathLike, str]
|
||||
|
||||
__doc__ = """
|
||||
Here are some helper functions for processing data.
|
||||
"""
|
||||
|
||||
def check_url(url): # HTTP HEAD
    """Return True when an HTTP HEAD request to `url` reports success."""
    response = requests.head(url)
    return response.ok
|
||||
|
||||
def download_stream(
        url         : str,
        file_object,
        block_size  : int = 1024,
        silent      : bool = False,
        label       : Optional[str] = None,
        ):
    """Stream `url` into `file_object` in `block_size` chunks.

    Shows a tqdm progress bar (described by `label`) unless `silent`.
    """
    resp = requests.get(url, stream=True)
    # content-length may be absent (e.g. chunked transfer); 0 disables the
    # size mismatch check below.
    total_size = int(resp.headers.get("content-length", 0))
    if not silent:
        progress_bar = tqdm(total=total_size , unit="iB", unit_scale=True, desc=label)

    for chunk in resp.iter_content(block_size):
        if not silent:
            progress_bar.update(len(chunk))
        file_object.write(chunk)

    if not silent:
        progress_bar.close()
        # NOTE(review): only reported when a progress bar exists, and goes to
        # stdout rather than stderr.
        if total_size != 0 and progress_bar.n != total_size:
            print("ERROR, something went wrong")
|
||||
|
||||
def download_data(
        url        : str,
        block_size : int = 1024,
        silent     : bool = False,
        label      : Optional[str] = None,
        ) -> bytearray:
    """Download `url` fully into memory and return the payload as a bytearray."""
    buffer = io.BytesIO()
    download_stream(url, buffer, block_size=block_size, silent=silent, label=label)
    # getvalue() returns the whole buffer contents regardless of position
    return bytearray(buffer.getvalue())
|
||||
|
||||
def download_file(
        url        : str,
        fname      : Union[Path, str],
        block_size : int = 1024,
        silent     = False,
        ):
    """Download `url` to the path `fname` (str or Path)."""
    if not isinstance(fname, Path):
        fname = Path(fname)
    # The progress bar is labeled with the file name, relative to the cwd.
    label = make_relative(fname, Path.cwd()).name
    with fname.open("wb") as f:
        download_stream(url, f, block_size=block_size, silent=silent, label=label)
|
||||
|
||||
def is_downloaded(
        target_dir : PathLike,
        url        : str,
        *,
        add     : bool = False,
        dbfiles : Union[list[PathLike], PathLike],
        ):
    """Check whether `url` is recorded as downloaded in the db files under
    `target_dir`; with `add=True`, record it (into the first db file).

    Returns True if the url is (now) recorded. Raises ValueError when
    `dbfiles` is empty.
    """
    if not isinstance(target_dir, os.PathLike):
        target_dir = Path(target_dir)
    if not isinstance(dbfiles, list):
        dbfiles = [dbfiles]
    if not dbfiles:
        raise ValueError("'dbfiles' empty")
    # Union of every db file's "downloaded" list.
    downloaded = set()
    for dbfile_fname in dbfiles:
        dbfile_fname = target_dir / dbfile_fname
        if dbfile_fname.is_file():
            with open(dbfile_fname, "r") as f:
                downloaded.update(json.load(f)["downloaded"])

    if add and url not in downloaded:
        downloaded.add(url)
        # BUGFIX: write inside target_dir, mirroring the reads above —
        # previously this wrote relative to the current working directory.
        with open(target_dir / dbfiles[0], "w") as f:
            data = {"downloaded": sorted(downloaded)}
            json.dump(data, f, indent=2, sort_keys=True)
        return True

    return url in downloaded
|
||||
370
ifield/data/common/h5_dataclasses.py
Normal file
370
ifield/data/common/h5_dataclasses.py
Normal file
@@ -0,0 +1,370 @@
|
||||
#!/usr/bin/env python3
|
||||
from abc import abstractmethod, ABCMeta
|
||||
from collections import namedtuple
|
||||
from pathlib import Path
|
||||
import copy
|
||||
import dataclasses
|
||||
import functools
|
||||
import h5py as h5
|
||||
import hdf5plugin
|
||||
import numpy as np
|
||||
import operator
|
||||
import os
|
||||
import sys
|
||||
import typing
|
||||
|
||||
__all__ = [
|
||||
"DataclassMeta",
|
||||
"Dataclass",
|
||||
"H5Dataclass",
|
||||
"H5Array",
|
||||
"H5ArrayNoSlice",
|
||||
]
|
||||
|
||||
T = typing.TypeVar("T")
|
||||
NoneType = type(None)
|
||||
PathLike = typing.Union[os.PathLike, str]
|
||||
H5Array = typing._alias(np.ndarray, 0, inst=False, name="H5Array")
|
||||
H5ArrayNoSlice = typing._alias(np.ndarray, 0, inst=False, name="H5ArrayNoSlice")
|
||||
|
||||
DataclassField = namedtuple("DataclassField", [
|
||||
"name",
|
||||
"type",
|
||||
"is_optional",
|
||||
"is_array",
|
||||
"is_sliceable",
|
||||
"is_prefix",
|
||||
])
|
||||
|
||||
def strip_optional(val: type) -> type:
|
||||
if typing.get_origin(val) is typing.Union:
|
||||
union = set(typing.get_args(val))
|
||||
if len(union - {NoneType}) == 1:
|
||||
val, = union - {NoneType}
|
||||
else:
|
||||
raise TypeError(f"Non-'typing.Optional' 'typing.Union' is not supported: {typing._type_repr(val)!r}")
|
||||
return val
|
||||
|
||||
def is_array(val, *, _inner=False):
    """
    Hacky way to check if a value or type is an array.
    The hack omits having to depend on large frameworks such as pytorch or pandas
    """
    # `val` may be a type or an instance: the first pass treats it as a type;
    # when nothing matches, recurse once more with type(val).
    val = strip_optional(val)
    if val is H5Array or val is H5ArrayNoSlice:
        return True

    # Compare by qualified type name so numpy/torch need not be importable here.
    if typing._type_repr(val) in (
        "numpy.ndarray",
        "torch.Tensor",
    ):
        return True
    if not _inner:
        return is_array(type(val), _inner=True)
    return False
|
||||
|
||||
def prod(numbers: typing.Iterable[T], initial: typing.Optional[T] = None) -> T:
    """Multiply all `numbers` together, optionally seeded with `initial`.

    Mirrors functools.reduce semantics: an empty iterable without `initial`
    raises TypeError.
    """
    if initial is None:
        return functools.reduce(operator.mul, numbers)
    return functools.reduce(operator.mul, numbers, initial)
|
||||
|
||||
class DataclassMeta(type):
    """Metaclass that turns every class created with it into a dataclass.

    On Python 3.10+ classes without an explicit __slots__ are rebuilt with
    slots=True; otherwise a plain dataclass decoration is applied.
    """
    def __new__(
            metacls,
            clsname   : str,
            bases     : tuple[type, ...],
            namespace : dict[str, typing.Any],
            **kwargs,
            ):
        new_cls = super().__new__(metacls, clsname, bases, namespace, **kwargs)
        use_slots = sys.version_info[:2] >= (3, 10) and not hasattr(new_cls, "__slots__")
        if use_slots:
            # dataclass(slots=True) returns a brand-new class object
            return dataclasses.dataclass(slots=True)(new_cls)
        return dataclasses.dataclass(new_cls)
|
||||
|
||||
class DataclassABCMeta(DataclassMeta, ABCMeta):
    """Combined metaclass so a class can be both an auto-dataclass
    (DataclassMeta) and an ABC (ABCMeta) without a metaclass conflict."""
    pass
|
||||
|
||||
class Dataclass(metaclass=DataclassMeta):
    """Base class: subclasses automatically become dataclasses (via
    DataclassMeta) with a mapping-style read/write interface over their fields."""

    def __getitem__(self, key: str) -> typing.Any:
        # Mapping-style read access, restricted to declared field names.
        if key in self.keys():
            return getattr(self, key)
        raise KeyError(key)

    def __setitem__(self, key: str, value: typing.Any):
        # Mapping-style write access, restricted to declared field names.
        if key in self.keys():
            return setattr(self, key, value)
        raise KeyError(key)

    def keys(self) -> typing.KeysView:
        # NOTE(review): backed by as_dict(), so each call runs
        # dataclasses.asdict (which deep-copies field values) — fine for small
        # instances, costly for large arrays.
        return self.as_dict().keys()

    def values(self) -> typing.ValuesView:
        return self.as_dict().values()

    def items(self) -> typing.ItemsView:
        return self.as_dict().items()

    def as_dict(self, properties_to_include: set[str] = None, **kw) -> dict[str, typing.Any]:
        # dataclasses.asdict recurses into nested dataclasses/containers.
        # Non-field properties can be appended by name.
        out = dataclasses.asdict(self, **kw)
        for name in (properties_to_include or []):
            out[name] = getattr(self, name)
        return out

    def as_tuple(self, properties_to_include: list[str]) -> tuple:
        # Field values in declaration order, optionally followed by the named
        # properties.
        out = dataclasses.astuple(self)
        if not properties_to_include:
            return out
        else:
            return (
                *out,
                *(getattr(self, name) for name in properties_to_include),
            )

    def copy(self: T, *, deep=True) -> T:
        # deep=True (default) deep-copies all field values.
        return (copy.deepcopy if deep else copy.copy)(self)
|
||||
|
||||
class H5Dataclass(Dataclass):
    """Dataclass base with HDF5 (de)serialization via h5py.

    Subclass parameters (``class Foo(H5Dataclass, prefix=..., ...)``):
      prefix      -- string prepended to every dataset key in the HDF5 file
      n_pages     -- default page count for paginated reads
      require_all -- if True, reading fails unless every HDF5 key is consumed
    """
    # settable with class params:
    _prefix      : str  = dataclasses.field(init=False, repr=False, default="")
    _n_pages     : int  = dataclasses.field(init=False, repr=False, default=10)
    _require_all : bool = dataclasses.field(init=False, repr=False, default=False)

    def __init_subclass__(cls,
            prefix      : typing.Optional[str]  = None,
            n_pages     : typing.Optional[int]  = None,
            require_all : typing.Optional[bool] = None,
            **kw,
            ):
        super().__init_subclass__(**kw)
        assert dataclasses.is_dataclass(cls)
        # only override the class-level defaults when explicitly given
        if prefix      is not None: cls._prefix      = prefix
        if n_pages     is not None: cls._n_pages     = n_pages
        if require_all is not None: cls._require_all = require_all

    @classmethod
    def _get_fields(cls) -> typing.Iterable[DataclassField]:
        """Yield a DataclassField descriptor for every init-able dataclass field."""
        for field in dataclasses.fields(cls):
            if not field.init:
                continue
            assert field.name not in ("_prefix", "_n_pages", "_require_all"), (
                f"{field.name!r} can not be in {cls.__qualname__}.__init__.\n"
                "Set it with dataclasses.field(default=YOUR_VALUE, init=False, repr=False)"
            )
            if isinstance(field.type, str):
                raise TypeError("Type hints are strings, perhaps avoid using `from __future__ import annotations`")

            # dict[str, X] fields are "prefix" fields: stored as multiple
            # datasets named <prefix><field>_<key>
            type_inner = strip_optional(field.type)
            is_prefix  = typing.get_origin(type_inner) is dict and typing.get_args(type_inner)[:1] == (str,)
            field_type = typing.get_args(type_inner)[1] if is_prefix else field.type
            if field.default is None or typing.get_origin(field.type) is typing.Union and NoneType in typing.get_args(field.type):
                field_type = typing.Optional[field_type]

            yield DataclassField(
                name         = field.name,
                type         = strip_optional(field_type),
                is_optional  = typing.get_origin(field_type) is typing.Union and NoneType in typing.get_args(field_type),
                is_array     = is_array(field_type),
                is_sliceable = is_array(field_type) and strip_optional(field_type) is not H5ArrayNoSlice,
                is_prefix    = is_prefix,
            )

    @classmethod
    def from_h5_file(cls : type[T],
            fname : typing.Union[PathLike, str],
            *,
            page               : typing.Optional[int] = None,
            n_pages            : typing.Optional[int] = None,
            read_slice         : slice = slice(None),
            require_even_pages : bool = True,
            ) -> T:
        """Read an instance from a HDF5 file.

        page/n_pages read only the given slice of each sliceable dataset;
        read_slice additionally narrows (or provides the step for) the read.
        Raises FileNotFoundError / TypeError for a missing / non-HDF5 file.
        """
        if not isinstance(fname, Path):
            fname = Path(fname)
        if n_pages is None:
            n_pages = cls._n_pages
        if not fname.exists():
            raise FileNotFoundError(str(fname))
        if not h5.is_hdf5(fname):
            raise TypeError(f"Not a HDF5 file: {str(fname)!r}")

        # if this class has no fields, print a example class:
        if not any(field.init for field in dataclasses.fields(cls)):
            with h5.File(fname, "r") as f:
                klen = max(map(len, f.keys()))
                example_cls = f"\nclass {cls.__name__}(Dataclass, require_all=True):\n" + "\n".join(
                    f"    {k.ljust(klen)} : "
                    + (
                        "H5Array" if prod(v.shape, 1) > 1 else (
                        "float" if issubclass(v.dtype.type, np.floating) else (
                        "int" if issubclass(v.dtype.type, np.integer) else (
                        "bool" if issubclass(v.dtype.type, np.bool_) else (
                        "typing.Any"
                    ))))).ljust(14 + 1)
                    + f" #{repr(v).split(':', 1)[1].removesuffix('>')}"
                    for k, v in f.items()
                )
            raise NotImplementedError(f"{cls!r} has no fields!\nPerhaps try the following:{example_cls}")

        fields_consumed = set()

        def make_kwarg(
                file  : h5.File,
                keys  : typing.KeysView,
                field : DataclassField,
                ) -> tuple[str, typing.Any]:
            # Resolve the constructor value for a single field.
            if field.is_optional:
                if field.name not in keys:
                    return field.name, None
            if field.is_sliceable:
                if page is not None:
                    # BUG FIX: was `f[...]` (outer closure) — use the `file` parameter
                    n_items  = int(file[cls._prefix + field.name].shape[0])
                    page_len = n_items // n_pages
                    modulus  = n_items % n_pages
                    if modulus: page_len += 1 # round up
                    if require_even_pages and modulus:
                        raise ValueError(f"Field {field.name!r} {tuple(file[cls._prefix + field.name].shape)} is not cleanly divisible into {n_pages} pages")
                    # BUG FIX: builtin slice() accepts no keyword arguments —
                    # slice(start=..., stop=..., step=...) raised TypeError
                    this_slice = slice(
                        page_len * page,
                        page_len * (page+1),
                        read_slice.step, # inherit step
                    )
                else:
                    this_slice = read_slice
            else:
                this_slice = slice(None) # read all

            # array or scalar?
            def read_dataset(var):
                # https://docs.h5py.org/en/stable/high/dataset.html#reading-writing-data
                if field.is_array:
                    return var[this_slice]
                if var.shape == (1,):
                    return var[0]
                else:
                    return var[()]

            if field.is_prefix:
                # collect every dataset named <prefix><field>_<key> into a dict
                fields_consumed.update(
                    key
                    for key in keys if key.startswith(f"{cls._prefix}{field.name}_")
                )
                return field.name, {
                    key.removeprefix(f"{cls._prefix}{field.name}_") : read_dataset(file[key])
                    for key in keys if key.startswith(f"{cls._prefix}{field.name}_")
                }
            else:
                fields_consumed.add(cls._prefix + field.name)
                return field.name, read_dataset(file[cls._prefix + field.name])

        with h5.File(fname, "r") as f:
            keys = f.keys()
            init_dict = dict( make_kwarg(f, keys, i) for i in cls._get_fields() )

        try:
            out = cls(**init_dict)
        except Exception as e:
            # re-raise with a diff of class fields vs file keys for debuggability
            class_attrs = set(field.name for field in dataclasses.fields(cls))
            file_attr = set(init_dict.keys())
            raise e.__class__(f"{e}. {class_attrs=}, {file_attr=}, diff={class_attrs.symmetric_difference(file_attr)}") from e

        if cls._require_all:
            fields_not_consumed = set(keys) - fields_consumed
            if fields_not_consumed:
                raise ValueError(f"Not all HDF5 fields consumed: {fields_not_consumed!r}")

        return out

    def to_h5_file(self,
            fname : PathLike,
            mkdir : bool = False,
            ):
        """Write this instance to a HDF5 file (LZ4-compressed array datasets).

        Raises NotADirectoryError if the parent directory is missing and
        mkdir is False; raises TypeError if an array field is not a numpy array.
        """
        if not isinstance(fname, Path):
            fname = Path(fname)
        if not fname.parent.is_dir():
            if mkdir:
                fname.parent.mkdir(parents=True)
            else:
                raise NotADirectoryError(fname.parent)

        with h5.File(fname, "w") as f:
            for field in type(self)._get_fields():
                if field.is_optional and getattr(self, field.name) is None:
                    continue
                value = getattr(self, field.name)
                if field.is_array:
                    # every array payload (or every dict value for prefix fields)
                    # must already be a numpy array
                    if any(type(i) is not np.ndarray for i in (value.values() if field.is_prefix else [value])):
                        raise TypeError(
                            "When dumping a H5Dataclass, make sure the array fields are "
                            f"numpy arrays (the type of {field.name!r} is {typing._type_repr(type(value))}).\n"
                            "Example: h5dataclass.map_arrays(torch.Tensor.numpy)"
                        )

                def write_value(key: str, value: typing.Any):
                    if field.is_array:
                        f.create_dataset(key, data=value, **hdf5plugin.LZ4())
                    else:
                        f.create_dataset(key, data=value)

                if field.is_prefix:
                    for k, v in value.items():
                        write_value(self._prefix + field.name + "_" + k, v)
                else:
                    write_value(self._prefix + field.name, value)

    def map_arrays(self: T, func: typing.Callable[[H5Array], H5Array], do_copy: bool = False) -> T:
        """Apply `func` to every array field (and every value of array prefix
        fields), in place unless do_copy=True (shallow copy)."""
        if do_copy: # shallow
            self = self.copy(deep=False)
        for field in type(self)._get_fields():
            if field.is_optional and getattr(self, field.name) is None:
                continue
            if field.is_prefix and field.is_array:
                setattr(self, field.name, {
                    k : func(v)
                    for k, v in getattr(self, field.name).items()
                })
            elif field.is_array:
                setattr(self, field.name, func(getattr(self, field.name)))

        return self

    def astype(self: T, t: type, do_copy: bool = False, convert_nonfloats: bool = False) -> T:
        """Cast all float arrays (all arrays if convert_nonfloats) to dtype `t`.

        BUG FIX: `do_copy` is now forwarded to map_arrays — it was accepted
        but silently ignored.
        """
        return self.map_arrays(
            lambda x: x.astype(t) if convert_nonfloats or not np.issubdtype(x.dtype, int) else x,
            do_copy = do_copy,
        )

    def copy(self: T, *, deep=True) -> T:
        out = super().copy(deep=deep)
        if not deep:
            for field in type(self)._get_fields():
                if field.is_prefix:
                    # BUG FIX: shallow-copy the actual dict value with setattr;
                    # previously this item-assigned a copy of the field *name*
                    # string (out[field.name] = copy.copy(field.name))
                    setattr(out, field.name, copy.copy(getattr(self, field.name)))
        return out

    @property
    def shape(self) -> dict[str, tuple[int, ...]]:
        """Shapes of every attribute value that has a .shape."""
        return {
            key: value.shape
            for key, value in self.items()
            if hasattr(value, "shape")
        }
|
||||
|
||||
class TransformableDataclassMixin(metaclass=DataclassABCMeta):
    """Mixin for dataclasses carrying a `transforms` dict of named 4x4
    homogeneous matrices that can be applied via `transform`."""

    @abstractmethod
    def transform(self: T, mat4: np.ndarray, inplace=False) -> T:
        """Apply the 4x4 transform `mat4` to all spatial data; return the
        transformed instance (self when inplace=True)."""
        ...

    # FIX: `inverse_name` defaulted to None but was annotated plain `str`;
    # make the implicit Optional explicit
    def transform_to(self: T, name: str, inverse_name: typing.Optional[str] = None, *, inplace=False) -> T:
        """Apply and consume the stored transform `name`.

        The remaining stored transforms are right-multiplied by the inverse so
        they still map from the *new* frame; if `inverse_name` is given, the
        inverse matrix is stored under that name.
        """
        mtx = self.transforms[name]
        out = self.transform(mtx, inplace=inplace)
        out.transforms.pop(name) # consumed

        inv = np.linalg.inv(mtx)
        for key in list(out.transforms.keys()): # maintain the other transforms
            out.transforms[key] = out.transforms[key] @ inv
        if inverse_name is not None: # store inverse
            out.transforms[inverse_name] = inv

        return out
|
||||
48
ifield/data/common/mesh.py
Normal file
48
ifield/data/common/mesh.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from math import pi
|
||||
from trimesh import Trimesh
|
||||
import numpy as np
|
||||
import os
|
||||
import trimesh
|
||||
import trimesh.transformations as T
|
||||
|
||||
DEBUG = bool(os.environ.get("IFIELD_DEBUG", ""))
|
||||
|
||||
__doc__ = """
|
||||
Here are some helper functions for processing data.
|
||||
"""
|
||||
|
||||
def rotate_to_closest_axis_aligned_bounds(
        mesh       : Trimesh,
        order_axes : bool = True,
        fail_ok    : bool = True,
        ) -> np.ndarray:
    """Return a rotation-only 4x4 matrix aligning the mesh's oriented bounding
    box with the world axes.

    With order_axes=True, the 24 axis-aligned cube orientations are searched
    for one whose residual Euler angles stay within ~45 degrees, keeping the
    result close to the input pose; if none qualifies, the plain OBB rotation
    is returned when fail_ok is set, otherwise an Exception is raised.
    """
    obb_mat4, _extents = trimesh.bounds.oriented_bounds(mesh, ordered=not order_axes)
    # keep only the rotational component of the OBB transform
    obb_rot_mat4 = T.euler_matrix(*T.decompose_matrix(obb_mat4)[3])

    if not order_axes:
        return obb_rot_mat4

    tolerance = pi / 4 * 1.01 # tolerance
    quarter   = pi / 2

    face_turns = (
        (0, 0),
        (1, 0),
        (2, 0),
        (3, 0),
        (0, 1),
        (0,-1),
    )
    # 6 faces x 4 in-plane rotations per face = 24 candidate orientations
    candidates = [
        (fx * quarter, fy * quarter, spin * quarter)
        for spin in range(4)
        for fx, fy in face_turns
    ]

    for rx, ry, rz in candidates:
        candidate_mat4 = T.euler_matrix(rx, ry, rz) @ obb_rot_mat4
        ai, aj, ak = T.euler_from_matrix(candidate_mat4)
        if max(abs(ai), abs(aj), abs(ak)) <= tolerance:
            return candidate_mat4

    if fail_ok:
        return obb_rot_mat4
    raise Exception("Unable to orient mesh")
|
||||
297
ifield/data/common/points.py
Normal file
297
ifield/data/common/points.py
Normal file
@@ -0,0 +1,297 @@
|
||||
from __future__ import annotations
|
||||
from ...utils.helpers import compose
|
||||
from functools import reduce, lru_cache
|
||||
from math import ceil
|
||||
from typing import Iterable
|
||||
import numpy as np
|
||||
import operator
|
||||
|
||||
__doc__ = """
|
||||
Here are some helper functions for processing data.
|
||||
"""
|
||||
|
||||
|
||||
def img2col(img: np.ndarray, psize: int) -> np.ndarray:
    """Extract every psize x psize patch of `img`.

    `img` is (rows, cols) or (channels, rows, cols); the result is
    (rows-psize+1, cols-psize+1, channels, psize, psize) with all size-1 axes
    squeezed away (so a grayscale image yields (rows-psize+1, cols-psize+1,
    psize, psize), and psize=1 yields the image itself).
    """
    # based on ycb_generate_point_cloud.py provided by YCB

    # FIX: removed a dead `n_channels = ...` assignment that was immediately
    # overwritten by the unpacking below.
    # normalize to (channels, rows, cols); grayscale gets a single channel
    n_channels, rows, cols = (1,) * (3 - len(img.shape)) + img.shape

    # pad the image up to a multiple of psize in each direction
    img_pad = np.zeros((
        n_channels,
        int(ceil(1.0 * rows / psize) * psize),
        int(ceil(1.0 * cols / psize) * psize),
    ))
    img_pad[:, 0:rows, 0:cols] = img

    # allocate output buffer
    final = np.zeros((
        img_pad.shape[1],
        img_pad.shape[2],
        n_channels,
        psize,
        psize,
    ))

    for c in range(n_channels):
        for x in range(psize):
            for y in range(psize):
                # roll the image by (x, y) so every alignment of the patch
                # grid is visited exactly once
                img_shift = np.vstack((
                    img_pad[c, x:],
                    img_pad[c, :x]))
                img_shift = np.column_stack((
                    img_shift[:, y:],
                    img_shift[:, :y]))
                final[x::psize, y::psize, c] = np.swapaxes(
                    img_shift.reshape(
                        int(img_pad.shape[1] / psize), psize,
                        int(img_pad.shape[2] / psize), psize),
                    1,
                    2)

    # crop output and unwrap axes with size==1
    return np.squeeze(final[
        0:rows - psize + 1,
        0:cols - psize + 1])
|
||||
|
||||
def filter_depth_discontinuities(depth_map: np.ndarray, filt_size = 7, thresh = 1000) -> np.ndarray:
    """
    Removes data close to discontinuities, with size filt_size.

    Returns a copy of depth_map where every pixel whose filt_size x filt_size
    neighborhood spans a depth range larger than `thresh` has been zeroed.
    Pixels closer than (filt_size-1)//2 to the border are kept unchanged
    (their mark stays 0).  # NOTE(review): `thresh` of 1000 presumably assumes
    depth in some fixed unit (e.g. tenths of mm) — confirm against the sensor.
    """
    # based on ycb_generate_point_cloud.py provided by YCB

    # Ensure that filter sizes are okay
    assert filt_size % 2, "Can only use odd filter sizes."

    # Compute discontinuities
    offset = int(filt_size - 1) // 2
    # img2col yields one filt_size x filt_size patch per interior pixel;
    # `* 1.0` promotes integer depth maps to float
    patches = 1.0 * img2col(depth_map, filt_size)
    # center value of each patch vs. the patch's min/max
    mids = patches[:, :, offset, offset]
    mins = np.min(patches, axis=(2, 3))
    maxes = np.max(patches, axis=(2, 3))

    # worst-case deviation of the neighborhood from the center pixel
    discont = np.maximum(
        np.abs(mins - mids),
        np.abs(maxes - mids))
    mark = discont > thresh

    # Account for offsets: the patch grid is smaller than the image by
    # `offset` on each side, so place the mark back at the patch centers
    final_mark = np.zeros(depth_map.shape, dtype=np.uint16)
    final_mark[offset:offset + mark.shape[0],
               offset:offset + mark.shape[1]] = mark

    # zero out the marked pixels, keep all others
    return depth_map * (1 - final_mark)
|
||||
|
||||
def reorient_depth_map(
        depth_map      : np.ndarray,
        rgb_map        : np.ndarray,
        depth_mat3     : np.ndarray, # 3x3 intrinsic camera matrix
        depth_vec5     : np.ndarray, # 5 distortion parameters (k1, k2, p1, p2, k3)
        rgb_mat3       : np.ndarray, # 3x3 intrinsic camera matrix
        rgb_vec5       : np.ndarray, # 5 distortion parameters (k1, k2, p1, p2, k3)
        ir_to_rgb_mat4 : np.ndarray, # extrinsic transformation matrix from depth to rgb camera viewpoint
        rgb_mask_map   : np.ndarray = None, # optional bool mask in rgb space
        _output_points = False, # retval (H, W) if false else (N, XYZRGB)
        _output_hits_uvs = False, # retval[1] is dtype=bool of hits shaped like depth_map
        ) -> np.ndarray:
    """
    Corrects depth_map to be from the same view as the rgb_map, with the same dimensions.
    If _output_points is True, the points returned are in the rgb camera space.

    Pixels with zero depth are dropped; points projecting outside the rgb
    frame (or outside rgb_mask_map, if given) are discarded.
    """
    # based on ycb_generate_point_cloud.py provided by YCB
    # now faster AND more easy on the GIL

    height_old, width_old, *_ = depth_map.shape
    height, width, *_ = rgb_map.shape

    d_cx, r_cx = depth_mat3[0, 2], rgb_mat3[0, 2] # optical center
    d_cy, r_cy = depth_mat3[1, 2], rgb_mat3[1, 2]
    d_fx, r_fx = depth_mat3[0, 0], rgb_mat3[0, 0] # focal length
    d_fy, r_fy = depth_mat3[1, 1], rgb_mat3[1, 1]
    d_k1, d_k2, d_p1, d_p2, d_k3 = depth_vec5
    c_k1, c_k2, c_p1, c_p2, c_k3 = rgb_vec5

    # make a UV grid over depth_map
    u, v = np.meshgrid(
        np.arange(width_old),
        np.arange(height_old),
    )

    # compute xyz coordinates for all depths (homogeneous, one column per pixel)
    xyz_depth = np.stack((
        (u - d_cx) / d_fx,
        (v - d_cy) / d_fy,
        depth_map,
        np.ones(depth_map.shape)
    )).reshape((4, -1))
    # drop pixels with no depth reading
    xyz_depth = xyz_depth[:, xyz_depth[2] != 0]

    # undistort depth coordinates (Brown-Conrady radial + tangential model)
    d_x, d_y = xyz_depth[:2]
    r = np.linalg.norm(xyz_depth[:2], axis=0)
    xyz_depth[0, :] \
        = d_x / (1 + d_k1*r**2 + d_k2*r**4 + d_k3*r**6) \
        - (2*d_p1*d_x*d_y + d_p2*(r**2 + 2*d_x**2))
    xyz_depth[1, :] \
        = d_y / (1 + d_k1*r**2 + d_k2*r**4 + d_k3*r**6) \
        - (d_p1*(r**2 + 2*d_y**2) + 2*d_p2*d_x*d_y)

    # unproject x and y
    xyz_depth[0, :] *= xyz_depth[2, :]
    xyz_depth[1, :] *= xyz_depth[2, :]

    # convert depths to RGB camera viewpoint
    xyz_rgb = ir_to_rgb_mat4 @ xyz_depth

    # project depths to RGB canvas
    rgb_z_inv = 1 / xyz_rgb[2] # perspective correction
    # BUG FIX: np.int (a deprecated alias of the builtin int) was removed in
    # NumPy 1.24 — use int directly, which is what np.int meant
    rgb_uv = np.stack((
        xyz_rgb[0] * rgb_z_inv * r_fx + r_cx + 0.5,
        xyz_rgb[1] * rgb_z_inv * r_fy + r_cy + 0.5,
    )).astype(int)

    # mask of the rgb_xyz values within view of rgb_map
    mask = reduce(operator.and_, [
        rgb_uv[0] >= 0,
        rgb_uv[1] >= 0,
        rgb_uv[0] < width,
        rgb_uv[1] < height,
    ])
    if rgb_mask_map is not None:
        mask[mask] &= rgb_mask_map[
            rgb_uv[1, mask],
            rgb_uv[0, mask]]

    if not _output_points: # output image
        output = np.zeros((height, width), dtype=depth_map.dtype)
        output[
            rgb_uv[1, mask],
            rgb_uv[0, mask],
        ] = xyz_rgb[2, mask]

    else: # output pointcloud
        rgbs = rgb_map[ # lookup rgb values using rgb_uv
            rgb_uv[1, mask],
            rgb_uv[0, mask]]
        output = np.stack((
            xyz_rgb[0, mask], # x
            xyz_rgb[1, mask], # y
            xyz_rgb[2, mask], # z
            rgbs[:, 0], # r
            rgbs[:, 1], # g
            rgbs[:, 2], # b
        )).T

    # output for realsies
    if not _output_hits_uvs: # raw
        return output
    else: # with hit mask
        uv = np.zeros((height, width), dtype=bool)
        # filter points overlapping in the depth map: keep one point per
        # (row << 32 | col) bucket
        uv_indices = (
            rgb_uv[1, mask],
            rgb_uv[0, mask],
        )
        _, chosen = np.unique( uv_indices[0] << 32 | uv_indices[1], return_index=True )
        output = output[chosen, :]
        uv[uv_indices] = True
        return output, uv
|
||||
|
||||
def join_rgb_and_depth_to_points(*args, **kwargs) -> np.ndarray:
    """Convenience wrapper around reorient_depth_map that returns a point
    cloud (N, XYZRGB) instead of a reprojected depth image."""
    return reorient_depth_map(*args, _output_points=True, **kwargs)
|
||||
|
||||
@compose(np.array) # block lru cache mutation
@lru_cache(maxsize=1)
@compose(list)
def generate_equidistant_sphere_points(
        n : int,
        centroid : np.ndarray = (0, 0, 0),
        radius : float = 1,
        compute_sphere_coordinates : bool = False,
        compute_normals : bool = False,
        shift_theta : bool = False,
        ) -> Iterable[tuple[float, ...]]:
    # Generates ~n approximately equidistributed points on a sphere.
    # The generator is materialized with @compose(list), memoized with
    # functools.lru_cache, and re-wrapped with @compose(np.array) on every
    # call so cache hits return a fresh array the caller may mutate.
    # NOTE: the actual number of points yielded is only approximately n.
    # Deserno M. How to generate equidistributed points on the surface of a sphere
    # https://www.cmu.edu/biolphys/deserno/pdf/sphere_equi.pdf

    if compute_sphere_coordinates and compute_normals:
        raise ValueError(
            "'compute_sphere_coordinates' and 'compute_normals' are mutually exclusive"
        )

    # target area per point -> grid spacing in theta and phi
    n_count = 0  # NOTE(review): incremented but never read — presumably leftover debugging
    a = 4 * np.pi / n
    d = np.sqrt(a)
    n_theta = round(np.pi / d)
    d_theta = np.pi / n_theta
    d_phi = a / d_theta

    for i in range(0, n_theta):
        # latitude band at colatitude theta in (0, pi)
        theta = np.pi * (i + 0.5) / n_theta
        # number of points in this band, proportional to its circumference
        n_phi = round(2 * np.pi * np.sin(theta) / d_phi)

        for j in range(0, n_phi):
            phi = 2 * np.pi * j / n_phi

            if compute_sphere_coordinates: # (theta, phi)
                # without shift_theta the colatitude is recentered to
                # latitude in (-pi/2, pi/2); convention-dependent
                yield (
                    theta if shift_theta else theta - 0.5*np.pi,
                    phi,
                )
            elif compute_normals: # (x, y, z, nx, ny, nz)
                # the normal of a sphere point is its unit offset from the center
                yield (
                    centroid[0] + radius * np.sin(theta) * np.cos(phi),
                    centroid[1] + radius * np.sin(theta) * np.sin(phi),
                    centroid[2] + radius * np.cos(theta),
                    np.sin(theta) * np.cos(phi),
                    np.sin(theta) * np.sin(phi),
                    np.cos(theta),
                )
            else: # (x, y, z)
                yield (
                    centroid[0] + radius * np.sin(theta) * np.cos(phi),
                    centroid[1] + radius * np.sin(theta) * np.sin(phi),
                    centroid[2] + radius * np.cos(theta),
                )
            n_count += 1
|
||||
|
||||
|
||||
def generate_random_sphere_points(
        n : int,
        centroid : np.ndarray = (0, 0, 0),
        radius : float = 1,
        compute_sphere_coordinates : bool = False,
        compute_normals : bool = False,
        shift_theta : bool = False, # depends on convention
        ) -> np.ndarray:
    """Uniformly sample n random points on a sphere.

    Returns (n, 2) latitude/longitude pairs when compute_sphere_coordinates,
    (n, 6) xyz + unit normals when compute_normals, else (n, 3) xyz.
    Uses the global numpy RNG.
    """
    if compute_sphere_coordinates and compute_normals:
        raise ValueError(
            "'compute_sphere_coordinates' and 'compute_normals' are mutually exclusive"
        )

    # inverse transform sampling: arcsin of uniform(-1, 1) gives a latitude
    # distribution with uniform area density on the sphere
    theta = np.arcsin(np.random.uniform(-1, 1, n))
    phi   = np.random.uniform(0, 2*np.pi, n)

    if compute_sphere_coordinates: # (theta, phi)
        shifted = 0.5*np.pi + theta if shift_theta else theta
        return np.stack((shifted, phi), axis=1)

    # unit normals double as the point offsets from the centroid
    nx = np.cos(theta) * np.cos(phi)
    ny = np.cos(theta) * np.sin(phi)
    nz = np.sin(theta)
    x = centroid[0] + radius * nx
    y = centroid[1] + radius * ny
    z = centroid[2] + radius * nz

    if compute_normals: # (x, y, z, nx, ny, nz)
        return np.stack((x, y, z, nx, ny, nz), axis=1)
    return np.stack((x, y, z), axis=1) # (x, y, z)
|
||||
85
ifield/data/common/processing.py
Normal file
85
ifield/data/common/processing.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from .h5_dataclasses import H5Dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Hashable, Optional, Callable
|
||||
import os
|
||||
|
||||
DEBUG = bool(os.environ.get("IFIELD_DEBUG", ""))
|
||||
|
||||
__doc__ = """
|
||||
Here are some helper functions for processing data.
|
||||
"""
|
||||
|
||||
# multiprocessing does not work due to my rediculous use of closures, which seemingly cannot be pickled
|
||||
# paralelize it in the shell instead
|
||||
|
||||
def precompute_data(
        computer     : Callable[[Hashable], Optional[H5Dataclass]],
        identifiers  : list[Hashable],
        output_paths : list[Path],
        page         : tuple[int, int] = (0, 1),
        *,
        force        : bool = False,
        debug        : bool = False,
        ):
    """
    precomputes data and stores them as HDF5 datasets using `.to_h5_file(path)`

    computer     -- maps an identifier to a H5Dataclass (or None for "no result")
    identifiers  -- work items, parallel to output_paths
    output_paths -- destination file per identifier
    page         -- (page_index, n_pages): process only this slice of the work
    force        -- recompute even if the output file already exists
    debug        -- re-raise exceptions instead of recording and continuing
    """

    page, n_pages = page
    assert len(identifiers) == len(output_paths)

    total = len(identifiers)
    identifier_max_len = max(map(len, map(str, identifiers)))
    t_epoch = None
    # progress logger; reads `index` and `identifier` from the enclosing loop
    # at call time (late binding), and tracks per-item elapsed time in t_epoch
    def log(state: str, is_start = False):
        nonlocal t_epoch
        if is_start: t_epoch = datetime.now()
        td = timedelta(0) if is_start else datetime.now() - t_epoch
        print(" - "
            f"{str(index+1).rjust(len(str(total)))}/{total}: "
            f"{str(identifier).ljust(identifier_max_len)} @ {td}: {state}"
        )

    print(f"precompute_data(computer={computer.__module__}.{computer.__qualname__}, identifiers=..., force={force}, page={page})")
    t_begin = datetime.now()
    failed = []

    # pagination: ceil-divide the work into n_pages chunks, take chunk `page`
    page_size = total // n_pages + bool(total % n_pages)
    jobs = list(zip(identifiers, output_paths))[page_size*page : page_size*(page+1)]

    for index, (identifier, output_path) in enumerate(jobs, start=page_size*page):
        # skip items whose output already exists (non-empty), unless forced
        if not force and output_path.exists() and output_path.stat().st_size > 0:
            continue

        log("compute", is_start=True)

        # compute
        try:
            res = computer(identifier)
        except Exception as e:
            failed.append(identifier)
            log(f"failed compute: {e.__class__.__name__}: {e}")
            if DEBUG or debug: raise e
            continue
        if res is None:
            failed.append(identifier)
            log("no result")
            continue

        # write to file
        try:
            output_path.parent.mkdir(parents=True, exist_ok=True)
            res.to_h5_file(output_path)
        except Exception as e:
            failed.append(identifier)
            log(f"failed write: {e.__class__.__name__}: {e}")
            # remove a possibly half-written file so the skip check above
            # does not treat it as done on the next run
            if output_path.is_file(): output_path.unlink() # cleanup
            if DEBUG or debug: raise e
            continue

        log("done")

    print("precompute_data finished in", datetime.now() - t_begin)
    print("failed:", failed or None)
|
||||
768
ifield/data/common/scan.py
Normal file
768
ifield/data/common/scan.py
Normal file
@@ -0,0 +1,768 @@
|
||||
from ...utils.helpers import compose
|
||||
from . import points
|
||||
from .h5_dataclasses import H5Dataclass, H5Array, H5ArrayNoSlice, TransformableDataclassMixin
|
||||
from methodtools import lru_cache
|
||||
from sklearn.neighbors import BallTree
|
||||
import faiss
|
||||
from trimesh import Trimesh
|
||||
from typing import Iterable
|
||||
from typing import Optional, TypeVar
|
||||
import mesh_to_sdf
|
||||
import mesh_to_sdf.scan as sdf_scan
|
||||
import numpy as np
|
||||
import trimesh
|
||||
import trimesh.transformations as T
|
||||
import warnings
|
||||
|
||||
__doc__ = """
|
||||
Here are some helper types for data.
|
||||
"""
|
||||
|
||||
_T = TypeVar("T")
|
||||
|
||||
class InvalidateLRUOnWriteMixin:
    """Clear every cached methodtools lru_cache on the instance whenever a
    regular attribute is written, so cached derived values never go stale."""
    def __setattr__(self, key, value):
        # methodtools stores its caches under "__wire|"-prefixed attributes;
        # writes to those must not themselves trigger a cache flush
        if not key.startswith("__wire|"):
            cache_names = (a for a in dir(self) if a.startswith("__wire|"))
            for cache_name in cache_names:
                getattr(self, cache_name).cache_clear()
        return super().__setattr__(key, value)
|
||||
def lru_property(func):
    # Memoizing property decorator.
    # NOTE: `lru_cache` here is methodtools.lru_cache (imported at the top of
    # the file), which — unlike functools.lru_cache — supports wrapping a
    # property descriptor and caches per instance.
    return lru_cache(maxsize=1)(property(func))
|
||||
|
||||
class SingleViewScan(H5Dataclass, TransformableDataclassMixin, InvalidateLRUOnWriteMixin, require_all=True):
    """A point-cloud scan captured from a single camera position.

    Hit points lie on the scanned surface; miss points are rays that reached
    the far plane.  Derived values are cached with lru_property and
    invalidated on any attribute write (InvalidateLRUOnWriteMixin).
    """
    points_hit     : H5ArrayNoSlice            # (N, 3)
    normals_hit    : Optional[H5ArrayNoSlice]  # (N, 3)
    points_miss    : H5ArrayNoSlice            # (M, 3)
    distances_miss : Optional[H5ArrayNoSlice]  # (M)
    colors_hit     : Optional[H5ArrayNoSlice]  # (N, 3)
    colors_miss    : Optional[H5ArrayNoSlice]  # (M, 3)
    uv_hits        : Optional[H5ArrayNoSlice]  # (H, W) dtype=bool
    uv_miss        : Optional[H5ArrayNoSlice]  # (H, W) dtype=bool (the reason we store both is due to missing data depth sensor data or filtered backfaces)
    cam_pos        : H5ArrayNoSlice            # (3)
    cam_mat4       : Optional[H5ArrayNoSlice]  # (4, 4)
    proj_mat4      : Optional[H5ArrayNoSlice]  # (4, 4)
    transforms     : dict[str, H5ArrayNoSlice] # a map of 4x4 transformation matrices

    def transform(self: _T, mat4: np.ndarray, inplace=False) -> _T:
        """Apply a (uniformly scaled) 4x4 transform to all spatial data."""
        # heuristic per-axis scale, https://math.stackexchange.com/a/1463487
        scale_xyz = mat4[:3, :3].sum(axis=0)
        # BUG FIX: use abs() so that shrunken axes also fail the uniformity check
        assert all(abs(scale_xyz - scale_xyz[0]) < 1e-8), f"differenty scaled axes: {scale_xyz}"

        out = self if inplace else self.copy(deep=False)
        out.points_hit  = T.transform_points(self.points_hit, mat4)
        # NOTE(review): transform_points applies the translation component to
        # the normals as well, which is only correct for rotation-only
        # matrices — confirm with callers before changing.
        out.normals_hit = T.transform_points(self.normals_hit, mat4) if self.normals_hit is not None else None
        out.points_miss = T.transform_points(self.points_miss, mat4)
        # BUG FIX: scale by the uniform scalar, not the length-3 vector (which
        # broadcasts wrongly against (M,)), and guard against None like the
        # other optional fields
        out.distances_miss = self.distances_miss * scale_xyz[0] if self.distances_miss is not None else None
        out.cam_pos   = T.transform_points(self.points_cam, mat4)[-1]
        out.cam_mat4  = (mat4 @ self.cam_mat4)  if self.cam_mat4  is not None else None
        out.proj_mat4 = (mat4 @ self.proj_mat4) if self.proj_mat4 is not None else None
        return out

    def compute_miss_distances(self: _T, *, copy: bool = False, deep: bool = False) -> _T:
        """Fill distances_miss with each miss ray's distance to the hit cloud.

        Mutates and returns self unless copy=True.
        """
        assert not self.has_miss_distances
        if not self.is_hitting:
            raise ValueError("No hits to compute the ray distance towards")

        out = self.copy(deep=deep) if copy else self
        out.distances_miss \
            = distance_from_rays_to_point_cloud(
                ray_origins = out.points_cam,
                ray_dirs    = out.ray_dirs_miss,
                points      = out.points_hit,
            ).astype(out.points_cam.dtype)

        return out

    @lru_property
    def points(self) -> np.ndarray: # (N+M+1, 3)
        # all hit points, miss points, and the camera position stacked
        return np.concatenate((
            self.points_hit,
            self.points_miss,
            self.points_cam,
        ))

    @lru_property
    def uv_points(self) -> np.ndarray: # (H, W, 3)
        # points scattered back onto the sensor grid, NaN where no data
        if not self.has_uv: raise ValueError
        out = np.full((*self.uv_hits.shape, 3), np.nan, dtype=self.points_hit.dtype)
        out[self.uv_hits, :] = self.points_hit
        out[self.uv_miss, :] = self.points_miss
        return out

    @lru_property
    def uv_normals(self) -> np.ndarray: # (H, W, 3)
        # hit normals scattered back onto the sensor grid, NaN where no hit
        if not self.has_uv: raise ValueError
        out = np.full((*self.uv_hits.shape, 3), np.nan, dtype=self.normals_hit.dtype)
        out[self.uv_hits, :] = self.normals_hit
        return out

    @lru_property
    def points_cam(self) -> Optional[np.ndarray]: # (1, 3)
        if self.cam_pos is None: return None
        return self.cam_pos[None, :]

    @lru_property
    def points_hit_centroid(self) -> np.ndarray:
        return self.points_hit.mean(axis=0)

    @lru_property
    def points_hit_std(self) -> np.ndarray:
        return self.points_hit.std(axis=0)

    @lru_property
    def is_hitting(self) -> bool:
        return len(self.points_hit) > 0

    @lru_property
    def is_empty(self) -> bool:
        return not (len(self.points_hit) or len(self.points_miss))

    @lru_property
    def has_colors(self) -> bool:
        return self.colors_hit is not None or self.colors_miss is not None

    @lru_property
    def has_normals(self) -> bool:
        return self.normals_hit is not None

    @lru_property
    def has_uv(self) -> bool:
        return self.uv_hits is not None

    @lru_property
    def has_miss_distances(self) -> bool:
        return self.distances_miss is not None

    @lru_property
    def xyzrgb_hit(self) -> np.ndarray: # (N, 6)
        if self.colors_hit is None: raise ValueError
        return np.concatenate([self.points_hit, self.colors_hit], axis=1)

    @lru_property
    def xyzrgb_miss(self) -> np.ndarray: # (M, 6)
        if self.colors_miss is None: raise ValueError
        return np.concatenate([self.points_miss, self.colors_miss], axis=1)

    @lru_property
    def ray_dirs_hit(self) -> np.ndarray: # (N, 3)
        # unit direction from the camera to each hit point
        out = self.points_hit - self.points_cam
        out /= np.linalg.norm(out, axis=-1)[:, None] # normalize
        return out

    @lru_property
    def ray_dirs_miss(self) -> np.ndarray: # (M, 3)
        # unit direction from the camera to each miss point
        out = self.points_miss - self.points_cam
        out /= np.linalg.norm(out, axis=-1)[:, None] # normalize
        return out

    @classmethod
    def from_mesh_single_view(cls, mesh: Trimesh, *, compute_miss_distances: bool = False, **kw) -> "SingleViewScan":
        """Virtually scan `mesh` from a single viewpoint (random unless
        theta/phi are given in kw)."""
        if "phi" not in kw and "theta" not in kw:
            kw["theta"], kw["phi"] = points.generate_random_sphere_points(1, compute_sphere_coordinates=True)[0]
        scan = sample_single_view_scan_from_mesh(mesh, **kw)
        if compute_miss_distances and scan.is_hitting:
            scan.compute_miss_distances()
        return scan

    def to_uv_scan(self) -> "SingleViewUVScan":
        return SingleViewUVScan.from_scan(self)

    @classmethod
    def from_uv_scan(self, uvscan: "SingleViewUVScan") -> "SingleViewUVScan":
        return uvscan.to_scan()
|
||||
|
||||
# The same, but with support for pagination (should have been this way since the start...)
|
||||
class SingleViewUVScan(H5Dataclass, TransformableDataclassMixin, InvalidateLRUOnWriteMixin, require_all=True):
|
||||
# B may be (N) or (H, W), the latter may be flattened
|
||||
hits : H5Array # (*B) dtype=bool
|
||||
miss : H5Array # (*B) dtype=bool (the reason we store both is due to missing data depth sensor data or filtered backface hits)
|
||||
points : H5Array # (*B, 3) on far plane if miss, NaN if neither hit or miss
|
||||
normals : Optional[H5Array] # (*B, 3) NaN if not hit
|
||||
colors : Optional[H5Array] # (*B, 3)
|
||||
distances : Optional[H5Array] # (*B) NaN if not miss
|
||||
cam_pos : Optional[H5ArrayNoSlice] # (3) or (*B, 3)
|
||||
cam_mat4 : Optional[H5ArrayNoSlice] # (4, 4)
|
||||
proj_mat4 : Optional[H5ArrayNoSlice] # (4, 4)
|
||||
transforms : dict[str, H5ArrayNoSlice] # a map of 4x4 transformation matrices
|
||||
|
||||
@classmethod
|
||||
def from_scan(cls, scan: SingleViewScan):
|
||||
if not scan.has_uv:
|
||||
raise ValueError("Scan cloud has no UV data")
|
||||
hits, miss = scan.uv_hits, scan.uv_miss
|
||||
dtype = scan.points_hit.dtype
|
||||
assert hits.ndim in (1, 2), hits.ndim
|
||||
assert hits.shape == miss.shape, (hits.shape, miss.shape)
|
||||
|
||||
points = np.full((*hits.shape, 3), np.nan, dtype=dtype)
|
||||
points[hits, :] = scan.points_hit
|
||||
points[miss, :] = scan.points_miss
|
||||
|
||||
normals = None
|
||||
if scan.has_normals:
|
||||
normals = np.full((*hits.shape, 3), np.nan, dtype=dtype)
|
||||
normals[hits, :] = scan.normals_hit
|
||||
|
||||
distances = None
|
||||
if scan.has_miss_distances:
|
||||
distances = np.full(hits.shape, np.nan, dtype=dtype)
|
||||
distances[miss] = scan.distances_miss
|
||||
|
||||
colors = None
|
||||
if scan.has_colors:
|
||||
colors = np.full((*hits.shape, 3), np.nan, dtype=dtype)
|
||||
if scan.colors_hit is not None:
|
||||
colors[hits, :] = scan.colors_hit
|
||||
if scan.colors_miss is not None:
|
||||
colors[miss, :] = scan.colors_miss
|
||||
|
||||
return cls(
|
||||
hits = hits,
|
||||
miss = miss,
|
||||
points = points,
|
||||
normals = normals,
|
||||
colors = colors,
|
||||
distances = distances,
|
||||
cam_pos = scan.cam_pos,
|
||||
cam_mat4 = scan.cam_mat4,
|
||||
proj_mat4 = scan.proj_mat4,
|
||||
transforms = scan.transforms,
|
||||
)
|
||||
|
||||
def to_scan(self) -> "SingleViewScan":
    """Convert back to a sparse SingleViewScan by gathering the dense grids
    through the hit/miss masks. Requires a single shared camera position."""
    if not self.is_single_view: raise ValueError
    return SingleViewScan(
        points_hit     = self.points   [self.hits, :],
        points_miss    = self.points   [self.miss, :],
        normals_hit    = self.normals  [self.hits, :] if self.has_normals        else None,
        distances_miss = self.distances[self.miss]    if self.has_miss_distances else None,
        colors_hit     = self.colors   [self.hits, :] if self.has_colors         else None,
        colors_miss    = self.colors   [self.miss, :] if self.has_colors         else None,
        uv_hits        = self.hits,
        uv_miss        = self.miss,
        cam_pos        = self.cam_pos,
        cam_mat4       = self.cam_mat4,
        proj_mat4      = self.proj_mat4,
        transforms     = self.transforms,
    )
|
||||
|
||||
def to_mesh(self) -> trimesh.Trimesh:
    """Triangulate the 2d hit grid into a mesh: each 2x2 cell with 4 hits
    becomes two triangles, a cell with exactly 3 hits becomes one triangle."""
    faces: list[(tuple[int, int],)*3] = []
    for x in range(self.hits.shape[0]-1):
        for y in range(self.hits.shape[1]-1):
            # the four corners of the current grid cell, in winding order
            c11 = x, y
            c12 = x, y+1
            c22 = x+1, y+1
            c21 = x+1, y

            # number of corners that registered a hit
            n = sum(map(self.hits.__getitem__, (c11, c12, c22, c21)))
            if n == 3:
                # one corner missing: a single triangle from the three hit corners
                faces.append((*filter(self.hits.__getitem__, (c11, c12, c22, c21)),))
            elif n == 4:
                # full quad: split along the c11-c22 diagonal
                faces.append((c11, c12, c22))
                faces.append((c11, c22, c21))
    # map grid coordinates of the used corners to contiguous vertex indices
    xy2idx = {c:i for i, c in enumerate(set(k for j in faces for k in j))}
    # NOTE(review): this assert makes the `if self.colors is not None` below dead — confirm intent
    assert self.colors is not None
    return trimesh.Trimesh(
        vertices      = [self.points[i] for i in xy2idx.keys()],
        vertex_colors = [self.colors[i] for i in xy2idx.keys()] if self.colors is not None else None,
        faces         = [tuple(xy2idx[i] for i in face) for face in faces],
    )
|
||||
|
||||
def transform(self: _T, mat4: np.ndarray, inplace=False) -> _T:
    """Apply a homogeneous 4x4 transform to this scan (points, normals,
    distances, and camera matrices). Only uniform scaling is supported.

    mat4    : (4, 4) homogeneous transformation matrix
    inplace : mutate self instead of a shallow copy
    returns : the transformed scan (self when inplace)
    """
    # per-axis scale factors are the column norms of the linear part
    # (https://math.stackexchange.com/a/1463487); the previous `sum(axis=0)`
    # is only correct for axis-aligned matrices
    scale_xyz = np.linalg.norm(mat4[:3, :3], axis=0)
    assert np.all(np.abs(scale_xyz - scale_xyz[0]) < 1e-8), f"differently scaled axes: {scale_xyz}"
    scale = scale_xyz[0]

    unflat = self.hits.shape

    out = self if inplace else self.copy(deep=False)
    # flatten batch dims to one axis, transform, then restore the batch shape
    # (the previous revision unpacked a scalar, `(*np.product(...), 3)`, which raises)
    out.points = T.transform_points(self.points.reshape((-1, 3)), mat4).reshape((*unflat, 3))
    if self.normals is not None:
        # normals are directions: apply only the linear part (no translation),
        # then divide out the uniform scale to keep them unit length
        normals = (self.normals.reshape((-1, 3)) @ mat4[:3, :3].T) / scale
        out.normals = normals.reshape((*unflat, 3))
    else:
        out.normals = None
    # distances scale with the (uniform) spatial scale
    out.distances = (self.distances * scale) if self.distances is not None else None
    out.cam_pos   = T.transform_points(self.cam_pos[None, ...], mat4)[0] if self.cam_pos is not None else None
    out.cam_mat4  = (mat4 @ self.cam_mat4)  if self.cam_mat4  is not None else None
    out.proj_mat4 = (mat4 @ self.proj_mat4) if self.proj_mat4 is not None else None
    return out
|
||||
|
||||
def compute_miss_distances(self: _T, *, copy: bool = False, deep: bool = False, surface_points: Optional[np.ndarray] = None) -> _T:
    """Populate `distances` for miss rays: the distance from each miss ray to
    the scanned surface.

    copy           : operate on (and return) a copy instead of mutating self
    deep           : make that copy deep
    surface_points : optional replacement surface cloud to measure against,
                     defaults to this scan's own hit points
    """
    assert not self.has_miss_distances  # must not already be populated

    shape = self.hits.shape

    out = self.copy(deep=deep) if copy else self
    out.distances = np.zeros(shape, dtype=self.points.dtype)
    if self.is_hitting:  # with no hits there is no surface to measure against
        out.distances[self.miss] \
            = distance_from_rays_to_point_cloud(
                ray_origins = self.cam_pos_unsqueezed_miss,
                ray_dirs    = self.ray_dirs_miss,
                points      = surface_points if surface_points is not None else self.points[self.hits],
            )

    return out
|
||||
|
||||
def fill_missing_points(self: _T, *, copy: bool = False, deep: bool = False) -> _T:
    """
    Fill in missing points as hitting the far plane.

    Requires a 2d single-view scan with both camera and projection matrices,
    since the missing pixels are unprojected from clip space.
    """
    if not self.is_2d:
        raise ValueError("Cannot fill missing points for non-2d scan!")
    if not self.is_single_view:
        raise ValueError("Cannot fill missing points for non-single-view scans!")
    if self.cam_mat4 is None:
        raise ValueError("cam_mat4 is None")
    if self.proj_mat4 is None:
        raise ValueError("proj_mat4 is None")

    # pixel indices of the missing rays, rescaled to [-1, 1]
    # NOTE(review): row index is divided by width and column by height — confirm axis order
    uv = np.argwhere(self.missing).astype(self.points.dtype)
    uv[:, 0] /= (self.missing.shape[1] - 1) / 2
    uv[:, 1] /= (self.missing.shape[0] - 1) / 2
    uv -= 1
    # assemble homogeneous clip-space coordinates on the far plane
    uv = np.stack((
        uv[:, 1],
        -uv[:, 0],
        np.ones(uv.shape[0]), # far clipping plane
        np.ones(uv.shape[0]), # homogeneous coordinate
    ), axis=-1)
    # unproject: clip space -> camera space -> world space
    uv = uv @ (self.cam_mat4 @ np.linalg.inv(self.proj_mat4)).T

    out = self.copy(deep=deep) if copy else self
    # perspective divide, then write the far-plane points into the grid
    out.points[self.missing, :] = uv[:, :3] / uv[:, 3][:, None]
    return out
|
||||
|
||||
@lru_property
def is_hitting(self) -> bool:
    """Whether at least one ray registered a hit."""
    hit_mask = self.hits
    return np.any(hit_mask)
|
||||
|
||||
@lru_property
def has_colors(self) -> bool:
    """Whether this scan carries per-ray color data."""
    # PEP 8 idiom: `is not None` instead of `not ... is None`
    return self.colors is not None
|
||||
|
||||
@lru_property
def has_normals(self) -> bool:
    """Whether this scan carries surface normals for its hits."""
    # PEP 8 idiom: `is not None` instead of `not ... is None`
    return self.normals is not None
|
||||
|
||||
@lru_property
def has_miss_distances(self) -> bool:
    """Whether miss-ray distances have been computed/stored."""
    # PEP 8 idiom: `is not None` instead of `not ... is None`
    return self.distances is not None
|
||||
|
||||
@lru_property
def any_missing(self) -> bool:
    """Whether any ray is neither a hit nor a miss."""
    missing_mask = self.missing
    return np.any(missing_mask)
|
||||
|
||||
@lru_property
def has_missing(self) -> bool:
    """Whether missing rays exist AND their points have been filled in
    (i.e. contain no NaN)."""
    if not self.any_missing:
        return False
    filled = self.points[self.missing]
    return not np.any(np.isnan(filled))
|
||||
|
||||
@lru_property
def cam_pos_unsqueezed(self) -> H5Array:
    """Camera position reshaped to broadcast against the (*B, 3) point grid."""
    position = self.cam_pos
    if position.ndim == 1:
        # single shared camera: prepend one singleton axis per batch dimension
        for _ in range(self.hits.ndim):
            position = position[None, ...]
    return position
|
||||
|
||||
@lru_property
def cam_pos_unsqueezed_hit(self) -> H5Array:
    """Camera position per hit ray: (n_hits, 3), or broadcastable (1, 3)
    when a single camera is shared by all rays."""
    if self.cam_pos.ndim == 1:
        return self.cam_pos[None, :]
    return self.cam_pos[self.hits, :]
|
||||
|
||||
@lru_property
def cam_pos_unsqueezed_miss(self) -> H5Array:
    """Camera position per miss ray: (n_miss, 3), or broadcastable (1, 3)
    when a single camera is shared by all rays."""
    if self.cam_pos.ndim == 1:
        return self.cam_pos[None, :]
    return self.cam_pos[self.miss, :]
|
||||
|
||||
@lru_property
def ray_dirs(self) -> H5Array:
    """Unit direction from the camera towards every point in the grid."""
    offsets = self.points - self.cam_pos_unsqueezed
    return offsets * (1 / self.depths[..., None])
|
||||
|
||||
@lru_property
def ray_dirs_hit(self) -> H5Array:
    """Unit directions from the camera to each hit point."""
    directions = self.points[self.hits, :] - self.cam_pos_unsqueezed_hit
    lengths = np.linalg.norm(directions, axis=-1)
    directions /= lengths[..., None]  # normalize in place
    return directions
|
||||
|
||||
@lru_property
def ray_dirs_miss(self) -> H5Array:
    """Unit directions from the camera to each miss (far-plane) point."""
    directions = self.points[self.miss, :] - self.cam_pos_unsqueezed_miss
    lengths = np.linalg.norm(directions, axis=-1)
    directions /= lengths[..., None]  # normalize in place
    return directions
|
||||
|
||||
@lru_property
def depths(self) -> H5Array:
    """Euclidean camera-to-point distance for every ray."""
    offsets = self.points - self.cam_pos_unsqueezed
    return np.linalg.norm(offsets, axis=-1)
|
||||
|
||||
@lru_property
def missing(self) -> H5Array:
    """Mask of rays that are neither hit nor miss (e.g. filtered backface
    hits or dropped depth-sensor readings)."""
    covered = self.hits | self.miss
    return ~covered
|
||||
|
||||
@classmethod
def from_mesh_single_view(cls, mesh: Trimesh, *, compute_miss_distances: bool = False, **kw) -> "SingleViewUVScan":
    """Synthesize a 2d single-view UV scan of `mesh`.

    If neither `phi` nor `theta` is given in `kw`, a viewing direction is
    drawn uniformly at random from the sphere. Remaining `kw` is forwarded
    to `sample_single_view_scan_from_mesh`.
    """
    # PEP 8 idiom: `"theta" not in kw` instead of `not "theta" in kw`
    if "phi" not in kw and "theta" not in kw:
        kw["theta"], kw["phi"] = points.generate_random_sphere_points(1, compute_sphere_coordinates=True)[0]
    scan = sample_single_view_scan_from_mesh(mesh, **kw).to_uv_scan()
    if compute_miss_distances:
        scan.compute_miss_distances()  # mutates scan in place (copy defaults to False)
    assert scan.is_2d
    return scan
|
||||
|
||||
@classmethod
def from_mesh_sphere_view(cls, mesh: Trimesh, *, compute_miss_distances: bool = False, **kw) -> "SingleViewUVScan":
    """Synthesize a flat sphere-view scan by ray-casting between pairs of
    unit-sphere points. `kw` is forwarded to `sample_sphere_view_scan_from_mesh`."""
    scan = sample_sphere_view_scan_from_mesh(mesh, **kw)
    if compute_miss_distances:
        surface_points = None
        # if the hit cloud is larger than the vertex set, the (smaller)
        # vertex set is a cheaper stand-in surface for the distance queries
        if scan.hits.sum() > mesh.vertices.shape[0]:
            surface_points = mesh.vertices.astype(scan.points.dtype)
            if not kw.get("no_unit_sphere", False):
                # scan points live in unit-sphere space; move the vertices there too
                translation, scale = compute_unit_sphere_transform(mesh, dtype=scan.points.dtype)
                surface_points = (surface_points + translation) * scale
        scan.compute_miss_distances(surface_points=surface_points)  # in place (copy defaults to False)
    assert scan.is_flat
    return scan
|
||||
|
||||
def flatten_and_permute_(self: _T, copy=False) -> _T: # inplace by default
    """Flatten all batch dimensions to one axis and apply a random
    permutation to every per-ray array. In place unless `copy`."""
    # np.product was removed in numpy 2.0; np.prod is the supported spelling
    n_items = int(np.prod(self.hits.shape))
    permutation = np.random.permutation(n_items)

    out = self.copy(deep=False) if copy else self
    out.hits      = out.hits     .reshape((n_items, ))[permutation]
    out.miss      = out.miss     .reshape((n_items, ))[permutation]
    out.points    = out.points   .reshape((n_items, 3))[permutation, :]
    out.normals   = out.normals  .reshape((n_items, 3))[permutation, :] if out.has_normals        else None
    out.colors    = out.colors   .reshape((n_items, 3))[permutation, :] if out.has_colors         else None
    out.distances = out.distances.reshape((n_items, ))[permutation]     if out.has_miss_distances else None
    return out
|
||||
|
||||
@property
def is_single_view(self) -> bool:
    """True when all rays share a single camera position (or none is stored)."""
    if self.cam_pos is None:  # PEP 8 idiom: `is None` instead of `not ... is None`
        return True
    # np.product was removed in numpy 2.0; np.prod is the supported spelling
    return np.prod(self.cam_pos.shape[:-1]) == 1
|
||||
|
||||
@property
def is_flat(self) -> bool:
    """True when the batch layout is one-dimensional."""
    batch_rank = len(self.hits.shape)
    return batch_rank == 1
|
||||
|
||||
@property
def is_2d(self) -> bool:
    """True when the batch layout is a two-dimensional (UV) grid."""
    batch_rank = len(self.hits.shape)
    return batch_rank == 2
|
||||
|
||||
|
||||
# transforms can be found in pytorch3d.transforms and in open3d
|
||||
# and in trimesh.transformations
|
||||
|
||||
def sample_single_view_scans_from_mesh(
        mesh               : Trimesh,
        *,
        n_batches          : int,
        scan_resolution    : int = 400,
        compute_normals    : bool = False,
        fov                : float = 1.0472, # 60 degrees in radians, vertical field of view.
        camera_distance    : float = 2,
        no_filter_backhits : bool = False,
        ) -> Iterable[SingleViewScan]:
    """Yield `n_batches` synthetic single-view scans of `mesh`, each from a
    uniformly random viewing direction.

    The unit-sphere normalization of the mesh is computed once and reused
    across iterations via `_mesh_cache`.
    """
    normalized_mesh_cache = []

    for _ in range(n_batches):
        theta, phi = points.generate_random_sphere_points(1, compute_sphere_coordinates=True)[0]

        # BUG FIX: the previous revision also passed `_mesh_is_normalized=False`,
        # a keyword `sample_single_view_scan_from_mesh` does not accept (TypeError)
        yield sample_single_view_scan_from_mesh(
            mesh               = mesh,
            phi                = phi,
            theta              = theta,
            scan_resolution    = scan_resolution,
            compute_normals    = compute_normals,
            fov                = fov,
            camera_distance    = camera_distance,
            no_filter_backhits = no_filter_backhits,
            _mesh_cache        = normalized_mesh_cache,
        )
|
||||
|
||||
def sample_single_view_scan_from_mesh(
        mesh               : Trimesh,
        *,
        phi                : float,
        theta              : float,
        scan_resolution    : int = 200,
        compute_normals    : bool = False,
        fov                : float = 1.0472, # 60 degrees in radians, vertical field of view.
        camera_distance    : float = 2,
        no_filter_backhits : bool = False,
        no_unit_sphere     : bool = False,
        dtype              : type = np.float32,
        _mesh_cache        : Optional[list] = None, # provide a list if mesh is reused
        ) -> SingleViewScan:
    """Render a synthetic depth scan of `mesh` from spherical direction
    (phi, theta) and return it as a SingleViewScan (hit points + far-plane
    miss points).

    The mesh is normalized to the unit sphere first; pass `no_unit_sphere`
    to map the results back to model space. `_mesh_cache` memoizes the
    normalization of the most recent mesh across repeated calls.
    """
    # scale and center to unit sphere
    is_cache = isinstance(_mesh_cache, list)
    if is_cache and _mesh_cache and _mesh_cache[0] is mesh:
        # cache hit: reuse the normalized mesh and its transform
        _, mesh, translation, scale = _mesh_cache
    else:
        if is_cache:
            if _mesh_cache:
                _mesh_cache.clear()  # only the most recent mesh is cached
            _mesh_cache.append(mesh)
        translation, scale = compute_unit_sphere_transform(mesh)
        mesh = mesh_to_sdf.scale_to_unit_sphere(mesh)
        if is_cache:
            _mesh_cache.extend((mesh, translation, scale))

    z_near = 1
    z_far = 3
    cam_mat4 = sdf_scan.get_camera_transform_looking_at_origin(phi, theta, camera_distance=camera_distance)
    cam_pos = cam_mat4 @ np.array([0, 0, 0, 1])  # camera origin in world space, homogeneous

    scan = sdf_scan.Scan(mesh,
        camera_transform = cam_mat4,
        resolution = scan_resolution,
        calculate_normals = compute_normals,
        fov = fov,
        z_near = z_near,
        z_far = z_far,
        no_flip_backfaced_normals = True
    )

    # all the scan rays that hit the far plane, based on sdf_scan.Scan.__init__
    misses = np.argwhere(scan.depth_buffer == 0)
    points_miss = np.ones((misses.shape[0], 4))
    # pixel indices -> [-1, 1] clip-space xy (note the swapped axes)
    points_miss[:, [1, 0]] = misses.astype(float) / (scan_resolution -1) * 2 - 1
    points_miss[:, 1] *= -1
    points_miss[:, 2] = 1 # far plane in clipping space
    # unproject: clip space -> camera space -> world space
    points_miss = points_miss @ (cam_mat4 @ np.linalg.inv(scan.projection_matrix)).T
    points_miss /= points_miss[:, 3][:, np.newaxis]  # perspective divide
    points_miss = points_miss[:, :3]

    uv_hits = scan.depth_buffer != 0
    uv_miss = ~uv_hits

    if not no_filter_backhits:
        if not compute_normals:
            raise ValueError("not `no_filter_backhits` requires `compute_normals`")
        # inner product: keep only hits whose normal faces the camera
        mask = np.einsum('ij,ij->i', scan.points - cam_pos[:3][None, :], scan.normals) < 0
        scan.points = scan.points [mask, :]
        scan.normals = scan.normals[mask, :]
        uv_hits[uv_hits] = mask  # filtered backface hits become neither hit nor miss

    transforms = {}

    # undo unit-sphere transform
    if no_unit_sphere:
        scan.points = scan.points * (1 / scale) - translation
        points_miss = points_miss * (1 / scale) - translation
        cam_pos[:3] = cam_pos[:3] * (1 / scale) - translation
        cam_mat4[:3, :] *= 1 / scale
        cam_mat4[:3, 3] -= translation

        transforms["unit_sphere"] = T.scale_and_translate(scale=scale, translate=translation)
        transforms["model"] = np.eye(4)
    else:
        transforms["model"] = np.linalg.inv(T.scale_and_translate(scale=scale, translate=translation))
        transforms["unit_sphere"] = np.eye(4)

    # NOTE(review): `scan.normals` is presumably None when compute_normals=False
    # — confirm `.astype` below does not fail in that configuration
    return SingleViewScan(
        normals_hit    = scan.normals              .astype(dtype),
        points_hit     = scan.points               .astype(dtype),
        points_miss    = points_miss               .astype(dtype),
        distances_miss = None,
        colors_hit     = None,
        colors_miss    = None,
        uv_hits        = uv_hits                   .astype(bool),
        uv_miss        = uv_miss                   .astype(bool),
        cam_pos        = cam_pos[:3]               .astype(dtype),
        cam_mat4       = cam_mat4                  .astype(dtype),
        proj_mat4      = scan.projection_matrix    .astype(dtype),
        transforms     = {k:v.astype(dtype) for k, v in transforms.items()},
    )
|
||||
|
||||
def sample_sphere_view_scan_from_mesh(
        mesh               : Trimesh,
        *,
        sphere_points      : int = 4000, # resulting rays are n*(n-1)
        compute_normals    : bool = False,
        no_filter_backhits : bool = False,
        no_unit_sphere     : bool = False,
        no_permute         : bool = False,
        dtype              : type = np.float32,
        **kw,
        ) -> SingleViewUVScan:
    """Ray-cast `mesh` along all ordered pairs of ~`sphere_points`
    equidistant unit-sphere points, returning a flat SingleViewUVScan.

    Misses keep the far sphere point as their "point"; `normals` is None
    when `compute_normals` is False.
    """
    translation, scale = compute_unit_sphere_transform(mesh, dtype=dtype)

    # get unit-sphere points, then transform to model space
    two_sphere = generate_equidistant_sphere_rays(sphere_points, **kw).astype(dtype) # (n*(n-1), 2, 3)
    two_sphere = two_sphere / scale - translation # we transform after cache lookup

    if mesh.ray.__class__.__module__.split(".")[-1] != "ray_pyembree":
        warnings.warn("Pyembree not found, the ray-tracing will be SLOW!")

    (
        locations,
        index_ray,
        index_tri,
    ) = mesh.ray.intersects_location(
        two_sphere[:, 0, :],
        two_sphere[:, 1, :] - two_sphere[:, 0, :], # direction, not target coordinate
        multiple_hits=False,
    )

    if compute_normals:
        location_normals = mesh.face_normals[index_tri]

    batch = two_sphere.shape[:1]
    # BUG FIX: `np.bool` was removed in numpy 1.24 — use the builtin `bool`
    hits = np.zeros((*batch,), dtype=bool)
    miss = np.ones((*batch,), dtype=bool)
    cam_pos = two_sphere[:, 0, :]
    intersections = two_sphere[:, 1, :] # far-plane, effectively
    # BUG FIX: previously `normals` was always allocated but only filled when
    # compute_normals, and the fill below raised NameError otherwise
    normals = np.zeros((*batch, 3), dtype=dtype) if compute_normals else None

    index_ray_front = index_ray
    if not no_filter_backhits:
        if not compute_normals:
            raise ValueError("not `no_filter_backhits` requires `compute_normals`")
        # keep only front-facing hits: ray direction against the face normal
        mask = ((intersections[index_ray] - cam_pos[index_ray]) * location_normals).sum(axis=-1) <= 0
        index_ray_front = index_ray[mask]

    hits[index_ray_front] = True
    miss[index_ray] = False  # every intersected ray is not a miss, even filtered backfaces
    intersections[index_ray] = locations
    if compute_normals:
        normals[index_ray] = location_normals

    if not no_permute:
        assert len(batch) == 1, batch
        permutation = np.random.permutation(*batch)
        hits          = hits         [permutation]
        miss          = miss         [permutation]
        intersections = intersections[permutation, :]
        if normals is not None:
            normals   = normals      [permutation, :]
        cam_pos       = cam_pos      [permutation, :]

    # apply unit sphere transform
    if not no_unit_sphere:
        intersections = (intersections + translation) * scale
        cam_pos       = (cam_pos       + translation) * scale

    return SingleViewUVScan(
        hits       = hits,
        miss       = miss,
        points     = intersections,
        normals    = normals,
        colors     = None, # colors
        distances  = None,
        cam_pos    = cam_pos,
        cam_mat4   = None,
        proj_mat4  = None,
        transforms = {},
    )
|
||||
|
||||
def distance_from_rays_to_point_cloud(
        ray_origins     : np.ndarray, # (*A, 3)
        ray_dirs        : np.ndarray, # (*A, 3)
        points          : np.ndarray, # (*B, 3)
        dirs_normalized : bool = False,
        n_steps         : int = 40,
        ) -> np.ndarray: # (A)
    """Approximate, per ray, the distance from the ray line to its nearest
    cloud point, by sphere-tracing against a nearest-neighbour index
    (sklearn BallTree for small clouds, faiss for large ones)."""

    # anything outside of this volume will never constribute to the result
    max_norm = max(
        np.linalg.norm(ray_origins, axis=-1).max(),
        np.linalg.norm(points, axis=-1).max(),
    ) * 1.02

    if not dirs_normalized:
        ray_dirs = ray_dirs / np.linalg.norm(ray_dirs, axis=-1)[..., None]

    # deal with single-view clouds
    if ray_origins.shape != ray_dirs.shape:
        ray_origins = np.broadcast_to(ray_origins, ray_dirs.shape)

    # NOTE(review): np.product was removed in numpy 2.0; np.prod is the replacement
    n_points = np.product(points.shape[:-1])
    use_faiss = n_points > 160000*4
    if not use_faiss:
        index = BallTree(points)
    else:
        # http://ann-benchmarks.com/index.html
        assert np.issubdtype(points.dtype, np.float32)
        assert np.issubdtype(ray_origins.dtype, np.float32)
        assert np.issubdtype(ray_dirs.dtype, np.float32)
        index = faiss.index_factory(points.shape[-1], "NSG32,Flat") # https://github.com/facebookresearch/faiss/wiki/The-index-factory

        index.nprobe = 5 # 10 # default is 1
        index.train(points)
        index.add(points)

    # initial nearest-neighbour distance from each ray origin
    if not use_faiss:
        min_d, min_n = index.query(ray_origins, k=1, return_distance=True)
    else:
        min_d, min_n = index.search(ray_origins, k=1)
        min_d = np.sqrt(min_d)  # faiss returns squared L2 distances
    acc_d = min_d.copy()  # accumulated march distance along each ray

    # sphere-trace: advance each ray by a fraction of its current
    # nearest-neighbour distance, tracking the minimum distance seen
    for step in range(1, n_steps+1):
        query_points = ray_origins + acc_d * ray_dirs
        # NOTE(review): max_norm is always set above, so the `else` branches
        # below are dead code in the current revision
        if max_norm is not None:
            # drop rays that have marched out of the cloud's bounding sphere
            qmask = np.linalg.norm(query_points, axis=-1) < max_norm
            if not qmask.any(): break
            query_points = query_points[qmask]
        else:
            qmask = slice(None)
        if not use_faiss:
            current_d, current_n = index.query(query_points, k=1, return_distance=True)
        else:
            current_d, current_n = index.search(query_points, k=1)
            current_d = np.sqrt(current_d)
        if max_norm is not None:
            min_d[qmask] = np.minimum(current_d, min_d[qmask])
            # remember which neighbour produced each new minimum
            new_min_mask = min_d[qmask] == current_d
            qmask2 = qmask.copy()
            qmask2[qmask2] = new_min_mask[..., 0]
            min_n[qmask2] = current_n[new_min_mask[..., 0]]
            acc_d[qmask] += current_d * 0.25  # conservative step to avoid overshooting
        else:
            np.minimum(current_d, min_d, out=min_d)
            new_min_mask = min_d == current_d
            min_n[new_min_mask] = current_n[new_min_mask]
            acc_d += current_d * 0.25

    # exact point-to-line distance against the best candidate neighbour
    closest_points = points[min_n[:, 0], :] # k=1
    distances = np.linalg.norm(np.cross(closest_points - ray_origins, ray_dirs, axis=-1), axis=-1)
    return distances
|
||||
|
||||
# helpers
|
||||
|
||||
@compose(np.array) # make copy to avoid lru cache mutation
@lru_cache(maxsize=1)
def generate_equidistant_sphere_rays(n : int, **kw) -> np.ndarray: # output (n*(n-1)) rays, n may be off
    """Return all ordered pairs of ~n equidistant unit-sphere points as rays,
    shape (N*(N-1), 2, 3), where N is the actual point count generated."""
    sphere_points = points.generate_equidistant_sphere_points(n=n, **kw)

    indices = np.arange(len(sphere_points)) # (N,) — simpler than np.indices((N,))[0]
    # cartesian product of all ordered index pairs
    cprod = np.transpose([np.tile(indices, len(indices)), np.repeat(indices, len(indices))]) # (N**2, 2)
    # drop degenerate pairs (a ray from a point to itself)
    pairs = cprod[cprod[:, 0] != cprod[:, 1], :] # (N*(N-1), 2)
    # lookup sphere points
    return sphere_points[pairs, :] # (N*(N-1), 2, 3)
|
||||
|
||||
def compute_unit_sphere_transform(mesh: Trimesh, *, dtype=None) -> tuple[np.ndarray, float]:
    """
    returns translation and scale which mesh_to_sdf applies to meshes before computing their SDF cloud

    dtype: optional numpy dtype to cast the results to; None keeps float64.
           (BUG FIX: the default was the builtin `type`, which made calls
           without an explicit dtype crash in `.astype(type)`.)
    """
    # the transformation applied by mesh_to_sdf.scale_to_unit_sphere(mesh)
    translation = -mesh.bounding_box.centroid
    scale = 1 / np.max(np.linalg.norm(mesh.vertices + translation, axis=1))
    if dtype is not None:
        translation = translation.astype(dtype)
        scale       = scale      .astype(dtype)
    return translation, scale
|
||||
6
ifield/data/common/types.py
Normal file
6
ifield/data/common/types.py
Normal file
@@ -0,0 +1,6 @@
|
||||
__doc__ = """
|
||||
Some helper types.
|
||||
"""
|
||||
|
||||
class MalformedMesh(Exception):
    """Raised when a mesh is malformed and cannot be processed."""
    pass
|
||||
28
ifield/data/config.py
Normal file
28
ifield/data/config.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from ..utils.helpers import make_relative
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import os
|
||||
import warnings
|
||||
|
||||
|
||||
def data_path_get(dataset_name: str, no_warn: bool = False) -> Path:
    """Resolve the on-disk directory for `dataset_name`.

    Resolution order: the dataset-specific envvar
    (IFIELD_DATA_MODELS_<NAME>), the shared IFIELD_DATA_MODELS envvar, then
    the in-repo ./data/models/<name> fallback. Warns when the resolved path
    is not a directory, unless `no_warn`.
    """
    dataset_envvar = f"IFIELD_DATA_MODELS_{dataset_name.replace(*'-_').upper()}"
    if dataset_envvar in os.environ:
        data_path = Path(os.environ[dataset_envvar])
    elif "IFIELD_DATA_MODELS" in os.environ:
        data_path = Path(os.environ["IFIELD_DATA_MODELS"]) / dataset_name
    else:
        repo_root = Path(__file__).resolve().parent.parent.parent
        data_path = repo_root / "data" / "models" / dataset_name
    missing = not data_path.is_dir()
    if missing and not no_warn:
        warnings.warn(f"{make_relative(data_path, Path.cwd()).__str__()!r} is not a directory!")
    return data_path
|
||||
|
||||
def data_path_persist(dataset_name: Optional[str], path: os.PathLike) -> os.PathLike:
    "Persist the datapath, ensuring subprocesses also will use it. The path passes through."

    if dataset_name is None:
        envvar = "IFIELD_DATA_MODELS"
    else:
        envvar = f"IFIELD_DATA_MODELS_{dataset_name.replace(*'-_').upper()}"
    os.environ[envvar] = str(path)

    return path
|
||||
56
ifield/data/coseg/__init__.py
Normal file
56
ifield/data/coseg/__init__.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from ..config import data_path_get, data_path_persist
|
||||
from collections import namedtuple
|
||||
import os
|
||||
|
||||
|
||||
# Data source:
|
||||
# http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/ssd.htm
|
||||
|
||||
# BUG FIX: was `__ALL__`, which Python ignores; `__all__` is the dunder that
# actually controls star-imports.
__all__ = ["config", "Model", "MODELS"]

# one downloadable archive: source url, local filename, human-readable size
Archive = namedtuple("Archive", "url fname download_size_str")
|
||||
|
||||
@(lambda x: x()) # singleton
class config:
    # Singleton holding the COSEG dataset location and its download tables.

    # Settable: assigning to config.DATA_PATH persists the path via envvar.
    DATA_PATH = property(
        doc = """
        Path to the dataset. The following envvars override it:
            ${IFIELD_DATA_MODELS}/coseg
            ${IFIELD_DATA_MODELS_COSEG}
        """,
        fget = lambda self: data_path_get ("coseg"),
        fset = lambda self, path: data_path_persist("coseg", path),
    )

    @property
    def IS_DOWNLOADED_DB(self) -> list[os.PathLike]:
        # Bookkeeping file(s) recording which archive urls are already fetched.
        return [
            self.DATA_PATH / "downloaded.json",
        ]

    # object-set name -> shapes archive (url, local filename, download size)
    SHAPES: dict[str, Archive] = {
        "candelabra"  : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Candelabra/shapes.zip",   "candelabra-shapes.zip",   "3,3M"),
        "chair"       : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Chair/shapes.zip",        "chair-shapes.zip",        "3,2M"),
        "four-legged" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Four-legged/shapes.zip",  "four-legged-shapes.zip",  "2,9M"),
        "goblets"     : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Goblets/shapes.zip",      "goblets-shapes.zip",      "500K"),
        "guitars"     : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Guitars/shapes.zip",      "guitars-shapes.zip",      "1,9M"),
        "lampes"      : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Lampes/shapes.zip",       "lampes-shapes.zip",       "2,4M"),
        "vases"       : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Vases/shapes.zip",        "vases-shapes.zip",        "5,5M"),
        "irons"       : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Irons/shapes.zip",        "irons-shapes.zip",        "1,2M"),
        "tele-aliens" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Tele-aliens/shapes.zip",  "tele-aliens-shapes.zip",  "15M"),
        "large-vases" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Large-Vases/shapes.zip",  "large-vases-shapes.zip",  "6,2M"),
        "large-chairs": Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Large-Chairs/shapes.zip", "large-chairs-shapes.zip", "14M"),
    }
    # object-set name -> ground-truth segmentation archive (same key set as SHAPES)
    GROUND_TRUTHS: dict[str, Archive] = {
        "candelabra"  : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Candelabra/gt.zip",   "candelabra-gt.zip",   "68K"),
        "chair"       : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Chair/gt.zip",        "chair-gt.zip",        "20K"),
        "four-legged" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Four-legged/gt.zip",  "four-legged-gt.zip",  "24K"),
        "goblets"     : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Goblets/gt.zip",      "goblets-gt.zip",      "4,0K"),
        "guitars"     : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Guitars/gt.zip",      "guitars-gt.zip",      "12K"),
        "lampes"      : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Lampes/gt.zip",       "lampes-gt.zip",       "60K"),
        "vases"       : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Vases/gt.zip",        "vases-gt.zip",        "40K"),
        "irons"       : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Irons/gt.zip",        "irons-gt.zip",        "8,0K"),
        "tele-aliens" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Tele-aliens/gt.zip",  "tele-aliens-gt.zip",  "72K"),
        "large-vases" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Large-Vases/gt.zip",  "large-vases-gt.zip",  "68K"),
        "large-chairs": Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Large-Chairs/gt.zip", "large-chairs-gt.zip", "116K"),
    }
|
||||
135
ifield/data/coseg/download.py
Normal file
135
ifield/data/coseg/download.py
Normal file
@@ -0,0 +1,135 @@
|
||||
#!/usr/bin/env python3
|
||||
from . import config
|
||||
from ...utils.helpers import make_relative
|
||||
from ..common import download
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
import argparse
|
||||
import io
|
||||
import zipfile
|
||||
|
||||
|
||||
|
||||
def is_downloaded(*a, **kw):
    """Check (or record, with add=True) whether a url has been downloaded,
    against this dataset's bookkeeping files."""
    return download.is_downloaded(*a, dbfiles=config.IS_DOWNLOADED_DB, **kw)
|
||||
|
||||
def download_and_extract(target_dir: Path, url_dict: dict[str, str], *, force=False, silent=False) -> bool:
    """Download (unless cached) and extract each archive in `url_dict`
    (url -> local filename) under `target_dir`.

    Returns True if at least one archive was actually fetched/extracted.
    """
    target_dir.mkdir(parents=True, exist_ok=True)

    ret = False
    for url, fname in url_dict.items():
        if not force:
            if is_downloaded(target_dir, url): continue  # already done, skip
        if not download.check_url(url):
            # unreachable url: report and move on instead of aborting the batch
            print("ERROR:", url)
            continue
        ret = True

        if force or not (target_dir / "archives" / fname).is_file():

            data = download.download_data(url, silent=silent, label=fname)
            assert url.endswith(".zip")

            print("writing...")

            # keep the raw archive around under ./archives/
            (target_dir / "archives").mkdir(parents=True, exist_ok=True)
            with (target_dir / "archives" / fname).open("wb") as f:
                f.write(data)
            del data  # archives can be large; free eagerly

        print(f"extracting {fname}...")

        # extract into a per-set directory named after the archive stem
        with zipfile.ZipFile(target_dir / "archives" / fname, 'r') as f:
            f.extractall(target_dir / Path(fname).stem.removesuffix("-shapes").removesuffix("-gt"))

        is_downloaded(target_dir, url, add=True)  # record success

    return ret
|
||||
|
||||
def make_parser() -> argparse.ArgumentParser:
    """Build the argument parser for the `download-coseg` entrypoint."""
    parser = argparse.ArgumentParser(description=dedent("""
        Download The COSEG Shape Dataset.
        More info: http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/ssd.htm

        Example:

            download-coseg --shapes chairs
        """), formatter_class=argparse.RawTextHelpFormatter)

    arg = parser.add_argument

    arg("sets", nargs="*", default=[],
        help="Which set to download, defaults to none.")
    arg("--all", action="store_true",
        help="Download all sets")
    arg("--dir", default=str(config.DATA_PATH),
        help=f"The target directory. Default is {make_relative(config.DATA_PATH, Path.cwd()).__str__()!r}")

    arg("--shapes", action="store_true",
        help="Download the 3d shapes for each chosen set")
    arg("--gts", action="store_true",
        help="Download the ground-truth segmentation data for each chosen set")

    arg("--list", action="store_true",
        help="Lists all the sets")
    arg("--list-urls", action="store_true",
        help="Lists the urls to download")
    arg("--list-sizes", action="store_true",
        help="Lists the download size of each set")
    arg("--silent", action="store_true",
        help="")
    arg("--force", action="store_true",
        help="Download again even if already downloaded")

    return parser
|
||||
|
||||
# entrypoint
|
||||
def cli(parser=make_parser()):
    """CLI entrypoint: parse arguments, resolve the selected object sets,
    handle the --list* modes, then download and extract the archives."""
    args = parser.parse_args()

    # both tables must cover the same object sets
    assert set(config.SHAPES.keys()) == set(config.GROUND_TRUTHS.keys())

    set_names = sorted(set(args.sets))
    if args.all:
        assert not set_names, "--all is mutually exclusive from manually selected sets"
        set_names = sorted(config.SHAPES.keys())

    if args.list:
        print(*config.SHAPES.keys(), sep="\n")
        exit()

    if args.list_sizes:
        print(*(f"{set_name:<15}{config.SHAPES[set_name].download_size_str}" for set_name in (set_names or config.SHAPES.keys())), sep="\n")
        exit()

    # build url -> filename for the requested archive kinds
    try:
        url_dict \
            = {config.SHAPES[set_name].url        : config.SHAPES[set_name].fname        for set_name in set_names if args.shapes} \
            | {config.GROUND_TRUTHS[set_name].url : config.GROUND_TRUTHS[set_name].fname for set_name in set_names if args.gts}
    except KeyError:
        print("Error: unrecognized object name:", *set(set_names).difference(config.SHAPES.keys()), sep="\n")
        exit(1)

    if not url_dict:
        if set_names and not (args.shapes or args.gts):
            # BUG FIX: message read "--shapes of --gts"
            print("Error: Provide at least one of --shapes or --gts")
        else:
            print("Error: No object set was selected for download!")
        exit(1)

    if args.list_urls:
        print(*url_dict.keys(), sep="\n")
        exit()

    print("Download start")
    any_downloaded = download_and_extract(
        target_dir = Path(args.dir),
        url_dict   = url_dict,
        force      = args.force,
        silent     = args.silent,
    )
    if not any_downloaded:
        print("Everything has already been downloaded, skipping.")
|
||||
137
ifield/data/coseg/preprocess.py
Normal file
137
ifield/data/coseg/preprocess.py
Normal file
@@ -0,0 +1,137 @@
|
||||
#!/usr/bin/env python3
|
||||
import os; os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
|
||||
from . import config, read
|
||||
from ...utils.helpers import make_relative
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
import argparse
|
||||
|
||||
|
||||
def make_parser() -> argparse.ArgumentParser:
    """Build the argparse CLI for preprocessing the COSEG dataset.

    Requires `download-coseg --shapes ...` to have populated the data
    directory first.
    """
    parser = argparse.ArgumentParser(description=dedent("""
        Preprocess the COSEG dataset. Depends on `download-coseg --shapes ...` having been run.
        """), formatter_class=argparse.RawTextHelpFormatter)

    arg = parser.add_argument # brevity

    arg("items", nargs="*", default=[],
        help="Which object-set[/model-id] to process, defaults to all downloaded. Format: OBJECT-SET[/MODEL-ID]")
    arg("--dir", default=str(config.DATA_PATH),
        help=f"The target directory. Default is {make_relative(config.DATA_PATH, Path.cwd()).__str__()!r}")
    arg("--force", action="store_true",
        help="Overwrite existing files")
    arg("--list-models", action="store_true",
        help="List the downloaded models available for preprocessing")
    arg("--list-object-sets", action="store_true",
        help="List the downloaded object-sets available for preprocessing")
    arg("--list-pages", type=int, default=None,
        help="List the downloaded models available for preprocessing, paginated into N pages.")
    arg("--page", nargs=2, type=int, default=[0, 1],
        help="Subset of parts to compute. Use to parallelize. (page, total), page is 0 indexed")

    arg2 = parser.add_argument_group("preprocessing targets").add_argument # brevity
    arg2("--precompute-mesh-sv-scan-clouds", action="store_true",
        help="Compute single-view hit+miss point clouds from 100 synthetic scans.")
    arg2("--precompute-mesh-sv-scan-uvs", action="store_true",
        help="Compute single-view hit+miss UV clouds from 100 synthetic scans.")
    arg2("--precompute-mesh-sphere-scan", action="store_true",
        help="Compute a sphere-view hit+miss cloud cast from n to n unit sphere points.")

    arg3 = parser.add_argument_group("modifiers").add_argument # brevity
    arg3("--n-sphere-points", type=int, default=4000,
        help="The number of unit-sphere points to sample rays from. Final result: n*(n-1).")
    arg3("--compute-miss-distances", action="store_true",
        help="Compute the distance to the nearest hit for each miss in the hit+miss clouds.")
    arg3("--fill-missing-uv-points", action="store_true",
        help="TODO")
    arg3("--no-filter-backhits", action="store_true",
        help="Do not filter scan hits on backside of mesh faces.")
    arg3("--no-unit-sphere", action="store_true",
        help="Do not center the objects to the unit sphere.")
    arg3("--convert-ok", action="store_true",
        help="Allow reusing point clouds for uv clouds and vice versa. (does not account for other hparams)")
    arg3("--debug", action="store_true",
        help="Abort on failure.")  # fixed typo: was "failiure"

    return parser
|
||||
|
||||
# entrypoint
def cli(parser=make_parser()):
    """CLI entrypoint: parse args, handle the --list-* modes, then run the
    selected preprocessing targets."""
    args = parser.parse_args()
    # NB: `args.list_pages is not None` so `--list-pages 0` counts as selected,
    # matching the `is not None` check further down (was a bare truthiness test)
    if not any(getattr(args, k) for k in dir(args) if k.startswith("precompute_")) and not (args.list_models or args.list_object_sets or args.list_pages is not None):
        parser.error("no preprocessing target selected") # exits

    config.DATA_PATH = Path(args.dir)

    # items may be whole object-sets ("OBJECT-SET") or single models ("OBJECT-SET/MODEL-ID")
    object_sets = [i for i in args.items if "/" not in i]
    models      = [i.split("/") for i in args.items if "/" in i]

    # convert/expand object-sets to models; the two item forms are mutually
    # exclusive (was a pair of bare asserts, which vanish under `python -O`)
    if object_sets and models:
        parser.error("OBJECT-SET and OBJECT-SET/MODEL-ID items are mutually exclusive") # exits
    if not models:
        models = read.list_model_ids(tuple(object_sets) or None)

    if args.list_models:
        try:
            print(*(f"{object_set_id}/{model_id}" for object_set_id, model_id in models), sep="\n")
        except BrokenPipeError:
            pass  # e.g. piped into `head`
        parser.exit()

    if args.list_object_sets:
        try:
            print(*sorted(set(object_set_id for object_set_id, model_id in models)), sep="\n")
        except BrokenPipeError:
            pass
        parser.exit()

    if args.list_pages is not None:
        try:
            print(*(
                f"--page {i} {args.list_pages} {object_set_id}/{model_id}"
                for object_set_id, model_id in models
                for i in range(args.list_pages)
            ), sep="\n")
        except BrokenPipeError:
            pass
        parser.exit()

    if args.precompute_mesh_sv_scan_clouds:
        read.precompute_mesh_scan_point_clouds(
            models,
            compute_miss_distances = args.compute_miss_distances,
            no_filter_backhits     = args.no_filter_backhits,
            no_unit_sphere         = args.no_unit_sphere,
            convert_ok             = args.convert_ok,
            page                   = args.page,
            force                  = args.force,
            debug                  = args.debug,
        )
    if args.precompute_mesh_sv_scan_uvs:
        read.precompute_mesh_scan_uvs(
            models,
            compute_miss_distances = args.compute_miss_distances,
            fill_missing_points    = args.fill_missing_uv_points,
            no_filter_backhits     = args.no_filter_backhits,
            no_unit_sphere         = args.no_unit_sphere,
            convert_ok             = args.convert_ok,
            page                   = args.page,
            force                  = args.force,
            debug                  = args.debug,
        )
    if args.precompute_mesh_sphere_scan:
        read.precompute_mesh_sphere_scan(
            models,
            sphere_points          = args.n_sphere_points,
            compute_miss_distances = args.compute_miss_distances,
            no_filter_backhits     = args.no_filter_backhits,
            no_unit_sphere         = args.no_unit_sphere,
            page                   = args.page,
            force                  = args.force,
            debug                  = args.debug,
        )

if __name__ == "__main__":
    cli()
|
||||
290
ifield/data/coseg/read.py
Normal file
290
ifield/data/coseg/read.py
Normal file
@@ -0,0 +1,290 @@
|
||||
from . import config
|
||||
from ..common import points
|
||||
from ..common import processing
|
||||
from ..common.scan import SingleViewScan, SingleViewUVScan
|
||||
from ..common.types import MalformedMesh
|
||||
from functools import lru_cache
|
||||
from typing import Optional, Iterable
|
||||
import numpy as np
|
||||
import trimesh
|
||||
import trimesh.transformations as T
|
||||
|
||||
__doc__ = """
|
||||
Here are functions for reading and preprocessing coseg benchmark data
|
||||
|
||||
There are essentially a few sets per object:
|
||||
"img" - meaning the RGBD images (none found in coseg)
|
||||
"mesh_scans" - meaning synthetic scans of a mesh
|
||||
"""
|
||||
|
||||
MESH_TRANSFORM_SKYWARD = T.rotation_matrix(np.pi/2, (1, 0, 0)) # rotate to be upright in pyrender
|
||||
MESH_POSE_CORRECTIONS = { # to gain a shared canonical orientation
|
||||
("four-legged", 381): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 382): T.rotation_matrix( 1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 383): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 384): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 385): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 386): T.rotation_matrix( 1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 387): T.rotation_matrix(-0.2*np.pi/2, (0, 1, 0))@T.rotation_matrix(1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 388): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 389): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 390): T.rotation_matrix( 0*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 391): T.rotation_matrix( 0*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 392): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 393): T.rotation_matrix( 0*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 394): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 395): T.rotation_matrix(-0.2*np.pi/2, (0, 1, 0))@T.rotation_matrix(1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 396): T.rotation_matrix( 1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 397): T.rotation_matrix( 0*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 398): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 399): T.rotation_matrix( 0*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 400): T.rotation_matrix( 0*np.pi/2, (0, 0, 1)),
|
||||
}
|
||||
|
||||
|
||||
# (object_set_id, model_id) — uniquely identifies one COSEG model
ModelUid = tuple[str, int]
|
||||
|
||||
@lru_cache(maxsize=1)
def list_object_sets() -> list[str]:
    "List the object-set names present under the data directory (cached)."
    names = []
    for entry in config.DATA_PATH.iterdir():
        # an object set is any non-archive directory containing a "shapes" dir
        if entry.name == "archive":
            continue
        if (entry / "shapes").is_dir():
            names.append(entry.name)
    names.sort()
    return names
|
||||
|
||||
@lru_cache(maxsize=1)
def list_model_ids(object_sets: Optional[tuple[str]] = None) -> list[ModelUid]:
    """List (object_set, model_id) pairs of downloaded models, optionally
    restricted to the given object sets (cached)."""
    found = []
    for object_set in config.DATA_PATH.iterdir():
        if object_set.name == "archive":
            continue
        shapes_dir = object_set / "shapes"
        if not shapes_dir.is_dir():
            continue
        if object_sets is not None and object_set.name not in object_sets:
            continue
        for model in shapes_dir.iterdir():
            if model.is_file() and model.suffix == ".off":
                found.append((object_set.name, int(model.stem)))
    return sorted(found)
|
||||
|
||||
def list_model_id_strings(object_sets: Optional[tuple[str]] = None) -> list[str]:
    "Same as list_model_ids, but each uid rendered as a 'set-id' string."
    return [
        model_uid_to_string(object_set_id, model_id)
        for object_set_id, model_id in list_model_ids(object_sets)
    ]
|
||||
|
||||
def model_uid_to_string(object_set_id: str, model_id: int) -> str:
    "Render an (object_set_id, model_id) pair as a single 'set-id' string."
    return "{}-{}".format(object_set_id, model_id)
|
||||
|
||||
def model_id_string_to_uid(model_string_uid: str) -> "ModelUid":
    """Parse a 'object_set-model_id' string back into an (str, int) pair.

    Splits on the *last* '-' so object-set names may themselves contain
    dashes (e.g. "four-legged-381" -> ("four-legged", 381)).

    Raises:
        ValueError: if the string contains no '-' separator or the id part is
            not an integer. (Was a bare `assert`, which vanishes under -O.)
    """
    object_set, split, model = model_string_uid.rpartition("-")
    if split != "-":
        raise ValueError(f"not a valid model uid string: {model_string_uid!r}")
    return (object_set, int(model))
|
||||
|
||||
@lru_cache(maxsize=1)
def list_mesh_scan_sphere_coords(n_poses: int = 50) -> list[tuple[float, float]]: # (theta, phi)
    # equidistant scan-camera poses on the unit sphere, as (theta, phi) pairs;
    # cached since every precompute pass asks for the same pose set
    return points.generate_equidistant_sphere_points(n_poses, compute_sphere_coordinates=True)
|
||||
|
||||
def mesh_scan_identifier(*, phi: float, theta: float) -> str:
    """Build a filesystem-safe identifier for a (theta, phi) camera pose,
    e.g. theta=-0.25, phi=1.50 -> "n0d25p1d50" (sign letter, then the
    magnitude to 2 decimals with '.' replaced by 'd')."""
    theta_sign = "p" if theta >= 0 else "n"
    phi_sign   = "p" if phi   >= 0 else "n"
    ident = f"{theta_sign}{abs(theta):.2f}{phi_sign}{abs(phi):.2f}"
    return ident.replace(".", "d")
|
||||
|
||||
@lru_cache(maxsize=1)
def list_mesh_scan_identifiers(n_poses: int = 50) -> list[str]:
    "Pose identifiers for each of the n_poses scan cameras; guaranteed unique."
    identifiers = []
    for theta, phi in list_mesh_scan_sphere_coords(n_poses):
        identifiers.append(mesh_scan_identifier(phi=phi, theta=theta))
    # two poses rounding to the same identifier would silently clobber files
    assert len(identifiers) == len(set(identifiers))
    return identifiers
|
||||
|
||||
# ===
|
||||
|
||||
def read_mesh(object_set_id: str, model_id: int) -> trimesh.Trimesh:
    """Load the .off mesh of one COSEG model, rotated into canonical pose.

    Raises:
        FileNotFoundError: if the model has not been downloaded.
        MalformedMesh:     if trimesh fails to parse the file.
    """
    path = config.DATA_PATH / object_set_id / "shapes" / f"{model_id}.off"
    if not path.is_file():
        raise FileNotFoundError(f"{path = }")
    try:
        mesh = trimesh.load(path, force="mesh")
    except Exception as e:
        raise MalformedMesh(f"Trimesh raised: {e.__class__.__name__}: {e}") from e

    # every mesh is rotated upright for pyrender; some models additionally get
    # a per-model correction so the whole set shares a canonical orientation
    pose = MESH_POSE_CORRECTIONS.get((object_set_id, int(model_id)))
    mesh.apply_transform(pose @ MESH_TRANSFORM_SKYWARD if pose is not None else MESH_TRANSFORM_SKYWARD)
    return mesh
|
||||
|
||||
# === single-view scan clouds
|
||||
|
||||
def compute_mesh_scan_point_cloud(
    object_set_id : str,
    model_id : int,
    phi : float,
    theta : float,
    *,
    compute_miss_distances : bool = False,
    fill_missing_points : bool = False,
    compute_normals : bool = True,
    convert_ok : bool = False,
    **kw,
) -> SingleViewScan:
    """Synthesize a single-view hit+miss point cloud by scanning the mesh from
    the (theta, phi) camera pose.

    When `convert_ok` is set, a previously stored UV scan (if any) is loaded
    and converted instead of re-scanning; note this does not account for the
    other hyper-parameters. Remaining keyword args are forwarded to
    SingleViewScan.from_mesh_single_view.
    """

    if convert_ok:
        try:
            return read_mesh_scan_uv(object_set_id, model_id, phi=phi, theta=theta).to_scan()
        except FileNotFoundError:
            pass  # no UV scan on disk -> fall through to a fresh scan

    mesh = read_mesh(object_set_id, model_id)
    scan = SingleViewScan.from_mesh_single_view(mesh,
        phi = phi,
        theta = theta,
        compute_normals = compute_normals,
        **kw,
    )
    if compute_miss_distances:
        scan.compute_miss_distances()
    if fill_missing_points:
        scan.fill_missing_points()

    return scan
|
||||
|
||||
def precompute_mesh_scan_point_clouds(models: Iterable[ModelUid], *, n_poses: int = 50, page: tuple[int, int] = (0, 1), force = False, debug = False, **kw):
    "precomputes all single-view scan clouds and stores them as HDF5 datasets"
    # `models` is iterated several times below; a one-shot iterator would
    # silently come up empty after the first pass
    models = list(models)
    cam_poses        = list_mesh_scan_sphere_coords(n_poses=n_poses)
    pose_identifiers = list_mesh_scan_identifiers  (n_poses=n_poses)
    assert len(cam_poses) == len(pose_identifiers)
    paths = list_mesh_scan_point_cloud_h5_fnames(models, pose_identifiers, n_poses=n_poses)
    # default=0 keeps an empty selection from raising ValueError on max()
    mlen_syn = max((len(object_set_id) for object_set_id, model_id in models), default=0)
    mlen_mod = max((len(str(model_id)) for object_set_id, model_id in models), default=0)
    # human-readable work-item labels; `computer` splits them back apart on "@"
    pretty_identifiers = [
        f"{object_set_id.ljust(mlen_syn)} @ {str(model_id).ljust(mlen_mod)} @ {i:>5} @ ({identifier}: {theta:.2f}, {phi:.2f})"
        for object_set_id, model_id in models
        for i, (identifier, (theta, phi)) in enumerate(zip(pose_identifiers, cam_poses))
    ]  # also fixes the "itentifier" typo
    mesh_cache = []  # shared so consecutive scans of the same model reuse the mesh
    def computer(pretty_identifier: str) -> SingleViewScan:
        object_set_id, model_id, index, _ = map(str.strip, pretty_identifier.split("@"))
        theta, phi = cam_poses[int(index)]
        return compute_mesh_scan_point_cloud(object_set_id, int(model_id), phi=phi, theta=theta, _mesh_cache=mesh_cache, **kw)
    return processing.precompute_data(computer, pretty_identifiers, paths, page=page, force=force, debug=debug)
|
||||
|
||||
def read_mesh_scan_point_cloud(object_set_id: str, model_id: int, *, identifier: str = None, phi: float = None, theta: float = None) -> SingleViewScan:
    """Load a precomputed single-view scan from its HDF5 file.

    Provide either a precomputed pose `identifier`, or both `phi` and `theta`
    (from which the identifier is derived).

    Raises:
        ValueError: if neither an identifier nor both angles are given.
    """
    if identifier is None:
        if phi is None or theta is None:
            raise ValueError("Provide either phi+theta or an identifier!")
        identifier = mesh_scan_identifier(phi=phi, theta=theta)
    # NOTE(review): point clouds share the "uv_scan_clouds" directory and
    # filename pattern with the UV scans — looks intentional (convert_ok
    # round-trips between the two formats), but confirm it is not a slip.
    file = config.DATA_PATH / object_set_id / "uv_scan_clouds" / f"{model_id}_normalized_{identifier}.h5"
    return SingleViewScan.from_h5_file(file)
|
||||
|
||||
def list_mesh_scan_point_cloud_h5_fnames(models: Iterable[ModelUid], identifiers: Optional[Iterable[str]] = None, **kw):
    "HDF5 paths for every (model, pose identifier) combination."
    if identifiers is None:
        identifiers = list_mesh_scan_identifiers(**kw)
    fnames = []
    for object_set_id, model_id in models:
        for identifier in identifiers:
            fname = f"{model_id}_normalized_{identifier}.h5"
            fnames.append(config.DATA_PATH / object_set_id / "uv_scan_clouds" / fname)
    return fnames
|
||||
|
||||
|
||||
# === single-view UV scan clouds
|
||||
|
||||
def compute_mesh_scan_uv(
    object_set_id : str,
    model_id : int,
    phi : float,
    theta : float,
    *,
    compute_miss_distances : bool = False,
    fill_missing_points : bool = False,
    compute_normals : bool = True,
    convert_ok : bool = False,
    **kw,
) -> SingleViewUVScan:
    """Synthesize a single-view hit+miss UV cloud by scanning the mesh from
    the (theta, phi) camera pose.

    When `convert_ok` is set, a previously stored point-cloud scan (if any)
    is loaded and converted instead of re-scanning; note this does not
    account for the other hyper-parameters. Remaining keyword args are
    forwarded to SingleViewUVScan.from_mesh_single_view.
    """

    if convert_ok:
        try:
            return read_mesh_scan_point_cloud(object_set_id, model_id, phi=phi, theta=theta).to_uv_scan()
        except FileNotFoundError:
            pass  # nothing to convert -> fall through to a fresh scan

    mesh = read_mesh(object_set_id, model_id)
    scan = SingleViewUVScan.from_mesh_single_view(mesh,
        phi = phi,
        theta = theta,
        compute_normals = compute_normals,
        **kw,
    )
    if compute_miss_distances:
        scan.compute_miss_distances()
    if fill_missing_points:
        scan.fill_missing_points()

    return scan
|
||||
|
||||
def precompute_mesh_scan_uvs(models: Iterable[ModelUid], *, n_poses: int = 50, page: tuple[int, int] = (0, 1), force = False, debug = False, **kw):
    "precomputes all single-view UV scan clouds and stores them as HDF5 datasets"
    # `models` is iterated several times below; a one-shot iterator would
    # silently come up empty after the first pass
    models = list(models)
    cam_poses        = list_mesh_scan_sphere_coords(n_poses=n_poses)
    pose_identifiers = list_mesh_scan_identifiers  (n_poses=n_poses)
    assert len(cam_poses) == len(pose_identifiers)
    paths = list_mesh_scan_uv_h5_fnames(models, pose_identifiers, n_poses=n_poses)
    # default=0 keeps an empty selection from raising ValueError on max()
    mlen_syn = max((len(object_set_id) for object_set_id, model_id in models), default=0)
    mlen_mod = max((len(str(model_id)) for object_set_id, model_id in models), default=0)
    # human-readable work-item labels; `computer` splits them back apart on "@"
    pretty_identifiers = [
        f"{object_set_id.ljust(mlen_syn)} @ {str(model_id).ljust(mlen_mod)} @ {i:>5} @ ({identifier}: {theta:.2f}, {phi:.2f})"
        for object_set_id, model_id in models
        for i, (identifier, (theta, phi)) in enumerate(zip(pose_identifiers, cam_poses))
    ]  # also fixes the "itentifier" typo
    mesh_cache = []  # shared so consecutive scans of the same model reuse the mesh
    def computer(pretty_identifier: str) -> SingleViewUVScan:
        object_set_id, model_id, index, _ = map(str.strip, pretty_identifier.split("@"))
        theta, phi = cam_poses[int(index)]
        return compute_mesh_scan_uv(object_set_id, int(model_id), phi=phi, theta=theta, _mesh_cache=mesh_cache, **kw)
    return processing.precompute_data(computer, pretty_identifiers, paths, page=page, force=force, debug=debug)
|
||||
|
||||
def read_mesh_scan_uv(object_set_id: str, model_id: int, *, identifier: str = None, phi: float = None, theta: float = None) -> SingleViewUVScan:
    """Load a precomputed single-view UV scan from its HDF5 file.

    Provide either a precomputed pose `identifier`, or both `phi` and `theta`
    (from which the identifier is derived).

    Raises:
        ValueError: if neither an identifier nor both angles are given.
    """
    if identifier is None:
        if phi is None or theta is None:
            raise ValueError("Provide either phi+theta or an identifier!")
        identifier = mesh_scan_identifier(phi=phi, theta=theta)
    file = config.DATA_PATH / object_set_id / "uv_scan_clouds" / f"{model_id}_normalized_{identifier}.h5"

    return SingleViewUVScan.from_h5_file(file)
|
||||
|
||||
def list_mesh_scan_uv_h5_fnames(models: Iterable[ModelUid], identifiers: Optional[Iterable[str]] = None, **kw):
    "HDF5 paths for every (model, pose identifier) combination of UV scans."
    if identifiers is None:
        identifiers = list_mesh_scan_identifiers(**kw)
    fnames = []
    for object_set_id, model_id in models:
        for identifier in identifiers:
            fname = f"{model_id}_normalized_{identifier}.h5"
            fnames.append(config.DATA_PATH / object_set_id / "uv_scan_clouds" / fname)
    return fnames
|
||||
|
||||
|
||||
# === sphere-view (UV) scan clouds
|
||||
|
||||
def compute_mesh_sphere_scan(
    object_set_id : str,
    model_id : int,
    *,
    compute_normals : bool = True,
    **kw,
) -> SingleViewUVScan:
    """Scan the mesh with rays spanning the whole unit sphere and return the
    resulting hit+miss cloud. Extra kwargs are forwarded to
    SingleViewUVScan.from_mesh_sphere_view."""
    mesh = read_mesh(object_set_id, model_id)
    return SingleViewUVScan.from_mesh_sphere_view(
        mesh,
        compute_normals=compute_normals,
        **kw,
    )
|
||||
|
||||
def precompute_mesh_sphere_scan(models: Iterable[ModelUid], *, page: tuple[int, int] = (0, 1), force: bool = False, debug: bool = False, n_points: int = 4000, **kw):
    "precomputes all sphere scan clouds and stores them as HDF5 datasets"
    # `models` is iterated twice below; a one-shot iterator would silently
    # come up empty the second time
    models = list(models)
    # NOTE(review): `n_points` is accepted but never forwarded — callers pass
    # `sphere_points=...` through **kw instead. Kept for backward
    # compatibility; confirm whether it should be wired up or removed.
    paths = list_mesh_sphere_scan_h5_fnames(models)
    identifiers = [model_uid_to_string(*i) for i in models]
    def computer(identifier: str) -> SingleViewScan:
        object_set_id, model_id = model_id_string_to_uid(identifier)
        return compute_mesh_sphere_scan(object_set_id, model_id, **kw)
    return processing.precompute_data(computer, identifiers, paths, page=page, force=force, debug=debug)
|
||||
|
||||
def read_mesh_mesh_sphere_scan(object_set_id: str, model_id: int) -> SingleViewUVScan:
    "Load a precomputed sphere scan from its HDF5 file."
    fname = f"{model_id}_normalized.h5"
    return SingleViewUVScan.from_h5_file(config.DATA_PATH / object_set_id / "sphere_scan_clouds" / fname)
|
||||
|
||||
def list_mesh_sphere_scan_h5_fnames(models: Iterable[ModelUid]) -> "list[Path]":
    # annotation fix: the elements are pathlib.Path objects (from the `/`
    # operator on config.DATA_PATH), not str as previously annotated
    "HDF5 path of the precomputed sphere scan for each model."
    return [
        config.DATA_PATH / object_set_id / "sphere_scan_clouds" / f"{model_id}_normalized.h5"
        for object_set_id, model_id in models
    ]
|
||||
76
ifield/data/stanford/__init__.py
Normal file
76
ifield/data/stanford/__init__.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from ..config import data_path_get, data_path_persist
|
||||
from collections import namedtuple
|
||||
import os
|
||||
|
||||
|
||||
# Data source:
|
||||
# http://graphics.stanford.edu/data/3Dscanrep/
|
||||
|
||||
__ALL__ = ["config", "Model", "MODELS"]
|
||||
|
||||
@(lambda x: x()) # singleton: the decorator instantiates the class, so the
                 # module-level name `config` is an *instance*, not the class
class config:
    # property() is assigned in the class body so that attribute access on the
    # singleton instance is routed through the getter/setter below
    DATA_PATH = property(
        doc = """
        Path to the dataset. The following envvars override it:
            ${IFIELD_DATA_MODELS}/stanford
            ${IFIELD_DATA_MODELS_STANFORD}
        """,
        fget = lambda self: data_path_get ("stanford"),
        fset = lambda self, path: data_path_persist("stanford", path),
    )

    @property
    def IS_DOWNLOADED_DB(self) -> list[os.PathLike]:
        # db file(s) recording which archives have already been fetched
        return [
            self.DATA_PATH / "downloaded.json",
        ]
|
||||
|
||||
# A downloadable model: the archive url, the path of the reconstructed mesh
# inside the extracted archive (or of the stored .ply.gz), and a
# human-readable download size (shown by `download-stanford --list-sizes`).
Model = namedtuple("Model", "url mesh_fname download_size_str")
MODELS: dict[str, Model] = {
    "bunny": Model(
        "http://graphics.stanford.edu/pub/3Dscanrep/bunny.tar.gz",
        "bunny/reconstruction/bun_zipper.ply",
        "4.89M",
    ),
    "drill_bit": Model(
        "http://graphics.stanford.edu/pub/3Dscanrep/drill.tar.gz",
        "drill/reconstruction/drill_shaft_vrip.ply",
        "555k",
    ),
    "happy_buddha": Model(
        # religious symbol
        "http://graphics.stanford.edu/pub/3Dscanrep/happy/happy_recon.tar.gz",
        "happy_recon/happy_vrip.ply",
        "14.5M",
    ),
    "dragon": Model(
        # symbol of Chinese culture
        "http://graphics.stanford.edu/pub/3Dscanrep/dragon/dragon_recon.tar.gz",
        "dragon_recon/dragon_vrip.ply",
        "11.2M",
    ),
    "armadillo": Model(
        "http://graphics.stanford.edu/pub/3Dscanrep/armadillo/Armadillo.ply.gz",
        "armadillo.ply.gz",
        "3.87M",
    ),
    "lucy": Model(
        # Christian angel
        "http://graphics.stanford.edu/data/3Dscanrep/lucy.tar.gz",
        "lucy.ply",
        "322M",
    ),
    "asian_dragon": Model(
        # symbol of Chinese culture
        "http://graphics.stanford.edu/data/3Dscanrep/xyzrgb/xyzrgb_dragon.ply.gz",
        "xyzrgb_dragon.ply.gz",
        "70.5M",
    ),
    "thai_statue": Model(
        # Hindu religious significance
        "http://graphics.stanford.edu/data/3Dscanrep/xyzrgb/xyzrgb_statuette.ply.gz",
        "xyzrgb_statuette.ply.gz",
        "106M",
    ),
}
|
||||
129
ifield/data/stanford/download.py
Normal file
129
ifield/data/stanford/download.py
Normal file
@@ -0,0 +1,129 @@
|
||||
#!/usr/bin/env python3
|
||||
from . import config
|
||||
from ...utils.helpers import make_relative
|
||||
from ..common import download
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
from typing import Iterable
|
||||
import argparse
|
||||
import io
|
||||
import tarfile
|
||||
|
||||
|
||||
def is_downloaded(*a, **kw):
    "Proxy for download.is_downloaded with this dataset's db files filled in."
    dbfiles = config.IS_DOWNLOADED_DB
    return download.is_downloaded(*a, dbfiles=dbfiles, **kw)
|
||||
|
||||
def download_and_extract(target_dir: Path, url_list: Iterable[str], *, force=False, silent=False) -> bool:
    """Download each URL and extract/store it under `target_dir / "meshes"`.

    URLs already recorded as downloaded are skipped unless `force`. Returns
    True if anything new was downloaded.

    Raises:
        NotImplementedError: for a url with an unrecognized archive suffix.
    """
    target_dir.mkdir(parents=True, exist_ok=True)

    ret = False
    for url in url_list:
        if not force:
            if is_downloaded(target_dir, url): continue
        if not download.check_url(url):
            print("ERROR:", url)
            continue
        ret = True

        data = download.download_data(url, silent=silent, label=str(Path(url).name))

        print("extracting...")
        if url.endswith(".ply.gz"):
            # stored compressed as-is; decompressed on load
            fname = target_dir / "meshes" / url.split("/")[-1].lower()
            fname.parent.mkdir(parents=True, exist_ok=True)
            with fname.open("wb") as f:
                f.write(data)
        elif url.endswith(".tar.gz"):
            with tarfile.open(fileobj=io.BytesIO(data)) as tar:
                for member in tar.getmembers():
                    # only plain files; guard against path traversal via
                    # absolute paths, hidden names, and ".." components
                    if not member.isfile(): continue
                    if member.name.startswith("/"): continue
                    if member.name.startswith("."): continue
                    if ".." in Path(member.name).parts: continue  # traversal fix
                    if Path(member.name).name.startswith("."): continue
                    tar.extract(member, target_dir / "meshes")
            del tar
        else:
            raise NotImplementedError(f"Extraction for {str(Path(url).name)} unknown")

        is_downloaded(target_dir, url, add=True)
        del data  # archives can be hundreds of MB; free eagerly

    return ret
|
||||
|
||||
def make_parser() -> argparse.ArgumentParser:
    """Build the argparse CLI for downloading the Stanford 3D scans."""
    parser = argparse.ArgumentParser(description=dedent("""
        Download The Stanford 3D Scanning Repository models.
        More info: http://graphics.stanford.edu/data/3Dscanrep/

        Example:

            download-stanford bunny
        """), formatter_class=argparse.RawTextHelpFormatter)

    arg = parser.add_argument

    arg("objects", nargs="*", default=[],
        help="Which objects to download, defaults to none.")
    arg("--all", action="store_true",
        help="Download all objects")
    arg("--dir", default=str(config.DATA_PATH),
        help=f"The target directory. Default is {make_relative(config.DATA_PATH, Path.cwd()).__str__()!r}")

    arg("--list", action="store_true",
        help="Lists all the objects")
    arg("--list-urls", action="store_true",
        help="Lists the urls to download")
    arg("--list-sizes", action="store_true",
        help="Lists the download size of each model")
    arg("--silent", action="store_true",
        help="Suppress the download progress output")  # was an empty help string
    arg("--force", action="store_true",
        help="Download again even if already downloaded")

    return parser
|
||||
|
||||
# entrypoint
def cli(parser=make_parser()):
    """CLI entrypoint: resolve the object selection, handle the --list-*
    modes, then download and extract the selected models."""
    args = parser.parse_args()

    obj_names = sorted(set(args.objects))
    if args.all:
        # was a bare assert (vanishes under `python -O`, opaque message)
        if obj_names:
            parser.error("--all is mutually exclusive from manually selected objects") # exits
        obj_names = sorted(config.MODELS.keys())
    if not obj_names and args.list_urls:
        # bugfix: this branch used to be a no-op expression statement
        # (`config.MODELS.keys()`); `--list-urls` with no selection now lists
        # every model's url as evidently intended
        obj_names = sorted(config.MODELS.keys())

    if args.list:
        print(*config.MODELS.keys(), sep="\n")
        exit()

    if args.list_sizes:
        print(*(f"{obj_name:<15}{config.MODELS[obj_name].download_size_str}" for obj_name in (obj_names or config.MODELS.keys())), sep="\n")
        exit()

    try:
        url_list = [config.MODELS[obj_name].url for obj_name in obj_names]
    except KeyError:
        print("Error: unrecognized object name:", *set(obj_names).difference(config.MODELS.keys()), sep="\n")
        exit(1)

    if not url_list:
        print("Error: No object set was selected for download!")
        exit(1)

    if args.list_urls:
        print(*url_list, sep="\n")
        exit()

    print("Download start")
    any_downloaded = download_and_extract(
        target_dir = Path(args.dir),
        url_list   = url_list,
        force      = args.force,
        silent     = args.silent,
    )
    if not any_downloaded:
        print("Everything has already been downloaded, skipping.")

if __name__ == "__main__":
    cli()
|
||||
118
ifield/data/stanford/preprocess.py
Normal file
118
ifield/data/stanford/preprocess.py
Normal file
@@ -0,0 +1,118 @@
|
||||
#!/usr/bin/env python3
|
||||
import os; os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
|
||||
from . import config, read
|
||||
from ...utils.helpers import make_relative
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
import argparse
|
||||
|
||||
|
||||
|
||||
def make_parser() -> argparse.ArgumentParser:
    """Build the argparse CLI for preprocessing the Stanford models.

    Requires `download-stanford` to have populated the data directory first.
    """
    parser = argparse.ArgumentParser(description=dedent("""
        Preprocess the Stanford models. Depends on `download-stanford` having been run.
        """), formatter_class=argparse.RawTextHelpFormatter)

    arg = parser.add_argument # brevity

    arg("objects", nargs="*", default=[],
        help="Which objects to process, defaults to all downloaded")
    arg("--dir", default=str(config.DATA_PATH),
        help=f"The target directory. Default is {make_relative(config.DATA_PATH, Path.cwd()).__str__()!r}")
    arg("--force", action="store_true",
        help="Overwrite existing files")
    arg("--list", action="store_true",
        help="List the downloaded models available for preprocessing")
    arg("--list-pages", type=int, default=None,
        help="List the downloaded models available for preprocessing, paginated into N pages.")
    arg("--page", nargs=2, type=int, default=[0, 1],
        help="Subset of parts to compute. Use to parallelize. (page, total), page is 0 indexed")

    arg2 = parser.add_argument_group("preprocessing targets").add_argument # brevity
    arg2("--precompute-mesh-sv-scan-clouds", action="store_true",
        help="Compute single-view hit+miss point clouds from 100 synthetic scans.")
    arg2("--precompute-mesh-sv-scan-uvs", action="store_true",
        help="Compute single-view hit+miss UV clouds from 100 synthetic scans.")
    arg2("--precompute-mesh-sphere-scan", action="store_true",
        help="Compute a sphere-view hit+miss cloud cast from n to n unit sphere points.")

    arg3 = parser.add_argument_group("ray-scan modifiers").add_argument # brevity
    arg3("--n-sphere-points", type=int, default=4000,
        help="The number of unit-sphere points to sample rays from. Final result: n*(n-1).")
    arg3("--compute-miss-distances", action="store_true",
        help="Compute the distance to the nearest hit for each miss in the hit+miss clouds.")
    arg3("--fill-missing-uv-points", action="store_true",
        help="TODO")
    arg3("--no-filter-backhits", action="store_true",
        help="Do not filter scan hits on backside of mesh faces.")
    arg3("--no-unit-sphere", action="store_true",
        help="Do not center the objects to the unit sphere.")
    arg3("--convert-ok", action="store_true",
        help="Allow reusing point clouds for uv clouds and vice versa. (does not account for other hparams)")
    arg3("--debug", action="store_true",
        help="Abort on failure.")  # fixed typo: was "failiure"

    arg5 = parser.add_argument_group("Shared modifiers").add_argument # brevity
    arg5("--scan-resolution", type=int, default=400,
        help="The resolution of the depth map rendered to sample points. Becomes x*x")

    return parser
|
||||
|
||||
# entrypoint
def cli(parser: argparse.ArgumentParser = make_parser()):
    """CLI entrypoint: parse args, handle the --list modes, then run the
    selected preprocessing targets."""
    args = parser.parse_args()
    # bugfix: `--list-pages 0` used to be treated as "not selected" here while
    # the branch below tests `is not None`; both now use `is not None`
    if not any(getattr(args, k) for k in dir(args) if k.startswith("precompute_")) and not (args.list or args.list_pages is not None):
        parser.error("no preprocessing target selected") # exits

    config.DATA_PATH = Path(args.dir)
    obj_names = args.objects or read.list_object_names()

    if args.list:
        print(*obj_names, sep="\n")
        parser.exit()

    if args.list_pages is not None:
        print(*(
            f"--page {i} {args.list_pages} {obj_name}"
            for obj_name in obj_names
            for i in range(args.list_pages)
        ), sep="\n")
        parser.exit()

    # NOTE(review): --scan-resolution is defined by make_parser but never
    # forwarded to any read.* call below — confirm whether it should be.
    if args.precompute_mesh_sv_scan_clouds:
        read.precompute_mesh_scan_point_clouds(
            obj_names,
            compute_miss_distances = args.compute_miss_distances,
            no_filter_backhits     = args.no_filter_backhits,
            no_unit_sphere         = args.no_unit_sphere,
            convert_ok             = args.convert_ok,
            page                   = args.page,
            force                  = args.force,
            debug                  = args.debug,
        )
    if args.precompute_mesh_sv_scan_uvs:
        read.precompute_mesh_scan_uvs(
            obj_names,
            compute_miss_distances = args.compute_miss_distances,
            fill_missing_points    = args.fill_missing_uv_points,
            no_filter_backhits     = args.no_filter_backhits,
            no_unit_sphere         = args.no_unit_sphere,
            convert_ok             = args.convert_ok,
            page                   = args.page,
            force                  = args.force,
            debug                  = args.debug,
        )
    if args.precompute_mesh_sphere_scan:
        read.precompute_mesh_sphere_scan(
            obj_names,
            sphere_points          = args.n_sphere_points,
            compute_miss_distances = args.compute_miss_distances,
            no_filter_backhits     = args.no_filter_backhits,
            no_unit_sphere         = args.no_unit_sphere,
            page                   = args.page,
            force                  = args.force,
            debug                  = args.debug,
        )

if __name__ == "__main__":
    cli()
|
||||
251
ifield/data/stanford/read.py
Normal file
251
ifield/data/stanford/read.py
Normal file
@@ -0,0 +1,251 @@
|
||||
from . import config
|
||||
from ..common import points
|
||||
from ..common import processing
|
||||
from ..common.scan import SingleViewScan, SingleViewUVScan
|
||||
from ..common.types import MalformedMesh
|
||||
from functools import lru_cache, wraps
|
||||
from typing import Optional, Iterable
|
||||
from pathlib import Path
|
||||
import gzip
|
||||
import numpy as np
|
||||
import trimesh
|
||||
import trimesh.transformations as T
|
||||
|
||||
__doc__ = """
|
||||
Here are functions for reading and preprocessing the Stanford 3D scans benchmark data
|
||||
|
||||
There are essentially a few sets per object:
|
||||
"img" - meaning the RGBD images (none found in stanford)
|
||||
"mesh_scans" - meaning synthetic scans of a mesh
|
||||
"""
|
||||
|
||||
MESH_TRANSFORM_SKYWARD = T.rotation_matrix(np.pi/2, (1, 0, 0))
|
||||
MESH_TRANSFORM_CANONICAL = { # to gain a shared canonical orientation
|
||||
"armadillo" : T.rotation_matrix(np.pi, (0, 0, 1)) @ MESH_TRANSFORM_SKYWARD,
|
||||
"asian_dragon" : T.rotation_matrix(-np.pi/2, (0, 0, 1)) @ MESH_TRANSFORM_SKYWARD,
|
||||
"bunny" : MESH_TRANSFORM_SKYWARD,
|
||||
"dragon" : MESH_TRANSFORM_SKYWARD,
|
||||
"drill_bit" : MESH_TRANSFORM_SKYWARD,
|
||||
"happy_buddha" : MESH_TRANSFORM_SKYWARD,
|
||||
"lucy" : T.rotation_matrix(np.pi, (0, 0, 1)),
|
||||
"thai_statue" : MESH_TRANSFORM_SKYWARD,
|
||||
}
|
||||
|
||||
def list_object_names() -> list[str]:
    """List the names of the models whose mesh file is present on disk (downloaded)."""
    mesh_dir = config.DATA_PATH / "meshes"
    downloaded = []
    for name, model in config.MODELS.items():
        if (mesh_dir / model.mesh_fname).is_file():
            downloaded.append(name)
    return downloaded
|
||||
|
||||
@lru_cache(maxsize=1)
def list_mesh_scan_sphere_coords(n_poses: int = 50) -> list[tuple[float, float]]: # (theta, phi)
    # Equidistant camera poses on the unit sphere, as (theta, phi) sphere
    # coordinates. Cached since the same pose set is reused throughout.
    return points.generate_equidistant_sphere_points(n_poses, compute_sphere_coordinates=True)#, shift_theta=True
|
||||
|
||||
def mesh_scan_identifier(*, phi: float, theta: float) -> str:
    """Build a filesystem-safe identifier for a (theta, phi) camera pose.

    The sign of each angle is encoded as 'n'/'p' and the decimal point is
    replaced with 'd', e.g. theta=-0.5, phi=1.5 -> "n0d50p1d50".
    """
    theta_part = f"{'p' if theta >= 0 else 'n'}{abs(theta):.2f}"
    phi_part   = f"{'p' if phi   >= 0 else 'n'}{abs(phi):.2f}"
    return (theta_part + phi_part).replace(".", "d")
|
||||
|
||||
@lru_cache(maxsize=1)
def list_mesh_scan_identifiers(n_poses: int = 50) -> list[str]:
    """Identifiers for every scan pose, ordered like the pose list."""
    identifiers = [
        mesh_scan_identifier(phi=phi, theta=theta)
        for theta, phi in list_mesh_scan_sphere_coords(n_poses)
    ]
    # identifiers double as file name components, so they must be unique
    assert len(identifiers) == len(set(identifiers))
    return identifiers
|
||||
|
||||
# ===
|
||||
|
||||
@lru_cache(maxsize=1)
def read_mesh(obj_name: str) -> trimesh.Trimesh:
    # Load the (possibly gzip-compressed) mesh for `obj_name` and rotate it
    # into the shared canonical orientation. Cached so repeated scans of the
    # same object reuse the loaded mesh.
    path = config.DATA_PATH / "meshes" / config.MODELS[obj_name].mesh_fname
    if not path.exists():
        raise FileNotFoundError(f"{obj_name = } -> {str(path) = }")
    try:
        if path.suffixes[-1] == ".gz":
            # trimesh cannot read gzip directly: decompress and pass the real
            # file type taken from the remaining suffixes (e.g. ".ply")
            with gzip.open(path, "r") as f:
                mesh = trimesh.load(f, file_type="".join(path.suffixes[:-1])[1:])
        else:
            mesh = trimesh.load(path)
    except Exception as e:
        # normalize every loader failure into our own exception type
        raise MalformedMesh(f"Trimesh raised: {e.__class__.__name__}: {e}") from e

    # rotate to be upright in pyrender
    mesh.apply_transform(MESH_TRANSFORM_CANONICAL.get(obj_name, MESH_TRANSFORM_SKYWARD))

    return mesh
|
||||
|
||||
# === single-view scan clouds
|
||||
|
||||
def compute_mesh_scan_point_cloud(
        obj_name               : str,
        *,
        phi                    : float,
        theta                  : float,
        compute_miss_distances : bool = False,
        compute_normals        : bool = True,
        convert_ok             : bool = False, # this does not respect the other hparams
        **kw,
        ) -> SingleViewScan:
    """Synthesize a single-view scan point cloud of `obj_name` seen from (theta, phi)."""
    if convert_ok:
        # reuse a precomputed UV scan from disk when one exists
        try:
            return read_mesh_scan_uv(obj_name, phi=phi, theta=theta).to_scan()
        except FileNotFoundError:
            pass

    return SingleViewScan.from_mesh_single_view(
        read_mesh(obj_name),
        phi                    = phi,
        theta                  = theta,
        compute_normals        = compute_normals,
        compute_miss_distances = compute_miss_distances,
        **kw,
    )
|
||||
|
||||
def precompute_mesh_scan_point_clouds(obj_names, *, page: tuple[int, int] = (0, 1), force: bool = False, debug: bool = False, n_poses: int = 50, **kw):
    "precomputes all single-view scan clouds and stores them as HDF5 datasets"
    cam_poses        = list_mesh_scan_sphere_coords(n_poses)
    pose_identifiers = list_mesh_scan_identifiers (n_poses)
    assert len(cam_poses) == len(pose_identifiers)
    paths = list_mesh_scan_point_cloud_h5_fnames(obj_names, pose_identifiers)
    name_width = max(map(len, config.MODELS.keys()))
    # human-readable labels which also encode (obj_name, pose index) so that
    # `computer` can parse them back out
    pretty_identifiers = [
        f"{obj_name.ljust(name_width)} @ {pose_idx:>5} @ ({identifier}: {theta:.2f}, {phi:.2f})"
        for obj_name in obj_names
        for pose_idx, (identifier, (theta, phi)) in enumerate(zip(pose_identifiers, cam_poses))
    ]
    mesh_cache = []
    @wraps(compute_mesh_scan_point_cloud)
    def computer(pretty_identifier: str) -> SingleViewScan:
        obj_name, pose_idx, _ = map(str.strip, pretty_identifier.split("@"))
        theta, phi = cam_poses[int(pose_idx)]
        return compute_mesh_scan_point_cloud(obj_name, phi=phi, theta=theta, _mesh_cache=mesh_cache, **kw)
    return processing.precompute_data(computer, pretty_identifiers, paths, page=page, force=force, debug=debug)
|
||||
|
||||
def read_mesh_scan_point_cloud(obj_name, *, identifier: str = None, phi: float = None, theta: float = None) -> SingleViewScan:
    """Load a precomputed single-view scan cloud from disk.

    Either `identifier` or both `phi` and `theta` must be given.
    Raises FileNotFoundError if the scan has not been precomputed.
    """
    if identifier is None:
        if phi is None or theta is None:
            raise ValueError("Provide either phi+theta or an identifier!")
        identifier = mesh_scan_identifier(phi=phi, theta=theta)
    fname = config.DATA_PATH / "clouds" / obj_name / f"mesh_scan_{identifier}_clouds.h5"
    if not fname.exists():
        raise FileNotFoundError(str(fname))
    return SingleViewScan.from_h5_file(fname)
|
||||
|
||||
def list_mesh_scan_point_cloud_h5_fnames(obj_names: Iterable[str], identifiers: Optional[Iterable[str]] = None, **kw) -> list[Path]:
    """Cartesian product of objects x pose identifiers, mapped to scan-cloud HDF5 paths."""
    if identifiers is None:
        identifiers = list_mesh_scan_identifiers(**kw)
    fnames = []
    for obj_name in obj_names:
        cloud_dir = config.DATA_PATH / "clouds" / obj_name
        fnames.extend(cloud_dir / f"mesh_scan_{identifier}_clouds.h5" for identifier in identifiers)
    return fnames
|
||||
|
||||
# === single-view UV scan clouds
|
||||
|
||||
def compute_mesh_scan_uv(
        obj_name               : str,
        *,
        phi                    : float,
        theta                  : float,
        compute_miss_distances : bool = False,
        fill_missing_points    : bool = False,
        compute_normals        : bool = True,
        convert_ok             : bool = False,
        **kw,
        ) -> SingleViewUVScan:
    """Synthesize a single-view UV scan of `obj_name` seen from (theta, phi)."""
    if convert_ok:
        # reuse a precomputed point-cloud scan from disk when one exists
        # (NB: the conversion does not respect the other hyperparameters)
        try:
            return read_mesh_scan_point_cloud(obj_name, phi=phi, theta=theta).to_uv_scan()
        except FileNotFoundError:
            pass

    scan = SingleViewUVScan.from_mesh_single_view(
        read_mesh(obj_name),
        phi             = phi,
        theta           = theta,
        compute_normals = compute_normals,
        **kw,
    )
    if compute_miss_distances:
        scan.compute_miss_distances()
    if fill_missing_points:
        scan.fill_missing_points()
    return scan
|
||||
|
||||
def precompute_mesh_scan_uvs(obj_names, *, page: tuple[int, int] = (0, 1), force: bool = False, debug: bool = False, n_poses: int = 50, **kw):
    "precomputes all single-view UV scans and stores them as HDF5 datasets"
    cam_poses        = list_mesh_scan_sphere_coords(n_poses)
    pose_identifiers = list_mesh_scan_identifiers (n_poses)
    assert len(cam_poses) == len(pose_identifiers)
    paths = list_mesh_scan_uv_h5_fnames(obj_names, pose_identifiers)
    name_width = max(map(len, config.MODELS.keys()))
    # human-readable labels which also encode (obj_name, pose index) so that
    # `computer` can parse them back out
    pretty_identifiers = [
        f"{obj_name.ljust(name_width)} @ {pose_idx:>5} @ ({identifier}: {theta:.2f}, {phi:.2f})"
        for obj_name in obj_names
        for pose_idx, (identifier, (theta, phi)) in enumerate(zip(pose_identifiers, cam_poses))
    ]
    mesh_cache = []
    @wraps(compute_mesh_scan_uv)
    def computer(pretty_identifier: str) -> SingleViewUVScan: # bug fix: was annotated -> SingleViewScan
        obj_name, pose_idx, _ = map(str.strip, pretty_identifier.split("@"))
        theta, phi = cam_poses[int(pose_idx)]
        return compute_mesh_scan_uv(obj_name, phi=phi, theta=theta, _mesh_cache=mesh_cache, **kw)
    return processing.precompute_data(computer, pretty_identifiers, paths, page=page, force=force, debug=debug)
|
||||
|
||||
def read_mesh_scan_uv(obj_name, *, identifier: str = None, phi: float = None, theta: float = None) -> SingleViewUVScan:
    """Load a precomputed single-view UV scan from disk.

    Either `identifier` or both `phi` and `theta` must be given.
    Raises FileNotFoundError if the scan has not been precomputed.
    """
    if identifier is None:
        if phi is None or theta is None:
            raise ValueError("Provide either phi+theta or an identifier!")
        identifier = mesh_scan_identifier(phi=phi, theta=theta)
    fname = config.DATA_PATH / "clouds" / obj_name / f"mesh_scan_{identifier}_uv.h5"
    if not fname.exists():
        raise FileNotFoundError(str(fname))
    return SingleViewUVScan.from_h5_file(fname)
|
||||
|
||||
def list_mesh_scan_uv_h5_fnames(obj_names: Iterable[str], identifiers: Optional[Iterable[str]] = None, **kw) -> list[Path]:
    """Cartesian product of objects x pose identifiers, mapped to UV-scan HDF5 paths."""
    if identifiers is None:
        identifiers = list_mesh_scan_identifiers(**kw)
    fnames = []
    for obj_name in obj_names:
        cloud_dir = config.DATA_PATH / "clouds" / obj_name
        fnames.extend(cloud_dir / f"mesh_scan_{identifier}_uv.h5" for identifier in identifiers)
    return fnames
|
||||
|
||||
# === sphere-view (UV) scan clouds
|
||||
|
||||
def compute_mesh_sphere_scan(
        obj_name        : str,
        *,
        compute_normals : bool = True,
        **kw,
        ) -> SingleViewUVScan:
    """Synthesize a full sphere-view UV scan of `obj_name`."""
    return SingleViewUVScan.from_mesh_sphere_view(
        read_mesh(obj_name),
        compute_normals = compute_normals,
        **kw,
    )
|
||||
|
||||
def precompute_mesh_sphere_scan(obj_names, *, page: tuple[int, int] = (0, 1), force: bool = False, debug: bool = False, n_points: int = 4000, **kw):
    "precomputes all sphere-view scans and stores them as HDF5 datasets"
    # NOTE(review): `n_points` is accepted but never forwarded anywhere; the
    # point count actually reaches the scan via **kw (the CLI passes
    # `sphere_points`). Kept for backward compatibility -- confirm whether it
    # should be wired up or deprecated.
    paths = list_mesh_sphere_scan_h5_fnames(obj_names)
    @wraps(compute_mesh_sphere_scan)
    def computer(obj_name: str) -> SingleViewUVScan: # bug fix: was annotated -> SingleViewScan
        return compute_mesh_sphere_scan(obj_name, **kw)
    return processing.precompute_data(computer, obj_names, paths, page=page, force=force, debug=debug)
|
||||
|
||||
def read_mesh_mesh_sphere_scan(obj_name) -> SingleViewUVScan:
    """Load the precomputed sphere-view scan of `obj_name` from disk."""
    fname = config.DATA_PATH / "clouds" / obj_name / "mesh_sphere_scan.h5"
    if not fname.exists():
        raise FileNotFoundError(str(fname))
    return SingleViewUVScan.from_h5_file(fname)
|
||||
|
||||
def list_mesh_sphere_scan_h5_fnames(obj_names: Iterable[str]) -> list[Path]:
    """Map each object name to the path of its sphere-scan HDF5 file."""
    return [
        config.DATA_PATH / "clouds" / name / "mesh_sphere_scan.h5"
        for name in obj_names
    ]
|
||||
3
ifield/datasets/__init__.py
Normal file
3
ifield/datasets/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
__doc__ = """
|
||||
Submodules defining various `torch.utils.data.Dataset`
|
||||
"""
|
||||
196
ifield/datasets/common.py
Normal file
196
ifield/datasets/common.py
Normal file
@@ -0,0 +1,196 @@
|
||||
from ..data.common.h5_dataclasses import H5Dataclass, PathLike
|
||||
from torch.utils.data import Dataset, IterableDataset
|
||||
from typing import Any, Iterable, Hashable, TypeVar, Iterator, Callable
|
||||
from functools import partial, lru_cache
|
||||
import inspect
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
T_H5 = TypeVar("T_H5", bound=H5Dataclass)
|
||||
|
||||
|
||||
class TransformableDatasetMixin:
    # Mixin adding a `.map(func)` transform pipeline to a Dataset or
    # IterableDataset subclass. On subclass creation it wraps the subclass'
    # `__getitem__` (map-style) or `__iter__` (iterable-style) so registered
    # transforms are applied to every produced item.

    def __init_subclass__(cls):
        # Subclasses can opt out of the automatic wrapping (e.g. when they
        # apply transforms manually, like TransformExtendedDataset).
        if getattr(cls, "_transformable_mixin_no_override_getitem", False):
            pass
        elif issubclass(cls, Dataset):
            # Only wrap once: re-wrapping would recurse into the wrapper itself.
            if cls.__getitem__ is not cls._transformable_mixin_getitem_wrapper:
                cls._transformable_mixin_inner_getitem = cls.__getitem__
                cls.__getitem__ = cls._transformable_mixin_getitem_wrapper
        elif issubclass(cls, IterableDataset):
            if cls.__iter__ is not cls._transformable_mixin_iter_wrapper:
                cls._transformable_mixin_inner_iter = cls.__iter__
                cls.__iter__ = cls._transformable_mixin_iter_wrapper
        else:
            raise TypeError(f"{cls.__name__!r} is neither a Dataset nor a IterableDataset!")

    def __init__(self, *a, **kw):
        super().__init__(*a, **kw)
        # ordered pipeline of callables applied to each item
        self._transforms = []

    # works as a decorator
    def map(self: T, func: callable = None, /, args=[], **kw) -> T:
        # Register `func` as a transform; returns self so calls can chain.
        # Usable directly (`ds.map(f)`) or as a decorator (`@ds.map`).
        # NOTE: the mutable default `args=[]` is never mutated here, so it is
        # safe, but treat it as read-only.
        def wrapper(func) -> T:
            if args or kw:
                # bind extra positional/keyword arguments up front
                func = partial(func, *args, **kw)
            self._transforms.append(func)
            return self

        if func is None:
            return wrapper
        else:
            return wrapper(func)


    def _transformable_mixin_getitem_wrapper(self, index: int):
        # Replaces __getitem__ on map-style subclasses; the duplicated call in
        # both branches keeps the "no transforms" case distinguishable in
        # tracebacks/profiles via the trailing comments.
        if not self._transforms:
            out = self._transformable_mixin_inner_getitem(index) # (TransformableDatasetMixin, no transforms)
        else:
            out = self._transformable_mixin_inner_getitem(index) # (TransformableDatasetMixin, has transforms)
            for f in self._transforms:
                out = f(out) # (TransformableDatasetMixin)
        return out

    def _transformable_mixin_iter_wrapper(self):
        # Replaces __iter__ on iterable-style subclasses; transforms are
        # applied lazily via the builtin map().
        if not self._transforms:
            out = self._transformable_mixin_inner_iter() # (TransformableDatasetMixin, no transforms)
        else:
            out = self._transformable_mixin_inner_iter() # (TransformableDatasetMixin, has transforms)
            for f in self._transforms:
                out = map(f, out) # (TransformableDatasetMixin)
        return out
|
||||
|
||||
|
||||
class TransformedDataset(Dataset, TransformableDatasetMixin):
    """Wrap another dataset and apply a pipeline of transforms to each item."""

    def __init__(self, dataset: Dataset, transforms: Iterable[callable]):
        super().__init__()
        self.dataset = dataset
        for transform in transforms:
            self.map(transform)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index: int):
        # transforms are applied by the mixin's wrapper around this method
        return self.dataset[index] # (TransformedDataset)
|
||||
|
||||
|
||||
class TransformExtendedDataset(Dataset, TransformableDatasetMixin):
    """Wrap a dataset so each item appears once per registered transform.

    Index `i` maps to item `i // n_transforms` passed through transform
    `i % n_transforms`. Opts out of the mixin's automatic wrapping since it
    applies transforms itself.
    """
    _transformable_mixin_no_override_getitem = True

    def __init__(self, dataset: Dataset):
        super().__init__()
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset) * len(self._transforms)

    def __getitem__(self, index: int):
        n_transforms = len(self._transforms)
        assert n_transforms > 0, f"{len(self._transforms) = }"

        transform = self._transforms[index % n_transforms]
        return transform(self.dataset[index // n_transforms])
|
||||
|
||||
|
||||
class CachedDataset(Dataset):
|
||||
# used to wrap an another dataset
|
||||
def __init__(self, dataset: Dataset, cache_size: int | None):
|
||||
super().__init__()
|
||||
self.dataset = dataset
|
||||
if cache_size is not None and cache_size > 0:
|
||||
self.cached_getter = lru_cache(cache_size, self.dataset.__getitem__)
|
||||
else:
|
||||
self.cached_getter = self.dataset.__getitem__
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
return self.cached_getter(index)
|
||||
|
||||
|
||||
class AutodecoderDataset(Dataset, TransformableDatasetMixin):
    # Pairs each dataset item with a hashable key (DeepSDF-style auto-decoder
    # training, where each key selects a learnable latent code). Also exposes
    # a dict-like keys()/values()/items() interface.
    def __init__(self,
            keys    : Iterable[Hashable],
            dataset : Dataset,
            ):
        super().__init__()
        self.ad_mapping = list(keys)  # index -> key, parallel to `dataset`
        self.dataset    = dataset
        if len(self.ad_mapping) != len(dataset):
            raise ValueError(f"__len__ mismatch between keys and dataset: {len(self.ad_mapping)} != {len(dataset)}")

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index: int) -> tuple[Hashable, Any]:
        return self.ad_mapping[index], self.dataset[index] # (AutodecoderDataset)

    def keys(self) -> list[Hashable]:
        return self.ad_mapping

    def values(self) -> Iterator:
        return iter(self.dataset)

    def items(self) -> Iterable[tuple[Hashable, Any]]:
        return zip(self.ad_mapping, self.dataset)
|
||||
|
||||
|
||||
class FunctionDataset(Dataset, TransformableDatasetMixin):
    """Dataset backed by a getter function applied to a list of keys."""

    def __init__(self,
            getter     : Callable[[Hashable], T],
            keys       : list[Hashable],
            cache_size : int | None = None,
            ):
        super().__init__()
        # optionally memoize the getter (disabled when cache_size is None or <= 0)
        should_cache = cache_size is not None and cache_size > 0
        self.getter = lru_cache(cache_size)(getter) if should_cache else getter
        self.keys   = keys

    def __len__(self) -> int:
        return len(self.keys)

    def __getitem__(self, index: int) -> T:
        return self.getter(self.keys[index])
|
||||
|
||||
class H5Dataset(FunctionDataset):
    """FunctionDataset reading H5Dataclass instances from a list of HDF5 files."""

    def __init__(self,
            h5_dataclass_cls : type[T_H5],
            fnames           : list[PathLike],
            **kw,
            ):
        super().__init__(
            keys   = fnames,
            getter = h5_dataclass_cls.from_h5_file,
            **kw,
        )
|
||||
|
||||
class PaginatedH5Dataset(Dataset, TransformableDatasetMixin):
    """Read H5 dataclasses page by page: each file contributes `n_pages` items.

    Index `i` maps to file `i // n_pages`, page `i % n_pages`.
    """

    def __init__(self,
            h5_dataclass_cls   : type[T_H5],
            fnames             : list[PathLike],
            n_pages            : int = 10,
            require_even_pages : bool = True,
            ):
        super().__init__()
        self.h5_dataclass_cls   = h5_dataclass_cls
        self.fnames             = fnames
        self.n_pages            = n_pages
        self.require_even_pages = require_even_pages

    def __len__(self) -> int:
        return len(self.fnames) * self.n_pages

    def __getitem__(self, index: int) -> T_H5:
        item = index // self.n_pages
        page = index % self.n_pages

        return self.h5_dataclass_cls.from_h5_file( # (PaginatedH5Dataset)
            fname              = self.fnames[item], # bug fix: was `self.fname` (AttributeError)
            page               = page,
            n_pages            = self.n_pages,
            require_even_pages = self.require_even_pages,
        )
|
||||
40
ifield/datasets/coseg.py
Normal file
40
ifield/datasets/coseg.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from . import common
|
||||
from ..data.coseg import config
|
||||
from ..data.coseg import read
|
||||
from ..data.common import scan
|
||||
from typing import Iterable, Optional, Union
|
||||
import os
|
||||
|
||||
|
||||
class SingleViewUVScanDataset(common.H5Dataset):
    """Single-view UV scans for every model in the given COSEG object sets."""

    def __init__(self,
            object_sets : tuple[str],
            identifiers : Optional[Iterable[str]] = None,
            data_path   : Union[str, os.PathLike, None] = None,
            ):
        if not object_sets:
            raise ValueError("'object_sets' cannot be empty!")
        if identifiers is None:
            identifiers = read.list_mesh_scan_identifiers()
        if data_path is not None:
            # redirect the module-level data location before listing files
            config.DATA_PATH = data_path
        models = read.list_model_ids(object_sets)
        super().__init__(
            h5_dataclass_cls = scan.SingleViewUVScan,
            fnames           = read.list_mesh_scan_uv_h5_fnames(models, identifiers),
        )
|
||||
|
||||
class AutodecoderSingleViewUVScanDataset(common.AutodecoderDataset):
    # Pairs every UV scan with its model-id string (one key repeated per pose),
    # for auto-decoder training on the COSEG sets.
    def __init__(self,
            object_sets : tuple[str],
            identifiers : Optional[Iterable[str]] = None,
            data_path   : Union[str, os.PathLike, None] = None,
            ):
        if identifiers is None:
            identifiers = read.list_mesh_scan_identifiers()
        # here do this step first, such that all the duplicate strings reference the same object
        super().__init__(
            keys    = [key for key in read.list_model_id_strings(object_sets) for _ in range(len(identifiers))],
            dataset = SingleViewUVScanDataset(object_sets, identifiers, data_path=data_path),
        )
|
||||
64
ifield/datasets/stanford.py
Normal file
64
ifield/datasets/stanford.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from . import common
|
||||
from ..data.stanford import config
|
||||
from ..data.stanford import read
|
||||
from ..data.common import scan
|
||||
from typing import Iterable, Optional, Union
|
||||
import os
|
||||
|
||||
|
||||
class SingleViewUVScanDataset(common.H5Dataset):
    """Single-view UV scans of the given Stanford objects."""

    def __init__(self,
            obj_names   : Iterable[str],
            identifiers : Optional[Iterable[str]] = None,
            data_path   : Union[str, os.PathLike, None] = None,
            ):
        if not obj_names:
            raise ValueError("'obj_names' cannot be empty!")
        if identifiers is None:
            identifiers = read.list_mesh_scan_identifiers()
        if data_path is not None:
            # redirect the module-level data location before listing files
            config.DATA_PATH = data_path
        super().__init__(
            h5_dataclass_cls = scan.SingleViewUVScan,
            fnames           = read.list_mesh_scan_uv_h5_fnames(obj_names, identifiers),
        )
|
||||
|
||||
class AutodecoderSingleViewUVScanDataset(common.AutodecoderDataset):
    """Single-view UV scans keyed by object name (one key repeated per pose)."""

    def __init__(self,
            obj_names   : Iterable[str],
            identifiers : Optional[Iterable[str]] = None,
            data_path   : Union[str, os.PathLike, None] = None,
            ):
        if identifiers is None:
            identifiers = read.list_mesh_scan_identifiers()
        n_poses = len(identifiers)
        super().__init__(
            keys    = [name for name in obj_names for _ in range(n_poses)],
            dataset = SingleViewUVScanDataset(obj_names, identifiers, data_path=data_path),
        )
|
||||
|
||||
|
||||
class SphereScanDataset(common.H5Dataset):
    """Full sphere-view scans of the given Stanford objects."""

    def __init__(self,
            obj_names : Iterable[str],
            data_path : Union[str, os.PathLike, None] = None,
            ):
        if not obj_names:
            raise ValueError("'obj_names' cannot be empty!")
        if data_path is not None:
            # redirect the module-level data location before listing files
            config.DATA_PATH = data_path
        super().__init__(
            h5_dataclass_cls = scan.SingleViewUVScan,
            fnames           = read.list_mesh_sphere_scan_h5_fnames(obj_names),
        )
|
||||
|
||||
class AutodecoderSphereScanDataset(common.AutodecoderDataset):
    # Pairs each sphere-view scan with its object name (one scan per object,
    # so keys are used as-is without repetition).
    def __init__(self,
            obj_names : Iterable[str],
            data_path : Union[str, os.PathLike, None] = None,
            ):
        super().__init__(
            keys    = obj_names,
            dataset = SphereScanDataset(obj_names, data_path=data_path),
        )
|
||||
258
ifield/logging.py
Normal file
258
ifield/logging.py
Normal file
@@ -0,0 +1,258 @@
|
||||
from . import param
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from typing import Union, Literal, Optional, TypeVar
|
||||
import concurrent.futures
|
||||
import psutil
|
||||
import pytorch_lightning as pl
|
||||
import statistics
|
||||
import threading
|
||||
import time
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
# from https://github.com/yaml/pyyaml/issues/240#issuecomment-1018712495
def str_presenter(dumper, data):
    """configures yaml for dumping multiline strings
    Ref: https://stackoverflow.com/questions/8640959/how-can-i-control-what-scalar-form-pyyaml-uses-for-my-data"""
    if len(data.splitlines()) > 1: # check for multiline string
        # block style ('|') keeps multiline strings readable in dumped YAML
        return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|')
    return dumper.represent_scalar('tag:yaml.org,2002:str', data)
# register globally: affects every subsequent yaml.dump of str values
yaml.add_representer(str, str_presenter)
|
||||
|
||||
|
||||
LoggerStr = Literal[
|
||||
#"csv",
|
||||
"tensorboard",
|
||||
#"mlflow",
|
||||
#"comet",
|
||||
#"neptune",
|
||||
#"wandb",
|
||||
None]
|
||||
try:
|
||||
Logger = TypeVar("L", bound=pl.loggers.Logger)
|
||||
except AttributeError:
|
||||
Logger = TypeVar("L", bound=pl.loggers.base.LightningLoggerBase)
|
||||
|
||||
def make_logger(
        experiment_name  : str,
        default_root_dir : Union[str, Path], # from pl.Trainer
        save_dir         : Union[str, Path],
        type             : LoggerStr = "tensorboard",
        project          : str = "ifield",
        ) -> Optional[Logger]:
    """Construct the experiment logger, or None when logging is disabled.

    Raises ValueError for unknown logger types.
    """
    if type is None:
        return None
    if type == "tensorboard":
        return pl.loggers.TensorBoardLogger(
            name      = "tensorboard",
            save_dir  = Path(default_root_dir) / save_dir,
            version   = experiment_name,
            log_graph = True,
        )
    raise ValueError(f"make_logger({type=})")
|
||||
|
||||
def make_jinja_template(*, save_dir: Union[None, str, Path], **kw) -> str:
    # Render a jinja config template for make_logger's parameters, with
    # save_dir pre-filled and the trainer-supplied arguments excluded.
    return param.make_jinja_template(make_logger,
        defaults = dict(
            save_dir = save_dir,
        ),
        # these are provided by the caller at runtime, not via config
        exclude_list = {
            "experiment_name",
            "default_root_dir",
        },
        # default the template name to "logging", overridable via **kw
        **({"name": "logging"} | kw),
    )
|
||||
|
||||
def get_checkpoints(experiment_name, default_root_dir, save_dir, type, project) -> list[Path]:
    """Locate checkpoint files for an experiment under the given logger type.

    Returns None when `type` is None, otherwise a sorted list of .ckpt paths.
    Raises NotImplementedError for logger backends not yet supported.
    """
    if type is None:
        return None
    if type == "tensorboard":
        folder = Path(default_root_dir) / save_dir / "tensorboard" / experiment_name
        # bug fix: returned a lazy generator despite the list[Path] annotation;
        # sorted() materializes it and makes the order deterministic
        return sorted(folder.glob("*.ckpt"))
    if type == "mlflow":
        raise NotImplementedError(f"{type=}")
    if type == "wandb":
        raise NotImplementedError(f"{type=}")
    raise ValueError(f"get_checkpoint({type=})")
|
||||
|
||||
|
||||
def log_config(_logger: Logger, **kwargs: Union[str, dict, list, int, float]):
    """Record the experiment configuration as hyperparameters on the logger.

    The leading underscore keeps `_logger` from colliding with a config key
    passed via **kwargs.
    """
    # bug fix: `pl.loggers.Logger` does not exist on older pytorch-lightning
    # versions (cf. the Logger TypeVar try/except above), so the original
    # `isinstance(_logger, pl.loggers.Logger) or ...` raised AttributeError
    # before the fallback check could run. Resolve the classes defensively.
    logger_classes = []
    if hasattr(pl.loggers, "Logger"):
        logger_classes.append(pl.loggers.Logger)
    if hasattr(pl.loggers, "base"):
        logger_classes.append(pl.loggers.base.LightningLoggerBase)
    assert isinstance(_logger, tuple(logger_classes)), _logger

    _logger: pl.loggers.TensorBoardLogger
    _logger.log_hyperparams(params=kwargs)
|
||||
|
||||
@dataclass
class ModelOutputMonitor(pl.callbacks.Callback):
    """Callback logging the values returned by training/validation steps."""
    log_training   : bool = True
    log_validation : bool = True

    def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: Optional[str] = None) -> None:
        if not trainer.loggers:
            # bug fix: was `self._class__.__name__` (AttributeError on the error path)
            raise MisconfigurationException(f"Cannot use {self.__class__.__name__} callback with Trainer that has no logger.")

    @staticmethod
    def _log_outputs(trainer: pl.Trainer, pl_module: pl.LightningModule, outputs, fname: str):
        # Normalize the step output into a dict of named scalar values.
        if outputs is None:
            return
        elif isinstance(outputs, list) or isinstance(outputs, tuple):
            outputs = {
                f"loss[{i}]": v
                for i, v in enumerate(outputs)
            }
        elif isinstance(outputs, torch.Tensor):
            outputs = {
                "loss": outputs,
            }
        elif isinstance(outputs, dict):
            pass
        else:
            raise ValueError
        sep = trainer.logger.group_separator
        pl_module.log_dict({
            f"{pl_module.__class__.__qualname__}.{fname}{sep}{k}":
                float(v.item()) if isinstance(v, torch.Tensor) else float(v)
            for k, v in outputs.items()
        }, sync_dist=True)

    def on_train_batch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs, batch, batch_idx, unused=0):
        if self.log_training:
            self._log_outputs(trainer, pl_module, outputs, "training_step")

    def on_validation_batch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs, batch, batch_idx, dataloader_idx=0):
        if self.log_validation:
            self._log_outputs(trainer, pl_module, outputs, "validation_step")
|
||||
|
||||
class EpochTimeMonitor(pl.callbacks.Callback):
    """Callback logging the wall-clock duration of train/validation/test/predict epochs."""
    __slots__ = [
        "epoch_start",
        "epoch_start_train",
        "epoch_start_validation",
        "epoch_start_test",
        "epoch_start_predict",
    ]

    def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: Optional[str] = None) -> None:
        if not trainer.loggers:
            # bug fix: was `self._class__.__name__` (AttributeError on the error path)
            raise MisconfigurationException(f"Cannot use {self.__class__.__name__} callback with Trainer that has no logger.")


    @rank_zero_only
    def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        self.epoch_start_train = time.time()

    @rank_zero_only
    def on_validation_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        self.epoch_start_validation = time.time()

    @rank_zero_only
    def on_test_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        self.epoch_start_test = time.time()

    @rank_zero_only
    def on_predict_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        self.epoch_start_predict = time.time()

    @rank_zero_only
    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        t = time.time() - self.epoch_start_train
        del self.epoch_start_train
        sep = trainer.logger.group_separator
        trainer.logger.log_metrics({f"{self.__class__.__qualname__}{sep}epoch_train_time" : t}, step=trainer.global_step)

    @rank_zero_only
    def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        t = time.time() - self.epoch_start_validation
        del self.epoch_start_validation
        sep = trainer.logger.group_separator
        trainer.logger.log_metrics({f"{self.__class__.__qualname__}{sep}epoch_validation_time" : t}, step=trainer.global_step)

    @rank_zero_only
    def on_test_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        t = time.time() - self.epoch_start_test
        del self.epoch_start_test # bug fix: was `del self.epoch_start_validation` (copy-paste)
        sep = trainer.logger.group_separator
        trainer.logger.log_metrics({f"{self.__class__.__qualname__}{sep}epoch_test_time" : t}, step=trainer.global_step)

    @rank_zero_only
    def on_predict_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        t = time.time() - self.epoch_start_predict
        del self.epoch_start_predict # bug fix: was `del self.epoch_start_validation` (copy-paste)
        sep = trainer.logger.group_separator
        trainer.logger.log_metrics({f"{self.__class__.__qualname__}{sep}epoch_predict_time" : t}, step=trainer.global_step)
|
||||
|
||||
@dataclass
class PsutilMonitor(pl.callbacks.Callback):
    """Samples CPU/RAM (and GPU when available) utilization in a daemon thread."""
    sample_rate : float = 0.2 # times per second

    # flag polled by the sampling thread; flipped by on_fit_end
    _should_stop = False

    @rank_zero_only
    def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        if not trainer.loggers:
            # bug fix: was `self._class__.__name__` (AttributeError on the error path)
            raise MisconfigurationException(f"Cannot use {self.__class__.__name__} callback with Trainer that has no logger.")
        assert not hasattr(self, "_thread")

        self._should_stop = False
        self._thread = threading.Thread(
            target = self.thread_target,
            name   = self.thread_target.__qualname__,
            args   = [trainer],
            daemon = True,
        )
        self._thread.start()

    @rank_zero_only
    def on_fit_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        assert getattr(self, "_thread", None) is not None
        self._should_stop = True # signals the sampling loop to terminate
        del self._thread

    def thread_target(self, trainer: pl.Trainer):
        # NOTE(review): assumes both GPUAccelerator and CUDAAccelerator exist on
        # the installed pytorch-lightning version -- confirm across versions
        uses_gpu = isinstance(trainer.accelerator, (pl.accelerators.GPUAccelerator, pl.accelerators.CUDAAccelerator))
        gpu_ids = trainer.device_ids

        prefix = f"{self.__class__.__qualname__}{trainer.logger.group_separator}"

        while not self._should_stop:
            step = trainer.global_step
            p = psutil.Process()

            meminfo = p.memory_info()
            rss_ram = meminfo.rss / 1024**2 # MB
            vms_ram = meminfo.vms / 1024**2 # MB

            util_per_cpu = psutil.cpu_percent(percpu=True)
            # only count the CPUs this process is allowed to run on
            util_per_cpu = [util_per_cpu[i] for i in p.cpu_affinity()]
            util_total = statistics.mean(util_per_cpu)

            if uses_gpu:
                with concurrent.futures.ThreadPoolExecutor() as e:
                    # pl moved the helper between modules across versions
                    if hasattr(pl.accelerators, "cuda"):
                        gpu_stats = e.map(pl.accelerators.cuda.get_nvidia_gpu_stats, gpu_ids)
                    else:
                        gpu_stats = e.map(pl.accelerators.gpu.get_nvidia_gpu_stats, gpu_ids)
                    trainer.logger.log_metrics({
                        f"{prefix}ram.rss"   : rss_ram,
                        f"{prefix}ram.vms"   : vms_ram,
                        f"{prefix}cpu.total" : util_total,
                        **{ f"{prefix}cpu.{i:03}.utilization" : stat for i, stat in enumerate(util_per_cpu) },
                        **{
                            f"{prefix}gpu.{gpu_idx:02}.{key.split(' ',1)[0]}" : stat
                            for gpu_idx, stats in zip(gpu_ids, gpu_stats)
                            for key, stat in stats.items()
                        },
                    }, step = step)
            else:
                trainer.logger.log_metrics({
                    f"{prefix}cpu.total" : util_total,
                    **{ f"{prefix}cpu.{i:03}.utilization" : stat for i, stat in enumerate(util_per_cpu) },
                }, step = step)

            time.sleep(1 / self.sample_rate)
        print("DAEMON END")
|
||||
3
ifield/models/__init__.py
Normal file
3
ifield/models/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
__doc__ = """
|
||||
Contains Pytorch Models
|
||||
"""
|
||||
159
ifield/models/conditioning.py
Normal file
159
ifield/models/conditioning.py
Normal file
@@ -0,0 +1,159 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from torch import nn, Tensor
|
||||
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
|
||||
from typing import Hashable, Union, Optional, KeysView, ValuesView, ItemsView, Any, Sequence
|
||||
import torch
|
||||
|
||||
|
||||
class RequiresConditioner(nn.Module, ABC): # mixin
    """
    Mixin interface for modules whose forward pass is conditioned on a latent
    feature vector. Concrete subclasses declare how wide the conditioning
    vector is, how stored embeddings are initialized, and how a training batch
    is turned into a conditioning vector.
    """

    @property
    @abstractmethod
    def n_latent_features(self) -> int:
        "This should provide the width of the conditioning feature vector"
        ...

    @property
    @abstractmethod
    def latent_embeddings_init_std(self) -> float:
        "This should provide the standard deviation to initialize the latent features with. DeepSDF uses 0.01."
        ...

    @property
    @abstractmethod
    def latent_embeddings(self) -> Optional[Tensor]:
        # fix: was declared as `def latent_embeddings()` without `self`, which
        # breaks any subclass relying on the inherited property signature.
        """This property should return a tensor containing all stored embeddings, for use in computing auto-decoder losses"""
        ...

    @abstractmethod
    def encode(self, batch: Any, batch_idx: int, optimizer_idx: int) -> Tensor:
        "This should, given a training batch, return the encoded conditioning vector"
        ...
||||
|
||||
|
||||
class AutoDecoderModuleMixin(RequiresConditioner, ABC):
    """
    Populates dunder methods making it behave as a mapping.
    The mapping indexes into a stored set of learnable embedding vectors.

    Based on the auto-decoder architecture of
    J.J. Park, P. Florence, J. Straub, R. Newcombe, S. Lovegrove, DeepSDF:
    Learning Continuous Signed Distance Functions for Shape Representation, in:
    2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR),
    IEEE, Long Beach, CA, USA, 2019: pp. 165–174.
    https://doi.org/10.1109/CVPR.2019.00025.
    """

    # maps observation identifier -> row index into `autodecoder_embeddings`
    _autodecoder_mapping: dict[Hashable, int]
    # (n_observations, n_latent_features) learnable embedding table
    autodecoder_embeddings: nn.Parameter

    def __init__(self, *a, **kw):
        super().__init__(*a, **kw)

        # Backward compat: old checkpoints stored the uid mapping under its own
        # key; rewrite it into the extra-state slot before loading.
        @self._register_load_state_dict_pre_hook
        def hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
            if f"{prefix}_autodecoder_mapping" in state_dict:
                state_dict[f"{prefix}{_EXTRA_STATE_KEY_SUFFIX}"] = state_dict.pop(f"{prefix}_autodecoder_mapping")

        # UninitializedParameter subclass that materializes itself to whatever
        # shape the checkpoint provides; after materialize() the instance's
        # class becomes nn.Parameter, so the re-entrant copy_ call dispatches
        # to Tensor.copy_ and does NOT recurse.
        class ICanBeLoadedFromCheckpointsAndChangeShapeStopBotheringMePyTorchAndSitInTheCornerIKnowWhatIAmDoing(nn.UninitializedParameter):
            def copy_(self, other):
                self.materialize(other.shape, other.device, other.dtype)
                return self.copy_(other)
        self.autodecoder_embeddings = ICanBeLoadedFromCheckpointsAndChangeShapeStopBotheringMePyTorchAndSitInTheCornerIKnowWhatIAmDoing()

    # nn.Module interface

    def get_extra_state(self):
        # Persist the uid -> row mapping alongside the embedding weights.
        return {
            "ad_uids": getattr(self, "_autodecoder_mapping", {}),
        }

    def set_extra_state(self, obj):
        if "ad_uids" not in obj: # backward compat
            self._autodecoder_mapping = obj
        else:
            self._autodecoder_mapping = obj["ad_uids"]

    # RequiresConditioner interface

    @property
    def latent_embeddings(self) -> Tensor:
        return self.autodecoder_embeddings

    # my interface

    def set_observation_ids(self, z_uids: set[Hashable]):
        """Allocate a fresh embedding table with one row per unique observation id."""
        assert self.latent_embeddings_init_std is not None, f"{self.__module__}.{self.__class__.__qualname__}.latent_embeddings_init_std"
        assert self.n_latent_features is not None, f"{self.__module__}.{self.__class__.__qualname__}.n_latent_features"
        assert self.latent_embeddings_init_std > 0, self.latent_embeddings_init_std
        assert self.n_latent_features > 0, self.n_latent_features

        # sorted() makes row assignment deterministic across runs
        self._autodecoder_mapping = {
            k: i
            for i, k in enumerate(sorted(set(z_uids)))
        }

        if not len(z_uids) == len(self._autodecoder_mapping):
            raise ValueError(f"Observation identifiers are not unique! {z_uids = }")

        self.autodecoder_embeddings = nn.Parameter(
            torch.Tensor(len(self._autodecoder_mapping), self.n_latent_features)
            .normal_(mean=0, std=self.latent_embeddings_init_std)
            .to(self.device, self.dtype)
        )

    def add_key(self, z_uid: Hashable, z: Optional[Tensor] = None):
        # NOTE(review): unfinished — registers the uid but never grows the
        # embedding table; the bare attribute access below is dead code.
        if z_uid in self._autodecoder_mapping:
            raise ValueError(f"Observation identifier {z_uid!r} not unique!")

        self._autodecoder_mapping[z_uid] = len(self._autodecoder_mapping)
        self.autodecoder_embeddings
        raise NotImplementedError

    def __delitem__(self, z_uid: Hashable):
        # Remove the row for z_uid and shift all higher row indices down by one.
        i = self._autodecoder_mapping.pop(z_uid)
        for k, v in list(self._autodecoder_mapping.items()):
            if v > i:
                self._autodecoder_mapping[k] -= 1

        with torch.no_grad():
            self.autodecoder_embeddings = nn.Parameter(torch.cat((
                self.autodecoder_embeddings.detach()[:i, :],
                self.autodecoder_embeddings.detach()[i+1:, :],
            ), dim=0))

    def __contains__(self, z_uid: Hashable) -> bool:
        return z_uid in self._autodecoder_mapping

    def __getitem__(self, z_uids: Union[Hashable, Sequence[Hashable]]) -> Tensor:
        # A list/tuple of uids selects multiple rows via advanced indexing.
        if isinstance(z_uids, tuple) or isinstance(z_uids, list):
            key = tuple(map(self._autodecoder_mapping.__getitem__, z_uids))
        else:
            key = self._autodecoder_mapping[z_uids]
        return self.autodecoder_embeddings[key, :]

    def __iter__(self):
        # fix: returning dict.keys() (a KeysView) from __iter__ makes iter(self)
        # raise "TypeError: iter() returned non-iterator"; return an iterator.
        return iter(self._autodecoder_mapping)

    def keys(self) -> KeysView[Hashable]:
        """
        lists the identifiers of each code
        """
        return self._autodecoder_mapping.keys()

    def values(self) -> list[Tensor]:
        # returns a list of embedding rows (not a live ValuesView)
        return list(self.autodecoder_embeddings)

    def items(self) -> ItemsView[Hashable, Tensor]:
        """
        lists all the learned codes / latent vectors with their identifiers as keys
        """
        return {
            k : self.autodecoder_embeddings[i]
            for k, i in self._autodecoder_mapping.items()
        }.items()
|
||||
|
||||
class EncoderModuleMixin(RequiresConditioner, ABC):
    # Encoder-style conditioning: latent codes are produced on the fly by
    # `encode()`, so there is no persistent embedding table to expose or
    # regularize.
    @property
    def latent_embeddings(self) -> None:
        """Always None: an encoder does not store per-observation embeddings."""
        return None
|
||||
589
ifield/models/intersection_fields.py
Normal file
589
ifield/models/intersection_fields.py
Normal file
@@ -0,0 +1,589 @@
|
||||
from .. import param
|
||||
from ..modules.dtype import DtypeMixin
|
||||
from ..utils import geometry
|
||||
from ..utils.helpers import compose
|
||||
from ..utils.loss import Schedulable, ensure_schedulables, HParamSchedule, HParamScheduleBase, Linear
|
||||
from ..utils.operators import diff
|
||||
from .conditioning import RequiresConditioner, AutoDecoderModuleMixin
|
||||
from .medial_atoms import MedialAtomNet
|
||||
from .orthogonal_plane import OrthogonalPlaneNet
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from torch import Tensor
|
||||
from torch.nn import functional as F
|
||||
from typing import TypedDict, Literal, Union, Hashable, Optional
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import os
|
||||
|
||||
# When set (default), every loss term is computed for logging even if its
# weight is zero; the scaling pass later drops zero-weighted terms from backprop.
LOG_ALL_METRICS = bool(int(os.environ.get("IFIELD_LOG_ALL_METRICS", "1")))

if __debug__:
    def broadcast_tensors(*tensors: torch.Tensor) -> list[torch.Tensor]:
        """
        Debug wrapper around torch.broadcast_tensors that converts the opaque
        RuntimeError into a ValueError naming each input's shape (a=..., b=...).
        """
        try:
            return torch.broadcast_tensors(*tensors)
        except RuntimeError as e:
            shapes = ", ".join(f"{chr(c)}.size={tuple(t.shape)}" for c, t in enumerate(tensors, ord("a")))
            # fix: chain explicitly so the original error is marked as the cause
            raise ValueError(f"Could not broadcast tensors {shapes}.\n{str(e)}") from e
else:
    # with -O, skip the wrapper entirely
    broadcast_tensors = torch.broadcast_tensors
|
||||
|
||||
|
||||
class ForwardDepthMapsBatch(TypedDict):
    # Camera-parameterized rays: a pose + intrinsics + per-pixel uv grid.
    cam2world : Tensor # (B, 4, 4)
    uv : Tensor # (B, H, W)
    intrinsics : Tensor # (B, 3, 3)

class ForwardScanRaysBatch(TypedDict):
    # Explicit rays: shared or per-ray origins plus unit direction vectors.
    origins : Tensor # (B, H, W, 3) or (B, 3)
    dirs : Tensor # (B, H, W, 3)

class LossBatch(TypedDict):
    # Supervision targets for training_step.
    # NOTE(review): training_step also reads batch["points"] (target surface
    # points), which is not declared here — confirm against the dataloader.
    hits : Tensor # (B, H, W) dtype=bool
    miss : Tensor # (B, H, W) dtype=bool
    depths : Tensor # (B, H, W)
    normals : Tensor # (B, H, W, 3) NaN if not hit
    distances : Tensor # (B, H, W, 1) NaN if not miss

class LabeledBatch(TypedDict):
    # Observation identifiers used to look up / learn auto-decoder codes.
    z_uid : list[Hashable]

# A forward batch is either camera-based or explicit rays; a training batch
# additionally carries supervision targets and observation labels.
ForwardBatch = Union[ForwardDepthMapsBatch, ForwardScanRaysBatch]
TrainingBatch = Union[ForwardBatch, LossBatch, LabeledBatch]


# The two supported output heads (see IntersectionFieldModel.__init__).
IntersectionMode = Literal[
    "medial_sphere",
    "orthogonal_plane",
]
|
||||
|
||||
class IntersectionFieldModel(pl.LightningModule, RequiresConditioner, DtypeMixin):
    """
    Lightning module wrapping a ray -> surface-intersection network.

    Depending on `output_mode`, the underlying net predicts either medial
    spheres (MedialAtomNet) or an orthogonal-plane intersection
    (OrthogonalPlaneNet). Loss weights are `Schedulable` hyperparameters
    resolved per-epoch in training_step.
    """
    net: Union[MedialAtomNet, OrthogonalPlaneNet]

    @ensure_schedulables
    def __init__(self,
            # mode
            input_mode : geometry.RayEmbedding = "plucker",
            output_mode : IntersectionMode = "medial_sphere",

            # network
            latent_features : int = 256,
            hidden_features : int = 512,
            hidden_layers : int = 8,
            improve_miss_grads: bool = True,
            normalize_ray_dirs: bool = False, # the dataset is usually already normalized, but this could still be important for backprop

            # orthogonal plane
            loss_hit_cross_entropy : Schedulable = 1.0,

            # medial atoms
            loss_intersection : Schedulable = 1,
            loss_intersection_l2 : Schedulable = 0,
            loss_intersection_proj : Schedulable = 0,
            loss_intersection_proj_l2 : Schedulable = 0,
            loss_normal_cossim : Schedulable = 0.25, # supervise target normal cosine similarity
            loss_normal_euclid : Schedulable = 0, # supervise target normal l2 distance
            loss_normal_cossim_proj : Schedulable = 0, # supervise target normal cosine similarity
            loss_normal_euclid_proj : Schedulable = 0, # supervise target normal l2 distance
            loss_hit_nodistance_l1 : Schedulable = 0, # constrain no miss distance for hits
            loss_hit_nodistance_l2 : Schedulable = 32, # constrain no miss distance for hits
            loss_miss_distance_l1 : Schedulable = 0, # supervise target miss distance for misses
            loss_miss_distance_l2 : Schedulable = 0, # supervise target miss distance for misses
            loss_inscription_hits : Schedulable = 0, # Penalize atom candidates using the supervision data of a different ray
            loss_inscription_hits_l2: Schedulable = 0, # Penalize atom candidates using the supervision data of a different ray
            loss_inscription_miss : Schedulable = 0, # Penalize atom candidates using the supervision data of a different ray
            loss_inscription_miss_l2: Schedulable = 0, # Penalize atom candidates using the supervision data of a different ray
            loss_sphere_grow_reg : Schedulable = 0, # maximialize sphere size
            loss_sphere_grow_reg_hit: Schedulable = 0, # maximialize sphere size
            loss_embedding_norm : Schedulable = "0.01**2 * Linear(15)", # DeepSDF schedules over 150 epochs. DeepSDF use 0.01**2, irobot uses 0.04**2
            loss_multi_view_reg : Schedulable = 0, # minimize gradient w.r.t. delta ray dir, when ray origin = intersection
            loss_atom_centroid_norm_std_reg : Schedulable = 0, # minimize per-atom centroid std

            # optimization
            opt_learning_rate : Schedulable = 1e-5,
            opt_weight_decay : float = 0,
            opt_warmup : float = 0,
            **kw,
            ):
        super().__init__()
        # warmup is expressed as a linear ramp schedule over epochs
        opt_warmup = Linear(opt_warmup)
        opt_warmup._param_name = "opt_warmup"
        self.save_hyperparameters()


        # "half" input embeddings only make sense with multiple medial atoms
        if "half" in input_mode:
            assert output_mode == "medial_sphere" and kw.get("n_atoms", 1) > 1

        assert output_mode in ["medial_sphere", "orthogonal_plane"]
        assert opt_weight_decay >= 0, opt_weight_decay

        # remaining **kw are forwarded to the chosen net (e.g. n_atoms)
        if output_mode == "orthogonal_plane":
            self.net = OrthogonalPlaneNet(
                in_features = self.n_input_embedding_features,
                hidden_layers = hidden_layers,
                hidden_features = hidden_features,
                latent_features = latent_features,
                **kw,
            )
        elif output_mode == "medial_sphere":
            self.net = MedialAtomNet(
                in_features = self.n_input_embedding_features,
                hidden_layers = hidden_layers,
                hidden_features = hidden_features,
                latent_features = latent_features,
                **kw,
            )

    def on_fit_start(self):
        # Sanity-check that every schedule stays positive over the run length.
        if __debug__:
            for k, v in self.hparams.items():
                if isinstance(v, HParamScheduleBase):
                    v.assert_positive(self.trainer.max_epochs)

    @property
    def n_input_embedding_features(self) -> int:
        # width of the ray embedding fed to the net (depends on input_mode)
        return geometry.ray_input_embedding_length(self.hparams.input_mode)

    @property
    def n_latent_features(self) -> int:
        return self.hparams.latent_features

    @property
    def latent_embeddings_init_std(self) -> float:
        # DeepSDF's initialization std (see RequiresConditioner)
        return 0.01

    @property
    def is_conditioned(self):
        return self.net.is_conditioned

    @property
    def is_double_backprop(self) -> bool:
        # True when training needs gradients w.r.t. the ray inputs themselves
        return self.is_double_backprop_origins or self.is_double_backprop_dirs

    @property
    def is_double_backprop_origins(self) -> bool:
        # "prif" here presumably refers to the orthogonal-plane (PRIF-style)
        # head — TODO confirm; its normals come from d(intersection)/d(origin)
        prif = self.hparams.output_mode == "orthogonal_plane"
        return prif and self.hparams.loss_normal_cossim

    @property
    def is_double_backprop_dirs(self) -> bool:
        return self.hparams.loss_multi_view_reg

    @classmethod
    @compose("\n".join)
    def make_jinja_template(cls, *, exclude_list: set[str] = {}, top_level: bool = True, **kw) -> str:
        """Yield config-template fragments for this class and its sub-net."""
        yield param.make_jinja_template(cls, top_level=top_level, **kw)
        yield MedialAtomNet.make_jinja_template(top_level=False, exclude_list={
            "in_features",
            "hidden_layers",
            "hidden_features",
            "latent_features",
        })

    def batch2rays(self, batch: ForwardBatch) -> tuple[Tensor, Tensor]:
        """Extract (ray_origins, ray_dirs) from either batch flavor."""
        if "uv" in batch:
            # NOTE(review): the camera-uv path is disabled; everything after
            # this raise is dead code kept for reference.
            raise NotImplementedError
            assert not (self.hparams.loss_multi_view_reg and self.training)
            ray_origins, \
            ray_dirs, \
            = geometry.camera_uv_to_rays(
                cam2world = batch["cam2world"],
                uv = batch["uv"],
                intrinsics = batch["intrinsics"],
            )
        else:
            # multi-view regularization trains with origins moved to the
            # target surface points
            ray_origins = batch["points" if self.hparams.loss_multi_view_reg and self.training else "origins"]
            ray_dirs = batch["dirs"]
        return ray_origins, ray_dirs

    def forward(self,
            batch : ForwardBatch,
            z : Optional[Tensor] = None, # latent code
            *,
            return_input : bool = False,
            allow_nans : bool = False, # in output
            **kw,
            ) -> tuple[torch.Tensor, ...]:
        """
        Embed the batch's rays, run the net, and compute ray intersections.
        Extra **kw are forwarded to net.compute_intersections.
        Returns the intersection tuple, or (origins, dirs, input, intersections)
        when return_input is set.
        """
        (
            ray_origins, # (B, 3)
            ray_dirs, # (B, H, W, 3)
        ) = self.batch2rays(batch)

        # Ensure rays are normalized
        # NOTICE: this is slow, make sure to train with optimizations!
        assert ray_dirs.detach().norm(dim=-1).allclose(torch.ones(ray_dirs.shape[:-1], **self.device_and_dtype)),\
            ray_dirs.detach().norm(dim=-1)

        # broadcast shared (B, 3) origins over the (H, W) ray grid
        if ray_origins.ndim + 2 == ray_dirs.ndim:
            ray_origins = ray_origins[..., None, None, :]

        ray_origins, ray_dirs = broadcast_tensors(ray_origins, ray_dirs)

        # enable input gradients when a loss differentiates w.r.t. the rays
        if self.is_double_backprop and self.training:
            if self.is_double_backprop_dirs:
                ray_dirs.requires_grad = True
            if self.is_double_backprop_origins:
                ray_origins.requires_grad = True
            assert ray_origins.requires_grad or ray_dirs.requires_grad

        input = geometry.ray_input_embedding(
            ray_origins, ray_dirs,
            mode = self.hparams.input_mode,
            normalize_dirs = self.hparams.normalize_ray_dirs,
            is_training = self.training,
        )
        assert not input.detach().isnan().any()

        predictions = self.net(input, z)

        intersections = self.net.compute_intersections(
            ray_origins, ray_dirs, predictions,
            allow_nans = allow_nans and not self.training, **kw
        )
        if return_input:
            return ray_origins, ray_dirs, input, intersections
        else:
            return intersections

    def training_step(self, batch: TrainingBatch, batch_idx: int, *, is_validation=False) -> Tensor:
        """
        Compute all enabled loss terms plus logging metrics for one batch.
        Returns a dict with the scaled "loss" plus unscaled_*/metric_* entries.
        NOTE(review): `is_validation` is accepted (see validation_step) but
        currently unused in the body.
        """
        z = self.encode(batch) if self.is_conditioned else None
        assert self.is_conditioned or len(set(batch["z_uid"])) <= 1, \
            f"Network is unconditioned, but the batch has multiple uids: {set(batch['z_uid'])!r}"

        # unpack
        target_hits = batch["hits"] # (B, H, W) dtype=bool
        target_miss = batch["miss"] # (B, H, W) dtype=bool
        target_points = batch["points"] # (B, H, W, 3)
        target_normals = batch["normals"] # (B, H, W, 3) NaN if not hit
        target_distances = batch["distances"] # (B, H, W) NaN if not miss
        assert not target_normals [target_hits].isnan().any()
        assert not target_distances[target_miss].isnan().any()
        # zero out NaN normals so arithmetic below stays finite
        target_normals[target_normals.isnan()] = 0
        assert not target_normals .isnan().any()

        # make z fit batch scheme
        if z is not None:
            z = z[..., None, None, :]

        losses = {}
        metrics = {}
        zeros = torch.zeros_like(target_distances)

        if self.hparams.output_mode == "medial_sphere":
            assert isinstance(self.net, MedialAtomNet)
            # selected-atom outputs first, then per-atom ("all_*") outputs
            ray_origins, ray_dirs, plucker, (
                depths, # (...) float, projection if not hit
                silhouettes, # (...) float
                intersections, # (..., 3) float, projection or NaN if not hit
                intersection_normals, # (..., 3) float, rejection or NaN if not hit
                is_intersecting, # (...) bool, true if hit
                sphere_centers, # (..., 3) network output
                sphere_radii, # (...) network output

                atom_indices,
                all_intersections, # (..., N_ATOMS) float, projection or NaN if not hit
                all_intersection_normals, # (..., N_ATOMS, 3) float, rejection or NaN if not hit
                all_depths, # (..., N_ATOMS) float, projection if not hit
                all_silhouettes, # (..., N_ATOMS, 3) float, projection or NaN if not hit
                all_is_intersecting, # (..., N_ATOMS) bool, true if hit
                all_sphere_centers, # (..., N_ATOMS, 3) network output
                all_sphere_radii, # (..., N_ATOMS) network output
            ) = self(batch, z,
                intersections_only = False,
                return_all_atoms = True,
                allow_nans = False,
                return_input = True,
                improve_miss_grads = True,
            )

            # target hit supervision
            # (losses are also computed when their weight is zero, purely for
            # logging, whenever __debug__ or LOG_ALL_METRICS is set)
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_intersection: # scores true hits
                losses["loss_intersection"] = (
                    (target_points - intersections).norm(dim=-1)
                ).where(target_hits & is_intersecting, zeros).mean()
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_intersection_l2: # scores true hits
                losses["loss_intersection_l2"] = (
                    (target_points - intersections).pow(2).sum(dim=-1)
                ).where(target_hits & is_intersecting, zeros).mean()
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_intersection_proj: # scores misses as if they were hits, using the projection
                losses["loss_intersection_proj"] = (
                    (target_points - intersections).norm(dim=-1)
                ).where(target_hits, zeros).mean()
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_intersection_proj_l2: # scores misses as if they were hits, using the projection
                losses["loss_intersection_proj_l2"] = (
                    (target_points - intersections).pow(2).sum(dim=-1)
                ).where(target_hits, zeros).mean()

            # target hit normal supervision
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_normal_cossim: # scores true hits
                losses["loss_normal_cossim"] = (
                    1 - torch.cosine_similarity(target_normals, intersection_normals, dim=-1)
                ).where(target_hits & is_intersecting, zeros).mean()
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_normal_euclid: # scores true hits
                losses["loss_normal_euclid"] = (
                    (target_normals - intersection_normals).norm(dim=-1)
                ).where(target_hits & is_intersecting, zeros).mean()
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_normal_cossim_proj: # scores misses as if they were hits
                losses["loss_normal_cossim_proj"] = (
                    1 - torch.cosine_similarity(target_normals, intersection_normals, dim=-1)
                ).where(target_hits, zeros).mean()
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_normal_euclid_proj: # scores misses as if they were hits
                losses["loss_normal_euclid_proj"] = (
                    (target_normals - intersection_normals).norm(dim=-1)
                ).where(target_hits, zeros).mean()

            # target sufficient hit radius
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_hit_nodistance_l1: # ensures hits become hits, instead of relying on the projection being right
                losses["loss_hit_nodistance_l1"] = (
                    silhouettes
                ).where(target_hits & (silhouettes > 0), zeros).mean()
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_hit_nodistance_l2: # ensures hits become hits, instead of relying on the projection being right
                losses["loss_hit_nodistance_l2"] = (
                    silhouettes
                ).where(target_hits & (silhouettes > 0), zeros).pow(2).mean()

            # target miss supervision
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_miss_distance_l1: # only positive misses reinforcement
                losses["loss_miss_distance_l1"] = (
                    target_distances - silhouettes
                ).where(target_miss, zeros).abs().mean()
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_miss_distance_l2: # only positive misses reinforcement
                losses["loss_miss_distance_l2"] = (
                    target_distances - silhouettes
                ).where(target_miss, zeros).pow(2).mean()

            # incentivise maximal spheres
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_sphere_grow_reg: # all atoms
                losses["loss_sphere_grow_reg"] = ((all_sphere_radii.detach() + 1) - all_sphere_radii).abs().mean()
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_sphere_grow_reg_hit: # true hits only
                losses["loss_sphere_grow_reg_hit"] = ((sphere_radii.detach() + 1) - sphere_radii).where(target_hits & is_intersecting, zeros).abs().mean()

            # spherical latent prior
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_embedding_norm:
                losses["loss_embedding_norm"] = self.latent_embeddings.norm(dim=-1).mean()


            is_grad_enabled = torch.is_grad_enabled()

            # multi-view regularization: atom should not change when view changes
            if self.hparams.loss_multi_view_reg and is_grad_enabled:
                assert ray_dirs.requires_grad, ray_dirs
                assert plucker.requires_grad, plucker
                assert intersections.grad_fn is not None
                assert intersection_normals.grad_fn is not None

                *center_grads, radii_grads = diff.gradients(
                    sphere_centers[..., 0],
                    sphere_centers[..., 1],
                    sphere_centers[..., 2],
                    sphere_radii,
                    wrt=ray_dirs,
                )

                losses["loss_multi_view_reg"] = (
                    sum(
                        i.pow(2).sum(dim=-1)
                        for i in center_grads
                    ).where(target_hits & is_intersecting, zeros).mean()
                    +
                    radii_grads.pow(2).sum(dim=-1)
                    .where(target_hits & is_intersecting, zeros).mean()
                )

            # minimize the volume spanned by each atom
            if self.hparams.loss_atom_centroid_norm_std_reg and self.net.n_atoms > 1:
                assert len(all_sphere_centers.shape) == 5, all_sphere_centers.shape
                losses["loss_atom_centroid_norm_std_reg"] \
                    = ((
                        all_sphere_centers
                        - all_sphere_centers
                        .mean(dim=(1, 2), keepdim=True)
                    ).pow(2).sum(dim=-1) - 0.05**2).clamp(0, None).mean()

            # prif is l1, LSMAT is l2
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_inscription_hits or self.hparams.loss_inscription_miss or self.hparams.loss_inscription_hits_l2 or self.hparams.loss_inscription_miss_l2:
                # re-score each object's atoms against a random permutation of
                # that object's own rays (inscription constraints)
                b = target_hits.shape[0] # number of objects
                n = target_hits.shape[1:].numel() # rays per object
                perm = torch.randperm(n, device=self.device) # ray2ray permutation
                flatten = dict(start_dim=1, end_dim=len(target_hits.shape) - 1)

                (
                    inscr_sphere_center_projs, # (b, n, n_atoms, 3)
                    inscr_intersections_near, # (b, n, n_atoms, 3)
                    inscr_intersections_far, # (b, n, n_atoms, 3)
                    inscr_is_intersecting, # (b, n, n_atoms) dtype=bool
                ) = geometry.ray_sphere_intersect(
                    ray_origins.flatten(**flatten)[:, perm, None, :],
                    ray_dirs .flatten(**flatten)[:, perm, None, :],
                    all_sphere_centers.flatten(**flatten),
                    all_sphere_radii .flatten(**flatten),
                    return_parts = True,
                    allow_nans = False,
                    improve_miss_grads = self.hparams.improve_miss_grads,
                )
                assert inscr_sphere_center_projs.shape == (b, n, self.net.n_atoms, 3), \
                    (inscr_sphere_center_projs.shape, (b, n, self.net.n_atoms, 3))
                # signed distance from permuted ray to each sphere surface
                inscr_silhouettes = (
                    inscr_sphere_center_projs - all_sphere_centers.flatten(**flatten)
                ).norm(dim=-1) - all_sphere_radii.flatten(**flatten)

                # penalize atoms protruding past another ray's hit point ...
                loss_inscription_hits = (
                    (
                        (inscr_intersections_near - target_points.flatten(**flatten)[:, perm, None, :])
                        * ray_dirs.flatten(**flatten)[:, perm, None, :]
                    ).sum(dim=-1)
                ).where(target_hits.flatten(**flatten)[:, perm, None] & inscr_is_intersecting,
                    torch.zeros(inscr_intersections_near.shape[:-1], **self.device_and_dtype),
                ).clamp(None, 0)
                # ... or intruding into another ray's miss margin
                loss_inscription_miss = (
                    inscr_silhouettes - target_distances.flatten(**flatten)[:, perm, None]
                ).where(target_miss.flatten(**flatten)[:, perm, None],
                    torch.zeros_like(inscr_silhouettes)
                ).clamp(None, 0)

                if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_inscription_hits:
                    losses["loss_inscription_hits"] = loss_inscription_hits.neg().mean()
                if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_inscription_miss:
                    losses["loss_inscription_miss"] = loss_inscription_miss.neg().mean()
                if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_inscription_hits_l2:
                    losses["loss_inscription_hits_l2"] = loss_inscription_hits.pow(2).mean()
                if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_inscription_miss_l2:
                    losses["loss_inscription_miss_l2"] = loss_inscription_miss.pow(2).mean()

            # metrics
            metrics["iou"] = (
                ((~target_miss) & is_intersecting.detach()).sum() /
                ((~target_miss) | is_intersecting.detach()).sum()
            )
            metrics["radii"] = sphere_radii.detach().mean() # with the constant applied pressure, we need to measure it this way instead

        elif self.hparams.output_mode == "orthogonal_plane":
            assert isinstance(self.net, OrthogonalPlaneNet)
            ray_origins, ray_dirs, input_embedding, (
                intersections, # (..., 3) dtype=float
                is_intersecting, # (...) dtype=float
            ) = self(batch, z, return_input=True, normalize_origins=True)

            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_intersection:
                losses["loss_intersection"] = (
                    (intersections - target_points).norm(dim=-1)
                ).where(target_hits, zeros).mean()
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_intersection_l2:
                losses["loss_intersection_l2"] = (
                    (intersections - target_points).pow(2).sum(dim=-1)
                ).where(target_hits, zeros).mean()

            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_hit_cross_entropy:
                # is_intersecting is a logit here
                losses["loss_hit_cross_entropy"] = (
                    F.binary_cross_entropy_with_logits(is_intersecting, (~target_miss).to(self.dtype))
                ).mean()

            # normals via the jacobian of the intersection w.r.t. the origin
            if self.hparams.loss_normal_cossim and torch.is_grad_enabled():
                jac = diff.jacobian(intersections, ray_origins)
                intersection_normals = self.compute_normals_from_intersection_origin_jacobian(jac, ray_dirs)
                losses["loss_normal_cossim"] = (
                    1 - torch.cosine_similarity(target_normals, intersection_normals, dim=-1)
                ).where(target_hits, zeros).mean()

            if self.hparams.loss_normal_euclid and torch.is_grad_enabled():
                jac = diff.jacobian(intersections, ray_origins)
                intersection_normals = self.compute_normals_from_intersection_origin_jacobian(jac, ray_dirs)
                losses["loss_normal_euclid"] = (
                    (target_normals - intersection_normals).norm(dim=-1)
                ).where(target_hits, zeros).mean()

            if self.hparams.loss_multi_view_reg and torch.is_grad_enabled():
                assert ray_dirs .requires_grad, ray_dirs
                assert intersections.grad_fn is not None
                grads = diff.gradients(
                    intersections[..., 0],
                    intersections[..., 1],
                    intersections[..., 2],
                    wrt=ray_dirs,
                )
                losses["loss_multi_view_reg"] = sum(
                    i.pow(2).sum(dim=-1)
                    for i in grads
                ).where(target_hits, zeros).mean()

            metrics["iou"] = (
                ((~target_miss) & (is_intersecting>0.5).detach()).sum() /
                ((~target_miss) | (is_intersecting>0.5).detach()).sum()
            )
        else:
            raise NotImplementedError(self.hparams.output_mode)

        # output losses and metrics

        # apply scaling:
        losses_unscaled = losses.copy() # shallow copy
        for k in list(losses.keys()):
            assert losses[k].numel() == 1, f"losses[{k!r}] shape: {losses[k].shape}"
            val_schedule: HParamSchedule = self.hparams[k]
            val = val_schedule.get(self)
            if val == 0:
                if (__debug__ or LOG_ALL_METRICS) and val_schedule.is_const:
                    del losses[k] # it was only added for unscaled logging, do not backprop
                else:
                    losses[k] = 0
            elif val != 1:
                losses[k] = losses[k] * val

        if not losses:
            raise MisconfigurationException("no loss was computed")

        # warmup scales the total loss during the first epochs
        losses["loss"] = sum(losses.values()) * self.hparams.opt_warmup.get(self)
        losses.update({f"unscaled_{k}": v.detach() for k, v in losses_unscaled.items()})
        losses.update({f"metric_{k}": v.detach() for k, v in metrics.items()})
        return losses


    # used by pl.callbacks.EarlyStopping, via cli.py
    @property
    def metric_early_stop(self): return (
        "unscaled_loss_intersection_proj"
        if self.hparams.output_mode == "medial_sphere" else
        "unscaled_loss_intersection"
    )

    def validation_step(self, batch: TrainingBatch, batch_idx: int) -> dict[str, Tensor]:
        # validation reuses the full training-step loss computation
        losses = self.training_step(batch, batch_idx, is_validation=True)
        return losses

    def configure_optimizers(self):
        """Adam, optionally wrapped in a per-epoch LambdaLR when the lr is scheduled."""
        # scheduled lr: set base lr to 1 so the lambda's value IS the lr
        adam = torch.optim.Adam(self.parameters(),
            lr=1 if not self.hparams.opt_learning_rate.is_const else self.hparams.opt_learning_rate.get_train_value(0),
            weight_decay=self.hparams.opt_weight_decay)
        schedules = []
        if not self.hparams.opt_learning_rate.is_const:
            schedules = [
                torch.optim.lr_scheduler.LambdaLR(adam,
                    lambda epoch: self.hparams.opt_learning_rate.get_train_value(epoch),
                ),
            ]
        return [adam], schedules

    @property
    def example_input_array(self) -> tuple[dict[str, Tensor], Tensor]:
        # used by Lightning for model summaries / tracing
        return (
            { # see self.batch2rays
                "origins" : torch.zeros(1, 3), # most commonly used
                "points" : torch.zeros(1, 3), # used if self.training and self.hparams.loss_multi_view_reg
                "dirs" : torch.ones(1, 3) * torch.rsqrt(torch.tensor(3)),
            },
            torch.ones(1, self.hparams.latent_features),
        )

    @staticmethod
    def compute_normals_from_intersection_origin_jacobian(origin_jac: Tensor, ray_dirs: Tensor) -> Tensor:
        """
        Recover unit surface normals from the jacobian of the intersection
        point w.r.t. the ray origin, weighted by the ray direction components.
        """
        normals = sum((
            torch.cross(origin_jac[..., 0], origin_jac[..., 1], dim=-1) * -ray_dirs[..., [2]],
            torch.cross(origin_jac[..., 1], origin_jac[..., 2], dim=-1) * -ray_dirs[..., [0]],
            torch.cross(origin_jac[..., 2], origin_jac[..., 0], dim=-1) * -ray_dirs[..., [1]],
        ))
        return normals / normals.norm(dim=-1, keepdim=True)
|
||||
|
||||
|
||||
class IntersectionFieldAutoDecoderModel(IntersectionFieldModel, AutoDecoderModuleMixin):
    """Auto-decoder variant: latent codes are looked up per object uid rather
    than produced by an encoder network."""

    def encode(self, batch: LabeledBatch) -> Tensor:
        """Fetch the learned latent code for each object uid in the batch."""
        # DataParallel replicates modules per device, which would desync the
        # latent-code table — refuse to run under that strategy.
        assert not isinstance(self.trainer.strategy, pl.strategies.DataParallelStrategy)
        latent_codes = self[batch["z_uid"]] # [N, Z_n]
        return latent_codes
|
||||
186
ifield/models/medial_atoms.py
Normal file
186
ifield/models/medial_atoms.py
Normal file
@@ -0,0 +1,186 @@
|
||||
from .. import param
|
||||
from ..modules import fc
|
||||
from ..data.common import points
|
||||
from ..utils import geometry
|
||||
from ..utils.helpers import compose
|
||||
from textwrap import indent, dedent
|
||||
from torch import nn, Tensor
|
||||
from typing import Optional
|
||||
import torch
|
||||
import warnings
|
||||
|
||||
# generalize this into a HypoHyperConcat net? ConditionedNet?
class MedialAtomNet(nn.Module):
    """Predicts `n_atoms` medial atoms (sphere center + radius) per query,
    optionally conditioned on a latent code, and ray-traces against them.

    The final linear layer emits `n_atoms * 4` values, interpreted as
    `n_atoms` tuples of (x, y, z, r).
    """

    def __init__(self,
            in_features     : int,
            latent_features : int,
            hidden_features : int,
            hidden_layers   : int,
            n_atoms         : int = 1,
            # (weight scale, initial center spread radius, initial atom radius)
            # for the final layer; None disables the custom initialization
            final_init_wrr  : tuple[float, float, float] | None = (0.05, 0.6, 0.1),
            **kw,
            ):
        super().__init__()
        assert n_atoms >= 1, n_atoms
        self.n_atoms = n_atoms  # number of medial atoms predicted per query

        self.fc = fc.FCBlock(
            in_features      = in_features,
            hidden_layers    = hidden_layers,
            hidden_features  = hidden_features,
            out_features     = n_atoms * 4, # n_atoms * (x, y, z, r)
            outermost_linear = True,
            latent_features  = latent_features,
            **kw,
        )

        if final_init_wrr is not None:
            with torch.no_grad():
                w, r1, r2 = final_init_wrr
                # w shrinks the final layer's weights so the bias dominates at init
                if w != 1: self.fc[-1].linear.weight *= w
                dtype = self.fc[-1].linear.bias.dtype
                # bias indices 4n+0..4n+2 are (x, y, z) of atom n: scatter the
                # initial centers on a sphere of radius r1
                self.fc[-1].linear.bias[..., [4*n+i for n in range(n_atoms) for i in range(3)]] = torch.tensor(points.generate_random_sphere_points(n_atoms, radius=r1), dtype=dtype).flatten()
                # bias indices 4n+3 are the radii: start all atoms at radius r2
                self.fc[-1].linear.bias[..., 3::4] = r2

    @property
    def is_conditioned(self):
        # True when the wrapped FC block expects a latent conditioning tensor
        return self.fc.is_conditioned

    @classmethod
    @compose("\n".join)
    def make_jinja_template(cls, *, exclude_list: set[str] = {}, top_level: bool = True, **kw) -> str:
        """Render a YAML config template: own params, then the wrapped
        FCBlock's params minus the ones this class fixes itself."""
        yield param.make_jinja_template(cls, top_level=top_level, exclude_list=exclude_list, **kw)
        yield fc.FCBlock.make_jinja_template(top_level=False, exclude_list={
            "in_features",
            "hidden_layers",
            "hidden_features",
            "out_features",
            "outermost_linear",
            "latent_features",
        })

    def forward(self, x: Tensor, z: Optional[Tensor] = None):
        """Predict raw medial-atom parameters for `x`, conditioned on `z`."""
        if __debug__ and self.is_conditioned and z is None:
            warnings.warn(f"{self.__class__.__qualname__} is conditioned, but the forward pass was not supplied with a conditioning tensor.")
        return self.fc(x, z)

    def compute_intersections(self,
            ray_origins  : Tensor, # (..., 3)
            ray_dirs     : Tensor, # (..., 3)
            medial_atoms : Tensor, # (..., 4*self.n_atoms)
            *,
            intersections_only : bool = True,
            return_all_atoms   : bool = False, # only applies if intersections_only=False
            allow_nans         : bool = True,
            improve_miss_grads : bool = False,
            ) -> tuple[Tensor, ...]:
        """Intersect each ray with its predicted atoms and pick the best atom per ray.

        Returns (annotation fixed: the tuple arity varies, not always 5):
          * intersections_only:   (intersections_near, is_intersecting)
          * otherwise:            (depths, silhouettes, intersections_near,
                                   intersection_normals, is_intersecting,
                                   sphere_centers, sphere_radii)
          * return_all_atoms adds 8 per-atom tensors (see the return below).
        """
        assert ray_origins.shape[:-1] == ray_dirs.shape[:-1] == medial_atoms.shape[:-1], \
            (ray_origins.shape, ray_dirs.shape, medial_atoms.shape)
        assert medial_atoms.shape[-1] % 4 == 0, \
            medial_atoms.shape
        assert ray_origins.shape[-1] == ray_dirs.shape[-1] == 3, \
            (ray_origins.shape, ray_dirs.shape)

        #n_atoms = medial_atoms.shape[-1] // 4
        n_atoms = medial_atoms.shape[-1] >> 2

        # reshape (..., n_atoms * d) to (..., n_atoms, d) and broadcast the
        # rays so each atom sees its ray
        medial_atoms = medial_atoms.view(*medial_atoms.shape[:-1], n_atoms, 4)
        ray_origins  = ray_origins.unsqueeze(-2).broadcast_to([*ray_origins.shape[:-1], n_atoms, 3])
        ray_dirs     = ray_dirs   .unsqueeze(-2).broadcast_to([*ray_dirs   .shape[:-1], n_atoms, 3])

        # unpack atoms; abs() makes the predicted radius sign-agnostic
        sphere_centers = medial_atoms[..., :3]
        sphere_radii   = medial_atoms[..., 3].abs()

        assert not ray_origins   .detach().isnan().any()
        assert not ray_dirs      .detach().isnan().any()
        assert not sphere_centers.detach().isnan().any()
        assert not sphere_radii  .detach().isnan().any()

        # compute intersections
        (
            sphere_center_projs, # (..., 3)
            intersections_near,  # (..., 3)
            intersections_far,   # (..., 3)
            is_intersecting,     # (...) bool
        ) = geometry.ray_sphere_intersect(
            ray_origins,
            ray_dirs,
            sphere_centers,
            sphere_radii,
            return_parts       = True,
            allow_nans         = allow_nans,
            improve_miss_grads = improve_miss_grads,
        )

        # early return
        if intersections_only and n_atoms == 1:
            return intersections_near.squeeze(-2), is_intersecting.squeeze(-1)

        # compute how close each hit and miss are:
        # depth along the ray, and signed distance from ray to sphere surface
        depths      = ((intersections_near - ray_origins) * ray_dirs).sum(-1)
        silhouettes = torch.linalg.norm(sphere_center_projs - sphere_centers, dim=-1) - sphere_radii

        if return_all_atoms:
            # stash per-atom tensors before they are collapsed below
            intersections_near_all = intersections_near
            depths_all             = depths
            silhouettes_all        = silhouettes
            is_intersecting_all    = is_intersecting
            sphere_centers_all     = sphere_centers
            sphere_radii_all       = sphere_radii

        # collapse n_atoms: per ray, prefer the nearest hit (misses are pushed
        # back by +100 so any hit wins); if no atom hits, fall back to the atom
        # with the smallest silhouette distance. detach() keeps the selection
        # itself out of the gradient.
        if n_atoms > 1:
            atom_indices = torch.where(is_intersecting.any(dim=-1, keepdim=True),
                torch.where(is_intersecting, depths.detach(), depths.detach()+100).argmin(dim=-1, keepdim=True),
                silhouettes.detach().argmin(dim=-1, keepdim=True),
            )

            intersections_near = intersections_near.take_along_dim(atom_indices[..., None], -2).squeeze(-2)
            depths             = depths            .take_along_dim(atom_indices, -1).squeeze(-1)
            silhouettes        = silhouettes       .take_along_dim(atom_indices, -1).squeeze(-1)
            is_intersecting    = is_intersecting   .take_along_dim(atom_indices, -1).squeeze(-1)
            sphere_centers     = sphere_centers    .take_along_dim(atom_indices[..., None], -2).squeeze(-2)
            sphere_radii       = sphere_radii      .take_along_dim(atom_indices, -1).squeeze(-1)
        else:
            atom_indices       = None
            intersections_near = intersections_near.squeeze(-2)
            depths             = depths            .squeeze(-1)
            silhouettes        = silhouettes       .squeeze(-1)
            is_intersecting    = is_intersecting   .squeeze(-1)
            sphere_centers     = sphere_centers    .squeeze(-2)
            sphere_radii       = sphere_radii      .squeeze(-1)

        # early return
        if intersections_only:
            return intersections_near, is_intersecting

        # compute sphere normals (epsilon guards against division by zero)
        intersection_normals = intersections_near - sphere_centers
        intersection_normals = intersection_normals / (intersection_normals.norm(dim=-1)[..., None] + 1e-9)

        if return_all_atoms:
            intersection_normals_all = intersections_near_all - sphere_centers_all
            intersection_normals_all = intersection_normals_all / (intersection_normals_all.norm(dim=-1)[..., None] + 1e-9)


        return (
            depths,               # (...) valid if hit, based on 'intersections'
            silhouettes,          # (...) always valid
            intersections_near,   # (..., 3) valid if hit, projection if not
            intersection_normals, # (..., 3) valid if hit, rejection if not
            is_intersecting,      # (...) dtype=bool
            sphere_centers,       # (..., 3) network output
            sphere_radii,         # (...) network output
            *(() if not return_all_atoms else (

                atom_indices,             # (..., 1) selected atom per ray, or None if n_atoms == 1
                intersections_near_all,   # (..., N_ATOMS, 3) valid if hit, projection if not
                intersection_normals_all, # (..., N_ATOMS, 3) valid if hit, rejection if not
                depths_all,               # (..., N_ATOMS) valid if hit
                silhouettes_all,          # (..., N_ATOMS) always valid
                is_intersecting_all,      # (..., N_ATOMS) dtype=bool
                sphere_centers_all,       # (..., N_ATOMS, 3) network output
                sphere_radii_all,         # (..., N_ATOMS) network output
            )))
|
||||
101
ifield/models/orthogonal_plane.py
Normal file
101
ifield/models/orthogonal_plane.py
Normal file
@@ -0,0 +1,101 @@
|
||||
from .. import param
|
||||
from ..modules import fc
|
||||
from ..utils import geometry
|
||||
from ..utils.helpers import compose
|
||||
from textwrap import indent, dedent
|
||||
from torch import nn, Tensor
|
||||
from typing import Optional
|
||||
import warnings
|
||||
|
||||
class OrthogonalPlaneNet(nn.Module):
    """
    Predicts, for each (ray origin, ray direction) pair, a signed displacement
    along the ray to the intersection point — conceptually a plane orthogonal
    to the ray through the surface — together with an is-intersecting score.
    """

    def __init__(self,
            in_features     : int,
            latent_features : int,
            hidden_features : int,
            hidden_layers   : int,
            **kw,
            ):
        """Build the FC backbone; extra kwargs are forwarded to fc.FCBlock."""
        super().__init__()

        self.fc = fc.FCBlock(
            in_features      = in_features,
            hidden_layers    = hidden_layers,
            hidden_features  = hidden_features,
            out_features     = 2, # (plane_offset, is_intersecting)
            outermost_linear = True,
            latent_features  = latent_features,
            **kw,
        )

    @property
    def is_conditioned(self):
        # True when the wrapped FC block expects a latent conditioning tensor
        return self.fc.is_conditioned

    @classmethod
    @compose("\n".join)
    def make_jinja_template(cls, *, exclude_list: set[str] = {}, top_level: bool = True, **kw) -> str:
        """Render a YAML config template: own params, then the wrapped
        FCBlock's params minus the ones this class fixes itself."""
        yield param.make_jinja_template(cls, top_level=top_level, exclude_list=exclude_list, **kw)
        yield param.make_jinja_template(fc.FCBlock, top_level=False, exclude_list={
            "in_features",
            "hidden_layers",
            "hidden_features",
            "out_features",
            "outermost_linear",
            "latent_features",
        })

    def forward(self, x: Tensor, z: Optional[Tensor] = None) -> Tensor:
        """Predict (plane_offset, is_intersecting) for `x`, conditioned on `z`."""
        if __debug__ and self.is_conditioned and z is None:
            warnings.warn(f"{self.__class__.__qualname__} is conditioned, but the forward pass was not supplied with a conditioning tensor.")
        return self.fc(x, z)

    @staticmethod
    def compute_intersections(
            ray_origins : Tensor, # (..., 3)
            ray_dirs    : Tensor, # (..., 3)
            predictions : Tensor, # (..., 2)
            *,
            normalize_origins           = True,
            return_signed_displacements = False,
            allow_nans        = False, # MARF compat, must remain False
            atom_random_prob  = None,  # MARF compat, ignored
            atom_dropout_prob = None,  # MARF compat, ignored
            ) -> tuple[Tensor, ...]:  # fixed: previously annotated as 5 tensors, returns 2 or 3
        """Convert network predictions into intersection points.

        Returns (intersections, is_intersecting) with signed_displacements
        appended when return_signed_displacements is set.
        """
        assert ray_origins.shape[:-1] == ray_dirs.shape[:-1] == predictions.shape[:-1], \
            (ray_origins.shape, ray_dirs.shape, predictions.shape)
        assert predictions.shape[-1] == 2, \
            predictions.shape

        assert not allow_nans  # nan handling is a MARF feature, unsupported here

        if normalize_origins:
            # slide each origin to the point on its ray closest to the world origin,
            # so the displacement prediction is origin-invariant
            ray_origins = geometry.project_point_on_ray(0, ray_origins, ray_dirs)

        # unpack predictions
        signed_displacements = predictions[..., 0]
        is_intersecting      = predictions[..., 1]

        # step along the ray by the (negated) predicted displacement
        intersections = ray_origins - signed_displacements[..., None] * ray_dirs

        return (
            intersections,
            is_intersecting,
            *((signed_displacements,) if return_signed_displacements else ()),
        )
||||
|
||||
|
||||
|
||||
|
||||
OrthogonalPlaneNet.__doc__ = __doc__ = f"""
|
||||
{dedent(OrthogonalPlaneNet.__doc__).strip()}
|
||||
|
||||
# Config template:
|
||||
|
||||
```yaml
|
||||
{OrthogonalPlaneNet.make_jinja_template()}
|
||||
```
|
||||
"""
|
||||
3
ifield/modules/__init__.py
Normal file
3
ifield/modules/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
__doc__ = """
|
||||
Contains Pytorch Modules
|
||||
"""
|
||||
22
ifield/modules/dtype.py
Normal file
22
ifield/modules/dtype.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import pytorch_lightning as pl
|
||||
|
||||
|
||||
class DtypeMixin:
|
||||
def __init_subclass__(cls):
|
||||
assert issubclass(cls, pl.LightningModule), \
|
||||
| ||||