From d8ea3c68fa4a4d676eae44bfe6c4404080d64e53 Mon Sep 17 00:00:00 2001
From: ivannz <ivannz@yandex.ru>
Date: Sat, 30 May 2020 12:59:01 +0300
Subject: [PATCH 1/9] fix grad norm formula

---
 pytorch_lightning/core/grads.py | 35 ++++++++++++++-------------------
 1 file changed, 15 insertions(+), 20 deletions(-)

diff --git a/pytorch_lightning/core/grads.py b/pytorch_lightning/core/grads.py
index b5d2d5616a60f..e0ad46e06f213 100644
--- a/pytorch_lightning/core/grads.py
+++ b/pytorch_lightning/core/grads.py
@@ -3,28 +3,23 @@
 """
 from typing import Dict
 
-from torch import nn
+import torch
 
 
-class GradInformation(nn.Module):
+class GradInformation(torch.nn.Module):
 
     def grad_norm(self, norm_type: float) -> Dict[str, int]:
-        results = {}
-        total_norm = 0
+        norms, all_norms = {}, []
         for name, p in self.named_parameters():
-            if p.requires_grad:
-                try:
-                    param_norm = p.grad.data.norm(norm_type)
-                    total_norm += param_norm ** norm_type
-                    norm = param_norm ** (1 / norm_type)
-
-                    grad = round(norm.data.cpu().numpy().flatten()[0], 3)
-                    results['grad_{}_norm_{}'.format(norm_type, name)] = grad
-                except Exception:
-                    # this param had no grad
-                    pass
-
-        total_norm = total_norm ** (1. / norm_type)
-        grad = round(total_norm.data.cpu().numpy().flatten()[0], 3)
-        results['grad_{}_norm_total'.format(norm_type)] = grad
-        return results
+            if p.grad is None:
+                continue
+
+            param_norm = float(p.grad.data.norm(norm_type))
+            norms[f'grad_{norm_type}_norm_{name}'] = round(param_norm, 3)
+
+            all_norms.append(param_norm)
+
+        total_norm = float(torch.tensor(all_norms).norm(norm_type))
+        norms[f'grad_{norm_type}_norm_total'] = round(total_norm, 3)
+
+        return norms

From e684666abc91abb3c85fffde58fdcd780b7c0e04 Mon Sep 17 00:00:00 2001
From: ivannz <ivannz@yandex.ru>
Date: Sun, 31 May 2020 19:29:29 +0300
Subject: [PATCH 2/9] grad-norm tracker test

---
 tests/models/test_grad_norm.py | 102 +++++++++++++++++++++++++++++++++
 1 file changed, 102 insertions(+)
 create mode 100644 tests/models/test_grad_norm.py

diff --git a/tests/models/test_grad_norm.py b/tests/models/test_grad_norm.py
new file mode 100644
index 0000000000000..f3f25f58e8caf
--- /dev/null
+++ b/tests/models/test_grad_norm.py
@@ -0,0 +1,102 @@
+import torch
+import pytest
+import numpy as np
+
+from pytorch_lightning import Trainer
+
+from pytorch_lightning.loggers import LightningLoggerBase
+from pytorch_lightning.utilities import rank_zero_only
+
+from tests.base import EvalModelTemplate
+
+
+class OnlyMetricsListLogger(LightningLoggerBase):
+    def __init__(self):
+        super().__init__()
+        self.metrics = []
+
+    @rank_zero_only
+    def log_metrics(self, metrics, step):
+        self.metrics.append(metrics)
+
+    @property
+    def experiment(self):
+        return 'test'
+
+    @rank_zero_only
+    def log_hyperparams(self, params):
+        pass
+
+    @rank_zero_only
+    def finalize(self, status):
+        pass
+
+    @property
+    def name(self):
+        return 'name'
+
+    @property
+    def version(self):
+        return '1'
+
+
+class ModelWithManualGradTracker(EvalModelTemplate):
+    def __init__(self, norm_type, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.stored_grad_norms, self.norm_type = [], norm_type
+
+    # validation spoils logger's metrics with `val_loss` records
+    validation_step = None
+    val_dataloader = None
+
+    def training_step(self, batch, batch_idx, optimizer_idx=None):
+        # just return a loss, no log or progress bar meta
+        x, y = batch
+        loss_val = self.loss(y, self(x.flatten(1, -1)))
+        return {'loss': loss_val}
+
+    def on_after_backward(self):
+        out, norms = {}, []
+        prefix = f'grad_{self.norm_type}_norm_'
+        for name, p in self.named_parameters():
+            if p.grad is None:
+                continue
+
+            # `np.linalg.norm` implementation likely uses fp64 intermediates
+            flat = p.grad.data.cpu().numpy().ravel()
+            norm = np.linalg.norm(flat, self.norm_type)
+            norms.append(norm)
+
+            out[prefix + name] = round(norm, 3)
+
+        # handle total norm
+        norm = np.linalg.norm(norms, self.norm_type)
+        out[prefix + 'total'] = round(norm, 3)
+        self.stored_grad_norms.append(out)
+
+
+# @pytest.mark.skip(reason="might fail for small `norm_type` due to round-off")
+@pytest.mark.parametrize("norm_type", [1., 1.25, 1.5, 2, 3, 5, 10, float('inf')])
+def test_custom_logger(tmpdir, norm_type):
+    # use a custom grad tracking module and a list logger
+    model = ModelWithManualGradTracker(norm_type)
+    logger = OnlyMetricsListLogger()
+
+    result = Trainer(
+        max_epochs=3,
+        logger=logger,
+        track_grad_norm=norm_type,
+        row_log_interval=1,  # request grad_norms every batch
+    ).fit(model)
+
+    assert result == 1, "Training failed"
+    assert logger.metrics
+
+    # compare the logged metrics gainst tracked by the model on `.backward`
+    for mod, log in zip(model.stored_grad_norms, logger.metrics):
+        common = mod.keys() & log.keys()
+
+        log, mod = [log[k] for k in common], [mod[k] for k in common]
+
+        # 1e-3 respects the round-off in grad_norms and above
+        assert np.allclose(log, mod, rtol=5e-3)

From 555c7b55e40d43692eb36729a8e6718919a2f357 Mon Sep 17 00:00:00 2001
From: ivannz <ivannz@yandex.ru>
Date: Mon, 1 Jun 2020 13:15:15 +0300
Subject: [PATCH 3/9] fixed seed and explicit rtol in grad norm tracking test

---
 tests/models/test_grad_norm.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/tests/models/test_grad_norm.py b/tests/models/test_grad_norm.py
index f3f25f58e8caf..39116af2d6ef5 100644
--- a/tests/models/test_grad_norm.py
+++ b/tests/models/test_grad_norm.py
@@ -2,7 +2,7 @@
 import pytest
 import numpy as np
 
-from pytorch_lightning import Trainer
+from pytorch_lightning import Trainer, seed_everything
 
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.utilities import rank_zero_only
@@ -75,9 +75,13 @@ def on_after_backward(self):
         self.stored_grad_norms.append(out)
 
 
-# @pytest.mark.skip(reason="might fail for small `norm_type` due to round-off")
+@pytest.mark.parametrize("seed", [479_158_593])  # a vetted random number
 @pytest.mark.parametrize("norm_type", [1., 1.25, 1.5, 2, 3, 5, 10, float('inf')])
-def test_custom_logger(tmpdir, norm_type):
+def test_custom_logger(tmpdir, norm_type, seed, rtol=5e-3):
+    # rtol=5e-3 respects the 3 decmials rounding in `.grad_norms` and above
+
+    seed_everything(seed)
+
     # use a custom grad tracking module and a list logger
     model = ModelWithManualGradTracker(norm_type)
     logger = OnlyMetricsListLogger()
@@ -90,13 +94,12 @@ def test_custom_logger(tmpdir, norm_type):
     ).fit(model)
 
     assert result == 1, "Training failed"
-    assert logger.metrics
+    assert len(logger.metrics) == len(model.stored_grad_norms)
 
-    # compare the logged metrics gainst tracked by the model on `.backward`
+    # compare the logged metrics against tracked norms on `.backward`
     for mod, log in zip(model.stored_grad_norms, logger.metrics):
         common = mod.keys() & log.keys()
 
         log, mod = [log[k] for k in common], [mod[k] for k in common]
 
-        # 1e-3 respects the round-off in grad_norms and above
-        assert np.allclose(log, mod, rtol=5e-3)
+        assert np.allclose(log, mod, rtol=rtol)

From 3c7068630e3839d01a2843a459dfb3f06a070791 Mon Sep 17 00:00:00 2001
From: ivannz <ivannz@yandex.ru>
Date: Mon, 1 Jun 2020 13:25:16 +0300
Subject: [PATCH 4/9] a docstring for grad-norms and forced cast to float of
 norm_type

---
 pytorch_lightning/core/grads.py | 26 ++++++++++++++++++++++++--
 1 file changed, 24 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/core/grads.py b/pytorch_lightning/core/grads.py
index e0ad46e06f213..d6bce643a7ff3 100644
--- a/pytorch_lightning/core/grads.py
+++ b/pytorch_lightning/core/grads.py
@@ -1,14 +1,36 @@
 """
 Module to describe gradients
 """
-from typing import Dict
+from typing import Dict, Union
 
 import torch
 
 
 class GradInformation(torch.nn.Module):
 
-    def grad_norm(self, norm_type: float) -> Dict[str, int]:
+    def grad_norm(self, norm_type: Union[float, int, str]) -> Dict[str, float]:
+        r"""Compute individual parameter's gradient norms and the overall norm.
+
+        Parameters
+        ----------
+        norm_type: float, int, str:
+            The type of the used p-norm, cast to float if necessary. Can be
+            ``'inf'`` for infinity norm.
+
+        Returns
+        -------
+        norms: dict
+            The dictionary of p-norms each individual gradient and the a
+            special entry for the total p-norm of the parameters' gradients
+            viewed as a single vector.
+
+        Details
+        -------
+        The overall norm is computed over all gradients together, as if they
+        were concatenated into a single vector.
+        """
+        norm_type = float(norm_type)
+
         norms, all_norms = {}, []
         for name, p in self.named_parameters():
             if p.grad is None:

From 5bebf23497cc2ba092bb3c97263cab4871366a27 Mon Sep 17 00:00:00 2001
From: ivannz <ivannz@yandex.ru>
Date: Mon, 1 Jun 2020 13:45:56 +0300
Subject: [PATCH 5/9] support for inf-norm

---
 pytorch_lightning/trainer/trainer.py       | 12 +++++++++---
 pytorch_lightning/trainer/training_loop.py |  2 +-
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 6239e66cd541f..32ef7f49d7051 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -101,7 +101,7 @@ def __init__(
             log_gpu_memory: Optional[str] = None,
             progress_bar_refresh_rate: int = 1,
             overfit_pct: float = 0.0,
-            track_grad_norm: int = -1,
+            track_grad_norm: Union[int, float, str] = -1,
             check_val_every_n_epoch: int = 1,
             fast_dev_run: bool = False,
             accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
@@ -205,7 +205,7 @@ def __init__(
 
             overfit_pct: How much of training-, validation-, and test dataset to check.
 
-            track_grad_norm: -1 no tracking. Otherwise tracks that norm
+            track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' for the infinity norm.
 
             check_val_every_n_epoch: Check val every n train epochs.
 
@@ -341,7 +341,13 @@ def __init__(
             self.gradient_clip = gradient_clip
 
         self.check_val_every_n_epoch = check_val_every_n_epoch
-        self.track_grad_norm = track_grad_norm
+
+        if not isinstance(track_grad_norm, (int, float)) \
+           and track_grad_norm != 'inf':
+            raise MisconfigurationException("track_grad_norm can be an int, a "
+                                            "float or 'inf' (infinity norm).")
+        self.track_grad_norm = float(track_grad_norm)
+
         self.on_gpu = True if (gpus and torch.cuda.is_available()) else False
 
         # tpu config
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index ff3ed0e4fec6a..1a829f37bbf42 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -625,7 +625,7 @@ def optimizer_closure():
 
                     # track gradient norms when requested
                     if batch_idx % self.row_log_interval == 0:
-                        if self.track_grad_norm > 0:
+                        if float(self.track_grad_norm) > 0:
                             model = self.get_model()
                             grad_norm_dic = model.grad_norm(
                                 self.track_grad_norm)

From 96a8dd0f2e840ae6e90c85c8e97b9e930abf7fcc Mon Sep 17 00:00:00 2001
From: ivannz <ivannz@yandex.ru>
Date: Mon, 1 Jun 2020 13:48:37 +0300
Subject: [PATCH 6/9] renamed the grad norm test

---
 tests/models/test_grad_norm.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/models/test_grad_norm.py b/tests/models/test_grad_norm.py
index 39116af2d6ef5..2db5603283609 100644
--- a/tests/models/test_grad_norm.py
+++ b/tests/models/test_grad_norm.py
@@ -43,7 +43,7 @@ def version(self):
 class ModelWithManualGradTracker(EvalModelTemplate):
     def __init__(self, norm_type, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.stored_grad_norms, self.norm_type = [], norm_type
+        self.stored_grad_norms, self.norm_type = [], float(norm_type)
 
     # validation spoils logger's metrics with `val_loss` records
     validation_step = None
@@ -76,8 +76,8 @@ def on_after_backward(self):
 
 
 @pytest.mark.parametrize("seed", [479_158_593])  # a vetted random number
-@pytest.mark.parametrize("norm_type", [1., 1.25, 1.5, 2, 3, 5, 10, float('inf')])
-def test_custom_logger(tmpdir, norm_type, seed, rtol=5e-3):
+@pytest.mark.parametrize("norm_type", [1., 1.25, 1.5, 2, 3, 5, 10, 'inf'])
+def test_grad_tracking(tmpdir, norm_type, seed, rtol=5e-3):
     # rtol=5e-3 respects the 3 decmials rounding in `.grad_norms` and above
 
     seed_everything(seed)

From 1dd2ef359a8bfac2b271a2b3df58bcb2a32497a7 Mon Sep 17 00:00:00 2001
From: Jirka <jirka@pytorchlightning.ai>
Date: Mon, 1 Jun 2020 14:11:09 +0200
Subject: [PATCH 7/9] docs

---
 pytorch_lightning/core/grads.py | 24 +++++++++---------------
 1 file changed, 9 insertions(+), 15 deletions(-)

diff --git a/pytorch_lightning/core/grads.py b/pytorch_lightning/core/grads.py
index d6bce643a7ff3..9c45af8a855e4 100644
--- a/pytorch_lightning/core/grads.py
+++ b/pytorch_lightning/core/grads.py
@@ -11,23 +11,17 @@ class GradInformation(torch.nn.Module):
     def grad_norm(self, norm_type: Union[float, int, str]) -> Dict[str, float]:
         r"""Compute individual parameter's gradient norms and the overall norm.
 
-        Parameters
-        ----------
-        norm_type: float, int, str:
-            The type of the used p-norm, cast to float if necessary. Can be
-            ``'inf'`` for infinity norm.
-
-        Returns
-        -------
-        norms: dict
-            The dictionary of p-norms each individual gradient and the a
-            special entry for the total p-norm of the parameters' gradients
-            viewed as a single vector.
-
-        Details
-        -------
         The overall norm is computed over all gradients together, as if they
         were concatenated into a single vector.
+
+        Args:
+            norm_type: The type of the used p-norm, cast to float if necessary.
+                Can be ``'inf'`` for infinity norm.
+
+        Return:
+            norms: The dictionary of p-norms each individual gradient and the a
+                special entry for the total p-norm of the parameters' gradients
+                viewed as a single vector.
         """
         norm_type = float(norm_type)
 

From bfb074eb4e4a844e96f23a64e56caf6b15c7d587 Mon Sep 17 00:00:00 2001
From: ivannz <ivannz@yandex.ru>
Date: Mon, 1 Jun 2020 15:31:25 +0300
Subject: [PATCH 8/9] fixed language in docstring

---
 pytorch_lightning/core/grads.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/core/grads.py b/pytorch_lightning/core/grads.py
index 9c45af8a855e4..cb2215002c7d8 100644
--- a/pytorch_lightning/core/grads.py
+++ b/pytorch_lightning/core/grads.py
@@ -9,7 +9,7 @@
 class GradInformation(torch.nn.Module):
 
     def grad_norm(self, norm_type: Union[float, int, str]) -> Dict[str, float]:
-        r"""Compute individual parameter's gradient norms and the overall norm.
+        """Compute each parameter's gradient's norm and their overall norm.
 
         The overall norm is computed over all gradients together, as if they
         were concatenated into a single vector.
@@ -19,9 +19,9 @@ def grad_norm(self, norm_type: Union[float, int, str]) -> Dict[str, float]:
                 Can be ``'inf'`` for infinity norm.
 
         Return:
-            norms: The dictionary of p-norms each individual gradient and the a
-                special entry for the total p-norm of the parameters' gradients
-                viewed as a single vector.
+            norms: The dictionary of p-norms of each parameter's gradient and
+                a special entry for the total p-norm of the gradients viewed
+                as a single vector.
         """
         norm_type = float(norm_type)
 

From ea374a3fa0e175cdcb0084efe60b2d6800f57037 Mon Sep 17 00:00:00 2001
From: Jirka Borovec <Borda@users.noreply.github.com>
Date: Mon, 1 Jun 2020 14:42:01 +0200
Subject: [PATCH 9/9] Apply suggestions from code review

---
 pytorch_lightning/trainer/trainer.py |  7 +++----
 tests/models/test_grad_norm.py       | 11 ++++++-----
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 32ef7f49d7051..4a42c3d7e1fa6 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -342,10 +342,9 @@ def __init__(
 
         self.check_val_every_n_epoch = check_val_every_n_epoch
 
-        if not isinstance(track_grad_norm, (int, float)) \
-           and track_grad_norm != 'inf':
-            raise MisconfigurationException("track_grad_norm can be an int, a "
-                                            "float or 'inf' (infinity norm).")
+        if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf':
+            raise MisconfigurationException(
+                "track_grad_norm can be an int, a float or 'inf' (infinity norm).")
         self.track_grad_norm = float(track_grad_norm)
 
         self.on_gpu = True if (gpus and torch.cuda.is_available()) else False
diff --git a/tests/models/test_grad_norm.py b/tests/models/test_grad_norm.py
index 2db5603283609..9140eef1624b0 100644
--- a/tests/models/test_grad_norm.py
+++ b/tests/models/test_grad_norm.py
@@ -8,6 +8,7 @@
 from pytorch_lightning.utilities import rank_zero_only
 
 from tests.base import EvalModelTemplate
+from tests.base.utils import reset_seed
 
 
 class OnlyMetricsListLogger(LightningLoggerBase):
@@ -75,23 +76,23 @@ def on_after_backward(self):
         self.stored_grad_norms.append(out)
 
 
-@pytest.mark.parametrize("seed", [479_158_593])  # a vetted random number
 @pytest.mark.parametrize("norm_type", [1., 1.25, 1.5, 2, 3, 5, 10, 'inf'])
-def test_grad_tracking(tmpdir, norm_type, seed, rtol=5e-3):
+def test_grad_tracking(tmpdir, norm_type, rtol=5e-3):
     # rtol=5e-3 respects the 3 decmials rounding in `.grad_norms` and above
 
-    seed_everything(seed)
+    reset_seed()
 
     # use a custom grad tracking module and a list logger
     model = ModelWithManualGradTracker(norm_type)
     logger = OnlyMetricsListLogger()
 
-    result = Trainer(
+    trainer = Trainer(
         max_epochs=3,
         logger=logger,
         track_grad_norm=norm_type,
         row_log_interval=1,  # request grad_norms every batch
-    ).fit(model)
+    )
+    result = trainer.fit(model)
 
     assert result == 1, "Training failed"
     assert len(logger.metrics) == len(model.stored_grad_norms)