Ring Attention & Sequence Sharding with JAX
If you’ve spent any time pushing Large Language Models (LLMs) to their limits lately, you know the context window arms race is in full swing. A few years ago, 4k tokens was standard. Today, we are routinely asked to ingest 100k, 500k, or even 1-million token documents.
But if you actually try to run a 1-million token forward pass on a standard GPU or TPU, you’ll almost immediately hit an Out of Memory (OOM) error.
Why? Let's pull some formulas and do the math.
In a Transformer, the memory footprint during inference or long-context prefill is heavily dominated by the Key-Value (KV) cache. The formula for the KV cache size in bytes (assuming bfloat16 or float16, which is 2 bytes per parameter) is:
KV Cache Size = 2 × Sequence Length × Layers × KV Heads × Head Dimension × 2 (for K and V)
Let’s plug in the numbers for a model like LLaMA-3 70B. It has 80 layers, 8 KV heads, and a head dimension of 128. If we want a 1-million token sequence length:
2 × 1,000,000 × 80 × 8 × 128 × 2 = 327,680,000,000 bytes ≈ 327 GB
That is roughly 327 GB just to hold the KV cache for a single sequence.
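As a sanity check, the formula can be evaluated with a throwaway helper (the function name here is mine, not part of any library):

```python
def kv_cache_bytes(seq_len, layers, kv_heads, head_dim, bytes_per_param=2):
    # K and V each hold seq_len x layers x kv_heads x head_dim parameters
    return 2 * seq_len * layers * kv_heads * head_dim * bytes_per_param

# LLaMA-3 70B at a 1-million token context:
print(kv_cache_bytes(1_000_000, 80, 8, 128) / 1e9)  # → 327.68 (GB)
```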
A single Google TPU v5e has 16 GB of HBM (High Bandwidth Memory). A flagship TPU v5p has 96 GB. Even if we completely ignore the 140 GB needed to store the model weights, the KV cache alone physically cannot fit on a single chip.
The Limits of Tensor Parallelism
Your first instinct might be to use Tensor Parallelism (often called Megatron sharding). As Chapter 5 of the Scaling book explains, Tensor Parallelism shards the heads across multiple chips.
But look at our model specs again: LLaMA-3 70B only has 8 KV heads. That means the absolute maximum we can shard the KV cache using standard Tensor Parallelism is 8 ways.
If we split that 327 GB across 8 TPUs, each chip still has to hold about 41 GB of KV cache. On a TPU v5e (16 GB), we still OOM. On a TPU v5p (96 GB), it fits, but once you add the weights and activations you are left with very little headroom, and you've hit a hard ceiling. You literally cannot scale to 16 or 32 chips to get more memory, because you've run out of KV heads to split.
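The ceiling is easy to see numerically (a quick sketch; the per-chip figures ignore weights and activations):

```python
# The KV cache can be split at most kv_heads = 8 ways with tensor parallelism.
total_kv = 2 * 1_000_000 * 80 * 8 * 128 * 2  # bytes, from the formula above

for tp in (2, 4, 8):
    print(f"TP={tp}: {total_kv / tp / 1e9:.2f} GB of KV cache per chip")
# At TP=8 we are out of KV heads: 16- or 32-way splits are not an option.
```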
We need a different axis to slice. We need to shard the sequence dimension.
Enter Sequence Sharding and Ring Attention
If we slice the sequence dimension, each TPU holds a chunk of the tokens. For example, with 8 TPUs and 1 million tokens, TPU 0 holds tokens 0 to 124,999, TPU 1 holds tokens 125,000 to 249,999, and so on.
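Concretely, the per-device token ranges look like this (a quick illustration):

```python
seq_len, num_devices = 1_000_000, 8
chunk = seq_len // num_devices  # 125,000 tokens per device

ranges = [(d * chunk, (d + 1) * chunk - 1) for d in range(num_devices)]
print(ranges[0])   # → (0, 124999): TPU 0's tokens
print(ranges[-1])  # → (875000, 999999): TPU 7's tokens
```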
The problem with sharding the sequence dimension is the attention mechanism itself. To compute attention, every Query (Q) needs to calculate a dot product with every single Key (K) and Value (V) in the past. If TPU 0 only holds the first chunk of K and V, how does TPU 7 compute attention for the last chunk of tokens?
We solve this by passing the K and V blocks around in a circle.
TPUs are physically wired together using Inter-Chip Interconnects (ICI) in a 2D or 3D torus, which gives us rings to route around. We can compute attention for our local chunk of Q against our local chunk of K and V. Then, everyone simultaneously passes their K and V one hop to their neighbor in the ring. We compute attention again. We repeat this until the K and V blocks have made a full lap around the TPUs.
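The schedule is easiest to see in a pure-Python sketch (no JAX involved): rotate the block IDs one hop per step and record what each device has seen.

```python
num_devices = 8
blocks = list(range(num_devices))         # K/V block j starts on device j
seen = [set() for _ in range(num_devices)]

for step in range(num_devices):
    for d in range(num_devices):
        seen[d].add(blocks[d])            # attend against the resident block
    blocks = blocks[1:] + blocks[:1]      # pass every block one hop around the ring

# After a full lap, every device has attended to every K/V block exactly once.
assert all(s == set(range(num_devices)) for s in seen)
```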
This algorithm is called Ring Attention. To write it efficiently, we drop down from standard high-level JAX and use jax.experimental.shard_map. Unlike jax.jit, which acts like an automatic compiler that guesses how to shard things, shard_map gives you a local, per-device view. You write the code exactly as it will execute on a single TPU, and you manually specify the communication.
Let's start by setting up our TPU mesh.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Let's assume we are running on an 8-TPU slice (like a v5e 4x2)
num_devices = 8
devices = jax.devices()[:num_devices]

# We define a 1D mesh specifically for the sequence dimension
mesh = Mesh(devices, ('seq',))

# We'll use dummy dimensions for our example
# B=1, SEQ=8192 (total), HEADS=4, DIM=128
B, SEQ, HEADS, DIM = 1, 8192, 4, 128

# Create dummy Q, K, V arrays and tell JAX to shard them across the 'seq' axis
seq_sharding = NamedSharding(mesh, P(None, 'seq', None, None))
q = jax.device_put(jax.random.normal(jax.random.PRNGKey(0), (B, SEQ, HEADS, DIM)), seq_sharding)
k = jax.device_put(jax.random.normal(jax.random.PRNGKey(1), (B, SEQ, HEADS, DIM)), seq_sharding)
v = jax.device_put(jax.random.normal(jax.random.PRNGKey(2), (B, SEQ, HEADS, DIM)), seq_sharding)
The Math of Flash Attention
Before we write the shard_map loop, we have to address a mathematical problem.
Standard attention computes the softmax over the entire sequence at once. But in our ring, TPU 0 only sees 1/8th of the keys at a time. You can't compute a true softmax if you only have a piece of the data, because the denominator of the softmax requires the sum of the exponentials for the entire row.
To fix this, we use the exact same math that makes Flash Attention work (detailed in Appendix A of Chapter 4 of the Scaling book). We maintain a running maximum (m) and a running sum of exponentials (l).
When a new block of K and V arrives over the network, we compute the local attention scores. We find the new maximum. We use the difference between the old maximum and the new maximum to scale down our previous running sums and output vectors, ensuring the math remains identical to doing it all at once.
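The rescaling trick is worth verifying in isolation. Here is a minimal sketch on a plain vector: we build the softmax denominator in two blocks, rescaling the running sum whenever the maximum changes, and check it against the one-shot denominator.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 3.0, 2.0, 5.0])

# Running max m and running denominator l, updated block by block
m, l = -jnp.inf, 0.0
for block in (x[:2], x[2:]):
    m_new = jnp.maximum(m, block.max())
    l = l * jnp.exp(m - m_new) + jnp.exp(block - m_new).sum()
    m = m_new

# Identical to computing the stabilized denominator over the whole row at once
assert jnp.allclose(l, jnp.exp(x - x.max()).sum())
```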
Let's write the skeleton of our shard_map function and initialize those running statistics.
from functools import partial

@jax.jit
def ring_attention(q_sharded, k_sharded, v_sharded):
    # We tell shard_map that the inputs are sharded along 'seq'
    # and the output will also be sharded along 'seq'
    @partial(shard_map, mesh=mesh,
             in_specs=(P(None, 'seq', None, None),
                       P(None, 'seq', None, None),
                       P(None, 'seq', None, None)),
             out_specs=P(None, 'seq', None, None))
    def local_ring_forward(q_local, k_local, v_local):
        # Inside this function, q_local shape is (B, SEQ/8, HEADS, DIM)
        # For our dummy sizes: (1, 1024, 4, 128)
        batch, local_seq, heads, dim = q_local.shape

        # Initialize running statistics for Flash Attention
        # m_i: running maximums, initialized to negative infinity
        m_i = jnp.full((batch, local_seq, heads, 1), -jnp.inf)
        # l_i: running denominator sums, initialized to 0
        l_i = jnp.zeros((batch, local_seq, heads, 1))
        # out_i: the actual attention output we are accumulating
        out_i = jnp.zeros_like(q_local)

        scale = 1.0 / jnp.sqrt(dim)
        q_scaled = q_local * scale
Passing the Blocks with ppermute
Now comes the actual ring communication. We need to loop num_devices times. In each iteration, we compute the Flash Attention math for the K and V blocks currently sitting in our TPU's VMEM.
Once the compute is done, we use jax.lax.ppermute to shift the and arrays to our neighbor. ppermute is a collective operation that performs a permutation of data across the devices in a mesh axis. It maps perfectly to the physical ICI wires connecting the TPUs.
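Before dropping it into the attention loop, ppermute can be tried on its own. This toy example (mine, runnable on a CPU-only machine by asking XLA to emulate four host devices) rotates one scalar per device one hop around the ring:

```python
import os
# Pretend we have 4 devices so this runs on a CPU-only machine
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp
from functools import partial
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices()[:4], ('seq',))

@partial(shard_map, mesh=mesh, in_specs=P('seq'), out_specs=P('seq'))
def rotate(x):
    # Device j sends its value to device (j - 1) % 4
    perm = [(j, (j - 1) % 4) for j in range(4)]
    return jax.lax.ppermute(x, axis_name='seq', perm=perm)

x = jnp.arange(4.0)   # device j holds the value j
print(rotate(x))      # each device now holds its right-hand neighbor's value: 1, 2, 3, 0
```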
Here is the loop:
        # We use a jax.lax.fori_loop to iterate over the ring.
        # (Note: this sketch computes full, non-causal attention; a causal
        # version would additionally mask each block based on its ring position.)
        def loop_body(step, carry):
            m_i, l_i, out_i, k_current, v_current = carry

            # --- 1. Compute Local Attention (Flash Math) ---
            # Dot product of local Q and the current K block
            # q_scaled: (B, L_SEQ, HEADS, DIM) -> (B, HEADS, L_SEQ, DIM)
            # k_current: (B, L_SEQ, HEADS, DIM) -> (B, HEADS, DIM, L_SEQ)
            q_transposed = jnp.transpose(q_scaled, (0, 2, 1, 3))
            k_transposed = jnp.transpose(k_current, (0, 2, 3, 1))

            # scores shape: (B, HEADS, L_SEQ, L_SEQ)
            scores = jnp.matmul(q_transposed, k_transposed)
            scores = jnp.transpose(scores, (0, 2, 1, 3))  # back to (B, L_SEQ, HEADS, L_SEQ)

            # Find the max in this block
            m_block = jnp.max(scores, axis=-1, keepdims=True)
            # Find the new global max between the running max and this block's max
            m_new = jnp.maximum(m_i, m_block)

            # Compute exponentials, adjusting by the new global max for numerical stability
            exp_scores = jnp.exp(scores - m_new)
            l_block = jnp.sum(exp_scores, axis=-1, keepdims=True)

            # Scale down the old running sum and old output by the difference in maxes
            scale_old = jnp.exp(m_i - m_new)
            l_new = l_i * scale_old + l_block

            # Multiply attention weights by V
            # exp_scores: (B, L_SEQ, HEADS, L_SEQ)
            # v_current: (B, L_SEQ, HEADS, DIM)
            exp_scores_t = jnp.transpose(exp_scores, (0, 2, 1, 3))
            v_transposed = jnp.transpose(v_current, (0, 2, 1, 3))
            out_block = jnp.matmul(exp_scores_t, v_transposed)
            out_block = jnp.transpose(out_block, (0, 2, 1, 3))

            # Accumulate into the running output
            out_new = out_i * scale_old + out_block

            # --- 2. Ring Communication ---
            # The ring size is a static property of the mesh
            axis_size = mesh.shape['seq']
            # Define the permutation: device j sends its K/V block to device (j - 1) % axis_size
            perm = [(j, (j - 1) % axis_size) for j in range(axis_size)]
            k_next = jax.lax.ppermute(k_current, axis_name='seq', perm=perm)
            v_next = jax.lax.ppermute(v_current, axis_name='seq', perm=perm)

            return m_new, l_new, out_new, k_next, v_next

        # Run the loop
        initial_carry = (m_i, l_i, out_i, k_local, v_local)
        final_carry = jax.lax.fori_loop(0, num_devices, loop_body, initial_carry)
        _, final_l, final_out, _, _ = final_carry

        # Final normalization: divide the accumulated output by the accumulated denominator
        return final_out / final_l

    # Call the shard_map'd function
    return local_ring_forward(q_sharded, k_sharded, v_sharded)

# Run it!
final_attention_output = ring_attention(q, k, v)
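Does all this rescaling actually reproduce vanilla attention? Without a multi-chip slice handy, the easiest check is a single-device sketch (mine, not from the original code) that feeds the K/V chunks through the same running-statistics updates serially and compares against full attention:

```python
import jax
import jax.numpy as jnp

B, S, H, D, n_chunks = 1, 256, 2, 16, 8
q = jax.random.normal(jax.random.PRNGKey(0), (B, S, H, D))
k = jax.random.normal(jax.random.PRNGKey(1), (B, S, H, D))
v = jax.random.normal(jax.random.PRNGKey(2), (B, S, H, D))
scale = 1.0 / jnp.sqrt(D)

# Reference: full (non-causal) attention in one shot
scores = jnp.einsum('bqhd,bkhd->bqhk', q * scale, k)
ref = jnp.einsum('bqhk,bkhd->bqhd', jax.nn.softmax(scores, axis=-1), v)

# Blockwise: identical math to the ring loop, with chunks arriving one at a time
m = jnp.full((B, S, H, 1), -jnp.inf)
l = jnp.zeros((B, S, H, 1))
out = jnp.zeros_like(q)
for kc, vc in zip(jnp.split(k, n_chunks, axis=1), jnp.split(v, n_chunks, axis=1)):
    s = jnp.einsum('bqhd,bkhd->bqhk', q * scale, kc)
    m_new = jnp.maximum(m, jnp.max(s, axis=-1, keepdims=True))
    alpha = jnp.exp(m - m_new)            # rescale factor for the old statistics
    p = jnp.exp(s - m_new)
    l = l * alpha + jnp.sum(p, axis=-1, keepdims=True)
    out = out * alpha + jnp.einsum('bqhk,bkhd->bqhd', p, vc)
    m = m_new

assert jnp.allclose(out / l, ref, atol=1e-4)
```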
Overlapping Compute and Communication
If you run this code and open up the JAX/Tensorboard Profiler (as detailed in Chapter 9 of the book), you'll see something beautiful in the Trace Viewer.
The ppermute operation translates directly to an XLA collective operation over the ICI links. Because matrix multiplication in the MXU and data transfer over the ICI network use entirely different physical hardware pathways on the TPU, the XLA compiler can overlap them.
While the MXU is busy churning through jnp.matmul(q_transposed, k_transposed), the networking hardware is simultaneously executing jax.lax.ppermute to fetch the next block of K and V into VMEM. As long as your local chunk size is large enough that the compute time (T_math) is greater than the network transfer time (T_comms), the network communication is entirely hidden. The data arrives right before the MXU needs it for the next loop iteration.
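That condition is easy to reason about with back-of-the-envelope arithmetic (a sketch using the dummy sizes from earlier; no hardware constants needed): per ring step, the matmul FLOPs grow quadratically with the local chunk length while the bytes sent over ICI grow only linearly, so bigger chunks make the transfer easier to hide.

```python
B, H, D = 1, 4, 128
bytes_per_el = 2  # bf16

def per_ring_step(L):
    flops = 4 * B * H * L * L * D                   # QK^T plus the PV matmul
    bytes_moved = 2 * B * L * H * D * bytes_per_el  # one K block + one V block
    return flops, bytes_moved

for L in (1024, 8192):
    flops, bytes_moved = per_ring_step(L)
    print(f"L={L}: {flops / bytes_moved:.0f} FLOPs per byte moved")
# Arithmetic intensity equals L here, i.e. it scales linearly with chunk size.
```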
Why this matters
By stepping outside of the compiler's automatic sharding and writing our own explicit shard_map function, we've broken the memory ceiling.
If we want to double our sequence length, we just double the number of TPUs in our ring; the memory footprint per chip stays exactly the same. This is exactly the kind of hardware-aware machine learning that modern frontier models require to achieve massive context windows. You aren't just writing Python anymore; you are choreographing the physical flow of data across a supercomputer.

Google Cloud credits are provided for this project. #TPUSprint