Skip to content

Commit f0d129c

Browse files
committed
[BugFix] Fix slow and flaky non-tensor parallel env test
ghstack-source-id: fcb5caa56e05176958b3468a7d6f69e363cfe558 Pull-Request-resolved: #2926
1 parent 0da9044 commit f0d129c

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

test/test_env.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -3716,14 +3716,15 @@ def test_serial(self, bwad, use_buffers):
37163716

37173717
@pytest.mark.parametrize("bwad", [True, False])
37183718
@pytest.mark.parametrize("use_buffers", [False, True])
3719-
def test_parallel(self, bwad, use_buffers):
3719+
def test_parallel(self, bwad, use_buffers, maybe_fork_ParallelEnv):
37203720
N = 50
3721-
env = ParallelEnv(2, EnvWithMetadata, use_buffers=use_buffers)
3721+
env = maybe_fork_ParallelEnv(2, EnvWithMetadata, use_buffers=use_buffers)
37223722
try:
37233723
r = env.rollout(N, break_when_any_done=bwad)
37243724
assert r.get("non_tensor").tolist() == [list(range(N))] * 2
37253725
finally:
37263726
env.close(raise_if_closed=False)
3727+
del env
37273728

37283729
class AddString(Transform):
37293730
def __init__(self):

torchrl/envs/batched_envs.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -2232,8 +2232,14 @@ def _shutdown_workers(self) -> None:
22322232

22332233
for channel in self.parent_channels:
22342234
channel.close()
2235+
start_time = time.time()
2236+
while (
2237+
any(proc.is_alive() for proc in self._workers)
2238+
and (time.time() - start_time) < self._timeout
2239+
):
2240+
time.sleep(0.01)
22352241
for proc in self._workers:
2236-
proc.join(timeout=self._timeout)
2242+
proc.join()
22372243
finally:
22382244
for proc in self._workers:
22392245
if proc.is_alive():
@@ -2731,16 +2737,13 @@ def _run_worker_pipe_direct(
27312737
if not initialized:
27322738
raise RuntimeError("call 'init' before closing")
27332739
env.close()
2734-
del (
2735-
env,
2736-
data,
2737-
)
27382740
mp_event.set()
27392741
child_pipe.close()
27402742
if verbose:
27412743
torchrl_logger.info(f"{pid} closed")
2744+
del (env, data, child_pipe, mp_event)
27422745
gc.collect()
2743-
break
2746+
return
27442747

27452748
elif cmd == "load_state_dict":
27462749
env.load_state_dict(data)

0 commit comments

Comments
 (0)