
Commit 793036d

nsarang and Borda authored
Support returning python scalars in DP (#1935)
* Override the default gather method to support scalars
* add computing average of a list
* bug: change if to elif
* add some tests
* change style
* change documentation
* use apply_to_collection in DP gather
* use apply_to_collection in DP gather
* fix warning msg
* override gather method in DP
* add tests for python scalars
* add python scalars to docstring
* Update message
* override gather method in DP
* formatting
* chlog

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent ea658e3 commit 793036d

File tree

7 files changed: +90 -42 lines

CHANGELOG.md  (+2)

@@ -35,6 +35,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added remaining `sklearn` metrics: `AveragePrecision`, `BalancedAccuracy`, `CohenKappaScore`, `DCG`, `Hamming`, `Hinge`, `Jaccard`, `MeanAbsoluteError`, `MeanSquaredError`, `MeanSquaredLogError`, `MedianAbsoluteError`, `R2Score`, `MeanPoissonDeviance`, `MeanGammaDeviance`, `MeanTweedieDeviance`, `ExplainedVariance` ([#2562](https://github.com/PyTorchLightning/pytorch-lightning/pull/2562))
 
+- Added support returning python scalars in DP ([#1935](https://github.com/PyTorchLightning/pytorch-lightning/pull/1935))
+
 ### Changed
 
 - Truncated long version numbers in progress bar ([#2594](https://github.com/PyTorchLightning/pytorch-lightning/pull/2594))

pytorch_lightning/core/lightning.py  (+23 -23)

@@ -168,7 +168,7 @@ def forward(self, batch):
 
         """
 
-    def training_step(self, *args, **kwargs) -> Union[int, Dict[str, Union[Tensor, Dict[str, Tensor]]]]:
+    def training_step(self, *args, **kwargs) -> Union[int, Dict[str, Union[Tensor, Dict[str, Union[float, Tensor]]]]]:
         r"""
         Here you compute and return the training loss and some additional metrics for e.g.
         the progress bar or logger.
@@ -186,8 +186,8 @@ def training_step(self, *args, **kwargs) -> Union[int, Dict[str, Union[Tensor, D
         When implementing :meth:`training_step`, return whatever you need in that step:
 
         - loss -> tensor scalar **REQUIRED**
-        - progress_bar -> Dict for progress bar display. Must have only tensors
-        - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)
+        - progress_bar -> Dict for progress bar display. Must have either scalar tensors or Python scalars
+        - log -> Dict of metrics to add to logger. Must have either scalar tensors or Python scalars (no images, etc)
 
         In this step you'd normally do the forward pass and calculate the loss for a batch.
         You can also do fancier things like multiple forward passes or something model specific.
@@ -202,14 +202,14 @@ def training_step(self, batch, batch_idx):
                 out = self(x)
                 loss = self.loss(out, x)
 
-                logger_logs = {'training_loss': loss} # optional (MUST ALL BE TENSORS)
+                logger_logs = {'training_loss': loss} # optional
 
                 # if using TestTubeLogger or TensorBoardLogger you can nest scalars
-                logger_logs = {'losses': logger_logs} # optional (MUST ALL BE TENSORS)
+                logger_logs = {'losses': logger_logs} # optional
 
                 output = {
                     'loss': loss, # required
-                    'progress_bar': {'training_loss': loss}, # optional (MUST ALL BE TENSORS)
+                    'progress_bar': {'training_loss': loss}, # optional
                     'log': logger_logs
                 }
 
@@ -259,8 +259,8 @@ def training_end(self, *args, **kwargs):
         """
 
     def training_epoch_end(
-            self, outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]]
-    ) -> Dict[str, Dict[str, Tensor]]:
+            self, outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Union[float, Tensor]]]]]
+    ) -> Dict[str, Dict[str, Union[float, Tensor]]]:
         """Called at the end of the training epoch with the outputs of all training steps.
 
         .. code-block:: python
@@ -334,7 +334,7 @@ def training_epoch_end(self, outputs):
             return results
         """
 
-    def training_step_end(self, *args, **kwargs) -> Dict[str, Union[Tensor, Dict[str, Tensor]]]:
+    def training_step_end(self, *args, **kwargs) -> Dict[str, Union[Tensor, Dict[str, Union[float, Tensor]]]]:
         """
         Use this when training with dp or ddp2 because :meth:`training_step`
        will operate on only part of the batch. However, this is still optional
@@ -358,8 +358,8 @@ def training_step_end(self, *args, **kwargs) -> Dict[str, Union[Tensor, Dict[str
             Dict with loss key and optional log or progress bar keys.
 
             - loss -> tensor scalar **REQUIRED**
-            - progress_bar -> Dict for progress bar display. Must have only tensors
-            - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)
+            - progress_bar -> Dict for progress bar display. Must have either scalar tensors or Python scalars
+            - log -> Dict of metrics to add to logger. Must have either scalar tensors or Python scalars (no images, etc)
 
         Examples:
             .. code-block:: python
@@ -396,7 +396,7 @@ def training_step_end(self, outputs):
         See the :ref:`multi-gpu-training` guide for more details.
         """
 
-    def validation_step(self, *args, **kwargs) -> Dict[str, Tensor]:
+    def validation_step(self, *args, **kwargs) -> Dict[str, Union[float, Tensor]]:
         r"""
         Operates on a single batch of data from the validation set.
         In this step you'd might generate examples or calculate anything of interest like accuracy.
@@ -486,7 +486,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx):
         the model goes back to training mode and gradients are enabled.
         """
 
-    def validation_step_end(self, *args, **kwargs) -> Dict[str, Tensor]:
+    def validation_step_end(self, *args, **kwargs) -> Dict[str, Union[float, Tensor]]:
         """
         Use this when validating with dp or ddp2 because :meth:`validation_step`
         will operate on only part of the batch. However, this is still optional
@@ -553,8 +553,8 @@ def validation_end(self, outputs):
         """
 
     def validation_epoch_end(
-            self, outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]]
-    ) -> Dict[str, Dict[str, Tensor]]:
+            self, outputs: Union[List[Dict[str, Union[float, Tensor]]], List[List[Dict[str, Union[float, Tensor]]]]]
+    ) -> Dict[str, Dict[str, Union[float, Tensor]]]:
         """
         Called at the end of the validation epoch with the outputs of all validation steps.
 
@@ -575,8 +575,8 @@ def validation_epoch_end(
             Dict or OrderedDict.
             May have the following optional keys:
 
-            - progress_bar (dict for progress bar display; only tensors)
-            - log (dict of metrics to add to logger; only tensors).
+            - progress_bar (dict for progress bar display; either scalar tensors or Python scalars)
+            - log (dict of metrics to add to logger; either scalar tensors or Python scalars).
 
         Note:
             If you didn't define a :meth:`validation_step`, this won't be called.
@@ -630,7 +630,7 @@ def validation_epoch_end(self, outputs):
             return results
         """
 
-    def test_step(self, *args, **kwargs) -> Dict[str, Tensor]:
+    def test_step(self, *args, **kwargs) -> Dict[str, Union[float, Tensor]]:
         r"""
         Operates on a single batch of data from the test set.
         In this step you'd normally generate examples or calculate anything of interest
@@ -713,7 +713,7 @@ def test_step(self, batch, batch_idx, dataloader_idx):
         to training mode and gradients are enabled.
         """
 
-    def test_step_end(self, *args, **kwargs) -> Dict[str, Tensor]:
+    def test_step_end(self, *args, **kwargs) -> Dict[str, Union[float, Tensor]]:
         """
         Use this when testing with dp or ddp2 because :meth:`test_step` will operate
         on only part of the batch. However, this is still optional
@@ -779,8 +779,8 @@ def test_end(self, outputs):
         """
 
     def test_epoch_end(
-            self, outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]]
-    ) -> Dict[str, Dict[str, Tensor]]:
+            self, outputs: Union[List[Dict[str, Union[float, Tensor]]], List[List[Dict[str, Union[float, Tensor]]]]]
+    ) -> Dict[str, Dict[str, Union[float, Tensor]]]:
        """
         Called at the end of a test epoch with the output of all test steps.
 
@@ -800,8 +800,8 @@ def test_epoch_end(
         Return:
             Dict or OrderedDict: Dict has the following optional keys:
 
-            - progress_bar -> Dict for progress bar display. Must have only tensors.
-            - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc).
+            - progress_bar -> Dict for progress bar display. Must have either scalar tensors or Python scalars.
+            - log -> Dict of metrics to add to logger. Must have either scalar tensors or Python scalars (no images, etc).
 
         Note:
             If you didn't define a :meth:`test_step`, this won't be called.
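With the updated signatures above, the `progress_bar` and `log` entries may hold plain Python scalars as well as scalar tensors when running under dp. Below is a minimal sketch of a `training_step` that mixes both; the module and metric names are illustrative and not taken from the commit:

    import torch
    import pytorch_lightning as pl


    class LitClassifier(pl.LightningModule):  # hypothetical example module
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(28 * 28, 10)

        def forward(self, x):
            return self.layer(x.view(x.size(0), -1))

        def training_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            loss = torch.nn.functional.cross_entropy(logits, y)
            acc = (logits.argmax(dim=1) == y).float().mean()
            return {
                'loss': loss,                         # scalar tensor, still required
                'progress_bar': {'acc': acc.item()},  # Python float now allowed under dp
                'log': {'train_loss': loss, 'train_acc': acc.item()},  # tensors and floats can mix
            }

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)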

pytorch_lightning/overrides/data_parallel.py  (+38 -4)

@@ -1,11 +1,13 @@
 import itertools
 import threading
 from itertools import chain
+from collections import Mapping, Iterable
 
 import torch
 from torch.cuda._utils import _get_device_index
 from torch.nn import DataParallel
 from torch.nn.parallel import DistributedDataParallel
+from torch.nn.parallel._functions import Gather
 from pytorch_lightning.core.step_result import Result
 
 
@@ -68,7 +70,7 @@ def forward(self, *inputs, **kwargs):
             if isinstance(outputs[0], Result):
                 outputs = self.__gather_structured_result(outputs)
             else:
-                outputs = self.gather(outputs, self.output_device)
+                outputs = self.gather(outputs)
         return outputs
 
     def __gather_structured_result(self, outputs):
@@ -81,7 +83,7 @@ def __gather_structured_result(self, outputs):
         for i, output in enumerate(outputs):
             del output['meta']
 
-        outputs = self.gather(outputs, self.output_device)
+        outputs = self.gather(outputs)
 
         # pass minimize to constructor for TrainResult
         if 'minimize' in outputs:
@@ -93,6 +95,39 @@ def __gather_structured_result(self, outputs):
         result['meta'] = meta
         return result
 
+    def gather(self, outputs):
+        r"""
+        Override the gather method to support python scalars as well.
+        """
+        def gather_map(outputs):
+            elem = outputs[0]
+            elem_type = type(elem)
+
+            if isinstance(elem, torch.Tensor):
+                return Gather.apply(self.output_device, self.dim, *outputs)
+
+            if elem is None:
+                return None
+
+            if isinstance(elem, Mapping):
+                if not all((len(elem) == len(d) for d in outputs)):
+                    raise ValueError('All dicts must have the same number of keys')
+                return elem_type(((k, gather_map([d[k] for d in outputs]))
+                                  for k in elem))
+
+            if isinstance(elem, Iterable) and not isinstance(elem, str):
+                return elem_type(map(gather_map, zip(*outputs)))
+
+            return outputs
+
+        # Recursive function calls like this create reference cycles.
+        # Setting the function to None clears the refcycle.
+        try:
+            res = gather_map(outputs)
+        finally:
+            gather_map = None
+        return res
+
     def parallel_apply(self, replicas, inputs, kwargs):
         return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
 
@@ -126,9 +161,8 @@ def forward(self, *inputs, **kwargs):  # pragma: no-cover
             outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
             output = self.gather(outputs, self.output_device)
         else:
-            # normal
             # output = self.module(*inputs, **kwargs)
-            # lightning (ddp_cpu)
+            # normal lightning (ddp_cpu)
             if self.module.training:
                 output = self.module.training_step(*inputs, **kwargs)
             elif self.module.testing:
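The overridden `gather` walks the per-device outputs recursively: tensors are merged with `Gather.apply`, mappings and iterables are gathered element-wise, and any other leaf (for example a Python float) is returned as the list of per-device values. Below is a rough standalone sketch of that recursion so it can run without GPUs; `torch.stack` stands in for `Gather.apply`, `collections.abc` replaces the deprecated `collections` aliases, and `gather_outputs` is an illustrative name, not the method added by the commit:

    from collections.abc import Mapping

    import torch


    def gather_outputs(outputs):
        """Merge a list of per-device step outputs into one structure (illustrative only)."""
        elem = outputs[0]

        if isinstance(elem, torch.Tensor):
            # the real override uses Gather.apply(self.output_device, self.dim, *outputs) here
            return torch.stack([o.cpu() for o in outputs])

        if elem is None:
            return None

        if isinstance(elem, Mapping):
            if not all(len(elem) == len(d) for d in outputs):
                raise ValueError('All dicts must have the same number of keys')
            return type(elem)((k, gather_outputs([d[k] for d in outputs])) for k in elem)

        if isinstance(elem, (list, tuple)):
            return type(elem)(map(gather_outputs, zip(*outputs)))

        # Python scalars (and any other leaf objects) are simply collected into a list
        return outputs


    # pretend two GPUs each returned a dict mixing tensors and Python floats
    device_outputs = [
        {'loss': torch.tensor(0.9), 'log': {'acc': 0.50}},
        {'loss': torch.tensor(0.7), 'log': {'acc': 0.75}},
    ]
    print(gather_outputs(device_outputs))
    # {'loss': tensor([0.9000, 0.7000]), 'log': {'acc': [0.5, 0.75]}}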

pytorch_lightning/trainer/ignored_warnings.py  (+3 -6)

@@ -3,12 +3,9 @@
 
 def ignore_scalar_return_in_dp():
     # Users get confused by this warning so we silence it
-    m_1 = """
-    Was asked to gather along dimension 0, but all
-    input tensors were scalars; will instead unsqueeze
-    and return a vector.
-    """
-    warnings.filterwarnings('ignore', message=m_1)
+    warnings.filterwarnings('ignore', message='Was asked to gather along dimension 0, but all'
+                                              ' input tensors were scalars; will instead unsqueeze'
+                                              ' and return a vector.')
 
 
 ignore_scalar_return_in_dp()
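The triple-quoted pattern is collapsed into a single string here, presumably because `warnings.filterwarnings` treats `message` as a regular expression that must match the start of the warning text, so a pattern beginning with a newline and indentation never silences PyTorch's single-line gather warning. A small self-contained illustration of that behaviour (the warning text is re-typed for the example):

    import warnings

    GATHER_WARNING = ('Was asked to gather along dimension 0, but all input tensors were '
                      'scalars; will instead unsqueeze and return a vector.')

    # pattern starting with a newline: never matches the start of the message, warning gets through
    with warnings.catch_warnings(record=True) as caught:
        warnings.filterwarnings('ignore', message='\n    Was asked to gather')
        warnings.warn(GATHER_WARNING)
        print(len(caught))  # 1

    # single-line pattern matching the start of the message: warning is silenced
    with warnings.catch_warnings(record=True) as caught:
        warnings.filterwarnings('ignore', message='Was asked to gather along dimension 0')
        warnings.warn(GATHER_WARNING)
        print(len(caught))  # 0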

pytorch_lightning/trainer/logging.py  (+4)

@@ -208,6 +208,10 @@ def reduce_distributed_output(self, output, num_gpus):
             if isinstance(output[k], dict):
                 output[k] = self.reduce_distributed_output(output[k], num_gpus)
 
+            # compute the average of scalars
+            elif isinstance(output[k], list):
+                output[k] = sum(output[k]) / len(output[k])
+
             # do nothing when there's a scalar
             elif isinstance(output[k], torch.Tensor) and output[k].dim() == 0:
                 pass
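With the new `gather`, Python scalars from each GPU reach this reducer as a plain list, and the added branch averages them. Below is a simplified standalone sketch of the reduction rule; the function name and the tensor-mean branch are illustrative, since only the dict, list, and 0-dim cases appear in the hunk above:

    import torch


    def reduce_dp_output(output, num_gpus):
        """Simplified illustration: recurse into dicts, average lists of Python scalars,
        leave 0-dim tensors alone, and mean-reduce per-GPU tensors."""
        if num_gpus <= 1:
            return output
        for k in output:
            if isinstance(output[k], dict):
                output[k] = reduce_dp_output(output[k], num_gpus)
            elif isinstance(output[k], list):      # python scalars gathered from each GPU
                output[k] = sum(output[k]) / len(output[k])
            elif isinstance(output[k], torch.Tensor) and output[k].dim() == 0:
                pass                               # already a single scalar, nothing to reduce
            elif isinstance(output[k], torch.Tensor):
                output[k] = output[k].mean()       # one value per GPU -> average them
        return output


    gathered = {'loss': torch.tensor([0.9, 0.7]), 'log': {'acc': [0.5, 0.75]}}
    print(reduce_dp_output(gathered, num_gpus=2))
    # {'loss': tensor(0.8000), 'log': {'acc': 0.625}}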

tests/base/model_train_steps.py  (+14 -8)

@@ -11,25 +11,31 @@ class TrainingStepVariations(ABC):
     """
     Houses all variations of training steps
     """
+
     test_step_inf_loss = float('inf')
 
     def training_step(self, batch, batch_idx, optimizer_idx=None):
         """Lightning calls this inside the training loop"""
         # forward pass
         x, y = batch
         x = x.view(x.size(0), -1)
-
         y_hat = self(x)
 
         # calculate loss
         loss_val = self.loss(y, y_hat)
-
-        # alternate possible outputs to test
-        output = OrderedDict({
-            'loss': loss_val,
-            'progress_bar': {'some_val': loss_val * loss_val},
-            'log': {'train_some_val': loss_val * loss_val},
-        })
+        log_val = loss_val
+
+        # alternate between tensors and scalars for "log" and "progress_bar"
+        if batch_idx % 2 == 0:
+            log_val = log_val.item()
+
+        output = OrderedDict(
+            {
+                'loss': loss_val,
+                'progress_bar': {'some_val': log_val * log_val},
+                'log': {'train_some_val': log_val * log_val},
+            }
+        )
         return output
 
     def training_step__inf_loss(self, batch, batch_idx, optimizer_idx=None):
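The alternation above makes a single short dp run exercise both the tensor and the Python-scalar paths of the new gather. The tests added by the PR are not reproduced on this page; a rough sketch of what such a test could look like with that era's Trainer API, where every argument name and the `EvalModelTemplate` import are assumptions:

    import pytest
    import torch

    from pytorch_lightning import Trainer
    from tests.base import EvalModelTemplate


    @pytest.mark.skipif(torch.cuda.device_count() < 2, reason='requires at least 2 GPUs')
    def test_dp_accepts_python_scalars(tmpdir):
        # the template's training_step alternates tensor/float metrics batch by batch
        model = EvalModelTemplate()
        trainer = Trainer(
            default_root_dir=tmpdir,
            max_epochs=1,
            gpus=2,
            distributed_backend='dp',
        )
        result = trainer.fit(model)
        assert result == 1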

tests/base/model_valid_epoch_ends.py  (+6 -1)

@@ -25,6 +25,11 @@ def _mean(res, key):
         val_loss_mean = _mean(outputs, 'val_loss')
         val_acc_mean = _mean(outputs, 'val_acc')
 
+        # alternate between tensor and scalar
+        if self.current_epoch % 2 == 0:
+            val_loss_mean = val_loss_mean.item()
+            val_acc_mean = val_acc_mean.item()
+
         metrics_dict = {'val_loss': val_loss_mean, 'val_acc': val_acc_mean}
         results = {'progress_bar': metrics_dict, 'log': metrics_dict}
         return results
@@ -54,6 +59,6 @@ def _mean(res, key):
         results = {
             'val_loss': torch.stack([v for k, v in pbar.items() if k.startswith('val_loss')]).mean(),
             'progress_bar': pbar,
-            'log': logs
+            'log': logs,
         }
         return results
