Skip to content

Commit 1fb6e58

Browse files
author
lezwon
committed
condition if tpu_id is None
1 parent 1741fc3 commit 1fb6e58

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

pytorch_lightning/trainer/evaluation_loop.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ def _evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_
250250

251251
# on TPU we have to wrap it under the ParallelLoader
252252
if self.use_tpu:
253-
device = xm.xla_device(self.tpu_id)
253+
device = xm.xla_device(self.tpu_id) if self.tpu_id is not None else xm.xla_device()
254254
dataloader = xla_pl.ParallelLoader(dataloader, [device])
255255
dataloader = dataloader.per_device_loader(device)
256256

pytorch_lightning/trainer/training_loop.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,7 @@ def run_training_epoch(self):
413413

414414
# on TPU we have to wrap it under the ParallelLoader
415415
if self.use_tpu:
416-
device = xm.xla_device(self.tpu_id)
416+
device = xm.xla_device(self.tpu_id) if self.tpu_id is not None else xm.xla_device()
417417
train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device])
418418
train_dataloader = train_dataloader.per_device_loader(device)
419419

0 commit comments

Comments
 (0)