diff --git a/pytorch_lightning/trainer/train_loop_mixin.py b/pytorch_lightning/trainer/train_loop_mixin.py index 68c34cf4f3793..7e9fde01485e3 100644 --- a/pytorch_lightning/trainer/train_loop_mixin.py +++ b/pytorch_lightning/trainer/train_loop_mixin.py @@ -155,7 +155,7 @@ def run_training_batch(self, batch, batch_nb): all_log_metrics = [] if batch is None: - return 0, grad_norm_dic + return 0, grad_norm_dic, {} # hook if self.is_function_implemented('on_batch_start'): @@ -163,7 +163,7 @@ def run_training_batch(self, batch, batch_nb): response = model_ref.on_batch_start(batch) if response == -1: - return -1, grad_norm_dic + return -1, grad_norm_dic, {} splits = [batch] if self.truncated_bptt_steps is not None: