1 file changed: +17 -20 lines changed
@@ -3,28 +3,25 @@
 """
 from typing import Dict

-from torch import nn
+import torch


-class GradInformation(nn.Module):
+class GradInformation(torch.nn.Module):

     def grad_norm(self, norm_type: float) -> Dict[str, int]:
-        results = {}
-        total_norm = 0
+        norms, all_norms = {}, []
         for name, p in self.named_parameters():
-            if p.requires_grad:
-                try:
-                    param_norm = p.grad.data.norm(norm_type)
-                    total_norm += param_norm ** norm_type
-                    norm = param_norm ** (1 / norm_type)
-
-                    grad = round(norm.data.cpu().numpy().flatten()[0], 3)
-                    results['grad_{}_norm_{}'.format(norm_type, name)] = grad
-                except Exception:
-                    # this param had no grad
-                    pass
-
-        total_norm = total_norm ** (1. / norm_type)
-        grad = round(total_norm.data.cpu().numpy().flatten()[0], 3)
-        results['grad_{}_norm_total'.format(norm_type)] = grad
-        return results
+            if p.grad is None:
+                continue
+
+            param_norm = float(p.grad.data.norm(norm_type))
+            norms[f'grad_{norm_type}_norm_{name}'] = round(param_norm, 3)
+
+            all_norms.append(param_norm)
+
+        total_norm = 0.
+        if all_norms:
+            total_norm = float(torch.tensor(all_norms).norm(norm_type))
+        norms[f'grad_{norm_type}_norm_total'] = round(total_norm, 3)
+
+        return norms
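For context on what the patch changes: the old implementation relied on a blanket try/except to skip parameters without gradients and accumulated `param_norm ** norm_type` by hand; the new one skips parameters whose `.grad` is `None` explicitly and takes the norm of the vector of per-parameter norms, which for any p-norm equals the p-norm over all gradient elements. Below is a minimal, self-contained sketch of the refactored method in use; the `ToyModel` class, the random input, and the loss are illustrative assumptions, not part of the patch (the method body mirrors the added lines, with explanatory comments and a `Dict[str, float]` annotation, since `round(x, 3)` on a float returns a float):

import torch
from typing import Dict


class GradInformation(torch.nn.Module):
    """Mixin that reports per-parameter and total gradient norms."""

    def grad_norm(self, norm_type: float) -> Dict[str, float]:
        norms, all_norms = {}, []
        for name, p in self.named_parameters():
            # Unused/frozen parameters have no .grad after backward();
            # skip them explicitly instead of catching exceptions.
            if p.grad is None:
                continue

            param_norm = float(p.grad.data.norm(norm_type))
            norms[f'grad_{norm_type}_norm_{name}'] = round(param_norm, 3)
            all_norms.append(param_norm)

        # The p-norm of the per-parameter p-norms equals the p-norm over
        # all gradient elements, so reduce the collected values in one call.
        total_norm = 0.
        if all_norms:
            total_norm = float(torch.tensor(all_norms).norm(norm_type))
        norms[f'grad_{norm_type}_norm_total'] = round(total_norm, 3)
        return norms


# Hypothetical usage: a tiny model, one backward pass, then the norms.
class ToyModel(GradInformation):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)


model = ToyModel()
model.fc(torch.randn(8, 4)).sum().backward()
print(model.grad_norm(norm_type=2.0))
# -> {'grad_2.0_norm_fc.weight': ..., 'grad_2.0_norm_fc.bias': ..., 'grad_2.0_norm_total': ...}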