Support Mean in DDP Sync #2568

Merged · 13 commits · Aug 4, 2020
9 changes: 9 additions & 0 deletions pytorch_lightning/metrics/converters.py
@@ -234,23 +234,32 @@ def _sync_ddp_if_available(result: Union[torch.Tensor],
        result: the value to sync and reduce (typically tensor or number)
        group: the process group to gather results from. Defaults to all processes (world)
        reduce_op: the reduction operation. Defaults to sum.
            Can also be the string 'avg' or 'mean' to calculate the mean during the reduction.

    Return:
        reduced value
    """

    if torch.distributed.is_available() and torch.distributed.is_initialized():
        divide_by_world_size = False

        if group is None:
            group = torch.distributed.group.WORLD

        if reduce_op is None:
            reduce_op = torch.distributed.ReduceOp.SUM
        elif isinstance(reduce_op, str) and reduce_op in ('avg', 'mean'):
            # 'avg'/'mean' is implemented as a SUM all-reduce followed by a division by the world size
            reduce_op = torch.distributed.ReduceOp.SUM
            divide_by_world_size = True

        # sync all processes before reduction
        torch.distributed.barrier(group=group)
        torch.distributed.all_reduce(result, op=reduce_op, group=group,
                                     async_op=False)

        if divide_by_world_size:
            result = result / torch.distributed.get_world_size(group)

    return result


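For reference, a minimal standalone sketch (not part of the PR) of what the new 'avg'/'mean' branch boils down to: a SUM all-reduce followed by a division by the world size. The gloo backend, master port, and tensor values here are assumptions made purely for the example.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # each process contributes a different value: 0., 1., ...
    value = torch.tensor([float(rank)])

    # SUM across processes, then divide by the world size -> mean
    dist.all_reduce(value, op=dist.ReduceOp.SUM)
    value = value / dist.get_world_size()

    # with world_size=2 the values are 0. and 1., so the mean is 0.5
    assert value.item() == 0.5

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)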
29 changes: 23 additions & 6 deletions tests/metrics/test_converters.py
@@ -114,14 +114,23 @@ def _setup_ddp(rank, worldsize):
    dist.init_process_group("gloo", rank=rank, world_size=worldsize)


def _ddp_test_fn(rank, worldsize, add_offset: bool, reduction_mean=False):
    _setup_ddp(rank, worldsize)
    tensor = torch.tensor([1.], device='cuda:0')

    if add_offset:
        # give every process a distinct value: 1. + rank
        tensor = tensor + rank

    if reduction_mean:
        reduced_tensor = _sync_ddp_if_available(tensor, reduce_op='avg')

        # every rank holds 1. + rank, so the expected mean is sum(1 + i) / world_size
        manual_reduction = sum([1. + i for i in range(dist.get_world_size())]) / dist.get_world_size()
        assert reduced_tensor.item() == manual_reduction
    else:
        reduced_tensor = _sync_ddp_if_available(tensor)

        assert reduced_tensor.item() == dist.get_world_size(), \
            'Sync-Reduce does not work properly with DDP and Tensors'


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@@ -131,9 +140,17 @@ def test_sync_reduce_ddp():
    tutils.set_random_master_port()

    worldsize = 2
    mp.spawn(_ddp_test_fn, args=(worldsize, False), nprocs=worldsize)

    # dist.destroy_process_group()

@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_sync_reduce_ddp_mean():
    """Make sure sync-reduce with mean reduction works with DDP"""
    tutils.reset_seed()
    tutils.set_random_master_port()

    worldsize = 2
    mp.spawn(_ddp_test_fn, args=(worldsize, True, True), nprocs=worldsize)


def test_sync_reduce_simple():
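A hypothetical call-site sketch (not from the PR), assuming the module path shown in the diff header above: outside an initialized process group the function returns its input unchanged, while under DDP the same call would return the mean across processes.

import torch

from pytorch_lightning.metrics.converters import _sync_ddp_if_available

t = torch.tensor([3.0])

# no process group is initialized here, so the tensor comes back untouched
assert _sync_ddp_if_available(t, reduce_op='mean').item() == 3.0

# under DDP (torch.distributed initialized) the same call would instead return
# the mean of `t` across all processes in the group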