
Commit 15cf6a8

Tpu logging (#2230)
* add tpu view
* add tpu view
* add tpu view
* add tpu view
* add tpu view
1 parent 9ec9404 commit 15cf6a8

File tree

1 file changed: +13 −1 lines changed


pytorch_lightning/trainer/distrib_data_parallel.py

+13 −1
@@ -153,6 +153,16 @@ def train_fx(trial_hparams, cluster_manager, _):
 HYDRA_AVAILABLE = True
 
 
+try:
+    import torch_xla
+    import torch_xla.core.xla_model as xm
+    import torch_xla.distributed.xla_multiprocessing as xmp
+except ImportError:
+    XLA_AVAILABLE = False
+else:
+    XLA_AVAILABLE = True
+
+
 class TrainerDDPMixin(ABC):
 
     # this is just a summary on variables used in this abstract class,
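The added block is the standard guarded-import pattern for an optional dependency: torch_xla is imported inside try/except and a module-level XLA_AVAILABLE flag records whether it is installed. As a rough sketch of how such a flag is usually consumed (the pick_device helper below is hypothetical and not part of this commit):

try:
    import torch_xla.core.xla_model as xm
    XLA_AVAILABLE = True
except ImportError:
    XLA_AVAILABLE = False

import torch

def pick_device() -> torch.device:
    # Prefer the TPU when torch_xla imported successfully, otherwise fall
    # back to CUDA or CPU. xm.xla_device() is torch_xla's device factory.
    if XLA_AVAILABLE:
        return xm.xla_device()
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Setting the flag once at import time, as the diff does, keeps the rest of the module free of repeated try/except blocks.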
@@ -172,6 +182,7 @@ class TrainerDDPMixin(ABC):
     num_processes: int
     num_nodes: int
     node_rank: int
+    tpu_cores: int
 
     @property
     def is_global_zero(self) -> int:
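tpu_cores: int joins the block of attribute annotations at the top of TrainerDDPMixin. The mixin never assigns these names itself; the annotations only document which attributes the concrete Trainer is expected to provide so the mixin's methods can reference them. A toy sketch of that convention (the class and method below are illustrative, not Lightning code):

from abc import ABC

class AcceleratorInfoMixin(ABC):
    # Summary of attributes this mixin expects the concrete class to define.
    on_gpu: bool
    tpu_cores: int

    def accelerator_summary(self) -> str:
        # Mirrors the commit's None-check: tpu_cores may be left unset (None).
        cores = self.tpu_cores if self.tpu_cores is not None else 0
        return f'gpu={self.on_gpu}, tpu_cores={cores}'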
@@ -277,6 +288,8 @@ def set_distributed_mode(self, distributed_backend):
             )
 
         rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self.on_gpu}')
+        num_cores = self.tpu_cores if self.tpu_cores is not None else 0
+        rank_zero_info(f'TPU available: {XLA_AVAILABLE}, using: {num_cores} TPU cores')
 
     def configure_slurm_ddp(self, num_gpu_nodes):
         self.is_slurm_managing_tasks = False
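The two new lines log TPU availability right after the existing GPU banner, and they go through rank_zero_info so a multi-process run prints the banner only once. A simplified sketch of the rank-zero idea (this is not PyTorch Lightning's actual implementation, and reading the RANK environment variable is an assumption here):

import logging
import os
from functools import wraps

log = logging.getLogger(__name__)

def rank_zero_only(fn):
    # Call fn only on the process whose global rank is 0 so that identical
    # log lines are not repeated by every worker.
    @wraps(fn)
    def wrapped(*args, **kwargs):
        if int(os.environ.get('RANK', '0')) == 0:
            return fn(*args, **kwargs)
    return wrapped

@rank_zero_only
def rank_zero_info(message: str) -> None:
    log.info(message)

On a machine without torch_xla and with tpu_cores left at None, the added f-string would therefore print 'TPU available: False, using: 0 TPU cores'.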
@@ -329,7 +342,6 @@ def determine_ddp_node_rank(self):
         node_ids = [(k, os.environ.get(k, None)) for k in env_vars]
         node_ids = [(k, v) for k, v in node_ids if v is not None]
         if len(node_ids) == 0:
-            log.warning("No environment variable for node rank defined. Set as 0.")
             return 0
         if len(node_ids) > 1:
             log.warning(f"Multiple environment variables ({node_ids}) defined for node rank. "
