Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

made load from checkpoint flexible #1307

Merged
merged 3 commits into from
Mar 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Enhanced load_from_checkpoint to also forward params to the model ([#1307](https://github.com/PyTorchLightning/pytorch-lightning/pull/1307))
- Made `evalaute` method private >> `Trainer._evaluate(...)`. ([#1260](https://github.com/PyTorchLightning/pytorch-lightning/pull/1260))

### Deprecated
Expand Down
43 changes: 38 additions & 5 deletions docs/source/weights_loading.rst
Original file line number Diff line number Diff line change
Expand Up @@ -84,17 +84,50 @@ To save your own checkpoint call:
Checkpoint Loading
------------------

To load a model along with its weights, biases and hyperparameters use following method:
To load a model along with its weights, biases and hyperparameters use following method.

.. code-block:: python

model = MyLightingModule.load_from_checkpoint(PATH)
model.eval()
y_hat = model(x)

A LightningModule is no different than a nn.Module. This means you can load it and use it for
predictions as you would a nn.Module.
The above only works if you used `hparams` in your model definition

.. code-block:: python

class MyModel(pl.LightningModule):

def __init__(self, hparams):
self.hparams = hparams
self.l1 = nn.Linear(hparams.in_dim, hparams.out_dim)

But if you don't and instead pass individual parameters

.. code-block:: python

class MyModel(pl.LightningModule):

def __init__(self, in_dim, out_dim):
self.l1 = nn.Linear(in_dim, out_dim)

you can restore the model like this

.. code-block:: python

model = MyModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)


Restoring Training State
------------------------

If you don't just want to load weights, but instead restore the full training,
do the following:

.. code-block:: python

model = LitModel()
trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')

.. note:: To restore the trainer state as well use
:meth:`pytorch_lightning.trainer.trainer.Trainer.resume_from_checkpoint`.
# automatically restores model, epoch, step, LR schedulers, apex, etc...
trainer.fit(model)
19 changes: 16 additions & 3 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -1324,6 +1324,7 @@ def load_from_checkpoint(
checkpoint_path: str,
map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
tags_csv: Optional[str] = None,
*args, **kwargs
) -> 'LightningModule':
r"""

Expand All @@ -1346,6 +1347,7 @@ def __init__(self, hparams):

Args:
checkpoint_path: Path to checkpoint.
model_args: Any keyword args needed to init the model.
map_location:
If your checkpoint saved a GPU model and you now load on CPUs
or a different number of GPUs, use this to map to the new setup.
Expand Down Expand Up @@ -1387,6 +1389,14 @@ def __init__(self, hparams):
tags_csv='/path/to/hparams_file.csv'
)

# or load passing whatever args the model takes to load
MyLightningModule.load_from_checkpoint(
'path/to/checkpoint.ckpt',
learning_rate=0.1,
layers=2,
pretrained_model=some_model
)

# predict
pretrained_model.eval()
pretrained_model.freeze()
Expand All @@ -1403,11 +1413,11 @@ def __init__(self, hparams):
hparams.__setattr__('on_gpu', False)
checkpoint['hparams'] = vars(hparams)

model = cls._load_model_state(checkpoint)
model = cls._load_model_state(checkpoint, *args, **kwargs)
return model

@classmethod
def _load_model_state(cls, checkpoint: Dict[str, Any]) -> 'LightningModule':
def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs) -> 'LightningModule':
cls_takes_hparams = 'hparams' in inspect.signature(cls.__init__).parameters
ckpt_hparams = checkpoint.get('hparams')

Expand All @@ -1433,7 +1443,10 @@ def _load_model_state(cls, checkpoint: Dict[str, Any]) -> 'LightningModule':

# load the state_dict on the model automatically
model_args = [hparams] if hparams else []
model = cls(*model_args)
if len(model_args) > 0:
model = cls(*model_args)
else:
model = cls(*args, **kwargs)
model.load_state_dict(checkpoint['state_dict'])

# give model a chance to load something
Expand Down