EarlyStopping reinitializes to .wait=0 even with Trainer resume_from_checkpoint #1463

Closed
lizhitwo opened this issue Apr 12, 2020 · 5 comments · Fixed by #2391
lizhitwo commented Apr 12, 2020

🐛 Bug

When using Trainer's resume_from_checkpoint together with the EarlyStopping callback, the callback's patience progress (i.e. self.wait) is loaded from the checkpoint, but is then reset by its on_train_start method, making the checkpoint restoration moot.

Also, EarlyStopping's .best is not saved or restored at all, which makes resuming it even less usable.
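
As a stopgap, one possible workaround (my own sketch, not something the library provides) is to cache the callback's counters before the parent on_train_start resets them and re-apply them afterwards; the attribute names wait and best are assumptions about the callback's internals in the version installed below:

# Hypothetical workaround sketch: preserve EarlyStopping state across on_train_start.
# Assumes the callback exposes `wait` and (optionally) `best` attributes.
class StatefulEarlyStopping(EarlyStopping):

    def on_train_start(self, trainer, pl_module):
        saved_wait = self.wait
        saved_best = getattr(self, 'best', None)
        super().on_train_start(trainer, pl_module)  # this is what resets the counters
        self.wait = saved_wait                      # restore patience progress
        if saved_best is not None:
            self.best = saved_best                  # restore best monitored value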

To Reproduce

Steps to reproduce the behavior:
Install using pip install git+https://github.com/PyTorchLightning/pytorch-lightning.git@master --upgrade

import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

import pytorch_lightning as pl

class CoolSystem(pl.LightningModule):

    def __init__(self):
        super(CoolSystem, self).__init__()
        # not the best model...
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        # REQUIRED
        x, y = batch
        y_hat = self.forward(x)
        return {'loss': F.cross_entropy(y_hat, y)}

    def validation_step(self, batch, batch_nb):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_epoch_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': avg_loss}

    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning_rate schedulers
        return torch.optim.Adam(self.parameters(), lr=0.02)

    def train_dataloader(self):
        # REQUIRED
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

    def val_dataloader(self):
        # OPTIONAL
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

model = CoolSystem()

checkpoint_callback = ModelCheckpoint(
    filepath='./model_ckpt/whatever_the_name_is_gonna_be_auto_chosen',
    save_top_k=-1,
    verbose=True,
    monitor='val_loss',
    mode='auto'
)

class EarlyStoppingPrinting(EarlyStopping):

    def on_train_start(self, trainer, pl_module):
        print('EarlyStoppingPrinting before on_train_start')
        print('self.wait = ', self.wait)
        super().on_train_start(trainer, pl_module)
        print('EarlyStoppingPrinting after on_train_start')
        print('self.wait = ', self.wait)

    def on_epoch_end(self, trainer, pl_module):
        ret = super().on_epoch_end(trainer, pl_module)
        if self.wait:
            print('Early stopping patience: %d/%d' % (self.patience-self.wait, self.patience))
        return ret


early_stopping = EarlyStoppingPrinting(
    monitor='val_loss',
    patience=5,
    verbose=True,
    mode='auto'
)

trainer = Trainer(max_nb_epochs=1000, train_percent_check=0.1, 
                  checkpoint_callback=checkpoint_callback, 
                  early_stop_callback=early_stopping)

trainer.fit(model)

Then interrupt training with KeyboardInterrupt while early_stopping.wait > 0. Load the corresponding checkpoint (let's say it's model_ckpt/_ckpt_epoch_5.ckpt) and resume with

trainer = Trainer(max_nb_epochs=1000, train_percent_check=0.1, 
                  checkpoint_callback=None, 
                  resume_from_checkpoint = 'model_ckpt/_ckpt_epoch_5.ckpt',
                  early_stop_callback=early_stopping)
trainer.fit(model)

The early_stopping callback would print:

EarlyStoppingPrinting before on_train_start
self.wait =  2
EarlyStoppingPrinting after on_train_start
self.wait =  0

As for self.best, it is not saved at all; do I need to write that code myself?

Expected behavior

The checkpointed value of self.wait should be preserved rather than reset:

EarlyStoppingPrinting before on_train_start
self.wait =  2
EarlyStoppingPrinting after on_train_start
self.wait =  2

And self.best should be saved and loaded from the checkpoint.
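
Until that happens, here is a hedged sketch of how the missing state could be carried through the checkpoint by hand, using the LightningModule on_save_checkpoint / on_load_checkpoint hooks. The key name 'early_stopping_state' and the reference to the early_stopping instance defined above are my own choices for illustration, not part of the library:

# Hypothetical sketch: manually persist EarlyStopping state in the checkpoint dict.
class CoolSystemWithESState(CoolSystem):

    def on_save_checkpoint(self, checkpoint):
        checkpoint['early_stopping_state'] = {
            'wait': early_stopping.wait,
            'best': getattr(early_stopping, 'best', None),
        }

    def on_load_checkpoint(self, checkpoint):
        state = checkpoint.get('early_stopping_state', {})
        early_stopping.wait = state.get('wait', 0)
        if state.get('best') is not None:
            early_stopping.best = state['best']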

Environment

This was run on Google Colab.
https://colab.research.google.com/drive/1ZdiFf6ksNpgsqOdSKM6lMO0yIhqpnTHD

Additional context

From reading the tutorials, it is unclear which member variables of the model Lightning saves into checkpoints -- they imply a wide range of things is saved, but what is actually saved is quite specific.

Also confusingly, there are several ways to restore a checkpoint (the model's load_from_checkpoint method, the trainer's resume_from_checkpoint parameter, and test_tube). These are not well documented (at least I didn't find this page before searching GitHub), and I have no idea whether I used the right one.
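
For what it's worth, my current understanding of the two main restore paths, using the checkpoint path from the reproduction above (this is how I read the docs, not an authoritative description):

# 1) Restore only the model weights -- no trainer or callback state:
model = CoolSystem.load_from_checkpoint('model_ckpt/_ckpt_epoch_5.ckpt')

# 2) Restore the full training state (epoch, optimizers, callbacks) and continue:
trainer = Trainer(resume_from_checkpoint='model_ckpt/_ckpt_epoch_5.ckpt')
trainer.fit(model)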

@lizhitwo added the bug (Something isn't working) and help wanted (Open to be worked on) labels on Apr 12, 2020
@github-actions
Contributor

Hi! Thanks for your contribution, great first issue!

@shijie-wu

@stale

stale bot commented Jun 15, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

The stale bot added the won't fix (This will not be worked on) label on Jun 15, 2020
@jeremyjordan removed the won't fix (This will not be worked on) and help wanted (Open to be worked on) labels on Jun 15, 2020
@williamFalcon added the feature (Is an improvement or enhancement) label and removed the bug (Something isn't working) label on Jun 26, 2020
@williamFalcon
Contributor

@jeremyjordan is this being added to #1504?

@jeremyjordan
Contributor

@williamFalcon yes, this is fixed and there is a test to prevent regressions.
