Hello,
I ran into a few issues while using the trainer.test() function. While debugging, I found that the problem lies in the _load_model_state class method, which is called by load_from_checkpoint.
Code For reference
@classmethod
def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs):
    # pass in the values we saved automatically
    if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
        model_args = {}
        # add some back compatibility, the actual one shall be last
        for hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS + (cls.CHECKPOINT_HYPER_PARAMS_KEY,):
            if hparam_key in checkpoint:
                model_args.update(checkpoint[hparam_key])
        if cls.CHECKPOINT_HYPER_PARAMS_TYPE in checkpoint:
            model_args = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_TYPE](model_args)
        args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
        init_args_name = inspect.signature(cls).parameters.keys()
        if args_name == 'kwargs':
            cls_kwargs = {k: v for k, v in model_args.items() if k in init_args_name}
            kwargs.update(**cls_kwargs)
        elif args_name:
            if args_name in init_args_name:
                kwargs.update({args_name: model_args})
        else:
            args = (model_args, ) + args
    # load the state_dict on the model automatically
    model = cls(*args, **kwargs)
    model.load_state_dict(checkpoint['state_dict'])
    # give model a chance to load something
    model.on_load_checkpoint(checkpoint)
    return model
Consider the case where the model takes no arguments, which corresponds to LightModel.load_from_checkpoint('path'). Here the else clause of the if-elif is executed, and the args variable changes from an empty tuple to a tuple containing an empty dictionary: args = (model_args, ) + args (since model_args = {}). As a result, when args and kwargs are unpacked in model = cls(*args, **kwargs), an extra positional argument is passed, which raises TypeError: __init__() takes 1 positional argument but 2 were given (see #2364).
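The failure described above can be reproduced without Lightning at all. The following is a minimal sketch, assuming a stand-in class NoArgModel and a hypothetical helper load_like_checkpoint that only mirrors the if/elif/else branch of the quoted method; neither name exists in pytorch_lightning.

```python
import inspect


class NoArgModel:
    """Stand-in for a LightningModule whose __init__ takes no arguments."""

    def __init__(self):
        self.ready = True


def load_like_checkpoint(cls, model_args, args_name, *args, **kwargs):
    """Hypothetical helper mirroring the argument handling quoted above."""
    init_args_name = inspect.signature(cls).parameters.keys()
    if args_name == 'kwargs':
        kwargs.update({k: v for k, v in model_args.items() if k in init_args_name})
    elif args_name:
        if args_name in init_args_name:
            kwargs.update({args_name: model_args})
    else:
        # args_name is None: model_args is always prepended, even when it is {}
        args = (model_args,) + args
    return cls(*args, **kwargs)


error_message = None
try:
    # No saved hyperparameters and no recorded argument name
    load_like_checkpoint(NoArgModel, {}, None)
except TypeError as err:
    error_message = str(err)
    print(error_message)
```

The empty dictionary ends up as the only positional argument, so the call is equivalent to NoArgModel({}), which is what triggers the TypeError.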
In other cases, if the model has an argument and the user forgets to pass it to load_from_checkpoint, an empty dictionary is passed in its place, and the resulting error depends on the model code. For example, in issue #2359 an empty dict is passed while loading the model, which raises RuntimeError: Error(s) in loading state_dict for Model:.
I do not fully understand what the function is doing. It would be great if someone could suggest in the comments which changes to make, so that I can start working on them in my fork.
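As a starting point for discussion, one possible guard is to prepend the saved hyperparameters only when the class constructor can actually accept a positional argument. This is a sketch under that assumption, not necessarily the change that should go into the library; accepts_positional, build_model, NoArgModel and HparamsModel are hypothetical names used only for illustration.

```python
import inspect


def accepts_positional(cls):
    """Return True if cls can be called with at least one positional argument."""
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.VAR_POSITIONAL,
    )
    params = inspect.signature(cls).parameters.values()
    return any(p.kind in positional_kinds for p in params)


def build_model(cls, model_args, *args, **kwargs):
    """Hypothetical replacement for the bare `else` branch."""
    if accepts_positional(cls):
        # Only pass the saved hyperparameters when __init__ has room for them
        args = (model_args,) + args
    return cls(*args, **kwargs)


class NoArgModel:
    def __init__(self):
        self.ready = True


class HparamsModel:
    def __init__(self, hparams):
        self.hparams = hparams


no_arg = build_model(NoArgModel, {})
with_hparams = build_model(HparamsModel, {'lr': 0.02})
```

With this guard the no-argument model is constructed without the extra dictionary, while a model that takes a single hparams argument still receives it. It does not address the second case, where a required argument was never saved in the checkpoint.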
Steps to reproduce
!pip install git+https://github.com/PytorchLightning/pytorch-lightning.git@master --upgrade
import os

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl


class MNISTModel(pl.LightningModule):

    def __init__(self):
        super(MNISTModel, self).__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def test_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)


train_loader = DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
mnist_model = MNISTModel()
trainer = pl.Trainer(gpus=1, max_epochs=3)
trainer.fit(mnist_model, train_loader)
test_loader = DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()), batch_size=32)
trainer.test(test_dataloaders=test_loader)
Which returns:
TypeError Traceback (most recent call last)
<ipython-input-5-50449ee4f6cc> in <module>()
1 test_loader = DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()), batch_size=32)
----> 2 trainer.test(test_dataloaders=test_loader)
/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/trainer.py in test(self, model, test_dataloaders, ckpt_path)
1168 if ckpt_path == 'best':
1169 ckpt_path = self.checkpoint_callback.best_model_path
-> 1170 model = self.get_model().load_from_checkpoint(ckpt_path)
1171
1172 self.testing = True
/usr/local/lib/python3.6/dist-packages/pytorch_lightning/core/saving.py in load_from_checkpoint(cls, checkpoint_path, map_location, hparams_file, tags_csv, *args, **kwargs)
167 checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)
168
--> 169 model = cls._load_model_state(checkpoint, *args, **kwargs)
170 return model
171
/usr/local/lib/python3.6/dist-packages/pytorch_lightning/core/saving.py in _load_model_state(cls, checkpoint, *cls_args, **cls_kwargs)
201
202 # load the state_dict on the model automatically
--> 203 model = cls(*cls_args, **cls_kwargs)
204 model.load_state_dict(checkpoint['state_dict'])
205
TypeError: __init__() takes 1 positional argument but 2 were given
Expected behavior
Start testing
Hi! thanks for your contribution!, great first issue!
nischal-sanil changed the title from "An Extra argument passed to the class loaded from load_from_checkpoint." to "An Extra argument passed to the class, loaded from load_from_checkpoint." on Jun 27, 2020