
Commit 256059a

tracks all outputs including TBPTT and multiple optimizers (#2890)
* pl 0.9 update
1 parent 4d0406e commit 256059a

7 files changed: +486 −50 lines changed

pytorch_lightning/core/step_result.py

+182 −28
@@ -92,6 +92,8 @@ def log(
         on_step: bool = False,
         on_epoch: bool = True,
         reduce_fx: Callable = torch.mean,
+        tbptt_reduce_fx: Callable = torch.mean,
+        tbptt_pad_token: int = 0,
         enable_graph: bool = False,
         sync_ddp: bool = False,
         sync_ddp_op: Union[Any, str] = 'mean',
@@ -113,15 +115,22 @@ def log(
         if on_step and on_epoch:
             # set step version
             step_name = f'step_{name}'
-            self.__set_meta(step_name, value, prog_bar, logger, on_step=True, on_epoch=False, reduce_fx=reduce_fx)
+            self.__set_meta(step_name, value, prog_bar, logger,
+                            on_step=True, on_epoch=False,
+                            reduce_fx=reduce_fx, tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token)
             self.__setitem__(step_name, value)

             # set epoch version
             epoch_name = f'epoch_{name}'
-            self.__set_meta(epoch_name, value, prog_bar, logger, on_step=False, on_epoch=True, reduce_fx=reduce_fx)
+            self.__set_meta(epoch_name, value, prog_bar, logger, on_step=False, on_epoch=True,
+                            reduce_fx=reduce_fx, tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token)
             self.__setitem__(epoch_name, value)
         else:
-            self.__set_meta(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx)
+            self.__set_meta(name, value,
+                            prog_bar, logger,
+                            on_step, on_epoch,
+                            reduce_fx,
+                            tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token)

         # set the value
         self.__setitem__(name, value)
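
Note on the hunk above: when both on_step and on_epoch are requested, log registers two metas under step_ and epoch_ prefixed keys, each carrying its own reduce and tbptt settings. A minimal sketch of the calling side, assuming the PL 0.9-style TrainResult API; the metric name 'nce_loss' and the values are illustrative only:

import torch
import pytorch_lightning as pl

loss = torch.tensor(0.25)
result = pl.TrainResult(minimize=loss)

# with both flags set, 'step_nce_loss' and 'epoch_nce_loss' entries are created
# internally, and the tbptt settings are stored per key in result['meta']
result.log('nce_loss', loss, on_step=True, on_epoch=True,
           tbptt_reduce_fx=torch.mean, tbptt_pad_token=0)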
@@ -135,6 +144,8 @@ def __set_meta(
         on_step: bool,
         on_epoch: bool,
         reduce_fx: Callable,
+        tbptt_pad_token: int,
+        tbptt_reduce_fx: Callable
     ):
         # set the meta for the item
         meta_value = value
@@ -144,7 +155,9 @@ def __set_meta(
             on_step=on_step,
             on_epoch=on_epoch,
             reduce_fx=reduce_fx,
-            value=meta_value
+            value=meta_value,
+            tbptt_reduce_fx=tbptt_reduce_fx,
+            tbptt_pad_token=tbptt_pad_token
         )

         self['meta'][name] = meta
@@ -253,6 +266,39 @@ def gather(cls, outputs):
         result['meta'] = meta
         return result

+    @classmethod
+    def padded_gather(cls, outputs):
+        meta = outputs[0].get('meta')
+        result = cls()
+        result = recursive_gather(outputs, result)
+
+        # find the padding used for other values
+        default_padding_idx = 0
+        for name, value in result.items():
+            if isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor):
+                if name not in {'checkpoint_on', 'early_stop_on', 'minimize'}:
+                    default_padding_idx = meta[name]['tbptt_pad_token']
+                    break
+
+        # pad across each key individually
+        for name, value in result.items():
+            is_reserved = name in {'checkpoint_on', 'early_stop_on', 'minimize'}
+            if isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor):
+
+                if is_reserved:
+                    padding_key = default_padding_idx
+                else:
+                    padding_key = meta[name]['tbptt_pad_token']
+                padded = torch.nn.utils.rnn.pad_sequence(value, batch_first=True, padding_value=padding_key)
+                result[name] = padded
+
+                # also update the result
+                if meta and not is_reserved:
+                    meta[name]['value'] = padded
+        if meta:
+            result['meta'] = meta
+        return result
+
     @classmethod
     def reduce_on_epoch_end(cls, outputs):
         meta = outputs[0]['meta']
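
padded_gather collects each key across the gathered outputs and pads ragged tensors to a common length using the per-key tbptt_pad_token stored in the meta. The padding primitive it relies on is torch.nn.utils.rnn.pad_sequence; a standalone sketch of that behaviour with made-up values:

import torch
from torch.nn.utils.rnn import pad_sequence

# TBPTT splits (or multiple optimizers) can produce tensors of different lengths
splits = [torch.tensor([1., 2., 3.]), torch.tensor([4., 5.])]

# pad to the longest entry, filling with the configured pad token (0 here)
padded = pad_sequence(splits, batch_first=True, padding_value=0)
# padded -> tensor([[1., 2., 3.],
#                   [4., 5., 0.]])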
@@ -271,10 +317,36 @@ def reduce_on_epoch_end(cls, outputs):
         result['meta'] = meta
         return result

+    @classmethod
+    def reduce_across_time(cls, time_outputs):
+        # auto-reduce across time for tbptt
+        meta = time_outputs[0]['meta']
+        result = cls()
+        result = recursive_gather(time_outputs, result)
+        recursive_stack(result)
+
+        for k, value in result.items():
+            if k == 'meta':
+                continue
+
+            # pick the reduce fx
+            if k in ['checkpoint_on', 'early_stop_on', 'minimize']:
+                tbptt_reduce_fx = torch.mean
+            else:
+                tbptt_reduce_fx = meta[k]['tbptt_reduce_fx']
+            result[k] = tbptt_reduce_fx(value)
+
+        result['meta'] = meta
+        return result
+
     @property
     def should_reduce_on_epoch_end(self) -> bool:
         return self['meta']['_internal']['_reduce_on_epoch']

+    def drop_hiddens(self):
+        if 'hiddens' in self:
+            del self['hiddens']
+


 def recursive_gather(outputs: Sequence[dict], result: Optional[MutableMapping] = None) -> Optional[MutableMapping]:
     for out in outputs:
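
reduce_across_time stacks each key over the TBPTT splits and collapses the time dimension with the per-key tbptt_reduce_fx, falling back to torch.mean for the reserved keys (checkpoint_on, early_stop_on, minimize). A rough standalone equivalent for a single key, with invented values:

import torch

# loss reported by three consecutive TBPTT splits
time_values = [torch.tensor(0.9), torch.tensor(0.7), torch.tensor(0.5)]

stacked = torch.stack(time_values)   # shape (3,), what recursive_stack produces
reduced_mean = torch.mean(stacked)   # default tbptt_reduce_fx -> tensor(0.7000)
reduced_max = torch.max(stacked)     # a custom tbptt_reduce_fx -> tensor(0.9000)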
@@ -303,6 +375,16 @@ def recursive_stack(result: MutableMapping):
         result[k] = v


+def recursive_padded_stack(result: MutableMapping):
+    for k, v in result.items():
+        if isinstance(v, dict):
+            recursive_stack(v)
+
+        if isinstance(v, list) and len(v) > 0 and isinstance(v[0], torch.Tensor):
+            v = torch.stack(v)
+            result[k] = v
+
+
 class TrainResult(Result):

     def __init__(
@@ -348,6 +430,8 @@ def log(
         on_step: bool = True,
         on_epoch: bool = False,
         reduce_fx: Callable = torch.mean,
+        tbptt_reduce_fx: Callable = torch.mean,
+        tbptt_pad_token: int = 0,
         enable_graph: bool = False,
         sync_ddp: bool = False,
         sync_ddp_op: Union[Any, str] = 'mean',
@@ -381,10 +465,26 @@ def log(
             on_step: if True logs the output of validation_step or test_step
             on_epoch: if True, logs the output of the training loop aggregated
             reduce_fx: Torch.mean by default
+            tbptt_reduce_fx: function to reduce on truncated back prop
+            tbptt_pad_token: token to use for padding
             enable_graph: if True, will not auto detach the graph
+            sync_ddp: if True, reduces the metric across GPUs/TPUs
+            sync_ddp_op: the op to sync across
+            sync_ddp_group: the ddp group
         """
-        super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
-                    sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)
+        super().log(name=name,
+                    value=value,
+                    prog_bar=prog_bar,
+                    logger=logger,
+                    on_step=on_step,
+                    on_epoch=on_epoch,
+                    reduce_fx=reduce_fx,
+                    enable_graph=enable_graph,
+                    sync_ddp=sync_ddp,
+                    sync_ddp_group=sync_ddp_group,
+                    sync_ddp_op=sync_ddp_op,
+                    tbptt_pad_token=tbptt_pad_token,
+                    tbptt_reduce_fx=tbptt_reduce_fx)

     def log_dict(
         self,
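
With the keyword forwarding above, the TBPTT options can be chosen per metric from training_step. A hedged sketch of how that might look when truncated_bptt_steps is enabled; the placeholder loss value and the -1 pad token are illustrative only:

import torch
import pytorch_lightning as pl

# body of a LightningModule.training_step when truncated_bptt_steps is set;
# loss and hiddens would normally come from the model
def training_step(self, batch, batch_idx, hiddens):
    loss = torch.tensor(0.3)  # placeholder for the real split loss
    result = pl.TrainResult(minimize=loss, hiddens=hiddens)
    # reduce this metric across TBPTT splits with max instead of the default mean,
    # and pad ragged outputs with -1 when splits differ in length
    result.log('split_loss', loss, tbptt_reduce_fx=torch.max, tbptt_pad_token=-1)
    return result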
@@ -394,6 +494,8 @@ def log_dict(
         on_step: bool = False,
         on_epoch: bool = True,
         reduce_fx: Callable = torch.mean,
+        tbptt_reduce_fx: Callable = torch.mean,
+        tbptt_pad_token: int = 0,
         enable_graph: bool = False,
         sync_ddp: bool = False,
         sync_ddp_op: Union[Any, str] = 'mean',
@@ -408,17 +510,33 @@ def log_dict(
             result.log_dict(values)

         Args:
-            dictionary:
-            prog_bar:
-            logger:
-            on_step:
-            on_epoch:
-            reduce_fx:
-            enable_graph:
+            dictionary: key value pairs (str, tensors)
+            prog_bar: if True logs to the progress base
+            logger: if True logs to the logger
+            on_step: if True logs the output of validation_step or test_step
+            on_epoch: if True, logs the output of the training loop aggregated
+            reduce_fx: Torch.mean by default
+            tbptt_reduce_fx: function to reduce on truncated back prop
+            tbptt_pad_token: token to use for padding
+            enable_graph: if True, will not auto detach the graph
+            sync_ddp: if True, reduces the metric across GPUs/TPUs
+            sync_ddp_op: the op to sync across
+            sync_ddp_group: the ddp group:
         """
         for k, v in dictionary.items():
-            self.log(k, v, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
-                     sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)
+            self.log(name=k,
+                     value=v,
+                     prog_bar=prog_bar,
+                     logger=logger,
+                     on_step=on_step,
+                     on_epoch=on_epoch,
+                     reduce_fx=reduce_fx,
+                     enable_graph=enable_graph,
+                     sync_ddp=sync_ddp,
+                     sync_ddp_group=sync_ddp_group,
+                     sync_ddp_op=sync_ddp_op,
+                     tbptt_pad_token=tbptt_pad_token,
+                     tbptt_reduce_fx=tbptt_reduce_fx)


 class EvalResult(Result):
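
log_dict loops over the dictionary and forwards every entry to log, so one call applies the same tbptt_* settings to all keys. A small sketch with placeholder metric names and values:

import torch
import pytorch_lightning as pl

loss = torch.tensor(0.3)
acc = torch.tensor(0.8)

result = pl.TrainResult(minimize=loss)
# every value in the dict is logged with the same reduction and pad token
result.log_dict({'loss': loss, 'acc': acc},
                tbptt_reduce_fx=torch.mean, tbptt_pad_token=0)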
@@ -464,6 +582,8 @@ def log(
         on_step: bool = False,
         on_epoch: bool = True,
         reduce_fx: Callable = torch.mean,
+        tbptt_reduce_fx: Callable = torch.mean,
+        tbptt_pad_token: int = 0,
         enable_graph: bool = False,
         sync_ddp: bool = False,
         sync_ddp_op: Union[Any, str] = 'mean',
@@ -494,12 +614,28 @@ def log(
             prog_bar: if True logs to the progress base
             logger: if True logs to the logger
             on_step: if True logs the output of validation_step or test_step
-            on_epoch: if True, logs the output of the validation loop or test loop aggregated
+            on_epoch: if True, logs the output of the training loop aggregated
             reduce_fx: Torch.mean by default
-            enable_graph: if True, will not auto detach the graph :
+            tbptt_reduce_fx: function to reduce on truncated back prop
+            tbptt_pad_token: token to use for padding
+            enable_graph: if True, will not auto detach the graph
+            sync_ddp: if True, reduces the metric across GPUs/TPUs
+            sync_ddp_op: the op to sync across
+            sync_ddp_group: the ddp group
         """
-        super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
-                    sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)
+        super().log(name=name,
+                    value=value,
+                    prog_bar=prog_bar,
+                    logger=logger,
+                    on_step=on_step,
+                    on_epoch=on_epoch,
+                    reduce_fx=reduce_fx,
+                    enable_graph=enable_graph,
+                    sync_ddp=sync_ddp,
+                    sync_ddp_group=sync_ddp_group,
+                    sync_ddp_op=sync_ddp_op,
+                    tbptt_pad_token=tbptt_pad_token,
+                    tbptt_reduce_fx=tbptt_reduce_fx)

     def log_dict(
         self,
@@ -509,6 +645,8 @@ def log_dict(
         on_step: bool = False,
         on_epoch: bool = True,
         reduce_fx: Callable = torch.mean,
+        tbptt_reduce_fx: Callable = torch.mean,
+        tbptt_pad_token: int = 0,
         enable_graph: bool = False,
         sync_ddp: bool = False,
         sync_ddp_op: Union[Any, str] = 'mean',
@@ -523,17 +661,33 @@ def log_dict(
             result.log_dict(values)

         Args:
-            dictionary:
-            prog_bar:
-            logger:
-            on_step:
-            on_epoch:
-            reduce_fx:
-            enable_graph:
+            dictionary: key value pairs (str, tensors)
+            prog_bar: if True logs to the progress base
+            logger: if True logs to the logger
+            on_step: if True logs the output of validation_step or test_step
+            on_epoch: if True, logs the output of the training loop aggregated
+            reduce_fx: Torch.mean by default
+            tbptt_reduce_fx: function to reduce on truncated back prop
+            tbptt_pad_token: token to use for padding
+            enable_graph: if True, will not auto detach the graph
+            sync_ddp: if True, reduces the metric across GPUs/TPUs
+            sync_ddp_op: the op to sync across
+            sync_ddp_group: the ddp group
         """
         for k, v in dictionary.items():
-            self.log(k, v, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
-                     sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)
+            self.log(name=k,
+                     value=v,
+                     prog_bar=prog_bar,
+                     logger=logger,
+                     on_step=on_step,
+                     on_epoch=on_epoch,
+                     reduce_fx=reduce_fx,
+                     enable_graph=enable_graph,
+                     sync_ddp=sync_ddp,
+                     sync_ddp_group=sync_ddp_group,
+                     sync_ddp_op=sync_ddp_op,
+                     tbptt_pad_token=tbptt_pad_token,
+                     tbptt_reduce_fx=tbptt_reduce_fx)


     def get_callback_metrics(self) -> dict:
         result = {
