Option to provide seed to random generators to ensure reproducibility #1572

Merged 34 commits on May 12, 2020

Commits
301098a
Option to provide seed to random generators to ensure reproducibility
kumuji Apr 23, 2020
e67d039
Apply recommendations from core contributors on seeding
kumuji Apr 24, 2020
b1cf263
Revert "Apply recommendations from core contributors on seeding"
kumuji Apr 24, 2020
957e0a9
Revert "Revert "Apply recommendations from core contributors on seedi…
kumuji Apr 24, 2020
f6edd60
Change in test, for correct seeding
kumuji Apr 24, 2020
4eae18f
Allow seed equal to 0
kumuji Apr 24, 2020
f09a706
Allow seed to be uint32.max
kumuji Apr 24, 2020
9720559
Added deterministic to benchmarks
kumuji Apr 24, 2020
71c8470
Cuda manual seed as in benchmark seeding
kumuji Apr 24, 2020
d40f2dd
Seeding should be done before model initialization
kumuji Apr 24, 2020
fd73061
cuda manual_seed is not necessary
kumuji Apr 24, 2020
6da087b
Fixing seed test_cpu_lbfgs
kumuji Apr 29, 2020
aaf6dfe
rebasing issue with old reproducibility.py
kumuji Apr 29, 2020
2907659
Improved documentation and ability to seed before initializing Train
kumuji Apr 30, 2020
16affce
Change in docs
kumuji May 2, 2020
2bf31c3
Removed seed from trainer, update for documentation
kumuji May 2, 2020
f1452a4
Typo in the docs
kumuji May 2, 2020
6a4b8cd
Added seed_everything to _all_
kumuji May 4, 2020
d02b0e4
Fixing old changes
kumuji May 5, 2020
374fed2
Model initialization should be earlier then Trainer
kumuji May 5, 2020
521df98
Update pytorch_lightning/trainer/__init__.py
kumuji May 5, 2020
460fc87
Fixing according to the contributors suggestions
kumuji May 7, 2020
12adf89
Moving horovod deterministic to Trainer class
kumuji May 7, 2020
a10f4dc
deterministic flag affects horovod docs update
kumuji May 7, 2020
40e22ec
Improved static typing
kumuji May 7, 2020
7cfa554
Added deterministic to test runners of horovod
kumuji May 7, 2020
bac5b63
static seeds for horovod tests
kumuji May 7, 2020
6b4d098
Change for reset_seed function in tests
kumuji May 7, 2020
d93d6f9
Seeding horovod using reset_seed from tutils
kumuji May 7, 2020
63c034a
Update pytorch_lightning/trainer/__init__.py
Borda May 9, 2020
c652311
chlog
Borda May 9, 2020
3110210
Update trainer.py
williamFalcon May 10, 2020
07c9c2f
change "testcode" to "Example" in trainer init documentation
kumuji May 10, 2020
2c97d5b
Update pytorch_lightning/trainer/seed.py, first line in comment
kumuji May 10, 2020
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Enable `NeptuneLogger` to work with `distributed_backend=ddp` ([#1753](https://github.com/PyTorchLightning/pytorch-lightning/pull/1753))

- Added option to provide seed to random generators to ensure reproducibility ([#1572](https://github.com/PyTorchLightning/pytorch-lightning/pull/1572))

### Changed

- Reduction when `batch_size < num_gpus` ([#1609](https://github.com/PyTorchLightning/pytorch-lightning/pull/1609))
17 changes: 6 additions & 11 deletions benchmarks/test_rnn_parity.py
@@ -8,7 +8,7 @@
from torch.utils.data import Dataset, DataLoader
import tests.base.utils as tutils

from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning import Trainer, LightningModule, seed_everything


class AverageDataset(Dataset):
@@ -68,13 +68,6 @@ def test_pytorch_parity(tmpdir):
tutils.assert_speed_parity(pl_times, pt_times, num_epochs)


def set_seed(seed):
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)


def vanilla_loop(MODEL, num_runs=10, num_epochs=10):
"""
Returns an array with the last loss from each epoch for each run
@@ -83,12 +76,13 @@ def vanilla_loop(MODEL, num_runs=10, num_epochs=10):
errors = []
times = []

torch.backends.cudnn.deterministic = True
for i in range(num_runs):
time_start = time.perf_counter()

# set seed
seed = i
set_seed(seed)
seed_everything(seed)

# init model parts
model = MODEL()
@@ -134,10 +128,10 @@ def lightning_loop(MODEL, num_runs=10, num_epochs=10):

# set seed
seed = i
set_seed(seed)
seed_everything(seed)
model = MODEL()

# init model parts
model = MODEL()
trainer = Trainer(
max_epochs=num_epochs,
progress_bar_refresh_rate=0,
@@ -146,6 +140,7 @@
early_stop_callback=False,
checkpoint_callback=False,
distributed_backend='dp',
deterministic=True,
)
trainer.fit(model)

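The two parity benchmarks now share the same seeding pattern. Below is a condensed sketch of that pattern; `model_cls`, `num_runs` and `num_epochs` are placeholders rather than the benchmarks' actual values:

```python
from pytorch_lightning import Trainer, seed_everything

def reproducible_runs(model_cls, num_runs=3, num_epochs=2):
    """Train the same model repeatedly; identical seeds should give identical runs."""
    for i in range(num_runs):
        seed_everything(i)       # seed before building the model so initial weights match
        model = model_cls()
        trainer = Trainer(
            max_epochs=num_epochs,
            deterministic=True,  # turns on cudnn.deterministic inside the Trainer
        )
        trainer.fit(model)
```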
19 changes: 7 additions & 12 deletions benchmarks/test_trainer_parity.py
@@ -10,7 +10,7 @@
from torchvision import transforms
import tests.base.utils as tutils

from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning import Trainer, LightningModule, seed_everything
from tests.base.datasets import TrialMNIST


@@ -69,13 +69,6 @@ def test_pytorch_parity(tmpdir):
tutils.assert_speed_parity(pl_times[1:], pt_times[1:], num_epochs)


def _set_seed(seed):
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)


def vanilla_loop(MODEL, num_runs=10, num_epochs=10):
"""
Returns an array with the last loss from each epoch for each run
@@ -84,12 +77,13 @@ def vanilla_loop(MODEL, num_runs=10, num_epochs=10):
errors = []
times = []

torch.backends.cudnn.deterministic = True
for i in range(num_runs):
time_start = time.perf_counter()

# set seed
seed = i
_set_seed(seed)
seed_everything(seed)

# init model parts
model = MODEL()
@@ -135,17 +129,18 @@ def lightning_loop(MODEL, num_runs=10, num_epochs=10):

# set seed
seed = i
_set_seed(seed)
seed_everything(seed)

# init model parts
model = MODEL()
# init model parts
trainer = Trainer(
max_epochs=num_epochs,
progress_bar_refresh_rate=0,
weights_summary=None,
gpus=1,
early_stop_callback=False,
checkpoint_callback=False
checkpoint_callback=False,
deterministic=True,
)
trainer.fit(model)

2 changes: 2 additions & 0 deletions pytorch_lightning/__init__.py
@@ -52,6 +52,7 @@
else:
from pytorch_lightning.core import LightningModule
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.trainer.seed import seed_everything
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core import data_loader

@@ -60,6 +61,7 @@
'LightningModule',
'Callback',
'data_loader',
'seed_everything'
]

# necessary for regular bolts imports. Skip exception since bolts is not always installed
36 changes: 35 additions & 1 deletion pytorch_lightning/trainer/__init__.py
@@ -101,6 +101,24 @@ def forward(self, x):
out = pretrained_model(x)
api_write({'response': out})

------------

Reproducibility
---------------

To ensure full reproducibility from run to run, you need to set seeds for pseudo-random generators
and set the ``deterministic`` flag in ``Trainer``.

.. code-block:: python

from pytorch_lightning import Trainer, seed_everything

seed_everything(42)
# sets seeds for numpy, torch, python.random and PYTHONHASHSEED.
model = Model()
trainer = Trainer(deterministic=True)


-------

Trainer flags
@@ -186,6 +204,21 @@ def forward(self, x):
# default used by the Trainer
trainer = Trainer(benchmark=False)

deterministic
^^^^^^^^^^^^^

If ``True``, enables ``cudnn.deterministic``.
This might make your system slower, but it ensures reproducibility.
Also sets ``$HOROVOD_FUSION_THRESHOLD=0``.

For more info check the `PyTorch docs <https://pytorch.org/docs/stable/notes/randomness.html>`_.

Example::

# default used by the Trainer
trainer = Trainer(deterministic=False)

callbacks
^^^^^^^^^

@@ -980,5 +1013,6 @@ def tbptt_split_batch(self, batch, split_size):
"""

from pytorch_lightning.trainer.trainer import Trainer
from pytorch_lightning.trainer.seed import seed_everything

__all__ = ['Trainer']
__all__ = ['Trainer', 'seed_everything']
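One detail the documentation above does not spell out: `seed_everything` returns the seed it applied (see the new `seed.py` further down), so a run can record which seed it ended up with even when the caller does not pass one. A small illustration using the top-level import added in this PR:

```python
from pytorch_lightning import seed_everything

# no seed supplied: a random fallback is chosen, applied and returned
seed = seed_everything()
print(f"this run can be reproduced with seed_everything({seed})")
```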
42 changes: 42 additions & 0 deletions pytorch_lightning/trainer/seed.py
@@ -0,0 +1,42 @@
"""Helper functions to help with reproducibility of models. """

import os
from typing import Optional

import numpy as np
import random
import torch

from pytorch_lightning import _logger as log


def seed_everything(seed: Optional[int] = None) -> int:
"""Function that sets seed for pseudo-random number generators in:
pytorch, numpy, python.random and sets PYTHONHASHSEED environment variable.
"""
max_seed_value = np.iinfo(np.uint32).max
min_seed_value = np.iinfo(np.uint32).min

try:
seed = int(seed)
except (TypeError, ValueError):
seed = _select_seed_randomly(min_seed_value, max_seed_value)

if (seed > max_seed_value) or (seed < min_seed_value):
log.warning(
f"{seed} is not in bounds, \
numpy accepts from {min_seed_value} to {max_seed_value}"
)
seed = _select_seed_randomly(min_seed_value, max_seed_value)

os.environ["PYTHONHASHSEED"] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
return seed


def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> int:
seed = random.randint(min_seed_value, max_seed_value)
log.warning(f"No correct seed found, seed set to {seed}")
return seed
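A quick sketch of how the new helper behaves at its boundaries, following the logic above; the fallback seeds are random, so only their range is checked:

```python
import numpy as np
from pytorch_lightning import seed_everything

uint32_max = np.iinfo(np.uint32).max

# valid seeds, including 0 and uint32.max, are applied as-is and returned
assert seed_everything(0) == 0
assert seed_everything(uint32_max) == uint32_max

# None, non-integers and out-of-range values fall back to a randomly chosen seed
for bad in (None, "not-a-seed", uint32_max + 1, -1):
    applied = seed_everything(bad)
    assert 0 <= applied <= uint32_max
```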
16 changes: 13 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -14,6 +14,7 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler, BaseProfiler
from pytorch_lightning.trainer.seed import seed_everything
from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin
from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
@@ -32,8 +33,7 @@
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities import parsing
from pytorch_lightning.utilities import rank_zero_warn, parsing


try:
@@ -126,10 +126,12 @@ def __init__(
resume_from_checkpoint: Optional[str] = None,
profiler: Optional[BaseProfiler] = None,
benchmark: bool = False,
deterministic: bool = False,
reload_dataloaders_every_epoch: bool = False,
auto_lr_find: Union[bool, str] = False,
replace_sampler_ddp: bool = True,
progress_bar_callback: Optional[Union[ProgressBarBase, bool]] = True,
terminate_on_nan: bool = False,
auto_scale_batch_size: Optional[str] = None,
amp_level: str = 'O1', # backward compatible, todo: remove in v0.8.0
default_save_path=None, # backward compatible, todo: remove in v0.8.0
@@ -140,7 +142,6 @@
use_amp=None, # backward compatible, todo: remove in v0.9.0
show_progress_bar=None, # backward compatible, todo: remove in v0.9.0
nb_sanity_val_steps=None, # backward compatible, todo: remove in v0.8.0
terminate_on_nan: bool = False,
**kwargs
):
r"""
@@ -293,6 +294,8 @@ def __init__(

benchmark: If true enables cudnn.benchmark.

deterministic: If true enables cudnn.deterministic.

terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
end of each training batch, if any of the parameters or the loss are NaN or +/-inf.

@@ -303,6 +306,13 @@
a power search or `binsearch` that estimates the batch size through a binary search.
"""

self.deterministic = deterministic
torch.backends.cudnn.deterministic = self.deterministic
if self.deterministic:
# fixing non-deterministic part of horovod
# https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0)

# Init callbacks
self.callbacks = callbacks or []
self.on_init_start()
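In effect, the new flag boils down to the snippet below, which paraphrases the lines added to `Trainer.__init__` above rather than introducing a separate API:

```python
import os
import torch

def _apply_determinism(deterministic: bool) -> None:
    # mirror of what Trainer.__init__ now does with the `deterministic` flag
    torch.backends.cudnn.deterministic = deterministic
    if deterministic:
        # horovod's tensor fusion is a source of non-determinism, so disable it
        os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0)
```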
5 changes: 2 additions & 3 deletions tests/base/utils.py
@@ -5,7 +5,7 @@
import torch

# from pl_examples import LightningTemplateModel
from pytorch_lightning import Trainer
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from tests import TEMP_PATH, RANDOM_PORTS, RANDOM_SEEDS
@@ -188,8 +188,7 @@ def assert_ok_model_acc(trainer, key='test_acc', thr=0.5):

def reset_seed():
seed = RANDOM_SEEDS.pop()
torch.manual_seed(seed)
np.random.seed(seed)
seed_everything(seed)


def set_random_master_port():
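`reset_seed` now delegates to `seed_everything`; a minimal sketch of the pattern, with `RANDOM_SEEDS` standing in for the pre-generated seed pool used by the tests:

```python
from pytorch_lightning import seed_everything

RANDOM_SEEDS = list(range(1000))   # stand-in for the test suite's seed pool

def reset_seed():
    # take the next unused seed from the pool and seed every generator with it
    seed = RANDOM_SEEDS.pop()
    seed_everything(seed)
```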
8 changes: 7 additions & 1 deletion tests/models/test_horovod.py
@@ -39,6 +39,7 @@ def _nccl_available():

def _run_horovod(trainer_options, on_gpu=False):
"""Execute the training script across multiple workers in parallel."""
tutils.reset_seed()
cmdline = [
'horovodrun',
'-np', '2',
@@ -62,7 +63,8 @@ def test_horovod_cpu(tmpdir):
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,
distributed_backend='horovod'
distributed_backend='horovod',
deterministic=True,
)
_run_horovod(trainer_options)

@@ -78,6 +80,7 @@ def test_horovod_cpu_implicit(tmpdir):
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,
deterministic=True,
)
_run_horovod(trainer_options)

@@ -96,6 +99,7 @@ def test_horovod_multi_gpu(tmpdir):
train_percent_check=0.4,
val_percent_check=0.2,
gpus=1,
deterministic=True,
distributed_backend='horovod'
)
_run_horovod(trainer_options, on_gpu=True)
@@ -130,6 +134,7 @@ def validation_step(self, batch, *args, **kwargs):
train_percent_check=0.4,
val_percent_check=0.2,
gpus=1,
deterministic=True,
distributed_backend='horovod'
)
tutils.run_model_test_without_loggers(trainer_options, model)
@@ -147,6 +152,7 @@ def test_horovod_multi_optimizer(tmpdir):
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,
deterministic=True,
distributed_backend='horovod'
)
