
Commit 27a3be0

Authored by srush and Borda
TPU gradient clipping. (#963)
* clip
* Update pytorch_lightning/trainer/training_tricks.py
* pull out epsilon
* add fp16 case
* Update pytorch_lightning/trainer/training_tricks.py

Co-authored-by: Jirka Borovec <[email protected]>
1 parent b941845 commit 27a3be0

File tree

1 file changed (+25, -1 lines)


pytorch_lightning/trainer/training_tricks.py (+25, -1)
@@ -2,9 +2,13 @@
 from abc import ABC, abstractmethod
 
 import torch
+import math
 
 from pytorch_lightning.callbacks import GradientAccumulationScheduler
 
+EPSILON = 1e-6
+EPSILON_FP16 = 1e-5
+
 
 class TrainerTrainingTricksMixin(ABC):
 
@@ -19,9 +23,29 @@ def get_model(self):
         pass
 
     def clip_gradients(self):
+        # this code is a modification of torch.nn.utils.clip_grad_norm_
+        # with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md
         if self.gradient_clip_val > 0:
             model = self.get_model()
-            torch.nn.utils.clip_grad_norm_(model.parameters(), self.gradient_clip_val)
+            parameters = model.parameters()
+            max_norm = float(self.gradient_clip_val)
+            norm_type = float(2.0)
+            if isinstance(parameters, torch.Tensor):
+                parameters = [parameters]
+            parameters = list(filter(lambda p: p.grad is not None, parameters))
+            if norm_type == math.inf:
+                total_norm = max(p.grad.data.abs().max() for p in parameters)
+            else:
+                device = parameters[0].device
+                total_norm = torch.zeros([], device=device if parameters else None)
+                for p in parameters:
+                    param_norm = p.grad.data.norm(norm_type) ** norm_type
+                    total_norm.add_(param_norm)
+                total_norm = (total_norm ** (1. / norm_type))
+            eps = EPSILON_FP16 if self.precision == 16 else EPSILON
+            clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps)
+            for p in parameters:
+                p.grad.data.mul_(torch.where(clip_coef < 1, clip_coef, torch.tensor(1., device=device)))
 
     def print_nan_gradients(self):
         model = self.get_model()
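For context: the stock torch.nn.utils.clip_grad_norm_ compares the total gradient norm against max_norm as a Python float, which forces a device-to-host transfer on XLA/TPU; the rewritten clip_gradients keeps the norm, the clip coefficient, and the comparison (torch.where) as tensor operations, following the linked pytorch/xla TROUBLESHOOTING.md guidance. Below is a hedged usage sketch, not part of the commit: gradient_clip_val and precision are real Trainer arguments, while the TPU argument name (num_tpu_cores) is an assumption for the Lightning release around this commit and may differ in later versions.

# Usage sketch (assumptions noted above): enabling gradient_clip_val is all that
# is needed; the training loop calls clip_gradients() after backward and before
# the optimizer step.
import pytorch_lightning as pl

trainer = pl.Trainer(
    gradient_clip_val=0.5,   # clip the global 2-norm of gradients to 0.5; 0 disables clipping
    precision=16,            # fp16 training -> EPSILON_FP16 (1e-5) is used in the denominator
    num_tpu_cores=8,         # assumed TPU flag for this era of Lightning; renamed in later releases
)
# trainer.fit(model)         # model: any LightningModule defined elsewhere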

0 commit comments
