
Trainers' .fit() mimics .test() after first call to .test() + .test() doesn't print metrics #633

Closed
lmartak opened this issue Dec 17, 2019 · 1 comment · Fixed by #1017
Labels
bug Something isn't working help wanted Open to be worked on
Milestone
0.6.1
Comments

lmartak commented Dec 17, 2019

🐛 Bug

  1. After the first call to Trainer.test(), every subsequent call to Trainer.fit() runs the test loop instead of training, i.e. it mimics Trainer.test().
  2. Trainer.test() neither prints nor returns the metrics produced by LightningModule.test_end(); it returns None.

To Reproduce

Run the following code in a Python 3.6.8 environment with torch==1.3.1 and pytorch-lightning==0.5.3.2 installed:

Code sample

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

import pytorch_lightning as pl

print(torch.__version__, pl.__version__)


class TestModule(pl.LightningModule):
    """Toy module that learns to map x to sign(x)."""

    def __init__(self, bs):
        super(TestModule, self).__init__()
        self.fc = nn.Linear(2, 2)
        self.bs = bs  # batch size for the dataloaders
        self.criterion = nn.MSELoss()

    def forward(self, x):
        x = self.fc(x)
        return x
    
    def training_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self.forward(x)
        return {'loss': self.criterion(y_hat, y)}

    def test_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self.forward(x)
        return {'test_loss': self.criterion(y_hat, y)}
    
    def test_end(self, outputs):
        # Aggregate per-batch test losses into a single mean metric.
        test_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        test_metrics = {'test_loss': test_loss}
        return {'progress_bar': test_metrics, 'log': test_metrics}

    def configure_optimizers(self):
        self.optimizer = optim.Adam(self.parameters())
        return self.optimizer

    @pl.data_loader
    def train_dataloader(self):
        # Toy data: targets are the sign of the (centered) inputs.
        x = torch.rand(1000, 2) - 0.5
        y = torch.sign(x)
        ds = TensorDataset(x, y)
        dl = DataLoader(ds, batch_size=self.bs, shuffle=True)
        return dl

    @pl.data_loader
    def test_dataloader(self):
        x = torch.rand(100, 2) - 0.5
        y = torch.sign(x)
        ds = TensorDataset(x, y)
        dl = DataLoader(ds, batch_size=self.bs * 2)
        return dl


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

net = TestModule(bs=32).to(device)

epochs = 10
trainer = pl.Trainer(gpus=-1, max_nb_epochs=epochs, min_nb_epochs=epochs)

trainer.fit(net)  # trains as expected
trainer.test()    # runs the test loop; prints no metrics and returns None
trainer.fit(net)  # bug: runs the test loop instead of training
trainer.fit(net)  # bug: runs the test loop instead of training

Code output

Output of the call sequence fit -> test -> fit -> fit:

1.3.1 0.5.3.2
Epoch 10: 100%|██████████| 32/32 [00:00<00:00, 272.26batch/s, batch_nb=31, gpu=0, loss=1.009, v_nb=65]
Testing: 100%|██████████| 2/2 [00:00<00:00, 222.23batch/s]
Testing: 100%|██████████| 2/2 [00:00<00:00, 357.40batch/s]
Testing: 100%|██████████| 2/2 [00:00<00:00, 380.52batch/s]

Expected behavior

  1. Trainer.fit() should always run model training, even after the model has already been tested via Trainer.test().
  2. Trainer.test() should return the metrics produced by LightningModule.test_end(); otherwise, running the entire test dataset through the model collects no performance metrics and serves no purpose. A sketch of the expected call pattern follows.
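
For illustration, a hedged sketch of the call pattern the reporter expects, continuing from the code sample above (the exact return format is an assumption, inferred from the dict that test_end() builds):

results = trainer.test()  # should run the test loop exactly once
print(results)            # expected: something like {'test_loss': ...}; currently None
trainer.fit(net)          # should train again instead of rerunning the test loop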

Environment

PyTorch version: 1.3.1
Is debug build: No
CUDA used to build PyTorch: 10.1.243

OS: Ubuntu 18.04.3 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: Could not collect

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration: GPU 0: TITAN V
Nvidia driver version: 418.87.00
cuDNN version: Probably one of the following:
/usr/local/cuda-10.0/targets/x86_64-linux/lib/libcudnn.so.7
/usr/local/cuda-10.1/targets/x86_64-linux/lib/libcudnn.so.7
/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudnn.so.7

Versions of relevant libraries:
[pip3] botorch==0.1.4
[pip3] gpytorch==0.3.6
[pip3] numpy==1.17.4
[pip3] pytorch-lightning==0.5.3.2
[pip3] torch==1.3.1
[pip3] torchsummary==1.5.1
[pip3] torchvision==0.4.2
[conda] Could not collect
lmartak added the bug (Something isn't working) label on Dec 17, 2019
@williamFalcon
Contributor

@lmartak good catch. looks like we need to reset a few flags after test?
submit a PR?
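
A minimal, self-contained sketch of that flag-reset idea (the testing attribute and the helper name are illustrative assumptions, not the actual Trainer internals):

class SketchTrainer:
    def __init__(self):
        self.testing = False  # fit() would consult this flag to pick its mode

    def _run_test_loop(self, model):
        pass  # stands in for the real evaluation loop

    def test(self, model=None):
        self.testing = True
        try:
            self._run_test_loop(model)
        finally:
            self.testing = False  # reset so later fit() calls train again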

Borda added the help wanted (Open to be worked on) label on Jan 24, 2020
abdullah-alnahas added a commit to abdullah-alnahas/pytorch-lightning that referenced this issue on Jan 27, 2020:
After the first call to Trainer.test(), all subsequent calls to Trainer.fit() will NOT exhibit the output behavior of Trainer.test(), and Trainer.test() WILL return the metrics returned by LightningModule.test_end(). However, Trainer.test(model) will not return the metrics returned by LightningModule.test_end().
williamFalcon added this to the 0.6.1 milestone on Feb 11, 2020