Show HN: Optimizing DeepSeek's NSA for TPUs – A Kernel Worklog
henryhmko.github.io·12h·

I enjoyed reading DeepSeek’s NSA (Native Sparse Attention) paper and thought it would be an interesting challenge to implement and optimize it for TPUs.

I was especially curious how NSA, which is heavily optimized for GPUs, could be adapted to TPUs, which follow a fundamentally different design philosophy.

Before we dive in, here’s the link to the Colab notebook with all my code. It includes the vectorized JAX baseline of NSA, the Pallas kernels, and the profiling code. I hope you get to tinker with it to understand NSA and Pallas better.

Note: All code and experiments were run on a TPU v5e. We’ll look only at the Selection Branch, as it is the quirkiest branch in NSA.

Let me …
