Commit 25bbd05

awaelchli and Borda authored
Also update progress_bar in training_epoch_end (#1724)
* update prog. bar metrics on train epoch end
* changelog
* wip test
* more thorough testing
* comments
* update docs
* move test

Co-authored-by: Jirka <[email protected]>
1 parent 3a64260 commit 25bbd05

File tree

4 files changed: +53 −1 lines changed

- CHANGELOG.md
- pytorch_lightning/core/lightning.py
- pytorch_lightning/trainer/training_loop.py
- tests/models/test_module_hooks.py


CHANGELOG.md

+2

@@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that also a list of dataloaders can be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723)).
 
+- The progress bar metrics now also get updated in `training_epoch_end` ([#1724](https://github.com/PyTorchLightning/pytorch-lightning/pull/1724)).
+
 ### Changed
 
 - Reduction when `batch_size < num_gpus` ([#1609](https://github.com/PyTorchLightning/pytorch-lightning/pull/1609))
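
The changelog entry above is the user-facing summary: a `progress_bar` dict returned from `training_epoch_end` now reaches the progress bar, just as it already did from `training_step`. A minimal sketch of the newly supported usage, assuming `training_step` stores a `train_acc` tensor in each output dict (the model skeleton and metric name are illustrative, not part of this commit):

```python
import torch
import pytorch_lightning as pl

class LitModel(pl.LightningModule):  # illustrative skeleton; other hooks omitted
    def training_epoch_end(self, outputs):
        # aggregate a metric that training_step put into each output dict
        train_acc_mean = torch.stack([x['train_acc'] for x in outputs]).mean()
        return {
            'log': {'train_acc': train_acc_mean.item()},    # sent to the logger
            'progress_bar': {'train_acc': train_acc_mean},  # now also shown in the progress bar
        }
```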

pytorch_lightning/core/lightning.py

+4 −1

@@ -257,6 +257,7 @@ def training_epoch_end(
         May contain the following optional keys:
 
         - log (metrics to be added to the logger; only tensors)
+        - progress_bar (dict for progress bar display)
         - any metric used in a callback (e.g. early stopping).
 
         Note:
@@ -280,7 +281,8 @@ def training_epoch_end(self, outputs):
 
                 # log training accuracy at the end of an epoch
                 results = {
-                    'log': {'train_acc': train_acc_mean.item()}
+                    'log': {'train_acc': train_acc_mean.item()},
+                    'progress_bar': {'train_acc': train_acc_mean},
                 }
                 return results
 
@@ -303,6 +305,7 @@ def training_epoch_end(self, outputs):
                 # log training accuracy at the end of an epoch
                 results = {
                     'log': {'train_acc': train_acc_mean.item(), 'step': self.current_epoch},
+                    'progress_bar': {'train_acc': train_acc_mean},
                 }
                 return results
         """

pytorch_lightning/trainer/training_loop.py

+1

@@ -491,6 +491,7 @@ def run_training_epoch(self):
             callback_epoch_metrics = _processed_outputs[3]
             self.log_metrics(log_epoch_metrics, {})
             self.callback_metrics.update(callback_epoch_metrics)
+            self.add_progress_bar_metrics(_processed_outputs[1])
 
         # when no val loop is present or fast-dev-run still need to call checkpoints
         if not self.is_overridden('validation_step') and not (self.fast_dev_run or should_check_val):
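
Here `_processed_outputs[1]` holds the progress-bar metrics extracted from the `training_epoch_end` result. The `add_progress_bar_metrics` helper itself is not part of this diff; based on the behavior the new test asserts, a rough sketch of what it presumably does (an assumption, not the library's actual code):

```python
import torch

def add_progress_bar_metrics(self, metrics):
    """Presumed behavior sketch: merge metrics into the trainer's progress bar state."""
    for k, v in metrics.items():
        # unwrap scalar tensors so the bar displays plain numbers (assumed)
        if isinstance(v, torch.Tensor):
            v = v.item()
        self.progress_bar_metrics[k] = v  # attribute name assumed
```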

tests/models/test_module_hooks.py

+46

@@ -0,0 +1,46 @@
+import torch
+
+from pytorch_lightning import Trainer
+from tests.base import EvalModelTemplate
+
+import tests.base.utils as tutils
+
+
+def test_training_epoch_end_metrics_collection(tmpdir):
+    """ Test that progress bar metrics also get collected at the end of an epoch. """
+    num_epochs = 3
+    class CurrentModel(EvalModelTemplate):
+
+        def training_step(self, *args, **kwargs):
+            output = super().training_step(*args, **kwargs)
+            output['progress_bar'].update({'step_metric': torch.tensor(-1)})
+            output['progress_bar'].update({'shared_metric': 100})
+            return output
+
+        def training_epoch_end(self, outputs):
+            epoch = self.current_epoch
+            # both scalar tensors and Python numbers are accepted
+            return {
+                'progress_bar': {
+                    f'epoch_metric_{epoch}': torch.tensor(epoch),  # add a new metric key every epoch
+                    'shared_metric': 111,
+                }
+            }
+
+    model = CurrentModel(tutils.get_default_hparams())
+    trainer = Trainer(
+        max_epochs=num_epochs,
+        default_root_dir=tmpdir,
+        overfit_pct=0.1,
+    )
+    result = trainer.fit(model)
+    assert result == 1
+    metrics = trainer.progress_bar_dict
+
+    # metrics added in training step should be unchanged by epoch end method
+    assert metrics['step_metric'] == -1
+    # a metric shared in both methods gets overwritten by epoch_end
+    assert metrics['shared_metric'] == 111
+    # metrics are kept after each epoch
+    for i in range(num_epochs):
+        assert metrics[f'epoch_metric_{i}'] == i
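
The assertions reduce to ordinary dict-update semantics: step-level metrics persist unless `training_epoch_end` reports the same key, in which case the epoch-end value wins, and epoch-specific keys accumulate across epochs. A standalone illustration of the merge the test expects (plain Python, no Lightning required):

```python
progress_bar = {}
progress_bar.update({'step_metric': -1, 'shared_metric': 100})    # from training_step
progress_bar.update({'epoch_metric_0': 0, 'shared_metric': 111})  # from training_epoch_end
assert progress_bar == {'step_metric': -1, 'shared_metric': 111, 'epoch_metric_0': 0}
```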
