
Commit ddbf7de

Authored by alexeykarnachev with Joe Davison, Borda, and williamFalcon
Added accumulation of loggers' metrics for the same steps (Lightning-AI#1278)
* `add_argparse_args` method fixed (argument types added)
* autopep8 fixes
* --gpus=0 removed from test (for ci tests)
* Update pytorch_lightning/trainer/trainer.py
  Co-Authored-By: Joe Davison <[email protected]>
* test_with_accumulate_grad_batches added
* agg_and_log_metrics logic added to the base logger class
* small format fix
* agg metrics strategies removed (not to complicate stuff)
* agg metrics: handle zero step
* autopep8
* changelog upd
* flake fix
* metrics aggregators factored out, metrics_agg.py added + tests
* metrics agg default value added
* Update pytorch_lightning/loggers/metrics_agg.py
  Co-Authored-By: Jirka Borovec <[email protected]>
* metrics aggregators factored out, metrics_agg.py added + tests
* metrics agg default value added
* Update pytorch_lightning/loggers/metrics_agg.py
  Co-Authored-By: Jirka Borovec <[email protected]>
* remove .item which causes sync issues (Lightning-AI#1254)
* remove .item which causes sync issues
* fixed gradient acc sched
* fixed gradient acc sched
* test_metrics_agg.py removed (all tested in docstrings), agg metrics refactored
* test_metrics_agg.py removed (all tested in docstrings), agg metrics refactored
* autopep8
* loggers base.py types fixed
* test
* test
* metrics aggregation for loggers: each key now has a specific function (or default one)
* metrics aggregation for loggers: each key now has a specific function (or default one)
* docstrings upd
* manual typehints removed from docstrings
* batch_size decreased for test `test_with_accumulate_grad_batches`
* extend running accum
* refactor
* fix tests
* fix tests
* allowed_types generator scoped
* trainer.py distutils was imported twice, fixed
* TensorRunningAccum refactored
* TensorRunningAccum added to change log (Changed)
* change log pull link added

Co-authored-by: Joe Davison <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: William Falcon <[email protected]>
Co-authored-by: J. Borovec <[email protected]>
1 parent 471499c commit ddbf7de

File tree: 8 files changed (+230, -38 lines)


CHANGELOG.md (+2)

@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added same step loggers' metrics aggregation ([#1278](https://github.com/PyTorchLightning/pytorch-lightning/pull/1278))
 - Added parity test between a vanilla MNIST model and lightning model ([#1284](https://github.com/PyTorchLightning/pytorch-lightning/pull/1284))
 - Added parity test between a vanilla RNN model and lightning model ([#1351](https://github.com/PyTorchLightning/pytorch-lightning/pull/1351))
 - Added Reinforcement Learning - Deep Q-network (DQN) lightning example ([#1232](https://github.com/PyTorchLightning/pytorch-lightning/pull/1232))
@@ -30,6 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
+- Changed (renamed and refactored) `TensorRunningMean` -> `TensorRunningAccum`: running accumulations were generalized. ([#1278](https://github.com/PyTorchLightning/pytorch-lightning/pull/1278))
 - Changed `progress_bar_refresh_rate` trainer flag to disable progress bar when set to 0. ([#1108](https://github.com/PyTorchLightning/pytorch-lightning/pull/1108))
 - Enhanced `load_from_checkpoint` to also forward params to the model ([#1307](https://github.com/PyTorchLightning/pytorch-lightning/pull/1307))
 - Updated references to self.forward() to instead use the `__call__` interface. ([#1211](https://github.com/PyTorchLightning/pytorch-lightning/pull/1211))

pytorch_lightning/loggers/base.py (+152, -4)

@@ -1,9 +1,12 @@
 import argparse
+import functools
+import operator
 from abc import ABC, abstractmethod
 from argparse import Namespace
 from functools import wraps
-from typing import Union, Optional, Dict, Iterable, Any, Callable, List
+from typing import Union, Optional, Dict, Iterable, Any, Callable, List, Sequence, Mapping, Tuple
 
+import numpy as np
 import torch
 
 
@@ -25,22 +28,119 @@ def wrapped_fn(self, *args, **kwargs):
 class LightningLoggerBase(ABC):
     """Base class for experiment loggers."""
 
-    def __init__(self):
+    def __init__(
+            self,
+            agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
+            agg_default_func: Callable[[Sequence[float]], float] = np.mean
+    ):
+        """
+        Args:
+            agg_key_funcs:
+                Dictionary which maps a metric name to a function, which will
+                aggregate the metric values for the same steps.
+            agg_default_func:
+                Default function to aggregate metric values. If some metric name
+                is not present in the `agg_key_funcs` dictionary, then the
+                `agg_default_func` will be used for aggregation.
+
+        Notes:
+            `agg_key_funcs` and `agg_default_func` are used only when one logs metrics with the
+            `LightningLoggerBase.agg_and_log_metrics` method.
+        """
         self._rank = 0
+        self._prev_step = -1
+        self._metrics_to_agg: List[Dict[str, float]] = []
+        self._agg_key_funcs = agg_key_funcs if agg_key_funcs else {}
+        self._agg_default_func = agg_default_func
+
+    def update_agg_funcs(
+            self,
+            agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
+            agg_default_func: Callable[[Sequence[float]], float] = np.mean
+    ):
+        """Update aggregation methods.
+
+        Args:
+            agg_key_funcs:
+                Dictionary which maps a metric name to a function, which will
+                aggregate the metric values for the same steps.
+            agg_default_func:
+                Default function to aggregate metric values. If some metric name
+                is not present in the `agg_key_funcs` dictionary, then the
+                `agg_default_func` will be used for aggregation.
+        """
+        if agg_key_funcs:
+            self._agg_key_funcs.update(agg_key_funcs)
+        if agg_default_func:
+            self._agg_default_func = agg_default_func
 
     @property
     @abstractmethod
     def experiment(self) -> Any:
         """Return the experiment object associated with this logger"""
 
+    def _aggregate_metrics(
+            self, metrics: Dict[str, float], step: Optional[int] = None
+    ) -> Tuple[int, Optional[Dict[str, float]]]:
+        """Aggregates metrics.
+
+        Args:
+            metrics: Dictionary with metric names as keys and measured quantities as values
+            step: Step number at which the metrics should be recorded
+
+        Returns:
+            Step and aggregated metrics. The return value could be None. In such case, metrics
+            are added to the aggregation list, but not aggregated yet.
+        """
+        # if still receiving metrics for the same step, just accumulate them
+        if step == self._prev_step:
+            self._metrics_to_agg.append(metrics)
+            return step, None
+
+        # compute the metrics
+        agg_step, agg_mets = self._finalize_agg_metrics()
+
+        # as a new step was received, reset the accumulator
+        self._metrics_to_agg = [metrics]
+        self._prev_step = step
+        return agg_step, agg_mets
+
+    def _finalize_agg_metrics(self):
+        """Aggregate accumulated metrics. This shall be called in close."""
+        # compute the metrics
+        if not self._metrics_to_agg:
+            agg_mets = None
+        elif len(self._metrics_to_agg) == 1:
+            agg_mets = self._metrics_to_agg[0]
+        else:
+            agg_mets = merge_dicts(self._metrics_to_agg, self._agg_key_funcs, self._agg_default_func)
+        return self._prev_step, agg_mets
+
+    def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
+        """Aggregates and records metrics.
+        This method doesn't log the passed metrics instantaneously, but instead
+        it aggregates them and logs only if metrics are ready to be logged.
+
+        Args:
+            metrics: Dictionary with metric names as keys and measured quantities as values
+            step: Step number at which the metrics should be recorded
+        """
+        agg_step, metrics_to_log = self._aggregate_metrics(metrics=metrics, step=step)
+
+        if metrics_to_log is not None:
+            self.log_metrics(metrics=metrics_to_log, step=agg_step)
+
     @abstractmethod
     def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
-        """Record metrics.
+        """Records metrics.
+        This method logs metrics as soon as it receives them. If you want to aggregate
+        metrics for one specific `step`, use the `agg_and_log_metrics` method.
 
         Args:
             metrics: Dictionary with metric names as keys and measured quantities as values
             step: Step number at which the metrics should be recorded
         """
+        pass
 
     @staticmethod
     def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]:
@@ -131,7 +231,10 @@ def finalize(self, status: str) -> None:
 
     def close(self) -> None:
         """Do any cleanup that is necessary to close an experiment."""
-        pass
+        agg_step, metrics_to_log = self._finalize_agg_metrics()
+
+        if metrics_to_log is not None:
+            self.log_metrics(metrics=metrics_to_log, step=agg_step)
 
     @property
     def rank(self) -> int:
@@ -200,3 +303,48 @@ def name(self) -> str:
     @property
     def version(self) -> str:
         return '_'.join([str(logger.version) for logger in self._logger_iterable])
+
+
+def merge_dicts(
+        dicts: Sequence[Mapping],
+        agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
+        default_func: Callable[[Sequence[float]], float] = np.mean
+) -> Dict:
+    """Merge a sequence of dictionaries into one dictionary by aggregating the
+    same keys with some given function.
+
+    Args:
+        dicts:
+            Sequence of dictionaries to be merged.
+        agg_key_funcs:
+            Mapping from key name to function. This function will aggregate a
+            list of values, obtained from the same key of all dictionaries.
+            If some key has no specified aggregation function, the default one
+            will be used. Default is: None (all keys will be aggregated by the
+            default function).
+        default_func:
+            Default function to aggregate keys which are not present in the
+            `agg_key_funcs` map.
+
+    Returns:
+        Dictionary with merged values.
+
+    Examples:
+        >>> import pprint
+        >>> d1 = {'a': 1.7, 'b': 2.0, 'c': 1}
+        >>> d2 = {'a': 1.1, 'b': 2.2, 'v': 1}
+        >>> d3 = {'a': 1.1, 'v': 2.3}
+        >>> dflt_func = min
+        >>> agg_funcs = {'a': np.mean, 'v': max}
+        >>> pprint.pprint(merge_dicts([d1, d2, d3], agg_funcs, dflt_func))
+        {'a': 1.3, 'b': 2.0, 'c': 1, 'v': 2.3}
+    """
+
+    keys = list(functools.reduce(operator.or_, [set(d.keys()) for d in dicts]))
+    d_out = {}
+    for k in keys:
+        fn = agg_key_funcs.get(k, default_func) if agg_key_funcs else default_func
+        agg_val = fn([v for v in [d_in.get(k) for d_in in dicts] if v is not None])
+        d_out[k] = agg_val
+
+    return d_out
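
To see what the new aggregation path in `LightningLoggerBase` does end to end, here is a minimal sketch (not part of this commit). It assumes the abstract members of the base class at this point are `experiment`, `log_hyperparams`, `log_metrics`, `name`, and `version`; the `PrintLogger` name and the metric values are made up for illustration.

import numpy as np

from pytorch_lightning.loggers import LightningLoggerBase


class PrintLogger(LightningLoggerBase):
    """Toy logger that just prints whatever finally reaches log_metrics."""

    def __init__(self):
        # aggregate 'lr' by keeping the last reported value, everything else by mean
        super().__init__(agg_key_funcs={'lr': lambda vals: vals[-1]},
                         agg_default_func=np.mean)

    @property
    def experiment(self):
        return None

    def log_hyperparams(self, params):
        pass

    def log_metrics(self, metrics, step=None):
        print(f'step={step}: {metrics}')

    @property
    def name(self):
        return 'print'

    @property
    def version(self):
        return '0'


logger = PrintLogger()
# two reports for the same step are only accumulated ...
logger.agg_and_log_metrics({'loss': 1.0, 'lr': 0.10}, step=0)
logger.agg_and_log_metrics({'loss': 3.0, 'lr': 0.09}, step=0)
# ... and flushed as one record once a new step arrives:
# step=0: {'loss': 2.0, 'lr': 0.09}
logger.agg_and_log_metrics({'loss': 0.5, 'lr': 0.08}, step=1)
# close() flushes whatever is still pending (step 1)
logger.close()

The per-key functions let one metric be summed, another keep its last value, and the rest be averaged within a step, instead of forcing a single strategy on every key.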

pytorch_lightning/trainer/logging.py (+1, -1)

@@ -71,7 +71,7 @@ def log_metrics(self, metrics, grad_norm_dic, step=None):
         step = step if step is not None else self.global_step
         # log actual metrics
         if self.proc_rank == 0 and self.logger is not None:
-            self.logger.log_metrics(scalar_metrics, step=step)
+            self.logger.agg_and_log_metrics(scalar_metrics, step=step)
             self.logger.save()
 
     def add_tqdm_metrics(self, metrics):
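
With `accumulate_grad_batches > 1`, the trainer can report metrics several times for the same `global_step`; routing them through `agg_and_log_metrics` merges those reports into one logged record instead of several points per step. A rough usage sketch (not part of this diff; the key name `'loss'` and the `np.sum` choice are only illustrative):

import numpy as np

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

# with accumulate_grad_batches=4 the trainer may report metrics several times
# for the same global_step; the logger now merges them before writing them out
logger = TensorBoardLogger('lightning_logs')
logger.update_agg_funcs(agg_key_funcs={'loss': np.sum}, agg_default_func=np.mean)

trainer = Trainer(logger=logger, accumulate_grad_batches=4, max_epochs=1)
# trainer.fit(model)  # `model` is any LightningModule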

pytorch_lightning/trainer/supporters.py (+26, -9)

@@ -1,13 +1,12 @@
 import torch
 
 
-class TensorRunningMean(object):
-    """
-    Tracks a running mean without graph references.
-    Round robbin for the mean
+class TensorRunningAccum(object):
+    """Tracks running accumulation values (min, max, mean) without graph
+    references.
 
     Examples:
-        >>> accum = TensorRunningMean(5)
+        >>> accum = TensorRunningAccum(5)
         >>> accum.last(), accum.mean()
         (None, None)
         >>> accum.append(torch.tensor(1.5))
@@ -18,8 +17,8 @@ class TensorRunningMean(object):
         (tensor(2.5000), tensor(2.))
         >>> accum.reset()
         >>> _= [accum.append(torch.tensor(i)) for i in range(13)]
-        >>> accum.last(), accum.mean()
-        (tensor(12.), tensor(10.))
+        >>> accum.last(), accum.mean(), accum.min(), accum.max()
+        (tensor(12.), tensor(10.), tensor(8.), tensor(12.))
     """
     def __init__(self, window_length: int):
         self.window_length = window_length
@@ -29,13 +28,16 @@ def __init__(self, window_length: int):
         self.rotated: bool = False
 
     def reset(self) -> None:
-        self = TensorRunningMean(self.window_length)
+        """Empty the accumulator."""
+        self = TensorRunningAccum(self.window_length)
 
     def last(self):
+        """Get the last added element."""
         if self.last_idx is not None:
             return self.memory[self.last_idx]
 
     def append(self, x):
+        """Add an element to the accumulator."""
         # ensure same device and type
         if self.memory.device != x.device or self.memory.type() != x.type():
             x = x.to(self.memory)
@@ -54,5 +56,20 @@ def append(self, x):
             self.rotated = True
 
     def mean(self):
+        """Get mean value from stored elements."""
+        return self._agg_memory('mean')
+
+    def max(self):
+        """Get maximal value from stored elements."""
+        return self._agg_memory('max')
+
+    def min(self):
+        """Get minimal value from stored elements."""
+        return self._agg_memory('min')
+
+    def _agg_memory(self, how: str):
         if self.last_idx is not None:
-            return self.memory.mean() if self.rotated else self.memory[:self.current_idx].mean()
+            if self.rotated:
+                return getattr(self.memory, how)()
+            else:
+                return getattr(self.memory[:self.current_idx], how)()
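
A small sketch of the window behaviour the generalized accumulator keeps from `TensorRunningMean` (the values are assumptions for illustration): once more than `window_length` values have been appended, old entries are overwritten round-robin, and `mean`/`min`/`max` are computed over the stored window only.

import torch

from pytorch_lightning.trainer.supporters import TensorRunningAccum

accum = TensorRunningAccum(window_length=3)
for value in [1.0, 5.0, 2.0, 4.0]:
    accum.append(torch.tensor(value))

# the window now holds [4., 5., 2.] (the first value was overwritten)
print(accum.last())  # tensor(4.)
print(accum.min())   # tensor(2.)
print(accum.max())   # tensor(5.)
print(accum.mean())  # tensor(3.6667)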

pytorch_lightning/trainer/trainer.py (+15, -16)

@@ -28,7 +28,7 @@
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
 from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
 from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
-from pytorch_lightning.trainer.supporters import TensorRunningMean
+from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.trainer.training_io import TrainerIOMixin
 from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
@@ -342,7 +342,7 @@ def __init__(
 
         # training bookeeping
         self.total_batch_idx = 0
-        self.running_loss = TensorRunningMean(window_length=20)
+        self.running_loss = TensorRunningAccum(window_length=20)
         self.batch_idx = 0
         self.tqdm_metrics = {}
         self.callback_metrics = {}
@@ -551,20 +551,19 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
 
         allowed_types = (str, float, int, bool)
         # TODO: get "help" from docstring :)
-        for arg, arg_types, arg_default in cls.get_init_arguments_and_types():
-            if arg not in depr_arg_names:
-                for allowed_type in allowed_types:
-                    if allowed_type in arg_types:
-                        if allowed_type is bool:
-                            allowed_type = lambda x: bool(distutils.util.strtobool(x))
-                        parser.add_argument(
-                            f'--{arg}',
-                            default=arg_default,
-                            type=allowed_type,
-                            dest=arg,
-                            help='autogenerated by pl.Trainer'
-                        )
-                        break
+        for arg, arg_types, arg_default in (at for at in cls.get_init_arguments_and_types()
+                                            if at[0] not in depr_arg_names):
+            for allowed_type in (at for at in allowed_types if at in arg_types):
+                if isinstance(allowed_type, bool):
+                    allowed_type = lambda x: bool(distutils.util.strtobool(x))
+                parser.add_argument(
+                    f'--{arg}',
+                    default=arg_default,
+                    type=allowed_type,
+                    dest=arg,
+                    help='autogenerated by pl.Trainer'
+                )
+                break
 
         return parser
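
For context, the autogenerated parser is typically consumed as in the sketch below (not part of this diff; the chosen flags are only examples):

from argparse import ArgumentParser

from pytorch_lightning import Trainer

# rough usage sketch: every supported Trainer __init__ argument becomes a CLI flag
parser = ArgumentParser()
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args(['--max_epochs', '3', '--accumulate_grad_batches', '4'])

trainer = Trainer(**vars(args))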

pytorch_lightning/trainer/training_loop.py (+2, -2)

@@ -147,7 +147,7 @@ def training_step(self, batch, batch_idx):
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.trainer.supporters import TensorRunningMean
+from pytorch_lightning.trainer.supporters import TensorRunningAccum
 
 try:
     from apex import amp
@@ -337,7 +337,7 @@ def train(self):
         self.accumulation_scheduler.on_epoch_start(self, self.get_model())
 
         # stores accumulated grad fractions per batch
-        self.batch_loss_value = TensorRunningMean(
+        self.batch_loss_value = TensorRunningAccum(
             window_length=self.accumulate_grad_batches
         )
