
Commit d9b5962

Adrian Wälchli authored and Borda committed
nan detection and intervention (Lightning-AI#1097)
* check for nan values
* test nan detection on loss
* sys.exit
* whitespace
* detect nan and inf values in loss and params
* update
* added documentation
* moved detect nan to training loop, remove flag for print
* blank line
* test
* rename
* deprecate print_nan_grads
* deprecated print_nan_grads
* remove unused imports
* update changelog
* fix line too long
* correct deprecated version
* raise exception instead of sysexit
* raise exception instead of sysexit
* Update pytorch_lightning/trainer/training_tricks.py
* Update pytorch_lightning/trainer/training_tricks.py
* fix test

Co-authored-by: Jirka Borovec <[email protected]>
1 parent ac588c2 commit d9b5962

5 files changed, +110 -16 lines changed


CHANGELOG.md

+2-2
@@ -13,15 +13,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added type hints to `pytorch_lightning.core` ([#946](https://github.com/PyTorchLightning/pytorch-lightning/pull/946))
 - Added support for IterableDataset in validation and testing ([#1104](https://github.com/PyTorchLightning/pytorch-lightning/pull/1104))
 - Added support for non-primitive types in hparams for TensorboardLogger ([#1130](https://github.com/PyTorchLightning/pytorch-lightning/pull/1130))
-
+- Added a check that stops the training when loss or weights contain NaN or inf values. ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097))
 
 ### Changed
 
 -
 
 ### Deprecated
 
--
+- Deprecated Trainer argument `print_nan_grads` ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097))
 
 ### Removed
 
pytorch_lightning/trainer/trainer.py

+12-3
@@ -109,7 +109,7 @@ def __init__(
             distributed_backend: Optional[str] = None,
             use_amp=False, # backward compatible, todo: remove in v0.9.0
             precision: int = 32,
-            print_nan_grads: bool = False,
+            print_nan_grads: bool = False, # backward compatible, todo: remove in v0.9.0
             weights_summary: str = 'full',
             weights_save_path: Optional[str] = None,
             amp_level: str = 'O1',
@@ -208,7 +208,10 @@ def __init__(
 
         precision: Full precision (32), half precision (16).
 
-        print_nan_grads: Prints gradients with nan values
+        print_nan_grads:
+            .. warning:: .. deprecated:: 0.7.2
+                Has no effect. When detected, NaN grads will be printed automatically.
+                Will remove 0.9.0.
 
         weights_summary: Prints a summary of the weights when training begins.
 
@@ -296,7 +299,13 @@ def __init__(
                           "`num_sanity_val_steps` since v0.5.0"
                           " and this method will be removed in v0.8.0", DeprecationWarning)
             self.nb_sanity_val_steps = nb_sanity_val_steps
-        self.print_nan_grads = print_nan_grads
+
+        # Backward compatibility, TODO: remove in v0.9.0
+        if print_nan_grads:
+            warnings.warn("Argument `print_nan_grads` has no effect and will be removed in v0.9.0."
+                          " NaN grads will be printed automatically when detected.",
+                          DeprecationWarning)
+
         self.truncated_bptt_steps = truncated_bptt_steps
         self.resume_from_checkpoint = resume_from_checkpoint
         self.shown_warnings = set()
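
As a quick illustration of the backward-compatibility path in this diff, the sketch below (plain Python, assuming a 0.7.x-era install where `Trainer()` can be constructed with defaults) shows the flag now only emitting a `DeprecationWarning`:

import warnings

from pytorch_lightning import Trainer

# Sketch only: the deprecated flag no longer configures anything, it just warns.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    Trainer(print_nan_grads=True)  # deprecated, has no effect

assert any(issubclass(w.category, DeprecationWarning) for w in caught)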

pytorch_lightning/trainer/training_loop.py

+15-6
@@ -119,6 +119,17 @@ def training_step(self, batch, batch_idx):
         trainer = Trainer(truncated_bptt_steps=2)
 
 
+NaN detection and intervention
+------------------------------
+In every forward pass in training, Lightning will check that
+
+1. the loss you return in `training_step` is finite (not NaN and not +/-inf)
+2. the model parameters have finite values.
+
+Lightning will terminate the training loop with an error message if NaN or infinite
+values are detected. If this happens, you should investigate numerically unstable operations
+in your model.
+
 """
 
 import copy
@@ -187,7 +198,6 @@ class TrainerTrainLoopMixin(ABC):
     optimizers: ...
     accumulate_grad_batches: int
     use_amp: bool
-    print_nan_grads: ...
     track_grad_norm: ...
     model: LightningModule
     running_loss: ...
@@ -200,7 +210,7 @@
     reload_dataloaders_every_epoch: bool
     progress_bar_refresh_rate: ...
     max_steps: int
-    max_steps: int
+    min_steps: int
     total_batch_idx: int
     checkpoint_callback: ...
 
@@ -239,7 +249,7 @@ def clip_gradients(self):
         """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
-    def print_nan_gradients(self):
+    def detect_nan_tensors(self, *args):
        """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
@@ -556,9 +566,8 @@ def optimizer_closure():
                 # calculate loss
                 loss = optimizer_closure()
 
-                # nan grads
-                if self.print_nan_grads:
-                    self.print_nan_gradients()
+                # check if loss or model weights are nan
+                self.detect_nan_tensors(loss)
 
                 # track total loss for logging (avoid mem leaks)
                 self.batch_loss_value += loss.item()
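
The docstring added above boils down to a `torch.isfinite` check on the returned loss and on every parameter tensor. A minimal plain-PyTorch illustration (not the Lightning API) of the kind of numerically unstable operation the check is meant to surface:

import torch

# log of a zero probability yields -inf, so the resulting loss is not finite
probs = torch.tensor([0.5, 0.0])
loss = -torch.log(probs).sum()

print(loss)                        # tensor(inf)
print(torch.isfinite(loss).all())  # tensor(False) -> training would be stopped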

pytorch_lightning/trainer/training_tricks.py

+21-1
@@ -1,7 +1,9 @@
 import math
+import sys
 from abc import ABC, abstractmethod
 
 import torch
+from torch import Tensor
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks import GradientAccumulationScheduler
@@ -15,6 +17,7 @@ class TrainerTrainingTricksMixin(ABC):
     # this is just a summary on variables used in this abstract class,
     # the proper values/initialisation should be done in child class
     gradient_clip_val: ...
+    precision: ...
 
     @abstractmethod
     def get_model(self):
@@ -45,12 +48,29 @@ def clip_gradients(self):
             for p in parameters:
                 p.grad.data.mul_(torch.where(clip_coef < 1, clip_coef, torch.tensor(1., device=device)))
 
-    def print_nan_gradients(self):
+    def print_nan_gradients(self) -> None:
         model = self.get_model()
         for param in model.parameters():
             if (param.grad is not None) and torch.isnan(param.grad.float()).any():
                 log.info(param, param.grad)
 
+    def detect_nan_tensors(self, loss: Tensor) -> None:
+        model = self.get_model()
+
+        # check if loss is nan
+        if not torch.isfinite(loss).all():
+            raise ValueError(
+                'The loss returned in `training_step` is nan or inf.'
+            )
+        # check if a network weight is nan
+        for name, param in model.named_parameters():
+            if not torch.isfinite(param).all():
+                self.print_nan_gradients()
+                raise ValueError(
+                    f'Detected nan and/or inf values in `{name}`.'
+                    ' Check your forward pass for numerically unstable operations.'
+                )
+
     def configure_accumulated_gradients(self, accumulate_grad_batches):
         if isinstance(accumulate_grad_batches, dict):
             self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
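
To see the new `detect_nan_tensors` behaviour outside the mixin, here is a hedged standalone sketch of equivalent logic; the helper name `check_finite` is illustrative and not part of this commit:

import torch
from torch import Tensor


def check_finite(loss: Tensor, model: torch.nn.Module) -> None:
    # mirrors the checks added above: loss first, then every named parameter
    if not torch.isfinite(loss).all():
        raise ValueError('The loss returned in `training_step` is nan or inf.')
    for name, param in model.named_parameters():
        if not torch.isfinite(param).all():
            raise ValueError(f'Detected nan and/or inf values in `{name}`.')


# poison a bias with NaN and observe the error naming the parameter
net = torch.nn.Linear(2, 1)
torch.nn.init.constant_(net.bias, float('nan'))
try:
    check_finite(torch.tensor(1.0), net)
except ValueError as err:
    print(err)  # Detected nan and/or inf values in `bias`.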

tests/test_cpu_models.py

+60-4
@@ -1,5 +1,7 @@
+import math
 import warnings
 
+import pytest
 import torch
 
 import tests.models.utils as tutils
@@ -26,7 +28,6 @@ def test_early_stopping_cpu_model(tmpdir):
         gradient_clip_val=1.0,
         overfit_pct=0.20,
         track_grad_norm=2,
-        print_nan_grads=True,
         show_progress_bar=True,
         logger=tutils.get_test_tube_logger(tmpdir),
         train_percent_check=0.1,
@@ -48,7 +49,6 @@ def test_lbfgs_cpu_model(tmpdir):
     trainer_options = dict(
         default_save_path=tmpdir,
         max_epochs=2,
-        print_nan_grads=True,
         show_progress_bar=False,
         weights_summary='top',
         train_percent_check=1.0,
@@ -68,7 +68,6 @@ def test_default_logger_callbacks_cpu_model(tmpdir):
         max_epochs=1,
         gradient_clip_val=1.0,
         overfit_pct=0.20,
-        print_nan_grads=True,
         show_progress_bar=False,
         train_percent_check=0.01,
         val_percent_check=0.01,
@@ -251,7 +250,6 @@ def test_all_features_cpu_model(tmpdir):
         gradient_clip_val=1.0,
         overfit_pct=0.20,
         track_grad_norm=2,
-        print_nan_grads=True,
         show_progress_bar=False,
         logger=tutils.get_test_tube_logger(tmpdir),
         accumulate_grad_batches=2,
@@ -359,5 +357,63 @@ def test_single_gpu_model(tmpdir):
     tutils.run_model_test(trainer_options, model)
 
 
+def test_nan_loss_detection(tmpdir):
+    test_step = 8
+
+    class InfLossModel(LightTrainDataloader, TestModelBase):
+
+        def training_step(self, batch, batch_idx):
+            output = super().training_step(batch, batch_idx)
+            if batch_idx == test_step:
+                if isinstance(output, dict):
+                    output['loss'] *= torch.tensor(math.inf)  # make loss infinite
+                else:
+                    output /= 0
+            return output
+
+    hparams = tutils.get_hparams()
+    model = InfLossModel(hparams)
+
+    # fit model
+    trainer = Trainer(
+        default_save_path=tmpdir,
+        max_steps=(test_step + 1),
+    )
+
+    with pytest.raises(ValueError, match=r'.*The loss returned in `training_step` is nan or inf.*'):
+        trainer.fit(model)
+        assert trainer.global_step == test_step
+
+    for param in model.parameters():
+        assert torch.isfinite(param).all()
+
+
+def test_nan_params_detection(tmpdir):
+    test_step = 8
+
+    class NanParamModel(LightTrainDataloader, TestModelBase):
+
+        def on_after_backward(self):
+            if self.global_step == test_step:
+                # simulate parameter that became nan
+                torch.nn.init.constant_(self.c_d1.bias, math.nan)
+
+    hparams = tutils.get_hparams()
+
+    model = NanParamModel(hparams)
+    trainer = Trainer(
+        default_save_path=tmpdir,
+        max_steps=(test_step + 1),
+    )
+
+    with pytest.raises(ValueError, match=r'.*Detected nan and/or inf values in `c_d1.bias`.*'):
+        trainer.fit(model)
+        assert trainer.global_step == test_step
+
+    # after aborting the training loop, model still has nan-valued params
+    params = torch.cat([param.view(-1) for param in model.parameters()])
+    assert not torch.isfinite(params).all()
+
+
 # if __name__ == '__main__':
 #     pytest.main([__file__])
