Commit c0332c6

fix grad norm formula
1 parent 9893681 commit c0332c6

File tree: 1 file changed, +17 -20 lines changed


pytorch_lightning/core/grads.py

+17 -20
@@ -3,28 +3,25 @@
 """
 from typing import Dict
 
-from torch import nn
+import torch
 
 
-class GradInformation(nn.Module):
+class GradInformation(torch.nn.Module):
 
     def grad_norm(self, norm_type: float) -> Dict[str, int]:
-        results = {}
-        total_norm = 0
+        norms, all_norms = {}, []
         for name, p in self.named_parameters():
-            if p.requires_grad:
-                try:
-                    param_norm = p.grad.data.norm(norm_type)
-                    total_norm += param_norm ** norm_type
-                    norm = param_norm ** (1 / norm_type)
-
-                    grad = round(norm.data.cpu().numpy().flatten()[0], 3)
-                    results['grad_{}_norm_{}'.format(norm_type, name)] = grad
-                except Exception:
-                    # this param had no grad
-                    pass
-
-        total_norm = total_norm ** (1. / norm_type)
-        grad = round(total_norm.data.cpu().numpy().flatten()[0], 3)
-        results['grad_{}_norm_total'.format(norm_type)] = grad
-        return results
+            if p.grad is None:
+                continue
+
+            param_norm = float(p.grad.data.norm(norm_type))
+            norms[f'grad_{norm_type}_norm_{name}'] = round(param_norm, 3)
+
+            all_norms.append(param_norm)
+
+        total_norm = 0.
+        if all_norms:
+            total_norm = float(torch.tensor(all_norms).norm(norm_type))
+        norms[f'grad_{norm_type}_norm_total'] = round(total_norm, 3)
+
+        return norms
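The fixed formula relies on the identity that, for any p-norm, the norm over all gradient entries taken together equals the p-norm of the vector of per-parameter norms: (sum_i ||g_i||_p^p)^(1/p) = ||(||g_1||_p, ..., ||g_n||_p)||_p. A minimal sketch checking this numerically (variable names and the tolerance are illustrative, not part of the commit):

import torch

norm_type = 2.0
grads = [torch.randn(3, 4), torch.randn(5)]  # stand-ins for p.grad tensors

# per-parameter norms, as the fixed grad_norm() collects them in all_norms
per_param = [float(g.norm(norm_type)) for g in grads]

# total norm via the norm-of-norms, as in the commit ...
total_via_norms = float(torch.tensor(per_param).norm(norm_type))

# ... versus the norm taken over all gradient entries concatenated
total_direct = float(torch.cat([g.flatten() for g in grads]).norm(norm_type))

assert abs(total_via_norms - total_direct) < 1e-4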

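For context, GradInformation is a mixin over torch.nn.Module, so grad_norm() can also be exercised on its own once gradients exist; a hypothetical stand-alone example (TinyModel and its layer are illustrative, not from the repo; the printed values are made up):

import torch

from pytorch_lightning.core.grads import GradInformation

class TinyModel(GradInformation):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

model = TinyModel()
model.layer(torch.randn(8, 4)).sum().backward()

print(model.grad_norm(2.0))
# e.g. {'grad_2.0_norm_layer.weight': 1.234,
#       'grad_2.0_norm_layer.bias': 0.567,
#       'grad_2.0_norm_total': 1.358}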