
Commit 670d221

Faster resets
1 parent: 8bac16f

2 files changed: +8 -1 lines

Diff for: baselines/ppo/ppo_waypoint.py (+4)

@@ -117,6 +117,8 @@ def run(
     off_road_weight: Annotated[Optional[float], typer.Option(help="The weight for off-road penalty")] = None,
     goal_achieved_weight: Annotated[Optional[float], typer.Option(help="The weight for goal-achieved reward")] = None,
     waypoint_distance_scale: Annotated[Optional[float], typer.Option(help="Scale for realism rewards")] = None,
+    speed_distance_scale: Annotated[Optional[float], typer.Option(help="Scale for realism rewards")] = None,
+    jerk_smoothness_scale: Annotated[Optional[float], typer.Option(help="Scale for realism rewards")] = None,
     dist_to_goal_threshold: Annotated[Optional[float], typer.Option(help="The distance threshold for goal-achieved")] = None,
     randomize_rewards: Annotated[Optional[int], typer.Option(help="If reward_type == reward_conditioned, choose the condition_mode; 0 or 1")] = 0,
     sampling_seed: Annotated[Optional[int], typer.Option(help="The seed for sampling scenes")] = None,
@@ -174,6 +176,8 @@ def run(
         "off_road_weight": off_road_weight,
         "goal_achieved_weight": goal_achieved_weight,
         "waypoint_distance_scale": waypoint_distance_scale,
+        "jerk_smoothness_scale": jerk_smoothness_scale,
+        "speed_distance_scale": speed_distance_scale,
         "dist_to_goal_threshold": dist_to_goal_threshold,
         "sampling_seed": sampling_seed,
         "obs_radius": obs_radius,

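For context, the two new options follow the same Annotated/typer.Option pattern as the neighboring reward options and flow into the same config dict shown in the second hunk. A minimal, self-contained sketch of that pattern (illustrative only; the real run() in ppo_waypoint.py takes many more options than shown here):

# Sketch of the typer option pattern used in the hunks above, trimmed to
# the two new options; this is not the full ppo_waypoint.py signature.
from typing import Annotated, Optional

import typer

app = typer.Typer()


@app.command()
def run(
    speed_distance_scale: Annotated[
        Optional[float], typer.Option(help="Scale for realism rewards")
    ] = None,
    jerk_smoothness_scale: Annotated[
        Optional[float], typer.Option(help="Scale for realism rewards")
    ] = None,
):
    # As in the second hunk, the options are collected into a config dict;
    # None means the caller did not override the scale.
    config = {
        "speed_distance_scale": speed_distance_scale,
        "jerk_smoothness_scale": jerk_smoothness_scale,
    }
    typer.echo(config)


if __name__ == "__main__":
    app()

Since typer maps underscores to dashes, the new options would be passed on the command line as --speed-distance-scale and --jerk-smoothness-scale.
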
Diff for: gpudrive/env/env_puffer.py (+4, -1)

@@ -581,7 +581,10 @@ def step(self, action):
         self.last_obs = self.env.get_obs(self.controlled_agent_mask)
 
         # Asynchronously reset the done worlds and empty storage
-        self.env.reset(env_idx_list=done_worlds_cpu)
+        self.env.reset(
+            env_idx_list=done_worlds_cpu,
+            mask=self.controlled_agent_mask
+        )
         self.episode_returns[done_worlds] = 0
         self.agent_episode_returns[done_worlds, :] = 0
         self.episode_lengths[done_worlds, :] = 0
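
The reset call now also forwards the controlled-agent mask, which is presumably where the "Faster resets" come from: only controlled agent slots in the finished worlds need their state cleared, rather than every slot in every world. A rough sketch of what such a masked partial reset could look like; the signature and buffer layout here are assumptions for illustration, not the actual GPUDrive API:

# Hypothetical masked partial reset, assuming CPU tensors and a
# (num_worlds, max_agents, obs_dim) observation buffer; the real
# GPUDrive reset runs inside the simulator.
import torch


def reset(obs: torch.Tensor, env_idx_list: list, mask: torch.Tensor) -> torch.Tensor:
    """Clear state only for controlled agents in the finished worlds.

    obs:          (num_worlds, max_agents, obs_dim) observation buffer
    env_idx_list: indices of worlds whose episodes just ended
    mask:         (num_worlds, max_agents) bool mask of controlled agents
    """
    world_done = torch.zeros(mask.shape[0], dtype=torch.bool)
    world_done[torch.as_tensor(env_idx_list, dtype=torch.long)] = True
    # Intersect "world is done" with "agent is controlled" so padding and
    # uncontrolled slots are never written, which is the point of the mask.
    obs[world_done[:, None] & mask] = 0.0
    return obs

Skipping uncontrolled and padding slots trims how much memory the asynchronous reset has to touch, which matches the commit's stated goal.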
