
TrainResult/EvalResult does not log properly with on_step=True and on_epoch=True #2972

Closed
sykrn opened this issue Aug 14, 2020 · 4 comments · Fixed by #2986
Labels
bug Something isn't working help wanted Open to be worked on

Comments

sykrn commented Aug 14, 2020

🐛 Bug

To Reproduce

Steps to reproduce the behavior:

Here is the minimal code in Colab: here

OR:

Code sample

import os

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms

import pytorch_lightning as pl
from pytorch_lightning import TrainResult, EvalResult
from pytorch_lightning.metrics.functional import accuracy

class MNISTModel(pl.LightningModule):

    def __init__(self):
        super(MNISTModel, self).__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = accuracy(y_hat, y)
        result = TrainResult(minimize=loss)
        result.log('tr_loss', loss, prog_bar=True, on_step=True, on_epoch=True)
        result.log('tr_acc', acc, prog_bar=True, on_step=True, on_epoch=True)
        return result

    def validation_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = accuracy(y_hat, y)
        result = EvalResult(checkpoint_on=loss, early_stop_on=loss)
        result.log('val_loss', loss, prog_bar=True, on_step=True, on_epoch=True)
        result.log('val_acc', acc, prog_bar=True, on_step=True, on_epoch=True)
        return result


    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=0.02)

train_loader = DataLoader(
    MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()),
    shuffle=True, batch_size=32)
val_loader = DataLoader(
    MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()),
    batch_size=32)

mnist_model = MNISTModel()
trainer = pl.Trainer(gpus=1, progress_bar_refresh_rate=20, max_epochs=5)
trainer.fit(mnist_model, train_loader, val_loader)

Expected behavior

The step_val_loss graph in TensorBoard should contain $n_{\text{batches}} \times n_{\text{epochs}}$ points (one per step), but instead it shows only about as many points as there are epochs (just a few of them).
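For concreteness, here is a back-of-the-envelope count of how many points each series should contain. This is a plain-Python sketch (no Lightning needed); it assumes the standard 60,000-image MNIST training split and the batch_size/max_epochs used in the repro above.

```python
# With on_step=True, the step-level series should have one point per
# optimizer step, i.e. batches_per_epoch * max_epochs points in total.
# With on_epoch=True, the epoch-level series should instead have one
# aggregated point per epoch.
import math

train_size, batch_size, max_epochs = 60_000, 32, 5
batches_per_epoch = math.ceil(train_size / batch_size)   # 1875
expected_step_points = batches_per_epoch * max_epochs    # 9375
expected_epoch_points = max_epochs                       # 5

print(expected_step_points, expected_epoch_points)       # 9375 5
```

So a step-level chart that shows only around five points is behaving like the epoch-level one, which is the mismatch reported here.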

Environment

The script above was run with PL version 0.9.0rc12; I used the master version, installed via:

!pip install git+https://github.com/PytorchLightning/pytorch-lightning.git@master --upgrade

Additional context

In another experiment, I found that step_tr_loss is also not logged properly (it behaves as if only on_epoch=True were in effect, with different values).

I hope someone can help with this problem, or point out a logical error in my code. I always upgrade PL to the latest master. :D

@sykrn sykrn added bug Something isn't working help wanted Open to be worked on labels Aug 14, 2020
@justusschock (Member)

cc @williamFalcon

@francoisruty

I have a similar problem

def training_step(self, batch, batch_nb):
    loss = ...
    result = TrainResult(minimize=loss)
    result.log('loss', loss, prog_bar=False, on_step=True, on_epoch=True)
    return result

yields nothing in the TensorBoard log dir except one data point at step 49; it's really weird
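One thing worth checking for the single point at step 49 (an assumption on my part, not confirmed in this thread): Trainers of this era only wrote logger rows every N global steps (an interval defaulting to 50), so a short run would produce exactly one point, at zero-based step 49. A minimal sketch of that sampling policy:

```python
# Hypothetical reconstruction of an every-N-steps row-logging policy,
# not Lightning's actual code. With a 0-based global step counter and
# interval 50, the first logged step is 49 -- matching the single data
# point reported above.
def logged_steps(total_steps: int, interval: int = 50):
    return [s for s in range(total_steps) if (s + 1) % interval == 0]

print(logged_steps(60))    # [49]
print(logged_steps(120))   # [49, 99]
```

If that is the cause, a run shorter than the interval would log nothing at all at step granularity.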

@williamFalcon (Contributor)

ummmm... that's weird. i'll check this out


sykrn commented Aug 15, 2020

To be more specific, here is a screenshot of the eval step (from my code above). Both the accuracy and the loss from EvalResult show the same problem.

[screenshot: step-level EvalResult metrics]

Compare with step_tr_loss (the correct one):

[screenshot: step_tr_loss]

In another case:

I also found the opposite inconsistency: a TrainResult that should record step values logged only a few of them, while the EvalResult logged the step values correctly (the flipped version of the case I posted here).

Sometimes it also did not log anything at all (yet another case).

williamFalcon added a commit that referenced this issue Aug 15, 2020
* add val step arg to metrics

* add step metrics
ameliatqy pushed a commit to ameliatqy/pytorch-lightning that referenced this issue Aug 17, 2020