Skip to content

Commit 45afa6c

Browse files
Add (final?) cvpr code + plotting scripts
1 parent 070355c commit 45afa6c

20 files changed

+633
-325
lines changed

dart/fields/grid.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Basic Plenoxels-inspired Grid."""
22

3+
import os
4+
import numpy as np
35
from jax import numpy as jnp
46
import haiku as hk
57
import math
@@ -73,10 +75,10 @@ def closure():
7375
def to_parser(p: types.ParserLike) -> None:
7476
"""Create grid command line arguments."""
7577
p.add_argument(
76-
"--lower", default=[-4.0, -4.0, -1.0], nargs='+', type=float,
78+
"--lower", default=None, nargs='+', type=float,
7779
help="Lower coordinate (x, y, z).")
7880
p.add_argument(
79-
"--upper", default=[4.0, 4.0, 1.0], nargs='+', type=float,
81+
"--upper", default=None, nargs='+', type=float,
8082
help="Upper coordinate (x, y, z).")
8183
p.add_argument(
8284
"--resolution", default=25.0, type=float,
@@ -88,6 +90,12 @@ def to_parser(p: types.ParserLike) -> None:
8890
@classmethod
8991
def args_to_config(cls, args: types.Namespace) -> dict:
9092
"""Create configuration dictionary."""
93+
if args.lower is None or args.upper is None:
94+
datadir = os.path.dirname(args.path)
95+
npz = np.load(os.path.join(datadir, "map.npz"))
96+
args.lower = npz['lower'].tolist()
97+
args.upper = npz['upper'].tolist()
98+
9199
assert len(args.upper) == 3
92100
assert len(args.lower) == 3
93101
grid_size = [

dart/metrics.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ def mse(
1919
Returns
2020
-------
2121
mse: MSE of the optimally-scaled `y_hat`.
22-
alpha: Optimal scale factor.
22+
xi: Optimal scale factor.
2323
"""
24-
alpha = jnp.sum(y_true * y_hat) / jnp.sum(y_hat**2)
25-
mse = jnp.sum(jnp.square(y_true - alpha * y_hat))
26-
return mse, alpha
24+
xi = jnp.sum(y_true * y_hat) / jnp.sum(y_hat**2)
25+
mse = jnp.sum(jnp.square(y_true - xi * y_hat))
26+
return mse, xi
2727

2828

2929
def ssim(

notebooks/map.ipynb

-179
This file was deleted.

notebooks/volshow.ipynb

-95
This file was deleted.

plot/_result.py

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""Result convenience wrapper."""
2+
3+
import os
4+
import json
5+
import numpy as np
6+
import pandas as pd
7+
import h5py
8+
9+
from beartype.typing import Any
10+
11+
12+
def _json(p):
13+
with open(p) as f:
14+
return json.load(f)
15+
16+
17+
class DartResult:
18+
"""DART result/dataset convenience wrapper."""
19+
20+
def __init__(self, path: str) -> None:
21+
_meta = os.path.join(path, "metadata.json")
22+
if not os.path.exists(_meta):
23+
raise FileNotFoundError(
24+
"Result path does not exist (could not find {})".format(_meta))
25+
26+
self.metadata = _json(_meta)
27+
self.resdir = path
28+
self.datadir = os.path.dirname(self.metadata["dataset"]["path"])
29+
30+
def dart(self) -> "DART": # type: ignore
31+
"""Construct DART for results.
32+
33+
NOTE: will import DART (and load jax & other heavy dependencies) on
34+
first call.
35+
"""
36+
from dart import DART
37+
return DART.from_config(**self.metadata)
38+
39+
def path(self, subpath: str) -> str:
40+
"""Translate path to result/data directory."""
41+
if subpath.startswith("data/"):
42+
return os.path.join(self.datadir, subpath.replace("data/", ""))
43+
else:
44+
return os.path.join(self.resdir, subpath.replace("result/", ""))
45+
46+
def __getitem__(self, subpath: str) -> Any:
47+
"""Load npz/csv/h5/json file.
48+
49+
Use `data/` to indicate when to load from the dataset directory;
50+
otherwise, `subpath` is assumed to be in `results`. A `results/` prefix
51+
can also be passed (which is removed).
52+
"""
53+
path = self.path(subpath)
54+
if not os.path.exists(path):
55+
raise FileNotFoundError("File does not exist: {}".format(path))
56+
57+
def _err(p):
58+
raise ValueError("Unknown file extension: {}".format(p))
59+
60+
exts = {
61+
".npz": np.load, ".csv": pd.read_csv,
62+
".h5": h5py.File, ".json": _json}
63+
return exts.get(os.path.splitext(path)[1], _err)(path)

plot/_stats.py

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
"""Statistical utilities."""
2+
3+
import os
4+
import numpy as np
5+
from jaxtyping import Num
6+
from beartype.typing import Union
7+
8+
9+
BASELINES = ['lidar', 'nearest', 'cfar', 'cfar_1e-2', 'cfar_1e-5', 'cfar_1e-8']
10+
11+
12+
def load_dir(path, key="ssim", baselines=BASELINES):
13+
"""Load metrics from a given dataset."""
14+
15+
def _try_load(*path):
16+
try:
17+
return np.load(os.path.join(*path))[key]
18+
except FileNotFoundError:
19+
return None
20+
21+
res = os.path.join("results", path)
22+
data = os.path.join("data", path)
23+
24+
contents = {k: _try_load(res, k, "metrics.npz") for k in os.listdir(res)}
25+
for k in baselines:
26+
contents[k] = _try_load(data, "baselines/{}.npz".format(k))
27+
28+
return {k: v for k, v in contents.items() if v is not None}
29+
30+
31+
def effective_sample_size(x: Num[np.ndarray, "t"]) -> float:
32+
"""Calculate effective sample size for time series data."""
33+
rho = np.array([
34+
np.cov(x[i:], x[:-i])[0, 1] / np.std(x[i:]) / np.std(x[:-i])
35+
for i in range(1, x.shape[0] // 2)])
36+
rho_sum = np.sum(np.maximum(0.0, rho))
37+
return x.shape[0] / (1 + 2 * rho_sum)

plot/boxes.py

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from matplotlib import pyplot as plt
2+
import h5py
3+
import numpy as np
4+
5+
6+
def _im(ax, x):
7+
lower, upper = np.percentile(x, [0, 99])
8+
return ax.imshow(np.clip(x, lower, upper), cmap='inferno')
9+
10+
h5 = h5py.File("results/boxes3/ngpsh/map.h5")
11+
fig, axs = plt.subplots(1, 2, figsize=(6, 4))
12+
im1 = _im(axs[0], np.mean(h5['sigma'][210:390, 230:410, 95:130], axis=2))
13+
im2 = _im(axs[1], np.mean(-h5['alpha'][210:390, 230:410, 95:130], axis=2))
14+
15+
axs[0].text(5, 5, "Reflectance", color='white', ha='left', va='top')
16+
axs[1].text(5, 5, "Transmittance", color='white', ha='left', va='top')
17+
for ax in axs:
18+
ax.set_xticks([])
19+
ax.set_yticks([])
20+
ax.text(50, 75, "(1)", color='white')
21+
ax.text(125, 75, "(2)", color='white')
22+
ax.text(20, 145, "(5)", color='white')
23+
ax.text(78, 172, "(4)", color='white')
24+
ax.text(135, 140, "(3)", color='white')
25+
fig.tight_layout(pad=1.0)
26+
27+
cbar_ax = fig.add_axes([0.021, 0.08, 0.956, 0.04])
28+
fig.colorbar(im2, cax=cbar_ax, orientation='horizontal')
29+
cbar_ax.set_xticks([])
30+
cbar_ax.set_xlabel(r"Increasing Reflectance $\longrightarrow$", loc='left')
31+
32+
cbar_ax2 = cbar_ax.twiny()
33+
cbar_ax2.xaxis.set_ticks_position('bottom')
34+
cbar_ax2.xaxis.set_label_position('bottom')
35+
cbar_ax2.set_xticks([])
36+
cbar_ax2.set_xlabel(r"$\longleftarrow$ Increasing Transmittance", loc='right')
37+
38+
fig.savefig("figures/boxes.pdf", bbox_inches='tight')

plot/cdf.py

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
"""Plot SSIM CDF."""
2+
3+
import os
4+
from matplotlib import pyplot as plt
5+
import numpy as np
6+
7+
from _stats import load_dir
8+
9+
10+
fig, axs = plt.subplots(2, 6, figsize=(12, 4))
11+
datasets = {
12+
"boxes2": "Lab 1",
13+
"boxes3": "Lab 2",
14+
"wiselab4": "Office 1",
15+
"wiselab5": "Office 2",
16+
"mallesh-half": "Rowhouse 1",
17+
"mallesh-1br": "Rowhouse 2",
18+
"mallesh-full": "Rowhouse 3",
19+
"agr-ground": "House 1",
20+
"agr-full": "House 2",
21+
"agr-yard": "Yard",
22+
"tianshu-full": "Apartment 1",
23+
"tianshu-half": "Apartment 2"
24+
}
25+
methods = {
26+
"ngpsh": ("DART", 'C0', '-'),
27+
"lidar": ("Lidar", 'C1', '--'),
28+
"nearest": ("Nearest", 'C2', ':'),
29+
"cfar": ("CFAR", 'C3', '-.')
30+
}
31+
32+
for (ds, label), ax in zip(datasets.items(), axs.reshape(-1)):
33+
ssim = load_dir(ds)
34+
_ref = np.load(os.path.join("data", ds, "baselines", "reference.npz"))
35+
ref = np.mean(_ref["ssim"], axis=0)
36+
37+
ax.axvline(
38+
ref[0], color='black', linestyle='--', label='25/30/35db Reference',
39+
linewidth=1.0)
40+
ax.axvline(ref[1], color='black', linestyle='--', linewidth=1.0)
41+
ax.axvline(ref[2], color='black', linestyle='--', linewidth=1.0)
42+
43+
for k, (desc, color, ls) in methods.items():
44+
if k in ssim:
45+
v = ssim[k]
46+
ax.plot(
47+
np.sort(v), np.arange(v.shape[0]) / v.shape[0],
48+
label=desc, color=color, linestyle=ls)
49+
50+
ax.grid(visible=True)
51+
ax.set_xlim(0.35, 0.85)
52+
ax.text(0.83, 0.01, label, ha='right', va='bottom', backgroundcolor='white')
53+
54+
for ax in axs[:, 1:].reshape(-1):
55+
for tick in ax.yaxis.get_major_ticks():
56+
tick.tick1line.set_visible(False)
57+
tick.tick2line.set_visible(False)
58+
tick.label1.set_visible(False)
59+
tick.label2.set_visible(False)
60+
for ax in axs[:-1].reshape(-1):
61+
for tick in ax.xaxis.get_major_ticks():
62+
tick.tick1line.set_visible(False)
63+
tick.tick2line.set_visible(False)
64+
tick.label1.set_visible(False)
65+
tick.label2.set_visible(False)
66+
67+
68+
fig.tight_layout(h_pad=0.2, w_pad=0.5)
69+
axs[-1, -1].legend(
70+
ncols=5, loc='upper right', bbox_to_anchor=(1.05, -0.15), frameon=False)
71+
axs[1, 0].set_ylabel(r"Cumulative Probability $\longrightarrow$", loc='bottom')
72+
axs[1, 0].set_xlabel(r"SSIM (higher is better) $\longrightarrow$", loc='left')
73+
fig.savefig("figures/ssim.pdf", bbox_inches='tight')

0 commit comments

Comments
 (0)