From 2324f6ff8114f441841fef7eb0ce84a2042677b0 Mon Sep 17 00:00:00 2001 From: Jacob Zhong Date: Mon, 27 Apr 2020 11:41:35 -0400 Subject: [PATCH 1/3] change module params to dict --- pytorch_lightning/core/lightning.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index a1f3eb4e9252c..3c9f2c7e77c4f 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1548,8 +1548,9 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs) -> 'Ligh ) # load the state_dict on the model automatically - model_args = [hparams] if hparams else [] - model = cls(*model_args, *args, **kwargs) + if hparams: + kwargs.update(hparams=hparams) + model = cls(*args, **kwargs) model.load_state_dict(checkpoint['state_dict']) # give model a chance to load something From 211f49516c83fd9f28bb8775b4da807fbba05603 Mon Sep 17 00:00:00 2001 From: Jacob Zhong Date: Wed, 29 Apr 2020 12:08:54 -0400 Subject: [PATCH 2/3] tiny change --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 3c9f2c7e77c4f..3e504b1202418 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1548,7 +1548,7 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs) -> 'Ligh ) # load the state_dict on the model automatically - if hparams: + if hparams: # only when the model has `hparams` argument kwargs.update(hparams=hparams) model = cls(*args, **kwargs) model.load_state_dict(checkpoint['state_dict']) From 0b04a858bb7204235f8ddf50b311a3a31404740a Mon Sep 17 00:00:00 2001 From: Jacob Zhong Date: Wed, 29 Apr 2020 12:35:55 -0400 Subject: [PATCH 3/3] reverse --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 3e504b1202418..3c9f2c7e77c4f 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1548,7 +1548,7 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs) -> 'Ligh ) # load the state_dict on the model automatically - if hparams: # only when the model has `hparams` argument + if hparams: kwargs.update(hparams=hparams) model = cls(*args, **kwargs) model.load_state_dict(checkpoint['state_dict'])