 import torch.distributed as dist

 import ignite.distributed as idist
-from ignite.distributed.utils import all_gather_tensors_with_shapes, sync
+from ignite.distributed.utils import _rank_not_in_group, all_gather_tensors_with_shapes, sync
 from ignite.engine import Engine, Events


@@ -122,7 +122,7 @@ def _test_distrib_all_reduce_group(device):
     assert idist.get_world_size() > 1, idist.get_world_size()
     assert idist.backend() is not None, idist.backend()

-    ranks = [0, 1]
+    ranks = sorted(range(idist.get_world_size() - 1, 0, -1))  # [0, 1, 2, 3] -> [1, 2, 3]
     rank = idist.get_rank()
     t = torch.tensor([rank], device=device)
     bnd = idist.backend()
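For reference, the snippet below (an illustration only, not part of the patch) shows what the new `ranks` expression evaluates to: `range(ws - 1, 0, -1)` enumerates every rank except 0 in descending order, and `sorted` restores ascending order.

```python
# Illustration only: the new ranks expression selects all ranks except 0,
# in ascending order. ws = 4 is an assumed world size for the example.
ws = 4
ranks = sorted(range(ws - 1, 0, -1))  # range(3, 0, -1) yields 3, 2, 1
assert ranks == [1, 2, 3]
```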
@@ -225,32 +225,27 @@ def _test_distrib_all_gather(device):
 def _test_distrib_all_gather_group(device):
     assert idist.get_world_size() > 1, idist.get_world_size()

-    ranks = list(range(idist.get_world_size() - 1, 0, -1))  # [0, 1, 2, 3] -> [3, 2, 1]
+    ranks = sorted(range(idist.get_world_size() - 1, 0, -1))  # [0, 1, 2, 3] -> [1, 2, 3]
     rank = idist.get_rank()
     bnd = idist.backend()

     t = torch.tensor([rank], device=device)
     group = idist.new_group(ranks)
-    if bnd in ("horovod"):
-        with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
-            res = idist.all_gather(t, group=group)
+    res = idist.all_gather(t, group=group)
+    if rank in ranks:
+        assert torch.equal(res, torch.tensor(ranks, device=device))
     else:
-        res = idist.all_gather(t, group=group)
-        if rank in ranks:
-            assert torch.equal(res, torch.tensor(sorted(ranks), device=device)), res
-        else:
-            assert res == t
+        assert res == t

     t = torch.tensor([rank], device=device)
-    if bnd in ("horovod"):
-        with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
-            res = idist.all_gather(t, group=ranks)
+    if bnd == "horovod":
+        res = idist.all_gather(t, group=group)
     else:
         res = idist.all_gather(t, group=ranks)
-        if rank in ranks:
-            assert torch.equal(res, torch.tensor(sorted(ranks), device=device))
-        else:
-            assert res == t
+    if rank in ranks:
+        assert torch.equal(res, torch.tensor(ranks, device=device))
+    else:
+        assert res == t

     t = {
         "a": [rank + 1, rank + 2, torch.tensor(rank + 3, device=device)],
@@ -262,12 +257,12 @@ def _test_distrib_all_gather_group(device):
             res = idist.all_gather(t, group=ranks)
     elif bnd in ("horovod"):
         with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
-            res = idist.all_gather(t, group=ranks)
+            res = idist.all_gather(t, group=group)
     else:
         res = idist.all_gather(t, group=ranks)
         if rank in ranks:
             assert isinstance(res, list) and len(res) == len(ranks)
-            for i, obj in zip(sorted(ranks), res):
+            for i, obj in zip(ranks, res):
                 assert isinstance(obj, dict)
                 assert list(obj.keys()) == ["a", "b", "c"], obj
                 expected_device = (
@@ -284,22 +279,20 @@ def _test_distrib_all_gather_group(device):
         else:
             assert res == t

-    if bnd in ("nccl", "gloo", "mpi"):
-        with pytest.raises(ValueError, match=r"Argument group should be list of int or ProcessGroup"):
+    t = torch.tensor([rank], device=device)
+    if bnd in ("nccl", "gloo", "mpi", "horovod"):
+        with pytest.raises(ValueError, match=r"Argument group should be list of int"):
             res = idist.all_gather(t, group="abc")
     elif bnd in ("xla-tpu"):
         with pytest.raises(ValueError, match=r"Argument group should be list of int"):
             res = idist.all_gather(t, group="abc")
-    elif bnd in ("horovod"):
-        with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
-            res = idist.all_gather(t, group="abc")


 def _test_idist_all_gather_tensors_with_shapes(device):
     torch.manual_seed(41)
     rank = idist.get_rank()
     ws = idist.get_world_size()
-    reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
+    reference = torch.randn(ws * 5, ws * 5, ws * 5, device=device)
     rank_tensor = reference[
         rank * (rank + 1) // 2 : rank * (rank + 1) // 2 + rank + 1,
         rank * (rank + 3) // 2 : rank * (rank + 3) // 2 + rank + 2,
@@ -312,41 +305,37 @@ def _test_idist_all_gather_tensors_with_shapes(device):
             r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
             r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
         ]
-        assert (r_tensor == tensors[r]).all()
+        assert r_tensor.allclose(tensors[r])


 def _test_idist_all_gather_tensors_with_shapes_group(device):
     assert idist.get_world_size(), idist.get_world_size()
     torch.manual_seed(41)

     rank = idist.get_rank()
-    ranks = list(range(1, idist.get_world_size()))
+    ranks = sorted(range(idist.get_world_size() - 1, 0, -1))  # [0, 1, 2, 3] -> [1, 2, 3]
     ws = idist.get_world_size()
-    bnd = idist.backend()
     if rank in ranks:
-        reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
+        reference = torch.randn(ws * 5, ws * 5, ws * 5, device=device)
         rank_tensor = reference[
             rank * (rank + 1) // 2 : rank * (rank + 1) // 2 + rank + 1,
             rank * (rank + 3) // 2 : rank * (rank + 3) // 2 + rank + 2,
             rank * (rank + 5) // 2 : rank * (rank + 5) // 2 + rank + 3,
         ]
     else:
         rank_tensor = torch.tensor([rank], device=device)
-    if bnd in ("horovod"):
-        with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
-            tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
+
+    tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
+    if rank in ranks:
+        for i, r in enumerate(ranks):
+            r_tensor = reference[
+                r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
+                r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
+                r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
+            ]
+            assert r_tensor.allclose(tensors[i])
     else:
-        tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
-        if rank in ranks:
-            for r in ranks:
-                r_tensor = reference[
-                    r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
-                    r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
-                    r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
-                ]
-                assert (r_tensor == tensors[r - 1]).all()
-        else:
-            assert [rank_tensor] == tensors
+        assert [rank_tensor] == tensors


 def _test_distrib_broadcast(device):
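Aside: the shape bookkeeping in the two tests above can be checked in isolation. The snippet below (an illustration only, not part of the patch; ws = 4 is an assumed world size) verifies that each rank's slice of the shared, identically-seeded reference tensor has shape (r + 1, r + 2, r + 3), which is exactly the shape list handed to `all_gather_tensors_with_shapes`, and that every slice stays within the enlarged `ws * 5` reference.

```python
# Illustration only: verify the slice arithmetic used by the tests above.
import torch

ws = 4  # assumed world size for the example
reference = torch.randn(ws * 5, ws * 5, ws * 5)
for r in range(ws):
    block = reference[
        r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
        r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
        r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
    ]
    # each rank's block has shape (r + 1, r + 2, r + 3)
    assert block.shape == (r + 1, r + 2, r + 3)
```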
@@ -413,31 +402,30 @@ def _test_distrib_barrier(device):
     assert tt.item() == true_res + 10.0


-def _test_distrib_new_group(device):
+def _test_distrib_group(device):
+    ranks = sorted(range(idist.get_world_size() - 1, 0, -1))  # [0, 1, 2, 3] -> [1, 2, 3]
     if idist.get_world_size() > 1 and idist.backend() is not None:
         bnd = idist.backend()
-        ranks = [0, 1]
+        rank = idist.get_rank()
+        g = idist.new_group(ranks)
         if idist.has_native_dist_support and bnd in ("nccl", "gloo", "mpi"):
-            g1 = idist.new_group(ranks)
-            g2 = dist.new_group(ranks)
-
-            rank = idist.get_rank()
             if rank in ranks:
-                assert g1.rank() == g2.rank()
+                # map global ranks to their rank inside the new group
+                global_to_group = {r: i for i, r in enumerate(ranks)}
+                assert g.rank() == global_to_group[rank], (g.rank(), global_to_group, rank)
+
         elif idist.has_xla_support and bnd in ("xla-tpu"):
-            assert idist.new_group(ranks) == [ranks]
+            assert g == [ranks]
         elif idist.has_hvd_support and bnd in ("horovod"):
-            from horovod.common.process_sets import ProcessSet
-
-            g1 = idist.new_group(ranks)
-            g2 = ProcessSet(ranks)
-
-            rank = idist.get_rank()
             if rank in ranks:
-                assert g1.ranks == g2.ranks
+                assert g.ranks == ranks
+
+        if rank in ranks:
+            assert not _rank_not_in_group(g)
+        else:
+            assert _rank_not_in_group(g)

     elif idist.backend() is None:
-        ranks = [0, 1]
         assert idist.new_group(ranks) == ranks

     with pytest.raises(ValueError, match="Argument ranks should be list of int"):
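As background for the rank-mapping assertion in the hunk above: members of a new group are renumbered 0..len(ranks)-1 following the `ranks` list, so a global rank's group rank is its index in that list, which is what the `global_to_group` dict encodes. A standalone illustration (not part of the patch):

```python
# Illustration only: group ranks are positions within the ranks list.
ranks = [1, 2, 3]  # global ranks that form the new group
global_to_group = {r: i for i, r in enumerate(ranks)}
assert global_to_group == {1: 0, 2: 1, 3: 2}  # e.g. global rank 2 -> group rank 1
```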