tpu_cores=8 not working #2106
Comments
I am sorry, I can't run it as it is read-only. @rohitgr7 mind checking the tests in #2094? All the tests there are passing, see: #2094 (comment)
I'm seeing this issue when I run the notebook. I'm unsure why it's happening. It's not consistent either, but it occurs the majority of the time. I ran the code with an old release of Lightning too, and it still breaks, so some change somewhere is breaking it. If I run just the model with plain
I have checked this on Colab only. I will try it on Kaggle kernels and check whether it's happening there too.
@lezwon is this still an issue? I will reopen if so, since this is now tested.
I just ran the notebook. The training seems to be working fine, but
May we add a test for it so we can fix it later?
On it. :]
Checked on Kaggle with master.
I tried training a model on 8 TPU cores and checked the model device in `forward`. The `global_rank` varies between 0 and 7, but the model and batch tensors are always on device 0 or 1. Can anyone please check whether something is wrong in the code? I have disabled the optimizer step just to check what is happening, and I return an arbitrary tensor in `forward` to keep the training procedure flowing. Here is the code:

```python
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
import torch_xla.core.xla_model as xm
from torch.utils.data import DataLoader, DistributedSampler

# AverageMeter, JigsawBertDataset and JigsawBertModel are helpers defined
# elsewhere in the notebook.


class JigsawBert(pl.LightningModule):
    def __init__(self, hparams, train_df, valid_df, train_tfms, valid_tfms, tokenizer):
        super().__init__()
        self.hparams = hparams
        self.train_df, self.valid_df = train_df, valid_df
        self.train_tfms, self.valid_tfms = train_tfms, valid_tfms
        self.tokenizer = tokenizer
        self.model = JigsawBertModel(self.hparams.model_path)
        self.am_tloss = AverageMeter()
        self.am_vloss = AverageMeter()
        self.am_vauroc = AverageMeter()

    def prepare_data(self):
        self.train_ds = JigsawBertDataset(
            df=self.train_df,
            tokenizer=self.tokenizer,
            max_length=self.hparams.max_length,
            tfms=self.train_tfms,
            is_testing=False
        )
        self.valid_ds = JigsawBertDataset(
            df=self.valid_df,
            tokenizer=self.tokenizer,
            max_length=self.hparams.max_length,
            tfms=self.valid_tfms,
            is_testing=False
        )

    def setup(self, stage):
        pass

    def train_dataloader(self):
        # Shard the training data across the TPU cores.
        sampler = DistributedSampler(
            self.train_ds,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True
        )
        dl = DataLoader(
            self.train_ds,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            sampler=sampler,
            drop_last=True
        )
        return dl

    def val_dataloader(self):
        sampler = DistributedSampler(
            self.valid_ds,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=False
        )
        dl = DataLoader(
            self.valid_ds,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            sampler=sampler,
            drop_last=False
        )
        return dl

    def _compute_loss(self, logits, targets):
        loss_fn = nn.BCEWithLogitsLoss()
        loss = loss_fn(logits, targets)
        return loss

    def forward(self, batch):
        # return self.model(batch['input_ids'], batch['token_type_ids'], batch['attn_mask'])
        # Log which device the model and the batch ended up on.
        print(f"Model device: {next(self.model.parameters()).device}",
              f"Input device: {batch['input_ids'].device}",
              f"Input shape: {batch['input_ids'].shape}",
              f"Global rank: {self.trainer.global_rank}")
        # Return an arbitrary tensor just to keep the training loop running.
        return nn.Linear(2, 1)(torch.randn(len(batch['input_ids']), 2)).to(batch['input_ids'].device)

    def training_step(self, batch, batch_nb):
        logits = self.forward(batch).view(-1)
        train_loss = self._compute_loss(logits, batch['label'])
        self.am_tloss.update(train_loss.item(), len(batch['input_ids']))
        return {'loss': train_loss}

    def training_epoch_end(self, outputs):
        logs = {'avg_train_loss': self.am_tloss.avg}
        self.am_tloss.reset()
        return {'progress_bar': logs, 'log': logs}

    def validation_step(self, batch, batch_nb):
        logits = self.forward(batch).view(-1)
        valid_loss = self._compute_loss(logits, batch['label'])
        logits = torch.sigmoid(logits)
        self.am_vloss.update(valid_loss.item(), len(batch['input_ids']))

    def validation_epoch_end(self, outputs):
        logs = {
            'avg_valid_loss': self.am_vloss.avg,
            'avg_valid_auroc': self.am_vauroc.avg
        }
        return {'progress_bar': logs, 'log': logs}

    def optimizer_step(self, *args, **kwargs):
        # Deliberately disabled while debugging the device placement.
        pass

    def configure_optimizers(self):
        # Scale the learning rate by the number of TPU cores.
        return optim.AdamW(self.parameters(), lr=self.hparams.lr * xm.xrt_world_size())
```
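The training itself is launched with `tpu_cores=8`; a minimal sketch of such a launch (the other `Trainer` arguments and the constructor inputs are placeholders, not taken from the notebook):

```python
import pytorch_lightning as pl

# Hypothetical launch: hparams, the dataframes, the transforms and the tokenizer
# stand in for the objects built earlier in the notebook.
model = JigsawBert(hparams, train_df, valid_df, train_tfms, valid_tfms, tokenizer)
trainer = pl.Trainer(tpu_cores=8, max_epochs=1)
trainer.fit(model)
```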
Fixed in 0.8.5, but `.test()` doesn't work since we can't share weights back to the main process. In the meantime, use `.test(ckpt_path=PATH)`.
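A sketch of that workaround (the checkpoint path is illustrative; use whatever path your run actually produced):

```python
# Weights trained in the TPU processes are not shared back to the main process,
# so point .test() at a saved checkpoint instead of the in-memory weights.
trainer.test(ckpt_path="lightning_logs/version_0/checkpoints/epoch=0.ckpt")  # hypothetical path
```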
@williamFalcon Still facing #2419 on master with 8 TPU cores. With
You need a val loop for now. Can you share a Colab?
With a val loop it's working fine :) but without a val loop it's not. I don't know how to share an editable Colab notebook link, but here is the same notebook on Kaggle: https://www.kaggle.com/rohitgr/lightning-tpu
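For anyone hitting the same thing, a minimal sketch of what adding a val loop can look like inside the LightningModule (names and return values are illustrative, not taken from the Kaggle notebook; `DataLoader` and `torch` are imported as in the module above):

```python
    # Minimal validation loop so the Trainer has a val step to run on TPU.
    def val_dataloader(self):
        return DataLoader(self.valid_ds, batch_size=self.hparams.batch_size)

    def validation_step(self, batch, batch_nb):
        logits = self.forward(batch).view(-1)
        return {'val_loss': self._compute_loss(logits, batch['label'])}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([o['val_loss'] for o in outputs]).mean()
        return {'val_loss': avg_loss, 'log': {'val_loss': avg_loss}}
```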
I have recently been attempting to run PyTorch on Google Colab with `tpu_cores=8` on torch version 1.9.0+cu102, but I was getting the following error:
I am also getting an issue using TPUs on Google Colab. I'm not sure what to do or how to fix it. I assume it's part of the Lightning package that is creating these issues.
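Before digging into Lightning itself, a quick sanity check can confirm that the runtime actually has an XLA/TPU build available (a minimal sketch, assuming the usual `torch_xla` setup on Colab):

```python
import torch_xla.core.xla_model as xm

# If this import fails or no device is returned, the runtime is not an
# XLA/TPU build, and tpu_cores=8 cannot work regardless of Lightning.
device = xm.xla_device()
print(device, xm.xrt_world_size())
```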
🐛 Bug
After #2016 was fixed by PR #2033, the code runs perfectly on a single TPU core and on a specific TPU core, but it is now not working with 8 TPU cores. After the training is complete I get:
`RuntimeError: Cannot replicate if number of devices (1) is different from 8`
To Reproduce
Colab notebook
Expected behavior
Should train with 8 TPU cores with no error, just as it works with a single core.
Environment
- Installed via (`conda`, `pip`, source): pip