|
2 | 2 | import logging as log
|
3 | 3 | import os
|
4 | 4 | import pickle
|
| 5 | +import functools |
5 | 6 |
|
6 | 7 | import cloudpickle
|
7 | 8 | import pytest
|
@@ -319,6 +320,105 @@ def test_model_saving_loading(tmpdir):
|
319 | 320 | assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1
|
320 | 321 |
|
321 | 322 |
|
| 323 | +@pytest.mark.parametrize('url_ckpt', [True, False]) |
| 324 | +def test_strict_model_load_more_params(monkeypatch, tmpdir, tmpdir_server, url_ckpt): |
| 325 | + """Tests use case where trainer saves the model, and user loads it from tags independently.""" |
| 326 | + # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir |
| 327 | + monkeypatch.setenv('TORCH_HOME', tmpdir) |
| 328 | + |
| 329 | + model = EvalModelTemplate() |
| 330 | + # Extra layer |
| 331 | + model.c_d3 = torch.nn.Linear(model.hidden_dim, model.hidden_dim) |
| 332 | + |
| 333 | + # logger file to get meta |
| 334 | + logger = tutils.get_default_logger(tmpdir) |
| 335 | + |
| 336 | + # fit model |
| 337 | + trainer = Trainer( |
| 338 | + default_root_dir=tmpdir, |
| 339 | + max_epochs=1, |
| 340 | + logger=logger, |
| 341 | + checkpoint_callback=ModelCheckpoint(tmpdir), |
| 342 | + ) |
| 343 | + result = trainer.fit(model) |
| 344 | + |
| 345 | + # traning complete |
| 346 | + assert result == 1 |
| 347 | + |
| 348 | + # save model |
| 349 | + new_weights_path = os.path.join(tmpdir, 'save_test.ckpt') |
| 350 | + trainer.save_checkpoint(new_weights_path) |
| 351 | + |
| 352 | + # load new model |
| 353 | + hparams_path = os.path.join(tutils.get_data_path(logger, path_dir=tmpdir), 'hparams.yaml') |
| 354 | + hparams_url = f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}' |
| 355 | + ckpt_path = hparams_url if url_ckpt else new_weights_path |
| 356 | + |
| 357 | + EvalModelTemplate.load_from_checkpoint( |
| 358 | + checkpoint_path=ckpt_path, |
| 359 | + hparams_file=hparams_path, |
| 360 | + strict=False, |
| 361 | + ) |
| 362 | + |
| 363 | + with pytest.raises(RuntimeError, match=r'Unexpected key\(s\) in state_dict: "c_d3.weight", "c_d3.bias"'): |
| 364 | + EvalModelTemplate.load_from_checkpoint( |
| 365 | + checkpoint_path=ckpt_path, |
| 366 | + hparams_file=hparams_path, |
| 367 | + strict=True, |
| 368 | + ) |
| 369 | + |
| 370 | + |
| 371 | +@pytest.mark.parametrize('url_ckpt', [True, False]) |
| 372 | +def test_strict_model_load_less_params(monkeypatch, tmpdir, tmpdir_server, url_ckpt): |
| 373 | + """Tests use case where trainer saves the model, and user loads it from tags independently.""" |
| 374 | + # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir |
| 375 | + monkeypatch.setenv('TORCH_HOME', tmpdir) |
| 376 | + |
| 377 | + model = EvalModelTemplate() |
| 378 | + |
| 379 | + # logger file to get meta |
| 380 | + logger = tutils.get_default_logger(tmpdir) |
| 381 | + |
| 382 | + # fit model |
| 383 | + trainer = Trainer( |
| 384 | + default_root_dir=tmpdir, |
| 385 | + max_epochs=1, |
| 386 | + logger=logger, |
| 387 | + checkpoint_callback=ModelCheckpoint(tmpdir), |
| 388 | + ) |
| 389 | + result = trainer.fit(model) |
| 390 | + |
| 391 | + # traning complete |
| 392 | + assert result == 1 |
| 393 | + |
| 394 | + # save model |
| 395 | + new_weights_path = os.path.join(tmpdir, 'save_test.ckpt') |
| 396 | + trainer.save_checkpoint(new_weights_path) |
| 397 | + |
| 398 | + # load new model |
| 399 | + hparams_path = os.path.join(tutils.get_data_path(logger, path_dir=tmpdir), 'hparams.yaml') |
| 400 | + hparams_url = f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}' |
| 401 | + ckpt_path = hparams_url if url_ckpt else new_weights_path |
| 402 | + |
| 403 | + class CurrentModel(EvalModelTemplate): |
| 404 | + def __init__(self): |
| 405 | + super().__init__() |
| 406 | + self.c_d3 = torch.nn.Linear(7, 7) |
| 407 | + |
| 408 | + CurrentModel.load_from_checkpoint( |
| 409 | + checkpoint_path=ckpt_path, |
| 410 | + hparams_file=hparams_path, |
| 411 | + strict=False, |
| 412 | + ) |
| 413 | + |
| 414 | + with pytest.raises(RuntimeError, match=r'Missing key\(s\) in state_dict: "c_d3.weight", "c_d3.bias"'): |
| 415 | + CurrentModel.load_from_checkpoint( |
| 416 | + checkpoint_path=ckpt_path, |
| 417 | + hparams_file=hparams_path, |
| 418 | + strict=True, |
| 419 | + ) |
| 420 | + |
| 421 | + |
322 | 422 | def test_model_pickle(tmpdir):
|
323 | 423 | model = EvalModelTemplate()
|
324 | 424 | pickle.dumps(model)
|
|
0 commit comments