-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathmap.py
76 lines (62 loc) · 2.49 KB
/
map.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""Evaluate DART model in a grid."""
import os
import h5py
from functools import partial
import numpy as np
from jax import numpy as jnp
import jax
from dart import DartResult, DART, utils
def _parse(p):
p.add_argument("-p", "--path", help="File path to output base name.")
p.add_argument("-c", "--checkpoint", help="Load specific checkpoint.")
p.add_argument(
"-l", "--lower", nargs='+', type=float, default=None,
help="Lower coordinate in x y z form.")
p.add_argument(
"-u", "--upper", nargs='+', type=float, default=None,
help="Upper coordinate in x y z form.")
p.add_argument(
"--padding", type=float, nargs='+', default=[4.0, 4.0, 2.0],
help="Region padding relative to trajectory min/max.")
p.add_argument(
"-r", "--resolution", type=int, default=25,
help="Map resolution, in units per meter.")
p.add_argument(
"-b", "--batch", type=int, default=4, help="Batch size along the z "
"axis for breaking up high resolution grids.")
return p
def _set_bounds(args, res):
if args.lower is None or args.upper is None:
args.padding = np.array((args.padding * 3)[:3])
x = np.array(h5py.File(
os.path.join(res.DATADIR, "trajectory.h5"))["pos"])
args.lower = np.min(x, axis=0) - args.padding
args.upper = np.max(x, axis=0) + args.padding
else:
assert len(args.lower) == 3
assert len(args.upper) == 3
args.lower = np.array(args.lower)
args.upper = np.array(args.upper)
def _main(args):
    """Evaluate the model on a regular grid and save the result.

    Steps, in order: open the result at ``args.path``, resolve the grid
    bounds, print the grid size, load the requested parameters, evaluate
    ``dart.grid`` over the grid in batches along z, and save the outputs.

    Side effects: ``args.lower``, ``args.upper`` and ``args.checkpoint`` are
    overwritten; ``args.checkpoint`` ends up as the path of the loaded
    parameters relative to ``args.path``.

    Args:
        args: parsed command line namespace (see ``_parse``).
    """
    result = DartResult(args.path)
    _set_bounds(args, result)

    # Grid points per axis: points-per-unit times extent, truncated to int.
    extent = args.upper - args.lower
    resolution = (args.resolution * extent).astype(int)
    summary = "Bounds: {:.1f}x{:.1f}x{:.1f}m ({}x{}x{}px)"
    print(summary.format(*extent, *resolution))

    # A named checkpoint lives under "checkpoints/" and gets its own output
    # file; without one, the default "model" parameters are used.
    requested = args.checkpoint
    if requested is not None:
        outfile = "map.{}.h5".format(requested)
        args.checkpoint = os.path.join("checkpoints", requested)
    else:
        outfile = "map.h5"
        args.checkpoint = "model"

    dart = DART.from_file(args.path)
    params = dart.load(os.path.join(args.path, args.checkpoint))

    axes = []
    for start, stop, count in zip(args.lower, args.upper, resolution):
        axes.append(jnp.linspace(start, stop, count))
    x, y, z = axes

    # params, x and y are bound up front; only z is fed in batches.
    grid_fn = partial(dart.grid, params, x, y)
    render = jax.jit(grid_fn)
    sigma, alpha = utils.vmap_batch(render, z, batch=args.batch, axis=2)

    payload = {
        "sigma": sigma, "alpha": alpha,
        "lower": args.lower, "upper": args.upper
    }
    result.save(outfile, payload)