|
| 1 | +# Copyright 2022 Cruise LLC |
| 2 | +import warnings |
| 3 | +from collections import OrderedDict |
| 4 | +import logging |
| 5 | + |
| 6 | +import torch.distributed as dist |
| 7 | +import torch.distributed.algorithms.model_averaging.utils as utils |
| 8 | + |
| 9 | +logger = logging.getLogger(__name__) |
| 10 | + |
| 11 | + |
| 12 | +class HierarchicalModelAverager: |
| 13 | + r""" |
| 14 | + A group of model averagers used for hierarchical model averaging (hierarchical SGD). |
| 15 | + Process groups of different sizes are organized in a hierarhicy, and they average parameters |
| 16 | + by using different periods concurrently after the warm-up stage. |
| 17 | + This is an extension of :class:`~torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager` |
| 18 | + that supports `post-local SGD <https://arxiv.org/abs/1808.07217>`_, which essentially only supports |
| 19 | + a two-level hierarchy: the intra-machine level and the global level, where the intra-machine |
| 20 | + level is usually embedded in :meth:`~torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook`. |
| 21 | + Similarly, the process groups within this class do not have such an intra-machine process |
| 22 | + subgroup, which should be embedded by the post-local SGD communication hook instead. |
| 23 | +
|
| 24 | + Args: |
| 25 | + period_group_size_dict: An ordered dict mapping keys of model averaging period to |
| 26 | + process group size, used for initializing process groups of |
| 27 | + different sizes in a hierarchy to average parameters concurrently. |
| 28 | + Particularly, at each iteration, there will be at most a single |
| 29 | + process group that runs averaging -- the period of such group should |
| 30 | + have the largest period which the current step can be divided by. |
| 31 | + For example, if the dict has three keys: 2, 4, and 8, |
| 32 | + then this means totally three process groups will be created to |
| 33 | + average parameters every 2, 4, and 8 iterations, respectively. |
| 34 | + At the 4th iteration, only the second process group will run |
| 35 | + averaging, because the first process group should be a |
| 36 | + subset of the second process group, and no need to execute the first |
| 37 | + process group redundantly. |
| 38 | + On the other hand, the third process group can only be triggered |
| 39 | + every 8 iterations, so it will not be triggered at the 4th iteration. |
| 40 | + warmup_steps (int): The number of warm-up steps. During this stage, model averaging is skipped. |
| 41 | + process_group (ProcessGroup, optional): The overall process group containing all the processes that runs model averaging. |
| 42 | + If ``None``, the default process group, which is created |
| 43 | + by :func:`torch.distributed.init_process_group`, will be used. |
| 44 | + (default: ``None``) |
| 45 | +
|
| 46 | + Example:: |
| 47 | + >>> from collections import OrderedDict |
| 48 | + >>> import torch |
| 49 | + >>> import torch.distributed as dist |
| 50 | + >>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import ( |
| 51 | + >>> PostLocalSGDState, |
| 52 | + >>> post_localSGD_hook, |
| 53 | + >>> ) |
| 54 | + >>> import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD |
| 55 | + >>> import torch.nn as nn |
| 56 | + >>> |
| 57 | + >>> dist.init_process_group("nccl", rank=rank, world_size=16) |
| 58 | + >>> torch.cuda.set_device(rank) |
| 59 | + >>> module = nn.Linear(1, 1, bias=False).to(rank) |
| 60 | + >>> model = nn.parallel.DistributedDataParallel( |
| 61 | + >>> module, device_ids=[rank], output_device=rank |
| 62 | + >>> ) |
| 63 | + >>> # Register a post-localSGD communication hook. |
| 64 | + >>> # Assume that each machine has 4 GPUs, then each intra-machine subgroup has a size of 4. |
| 65 | + >>> subgroup, _ = dist.new_subgroups() |
| 66 | + >>> state = PostLocalSGDState(subgroup=subgroup, start_localSGD_iter=100) |
| 67 | + >>> model.register_comm_hook(state, post_localSGD_hook) |
| 68 | + >>> |
| 69 | + >>> # Average parameters among each group of 8 processes every 4 iterations, and among all |
| 70 | + >>> # the 16 processes every 16 iterations. |
| 71 | + >>> averager = hierarchicalSGD.HierarchicalModelAverager( |
| 72 | + >>> period_group_size_dict=OrderedDict([(4, 8), (16, 16)]), warmup_steps=100) |
| 73 | + >>> # Note that ``warmup_steps`` must be the same as ``start_localSGD_iter`` used in ``PostLocalSGDState``. |
| 74 | + >>> # In the first 100 steps, run global gradient averaging like normal DDP at every step. |
| 75 | + >>> # After 100 steps, run model averaging at two levels. |
| 76 | + >>> for step in range(0, 200): |
| 77 | + >>> optimizer.zero_grad() |
| 78 | + >>> loss = loss_fn(output, labels) |
| 79 | + >>> loss.backward() |
| 80 | + >>> optimizer.step() |
| 81 | + >>> # Average parameters after ``optimizer.step()``. |
| 82 | + >>> # Thus, the inter-node communication only occurs periodically after ``warmup_steps``. |
| 83 | + >>> averager.average_parameters(model.parameters()) |
| 84 | +
|
| 85 | + .. warning :: |
| 86 | + The last group size in the dict must be the size of the provided ``process_group``, |
| 87 | + which indicates model averaging at the highest level of the hierarchy. |
| 88 | + If ``process_group`` is not provided, then the last group size should be equal to the world size. |
| 89 | +
|
| 90 | + .. warning :: |
| 91 | + `HierarchicalModelAverager` is experimental and subject to change. |
| 92 | + """ |
| 93 | + |
| 94 | + def __init__(self, period_group_size_dict=None, warmup_steps=0, process_group=None): |
| 95 | + if not period_group_size_dict: |
| 96 | + raise ValueError("Arg ``period_group_size_dict`` must not be empty.") |
| 97 | + self._periods = list(period_group_size_dict.keys()) |
| 98 | + if self._periods[0] <= 0: |
| 99 | + raise ValueError("The minimum period in arg ``period_group_size_dict`` must be a positive value.") |
| 100 | + elif self._periods[-1] == 1: |
| 101 | + warnings.warn( |
| 102 | + "When the maximum period in arg ``period_group_size_dict`` is 1, " |
| 103 | + "no need to use model averaging because the communication cost " |
| 104 | + "of all-reducing parameters will be no less than the cost of all-reducing gradients " |
| 105 | + "by DistributedDataParallel in the backward pass. Therefore, only " |
| 106 | + "DistributedDataParallel should be used for this case." |
| 107 | + ) |
| 108 | + ovall_group : dist.ProcessGroup = ( |
| 109 | + process_group if process_group is not None else dist.group.WORLD |
| 110 | + ) |
| 111 | + overall_group_size = dist.get_world_size(group=ovall_group) |
| 112 | + if list(period_group_size_dict.values())[-1] != overall_group_size: |
| 113 | + raise ValueError( |
| 114 | + "The last value in arg ``period_process_group_dict`` " |
| 115 | + "must be equal to the size of arg ``process_group``.") |
| 116 | + |
| 117 | + self.period_process_group_dict = OrderedDict() |
| 118 | + logger.info("Model averaging hierarchy:") |
| 119 | + for period, group_size in period_group_size_dict.items(): |
| 120 | + logger.info( |
| 121 | + f"\tEach group that has {group_size} processes average parameters every {period} iterations, " |
| 122 | + "if no higher-level averaging.") |
| 123 | + if group_size != overall_group_size: |
| 124 | + self.period_process_group_dict[period], _ = dist.new_subgroups( |
| 125 | + group_size=group_size, group=ovall_group) |
| 126 | + else: |
| 127 | + self.period_process_group_dict[period] = ovall_group |
| 128 | + |
| 129 | + if warmup_steps < 0: |
| 130 | + raise ValueError("Arg ``warmup_steps`` must be a non-negative number.") |
| 131 | + self.warmup_steps = warmup_steps |
| 132 | + self.step = 0 |
| 133 | + |
| 134 | + def _find_process_group(self): |
| 135 | + """ |
| 136 | + Returns a tuple consisting of whether ``step`` can be divided by |
| 137 | + a period in the keys of ``period_process_group_dict`` and the associated process group if any. |
| 138 | + If ``step`` can be divided by multiple periods in the keys of ``period_process_group_dict``, |
| 139 | + then the returned process group is the one corresponding to the largest period, |
| 140 | + since this process group will be used for averaging parameters at this ``step``. |
| 141 | + """ |
| 142 | + for period in reversed(self._periods): |
| 143 | + if self.step % period == 0: |
| 144 | + return (True, self.period_process_group_dict[period]) |
| 145 | + return (False, None) |
| 146 | + |
| 147 | + def average_parameters(self, params): |
| 148 | + r""" |
| 149 | + Averages parameters if ``step`` is no less than ``warmup_steps`` |
| 150 | + and it can be divided by a period in the keys of ``period_process_group_dict``, |
| 151 | + where ``step`` is increased by 1 at each iteration in the training loop. |
| 152 | + If ``step`` can be divided by multiple periods in the keys of ``period_process_group_dict``, |
| 153 | + only the largest period is used, and the corresponding process group is used for averaging parameters. |
| 154 | + """ |
| 155 | + if self.step >= self.warmup_steps: |
| 156 | + found, group = self._find_process_group() |
| 157 | + if found: |
| 158 | + utils.average_parameters(iter(params), group) |
| 159 | + self.step += 1 |
0 commit comments