From 89aa2085ec507df83cccc5c3882dc89e892178e2 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 9 Jul 2020 07:21:08 -0400 Subject: [PATCH 1/2] enable none checkpoint --- pytorch_lightning/trainer/distrib_parts.py | 2 +- pytorch_lightning/trainer/trainer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index 78bc22d21589d..8d19655175f8e 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -222,7 +222,7 @@ def tpu_train(self, tpu_core_idx, model): self.run_pretrain_routine(model) # when training ends on these platforms dump weights to get out of the main process - if self.on_colab_kaggle: + if self.on_colab_kaggle and not self.testing: rank_zero_warn('cleaning up... please do not interrupt') self.save_spawn_weights(model) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index eec21752912b0..023a7d4a57557 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1014,7 +1014,7 @@ def fit( xmp.spawn(self.tpu_train, args=(model,), nprocs=self.tpu_cores, start_method=start_method) # load weights if not interrupted - if self.on_colab_kaggle: + if self.on_colab_kaggle and not self.testing: self.load_spawn_weights(model) self.model = model From d117fe18e09bc63b0c57980c771f52f52dde4f49 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 9 Jul 2020 11:31:50 -0400 Subject: [PATCH 2/2] enable none checkpoint --- tests/callbacks/test_model_checkpoint.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/callbacks/test_model_checkpoint.py b/tests/callbacks/test_model_checkpoint.py index b5cb7ca0c756e..1091a4cf3a8dd 100644 --- a/tests/callbacks/test_model_checkpoint.py +++ b/tests/callbacks/test_model_checkpoint.py @@ -15,7 +15,9 @@ @pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2]) def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k): - """ Test that None in checkpoint callback is valid and that chkp_path is set correctly """ + """ + Test that None in checkpoint callback is valid and that chkp_path is set correctly + """ tutils.reset_seed() model = EvalModelTemplate()