tracks all outputs including TBPTT and multiple optimizers #2890

Merged: 21 commits, Aug 9, 2020
Changes from all commits
pytorch_lightning/core/step_result.py (210 changes: 182 additions & 28 deletions)
@@ -92,6 +92,8 @@ def log(
on_step: bool = False,
on_epoch: bool = True,
reduce_fx: Callable = torch.mean,
tbptt_reduce_fx: Callable = torch.mean,
tbptt_pad_token: int = 0,
enable_graph: bool = False,
sync_ddp: bool = False,
sync_ddp_op: Union[Any, str] = 'mean',
@@ -113,15 +115,22 @@ def log(
if on_step and on_epoch:
# set step version
step_name = f'step_{name}'
self.__set_meta(step_name, value, prog_bar, logger, on_step=True, on_epoch=False, reduce_fx=reduce_fx)
self.__set_meta(step_name, value, prog_bar, logger,
on_step=True, on_epoch=False,
reduce_fx=reduce_fx, tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token)
self.__setitem__(step_name, value)

# set epoch version
epoch_name = f'epoch_{name}'
self.__set_meta(epoch_name, value, prog_bar, logger, on_step=False, on_epoch=True, reduce_fx=reduce_fx)
self.__set_meta(epoch_name, value, prog_bar, logger, on_step=False, on_epoch=True,
reduce_fx=reduce_fx, tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token)
self.__setitem__(epoch_name, value)
else:
self.__set_meta(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx)
self.__set_meta(name, value,
prog_bar, logger,
on_step, on_epoch,
reduce_fx,
tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token)

# set the value
self.__setitem__(name, value)
@@ -135,6 +144,8 @@ def __set_meta(
on_step: bool,
on_epoch: bool,
reduce_fx: Callable,
tbptt_pad_token: int,
tbptt_reduce_fx: Callable
):
# set the meta for the item
meta_value = value
@@ -144,7 +155,9 @@ def __set_meta(
on_step=on_step,
on_epoch=on_epoch,
reduce_fx=reduce_fx,
value=meta_value
value=meta_value,
tbptt_reduce_fx=tbptt_reduce_fx,
tbptt_pad_token=tbptt_pad_token
)

self['meta'][name] = meta
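
For illustration only (not part of this diff): a minimal sketch of how the two new keyword arguments flow from a `log()` call into the stored meta entry, assuming the 0.9-era `Result` API shown in this file. The metric name and tensor value are made up.

```python
import torch
from pytorch_lightning.core.step_result import Result

# hypothetical value logged for one TBPTT split
loss = torch.tensor(0.42)

result = Result()
# keep the max across TBPTT splits instead of the default mean,
# and pad variable-length tensors with -1 when gathering splits
result.log('tbptt_loss', loss, tbptt_reduce_fx=torch.max, tbptt_pad_token=-1)

# __set_meta records both settings next to the usual reduce_fx metadata
meta = result['meta']['tbptt_loss']
assert meta['tbptt_reduce_fx'] is torch.max
assert meta['tbptt_pad_token'] == -1
```
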
@@ -253,6 +266,39 @@ def gather(cls, outputs):
result['meta'] = meta
return result

@classmethod
def padded_gather(cls, outputs):
meta = outputs[0].get('meta')
result = cls()
result = recursive_gather(outputs, result)

# find the padding used for other values
default_padding_idx = 0
for name, value in result.items():
if isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor):
if name not in {'checkpoint_on', 'early_stop_on', 'minimize'}:
default_padding_idx = meta[name]['tbptt_pad_token']
break

# pad across each key individually
for name, value in result.items():
is_reserved = name in {'checkpoint_on', 'early_stop_on', 'minimize'}
if isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor):

if is_reserved:
padding_key = default_padding_idx
else:
padding_key = meta[name]['tbptt_pad_token']
padded = torch.nn.utils.rnn.pad_sequence(value, batch_first=True, padding_value=padding_key)
result[name] = padded

# also update the result
if meta and not is_reserved:
meta[name]['value'] = padded
if meta:
result['meta'] = meta
return result
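
For intuition: `padded_gather` relies on `torch.nn.utils.rnn.pad_sequence` to bring variable-length outputs from different TBPTT splits to a common length, using the logged `tbptt_pad_token` as the fill value. A standalone sketch with made-up tensors:

```python
import torch

# token-level outputs from three TBPTT splits of different lengths (hypothetical)
splits = [torch.tensor([1, 2, 3]), torch.tensor([4, 5]), torch.tensor([6])]

# same call padded_gather makes, with tbptt_pad_token=0 as the padding value
padded = torch.nn.utils.rnn.pad_sequence(splits, batch_first=True, padding_value=0)
print(padded)
# tensor([[1, 2, 3],
#         [4, 5, 0],
#         [6, 0, 0]])
```
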

@classmethod
def reduce_on_epoch_end(cls, outputs):
meta = outputs[0]['meta']
@@ -271,10 +317,36 @@ def reduce_on_epoch_end(cls, outputs):
result['meta'] = meta
return result

@classmethod
def reduce_across_time(cls, time_outputs):
# auto-reduce across time for tbptt
meta = time_outputs[0]['meta']
result = cls()
result = recursive_gather(time_outputs, result)
recursive_stack(result)

for k, value in result.items():
if k == 'meta':
continue

# pick the reduce fx
if k in ['checkpoint_on', 'early_stop_on', 'minimize']:
tbptt_reduce_fx = torch.mean
else:
tbptt_reduce_fx = meta[k]['tbptt_reduce_fx']
result[k] = tbptt_reduce_fx(value)

result['meta'] = meta
return result
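
Again for intuition: once the per-split values are stacked, the configured `tbptt_reduce_fx` collapses the time dimension, while the reserved keys (`minimize`, `checkpoint_on`, `early_stop_on`) always fall back to `torch.mean`. A toy example with made-up numbers:

```python
import torch

# one scalar logged per TBPTT split (hypothetical)
per_split = torch.tensor([0.9, 0.7, 0.5])

print(torch.mean(per_split))  # default tbptt_reduce_fx -> tensor(0.7000)
print(torch.max(per_split))   # custom tbptt_reduce_fx  -> tensor(0.9000)
print(torch.sum(per_split))   # another option          -> tensor(2.1000)
```
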

@property
def should_reduce_on_epoch_end(self) -> bool:
return self['meta']['_internal']['_reduce_on_epoch']

def drop_hiddens(self):
if 'hiddens' in self:
del self['hiddens']


def recursive_gather(outputs: Sequence[dict], result: Optional[MutableMapping] = None) -> Optional[MutableMapping]:
for out in outputs:
@@ -303,6 +375,16 @@ def recursive_stack(result: MutableMapping):
result[k] = v


def recursive_padded_stack(result: MutableMapping):
for k, v in result.items():
if isinstance(v, dict):
recursive_stack(v)

if isinstance(v, list) and len(v) > 0 and isinstance(v[0], torch.Tensor):
v = torch.stack(v)
result[k] = v


class TrainResult(Result):

def __init__(
@@ -348,6 +430,8 @@ def log(
on_step: bool = True,
on_epoch: bool = False,
reduce_fx: Callable = torch.mean,
tbptt_reduce_fx: Callable = torch.mean,
tbptt_pad_token: int = 0,
enable_graph: bool = False,
sync_ddp: bool = False,
sync_ddp_op: Union[Any, str] = 'mean',
@@ -381,10 +465,26 @@ def log(
on_step: if True logs the output of validation_step or test_step
on_epoch: if True, logs the output of the training loop aggregated
reduce_fx: Torch.mean by default
tbptt_reduce_fx: function to reduce on truncated back prop
tbptt_pad_token: token to use for padding
enable_graph: if True, will not auto detach the graph
sync_ddp: if True, reduces the metric across GPUs/TPUs
sync_ddp_op: the op to sync across
sync_ddp_group: the ddp group
"""
super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)
super().log(name=name,
value=value,
prog_bar=prog_bar,
logger=logger,
on_step=on_step,
on_epoch=on_epoch,
reduce_fx=reduce_fx,
enable_graph=enable_graph,
sync_ddp=sync_ddp,
sync_ddp_group=sync_ddp_group,
sync_ddp_op=sync_ddp_op,
tbptt_pad_token=tbptt_pad_token,
tbptt_reduce_fx=tbptt_reduce_fx)
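
A hedged end-to-end sketch of how the forwarded arguments might be used from `training_step` when truncated backprop through time is enabled on the Trainer. It assumes the `TrainResult(minimize=..., hiddens=...)` constructor of the 0.9-era API; the RNN module, batch layout, and metric name are hypothetical.

```python
import torch
from pytorch_lightning.core.step_result import TrainResult

def training_step(self, batch, batch_idx, hiddens):
    # hypothetical recurrent forward over one TBPTT split
    x, y = batch
    out, hiddens = self.rnn(x, hiddens)
    loss = torch.nn.functional.mse_loss(out, y)

    result = TrainResult(minimize=loss, hiddens=hiddens)
    # sum this metric across the splits of a batch instead of averaging it
    result.log('split_loss', loss, tbptt_reduce_fx=torch.sum)
    return result
```
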

def log_dict(
self,
@@ -394,6 +494,8 @@ def log_dict(
on_step: bool = False,
on_epoch: bool = True,
reduce_fx: Callable = torch.mean,
tbptt_reduce_fx: Callable = torch.mean,
tbptt_pad_token: int = 0,
enable_graph: bool = False,
sync_ddp: bool = False,
sync_ddp_op: Union[Any, str] = 'mean',
@@ -408,17 +510,33 @@ def log_dict(
result.log_dict(values)

Args:
dictionary:
prog_bar:
logger:
on_step:
on_epoch:
reduce_fx:
enable_graph:
dictionary: key value pairs (str, tensors)
prog_bar: if True logs to the progress base
logger: if True logs to the logger
on_step: if True logs the output of validation_step or test_step
on_epoch: if True, logs the output of the training loop aggregated
reduce_fx: Torch.mean by default
tbptt_reduce_fx: function to reduce on truncated back prop
tbptt_pad_token: token to use for padding
enable_graph: if True, will not auto detach the graph
sync_ddp: if True, reduces the metric across GPUs/TPUs
sync_ddp_op: the op to sync across
sync_ddp_group: the ddp group:
"""
for k, v in dictionary.items():
self.log(k, v, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)
self.log(name=k,
value=v,
prog_bar=prog_bar,
logger=logger,
on_step=on_step,
on_epoch=on_epoch,
reduce_fx=reduce_fx,
enable_graph=enable_graph,
sync_ddp=sync_ddp,
sync_ddp_group=sync_ddp_group,
sync_ddp_op=sync_ddp_op,
tbptt_pad_token=tbptt_pad_token,
tbptt_reduce_fx=tbptt_reduce_fx)


class EvalResult(Result):
Expand Down Expand Up @@ -464,6 +582,8 @@ def log(
on_step: bool = False,
on_epoch: bool = True,
reduce_fx: Callable = torch.mean,
tbptt_reduce_fx: Callable = torch.mean,
tbptt_pad_token: int = 0,
enable_graph: bool = False,
sync_ddp: bool = False,
sync_ddp_op: Union[Any, str] = 'mean',
@@ -494,12 +614,28 @@ def log(
prog_bar: if True logs to the progress base
logger: if True logs to the logger
on_step: if True logs the output of validation_step or test_step
on_epoch: if True, logs the output of the validation loop or test loop aggregated
on_epoch: if True, logs the output of the training loop aggregated
reduce_fx: Torch.mean by default
enable_graph: if True, will not auto detach the graph :
tbptt_reduce_fx: function to reduce on truncated back prop
tbptt_pad_token: token to use for padding
enable_graph: if True, will not auto detach the graph
sync_ddp: if True, reduces the metric across GPUs/TPUs
sync_ddp_op: the op to sync across
sync_ddp_group: the ddp group
"""
super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)
super().log(name=name,
value=value,
prog_bar=prog_bar,
logger=logger,
on_step=on_step,
on_epoch=on_epoch,
reduce_fx=reduce_fx,
enable_graph=enable_graph,
sync_ddp=sync_ddp,
sync_ddp_group=sync_ddp_group,
sync_ddp_op=sync_ddp_op,
tbptt_pad_token=tbptt_pad_token,
tbptt_reduce_fx=tbptt_reduce_fx)

def log_dict(
self,
@@ -509,6 +645,8 @@ def log_dict(
on_step: bool = False,
on_epoch: bool = True,
reduce_fx: Callable = torch.mean,
tbptt_reduce_fx: Callable = torch.mean,
tbptt_pad_token: int = 0,
enable_graph: bool = False,
sync_ddp: bool = False,
sync_ddp_op: Union[Any, str] = 'mean',
@@ -523,17 +661,33 @@ def log_dict(
result.log_dict(values)

Args:
dictionary:
prog_bar:
logger:
on_step:
on_epoch:
reduce_fx:
enable_graph:
dictionary: key value pairs (str, tensors)
prog_bar: if True logs to the progress base
logger: if True logs to the logger
on_step: if True logs the output of validation_step or test_step
on_epoch: if True, logs the output of the training loop aggregated
reduce_fx: Torch.mean by default
tbptt_reduce_fx: function to reduce on truncated back prop
tbptt_pad_token: token to use for padding
enable_graph: if True, will not auto detach the graph
sync_ddp: if True, reduces the metric across GPUs/TPUs
sync_ddp_op: the op to sync across
sync_ddp_group: the ddp group
"""
for k, v in dictionary.items():
self.log(k, v, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)
self.log(name=k,
value=v,
prog_bar=prog_bar,
logger=logger,
on_step=on_step,
on_epoch=on_epoch,
reduce_fx=reduce_fx,
enable_graph=enable_graph,
sync_ddp=sync_ddp,
sync_ddp_group=sync_ddp_group,
sync_ddp_op=sync_ddp_op,
tbptt_pad_token=tbptt_pad_token,
tbptt_reduce_fx=tbptt_reduce_fx)
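
And the `log_dict` counterpart: one call forwards the same TBPTT settings to every key. A sketch assuming the 0.9-era `EvalResult(checkpoint_on=...)` constructor; the metric names and model calls are hypothetical.

```python
import torch
from pytorch_lightning.core.step_result import EvalResult

def validation_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = torch.nn.functional.cross_entropy(logits, y)
    acc = (logits.argmax(dim=-1) == y).float().mean()

    result = EvalResult(checkpoint_on=loss)
    # both metrics share the same reduction across TBPTT splits
    result.log_dict({'val_loss': loss, 'val_acc': acc},
                    tbptt_reduce_fx=torch.mean, tbptt_pad_token=0)
    return result
```
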

def get_callback_metrics(self) -> dict:
result = {