"""
from __future__ import annotations

+ import warnings
+
import hydra
- from torchrl._utils import logger as torchrl_logger
- from torchrl.record import VideoRecorder
+
+ from torchrl._utils import compile_with_warmup


@hydra.main(config_path="", config_name="config_atari", version_base="1.1")
def main(cfg: "DictConfig"):  # noqa: F821

-     import time
-
    import torch.optim
    import tqdm

    from tensordict import TensorDict
+     from tensordict.nn import CudaGraphModule
+
+     from torchrl._utils import timeit
    from torchrl.collectors import SyncDataCollector
-     from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
+     from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
    from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
    from torchrl.envs import ExplorationType, set_exploration_type
    from torchrl.objectives import ClipPPOLoss
    from torchrl.objectives.value.advantages import GAE
+     from torchrl.record import VideoRecorder
    from torchrl.record.loggers import generate_exp_name, get_logger
    from utils_atari import eval_model, make_parallel_env, make_ppo_models

-     device = "cpu" if not torch.cuda.device_count() else "cuda"
+     torch.set_float32_matmul_precision("high")
+
+     device = cfg.optim.device
+     if device in ("", None):
+         if torch.cuda.is_available():
+             device = "cuda:0"
+         else:
+             device = "cpu"
+     device = torch.device(device)

    # Correct for frame_skip
    frame_skip = 4
@@ -41,27 +53,40 @@ def main(cfg: "DictConfig"):  # noqa: F821
    mini_batch_size = cfg.loss.mini_batch_size // frame_skip
    test_interval = cfg.logger.test_interval // frame_skip

+     compile_mode = None
+     if cfg.compile.compile:
+         compile_mode = cfg.compile.compile_mode
+         if compile_mode in ("", None):
+             if cfg.compile.cudagraphs:
+                 compile_mode = "default"
+             else:
+                 compile_mode = "reduce-overhead"
+
    # Create models (check utils_atari.py)
-     actor, critic = make_ppo_models(cfg.env.env_name)
-     actor, critic = actor.to(device), critic.to(device)
+     actor, critic = make_ppo_models(cfg.env.env_name, device=device)

    # Create collector
    collector = SyncDataCollector(
        create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, "cpu"),
        policy=actor,
        frames_per_batch=frames_per_batch,
        total_frames=total_frames,
-         device="cpu",
-         storing_device="cpu",
+         device=device,
+         storing_device=device,
        max_frames_per_traj=-1,
+         compile_policy={"mode": compile_mode, "warmup": 1} if compile_mode else False,
+         cudagraph_policy=cfg.compile.cudagraphs,
    )

    # Create data buffer
    sampler = SamplerWithoutReplacement()
    data_buffer = TensorDictReplayBuffer(
-         storage=LazyMemmapStorage(frames_per_batch),
+         storage=LazyTensorStorage(
+             frames_per_batch, compilable=cfg.compile.compile, device=device
+         ),
        sampler=sampler,
        batch_size=mini_batch_size,
+         compilable=cfg.compile.compile,
    )

    # Create loss and adv modules
@@ -70,6 +95,8 @@ def main(cfg: "DictConfig"):  # noqa: F821
        lmbda=cfg.loss.gae_lambda,
        value_network=critic,
        average_gae=False,
+         device=device,
+         vectorized=not cfg.compile.compile,
    )
    loss_module = ClipPPOLoss(
        actor_network=actor,
@@ -121,15 +148,52 @@ def main(cfg: "DictConfig"):  # noqa: F821

    # Main loop
    collected_frames = 0
-     num_network_updates = 0
-     start_time = time.time()
+     num_network_updates = torch.zeros((), dtype=torch.int64, device=device)
    pbar = tqdm.tqdm(total=total_frames)
    num_mini_batches = frames_per_batch // mini_batch_size
    total_network_updates = (
        (total_frames // frames_per_batch) * cfg.loss.ppo_epochs * num_mini_batches
    )

-     sampling_start = time.time()
+     def update(batch, num_network_updates):
+         optim.zero_grad(set_to_none=True)
+
+         # Linearly decrease the learning rate and clip epsilon
+         alpha = torch.ones((), device=device)
+         if cfg_optim_anneal_lr:
+             alpha = 1 - (num_network_updates / total_network_updates)
+             for group in optim.param_groups:
+                 group["lr"] = cfg_optim_lr * alpha
+         if cfg_loss_anneal_clip_eps:
+             loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
+         num_network_updates = num_network_updates + 1
+         # Get a data batch
+         batch = batch.to(device, non_blocking=True)
+
+         # Forward pass PPO loss
+         loss = loss_module(batch)
+         loss_sum = loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
+         # Backward pass
+         loss_sum.backward()
+         torch.nn.utils.clip_grad_norm_(
+             loss_module.parameters(), max_norm=cfg_optim_max_grad_norm
+         )
+
+         # Update the networks
+         optim.step()
+         return loss.detach().set("alpha", alpha), num_network_updates
+
+     if cfg.compile.compile:
+         update = compile_with_warmup(update, mode=compile_mode, warmup=1)
+         adv_module = compile_with_warmup(adv_module, mode=compile_mode, warmup=1)
+
+     if cfg.compile.cudagraphs:
+         warnings.warn(
+             "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+             category=UserWarning,
+         )
+         update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)
+         adv_module = CudaGraphModule(adv_module)

    # extract cfg variables
    cfg_loss_ppo_epochs = cfg.loss.ppo_epochs
@@ -142,13 +206,16 @@ def main(cfg: "DictConfig"):  # noqa: F821
    cfg.loss.clip_epsilon = cfg_loss_clip_epsilon
    losses = TensorDict(batch_size=[cfg_loss_ppo_epochs, num_mini_batches])

-     for i, data in enumerate(collector):
+     collector_iter = iter(collector)
+
+     for i in range(len(collector)):
+         with timeit("collecting"):
+             data = next(collector_iter)

        log_info = {}
-         sampling_time = time.time() - sampling_start
        frames_in_batch = data.numel()
        collected_frames += frames_in_batch * frame_skip
-         pbar.update(data.numel())
+         pbar.update(frames_in_batch)

        # Get training rewards and episode lengths
        episode_rewards = data["next", "episode_reward"][data["next", "terminated"]]
@@ -162,96 +229,70 @@ def main(cfg: "DictConfig"):  # noqa: F821
            }
        )

-         training_start = time.time()
-         for j in range(cfg_loss_ppo_epochs):
-
-             # Compute GAE
-             with torch.no_grad():
-                 data = adv_module(data.to(device, non_blocking=True))
-             data_reshape = data.reshape(-1)
-             # Update the data buffer
-             data_buffer.extend(data_reshape)
-
-             for k, batch in enumerate(data_buffer):
-
-                 # Linearly decrease the learning rate and clip epsilon
-                 alpha = 1.0
-                 if cfg_optim_anneal_lr:
-                     alpha = 1 - (num_network_updates / total_network_updates)
-                     for group in optim.param_groups:
-                         group["lr"] = cfg_optim_lr * alpha
-                 if cfg_loss_anneal_clip_eps:
-                     loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
-                 num_network_updates += 1
-                 # Get a data batch
-                 batch = batch.to(device, non_blocking=True)
-
-                 # Forward pass PPO loss
-                 loss = loss_module(batch)
-                 losses[j, k] = loss.select(
-                     "loss_critic", "loss_entropy", "loss_objective"
-                 ).detach()
-                 loss_sum = (
-                     loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
-                 )
-                 # Backward pass
-                 loss_sum.backward()
-                 torch.nn.utils.clip_grad_norm_(
-                     list(loss_module.parameters()), max_norm=cfg_optim_max_grad_norm
-                 )
-
-                 # Update the networks
-                 optim.step()
-                 optim.zero_grad()
+         with timeit("training"):
+             for j in range(cfg_loss_ppo_epochs):
+
+                 # Compute GAE
+                 with torch.no_grad(), timeit("adv"):
+                     torch.compiler.cudagraph_mark_step_begin()
+                     data = adv_module(data)
+                     if compile_mode:
+                         data = data.clone()
+                 with timeit("rb - extend"):
+                     # Update the data buffer
+                     data_reshape = data.reshape(-1)
+                     data_buffer.extend(data_reshape)
+
+                 for k, batch in enumerate(data_buffer):
+                     torch.compiler.cudagraph_mark_step_begin()
+                     loss, num_network_updates = update(
+                         batch, num_network_updates=num_network_updates
+                     )
+                     loss = loss.clone()
+                     num_network_updates = num_network_updates.clone()
+                     losses[j, k] = loss.select(
+                         "loss_critic", "loss_entropy", "loss_objective"
+                     )

        # Get training losses and times
-         training_time = time.time() - training_start
        losses_mean = losses.apply(lambda x: x.float().mean(), batch_size=[])
        for key, value in losses_mean.items():
            log_info.update({f"train/{key}": value.item()})
        log_info.update(
            {
-                 "train/lr": alpha * cfg_optim_lr,
-                 "train/sampling_time": sampling_time,
-                 "train/training_time": training_time,
-                 "train/clip_epsilon": alpha * cfg_loss_clip_epsilon,
+                 "train/lr": loss["alpha"] * cfg_optim_lr,
+                 "train/clip_epsilon": loss["alpha"] * cfg_loss_clip_epsilon,
            }
        )

        # Get test rewards
-         with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
+         with torch.no_grad(), set_exploration_type(
+             ExplorationType.DETERMINISTIC
+         ), timeit("eval"):
            if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
                i * frames_in_batch * frame_skip
            ) // test_interval:
                actor.eval()
-                 eval_start = time.time()
                test_rewards = eval_model(
                    actor, test_env, num_episodes=cfg_logger_num_test_episodes
                )
-                 eval_time = time.time() - eval_start
                log_info.update(
                    {
                        "eval/reward": test_rewards.mean(),
-                         "eval/time": eval_time,
                    }
                )
                actor.train()
-
        if logger:
+             log_info.update(timeit.todict(prefix="time"))
            for key, value in log_info.items():
                logger.log_scalar(key, value, collected_frames)

        collector.update_policy_weights_()
-         sampling_start = time.time()

    collector.shutdown()
    if not test_env.is_closed:
        test_env.close()

-     end_time = time.time()
-     execution_time = end_time - start_time
-     torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
-

if __name__ == "__main__":
    main()
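
The changes above assume that config_atari.yaml exposes an optim.device entry and a compile group (compile, compile_mode, cudagraphs), since the script now reads those fields. The following is a minimal sketch of that shape, written with omegaconf purely for illustration; the field names come from the diff, but the defaults shown are assumptions rather than the repository's actual values:

    # Hypothetical stand-in for the relevant entries of config_atari.yaml.
    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "optim": {"device": ""},  # "" or None -> "cuda:0" if available, else "cpu"
            "compile": {
                "compile": False,  # wrap update/adv_module with compile_with_warmup
                "compile_mode": "",  # "" -> "default" under cudagraphs, else "reduce-overhead"
                "cudagraphs": False,  # additionally wrap update/adv_module in CudaGraphModule
            },
        }
    )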