-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathsimulate.py
72 lines (57 loc) · 2.42 KB
/
simulate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
"""Generate a simulated dataset from a ground truth reflectance grid."""
import os
from tqdm import tqdm
import h5py
from functools import partial
import numpy as np
from jax import numpy as jnp
import jax
from dart import VirtualRadar, fields, pose, types
def _parse(p):
p.add_argument("-p", "--path", help="Path to data directory.")
p.add_argument(
"-r", "--key", default=42, type=int, help="Random seed.")
p.add_argument("-o", "--out", default=None, help="Save path.")
p.add_argument("-b", "--batch", default=16, type=int, help="Batch size")
p.add_argument(
"-m", "--mode", default="lidar",
help="Simulation source (lidar/cfar).")
return p
def _load_poses(path):
    """Load a trajectory from an HDF5 file as a dataset of poses.

    Args:
        path: path to an HDF5 file with ``vel``, ``pos`` and ``rot``
            datasets; all three are mapped over their first axis, which
            must therefore have the same length N.

    Returns:
        ``types.Dataset`` built with ``from_tensor_slices`` over the N
        poses produced by ``pose.make_pose``.
    """
    # Open read-only and close the handle as soon as the arrays have been
    # copied into memory (the previous version never closed the file, and
    # older h5py versions default to the writable 'a' mode).
    with h5py.File(path, "r") as data:
        vel = jnp.array(data["vel"])
        pos = jnp.array(data["pos"])
        rot = jnp.array(data["rot"])
    # The fourth argument is an index 0..N-1 for each pose (presumably the
    # frame index -- confirm against pose.make_pose).
    poses = jax.vmap(pose.make_pose)(vel, pos, rot, jnp.arange(vel.shape[0]))
    return types.Dataset.from_tensor_slices(poses)
def _load_ground_truth(path, mode):
    """Load the ground-truth field used as the simulation source.

    Args:
        path: data directory containing the ground-truth ``.npz`` file.
        mode: ``"lidar"`` reads ``map.npz`` and builds the field with
            ``GroundTruth.from_occupancy``; any mode starting with
            ``"cfar"`` reads ``<mode>.npz`` and builds ``GroundTruth``
            directly from the max-normalized grid.

    Returns:
        A ``fields.GroundTruth`` instance.

    Raises:
        ValueError: if ``mode`` is not recognized.
    """
    if mode == "lidar":
        # ``with`` closes the NpzFile; the arrays are read eagerly first.
        with np.load(os.path.join(path, "map.npz")) as gt_data:
            grid, lower, upper = (
                gt_data['grid'], gt_data['lower'], gt_data['upper'])
        return fields.GroundTruth.from_occupancy(
            jnp.array(grid), lower, upper, alpha_scale=100.0)
    if mode.startswith("cfar"):
        with np.load(os.path.join(path, "{}.npz".format(mode))) as gt_data:
            grid, lower, upper = (
                gt_data['grid'], gt_data['lower'], gt_data['upper'])
        # Grid is scaled to a peak value of 1. The third argument is the
        # grid shape divided by the lower/upper extent, i.e. cells per unit
        # of those coordinates (presumably the field resolution).
        return fields.GroundTruth(
            jnp.array(grid / np.max(grid)), lower,
            jnp.array(grid.shape) / (upper - lower), alpha_scale=0.0)
    raise ValueError("Unknown mode: {}.".format(mode))


def _render_trajectory(sensor, gt, traj, batch_size, seed):
    """Render one frame per pose in ``traj`` using ``sensor.render``.

    Args:
        sensor: object whose ``render(pose=..., key=..., sigma=...)`` is used.
        gt: ground-truth field passed to ``render`` as ``sigma``.
        traj: pose dataset supporting ``.batch(batch_size)``.
        batch_size: number of poses rendered per jitted call.
        seed: integer seed for ``jax.random.PRNGKey``.

    Returns:
        ``np.float16`` array of all rendered frames concatenated on axis 0.

    Raises:
        ValueError: if ``traj`` yields no batches.
    """
    # Bind the ground truth once, then vectorize over (pose, key) pairs.
    render = jax.jit(jax.vmap(partial(sensor.render, sigma=gt)))
    root_key = jax.random.PRNGKey(seed)
    frames = []
    for batch in tqdm(traj.batch(batch_size)):
        # Fresh subkey per batch, then one key per pose. The count comes
        # from the batch itself (presumably the final batch can be short).
        root_key, key = jax.random.split(root_key, 2)
        keys = jnp.array(jax.random.split(key, batch.x.shape[0]))
        # Named ``poses`` so the imported ``dart.pose`` module is not shadowed.
        poses = jax.tree_util.tree_map(jnp.array, batch)
        frames.append(
            np.asarray(render(pose=poses, key=keys), dtype=np.float16))
    if not frames:
        raise ValueError("Trajectory is empty; nothing to simulate.")
    return np.concatenate(frames, axis=0)


def _main(args):
    """Simulate a dataset from a ground-truth field and save it as HDF5.

    Builds the sensor with ``VirtualRadar.from_file(args.path)``, loads the
    ground truth selected by ``args.mode`` and the poses in
    ``<path>/trajectory.h5``, renders every pose, and writes the frames to
    the ``"rad"`` dataset of ``args.out`` (default
    ``<path>/baselines/<mode>.h5``).

    Raises:
        ValueError: if ``args.mode`` is unknown or the trajectory is empty.
    """
    if args.out is None:
        args.out = os.path.join(args.path, "baselines/{}.h5".format(args.mode))

    sensor = VirtualRadar.from_file(args.path)
    gt = _load_ground_truth(args.path, args.mode)
    traj = _load_poses(os.path.join(args.path, "trajectory.h5"))

    # Create the output directory before rendering: the default
    # ``baselines/`` folder may not exist, and failing only at write time
    # would throw away the whole simulation.
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    rad = _render_trajectory(sensor, gt, traj, args.batch, args.key)
    with h5py.File(args.out, 'w') as hf:
        hf.create_dataset("rad", data=rad)