Commit 6034d5e

fix apex gradient clipping (#2829)
1 parent: 5bbcb8d

2 files changed: +15 −4 lines

pytorch_lightning/trainer/training_loop.py

+2 −2
@@ -291,7 +291,7 @@ def transfer_batch_to_tpu(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
-    def clip_gradients(self):
+    def clip_gradients(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
@@ -817,7 +817,7 @@ def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer):
        # ------------------
        if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
            self.scaler.unscale_(optimizer)
-           self.clip_gradients()
+           self.clip_gradients(optimizer)
 
        # ------------------
        # .STEP + ZERO_GRAD
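
For context on the second hunk: with native (torch.cuda.amp) mixed precision, gradients must be unscaled before any norm-based clipping, which is why unscale_ runs right before clip_gradients and why the optimizer is now passed along. A minimal sketch of that pattern outside Lightning, assuming a plain model, optimizer, batch, and loss already exist (the 0.5 clip value is only a placeholder):

    import torch

    scaler = torch.cuda.amp.GradScaler()

    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(batch).mean()           # placeholder forward/loss
    scaler.scale(loss).backward()            # gradients are produced in scaled form
    scaler.unscale_(optimizer)               # restore true gradient magnitudes before clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
    scaler.step(optimizer)                   # skips the step if grads are inf/nan
    scaler.update()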

pytorch_lightning/trainer/training_tricks.py

+13 −2
@@ -27,9 +27,17 @@
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.callbacks import GradientAccumulationScheduler
 from pytorch_lightning.loggers.base import DummyLogger
+from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.memory import is_oom_error, garbage_collection_cuda
 
+try:
+    from apex import amp
+except ImportError:
+    APEX_AVAILABLE = False
+else:
+    APEX_AVAILABLE = True
+
 EPSILON = 1e-6
 EPSILON_FP16 = 1e-5
 
@@ -60,14 +68,17 @@ def restore(self, *args):
     def fit(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
 
-    def clip_gradients(self):
+    def clip_gradients(self, optimizer):
 
         # this code is a modification of torch.nn.utils.clip_grad_norm_
        # with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md
        if self.gradient_clip_val <= 0:
            return
        model = self.get_model()
-       parameters = model.parameters()
+       if self.use_amp and not NATIVE_AMP_AVALAIBLE:
+           parameters = amp.master_params(optimizer)
+       else:
+           parameters = model.parameters()
        max_norm = float(self.gradient_clip_val)
        norm_type = float(2.0)
        if isinstance(parameters, torch.Tensor):
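
Why the change matters: with apex AMP (opt levels that keep fp16 model weights), the optimizer steps on fp32 "master" copies of the parameters, so clipping model.parameters() would operate on the wrong tensors. A minimal sketch of the apex-side pattern this fix relies on, assuming model, optimizer, and loss already exist (the opt_level and the 0.5 clip value are placeholders):

    import torch
    from apex import amp

    model, optimizer = amp.initialize(model, optimizer, opt_level='O2')

    optimizer.zero_grad()
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()

    # clip the fp32 master params that the optimizer actually updates,
    # not model.parameters(), which may hold fp16 copies
    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm=0.5)
    optimizer.step()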
