tpu_cores=8 not working #2106

Closed
rohitgr7 opened this issue Jun 7, 2020 · 15 comments
Labels: accelerator: tpu · bug · help wanted · waiting on author

rohitgr7 (Contributor) commented Jun 7, 2020

🐛 Bug

After #2016 was fixed by PR #2033, the code runs perfectly on a single TPU core and on a specific TPU core, but it no longer works with 8 TPU cores. After training completes, I get RuntimeError: Cannot replicate if number of devices (1) is different from 8.

To Reproduce

Colab notebook
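
The notebook itself isn't reproduced here, so below is a minimal, hedged sketch of the reported setup; MyModel stands in for whatever LightningModule the notebook defines.

import pytorch_lightning as pl

model = MyModel()  # placeholder for the notebook's LightningModule

# Works: a single core, or one specific core, e.g.
#   trainer = pl.Trainer(tpu_cores=1)
#   trainer = pl.Trainer(tpu_cores=[1])

# Fails after training completes with:
#   RuntimeError: Cannot replicate if number of devices (1) is different from 8
trainer = pl.Trainer(tpu_cores=8)
trainer.fit(model)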

Expected behavior

Training should run on 8 TPU cores without error, just as it does on a single core.

Environment

  • pytorch/xla: nightly
  • pytorch-lightning: master
  • PyTorch version: 1.5
  • OS: Linux
  • How PyTorch was installed: pip
  • Python version: 3.7
rohitgr7 added the help wanted label on Jun 7, 2020
Borda added the bug and priority: 0 labels on Jun 10, 2020
Borda added this to the 0.8.0 milestone on Jun 10, 2020
Borda (Member) commented Jun 16, 2020

To Reproduce
Colab notebook

I am sorry, I can't run it as it is read-only.

@rohitgr7 mind checking the tests in #2094? All the tests there are passing; see #2094 (comment).

Borda added the waiting on author and accelerator: tpu labels on Jun 16, 2020
lezwon (Contributor) commented Jun 16, 2020

I'm seeing this issue when I run the notebook, and I'm unsure why it's happening. It's not consistent, but it occurs the majority of the time. I ran the code with an old release of Lightning too, and it still breaks, so some change somewhere is breaking it. If I run just the model with plain xmp.spawn, it seems to work fine.
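
For reference, a rough, hedged sketch of the "plain xmp.spawn" comparison: a toy model and loop, not the notebook's actual code.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    device = xm.xla_device()                    # each spawned process gets its own TPU core
    model = torch.nn.Linear(32, 1).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    for _ in range(10):
        x = torch.randn(16, 32, device=device)
        loss = model(x).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        xm.optimizer_step(optimizer)            # all-reduce the gradients, then step

if __name__ == '__main__':
    xmp.spawn(_mp_fn, nprocs=8, start_method='fork')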

rohitgr7 (Contributor, Author) commented

I have checked this on Colab only. I'll try it on Kaggle kernels and check whether it happens there too.

Borda removed the priority: 0 label on Jun 17, 2020
Borda changed the milestone from 0.8.0 to 0.8.x on Jun 17, 2020
williamFalcon (Contributor) commented

@lezwon is this still an issue? I'll reopen if it is, since this is now tested.

lezwon (Contributor) commented Jun 26, 2020

I just ran the notebook. Training seems to be working fine, but .test() fails. Calling .fit() a second time also throws the same error @rohitgr7 mentioned.

edenlightning reopened this on Jun 26, 2020
Borda (Member) commented Jun 26, 2020

Can we add a test for it so we can fix it later?

lezwon (Contributor) commented Jun 27, 2020

On it. :]

rohitgr7 (Contributor, Author) commented Jun 27, 2020

Checked on Kaggle with master.

  • When specifying the tpu_id, both training and testing work fine 👌.
  • When training on 1 core, both training and testing complete with SystemExit: 0.
  • When training on 8 cores, training gets stuck and nothing happens unless I restart the kernel. Testing throws RuntimeError: tensorflow/compiler/xla/xla_client/computation_client.cc:272 : Missing XLA configuration whenever I run it without training, since training never finishes in my case. (The three configurations are sketched below.)

@lezwon
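
For reference, a hedged sketch of the three configurations above (Trainer flags assumed from the Lightning API of that time; model construction omitted):

import pytorch_lightning as pl

# Specific core (tpu_id): training and testing both work.
trainer = pl.Trainer(tpu_cores=[1])

# Single core: training and testing complete, then exit with SystemExit: 0.
trainer = pl.Trainer(tpu_cores=1)

# 8 cores: training hangs, and .test() without a prior fit raises
# "Missing XLA configuration".
trainer = pl.Trainer(tpu_cores=8)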

rohitgr7 (Contributor, Author) commented Jul 5, 2020

I tried training a model on 8 TPU cores and checked the model's device in forward. The global_rank varies between 0 and 7, but the model and batch tensors are always on device 0 or 1.

Can anyone please check whether something is wrong in the code?

I have disabled the optimizer step just to see what's happening, and forward returns an arbitrary tensor to keep the training procedure flowing. Here is the code; the output printed during an 8-core run follows it:

class JigsawBert(pl.LightningModule):
    def __init__(self, hparams, train_df, valid_df, train_tfms, valid_tfms, tokenizer):
        super().__init__()
    
        self.hparams = hparams
        self.train_df, self.valid_df = train_df, valid_df
        self.train_tfms, self.valid_tfms = train_tfms, valid_tfms
        self.tokenizer = tokenizer
        
        self.model = JigsawBertModel(self.hparams.model_path)
    
        self.am_tloss = AverageMeter()
        self.am_vloss = AverageMeter()
        self.am_vauroc = AverageMeter()
    
    def prepare_data(self):
        self.train_ds = JigsawBertDataset(
            df=self.train_df,
            tokenizer=self.tokenizer,
            max_length=self.hparams.max_length,
            tfms=self.train_tfms,
            is_testing=False
        )
        
        self.valid_ds = JigsawBertDataset(
            df=self.valid_df,
            tokenizer=self.tokenizer,
            max_length=self.hparams.max_length,
            tfms=self.valid_tfms,
            is_testing=False
        )
        
    def setup(self, stage):
        pass

    def train_dataloader(self):
        sampler = DistributedSampler(
            self.train_ds,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True
        )
        
        dl = DataLoader(
            self.train_ds,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            sampler=sampler,
            drop_last=True
        )
        
        return dl
    
    def val_dataloader(self):
        sampler = DistributedSampler(
            self.valid_ds,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=False
        )
        
        dl = DataLoader(
            self.valid_ds,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            sampler=sampler,
            drop_last=False
        )
        
        return dl
    
    def _compute_loss(self, logits, targets):
        loss_fn = nn.BCEWithLogitsLoss()
        loss = loss_fn(logits, targets)
        return loss
    
    def forward(self, batch):
#         return self.model(batch['input_ids'], batch['token_type_ids'], batch['attn_mask'])
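        # Debugging only: skip the real model call, log device placement, and
        # return an arbitrary tensor so the training loop keeps running.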
        print(f"Model device: {next(self.model.parameters()).device}",
              f"Input device: {batch['input_ids'].device}",
              f"Input shape: {batch['input_ids'].shape}",
              f"Global rank: {self.trainer.global_rank}")

        return nn.Linear(2, 1)(torch.randn(len(batch['input_ids']), 2)).to(batch['input_ids'].device)
    
    def training_step(self, batch, batch_nb):
        logits = self.forward(batch).view(-1)
        train_loss = self._compute_loss(logits, batch['label'])
        self.am_tloss.update(train_loss.item(), len(batch['input_ids']))
        return {'loss': train_loss}
    
    def training_epoch_end(self, outputs):
        logs = {'avg_train_loss': self.am_tloss.avg}
        self.am_tloss.reset()
        return {'progress_bar': logs, 'log': logs}
    
    def validation_step(self, batch, batch_nb):
        logits = self.forward(batch).view(-1)
        valid_loss = self._compute_loss(logits, batch['label'])
        logits = torch.sigmoid(logits)
        self.am_vloss.update(valid_loss.item(), len(batch['input_ids']))
    
    def validation_epoch_end(self, outputs):
        logs = {
            'avg_valid_loss': self.am_vloss.avg,
            'avg_valid_auroc': self.am_vauroc.avg
        }
        return {'progress_bar': logs, 'log': logs}
    
    def optimizer_step(self, *args, **kwargs):
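        # Optimizer step intentionally disabled for this debugging run.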
        pass
    
    def configure_optimizers(self):
        return optim.AdamW(self.parameters(), lr=self.hparams.lr*xm.xrt_world_size())
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 2
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 4
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 1
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 2
Model device: xla:1 Input device: xla:1 Input shape: torch.Size([16, 32]) Global rank: 0
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 1
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 2
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 3
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 7
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 3
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 7
Model device: xla:1 Input device: xla:1 Input shape: torch.Size([16, 32]) Global rank: 0
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 3
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 7
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 3
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 7
Model device: xla:1 Input device: xla:1 Input shape: torch.Size([16, 32]) Global rank: 0
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 6
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 6
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 6
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 5
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 6
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 4
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 5
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 4
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 5
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 4
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 4
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 5
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 1
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 1
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 1
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 1
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 2
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 2
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 2
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 3
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 2
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 3
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 3
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 3
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 6
Model device: xla:1 Input device: xla:1 Input shape: torch.Size([16, 32]) Global rank: 0
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 6
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 4
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 6
Model device: xla:1 Input device: xla:1 Input shape: torch.Size([16, 32]) Global rank: 0
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 4
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 6
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 4
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 7
Model device: xla:1 Input device: xla:1 Input shape: torch.Size([16, 32]) Global rank: 0
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 4
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 7
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 7
Model device: xla:1 Input device: xla:1 Input shape: torch.Size([16, 32]) Global rank: 0
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 7
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 5
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 5
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 3
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 5
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 1
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 5
Model device: xla:0 Input device: xla:0 Input shape: torch.Size([16, 32]) Global rank: 3

williamFalcon (Contributor) commented

Fixed on 0.8.5, but .test() doesn't work since we can't share weights back to the main process. In the meantime, use .test(ckpt_path=PATH).
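
A quick sketch of the suggested workaround; PATH is a placeholder for a checkpoint file saved during fit, and MyModel is a hypothetical LightningModule.

import pytorch_lightning as pl

model = MyModel()                      # hypothetical LightningModule
trainer = pl.Trainer(tpu_cores=8)
trainer.fit(model)

# Workaround: point .test() at a checkpoint on disk instead of relying on the
# in-memory weights, which can't be shared back from the TPU processes.
trainer.test(ckpt_path=PATH)           # PATH: path to a .ckpt saved during fit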

rohitgr7 (Contributor, Author) commented

@williamFalcon Still facing #2419 on master with 8 TPU cores. With checkpoint_callback=True it gets stuck at the beginning of the 2nd epoch, and with checkpoint_callback=False it gets stuck at the end of the last epoch (I waited more than 10 minutes for cleanup but nothing happens). Tried on both Kaggle and Colab.

williamFalcon (Contributor) commented

You need a val loop for now. Can you share a Colab?
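
A minimal, hedged sketch of what "a val loop" could look like inside the LightningModule; the assumption here is that even a lightweight validation_step plus val_dataloader is enough to avoid the hang.

    # inside the LightningModule (requires: from torch.utils.data import DataLoader)
    def validation_step(self, batch, batch_idx):
        return {}                                   # no-op validation step

    def val_dataloader(self):
        return DataLoader(self.valid_ds, batch_size=self.hparams.batch_size)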

rohitgr7 (Contributor, Author) commented Jul 11, 2020

With a val loop it's working fine :) but without a val loop it's not. I don't know how to share an editable Colab notebook link, but here is the same notebook on Kaggle: https://www.kaggle.com/rohitgr/lightning-tpu

ghost commented Jun 30, 2021

I have recently been attempting to run PyTorch on Google Colab with tpu_cores=8 on torch 1.9.0+cu102, but I get the following error: ProcessExitedException: process 0 terminated with exit code 17. However, using a single TPU core works (extremely slowly). Any idea what's going on?

EricPHassey commented

I am also getting an issue using TPUs on Google Colab. Not sure what to do or how to fix it. I assume it's something in the Lightning package causing these issues.
