Commit c7e4493: tpu id

1 parent f62660d commit c7e4493
3 files changed (+4 -4 lines changed)

pytorch_lightning/trainer/distrib_parts.py (+2 -2)

@@ -432,8 +432,8 @@ def copy_trainer_model_properties(self, model):
             m.tpu_local_core_rank = self.tpu_local_core_rank
             m.tpu_global_core_rank = self.tpu_global_core_rank
 
-    def transfer_batch_to_tpu(self, batch: Any):
-        device = xm.xla_device() if XLA_AVAILABLE else torch.device('cpu')
+    def transfer_batch_to_tpu(self, batch: Any, tpu_id: int = None):
+        device = xm.xla_device(tpu_id) if XLA_AVAILABLE else torch.device('cpu')
         return self.__transfer_data_to_device(batch, device)
 
     def transfer_batch_to_gpu(self, batch: Any, gpu_id: int):
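For context, here is a minimal self-contained sketch of what the changed helper does once it accepts a tpu_id. This is an illustration under assumptions, not Lightning's actual implementation: move_to_device is a hypothetical stand-in for the private __transfer_data_to_device helper, and the XLA_AVAILABLE handling is simplified.

    from typing import Any, Optional

    import torch

    try:
        import torch_xla.core.xla_model as xm
        XLA_AVAILABLE = True
    except ImportError:
        XLA_AVAILABLE = False


    def move_to_device(batch: Any, device: torch.device) -> Any:
        """Recursively move tensors inside (possibly nested) batch structures."""
        if isinstance(batch, torch.Tensor):
            return batch.to(device)
        if isinstance(batch, (list, tuple)):
            return type(batch)(move_to_device(x, device) for x in batch)
        if isinstance(batch, dict):
            return {k: move_to_device(v, device) for k, v in batch.items()}
        return batch


    def transfer_batch_to_tpu(batch: Any, tpu_id: Optional[int] = None) -> Any:
        # With tpu_id=None, xm.xla_device() returns the process-default XLA
        # device (the pre-commit behavior); passing an ordinal pins the batch
        # to that specific TPU core. Without torch_xla, fall back to CPU.
        device = xm.xla_device(tpu_id) if XLA_AVAILABLE else torch.device('cpu')
        return move_to_device(batch, device)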

pytorch_lightning/trainer/evaluation_loop.py (+1 -1)

@@ -434,7 +434,7 @@ def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode:
 
         # TPU data transfer
         if self.use_tpu:
-            batch = self.transfer_batch_to_tpu(batch)
+            batch = self.transfer_batch_to_tpu(batch, self.tpu_id)
             args[0] = batch
 
         # CPU, TPU or gpu step
pytorch_lightning/trainer/training_loop.py (+1 -1)

@@ -729,7 +729,7 @@ def training_forward(self, batch, batch_idx, opt_idx, hiddens):
 
         # TPU support
         elif self.use_tpu:
-            batch = self.transfer_batch_to_tpu(batch)
+            batch = self.transfer_batch_to_tpu(batch, self.tpu_id)
             args[0] = batch
             output = self.model.training_step(*args)
 
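Both call sites follow the same pattern: the loops now forward the trainer's tpu_id so every batch lands on the selected core, while a None id preserves the previous default-device behavior. A short hedged usage sketch of the helper above (the batch contents and the ordinal value are illustrative, not taken from this commit):

    import torch

    # A nested batch, moved with the transfer_batch_to_tpu sketch above.
    batch = {"x": torch.randn(4, 3), "y": torch.tensor([0, 1, 2, 3])}

    # tpu_id=None reproduces the pre-commit behavior (process-default device).
    default_batch = transfer_batch_to_tpu(batch)

    # An explicit ordinal (hypothetical value) pins the batch to that TPU core,
    # which is what the loops now do by forwarding self.tpu_id.
    pinned_batch = transfer_batch_to_tpu(batch, tpu_id=1)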