Skip to content

Commit 48b4b62

Browse files
BordawilliamFalcon
authored and
akarnachev
committed
fix incomplete RunningMean (Lightning-AI#1309)
* fix RunningMean * changelog * fix none * Update supporters.py just needed to multiply by zero for init * Revert "Update supporters.py" This reverts commit 7e0da6c * fix NaN * formatting Co-authored-by: William Falcon <[email protected]>
1 parent 09e5a87 commit 48b4b62

File tree

7 files changed

+65
-44
lines changed

7 files changed

+65
-44
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4040
- Fixed a bug to ensure lightning checkpoints to be backward compatible ([#1132](https://github.com/PyTorchLightning/pytorch-lightning/pull/1132))
4141
- Fixed all warnings and errors in the docs build process ([#1191](https://github.com/PyTorchLightning/pytorch-lightning/pull/1191))
4242
- Fixed an issue where `val_percent_check=0` would not disable validation ([#1251](https://github.com/PyTorchLightning/pytorch-lightning/pull/1251))
43+
- Fixed average of incomplete `TensorRunningMean` ([#1309](https://github.com/PyTorchLightning/pytorch-lightning/pull/1309))
4344

4445
## [0.7.1] - 2020-03-07
4546

pytorch_lightning/core/lightning.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1525,9 +1525,10 @@ def get_tqdm_dict(self) -> Dict[str, Union[int, str]]:
15251525
Dictionary with the items to be displayed in the progress bar.
15261526
"""
15271527
# call .item() only once but store elements without graphs
1528-
running_training_loss = self.trainer.running_loss.mean().cpu().item()
1528+
running_train_loss = self.trainer.running_loss.mean()
1529+
avg_training_loss = running_train_loss.cpu().item() if running_train_loss is not None else float('NaN')
15291530
tqdm_dict = {
1530-
'loss': '{:.3f}'.format(running_training_loss)
1531+
'loss': '{:.3f}'.format(avg_training_loss)
15311532
}
15321533

15331534
if self.trainer.truncated_bptt_steps is not None:
+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import torch
2+
3+
4+
class TensorRunningMean(object):
5+
"""
6+
Tracks a running mean without graph references.
7+
Round robbin for the mean
8+
9+
Examples:
10+
>>> accum = TensorRunningMean(5)
11+
>>> accum.last(), accum.mean()
12+
(None, None)
13+
>>> accum.append(torch.tensor(1.5))
14+
>>> accum.last(), accum.mean()
15+
(tensor(1.5000), tensor(1.5000))
16+
>>> accum.append(torch.tensor(2.5))
17+
>>> accum.last(), accum.mean()
18+
(tensor(2.5000), tensor(2.))
19+
>>> accum.reset()
20+
>>> _= [accum.append(torch.tensor(i)) for i in range(13)]
21+
>>> accum.last(), accum.mean()
22+
(tensor(12.), tensor(10.))
23+
"""
24+
def __init__(self, window_length: int):
25+
self.window_length = window_length
26+
self.memory = torch.Tensor(self.window_length)
27+
self.current_idx: int = 0
28+
self.last_idx: int = None
29+
self.rotated: bool = False
30+
31+
def reset(self) -> None:
32+
self = TensorRunningMean(self.window_length)
33+
34+
def last(self):
35+
if self.last_idx is not None:
36+
return self.memory[self.last_idx]
37+
38+
def append(self, x):
39+
# map proper type for memory if they don't match
40+
if self.memory.type() != x.type():
41+
self.memory.type_as(x)
42+
43+
# store without grads
44+
with torch.no_grad():
45+
self.memory[self.current_idx] = x
46+
self.last_idx = self.current_idx
47+
48+
# increase index
49+
self.current_idx += 1
50+
51+
# reset index when hit limit of tensor
52+
self.current_idx = self.current_idx % self.window_length
53+
if self.current_idx == 0:
54+
self.rotated = True
55+
56+
def mean(self):
57+
if self.last_idx is not None:
58+
return self.memory.mean() if self.rotated else self.memory[:self.current_idx].mean()

pytorch_lightning/trainer/supporting_classes.py

-39
This file was deleted.

pytorch_lightning/trainer/trainer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
3535
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
3636
from pytorch_lightning.utilities.debugging import MisconfigurationException
37-
from pytorch_lightning.trainer.supporting_classes import TensorRunningMean
37+
from pytorch_lightning.trainer.supporters import TensorRunningMean
3838

3939
try:
4040
from apex import amp

pytorch_lightning/trainer/training_loop.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def training_step(self, batch, batch_idx):
146146
from pytorch_lightning.core.lightning import LightningModule
147147
from pytorch_lightning.loggers import LightningLoggerBase
148148
from pytorch_lightning.utilities.debugging import MisconfigurationException
149-
from pytorch_lightning.trainer.supporting_classes import TensorRunningMean
149+
from pytorch_lightning.trainer.supporters import TensorRunningMean
150150

151151
try:
152152
from apex import amp

tests/collect_env_details.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def info_system():
4848

4949
def info_cuda():
5050
return {
51-
'GPU': set([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]),
51+
'GPU': [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())],
5252
# 'nvidia_driver': get_nvidia_driver_version(run_lambda),
5353
'available': torch.cuda.is_available(),
5454
'version': torch.version.cuda,

0 commit comments

Comments
 (0)