Skip to content

Commit e350a7d

Browse files
Enable apex O2 + dp (#493)
* remove O2 crash * remove O2 crash * bananas
1 parent 8ea7473 commit e350a7d

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

pytorch_lightning/trainer/dp_mixin.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,15 @@ def dp_train(self, model):
8787
# check for this bug (amp + dp + !01 doesn't work)
8888
# https://github.com/NVIDIA/apex/issues/227
8989
if self.use_dp and self.use_amp:
90-
m = f"""
91-
Amp level {self.amp_level} with DataParallel is not supported.
92-
See this note from NVIDIA for more info: https://github.com/NVIDIA/apex/issues/227.
93-
We recommend you switch to ddp if you want to use amp
94-
"""
95-
raise MisconfigurationException(m)
90+
if self.amp_level == 'O2':
91+
m = f"""
92+
Amp level {self.amp_level} with DataParallel is not supported.
93+
See this note from NVIDIA for more info: https://github.com/NVIDIA/apex/issues/227.
94+
We recommend you switch to ddp if you want to use amp
95+
"""
96+
raise MisconfigurationException(m)
97+
else:
98+
model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
9699

97100
# create list of device ids
98101
device_ids = self.data_parallel_device_ids

0 commit comments

Comments
 (0)