@@ -155,6 +155,11 @@ def __transfer_batch_to_device(self, batch: Any, device: torch.device):
         return move_data_to_device(batch, device)
 
     def single_gpu_train(self, model):
+        # call setup
+        self.setup('fit')
+        if self.is_function_implemented('setup', model):
+            model.setup('fit')
+
         model.cuda(self.root_gpu)
 
         # CHOOSE OPTIMIZER
@@ -171,6 +176,11 @@ def single_gpu_train(self, model):
         self.run_pretrain_routine(model)
 
     def tpu_train(self, tpu_core_idx, model):
+        # call setup after the ddp process has connected
+        self.setup('fit')
+        if self.is_function_implemented('setup', model):
+            model.setup('fit')
+
         # put model on tpu
         self._device = xm.xla_device(self.tpu_id) if self.tpu_id is not None else xm.xla_device()
         model.to(self._device)
@@ -205,6 +215,10 @@ def tpu_train(self, tpu_core_idx, model):
         self.save_spawn_weights(model)
 
     def dp_train(self, model):
+        # call setup after the ddp process has connected
+        self.setup('fit')
+        if self.is_function_implemented('setup', model):
+            model.setup('fit')
 
         # CHOOSE OPTIMIZER
         # allow for lr schedulers as well
@@ -246,6 +260,11 @@ def dp_train(self, model):
         model.forward = model_autocast_original_forward
 
     def horovod_train(self, model):
+        # call setup after the ddp process has connected
+        self.setup('fit')
+        if self.is_function_implemented('setup', model):
+            model.setup('fit')
+
         if torch.cuda.is_available() and self.on_gpu:
             # Horovod: pin GPU to local rank
             assert self.root_gpu == hvd.local_rank()
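
The hunks above give every single-process backend (single GPU, TPU, DP, Horovod) the same behaviour: the trainer's own setup('fit') runs first, then the model's setup('fit') hook if the user implemented it, all before the model is moved to its device. Below is a minimal sketch of a LightningModule that relies on this hook; the module name, layer sizes, and random dataset are illustrative assumptions, not part of the diff.

# Illustrative sketch only: builds its training data inside setup('fit'),
# which the trainer hunks above now call on every backend before device placement.
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class ExampleModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(16, 2)

    def setup(self, stage):
        # runs once per process, before .cuda()/.to(device) in the hunks above
        if stage == 'fit':
            data = torch.randn(64, 16)
            target = torch.randint(0, 2, (64,))
            self.train_set = TensorDataset(data, target)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.cross_entropy(self(x), y)
        return {'loss': loss}

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=8)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)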