
Commit 7b60d49

fixed native amp + ddp (#1788)
1 parent 1df0d2d

1 file changed: pytorch_lightning/trainer/trainer.py (+5 -2)
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -510,8 +510,7 @@ def __init__(
         self.autocast_original_forward = None
         self.use_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
         self.precision = precision
-        if self.use_native_amp and self.precision == 16:
-            self.scaler = torch.cuda.amp.GradScaler()
+        self.scaler = None
 
         # TODO: remove for v0.8.0
         self.amp_level = amp_level
@@ -858,6 +857,10 @@ def run_pretrain_routine(self, model: LightningModule):
         # set local properties on the model
         self.copy_trainer_model_properties(ref_model)
 
+        # init amp. Must be done here instead of __init__ to allow ddp to work
+        if self.use_native_amp and self.precision == 16:
+            self.scaler = torch.cuda.amp.GradScaler()
+
         # log hyper-parameters
         if self.logger is not None:
             # save exp to get started
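
The comment added in the second hunk states the rationale: the GradScaler has to be constructed in run_pretrain_routine, after DDP has launched its worker processes, rather than in __init__, where it would be created once in the parent process. Below is a minimal sketch of that pattern, assuming the PyTorch >= 1.6 native AMP API (torch.cuda.amp.GradScaler / autocast); MiniTrainer and its methods are illustrative stand-ins, not Lightning's actual internals.

# A minimal sketch of the pattern this commit moves to (not Lightning's
# actual code): defer GradScaler creation until after distributed setup,
# then drive a native-AMP step with it. Assumes PyTorch >= 1.6.
import torch

class MiniTrainer:  # illustrative stand-in for pytorch_lightning.Trainer
    def __init__(self, precision=16):
        self.use_native_amp = hasattr(torch.cuda, "amp") and \
            hasattr(torch.cuda.amp, "autocast")
        self.precision = precision
        self.scaler = None  # deferred, mirroring the diff above

    def run_pretrain_routine(self):
        # By the time this runs, DDP has already launched its worker
        # processes, so each process builds its own scaler instead of
        # inheriting one created in the parent process by __init__.
        if self.use_native_amp and self.precision == 16:
            self.scaler = torch.cuda.amp.GradScaler()

    def training_step(self, model, batch, optimizer):
        # Standard native-AMP step: autocast the forward pass, scale the
        # loss before backward, then step/update through the scaler.
        with torch.cuda.amp.autocast():
            loss = model(batch).mean()
        self.scaler.scale(loss).backward()
        self.scaler.step(optimizer)
        self.scaler.update()
        optimizer.zero_grad()

The design point is lazy construction of CUDA-backed state: anything built in __init__ exists before DDP spawns or forks its workers, so per-process objects like the scaler belong in the per-process entry point instead.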
