Make training_epoch_end behave like validation_epoch_end #1357

Merged · 4 commits · Apr 3, 2020
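For context, here is a minimal editorial sketch (not part of the diff) of how the hook this PR enables is meant to be used from a `LightningModule`, written against the 0.7-era dict-returning API used elsewhere in this PR. `LitClassifier`, its layer sizes, and the optimizer settings are made up for illustration.

```python
import torch
import pytorch_lightning as pl


class LitClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.layer(x.view(x.size(0), -1))

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = torch.nn.functional.cross_entropy(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        # every dict returned here is collected and handed to training_epoch_end
        return {'loss': loss, 'train_acc': acc}

    def training_epoch_end(self, outputs):
        # outputs: the list of dicts returned by training_step over the epoch
        train_acc_mean = torch.stack([o['train_acc'] for o in outputs]).mean()
        # 'step' overrides the logging step, mirroring validation_epoch_end
        return {'log': {'train_acc': train_acc_mean.item(), 'step': self.current_epoch}}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.02)
```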
78 changes: 75 additions & 3 deletions pytorch_lightning/core/lightning.py
@@ -231,6 +231,78 @@ def training_end(self, *args, **kwargs):
Deprecated in v0.7.0. use training_step_end instead
"""

def training_epoch_end(
self,
outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]]
) -> Dict[str, Dict[str, Tensor]]:
"""Called at the end of training epoch with the outputs of all training_steps

.. code-block:: python

# the pseudocode for these calls

train_outs = []
for train_batch in train_data:
out = training_step(train_batch)
train_outs.append(out)
training_epoch_end(train_outs)

Args:
outputs: List of outputs you defined in training_step, or if there are multiple
dataloaders, a list containing a list of outputs for each dataloader

Return:
Dict or OrderedDict (dict): Dict has the following optional keys:
progress_bar -> Dict for progress bar display. Must have only tensors
log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)

.. note:: If this method is not overridden, this won't be called.

- The outputs here are strictly for logging or progress bar.
- If you don't need to display anything, don't return anything.
- If you want to manually set current step, you can specify the 'step' key in the 'log' Dict

Examples:
With a single dataloader

.. code-block:: python

def training_epoch_end(self, outputs):
train_acc_mean = 0
for output in outputs:
train_acc_mean += output['train_acc']

train_acc_mean /= len(outputs)

# log training accuracy at the end of an epoch
results = {
'log': {'train_acc': train_acc_mean.item()}
}
return results

With multiple dataloaders, `outputs` will be a list of lists. The outer list contains
one entry per dataloader, while the inner list contains the individual outputs of
each training step for that dataloader.

.. code-block:: python

def training_epoch_end(self, outputs):
train_acc_mean = 0
i = 0
for dataloader_outputs in outputs:
for output in dataloader_outputs:
train_acc_mean += output['train_acc']
i += 1

train_acc_mean /= i

# log training accuracy at the end of an epoch
results = {
'log': {'train_acc': train_acc_mean.item(), 'step': self.current_epoch}
}
return results
"""

def training_step_end(self, *args, **kwargs) -> Dict[
str, Union[Tensor, Dict[str, Tensor]]
]:
@@ -453,7 +525,7 @@ def validation_epoch_end(
outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]]
) -> Dict[str, Dict[str, Tensor]]:
"""
Called at end of validation epoch with the output of all validation_steps
Called at end of validation epoch with the outputs of all validation_steps

.. code-block:: python

@@ -462,7 +534,7 @@ def validation_epoch_end(
val_outs = []
for val_batch in val_data:
out = validation_step(train_batch)
train_outs.append(out)
val_outs.append(out)
validation_epoch_end(val_outs)

Args:
@@ -493,7 +565,7 @@ def validation_epoch_end(self, outputs):
val_acc_mean /= len(outputs)
tqdm_dict = {'val_acc': val_acc_mean.item()}

# show val_loss and val_acc in progress bar but only log val_loss
# show val_acc in the progress bar and also log it
results = {
'progress_bar': tqdm_dict,
'log': {'val_acc': val_acc_mean.item()}
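The docstring examples above average `train_acc` with a manual loop. A more compact way to handle both shapes of `outputs` is sketched below (editorial, not part of the PR): a flat list of dicts for a single dataloader, or a list of lists for multiple dataloaders.

```python
import torch


def mean_train_acc(outputs):
    # flatten the multi-dataloader case into a single list of step outputs
    if outputs and isinstance(outputs[0], list):
        outputs = [out for dl_outs in outputs for out in dl_outs]
    return torch.stack([out['train_acc'] for out in outputs]).mean()


# single dataloader: a flat list of training_step outputs
single = [{'train_acc': torch.tensor(0.8)}, {'train_acc': torch.tensor(0.9)}]
# two dataloaders: one inner list per dataloader
multi = [single, [{'train_acc': torch.tensor(0.7)}]]

assert abs(mean_train_acc(single).item() - 0.85) < 1e-6
assert abs(mean_train_acc(multi).item() - 0.80) < 1e-6
```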
79 changes: 61 additions & 18 deletions pytorch_lightning/trainer/training_loop.py
@@ -145,6 +145,7 @@ def training_step(self, batch, batch_idx):
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.trainer.supporters import TensorRunningMean

@@ -390,14 +391,17 @@ def train(self):

def run_training_epoch(self):

# get model
model = self.get_model()

# Epoch start events
with self.profiler.profile('on_epoch_start'):
# callbacks
self.on_epoch_start()

# model hooks
if self.is_function_implemented('on_epoch_start'):
self.get_model().on_epoch_start()
model.on_epoch_start()

# track local dataloader so TPU can wrap each epoch
train_dataloader = self.train_dataloader
@@ -408,6 +412,9 @@ def run_training_epoch(self):
train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device])
train_dataloader = train_dataloader.per_device_loader(device)

# bookkeeping
outputs = []

# run epoch
for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
enumerate(_with_is_last(train_dataloader)), "get_train_batch"
@@ -418,14 +425,15 @@ def run_training_epoch(self):

self.batch_idx = batch_idx

model = self.get_model()
model.global_step = self.global_step

# ---------------
# RUN TRAIN STEP
# ---------------
output = self.run_training_batch(batch, batch_idx)
batch_result, grad_norm_dic, batch_step_metrics = output
_outputs = self.run_training_batch(batch, batch_idx)
batch_result, grad_norm_dic, batch_step_metrics, batch_output = _outputs
# detach tensors in batch_output before appending to outputs
outputs.append(_recursive_detach(batch_output))

# when returning -1 from train_step, we end epoch early
early_stop_epoch = batch_result == -1
@@ -484,6 +492,18 @@ def run_training_epoch(self):
if early_stop_epoch or self.fast_dev_run:
break

# process epoch outputs
if isinstance(model, (LightningDistributedDataParallel, LightningDataParallel)):
model = model.module

if self.is_overriden('training_epoch_end', model=model):
epoch_output = model.training_epoch_end(outputs)
_processed_outputs = self.process_output(epoch_output)
log_epoch_metrics = _processed_outputs[2]
callback_epoch_metrics = _processed_outputs[3]
self.log_metrics(log_epoch_metrics, {})
self.callback_metrics.update(callback_epoch_metrics)

# in case validation step is missing and you are not running fast-dev to duplicate last batch
if not self.is_overriden('validation_step') and not (self.fast_dev_run or should_check_val):
self.call_checkpoint_callback()
@@ -497,7 +517,7 @@ def run_training_epoch(self):
self.on_epoch_end()
# model hooks
if self.is_function_implemented('on_epoch_end'):
self.get_model().on_epoch_end()
model.on_epoch_end()

def run_training_batch(self, batch, batch_idx):
# track grad norms
@@ -546,14 +566,13 @@ def run_training_batch(self, batch, batch_idx):
def optimizer_closure():
# forward pass
with self.profiler.profile('model_forward'):
output = self.training_forward(
output_dict = self.training_forward(
split_batch, batch_idx, opt_idx, self.hiddens)

closure_loss = output[0]
progress_bar_metrics = output[1]
log_metrics = output[2]
callback_metrics = output[3]
self.hiddens = output[4]
# format and reduce outputs accordingly
processed_output = self.process_output(output_dict, train=True)

closure_loss, progress_bar_metrics, log_metrics, callback_metrics, self.hiddens = processed_output

# accumulate loss
# (if accumulate_grad_batches = 1 no effect)
@@ -577,10 +596,10 @@ def optimizer_closure():
with self.profiler.profile('on_after_backward'):
model_ref.on_after_backward()

return closure_loss
return closure_loss, output_dict

# calculate loss
loss = optimizer_closure()
loss, batch_output = optimizer_closure()

# check if loss or model weights are nan
self.detect_nan_tensors(loss)
@@ -606,7 +625,8 @@ def optimizer_closure():
model = self.get_model()
with self.profiler.profile('optimizer_step'):
model.optimizer_step(self.current_epoch, batch_idx,
optimizer, opt_idx, optimizer_closure)
optimizer, opt_idx,
lambda: optimizer_closure()[0])

# calculate running loss for display
self.running_loss.append(self.batch_loss_value.mean())
@@ -633,7 +653,7 @@ def optimizer_closure():
# track all metrics for callbacks
self.callback_metrics.update({k: v for d in all_callback_metrics for k, v in d.items()})

return 0, grad_norm_dic, all_log_metrics
return 0, grad_norm_dic, all_log_metrics, batch_output

def _get_optimizers_iterable(self):
if not self.optimizer_frequencies:
@@ -732,9 +752,6 @@ def training_forward(self, batch, batch_idx, opt_idx, hiddens):
warnings.warn('`training_end` was deprecated in 0.7.0 and will be removed 1.0.0.'
' Use training_epoch_end instead', DeprecationWarning)

# format and reduce outputs accordingly
output = self.process_output(output, train=True)

return output

def update_learning_rates(self, interval: str):
@@ -784,3 +801,29 @@ def _with_is_last(iterable):
last = val
# yield last, no longer has next
yield last, True


def _recursive_detach(in_dict):
"""Detach all tensors in `in_dict`.

May operate recursively if some of the values in `in_dict` are dictionaries
which contain instances of `torch.Tensor`. Other types in `in_dict` are
not affected by this utility function.

Parameters
----------
in_dict : dict

Returns
-------
out_dict : dict
"""
out_dict = {}
for k, v in in_dict.items():
if isinstance(v, dict):
out_dict.update({k: _recursive_detach(v)})
elif callable(getattr(v, 'detach', None)):
out_dict.update({k: v.detach()})
else:
out_dict.update({k: v})
return out_dict
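The epoch-level accumulation added above only stays memory-safe because each step output is detached first, which is what `_recursive_detach` provides. The editorial sketch below (not part of the PR) shows the effect on a small nested dict by writing one level of the recursion out by hand.

```python
import torch

w = torch.randn(3, requires_grad=True)
step_output = {'loss': (w ** 2).sum(), 'metrics': {'acc': torch.tensor(0.9)}}

# one level of the recursive detach, spelled out explicitly
detached = {
    k: ({kk: vv.detach() for kk, vv in v.items()} if isinstance(v, dict) else v.detach())
    for k, v in step_output.items()
}

assert step_output['loss'].requires_grad          # still attached to the autograd graph
assert not detached['loss'].requires_grad         # safe to accumulate for a whole epoch
assert not detached['metrics']['acc'].requires_grad
```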
10 changes: 8 additions & 2 deletions tests/loggers/test_base.py
@@ -124,11 +124,16 @@ def test_multiple_loggers_pickle(tmpdir):
def test_adding_step_key(tmpdir):
logged_step = 0

def _validation_end(outputs):
def _validation_epoch_end(outputs):
nonlocal logged_step
logged_step += 1
return {"log": {"step": logged_step, "val_acc": logged_step / 10}}

def _training_epoch_end(outputs):
nonlocal logged_step
logged_step += 1
return {"log": {"step": logged_step, "train_acc": logged_step / 10}}

def _log_metrics_decorator(log_metrics_fn):
def decorated(metrics, step):
if "val_acc" in metrics:
@@ -138,7 +143,8 @@ def decorated(metrics, step):
return decorated

model, hparams = tutils.get_default_model()
model.validation_epoch_end = _validation_end
model.validation_epoch_end = _validation_epoch_end
model.training_epoch_end = _training_epoch_end
trainer_options = dict(
max_epochs=4,
default_save_path=tmpdir,
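In isolation, what the updated test asserts is that a `step` key inside the returned `log` dict overrides the step that would otherwise be passed to the logger. The sketch below is a simplified stand-in for that routing, not Lightning's actual implementation.

```python
def route_to_logger(log_metrics, global_step, log_fn):
    metrics = dict(log_metrics)              # don't mutate the caller's dict
    step = metrics.pop('step', global_step)  # prefer the user-supplied step if present
    log_fn(metrics, step)


logged = []
route_to_logger({'step': 3, 'val_acc': 0.3}, global_step=120,
                log_fn=lambda m, s: logged.append((m, s)))
route_to_logger({'train_acc': 0.5}, global_step=121,
                log_fn=lambda m, s: logged.append((m, s)))

assert logged == [({'val_acc': 0.3}, 3), ({'train_acc': 0.5}, 121)]
```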