Towards Native Fault Tolerance for Semi-Synchronous Training #171

Open
WarrenZhu050413 opened this issue Apr 23, 2025 · 1 comment
Labels
enhancement New feature or request

@WarrenZhu050413

Problem Description

DiLoCo and LocalSGD, as implemented in torchFT, are both examples of semi-synchronous training algorithms. Semi-synchronous training algorithms have a natural hierarchical structure of synchronization that one could take advantage of for more fine-grained fault tolerance. As implemented in torchFT, semi-synchronous training algorithms are treated similarly to HSDP, which leads to interesting challenges and opportunities for torchFT to more natively support fine-grained fault tolerance.

Current TorchFT implementation

The current algorithm for the outer optimization step of DiLoCo in torchFT (LocalSGD follows a similar algorithm; I focus on DiLoCo to simplify the exposition) is the following:

        if self._local_step >= self._sync_every:
            self.sync()

Within self.sync(), participants of the same rank in different Replica Groups all-reduce their pseudogradients.
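
For reference, here is a rough sketch (not the actual torchft source) of what this outer step conceptually does, reusing the pieces that appear in the snippets later in this issue (pseudogradient computation, _average_grads, should_commit, and the outer optimizer step):

    def sync(self) -> None:
        # Set each parameter's .grad to its pseudogradient: how far the inner
        # optimizer has moved it since the last outer synchronization.
        for name, p in self._model.named_parameters():
            p.grad = p.data - self.original_parameters[name]

        self._average_grads()       # allreduce pseudogradients across Replica Groups
        self._restore_parameters()  # rewind to the last globally agreed parameters

        if self._manager.should_commit():
            self._outer_optimizer.step()  # apply the averaged pseudogradients
            self._save_parameters()       # snapshot for the next round of pseudogradients
        self._outer_optimizer.zero_grad()
        self._local_step = 0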

An issue with the above implementation is that it does not natively support multiple DP groups within each DiLoCo worker (the set of nodes that sync their gradients at every iteration), since torchFT is designed so that each Replica Group has DP_dim=1.

This support, however, may be important if we imagine DiLoCo being used for training across datacenters, with each datacenter corresponding to a DiLoCo worker: Many of the compute units would not be utilized without DP support.

Solution Using Current TorchFT LocalSGD.py

One way to implement multiple DP groups is to have multiple DP groups inside a single Replica Group launched by torchrun, as illustrated by the image below:

[Image: torchFT DiLoCo implementation]

However, this contains redundant inter-worker communication. Only one set of pseudogradients needs to be transmitted across workers, since the pseudogradients are replicated within the DP groups of each worker.

An alternative solution

Instead, given the parallelism and the semi-synchronous training structure, one could send only one set of pseudogradients across workers via a "Leader" DP group, then let the Leader broadcast it to all the Follower DP groups, as shown below:

[Image: torchFT DiLoCo alternative design]
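
A minimal sketch of this Leader/Follower communication pattern, written against raw torch.distributed process groups purely to illustrate the data flow (inter_pg spanning the Leader DP groups across workers and intra_pg spanning the DP groups inside one worker are assumed to already exist; in practice torchft's managed process groups would be used):

    import torch
    import torch.distributed as dist

    def leader_follower_outer_sync(model, original_params, inter_pg, intra_pg, is_leader):
        if is_leader:
            # Only the Leader DP group computes pseudogradients and averages them
            # across DiLoCo workers (e.g. across datacenters).
            for name, p in model.named_parameters():
                p.grad = p.data - original_params[name]
                dist.all_reduce(p.grad, op=dist.ReduceOp.AVG, group=inter_pg)
        else:
            # Followers only allocate buffers to receive the averaged result.
            for p in model.parameters():
                if p.grad is None:
                    p.grad = torch.zeros_like(p)

        # Fan the averaged pseudogradients out over the fast intra-worker network;
        # the Leader DP group sits at group rank 0 of intra_pg.
        leader_rank = dist.get_global_rank(intra_pg, 0)
        for p in model.parameters():
            dist.broadcast(p.grad, src=leader_rank, group=intra_pg)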

Challenges and opportunities for Fault Tolerance Native to Semi-Synchronous Training

I have implemented a simple version of the above broadcast-based solution. It uses an inter-worker and an intra-worker process group. By building on top of the ManagerProcessGroup abstraction in torchft, I was able to get fault tolerance across the DP dimension of an individual DiLoCo worker. The code snippets are attached at the end of the issue.

One limitation of my implementation is that I am unable to test it rigorously as I have only 2 available GPUs.

Even with my current implementation, I have encountered a few interesting challenges (and opportunities) with implementing such a multi-level coordination scheme:

Opportunity 1: Changing local DiLoCo step for synchronicity.

Related to #130.
If we enable different DP dimensions in different DiLoCo workers, we are in a setting where worker capabilities are heterogeneous. Here, one may wish to do Dynamic Local Update, as proposed in Asynchronous Local-SGD Training for Language Modeling: the local step count (sync_every) of each DiLoCo worker is dynamically adjusted based on its compute capability.

I am attaching the relevant passages from the paper:

[Images: DyLU excerpts (DyLU-1, DyLU-2, DyLU footnotes)]

This could be implemented as an additional return parameter from quorum. Alternatively, uniform local-step semantics could be enforced across workers, but this would lead to wasted compute unless excess DP groups are freed from DiLoCo training to serve other pending jobs in the cluster.
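
As an illustration, here is a minimal sketch of how sync_every could be scaled per worker; base_sync_every and reference_steps_per_sec are hypothetical tuning knobs, not torchft parameters:

    def adjusted_sync_every(base_sync_every: int,
                            reference_steps_per_sec: float,
                            measured_steps_per_sec: float) -> int:
        # Faster workers take proportionally more local steps so that all DiLoCo
        # workers arrive at the outer sync at roughly the same wall-clock time.
        ratio = measured_steps_per_sec / reference_steps_per_sec
        return max(1, round(base_sync_every * ratio))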

Opportunity 2: Alternative algorithms beyond the Leader-Follower broadcast

The current proposal, with a Leader broadcasting its all-reduced pseudogradients to Followers, is simple to implement. However, it may not be optimal.

For example, one may want to distribute the pseudogradient synchronization and broadcasting across the different DP ranks in a single DiLoCo worker. If there are three DP groups inside a DiLoCo worker, we may want each to synchronize 1/3 of the pseudogradients across workers and then broadcast its share to its intra-DiLoCo-worker peers.
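
A sketch of this sharded variant, again against raw torch.distributed groups (inter_pg across workers, intra_pg within a worker) purely to show the communication pattern:

    import torch
    import torch.distributed as dist

    def sharded_outer_sync(pseudograds, dp_rank, num_dp_groups, inter_pg, intra_pg):
        # Flatten the pseudogradients and give each DP group ownership of one shard.
        flat = torch.cat([g.view(-1) for g in pseudograds])
        shards = list(flat.chunk(num_dp_groups))

        # Each DP group averages only its own shard across DiLoCo workers...
        dist.all_reduce(shards[dp_rank], op=dist.ReduceOp.AVG, group=inter_pg)
        # ...then every shard is broadcast within the worker from its owner, so the
        # wide-area traffic is split roughly evenly across the DP groups.
        for owner, shard in enumerate(shards):
            dist.broadcast(shard, src=dist.get_global_rank(intra_pg, owner), group=intra_pg)

        # Copy the synchronized values back into the original tensors (the shards
        # are views into `flat`, so `flat` already holds the averaged values).
        offset = 0
        for g in pseudograds:
            g.copy_(flat[offset:offset + g.numel()].view_as(g))
            offset += g.numel()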

This could similarly apply to Live Checkpoint Recovery, where multiple up-to-date DP replicas recover a single unhealthy node.

Whether this would lead to a substantial performance gain depends heavily on the network topology. Thus, it could be combined with a perftest to measure bandwidth (as implemented by prime), along with traceroute to discover the network topology.

Opportunity 3: Different timeout parameters for inter-DiLoCo worker sync and intra-DiLoCo worker DP sync.

Related to #130.

The timeout parameters should be different given the difference in bandwidth within a DiLoCo worker and between DiLoCo workers.

These timeouts could be auto-configured by the lighthouse using network bandwidth information, or measurements of how long NCCL operations take to complete.
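
As a back-of-the-envelope sketch (the sizes and bandwidths below are assumptions, not measurements), the two timeouts could be derived from the pseudogradient volume and the measured link bandwidth:

    from datetime import timedelta

    def sync_timeout(num_bytes, bandwidth_bytes_per_sec, safety_factor=4.0):
        # Allow a few multiples of the ideal transfer time before declaring a failure.
        return timedelta(seconds=safety_factor * num_bytes / bandwidth_bytes_per_sec)

    pseudograd_bytes = 1e9           # e.g. ~1 GB of fp32 pseudogradients (assumed)
    intra_dc_bandwidth = 50e9 / 8    # ~50 Gbit/s within a datacenter (assumed)
    inter_dc_bandwidth = 5e9 / 8     # ~5 Gbit/s between datacenters (assumed)

    inner_timeout = sync_timeout(pseudograd_bytes, intra_dc_bandwidth)  # well under a second
    outer_timeout = sync_timeout(pseudograd_bytes, inter_dc_bandwidth)  # several seconds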

Opportunity 4: Using (or not using) multiple Lighthouses and managers

A naive extension of torchFT to this setting is to have a Local-Lighthouse to manage the DP within each DiLoCo worker, and a Global-Lighthouse between DiLoCo workers.

This introduces a lot more controller processes into the training and makes configuration more complicated.

Are there alternative ways of implementing multi-level fault tolerance?

Current Implementation

I do not believe that the current implementation is the fully correct design. However, I think it may be helpful to attach my code for reference. I have skipped many of the configuration details in the code.

Currently, we create two managers. The outer_manager is only created by the leader and is used for the pseudogradient allreduce between DiLoCo workers. The inner_manager is responsible for both broadcasting and gradient allreduce within each DiLoCo worker.

We further implemented an augmented version of the DiLoCo context manager called "DiLoCo_ICT", where ICT stands for "Inter-Cluster Transport".

    # External configs
    CLUSTER_GROUP_ID = int(os.environ.get("CLUSTER_GROUP_ID", 0)) # ID of the inner cluster
    NUM_CLUSTERS = int(os.environ.get("NUM_CLUSTERS", 2)) # number of physical clusters
    # Local replica group ID and number within the inner cluster
    REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0)) # ID of the inner replica group
    NUM_REPLICA_GROUPS_LOCAL_CLUSTER = int(os.environ.get("NUM_REPLICA_GROUPS_LOCAL_CLUSTER", 2)) # DP groups *inside* each cluster
    leader = (REPLICA_GROUP_ID == 0) # the first replica group in each cluster acts as the leader

    inner_manager = Manager(
        pg=pg_inner,
        min_replica_size=MIN_SIZE_LOCAL,
        load_state_dict=inner_load_state_dict,
        state_dict=inner_state_dict,
        replica_id=f"train_outermanager_{CLUSTER_GROUP_ID}_{REPLICA_GROUP_ID}", #TODO: Do we need to have the cluster group id here?
        timeout=timedelta(seconds=10),
        # Different variables for the outer and inner managers.
        lighthouse_addr=lighthouse_addr_inner,
        store_addr=store_addr_inner,
        store_port=store_port_inner,
        rank=rank_inner,
        world_size=world_size_replica_group,
    )

    outer_manager: Optional[Manager] = None

    if leader:
        # Define the outer manager, only needed by the leader
        def outer_load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
            m.load_state_dict(state_dict["model"])
            m.to(device)
            diloco.original_parameters = state_dict["original_params"]
            for name in diloco.original_parameters.keys():
                diloco.original_parameters[name] = diloco.original_parameters[name].to(
                    device
                )
            inner_optimizer.load_state_dict(state_dict["inner_optim"])
            outer_optimizer.load_state_dict(state_dict["outer_optim"])
            
        def outer_state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
            return {
                "model": m.state_dict(),
                "original_params": diloco.original_parameters,
                "inner_optim": inner_optimizer.state_dict(),
                "outer_optim": outer_optimizer.state_dict(),
            }
        outer_manager = Manager(
            pg=pg_outer,
            min_replica_size=MIN_SIZE_GLOBAL,
            load_state_dict=outer_load_state_dict,
            state_dict=outer_state_dict,
            replica_id=f"train_outermanager_{CLUSTER_GROUP_ID}",
            timeout=timedelta(seconds=10),

            # Different variables for outer and inner managers.
            lighthouse_addr=lighthouse_addr_outer,
            store_addr=store_addr_outer,
            store_port=store_port_outer,
            rank=outer_manager_rank,
            world_size=world_size_cluster,
        )

        outer_manager._use_async_quorum = False

The train loop:

    managed_inner_optimizer = Optimizer(manager=inner_manager, optim=inner_optimizer)
    criterion = nn.CrossEntropyLoss()

    current_step = 0
    with DiLoCo_ICT(
        outer_manager=outer_manager,
        inner_manager=inner_manager,
        model=m,
        inner_optimizer=inner_optimizer,
        outer_optimizer=outer_optimizer,
        sync_every=sync_every,
        backup_device=device, #TODO: Make this CPU for CPU offloading. Currently using GPU for backup device.
        device=device,
        debug=True,
    ) as diloco:
        # If REPLICA_GROUP_ID == 0, then sync with the outer manager.
        # After syncing with the outer manager, needs to broadcast the model parameters to the other inner replica groups.
        for x, y in trainloader:
            current_step += 1
            if outer_manager is not None:
                debug_print(f"outer_manager.current_step(): {outer_manager.current_step()}")
            # debug_print(f"inner_manager.current_step(): {inner_manager.current_step()}")
            x = x.to(device)
            y = y.to(device)
            output = m(x)
            loss = criterion(output, y)
            managed_inner_optimizer.zero_grad()
            loss.backward()
            managed_inner_optimizer.step()

            if current_step >= steps_to_run:
                # complete training
                exit()

            sleep(sleep_time)

Here is the implementation of the DiLoCo_ICT class. The main change is in the _perform_outer_sync function, where we add two broadcasts: one to synchronize the pseudogradients across the DP groups, and another to broadcast the should_commit flag so that the other DP groups learn the should_commit result of the outer manager.

A clear limitation of the current implementation is that I have not implemented a mechanism for the DP groups within each DiLoCo worker to check that the broadcasts for the other shards in the DP group have returned successfully. I have not encountered this problem yet because I have not integrated this with FSDP.

    def _perform_outer_sync(self) -> None:
        """
        Overrides the sync method to calculate the pseudogradients, average them across the outer_manager group, and
        step using the outer optimizer.
        """
        self.debug_print(f"Setting pseudogradients")
        if self._leader:
            # Set the .grad field of each parameter to its pseudogradient
            for name, p in self._model.named_parameters():
                pseudogradient = p.data - self.original_parameters[name]
                p.grad = pseudogradient

            self._average_grads() # synchronous allreduce

        # Restore the parameters back to the previous state
        self._restore_parameters() # Potentially asynchronous
        
        self.broadcast_pseudograds() # synchronous broadcast (at least currently so) 

        if self._leader:
            # Leader decides whether to commit and broadcasts the decision (1 = commit, 0 = skip)
            should_commit_flag = self._outer_manager.should_commit() # Implicitly waits for the live checkpoint recovery to be finished
            commit_tensor = torch.tensor([1 if should_commit_flag else 0], dtype=torch.uint8, device=self._device)
            # Broadcast to everyone in the *inner* replica‑group. Leader is root_rank 0.
            fut = self._inner_manager.broadcast_one(commit_tensor, root_rank=0)
            fut.wait()
            self.debug_print(f"Broadcasted should_commit={should_commit_flag}")
            
            # Use the outer optimizer to update the model parameters
            # Need to check whether should_commit()
            # Currently this will always return true because we have not tested FSDP
            if should_commit_flag:
                # Use the outer optimizer to update the model parameters
                self._outer_optimizer.step()
                self.debug_print(f"Commited, saving parameters")
                self._save_parameters() # Currently synchronous
            else:
                self.debug_print(f"Not committing")
        else:
            # Is follower
            # 1. Wait for the checkpoint assist to be finished
            # 2. Step the outer optimizer
            # 3. Save the parameters
            # Followers receive leader’s decision
            # Wait until the checkpoint‑assist thread (if any) finishes before potentially stepping.
            self.debug_print(f"Waiting for checkpoint assist")
            self._finish_checkpoint_assist()

            # Receive leader’s decision
            commit_tensor = torch.zeros(1, dtype=torch.uint8, device=self._device)
            self._inner_manager.broadcast_one(commit_tensor, root_rank=0).wait()
            should_commit_flag = bool(commit_tensor.item())
            self.debug_print(f"Received should_commit={should_commit_flag} from leader")

            self.debug_print(f"Stepping outer optimizer")
            if should_commit_flag:
                self._outer_optimizer.step()
            self.debug_print(f"Save parameters")
            self._save_parameters() # Currently synchronous

        self.debug_print(f"Zeroing gradients")
        self._outer_optimizer.zero_grad()
@d4l3k added the enhancement (New feature or request) label Apr 25, 2025
@d4l3k
Member

d4l3k commented Apr 25, 2025

The main concern I have with multiple layers of DP is really around checkpoint management.

If you're doing local DP + broadcasts you need some signalling mechanism within the group to know when the step count has changed so you can also broadcast the weights. Though I suppose if you only run the manager/outer step on rank0 within each group and always broadcast after every step that may actually work just fine.

The HSDP structure is fairly nice for this since FSDP/PP/TP are sharded so you don't waste any resources by having all ranks participate in DDP/checkpoint restore and may actually be more efficient since you could have more bandwidth by having less data to transfer per worker.

Your implementation of the two-layer Manager does seem reasonably elegant, but I am nervous about what happens when the outer Manager restores a checkpoint. You need some mechanism to force the inner Manager to also heal, though maybe your sync code actually handles this fine?

May be useful to have a call to chat more about this.

Another interesting scenario to consider is for a LLM reinforcement learning scenario where you may want to have one training group and then multiple inference groups of a different size. Ideally you could coordinate between these and do scale up / scale down operations as necessary as it's fine to restart the inference workers but you don't want to block the trainer.
