diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index f23ab04523766..88067ae07031c 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -231,6 +231,78 @@ def training_end(self, *args, **kwargs):
         Deprecated in v0.7.0. use training_step_end instead
         """

+    def training_epoch_end(
+            self,
+            outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]]
+    ) -> Dict[str, Dict[str, Tensor]]:
+        """Called at the end of the training epoch with the outputs of all training_steps
+
+        .. code-block:: python
+
+            # the pseudocode for these calls
+
+            train_outs = []
+            for train_batch in train_data:
+                out = training_step(train_batch)
+                train_outs.append(out)
+            training_epoch_end(train_outs)
+
+        Args:
+            outputs: List of outputs you defined in training_step, or if there are multiple
+                dataloaders, a list containing a list of outputs for each dataloader
+
+        Return:
+            Dict or OrderedDict (dict): Dict has the following optional keys:
+                progress_bar -> Dict for progress bar display. Must have only tensors
+                log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)
+
+        .. note:: If this method is not overridden, it won't be called.
+
+        - The outputs here are strictly for logging or progress bar.
+        - If you don't need to display anything, don't return anything.
+        - If you want to manually set the current step, you can specify the 'step' key in the 'log' Dict
+
+        Examples:
+            With a single dataloader
+
+            .. code-block:: python
+
+                def training_epoch_end(self, outputs):
+                    train_acc_mean = 0
+                    for output in outputs:
+                        train_acc_mean += output['train_acc']
+
+                    train_acc_mean /= len(outputs)
+
+                    # log training accuracy at the end of an epoch
+                    results = {
+                        'log': {'train_acc': train_acc_mean.item()}
+                    }
+                    return results
+
+            With multiple dataloaders, `outputs` will be a list of lists. The outer list contains
+            one entry per dataloader, while the inner list contains the individual outputs of
+            each training step for that dataloader.
+
+            .. code-block:: python
+
+                def training_epoch_end(self, outputs):
+                    train_acc_mean = 0
+                    i = 0
+                    for dataloader_outputs in outputs:
+                        for output in dataloader_outputs:
+                            train_acc_mean += output['train_acc']
+                            i += 1
+
+                    train_acc_mean /= i
+
+                    # log training accuracy at the end of an epoch
+                    results = {
+                        'log': {'train_acc': train_acc_mean.item(), 'step': self.current_epoch}
+                    }
+                    return results
+        """
+
     def training_step_end(self, *args, **kwargs) -> Dict[
         str, Union[Tensor, Dict[str, Tensor]]
     ]:
@@ -453,7 +525,7 @@ def validation_epoch_end(
             outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]]
     ) -> Dict[str, Dict[str, Tensor]]:
         """
-        Called at end of validation epoch with the output of all validation_steps
+        Called at end of validation epoch with the outputs of all validation_steps

         .. code-block:: python

@@ -462,7 +534,7 @@ def validation_epoch_end(
             val_outs = []
             for val_batch in val_data:
                 out = validation_step(train_batch)
-                train_outs.append(out)
+                val_outs.append(out)
             validation_epoch_end(val_outs)

         Args:
@@ -493,7 +565,7 @@ def validation_epoch_end(self, outputs):
                 val_acc_mean /= len(outputs)
                 tqdm_dict = {'val_acc': val_acc_mean.item()}

-                # show val_loss and val_acc in progress bar but only log val_loss
+                # show val_acc in progress bar and log val_acc
                 results = {
                     'progress_bar': tqdm_dict,
                     'log': {'val_acc': val_acc_mean.item()}
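Not part of the change set: a minimal sketch of how a user module would pair the new hook with its training_step, so the `output['train_acc']` lookups in the docstring examples have an obvious origin. The class name, layer sizes, metric key, and optimizer below are illustrative assumptions, not something this diff adds.

    import torch
    from torch.nn import functional as F
    from pytorch_lightning import LightningModule


    class LitClassifier(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(28 * 28, 10)

        def forward(self, x):
            return self.layer(x.view(x.size(0), -1))

        def training_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            loss = F.cross_entropy(logits, y)
            acc = (logits.argmax(dim=1) == y).float().mean()
            # 'train_acc' is the key that training_epoch_end aggregates below
            return {'loss': loss, 'train_acc': acc}

        def training_epoch_end(self, outputs):
            # average the per-batch accuracy collected from every training_step
            train_acc_mean = torch.stack([out['train_acc'] for out in outputs]).mean()
            return {'log': {'train_acc': train_acc_mean.item()}}

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)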
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index e3a8bb16a7b14..8a7e2816219ec 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -145,6 +145,7 @@ def training_step(self, batch, batch_idx):
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loggers import LightningLoggerBase
+from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.trainer.supporters import TensorRunningMean

@@ -390,6 +391,9 @@ def train(self):

     def run_training_epoch(self):

+        # get model
+        model = self.get_model()
+
         # Epoch start events
         with self.profiler.profile('on_epoch_start'):
             # callbacks
@@ -397,7 +401,7 @@ def run_training_epoch(self):

             # model hooks
             if self.is_function_implemented('on_epoch_start'):
-                self.get_model().on_epoch_start()
+                model.on_epoch_start()

         # track local dataloader so TPU can wrap each epoch
         train_dataloader = self.train_dataloader
@@ -408,6 +412,9 @@ def run_training_epoch(self):
             train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device])
             train_dataloader = train_dataloader.per_device_loader(device)

+        # bookkeeping
+        outputs = []
+
         # run epoch
         for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
             enumerate(_with_is_last(train_dataloader)), "get_train_batch"
@@ -418,14 +425,15 @@ def run_training_epoch(self):

             self.batch_idx = batch_idx

-            model = self.get_model()
             model.global_step = self.global_step

             # ---------------
             # RUN TRAIN STEP
             # ---------------
-            output = self.run_training_batch(batch, batch_idx)
-            batch_result, grad_norm_dic, batch_step_metrics = output
+            _outputs = self.run_training_batch(batch, batch_idx)
+            batch_result, grad_norm_dic, batch_step_metrics, batch_output = _outputs
+            # detach tensors in batch_output before appending to outputs
+            outputs.append(_recursive_detach(batch_output))

             # when returning -1 from train_step, we end epoch early
             early_stop_epoch = batch_result == -1
@@ -484,6 +492,18 @@ def run_training_epoch(self):
             if early_stop_epoch or self.fast_dev_run:
                 break

+        # process epoch outputs
+        if isinstance(model, (LightningDistributedDataParallel, LightningDataParallel)):
+            model = model.module
+
+        if self.is_overriden('training_epoch_end', model=model):
+            epoch_output = model.training_epoch_end(outputs)
+            _processed_outputs = self.process_output(epoch_output)
+            log_epoch_metrics = _processed_outputs[2]
+            callback_epoch_metrics = _processed_outputs[3]
+            self.log_metrics(log_epoch_metrics, {})
+            self.callback_metrics.update(callback_epoch_metrics)
+
         # in case validation step is missing and you are not running fast-dev to duplicate last batch
         if not self.is_overriden('validation_step') and not (self.fast_dev_run or should_check_val):
             self.call_checkpoint_callback()
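Not part of the change set: for orientation, the `outputs` list built above holds one detached copy of each dict returned by training_step, and the dict returned by training_epoch_end goes through self.process_output like any step-level result, with index 2 feeding the logger and index 3 the callback metrics, as in the hunk above. A rough illustration, assuming training_step returned a loss and a 'train_acc' value per batch (all values made up):

    import torch

    # hypothetical contents of `outputs` after two batches
    outputs = [
        {'loss': torch.tensor(0.91), 'train_acc': torch.tensor(0.44)},
        {'loss': torch.tensor(0.85), 'train_acc': torch.tensor(0.50)},
    ]

    # hypothetical dict returned by model.training_epoch_end(outputs)
    epoch_output = {
        'progress_bar': {'train_acc': 0.47},
        'log': {'train_acc': 0.47, 'step': 0},
    }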
@@ -497,7 +517,7 @@ def run_training_epoch(self):
             self.on_epoch_end()
             # model hooks
             if self.is_function_implemented('on_epoch_end'):
-                self.get_model().on_epoch_end()
+                model.on_epoch_end()

     def run_training_batch(self, batch, batch_idx):
         # track grad norms
@@ -546,14 +566,13 @@ def run_training_batch(self, batch, batch_idx):
                 def optimizer_closure():
                     # forward pass
                     with self.profiler.profile('model_forward'):
-                        output = self.training_forward(
+                        output_dict = self.training_forward(
                             split_batch, batch_idx, opt_idx, self.hiddens)

-                        closure_loss = output[0]
-                        progress_bar_metrics = output[1]
-                        log_metrics = output[2]
-                        callback_metrics = output[3]
-                        self.hiddens = output[4]
+                        # format and reduce outputs accordingly
+                        processed_output = self.process_output(output_dict, train=True)
+
+                        closure_loss, progress_bar_metrics, log_metrics, callback_metrics, self.hiddens = processed_output

                     # accumulate loss
                     # (if accumulate_grad_batches = 1 no effect)
@@ -577,10 +596,10 @@ def optimizer_closure():
                         with self.profiler.profile('on_after_backward'):
                             model_ref.on_after_backward()

-                    return closure_loss
+                    return closure_loss, output_dict

                 # calculate loss
-                loss = optimizer_closure()
+                loss, batch_output = optimizer_closure()

                 # check if loss or model weights are nan
                 self.detect_nan_tensors(loss)
@@ -606,7 +625,8 @@ def optimizer_closure():
                     model = self.get_model()
                     with self.profiler.profile('optimizer_step'):
                         model.optimizer_step(self.current_epoch, batch_idx,
-                                             optimizer, opt_idx, optimizer_closure)
+                                             optimizer, opt_idx,
+                                             lambda: optimizer_closure()[0])

                     # calculate running loss for display
                     self.running_loss.append(self.batch_loss_value.mean())
@@ -633,7 +653,7 @@ def optimizer_closure():
         # track all metrics for callbacks
         self.callback_metrics.update({k: v for d in all_callback_metrics for k, v in d.items()})

-        return 0, grad_norm_dic, all_log_metrics
+        return 0, grad_norm_dic, all_log_metrics, batch_output

     def _get_optimizers_iterable(self):
         if not self.optimizer_frequencies:
@@ -732,9 +752,6 @@ def training_forward(self, batch, batch_idx, opt_idx, hiddens):
             warnings.warn('`training_end` was deprecated in 0.7.0 and will be removed 1.0.0.'
                           ' Use training_epoch_end instead', DeprecationWarning)

-        # format and reduce outputs accordingly
-        output = self.process_output(output, train=True)
-
         return output

     def update_learning_rates(self, interval: str):
@@ -784,3 +801,29 @@ def _with_is_last(iterable):
         last = val
     # yield last, no longer has next
     yield last, True
+
+
+def _recursive_detach(in_dict):
+    """Detach all tensors in `in_dict`.
+
+    May operate recursively if some of the values in `in_dict` are dictionaries
+    which contain instances of `torch.Tensor`. Other types in `in_dict` are
+    not affected by this utility function.
+
+    Parameters
+    ----------
+    in_dict : dict
+
+    Returns
+    -------
+    out_dict : dict
+    """
+    out_dict = {}
+    for k, v in in_dict.items():
+        if isinstance(v, dict):
+            out_dict.update({k: _recursive_detach(v)})
+        elif callable(getattr(v, 'detach', None)):
+            out_dict.update({k: v.detach()})
+        else:
+            out_dict.update({k: v})
+    return out_dict
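Not part of the change set: a quick sanity check of the helper added above. The sample dict is invented, and `_recursive_detach` is a private helper, so importing it directly is for demonstration only.

    import torch

    from pytorch_lightning.trainer.training_loop import _recursive_detach

    batch_output = {
        'loss': torch.tensor(0.5, requires_grad=True),
        'log': {'train_acc': torch.tensor(0.75)},
        'n_samples': 32,  # non-tensor values pass through unchanged
    }

    detached = _recursive_detach(batch_output)
    assert detached['loss'].requires_grad is False      # cut from the autograd graph
    assert detached['log']['train_acc'].item() == 0.75  # nested dicts handled recursively
    assert detached['n_samples'] == 32                   # other types left untouched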
diff --git a/tests/loggers/test_base.py b/tests/loggers/test_base.py
index 9217e1c27de9c..d7f3503cf8193 100644
--- a/tests/loggers/test_base.py
+++ b/tests/loggers/test_base.py
@@ -124,11 +124,16 @@ def test_multiple_loggers_pickle(tmpdir):
 def test_adding_step_key(tmpdir):
     logged_step = 0

-    def _validation_end(outputs):
+    def _validation_epoch_end(outputs):
         nonlocal logged_step
         logged_step += 1
         return {"log": {"step": logged_step, "val_acc": logged_step / 10}}

+    def _training_epoch_end(outputs):
+        nonlocal logged_step
+        logged_step += 1
+        return {"log": {"step": logged_step, "train_acc": logged_step / 10}}
+
     def _log_metrics_decorator(log_metrics_fn):
         def decorated(metrics, step):
             if "val_acc" in metrics:
@@ -138,7 +143,8 @@ def decorated(metrics, step):
         return decorated

     model, hparams = tutils.get_default_model()
-    model.validation_epoch_end = _validation_end
+    model.validation_epoch_end = _validation_epoch_end
+    model.training_epoch_end = _training_epoch_end
     trainer_options = dict(
         max_epochs=4,
         default_save_path=tmpdir,
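Not part of the change set: the test above relies on the documented behaviour that a 'step' entry inside the returned 'log' dict overrides the step the logger would otherwise receive. In an ordinary LightningModule the same idea looks roughly like this, assuming training_step returns a 'loss' entry; the metric name is illustrative.

    def training_epoch_end(self, outputs):
        train_loss_mean = torch.stack([out['loss'] for out in outputs]).mean()
        # record this point against the epoch index rather than the global step
        return {'log': {'train_loss': train_loss_mean.item(), 'step': self.current_epoch}}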