
Commit a76736a

Initial commit of Horovod distributed backend implementation

1 parent 4d24032 · commit a76736a

17 files changed: +597 −25 lines

.circleci/config.yml (+1)

@@ -10,6 +10,7 @@ references:
     run:
       name: Install Dependences
       command: |
+        sudo apt-get update && sudo apt-get install -y cmake
         pip install "$TORCH_VERSION"
         pip install -r requirements.txt -q
         sudo pip install pytest pytest-cov pytest-flake8 -q

.drone.yml (+10 −1)

@@ -6,24 +6,33 @@ name: torch-GPU

 steps:
 - name: testing
-  image: pytorch/pytorch:1.4-cuda10.1-cudnn7-runtime
+  image: pytorch/pytorch:1.4-cuda10.1-cudnn7-devel

   environment:
     SLURM_LOCALID: 0
     CODECOV_TOKEN:
       from_secret: codecov_token
+    HOROVOD_GPU_ALLREDUCE: NCCL
+    HOROVOD_GPU_BROADCAST: NCCL
+    HOROVOD_WITH_PYTORCH: 1
+    HOROVOD_WITHOUT_TENSORFLOW: 1
+    HOROVOD_WITHOUT_MXNET: 1
+    HOROVOD_WITH_GLOO: 1
+    HOROVOD_WITHOUT_MPI: 1

   #volumes:
   #  # Mount pip cache from host
   #  - name: pip_cache
   #    path: /opt/conda/lib/python3.7/site-packages

   commands:
+  - export PATH="$PATH:/root/.local/bin"
   - python --version
   - pip install pip -U
   - pip --version
   - nvidia-smi
   - bash ./tests/install_AMP.sh
+  - apt-get update && apt-get install -y cmake
   - pip install -r requirements.txt --user -q
   - pip install coverage pytest pytest-cov pytest-flake8 codecov -q
   - pip install -r ./tests/requirements.txt --user -q

.github/workflows/ci-testing.yml (+9 −2)

@@ -41,9 +41,15 @@ jobs:
       if: runner.os == 'macOS'
       run: |
         brew install libomp # https://github.com/pytorch/pytorch/issues/20030
+        brew install openmpi # Horovod on macOS requires OpenMPI, Gloo not currently supported

-    # TODO: remove after https://github.com/pytorch/pytorch/issues/32186 is resolved
     - name: Setup Windows
+      if: runner.os == 'windows'
+      run: |
+        python -c "lines = [line for line in open('requirements-extra.txt').readlines() if not line.startswith('horovod')] ; open('requirements-extra.txt', 'w').writelines(lines)"
+
+    # TODO: remove after https://github.com/pytorch/pytorch/issues/32186 is resolved
+    - name: Setup Windows on Latest
       if: runner.os == 'windows' && matrix.requires == 'latest'
       run: |
         python -c "req = open('requirements.txt').read().replace('torch>=1.1', 'torch<1.5') ; open('requirements.txt', 'w').write(req)"

@@ -75,11 +81,12 @@ jobs:
       run: |
         # python -m pip install --upgrade --user pip
         pip install -r requirements.txt -U -f https://download.pytorch.org/whl/torch_stable.html -q
-        pip install -r ./tests/requirements.txt -q
+        HOROVOD_BUILD_ARCH_FLAGS="-mfma" pip install -r ./tests/requirements.txt -q
         # pip install tox coverage
         python --version
         pip --version
         pip list
+      shell: bash

     - name: Tests
       # env:
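The new `Setup Windows` step strips Horovod from `requirements-extra.txt` before installation, since Horovod is not supported on Windows. A more readable sketch of what that `python -c` one-liner does (illustrative only, not part of the commit):

    # Drop any 'horovod' requirement on Windows, where Horovod cannot be installed.
    with open('requirements-extra.txt') as f:
        lines = [line for line in f if not line.startswith('horovod')]

    with open('requirements-extra.txt', 'w') as f:
        f.writelines(lines)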

CHANGELOG.md (+1)

@@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added `ddp_cpu` backend for testing ddp without GPUs ([#1158](https://github.com/PyTorchLightning/pytorch-lightning/pull/1158))

+- Added [Horovod](http://horovod.ai) support as a distributed backend `Trainer(distributed_backend='horovod')` ([#1529](https://github.com/PyTorchLightning/pytorch-lightning/pull/1529))

 ### Changed

docs/source/multi_gpu.rst (+38)

@@ -76,6 +76,7 @@ Lightning allows multiple ways of training
 - Data Parallel (`distributed_backend='dp'`) (multiple-gpus, 1 machine)
 - DistributedDataParallel (`distributed_backend='ddp'`) (multiple-gpus across many machines).
 - DistributedDataParallel2 (`distributed_backend='ddp2'`) (dp in a machine, ddp across machines).
+- Horovod (`distributed_backend='horovod'`) (multi-machine, multi-gpu, configured at runtime)
 - TPUs (`num_tpu_cores=8|x`) (tpu or TPU pod)

 Data Parallel (dp)

@@ -136,6 +137,43 @@ In this case, we can use ddp2 which behaves like dp in a machine and ddp across
     # train on 32 GPUs (4 nodes)
     trainer = pl.Trainer(gpus=8, distributed_backend='ddp2', num_nodes=4)

+Horovod
+^^^^^^^
+`Horovod <http://horovod.ai>`_ allows the same training script to be used for single-GPU,
+multi-GPU, and multi-node training.
+
+Like Distributed Data Parallel, every process in Horovod operates on a single GPU with a fixed
+subset of the data. Gradients are averaged across all GPUs in parallel during the backward pass,
+then synchronously applied before beginning the next step.
+
+The number of worker processes is configured by a driver application (`horovodrun` or `mpirun`). In
+the training script, Horovod will detect the number of workers from the environment, and automatically
+scale the learning rate to compensate for the increased total batch size.
+
+Horovod can be configured in the training script to run with any number of GPUs / processes as follows:
+
+.. code-block:: python
+
+    # train Horovod on GPU (number of GPUs / machines provided on command-line)
+    trainer = pl.Trainer(distributed_backend='horovod', gpus=1)
+
+    # train Horovod on CPU (number of processes / machines provided on command-line)
+    trainer = pl.Trainer(distributed_backend='horovod')
+
+When starting the training job, the driver application will then be used to specify the total
+number of worker processes:
+
+.. code-block:: bash
+
+    # run training with 4 GPUs on a single machine
+    horovodrun -np 4 python train.py
+
+    # run training with 8 GPUs on two machines (4 GPUs each)
+    horovodrun -np 8 -H hostname1:4,hostname2:4 python train.py
+
+See the official `Horovod documentation <https://horovod.readthedocs.io/en/stable>`_ for details
+on installation and performance tuning.
+
 DP/DDP2 caveats
 ^^^^^^^^^^^^^^^
 In DP and DDP2 each GPU within a machine sees a portion of a batch.
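Taken together, the documented workflow is: write an ordinary Lightning script, select the backend in the `Trainer`, and let the `horovodrun` command line decide the process count. A minimal sketch of such a `train.py` (the module name `MyLightningModule` is a placeholder, not part of this commit):

    # train.py -- minimal sketch; runs unchanged as `python train.py`
    # or under `horovodrun -np 4 python train.py`.
    import pytorch_lightning as pl
    from my_project import MyLightningModule  # hypothetical LightningModule

    model = MyLightningModule()

    # one GPU per Horovod worker; the worker count comes from the horovodrun command line
    trainer = pl.Trainer(distributed_backend='horovod', gpus=1)
    trainer.fit(model)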

pytorch_lightning/trainer/data_loading.py (+17 −1)

@@ -26,6 +26,13 @@
 else:
     XLA_AVAILABLE = True

+try:
+    import horovod.torch as hvd
+except ImportError:
+    HOROVOD_AVAILABLE = False
+else:
+    HOROVOD_AVAILABLE = True
+

 def _has_len(dataloader: DataLoader) -> bool:
     """ Checks if a given Dataloader has __len__ method implemented i.e. if

@@ -47,6 +54,7 @@ class TrainerDataLoadingMixin(ABC):
     proc_rank: int
     use_ddp: bool
     use_ddp2: bool
+    use_horovod: bool
     shown_warnings: ...
     val_check_interval: float
     use_tpu: bool

@@ -89,7 +97,7 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
         # don't do anything if it's not a dataloader
         if not isinstance(dataloader, DataLoader):
             return dataloader
-        need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_tpu)
+        need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu)
         if self.replace_sampler_ddp and need_dist_sampler:

             skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']

@@ -104,6 +112,10 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
                     num_replicas=xm.xrt_world_size(),
                     rank=xm.get_ordinal(),
                 )
+            elif self.use_horovod:
+                sampler = DistributedSampler(dataloader.dataset,
+                                             num_replicas=hvd.size(),
+                                             rank=hvd.rank())
             else:
                 world_size = {
                     'ddp': self.num_nodes * self.num_processes,

@@ -254,6 +266,10 @@ def request_dataloader(self, dataloader_fx: Callable) -> DataLoader:
             # all processes wait until data download has happened
             torch_xla.core.xla_model.rendezvous('pl.TrainerDataLoadingMixin.get_dataloaders')

+        elif self.use_horovod:
+            # all processes wait until data download has happened
+            hvd.join()
+
         return dataloader

     def determine_data_use_amount(self, train_percent_check: float, val_percent_check: float,
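The new `use_horovod` branch shards data the same way the DDP and TPU branches do: each worker gets a `DistributedSampler` keyed on the Horovod world size and rank. A standalone sketch of that wiring outside the Trainer (the `shard_dataloader` helper is illustrative, not part of this commit):

    import horovod.torch as hvd
    from torch.utils.data import DataLoader, DistributedSampler

    hvd.init()  # inside the Trainer this has already happened by the time dataloaders are built

    def shard_dataloader(dataloader: DataLoader) -> DataLoader:
        # every worker sees a disjoint 1/hvd.size() shard of the dataset, selected by its rank
        sampler = DistributedSampler(dataloader.dataset,
                                     num_replicas=hvd.size(),
                                     rank=hvd.rank())
        return DataLoader(dataloader.dataset,
                          batch_size=dataloader.batch_size,
                          sampler=sampler,
                          num_workers=dataloader.num_workers)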

pytorch_lightning/trainer/distrib_data_parallel.py (+35 −1)

@@ -131,6 +131,13 @@ def train_fx(trial_hparams, cluster_manager, _):
 else:
     APEX_AVAILABLE = True

+try:
+    import horovod.torch as hvd
+except ImportError:
+    HOROVOD_AVAILABLE = False
+else:
+    HOROVOD_AVAILABLE = True
+

 class TrainerDDPMixin(ABC):

@@ -178,10 +185,14 @@ def set_distributed_mode(self, distributed_backend):
         self.use_dp = False
         self.use_ddp = False
         self.use_ddp2 = False
+        self.use_horovod = False
         self.single_gpu = False

         if distributed_backend is None:
-            if self.num_gpus == 0:
+            if self.has_horovodrun():
+                self.check_horovod()
+                self.use_horovod = True
+            elif self.num_gpus == 0:
                 if self.num_nodes > 1 or self.num_processes > 1:
                     self.use_ddp = True  # ddp_cpu
             elif self.num_gpus == 1:

@@ -219,6 +230,9 @@ def set_distributed_mode(self, distributed_backend):
             self.use_ddp = True
             self.data_parallel_device_ids = None
             self.on_gpu = False
+        elif distributed_backend == 'horovod':
+            self.check_horovod()
+            self.use_horovod = True

         # throw error to force user ddp or ddp2 choice
         if self.num_nodes > 1 and not (self.use_ddp2 or self.use_ddp):

@@ -402,3 +416,23 @@ def resolve_root_node_address(self, root_node):
             root_node = name + number

         return root_node
+
+    def check_horovod(self):
+        """Raises a `MisconfigurationException` if the Trainer is not configured correctly for Horovod."""
+        if not HOROVOD_AVAILABLE:
+            raise MisconfigurationException(
+                'Requested `distributed_backend="horovod"`, but Horovod is not available. See: '
+                'https://horovod.readthedocs.io/en/stable/install_include.html for installation '
+                'instructions.'
+            )
+
+        if self.num_gpus > 1 or self.num_nodes > 1:
+            raise MisconfigurationException(
+                'Horovod does not support setting num_nodes / num_gpus explicitly. Use '
+                'horovodrun / mpirun to configure the number of processes.'
+            )
+
+    @staticmethod
+    def has_horovodrun():
+        """Returns True if running with `horovodrun` using Gloo or OpenMPI."""
+        return 'OMPI_COMM_WORLD_RANK' in os.environ or 'HOROVOD_RANK' in os.environ
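With `has_horovodrun()` in place, the backend can also be picked up implicitly: when the script is launched by `horovodrun` (which sets `HOROVOD_RANK`, or `OMPI_COMM_WORLD_RANK` under MPI), a plain `Trainer()` with no `distributed_backend` falls into the Horovod path. A small sketch of the two equivalent entry points (illustrative, assuming the script is started via `horovodrun`):

    import pytorch_lightning as pl

    # both select the Horovod backend when launched with `horovodrun -np 4 python train.py`
    trainer = pl.Trainer(distributed_backend='horovod')  # explicit
    trainer = pl.Trainer()                               # auto-detected via has_horovodrun()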

pytorch_lightning/trainer/distrib_parts.py (+69)

@@ -337,13 +337,16 @@

 """

+from contextlib import ExitStack
 import os
 from abc import ABC, abstractmethod
 import time
 import random
 import torch
+from typing import Union

 from pytorch_lightning import _logger as log
+from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.overrides.data_parallel import (
     LightningDistributedDataParallel,
     LightningDataParallel,

@@ -365,6 +368,13 @@
 else:
     XLA_AVAILABLE = True

+try:
+    import horovod.torch as hvd
+except ImportError:
+    HOROVOD_AVAILABLE = False
+else:
+    HOROVOD_AVAILABLE = True
+

 class TrainerDPMixin(ABC):

@@ -385,6 +395,7 @@ class TrainerDPMixin(ABC):
     tpu_global_core_rank: int
     use_tpu: bool
     data_parallel_device_ids: ...
+    logger: Union[LightningLoggerBase, bool]

     @property
     @abstractmethod

@@ -540,6 +551,64 @@ def dp_train(self, model):

         self.run_pretrain_routine(model)

+    def horovod_train(self, model):
+        # Horovod: initialize library
+        hvd.init()
+
+        if torch.cuda.is_available() and self.on_gpu:
+            # Horovod: pin GPU to local rank
+            torch.cuda.set_device(hvd.local_rank())
+            model.cuda(hvd.local_rank())
+
+        # Only show progress bar from the first worker
+        self.progress_bar_refresh_rate = self.progress_bar_refresh_rate if hvd.rank() == 0 else 0
+
+        # CHOOSE OPTIMIZER
+        # allow for lr schedulers as well
+        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)
+
+        # Horovod: scale the learning rate by the number of workers to account for
+        # increased total batch size
+        for optimizer in self.optimizers:
+            for param_group in optimizer.param_groups:
+                param_group['lr'] *= hvd.size()
+
+        if self.use_amp:
+            # An example
+            model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
+            self.optimizers = optimizers
+
+        # Horovod: broadcast parameters & optimizer state to ensure consistent initialization
+        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
+        for optimizer in self.optimizers:
+            hvd.broadcast_optimizer_state(optimizer, root_rank=0)
+
+        def filter_named_parameters(model, optimizer):
+            opt_params = set([p for group in optimizer.param_groups for p in group.get('params', [])])
+            return [(name, p) for name, p in model.named_parameters() if p in opt_params]
+
+        # Horovod: wrap optimizers to perform gradient aggregation via allreduce
+        self.optimizers = [
+            hvd.DistributedOptimizer(optimizer, named_parameters=filter_named_parameters(model, optimizer))
+            for optimizer in self.optimizers
+        ]
+
+        # Update logger rank info from Horovod to avoid race conditions from different ranks
+        # creating directories / writing files in the same locations.
+        self.proc_rank = hvd.rank()
+        set_proc_rank(self.proc_rank)
+        if self.logger:
+            self.logger.rank = self.proc_rank
+        if model.logger:
+            model.logger.rank = self.proc_rank
+
+        with ExitStack() as stack:
+            for optimizer in self.optimizers:
+                # Synchronization will be performed explicitly following backward()
+                stack.enter_context(optimizer.skip_synchronize())
+
+            self.run_pretrain_routine(model)
+

 def normalize_parse_gpu_string_input(s):
     if isinstance(s, str):
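`horovod_train` above is essentially the standard Horovod-with-PyTorch recipe applied to the Trainer's own optimizers. The same steps in plain PyTorch, as a hedged standalone sketch (the model and learning rate are placeholders, not part of this commit):

    import torch
    import horovod.torch as hvd

    hvd.init()
    if torch.cuda.is_available():
        torch.cuda.set_device(hvd.local_rank())  # pin each process to one GPU

    model = torch.nn.Linear(32, 2)
    if torch.cuda.is_available():
        model.cuda(hvd.local_rank())

    # scale the base learning rate by the worker count (larger effective batch size)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01 * hvd.size())

    # start every worker from rank 0's weights and optimizer state
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    # average gradients across workers with allreduce on each optimizer.step()
    optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())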

pytorch_lightning/trainer/evaluation_loop.py (+14)

@@ -145,6 +145,13 @@
 else:
     XLA_AVAILABLE = True

+try:
+    import horovod.torch as hvd
+except ImportError:
+    HOROVOD_AVAILABLE = False
+else:
+    HOROVOD_AVAILABLE = True
+

 class TrainerEvaluationLoopMixin(ABC):

@@ -153,9 +160,11 @@ class TrainerEvaluationLoopMixin(ABC):
     test_progress_bar: ...
     val_progress_bar: ...
     main_progress_bar: ...
+    on_gpu: bool
     use_ddp: bool
     use_dp: bool
     use_ddp2: bool
+    use_horovod: bool
     single_gpu: bool
     data_parallel_device_ids: ...
     model: LightningModule

@@ -429,6 +438,11 @@ def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode:
             output = model(*args)
             return output

+        # Horovod
+        if self.use_horovod and self.on_gpu:
+            batch = self.transfer_batch_to_gpu(batch, hvd.local_rank())
+            args[0] = batch
+
         # single GPU data transfer
         if self.single_gpu:
             # for single GPU put inputs on gpu manually
