Skip to content

Commit 3aac785

Browse files
Make VBD model imports optional (#397)
* Optional vbd imports * use_vdb = False by default * Add optional dependencies pytorch-lightning and jaxlib
1 parent 90c9883 commit 3aac785

File tree

3 files changed

+29
-27
lines changed

3 files changed

+29
-27
lines changed

Diff for: README.md

+16-16
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,10 @@ Finally, install the Python components of the repository using pip:
9494

9595
```bash
9696
# macOS and Linux.
97-
pip install -e .
97+
pip install -e .
9898
```
9999

100-
Optional depencies include [pufferlib], [sb3] and [tests].
100+
Optional depencies include [pufferlib], [sb3], [vbd], and [tests].
101101

102102
```bash
103103
# On Windows.
@@ -113,22 +113,22 @@ pip install -e . -Cpackages.madrona_escape_room.ext-out-dir=PATH_TO_YOUR_BUILD_D
113113
<details>
114114
<summary> 🐳 Option 2. Docker </summary>
115115

116-
To get started quickly, we provide a Dockerfile in the root directory.
116+
To get started quickly, we provide a Dockerfile in the root directory.
117117

118-
### Prerequisites
119-
Ensure you have the following installed:
120-
- [Docker](https://docs.docker.com/get-docker/)
121-
- [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)
118+
### Prerequisites
119+
Ensure you have the following installed:
120+
- [Docker](https://docs.docker.com/get-docker/)
121+
- [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)
122122

123-
### Building the Docker mage
124-
Once installed, you can build the container with:
123+
### Building the Docker mage
124+
Once installed, you can build the container with:
125125

126126
```bash
127127
DOCKER_BUILDKIT=1 docker build --build-arg USE_CUDA=true --tag my_image:latest --progress=plain .
128128
```
129129

130-
### Running the Container
131-
To run the container with GPU support and shared memory:
130+
### Running the Container
131+
To run the container with GPU support and shared memory:
132132

133133
```bash
134134
docker run --gpus all -it --rm --shm-size=20G -v ${PWD}:/workspace my_image:latest /bin/bash
@@ -206,15 +206,15 @@ cd build
206206

207207
## Pre-trained policies
208208

209-
Several pre-trained policies are available via the `PyTorchModelHubMixin` class on 🤗 huggingface_hub.
209+
Several pre-trained policies are available via the `PyTorchModelHubMixin` class on 🤗 huggingface_hub.
210210

211-
- **Best Policy (10,000 Scenarios).** The best policy from [Building reliable sim driving agents by scaling self-play](https://arxiv.org/abs/2502.14706) is available here [here](https://huggingface.co/daphne-cornelisse/policy_S10_000_02_27). This policy was trained on 10,000 randomly sampled scenarios from the WOMD training dataset.
211+
- **Best Policy (10,000 Scenarios).** The best policy from [Building reliable sim driving agents by scaling self-play](https://arxiv.org/abs/2502.14706) is available here [here](https://huggingface.co/daphne-cornelisse/policy_S10_000_02_27). This policy was trained on 10,000 randomly sampled scenarios from the WOMD training dataset.
212212

213213
- **Alternative Policy (1,000 Scenarios).** A policy trained on 1,000 scenarios can be found [here](https://huggingface.co/daphne-cornelisse/policy_S1000_02_27)
214214

215215
---
216216

217-
> Note: These models were trained with the environment configurations defined in `examples/experimental/config/reliable_agents_params.yaml`, changing environment/observation configurations will affect performance.
217+
> Note: These models were trained with the environment configurations defined in `examples/experimental/config/reliable_agents_params.yaml`, changing environment/observation configurations will affect performance.
218218
219219
---
220220

@@ -240,7 +240,7 @@ See [tutorial 04](https://github.com/Emerge-Lab/gpudrive/tree/main/examples/tuto
240240
<details>
241241
<summary>Download the dataset</summary>
242242

243-
To download the dataset you need the huggingface_hub library
243+
To download the dataset you need the huggingface_hub library
244244

245245
```bash
246246
pip install huggingface_hub
@@ -356,7 +356,7 @@ and that's it!
356356
If you use GPUDrive in your research, please cite our ICLR 2025 paper
357357
```bibtex
358358
@inproceedings{kazemkhani2025gpudrive,
359-
title={GPUDrive: Data-driven, multi-agent driving simulation at 1 million FPS},
359+
title={GPUDrive: Data-driven, multi-agent driving simulation at 1 million FPS},
360360
author={Saman Kazemkhani and Aarav Pandya and Daphne Cornelisse and Brennan Shacklett and Eugene Vinitsky},
361361
booktitle={Proceedings of the International Conference on Learning Representations (ICLR)},
362362
year={2025},

Diff for: gpudrive/env/env_torch.py

+7-11
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
import mediapy as media
88
import gymnasium
99

10-
from gpudrive.integrations.vbd.data_utils import process_scenario_data
11-
1210
from gpudrive.datatypes.observation import (
1311
LocalEgoState,
1412
GlobalEgoState,
@@ -33,8 +31,7 @@
3331
from gpudrive.env.dataset import SceneDataLoader
3432
from gpudrive.utils.geometry import normalize_min_max
3533

36-
# Versatile Behavior Diffusion model
37-
from gpudrive.integrations.vbd.sim_agent.sim_actor import VBDTest
34+
from gpudrive.integrations.vbd.data_utils import process_scenario_data
3835

3936

4037
class GPUDriveTorchEnv(GPUDriveGymEnv):
@@ -132,12 +129,12 @@ def __init__(
132129

133130
def _initialize_vbd(self):
134131
"""
135-
Initialize the Versatile Behavior Diffusion (VBD) model and related components.
132+
Initialize the Versatile Behavior Diffusion (VBD) model and related
133+
components. Link: https://arxiv.org/abs/2404.02524.
136134
137135
Args:
138136
config: Configuration object containing VBD settings.
139137
"""
140-
# Set VBD configuration parameters
141138
self.use_vbd = self.config.use_vbd
142139
self.vbd_trajectory_weight = self.config.vbd_trajectory_weight
143140

@@ -149,16 +146,13 @@ def _initialize_vbd(self):
149146
else:
150147
self.init_steps = self.config.init_steps
151148

152-
# Initialize VBD model and trajectories if enabled
153149
if (
154150
self.use_vbd
155151
and hasattr(self.config, "vbd_model_path")
156152
and self.config.vbd_model_path
157153
):
158-
# Load VBD model from specified path
159154
self.vbd_model = self._load_vbd_model(self.config.vbd_model_path)
160155

161-
# Initialize trajectory tensor with zeros
162156
self.vbd_trajectories = torch.zeros(
163157
(
164158
self.num_worlds,
@@ -167,17 +161,18 @@ def _initialize_vbd(self):
167161
5,
168162
),
169163
device=self.device,
170-
dtype=torch.float32, # Explicitly specify dtype for clarity
164+
dtype=torch.float32,
171165
)
172166

173167
self._generate_vbd_trajectories()
174168
else:
175-
# Set to None if VBD is disabled or model path is not provided
176169
self.vbd_model = None
177170
self.vbd_trajectories = None
178171

179172
def _load_vbd_model(self, model_path):
180173
"""Load the Versatile Behavior Diffusion (VBD) model from checkpoint."""
174+
from gpudrive.integrations.vbd.sim_agent.sim_actor import VBDTest
175+
181176
model = VBDTest.load_from_checkpoint(
182177
model_path, torch.device(self.device)
183178
)
@@ -1169,6 +1164,7 @@ def get_obs(self, mask=None):
11691164
ego_states = self._get_ego_state(mask)
11701165
partner_observations = self._get_partner_obs(mask)
11711166
road_map_observations = self._get_road_map_obs(mask)
1167+
11721168
if (
11731169
self.use_vbd
11741170
and self.vbd_model is not None

Diff for: pyproject.toml

+6
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ test = [
5050
"pytest>=8.2.1",
5151
]
5252

53+
vbd = [
54+
"lightning",
55+
"jaxlib",
56+
"waymo-waymax @ git+https://github.com/waymo-research/waymax.git@main",
57+
]
58+
5359
[tool.madrona.packages.madrona_gpudrive]
5460
ext-only = true
5561
ext-out-dir = "build"

0 commit comments

Comments
 (0)