Commit d0c9472

Authored by areshytko (Alexander Reshytko), with Borda (Jirka Borovec) and williamFalcon (William Falcon)
Add SLURM check in ddp_train() and init_ddp_connection() (#1387)
* slurm check in ddp_train and init_ddp_connection
* Remove code example in init_ddp_connection
  Co-Authored-By: Jirka Borovec <[email protected]>
* remove blank line
  Co-Authored-By: Jirka Borovec <[email protected]>
* improve for test coverage
  Co-Authored-By: Jirka Borovec <[email protected]>
* update changelog
* Default values and warnings for DDP env variables
* fix merge artifacts
* update localhost value
* change to NODE_RANK

Co-authored-by: Alexander Reshytko <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: William Falcon <[email protected]>
1 parent 7131685 commit d0c9472
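
In practice this change lets torch.distributed rendezvous on clusters that SLURM does not manage, provided each node exports the connection variables itself. A minimal per-node sketch follows; the address, port, and rank values are placeholders, not part of this commit:

```python
import os

# Hypothetical environment for a two-node DDP run without SLURM.
# MASTER_ADDR and MASTER_PORT must be identical on every node; NODE_RANK differs per node.
os.environ.setdefault('MASTER_ADDR', '10.0.0.1')  # reachable address of the rank-0 node
os.environ.setdefault('MASTER_PORT', '12910')     # any free port, same value on all nodes
os.environ.setdefault('NODE_RANK', '0')           # 0 on the master node, 1 on the second node

# With this commit, leaving MASTER_ADDR or MASTER_PORT unset only logs a warning and
# falls back to 127.0.0.1:12910, which is suitable for single-node runs only.
```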

File tree

3 files changed: +50 -52 lines

  CHANGELOG.md (+6 -0)
  pytorch_lightning/core/lightning.py (+38 -47)
  pytorch_lightning/trainer/distrib_data_parallel.py (+6 -5)

CHANGELOG.md (+6 -0)

@@ -10,10 +10,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added flag `replace_sampler_ddp` to manually disable sampler replacement in ddp ([#1513](https://github.com/PyTorchLightning/pytorch-lightning/pull/1513))
 - Added `auto_select_gpus` flag to trainer that enables automatic selection of available GPUs on exclusive mode systems.
+
+- Added learning rate finder ([#1347](https://github.com/PyTorchLightning/pytorch-lightning/pull/1347))
+
+- Added support for ddp mode in clusters without SLURM ([#1345](https://github.com/PyTorchLightning/pytorch-lightning/issues/1345))
+
 - Added learning rate finder ([#1347](https://github.com/PyTorchLightning/pytorch-lightning/pull/1347))
 
 - Added `terminate_on_nan` flag to trainer that performs a NaN check with each training iteration when set to `True`. ([#1475](https://github.com/PyTorchLightning/pytorch-lightning/pull/1475))
 
+
 ### Changed
 
 - Changed the default behaviour to no longer include a NaN check with each training iteration. ([#1475](https://github.com/PyTorchLightning/pytorch-lightning/pull/1475))

pytorch_lightning/core/lightning.py (+38 -47)

@@ -873,53 +873,10 @@ def configure_ddp(self, model, device_ids):
         )
         return model
 
-    def init_ddp_connection(self, proc_rank: int, world_size: int) -> None:
-        r"""
-        Override to define your custom way of setting up a distributed environment.
-
-        Lightning's implementation uses ``env://`` init by default and sets the first node as root.
-
-        Args:
-            proc_rank: The current process rank within the node.
-            world_size: Number of GPUs being used across all nodes (num_nodes * num_gpus).
-
-        Examples:
-            .. code-block:: python
-
-                def init_ddp_connection(self):
-                    # use slurm job id for the port number
-                    # guarantees unique ports across jobs from same grid search
-                    try:
-                        # use the last 4 numbers in the job id as the id
-                        default_port = os.environ['SLURM_JOB_ID']
-                        default_port = default_port[-4:]
-
-                        # all ports should be in the 10k+ range
-                        default_port = int(default_port) + 15000
-
-                    except Exception as e:
-                        default_port = 12910
-
-                    # if user gave a port number, use that one instead
-                    try:
-                        default_port = os.environ['MASTER_PORT']
-                    except Exception:
-                        os.environ['MASTER_PORT'] = str(default_port)
-
-                    # figure out the root node addr
-                    try:
-                        root_node = os.environ['SLURM_NODELIST'].split(' ')[0]
-                    except Exception:
-                        root_node = '127.0.0.2'
-
-                    root_node = self.trainer.resolve_root_node_address(root_node)
-                    os.environ['MASTER_ADDR'] = root_node
-                    dist.init_process_group(
-                        'nccl',
-                        rank=self.proc_rank,
-                        world_size=self.world_size
-                    )
-
+    def _init_slurm_connection(self) -> None:
+        """
+        Sets up environment variables necessary for pytorch distributed communications
+        based on the SLURM environment.
         """
         # use slurm job id for the port number
         # guarantees unique ports across jobs from same grid search
@@ -948,6 +905,40 @@ def init_ddp_connection(self):
 
         root_node = self.trainer.resolve_root_node_address(root_node)
         os.environ['MASTER_ADDR'] = root_node
+
+    def init_ddp_connection(
+            self,
+            proc_rank: int,
+            world_size: int,
+            is_slurm_managing_tasks: bool = True
+    ) -> None:
+        """
+        Override to define your custom way of setting up a distributed environment.
+
+        Lightning's implementation uses env:// init by default and sets the first node as root
+        for SLURM managed clusters.
+
+        Args:
+            proc_rank: The current process rank within the node.
+            world_size: Number of GPUs being used across all nodes (num_nodes * num_gpus).
+            is_slurm_managing_tasks: is cluster managed by SLURM.
+
+        """
+        if is_slurm_managing_tasks:
+            self._init_slurm_connection()
+
+        if 'MASTER_ADDR' not in os.environ:
+            log.warning("MASTER_ADDR environment variable is not defined. Set as localhost")
+            os.environ['MASTER_ADDR'] = '127.0.0.1'
+
+        if 'MASTER_PORT' not in os.environ:
+            log.warning("MASTER_PORT environment variable is not defined. Set as 12910")
+            os.environ['MASTER_PORT'] = '12910'
+
+        if 'WORLD_SIZE' in os.environ and os.environ['WORLD_SIZE'] != world_size:
+            log.warning("WORLD_SIZE environment variable is not equal to the computed "
+                        "world size. Ignored.")
 
         torch_backend = "nccl" if self.trainer.on_gpu else "gloo"
         torch_distrib.init_process_group(torch_backend, rank=proc_rank, world_size=world_size)

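`init_ddp_connection()` stays a public hook, so users on schedulers other than SLURM can still override it. Below is a hedged sketch of such an override using the new three-argument signature; the class name, the `MY_SCHEDULER_HOST` variable, and the backend selection are illustrative assumptions, and only the signature and defaults come from the diff above:

```python
import os

import torch.distributed as torch_distrib
import pytorch_lightning as pl


class CustomClusterModel(pl.LightningModule):
    # Illustrative override: mirrors the env:// fallback added in this commit, but
    # reads the master address from a hypothetical scheduler-provided variable.
    def init_ddp_connection(self, proc_rank: int, world_size: int,
                            is_slurm_managing_tasks: bool = True) -> None:
        os.environ.setdefault('MASTER_ADDR', os.environ.get('MY_SCHEDULER_HOST', '127.0.0.1'))
        os.environ.setdefault('MASTER_PORT', '12910')
        backend = 'nccl' if self.trainer.on_gpu else 'gloo'
        torch_distrib.init_process_group(backend, rank=proc_rank, world_size=world_size)
```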
pytorch_lightning/trainer/distrib_data_parallel.py (+6 -5)

@@ -285,12 +285,13 @@ def ddp_train(self, process_idx, model):
         :param cluster_obj:
         :return:
         """
-        # node rank using relative slurm id
-        # otherwise default to node rank 0
+        # node rank using relative slurm id if under slurm management
+        # otherwise use given node rank or default to node rank 0
         try:
-            node_id = os.environ['SLURM_NODEID']
+            node_id = os.environ['SLURM_NODEID'] if self.is_slurm_managing_tasks else os.environ['NODE_RANK']
             self.node_rank = int(node_id)
-        except Exception:
+        except KeyError:
+            log.warning("SLURM_NODEID or NODE_RANK environment variable is not defined. Set as 0.")
             self.node_rank = 0
 
         # show progressbar only on progress_rank 0
@@ -317,7 +318,7 @@ def ddp_train(self, process_idx, model):
         # try to init for 20 times at max in case ports are taken
         # where to store ip_table
         model.trainer = self
-        model.init_ddp_connection(self.proc_rank, self.world_size)
+        model.init_ddp_connection(self.proc_rank, self.world_size, self.is_slurm_managing_tasks)
 
         # CHOOSE OPTIMIZER
         # allow for lr schedulers as well
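
The node-rank fallback above also reads as a small standalone function. This only restates the logic from the hunk for clarity; `resolve_node_rank` is a hypothetical helper name, not something Lightning exposes:

```python
import logging
import os

log = logging.getLogger(__name__)


def resolve_node_rank(is_slurm_managing_tasks: bool) -> int:
    """Mirror of the lookup in ddp_train(): SLURM_NODEID under SLURM, NODE_RANK otherwise."""
    try:
        node_id = os.environ['SLURM_NODEID'] if is_slurm_managing_tasks else os.environ['NODE_RANK']
        return int(node_id)
    except KeyError:
        log.warning("SLURM_NODEID or NODE_RANK environment variable is not defined. Set as 0.")
        return 0
```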
