Skip to content

Commit 9031dc3

Browse files
authored
Fix result gathering with varying tensor shapes (#3020)
* test for gathering results * fix gather * document tests * changelog * assert dtype * default to concat * additional test
1 parent 9445c80 commit 9031dc3

File tree

3 files changed

+79
-11
lines changed

3 files changed

+79
-11
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
146146

147147
- Fixed an issue that caused `Trainer.test()` to stall in ddp mode ([#2997](https://github.com/PyTorchLightning/pytorch-lightning/pull/2997))
148148

149+
- Fixed gathering of results with tensors of varying shape ([#3020](https://github.com/PyTorchLightning/pytorch-lightning/pull/3020))
150+
149151
## [0.8.5] - 2020-07-09
150152

151153
### Added

pytorch_lightning/core/step_result.py

+15-11
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numbers
22
from copy import copy
3-
from typing import Optional, Dict, Union, Sequence, Callable, MutableMapping, Any
3+
from typing import Optional, Dict, Union, Sequence, Callable, MutableMapping, Any, List, Tuple
44

55
import torch
66
from torch import Tensor
@@ -417,19 +417,23 @@ def recursive_stack(result: MutableMapping):
417417
if isinstance(v, dict):
418418
recursive_stack(v)
419419

420-
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], torch.Tensor):
421-
v = torch.stack(v)
422-
result[k] = v
420+
result[k] = collate_tensors(v)
423421

424422

425-
def recursive_padded_stack(result: MutableMapping):
426-
for k, v in result.items():
427-
if isinstance(v, dict):
428-
recursive_stack(v)
423+
def collate_tensors(items: Union[List, Tuple]) -> Union[Tensor, List, Tuple]:
    """Combine a sequence of tensors into a single tensor when their shapes allow.

    0-d (scalar) tensors are stacked into a 1-d tensor; tensors that agree on
    every dimension except the first are concatenated along dim 0. Any other
    input — an empty or non-sequence value, non-tensor elements, or tensors of
    incompatible shapes — is returned unchanged.
    """
    is_tensor_sequence = (
        isinstance(items, (list, tuple))
        and len(items) > 0
        and all(isinstance(item, Tensor) for item in items)
    )
    if not is_tensor_sequence:
        # nothing to collate: pass the value through untouched
        return items

    if all(item.ndim == 0 for item in items):
        # every element is a scalar tensor -> stack into one 1-d tensor
        return torch.stack(items)

    trailing_shape = items[0].shape[1:]
    if all(item.ndim >= 1 and item.shape[1:] == trailing_shape for item in items):
        # shapes agree beyond the batch dimension -> concatenate along dim 0
        return torch.cat(items)

    # incompatible shapes: keep the original sequence
    return items
433437

434438

435439
class TrainResult(Result):

tests/core/test_results.py

+62
Original file line numberDiff line numberDiff line change
@@ -174,3 +174,65 @@ def test_result_obj_predictions_ddp_spawn(tmpdir):
174174
predictions = torch.load(prediction_file)
175175
size += len(predictions)
176176
assert size == len(dm.mnist_test)
177+
178+
179+
def test_result_gather_stack():
    """ Test that tensors get concatenated when they all have the same shape. """
    batches = [{"foo": torch.zeros(4, 5)} for _ in range(3)]
    gathered = Result.gather(batches)
    assert isinstance(gathered["foo"], torch.Tensor)
    assert list(gathered["foo"].shape) == [12, 5]
189+
190+
191+
def test_result_gather_concatenate():
    """ Test that tensors get concatenated when they have varying size in first dimension. """
    sizes = (4, 8, 3)
    batches = [{"foo": torch.zeros(size, 5)} for size in sizes]
    gathered = Result.gather(batches)
    assert isinstance(gathered["foo"], torch.Tensor)
    assert list(gathered["foo"].shape) == [15, 5]
201+
202+
203+
def test_result_gather_scalar():
    """ Test that 0-dim tensors get gathered and stacked correctly. """
    batches = [{"foo": torch.tensor(value)} for value in (1, 2, 3)]
    gathered = Result.gather(batches)
    assert isinstance(gathered["foo"], torch.Tensor)
    assert list(gathered["foo"].shape) == [3]
213+
214+
215+
def test_result_gather_different_shapes():
    """ Test that tensors of varying shape get gathered into a list. """
    batches = [
        {"foo": torch.tensor(1)},
        {"foo": torch.zeros(2, 3)},
        {"foo": torch.zeros(1, 2, 3)},
    ]
    gathered = Result.gather(batches)
    wanted = [torch.tensor(1), torch.zeros(2, 3), torch.zeros(1, 2, 3)]
    assert isinstance(gathered["foo"], list)
    # element-wise value equality against freshly built expectations
    assert all(torch.eq(actual, target).all() for actual, target in zip(gathered["foo"], wanted))
226+
227+
228+
def test_result_gather_mixed_types():
    """ Test that a collection of mixed types gets gathered into a list. """
    values = [1.2, ["bar", None], torch.tensor(1)]
    gathered = Result.gather([{"foo": value} for value in values])
    assert isinstance(gathered["foo"], list)
    # compare against freshly constructed expected values, not the inputs themselves
    expected = [1.2, ["bar", None], torch.tensor(1)]
    assert gathered["foo"] == expected

0 commit comments

Comments
 (0)