Team: William Hu, Drew Wadsworth, Sean Siddens, Stanley Winata, Daniel Fu, Ryan Swann, Muhammad Osama, Christopher Ré, Simran Arora Links: Arxiv | Code
AI is compute hungry. So we’ve been asking: How do we build AI from the hardware up? How do we lead AI developers to do what the hardware prefers?
AMD GPUs are now offering state-of-the-art speeds and feeds. However, this performance is locked away from AI workflows due to the lack of mature AMD software. We share HipKittens (HK), an opinionated collection of programming primitives to help developers realize the hardware’s capabilities: optimized register tiles, 8-wave and 4-wave kernel patterns instead of wave specialization to schedule work within processors, and chiplet-optimized cache reuse patterns to schedule work across processors.
Check out part one of this series for an intro to HipKittens, and check out this post for a technical deep dive.
What do AMD CDNA GPUs look like? A lay of the land.
An AMD MI355X GPU has 256 processors called “compute units” (CUs), and a CU contains four SIMDs. Each SIMD has several execution units. A 64-thread “wave” (in contrast to the 32-thread warp on NVIDIA) occupies a single SIMD. We show the MI355X memory hierarchy below.
Unsurprisingly, making AMD GPUs go brr boils down to keeping the “matrix cores” (tensor cores on NVIDIA) fed. There are a few differences in how we think about this hardware:
- What it’s not. An MI355X has about 70% of the SRAM per processor of a B200 (165KB instead of 228KB), lacks asynchronous matrix multiplication instructions that operate on inputs in shared or tensor memory (wgmma, tcgen05), lacks register reallocation (the ability for some waves to give their registers to others), lacks a tensor memory accelerator (TMA, dedicated hardware for global memory access), and lacks first-class mbarrier primitives (for fine-grained synchronization).
- What it is. On the other hand, AMD GPUs have a 2x larger register file per processor than the B200 and offer 60% more processors per GPU (256 compute units versus 160 streaming multiprocessors). AMD offers tiny, fine-grained matrix core instructions, while NVIDIA tensor core instructions are generally called with large input operands. AMD also has TMA-like direct global-to-shared memory loads via `buffer_load_dword` instructions, which bypass the register file.
- Towards chiplet architectures. AMD is also leading the charge in the shift from monolithic dies to chiplets. AMD splits the 256 processors into 8 chiplets called “XCDs” of 32 CUs each, whereas NVIDIA B200s include 2 dies. The AMD cache is disaggregated: each AMD XCD has a private L2 cache, and an extra last-level cache (LLC) sits between the L2 and HBM memory.


| Spec | NVIDIA B200 SXM5 | AMD MI355X OAM |
|---|---|---|
| BF16 matrix / tensor | 2.2 PFLOPs | 2.5 PFLOPs |
| MXFP8 matrix / tensor | 4.5 PFLOPs | 5.0 PFLOPs |
| MXFP6 matrix / tensor | 4.5 PFLOPs | 10.1 PFLOPs |
| MXFP4 matrix / tensor | 9.0 PFLOPs | 10.1 PFLOPs |
| Memory capacity | 180 GB | 288 GB |
| Memory bandwidth | 8.0 TB/s | 8.0 TB/s |
Table 1: Hardware overview. Peak memory and compute speeds for the latest generation GPU platforms.
These differences impact the ways in which we design kernels on AMD.
- Optimized memory access: This matters on NVIDIA too, but AMD’s layouts, HIPCC compiler limitations, and the (undocumented) quirks of different I/O instructions yield new challenges.
- Scheduling within processors: We need to rely on our register file and small matrix core instructions instead of shared memory and bulky wgmma/tcgen05 instructions to establish deep pipelines and hide memory costs. Wave specialization / producer consumer, which reigns supreme in NVIDIA kernels, is not the right answer on AMD.
- Scheduling across processors: We need to start thinking about NUMA effects at the cache level as we schedule work across thread blocks.
We walk through these three topics next.
HipKittens memory access patterns
As in ThunderKittens (TK), developers in HK program using tiles as the basic data structure. Tiles exist in shared or register memory and are parametrized by a data type, size dimensions (a multiple of the matrix core instruction shape), and layout. HK provides a library of PyTorch-like functions that operate on tiles. Examples include `exp`, `mma`, `sub`, `add`, and `row_max` compute ops, as well as `load` and `store` memory ops. A tile is collectively owned by the threads in a wave (warp). The functions use template metaprogramming to generalize to different input tiles and are lightweight, directly wrapping assembly (PTX, CDNA ISA) and C++.
A memory layout determines how data elements map to thread ownership. A matrix core instruction expects a particular register layout depending on the data type and instruction shape. We also want to maximize the granularity and coalescing of global memory loads. Shared memory, which sits between registers and HBM, is split into banks (4-byte regions) that can serve data simultaneously. If threads from a wave request data from the same bank at the same time, their accesses are serialized. Efficient kernels use “swizzle patterns” to organize data in a way that avoids these bank conflicts.
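To make the bank-conflict rule concrete, here is a minimal Python model for illustration only. It makes four assumptions:

- banks are 4 bytes wide;
- there are 64 banks by default (the count reported below for `ds_read_b128`);
- all addresses passed in are issued in the same phase;
- several threads reading the same word are served by a broadcast.

The XOR swizzle shown is a generic textbook pattern, not one of HK’s layouts.

```python
from collections import defaultdict

BANK_BYTES = 4  # each bank serves one 4-byte word per access


def bank_of(addr: int, num_banks: int = 64) -> int:
    """Bank holding the 4-byte word that contains byte address `addr`."""
    return (addr // BANK_BYTES) % num_banks


def conflict_degree(addrs: list[int], num_banks: int = 64) -> int:
    """Serialization factor for one phase of shared-memory accesses.

    Distinct words that map to the same bank are serialized; threads that
    read the same word are assumed to be served by a single broadcast.
    """
    words_per_bank = defaultdict(set)
    for addr in addrs:
        words_per_bank[bank_of(addr, num_banks)].add(addr // BANK_BYTES)
    return max(len(words) for words in words_per_bank.values())


ROW_BYTES = 256  # 64 fp32 values per row: one full sweep over 64 banks
COL = 0          # every thread reads column 0 of a different row

# Unswizzled: row r, column c lives at word c, so column 0 is always bank 0.
naive = [r * ROW_BYTES + COL * BANK_BYTES for r in range(16)]

# XOR swizzle: row r, column c lives at word (c ^ r), spreading a column over banks.
swizzled = [r * ROW_BYTES + (COL ^ r) * BANK_BYTES for r in range(16)]

print(conflict_degree(naive))     # 16
print(conflict_degree(swizzled))  # 1
```

In the unswizzled layout, the 16 column reads all fall into bank 0 and are serialized 16 ways. XOR-permuting each row’s columns by the row index sends them to 16 different banks.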
A few challenges for memory access in HK include:
- Register scheduling: A core tenet of HK and TK is to give developers full control over register allocation by remaining embedded in C++. Compilers like Triton prevent register management altogether. Surprisingly, we find that even the HIPCC compiler imposes severe limitations (no wonder AMD uses raw assembly!).1 For instance, 4-wave (1 wave per SIMD) kernels compiled via HIPCC cannot use data held in certain types of registers as inputs to matrix instructions. This motivated us to add explicit register scheduling to HK: developers pin specific registers when creating register tiles, effectively replacing the compiler’s register management. Developers thus have the control necessary to write peak-performance kernels!
When a single wave is mapped per SIMD, the 512 registers are actually divided into 256 accumulator general purpose registers (AGPRs) and 256 vector general purpose registers (VGPRs). AGPRs have fundamental hardware limitations (e.g., vector arithmetic instructions cannot operate on them), but they can still serve as the inputs or outputs of MFMA instructions. HIPCC, however, cannot generate code that uses AGPRs as inputs to MFMA instructions. This leads to inefficient register management for register-heavy workloads and to redundant `accvgpr_read`/`accvgpr_write` instructions that move data between VGPRs and AGPRs.
- Register layouts: NVIDIA tensor core layouts are regular. As we vary the data type or matrix shape, the layout is composed from an underlying “core matrix” structure, so frameworks like TK and Gluon can apply a unified swizzling rule to avoid bank conflicts. AMD layouts, however, differ significantly based on the data type and matrix shape. In fact, we show that it’s not possible to use a single swizzle pattern for all layouts. Further, we sometimes want to use multiple matrix shapes within the same kernel, so our swizzle pattern must be compatible with multiple types of layouts concurrently.

Figure: AMD register layouts for matrix instructions are less structured. NVIDIA layouts are all composed from an underlying core matrix structure.
- Instruction phases: Waves (and NVIDIA warps) execute shared memory read/write instructions in phases, where a subset of threads in the wave accesses shared memory concurrently. NVIDIA instructions assign threads to phases sequentially (e.g., threads 0-7 in phase one, 8-15 in phase two). On AMD CDNA, however, the phase groups are non-sequential and differ entirely from one memory instruction to another. For example, we found that a `ds_read_b128` instruction, which reads 128 bits per thread, executes across 4 phases and has access to 64 banks. A `ds_write_b64` instruction, which writes 64 bits per thread, also executes across 4 phases but has access to only 32 banks. This behavior is not well documented even within AMD (! 😔), so we created and released solvers that reverse-engineer it.
Proof by contradiction:
To show why a single swizzling pattern is insufficient across different register tile shapes and layouts on AMD GPUs, consider the following two access patterns that surface in attention backwards:
- A row-layout 16x16 bf16 tile is written to shared memory. For this tile configuration, each thread holds 4 contiguous bf16 values (64 bits in memory), and the best instruction to issue this write is `ds_write_b64`. Avoiding bank conflicts for this access requires a swizzle pattern that respects the phase ordering and bank behavior described above. In this case, a swizzle that satisfies these constraints is `offset ^= ((offset % 512) >> 7) << 3`, an XOR swizzle that permutes 64-bit chunks of memory.
- A row-layout 16x32 bf16 tile is read from shared memory. For this tile, each thread holds 8 contiguous bf16 values (128 bits in memory), and the best instruction to issue this read is `ds_read_b128`.
Regardless of the swizzling pattern required for `ds_read_b128`, the granularities of these two instructions conflict. `ds_read_b128` requires at least 128 bits of memory to be contiguous in shared memory, while the swizzle pattern for `ds_write_b64` breaks memory apart into 64-bit chunks. As a result, each access needs a different swizzling pattern.
- Address generation: AMD GPUs support direct asynchronous HBM-to-shared-memory loads. Like TMA, these loads bypass the register file. The instruction takes as input the per-thread HBM addresses from which each thread reads data. DSLs like TK swizzle the shared memory addresses directly. On AMD, the shared memory layout is instead swizzled by swizzling the HBM addresses.
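Both the swizzle argument and the address-generation point can be checked with a small Python model. It implements the `ds_write_b64` swizzle from the proof, assuming `offset` is a byte offset into a shared-memory tile. The model then does two things:

1. It counts the 16-byte chunks whose two halves a `ds_read_b128` would see out of order.
2. It loads a tile, assuming lane `i` of a direct-to-shared load always writes to shared offset `base + 4*i`. Under that assumption, the only freedom left is which global address each lane reads.

All function names are hypothetical. The fixed-destination rule is a simplification, not a precise statement of the hardware behavior.

```python
DWORD = 4  # bytes a lane moves in one direct-to-shared load (model assumption)


def swizzle_b64(offset: int) -> int:
    """ds_write_b64 swizzle: offset ^= ((offset % 512) >> 7) << 3.

    Bits 7-8 of the byte offset pick a multiple of 8 bytes that is XORed
    into bits 3-4, permuting whole 64-bit chunks. Bits 7-8 are never
    changed, so the function is its own inverse.
    """
    return offset ^ (((offset % 512) >> 7) << 3)


# 1) The proof: does each 16-byte chunk a ds_read_b128 wants stay intact?
def b128_chunk_intact(offset: int) -> bool:
    """True if the two 8-byte halves at `offset` stay adjacent and in order."""
    return swizzle_b64(offset + 8) == swizzle_b64(offset) + 8


broken = [off for off in range(0, 512, 16) if not b128_chunk_intact(off)]
print(broken[:4], len(broken))  # [128, 144, 160, 176] 16


# 2) Address generation: each lane's shared-memory destination is fixed
#    (lane i -> base + 4*i), so the swizzle is applied to the global address.
def lane_global_offsets(base: int, lanes: int = 64) -> list[int]:
    """Tile byte offset each lane must read so shared memory ends up swizzled."""
    # The swizzle is an involution, so its inverse is the swizzle itself.
    return [swizzle_b64(base + DWORD * lane) for lane in range(lanes)]


TILE_BYTES = 1024
global_tile = {off: f"elem@{off}" for off in range(0, TILE_BYTES, DWORD)}
shared = {}
for base in range(0, TILE_BYTES, 64 * DWORD):
    for lane, g_off in enumerate(lane_global_offsets(base)):
        shared[base + DWORD * lane] = global_tile[g_off]  # fixed, linear destination

# Every element landed at its swizzled address without touching shared addresses.
assert all(shared[swizzle_b64(off)] == val for off, val in global_tile.items())
```

Half of the 16-byte chunks (the second and fourth 128-byte rows of every 512-byte window) have their 64-bit halves swapped. A 128-bit read cannot fetch them in one piece.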
We provide developers with optimized tile layouts and memory access patterns by default within HK. Check out our paper to learn more about how we implement solutions to the above challenges.
HipKittens schedules within a processor
Ideally we would have simple, reusable patterns for scheduling the compute and memory within kernels that generalize across AI workloads. Wave specialization / producer consumer serves this purpose on NVIDIA, but what about on AMD?
Wave specialization struggles on AMD. Wave specialization is the dominant paradigm for achieving high occupancy on modern NVIDIA GPUs: producer waves focus on memory movement while consumer waves focus on computation. This strategy underpins today’s state-of-the-art AI kernels, including FlashAttention-3, COMET for MoE models, and high-performance GEMMs. It also underpins kernel DSLs such as ThunderKittens LCSF and TileLang.
But we show that wave specialization underperforms on AMD due to the lack of register reallocation. On the MI355X, registers are statically divided across all waves. Producer waves that only need a few registers for address calculation are allocated more registers than they need. Consumer waves cannot recoup those registers and must either spill to scratch memory or run at a lower arithmetic intensity, and both are disastrous for performance. In practice, wave specialization limits the output tile size and makes our kernels more memory bound. For GEMMs, data loaded from memory is O(MK + NK) while compute is O(MNK), so decreasing M or N of the per-thread-block output tile lowers arithmetic intensity.2
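The arithmetic behind the table that follows can be sketched in a few lines of Python. The sketch assumes:

- bf16 inputs (2 bytes);
- one 32-bit register per lane for each fp32 accumulator;
- waves spread evenly over a CU’s four SIMDs;
- 512 registers per lane per SIMD, split evenly among resident waves (a simplification of the real allocation granularity).

It computes per-block arithmetic intensity and compares each consumer wave’s accumulator footprint with its register share. Intensity is 2MNK FLOPs over 2(M+N)K bytes, i.e., MN/(M+N) FLOPs per byte.

```python
from fractions import Fraction

SIMDS_PER_CU = 4
REGS_PER_SIMD_LANE = 512  # VGPRs + AGPRs available to a lane on one SIMD
WAVE_SIZE = 64
BF16_BYTES = 2


def intensity(m: int, n: int, k: int = 64) -> Fraction:
    """FLOPs per byte loaded for an m x n output tile over a k-deep step (k cancels)."""
    return Fraction(2 * m * n * k, (m + n) * k * BF16_BYTES)


def register_budget(producers: int, consumers: int) -> int:
    """Registers per lane for each wave if all waves get an equal static share."""
    waves_per_simd = -(-(producers + consumers) // SIMDS_PER_CU)  # ceiling division
    return REGS_PER_SIMD_LANE // waves_per_simd


def accumulator_regs(m: int, n: int, consumers: int) -> int:
    """fp32 accumulator registers per lane when consumers split the output tile."""
    return m * n // (consumers * WAVE_SIZE)


configs = [  # (producers, consumers, M, N), taken from the HK rows of the table
    (4, 8, 128, 256),
    (4, 12, 192, 256),
    (0, 8, 192, 256),
    (0, 8, 256, 256),
]
for p, c, m, n in configs:
    print(f"{p}/{c} {m}x{n}: {float(intensity(m, n)):.1f} FLOP/B, "
          f"{accumulator_regs(m, n, c)} acc regs of {register_budget(p, c)}")
```

This is only a back-of-the-envelope model. It shows the trend the table reports: larger output tiles raise intensity, and producer waves shrink the per-wave register share that larger tiles need. It does not reflect the kernels’ actual register allocation.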
| # P / # C | MFMA Shape | Output | TFLOPS |
|---|---|---|---|
| HK 4 / 8 | 16×16×32 | 128×256 | 893 |
| HK 4 / 12 | 16×16×32 | 192×256 | 1278 |
| HK 0 / 8 | 16×16×32 | 192×256 | 1281 |
| HK 0 / 8 | 16×16×32 | 256×256 | 1605 |
| TK | 256×256×16 | 256×256 | 1538 |
| CUTLASS | 256×256×16 | 256×256 | 1570 |
Figure: Wave specialization underperforms on AMD GPUs. We benchmark AMD GEMMs on the MI355X using different numbers of producer (P) and consumer (C) waves. We report the matrix core intrinsic shape, the output tile size computed per thread block, and TFLOPS (500 warmup iterations / 100 measured iterations). The CUTLASS GEMM is selected and tuned using the CUTLASS profiler tool on a B200 GPU.
As an aside, it might be surprising that AMD matches NVIDIA GEMM performance without the bells and whistles of wgmma, producer-consumer scheduling, TMA, mbarriers, large shared memory for deep multi-stage pipelining, etc. But AMD has a 2× larger register file, and AMD’s smaller matrix core shapes (e.g., 16×16×32) provide an alternative path to deep pipelines through finer-grained load and compute stages.
Scheduling patterns for AMD. Our attempt to use wave specialization, a strategy that works well on NVIDIA GPUs, did not yield the expected speedups on AMD hardware. All is not lost! We found two scheduling patterns that consistently yield high occupancy on AMD GPUs while using tile programming primitives (no raw assembly)!
- 8-wave ping-pong: We assign two waves per SIMD. At any given time, one wave executes a cluster of memory instructions while the other executes a cluster of compute instructions, and the waves swap roles at the end of each cluster. With this approach, the developer can use large HK tiles, since a wave issues many instructions of the same kind at once!
- 4-wave interleave: We assign one wave per SIMD, and the wave switches between issuing memory and compute operations at fine granularity. Here, the developer uses small HK tiles (essentially matching the matrix core instruction shape) to achieve the fine-grained schedule.
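A toy timing model shows why the 8-wave ping-pong hides memory latency. It assumes an idealized SIMD on which one wave’s memory cluster and the other wave’s compute cluster overlap perfectly. Each cluster has a fixed cost (`mem`, `comp`, in arbitrary time units), the waves meet at a barrier after every cluster, and prologue/epilogue effects are ignored. None of these names come from HK.

```python
def ping_pong_timeline(mem: int, comp: int, iters_per_wave: int):
    """Return (timeline, total_time) for two waves sharing one SIMD.

    Wave 0 starts with a memory cluster and wave 1 with a compute cluster
    (its first tile is assumed to be prefetched). After each barrier the
    roles swap, so every phase overlaps one memory cluster with one
    compute cluster and lasts max(mem, comp).
    """
    roles = ["mem", "comp"]  # current roles of wave 0 and wave 1
    timeline, total = [], 0
    for _ in range(2 * iters_per_wave):  # each iteration = one mem + one compute cluster
        duration = max(mem if r == "mem" else comp for r in roles)
        timeline.append((tuple(roles), duration))
        total += duration
        roles.reverse()  # barrier: the waves swap roles
    return timeline, total


def serial_time(mem: int, comp: int, iters: int) -> int:
    """A single wave doing the same total work with no overlap."""
    return iters * (mem + comp)


timeline, overlapped = ping_pong_timeline(mem=3, comp=5, iters_per_wave=4)
print(timeline[:2])                      # [(('mem', 'comp'), 5), (('comp', 'mem'), 5)]
print(overlapped, serial_time(3, 5, 8))  # 40 64
```

In this idealized model, each iteration costs max(mem, comp) instead of mem + comp, so whichever of memory or compute is cheaper is hidden entirely.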
These two patterns trade off programmability and performance. 8-wave and its large tile primitives lead to compact code, while 4-wave fine-grained interleaving expands code size. Surprisingly, the 8-wave schedule is sufficient to achieve SoTA-level performance on GEMMs and attention forwards. For GQA non-causal attention backwards, 8-wave also outperforms all AMD baselines by 1.8×, and our HK 4-wave kernel further outperforms them by 2.3×.
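The 4-wave alternative can be pictured as a single instruction stream in which memory operations are threaded between small matrix-core operations. The Python sketch below spreads `n_loads` loads evenly across `n_mfma` MFMAs. Each entry stands for one small-tile call, which illustrates why the fine-grained schedule grows the code. The op names are placeholders, and the uniform spacing is an assumption; real kernels tune placement by hand.

```python
def interleave(n_mfma: int, n_loads: int) -> list[str]:
    """Evenly spread n_loads memory ops among n_mfma matrix-core ops.

    After the i-th MFMA, issue loads until floor((i + 1) * n_loads / n_mfma)
    loads have been issued, so every load appears exactly once.
    """
    stream, issued = [], 0
    for i in range(n_mfma):
        stream.append(f"mfma_{i}")
        target = (i + 1) * n_loads // n_mfma
        while issued < target:
            stream.append(f"load_{issued}")
            issued += 1
    return stream


stream = interleave(n_mfma=8, n_loads=4)
print(stream[:6])   # ['mfma_0', 'mfma_1', 'load_0', 'mfma_2', 'mfma_3', 'load_1']
print(len(stream))  # 12
```

An 8-wave step, by contrast, can be written as one large-tile load cluster followed by one large-tile compute cluster, regardless of how many MFMAs those tiles expand into.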

Figure: HK 8-wave ping pong pattern. We include a profiler snippet of the HK BF16 GEMM.
HipKittens schedules across processors
Modern GPUs are moving toward chiplet-based architectures, shifting away from traditional monolithic dies. AMD’s MI355X, for instance, integrates eight chiplets (XCDs), each with its own L2 cache, while NVIDIA’s B200 pairs two dies together. This shift enables higher scalability and yield but introduces a new performance challenge: disaggregated memory hierarchies. Each cluster of compute units now has local caches, and memory locality is no longer uniform across the chip.
On AMD GPUs, thread blocks are scheduled to chiplets in a round-robin fashion, meaning that the order in which blocks are launched—the grid schedule—directly affects how effectively data is reused in cache. Even perfectly tuned kernels can lose bandwidth if their grid layout is cache-unfriendly.
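A minimal sketch of the round-robin dispatch described above. It assumes block `b` lands on XCD `b % 8`, an idealization of the hardware dispatcher, and uses a row-major mapping from block ID to output tile. The 36-tile width is a hypothetical example (a 9216-wide output split into 256-wide tiles). The point is that neighboring output tiles, which share input data, end up on different XCDs and therefore in different L2 caches.

```python
NUM_XCDS = 8


def xcd_of(block_id: int) -> int:
    """XCD (and thus private L2) that receives this thread block."""
    return block_id % NUM_XCDS


def row_major_tile(block_id: int, tiles_n: int) -> tuple[int, int]:
    """(row, col) of the output tile computed by this block in row-major order."""
    return divmod(block_id, tiles_n)


TILES_N = 36  # hypothetical: 9216-wide output / 256-wide tiles
row0 = [row_major_tile(b, TILES_N) for b in range(8)]
print(row0[:3])                       # [(0, 0), (0, 1), (0, 2)]
print([xcd_of(b) for b in range(8)])  # [0, 1, 2, 3, 4, 5, 6, 7]
# Eight adjacent tiles in row 0 all read the same row panel of A,
# yet each one is fetched into a different XCD's L2.
```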






Figure: Visualization of three different grid schedules for the output matrix of a BF16 GEMM. The color represents the XCD assignment for the first set of thread blocks scheduled across the GPU’s 256 processors. The top row is for a 9216×9216×9216 GEMM and the bottom row is for a 14592×14592×14592 GEMM. The leftmost column shows the assignments under a naive row-major layout. The middle column shows an approach that optimizes L2 cache reuse. The right column shows the output of our algorithm, which balances L2 and LLC cache reuse.
| Block Order | L2 % | LLC % | Mem. BW | TFLOPS |
|---|---|---|---|---|
| Matrix Multiply (M=N=K=9216) | | | | |
| Row-major | 55% | 95% | 15.1 TB/s | 1113 |
| XCD (W 7/C 216) | 79% | 24% | 14.9 TB/s | 991 |
| XCD (W 5/C 25) | 75% | 93% | 18.3 TB/s | 1145 |
| Matrix Multiply (M=N=K=14592) | | | | |
| Row-major | 36% | 76% | 10.7 TB/s | 900 |
| XCD (W 8/C 542) | 79% | 7% | 13.9 TB/s | 980 |
| XCD (W 8/C 64) | 78% | 55% | 16.6 TB/s | 1068 |
Table: Performance results corresponding to the above chiplet swizzling figures.
Above, for a GEMM D = AB + C, we show different patterns for assigning thread blocks the responsibility of computing different tiles of the output matrix D. When thread blocks are scheduled in naive row-major order, L2 cache reuse is suboptimal (≈55% L2 hit rate for the 9216 GEMM). Blocks that share the same L2 cache often load different, non-overlapping tiles of A and B. Further, optimizing purely for L2 locality can cause each XCD to fetch disjoint portions of A and B, leading to redundant loads at the next cache level.
To address this, HipKittens introduces a chiplet-aware scheduling strategy that reorganizes the grid launch order to better exploit locality at both the L2 (per-chiplet) and LLC (shared) cache levels for GEMM workloads. The key idea is to group thread blocks that operate on nearby regions of the output matrix so that they naturally reuse overlapping tiles of input data across cache hierarchies.
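The effect of grouping can be sketched by combining two widely used tricks:

- an XCD remap that gives each chiplet a contiguous range of logical tile IDs;
- a grouped, `GROUP_SIZE_M`-style traversal (popularized by Triton’s matmul tutorial) that walks tiles in narrow bands of rows.

This is not HK’s algorithm. Its W/C parameters in the table above are not modeled, and neither is LLC reuse. The sketch only shows how reordering shrinks each XCD’s input footprint, counted here as the distinct A row panels plus B column panels that an XCD’s blocks touch. It assumes round-robin dispatch, a block count that is a multiple of 8, and a toy 8×8 tile grid with all blocks resident at once.

```python
NUM_XCDS = 8


def remap_for_xcd(block_id: int, num_blocks: int) -> int:
    """Logical tile ID so that XCD x owns a contiguous chunk of IDs.

    Hardware sends block b to XCD b % 8; we hand it logical ID
    (b % 8) * (num_blocks // 8) + b // 8. Assumes num_blocks % 8 == 0.
    """
    per_xcd = num_blocks // NUM_XCDS
    return (block_id % NUM_XCDS) * per_xcd + block_id // NUM_XCDS


def grouped_tile(pid: int, tiles_m: int, tiles_n: int, group_m: int) -> tuple[int, int]:
    """Walk tiles in bands of group_m rows, column by column within a band."""
    per_group = group_m * tiles_n
    first_m = (pid // per_group) * group_m
    rows_in_group = min(tiles_m - first_m, group_m)
    local = pid % per_group
    return first_m + local % rows_in_group, local // rows_in_group


def footprint_per_xcd(tiles: list[tuple[int, int]]) -> list[int]:
    """Distinct A row panels + B column panels touched by each XCD."""
    rows = [set() for _ in range(NUM_XCDS)]
    cols = [set() for _ in range(NUM_XCDS)]
    for block_id, (r, c) in enumerate(tiles):
        xcd = block_id % NUM_XCDS  # round-robin dispatch
        rows[xcd].add(r)
        cols[xcd].add(c)
    return [len(rs) + len(cs) for rs, cs in zip(rows, cols)]


TM = TN = 8  # toy 8x8 grid of output tiles
n = TM * TN
row_major = [divmod(b, TN) for b in range(n)]
chiplet_aware = [grouped_tile(remap_for_xcd(b, n), TM, TN, group_m=2) for b in range(n)]
print(footprint_per_xcd(row_major))      # [9, 9, 9, 9, 9, 9, 9, 9]
print(footprint_per_xcd(chiplet_aware))  # [6, 6, 6, 6, 6, 6, 6, 6]
```

With row-major order, each XCD gets a full column of tiles and must fetch 8 row panels of A. After remapping and grouping, each XCD works on a compact 2×4 patch and fetches fewer panels in total.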
Putting it all together
Let’s take a look at a few kernels written in HK.
- First, here’s the hot loop of our attention forwards pass kernel (the entire kernel is ≈500 lines of code). The kernel uses HK’s 8-wave ping-pong schedule, where waves alternate between compute instruction clusters and memory clusters.
- Here’s the core hot loop structure for our BF16 GEMM kernel. Again, waves alternate between compute clusters and memory clusters using HK’s 8-wave ping-pong schedule.
Multi-silicon AI is coming!
HipKittens delivers competitive performance on AMD CDNA3 and CDNA4 through three key insights:

- optimized memory access;
- AMD-centric wave scheduling patterns within a processor;
- chiplet-aware grid scheduling across processors to exploit AMD’s disaggregated cache hierarchy.

Our kernels consistently achieve peak performance among AMD baselines across workloads (and compete with peak Blackwell kernels as well).
Realizing AI’s full potential requires diverse, open hardware.1 Today, that means making AMD GPUs truly accessible.
We want more AI in the world. AI has relied on and innovated on a single hardware provider, but we need to be able to use and experiment with all the compute we can. We need to be able to use the fastest hardware out there. We’re happy to help address these problems with HipKittens!

Figure: Surfing towards multi-silicon AI!
- We believe that the HIPCC register scheduling is one of the most important areas for improvement in AMD’s kernel software stack.↩
- We hope these findings lead to hardware changes that support wave specialization or guide AMD kernel development; for instance, Mojo currently provides a warp-specialized matmul kernel as of 11/06/2025 even though AMD CDNA doesn’t have register reallocation.↩