Learning to love mesh-oriented sharding
Famously, PyTorch and JAX don’t agree on how shardings should be represented: PyTorch takes a mesh-dim oriented view, where for each dimension in your device mesh, you specify what sharding should be applied; JAX takes a tensor-dim oriented view, where for each dimension on your tensor, you say which mesh dimensions (potentially multiple!) shard it. Among my Twitter followers, it is generally agreed that the JAX formulation is more intuitive from a user perspective. OK, fine; if you prefer one representation over another, it’s easy enough to translate between the two representations (in easy situations, at least!) In this post, I want to talk more about the framework implementation side: what is the better internal representation of sharding? I don’t claim to have all the answers, but my motivation for writing this post is to help explain where I currently stand and how I evaluate proposals for evolving DTensor and sharding in PyTorch.
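To make the two views concrete, here is a minimal sketch of translating between them in the easy case. It assumes only two placement kinds, an even Shard(dim) and Replicate, plus named mesh dims. The classes and functions below are hypothetical stand-ins, not the real torch.distributed.tensor or jax.sharding types. The tensor-dim oriented result is just a tuple shaped like a JAX PartitionSpec.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple


@dataclass(frozen=True)
class Shard:
    """Hypothetical stand-in for an even, contiguous shard of tensor dim `dim`."""
    dim: int


@dataclass(frozen=True)
class Replicate:
    """Hypothetical stand-in for a replicated placement."""


# For each tensor dim: None (unsharded) or the mesh axes that shard it.
TensorSpec = Tuple[Optional[Tuple[str, ...]], ...]


def mesh_to_tensor_oriented(placements: Sequence[object],
                            mesh_dim_names: Sequence[str], ndim: int) -> TensorSpec:
    """One placement per mesh dim -> for each tensor dim, the mesh axes sharding it."""
    if len(placements) != len(mesh_dim_names):
        raise ValueError("need exactly one placement per mesh dim")
    per_dim: List[List[str]] = [[] for _ in range(ndim)]
    # Visiting mesh dims left to right records, for each tensor dim, the mesh
    # axes in the order they split it (outermost first).
    for name, p in zip(mesh_dim_names, placements):
        if isinstance(p, Shard):
            per_dim[p.dim].append(name)
        elif not isinstance(p, Replicate):
            raise NotImplementedError(f"no tensor-oriented encoding for {p!r}")
    return tuple(tuple(axes) if axes else None for axes in per_dim)


def tensor_to_mesh_oriented(spec: TensorSpec, mesh_dim_names: Sequence[str]) -> list:
    """Inverse translation, limited to the default left-to-right sharding order."""
    position = {name: i for i, name in enumerate(mesh_dim_names)}
    placements = {name: Replicate() for name in mesh_dim_names}
    for dim, axes in enumerate(spec):
        axes = axes or ()
        if list(axes) != sorted(axes, key=position.__getitem__):
            # e.g. ("tp", "dp") on one dim means TP splits first; a plain list
            # of placements applied left to right cannot express that.
            raise NotImplementedError("non-default sharding order")
        for name in axes:
            placements[name] = Shard(dim)
    return [placements[name] for name in mesh_dim_names]


print(mesh_to_tensor_oriented([Shard(0), Shard(0)], ["dp", "tp"], 2))  # (('dp', 'tp'), None)
```

Anything outside the two hard-coded cases (a Partial, a custom Placement, or an unusual sharding order) falls out of this translation. That is one way to see why it only works "in easy situations."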
Closed versus open. I am going to make a precise technical claim: JAX sharding is closed, whereas PyTorch sharding is (in principle) open. Here, closed/open refers to the ability of users to extend a system: traditional ADTs are closed (you can’t add another constructor to an ADT), whereas object-oriented classes are open (you can define a new subclass of a class). Now, technically JAX sharding is open: jax.sharding.Sharding is a base class that is intended to be subclassed, but to do this you have to define things like _to_xla_hlo_sharding, which is as good as it not being supported. The regular class everyone uses, NamedSharding, consists of a mesh and a tuple of mesh axes, with no obvious extension points. As further evidence, I offer this unanswered forum post: https://github.com/jax-ml/jax/discussions/23703
In contrast, PyTorch sharding is in principle extensible: the sharding is expressed as a list of Placement, a class which is subclassed to define custom shardings. The extensibility of Placement isn’t really well supported (for example, there’s no convenient way to add sharding rules for new placements). Still, it works well enough that there are implementations of weird placements both internally and externally. Internally, there are StridedShard and NormPartial... and technically all of the non-sum reductions supported by Partial, as well as uneven sharding. Externally, see RaggedShard and InterleavedShard.
Why does mesh-dim oriented sharding support extensibility in this way? The key is that mesh-oriented sharding is very imperative in nature: you can think of the list of placements as a sequence of transformations you apply to the tensor from left to right. Concretely, start from the current local tensor, as produced by the placements for all the mesh dims before the one you’re currently processing. Run an invertible function to split this tensor along the current mesh dimension. This gives you a bunch of new local tensors, which you recursively continue sharding with the rest of the mesh dims. Invertibility is the only real constraint on the function you provide, since you need to be able to reassemble the shards into the original full tensor. Otherwise, your choice of function is unconstrained. It is in this sense that Placement is morally extensible.
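Here is a minimal sketch of that left-to-right reading. It assumes 1-D tensors represented as Python lists and a hypothetical Placement interface with just two methods: split (the invertible function) and its inverse, join. Interleave stands in for a user-defined placement. None of these names are PyTorch's actual API.

```python
from typing import Dict, List, Sequence, Tuple

Local = List[int]


class Placement:
    """Hypothetical base class: an invertible way to split a local 1-D tensor
    (a Python list here) into n pieces along one mesh dim."""

    def split(self, local: Local, n: int) -> List[Local]:
        raise NotImplementedError

    def join(self, pieces: List[Local]) -> Local:
        raise NotImplementedError


class Replicate(Placement):
    def split(self, local, n):
        return [list(local) for _ in range(n)]

    def join(self, pieces):
        return list(pieces[0])  # every piece is a full copy


class Shard(Placement):
    """Even, contiguous chunks (this sketch rejects uneven sizes)."""

    def split(self, local, n):
        if len(local) % n:
            raise ValueError("uneven sharding is out of scope for this sketch")
        k = len(local) // n
        return [local[i * k:(i + 1) * k] for i in range(n)]

    def join(self, pieces):
        return [x for piece in pieces for x in piece]


class Interleave(Placement):
    """A 'user-defined' placement: deal elements round-robin across ranks."""

    def split(self, local, n):
        return [local[i::n] for i in range(n)]

    def join(self, pieces):
        n = len(pieces)
        out = [0] * sum(len(p) for p in pieces)
        for i, piece in enumerate(pieces):
            out[i::n] = piece  # put each rank's elements back in their slots
        return out


def shard(tensor: Local, placements: Sequence[Placement],
          mesh_shape: Sequence[int]) -> Dict[Tuple[int, ...], Local]:
    """Apply placements left to right: split along the first mesh dim, then
    recursively shard each piece with the remaining placements."""
    if not placements:
        return {(): tensor}
    out = {}
    for i, piece in enumerate(placements[0].split(tensor, mesh_shape[0])):
        for coord, local in shard(piece, placements[1:], mesh_shape[1:]).items():
            out[(i,) + coord] = local
    return out


def unshard(shards: Dict[Tuple[int, ...], Local], placements: Sequence[Placement],
            mesh_shape: Sequence[int]) -> Local:
    """Invert shard(): reassemble the inner mesh dims first, then join along
    the first mesh dim using that placement's inverse."""
    if not placements:
        return shards[()]
    pieces = []
    for i in range(mesh_shape[0]):
        sub = {c[1:]: v for c, v in shards.items() if c[0] == i}
        pieces.append(unshard(sub, placements[1:], mesh_shape[1:]))
    return placements[0].join(pieces)


x = list(range(8))
placements = [Shard(), Interleave()]
shards = shard(x, placements, (2, 2))
assert unshard(shards, placements, (2, 2)) == x
```

Note that shard and unshard never mention Interleave. Any placement works as long as join undoes split. Interleave even handles uneven lengths that this sketch's Shard rejects.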
When designing systems, making the system more flexible is not an unambiguous good. Closed systems like JAX’s mean you don’t have to worry about hilariously complicated situations. One example is what happens if you unevenly shard the same dimension multiple times: do you have any guarantee that the local tensor sizes are even somewhat balanced? But sometimes the use case demands a greater degree of expressivity, in the same way that manual memory management lets you do more than you can conveniently do in a GC’ed language.
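Here is a small worked example of why repeated uneven sharding is hard to reason about. The sketch below computes local sizes when one tensor dim is split over several mesh dims. It assumes a split rule modeled on torch.chunk: each rank takes ceil(n / k) elements until the data runs out, and later ranks may get nothing. The helper names are hypothetical.

```python
from typing import List, Sequence


def chunk_sizes(n: int, k: int) -> List[int]:
    """Local sizes when n elements are split over k ranks, torch.chunk-style:
    each rank takes ceil(n / k) elements until the data runs out."""
    step = -(-n // k)  # integer ceil(n / k)
    return [max(0, min(step, n - i * step)) for i in range(k)]


def nested_sizes(n: int, mesh_shape: Sequence[int]) -> List[int]:
    """Local sizes when the same tensor dim is sharded once per mesh dim,
    left to right, applying chunk_sizes at every level."""
    sizes = [n]
    for k in mesh_shape:
        sizes = [s for size in sizes for s in chunk_sizes(size, k)]
    return sizes


print(chunk_sizes(7, 6))        # [2, 2, 2, 1, 0, 0]
print(nested_sizes(7, (2, 3)))  # [2, 2, 0, 1, 1, 1]
print(nested_sizes(7, (3, 2)))  # [2, 1, 2, 1, 1, 0]
```

Six ranks and seven elements give three different size layouts. Sharding once, over a (2, 3) mesh, or over a (3, 2) mesh each produces its own layout, and which ranks end up empty changes each time.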
How expressive does Sharding have to be? One of the primary value propositions of DTensor is that it specifies a standardized representation of how a tensor is sharded across your cluster. This information is very valuable because it prevents accidents. For example, you might forget that a tensor dimension is sharded and reduce over it without first doing a collective, getting subtly wrong results that take weeks to debug. It’s better to have a system that is correct but slow than one that is fast but incorrect.
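Here is a toy illustration of how carrying sharding metadata prevents this accident. This is not DTensor's API; ShardedVector and PartialSum are hypothetical names. The idea is that a local reduction over a sharded dim returns a value that still owes a collective.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class PartialSum:
    """Per-rank partial results that still owe a sum across ranks."""
    partials: List[float]

    def all_reduce(self) -> List[float]:
        # Simulated all-reduce: every rank ends up holding the global total.
        total = sum(self.partials)
        return [total] * len(self.partials)


@dataclass
class ShardedVector:
    """A vector sharded along dim 0, stored as one list per rank."""
    shards: List[List[float]]

    def reduce_sum(self) -> PartialSum:
        # We know dim 0 is sharded, so a local reduction can only produce
        # partial results; the return type records the pending collective.
        return PartialSum([sum(s) for s in self.shards])


x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
dx = ShardedVector([x[0:2], x[2:4], x[4:6]])

forgot = [sum(s) for s in dx.shards]   # [3.0, 7.0, 11.0]: each rank's wrong "total"
tracked = dx.reduce_sum().all_reduce()  # [21.0, 21.0, 21.0]
```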
Being able to express all distributed states is not a terminal goal. There are lots of situations in distributed optimizations where you temporarily need to put the system in a state where it is very difficult to describe exactly how to interpret data across nodes. For example, when you implement ring attention, to avoid communications when performing softmax, you instead perform online softmax. It’s quite difficult to say what the "placements" of the running quantities in online softmax are. In this case, we shouldn’t overly stress ourselves with defining a placement: we should just use local_map or shard_map and absolve ourselves of needing to actually say exactly how data is laid out at any given point in time. But the key is that we should only do this in local regions of code; if we give up and local_map our entire model, we might as well have just not written our code with DTensor at all. So we should seek additional expressivity when it is needed to express how data is being communicated across system boundaries.
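For reference, here is a minimal sketch of the online softmax recurrence. It assumes scores arrive one non-empty block at a time, as they would across ring-attention steps. Midway through, the running max m and denominator l describe only the blocks seen so far. Moreover, l is scaled relative to a max that will keep changing, which is exactly the kind of state that has no natural placement. A real ring-attention kernel would also carry a rescaled partial output. This sketch instead revisits the scores at the end so the result can be checked.

```python
import math
from typing import List, Sequence


def online_softmax(blocks: Sequence[Sequence[float]]) -> List[float]:
    """Softmax over the concatenation of `blocks`, visiting one block at a time
    instead of first computing a global max and a global denominator."""
    m = -math.inf  # running max over the scores seen so far
    l = 0.0        # running sum of exp(score - m) over the scores seen so far
    for block in blocks:
        m_new = max(m, max(block))
        # Rescale the old denominator to the new max, then add this block's terms.
        l = l * math.exp(m - m_new) + sum(math.exp(s - m_new) for s in block)
        m = m_new
    # Only now are (m, l) the true global statistics, so only now can we normalize.
    return [math.exp(s - m) / l for block in blocks for s in block]


def softmax(scores: Sequence[float]) -> List[float]:
    """Reference softmax using a global max, for comparison."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


blocks = [[1.0, 3.0], [-2.0], [0.5, 2.0]]
flat = [s for block in blocks for s in block]
assert all(abs(p - q) < 1e-12 for p, q in zip(online_softmax(blocks), softmax(flat)))
```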
Here are some classic examples from LLM training where you need a little bit of extra expressivity, starting with simple cases and becoming more complicated:
- Suppose you are applying FSDP to a model whose parameter sizes weren’t chosen with parallelism in mind; in particular, the parameter count isn’t evenly divisible by the size of your cluster. It can be convenient to allow uneven sharding in this case, so that the user doesn’t have to manually pad out their tensor before it can be all-gathered.
- Say you do a matrix multiply between two tensors which are sharded on the contraction dimension. A reduction is required to communicate the local results into the final tensor. However, sometimes, it can be profitable to delay this reduction, since it can be coalesced into a later reduction. This requires the ability to express that a tensor has a pending reduction on some mesh axis.
- Suppose you have both FSDP and row-wise TP, and your FSDP implementation naively shards on the first dim of your weight tensor. Then you need to ensure that the TP sharding occurs before the FSDP sharding, so that when you undo your FSDP sharding, you have the expected TP sharding ready to go for TP. This requires one of two abilities. The first is expressing sharding in a non-standard order (right-to-left, as is supported by a list of mesh axes, aka PartitionSpec). The second is expressing that the FSDP sharding is a weird "strided" shard where you don’t have contiguous data; instead, you have stripes of data that will then be further sharded by the TP sharding.
- Suppose you have a tensor going into a matrix multiply which is sharded on sequence but not on batch. (You’re naughty and following the standard PyTorch convention of not actually expressing batch sharding in DTensor.) If you want to treat both batch and sequence as "batch" for the matmul, in PyTorch this typically requires flattening these two dimensions into a single flat batch dimension. However, this can’t be done with the standard placements, as there is no native Placement that can represent a flattened (Replicate, Shard). It does work with StridedShard (or InterleavedShard, which is the same thing). More generally, it is extremely irritating that DTensors cannot reliably have view operations applied to them (operations that would be supported on plain tensors), and that you need weird shard types to handle many view operations.
- Traditional FSDP2 assumes there is no requirement on how parameters are placed on nodes. However, things like block-wise quantization and structure-aware optimizers need to place a block/parameter on a single device, so that you have access to all the data you need. This won’t be a standard sharding; the shards will be ragged.
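To see why flattening a sequence-sharded tensor produces a strided layout, here is a small sketch with nested Python lists, assuming an even sequence sharding. The names shard_on_seq and strided_owner are hypothetical. strided_owner is one way to write the ownership rule down, not the actual parameterization of StridedShard or InterleavedShard.

```python
from typing import List

Matrix = List[List[int]]


def shard_on_seq(x: Matrix, world: int) -> List[Matrix]:
    """Evenly shard a (B, S) nested list on dim 1 (sequence); one (B, S/world) per rank."""
    S = len(x[0])
    if S % world:
        raise ValueError("this sketch assumes S is divisible by world")
    k = S // world
    return [[row[r * k:(r + 1) * k] for row in x] for r in range(world)]


def flatten(x: Matrix) -> List[int]:
    """Row-major flatten, i.e. the view (B, S) -> (B * S,)."""
    return [v for row in x for v in row]


def strided_owner(i: int, S: int, world: int) -> int:
    """Hypothetical strided rule: in every length-S run of the flat dim,
    rank r owns positions [r * S/world, (r + 1) * S/world)."""
    return (i % S) // (S // world)


B, S, world = 2, 4, 2
x = [[b * S + s for s in range(S)] for b in range(B)]  # element value == flat index

# Flattening each rank's local shard needs no communication...
local_flat = [flatten(local) for local in shard_on_seq(x, world)]
# ...but the result is not a contiguous Shard(0) of the flattened tensor.
n_local = B * S // world
contiguous = [flatten(x)[r * n_local:(r + 1) * n_local] for r in range(world)]

print(local_flat)  # [[0, 1, 4, 5], [2, 3, 6, 7]]
print(contiguous)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

Each rank's data is a run of S / world elements repeated with stride S. A contiguous Shard(0) cannot describe that layout, but a strided placement can.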
I think it’s a worthy use of complexity budget to search for a system design that can handle all of these things, especially since PyTorch’s existing mesh-oriented sharding is already tantalizingly close to supporting this.
Why is adding a new Placement to PyTorch hard? Fundamentally, I tend to think that a mesh-oriented sharding strategy can support arbitrary Placement subclasses. So why does this not work so well in PyTorch? I think there are really only two issues:
- There is no official way to retroactively add extra sharding propagation rules to existing operators. What I always tell people is that a sharding propagation rule is simply a mathematical equality, saying that map(f, in_placement(x)) == out_placement(f(x)). Mathematical equalities are infinitely compositional: you can always add more true equalities to your system without compromising correctness. But PyTorch currently gives you no way to add them.
- Many sharding propagation rules are written assuming only Shard exists. Placement provides an is_shard method to test if something is a shard (as opposed to replicate/partial). Sharding propagation rules often assume that if this is True, you specifically have a standard, even Shard, as if it were the only sharding in the universe. This means that rules are often secretly buggy when custom Placements are added. StridedShard, in particular, naughtily advertises that it is_shard(), which means that we will happily allow it to contract with a plain Shard, leading to a real bug (https://github.com/pytorch/pytorch/issues/166598). Now, to be clear: often rules WILL work for arbitrary sharding subclasses. For example, if an input dimension is a batch dimension, it doesn’t matter how the data is sliced up, because your operation is functorial over that dimension. Will Constable has been working on refactoring our sharding rules to distinguish three situations: "it’s a batch dimension," "I specifically need an even sharding," and "I need the shardings on these two inputs to be exactly the same kind of sharding."
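Both issues can be made concrete with a toy contraction. The sketch below treats a propagation rule as the equality it claims. For a dot product with both inputs sharded on the contraction dim, the per-rank local dots must sum to the global dot, i.e. the output is a pending sum. The sketch checks this on an example input, which can refute a rule but not prove it. Strided is a hypothetical stand-in for a StridedShard-like placement that also answers True to is_shard(). The rule functions are illustrative, not PyTorch's actual rules.

```python
from dataclasses import dataclass
from typing import List

Vec = List[int]


class Placement:
    """Hypothetical mini Placement hierarchy for 1-D vectors."""

    def is_shard(self) -> bool:
        return False

    def split(self, x: Vec, n: int) -> List[Vec]:
        raise NotImplementedError


@dataclass(frozen=True)
class Shard(Placement):
    """Even, contiguous chunks."""

    def is_shard(self) -> bool:
        return True

    def split(self, x, n):
        k = len(x) // n
        return [x[i * k:(i + 1) * k] for i in range(n)]


@dataclass(frozen=True)
class Strided(Placement):
    """Hypothetical strided/interleaved shard: elements dealt round-robin."""

    def is_shard(self) -> bool:
        return True

    def split(self, x, n):
        return [x[i::n] for i in range(n)]


def dot(a: Vec, b: Vec) -> int:
    return sum(u * v for u, v in zip(a, b))


def partial_rule_holds(pa: Placement, pb: Placement, a: Vec, b: Vec, n: int) -> bool:
    """Check the equality behind "(pa, pb) on the contraction dim -> pending sum":
    summing the per-rank local dots must equal the global dot."""
    local_dots = [dot(la, lb) for la, lb in zip(pa.split(a, n), pb.split(b, n))]
    return sum(local_dots) == dot(a, b)


def rule_applies_buggy(pa: Placement, pb: Placement) -> bool:
    # Only asks "are both sharded?", as if every shard were a plain Shard.
    return pa.is_shard() and pb.is_shard()


def rule_applies_fixed(pa: Placement, pb: Placement) -> bool:
    # A contraction also needs both inputs to assign elements to ranks identically.
    return pa.is_shard() and pb.is_shard() and pa == pb


a, b = [1, 2, 3, 4], [10, 20, 30, 40]
for pa, pb in [(Shard(), Shard()), (Strided(), Strided()), (Shard(), Strided())]:
    print(type(pa).__name__, type(pb).__name__, rule_applies_buggy(pa, pb),
          rule_applies_fixed(pa, pb), partial_rule_holds(pa, pb, a, b, 2))
# Shard Shard True True True
# Strided Strided True True True
# Shard Strided True False False
```

The buggy rule accepts (Shard, Strided) because both claim to be shards, even though the equality fails. The fixed rule only fires when both inputs split the contraction dim identically.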
With these two issues fixed, and a bit of careful design of the overridable API that Placement exposes to subclasses, I think we can have a very good extensibility story for shardings.