
Commit 42b6fd2

williamFalcon authored and justusschock committed
Adds back the slow spawn ddp implementation that people want (#2115)
* training batch clean up * training batch clean up * training batch clean up * training batch clean up * training batch clean up * training batch clean up * training batch clean up * training batch clean up * training batch clean up * training batch clean up * training batch clean up * training batch clean up * training batch clean up * training batch clean up * training batch clean up * training batch clean up * training batch clean up * training batch clean up * training batch clean up * training batch clean up * adding spawn * adding spawn * adding spawn * adding spawn * adding spawn * adding spawn * adding spawn * adding spawn
1 parent ee05ee1 commit 42b6fd2

File tree

4 files changed: +127 −5 lines changed


docs/source/multi_gpu.rst (+111 −1)
@@ -200,7 +200,8 @@ Distributed modes
 Lightning allows multiple ways of training

 - Data Parallel (`distributed_backend='dp'`) (multiple-gpus, 1 machine)
-- DistributedDataParallel (`distributed_backend='ddp'`) (multiple-gpus across many machines).
+- DistributedDataParallel (`distributed_backend='ddp'`) (multiple-gpus across many machines (python script based)).
+- DistributedDataParallel (`distributed_backend='ddp_spawn'`) (multiple-gpus across many machines (spawn based)).
 - DistributedDataParallel 2 (`distributed_backend='ddp2'`) (dp in a machine, ddp across machines).
 - Horovod (`distributed_backend='horovod'`) (multi-machine, multi-gpu, configured at runtime)
 - TPUs (`tpu_cores=8|x`) (tpu or TPU pod)
@@ -253,6 +254,26 @@ Distributed Data Parallel
     # train on 32 GPUs (4 nodes)
     trainer = Trainer(gpus=8, distributed_backend='ddp', num_nodes=4)

+This Lightning implementation of ddp calls your script under the hood multiple times with the correct environment
+variables. If your code does not support this (ie: jupyter notebook, colab, or a nested script without a root package),
+use `dp` or `ddp_spawn`.
+
+.. code-block:: bash
+
+    # example for 3 GPUs ddp
+    MASTER_ADDR=localhost MASTER_PORT=random() WORLD_SIZE=3 NODE_RANK=0 LOCAL_RANK=0 python my_file.py --gpus 3 --etc
+    MASTER_ADDR=localhost MASTER_PORT=random() WORLD_SIZE=3 NODE_RANK=1 LOCAL_RANK=0 python my_file.py --gpus 3 --etc
+    MASTER_ADDR=localhost MASTER_PORT=random() WORLD_SIZE=3 NODE_RANK=2 LOCAL_RANK=0 python my_file.py --gpus 3 --etc
+
+The reason we use ddp this way is that `ddp_spawn` has a few limitations (because of Python and PyTorch):
+
+1. Since `.spawn()` trains the model in subprocesses, the model on the main process does not get updated.
+2. Dataloader(num_workers=N) where N is large bottlenecks training with ddp...
+   ie: it will be VERY slow or not work at all. This is a PyTorch limitation.
+3. Forces everything to be picklable.
+
+However, if you don't mind these limitations, please use `ddp_spawn`.
+
 Distributed Data Parallel 2
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
 In certain cases, it's advantageous to use all batches on the same machine instead of a subset.
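The bash invocation shown in this hunk only works if each launched process can turn those environment variables into a torch.distributed rank. Below is a minimal sketch of how a launched process could do that; it is not Lightning's internal code, and the names `init_ddp_connection` and `gpus_per_node` are assumptions made for illustration.

import os

import torch
import torch.distributed as dist


def init_ddp_connection(gpus_per_node):
    """Derive this process's global rank from the env variables above and join the process group."""
    node_rank = int(os.environ.get('NODE_RANK', 0))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))

    # global rank = node index * processes per node + index on this node
    global_rank = node_rank * gpus_per_node + local_rank

    # MASTER_ADDR / MASTER_PORT are read by the env:// init method
    dist.init_process_group(
        backend='nccl' if torch.cuda.is_available() else 'gloo',
        init_method='env://',
        world_size=world_size,
        rank=global_rank,
    )
    return global_rank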
@@ -275,6 +296,75 @@ In this case, we can use ddp2 which behaves like dp in a machine and ddp across
     # train on 32 GPUs (4 nodes)
     trainer = Trainer(gpus=8, distributed_backend='ddp2', num_nodes=4)

+Distributed Data Parallel Spawn
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+`ddp_spawn` is exactly like `ddp` except that it uses `.spawn()` to start the training processes.
+
+.. warning:: It is STRONGLY recommended to use `ddp` for speed and performance.
+
+.. code-block:: python
+
+    mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model, ))
+
+Here's how to call this.
+
+.. code-block:: python
+
+    # train on 8 GPUs (same machine (ie: node))
+    trainer = Trainer(gpus=8, distributed_backend='ddp_spawn')
+
+Use this method if your script does not support being called from the command line (ie: it is nested without a root
+project module). However, we STRONGLY discourage this use because it has limitations (because of Python and PyTorch):
+
+1. The model you pass in will not update. Please save a checkpoint and restore from there.
+2. Set Dataloader(num_workers=0) or it will bottleneck training.
+
+`ddp` is MUCH faster than `ddp_spawn`. We recommend you install a top-level module for your project using setup.py.
+
+.. code-block:: python
+
+    # setup.py
+    #!/usr/bin/env python
+
+    from setuptools import setup, find_packages
+
+    setup(name='src',
+          version='0.0.1',
+          description='Describe Your Cool Project',
+          author='',
+          author_email='',
+          url='https://github.com/YourSeed',  # REPLACE WITH YOUR OWN GITHUB PROJECT LINK
+          install_requires=[
+              'pytorch-lightning'
+          ],
+          packages=find_packages()
+          )
+
+Then setup your project like so:
+
+.. code-block:: bash
+
+    /project
+        /src
+            some_file.py
+            /or_a_folder
+        setup.py
+
+Then install as a root-level package
+
+.. code-block:: bash
+
+    cd /project
+    pip install -e .
+
+Now you can call your scripts anywhere
+
+.. code-block:: bash
+
+    cd /project/src
+    python some_file.py --distributed_backend 'ddp' --gpus 8
+
+
 Horovod
 ^^^^^^^
 `Horovod <http://horovod.ai>`_ allows the same training script to be used for single-GPU,
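Limitation 1 in the hunk above comes with the advice to save a checkpoint and restore from it. A small, hedged sketch of that workaround follows; `train_worker` and `CKPT_PATH` are names assumed for illustration, not Lightning API.

import torch
import torch.multiprocessing as mp

CKPT_PATH = 'last_weights.ckpt'  # assumed path for this sketch


def train_worker(rank, model):
    # ... each spawned process trains its own copy of the model here ...
    if rank == 0:
        torch.save(model.state_dict(), CKPT_PATH)


if __name__ == '__main__':
    model = torch.nn.Linear(8, 2)
    mp.spawn(train_worker, nprocs=2, args=(model,))
    # the parent-process model was not trained; reload the weights written by rank 0
    model.load_state_dict(torch.load(CKPT_PATH))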
@@ -516,3 +606,23 @@ And then launch the elastic job with:

 See the official `PytorchElastic documentation <https://pytorch.org/elastic>`_ for details
 on installation and more use cases.
+
+Jupyter Notebooks
+-----------------
+Unfortunately any `ddp_` backend is not supported in jupyter notebooks. Please use `dp` for multiple GPUs. This is a known
+Jupyter issue. If you feel like taking a stab at adding this support, feel free to submit a PR!
+
+Pickle Errors
+--------------
+Multi-GPU training sometimes requires your model to be pickled. If you run into an issue with pickling,
+try the following to figure out the issue
+
+.. code-block:: python
+
+    import pickle
+
+    model = YourModel()
+    pickle.dumps(model)
+
+However, if you use `ddp` the pickling requirement is not there and you should be fine. If you use `ddp_spawn` the
+pickling requirement remains. This is a limitation of Python.
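If `pickle.dumps(model)` fails, it can also help to narrow the error down to a single attribute. The helper below is an assumed debugging aid sketched for that purpose, not part of Lightning.

import pickle


def find_unpicklable_attributes(obj):
    """Return (attribute name, error) pairs for every member of obj that cannot be pickled."""
    bad = []
    for name, value in vars(obj).items():
        try:
            pickle.dumps(value)
        except Exception as err:
            bad.append((name, repr(err)))
    return bad


# usage: print(find_unpicklable_attributes(YourModel()))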

pytorch_lightning/trainer/data_loading.py (+1)
@@ -140,6 +140,7 @@ def _get_distributed_sampler(self, dataloader):
         else:
             world_size = {
                 'ddp': self.num_nodes * self.num_processes,
+                'ddp_spawn': self.num_nodes * self.num_processes,
                 'ddp2': self.num_nodes,
                 'ddp_cpu': self.num_processes * self.num_nodes
             }
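For context, the `world_size` computed above typically becomes `num_replicas` of the `DistributedSampler` attached to each dataloader. A simplified, assumed sketch follows; `make_ddp_dataloader` is illustrative and not a Lightning function.

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def make_ddp_dataloader(dataset, world_size, global_rank, batch_size=32):
    # with ddp_spawn, world_size == num_nodes * num_processes, matching the dict above
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=global_rank)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)


# e.g. one node with two spawned processes -> world_size of 2
dataset = TensorDataset(torch.randn(100, 4))
loader = make_ddp_dataloader(dataset, world_size=2, global_rank=0)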

pytorch_lightning/trainer/distrib_data_parallel.py (+7 −3)
@@ -221,7 +221,7 @@ def set_distributed_mode(self, distributed_backend):
         elif self.num_gpus > 1:
             self.use_dp = True

-        elif distributed_backend == "ddp":
+        elif distributed_backend in ['ddp', 'ddp_spawn']:
             if self.num_gpus == 0:
                 if self.num_nodes > 1 or self.num_processes > 1:
                     self.use_ddp = True  # ddp_cpu
@@ -378,6 +378,7 @@ def spawn_ddp_children(self, model):

         self.interactive_ddp_procs = []
         for local_rank in range(1, self.num_processes):
+            print('launching local_rank', local_rank)
             env_copy = os.environ.copy()
             env_copy['LOCAL_RANK'] = f'{local_rank}'
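This loop is the "calls your script under the hood" mechanism from the docs change above: each extra local rank is started as a fresh copy of the current script with its own environment. A rough, assumed sketch of that pattern is shown below; `launch_children` is illustrative, not the actual Lightning helper.

import os
import subprocess
import sys


def launch_children(num_processes):
    procs = []
    for local_rank in range(1, num_processes):
        env_copy = os.environ.copy()
        env_copy['LOCAL_RANK'] = f'{local_rank}'
        # re-run the same script with the same arguments but a different LOCAL_RANK
        procs.append(subprocess.Popen([sys.executable] + sys.argv, env=env_copy))
    return procs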
@@ -394,14 +395,17 @@ def spawn_ddp_children(self, model):
         local_rank = 0
         self.ddp_train(local_rank, model, is_master=True)

-    def ddp_train(self, process_idx, model, is_master=False):
+    def ddp_train(self, process_idx, model, is_master=False, proc_offset=0):
         """
         Entry point into a DP thread
         :param gpu_idx:
         :param model:
         :param cluster_obj:
         :return:
         """
+        # offset the process id if requested
+        process_idx = process_idx + proc_offset
+
         # show progressbar only on progress_rank 0
         if (self.node_rank != 0 or process_idx != 0) and self.progress_bar_callback is not None:
             self.progress_bar_callback.disable()
@@ -454,7 +458,7 @@ def ddp_train(self, process_idx, model, is_master=False):
             self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)

         # DDP2 uses all GPUs on the machine
-        if self.distributed_backend == 'ddp':
+        if self.distributed_backend == 'ddp' or self.distributed_backend == 'ddp_spawn':
             device_ids = [self.root_gpu]
         elif self.use_ddp2:
             device_ids = self.data_parallel_device_ids
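The branch changed above decides which GPUs the DistributedDataParallel wrapper sees. A hedged, simplified sketch of that distinction is below; `wrap_model` is an assumed helper, and an initialized process group is required to actually call it.

from torch.nn.parallel import DistributedDataParallel


def wrap_model(model, backend, root_gpu, all_gpus):
    if backend in ('ddp', 'ddp_spawn'):
        device_ids = [root_gpu]    # one process drives exactly one GPU
    elif backend == 'ddp2':
        device_ids = all_gpus      # one process per node, DP across its GPUs
    else:
        device_ids = None          # e.g. ddp_cpu
    return DistributedDataParallel(model, device_ids=device_ids)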

pytorch_lightning/trainer/trainer.py (+8 −1)
@@ -246,7 +246,7 @@ def __init__(

                 Use `row_log_interval` instead. Will remove 0.9.0.

-            distributed_backend: The distributed backend to use.
+            distributed_backend: The distributed backend to use (dp, ddp, ddp2, ddp_spawn)

             use_amp:
                 .. warning:: .. deprecated:: 0.7.0
@@ -876,9 +876,16 @@ def fit(
             self.ddp_train(task, model)

         elif self.distributed_backend == 'cpu_ddp':
+            self.__set_random_port()
             self.model = model
             mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model,))

+        elif self.distributed_backend == 'ddp_spawn':
+            model.share_memory()
+
+            # spin up peers
+            mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model, ))
+
         elif self.distributed_backend == 'ddp':
             self.spawn_ddp_children(model)
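For reference, the `mp.spawn` call added above follows the usual torch.multiprocessing convention: the target receives the process index as its first argument, followed by the entries of `args`. A small standalone sketch of that contract is below; the names are assumptions and this is not Trainer code.

import torch
import torch.multiprocessing as mp


def ddp_train(process_idx, model):
    # Lightning's Trainer.ddp_train plays this role; here we only show the call contract
    n_params = sum(p.numel() for p in model.parameters())
    print(f'process {process_idx} got a model with {n_params} parameters')


if __name__ == '__main__':
    model = torch.nn.Linear(4, 4)
    model.share_memory()  # keep the parameters in shared memory, as the diff does for ddp_spawn
    mp.spawn(ddp_train, nprocs=2, args=(model,))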
