Muhammad Maaz
October 5, 2025

The code accompanying this post is available at pbt-batch-invariance.
Recently, a Thinking Machines blog post discussed why nondeterminism in large language models is a problem. The post argues that batch-invariance in matrix multiplication, RMSNorm, and attention is crucial for deterministic inference. In their repo, the test_batch_invariance.py file shows a simple test for batch-invariance of matrix multiplication, with a random draw of PyTorch tensors (basically, using torch.randn).
This testing seemed interesting enough, but I wanted to do something more rigorous: use property-based testing to check for batch-invariance. The Hypothesis library allows for more sophisticated testing: you define an input domain (which can be quite complex) and a property (or properties) that should hold. Hypothesis then generates random inputs from the domain and checks that the property holds. Another nice feature of Hypothesis is that, if it finds a counterexample, it will attempt to shrink it to a minimal example.
An example of a property is that a list sorted via a function my_sort() should be non-decreasing. In Hypothesis, we would write this as:
from hypothesis import given, strategies as st

@given(st.lists(st.integers()))
def test_sorted_list(lst):
    sorted_lst = my_sort(lst)
    for i in range(len(sorted_lst) - 1):
        assert sorted_lst[i] <= sorted_lst[i + 1]
This test generates random lists of integers, sorts them via my_sort(), and checks that the sorted list is non-decreasing. We are going to try to do something similar for batch-invariance.
The batch-invariance property
Matrix multiplication
The test in the Thinking Machines repo is essentially the following: given two tensors \( a \) and \( b \), test that \( a[:1] @ b = (a @ b)[:1] \). The left-hand side is \( a[:1] \) matrix-multiplied by \( b \), while the right-hand side is the first row of \( a @ b \). The dimensions are the same and, because of how matrix multiplication works, the two are mathematically equivalent. However, as the Thinking Machines blog post argues, in practice they are not equal, because of how the kernel operations are implemented. As currently written, the test draws a random \( a \) and \( b \) of fixed size, with elements defined by a fixed linear space, and then checks this property by computing the difference between the two sides of the equality.
It would be better to define a general property and let Hypothesis generate random tensors to test it against. First, a more general input strategy is that instead of taking only the first row, we can take any slice, namely rows \( m \) to \( n \) (with \( n \) exclusive). Second, the sizes of the tensors can be random, and the elements of the tensors are random floats within a given range. So, we write the following input strategy:
- Generate random dimensions \( B, D, N \) in a specified range.
- Generate random tensors \( a \) (size \( B \times D \)) and \( b \) (size \( D \times N \)) with elements in a specified range, with \(\text{inf}\) and \(\text{nan}\) disallowed.
- Generate random integers \( m \) and \( n \) with \( 0 \leq m < n \leq B \) as the slice indices.

Then, define two ways of computing the matrix multiplication: \(\text{out1} = a[m:n] @ b\) and \(\text{out2} = (a @ b)[m:n]\). Finally, the property we assert is that \(\text{diff} = \max |\text{out1} - \text{out2}|\), the maximum absolute elementwise difference, is less than a specified tolerance.
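Here is a minimal sketch of what this strategy and test might look like with Hypothesis and its numpy extra; the dimension bounds, element range, and tolerance are my own assumptions, not necessarily the values used in the accompanying repo.

import numpy as np
import torch
from hypothesis import given, strategies as st
from hypothesis.extra import numpy as hnp

@st.composite
def matmul_inputs(draw, max_dim=8):
    # Random dimensions B, D, N in a specified range.
    B = draw(st.integers(min_value=2, max_value=max_dim))
    D = draw(st.integers(min_value=1, max_value=max_dim))
    N = draw(st.integers(min_value=1, max_value=max_dim))
    # Finite float elements in a bounded range (no inf/nan).
    elements = st.floats(-10.0, 10.0, allow_nan=False, allow_infinity=False, width=32)
    a = torch.tensor(draw(hnp.arrays(np.float32, (B, D), elements=elements)))
    b = torch.tensor(draw(hnp.arrays(np.float32, (D, N), elements=elements)))
    # Slice indices 0 <= m < n <= B.
    m = draw(st.integers(min_value=0, max_value=B - 1))
    n = draw(st.integers(min_value=m + 1, max_value=B))
    return a, b, m, n

@given(matmul_inputs())
def test_matmul_batch_invariance(inputs):
    a, b, m, n = inputs
    out1 = a[m:n] @ b      # multiply the slice
    out2 = (a @ b)[m:n]    # slice the full product
    diff = (out1 - out2).abs().max().item()
    assert diff < 1e-6     # tolerance is an assumption

If the batched matmul kernel is not batch-invariant, Hypothesis will report a failing \( (a, b, m, n) \) and shrink it toward the smallest dimensions and simplest values that still exceed the tolerance.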
RMSNorm
You can do something very similar for RMSNorm (root-mean-square normalization). Now the input strategy is:
- Generate random dimensions \( B, D \) in a specified range.
- Generate random tensors \( x \) (size \( B \times D \)) and \( \gamma \) (size \( D \)) with elements in a specified range, with \(\text{inf}\) and \(\text{nan}\) disallowed.
- Generate random integers \( m \) and \( n \) with \( 0 \leq m < n \leq B \) as the slice indices.

Then, the property is essentially the same: assert that the difference between the RMSNorm of \( x[m:n] \) with \( \gamma \) and the \( [m:n] \) slice of the RMSNorm of \( x \) with \( \gamma \) is within a specified tolerance.
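A sketch of the corresponding strategy and test, mirroring the matmul one above. The RMSNorm definition here is the one used later in the Implementations section (with no eps term), so the elements of \( x \) are kept away from zero to avoid dividing by a zero mean square; the bounds and tolerance are again my own assumptions.

import numpy as np
import torch
from hypothesis import given, strategies as st
from hypothesis.extra import numpy as hnp

@st.composite
def rmsnorm_inputs(draw, max_dim=8):
    # Random dimensions B, D in a specified range.
    B = draw(st.integers(min_value=2, max_value=max_dim))
    D = draw(st.integers(min_value=1, max_value=max_dim))
    # Keep x away from zero so the row mean-square never underflows to zero.
    nonzero = st.floats(0.5, 10.0, width=32) | st.floats(-10.0, -0.5, width=32)
    finite = st.floats(-10.0, 10.0, allow_nan=False, allow_infinity=False, width=32)
    x = torch.tensor(draw(hnp.arrays(np.float32, (B, D), elements=nonzero)))
    gamma = torch.tensor(draw(hnp.arrays(np.float32, (D,), elements=finite)))
    # Slice indices 0 <= m < n <= B.
    m = draw(st.integers(min_value=0, max_value=B - 1))
    n = draw(st.integers(min_value=m + 1, max_value=B))
    return x, gamma, m, n

def rmsnorm(x, gamma):
    # RMSNorm as defined in the Implementations section below (no eps term).
    return x * torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True)) * gamma

@given(rmsnorm_inputs())
def test_rmsnorm_batch_invariance(inputs):
    x, gamma, m, n = inputs
    out1 = rmsnorm(x[m:n], gamma)   # normalize the slice
    out2 = rmsnorm(x, gamma)[m:n]   # slice the normalized batch
    assert (out1 - out2).abs().max().item() < 1e-6   # tolerance is an assumption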
Attention
Lastly, you can do the same thing with attention (scaled dot-product attention). Now the input strategy is:
- Generate random dimensions \( B \), \( \text{seq_len} \), \( \text{num_heads} \), \( \text{head_dim} \) in a specified range.
- Generate random tensors \( Q \), \( K \), and \( V \), each of size \( B \times \text{seq_len} \times \text{num_heads} \times \text{head_dim} \), with elements in a specified range, with \(\text{inf}\) and \(\text{nan}\) disallowed.
- Generate random integers \( m \) and \( n \) with \( 0 \leq m < n \leq B \) as the slice indices.

Then, we assert that the difference between the attention of \( Q[m:n] \), \( K[m:n] \), and \( V[m:n] \) and the \( [m:n] \) slice of the attention of \( Q \), \( K \), and \( V \) is within a specified tolerance.
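A sketch of the corresponding strategy and test. It calls torch.nn.functional.scaled_dot_product_attention directly; since the tensors are generated as (B, seq_len, num_heads, head_dim) and SDPA expects the heads dimension before the sequence dimension, the sketch transposes in and out, which is an assumption about the intended layout. Bounds and tolerance are again my own.

import numpy as np
import torch
import torch.nn.functional as F
from hypothesis import given, strategies as st
from hypothesis.extra import numpy as hnp

@st.composite
def attn_inputs(draw, max_dim=4):
    # Random dimensions B, seq_len, num_heads, head_dim in a specified range.
    B = draw(st.integers(min_value=2, max_value=max_dim))
    seq_len = draw(st.integers(min_value=1, max_value=max_dim))
    num_heads = draw(st.integers(min_value=1, max_value=max_dim))
    head_dim = draw(st.integers(min_value=1, max_value=max_dim))
    elements = st.floats(-10.0, 10.0, allow_nan=False, allow_infinity=False, width=32)
    shape = (B, seq_len, num_heads, head_dim)
    q = torch.tensor(draw(hnp.arrays(np.float32, shape, elements=elements)))
    k = torch.tensor(draw(hnp.arrays(np.float32, shape, elements=elements)))
    v = torch.tensor(draw(hnp.arrays(np.float32, shape, elements=elements)))
    # Slice indices 0 <= m < n <= B.
    m = draw(st.integers(min_value=0, max_value=B - 1))
    n = draw(st.integers(min_value=m + 1, max_value=B))
    return q, k, v, m, n

def attention(q, k, v):
    # Tensors are (B, seq_len, num_heads, head_dim); SDPA expects
    # (B, num_heads, seq_len, head_dim), so transpose in and out.
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    return F.scaled_dot_product_attention(q, k, v).transpose(1, 2)

@given(attn_inputs())
def test_attn_batch_invariance(inputs):
    q, k, v, m, n = inputs
    out1 = attention(q[m:n], k[m:n], v[m:n])   # attend within the slice
    out2 = attention(q, k, v)[m:n]             # slice the full result
    assert (out1 - out2).abs().max().item() < 1e-6   # tolerance is an assumption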
Implementations
Batched
The batched implementations are the PyTorch built-ins applied to the whole batch at once. For matrix multiplication, we use the @ operator. For RMSNorm, we use the definition x * torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True)) * gamma. For attention, we use the torch.nn.functional.scaled_dot_product_attention function.
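Collected as named functions, so the parameterized tests below can refer to them, the batched implementations might look like the following sketch (the names are mine, chosen to mirror the test IDs in the results).

import torch
import torch.nn.functional as F

def matmul_batched(a, b):
    # The built-in @ operator over the whole batch.
    return a @ b

def rmsnorm_batched(x, gamma):
    # RMSNorm per the definition above (no eps term).
    return x * torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True)) * gamma

def attn_batched(q, k, v):
    # Same wrapper as in the attention sketch above: transpose to the layout
    # SDPA expects, run it over the whole batch, and transpose back.
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    return F.scaled_dot_product_attention(q, k, v).transpose(1, 2)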
Rowwise
The rowwise implementations break up the computation across rows (batch elements) in order to enforce batch-invariance. They are essentially the same as the batched implementations, but compute each row separately and then stack the results, as sketched below. This is obviously much slower, and it is not the same as the batch-invariant kernels implemented in the Thinking Machines post.
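A minimal sketch of what the rowwise versions might look like (the names are mine, and this is not the repo's exact code):

import torch

def matmul_rowwise(a, b):
    # Multiply one row of `a` at a time, then stack the per-row results.
    return torch.stack([a[i] @ b for i in range(a.shape[0])])

def rmsnorm_rowwise(x, gamma):
    # Normalize one row at a time, so each row's computation is independent
    # of the batch size.
    return torch.stack(
        [x[i] * torch.rsqrt(torch.mean(x[i] ** 2, dim=-1, keepdim=True)) * gamma
         for i in range(x.shape[0])]
    )

# attn_rowwise follows the same pattern: run each batch element through the
# attention wrapper with a batch dimension of 1, then concatenate the results.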
How tests are structured
For each operation, we write a parameterized test that takes the operation's implementation as an argument. Each operation has a batched version and a rowwise version, and each test comes with a strategy for generating the random inputs, as described above. Running the tests therefore exercises both the batched and the rowwise version of every operation; in theory, the batched version should fail, since it is not batch-invariant, and the rowwise version should pass.
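For matrix multiplication, for example, the test might be structured roughly like this sketch, combining pytest's parametrization with Hypothesis; matmul_inputs, matmul_batched, and matmul_rowwise are the strategy and implementations assumed in the sketches above, and the IDs mirror the test names in the results below.

import pytest
from hypothesis import given

@pytest.mark.parametrize(
    "matmul_impl",
    [matmul_batched, matmul_rowwise],
    ids=["matmul_batched", "matmul_rowwise"],
)
@given(inputs=matmul_inputs())
def test_matmul(matmul_impl, inputs):
    a, b, m, n = inputs
    out1 = matmul_impl(a[m:n], b)
    out2 = matmul_impl(a, b)[m:n]
    assert (out1 - out2).abs().max().item() < 1e-6   # tolerance is an assumption

Note that pytest.mark.parametrize supplies matmul_impl while Hypothesis fills inputs, which is why @given is applied with a keyword argument.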
Results
I ran the tests on my MacBook Air (CPU) as well as a Google Colab notebook with a T4 GPU. The results of these tests can be seen in the GitHub repo, under the test_outputs folder, including the counterexamples that Hypothesis found. On the CPU, the results are:
test_batch_invariance.py::test_matmul[matmul_batched] FAILED
test_batch_invariance.py::test_matmul[matmul_rowwise] PASSED
test_batch_invariance.py::test_rmsnorm[rmsnorm_batched] PASSED
test_batch_invariance.py::test_rmsnorm[rmsnorm_rowwise] PASSED
test_batch_invariance.py::test_attn[attn_batched] PASSED
test_batch_invariance.py::test_attn[attn_rowwise] PASSED
Here, on the CPU, only matmul_batched failed, not the other *_batched versions. This is likely due to how these operations are implemented on the CPU. The output of the tests on the GPU is:
test_batch_invariance.py::test_matmul[matmul_batched] FAILED
test_batch_invariance.py::test_matmul[matmul_rowwise] PASSED
test_batch_invariance.py::test_rmsnorm[rmsnorm_batched] FAILED
test_batch_invariance.py::test_rmsnorm[rmsnorm_rowwise] PASSED
test_batch_invariance.py::test_attn[attn_batched] FAILED
test_batch_invariance.py::test_attn[attn_rowwise] PASSED
Here, on the GPU, all *_batched versions failed, and all *_rowwise versions passed. As the original blog post argues, the way that GPU kernels handle batching causes nondeterminism.
Conclusion
These tests have definitively shown that the built-in implementations for matrix multiplication, RMSNorm, and attention are not batch-invariant, at least on the GPU. Using Hypothesis, we have tested this property across a much wider range of inputs than the original tests from the Thinking Machines blog post, giving us explicit, minimal counterexamples. Notably, on the CPU, the batched versions of RMSNorm and attention pass and only the batched matmul fails, while on the GPU all of the batched versions fail, reflecting differences in how the kernels are implemented.