
Commit 06ca642

Allow user to specify 'step' key while logging metrics (#808)
* allow to specify 'step' key
* add test
* docs to log_metrics
* fix test
* rename
* also rename
1 parent 62e9963 commit 06ca642

3 files changed: +44 -9 lines changed

pytorch_lightning/core/lightning.py (+3 -2)

@@ -418,7 +418,8 @@ def validation_end(self, outputs):
         The outputs here are strictly for the progress bar.
         If you don't need to display anything, don't return anything.
         Any keys present in 'log', 'progress_bar' or the rest of the dictionary
-        are available for callbacks to access.
+        are available for callbacks to access. If you want to manually set current step, you can specify it with
+        'step' key in the 'log' Dict.

         Example
         -------
@@ -468,7 +469,7 @@ def validation_end(self, outputs):
            # show val_loss and val_acc in progress bar but only log val_loss
            results = {
                'progress_bar': tqdm_dict,
-               'log': {'val_loss': val_loss_mean.item()}
+               'log': {'val_loss': val_loss_mean.item(), 'step': self.current_epoch}
            }
            return results
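For reference, the updated docstring example corresponds to a validation_end roughly like the one below. This is a minimal sketch rather than code from the commit: it assumes a standard LightningModule whose validation_step returns a 'val_loss' tensor, and only this method is shown.

import torch

# defined on a LightningModule subclass (only this method is sketched here)
def validation_end(self, outputs):
    # average the per-batch 'val_loss' values returned by validation_step
    val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean()
    tqdm_dict = {'val_loss': val_loss_mean}
    return {
        'progress_bar': tqdm_dict,
        # a 'step' key in the 'log' dict overrides the global step the logger would otherwise use
        'log': {'val_loss': val_loss_mean.item(), 'step': self.current_epoch},
    }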

pytorch_lightning/trainer/logging.py (+11 -7)

@@ -39,13 +39,12 @@ def configure_logger(self, logger):

     def log_metrics(self, metrics, grad_norm_dic, step=None):
         """Logs the metric dict passed in.
-
-        :param metrics:
-        :param grad_norm_dic:
+        If `step` parameter is None and `step` key is presented is metrics,
+        uses metrics["step"] as a step
+        :param metrics (dict): Metric values
+        :param grad_norm_dic (dict): Gradient norms
+        :param step (int): Step for which metrics should be logged. Default value corresponds to `self.global_step`
         """
-        # added metrics by Lightning for convenience
-        metrics['epoch'] = self.current_epoch
-
         # add gpu memory
         if self.on_gpu and self.log_gpu_memory:
             mem_map = memory.get_memory_profile(self.log_gpu_memory)
@@ -57,7 +56,12 @@ def log_metrics(self, metrics, grad_norm_dic, step=None):
         # turn all tensors to scalars
         scalar_metrics = self.metrics_to_scalars(metrics)

-        step = step if step is not None else self.global_step
+        if "step" in scalar_metrics and step is None:
+            step = scalar_metrics.pop("step")
+        else:
+            # added metrics by Lightning for convenience
+            metrics['epoch'] = self.current_epoch
+            step = step if step is not None else self.global_step
         # log actual metrics
         if self.proc_rank == 0 and self.logger is not None:
             self.logger.log_metrics(scalar_metrics, step=step)
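The effect of the new branch is easiest to see in isolation. The snippet below is a standalone sketch of the step-resolution rule, not code from the commit: resolve_step is a hypothetical helper name, and global_step / current_epoch are passed in explicitly instead of being read from the Trainer.

def resolve_step(scalar_metrics, step, global_step, current_epoch):
    # hypothetical helper mirroring the updated log_metrics logic
    if "step" in scalar_metrics and step is None:
        # a user-supplied 'step' wins and is removed from the logged metrics
        step = scalar_metrics.pop("step")
    else:
        # otherwise Lightning adds 'epoch' for convenience and falls back to the global step
        scalar_metrics["epoch"] = current_epoch
        step = step if step is not None else global_step
    return scalar_metrics, step

# a 'step' key supplied via the 'log' dict overrides global_step:
print(resolve_step({"val_loss": 0.25, "step": 3}, None, 120, 3))  # ({'val_loss': 0.25}, 3)
# without it, behaviour is unchanged: 'epoch' is added and global_step is used
print(resolve_step({"val_loss": 0.25}, None, 120, 3))             # ({'val_loss': 0.25, 'epoch': 3}, 120)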

tests/test_logging.py (+30 -0)

@@ -376,3 +376,33 @@ def version(self):
     assert logger.hparams_logged == hparams
     assert logger.metrics_logged != {}
     assert logger.finalized_status == "success"
+
+
+def test_adding_step_key(tmpdir):
+    logged_step = 0
+
+    def _validation_end(outputs):
+        nonlocal logged_step
+        logged_step += 1
+        return {"log": {"step": logged_step, "val_acc": logged_step / 10}}
+
+    def _log_metrics_decorator(log_metrics_fn):
+        def decorated(metrics, step):
+            if "val_acc" in metrics:
+                assert step == logged_step
+            return log_metrics_fn(metrics, step)
+
+        return decorated
+
+    model, hparams = tutils.get_model()
+    model.validation_end = _validation_end
+    trainer_options = dict(
+        max_epochs=4,
+        default_save_path=tmpdir,
+        train_percent_check=0.001,
+        val_percent_check=0.01,
+        num_sanity_val_steps=0
+    )
+    trainer = Trainer(**trainer_options)
+    trainer.logger.log_metrics = _log_metrics_decorator(trainer.logger.log_metrics)
+    trainer.fit(model)
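Assuming the repository's usual test setup (the helpers imported as tutils and a default logger available), the new case can be run on its own by pytest node id:

python -m pytest tests/test_logging.py::test_adding_step_key -v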
