Debug misstep #2475

Closed
wants to merge 27 commits
20 changes: 19 additions & 1 deletion pytorch_lightning/loggers/base.py
@@ -102,7 +102,25 @@ def _reduce_agg_metrics(self):
         elif len(self._metrics_to_agg) == 1:
             agg_mets = self._metrics_to_agg[0]
         else:
-            agg_mets = merge_dicts(self._metrics_to_agg, self._agg_key_funcs, self._agg_default_func)
+            # pop out 'epoch' because it is a metric automatically added by log_metrics and would
+            # always count as a duplicate. To remove this workaround, drop
+            # `scalar_metrics['epoch'] = self.current_epoch` in TrainerLoggingMixin.log_metrics()
+            # check whether the dictionary keys are unique
+            agg_keys = set()
+            num_keys = 0
+            for met in self._metrics_to_agg:
+                met.pop("epoch", None)
+                agg_keys.update(met.keys())
+                num_keys += len(met)
+
+            if len(agg_keys) == num_keys:
+                # dictionary keys are unique, so a plain merge is enough
+                agg_mets = self._metrics_to_agg[0]
+                for mets in self._metrics_to_agg[1:]:
+                    agg_mets.update(mets)
+            else:
+                agg_mets = merge_dicts(self._metrics_to_agg, self._agg_key_funcs, self._agg_default_func)
+
         return self._prev_step, agg_mets

     def _finalize_agg_metrics(self):
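Note on the hunk above: when every metric key appears in only one of the buffered dicts, the new fast path merges them with plain dict.update() instead of calling merge_dicts. A minimal, self-contained sketch of that behaviour (the sample dicts and the mean-based fallback are illustrative only, not Lightning's actual merge_dicts implementation):

metrics_to_agg = [{"loss": 0.5, "acc": 0.8}, {"val_loss": 0.4}]

agg_keys = set()
num_keys = 0
for met in metrics_to_agg:
    met.pop("epoch", None)       # 'epoch' is added automatically, so it would always collide
    agg_keys.update(met.keys())
    num_keys += len(met)

if len(agg_keys) == num_keys:
    # no key appears twice -> a plain merge preserves every value as-is
    agg_mets = dict(metrics_to_agg[0])
    for met in metrics_to_agg[1:]:
        agg_mets.update(met)
else:
    # duplicate keys -> aggregate them, here by taking the mean of all occurrences
    agg_mets = {
        key: sum(m[key] for m in metrics_to_agg if key in m) / sum(key in m for m in metrics_to_agg)
        for key in agg_keys
    }

print(agg_mets)  # {'loss': 0.5, 'acc': 0.8, 'val_loss': 0.4}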
11 changes: 7 additions & 4 deletions pytorch_lightning/trainer/training_loop.py
@@ -421,7 +421,6 @@ def run_on_epoch_start_hook(self, model):
         model.on_epoch_start()

     def run_training_epoch(self):
-
         # get model
         model = self.get_model()

@@ -479,8 +478,10 @@ def run_training_epoch(self):
             # -----------------------------------------
             self.save_train_loop_metrics_to_loggers(batch_idx, batch_output)

-            # progress global step according to grads progress
-            self.increment_accumulated_grad_global_step()
+            # progress global step according to grads progress. If this is the last batch,
+            # global_step is incremented after the loop has finished
+            if not is_last_batch:
+                self.increment_accumulated_grad_global_step()

             # max steps reached, end training
             if self.max_steps is not None and self.max_steps == self.global_step:
@@ -504,6 +505,9 @@ def run_training_epoch(self):
         # epoch end hook
         self.run_on_epoch_end_hook(model)

+        # increment global step by one to progress to the next epoch
+        self.global_step += 1
+
     def check_checkpoint_callback(self, should_check_val):
         # when no val loop is present or fast-dev-run still need to call checkpoints
         # TODO bake this logic into the checkpoint callback
@@ -527,7 +531,6 @@ def run_on_epoch_end_hook(self, model):
     def run_training_epoch_end(self, epoch_output):
         model = self.get_model()
         if self.is_overridden('training_epoch_end', model=model):
-            self.global_step += 1
             epoch_output = model.training_epoch_end(epoch_output)
             _processed_outputs = self.process_output(epoch_output)
             log_epoch_metrics = _processed_outputs[2]
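For intuition, here is a toy simulation of the step bookkeeping these hunks rearrange (plain Python, not Lightning code; it assumes one optimizer step per batch and no gradient accumulation): the per-batch increment is skipped for the last batch and the counter is bumped once after the epoch-end hooks instead, so code that runs at epoch end still sees the global step of the last batch rather than one past it.

num_batches = 4
global_step = 0

for batch_idx in range(num_batches):
    is_last_batch = batch_idx == num_batches - 1
    # ... forward/backward, optimizer step, logging ...
    if not is_last_batch:
        global_step += 1  # stands in for increment_accumulated_grad_global_step()

# epoch-end hooks and training_epoch_end run here and observe the last batch's step
print("step seen by epoch-end hooks:", global_step)        # 3

global_step += 1  # single bump to move on to the next epoch
print("step at the start of the next epoch:", global_step)  # 4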