
Incorrect default cuda device when using single gpu other than cuda:0 #3030

Closed
manipopopo opened this issue Aug 18, 2020 · 3 comments · Fixed by #3042

@manipopopo (Contributor)

🐛 Bug

The default CUDA device is not set to trainer.root_gpu in single-GPU mode. Tensors created with device='cuda' are therefore placed on the wrong GPU, and the dataloader allocates pinned memory on the wrong GPU when pin_memory=True.

Maybe we'll need to add
torch.cuda.set_device(self.trainer.root_gpu) to https://github.com/PyTorchLightning/pytorch-lightning/blob/5dfc7b157e7febab692036b7392dac8b52f41b87/pytorch_lightning/accelerators/gpu_backend.py#L24
as DDPBackend does:

https://github.com/PyTorchLightning/pytorch-lightning/blob/5dfc7b157e7febab692036b7392dac8b52f41b87/pytorch_lightning/accelerators/ddp_backend.py#L195
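
For illustration, a minimal sketch of what such a change could look like. The surrounding GPUBackend.setup structure is approximated from the linked commit, and this is not the merged fix from #3042:

import torch


class GPUBackend(object):

  def __init__(self, trainer):
    self.trainer = trainer

  def setup(self, model):
    # Proposed addition: make the selected root GPU the default CUDA device so
    # that tensors created with device='cuda' and pinned dataloader buffers end
    # up on trainer.root_gpu instead of cuda:0.
    torch.cuda.set_device(self.trainer.root_gpu)

    model.cuda(self.trainer.root_gpu)
    return model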

To Reproduce

Running the following code produces:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!

Code sample

import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils import data


class Dataset(data.Dataset):

  def __getitem__(self, item):
    return torch.zeros(1)

  def __len__(self):
    return 5


class Model(pl.LightningModule):

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.x = nn.Parameter(torch.zeros(1))

  def forward(self, *args, **kwargs):
    return self.x

  def training_step(self, *args, **kwargs):
    return self.x + torch.zeros(1, device='cuda')  # RuntimeError.

  def train_dataloader(self):
    return data.DataLoader(Dataset(), num_workers=1, pin_memory=True)

  def configure_optimizers(self):
    return torch.optim.SGD(self.parameters(), 1.0)


if __name__ == '__main__':
  trainer = pl.Trainer(gpus=[1], num_sanity_val_steps=0, max_epochs=1)  # train on cuda:1, not the default cuda:0
  model = Model()
  trainer.fit(model)

Expected behavior

No RuntimeError occurs.

Environment

  • CUDA:
    • GPU:
    • available:
    • version:
  • Packages:
    • numpy: 1.18.5
    • pyTorch_debug: False
    • pyTorch_version: 1.6.0
    • pytorch-lightning: 0.9.0rc16
    • tensorboard: 2.3.0
    • tqdm: 4.48.2
  • System:
    • OS: Windows
    • architecture:
      • 64bit
      • WindowsPE
    • processor:
    • python: 3.7.3
    • version: 10.0.18362

Additional context
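
As a temporary user-side workaround until the fix lands (my own sketch, not part of the proposed patch), the default device can be set manually before calling trainer.fit:

import torch

# Call before trainer.fit(model); 1 matches gpus=[1] in the example above and
# also makes pin_memory allocations in the dataloader use the intended GPU.
torch.cuda.set_device(1)

Inside training_step, creating the tensor with device=self.x.device (or via type_as) instead of the hard-coded device='cuda' also avoids the device mismatch.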

@manipopopo added the bug and help wanted labels on Aug 18, 2020
@nateraw (Contributor) commented Aug 18, 2020

Directly related to #3016 I believe, perhaps a duplicate? Thanks for bringing this up 😄

@ananyahjha93 self-assigned this on Aug 18, 2020
@edenlightning added the priority: 0 label on Aug 18, 2020
@edenlightning added this to the 0.9.0 milestone on Aug 18, 2020
@ananyahjha93 (Contributor)

@nateraw I looked at #3016; it is slightly different. I'm limiting my PR #3042 to fixing this particular bug.

@kencyshaka

I am still getting the same error. Is there an update on this?
