By Yujun (Lucas) Qian, in collaboration with William Baisi, KangHyuk Lee, and Anshul Sadh-Gauri (Columbia University). Mentored by Dr. Garrett Goon.
The “Straggler” Problem in Large Scale Inference
As Large Language Models (LLMs) continue to grow in context length (now routinely handling 128k or even 1M tokens), serving them efficiently has become a major systems challenge. While techniques like FlashAttention and Ring Attention have revolutionized memory management by splitting the Key-Value (KV) cache across multiple GPUs, they have a hidden limitation: they implicitly assume hardware homogeneity.
In a perfect world, every GPU in a cluster is identical. In the real world, data centers evolve. Partial upgrades, cost constraints, or simply replacing a failed H100 with an available A100 can lead to heterogeneous clusters.
In my recent project at Columbia University, collaborating with mentors from IBM, we asked a simple question: What happens to distributed inference when one GPU is slower than the rest?
The answer was sobering: the entire system collapses to the speed of the slowest device, the classic “straggler effect.”
The Challenge: Ring Attention Meets Reality
Ring Attention is a powerful algorithm that allows us to process sequences longer than a single GPU’s memory can hold. It works by circulating blocks of Key-Value data between GPUs in a ring topology, overlapping computation (attention calculation) with communication (P2P data transfers).
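To make the pattern concrete, here is a minimal single-head PyTorch sketch of the ring exchange, assuming one process per GPU and the standard even KV shards. The names (`ring_attention_prefill`, `_ring_exchange`) are purely illustrative and not APIs from FMS or any other library, and for clarity the accumulation skips the running-max (online softmax) trick that production FlashAttention-style kernels use for numerical stability:

```python
import torch
import torch.distributed as dist

def _ring_exchange(block: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Send our K/V block to the next rank and receive one from the previous rank."""
    recv = torch.empty_like(block)
    ops = [
        dist.P2POp(dist.isend, block, (rank + 1) % world_size),
        dist.P2POp(dist.irecv, recv, (rank - 1) % world_size),
    ]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return recv

def ring_attention_prefill(q, k, v, rank, world_size):
    """Each rank keeps its query shard [seq_q, d]; K/V shards rotate around the
    ring so every rank eventually attends over the full sequence."""
    num = torch.zeros_like(q, dtype=torch.float32)   # unnormalized attention output
    den = torch.zeros(q.shape[:-1], dtype=torch.float32, device=q.device)
    k_blk, v_blk = k, v
    for step in range(world_size):
        scores = (q @ k_blk.transpose(-1, -2)) / q.shape[-1] ** 0.5
        weights = torch.exp(scores.float())          # real kernels track a running max here
        num += weights @ v_blk.float()
        den += weights.sum(dim=-1)
        if step < world_size - 1:
            # Production implementations overlap this transfer with the next compute step.
            k_blk = _ring_exchange(k_blk, rank, world_size)
            v_blk = _ring_exchange(v_blk, rank, world_size)
    return (num / den.unsqueeze(-1)).to(q.dtype)
```

The key property is that every rank does the same amount of work per step, so the ring never stalls as long as all ranks finish their local block at the same time; the moment one rank is slower, everyone waits for it.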
However, current implementations partition the sequence evenly. If you have a fast GPU and a slow GPU, the fast GPU finishes its chunk quickly and sits idle, wasting expensive compute cycles while waiting for the slower GPU to catch up.
Our Solution: Heterogeneity-Aware Context Parallelism
To solve this, we moved away from the rigid “even split” dogma. We proposed and implemented Heterogeneity-Aware Partitioning, which dynamically allocates the context workload based on the compute capability of each rank.
We explored three strategies to rebalance the load:
- Uneven Split (Proportional): Distributing tokens based on a simple ratio of compute throughput (e.g., if GPU A is 2x faster than GPU B, it gets 2x the tokens); a minimal sketch follows this list.
- Lookup Table (LUT): Using offline profiling data to map exact sequence lengths and hardware speeds to an optimal split.
- Polynomial Regression: Modeling the performance curve to predict the optimal split on the fly.
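As a rough sketch of the first (proportional) policy, assuming we have one relative-throughput number per rank, e.g. from a quick profiling pass or published peak FLOPs (the function name is illustrative, not an FMS API):

```python
def proportional_split(seq_len: int, throughputs: list[float]) -> list[int]:
    """Assign each rank a token count proportional to its compute throughput."""
    total = sum(throughputs)
    counts = [int(seq_len * t / total) for t in throughputs]
    counts[-1] += seq_len - sum(counts)  # hand any rounding leftovers to the last rank
    return counts

# Example: a full-speed GPU paired with one throttled to 25% of its capacity.
print(proportional_split(65536, [1.0, 0.25]))  # -> [52428, 13108]
```

A real implementation would likely also round these counts to the tile size the attention kernel expects, but the core idea is just this one-line ratio.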
Figure 1: Token distribution across GPUs under different partitioning strategies. As heterogeneity increases (lower MPS), adaptive strategies assign more tokens to the faster GPU to balance execution time.
Technical Implementation
We implemented these strategies within IBM’s Foundation Model Stack (FMS). This wasn’t just a Python logic change; it required deep systems work:
- Modified RingAttentionStrategy: We rewrote the strategy to support uneven block lengths across the process group.
- Custom Triton Kernels: Standard FlashAttention kernels assume uniform block sizes. We wrote custom Triton kernels to handle the online softmax calculations required when aggregating attention scores across uneven KV shards.
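The reason uneven shards are workable at all is that the exact attention output can be reconstructed from per-shard partial results plus their log-sum-exp statistics, regardless of how long each shard is. Here is a PyTorch reference of that merge; conceptually this is the rescaling our Triton kernels perform, but `merge_partial_attention` is an illustrative name rather than code from FMS:

```python
import torch

def merge_partial_attention(outs: list[torch.Tensor], lses: list[torch.Tensor]) -> torch.Tensor:
    """Combine per-shard attention results into the exact full-sequence output.

    outs: per-shard softmax-normalized attention outputs of shape [seq_q, d];
          the KV shards behind them may all have different lengths.
    lses: matching per-shard log-sum-exp statistics of shape [seq_q].
    """
    total_lse = torch.logsumexp(torch.stack(lses), dim=0)
    merged = torch.zeros_like(outs[0])
    for out, lse in zip(outs, lses):
        # Each shard contributes in proportion to its share of the softmax mass;
        # note that nothing here depends on the shard's KV length.
        merged = merged + out * torch.exp(lse - total_lse).unsqueeze(-1)
    return merged
```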
Methodology & Experiment Design
Since access to physical clusters with specific mixed cards (like H100 + A100 pairs) is difficult to control precisely, we simulated heterogeneity using NVIDIA Multi-Process Service (MPS).
- Setup: 2x NVIDIA L40 GPUs.
- Simulation: We kept Rank 0 at 100% capacity and throttled Rank 1 using MPS (varying from 10% to 90% capacity) to create a “synthetic straggler” (see the sketch after this list).
- Workload: We focused on the prefill phase of inference, sweeping sequence lengths from 4k to 65k tokens.
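For reproducibility, NVIDIA MPS exposes this throttle through the documented CUDA_MPS_ACTIVE_THREAD_PERCENTAGE environment variable, which caps the portion of GPU threads available to a client process. Below is a minimal sketch of the per-rank setup; the function name is illustrative, and it assumes the MPS control daemon is already running on the node and that the variable is set before the process initializes CUDA:

```python
import os

def configure_straggler(rank: int, straggler_rank: int = 1, capacity_pct: int = 10) -> None:
    """Throttle one rank's GPU capacity via MPS to emulate a slower device.

    Must run before the process creates its CUDA context (i.e. before any
    torch.cuda call), since the limit is read when the client attaches to MPS.
    """
    if rank == straggler_rank:
        os.environ["CUDA_MPS_ACTIVE_THREAD_PERCENTAGE"] = str(capacity_pct)
```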
Figure 2: Simulating heterogeneity using NVIDIA MPS to throttle specific ranks, creating a controlled testbed for “straggler” nodes.
Results: Recovering Lost Performance
The results were stark. Under extreme heterogeneity (where one GPU is effectively 10x slower than the other), the standard approach failed significantly.
1. The Baseline Failure: The “Even Split” strategy suffered a 5–8x slowdown. The system effectively ran as if both GPUs were the slow one, completely negating the presence of the faster card.
2. The Adaptive Success: Our “Uneven Split” strategy achieved up to a 4.4x speedup over the baseline. By giving the faster GPU more work, we ensured both GPUs finished at roughly the same time, maximizing aggregate throughput.
Figure 3: Slowdown factor relative to a homogeneous baseline. Note how the blue line (Even Split) spikes dramatically as heterogeneity increases, while the orange line (Uneven Split) remains efficient.
3. Simplicity Wins: Most importantly, we found that the simple Proportional Split strategy was remarkably effective, often matching or beating the more complex LUT and regression approaches. This suggests that for most production deployments, a simple heuristic based on relative FLOPs is enough to reclaim the wasted performance.
Figure 4: Speedup heatmap. As sequence lengths grow (y-axis) and heterogeneity increases (x-axis), the value of our approach scales exponentially.
Future Steps
This project, developed for the High Performance Machine Learning course (COMS E6998), highlights the importance of hardware-aware algorithms in modern ML systems. The next frontier involves:
- GQA Kernel Optimization: Extending our Triton kernels to support Grouped Query Attention (GQA).
- Disaggregated Inference: Decoupling the compute-bound prefill phase from the memory-bound decode phase, assigning them to different classes of hardware entirely.
- Advanced Q/KV Partitioning Algorithms: Designing more sophisticated partitioning policies that jointly optimize Q and KV sharding, take network topology and contention into account, and potentially re-balance shards at runtime as loads change.
Special thanks to my teammates William Baisi, Anshul Sadh-Gauri, Kanghyuk Lee, and our mentor Dr. Garrett Goon for their contributions to this research.