
Commit 0bb3b06

wayi1 authored and pytorchmergebot committed
[Model Averaging] Support hierarchical model averaging (pytorch#73285)
Summary: Implement hierarchical model averaging as proposed in pytorch#71325. Unit tests are added. Since I don't have access to 4-GPU machines in the open-source environment, I expect that the branch with the `ci-all` prefix can run the test that requires 4 GPUs. In the future, the internals of `PeriodicModelAverager` can be simplified into an implementation of a specialized hierarchical model averaging, where `period_group_size_dict` has only a single pair of period and world size.

Pull Request resolved: pytorch#73285
Reviewed By: mrshenli
Differential Revision: D34457792
Pulled By: rohan-varma
fbshipit-source-id: 39a6c5bf8a2852b6394a56abbad17b8a909b9fba
(cherry picked from commit 5f543d4)
1 parent bcd0843 commit 0bb3b06
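
The simplification mentioned in the summary can be illustrated with a short, hypothetical sketch (not part of this commit): a `period_group_size_dict` with a single (period, world size) pair makes `HierarchicalModelAverager` schedule averaging the same way as `PeriodicModelAverager`. The snippet assumes a default process group has already been initialized.

```python
# Hypothetical sketch: a one-level hierarchy degenerates to periodic model averaging.
from collections import OrderedDict

import torch.distributed as dist
import torch.distributed.algorithms.model_averaging.averagers as averagers
import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD

# Assumes dist.init_process_group(...) has already been called on every rank.
world_size = dist.get_world_size()

periodic = averagers.PeriodicModelAverager(period=4, warmup_steps=100)
one_level = hierarchicalSGD.HierarchicalModelAverager(
    # A single pair: average across the whole world every 4 steps after 100 warm-up steps,
    # which is the same schedule as `periodic` above.
    period_group_size_dict=OrderedDict([(4, world_size)]),
    warmup_steps=100,
)
```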

File tree

LICENSE
torch/distributed/algorithms/model_averaging/averagers.py
torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py
torch/testing/_internal/distributed/distributed_test.py

4 files changed (+262, −2 lines)


LICENSE

+4
@@ -28,6 +28,10 @@ All rights reserved.
 All contributions by Kakao Brain:
 Copyright 2019-2020 Kakao Brain
 
+All contributions by Cruise LLC:
+Copyright (c) 2022 Cruise LLC.
+All rights reserved.
+
 All contributions from Caffe:
 Copyright(c) 2013, 2014, 2015, the respective contributors
 All rights reserved.

torch/distributed/algorithms/model_averaging/averagers.py

+1 −1

@@ -94,7 +94,7 @@ def __init__(
             warnings.warn(
                 "When period is 1, no need to use model averaging because the communication cost "
                 "of all-reducing parameters will be no less than the cost of all-reducing gradients "
-                "by DistributedDataParall in the backward pass. Therefore, only "
+                "by DistributedDataParallel in the backward pass. Therefore, only "
                 "DistributedDataParallel should be used for this case."
             )
         self.period = period
torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py

+159

@@ -0,0 +1,159 @@
+# Copyright 2022 Cruise LLC
+import warnings
+from collections import OrderedDict
+import logging
+
+import torch.distributed as dist
+import torch.distributed.algorithms.model_averaging.utils as utils
+
+logger = logging.getLogger(__name__)
+
+
+class HierarchicalModelAverager:
+    r"""
+    A group of model averagers used for hierarchical model averaging (hierarchical SGD).
+    Process groups of different sizes are organized in a hierarchy, and they average parameters
+    by using different periods concurrently after the warm-up stage.
+    This is an extension of :class:`~torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager`
+    that supports `post-local SGD <https://arxiv.org/abs/1808.07217>`_, which essentially only supports
+    a two-level hierarchy: the intra-machine level and the global level, where the intra-machine
+    level is usually embedded in :meth:`~torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook`.
+    Similarly, the process groups within this class do not have such an intra-machine process
+    subgroup, which should be embedded by the post-local SGD communication hook instead.
+
+    Args:
+        period_group_size_dict: An ordered dict mapping keys of model averaging period to
+                                process group size, used for initializing process groups of
+                                different sizes in a hierarchy to average parameters concurrently.
+                                In particular, at each iteration there will be at most a single
+                                process group that runs averaging -- the group whose period is the
+                                largest period that divides the current step.
+                                For example, if the dict has three keys: 2, 4, and 8,
+                                then three process groups will be created in total, to
+                                average parameters every 2, 4, and 8 iterations, respectively.
+                                At the 4th iteration, only the second process group will run
+                                averaging, because the first process group should be a
+                                subset of the second process group, so there is no need to execute
+                                the first process group redundantly.
+                                On the other hand, the third process group can only be triggered
+                                every 8 iterations, so it will not be triggered at the 4th iteration.
+        warmup_steps (int): The number of warm-up steps. During this stage, model averaging is skipped.
+        process_group (ProcessGroup, optional): The overall process group containing all the processes that run model averaging.
+                                                If ``None``, the default process group, which is created
+                                                by :func:`torch.distributed.init_process_group`, will be used.
+                                                (default: ``None``)
+
+    Example::
+        >>> from collections import OrderedDict
+        >>> import torch
+        >>> import torch.distributed as dist
+        >>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
+        >>>     PostLocalSGDState,
+        >>>     post_localSGD_hook,
+        >>> )
+        >>> import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD
+        >>> import torch.nn as nn
+        >>>
+        >>> dist.init_process_group("nccl", rank=rank, world_size=16)
+        >>> torch.cuda.set_device(rank)
+        >>> module = nn.Linear(1, 1, bias=False).to(rank)
+        >>> model = nn.parallel.DistributedDataParallel(
+        >>>     module, device_ids=[rank], output_device=rank
+        >>> )
+        >>> # Register a post-localSGD communication hook.
+        >>> # Assume that each machine has 4 GPUs, so each intra-machine subgroup has a size of 4.
+        >>> subgroup, _ = dist.new_subgroups()
+        >>> state = PostLocalSGDState(subgroup=subgroup, start_localSGD_iter=100)
+        >>> model.register_comm_hook(state, post_localSGD_hook)
+        >>>
+        >>> # Average parameters among each group of 8 processes every 4 iterations, and among all
+        >>> # the 16 processes every 16 iterations.
+        >>> averager = hierarchicalSGD.HierarchicalModelAverager(
+        >>>     period_group_size_dict=OrderedDict([(4, 8), (16, 16)]), warmup_steps=100)
+        >>> # Note that ``warmup_steps`` must be the same as ``start_localSGD_iter`` used in ``PostLocalSGDState``.
+        >>> # In the first 100 steps, run global gradient averaging like normal DDP at every step.
+        >>> # After 100 steps, run model averaging at two levels.
+        >>> for step in range(0, 200):
+        >>>     optimizer.zero_grad()
+        >>>     loss = loss_fn(output, labels)
+        >>>     loss.backward()
+        >>>     optimizer.step()
+        >>>     # Average parameters after ``optimizer.step()``.
+        >>>     # Thus, the inter-node communication only occurs periodically after ``warmup_steps``.
+        >>>     averager.average_parameters(model.parameters())
+
+    .. warning ::
+        The last group size in the dict must be the size of the provided ``process_group``,
+        which indicates model averaging at the highest level of the hierarchy.
+        If ``process_group`` is not provided, then the last group size should be equal to the world size.
+
+    .. warning ::
+        `HierarchicalModelAverager` is experimental and subject to change.
+    """
+
+    def __init__(self, period_group_size_dict=None, warmup_steps=0, process_group=None):
+        if not period_group_size_dict:
+            raise ValueError("Arg ``period_group_size_dict`` must not be empty.")
+        self._periods = list(period_group_size_dict.keys())
+        if self._periods[0] <= 0:
+            raise ValueError("The minimum period in arg ``period_group_size_dict`` must be a positive value.")
+        elif self._periods[-1] == 1:
+            warnings.warn(
+                "When the maximum period in arg ``period_group_size_dict`` is 1, "
+                "no need to use model averaging because the communication cost "
+                "of all-reducing parameters will be no less than the cost of all-reducing gradients "
+                "by DistributedDataParallel in the backward pass. Therefore, only "
+                "DistributedDataParallel should be used for this case."
+            )
+        overall_group: dist.ProcessGroup = (
+            process_group if process_group is not None else dist.group.WORLD
+        )
+        overall_group_size = dist.get_world_size(group=overall_group)
+        if list(period_group_size_dict.values())[-1] != overall_group_size:
+            raise ValueError(
+                "The last value in arg ``period_group_size_dict`` "
+                "must be equal to the size of arg ``process_group``.")
+
+        self.period_process_group_dict = OrderedDict()
+        logger.info("Model averaging hierarchy:")
+        for period, group_size in period_group_size_dict.items():
+            logger.info(
+                f"\tEach group of {group_size} processes averages parameters every {period} iterations, "
+                "if there is no higher-level averaging.")
+            if group_size != overall_group_size:
+                self.period_process_group_dict[period], _ = dist.new_subgroups(
+                    group_size=group_size, group=overall_group)
+            else:
+                self.period_process_group_dict[period] = overall_group
+
+        if warmup_steps < 0:
+            raise ValueError("Arg ``warmup_steps`` must be a non-negative number.")
+        self.warmup_steps = warmup_steps
+        self.step = 0
+
+    def _find_process_group(self):
+        """
+        Returns a tuple consisting of whether ``step`` can be divided by
+        a period in the keys of ``period_process_group_dict`` and the associated process group, if any.
+        If ``step`` can be divided by multiple periods in the keys of ``period_process_group_dict``,
+        then the returned process group is the one corresponding to the largest period,
+        since this process group will be used for averaging parameters at this ``step``.
+        """
+        for period in reversed(self._periods):
+            if self.step % period == 0:
+                return (True, self.period_process_group_dict[period])
+        return (False, None)
+
+    def average_parameters(self, params):
+        r"""
+        Averages parameters if ``step`` is no less than ``warmup_steps``
+        and it can be divided by a period in the keys of ``period_process_group_dict``,
+        where ``step`` is increased by 1 at each iteration in the training loop.
+        If ``step`` can be divided by multiple periods in the keys of ``period_process_group_dict``,
+        only the largest period is used, and the corresponding process group is used for averaging parameters.
+        """
+        if self.step >= self.warmup_steps:
+            found, group = self._find_process_group()
+            if found:
+                utils.average_parameters(iter(params), group)
+        self.step += 1
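
The scheduling rule implemented by `_find_process_group` above can be seen in isolation with a small, torch-free sketch; the periods, group sizes, and warm-up value below are illustrative assumptions, not values taken from the commit.

```python
# Standalone sketch of HierarchicalModelAverager's period selection: after warm-up,
# the largest period that divides the current step wins, and lower levels are skipped.
from collections import OrderedDict

period_group_size_dict = OrderedDict([(2, 2), (4, 4), (8, 16)])  # assumed example hierarchy
periods = list(period_group_size_dict.keys())
warmup_steps = 10

def selected_period(step):
    """Return the period whose process group would average at this step, or None."""
    if step < warmup_steps:
        return None  # model averaging is skipped during warm-up
    for period in reversed(periods):
        if step % period == 0:
            return period
    return None

print([(step, selected_period(step)) for step in range(10, 17)])
# [(10, 2), (11, None), (12, 4), (13, None), (14, 2), (15, None), (16, 8)]
```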

torch/testing/_internal/distributed/distributed_test.py

+98 −1

@@ -6,7 +6,7 @@
 import sys
 import tempfile
 import time
-from collections import namedtuple
+from collections import namedtuple, OrderedDict
 from contextlib import contextmanager, suppress
 from datetime import timedelta
 from functools import reduce
@@ -16,6 +16,7 @@
 import torch.cuda
 import torch.distributed as dist
 import torch.distributed.algorithms.model_averaging.averagers as averagers
+import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD
 import torch.distributed.algorithms.model_averaging.utils as model_averaging_utils
 import torch.nn as nn
 import torch.nn.functional as F
@@ -1033,6 +1034,102 @@ def test_periodic_model_averager(self):
                 # No model averaging, so the parameters are not updated.
                 self.assertEqual(param.data, tensor)
 
+    @sandcastle_skip_if(
+        BACKEND not in DistTestCases.backend_feature["subgroup"],
+        f"The {BACKEND} backend does not support creating subgroups on CUDA devices"
+    )
+    @skip_if_lt_x_gpu(2)
+    def test_1_level_hierarchical_model_averager_equivalent_to_periodic_model_averager(self):
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+        device_id = rank_to_GPU[rank][0]
+
+        model = nn.Linear(1, 5, bias=False).cuda(device_id)
+        param = next(model.parameters())
+        tensor = torch.ones_like(param.data) * rank
+        expected_avg_tensor = (
+            torch.ones_like(param.data) * sum(range(world_size)) / world_size
+        )
+        period = 4
+        for warmup_steps in [12, 13, 14, 15]:
+            averager = hierarchicalSGD.HierarchicalModelAverager(
+                # Run the global averaging at a period of 4,
+                # which is equivalent to the above periodic model averaging test case.
+                period_group_size_dict=OrderedDict([(period, world_size)]), warmup_steps=warmup_steps
+            )
+
+            averager = averagers.PeriodicModelAverager(period=period, warmup_steps=warmup_steps)
+            for step in range(0, 20):
+                # Reset the parameters at every step.
+                param.data = copy.deepcopy(tensor)
+                averager.average_parameters(model.parameters())
+                if step >= warmup_steps and (step - warmup_steps) % period == 0:
+                    self.assertEqual(param.data, expected_avg_tensor)
+                else:
+                    # No model averaging, so the parameters are not updated.
+                    self.assertEqual(param.data, tensor)
+
+    @sandcastle_skip_if(
+        BACKEND not in DistTestCases.backend_feature["subgroup"],
+        f"The {BACKEND} backend does not support creating subgroups on CUDA devices"
+    )
+    @require_world_size(4)
+    @skip_if_lt_x_gpu(4)
+    def test_3_level_hierarchical_model_averager(self):
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+        device_id = rank_to_GPU[rank][0]
+
+        model = nn.Linear(1, 5, bias=False).cuda(device_id)
+        param = next(model.parameters())
+        tensor = torch.ones_like(param.data) * rank
+        # Set up hierarchical model averaging as follows:
+        # after the first 10 warmup steps,
+        # run model averaging every 2 steps within each subgroup of size 2,
+        # run model averaging every 4 steps within each subgroup of size 4,
+        # and run the global model averaging every 8 steps.
+        # If there is a conflict in model averaging at a step, only run the highest-level model averaging.
+        warmup_steps = 10
+        subgroup_size1 = 2
+        subgroup_avg_period1 = 2
+        subgroup_size2 = 4
+        subgroup_avg_period2 = 4
+        global_avg_period = 8
+        period_group_size_dict = OrderedDict(
+            [(subgroup_avg_period1, subgroup_size1),
+             (subgroup_avg_period2, subgroup_size2),
+             (global_avg_period, world_size)])
+        averager = hierarchicalSGD.HierarchicalModelAverager(
+            period_group_size_dict=period_group_size_dict, warmup_steps=warmup_steps
+        )
+        expected_avg_tensor_within_subgroup1 = (
+            torch.ones_like(param.data) * sum(range(subgroup_size1)) / subgroup_size1
+        )
+        expected_avg_tensor_within_subgroup2 = (
+            torch.ones_like(param.data) * sum(range(subgroup_size2)) / subgroup_size2
+        )
+        expected_global_avg_tensor = (
+            torch.ones_like(param.data) * sum(range(world_size)) / world_size
+        )
+        for step in range(0, 25):
+            # Reset the parameters at every step.
+            param.data = copy.deepcopy(tensor)
+            averager.average_parameters(model.parameters())
+            if step == 16 or step == 24:
+                # Run global model averaging when `step` can be divided by 8.
+                self.assertEqual(param.data, expected_global_avg_tensor)
+            elif step == 12 or step == 20:
+                # Run model averaging within subgroup when `step` can be divided by 4 but not by 8.
+                self.assertEqual(param.data, expected_avg_tensor_within_subgroup1)
+            elif step == 10 or step == 14 or step == 18 or step == 22:
+                # Run model averaging within subgroup when `step` can be divided by 2 but not by 4 or 8.
+                self.assertEqual(param.data, expected_avg_tensor_within_subgroup1)
+            else:
+                # No model averaging, so the parameters are not updated.
+                self.assertEqual(param.data, tensor)
+
     # NCCL Batch SEND RECV
     @skip_if_no_gpu
     @sandcastle_skip_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
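
For intuition about which ranks average together in the 3-level test above, the torch-free sketch below assumes subgroups are built from contiguous, equally sized rank blocks (the layout `dist.new_subgroups(group_size=...)` is expected to produce here) and lists the membership at each level. With a world size of 4, the size-4 level coincides with the overall group, which matches the docstring warning that the last entry of `period_group_size_dict` must equal the overall group size.

```python
# Torch-free sketch (not part of the commit): membership of contiguous, equally sized
# rank subgroups at each level of the test hierarchy, assuming world_size == 4.
def partition(world_size, group_size):
    """Split ranks 0..world_size-1 into consecutive blocks of `group_size` ranks."""
    return [
        list(range(start, start + group_size))
        for start in range(0, world_size, group_size)
    ]

world_size = 4  # the test requires at least 4 GPUs
for group_size in (2, 4):
    print(f"group size {group_size}: {partition(world_size, group_size)}")
# group size 2: [[0, 1], [2, 3]]
# group size 4: [[0, 1, 2, 3]]  (identical to the overall group when world_size == 4)
```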
