Commit f5a187d
[Feature] PPO compatibility with compile
ghstack-source-id: 0ed29f352fcd85f0dc0683d90e95bdbecf6c14f9
Pull Request resolved: #2652
1 parent 2cfc2ab commit f5a187d

10 files changed: +288 −176 lines changed

sota-implementations/dqn/dqn_atari.py (+2)

@@ -28,6 +28,8 @@
 from torchrl.record.loggers import generate_exp_name, get_logger
 from utils_atari import eval_model, make_dqn_model, make_env
 
+torch.set_float32_matmul_precision("high")
+
 
 @hydra.main(config_path="", config_name="config_atari", version_base="1.1")
 def main(cfg: "DictConfig"):  # noqa: F821
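
Note: torch.set_float32_matmul_precision("high") lets float32 matrix multiplications run on TensorFloat32 (or comparable reduced-precision) kernels on recent GPUs, which typically speeds up training at a small precision cost. A minimal standalone sketch of the call, with the tensors purely illustrative:

    import torch

    # "high" permits TF32-style matmul kernels; "highest" keeps full fp32.
    torch.set_float32_matmul_precision("high")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    a = torch.randn(1024, 1024, device=device)
    b = torch.randn(1024, 1024, device=device)
    c = a @ b  # on CUDA, this matmul may now use TF32 tensor cores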

sota-implementations/dqn/dqn_cartpole.py (+2)

@@ -23,6 +23,8 @@
 from torchrl.record.loggers import generate_exp_name, get_logger
 from utils_cartpole import eval_model, make_dqn_model, make_env
 
+torch.set_float32_matmul_precision("high")
+
 
 @hydra.main(config_path="", config_name="config_cartpole", version_base="1.1")
 def main(cfg: "DictConfig"):  # noqa: F821

sota-implementations/ppo/config_atari.yaml (+6)

@@ -25,6 +25,7 @@ optim:
   weight_decay: 0.0
   max_grad_norm: 0.5
   anneal_lr: True
+  device:
 
 # loss
 loss:
@@ -37,3 +38,8 @@ loss:
   critic_coef: 1.0
   entropy_coef: 0.01
   loss_critic_type: l2
+
+compile:
+  compile: False
+  compile_mode:
+  cudagraphs: False
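
Note: the new device: and compile_mode: keys are left empty on purpose; empty YAML values resolve to None under OmegaConf/Hydra, which is why the training scripts below test for ("", None) before falling back to defaults. A small sketch of that behavior, with the config inlined for illustration:

    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        "optim:\n"
        "  device:\n"
        "compile:\n"
        "  compile: False\n"
        "  compile_mode:\n"
        "  cudagraphs: False\n"
    )
    assert cfg.optim.device is None  # empty YAML value -> None
    assert cfg.compile.compile_mode is None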

sota-implementations/ppo/config_mujoco.yaml (+6)

@@ -22,6 +22,7 @@ optim:
   lr: 3e-4
   weight_decay: 0.0
   anneal_lr: True
+  device:
 
 # loss
 loss:
@@ -34,3 +35,8 @@ loss:
   critic_coef: 0.25
   entropy_coef: 0.0
   loss_critic_type: l2
+
+compile:
+  compile: False
+  compile_mode:
+  cudagraphs: False

sota-implementations/ppo/ppo_atari.py (+114 −73)

@@ -9,30 +9,42 @@
 """
 from __future__ import annotations
 
+import warnings
+
 import hydra
-from torchrl._utils import logger as torchrl_logger
-from torchrl.record import VideoRecorder
+
+from torchrl._utils import compile_with_warmup
 
 
 @hydra.main(config_path="", config_name="config_atari", version_base="1.1")
 def main(cfg: "DictConfig"):  # noqa: F821
 
-    import time
-
     import torch.optim
     import tqdm
 
     from tensordict import TensorDict
+    from tensordict.nn import CudaGraphModule
+
+    from torchrl._utils import timeit
     from torchrl.collectors import SyncDataCollector
-    from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
+    from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
     from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
     from torchrl.envs import ExplorationType, set_exploration_type
     from torchrl.objectives import ClipPPOLoss
     from torchrl.objectives.value.advantages import GAE
+    from torchrl.record import VideoRecorder
     from torchrl.record.loggers import generate_exp_name, get_logger
     from utils_atari import eval_model, make_parallel_env, make_ppo_models
 
-    device = "cpu" if not torch.cuda.device_count() else "cuda"
+    torch.set_float32_matmul_precision("high")
+
+    device = cfg.optim.device
+    if device in ("", None):
+        if torch.cuda.is_available():
+            device = "cuda:0"
+        else:
+            device = "cpu"
+    device = torch.device(device)
 
     # Correct for frame_skip
     frame_skip = 4
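
Note: the device resolution added above reads as a small helper: an explicitly configured cfg.optim.device wins, otherwise the script prefers the first CUDA device. A sketch restating that logic (the helper name is hypothetical, not part of the diff):

    import torch

    def resolve_device(device=None):
        # Empty string or None means "not configured": prefer cuda:0 if available.
        if device in ("", None):
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        return torch.device(device)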
@@ -41,27 +53,40 @@ def main(cfg: "DictConfig"):  # noqa: F821
     mini_batch_size = cfg.loss.mini_batch_size // frame_skip
     test_interval = cfg.logger.test_interval // frame_skip
 
+    compile_mode = None
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+
     # Create models (check utils_atari.py)
-    actor, critic = make_ppo_models(cfg.env.env_name)
-    actor, critic = actor.to(device), critic.to(device)
+    actor, critic = make_ppo_models(cfg.env.env_name, device=device)
 
     # Create collector
     collector = SyncDataCollector(
         create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, "cpu"),
         policy=actor,
         frames_per_batch=frames_per_batch,
         total_frames=total_frames,
-        device="cpu",
-        storing_device="cpu",
+        device=device,
+        storing_device=device,
         max_frames_per_traj=-1,
+        compile_policy={"mode": compile_mode, "warmup": 1} if compile_mode else False,
+        cudagraph_policy=cfg.compile.cudagraphs,
     )
 
     # Create data buffer
     sampler = SamplerWithoutReplacement()
     data_buffer = TensorDictReplayBuffer(
-        storage=LazyMemmapStorage(frames_per_batch),
+        storage=LazyTensorStorage(
+            frames_per_batch, compilable=cfg.compile.compile, device=device
+        ),
         sampler=sampler,
         batch_size=mini_batch_size,
+        compilable=cfg.compile.compile,
     )
 
     # Create loss and adv modules
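
Note: the compile-mode fallback at the top of this hunk reads as: an explicit compile.compile_mode wins; otherwise "default" is used when CUDA graphs are requested (presumably because CudaGraphModule does its own capture, making the built-in cudagraphs of "reduce-overhead" redundant), else "reduce-overhead". Restated as a hypothetical pure function:

    def resolve_compile_mode(compile_enabled, compile_mode, cudagraphs):
        if not compile_enabled:
            return None
        if compile_mode in ("", None):
            return "default" if cudagraphs else "reduce-overhead"
        return compile_mode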
@@ -70,6 +95,8 @@ def main(cfg: "DictConfig"):  # noqa: F821
         lmbda=cfg.loss.gae_lambda,
         value_network=critic,
         average_gae=False,
+        device=device,
+        vectorized=not cfg.compile.compile,
     )
     loss_module = ClipPPOLoss(
         actor_network=actor,
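
Note: GAE now receives the device explicitly and disables its vectorized path under compile (presumably the sequential scan traces better with torch.compile). For reference, a self-contained GAE call on a fake rollout; the keys follow torchrl's conventions and the shapes are purely illustrative:

    import torch
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule
    from torchrl.objectives.value.advantages import GAE

    value_net = TensorDictModule(
        torch.nn.Linear(4, 1), in_keys=["observation"], out_keys=["state_value"]
    )
    adv = GAE(gamma=0.99, lmbda=0.95, value_network=value_net, average_gae=False)
    rollout = TensorDict(
        {
            "observation": torch.randn(8, 4),
            "next": {
                "observation": torch.randn(8, 4),
                "reward": torch.randn(8, 1),
                "done": torch.zeros(8, 1, dtype=torch.bool),
                "terminated": torch.zeros(8, 1, dtype=torch.bool),
            },
        },
        batch_size=[8],
    )
    adv(rollout)  # writes "advantage" and "value_target" into the rollout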
@@ -121,15 +148,52 @@ def main(cfg: "DictConfig"):  # noqa: F821
 
     # Main loop
     collected_frames = 0
-    num_network_updates = 0
-    start_time = time.time()
+    num_network_updates = torch.zeros((), dtype=torch.int64, device=device)
     pbar = tqdm.tqdm(total=total_frames)
     num_mini_batches = frames_per_batch // mini_batch_size
     total_network_updates = (
         (total_frames // frames_per_batch) * cfg.loss.ppo_epochs * num_mini_batches
     )
 
-    sampling_start = time.time()
+    def update(batch, num_network_updates):
+        optim.zero_grad(set_to_none=True)
+
+        # Linearly decrease the learning rate and clip epsilon
+        alpha = torch.ones((), device=device)
+        if cfg_optim_anneal_lr:
+            alpha = 1 - (num_network_updates / total_network_updates)
+            for group in optim.param_groups:
+                group["lr"] = cfg_optim_lr * alpha
+        if cfg_loss_anneal_clip_eps:
+            loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
+        num_network_updates = num_network_updates + 1
+        # Get a data batch
+        batch = batch.to(device, non_blocking=True)
+
+        # Forward pass PPO loss
+        loss = loss_module(batch)
+        loss_sum = loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
+        # Backward pass
+        loss_sum.backward()
+        torch.nn.utils.clip_grad_norm_(
+            loss_module.parameters(), max_norm=cfg_optim_max_grad_norm
+        )
+
+        # Update the networks
+        optim.step()
+        return loss.detach().set("alpha", alpha), num_network_updates
+
+    if cfg.compile.compile:
+        update = compile_with_warmup(update, mode=compile_mode, warmup=1)
+        adv_module = compile_with_warmup(adv_module, mode=compile_mode, warmup=1)
+
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)
+        adv_module = CudaGraphModule(adv_module)
 
     # extract cfg variables
     cfg_loss_ppo_epochs = cfg.loss.ppo_epochs
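
Note: compile_with_warmup (imported from torchrl._utils in the first hunk) wraps a callable so that compilation only kicks in after a number of eager warmup calls. This is not torchrl's implementation, just a sketch of the semantics the name and the call sites suggest:

    import torch

    def compile_with_warmup_sketch(fn, mode=None, warmup=1):
        # Run the first `warmup` calls eagerly, then hand off to torch.compile.
        state = {"calls": 0, "compiled": None}

        def wrapper(*args, **kwargs):
            if state["calls"] < warmup:
                state["calls"] += 1
                return fn(*args, **kwargs)
            if state["compiled"] is None:
                state["compiled"] = torch.compile(fn, mode=mode)
            return state["compiled"](*args, **kwargs)

        return wrapper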
@@ -142,13 +206,16 @@ def main(cfg: "DictConfig"):  # noqa: F821
     cfg.loss.clip_epsilon = cfg_loss_clip_epsilon
     losses = TensorDict(batch_size=[cfg_loss_ppo_epochs, num_mini_batches])
 
-    for i, data in enumerate(collector):
+    collector_iter = iter(collector)
+
+    for i in range(len(collector)):
+        with timeit("collecting"):
+            data = next(collector_iter)
 
         log_info = {}
-        sampling_time = time.time() - sampling_start
         frames_in_batch = data.numel()
         collected_frames += frames_in_batch * frame_skip
-        pbar.update(data.numel())
+        pbar.update(frames_in_batch)
 
         # Get training rewards and episode lengths
         episode_rewards = data["next", "episode_reward"][data["next", "terminated"]]
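
Note: timeit here is torchrl._utils.timeit, used as a named context manager that accumulates wall-clock time per section; the totals are flushed into the logs further down via timeit.todict(prefix="time"). A minimal stand-in with the same shape (a sketch, not torchrl's implementation):

    import time
    from collections import defaultdict

    class timeit_sketch:
        _times = defaultdict(float)

        def __init__(self, name):
            self.name = name

        def __enter__(self):
            self.t0 = time.perf_counter()

        def __exit__(self, *exc):
            # Accumulate elapsed seconds under the section name.
            type(self)._times[self.name] += time.perf_counter() - self.t0

        @classmethod
        def todict(cls, prefix=None):
            return {
                (f"{prefix}/{k}" if prefix else k): v
                for k, v in cls._times.items()
            }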
@@ -162,96 +229,70 @@ def main(cfg: "DictConfig"):  # noqa: F821
             }
         )
 
-        training_start = time.time()
-        for j in range(cfg_loss_ppo_epochs):
-
-            # Compute GAE
-            with torch.no_grad():
-                data = adv_module(data.to(device, non_blocking=True))
-            data_reshape = data.reshape(-1)
-            # Update the data buffer
-            data_buffer.extend(data_reshape)
-
-            for k, batch in enumerate(data_buffer):
-
-                # Linearly decrease the learning rate and clip epsilon
-                alpha = 1.0
-                if cfg_optim_anneal_lr:
-                    alpha = 1 - (num_network_updates / total_network_updates)
-                    for group in optim.param_groups:
-                        group["lr"] = cfg_optim_lr * alpha
-                if cfg_loss_anneal_clip_eps:
-                    loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
-                num_network_updates += 1
-                # Get a data batch
-                batch = batch.to(device, non_blocking=True)
-
-                # Forward pass PPO loss
-                loss = loss_module(batch)
-                losses[j, k] = loss.select(
-                    "loss_critic", "loss_entropy", "loss_objective"
-                ).detach()
-                loss_sum = (
-                    loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
-                )
-                # Backward pass
-                loss_sum.backward()
-                torch.nn.utils.clip_grad_norm_(
-                    list(loss_module.parameters()), max_norm=cfg_optim_max_grad_norm
-                )
-
-                # Update the networks
-                optim.step()
-                optim.zero_grad()
+        with timeit("training"):
+            for j in range(cfg_loss_ppo_epochs):
+
+                # Compute GAE
+                with torch.no_grad(), timeit("adv"):
+                    torch.compiler.cudagraph_mark_step_begin()
+                    data = adv_module(data)
+                    if compile_mode:
+                        data = data.clone()
+                with timeit("rb - extend"):
+                    # Update the data buffer
+                    data_reshape = data.reshape(-1)
+                    data_buffer.extend(data_reshape)
+
+                for k, batch in enumerate(data_buffer):
+                    torch.compiler.cudagraph_mark_step_begin()
+                    loss, num_network_updates = update(
+                        batch, num_network_updates=num_network_updates
+                    )
+                    loss = loss.clone()
+                    num_network_updates = num_network_updates.clone()
+                    losses[j, k] = loss.select(
+                        "loss_critic", "loss_entropy", "loss_objective"
+                    )
 
         # Get training losses and times
-        training_time = time.time() - training_start
         losses_mean = losses.apply(lambda x: x.float().mean(), batch_size=[])
         for key, value in losses_mean.items():
             log_info.update({f"train/{key}": value.item()})
         log_info.update(
             {
-                "train/lr": alpha * cfg_optim_lr,
-                "train/sampling_time": sampling_time,
-                "train/training_time": training_time,
-                "train/clip_epsilon": alpha * cfg_loss_clip_epsilon,
+                "train/lr": loss["alpha"] * cfg_optim_lr,
+                "train/clip_epsilon": loss["alpha"] * cfg_loss_clip_epsilon,
             }
         )
 
         # Get test rewards
-        with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
+        with torch.no_grad(), set_exploration_type(
+            ExplorationType.DETERMINISTIC
+        ), timeit("eval"):
             if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
                 i * frames_in_batch * frame_skip
             ) // test_interval:
                 actor.eval()
-                eval_start = time.time()
                 test_rewards = eval_model(
                     actor, test_env, num_episodes=cfg_logger_num_test_episodes
                 )
-                eval_time = time.time() - eval_start
                 log_info.update(
                     {
                         "eval/reward": test_rewards.mean(),
-                        "eval/time": eval_time,
                     }
                 )
                 actor.train()
-
         if logger:
+            log_info.update(timeit.todict(prefix="time"))
             for key, value in log_info.items():
                 logger.log_scalar(key, value, collected_frames)
 
         collector.update_policy_weights_()
-        sampling_start = time.time()
 
     collector.shutdown()
     if not test_env.is_closed:
         test_env.close()
 
-    end_time = time.time()
-    execution_time = end_time - start_time
-    torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
-
 
 if __name__ == "__main__":
     main()
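
Note on the cudagraph_mark_step_begin()/clone() pairing in the training loop: CUDA-graph-backed compilation ("reduce-overhead" mode, or CudaGraphModule) replays into preallocated output buffers, so results kept across iterations must be cloned, and each iteration boundary is marked explicitly. A minimal illustration with plain torch.compile, assuming a CUDA device:

    import torch

    @torch.compile(mode="reduce-overhead")
    def step(x):
        return x * 2

    if torch.cuda.is_available():
        x = torch.ones(4, device="cuda")
        outs = []
        for _ in range(3):
            # Tell the cudagraph machinery that a new iteration starts here.
            torch.compiler.cudagraph_mark_step_begin()
            y = step(x)
            # Clone before storing: the next replay may overwrite y's buffer.
            outs.append(y.clone())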
