diff --git a/CHANGELOG.md b/CHANGELOG.md
index a6afe75a1e821..57143d63a7b9c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Enable `NeptuneLogger` to work with `distributed_backend=ddp` ([#1753](https://github.com/PyTorchLightning/pytorch-lightning/pull/1753))
 
+- Added option to provide seed to random generators to ensure reproducibility ([#1572](https://github.com/PyTorchLightning/pytorch-lightning/pull/1572))
+
 ### Changed
 
 - Reduction when `batch_size < num_gpus` ([#1609](https://github.com/PyTorchLightning/pytorch-lightning/pull/1609))
diff --git a/benchmarks/test_rnn_parity.py b/benchmarks/test_rnn_parity.py
index b1aa96088a896..b535a9ddb5f65 100644
--- a/benchmarks/test_rnn_parity.py
+++ b/benchmarks/test_rnn_parity.py
@@ -8,7 +8,7 @@
 from torch.utils.data import Dataset, DataLoader
 
 import tests.base.utils as tutils
-from pytorch_lightning import Trainer, LightningModule
+from pytorch_lightning import Trainer, LightningModule, seed_everything
 
 
 class AverageDataset(Dataset):
@@ -68,13 +68,6 @@ def test_pytorch_parity(tmpdir):
     tutils.assert_speed_parity(pl_times, pt_times, num_epochs)
 
 
-def set_seed(seed):
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
-
-
 def vanilla_loop(MODEL, num_runs=10, num_epochs=10):
     """
     Returns an array with the last loss from each epoch for each run
@@ -83,12 +76,13 @@ def vanilla_loop(MODEL, num_runs=10, num_epochs=10):
     errors = []
     times = []
 
+    torch.backends.cudnn.deterministic = True
     for i in range(num_runs):
         time_start = time.perf_counter()
 
         # set seed
         seed = i
-        set_seed(seed)
+        seed_everything(seed)
 
         # init model parts
         model = MODEL()
@@ -134,10 +128,10 @@ def lightning_loop(MODEL, num_runs=10, num_epochs=10):
 
         # set seed
         seed = i
-        set_seed(seed)
+        seed_everything(seed)
 
+        model = MODEL()
         # init model parts
-        model = MODEL()
         trainer = Trainer(
             max_epochs=num_epochs,
             progress_bar_refresh_rate=0,
@@ -146,6 +140,7 @@ def lightning_loop(MODEL, num_runs=10, num_epochs=10):
             early_stop_callback=False,
             checkpoint_callback=False,
             distributed_backend='dp',
+            deterministic=True,
         )
 
         trainer.fit(model)
diff --git a/benchmarks/test_trainer_parity.py b/benchmarks/test_trainer_parity.py
index 4c0e89d107b7d..d97eb07c5c621 100644
--- a/benchmarks/test_trainer_parity.py
+++ b/benchmarks/test_trainer_parity.py
@@ -10,7 +10,7 @@
 from torchvision import transforms
 
 import tests.base.utils as tutils
-from pytorch_lightning import Trainer, LightningModule
+from pytorch_lightning import Trainer, LightningModule, seed_everything
 from tests.base.datasets import TrialMNIST
 
 
@@ -69,13 +69,6 @@ def test_pytorch_parity(tmpdir):
     tutils.assert_speed_parity(pl_times[1:], pt_times[1:], num_epochs)
 
 
-def _set_seed(seed):
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
-
-
 def vanilla_loop(MODEL, num_runs=10, num_epochs=10):
     """
     Returns an array with the last loss from each epoch for each run
@@ -84,12 +77,13 @@ def vanilla_loop(MODEL, num_runs=10, num_epochs=10):
     errors = []
     times = []
 
+    torch.backends.cudnn.deterministic = True
    for i in range(num_runs):
         time_start = time.perf_counter()
 
         # set seed
         seed = i
-        _set_seed(seed)
+        seed_everything(seed)
 
         # init model parts
         model = MODEL()
@@ -135,17 +129,18 @@ def lightning_loop(MODEL, num_runs=10, num_epochs=10):
 
         # set seed
         seed = i
-        _set_seed(seed)
+        seed_everything(seed)
 
-        # init model parts
         model = MODEL()
+        # init model parts
         trainer = Trainer(
             max_epochs=num_epochs,
             progress_bar_refresh_rate=0,
             weights_summary=None,
             gpus=1,
             early_stop_callback=False,
-            checkpoint_callback=False
+            checkpoint_callback=False,
+            deterministic=True,
         )
 
         trainer.fit(model)
diff --git a/pytorch_lightning/__init__.py b/pytorch_lightning/__init__.py
index f305c06b59d9b..01ace195ddb41 100644
--- a/pytorch_lightning/__init__.py
+++ b/pytorch_lightning/__init__.py
@@ -52,6 +52,7 @@
 else:
     from pytorch_lightning.core import LightningModule
     from pytorch_lightning.trainer import Trainer
+    from pytorch_lightning.trainer.seed import seed_everything
     from pytorch_lightning.callbacks import Callback
     from pytorch_lightning.core import data_loader
 
@@ -60,6 +61,7 @@
     'LightningModule',
     'Callback',
-    'data_loader'
+    'data_loader',
+    'seed_everything',
 ]
 
 # necessary for regular bolts imports. Skip exception since bolts is not always installed
diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py
index 80f82d917396f..98d2d8a7e8d7e 100644
--- a/pytorch_lightning/trainer/__init__.py
+++ b/pytorch_lightning/trainer/__init__.py
@@ -101,6 +101,24 @@ def forward(self, x):
     out = pretrained_model(x)
     api_write({'response': out}
 
+------------
+
+Reproducibility
+---------------
+
+To ensure full reproducibility from run to run, you need to set seeds for pseudo-random generators
+and set the ``deterministic`` flag in ``Trainer``.
+
+.. code-block:: python
+
+    from pytorch_lightning import Trainer, seed_everything
+
+    seed_everything(42)
+    # sets seeds for numpy, torch, python.random and PYTHONHASHSEED.
+    model = Model()
+    trainer = Trainer(deterministic=True)
+
+
 -------
 
 Trainer flags
@@ -186,6 +204,21 @@ def forward(self, x):
     # default used by the Trainer
     trainer = Trainer(benchmark=False)
 
+deterministic
+^^^^^^^^^^^^^
+
+If true enables cudnn.deterministic.
+Might make your system slower, but ensures reproducibility.
+Also sets ``$HOROVOD_FUSION_THRESHOLD=0``.
+
+For more info check `[pytorch docs]
+`_.
+
+Example::
+
+    # default used by the Trainer
+    trainer = Trainer(deterministic=False)
+
 callbacks
 ^^^^^^^^^
 
@@ -980,5 +1013,6 @@ def tbptt_split_batch(self, batch, split_size):
 """
 
 from pytorch_lightning.trainer.trainer import Trainer
+from pytorch_lightning.trainer.seed import seed_everything
 
-__all__ = ['Trainer']
+__all__ = ['Trainer', 'seed_everything']
diff --git a/pytorch_lightning/trainer/seed.py b/pytorch_lightning/trainer/seed.py
new file mode 100644
index 0000000000000..604f2f8984b90
--- /dev/null
+++ b/pytorch_lightning/trainer/seed.py
@@ -0,0 +1,42 @@
+"""Helper functions to help with reproducibility of models. """
+
+import os
+from typing import Optional
+
+import numpy as np
+import random
+import torch
+
+from pytorch_lightning import _logger as log
+
+
+def seed_everything(seed: Optional[int] = None) -> int:
+    """Function that sets seed for pseudo-random number generators in:
+        pytorch, numpy, python.random and sets PYTHONHASHSEED environment variable.
+    """
+    max_seed_value = np.iinfo(np.uint32).max
+    min_seed_value = np.iinfo(np.uint32).min
+
+    try:
+        seed = int(seed)
+    except (TypeError, ValueError):
+        seed = _select_seed_randomly(min_seed_value, max_seed_value)
+
+    if (seed > max_seed_value) or (seed < min_seed_value):
+        log.warning(
+            f"{seed} is not in bounds, "
+            f"numpy accepts from {min_seed_value} to {max_seed_value}"
+        )
+        seed = _select_seed_randomly(min_seed_value, max_seed_value)
+
+    os.environ["PYTHONHASHSEED"] = str(seed)
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    return seed
+
+
+def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> int:
+    seed = random.randint(min_seed_value, max_seed_value)
+    log.warning(f"No correct seed found, seed set to {seed}")
+    return seed
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index e01f1dbb497d1..d401a94645ca1 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -14,6 +14,7 @@
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler, BaseProfiler
+from pytorch_lightning.trainer.seed import seed_everything
 from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin
 from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin
 from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
@@ -32,8 +33,7 @@
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
 from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities import rank_zero_warn
-from pytorch_lightning.utilities import parsing
+from pytorch_lightning.utilities import rank_zero_warn, parsing
 
 
 try:
@@ -126,10 +126,12 @@ def __init__(
         resume_from_checkpoint: Optional[str] = None,
         profiler: Optional[BaseProfiler] = None,
         benchmark: bool = False,
+        deterministic: bool = False,
         reload_dataloaders_every_epoch: bool = False,
         auto_lr_find: Union[bool, str] = False,
         replace_sampler_ddp: bool = True,
         progress_bar_callback: Optional[Union[ProgressBarBase, bool]] = True,
+        terminate_on_nan: bool = False,
         auto_scale_batch_size: Optional[str] = None,
         amp_level: str = 'O1',  # backward compatible, todo: remove in v0.8.0
         default_save_path=None,  # backward compatible, todo: remove in v0.8.0
@@ -140,7 +142,6 @@ def __init__(
         use_amp=None,  # backward compatible, todo: remove in v0.9.0
         show_progress_bar=None,  # backward compatible, todo: remove in v0.9.0
         nb_sanity_val_steps=None,  # backward compatible, todo: remove in v0.8.0
-        terminate_on_nan: bool = False,
         **kwargs
     ):
         r"""
@@ -293,6 +294,8 @@
 
         benchmark: If true enables cudnn.benchmark.
 
+        deterministic: If true enables cudnn.deterministic.
+
         terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
             end of each training batch, if any of the parameters or the loss are NaN or +/-inf.
 
@@ -303,6 +306,13 @@
             a power search or `binsearch` that estimates the batch size through a binary search.
""" + self.deterministic = deterministic + torch.backends.cudnn.deterministic = self.deterministic + if self.deterministic: + # fixing non-deterministic part of horovod + # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383 + os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0) + # Init callbacks self.callbacks = callbacks or [] self.on_init_start() diff --git a/tests/base/utils.py b/tests/base/utils.py index a193a92ff024f..973972d60bf57 100644 --- a/tests/base/utils.py +++ b/tests/base/utils.py @@ -5,7 +5,7 @@ import torch # from pl_examples import LightningTemplateModel -from pytorch_lightning import Trainer +from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger from tests import TEMP_PATH, RANDOM_PORTS, RANDOM_SEEDS @@ -188,8 +188,7 @@ def assert_ok_model_acc(trainer, key='test_acc', thr=0.5): def reset_seed(): seed = RANDOM_SEEDS.pop() - torch.manual_seed(seed) - np.random.seed(seed) + seed_everything(seed) def set_random_master_port(): diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py index 2f305b41e7281..6117bc8a0e264 100644 --- a/tests/models/test_horovod.py +++ b/tests/models/test_horovod.py @@ -39,6 +39,7 @@ def _nccl_available(): def _run_horovod(trainer_options, on_gpu=False): """Execute the training script across multiple workers in parallel.""" + tutils.reset_seed() cmdline = [ 'horovodrun', '-np', '2', @@ -62,7 +63,8 @@ def test_horovod_cpu(tmpdir): max_epochs=1, train_percent_check=0.4, val_percent_check=0.2, - distributed_backend='horovod' + distributed_backend='horovod', + deterministic=True, ) _run_horovod(trainer_options) @@ -78,6 +80,7 @@ def test_horovod_cpu_implicit(tmpdir): max_epochs=1, train_percent_check=0.4, val_percent_check=0.2, + deterministic=True, ) _run_horovod(trainer_options) @@ -96,6 +99,7 @@ def test_horovod_multi_gpu(tmpdir): train_percent_check=0.4, val_percent_check=0.2, gpus=1, + deterministic=True, distributed_backend='horovod' ) _run_horovod(trainer_options, on_gpu=True) @@ -130,6 +134,7 @@ def validation_step(self, batch, *args, **kwargs): train_percent_check=0.4, val_percent_check=0.2, gpus=1, + deterministic=True, distributed_backend='horovod' ) tutils.run_model_test_without_loggers(trainer_options, model) @@ -147,6 +152,7 @@ def test_horovod_multi_optimizer(tmpdir): max_epochs=1, train_percent_check=0.4, val_percent_check=0.2, + deterministic=True, distributed_backend='horovod' )