@@ -70,7 +70,7 @@ def forward(self, *inputs, **kwargs):
70
70
if isinstance (outputs [0 ], Result ):
71
71
outputs = self .__gather_structured_result (outputs )
72
72
else :
73
- outputs = self .gather (outputs , self . output_device )
73
+ outputs = self .gather (outputs )
74
74
return outputs
75
75
76
76
def __gather_structured_result (self , outputs ):
@@ -83,7 +83,7 @@ def __gather_structured_result(self, outputs):
83
83
for i , output in enumerate (outputs ):
84
84
del output ['meta' ]
85
85
86
- outputs = self .gather (outputs , self . output_device )
86
+ outputs = self .gather (outputs )
87
87
88
88
# pass minimize to constructor for TrainResult
89
89
if 'minimize' in outputs :
@@ -106,16 +106,16 @@ def gather_map(outputs):
106
106
if isinstance (elem , torch .Tensor ):
107
107
return Gather .apply (self .output_device , self .dim , * outputs )
108
108
109
- elif elem is None :
109
+ if elem is None :
110
110
return None
111
111
112
- elif isinstance (elem , Mapping ):
112
+ if isinstance (elem , Mapping ):
113
113
if not all ((len (elem ) == len (d ) for d in outputs )):
114
114
raise ValueError ('All dicts must have the same number of keys' )
115
115
return elem_type (((k , gather_map ([d [k ] for d in outputs ]))
116
116
for k in elem ))
117
117
118
- elif isinstance (elem , Iterable ) and not isinstance (elem , str ):
118
+ if isinstance (elem , Iterable ) and not isinstance (elem , str ):
119
119
return elem_type (map (gather_map , zip (* outputs )))
120
120
121
121
return outputs
0 commit comments