
Commit a18067b

ivannz authored and Borda committed
Mistake in parameters' grad norm tracking (#2012)
* fix grad norm formula
* grad-norm tracker test
* fixed seed and explicit rtol in grad norm tracking test
* a docstring for grad-norms and forced cast to float of norm_type
* support for inf-norm
* renamed the grad norm test
* docs
* fixed language in docstring
* Apply suggestions from code review

Co-authored-by: Jirka <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 93dbca9 commit a18067b

File tree: 4 files changed, +148 / -26 lines


pytorch_lightning/core/grads.py (+33, -22)
@@ -1,30 +1,41 @@
 """
 Module to describe gradients
 """
-from typing import Dict
+from typing import Dict, Union
 
-from torch import nn
+import torch
 
 
-class GradInformation(nn.Module):
+class GradInformation(torch.nn.Module):
 
-    def grad_norm(self, norm_type: float) -> Dict[str, int]:
-        results = {}
-        total_norm = 0
+    def grad_norm(self, norm_type: Union[float, int, str]) -> Dict[str, float]:
+        """Compute each parameter's gradient's norm and their overall norm.
+
+        The overall norm is computed over all gradients together, as if they
+        were concatenated into a single vector.
+
+        Args:
+            norm_type: The type of the used p-norm, cast to float if necessary.
+                Can be ``'inf'`` for infinity norm.
+
+        Return:
+            norms: The dictionary of p-norms of each parameter's gradient and
+                a special entry for the total p-norm of the gradients viewed
+                as a single vector.
+        """
+        norm_type = float(norm_type)
+
+        norms, all_norms = {}, []
         for name, p in self.named_parameters():
-            if p.requires_grad:
-                try:
-                    param_norm = p.grad.data.norm(norm_type)
-                    total_norm += param_norm ** norm_type
-                    norm = param_norm ** (1 / norm_type)
-
-                    grad = round(norm.data.cpu().numpy().flatten()[0], 3)
-                    results['grad_{}_norm_{}'.format(norm_type, name)] = grad
-                except Exception:
-                    # this param had no grad
-                    pass
-
-        total_norm = total_norm ** (1. / norm_type)
-        grad = round(total_norm.data.cpu().numpy().flatten()[0], 3)
-        results['grad_{}_norm_total'.format(norm_type)] = grad
-        return results
+            if p.grad is None:
+                continue
+
+            param_norm = float(p.grad.data.norm(norm_type))
+            norms[f'grad_{norm_type}_norm_{name}'] = round(param_norm, 3)
+
+            all_norms.append(param_norm)
+
+        total_norm = float(torch.tensor(all_norms).norm(norm_type))
+        norms[f'grad_{norm_type}_norm_total'] = round(total_norm, 3)
+
+        return norms
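
The docstring added above states that the total entry is the p-norm of all gradients viewed as a single concatenated vector: for finite p, (sum_j ||g_j||_p^p)^(1/p) equals ||concat(g)||_p, and for the infinity norm it is a maximum of maxima. Below is a minimal sketch, not part of this commit, that checks the two-stage computation used in `grad_norm` (per-parameter norms, then the norm of those norms) against the norm of the concatenated gradient vector; the throwaway `torch.nn.Linear` model and the chosen norm types are arbitrary illustrations.

import torch

# throwaway model and loss, just to give every parameter a gradient
model = torch.nn.Linear(4, 2)
model(torch.randn(8, 4)).pow(2).mean().backward()

for norm_type in (1.0, 2.0, float('inf')):
    # what grad_norm does: per-parameter norms, then the norm of those norms
    per_param = torch.tensor([float(p.grad.data.norm(norm_type))
                              for p in model.parameters() if p.grad is not None])
    total = per_param.norm(norm_type)

    # the same quantity computed directly on the concatenated gradient vector
    flat = torch.cat([p.grad.detach().flatten()
                      for p in model.parameters() if p.grad is not None])
    assert torch.allclose(total, flat.norm(norm_type))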

pytorch_lightning/trainer/trainer.py (+8, -3)
@@ -100,7 +100,7 @@ def __init__(
         log_gpu_memory: Optional[str] = None,
         progress_bar_refresh_rate: int = 1,
         overfit_pct: float = 0.0,
-        track_grad_norm: int = -1,
+        track_grad_norm: Union[int, float, str] = -1,
         check_val_every_n_epoch: int = 1,
         fast_dev_run: bool = False,
         accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
@@ -204,7 +204,7 @@ def __init__(
 
             overfit_pct: How much of training-, validation-, and test dataset to check.
 
-            track_grad_norm: -1 no tracking. Otherwise tracks that norm
+            track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm.
 
             check_val_every_n_epoch: Check val every n train epochs.
 
@@ -340,7 +340,12 @@ def __init__(
         self.gradient_clip = gradient_clip
 
         self.check_val_every_n_epoch = check_val_every_n_epoch
-        self.track_grad_norm = track_grad_norm
+
+        if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf':
+            raise MisconfigurationException(
+                "track_grad_norm can be an int, a float or 'inf' (infinity norm).")
+        self.track_grad_norm = float(track_grad_norm)
+
         self.on_gpu = True if (gpus and torch.cuda.is_available()) else False
 
         # tpu config
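
With the validation added above, `track_grad_norm` accepts ints, floats and the string 'inf', and anything else fails at `Trainer` construction with a `MisconfigurationException`. A short usage sketch, illustrative only, with no other Trainer arguments:

from pytorch_lightning import Trainer

Trainer(track_grad_norm=2)      # 2-norm, as before
Trainer(track_grad_norm=1.5)    # fractional p-norms are now accepted
Trainer(track_grad_norm='inf')  # infinity norm
Trainer(track_grad_norm=-1)     # default: tracking disabled

# Trainer(track_grad_norm='two')  # would raise MisconfigurationException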

pytorch_lightning/trainer/training_loop.py (+1, -1)
@@ -631,7 +631,7 @@ def optimizer_closure():
 
                 # track gradient norms when requested
                 if batch_idx % self.row_log_interval == 0:
-                    if self.track_grad_norm > 0:
+                    if float(self.track_grad_norm) > 0:
                         model = self.get_model()
                         grad_norm_dic = model.grad_norm(
                             self.track_grad_norm)
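
The explicit cast is safe because `self.track_grad_norm` is now stored as a float by the Trainer change above, possibly `float('inf')`: the `-1` sentinel still fails the `> 0` check, while every valid p-norm, including infinity, passes it. A tiny sanity check in plain Python:

assert not float(-1) > 0                     # -1 keeps tracking disabled
assert float('inf') > 0 and float(1.5) > 0   # any valid p, including inf, enables it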

tests/models/test_grad_norm.py (new file, +106)
@@ -0,0 +1,106 @@
import torch
import pytest
import numpy as np

from pytorch_lightning import Trainer, seed_everything

from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.utilities import rank_zero_only

from tests.base import EvalModelTemplate
from tests.base.utils import reset_seed


class OnlyMetricsListLogger(LightningLoggerBase):
    def __init__(self):
        super().__init__()
        self.metrics = []

    @rank_zero_only
    def log_metrics(self, metrics, step):
        self.metrics.append(metrics)

    @property
    def experiment(self):
        return 'test'

    @rank_zero_only
    def log_hyperparams(self, params):
        pass

    @rank_zero_only
    def finalize(self, status):
        pass

    @property
    def name(self):
        return 'name'

    @property
    def version(self):
        return '1'


class ModelWithManualGradTracker(EvalModelTemplate):
    def __init__(self, norm_type, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.stored_grad_norms, self.norm_type = [], float(norm_type)

    # validation spoils logger's metrics with `val_loss` records
    validation_step = None
    val_dataloader = None

    def training_step(self, batch, batch_idx, optimizer_idx=None):
        # just return a loss, no log or progress bar meta
        x, y = batch
        loss_val = self.loss(y, self(x.flatten(1, -1)))
        return {'loss': loss_val}

    def on_after_backward(self):
        out, norms = {}, []
        prefix = f'grad_{self.norm_type}_norm_'
        for name, p in self.named_parameters():
            if p.grad is None:
                continue

            # `np.linalg.norm` implementation likely uses fp64 intermediates
            flat = p.grad.data.cpu().numpy().ravel()
            norm = np.linalg.norm(flat, self.norm_type)
            norms.append(norm)

            out[prefix + name] = round(norm, 3)

        # handle total norm
        norm = np.linalg.norm(norms, self.norm_type)
        out[prefix + 'total'] = round(norm, 3)
        self.stored_grad_norms.append(out)


@pytest.mark.parametrize("norm_type", [1., 1.25, 1.5, 2, 3, 5, 10, 'inf'])
def test_grad_tracking(tmpdir, norm_type, rtol=5e-3):
    # rtol=5e-3 respects the 3-decimal rounding in `.grad_norms` and above

    reset_seed()

    # use a custom grad tracking module and a list logger
    model = ModelWithManualGradTracker(norm_type)
    logger = OnlyMetricsListLogger()

    trainer = Trainer(
        max_epochs=3,
        logger=logger,
        track_grad_norm=norm_type,
        row_log_interval=1,  # request grad_norms every batch
    )
    result = trainer.fit(model)

    assert result == 1, "Training failed"
    assert len(logger.metrics) == len(model.stored_grad_norms)

    # compare the logged metrics against tracked norms on `.backward`
    for mod, log in zip(model.stored_grad_norms, logger.metrics):
        common = mod.keys() & log.keys()

        log, mod = [log[k] for k in common], [mod[k] for k in common]

        assert np.allclose(log, mod, rtol=rtol)
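
Assuming a standard pytest setup for the repository, the new test can be run on its own, for example restricted to the infinity-norm parametrisation; one possible invocation from Python, equivalent to `pytest tests/models/test_grad_norm.py -k inf` on the command line:

import pytest

# run only the parametrisations whose id contains 'inf'
pytest.main(['tests/models/test_grad_norm.py', '-k', 'inf'])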

0 commit comments
