add auto_restore option to disable auto loading #972

Closed · wants to merge 2 commits
9 changes: 8 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -116,6 +116,7 @@ def __init__(
        profiler: Optional[BaseProfiler] = None,
        benchmark: bool = False,
        reload_dataloaders_every_epoch: bool = False,
        auto_restore: bool = True,
    ):
        r"""

@@ -610,6 +611,8 @@ def on_train_end(self):
            algorithm for the hardware `[see discussion here]
            <https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936>`_.

        auto_restore (bool): If ``True``, automatically restore the model weights and
            training state before training starts. Set to ``False`` to skip the
            trainer's pre-train restore step.

        .. warning:: The following arguments are deprecated and will be removed in v0.8.0:

            - `nb_sanity_val_steps`
@@ -620,6 +623,9 @@ def on_train_end(self):
        self.callbacks = callbacks
        self.on_init_start(self)

        # auto restore option
        self.auto_restore = auto_restore

        # benchmarking
        self.benchmark = benchmark
        if benchmark:
@@ -1104,7 +1110,8 @@ def run_pretrain_routine(self, model: LightningModule):
        self.model = model

        # restore training and model before hpc call
        if self.auto_restore:
            self.restore_weights(model)

        # download the data and do whatever transforms we need
        self.call_prepare_data(ref_model)
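For context, a minimal usage sketch of the new flag. This assumes the Trainer from this PR's branch (`auto_restore` is not part of any released pytorch-lightning version), and `MyLightningModule` is a placeholder for any LightningModule:

    from pytorch_lightning import Trainer

    model = MyLightningModule(hparams)  # placeholder: any LightningModule

    # Default behaviour (auto_restore=True): restore_weights() runs in
    # run_pretrain_routine(), so model weights and training state are
    # restored before training starts.
    trainer = Trainer(default_save_path='./checkpoints')
    trainer.fit(model)

    # With auto_restore=False the restore_weights() call is skipped and
    # training starts from the model's current (fresh) weights.
    trainer = Trainer(default_save_path='./checkpoints', auto_restore=False)
    trainer.fit(model)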
55 changes: 55 additions & 0 deletions tests/trainer/test_trainer.py
@@ -713,5 +713,60 @@ def on_test_end(self, trainer, pl_module):
    assert test_callback.on_test_start_called
    assert test_callback.on_test_end_called


def test_auto_restore(tmpdir):
"""Verify auto restore option."""
import types
tutils.reset_seed()
hparams = tutils.get_hparams()

def new_model():
# Create a model that tracks epochs and batches seen
model = LightningTestModel(hparams)
model.num_epochs_seen = 0
model.num_batches_seen = 0

def increment_epoch(self):
self.num_epochs_seen += 1

def increment_batch(self, _):
self.num_batches_seen += 1

# Bind the increment_epoch function on_epoch_end so that the
# model keeps track of the number of epochs it has seen.
model.on_epoch_end = types.MethodType(increment_epoch, model)
model.on_batch_start = types.MethodType(increment_batch, model)
return model

model = new_model()

trainer_options = dict(
show_progress_bar=False,
max_epochs=2,
train_percent_check=0.65,
val_percent_check=1,
checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1),
logger=False,
default_save_path=tmpdir,
early_stop_callback=False,
val_check_interval=0.5,
)

# fit model
trainer = Trainer(**trainer_options)
trainer.fit(model)

training_batches = trainer.num_training_batches

assert model.num_epochs_seen == 2
assert model.num_batches_seen == training_batches * 2

next_model = new_model()
# Resume training with auto_restore=False
trainer_options['max_epochs'] = 1
new_trainer = Trainer(**trainer_options, auto_restore=False)
new_trainer.fit(next_model)
assert next_model.num_epochs_seen == 1

# if __name__ == '__main__':
# pytest.main([__file__])
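When auto-restoring is disabled, a checkpoint can still be loaded explicitly. A short sketch under stated assumptions: it uses the `load_from_checkpoint` classmethod from the 0.7-era LightningModule API, and the checkpoint filename is illustrative (ModelCheckpoint named its files along the lines of `_ckpt_epoch_N.ckpt` at the time):

    import os

    # Explicitly restore weights instead of relying on the trainer's
    # automatic restore. The filename below is illustrative.
    checkpoint_path = os.path.join(str(tmpdir), '_ckpt_epoch_1.ckpt')
    pretrained_model = LightningTestModel.load_from_checkpoint(checkpoint_path)

    # auto_restore=False skips the trainer's own pre-train restore step,
    # so the explicitly loaded weights are what training starts from.
    trainer = Trainer(max_epochs=1, auto_restore=False)
    trainer.fit(pretrained_model)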