diff --git a/pytorch_lightning/trainer/dp_mixin.py b/pytorch_lightning/trainer/dp_mixin.py
index 684ff15c6989b..82ecdc1658905 100644
--- a/pytorch_lightning/trainer/dp_mixin.py
+++ b/pytorch_lightning/trainer/dp_mixin.py
@@ -63,12 +63,12 @@ def transfer_batch_to_gpu(self, batch, gpu_id):
         return batch

     def single_gpu_train(self, model):
+        model.cuda(self.root_gpu)
+
         # CHOOSE OPTIMIZER
         # allow for lr schedulers as well
         self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())

-        model.cuda(self.root_gpu)
-
         if self.use_amp:
             # An example
             model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
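
For context on the reordering: the torch.optim documentation recommends moving a model to the GPU before constructing its optimizers, so the optimizer references the on-device parameters rather than their pre-`.cuda()` counterparts. A minimal standalone sketch of that ordering in plain PyTorch (not the Lightning code itself; the `nn.Linear` model and SGD optimizer here are just illustrative):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)

# Move the model to the target GPU *before* constructing the optimizer,
# mirroring the ordering this patch enforces in single_gpu_train.
if torch.cuda.is_available():
    model = model.cuda(0)

# The optimizer now holds references to the on-device parameters.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
```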