Commit ad0f119

justusschock and Borda authored
Support Mean in DDP Sync (#2568)
* Update converters.py
* Update test_converters.py
* pep8
* pep8 tests
* Update test_datamodules.py
* Update test_converters.py
* Update converters.py
* Update test_datamodules.py
* Update test_converters.py
* Update test_converters.py
* fix tests
* fix ddp tests on windows
* chlog

Co-authored-by: Jirka Borovec <[email protected]>
1 parent 38d6b25 commit ad0f119

File tree: 3 files changed (+41 −11 lines)

CHANGELOG.md (+2)
@@ -27,6 +27,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added call DataModule hooks implicitly in trainer ([#2755](https://github.com/PyTorchLightning/pytorch-lightning/pull/2755))
 
+- Added support for Mean in DDP Sync ([#2568](https://github.com/PyTorchLightning/pytorch-lightning/pull/2568))
+
 ### Changed
 
 - Truncated long version numbers in progress bar ([#2594](https://github.com/PyTorchLightning/pytorch-lightning/pull/2594))

pytorch_lightning/metrics/converters.py (+9)
@@ -234,23 +234,32 @@ def _sync_ddp_if_available(result: Union[torch.Tensor],
         result: the value to sync and reduce (typically tensor or number)
         group: the process group to gather results from. Defaults to all processes (world)
         reduce_op: the reduction operation. Defaults to sum.
+            Can also be a string of 'avg', 'mean' to calculate the mean during reduction.
 
     Return:
         reduced value
     """
 
     if torch.distributed.is_available() and torch.distributed.is_initialized():
+        divide_by_world_size = False
+
         if group is None:
             group = torch.distributed.group.WORLD
 
         if reduce_op is None:
             reduce_op = torch.distributed.ReduceOp.SUM
+        elif isinstance(reduce_op, str) and reduce_op in ('avg', 'mean'):
+            reduce_op = torch.distributed.ReduceOp.SUM
+            divide_by_world_size = True
 
         # sync all processes before reduction
         torch.distributed.barrier(group=group)
         torch.distributed.all_reduce(result, op=reduce_op, group=group,
                                      async_op=False)
 
+        if divide_by_world_size:
+            result = result / torch.distributed.get_world_size(group)
+
     return result
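In short, passing 'avg' or 'mean' now maps onto a SUM all-reduce followed by division by the group's world size. A minimal sketch of how this could be exercised end to end on CPU with the gloo backend (illustrative only, not part of the commit; the import path, the _worker helper, and the master address/port values are assumptions):

# Sketch: reduce a per-rank tensor with reduce_op='mean' across two CPU processes.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from pytorch_lightning.metrics.converters import _sync_ddp_if_available


def _worker(rank: int, world_size: int):
    # assumed rendezvous settings for a local two-process group
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29501'
    dist.init_process_group('gloo', rank=rank, world_size=world_size)

    # each rank contributes its rank value: ranks 0 and 1 -> sum 1.0, mean 0.5
    tensor = torch.tensor([float(rank)])
    reduced = _sync_ddp_if_available(tensor, reduce_op='mean')
    assert reduced.item() == 0.5

    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(_worker, args=(2,), nprocs=2)

With reduce_op left as None (the default), the same call still returns the plain sum, so existing callers are unaffected.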

tests/metrics/test_converters.py (+30 −11)
@@ -1,5 +1,6 @@
 import numpy as np
 import pytest
+import sys
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
@@ -114,26 +115,44 @@ def _setup_ddp(rank, worldsize):
     dist.init_process_group("gloo", rank=rank, world_size=worldsize)
 
 
-def _ddp_test_fn(rank, worldsize):
+def _ddp_test_fn(rank, worldsize, add_offset: bool, reduction_mean=False):
     _setup_ddp(rank, worldsize)
-    tensor = torch.tensor([1.], device='cuda:0')
-
-    reduced_tensor = _sync_ddp_if_available(tensor)
+    if add_offset:
+        tensor = torch.tensor([float(rank)])
+    else:
+        tensor = torch.tensor([1.], )
+    if reduction_mean:
+        reduced_tensor = _sync_ddp_if_available(tensor, reduce_op='avg')
+
+        manual_reduction = sum([i for i in range(dist.get_world_size())]) / dist.get_world_size()
+        print(reduced_tensor)
+        print(manual_reduction)
+        assert reduced_tensor.item() == manual_reduction
+    else:
+        reduced_tensor = _sync_ddp_if_available(tensor)
 
-    assert reduced_tensor.item() == dist.get_world_size(), \
-        'Sync-Reduce does not work properly with DDP and Tensors'
+        assert reduced_tensor.item() == dist.get_world_size(), \
+            'Sync-Reduce does not work properly with DDP and Tensors'
 
 
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
 def test_sync_reduce_ddp():
     """Make sure sync-reduce works with DDP"""
     tutils.reset_seed()
     tutils.set_random_master_port()
 
     worldsize = 2
-    mp.spawn(_ddp_test_fn, args=(worldsize,), nprocs=worldsize)
+    mp.spawn(_ddp_test_fn, args=(worldsize, False), nprocs=worldsize)
 
-    # dist.destroy_process_group()
+
+@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
+def test_sync_reduce_ddp_mean():
+    """Make sure sync-reduce works with DDP"""
+    tutils.reset_seed()
+    tutils.set_random_master_port()
+
+    worldsize = 2
+    mp.spawn(_ddp_test_fn, args=(worldsize, True, True), nprocs=worldsize)
 
 
 def test_sync_reduce_simple():
@@ -172,7 +191,7 @@ def _ddp_test_tensor_metric(rank, worldsize):
     _test_tensor_metric(True)
 
 
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
 def test_tensor_metric_ddp():
     tutils.reset_seed()
     tutils.set_random_master_port()
@@ -212,7 +231,7 @@ def _ddp_test_numpy_metric(rank, worldsize):
     _test_numpy_metric(True)
 
 
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
 def test_numpy_metric_ddp():
     tutils.reset_seed()
     tutils.set_random_master_port()
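For reference, the expectation in test_sync_reduce_ddp_mean can be checked by hand: with a world size of 2, rank 0 contributes 0.0 and rank 1 contributes 1.0, so the all-reduced sum is 1.0 and the mean is 1.0 / 2 = 0.5, which is what manual_reduction evaluates to. A tiny standalone restatement of that arithmetic (illustrative only, no torch.distributed involved):

# same arithmetic as manual_reduction in the test, for world_size = 2
world_size = 2
manual_reduction = sum(range(world_size)) / world_size  # (0 + 1) / 2
assert manual_reduction == 0.5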
