Commit f35337a

Fixes .test() for ddp (#2570)
* enable none checkpoint
1 parent b738126 commit f35337a

7 files changed, +87 -66 lines

pytorch_lightning/trainer/distrib_data_parallel.py  (+14, -8)

@@ -122,16 +122,15 @@ def train_fx(trial_hparams, cluster_manager, _):
 from time import sleep
 import numpy as np
 from os.path import abspath
-from torch import distributed as dist
-import queue

 import torch
 from pytorch_lightning import _logger as log
-from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, rank_zero_info
+from pytorch_lightning.core.lightning import LightningModule
+

 try:
     from apex import amp
@@ -230,6 +229,10 @@ def save_checkpoint(self, *args):
     def setup(self, *args) -> None:
         """Warning: this is just empty shell for code implemented in other class."""

+    @abstractmethod
+    def get_model(self) -> LightningModule:
+        """Warning: this is just empty shell for code implemented in other class."""
+
     @abstractmethod
     def is_function_implemented(self, *args) -> bool:
         """Warning: this is just empty shell for code implemented in other class."""
@@ -556,17 +559,20 @@ def ddp_train(self, process_idx, q, model, is_master=False, proc_offset=0):
         # continue training routine
         results = self.run_pretrain_routine(model)

+        # get original model
+        model = self.get_model()
+
         # persist info in ddp_spawn
-        self.__transfer_ddp_spawn_state_on_fit_end(model, q, results)
+        self.transfer_ddp_spawn_state_on_fit_end(model, q, results)

         # clean up memory
         torch.cuda.empty_cache()

         if self.global_rank == 0 and self.distributed_backend not in ['ddp_spawn', 'ddp_cpu']:
             return results

-    def __transfer_ddp_spawn_state_on_fit_end(self, model, q, results):
-        if not self.distributed_backend in ['ddp_spawn', 'ddp_cpu']:
+    def transfer_ddp_spawn_state_on_fit_end(self, model, q, results):
+        if self.distributed_backend not in ['ddp_spawn', 'ddp_cpu', 'tpu']:
             return

         # track the best model path
@@ -581,8 +587,8 @@ def __transfer_ddp_spawn_state_on_fit_end(self, model, q, results):

         # save the last weights
         last_path = None
-        if not self.testing:
-            last_path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt')
+        if not self.testing and best_model_path is not None and len(best_model_path) > 0:
+            last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
         torch.save(model.state_dict(), last_path)
         q.put(last_path)
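Note: get_model() is introduced above only as an abstract shell; ddp_train relies on it to recover the original LightningModule once the model has been wrapped for distributed training. The helper below is a hypothetical standalone sketch of that unwrapping logic, not the mixin's actual implementation:

    from torch.nn import DataParallel
    from torch.nn.parallel import DistributedDataParallel

    def unwrap_model(model):
        # DataParallel / DistributedDataParallel keep the user's module under `.module`;
        # returning the inner module means checkpoint saving and state transfer always
        # see the plain LightningModule that was passed to .fit()
        if isinstance(model, (DataParallel, DistributedDataParallel)):
            return model.module
        return model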

pytorch_lightning/trainer/distrib_parts.py  (+1, -1)

@@ -222,7 +222,7 @@ def tpu_train(self, tpu_core_idx, model):
         self.run_pretrain_routine(model)

         # when training ends on these platforms dump weights to get out of the main process
-        if self.on_colab_kaggle and not self.testing:
+        if self.on_colab_kaggle:
             rank_zero_warn('cleaning up... please do not interrupt')
             self.save_spawn_weights(model)

pytorch_lightning/trainer/trainer.py  (+62, -33)

@@ -396,6 +396,9 @@ def __init__(
         self.test_dataloaders = None
         self.val_dataloaders = None

+        # when .test() is called, it sets this
+        self.tested_ckpt_path = None
+
         # training state
         self.model = None
         self.testing = False
@@ -965,6 +968,10 @@ def fit(

                 self.ddp_train(process_idx=task, q=None, model=model)
         elif self.use_ddp:
+
+            # set testing if set in environ
+            self.testing = os.environ.get('PL_TESTING_MODE', self.testing)
+
             if self.is_slurm_managing_tasks:
                 task = int(os.environ['SLURM_LOCALID'])
                 self.ddp_train(process_idx=task, q=None, model=model)
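Note: the testing flag has to survive a process boundary here, because under ddp the trainer re-launches or spawns worker processes; the commit carries the flag across via the PL_TESTING_MODE environment variable. A minimal standalone sketch of the same pattern (function names are illustrative, not Lightning API; environment values always come back as strings):

    import os

    def mark_testing():
        # parent process: record test mode before the ddp workers start
        os.environ['PL_TESTING_MODE'] = '1'

    def is_testing(default=False):
        # child processes inherit the environment, but the value arrives as a string,
        # so convert explicitly rather than truth-testing the raw value
        value = os.environ.get('PL_TESTING_MODE')
        return default if value is None else value == '1'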
@@ -1058,7 +1065,7 @@ def __run_ddp_spawn(self, model, nprocs):
         smp = mp.get_context('spawn')
         q = smp.SimpleQueue()

-        mp.spawn(self.ddp_train, nprocs=nprocs, args=(q, model,))
+        mp.spawn(self.ddp_train, nprocs=nprocs, args=(q, model, ))

         # restore main state with best weights
         best_path = q.get()
@@ -1070,7 +1077,8 @@ def __run_ddp_spawn(self, model, nprocs):

         # load last weights
         if last_path is not None and not self.testing:
-            torch.load(last_path, map_location=lambda storage, loc: storage)
+            ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
+            model.load_state_dict(ckpt)

         self.model = model
         return results
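Note: before this change the parent process called torch.load on the temporary checkpoint written by the spawned workers but discarded the result, so it kept its original weights; the fix binds the loaded state dict and applies it to the model. The same pattern in isolation (a hedged sketch; the model and path are placeholders):

    import torch

    def restore_last_weights(model, last_path):
        # map_location keeps every tensor on CPU, so the parent process can restore
        # the spawned worker's weights without owning a GPU
        state_dict = torch.load(last_path, map_location=lambda storage, loc: storage)
        model.load_state_dict(state_dict)
        return model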
@@ -1262,62 +1270,83 @@ def test(
         # --------------------
         # SETUP HOOK
         # --------------------
+        if self.global_rank != 0:
+            return
+
         self.setup('test')
-        model_ref = self.model if model is None else model
-        if self.is_function_implemented('setup', model_ref):
-            model_ref.setup('test')
+
+        if model is not None:
+            results = self.__test_given_model(model, test_dataloaders)
+        else:
+            results = self.__test_using_best_weights(ckpt_path, test_dataloaders)
+
+        self.teardown('test')
+
+        return results
+
+    def __test_using_best_weights(self, ckpt_path, test_dataloaders):
+        model = self.get_model()
+        if self.is_function_implemented('setup', model):
+            model.setup('test')

         # if user requests the best checkpoint but we don't have it, error
-        if model is None and ckpt_path == 'best' and self.checkpoint_callback.save_top_k <= 0:
+        if ckpt_path == 'best' and self.checkpoint_callback.save_top_k <= 0:
             raise MisconfigurationException(
                 'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.')

-        # --------------------
-        # AUTO-LOAD BEST CKPT
-        # --------------------
-        # load the best checkpoint automatically unless model is given
-        # in which case we use that one
-        if model is None and ckpt_path is not None:
+        # load best weights
+        if ckpt_path is not None:
             # ckpt_path is 'best' so load the best model
             if ckpt_path == 'best':
                 ckpt_path = self.checkpoint_callback.best_model_path
-            model = self.get_model().load_from_checkpoint(ckpt_path)

-        # ----------------------------------------------------
-        # AUTO-LOAD BEST CKPT with the model trained in .fit()
-        # ----------------------------------------------------
-        elif model is None and ckpt_path is None:
-            model = model_ref
+            ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
+            model.load_state_dict(ckpt['state_dict'])

-        # --------------------
-        # LOAD DATA
-        # --------------------
+        # attach dataloaders
         if test_dataloaders is not None:
-            if model:
-                self.__attach_dataloaders(model, test_dataloaders=test_dataloaders)
-            else:
-                self.__attach_dataloaders(self.model, test_dataloaders=test_dataloaders)
+            self.__attach_dataloaders(model, test_dataloaders=test_dataloaders)

-        # --------------------
-        # RUN TEST SET
-        # --------------------
-        # sets up testing so we short circuit to eval
+        # run tests
+        self.tested_ckpt_path = ckpt_path
         self.set_random_port(force=True)
         self.testing = True
+        os.environ['PL_TESTING_MODE'] = '1'
         self.model = model
         results = self.fit(model)
         self.testing = False
+        del os.environ['PL_TESTING_MODE']

-        # --------------------
-        # TEAR DOWN HOOK
-        # --------------------
-        self.teardown('test')
+        # teardown
         if self.is_function_implemented('teardown'):
             model_ref = self.get_model()
             model_ref.teardown('test')

         return results

+    def __test_given_model(self, model, test_dataloaders):
+        # setup hook
+        if self.is_function_implemented('setup', model):
+            model.setup('test')
+
+        # attach data
+        if test_dataloaders is not None:
+            self.__attach_dataloaders(model, test_dataloaders=test_dataloaders)
+
+        # run test
+        # sets up testing so we short circuit to eval
+        self.set_random_port(force=True)
+        self.testing = True
+        self.model = model
+        results = self.fit(model)
+        self.testing = False
+
+        # teardown
+        if self.is_function_implemented('teardown'):
+            model.teardown('test')
+
+        return results
+
     def check_model_configuration(self, model: LightningModule):
         r"""
         Checks that the model is configured correctly before training or testing is started.
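Note: after this refactor, test() dispatches to __test_given_model when a model is passed in and to __test_using_best_weights otherwise, and it records the restored checkpoint in trainer.tested_ckpt_path. A hedged usage sketch, assuming the repo's EvalModelTemplate test helper and a completed .fit() run:

    from pytorch_lightning import Trainer
    from tests.base import EvalModelTemplate

    model = EvalModelTemplate()
    trainer = Trainer(max_epochs=1)
    trainer.fit(model)

    trainer.test(ckpt_path='best')   # restores checkpoint_callback.best_model_path, then tests
    print(trainer.tested_ckpt_path)  # path of the checkpoint that was restored

    trainer.test(ckpt_path=None)     # tests with the weights left in memory after .fit()
    trainer.test(model)              # tests the given model directly via __test_given_model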

tests/loggers/test_wandb.py  (+1, -1; whitespace-only change)

@@ -25,7 +25,7 @@ def test_wandb_logger(wandb):
         {'test': 'None', 'nested/a': 1, 'b': [2, 3, 4]},
         allow_val_change=True,
     )
-
+
     logger.watch('model', 'log', 10)
     wandb.init().watch.assert_called_once_with('model', log='log', log_freq=10)

tests/models/data/horovod/train_default_model.py  (-1)

@@ -88,7 +88,6 @@ def run_test_from_config(trainer_options):
     assert trainer.root_gpu == hvd.local_rank()


-
 if __name__ == "__main__":
     args = parser.parse_args()
     run_test_from_config(json.loads(args.trainer_options))

tests/models/test_tpu.py  (+5, -9)

@@ -141,11 +141,7 @@ def long_train_loader():


 @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
-@pytest.mark.parametrize(['tpu_cores', 'expected_device'], [
-    pytest.param([1], 'xla:1'),
-    pytest.param([8], 'xla:8'),
-])
-def test_early_stop_checkpoints_on_tpu(tmpdir, tpu_cores, expected_device):
+def test_early_stop_checkpoints_on_tpu(tmpdir):
     """Test if single TPU core training works"""
     model = EvalModelTemplate()
     trainer = Trainer(
@@ -155,10 +151,10 @@ def test_early_stop_checkpoints_on_tpu(tmpdir, tpu_cores, expected_device):
         max_epochs=50,
         limit_train_batches=10,
         limit_val_batches=10,
-        tpu_cores=tpu_cores,
+        tpu_cores=[1],
     )
     trainer.fit(model)
-    assert torch_xla._XLAC._xla_get_default_device() == expected_device
+    assert torch_xla._XLAC._xla_get_default_device() == 'xla:1'


 @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
@@ -172,10 +168,10 @@ def test_early_stop_checkpoints_on_tpu(tmpdir):
         max_epochs=50,
         limit_train_batches=10,
         limit_val_batches=10,
-        tpu_cores=1,
+        tpu_cores=[8],
     )
     trainer.fit(model)
-    assert torch_xla._XLAC._xla_get_default_device() == 'xla:1'
+    assert torch_xla._XLAC._xla_get_default_device() == 'xla:8'


 @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")

tests/trainer/test_trainer.py  (+4, -13)

@@ -562,16 +562,7 @@ def test_testpass_overrides(tmpdir):
 def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k):
     hparams = EvalModelTemplate.get_default_hparams()

-    loaded_checkpoint_path = ''
-
-    class TestBestModel(EvalModelTemplate):
-        @classmethod
-        def load_from_checkpoint(cls, checkpoint_path, *args, **kwargs):
-            nonlocal loaded_checkpoint_path
-            loaded_checkpoint_path = checkpoint_path
-            return super().load_from_checkpoint(checkpoint_path, *args, **kwargs)
-
-    model = TestBestModel(**hparams)
+    model = EvalModelTemplate(**hparams)
     trainer = Trainer(
         max_epochs=2,
         progress_bar_refresh_rate=0,
@@ -586,12 +577,12 @@ def load_from_checkpoint(cls, checkpoint_path, *args, **kwargs):
                 trainer.test(ckpt_path=ckpt_path)
         else:
             trainer.test(ckpt_path=ckpt_path)
-            assert loaded_checkpoint_path == trainer.checkpoint_callback.best_model_path
+            assert trainer.tested_ckpt_path == trainer.checkpoint_callback.best_model_path
     elif ckpt_path is None:
         # ckpt_path is None, meaning we don't load any checkpoints and
         # use the weights from the end of training
         trainer.test(ckpt_path=ckpt_path)
-        assert loaded_checkpoint_path == ''
+        assert trainer.tested_ckpt_path is None
     else:
         # specific checkpoint, pick one from saved ones
         if save_top_k == 0:
@@ -600,7 +591,7 @@ def load_from_checkpoint(cls, checkpoint_path, *args, **kwargs):
         else:
             ckpt_path = str(list((Path(tmpdir) / 'lightning_logs/version_0/checkpoints').iterdir())[0].absolute())
             trainer.test(ckpt_path=ckpt_path)
-            assert loaded_checkpoint_path == ckpt_path
+            assert trainer.tested_ckpt_path == ckpt_path


 def test_disabled_validation(tmpdir):
