
Commit 9b83715

Revert "Added accumulation of loggers' metrics for the same steps (Lightning-AI#1278)"

This reverts commit ddbf7de.

1 parent 3f1e4b9 · commit 9b83715

File tree

8 files changed: +24 -217 lines

CHANGELOG.md

-2

@@ -63,7 +63,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
-- Added same step loggers' metrics aggregation ([#1278](https://github.com/PyTorchLightning/pytorch-lightning/pull/1278))
 - Added parity test between a vanilla MNIST model and lightning model ([#1284](https://github.com/PyTorchLightning/pytorch-lightning/pull/1284))
 - Added parity test between a vanilla RNN model and lightning model ([#1351](https://github.com/PyTorchLightning/pytorch-lightning/pull/1351))
 - Added Reinforcement Learning - Deep Q-network (DQN) lightning example ([#1232](https://github.com/PyTorchLightning/pytorch-lightning/pull/1232))

@@ -85,7 +84,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
-- Changed (renamed and refatored) `TensorRunningMean` -> `TensorRunningAccum`: running accumulations were generalized. ([#1278](https://github.com/PyTorchLightning/pytorch-lightning/pull/1278))
 - Changed `progress_bar_refresh_rate` trainer flag to disable progress bar when set to 0. ([#1108](https://github.com/PyTorchLightning/pytorch-lightning/pull/1108))
 - Enhanced `load_from_checkpoint` to also forward params to the model ([#1307](https://github.com/PyTorchLightning/pytorch-lightning/pull/1307))
 - Updated references to `self.forward()` to instead use the `__call__` interface. ([#1211](https://github.com/PyTorchLightning/pytorch-lightning/pull/1211))

pytorch_lightning/loggers/base.py

+4 -152

@@ -1,12 +1,9 @@
 import argparse
-import functools
-import operator
 from abc import ABC, abstractmethod
 from argparse import Namespace
 from functools import wraps
-from typing import Union, Optional, Dict, Iterable, Any, Callable, List, Sequence, Mapping, Tuple
+from typing import Union, Optional, Dict, Iterable, Any, Callable, List
 
-import numpy as np
 import torch
 
@@ -28,119 +25,22 @@ def wrapped_fn(self, *args, **kwargs):
 class LightningLoggerBase(ABC):
     """Base class for experiment loggers."""
 
-    def __init__(
-        self,
-        agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
-        agg_default_func: Callable[[Sequence[float]], float] = np.mean
-    ):
-        """
-        Args:
-            agg_key_funcs:
-                Dictionary which maps a metric name to a function, which will
-                aggregate the metric values for the same steps.
-            agg_default_func:
-                Default function to aggregate metric values. If some metric name
-                is not presented in the `agg_key_funcs` dictionary, then the
-                `agg_default_func` will be used for aggregation.
-
-        Notes:
-            `agg_key_funcs` and `agg_default_func` are used only when one logs metrics with
-            the `LightningLoggerBase.agg_and_log_metrics` method.
-        """
+    def __init__(self):
         self._rank = 0
-        self._prev_step = -1
-        self._metrics_to_agg: List[Dict[str, float]] = []
-        self._agg_key_funcs = agg_key_funcs if agg_key_funcs else {}
-        self._agg_default_func = agg_default_func
-
-    def update_agg_funcs(
-        self,
-        agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
-        agg_default_func: Callable[[Sequence[float]], float] = np.mean
-    ):
-        """Update aggregation methods.
-
-        Args:
-            agg_key_funcs:
-                Dictionary which maps a metric name to a function, which will
-                aggregate the metric values for the same steps.
-            agg_default_func:
-                Default function to aggregate metric values. If some metric name
-                is not presented in the `agg_key_funcs` dictionary, then the
-                `agg_default_func` will be used for aggregation.
-        """
-        if agg_key_funcs:
-            self._agg_key_funcs.update(agg_key_funcs)
-        if agg_default_func:
-            self._agg_default_func = agg_default_func
 
     @property
     @abstractmethod
     def experiment(self) -> Any:
         """Return the experiment object associated with this logger"""
 
-    def _aggregate_metrics(
-        self, metrics: Dict[str, float], step: Optional[int] = None
-    ) -> Tuple[int, Optional[Dict[str, float]]]:
-        """Aggregates metrics.
-
-        Args:
-            metrics: Dictionary with metric names as keys and measured quantities as values
-            step: Step number at which the metrics should be recorded
-
-        Returns:
-            Step and aggregated metrics. The return value could be None. In such case, metrics
-            are added to the aggregation list, but not aggregated yet.
-        """
-        # if still receiving metrics for the same step, just accumulate them
-        if step == self._prev_step:
-            self._metrics_to_agg.append(metrics)
-            return step, None
-
-        # a new step arrived, so finalize the accumulated metrics
-        agg_step, agg_mets = self._finalize_agg_metrics()
-
-        # reset the accumulator for the new step
-        self._metrics_to_agg = [metrics]
-        self._prev_step = step
-        return agg_step, agg_mets
-
-    def _finalize_agg_metrics(self):
-        """Aggregate accumulated metrics. This shall be called in close."""
-        # compute the metrics
-        if not self._metrics_to_agg:
-            agg_mets = None
-        elif len(self._metrics_to_agg) == 1:
-            agg_mets = self._metrics_to_agg[0]
-        else:
-            agg_mets = merge_dicts(self._metrics_to_agg, self._agg_key_funcs, self._agg_default_func)
-        return self._prev_step, agg_mets
-
-    def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
-        """Aggregates and records metrics.
-        This method doesn't log the passed metrics instantaneously, but instead
-        it aggregates them and logs only when the metrics are ready to be logged.
-
-        Args:
-            metrics: Dictionary with metric names as keys and measured quantities as values
-            step: Step number at which the metrics should be recorded
-        """
-        agg_step, metrics_to_log = self._aggregate_metrics(metrics=metrics, step=step)
-
-        if metrics_to_log is not None:
-            self.log_metrics(metrics=metrics_to_log, step=agg_step)
-
     @abstractmethod
     def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
-        """Records metrics.
-        This method logs metrics as soon as it receives them. If you want to aggregate
-        metrics for one specific `step`, use the `agg_and_log_metrics` method.
+        """Record metrics.
 
         Args:
             metrics: Dictionary with metric names as keys and measured quantities as values
             step: Step number at which the metrics should be recorded
         """
-        pass
 
     @staticmethod
     def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]:
@@ -231,10 +131,7 @@ def finalize(self, status: str) -> None:
 
     def close(self) -> None:
         """Do any cleanup that is necessary to close an experiment."""
-        agg_step, metrics_to_log = self._finalize_agg_metrics()
-
-        if metrics_to_log is not None:
-            self.log_metrics(metrics=metrics_to_log, step=agg_step)
+        pass
 
     @property
     def rank(self) -> int:
@@ -303,48 +200,3 @@ def name(self) -> str:
     @property
     def version(self) -> str:
         return '_'.join([str(logger.version) for logger in self._logger_iterable])
-
-
-def merge_dicts(
-    dicts: Sequence[Mapping],
-    agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
-    default_func: Callable[[Sequence[float]], float] = np.mean
-) -> Dict:
-    """Merge a sequence of dictionaries into one dictionary by aggregating the
-    same keys with some given function.
-
-    Args:
-        dicts:
-            Sequence of dictionaries to be merged.
-        agg_key_funcs:
-            Mapping from key name to function. This function will aggregate a
-            list of values, obtained from the same key of all dictionaries.
-            If some key has no specified aggregation function, the default one
-            will be used. Default is: None (all keys will be aggregated by the
-            default function).
-        default_func:
-            Default function to aggregate keys, which are not presented in the
-            `agg_key_funcs` map.
-
-    Returns:
-        Dictionary with merged values.
-
-    Examples:
-        >>> import pprint
-        >>> d1 = {'a': 1.7, 'b': 2.0, 'c': 1}
-        >>> d2 = {'a': 1.1, 'b': 2.2, 'v': 1}
-        >>> d3 = {'a': 1.1, 'v': 2.3}
-        >>> dflt_func = min
-        >>> agg_funcs = {'a': np.mean, 'v': max}
-        >>> pprint.pprint(merge_dicts([d1, d2, d3], agg_funcs, dflt_func))
-        {'a': 1.3, 'b': 2.0, 'c': 1, 'v': 2.3}
-    """
-
-    keys = list(functools.reduce(operator.or_, [set(d.keys()) for d in dicts]))
-    d_out = {}
-    for k in keys:
-        fn = agg_key_funcs.get(k, default_func) if agg_key_funcs else default_func
-        agg_val = fn([v for v in [d_in.get(k) for d_in in dicts] if v is not None])
-        d_out[k] = agg_val
-
-    return d_out
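
For context, the removed `merge_dicts` helper (and the `agg_and_log_metrics` path built on it) reduced a list of per-step metric dictionaries key by key, using `np.mean` unless a per-key function was given. The snippet below is a condensed, illustrative re-derivation of that behavior; `merge_metric_dicts` is a made-up name for this sketch and is not part of the codebase after this revert.

# Illustrative only: a condensed sketch of the reverted same-step aggregation.
import numpy as np

def merge_metric_dicts(dicts, agg_key_funcs=None, default_func=np.mean):
    # collect every key seen across the per-step dictionaries
    keys = set().union(*(d.keys() for d in dicts))
    out = {}
    for k in keys:
        fn = (agg_key_funcs or {}).get(k, default_func)
        out[k] = fn([d[k] for d in dicts if k in d])
    return out

# Two metric dicts that targeted the same step were merged before logging:
same_step = [{'loss': 1.7, 'acc': 0.50}, {'loss': 1.1, 'acc': 0.54}]
print(merge_metric_dicts(same_step))                 # means: loss 1.4, acc 0.52
print(merge_metric_dicts(same_step, {'loss': min}))  # loss 1.1, acc still averaged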

pytorch_lightning/trainer/logging.py

+1 -1

@@ -71,7 +71,7 @@ def log_metrics(self, metrics, grad_norm_dic, step=None):
         step = step if step is not None else self.global_step
         # log actual metrics
         if self.proc_rank == 0 and self.logger is not None:
-            self.logger.agg_and_log_metrics(scalar_metrics, step=step)
+            self.logger.log_metrics(scalar_metrics, step=step)
             self.logger.save()
 
     def add_tqdm_metrics(self, metrics):
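
This one-line change is the behavioral core of the revert: the trainer now forwards each metrics dictionary straight to `logger.log_metrics` instead of routing it through the buffering `agg_and_log_metrics` path. A rough before/after sketch with a toy logger (the `PrintLogger` class is hypothetical, for illustration only):

# Hypothetical illustration of the before/after behavior; not library code.
class PrintLogger:
    def log_metrics(self, metrics, step=None):
        print(f"step={step}: {metrics}")

logger = PrintLogger()

# After the revert: every call reaches the logger backend immediately,
# even when several calls share the same global step.
logger.log_metrics({'train_loss': 0.9}, step=0)
logger.log_metrics({'val_loss': 1.2}, step=0)  # a second record for step 0

# Before the revert, agg_and_log_metrics buffered both dicts and emitted one
# merged record for step 0 only once a later step arrived (or on close()).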
pytorch_lightning/trainer/supporters.py

+9 -26

@@ -1,12 +1,13 @@
 import torch
 
 
-class TensorRunningAccum(object):
-    """Tracks a running accumulation values (min, max, mean) without graph
-    references.
+class TensorRunningMean(object):
+    """
+    Tracks a running mean without graph references.
+    Round robbin for the mean
 
     Examples:
-        >>> accum = TensorRunningAccum(5)
+        >>> accum = TensorRunningMean(5)
         >>> accum.last(), accum.mean()
         (None, None)
         >>> accum.append(torch.tensor(1.5))

@@ -17,8 +18,8 @@ class TensorRunningAccum(object):
         (tensor(2.5000), tensor(2.))
         >>> accum.reset()
         >>> _= [accum.append(torch.tensor(i)) for i in range(13)]
-        >>> accum.last(), accum.mean(), accum.min(), accum.max()
-        (tensor(12.), tensor(10.), tensor(8.), tensor(12.))
+        >>> accum.last(), accum.mean()
+        (tensor(12.), tensor(10.))
     """
 
     def __init__(self, window_length: int):

@@ -29,16 +30,13 @@ def __init__(self, window_length: int):
         self.rotated: bool = False
 
     def reset(self) -> None:
-        """Empty the accumulator."""
-        self = TensorRunningAccum(self.window_length)
+        self = TensorRunningMean(self.window_length)
 
     def last(self):
-        """Get the last added element."""
         if self.last_idx is not None:
             return self.memory[self.last_idx]
 
     def append(self, x):
-        """Add an element to the accumulator."""
         # ensure same device and type
         if self.memory.device != x.device or self.memory.type() != x.type():
             x = x.to(self.memory)

@@ -57,20 +55,5 @@ def append(self, x):
             self.rotated = True
 
     def mean(self):
-        """Get mean value from stored elements."""
-        return self._agg_memory('mean')
-
-    def max(self):
-        """Get maximal value from stored elements."""
-        return self._agg_memory('max')
-
-    def min(self):
-        """Get minimal value from stored elements."""
-        return self._agg_memory('min')
-
-    def _agg_memory(self, how: str):
         if self.last_idx is not None:
-            if self.rotated:
-                return getattr(self.memory, how)()
-            else:
-                return getattr(self.memory[:self.current_idx], how)()
+            return self.memory.mean() if self.rotated else self.memory[:self.current_idx].mean()
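
The doctest values above follow from the fixed-size, round-robin buffer: with `window_length=5`, appending 0 through 12 leaves only the last five values (8, 9, 10, 11, 12) in memory, so `mean()` returns 10 and `last()` returns 12. A quick usage sketch against the restored class, assuming it is importable from `pytorch_lightning.trainer.supporters` at this revision:

import torch
from pytorch_lightning.trainer.supporters import TensorRunningMean

accum = TensorRunningMean(window_length=5)
for i in range(13):
    accum.append(torch.tensor(float(i)))

# the circular buffer has rotated, so only the last 5 values remain
print(accum.last())   # tensor(12.)
print(accum.mean())   # tensor(10.)  == mean of 8, 9, 10, 11, 12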

pytorch_lightning/trainer/trainer.py

+2 -2

@@ -32,7 +32,7 @@
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
 from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
 from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
-from pytorch_lightning.trainer.supporters import TensorRunningAccum
+from pytorch_lightning.trainer.supporters import TensorRunningMean
 from pytorch_lightning.trainer.training_io import TrainerIOMixin
 from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin

@@ -378,7 +378,7 @@ def __init__(
 
         # training bookeeping
         self.total_batch_idx = 0
-        self.running_loss = TensorRunningAccum(window_length=20)
+        self.running_loss = TensorRunningMean(window_length=20)
         self.batch_idx = 0
         self.tqdm_metrics = {}
         self.callback_metrics = {}

pytorch_lightning/trainer/training_loop.py

+2 -2

@@ -146,8 +146,8 @@ def training_step(self, batch, batch_idx):
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.trainer.supporters import TensorRunningMean
 
 try:
     from apex import amp

@@ -337,7 +337,7 @@ def train(self):
         self.accumulation_scheduler.on_epoch_start(self, self.get_model())
 
         # stores accumulated grad fractions per batch
-        self.batch_loss_value = TensorRunningAccum(
+        self.batch_loss_value = TensorRunningMean(
             window_length=self.accumulate_grad_batches
         )
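
In the training loop, the running-mean window is tied to gradient accumulation: `batch_loss_value` collects the loss of each accumulation split, and its `mean()` gives the averaged loss for the optimizer step. A simplified sketch of that pattern (illustrative, not the Trainer's actual code path):

# Simplified illustration of pairing a windowed running mean with
# gradient accumulation; not the actual Trainer loop.
import torch
from pytorch_lightning.trainer.supporters import TensorRunningMean

accumulate_grad_batches = 4
batch_loss_value = TensorRunningMean(window_length=accumulate_grad_batches)

for split_loss in [0.9, 0.8, 0.7, 0.6]:
    batch_loss_value.append(torch.tensor(split_loss))

# once all accumulation splits are done, report the averaged loss
print(batch_loss_value.mean())   # tensor(0.7500)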

tests/loggers/test_base.py

+1 -30

@@ -1,9 +1,6 @@
 import pickle
-from collections import OrderedDict
 from unittest.mock import MagicMock
 
-import numpy as np
-
 import tests.base.utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import LightningLoggerBase, rank_zero_only, LoggerCollection

@@ -59,18 +56,6 @@ def version(self):
         return "1"
 
 
-class StoreHistoryLogger(CustomLogger):
-    def __init__(self):
-        super().__init__()
-        self.history = {}
-
-    @rank_zero_only
-    def log_metrics(self, metrics, step):
-        if step not in self.history:
-            self.history[step] = {}
-        self.history[step].update(metrics)
-
-
 def test_custom_logger(tmpdir):
     hparams = tutils.get_default_hparams()
     model = LightningTestModel(hparams)

@@ -168,19 +153,5 @@ def decorated(metrics, step):
         num_sanity_val_steps=0,
     )
     trainer = Trainer(**trainer_options)
-    trainer.logger.log_metrics = _log_metrics_decorator(
-        trainer.logger.log_metrics)
+    trainer.logger.log_metrics = _log_metrics_decorator(trainer.logger.log_metrics)
     trainer.fit(model)
-
-
-def test_with_accumulate_grad_batches():
-    """Checks if the logging is performed once for `accumulate_grad_batches` steps."""
-    logger = StoreHistoryLogger()
-
-    np.random.seed(42)
-    for i, loss in enumerate(np.random.random(10)):
-        logger.agg_and_log_metrics({'loss': loss}, step=int(i / 5))
-
-    assert logger.history == {0: {'loss': 0.5623850983416314}}
-    logger.close()
-    assert logger.history == {0: {'loss': 0.5623850983416314}, 1: {'loss': 0.4778883735637184}}
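
The magic constants in the removed test come straight from the seeded RNG: with `np.random.seed(42)`, the first five draws of `np.random.random(10)` average to ~0.5624 and the last five to ~0.4779, which is what the aggregated history was expected to hold for steps 0 and 1. A quick standalone check (illustrative, independent of the test harness):

import numpy as np

np.random.seed(42)
values = np.random.random(10)

print(values[:5].mean())   # 0.5623850983416314  -> expected history[0]['loss']
print(values[5:].mean())   # 0.4778883735637184  -> expected history[1]['loss']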

tests/trainer/test_trainer.py

+5 -2

@@ -1,14 +1,17 @@
 import glob
 import math
 import os
-from argparse import Namespace, ArgumentParser
+from argparse import Namespace
 
 import pytest
 import torch
 
 import tests.base.utils as tutils
 from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
+from pytorch_lightning.callbacks import (
+    EarlyStopping,
+    ModelCheckpoint,
+)
 from pytorch_lightning import Callback
 from pytorch_lightning.core.lightning import load_hparams_from_tags_csv
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
