Mistake in parameters' grad norm tracking #2012

Merged · 9 commits · Jun 2, 2020
Changes from 7 commits
55 changes: 33 additions & 22 deletions pytorch_lightning/core/grads.py
@@ -1,30 +1,41 @@
"""
Module to describe gradients
"""
from typing import Dict
from typing import Dict, Union

from torch import nn
import torch


class GradInformation(nn.Module):
class GradInformation(torch.nn.Module):

def grad_norm(self, norm_type: float) -> Dict[str, int]:
results = {}
total_norm = 0
def grad_norm(self, norm_type: Union[float, int, str]) -> Dict[str, float]:
r"""Compute individual parameter's gradient norms and the overall norm.

The overall norm is computed over all gradients together, as if they
were concatenated into a single vector.

Args:
norm_type: The type of the used p-norm, cast to float if necessary.
Can be ``'inf'`` for infinity norm.

Return:
norms: The dictionary of p-norms each individual gradient and the a
special entry for the total p-norm of the parameters' gradients
viewed as a single vector.
"""
norm_type = float(norm_type)

norms, all_norms = {}, []
for name, p in self.named_parameters():
if p.requires_grad:
try:
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm ** norm_type
norm = param_norm ** (1 / norm_type)

grad = round(norm.data.cpu().numpy().flatten()[0], 3)
results['grad_{}_norm_{}'.format(norm_type, name)] = grad
except Exception:
# this param had no grad
pass

total_norm = total_norm ** (1. / norm_type)
grad = round(total_norm.data.cpu().numpy().flatten()[0], 3)
results['grad_{}_norm_total'.format(norm_type)] = grad
return results
if p.grad is None:
continue

param_norm = float(p.grad.data.norm(norm_type))
norms[f'grad_{norm_type}_norm_{name}'] = round(param_norm, 3)

all_norms.append(param_norm)

total_norm = float(torch.tensor(all_norms).norm(norm_type))
norms[f'grad_{norm_type}_norm_total'] = round(total_norm, 3)

return norms
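
The fix relies on the identity stated in the new docstring: for any p-norm, the norm of all gradients viewed as one concatenated vector equals the p-norm of the vector of per-parameter norms, which is exactly what `torch.tensor(all_norms).norm(norm_type)` computes. A minimal sketch of that identity, not part of the diff, using stand-in tensors for the gradients:

```python
import torch

# Stand-ins for per-parameter gradients (assumed shapes, purely illustrative).
grads = [torch.randn(3, 4), torch.randn(5)]
p = 2.0

# What the new grad_norm accumulates: one float norm per parameter...
all_norms = [float(g.norm(p)) for g in grads]
# ...and the total norm taken over those per-parameter norms.
total_from_parts = float(torch.tensor(all_norms).norm(p))

# The same quantity computed directly on the concatenated gradient vector.
total_from_concat = float(torch.cat([g.flatten() for g in grads]).norm(p))

assert abs(total_from_parts - total_from_concat) < 1e-5
```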
12 changes: 9 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -101,7 +101,7 @@ def __init__(
             log_gpu_memory: Optional[str] = None,
             progress_bar_refresh_rate: int = 1,
             overfit_pct: float = 0.0,
-            track_grad_norm: int = -1,
+            track_grad_norm: Union[int, float, str] = -1,
             check_val_every_n_epoch: int = 1,
             fast_dev_run: bool = False,
             accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
@@ -205,7 +205,7 @@ def __init__(

             overfit_pct: How much of training-, validation-, and test dataset to check.
 
-            track_grad_norm: -1 no tracking. Otherwise tracks that norm
+            track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' for infinity norm.
 
             check_val_every_n_epoch: Check val every n train epochs.

@@ -341,7 +341,13 @@ def __init__(
         self.gradient_clip = gradient_clip
 
         self.check_val_every_n_epoch = check_val_every_n_epoch
-        self.track_grad_norm = track_grad_norm
+
+        if not isinstance(track_grad_norm, (int, float)) \
+                and track_grad_norm != 'inf':
+            raise MisconfigurationException("track_grad_norm can be an int, a "
+                                            "float or 'inf' (infinity norm).")
+        self.track_grad_norm = float(track_grad_norm)
 
         self.on_gpu = True if (gpus and torch.cuda.is_available()) else False
 
         # tpu config
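
With the check above, `track_grad_norm` now accepts an int, a float, or the string `'inf'`, and is stored as a float; anything else raises a `MisconfigurationException` at construction time. A hedged usage sketch (the fitted model is left out, since any LightningModule would do):

```python
from pytorch_lightning import Trainer

# Log 2-norms of all gradients every `row_log_interval` batches.
trainer = Trainer(max_epochs=1, track_grad_norm=2, row_log_interval=1)

# The infinity norm is now accepted as well.
trainer_inf = Trainer(max_epochs=1, track_grad_norm='inf')

# Anything that is not an int, a float or 'inf' raises MisconfigurationException:
# Trainer(track_grad_norm='two')

# trainer.fit(model)  # `model` would be any LightningModule
```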
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/training_loop.py
@@ -625,7 +625,7 @@ def optimizer_closure():

                 # track gradient norms when requested
                 if batch_idx % self.row_log_interval == 0:
-                    if self.track_grad_norm > 0:
+                    if float(self.track_grad_norm) > 0:
                         model = self.get_model()
                         grad_norm_dic = model.grad_norm(
                             self.track_grad_norm)
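
The cast above keeps the tracking condition well defined now that `track_grad_norm` is stored as a float: `float('inf') > 0` is `True`, so the infinity norm is tracked, while the default `-1.0` still disables tracking. A small sketch of the gating logic, with variable names assumed for illustration:

```python
# Illustrative only: how the training loop decides when to collect gradient norms.
row_log_interval = 1            # log every batch, as in the new test
track_grad_norm = float('inf')  # -1 would disable tracking

for batch_idx in range(3):
    if batch_idx % row_log_interval == 0 and float(track_grad_norm) > 0:
        # the real loop calls model.grad_norm(track_grad_norm) here
        print(f'batch {batch_idx}: collect grad norms')
```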
105 changes: 105 additions & 0 deletions tests/models/test_grad_norm.py
@@ -0,0 +1,105 @@
import torch
import pytest
import numpy as np

from pytorch_lightning import Trainer, seed_everything

from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.utilities import rank_zero_only

from tests.base import EvalModelTemplate


class OnlyMetricsListLogger(LightningLoggerBase):
    def __init__(self):
        super().__init__()
        self.metrics = []

    @rank_zero_only
    def log_metrics(self, metrics, step):
        self.metrics.append(metrics)

    @property
    def experiment(self):
        return 'test'

    @rank_zero_only
    def log_hyperparams(self, params):
        pass

    @rank_zero_only
    def finalize(self, status):
        pass

    @property
    def name(self):
        return 'name'

    @property
    def version(self):
        return '1'


class ModelWithManualGradTracker(EvalModelTemplate):
    def __init__(self, norm_type, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.stored_grad_norms, self.norm_type = [], float(norm_type)

    # validation spoils logger's metrics with `val_loss` records
    validation_step = None
    val_dataloader = None

    def training_step(self, batch, batch_idx, optimizer_idx=None):
        # just return a loss, no log or progress bar meta
        x, y = batch
        loss_val = self.loss(y, self(x.flatten(1, -1)))
        return {'loss': loss_val}

    def on_after_backward(self):
        out, norms = {}, []
        prefix = f'grad_{self.norm_type}_norm_'
        for name, p in self.named_parameters():
            if p.grad is None:
                continue

            # `np.linalg.norm` implementation likely uses fp64 intermediates
            flat = p.grad.data.cpu().numpy().ravel()
            norm = np.linalg.norm(flat, self.norm_type)
            norms.append(norm)

            out[prefix + name] = round(norm, 3)

        # handle total norm
        norm = np.linalg.norm(norms, self.norm_type)
        out[prefix + 'total'] = round(norm, 3)
        self.stored_grad_norms.append(out)


@pytest.mark.parametrize("seed", [479_158_593])  # a vetted random number
@pytest.mark.parametrize("norm_type", [1., 1.25, 1.5, 2, 3, 5, 10, 'inf'])
def test_grad_tracking(tmpdir, norm_type, seed, rtol=5e-3):
    # rtol=5e-3 respects the 3 decimals rounding in `.grad_norms` and above

    seed_everything(seed)

    # use a custom grad tracking module and a list logger
    model = ModelWithManualGradTracker(norm_type)
    logger = OnlyMetricsListLogger()

    result = Trainer(
        max_epochs=3,
        logger=logger,
        track_grad_norm=norm_type,
        row_log_interval=1,  # request grad_norms every batch
    ).fit(model)

    assert result == 1, "Training failed"
    assert len(logger.metrics) == len(model.stored_grad_norms)

    # compare the logged metrics against tracked norms on `.backward`
    for mod, log in zip(model.stored_grad_norms, logger.metrics):
        common = mod.keys() & log.keys()

        log, mod = [log[k] for k in common], [mod[k] for k in common]

        assert np.allclose(log, mod, rtol=rtol)
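
The comparison in `test_grad_tracking` assumes that `np.linalg.norm` and PyTorch's `Tensor.norm` agree for every parametrized `norm_type`, including `'inf'`, up to the 3-decimal rounding applied in `grad_norm`. A quick sketch of that assumption, not part of the test file:

```python
import numpy as np
import torch

g = torch.randn(10)
for p in (1.0, 1.25, 2.0, float('inf')):
    np_norm = np.linalg.norm(g.numpy(), p)      # numpy vector p-norm
    torch_norm = float(g.norm(p))               # what grad_norm would log (before rounding)
    assert np.isclose(np_norm, torch_norm, rtol=5e-3), (p, np_norm, torch_norm)
```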