@@ -153,6 +153,16 @@ def train_fx(trial_hparams, cluster_manager, _):
     HYDRA_AVAILABLE = True


+try:
+    import torch_xla
+    import torch_xla.core.xla_model as xm
+    import torch_xla.distributed.xla_multiprocessing as xmp
+except ImportError:
+    XLA_AVAILABLE = False
+else:
+    XLA_AVAILABLE = True
+
+
 class TrainerDDPMixin(ABC):

     # this is just a summary on variables used in this abstract class,
@@ -172,6 +182,7 @@ class TrainerDDPMixin(ABC):
     num_processes: int
     num_nodes: int
     node_rank: int
+    tpu_cores: int

     @property
     def is_global_zero(self) -> int:
@@ -277,6 +288,8 @@ def set_distributed_mode(self, distributed_backend):
             )

         rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self.on_gpu}')
+        num_cores = self.tpu_cores if self.tpu_cores is not None else 0
+        rank_zero_info(f'TPU available: {XLA_AVAILABLE}, using: {num_cores} TPU cores')

     def configure_slurm_ddp(self, num_gpu_nodes):
         self.is_slurm_managing_tasks = False
@@ -329,7 +342,6 @@ def determine_ddp_node_rank(self):
         node_ids = [(k, os.environ.get(k, None)) for k in env_vars]
         node_ids = [(k, v) for k, v in node_ids if v is not None]
         if len(node_ids) == 0:
-            log.warning("No environment variable for node rank defined. Set as 0.")
             return 0
         if len(node_ids) > 1:
             log.warning(f"Multiple environment variables ({node_ids}) defined for node rank. "