Commit 0d316c4

daphne-cornelisse, nadarenator, and aaravpandya authored
Integrate Versatile Behavior Diffusion (VBD) model (#283)
* NYU HPC Greene setup (#316)
* Add .env for private variables and update reqs
* Add sbatch generation script
* Add env template
* Combine into single sbatch script
* Add logs folder
* Version changes to make pufferlib compatible
* Add jax version
* Remove old files
* rmv old files
* Minor: Fix path
* Use typer to override default args using CLI
* Add versions
* WIP
* Add optional warmup period
* Add scene selector return info
* Add VBD model and util functions
* Linting
* Add VBD model and utils
* Cleanup notebook
* Use StateDynamics model
* Refactor expert actions function
* Add minimal reproducible example
* WIP
* Rebase over main
* Warmup period steps vehicles with log playback trajectory for the first init_steps of a scene.
* Fix nans in diffusion model pred trajs
* First pass working vbd
* Small fixes
* Align road types
* Get id from waymax scenarios
* wip
* Align sims with same scenarios
* Complete align sim notebook
* wip
* Bug: Polyline dist collapsed to only 0s
* Init setup
* Add plots roadgraph
* Add more dist plots
* Update
* Rebase over main
* WIP
* Revert changes related to road graph type conversion
* Revert type additions
* Remove conversion functions
* Big restructure
* Integrate new polyline construction code
* New debug file with map_element_id and road_ids
* Merge branch 'main' of https://github.com/Emerge-Lab/gpudrive into integrations/vbd
* Use newly integrated maptypes
* Minor
* integrate new road types
* clean test notebook
* Delete old debugging scene
* Cleanup gitignore
* Feature comparison debug
* Feature comparison debug
* First pass feature analysis
* Convert GPUDrive yaw to be between [-pi, pi]
* Cleanup
* Add VBD predicted outputs to notebooks for inspection
* Fix typo in get_obs()
* revert
* refactor
* path fix
* add example
* new data
* Add debugging notebook
* Wrap yaws
* Update data structure
* Cleanup
* Update
* Decrease img sizes
* Update
* update
* regenerated debug scene
* Remove corrupted data
* Applying APs changes (todo: rebase over main after PR is merged)
* Minor
* polyline feature alignment
* minor
* Cleanup
* minor
* minor
* Integrate new datatypes
* rg
* roadgraph, deepcopy, refactoring
* roadgraph and vbd_data
* Add debugging scenes
* 50 debug scenes
* rebase over main
* rebase over main
* Visualization updates and support to train on fixed dataset size (#317)
* Add support for resample_limit
* Add option to draw the log replay trajectories in visualizer
* WIP
* minimal set maps script
* minimal set maps script
* Data processing utils
* Rendering improvements and ppo
* Make dataset to resample from configurable
* Bug fix: resampling
* Agent observation figures
* Cleanup
* Prep experiments
* Prep experiments
* vbd_data restructuring
* Fix breaking visualizer changes
* Update vbd_types
* err_val handling in vbd_data
* clone road graph tensor
* output gifs, debug scenes update
* Make collision behavior transient (#318)
* Make collision behavior transient
* Simpler way to reset collision state for collision behavior ignore
* Reset info on collision ignore

---------

Co-authored-by: Aarav Pandya <[email protected]>

* Improved visualizer and gym environment (cloning all tensors) (#320)
* Missing function
* vbd trajectory in agent obs plot
* misc fixes
* partial training integration
* Complete first 2 todo's
* Reward penalty pseudocode
* Add vbd_in_obs argument to model
* partial integration
* more partial integration changes
* make VBD work with 'good' agents
* global->local coord transform
* norm vbd obs
* discard artifacts
* move vbd to gpudrive/integrations
* moving stuff
* more deletions
* more deletions
* full cleanup
* Revert
* Improve get_obs() function
* rmv unused function
* Partial vbd updates
* minor fix
* Ensure that init_steps > 0 when using vbd
* Integrate VBD with PPO script
* Remove tutorial 9 for now
* misc
* typos

---------

Co-authored-by: Daphne Cornelisse <[email protected]>
Co-authored-by: kevin <[email protected]>
Co-authored-by: Aarav Pandya <[email protected]>
1 parent d9e745b commit 0d316c4

47 files changed, +9199 -1814 lines changed

Diff for: .env.template

+13
@@ -0,0 +1,13 @@
+# .env template
+
+# Path for logs
+LOG_FOLDER=
+
+# Your HPC account code
+NYU_HPC_ACCOUNT=
+
+# NYU ID
+USERNAME=
+
+SINGULARITY_IMAGE=
+OVERLAY_FILE=
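
The commit does not show how the filled-in `.env` file is consumed, so the following is only a sketch of one way to check that every variable from the template has been given a value before submitting a job. `parse_env_text` and `missing_keys` are hypothetical helpers written for this illustration; they are not part of the repository.

```python
def parse_env_text(text: str) -> dict:
    """Parse KEY=VALUE lines, skipping blank lines and '#' comments."""
    values = {}
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue
        key, _, value = line.partition("=")
        values[key.strip()] = value.strip()
    return values


def missing_keys(values: dict) -> list:
    """Return the keys whose value is still empty."""
    return [key for key, value in values.items() if not value]


# The template exactly as added in this commit (all values empty).
TEMPLATE = """\
# .env template

# Path for logs
LOG_FOLDER=

# Your HPC account code
NYU_HPC_ACCOUNT=

# NYU ID
USERNAME=

SINGULARITY_IMAGE=
OVERLAY_FILE=
"""

template_values = parse_env_text(TEMPLATE)
print("Still to fill in:", missing_keys(template_values))
```

Run against the untouched template, every key is reported as missing; a copy with all values set would report an empty list.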

Diff for: .gitignore

+2 -1
@@ -8,6 +8,7 @@
 .vscode/launch.json
 .vscode/settings.json
 .vscode/tasks.json
+
 /examples/benchmarks/results/
 /baselines/ppo/logs/*
 *.sif
@@ -24,8 +25,8 @@
 hpc/overlay*
 data/raw/*
 data/processed/validation/*
-data/processed/testing/*
 data/processed/training/*
+data/processed/testing/*
 data/processed/sampled/*
 data/processed/hand_designed/*
 analyze/figures/*

Diff for: baselines/imitation_data_generation.py

+1 -1
@@ -276,4 +276,4 @@ def generate_state_action_pairs(
 
 # Uncommment to save the expert actions and observations
 # torch.save(expert_actions, "expert_actions.pt")
-# torch.save(expert_obs, "expert_obs.pt")
+# torch.save(expert_obs, "expert_obs.pt")

Diff for: baselines/ppo/config/ppo_base_puffer.yaml

+11 -4
@@ -8,16 +8,16 @@ model_cpt: null
 
 environment: # Overrides default environment configs (see pygpudrive/env/config.py)
 name: "gpudrive"
-num_worlds: 100 # Number of parallel environments
-k_unique_scenes: 100 # Number of unique scenes to sample from
+num_worlds: 75 # Number of parallel environments
+k_unique_scenes: 75 # Number of unique scenes to sample from
 max_controlled_agents: 64 # Maximum number of agents controlled by the model. Make sure this aligns with the variable kMaxAgentCount in src/consts.hpp
 ego_state: true
 road_map_obs: true
 partner_obs: true
 norm_obs: true
 remove_non_vehicles: true # If false, all agents are included (vehicles, pedestrians, cyclists)
 lidar_obs: false # NOTE: Setting this to true currently turns of the other observation types
-reward_type: "weighted_combination"
+reward_type: "weighted_combination"
 collision_weight: -0.75
 off_road_weight: -0.75
 goal_achieved_weight: 1.0
@@ -29,6 +29,13 @@ environment: # Overrides default environment configs (see pygpudrive/env/config.
 obs_radius: 50.0 # Visibility radius of the agents
 action_space_steer_disc: 13
 action_space_accel_disc: 7
+# Versatile Behavior Diffusion (VBD): This will slow down training
+use_vbd: false
+vbd_model_path: "gpudrive/integrations/vbd/weights/epoch=18.ckpt"
+init_steps: 11
+vbd_trajectory_weight: 0.1 # Importance of distance to the vbd trajectories in the reward function
+vbd_in_obs: false
+
 wandb:
 entity: ""
 project: "gpudrive"
@@ -55,7 +62,7 @@ train:
 # # # PPO # # #
 torch_deterministic: false
 total_timesteps: 1_000_000_000
-batch_size: 262_144
+batch_size: 131_072
 minibatch_size: 8192
 learning_rate: 3e-4
 anneal_lr: false
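
The config comment describes `vbd_trajectory_weight` as the importance of the distance to the VBD trajectories in the reward function, but the reward code itself is not part of the visible hunks. The sketch below shows one plausible reading of that comment: a penalty equal to the weight times the Euclidean distance between an agent and the VBD-predicted position for the same timestep. The function name, the choice of Euclidean distance, and the way the penalty is added to a base reward are all assumptions made for illustration.

```python
import math


def vbd_deviation_penalty(agent_xy, vbd_xy, vbd_trajectory_weight=0.1):
    """Hypothetical penalty: weight times the Euclidean distance between
    the agent position and the VBD-predicted position (same timestep)."""
    distance = math.hypot(agent_xy[0] - vbd_xy[0], agent_xy[1] - vbd_xy[1])
    return -vbd_trajectory_weight * distance


# Assumed example values: the agent sits 5 units away from the prediction.
base_reward = 1.0  # stands in for the weighted-combination reward
shaped_reward = base_reward + vbd_deviation_penalty((3.0, 4.0), (0.0, 0.0))
print(shaped_reward)
```

With the default weight of 0.1 from the config, a 5-unit deviation removes 0.5 from the reward, which gives a feel for how the weight trades off against `goal_achieved_weight: 1.0`.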

Diff for: baselines/ppo/ppo_pufferlib.py

+11
@@ -170,6 +170,11 @@ def run(
 obs_radius: Annotated[Optional[float], typer.Option(help="The radius for the observation")] = None,
 collision_behavior: Annotated[Optional[str], typer.Option(help="The collision behavior; 'ignore' or 'remove'")] = None,
 remove_non_vehicles: Annotated[Optional[int], typer.Option(help="Remove non-vehicles from the scene; 0 or 1")] = None,
+use_vbd: Annotated[Optional[bool], typer.Option(help="Use VBD model for trajectory predictions")] = False,
+vbd_model_path: Annotated[Optional[str], typer.Option(help="Path to VBD model checkpoint")] = None,
+vbd_trajectory_weight: Annotated[Optional[float], typer.Option(help="Weight for VBD trajectory deviation penalty")] = 0.1,
+vbd_in_obs: Annotated[Optional[bool], typer.Option(help="Include VBD predictions in the observation")] = False,
+init_steps: Annotated[Optional[int], typer.Option(help="Environment warmup steps")] = 0,
 # Train options
 seed: Annotated[Optional[int], typer.Option(help="The seed for training")] = None,
 learning_rate: Annotated[Optional[float], typer.Option(help="The learning rate for training")] = None,
@@ -210,10 +215,16 @@ def run(
 "remove_non_vehicles": None
 if remove_non_vehicles is None
 else bool(remove_non_vehicles),
+"use_vbd": use_vbd,
+"vbd_model_path": vbd_model_path,
+"vbd_trajectory_weight": vbd_trajectory_weight,
+"vbd_in_obs": vbd_in_obs,
+"init_steps": init_steps,
 }
 config.environment.update(
 {k: v for k, v in env_config.items() if v is not None}
 )
+
 train_config = {
 "seed": seed,
 "learning_rate": learning_rate,
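
The override logic in the second hunk keeps only CLI values that are not `None`, so the defaults matter. The new options default to `False`, `0.1`, and `0` rather than `None`, which means they pass the filter even when the flag is not given on the command line. The sketch below illustrates that interaction. It assumes `config.environment` behaves like a plain dict that already holds the values from `ppo_base_puffer.yaml`; `merge_overrides` is a hypothetical helper, not a function from the repository.

```python
def merge_overrides(base: dict, overrides: dict) -> dict:
    """Apply the same filter as the hunk: drop None, then update."""
    merged = dict(base)
    merged.update({k: v for k, v in overrides.items() if v is not None})
    return merged


# Values taken from the ppo_base_puffer.yaml hunk in this commit.
yaml_env = {
    "use_vbd": False,
    "init_steps": 11,
    "vbd_trajectory_weight": 0.1,
    "vbd_in_obs": False,
}

# CLI defaults as committed: only vbd_model_path defaults to None.
committed_defaults = {
    "use_vbd": False,
    "vbd_model_path": None,
    "vbd_trajectory_weight": 0.1,
    "vbd_in_obs": False,
    "init_steps": 0,
}

# Alternative (assumption): every new option defaults to None.
none_defaults = {key: None for key in committed_defaults}

merged_committed = merge_overrides(yaml_env, committed_defaults)
merged_none = merge_overrides(yaml_env, none_defaults)

print(merged_committed["init_steps"])  # YAML value replaced by the CLI default
print(merged_none["init_steps"])       # YAML value preserved
```

Under these assumptions, running without any VBD flags replaces the YAML `init_steps: 11` with the CLI default `0`. If the YAML value is meant to win unless a flag is passed explicitly, defaulting the new options to `None`, as the existing options do, would be one way to get that behavior.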

Diff for: data_utils/process_waymo_files.py

+6 -5
@@ -13,7 +13,8 @@
 import psutil
 from pathlib import Path
 import warnings
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, List
+from pdb import set_trace as T
 from tqdm import tqdm
 from waymo_open_dataset.protos import scenario_pb2, map_pb2
 from datatypes import MapElementIds
@@ -132,10 +133,10 @@ def _parse_object_state(
 "width": final_state.width,
 "length": final_state.length,
 "height": final_state.height,
-"heading": [
-wrap_yaws(state.heading) if state.valid else ERR_VAL
+"heading": [ # In radians between [-pi, pi]
+(state.heading + np.pi) % (2 * np.pi) - np.pi if state.valid else ERR_VAL
 for state in states
-],
+],
 "velocity": [
 {"x": state.velocity_x, "y": state.velocity_y}
 if state.valid
@@ -668,7 +669,7 @@ def process_data(args):
 
 parser = argparse.ArgumentParser(
 description="Convert TFRecord files to JSON. \
-Note: This takes about 45 seconds per tfrecord file (=50 traffic scenes)."
+Note: This takes about 45 seconds per tfrecord file (=500 traffic scenes)."
 )
 parser.add_argument(
 "tfrecord_dir", help="Path to the directory containing TFRecord files"
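
The heading hunk replaces the `wrap_yaws` call with an inline expression: shift by pi, reduce modulo 2*pi, shift back. A small sketch of the same arithmetic on scalar values follows; it uses `math.pi` instead of `np.pi` to stay dependency-free, and `wrap_to_pi` is a name chosen for this illustration. Note that the result is mathematically in the half-open interval [-pi, pi), so an input of exactly pi comes back as -pi.

```python
import math


def wrap_to_pi(heading: float) -> float:
    """Same expression as the new heading line, for a single float."""
    return (heading + math.pi) % (2 * math.pi) - math.pi


# A few headings in radians, including values outside [-pi, pi].
for heading in (0.0, 1.5 * math.pi, -1.5 * math.pi, math.pi):
    print(f"{heading:+.4f} -> {wrap_to_pi(heading):+.4f}")
```

The expression relies on Python's `%` returning a non-negative result for a positive divisor, which is what keeps negative headings from wrapping to the wrong side.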
