
Commit cb0c6ad

fix setup call while testing (#2624)
* fix setup call while testing
* changelog
* drop if condition
* add test to check setup call
* flake8
* update test to check model stage

Co-authored-by: William Falcon <[email protected]>
1 parent 8599b67 commit cb0c6ad
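
This commit makes the `setup` hooks receive the stage that is actually running: the trainer-level and model-level `setup('fit')` calls are skipped when the trainer is in testing mode, and the test paths call `model.setup('test')` unconditionally. Below is a minimal sketch of the resulting call pattern, using simplified stand-in classes rather than the real library code:

    class Model:
        def setup(self, stage):
            # receives 'fit' during fitting, 'test' during testing
            self.stage = stage

    class Trainer:
        def __init__(self):
            self.testing = False

        def setup(self, stage):
            self.stage = stage

        def fit(self, model):
            self.model = model
            # after this commit: 'fit' hooks run only when not testing
            if not self.testing:
                self.setup('fit')
                model.setup('fit')

        def test(self, model):
            self.testing = True
            # both hooks see stage='test' (asserted by the new test below)
            self.setup('test')
            model.setup('test')

    trainer = Trainer()
    model = Model()
    trainer.fit(model)   # model.stage == 'fit'
    trainer.test(model)  # model.stage == 'test'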

5 files changed, +47 -16 lines changed

CHANGELOG.md (+2)

@@ -25,6 +25,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed setup call while testing ([#2624](https://github.com/PyTorchLightning/pytorch-lightning/pull/2624))
+
 - Fixed Horovod backend to scale LR schedulers with the optimizer ([#2626](https://github.com/PyTorchLightning/pytorch-lightning/pull/2626))
 
 - Fixed `dtype` and `device` properties not getting updated in submodules ([#2657](https://github.com/PyTorchLightning/pytorch-lightning/pull/2657))

pytorch_lightning/trainer/distrib_data_parallel.py (+2 -2)

@@ -509,8 +509,8 @@ def ddp_train(self, process_idx, q, model, is_master=False, proc_offset=0):
         model.init_ddp_connection(self.global_rank, self.world_size, self.is_slurm_managing_tasks)
 
         # call setup after the ddp process has connected
-        self.setup('fit')
-        if self.is_function_implemented('setup', model):
+        if not self.testing:
+            self.setup('fit')
             model.setup('fit')
 
         # on world_size=0 let everyone know training is starting

pytorch_lightning/trainer/distrib_parts.py (+8 -8)

@@ -167,8 +167,8 @@ def __transfer_batch_to_device(self, batch: Any, device: torch.device):
 
     def single_gpu_train(self, model):
         # call setup
-        self.setup('fit')
-        if self.is_function_implemented('setup', model):
+        if not self.testing:
+            self.setup('fit')
             model.setup('fit')
 
         model.cuda(self.root_gpu)
@@ -189,8 +189,8 @@ def single_gpu_train(self, model):
 
     def tpu_train(self, tpu_core_idx, model):
         # call setup after the ddp process has connected
-        self.setup('fit')
-        if self.is_function_implemented('setup', model):
+        if not self.testing:
+            self.setup('fit')
             model.setup('fit')
 
         # put model on tpu
@@ -229,8 +229,8 @@ def tpu_train(self, tpu_core_idx, model):
 
     def dp_train(self, model):
         # call setup after the ddp process has connected
-        self.setup('fit')
-        if self.is_function_implemented('setup', model):
+        if not self.testing:
+            self.setup('fit')
             model.setup('fit')
 
         model.cuda(self.root_gpu)
@@ -275,8 +275,8 @@ def dp_train(self, model):
 
     def horovod_train(self, model):
         # call setup after the ddp process has connected
-        self.setup('fit')
-        if self.is_function_implemented('setup', model):
+        if not self.testing:
+            self.setup('fit')
             model.setup('fit')
 
         if torch.cuda.is_available() and self.on_gpu:

pytorch_lightning/trainer/trainer.py (+4 -6)

@@ -1087,8 +1087,8 @@ def fit(
             raise MisconfigurationException('amp + cpu is not supported. Please use a GPU option')
 
         # call setup after the ddp process has connected
-        self.setup('fit')
-        if self.is_function_implemented('setup', model):
+        if not self.testing:
+            self.setup('fit')
             model.setup('fit')
 
         # CHOOSE OPTIMIZER
@@ -1381,8 +1381,7 @@ def test(
 
     def __test_using_best_weights(self, ckpt_path, test_dataloaders):
         model = self.get_model()
-        if self.is_function_implemented('setup', model):
-            model.setup('test')
+        model.setup('test')
 
         # if user requests the best checkpoint but we don't have it, error
         if ckpt_path == 'best' and self.checkpoint_callback.save_top_k <= 0:
@@ -1429,8 +1428,7 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
 
     def __test_given_model(self, model, test_dataloaders):
         # setup hook
-        if self.is_function_implemented('setup', model):
-            model.setup('test')
+        model.setup('test')
 
         # attach data
         if test_dataloaders is not None:
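
The `is_function_implemented('setup', model)` guard is dropped in the test paths above (the "drop if condition" bullet in the commit message). A plausible reading, sketched below with a hypothetical base class rather than the real `LightningModule`: when the base class ships a no-op `setup` default, calling the hook unconditionally is safe whether or not the user overrides it, so the reflection check adds nothing.

    class ModuleHooks:
        def setup(self, stage):
            """No-op default hook; user subclasses may override."""

    class UserModel(ModuleHooks):
        pass  # does not override setup

    UserModel().setup('test')  # safe: resolves to the no-op default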

tests/trainer/test_trainer.py (+31)

@@ -980,3 +980,34 @@ def test_trainer_pickle(tmpdir):
     )
     pickle.dumps(trainer)
     cloudpickle.dumps(trainer)
+
+
+def test_trainer_setup_call(tmpdir):
+    """Test setup call with fit and test call."""
+
+    class CurrentModel(EvalModelTemplate):
+
+        def setup(self, stage):
+            self.stage = stage
+
+    class TrainerSubclass(Trainer):
+
+        def setup(self, stage):
+            self.stage = stage
+
+    model = CurrentModel()
+
+    # fit model
+    trainer = TrainerSubclass(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        checkpoint_callback=False
+    )
+
+    trainer.fit(model)
+    assert trainer.stage == 'fit'
+    assert trainer.get_model().stage == 'fit'
+
+    trainer.test(ckpt_path=None)
+    assert trainer.stage == 'test'
+    assert trainer.get_model().stage == 'test'
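
To exercise just the new test locally, one option (assuming pytest is installed and the repository root is the working directory) is to invoke it through pytest's Python entry point:

    import pytest

    # Select the new test by its node id, taken from the diff above.
    pytest.main(["tests/trainer/test_trainer.py::test_trainer_setup_call", "-v"])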
