Fail-safe and partial redundancy for HSDP on unreliable compute #561

Open
evkogs opened this issue Aug 27, 2024 · 5 comments

evkogs commented Aug 27, 2024

I'd like to propose a feature for implementing fail-safe mechanisms and partial redundancy in FSDP2 (or rather in HSDP, not plain FSDP) to allow for more robust training on unreliable compute resources, such as cloud spot instances. The main goal is to make training more resilient to node failures, GPU issues, and other potential interruptions.

Key points:

  1. Implement an abstraction over DDP and FSDP with configurable parameters for redundancy (see the sketch after this list).
  2. Allow for partial redundancy, similar to RAID 5 or RAID 6 concepts, where full redundancy would be equivalent to DDP and zero redundancy would be equivalent to FSDP full-shard or ZeRO-3.
  3. Mitigate node failures and individual GPU failures by storing additional fractions (e.g., 1/8 or 1/4) of other nodes' optimizer states on each node.
  4. Accept a trade-off between memory usage and all-reduce overhead (estimated at 10-20%) in exchange for increased training resilience.
  5. Implement automatic downscaling with resharding and upscaling with automatic sharding, with a configurable overlapping sharding parameter (0.0 to 1.0).
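
To make the proposed knobs concrete, here is a minimal, purely hypothetical configuration sketch; none of these class or field names exist in PyTorch today, and the defaults are illustrative only.

```python
# Hypothetical configuration surface for the proposed DDP/FSDP abstraction.
# None of these names exist in PyTorch today; they only illustrate the knobs
# described in the list above.
from dataclasses import dataclass


@dataclass
class RedundantShardingConfig:
    # 0.0 == FSDP full-shard / ZeRO-3 (no redundancy),
    # 1.0 == DDP (full replication); values in between give partial redundancy.
    overlapping_factor: float = 0.125

    # Whether optimizer state (usually the largest consumer) is included in
    # the redundant fraction stored on peer nodes.
    replicate_optimizer_state: bool = True

    # Reaction to a lost node: reshard onto the survivors ("scale_down") or
    # wait for the scheduler to replace the node ("replace").
    on_node_failure: str = "scale_down"
```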

Use case examples:

  1. Training on cloud spot instances that may be terminated mid-training.
  2. Training giant models on 99.9%-reliable hardware, protecting against network adapter failures, power outages, etc.
  3. Enabling cross-regional model training on spot instances or multi-region clusters for colossal models.
  4. Supporting distributed training methods like DisTrO (https://github.com/NousResearch/DisTrO) that allow training over the internet with much lower throughput requirements than the traditional all-reduce approach.

This feature would greatly enhance the flexibility and reliability of large-scale distributed training, especially in scenarios where compute resources are not guaranteed to be stable throughout the entire training process.

A key aspect of this implementation would be an overlapping factor, ranging from 0.0 to 1.0, which determines the degree of redundancy. For example, with 64 GPUs across 8 nodes:

  • An overlapping factor of 0.0 would be equivalent to standard FSDP (no redundancy).
  • An overlapping factor of 0.125 (1/8) would allow for one node failure without interrupting training.
  • An overlapping factor of 0.25 (1/4) would provide resilience against two simultaneous node failures.
  • An overlapping factor of 1.0 would be equivalent to full DDP (complete redundancy).

The system would need to integrate downscaling with resharding and automatic restoring, as well as upscaling with automatic sharding, all governed by the specified overlapping factor (probably using Kubernetes with torchx, for example); a small sketch of the arithmetic follows.
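
Below is a small back-of-the-envelope sketch of that arithmetic for the 64-GPU / 8-node example, assuming (my assumption, not an existing API) that the overlapping factor is simply the extra fraction of sharded state each node holds on behalf of its peers:

```python
# Back-of-the-envelope arithmetic for partial redundancy (illustrative only).

def redundancy_stats(num_nodes: int, overlapping_factor: float):
    """Return (extra memory vs. FSDP full-shard, tolerated node failures).

    Assumes each node stores an extra `overlapping_factor` fraction of the
    sharded state spread across its peers, so losing floor(f * N) nodes still
    leaves every shard recoverable somewhere.
    """
    extra_memory = overlapping_factor
    tolerated_failures = int(overlapping_factor * num_nodes)
    return extra_memory, tolerated_failures


# 8 nodes (64 GPUs): factor 0.125 -> +12.5% memory, survives 1 node failure.
print(redundancy_stats(8, 0.125))   # (0.125, 1)
# Factor 0.25 -> +25% memory, survives 2 simultaneous node failures.
print(redundancy_stats(8, 0.25))    # (0.25, 2)
# 128 nodes: a factor of only 0.025-0.05 already tolerates 3-6 node failures.
print(redundancy_stats(128, 0.05))  # (0.05, 6)
```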

I'd be happy to discuss this further and provide more details if needed! Looking forward to your thoughts on this proposal!

evkogs changed the title from "Fail-safe and partial redundancy for FSDP2 on unreliable compute" to "Fail-safe and partial redundancy for HSDP on unreliable compute" on Aug 27, 2024
@tianyu-l (Contributor)

@awgu @wconstab @fegin

tianyu-l added the enhancement (New feature or request) label on Aug 27, 2024
evkogs (Author) commented Aug 27, 2024

I see it mainly as a complementary addition to the existing torch.distributed.elastic functionality.

Also, considering the numerous ways to launch a training job, the core functionality would be restoring all model weights, activations, and optimizer states onto a smaller number of workers (scale-down).

In the case of a specific launcher, e.g. torchrun or torchx with a Kubernetes scheduler, there's also the option to fully manage the cluster and replace workers (both scale-up and scale-down).

Also, for clusters of thousands of GPUs, the overhead won't be significant: for 64-128 or more nodes, the desired overlapping factor might be 2.5%-5% to guarantee resilience to outages, which is a small cost.

@jiamings

This is actually a great idea -- since ECC errors are quite common in HBM, this could save us from having to restart the entire job when we encounter a single ECC error. But I'm not sure how well this would work with distributed checkpointing.
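
For reference, a minimal sketch of the relevant distributed-checkpointing property: torch.distributed.checkpoint (DCP) saves sharded checkpoints that can be resharded at load time when the world size changes. This assumes PyTorch >= 2.2; the toy model, optimizer, and path are illustrative only.

```python
# Sketch: DCP sharded save/load, which reshards on load across a different
# world size (e.g. after dropping the node that hit an ECC error).
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")
model = FSDP(torch.nn.Linear(1024, 1024).cuda())
optim = torch.optim.AdamW(model.parameters())

# Save sharded model + optimizer state; every rank writes only its own shards.
model_sd, optim_sd = get_state_dict(model, optim)
dcp.save({"model": model_sd, "optim": optim_sd}, checkpoint_id="/tmp/ckpt")

# After restarting with a different world size, DCP reshards while loading.
model_sd, optim_sd = get_state_dict(model, optim)
dcp.load({"model": model_sd, "optim": optim_sd}, checkpoint_id="/tmp/ckpt")
set_state_dict(model, optim, model_state_dict=model_sd, optim_state_dict=optim_sd)
```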

wconstab (Contributor) commented Sep 3, 2024

Thanks for this proposal @evkogs! We would need to get more specific about a design to say for sure, but I think there are largely two issues that need to be addressed before this could be feasible.

  1. How can we drop some members out of a communicator and add new ones when the scheduler replaces them (e.g. PyTorch ProcessGroupNCCL + nccl communicator)? Today, the only way to do this is to tear down the 'world' and create a new 'world'. This can be expensive, and requires coordination.
  2. What is the right abstraction boundary between pytorch and the scheduler? We probably do not want to build all of this logic into pytorch as some of it ties into the job scheduling layer. Can we come up with a clear abstraction and propose which behaviors pytorch should implement and which ones should be provided by the scheduler itself?
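
For context on point 1, here is a rough sketch of the only path available today: tear down the old world and build a new one. The function and flow are illustrative, and it assumes the launcher re-runs rendezvous and provides the usual MASTER_ADDR / MASTER_PORT environment variables.

```python
# Illustrative only: destroy the old communicator and re-initialize a new
# world after a membership change, which is the expensive step noted above.
import os
import torch.distributed as dist


def rebuild_world(new_rank: int, new_world_size: int) -> None:
    # 1) Every surviving rank tears down the old process group / NCCL comms.
    if dist.is_initialized():
        dist.destroy_process_group()

    # 2) The launcher re-runs rendezvous and hands out new ranks; here we just
    #    assume it exported the results into the environment.
    os.environ["RANK"] = str(new_rank)
    os.environ["WORLD_SIZE"] = str(new_world_size)
    dist.init_process_group(backend="nccl")

    # 3) The training loop must then reshard parameters and optimizer state
    #    onto the new world, or, with the proposed partial redundancy,
    #    reconstruct the lost shards from the copies held by peer nodes.
```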

evkogs (Author) commented Sep 3, 2024

Thanks, @wconstab!

How can we drop some members out of a communicator and add new ones when the scheduler replaces them (e.g. PyTorch ProcessGroupNCCL + nccl communicator)? Today, the only way to do this is to tear down the 'world' and create a new 'world'. This can be expensive and requires coordination.

Well, I don't think that's an issue, since it would be an infrequent event, happening at most 2-3 times even with many nodes in an unreliable setup. So I think the current approach would be absolutely fine for real-world cases; a small launcher sketch follows the quote below. From the torch.distributed.elastic docs:

Membership Changes
Node departure (scale-down): The agent is notified of the departure, all existing workers are stopped, a new WorkerGroup is formed, and all workers are started with a new RANK and WORLD_SIZE.
Node arrival (scale-up): The new node is admitted to the job, all existing workers are stopped, a new WorkerGroup is formed, and all workers are started with a new RANK and WORLD_SIZE.
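
A small sketch of how this is driven today, assuming the torch.distributed elastic launcher (the programmatic twin of `torchrun --nnodes=MIN:MAX --max-restarts=N`); the endpoint, node counts, and entrypoint are illustrative only.

```python
# Sketch: elastic launch that tolerates membership changes by restarting the
# worker group with new RANK / WORLD_SIZE, as described in the docs above.
from torch.distributed.launcher.api import LaunchConfig, elastic_launch


def train():
    # Normal training entrypoint; on every membership change the elastic
    # agent restarts it with fresh RANK / WORLD_SIZE environment variables,
    # at which point checkpoints (or redundant shards) would be restored.
    ...


config = LaunchConfig(
    min_nodes=6,                     # keep training while >= 6 of 8 nodes live
    max_nodes=8,
    nproc_per_node=8,
    rdzv_backend="c10d",
    rdzv_endpoint="head-node:29400",
    max_restarts=3,                  # how many membership changes to tolerate
)
elastic_launch(config, train)()
```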

What is the right abstraction boundary between pytorch and the scheduler? We probably do not want to build all of this logic into pytorch as some of it ties into the job scheduling layer. Can we come up with a clear abstraction and propose which behaviors pytorch should implement and which ones should be provided by the scheduler itself?

That's a very good question! I think there's room for a unified approach that combines the existing ones. Also, I was curious to look into the PyTorch H2 2024 roadmap and saw there are plans to integrate up to 5D model parallelism (whatever that means), so this might get even trickier soon. I feel that if we keep growing the number of abstractions, it won't end well.
