Skip to content

Commit e5a68d9

Browse files
More preprocess changes + video batching
1 parent 5c130e0 commit e5a68d9

File tree

8 files changed

+143
-260
lines changed

8 files changed

+143
-260
lines changed

dart/preprocess/dataset.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44

55
from beartype.typing import NamedTuple
6-
from jaxtyping import Complex64, Int16
6+
from jaxtyping import Complex64, Int16, Float64
77

88
from .radar import AWR1843Boost
99

@@ -52,7 +52,7 @@ def _get_frames(self, packets, valid):
5252
res = rad[self.start:self.end].reshape(-1, self.frame_size)[valid]
5353
return res
5454

55-
def _get_times(self, packets, valid):
55+
def _get_times(self, packets, valid) -> Float64[np.ndarray, "frames"]:
5656
"""Get timestamps for each frame.
5757
5858
Timestamps are denoted by the first packet corresponding to data
@@ -75,7 +75,12 @@ def _to_iq(
7575
iq[:, 1::2] = frames[:, 1::4] + 1j * frames[:, 3::4]
7676
return iq.reshape((-1, *radar.frame_shape))
7777

78-
def as_frames(self, radar, packets):
78+
def as_frames(
79+
self, radar: AWR1843Boost, packets
80+
) -> tuple[
81+
Complex64[np.ndarray, "frames tx rx chirp"],
82+
Float64[np.ndarray, "frames"]
83+
]:
7984
"""Convert packets to frames."""
8085
valid = self._get_valid(packets)
8186
data = self._get_frames(packets, valid)

dart/preprocess/radar.py

+17-12
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313

1414

1515
def to_float16(
16-
arr: Float32[Array, "..."], max_normal: float = 65504.0
17-
) -> Float16[Array, "..."]:
16+
arr: Float32[types.ArrayLike, "..."], max_normal: float = 65504.0
17+
) -> Float16[types.ArrayLike, "..."]:
1818
"""Convert array to float16.
1919
2020
The largest observed value is scaled to the largest normal normal number
@@ -79,14 +79,18 @@ def estimate_speed(
7979
"""Estimate speed for this frame using doppler spectrum."""
8080
nd_half = int(self.framelen / 2)
8181

82-
# Spectrum across doppler bins; folded in half
83-
_spectrum = np.sum(rda, axis=(1, 3))
84-
spectrum = _spectrum[:, :nd_half] + _spectrum[:, nd_half:][:, ::-1]
82+
percentile = 99.75
83+
min_count = 10
8584

86-
threshold = jnp.min(spectrum, axis=1).reshape(-1, 1)
87-
speed_nd = jnp.argmax(
88-
(jnp.diff(spectrum, axis=1) > threshold)
89-
& (spectrum[:, :-1] > (threshold * 2)), axis=1)
85+
# Threshold for each frame, antenna (over=1,2)
86+
threshold = jnp.percentile(
87+
rda, percentile, axis=(1, 2))[:, None, None, ...]
88+
# Count across range, antenna (over=2,3)
89+
valid = jnp.sum(rda > threshold, axis=(1, 3)) > min_count
90+
91+
left = jnp.argmax(valid, axis=1)
92+
right = jnp.argmax(valid[:, ::-1], axis=1)
93+
speed_nd = jnp.minimum(left, right)
9094

9195
return (nd_half - speed_nd) / nd_half * self.dmax / 2
9296

@@ -126,8 +130,8 @@ def range_doppler_azimuth(
126130
return rda[:, :self.max_range], self.estimate_speed(rda)
127131

128132
def remove_artifact(
129-
self, images: Float32[types.ArrayLike, "frame range doppler antenna"],
130-
) -> Float32[types.ArrayLike, "frame range doppler antenna"]:
133+
self, images: Float32[Array, "frame range doppler antenna"],
134+
) -> Float32[Array, "frame range doppler antenna"]:
131135
"""Remove zero-doppler artifact from images.
132136
133137
Collected range-doppler radar data will have an artifact at close
@@ -138,7 +142,7 @@ def remove_artifact(
138142
zero = int(images.shape[2] / 2)
139143
artifact = jnp.percentile(
140144
images[..., zero - 1:zero + 2, :], self.artifact_threshold, axis=0)
141-
removed = np.maximum(
145+
removed = jnp.maximum(
142146
images[..., zero - 1:zero + 2, :]
143147
- artifact.reshape(1, *artifact.shape), 0)
144148
return images.at[..., zero - 1:zero + 2, :].set(removed)
@@ -197,6 +201,7 @@ def process_data(
197201
chirps_flat, window_shape=self.framelen, axis=0)[::stride]
198202

199203
process_func = jax.jit(partial(self.range_doppler_azimuth))
204+
# process_func = partial(self.range_doppler_azimuth)
200205
res, speed = [], []
201206
for _ in range(int(np.ceil(frames.shape[0] / batch_size))):
202207
r, s = process_func(jnp.array(frames[:batch_size]))

dart/preprocess/trajectory.py

+16-12
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from beartype.typing import NamedTuple
1111
from jaxtyping import Float32, Float64, Bool
1212

13+
from dart import types
14+
1315

1416
class Trajectory(NamedTuple):
1517
"""Sensor trajectory.
@@ -36,10 +38,8 @@ def from_csv(cls, path: str) -> "Trajectory":
3638
"""
3739
df = pd.read_csv(os.path.join(path, "trajectory.csv"))
3840

39-
t_slam = np.array(df["field.header.stamp"]) / 1e9
40-
41-
# tmp
42-
t_slam = t_slam + 0.5
41+
# Manual 0.5s offset, likely due to buffering the DCA1000
42+
t_slam = np.array(df["field.header.stamp"]) / 1e9 + 0.5
4343

4444
xyz = np.array(
4545
[df["field.transform.translation." + char] for char in "xyz"])
@@ -51,14 +51,18 @@ def from_csv(cls, path: str) -> "Trajectory":
5151

5252
return cls(spline=spline, slerp=slerp)
5353

54-
def valid_mask(self, t: Float64[np.ndarray, "N"]) -> Bool[np.ndarray, "N"]:
54+
def valid_mask(
55+
self, t: Float64[types.ArrayLike, "N"], window: float = 0.1
56+
) -> Bool[types.ArrayLike, "N"]:
5557
"""Get mask of valid timestamps (within the trajectory timestamps)."""
56-
return (t >= self.spline.x[0]) & (t <= self.spline.x[-1])
58+
return (
59+
(t - window >= self.spline.x[0])
60+
& (t + window <= self.spline.x[-1]))
5761

5862
def interpolate(
59-
self, t: Float64[np.ndarray, "N"], window: float = 0.1,
63+
self, t: Float64[types.ArrayLike, "N"], window: float = 0.1,
6064
samples: int = 25
61-
) -> dict[str, Float32[np.ndarray, "N ..."]]:
65+
) -> dict[str, Float32[types.ArrayLike, "N ..."]]:
6266
"""Calculate poses, averaging along a window.
6367
6468
Parameters
@@ -72,13 +76,13 @@ def interpolate(
7276
Dictionary with pos, vel, and rot entries.
7377
"""
7478
window_offsets = np.linspace(-window / 2, window / 2, samples)
75-
samples = t[..., None] + window_offsets[None, ...]
79+
tt = t[..., None] + window_offsets[None, ...]
7680

7781
# Rotation unfortunately does not allow vectorization at this time.
78-
rot = Rotation.concatenate([self.slerp(row).mean() for row in samples])
82+
rot = Rotation.concatenate([self.slerp(row).mean() for row in tt])
7983

8084
return {
81-
"pos": np.mean(self.spline(samples), axis=1),
82-
"vel": np.mean(self.spline.derivative()(samples), axis=1),
85+
"pos": np.mean(self.spline(tt), axis=1),
86+
"vel": np.mean(self.spline.derivative()(tt), axis=1),
8387
"rot": rot.as_matrix()
8488
}

dart/result.py

+40
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
import numpy as np
66
import h5py
77
import matplotlib as mpl
8+
from jax import numpy as jnp
89

910
from beartype.typing import Optional, Any
1011
from jaxtyping import Integer, Array, Float, UInt8
1112

1213
from .dart import DART
14+
from dart.jaxcolors import colormap
1315
from .dataset import load_arrays, trajectory
1416
from . import types
1517

@@ -39,6 +41,8 @@ def __init__(self, path: str) -> None:
3941
with open(_meta) as f:
4042
self.metadata = json.load(f)
4143

44+
self.DATASET = self.metadata["dataset"]["path"]
45+
4246
def dart(self) -> DART:
4347
"""Get DART object."""
4448
return DART.from_config(**self.metadata)
@@ -64,6 +68,10 @@ def load(self, path: str, keys: Optional[list[str]] = None) -> Any:
6468
"""Load file inside this result."""
6569
return load_arrays(os.path.join(self.path, path), keys=keys)
6670

71+
def open(self, path: str) -> Any:
72+
"""Load h5py file in this scope."""
73+
return h5py.File(os.path.join(self.path, path))
74+
6775
@staticmethod
6876
def colorize_map(
6977
arr: Float[types.ArrayLike, "..."], sigma: bool = True,
@@ -99,3 +107,35 @@ def colorize_map(
99107
arr = np.exp(arr) # type: ignore
100108

101109
return (mpl.colormaps['viridis'](arr)[..., :3] * 255).astype(np.uint8)
110+
111+
@staticmethod
112+
def colorize_radar(
113+
cmap: Float[types.ArrayLike, "..."],
114+
rad: Float[types.ArrayLike, "..."],
115+
clip: tuple[float, float] = (5.0, 99.9)
116+
) -> UInt8[types.ArrayLike, "... 3"]:
117+
"""Colorize a radar intensity map in a jax-friendly way.
118+
119+
Parameters
120+
----------
121+
cmap: color map to apply.
122+
rad: input array. If range-doppler, colorize directly; if
123+
range-doppler-azimuth, tiles into 2 columns x 4 rows.
124+
clip: percentile clipping range.
125+
"""
126+
def _tile(arr):
127+
unpack = [arr[:, :, :, i] for i in range(arr.shape[3])]
128+
left = jnp.concatenate(unpack[:4], axis=1)
129+
right = jnp.concatenate(unpack[4:], axis=1)
130+
return jnp.concatenate([left, right], axis=2)
131+
132+
p5, p95 = jnp.nanpercentile(rad, jnp.array(clip))
133+
rad = (rad - p5) / (p95 - p5)
134+
colors = (colormap(cmap, rad) * 255).astype(jnp.uint8)
135+
if len(rad.shape) == 4:
136+
if rad.shape[-1] > 1:
137+
return _tile(colors)
138+
else:
139+
return colors[..., 0, :]
140+
else:
141+
return colors

notebooks/map.ipynb

+8-164
Large diffs are not rendered by default.

tools/preprocess.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import json
66
from tqdm import tqdm
77
import numpy as np
8+
from scipy.ndimage import gaussian_filter1d
89

910
from dart.preprocess import AWR1843Boost, AWR1843BoostDataset, Trajectory
1011

@@ -27,7 +28,7 @@ def _process_batch(radar: AWR1843Boost, file_batch, traj: Trajectory):
2728
range_doppler, speed_est = radar.process_data(chirps)
2829
t_image = radar.process_timestamps(t_chirp)
2930

30-
t_valid = traj.valid_mask(t_image)
31+
t_valid = traj.valid_mask(t_image, window=radar.frame_time * 0.5)
3132
pose = traj.interpolate(t_image[t_valid], window=radar.frame_time * 0.5)
3233
pose["t"] = t_image[t_valid]
3334
pose["speed"] = speed_est[t_valid]
@@ -37,7 +38,7 @@ def _process_batch(radar: AWR1843Boost, file_batch, traj: Trajectory):
3738

3839
def _process(
3940
path: str, radar: AWR1843Boost, batch_size: int = 1000000,
40-
overwrite: bool = False):
41+
overwrite: bool = False, sigma: float = 2.0):
4142

4243
traj = Trajectory.from_csv(path)
4344
packetfile = h5py.File(os.path.join(path, "radarpackets.h5"), 'r')
@@ -53,22 +54,25 @@ def _process(
5354
"rad", (1, *rs), dtype='f2', chunks=(1, *rs), maxshape=(None, *rs))
5455

5556
total_size = 0
56-
poses = []
57+
_poses = []
5758
for _ in tqdm(range(int(np.ceil(packet_dataset.shape[0] / batch_size)))):
5859
rda, pose = _process_batch(radar, packet_dataset[:batch_size], traj)
59-
poses.append(pose)
60+
_poses.append(pose)
6061

6162
total_size += rda.shape[0]
6263
range_doppler_azimuth.resize((total_size, *radar.image_shape))
6364
range_doppler_azimuth[-rda.shape[0]:] = rda
6465

6566
packet_dataset = packet_dataset[batch_size:]
6667

67-
poses_cat = {}
68-
for k in poses[0]:
69-
poses_cat[k] = np.concatenate([p[k] for p in poses])
68+
poses = {}
69+
for k in _poses[0]:
70+
poses[k] = np.concatenate([p[k] for p in _poses])
7071

71-
for k, v in poses_cat.items():
72+
if sigma > 0:
73+
poses['vel'] = gaussian_filter1d(poses['vel'], sigma=sigma, axis=0)
74+
75+
for k, v in poses.items():
7276
outfile.create_dataset(k, data=v)
7377
packetfile.close()
7478
outfile.close()

0 commit comments

Comments
 (0)