I enjoyed reading DeepSeek’s Native Sparse Attention (NSA) paper, and I thought it would be an interesting challenge to implement and optimize it for TPUs.

I was especially curious about how NSA, which is heavily optimized for GPUs, could be adapted to TPUs, whose hardware follows a fundamentally different design philosophy.

Before we dive in, here’s the link to the Colab notebook with all my code. It includes the vectorized JAX baseline of NSA, the Pallas kernels, and the profiling code. I hope you tinker with it to get a better understanding of both NSA and Pallas.

Note: All code and experiments were run on a TPU v5e. We’ll focus on the Selection Branch only, since it is the quirkiest branch in NSA.

Let me …