@@ -92,6 +92,8 @@ def log(
         on_step: bool = False,
         on_epoch: bool = True,
         reduce_fx: Callable = torch.mean,
+        tbptt_reduce_fx: Callable = torch.mean,
+        tbptt_pad_token: int = 0,
         enable_graph: bool = False,
         sync_ddp: bool = False,
         sync_ddp_op: Union[Any, str] = 'mean',
@@ -113,15 +115,22 @@ def log(
         if on_step and on_epoch:
             # set step version
             step_name = f'step_{name}'
-            self.__set_meta(step_name, value, prog_bar, logger, on_step=True, on_epoch=False, reduce_fx=reduce_fx)
+            self.__set_meta(step_name, value, prog_bar, logger,
+                            on_step=True, on_epoch=False,
+                            reduce_fx=reduce_fx, tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token)
             self.__setitem__(step_name, value)

             # set epoch version
             epoch_name = f'epoch_{name}'
-            self.__set_meta(epoch_name, value, prog_bar, logger, on_step=False, on_epoch=True, reduce_fx=reduce_fx)
+            self.__set_meta(epoch_name, value, prog_bar, logger, on_step=False, on_epoch=True,
+                            reduce_fx=reduce_fx, tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token)
             self.__setitem__(epoch_name, value)
         else:
-            self.__set_meta(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx)
+            self.__set_meta(name, value,
+                            prog_bar, logger,
+                            on_step, on_epoch,
+                            reduce_fx,
+                            tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token)

         # set the value
         self.__setitem__(name, value)
@@ -135,6 +144,8 @@ def __set_meta(
         on_step: bool,
         on_epoch: bool,
         reduce_fx: Callable,
+        tbptt_pad_token: int,
+        tbptt_reduce_fx: Callable
     ):
         # set the meta for the item
         meta_value = value
@@ -144,7 +155,9 @@ def __set_meta(
             on_step=on_step,
             on_epoch=on_epoch,
             reduce_fx=reduce_fx,
-            value=meta_value
+            value=meta_value,
+            tbptt_reduce_fx=tbptt_reduce_fx,
+            tbptt_pad_token=tbptt_pad_token
         )

         self['meta'][name] = meta
@@ -253,6 +266,39 @@ def gather(cls, outputs):
         result['meta'] = meta
         return result

+    @classmethod
+    def padded_gather(cls, outputs):
+        meta = outputs[0].get('meta')
+        result = cls()
+        result = recursive_gather(outputs, result)
+
+        # find the padding used for other values
+        default_padding_idx = 0
+        for name, value in result.items():
+            if isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor):
+                if name not in {'checkpoint_on', 'early_stop_on', 'minimize'}:
+                    default_padding_idx = meta[name]['tbptt_pad_token']
+                    break
+
+        # pad across each key individually
+        for name, value in result.items():
+            is_reserved = name in {'checkpoint_on', 'early_stop_on', 'minimize'}
+            if isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor):
+
+                if is_reserved:
+                    padding_key = default_padding_idx
+                else:
+                    padding_key = meta[name]['tbptt_pad_token']
+                padded = torch.nn.utils.rnn.pad_sequence(value, batch_first=True, padding_value=padding_key)
+                result[name] = padded
+
+                # also update the result
+                if meta and not is_reserved:
+                    meta[name]['value'] = padded
+        if meta:
+            result['meta'] = meta
+        return result
+
     @classmethod
     def reduce_on_epoch_end(cls, outputs):
         meta = outputs[0]['meta']
@@ -271,10 +317,36 @@ def reduce_on_epoch_end(cls, outputs):
         result['meta'] = meta
         return result

+    @classmethod
+    def reduce_across_time(cls, time_outputs):
+        # auto-reduce across time for tbptt
+        meta = time_outputs[0]['meta']
+        result = cls()
+        result = recursive_gather(time_outputs, result)
+        recursive_stack(result)
+
+        for k, value in result.items():
+            if k == 'meta':
+                continue
+
+            # pick the reduce fx
+            if k in ['checkpoint_on', 'early_stop_on', 'minimize']:
+                tbptt_reduce_fx = torch.mean
+            else:
+                tbptt_reduce_fx = meta[k]['tbptt_reduce_fx']
+            result[k] = tbptt_reduce_fx(value)
+
+        result['meta'] = meta
+        return result
+
     @property
     def should_reduce_on_epoch_end(self) -> bool:
         return self['meta']['_internal']['_reduce_on_epoch']

+    def drop_hiddens(self):
+        if 'hiddens' in self:
+            del self['hiddens']
+

 def recursive_gather(outputs: Sequence[dict], result: Optional[MutableMapping] = None) -> Optional[MutableMapping]:
     for out in outputs:
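Note on the two new classmethods above: they work as a pair for truncated back-propagation through time. padded_gather pads each key's per-split tensors to a common length using that key's tbptt_pad_token, and reduce_across_time then collapses the time dimension with the key's tbptt_reduce_fx (falling back to torch.mean for the reserved checkpoint_on / early_stop_on / minimize keys). A minimal sketch of that behaviour with plain torch, outside the Result class (the values here are illustrative, not from the diff):

import torch

# token-level losses from three tbptt splits of different lengths
splits = [torch.tensor([0.9, 0.8]), torch.tensor([0.7, 0.6, 0.5]), torch.tensor([0.4])]

# what padded_gather does per key: pad to a common length with tbptt_pad_token (0 here)
padded = torch.nn.utils.rnn.pad_sequence(splits, batch_first=True, padding_value=0)
# tensor([[0.9, 0.8, 0.0],
#         [0.7, 0.6, 0.5],
#         [0.4, 0.0, 0.0]])

# what reduce_across_time does per key: apply tbptt_reduce_fx (torch.mean by default)
reduced = torch.mean(padded)
print(padded.shape, reduced)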
@@ -303,6 +375,16 @@ def recursive_stack(result: MutableMapping):
             result[k] = v


+def recursive_padded_stack(result: MutableMapping):
+    for k, v in result.items():
+        if isinstance(v, dict):
+            recursive_stack(v)
+
+        if isinstance(v, list) and len(v) > 0 and isinstance(v[0], torch.Tensor):
+            v = torch.stack(v)
+            result[k] = v
+
+
 class TrainResult(Result):

     def __init__(
@@ -348,6 +430,8 @@ def log(
         on_step: bool = True,
         on_epoch: bool = False,
         reduce_fx: Callable = torch.mean,
+        tbptt_reduce_fx: Callable = torch.mean,
+        tbptt_pad_token: int = 0,
         enable_graph: bool = False,
         sync_ddp: bool = False,
         sync_ddp_op: Union[Any, str] = 'mean',
@@ -381,10 +465,26 @@ def log(
             on_step: if True logs the output of validation_step or test_step
             on_epoch: if True, logs the output of the training loop aggregated
             reduce_fx: Torch.mean by default
+            tbptt_reduce_fx: function to reduce on truncated back prop
+            tbptt_pad_token: token to use for padding
             enable_graph: if True, will not auto detach the graph
+            sync_ddp: if True, reduces the metric across GPUs/TPUs
+            sync_ddp_op: the op to sync across
+            sync_ddp_group: the ddp group
         """
-        super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
-                    sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)
+        super().log(name=name,
+                    value=value,
+                    prog_bar=prog_bar,
+                    logger=logger,
+                    on_step=on_step,
+                    on_epoch=on_epoch,
+                    reduce_fx=reduce_fx,
+                    enable_graph=enable_graph,
+                    sync_ddp=sync_ddp,
+                    sync_ddp_group=sync_ddp_group,
+                    sync_ddp_op=sync_ddp_op,
+                    tbptt_pad_token=tbptt_pad_token,
+                    tbptt_reduce_fx=tbptt_reduce_fx)

     def log_dict(
         self,
@@ -394,6 +494,8 @@ def log_dict(
         on_step: bool = False,
         on_epoch: bool = True,
         reduce_fx: Callable = torch.mean,
+        tbptt_reduce_fx: Callable = torch.mean,
+        tbptt_pad_token: int = 0,
         enable_graph: bool = False,
         sync_ddp: bool = False,
         sync_ddp_op: Union[Any, str] = 'mean',
@@ -408,17 +510,33 @@ def log_dict(
            result.log_dict(values)

        Args:
-            dictionary:
-            prog_bar:
-            logger:
-            on_step:
-            on_epoch:
-            reduce_fx:
-            enable_graph:
+            dictionary: key value pairs (str, tensors)
+            prog_bar: if True logs to the progress base
+            logger: if True logs to the logger
+            on_step: if True logs the output of validation_step or test_step
+            on_epoch: if True, logs the output of the training loop aggregated
+            reduce_fx: Torch.mean by default
+            tbptt_reduce_fx: function to reduce on truncated back prop
+            tbptt_pad_token: token to use for padding
+            enable_graph: if True, will not auto detach the graph
+            sync_ddp: if True, reduces the metric across GPUs/TPUs
+            sync_ddp_op: the op to sync across
+            sync_ddp_group: the ddp group:
        """
        for k, v in dictionary.items():
-            self.log(k, v, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
-                     sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)
+            self.log(name=k,
+                     value=v,
+                     prog_bar=prog_bar,
+                     logger=logger,
+                     on_step=on_step,
+                     on_epoch=on_epoch,
+                     reduce_fx=reduce_fx,
+                     enable_graph=enable_graph,
+                     sync_ddp=sync_ddp,
+                     sync_ddp_group=sync_ddp_group,
+                     sync_ddp_op=sync_ddp_op,
+                     tbptt_pad_token=tbptt_pad_token,
+                     tbptt_reduce_fx=tbptt_reduce_fx)


 class EvalResult(Result):
@@ -464,6 +582,8 @@ def log(
         on_step: bool = False,
         on_epoch: bool = True,
         reduce_fx: Callable = torch.mean,
+        tbptt_reduce_fx: Callable = torch.mean,
+        tbptt_pad_token: int = 0,
         enable_graph: bool = False,
         sync_ddp: bool = False,
         sync_ddp_op: Union[Any, str] = 'mean',
@@ -494,12 +614,28 @@ def log(
             prog_bar: if True logs to the progress base
             logger: if True logs to the logger
             on_step: if True logs the output of validation_step or test_step
-            on_epoch: if True, logs the output of the validation loop or test loop aggregated
+            on_epoch: if True, logs the output of the training loop aggregated
             reduce_fx: Torch.mean by default
-            enable_graph: if True, will not auto detach the graph :
+            tbptt_reduce_fx: function to reduce on truncated back prop
+            tbptt_pad_token: token to use for padding
+            enable_graph: if True, will not auto detach the graph
+            sync_ddp: if True, reduces the metric across GPUs/TPUs
+            sync_ddp_op: the op to sync across
+            sync_ddp_group: the ddp group
         """
-        super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
-                    sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)
+        super().log(name=name,
+                    value=value,
+                    prog_bar=prog_bar,
+                    logger=logger,
+                    on_step=on_step,
+                    on_epoch=on_epoch,
+                    reduce_fx=reduce_fx,
+                    enable_graph=enable_graph,
+                    sync_ddp=sync_ddp,
+                    sync_ddp_group=sync_ddp_group,
+                    sync_ddp_op=sync_ddp_op,
+                    tbptt_pad_token=tbptt_pad_token,
+                    tbptt_reduce_fx=tbptt_reduce_fx)

     def log_dict(
         self,
@@ -509,6 +645,8 @@ def log_dict(
         on_step: bool = False,
         on_epoch: bool = True,
         reduce_fx: Callable = torch.mean,
+        tbptt_reduce_fx: Callable = torch.mean,
+        tbptt_pad_token: int = 0,
         enable_graph: bool = False,
         sync_ddp: bool = False,
         sync_ddp_op: Union [Any, str] = 'mean',
@@ -523,17 +661,33 @@ def log_dict(
            result.log_dict(values)

        Args:
-            dictionary:
-            prog_bar:
-            logger:
-            on_step:
-            on_epoch:
-            reduce_fx:
-            enable_graph:
+            dictionary: key value pairs (str, tensors)
+            prog_bar: if True logs to the progress base
+            logger: if True logs to the logger
+            on_step: if True logs the output of validation_step or test_step
+            on_epoch: if True, logs the output of the training loop aggregated
+            reduce_fx: Torch.mean by default
+            tbptt_reduce_fx: function to reduce on truncated back prop
+            tbptt_pad_token: token to use for padding
+            enable_graph: if True, will not auto detach the graph
+            sync_ddp: if True, reduces the metric across GPUs/TPUs
+            sync_ddp_op: the op to sync across
+            sync_ddp_group: the ddp group
        """
        for k, v in dictionary.items():
-            self.log(k, v, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
-                     sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)
+            self.log(name=k,
+                     value=v,
+                     prog_bar=prog_bar,
+                     logger=logger,
+                     on_step=on_step,
+                     on_epoch=on_epoch,
+                     reduce_fx=reduce_fx,
+                     enable_graph=enable_graph,
+                     sync_ddp=sync_ddp,
+                     sync_ddp_group=sync_ddp_group,
+                     sync_ddp_op=sync_ddp_op,
+                     tbptt_pad_token=tbptt_pad_token,
+                     tbptt_reduce_fx=tbptt_reduce_fx)

     def get_callback_metrics(self) -> dict:
         result = {
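For context, a hedged usage sketch of the new log() arguments from a LightningModule trained with truncated back-prop through time. The extra hiddens argument to training_step, the pl.TrainResult construction, and the 'hiddens' key (which the drop_hiddens() added above later removes) follow the Lightning API at this revision as best understood; the model internals and metric name are placeholders, not part of the diff.

import torch
import pytorch_lightning as pl


class SeqTagger(pl.LightningModule):
    # with truncated_bptt_steps set on the Trainer, Lightning passes `hiddens` here
    def training_step(self, batch, batch_idx, hiddens):
        x, y = batch
        y_hat, hiddens = self(x, hiddens)                      # placeholder forward
        loss = torch.nn.functional.cross_entropy(y_hat, y)

        result = pl.TrainResult(minimize=loss)
        result['hiddens'] = hiddens                            # consumed (and later dropped) by the training loop
        # reduce this metric across tbptt splits with max instead of the default mean,
        # and pad ragged splits with 0 when they are gathered
        result.log('train_loss', loss,
                   tbptt_reduce_fx=torch.max,
                   tbptt_pad_token=0)
        return result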