
Additional dataloader created and discarded when training with reload_dataloaders_every_epoch #1181

Closed
TevenLeScao opened this issue Mar 18, 2020 · 1 comment · Fixed by #1196
Labels
bug (Something isn't working) · help wanted (Open to be worked on)

Comments

TevenLeScao (Contributor) commented Mar 18, 2020

🐛 Bug

I am training with reload_dataloaders_every_epoch, and I've noticed that an extra DataLoader is instantiated before training starts and then never used. This is a problem for me because I train on chunks that get loaded every epoch, and the extra instantiation messes with the order in which I load them, especially when resuming from a checkpoint; it would also affect anyone who generates a new dataset every epoch, since that computation is wasted. On top of that, the tqdm bar keeps the information of the first, discarded DataLoader (in the screenshot below, the number of iterations is the same for both, whereas the two datasets have different sizes).

[Screenshot: tqdm progress bars for both DataLoaders showing the same number of iterations, even though the datasets have different sizes]

To Reproduce

Run the code sample below, which runs for one epoch and displays a message every time a DataLoader is created.

A DataLoader first gets instantiated at line 286 of training_loop.py, outside the epoch loop (that is where it is normally instantiated when not reloading every epoch). Then, when reload_dataloaders_every_epoch is used, another one is created at the start of every epoch at line 386, inside the loop, so the first epoch ends up with an extra, unused one.
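For clarity, here is a rough sketch of the control flow as I understand it (simplified pseudocode, not the actual training_loop.py source; the reset_train_dataloader name is illustrative):

# simplified, illustrative sketch of the observed control flow
self.reset_train_dataloader(model)           # ~line 286: always runs before the epoch loop
for epoch in range(...):
    if self.reload_dataloaders_every_epoch:
        self.reset_train_dataloader(model)   # ~line 386: runs again at the start of every
                                             # epoch, so the pre-loop DataLoader is never used
    # ... run the epoch ...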

Code sample

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from time import sleep

class MinimalDataset(Dataset):

    def __init__(self, index):
        # dataset size grows with the reload index (64 * index rows),
        # so each reloaded DataLoader should have a different length
        self.data = torch.Tensor(64 * index, 1024)

    def __getitem__(self, item):
        return self.data[item]

    def __len__(self):
        return len(self.data)

class MinimalModule(pl.LightningModule):

    def __init__(self):
        super(MinimalModule, self).__init__()
        self.nn = torch.nn.Linear(1024, 1)
        self.current_index = 0

    def forward(self, batch):
        return self.nn(batch)

    def training_step(self, batch, batch_idx):
        sleep(0.1)
        loss = self.nn(batch)[0]
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        loss = self.nn(batch)[0]
        return {'val_loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.01)

    def train_dataloader(self):
        # REQUIRED
        # called on every (re)load; prints so each DataLoader creation is visible
        self.current_index += 1
        print(f"initializing DataLoader n{self.current_index}")
        data_loader = DataLoader(MinimalDataset(self.current_index))
        return data_loader
    
model = MinimalModule()
trainer = pl.Trainer(reload_dataloaders_every_epoch=True, num_sanity_val_steps=0, val_check_interval=8, max_epochs=1)

trainer.fit(model)
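
For reference, with the current behavior the script above prints two initialization messages even though it runs a single epoch, and the progress bar reports 64 iterations even though the second dataset has 128 rows. The output looks roughly like this:

initializing DataLoader n1
initializing DataLoader n2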

Expected behavior

Only one DataLoader should be created; two are. The tqdm bar should show 128 iterations, since that is the dataset size the second time, but it shows 64 instead (I added the sleep(0.1) to leave time to observe this).

Environment

PyTorch version: 1.4.0
Is debug build: No
CUDA used to build PyTorch: 10.1

OS: Ubuntu 18.04.4 LTS
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
CMake version: Could not collect

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: GeForce RTX 2070 with Max-Q Design
Nvidia driver version: 435.21
cuDNN version: Could not collect

Versions of relevant libraries:
[pip3] numpy==1.18.1
[pip3] pytorch-lightning==0.7.1
[pip3] torch==1.4.0
[pip3] torchvision==0.4.2
[conda] Could not collect

Additional context
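
Until this is fixed, a user-side workaround I'm considering is to hand the extra pre-loop call a throwaway DataLoader so that it does not consume a data chunk. This is only a sketch: it assumes the duplicate call always happens exactly once, before the epoch loop, and the _skipped_extra_call flag is something I made up for illustration.

    def train_dataloader(self):
        # Workaround sketch (hypothetical): give the extra pre-loop call a throwaway
        # DataLoader so it does not advance the chunk index. Assumes the duplicate
        # call happens exactly once, before the epoch loop.
        if not getattr(self, "_skipped_extra_call", False):
            self._skipped_extra_call = True
            return DataLoader(MinimalDataset(1))  # placeholder, never actually iterated
        self.current_index += 1
        print(f"initializing DataLoader n{self.current_index}")
        return DataLoader(MinimalDataset(self.current_index))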

TevenLeScao added the bug and help wanted labels on Mar 18, 2020
@github-actions

Hi! Thanks for your contribution, great first issue!
