
Commit c1ed4b8

Misc refactoring
1 parent de32b73 commit c1ed4b8

26 files changed: +122 -473 lines changed

Makefile

+2 -2

@@ -37,7 +37,7 @@ $(_SLICES): $(_MAP)
 $(_RAD): $(_TGT)
 	$(DART) evaluate -p results/$(TARGET) -a -b $(BATCH)
 # SSIM calculation
-$(_SSIM): $(_TGT)
+$(_SSIM): $(_RAD)
 	$(DART) ssim -p results/$(TARGET)
 # Camera evaluation
 $(_CAM): $(_TGT)
@@ -54,7 +54,7 @@ train: results/$(TARGET)
 map: $(_MAP)
 slices: $(_SLICES)
 radar: $(_RAD)
-ssim: results/$(TARGET)/ssim.npz
+ssim: $(_SSIM)
 camera: $(_CAM)
 video: $(_VIDEO)

dart/dataset.py

+13 -25

@@ -46,13 +46,6 @@ def load_arrays(file: str, keys: Optional[list[str]] = None) -> Any:
         "Unknown file type: {} (expected .npz, .h5, or .mat)".format(file))


-def gt_map(file: str) -> GroundTruth:
-    """Load ground truth reflectance map."""
-    data = load_arrays(file)
-    return GroundTruth.from_occupancy(
-        jnp.array(data['grid']), data["lower"], data["upper"])
-
-
 def trajectory(
     traj: str, subset: Optional[Integer[types.ArrayLike, "nval"]] = None
 ) -> types.Dataset:
@@ -113,8 +106,8 @@ def __doppler_decimation(


 def doppler_columns(
-    sensor: VirtualRadar, path: str,
-    pval: float = 0., iid_val: bool = False, doppler_decimation: int = 0,
+    path: str, pval: float = 0., iid_val: bool = False,
+    doppler_decimation: int = 0,
     key: types.PRNGSeed = 42
 ) -> tuple[types.Dataset, Optional[types.Dataset], dict[str, PyTree]]:
     """Load dataset trajectory and images.
@@ -129,55 +122,50 @@ def doppler_columns(

     Parameters
     ----------
-    sensor: Sensor profile for this dataset.
     path: Path to file containing data.
-    norm: Normalization factor.
     pval: Proportion of dataset to hold as a validation set. If `pval=0`,
         no validation dataset is returned.
     iid_val: If True, then shuffles the dataset before training so that the
         validation split is drawn randomly from the dataset instead of just
         from the end.
-    min_speed: Minimum speed for usable samples. Images with lower
-        velocities are rejected.
-    repeat: Repeat dataset within each epoch to reduce overhead.
-    threshold: Mask out values less than the provided threshold (set to 0).
-    doppler_decimation: Simulate a lower doppler resolution by setting each
-        block of consecutive doppler columns to their average.
     key: Random key to shuffle dataset frames. Does not shuffle columns.

     Returns
     -------
     train: Train dataset.
     val: Val dataset.
-    validx: Indices of original images corresponding to the validation set.
+    meta: Metadata (exact split indices).
     """
     file = h5py.File(path)

     pose = types.RadarPose.from_h5file(file)
-    idx = np.array(file["idx"])
     rad = np.array(file["rad"], dtype=np.float16)
     weight = np.array(file["weight"], dtype=np.float32)
     doppler = np.array(file["doppler"], dtype=np.float32)
+    idx = np.arange(rad.shape[0])

     meta = types.TrainingColumn(pose=pose, weight=weight, doppler=doppler)
-    data = (meta, rad)
+    data = (meta, rad), idx

     print("Loaded dataset : {} valid columns".format(rad.shape))

     if iid_val:
         data = utils.shuffle(data, key=key)

-    meta, rad = data
-    nval = 0 if pval <= 0 else int(utils.get_size(data) * pval)
-    train, val = utils.split((meta, rad), nval=nval)
+    nval = 0 if pval <= 0 else int(rad.shape[0] * pval)
+    (train, itrain), _val = utils.split(data, nval=nval)

     if not iid_val:
         train = utils.shuffle(train, key=key)

     print("Train split : {} columns".format(train[1].shape))
     train = types.Dataset.from_tensor_slices(train)
-    if val is not None:
+    if _val is not None:
+        val, ival = _val
         print("Test split : {} columns".format(val[1].shape))
         val = types.Dataset.from_tensor_slices(val)
+    else:
+        val = None
+        ival = np.zeros(0, dtype=bool)

-    return train, val, {"val": idx[-nval:]}
+    return train, val, {"train": itrain, "val": ival}
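
The refactored `doppler_columns` no longer takes a `sensor` argument and now returns the exact train/validation split indices in its metadata dict. A minimal usage sketch under that signature; the file path is hypothetical, and only the dict keys shown in the diff (`"train"`, `"val"`) are assumed:

    # Sketch only: load columns and keep the split bookkeeping so that
    # validation predictions can be mapped back to their source columns.
    from dart.dataset import doppler_columns

    train, val, meta = doppler_columns(
        path="data/example.h5", pval=0.1, iid_val=False, key=42)

    # meta["train"] / meta["val"] hold the column indices assigned to each
    # split (original file order unless iid_val shuffles first).
    print(len(meta["train"]), "train columns,", len(meta["val"]), "val columns")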

dart/fields/ngp.py

+12 -4

@@ -13,8 +13,8 @@
 import jax
 import haiku as hk

-from jaxtyping import Float32, Integer, Array, Float
-from beartype.typing import Optional, Callable, Union
+from jaxtyping import Float32, Integer, Array
+from beartype.typing import Optional, Callable

 from dart import types
 from ._spatial import interpolate, spherical_harmonics
@@ -104,11 +104,15 @@ def hash_table(c):

     def __call__(
         self, x: Float32[Array, "3"], dx: Optional[Float32[Array, "3"]] = None,
-        **kwargs
+        alpha_clip: Optional[types.FloatLike] = None, **kwargs
     ) -> tuple[Float32[Array, ""], Float32[Array, ""]]:
         """Index into learned reflectance map."""
         table_out = self.lookup(x)
         sigma, alpha = self.head(table_out.reshape(-1))
+
+        if alpha_clip is not None:
+            alpha = jnp.where(sigma > alpha_clip, alpha, 0)
+
         return sigma, clip(alpha) * self.alpha_scale

     @classmethod
@@ -239,7 +243,7 @@ def __init__(self, harmonics: int = 16, **kwargs) -> None:

     def __call__(
         self, x: Float32[Array, "3"], dx: Optional[Float32[Array, "3"]] = None,
-        **kwargs
+        alpha_clip: Optional[types.FloatLike] = None, **kwargs
     ) -> tuple[Float32[Array, ""], Float32[Array, ""]]:
         """Index into learned reflectance map."""
         table_out = self.lookup(x)
@@ -249,6 +253,10 @@ def __call__(

         mlp_out = self.head(jnp.concatenate([table_out.reshape(-1), sh]))
         sigma, alpha = mlp_out
+
+        if alpha_clip is not None:
+            alpha = jnp.where(sigma > alpha_clip, alpha, 0)
+
         return sigma, clip(alpha) * self.alpha_scale

     @staticmethod
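
Both field variants now accept an `alpha_clip` threshold that zeroes `alpha` wherever the predicted `sigma` falls below it, before the usual `clip`/`alpha_scale` step. The masking itself is just the two-line `jnp.where` shown above; a standalone sketch of the behavior (function name and values are illustrative only):

    import jax.numpy as jnp

    def apply_alpha_clip(sigma, alpha, alpha_clip=None):
        # Zero alpha at points whose sigma falls below the threshold.
        if alpha_clip is not None:
            alpha = jnp.where(sigma > alpha_clip, alpha, 0)
        return alpha

    print(apply_alpha_clip(jnp.float32(0.01), jnp.float32(0.8), 0.05))  # 0.0
    print(apply_alpha_clip(jnp.float32(0.20), jnp.float32(0.8), 0.05))  # 0.8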

dart/pose.py

+3 -2

@@ -55,7 +55,7 @@ def make_pose(


 def sensor_to_world(
-    r: Float32[Array, ""], t: Float32[Array, "3 k"],
+    r: Float32[types.ArrayLike, ""], t: Float32[types.ArrayLike, "3 k"],
     pose: Union[types.CameraPose, types.RadarPose]
 ) -> Float32[Array, "3 k"]:
     """Project points to world-space.
@@ -74,7 +74,8 @@ def sensor_to_world(


 def project_angle(
-    d: Float32[Array, ""], psi: Float32[Array, "n"], pose: types.RadarPose
+    d: Float32[types.ArrayLike, ""], psi: Float32[types.ArrayLike, "n"],
+    pose: types.RadarPose
 ) -> Float32[Array, "3 n"]:
     """Project angles to intersection circle on a unit sphere.

dart/script.py

+1 -1

@@ -24,7 +24,7 @@ def script_train(cfg: dict) -> None:
     k1, k2, k3 = jax.random.split(root, 3)

     dart = DART.from_config(**cfg)
-    train, val, meta = doppler_columns(dart.sensor, key=k1, **cfg["dataset"])
+    train, val, meta = doppler_columns(key=k1, **cfg["dataset"])
     assert val is not None
     train = train.shuffle(cfg["shuffle_buffer"], reshuffle_each_iteration=True)

dart/sensor.py

+1 -1

@@ -58,7 +58,7 @@ def from_config(
     @staticmethod
     def get_psi_min(
         d: Float32[types.ArrayLike, ""], pose: types.RadarPose
-    ) -> Float32[types.ArrayLike, ""]:
+    ) -> Float32[Array, ""]:
         """Get psi value representing visible region of integration circle.

         Visible psi angles fall in the range of (-psi_min, psi_min). These

dart/utils/jaxcolors.py

+29 -6

@@ -7,13 +7,23 @@


 def hsv_to_rgb(
-    hsv: Float[types.ArrayLike, "... 3"]
+    hsv: Float[types.ArrayLike, "... 3"], _np=jnp
 ) -> Float[Array, "... 3"]:
     """Convert hsv values to rgb.

     Copied here, and modified for vectorization:
     https://matplotlib.org/3.1.1/_modules/matplotlib/colors.html#hsv_to_rgb
     and converted to jax.
+
+    Parameters
+    ----------
+    hsv: HSV colors.
+    _np: numpy-like backend to use.
+
+    Returns
+    -------
+    RGB colors `float: (0, 1)`, using the array format corresponding to the
+    provided backend.
     """
     in_shape = hsv.shape
     h = hsv[..., 0]
@@ -36,10 +46,23 @@ def hsv_to_rgb(


 def colormap(
-    colors: Num[types.ArrayLike, "n 3"],
-    data: Float[types.ArrayLike, "..."]
+    colors: Num[types.ArrayLike, "n d"],
+    data: Float[types.ArrayLike, "..."],
+    _np=jnp
 ) -> Num[Array, "... 3"]:
-    """Apply a discrete colormap."""
+    """Apply a discrete colormap.
+
+    Parameters
+    ----------
+    colors: list of discrete colors to apply (e.g. 0-255 RGB values). Can be
+        an arbitrary number of channels, not just RGB.
+    data: input data to index (`0 < data < 1`).
+    _np: numpy-like backend to use.
+
+    Returns
+    -------
+    An array with the same shape as `data`, with an extra dimension appended.
+    """
     fidx = data * (colors.shape[0] - 1)
-    left = jnp.clip(jnp.floor(fidx).astype(int), 0, colors.shape[0] - 1)
-    return jnp.take(colors, left, axis=0)
+    left = _np.clip(_np.floor(fidx).astype(int), 0, colors.shape[0] - 1)
+    return _np.take(colors, left, axis=0)
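
`hsv_to_rgb` and `colormap` are now parameterized over a numpy-like backend through `_np`, so the same code can produce plain numpy output or stay in jax. A small sketch of `colormap` with both backends; the 3-entry color table is made up for illustration:

    import numpy as np
    from jax import numpy as jnp
    from dart.utils.jaxcolors import colormap

    # Hypothetical 3-color table (0-255 RGB); data should lie in (0, 1).
    colors = np.array([[0, 0, 0], [128, 0, 128], [255, 255, 0]])
    data = np.linspace(0.01, 0.99, 5)

    rgb_np = colormap(colors, data, _np=np)    # plain numpy result
    rgb_jax = colormap(colors, data, _np=jnp)  # jax result
    print(rgb_np.shape)  # (5, 3): data shape plus the color channel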

dart/utils/misc.py

+5 -2

@@ -10,14 +10,17 @@
 from jax import numpy as jnp
 import numpy as np

-
 from jaxtyping import PyTree
 from beartype.typing import TypeVar, Optional, Union
 from dart import types


 def tf_to_jax(batch: PyTree) -> PyTree:
-    """Convert tensorflow array to jax array without copying."""
+    """Convert tensorflow array to jax array without copying.
+
+    NOTE: going through dlpack is much slower, so it seems jax/tf have some
+    kind of interop going on under the hood already.
+    """
     return jax.tree_util.tree_map(jnp.array, batch)
dart/utils/ssim.py

+1 -1

@@ -4,7 +4,7 @@
 from jax import numpy as jnp
 import jax.scipy as jsp

-from jaxtyping import Float, Float32, Array
+from jaxtyping import Float, Array


 def ssim(

datasets.md

-83
This file was deleted.

docs/hyperparams.md

+9

@@ -0,0 +1,9 @@
+# Hyperparameter Tuning Notes
+
+`alpha_clip` (`--clip`): 0.05 seems to work pretty well.
+- If the velocity data are very noisy or unreliable, a high `alpha_clip` can cause problems.
+- A low `alpha_clip` causes phantom "trails" of high reflectance and low transmittance behind the trajectory, which probably indicates overfitting.
+
+`epochs` (`--epochs`): 3 seems to be enough.
+- With an appropriate `alpha_clip`, the validation loss doesn't move much after the first epoch.
+- Too many epochs can cause overfitting that is qualitatively apparent on visual inspection (the loss doesn't tell the whole story): the reflectance map starts to get a lot of holes, and the transmittance map becomes even less continuous.

manage.py

+1 -2

@@ -12,8 +12,7 @@
     subparsers = parser.add_subparsers()
     for name, command in commands.items():
        p = subparsers.add_parser(
-            name, help=command._desc.split('\n')[0],
-            description=command._desc,
+            name, help=command.__doc__, description=command.__doc__,
            formatter_class=RawTextHelpFormatter)
        command._parse(p)
        p.set_defaults(_func=command._main)
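
`manage.py` now pulls subcommand help text directly from each command's `__doc__` instead of a separate `_desc` attribute. A self-contained sketch of the same argparse wiring; the `train` class below is a stand-in for a real command module:

    from argparse import ArgumentParser, RawTextHelpFormatter

    class train:
        """Train a DART model (stand-in command for illustration)."""

        @staticmethod
        def _parse(p):
            p.add_argument("-p", "--path", default="results/demo")

        @staticmethod
        def _main(args):
            print("would train, writing to", args.path)

    commands = {"train": train}

    parser = ArgumentParser()
    subparsers = parser.add_subparsers()
    for name, command in commands.items():
        # The docstring becomes both the short help and the long description.
        p = subparsers.add_parser(
            name, help=command.__doc__, description=command.__doc__,
            formatter_class=RawTextHelpFormatter)
        command._parse(p)
        p.set_defaults(_func=command._main)

    args = parser.parse_args(["train", "-p", "results/demo"])
    args._func(args)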
