
Commit ed6e758

Author: lezwon
Parent: fdbbe96
Commit message: use parallel loader

File tree: 2 files changed, +6 -4 lines

pytorch_lightning/trainer/evaluation_loop.py (+3, -2)

@@ -135,6 +135,7 @@
 from pytorch_lightning.utilities import rank_zero_warn
 
 try:
+    import torch_xla
     import torch_xla.distributed.parallel_loader as xla_pl
     import torch_xla.core.xla_model as xm
 except ImportError:

@@ -249,8 +250,8 @@ def _evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_
     dl_outputs = []
 
     # on TPU we have to wrap it under the ParallelLoader
-    if self.use_tpu and self.tpu_id is None:
-        device = xm.xla_device()
+    if self.use_tpu:
+        device = torch_xla._XLAC._xla_get_default_device()
         dataloader = xla_pl.ParallelLoader(dataloader, [device])
         dataloader = dataloader.per_device_loader(device)
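
The change above reduces to a small pattern: look up the process's default XLA device, wrap the existing DataLoader in a ParallelLoader, and iterate over its per-device view. Below is a minimal sketch of that pattern outside the Trainer, assuming torch_xla is installed and the script runs on a TPU host; the dataset, batch size, and loop body are placeholders, not code from this commit.

import torch
from torch.utils.data import DataLoader, TensorDataset

import torch_xla
import torch_xla.distributed.parallel_loader as xla_pl

dataset = TensorDataset(torch.randn(64, 4))
loader = DataLoader(dataset, batch_size=8)

# The commit swaps xm.xla_device() for the private default-device lookup,
# which returns the device string already assigned to the current process.
device = torch_xla._XLAC._xla_get_default_device()

# ParallelLoader prefetches batches onto the XLA device in the background;
# per_device_loader yields batches already resident on that single device.
para_loader = xla_pl.ParallelLoader(loader, [device])
for (batch,) in para_loader.per_device_loader(device):
    pass  # the evaluation forward step would go here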

pytorch_lightning/trainer/training_loop.py (+3, -2)

@@ -167,6 +167,7 @@ def training_step(self, batch, batch_idx):
 APEX_AVAILABLE = True
 
 try:
+    import torch_xla
     import torch_xla.distributed.parallel_loader as xla_pl
     import torch_xla.core.xla_model as xm
 except ImportError:

@@ -412,8 +413,8 @@ def run_training_epoch(self):
     train_dataloader = self.train_dataloader
 
     # on TPU we have to wrap it under the ParallelLoader
-    if self.use_tpu and self.tpu_id is None:
-        device = xm.xla_device()
+    if self.use_tpu:
+        device = torch_xla._XLAC._xla_get_default_device()
         train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device])
         train_dataloader = train_dataloader.per_device_loader(device)
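
Both hunks also drop the "and self.tpu_id is None" guard, so the ParallelLoader wrapping now applies to single-core runs as well. A plausible reason the default-device lookup is safe in both modes: under xmp.spawn each worker process is pinned to one core, so the default device already names "this process's core". The sketch below is a hypothetical illustration of that, not code from this commit; nprocs=8 assumes an 8-core (v2/v3) TPU.

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _worker(index):
    # Inside a spawned worker both lookups should agree: xm.xla_device()
    # constructs a torch.device, while the private _XLAC call returns the
    # default device string for this process.
    print(index, xm.xla_device(), torch_xla._XLAC._xla_get_default_device())

if __name__ == '__main__':
    xmp.spawn(_worker, nprocs=8)  # nprocs=8 assumes an 8-core TPU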
