Commit 31a658e

made load from checkpoint flexible (#1307)
* made load from checkpoint flexible
* made load from checkpoint flexible
* made load from checkpoint flexible
1 parent 3101712 commit 31a658e

File tree

3 files changed, +55 -8 lines changed


CHANGELOG.md

+1
@@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Changed

+- Enhanced load_from_checkpoint to also forward params to the model ([#1307](https://github.com/PyTorchLightning/pytorch-lightning/pull/1307))
 - Made `evalaute` method private >> `Trainer._evaluate(...)`. ([#1260](https://github.com/PyTorchLightning/pytorch-lightning/pull/1260))

 ### Deprecated
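
In practice, the new changelog entry means that keyword arguments which `load_from_checkpoint` does not consume itself now reach the model's `__init__`, while stored `hparams` (when present and accepted by the module) still take precedence, as the `_load_model_state` change further down shows. A minimal sketch, with hypothetical module names and checkpoint path:

    # module defined without `hparams`: the forwarded kwargs construct the model
    model = LitClassifier.load_from_checkpoint('epoch=3.ckpt', num_classes=10, learning_rate=1e-3)

    # module defined with `hparams`: the hparams stored in the checkpoint are used,
    # so no extra kwargs are needed at load time
    model = LitHparamsClassifier.load_from_checkpoint('epoch=3.ckpt')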

docs/source/weights_loading.rst

+38 -5

@@ -84,17 +84,50 @@ To save your own checkpoint call:
 Checkpoint Loading
 ------------------

-To load a model along with its weights, biases and hyperparameters use following method:
+To load a model along with its weights, biases and hyperparameters, use the following method:

 .. code-block:: python

     model = MyLightingModule.load_from_checkpoint(PATH)
     model.eval()
     y_hat = model(x)

-A LightningModule is no different than a nn.Module. This means you can load it and use it for
-predictions as you would a nn.Module.
+The above only works if you used `hparams` in your model definition:

+.. code-block:: python
+
+    class MyModel(pl.LightningModule):
+
+        def __init__(self, hparams):
+            self.hparams = hparams
+            self.l1 = nn.Linear(hparams.in_dim, hparams.out_dim)
+
+But if you don't, and instead pass in individual parameters,
+
+.. code-block:: python
+
+    class MyModel(pl.LightningModule):
+
+        def __init__(self, in_dim, out_dim):
+            self.l1 = nn.Linear(in_dim, out_dim)
+
+you can restore the model like this:
+
+.. code-block:: python
+
+    model = MyModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)
+
+
+Restoring Training State
+------------------------
+
+If you don't just want to load weights, but instead want to restore the full training state,
+do the following:
+
+.. code-block:: python
+
+    model = LitModel()
+    trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')

-.. note:: To restore the trainer state as well use
-    :meth:`pytorch_lightning.trainer.trainer.Trainer.resume_from_checkpoint`.
+    # automatically restores model, epoch, step, LR schedulers, apex, etc...
+    trainer.fit(model)
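
The new docs cover the basic cases; as a complementary sketch (the module `LitRegressor`, its dimensions, and the checkpoint path are made up for illustration, and a matching checkpoint is assumed to exist), the forwarded kwargs can be combined with the existing `map_location` argument, for example to load a GPU-trained checkpoint on a CPU-only machine:

    import torch
    from torch import nn
    import pytorch_lightning as pl

    class LitRegressor(pl.LightningModule):
        # defined with plain constructor args instead of an `hparams` Namespace
        def __init__(self, in_dim, out_dim):
            super().__init__()
            self.l1 = nn.Linear(in_dim, out_dim)

        def forward(self, x):
            return self.l1(x)

    # map_location remaps GPU-saved tensors onto CPU; in_dim/out_dim are forwarded
    # to LitRegressor.__init__, which takes no `hparams` argument
    model = LitRegressor.load_from_checkpoint(
        'path/to/checkpoint.ckpt',
        map_location=torch.device('cpu'),
        in_dim=128,
        out_dim=10,
    )
    model.eval()

Note that, given the signature shown below, extra positional arguments would bind to `map_location` and `tags_csv` first, so keyword arguments are the safer way to forward constructor parameters.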

pytorch_lightning/core/lightning.py

+16 -3

@@ -1324,6 +1324,7 @@ def load_from_checkpoint(
             checkpoint_path: str,
             map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
             tags_csv: Optional[str] = None,
+            *args, **kwargs
     ) -> 'LightningModule':
         r"""

@@ -1346,6 +1347,7 @@ def __init__(self, hparams):

         Args:
             checkpoint_path: Path to checkpoint.
+            model_args: Any keyword args needed to init the model.
             map_location:
                 If your checkpoint saved a GPU model and you now load on CPUs
                 or a different number of GPUs, use this to map to the new setup.
@@ -1387,6 +1389,14 @@ def __init__(self, hparams):
                    tags_csv='/path/to/hparams_file.csv'
                )

+                # or load passing whatever args the model takes to load
+                MyLightningModule.load_from_checkpoint(
+                    'path/to/checkpoint.ckpt',
+                    learning_rate=0.1,
+                    layers=2,
+                    pretrained_model=some_model
+                )
+
                # predict
                pretrained_model.eval()
                pretrained_model.freeze()
@@ -1403,11 +1413,11 @@ def __init__(self, hparams):
         hparams.__setattr__('on_gpu', False)
         checkpoint['hparams'] = vars(hparams)

-        model = cls._load_model_state(checkpoint)
+        model = cls._load_model_state(checkpoint, *args, **kwargs)
         return model

     @classmethod
-    def _load_model_state(cls, checkpoint: Dict[str, Any]) -> 'LightningModule':
+    def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs) -> 'LightningModule':
         cls_takes_hparams = 'hparams' in inspect.signature(cls.__init__).parameters
         ckpt_hparams = checkpoint.get('hparams')

@@ -1433,7 +1443,10 @@ def _load_model_state(cls, checkpoint: Dict[str, Any]) -> 'LightningModule':

         # load the state_dict on the model automatically
         model_args = [hparams] if hparams else []
-        model = cls(*model_args)
+        if len(model_args) > 0:
+            model = cls(*model_args)
+        else:
+            model = cls(*args, **kwargs)
         model.load_state_dict(checkpoint['state_dict'])

         # give model a chance to load something
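
To make the control flow of the patched `_load_model_state` concrete, here is a small self-contained sketch in plain Python (no Lightning dependency; `ToyModule` and `load_model_state_sketch` are made-up names) of the hparams-first, kwargs-fallback dispatch introduced above:

    import inspect
    from typing import Any, Dict


    class ToyModule:
        """Stand-in for a LightningModule, used only to illustrate the dispatch."""

        def __init__(self, in_dim: int, out_dim: int):
            self.in_dim, self.out_dim = in_dim, out_dim

        @classmethod
        def load_model_state_sketch(cls, checkpoint: Dict[str, Any], *args, **kwargs):
            # prefer hparams stored in the checkpoint (when the class accepts them),
            # otherwise fall back to the args/kwargs forwarded by the caller
            cls_takes_hparams = 'hparams' in inspect.signature(cls.__init__).parameters
            hparams = checkpoint.get('hparams') if cls_takes_hparams else None

            model_args = [hparams] if hparams else []
            if len(model_args) > 0:
                model = cls(*model_args)
            else:
                model = cls(*args, **kwargs)
            return model


    # ToyModule.__init__ has no `hparams` parameter, so the kwargs branch is taken
    model = ToyModule.load_model_state_sketch({'state_dict': {}}, in_dim=128, out_dim=10)
    assert model.in_dim == 128 and model.out_dim == 10

The real method also converts and applies the stored hyperparameters and loads the state dict; only the constructor dispatch is shown here.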

0 commit comments