File tree: 1 file changed, +5 −2 lines changed — pytorch_lightning/trainer

```diff
@@ -510,8 +510,7 @@ def __init__(
         self.autocast_original_forward = None
         self.use_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
         self.precision = precision
-        if self.use_native_amp and self.precision == 16:
-            self.scaler = torch.cuda.amp.GradScaler()
+        self.scaler = None

         # TODO: remove for v0.8.0
         self.amp_level = amp_level
@@ -858,6 +857,10 @@ def run_pretrain_routine(self, model: LightningModule):
         # set local properties on the model
         self.copy_trainer_model_properties(ref_model)

+        # init amp. Must be done here instead of __init__ to allow ddp to work
+        if self.use_native_amp and self.precision == 16:
+            self.scaler = torch.cuda.amp.GradScaler()
+
         # log hyper-parameters
         if self.logger is not None:
             # save exp to get started
```
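The change moves `GradScaler` construction out of `Trainer.__init__` and into `run_pretrain_routine`: the attribute is set to `None` in the constructor and the scaler is built only once pretraining setup runs, which per the new comment is required for DDP to work. Below is a minimal, illustrative sketch of the same lazy-initialization pattern together with the standard native-AMP step a scaler is used for. `MiniTrainer`, `setup`, and `training_step` are hypothetical names for illustration, not Lightning's actual API; only the `torch.cuda.amp` calls are real.

```python
# Sketch of the pattern in this diff: defer GradScaler creation until after
# process setup, then use it in the usual native-AMP training step.
# Assumes a CUDA-capable setup; MiniTrainer is illustrative, not Lightning code.
import torch


class MiniTrainer:
    def __init__(self, precision=16):
        # mirror the diff's __init__: detect native AMP, but do NOT build the scaler yet
        self.use_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
        self.precision = precision
        self.scaler = None  # created later, once each (DDP) process is running

    def setup(self):
        # mirrors run_pretrain_routine: each process builds its own scaler
        if self.use_native_amp and self.precision == 16:
            self.scaler = torch.cuda.amp.GradScaler()

    def training_step(self, model, batch, optimizer):
        optimizer.zero_grad()
        if self.scaler is not None:
            with torch.cuda.amp.autocast():      # run forward pass in mixed precision
                loss = model(batch).mean()
            self.scaler.scale(loss).backward()   # scale the loss to avoid fp16 underflow
            self.scaler.step(optimizer)          # unscales grads, then calls optimizer.step()
            self.scaler.update()                 # adjust the scale factor for the next step
        else:
            loss = model(batch).mean()
            loss.backward()
            optimizer.step()
        return loss
```

The design point, as far as the diff's comment indicates, is that a Trainer constructed in the parent process should not yet hold the scaler; each DDP worker then creates its own `GradScaler` after it has been launched, rather than inheriting one created in `__init__`.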