Relax hparams in model saving/loading #907

Closed
polars05 opened this issue Feb 21, 2020 · 8 comments · Fixed by #919
Labels
question (Further information is requested)

Comments

@polars05

I've managed to train a model using trainer.fit(model) and have the .ckpt file. Now, I'm trying to load the .ckpt file so that I can do inference on a single image:

model = CoolSystem()
to_infer = torch.load('checkpoints/try_ckpt_epoch_1_v0.ckpt')
model.load_from_checkpoint(to_infer) # ------------- error is thrown at this line

However, upon loading the .ckpt file, the following error is thrown:

AttributeError: 'dict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.

Am I doing something wrong when using PyTorch Lightning for inference?

For reference, this is my system:

import pytorch_lightning as pl

import os
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F

class CoolSystem(pl.LightningModule):

    def __init__(self):
        super(CoolSystem, self).__init__()

        # self.hparams = hparams
        self.data_dir = '/content/hymenoptera_data'

        self.model = torchvision.models.resnet18(pretrained=True) # final layer is of size [bs, 1000]
        num_ftrs = self.model.fc.in_features
        self.model.fc = torch.nn.Linear(num_ftrs, 2) # change final layer to be of size [bs, 2]
        
    def forward(self, x):
        x = self.model(x)
        return x
    
    def configure_optimizers(self):
        # Observe that all parameters are being optimized
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)

        # Decay LR by a factor of 0.1 every 7 epochs
        exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
 
        return [optimizer], [exp_lr_scheduler]

    def training_step(self, batch, batch_idx):
        # REQUIRED
        x, y = batch
        y_hat = self.forward(x)
        loss = F.cross_entropy(y_hat, y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}
    
    def validation_step(self, batch, batch_idx):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}

    @pl.data_loader
    def train_dataloader(self):
        # REQUIRED

        transform = transforms.Compose([
                                transforms.RandomResizedCrop(224),
                                transforms.RandomHorizontalFlip(),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                ])

        train_set = torchvision.datasets.ImageFolder(os.path.join(self.data_dir, 'train'), transform)
        train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)

        return train_loader
    
    @pl.data_loader
    def val_dataloader(self):
        transform = transforms.Compose([
                                transforms.Resize(256),
                                transforms.CenterCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                ])

        val_set = torchvision.datasets.ImageFolder(os.path.join(self.data_dir, 'val'), transform)
        val_loader = torch.utils.data.DataLoader(val_set, batch_size=32, shuffle=True, num_workers=4)

        return val_loader

And I'm training it this way:

model = CoolSystem() 

import os

checkpoint_callback = pl.callbacks.ModelCheckpoint(
          filepath=os.path.join(os.getcwd(), 'checkpoints'),
          verbose=True,
          monitor='val_loss', 
          mode='min', 
          prefix='try',
          save_top_k=-1,
          period=1 # run the checkpoint callback every epoch; with save_top_k=-1, every epoch's checkpoint is kept
      )

trainer = pl.Trainer(
      max_epochs=2,
      checkpoint_callback=checkpoint_callback)  

trainer.fit(model)
@polars05 polars05 added the question label Feb 21, 2020
@github-actions
Contributor

Hey, thanks for your contribution! Great first issue!

@awaelchli
Contributor

Have not tested it, but I think it should be
model.load_from_checkpoint('checkpoints/try_ckpt_epoch_1_v0.ckpt')
(the method takes a string).
See docs:
https://pytorch-lightning.readthedocs.io/en/0.6.0/pytorch_lightning.core.html#pytorch_lightning.core.LightningModule.load_from_checkpoint
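
Untested sketch of what single-image inference could then look like (note that load_from_checkpoint is a classmethod and returns a new model instance, so keep its return value; image_tensor here is assumed to be an already preprocessed [3, 224, 224] tensor):

import torch

# load_from_checkpoint is a classmethod: call it on the class and keep the result
model = CoolSystem.load_from_checkpoint('checkpoints/try_ckpt_epoch_1_v0.ckpt')
model.eval()

with torch.no_grad():
    logits = model(image_tensor.unsqueeze(0))  # add a batch dimension
    pred = logits.argmax(dim=1)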

@polars05
Author

polars05 commented Feb 22, 2020

After trying model.load_from_checkpoint('checkpoints/try_ckpt_epoch_1_v0.ckpt'), the following error is now thrown:

OSError: Checkpoint does not contain hyperparameters. Are your model hyperparameters stored in self.hparams?

I built CoolSystem() without self.hparams, as per the example Colab notebook (https://colab.research.google.com/drive/1F_RNcHzTfFuQf-LeKvSlud6x7jXYkG31#scrollTo=HOk9c4_35FKg)

Any advice on this?

@awaelchli
Contributor

awaelchli commented Feb 22, 2020

I guess it should be added to that example. The GAN example has it.
Add the hparams to the __init__, train it, and then try to load again.
Looks like it is always needed, even if you don't pass any hparams in.
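
Roughly like this (untested sketch; the Namespace can stay empty if you have nothing to store, and the rest of __init__ is unchanged):

from argparse import Namespace
import pytorch_lightning as pl

class CoolSystem(pl.LightningModule):

    def __init__(self, hparams=None):
        super(CoolSystem, self).__init__()
        # keeping the hyperparameters on self.hparams is what the
        # checkpoint save/load machinery looks for, even when empty
        self.hparams = hparams if hparams is not None else Namespace()
        # ... rest of __init__ (model, data_dir, etc.) as before

model = CoolSystem(Namespace(lr=0.001, batch_size=32))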

@polars05
Author

Got it! Will take note to always add hparams to __init__ then

@williamFalcon
Contributor

@awaelchli mind submitting a PR to fix?

I think the point was for hparams to be optional? Or should we make it more flexible? @neggert

@awaelchli
Contributor

I can look at it.
To make it optional, I guess we could simply change the loading behaviour depending on whether the user has defined hparams or not.
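
Something like this, maybe (hypothetical sketch, not the actual Lightning implementation; 'hparams' and 'state_dict' are assumed to be the keys the 0.6-era checkpoint dict uses):

from argparse import Namespace
import torch

def load_relaxed(cls, checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
    if 'hparams' in checkpoint:
        # checkpoint was saved from a model that defined self.hparams
        model = cls(Namespace(**checkpoint['hparams']))
    else:
        # no hyperparameters stored: fall back to a no-argument constructor
        model = cls()
    model.load_state_dict(checkpoint['state_dict'])
    return model

# e.g. model = load_relaxed(CoolSystem, 'checkpoints/try_ckpt_epoch_1_v0.ckpt')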

@williamFalcon williamFalcon reopened this Feb 22, 2020
@williamFalcon williamFalcon changed the title from "Loading model trained with pl.fit(model) for inference" to "Relax hparams in model saving/loading" Feb 22, 2020
@awaelchli
Contributor

I will hold back until #849 is finalized because it affects the ModelCheckpoint callback.
