
Commit 560b9ae

nsarang authored and Borda committed
override gather method in DP
1 parent 1f82457 commit 560b9ae

File tree

1 file changed: +5 −5


pytorch_lightning/overrides/data_parallel.py

+5 −5
@@ -70,7 +70,7 @@ def forward(self, *inputs, **kwargs):
         if isinstance(outputs[0], Result):
             outputs = self.__gather_structured_result(outputs)
         else:
-            outputs = self.gather(outputs, self.output_device)
+            outputs = self.gather(outputs)
         return outputs

     def __gather_structured_result(self, outputs):
@@ -83,7 +83,7 @@ def __gather_structured_result(self, outputs):
         for i, output in enumerate(outputs):
             del output['meta']

-        outputs = self.gather(outputs, self.output_device)
+        outputs = self.gather(outputs)

         # pass minimize to constructor for TrainResult
         if 'minimize' in outputs:
@@ -106,16 +106,16 @@ def gather_map(outputs):
             if isinstance(elem, torch.Tensor):
                 return Gather.apply(self.output_device, self.dim, *outputs)

-            elif elem is None:
+            if elem is None:
                 return None

-            elif isinstance(elem, Mapping):
+            if isinstance(elem, Mapping):
                 if not all((len(elem) == len(d) for d in outputs)):
                     raise ValueError('All dicts must have the same number of keys')
                 return elem_type(((k, gather_map([d[k] for d in outputs]))
                                   for k in elem))

-            elif isinstance(elem, Iterable) and not isinstance(elem, str):
+            if isinstance(elem, Iterable) and not isinstance(elem, str):
                 return elem_type(map(gather_map, zip(*outputs)))

             return outputs
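
For context: the call sites can drop the explicit `self.output_device` argument because the overridden `gather` resolves the target device from the instance itself. Below is a minimal sketch of what such an override might look like, assuming a `torch.nn.DataParallel` subclass; the class name `PatchedDataParallel` is hypothetical, while `Gather.apply`, `self.output_device`, and `self.dim` come from the diff above. This is not the actual pytorch_lightning implementation.

    # Minimal sketch, assuming a plain DataParallel subclass.
    from collections.abc import Iterable, Mapping

    import torch
    from torch.nn.parallel import DataParallel
    from torch.nn.parallel._functions import Gather


    class PatchedDataParallel(DataParallel):  # hypothetical name for illustration
        def gather(self, outputs):
            # New signature: no `output_device` parameter. The override reads
            # `self.output_device` and `self.dim`, which is why the call sites
            # in this commit shrink to `self.gather(outputs)`.
            def gather_map(outputs):
                elem = outputs[0]
                elem_type = type(elem)

                if isinstance(elem, torch.Tensor):
                    # Move per-replica tensors onto the configured output device.
                    return Gather.apply(self.output_device, self.dim, *outputs)

                if elem is None:
                    return None

                if isinstance(elem, Mapping):
                    # Gather dicts key by key, preserving the mapping type.
                    if not all(len(elem) == len(d) for d in outputs):
                        raise ValueError('All dicts must have the same number of keys')
                    return elem_type((k, gather_map([d[k] for d in outputs]))
                                     for k in elem)

                if isinstance(elem, Iterable) and not isinstance(elem, str):
                    # Gather lists/tuples element-wise.
                    return elem_type(map(gather_map, zip(*outputs)))

                # Plain Python scalars and other leaves pass through unchanged.
                return outputs

            return gather_map(outputs)

With this shape, `self.gather(outputs)` in `forward` and `__gather_structured_result` behaves like the old `self.gather(outputs, self.output_device)`. The `elif` → `if` change in `gather_map` is behavior-preserving here, since every preceding branch returns.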
