Distributed training with MLX: tensor parallelism

The content of this document was originally intended to be merged into the official MLX docs as an example of tensor parallelism. It is currently in PR review here. Since I was proud of the effort and the explanation, I decided to cross-post it here as well.

Tensor Parallelism in MLX

MLX enables efficient tensor parallelism (TP) through its distributed layers. In this example, we will look at what these layers are and build a small inference script for Llama-family transformer models using MLX tensor parallelism.
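Before touching the layers themselves, it helps to see how a process joins the distributed group they operate over. Here is a minimal sketch, assuming the documented `mx.distributed.init()` entry point and that the script is started with MLX's `mlx.launch` helper (or an MPI launcher); the process count is illustrative.

```python
import mlx.core as mx

# Join the world group. Every process started by the launcher,
# e.g. `mlx.launch -n 2 script.py`, ends up in the same group.
group = mx.distributed.init()

# Each rank can ask where it sits in the group.
print(f"rank {group.rank()} of {group.size()}")
```

With more than one process, each prints a distinct rank; the sharded layers below take such a group (or default to the world group) to decide which slice of a weight each rank owns.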

Sharded Layers

AllToShardedLinear

Column-wise tensor parallelism. This layer replicates a common input and shards the weight matrix along the output dimension (columns), so each rank produces its own slice of the output features.
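As a rough illustration of that data flow (not the layer's actual implementation), the sketch below hand-rolls column-wise sharding with plain `mx.distributed` collectives. The sizes, weight values, and the final gather are illustrative assumptions; the layer itself leaves its output sharded for the next layer to consume, so no gather is needed in a real model.

```python
import mlx.core as mx

group = mx.distributed.init()
rank, size = group.rank(), group.size()

in_dims, out_dims = 8, 16          # illustrative sizes; out_dims % size == 0
shard = out_dims // size

# Every rank sees the same (replicated) input.
x = mx.ones((2, in_dims))

# Each rank owns a slice of the weight along the output dimension,
# i.e. rows [rank * shard, (rank + 1) * shard) of the full matrix.
w_full = mx.arange(out_dims * in_dims).reshape(out_dims, in_dims).astype(mx.float32)
w_shard = w_full[rank * shard : (rank + 1) * shard]

# The local matmul yields this rank's slice of the output features.
y_local = x @ w_shard.T            # shape: (2, shard)

# Only to check the math: all_gather concatenates along the first axis,
# so gather the transposed slices and transpose back. The result matches
# the unsharded x @ w_full.T.
y_full = mx.distributed.all_gather(y_local.T, group=group).T
```

The key property is that no communication happens on the forward path itself: the input is already replicated, and each rank's matmul is independent. Communication is only needed when a downstream layer wants the full activation back.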

