from abc import ABC, abstractmethod

import torch
+import math

from pytorch_lightning.callbacks import GradientAccumulationScheduler

+EPSILON = 1e-6
+EPSILON_FP16 = 1e-5
+

class TrainerTrainingTricksMixin(ABC):

@@ -19,9 +23,29 @@ def get_model(self):
        pass

    def clip_gradients(self):
+        # this code is a modification of torch.nn.utils.clip_grad_norm_
+        # with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md
        if self.gradient_clip_val > 0:
            model = self.get_model()
-            torch.nn.utils.clip_grad_norm_(model.parameters(), self.gradient_clip_val)
+            parameters = model.parameters()
+            max_norm = float(self.gradient_clip_val)
+            norm_type = float(2.0)
+            if isinstance(parameters, torch.Tensor):
+                parameters = [parameters]
+            parameters = list(filter(lambda p: p.grad is not None, parameters))
+            if norm_type == math.inf:
+                total_norm = max(p.grad.data.abs().max() for p in parameters)
+            else:
+                device = parameters[0].device
+                total_norm = torch.zeros([], device=device if parameters else None)
+                for p in parameters:
+                    param_norm = p.grad.data.norm(norm_type) ** norm_type
+                    total_norm.add_(param_norm)
+                total_norm = (total_norm ** (1. / norm_type))
+            eps = EPSILON_FP16 if self.precision == 16 else EPSILON
+            clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps)
+            for p in parameters:
+                p.grad.data.mul_(torch.where(clip_coef < 1, clip_coef, torch.tensor(1., device=device)))

    def print_nan_gradients(self):
        model = self.get_model()
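
For readers who want to try the clipping scheme outside the Trainer, below is a minimal standalone sketch of the same idea: the total gradient norm is accumulated into a 0-dim tensor that stays on the device, and the "clip or not" decision is made with torch.where instead of a Python `if`, which is the pattern the XLA troubleshooting guide linked in the diff recommends for TPU. The helper name clip_grad_norm_on_device, the toy Linear model, and max_norm=0.5 are illustrative assumptions, not part of the commit.

# Standalone sketch; names and values below are assumptions for illustration only.
import math

import torch


def clip_grad_norm_on_device(parameters, max_norm, norm_type=2.0, eps=1e-6):
    # Keep only parameters that actually received gradients.
    parameters = [p for p in parameters if p.grad is not None]
    device = parameters[0].device
    if norm_type == math.inf:
        total_norm = max(p.grad.data.abs().max() for p in parameters)
    else:
        # Accumulate ||grad||^norm_type into a 0-dim tensor without leaving the device.
        total_norm = torch.zeros([], device=device)
        for p in parameters:
            total_norm.add_(p.grad.data.norm(norm_type) ** norm_type)
        total_norm = total_norm ** (1. / norm_type)
    # torch.where keeps the clip decision on-device instead of branching in Python.
    clip_coef = torch.tensor(float(max_norm), device=device) / (total_norm + eps)
    for p in parameters:
        p.grad.data.mul_(torch.where(clip_coef < 1, clip_coef, torch.tensor(1., device=device)))
    return total_norm


model = torch.nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()
clip_grad_norm_on_device(model.parameters(), max_norm=0.5)
# clip_grad_norm_ reports the pre-clip norm of the (already clipped) gradients,
# so the printed value should be at most roughly 0.5.
print(torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5))

The small eps in the denominator mirrors the EPSILON / EPSILON_FP16 constants added by the commit; it only guards against division by a zero norm and has no visible effect on ordinary gradients.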