Commit 001e81e
Add _rank_not_in_group to idist (#3339)
* Add _rank_not_in_group
* Fix the function
* Fix typehint
* Fix the function
* Improve the function
* Fix tests
* Fix tests
* Fix the function
* Add _rank_not_in_group method to CompModel
* Add _rank_not_in_group method to CompModel
* Add group option to hvd allgather
* Add group option to hvd allgather
* Fix a bug
* Fix a bug
* Remove pytest assertion on hvd's not having allgather with group
* More fixes for hvd
* few more hvd fixes
* Fixed _test_distrib_group for native dist config

---------

Co-authored-by: vfdev-5 <[email protected]>
1 parent 678405d commit 001e81e
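What the commit enables at the public API level, as a hedged usage sketch (not taken from the diff): idist.all_gather can now be given a group on every backend, and ranks outside the group simply get their own tensor back. The gloo backend, world size of 3 and group [1, 2] below are illustrative assumptions.

import torch

import ignite.distributed as idist


def step(local_rank: int) -> None:
    rank = idist.get_rank()
    t = torch.tensor([rank])

    # Ranks 1 and 2 receive the gathered tensor([1, 2]);
    # rank 0 is outside the group and should get tensor([0]) back unchanged.
    res = idist.all_gather(t, group=[1, 2])
    print(rank, res)


if __name__ == "__main__":
    idist.spawn("gloo", step, args=(), nproc_per_node=3)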

File tree: 10 files changed, +114 -82 lines

ignite/distributed/comp_models/base.py (+7)

@@ -298,6 +298,10 @@ def barrier(self) -> None:
     def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
         pass

+    @abstractmethod
+    def _rank_not_in_group(self, group: Any) -> bool:
+        pass
+

 class _SerialModel(ComputationModel):
     """Private class defines non-distributed computation model for code compatibility with other distributed models."""

@@ -396,3 +400,6 @@ def new_group(self, ranks: List[int], **kwargs: Any) -> Any:
             return self._do_new_group(ranks, **kwargs)
         else:
             raise ValueError("Argument ranks should be list of int")
+
+    def _rank_not_in_group(self, group: Any) -> bool:
+        return False
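As a hedged illustration (not part of the diff): with no distributed backend initialised, the _SerialModel default added above means every rank counts as "in the group", so a group-aware gather degrades to a no-op on a single process. The group [0, 1] is an arbitrary example.

import torch

import ignite.distributed as idist

# Serial run: idist.backend() is None, so the _SerialModel is active.
assert idist.backend() is None

t = torch.tensor([idist.get_rank()])  # rank is 0 in a serial run
# _SerialModel._rank_not_in_group(...) always returns False, so the call
# below should simply hand the tensor back unchanged.
print(idist.all_gather(t, group=[0, 1]))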

ignite/distributed/comp_models/horovod.py (+27 -4)

@@ -1,3 +1,4 @@
+import os
 import warnings
 from typing import Any, Callable, cast, List, Mapping, Optional, Tuple

@@ -23,6 +24,9 @@
 if has_hvd_support:
     HOROVOD = "horovod"

+    # Enables dynamic process sets: new_group methods and passing group into collective ops
+    os.environ["HOROVOD_DYNAMIC_PROCESS_SETS"] = "1"
+
     class _HorovodDistModel(ComputationModel):
         """Private class for `Horovod <https://horovod.readthedocs.io/en/stable/>`_ distributed computation model."""

@@ -155,6 +159,15 @@ def spawn(
                **kwargs,
            )

+        def _setup_group(self, group: Any) -> hvd.ProcessSet:
+            if isinstance(group, list) and all(isinstance(item, int) for item in group):
+                group = self._do_new_group(group)
+            if not isinstance(group, hvd.ProcessSet):
+                raise ValueError(
+                    f"Argument group should be list of int or hvd.ProcessSet, got {type(group)}, group={group}"
+                )
+            return group
+
         _reduce_op_map = {
             "SUM": hvd.mpi_ops.Sum,
             "AVERAGE": hvd.mpi_ops.Average,

@@ -187,19 +200,24 @@ def _do_manual_all_reduce(self, tensor: torch.Tensor, op: Any) -> torch.Tensor:

         def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
             if group is not None:
-                raise NotImplementedError("all_gather with group for horovod is not implemented")
+                group = self._setup_group(group)
+            if self._rank_not_in_group(group):
+                return tensor
             if tensor.ndimension() == 0:
                 tensor = tensor.unsqueeze(0)
-            return hvd.allgather(tensor)
+            if group is not None:
+                return hvd.allgather(tensor, process_set=group)
+            else:
+                return hvd.allgather(tensor)

         def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> List[Any]:
             if group is not None:
                 raise NotImplementedError("all_gather with group for horovod is not implemented")

             return hvd.allgather_object(tensor)

-        def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
-            return hvd.ProcessSet(ranks)
+        def _do_new_group(self, ranks: List[int], **kwargs: Any) -> hvd.ProcessSet:
+            return hvd.add_process_set(ranks)

         def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
             return hvd.broadcast(tensor, root_rank=src)

@@ -208,3 +226,8 @@ def barrier(self) -> None:
            # https://github.com/horovod/horovod/issues/159#issuecomment-424834603
            # hvd.allreduce(torch.tensor(0, device=self.device()), name="barrier")
            hvd.allreduce(torch.tensor(0, device="cpu"), name="barrier")
+
+        def _rank_not_in_group(self, group: Optional[Any]) -> bool:
+            if group is None:
+                return False
+            return not group.included()
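A hedged sketch of the Horovod primitives this file now leans on (dynamic process sets). The launch command, world size of 4 and ranks [1, 2, 3] are illustrative assumptions, not part of the commit; run e.g. with horovodrun -np 4 python sketch.py.

import os

# Must be set before hvd.init(), as the diff does at import time.
os.environ["HOROVOD_DYNAMIC_PROCESS_SETS"] = "1"

import torch
import horovod.torch as hvd

hvd.init()
ps = hvd.add_process_set([1, 2, 3])         # registered collectively on all ranks
t = torch.tensor([hvd.rank()])

if ps.included():                           # mirrors "not _rank_not_in_group(ps)"
    out = hvd.allgather(t, process_set=ps)  # gathers only across ranks 1, 2, 3
else:
    out = t                                 # excluded rank keeps its own tensor
print(hvd.rank(), out)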

ignite/distributed/comp_models/native.py (+6 -3)

@@ -408,7 +408,7 @@ def spawn(
            **spawn_kwargs,
        )

-    def _setup_group(self, group: Optional[Any]) -> dist.ProcessGroup:
+    def _setup_group(self, group: Any) -> dist.ProcessGroup:
         if isinstance(group, list) and all(isinstance(item, int) for item in group):
             group = self._do_new_group(group)
         if not (isinstance(group, dist.ProcessGroup) or group == dist.GroupMember.NON_GROUP_MEMBER):

@@ -442,7 +442,7 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
     def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
         if group is not None:
             group = self._setup_group(group)
-        if group == dist.GroupMember.NON_GROUP_MEMBER:
+        if self._rank_not_in_group(group):
             return tensor
         if group is None:
             group_size = self.get_world_size()

@@ -466,7 +466,7 @@ def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> Lis
            )
         if group is not None:
             group = self._setup_group(group)
-        if group == dist.GroupMember.NON_GROUP_MEMBER:
+        if self._rank_not_in_group(group):
             return tensor
         if group is None:
             group_size = self.get_world_size()

@@ -491,6 +491,9 @@ def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
     def barrier(self) -> None:
         dist.barrier()

+    def _rank_not_in_group(self, group: Optional[Any]) -> bool:
+        return dist._rank_not_in_group(group)
+

 def _expand_hostlist(nodelist: str) -> List[str]:
     """Expand a compressed hostlist string and returns all hosts listed.
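For context, a hedged sketch of the torch.distributed behaviour the native model wraps here; a 4-process gloo job launched with torchrun and the group [1, 2, 3] are illustrative assumptions.

import torch
import torch.distributed as dist

dist.init_process_group("gloo")
group = dist.new_group([1, 2, 3])  # must be called collectively by every rank

# Excluded ranks receive GroupMember.NON_GROUP_MEMBER from new_group, and
# dist._rank_not_in_group(group) reports exactly that.
if not dist._rank_not_in_group(group):
    t = torch.tensor([dist.get_rank()])
    out = [torch.empty_like(t) for _ in range(dist.get_world_size(group))]
    dist.all_gather(out, t, group=group)  # gathers only across ranks 1, 2, 3
    print(dist.get_rank(), out)
else:
    print(dist.get_rank(), "not in group")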

ignite/distributed/comp_models/xla.py (+3)

@@ -175,3 +175,6 @@ def _check_group_type(self, group: Optional[Any]) -> bool:
         if isinstance(group, list) and all(isinstance(item, int) for item in group):
             return True
         return False
+
+    def _rank_not_in_group(self, group: Any) -> bool:
+        return self.get_rank() not in group
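The XLA model treats the group it is handed as a collection of ranks, so the check above is plain Python containment. A trivial illustration with made-up values (no TPU runtime needed):

group = [1, 2, 3]  # hypothetical ranks forming the group
for rank in range(4):
    print(rank, rank not in group)  # rank 0 -> True (excluded), ranks 1-3 -> False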

ignite/distributed/utils.py (+12 -4)

@@ -2,10 +2,9 @@
 import socket
 from contextlib import contextmanager
 from functools import wraps
-from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, cast, List, Mapping, Optional, Sequence, Tuple, Union

 import torch
-from torch import distributed as dist

 from ignite.distributed.comp_models import (
     _SerialModel,

@@ -384,15 +383,15 @@ def all_gather_tensors_with_shapes(
     if isinstance(group, list) and all(isinstance(item, int) for item in group):
         group = _model.new_group(group)

-    if isinstance(_model, _SerialModel) or group == dist.GroupMember.NON_GROUP_MEMBER:
+    if _rank_not_in_group(group):
         return [tensor]

     max_shape = torch.tensor(shapes).amax(dim=0)
     padding_sizes = (max_shape - torch.tensor(tensor.shape)).tolist()
     padded_tensor = torch.nn.functional.pad(
         tensor, tuple(itertools.chain.from_iterable(map(lambda dim_size: (0, dim_size), reversed(padding_sizes))))
     )
-    all_padded_tensors: torch.Tensor = _model.all_gather(padded_tensor, group=group)
+    all_padded_tensors: torch.Tensor = cast(torch.Tensor, _model.all_gather(padded_tensor, group=group))
     return [
         all_padded_tensors[
             [

@@ -731,3 +730,12 @@ def download_dataset():

     if current_rank == rank:
         barrier()
+
+
+def _rank_not_in_group(group: Optional[Union[Any, List[int]]]) -> bool:
+    """Check if the current process's rank is not in a given group."""
+    if group is None:
+        return False
+    if isinstance(group, list) and all(isinstance(item, int) for item in group):
+        group = new_group(group)
+    return _model._rank_not_in_group(group)
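A hedged sketch of the utils-level path touched above: gathering differently shaped tensors across a sub-group, which now short-circuits through _rank_not_in_group for excluded ranks. The gloo backend, 3 processes and the group [1, 2] are illustrative assumptions.

import torch

import ignite.distributed as idist
from ignite.distributed.utils import all_gather_tensors_with_shapes


def step(local_rank: int) -> None:
    rank = idist.get_rank()
    t = torch.full((rank + 1, 2), float(rank))  # rank-dependent shape
    shapes = [[r + 1, 2] for r in [1, 2]]       # shapes owned by ranks 1 and 2
    out = all_gather_tensors_with_shapes(t, shapes, [1, 2])
    # ranks 1 and 2 get both tensors back; rank 0 just gets [t] untouched
    print(rank, [o.shape for o in out])


if __name__ == "__main__":
    idist.spawn("gloo", step, args=(), nproc_per_node=3)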

tests/ignite/distributed/utils/__init__.py (+48 -60)

@@ -3,7 +3,7 @@
 import torch.distributed as dist

 import ignite.distributed as idist
-from ignite.distributed.utils import all_gather_tensors_with_shapes, sync
+from ignite.distributed.utils import _rank_not_in_group, all_gather_tensors_with_shapes, sync
 from ignite.engine import Engine, Events


@@ -122,7 +122,7 @@ def _test_distrib_all_reduce_group(device):
     assert idist.get_world_size() > 1, idist.get_world_size()
     assert idist.backend() is not None, idist.backend()

-    ranks = [0, 1]
+    ranks = sorted(range(idist.get_world_size() - 1, 0, -1))  # [0, 1, 2, 3] -> [3, 2, 1]
     rank = idist.get_rank()
     t = torch.tensor([rank], device=device)
     bnd = idist.backend()

@@ -225,32 +225,27 @@ def _test_distrib_all_gather(device):
 def _test_distrib_all_gather_group(device):
     assert idist.get_world_size() > 1, idist.get_world_size()

-    ranks = list(range(idist.get_world_size() - 1, 0, -1))  # [0, 1, 2, 3] -> [3, 2, 1]
+    ranks = sorted(range(idist.get_world_size() - 1, 0, -1))  # [0, 1, 2, 3] -> [3, 2, 1]
     rank = idist.get_rank()
     bnd = idist.backend()

     t = torch.tensor([rank], device=device)
     group = idist.new_group(ranks)
-    if bnd in ("horovod"):
-        with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
-            res = idist.all_gather(t, group=group)
+    res = idist.all_gather(t, group=group)
+    if rank in ranks:
+        assert torch.equal(res, torch.tensor(ranks, device=device))
     else:
-        res = idist.all_gather(t, group=group)
-        if rank in ranks:
-            assert torch.equal(res, torch.tensor(sorted(ranks), device=device)), res
-        else:
-            assert res == t
+        assert res == t

     t = torch.tensor([rank], device=device)
-    if bnd in ("horovod"):
-        with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
-            res = idist.all_gather(t, group=ranks)
+    if bnd == "horovod":
+        res = idist.all_gather(t, group=group)
     else:
         res = idist.all_gather(t, group=ranks)
-        if rank in ranks:
-            assert torch.equal(res, torch.tensor(sorted(ranks), device=device))
-        else:
-            assert res == t
+    if rank in ranks:
+        assert torch.equal(res, torch.tensor(ranks, device=device))
+    else:
+        assert res == t

     t = {
         "a": [rank + 1, rank + 2, torch.tensor(rank + 3, device=device)],

@@ -262,12 +257,12 @@ def _test_distrib_all_gather_group(device):
         res = idist.all_gather(t, group=ranks)
     elif bnd in ("horovod"):
         with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
-            res = idist.all_gather(t, group=ranks)
+            res = idist.all_gather(t, group=group)
     else:
         res = idist.all_gather(t, group=ranks)
         if rank in ranks:
             assert isinstance(res, list) and len(res) == len(ranks)
-            for i, obj in zip(sorted(ranks), res):
+            for i, obj in zip(ranks, res):
                 assert isinstance(obj, dict)
                 assert list(obj.keys()) == ["a", "b", "c"], obj
                 expected_device = (

@@ -284,22 +279,20 @@ def _test_distrib_all_gather_group(device):
         else:
             assert res == t

-    if bnd in ("nccl", "gloo", "mpi"):
-        with pytest.raises(ValueError, match=r"Argument group should be list of int or ProcessGroup"):
+    t = torch.tensor([rank], device=device)
+    if bnd in ("nccl", "gloo", "mpi", "horovod"):
+        with pytest.raises(ValueError, match=r"Argument group should be list of int"):
             res = idist.all_gather(t, group="abc")
     elif bnd in ("xla-tpu"):
         with pytest.raises(ValueError, match=r"Argument group should be list of int"):
             res = idist.all_gather(t, group="abc")
-    elif bnd in ("horovod"):
-        with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
-            res = idist.all_gather(t, group="abc")


 def _test_idist_all_gather_tensors_with_shapes(device):
     torch.manual_seed(41)
     rank = idist.get_rank()
     ws = idist.get_world_size()
-    reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
+    reference = torch.randn(ws * 5, ws * 5, ws * 5, device=device)
     rank_tensor = reference[
         rank * (rank + 1) // 2 : rank * (rank + 1) // 2 + rank + 1,
         rank * (rank + 3) // 2 : rank * (rank + 3) // 2 + rank + 2,

@@ -312,41 +305,37 @@ def _test_idist_all_gather_tensors_with_shapes(device):
             r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
             r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
         ]
-        assert (r_tensor == tensors[r]).all()
+        assert r_tensor.allclose(tensors[r])


 def _test_idist_all_gather_tensors_with_shapes_group(device):
     assert idist.get_world_size(), idist.get_world_size()
     torch.manual_seed(41)

     rank = idist.get_rank()
-    ranks = list(range(1, idist.get_world_size()))
+    ranks = sorted(range(idist.get_world_size() - 1, 0, -1))  # [0, 1, 2, 3] -> [1, 2, 3]
     ws = idist.get_world_size()
-    bnd = idist.backend()
     if rank in ranks:
-        reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
+        reference = torch.randn(ws * 5, ws * 5, ws * 5, device=device)
         rank_tensor = reference[
             rank * (rank + 1) // 2 : rank * (rank + 1) // 2 + rank + 1,
             rank * (rank + 3) // 2 : rank * (rank + 3) // 2 + rank + 2,
             rank * (rank + 5) // 2 : rank * (rank + 5) // 2 + rank + 3,
         ]
     else:
         rank_tensor = torch.tensor([rank], device=device)
-    if bnd in ("horovod"):
-        with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
-            tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
+
+    tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
+    if rank in ranks:
+        for i, r in enumerate(ranks):
+            r_tensor = reference[
+                r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
+                r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
+                r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
+            ]
+            assert r_tensor.allclose(tensors[i])
     else:
-        tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
-        if rank in ranks:
-            for r in ranks:
-                r_tensor = reference[
-                    r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
-                    r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
-                    r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
-                ]
-                assert (r_tensor == tensors[r - 1]).all()
-        else:
-            assert [rank_tensor] == tensors
+        assert [rank_tensor] == tensors


 def _test_distrib_broadcast(device):

@@ -413,31 +402,30 @@ def _test_distrib_barrier(device):
     assert tt.item() == true_res + 10.0


-def _test_distrib_new_group(device):
+def _test_distrib_group(device):
+    ranks = sorted(range(idist.get_world_size() - 1, 0, -1))  # [0, 1, 2, 3] -> [1, 2, 3]
     if idist.get_world_size() > 1 and idist.backend() is not None:
         bnd = idist.backend()
-        ranks = [0, 1]
+        rank = idist.get_rank()
+        g = idist.new_group(ranks)
         if idist.has_native_dist_support and bnd in ("nccl", "gloo", "mpi"):
-            g1 = idist.new_group(ranks)
-            g2 = dist.new_group(ranks)
-
-            rank = idist.get_rank()
             if rank in ranks:
-                assert g1.rank() == g2.rank()
+                # mapping between group ranks and global ranks
+                global_to_group = {r: i for i, r in enumerate(ranks)}
+                assert g.rank() == global_to_group[rank], (g.rank(), global_to_group, rank)
+
         elif idist.has_xla_support and bnd in ("xla-tpu"):
-            assert idist.new_group(ranks) == [ranks]
+            assert g == [ranks]
         elif idist.has_hvd_support and bnd in ("horovod"):
-            from horovod.common.process_sets import ProcessSet
-
-            g1 = idist.new_group(ranks)
-            g2 = ProcessSet(ranks)
-
-            rank = idist.get_rank()
             if rank in ranks:
-                assert g1.ranks == g2.ranks
+                assert g.ranks == ranks
+
+        if rank in ranks:
+            assert not _rank_not_in_group(g)
+        else:
+            assert _rank_not_in_group(g)

     elif idist.backend() is None:
-        ranks = [0, 1]
         assert idist.new_group(ranks) == ranks

         with pytest.raises(ValueError, match="Argument ranks should be list of int"):
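A hedged sketch of how the group behaviour asserted by the new _test_distrib_group could be exercised outside the test fixtures; the gloo backend and 4 CPU processes are illustrative assumptions.

import ignite.distributed as idist
from ignite.distributed.utils import _rank_not_in_group


def run(local_rank: int) -> None:
    ranks = sorted(range(idist.get_world_size() - 1, 0, -1))  # [1, 2, 3] for 4 processes
    g = idist.new_group(ranks)
    # every rank agrees with the plain membership check
    assert _rank_not_in_group(g) == (idist.get_rank() not in ranks)


if __name__ == "__main__":
    idist.spawn("gloo", run, args=(), nproc_per_node=4)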
