
Context Parallelism and Ring Attention


1. Motivation: The Memory Wall of Long Sequences

1.1 Recap of Tensor Parallelism + Sequence Parallelism

With Tensor Parallelism (TP), model weight matrices are sharded across $N_{\text{TP}}$ GPUs. Sequence Parallelism (SP) complements TP by splitting activations along the sequence dimension in the non-TP regions of the model (e.g., LayerNorm, Dropout), thereby distributing activation memory for those modules.

Under TP+SP, the per-GPU activation memory for a single transformer layer scales approximately as:

$$ M_{\text{act}}^{\text{TP+SP}} \approx \frac{s \cdot b \cdot h}{N_{\text{TP}}} \cdot \alpha $$

where:

  • $s$ = sequence length
  • $b$ = micro-batch size
  • $h$ = hidden dimension
  • $N_{\text{TP}}$ = tensor-parallel degree
  • $\alpha$ = a constant capturing the number of intermediate tensors retained per layer

1.2 The Residual Scaling Problem

Even with TP+SP and full activation recomputation (which incurs $\sim 30\%$ compute overhead by rerunning the forward pass during backpropagation), certain activations must be retained at layer boundaries (the inputs to each transformer block needed for the backward pass). These boundary activations scale as:

$$ M_{\text{boundary}} = \mathcal{O}\!\left(L \cdot s \cdot b \cdot h\right) $$

where $L$ is the number of transformer layers. Critically, this expression is linear in $s$. When $s$ scales to $128\text{k}$, $256\text{k}$, or $1\text{M}$ tokens, $M_{\text{boundary}}$ can exceed the memory of an entire node even after TP+SP sharding.
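
As a rough sense of scale, a minimal back-of-the-envelope sketch (the layer count, hidden size, and dtype below are illustrative assumptions, not taken from a specific model):

# Boundary activations in bf16 for one hypothetical long-context configuration.
L, s, b, h = 80, 131_072, 1, 8192       # layers, sequence length, micro-batch, hidden dim
bytes_per_elem = 2                       # bf16
boundary_gib = L * s * b * h * bytes_per_elem / 2**30
print(f"~{boundary_gib:.0f} GiB of boundary activations")   # ~160 GiB before sequence sharding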

Core problem statement: TP+SP distributes weights and some activations, but every GPU still processes the full sequence inside the TP region (the attention and MLP blocks). For very long sequences, this remaining per-GPU memory footprint becomes the binding constraint.


2. Context Parallelism — Definition and Mechanism

2.1 Core Idea

Context Parallelism (CP) splits the input along the sequence dimension and applies this split across the entire model, including the TP regions (attention, MLP), not just the SP-only regions.

Given a context-parallel group of size $N_{\text{CP}}$, an input tensor of shape $(b, s, h)$ is partitioned into $N_{\text{CP}}$ chunks along the sequence axis:

$$ X \in \mathbb{R}^{b \times s \times h} \;\longrightarrow\; X^{(i)} \in \mathbb{R}^{b \times (s / N_{\text{CP}}) \times h}, \quad i = 0, 1, \ldots, N_{\text{CP}} - 1 $$

Each GPU $i$ in the CP group holds only $X^{(i)}$ and processes it through every layer of the network.
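
A minimal sketch of this contiguous split with torch.chunk (load balancing under a causal mask is addressed later, in Section 6):

import torch

b, s, h, N_cp = 2, 8192, 4096, 4
x = torch.randn(b, s, h)
# Split along the sequence axis; CP rank i keeps chunks[i].
chunks = torch.chunk(x, N_cp, dim=1)
assert all(c.shape == (b, s // N_cp, h) for c in chunks)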

2.2 Impact on Different Module Types

| Module | Token Interaction Pattern | CP Communication Required? |
| --- | --- | --- |
| LayerNorm | Per-token (independent) | No — each token normalizes independently |
| MLP / FFN | Per-token (independent) | No — pointwise or token-local computation |
| Dropout | Per-token (independent) | No |
| Self-Attention | All-to-all across sequence | Yes — each query must access all keys/values |

For token-independent modules (LayerNorm, MLP), splitting along the sequence dimension requires zero inter-GPU communication — each GPU applies the module to its local token subset identically to how it would process the full sequence. The weight matrices are not split (unlike TP), so no expensive all-reduce on activations is needed for these modules.
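
A quick way to see this: applying a token-independent module to a shard gives exactly the corresponding rows of applying it to the full sequence (the sketch below uses LayerNorm as the example):

import torch
import torch.nn as nn

ln = nn.LayerNorm(64)
x = torch.randn(2, 128, 64)
full = ln(x)                                       # full sequence on one device
shard = ln(x[:, :32])                              # the same module on a 32-token shard
torch.testing.assert_close(shard, full[:, :32])    # identical: no cross-token coupling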

2.3 Gradient Synchronization

After the backward pass, each CP-rank holds gradients computed from a different subset of the sequence. Since the weights are replicated across the CP group (analogous to Data Parallelism), an all-reduce over the CP group is required to aggregate gradients before the optimizer step:

$$ \nabla_{\theta} \mathcal{L} = \frac{1}{N_{\text{CP}}} \sum_{i=0}^{N_{\text{CP}}-1} \nabla_{\theta} \mathcal{L}^{(i)} $$

This all-reduce is identical in structure to the gradient synchronization in standard data parallelism.
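
A minimal sketch of this synchronization (the names `model` and `cp_group` are placeholders for whatever module and CP process group are in scope; the end-to-end example later in this document does the same thing):

import torch.distributed as dist

# After loss.backward(): average parameter gradients across the CP group,
# exactly as data parallelism does across its replica group.
for p in model.parameters():
    if p.grad is not None:
        dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=cp_group)
        p.grad /= dist.get_world_size(group=cp_group)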

2.4 Per-GPU Activation Memory Under CP

With CP, the per-GPU activation memory becomes:

$$ M_{\text{act}}^{\text{CP}} \approx \frac{s \cdot b \cdot h}{N_{\text{CP}} \cdot N_{\text{TP}}} \cdot \alpha $$

The per-GPU activation memory is reduced by an additional factor of $N_{\text{CP}}$, which is precisely the relief needed for very long sequences.

2.5 Relationship to Sequence Parallelism

| Property | Sequence Parallelism (SP) | Context Parallelism (CP) |
| --- | --- | --- |
| Where applied | Non-TP regions only (LayerNorm, Dropout) | Entire model including TP regions |
| Split dimension | Sequence | Sequence |
| Weight sharding | Via TP (in TP regions) | Weights replicated across CP group |
| Gradient sync | Handled by TP all-reduce | Separate CP all-reduce |
| Communication for attention | Not applicable (attention is inside TP) | Required — Ring Attention |

3. The Attention Bottleneck Under Context Parallelism

3.1 Why Attention Is Special

The scaled dot-product attention for a single head is defined as:

$$ \text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{Q K^T}{\sqrt{d_k}}\right) V $$

where:

  • $Q \in \mathbb{R}^{s \times d_k}$ — queries
  • $K \in \mathbb{R}^{s \times d_k}$ — keys
  • $V \in \mathbb{R}^{s \times d_v}$ — values
  • $d_k$ — head dimension

The attention score matrix $A = Q K^T \in \mathbb{R}^{s \times s}$ couples every query position with every key position. Under CP, each GPU holds only $s / N_{\text{CP}}$ query positions and $s / N_{\text{CP}}$ key/value positions. To compute the full attention, each GPU’s queries must access all $s$ key/value positions from all GPUs.

3.2 Naive Communication Cost

A naive implementation would perform an all-gather of all $K$ and $V$ tensors before computing attention. The total communication volume per GPU would be:

$$ \text{Comm}_{\text{naive}} = 2 \cdot (N_{\text{CP}} - 1) \cdot \frac{s}{N_{\text{CP}}} \cdot d_k \cdot b \cdot n_h $$

where $n_h$ is the number of attention heads and the factor $2$ accounts for both $K$ and $V$. This communication is blocking: the GPU cannot begin computing attention until all $K/V$ data has arrived, resulting in idle GPU cycles.


4. Ring Attention — Efficient Communication for CP

4.1 Core Principle

Ring Attention organizes the $N_{\text{CP}}$ GPUs into a logical ring topology. Instead of gathering all $K/V$ pairs at once, each GPU:

  1. Sends its current $K/V$ chunk to the next GPU in the ring (asynchronous, non-blocking).
  2. Computes partial attention using the $K/V$ chunk it currently holds in local memory.
  3. Receives a $K/V$ chunk from the previous GPU in the ring.
  4. Repeats for $N_{\text{CP}}$ steps total.

This overlaps communication with computation, hiding the latency of data transfer behind useful arithmetic.

4.2 Detailed Step-by-Step Execution

Consider $N_{\text{CP}} = 4$ GPUs and a sequence of length $s$ split into 4 chunks. GPU $i$ holds:

  • $Q^{(i)}, K^{(i)}, V^{(i)} \in \mathbb{R}^{b \times (s/4) \times d_k}$

At time step $t \in \{0, 1, 2, 3\}$, GPU $i$ has the $K/V$ chunk originally from GPU $(i - t) \bmod 4$.
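
A tiny bookkeeping sketch of this rotation schedule (no communication, just the index arithmetic):

# Which rank's K/V chunk GPU i holds at ring step t, for N_CP = 4.
N_CP = 4
for t in range(N_CP):
    print(f"step {t}: " + "  ".join(f"GPU{i} holds chunk {(i - t) % N_CP}" for i in range(N_CP)))
# step 0: GPU0 holds chunk 0  GPU1 holds chunk 1  GPU2 holds chunk 2  GPU3 holds chunk 3
# step 1: GPU0 holds chunk 3  GPU1 holds chunk 0  GPU2 holds chunk 1  GPU3 holds chunk 2
# ...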

Operations at each step $t$:

Step A (Communication — non-blocking):

$$ \text{GPU } i \xrightarrow{\text{send}} K^{(\text{curr})}, V^{(\text{curr})} \xrightarrow{} \text{GPU } (i+1) \bmod N_{\text{CP}} $$

Step B (Compute — overlapped with Step A):

Compute the partial attention score block:

$$ S^{(i,t)} = \frac{Q^{(i)} \left(K^{(\text{curr})}\right)^T}{\sqrt{d_k}} \in \mathbb{R}^{b \times (s/N_{\text{CP}}) \times (s/N_{\text{CP}})} $$
$$ P^{(i,t)} = \text{softmax}_{\text{partial}}\!\left(S^{(i,t)}\right) $$
$$ O^{(i,t)} = P^{(i,t)} \cdot V^{(\text{curr})} $$

Step C (Synchronization):

Wait for receive of $K^{(\text{new})}, V^{(\text{new})}$ from GPU $(i-1) \bmod N_{\text{CP}}$. Set $K^{(\text{curr})} \leftarrow K^{(\text{new})}$, $V^{(\text{curr})} \leftarrow V^{(\text{new})}$.

After all $N_{\text{CP}}$ steps, each GPU has computed partial attention outputs $\{O^{(i,0)}, O^{(i,1)}, \ldots, O^{(i, N_{\text{CP}}-1)}\}$.

4.3 Online Softmax Aggregation

A critical subtlety: softmax is computed row-wise over the full key sequence, but each GPU sees only one chunk of keys at each step. A naive approach would require storing all $s$ logits per query, defeating the purpose.

Online softmax (also used in FlashAttention) resolves this. At each step $t$, GPU $i$ maintains running statistics:

  • $m^{(i,t)} \in \mathbb{R}^{b \times (s/N_{\text{CP}})}$ — running row-wise maximum of logits
  • $\ell^{(i,t)} \in \mathbb{R}^{b \times (s/N_{\text{CP}})}$ — running row-wise sum of exponentiated logits (denominator of softmax)
  • $O^{(i,t)} \in \mathbb{R}^{b \times (s/N_{\text{CP}}) \times d_v}$ — running weighted output accumulator

Update rules (per query row $q$):

At step $t = 0$, initialize:

$$ m^{(i,0)} = \max_j S^{(i,0)}_{q,j} $$
$$ \ell^{(i,0)} = \sum_j \exp\!\left(S^{(i,0)}_{q,j} - m^{(i,0)}\right) $$
$$ O^{(i,0)} = \frac{1}{\ell^{(i,0)}} \sum_j \exp\!\left(S^{(i,0)}_{q,j} - m^{(i,0)}\right) V^{(\text{curr})}_j $$

At step $t > 0$, compute the new block’s statistics:

$$ \tilde{m} = \max_j S^{(i,t)}_{q,j} $$
$$ m^{(i,t)}_{\text{new}} = \max\!\left(m^{(i,t-1)},\; \tilde{m}\right) $$

Rescale the running accumulator:

$$ \ell^{(i,t)} = \ell^{(i,t-1)} \cdot \exp\!\left(m^{(i,t-1)} - m^{(i,t)}_{\text{new}}\right) + \sum_j \exp\!\left(S^{(i,t)}_{q,j} - m^{(i,t)}_{\text{new}}\right) $$
$$ O^{(i,t)} = \frac{\ell^{(i,t-1)} \cdot \exp\!\left(m^{(i,t-1)} - m^{(i,t)}_{\text{new}}\right)}{\ell^{(i,t)}} \cdot O^{(i,t-1)} + \frac{1}{\ell^{(i,t)}} \sum_j \exp\!\left(S^{(i,t)}_{q,j} - m^{(i,t)}_{\text{new}}\right) V^{(\text{curr})}_j $$

After all $N_{\text{CP}}$ steps, $O^{(i, N_{\text{CP}}-1)}$ is the exact attention output for the query rows on GPU $i$ — numerically equivalent to computing the full softmax over all $s$ keys.

4.4 Communication Volume Analysis

At each of the $N_{\text{CP}} - 1$ communication steps, each GPU sends one $K/V$ chunk:

$$ \text{Comm per step per GPU} = 2 \cdot \frac{s}{N_{\text{CP}}} \cdot d_k \cdot b \cdot n_h $$

Total communication volume:

$$ \text{Comm}_{\text{total}} = (N_{\text{CP}} - 1) \cdot 2 \cdot \frac{s}{N_{\text{CP}}} \cdot d_k \cdot b \cdot n_h \;\approx\; 2 \cdot s \cdot d_k \cdot b \cdot n_h $$

This is the same total volume as the naive all-gather, but the latency is hidden because computation overlaps with communication. The effective wall-clock overhead approaches zero when:

$$ T_{\text{compute}}^{(t)} \geq T_{\text{comm}}^{(t)} $$

i.e., the time to compute one partial attention block exceeds the time to transfer one $K/V$ chunk.

4.5 Relationship to FlashAttention

| Aspect | FlashAttention | Ring Attention (CP) |
| --- | --- | --- |
| Scope | Single GPU, single attention kernel | Multi-GPU, distributed attention |
| Tiling axis | Tiles $Q$ and $K/V$ blocks in SRAM | Tiles $K/V$ blocks across GPUs |
| Memory savings | Avoids materializing the $s \times s$ attention matrix in HBM | Avoids storing full-sequence activations per GPU |
| Online softmax | Yes — essential for block-wise tiling | Yes — essential for incremental aggregation |
| Complementary? | Yes — can be used within each GPU’s local attention block | Yes — each Ring Attention step can use FlashAttention for its local block computation |

5. The Causal Mask Imbalance Problem

5.1 Causal Attention Mask

For autoregressive (decoder-only) models, the attention mask enforces causality:

$$ M_{q,k} = \begin{cases} 0 & \text{if } \text{pos}(q) \geq \text{pos}(k) \\ -\infty & \text{otherwise} \end{cases} $$

The masked attention score matrix becomes:

$$ A = \text{softmax}\!\left(\frac{Q K^T}{\sqrt{d_k}} + M\right) $$

This produces a lower-triangular structure, meaning earlier tokens attend to fewer keys than later tokens.
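
Note that under CP the mask has to be evaluated on global token positions rather than local chunk indices; a minimal sketch:

import torch

q_pos = torch.arange(4, 8)                 # global positions of the local queries
k_pos = torch.arange(0, 4)                 # global positions of some K/V chunk
mask = q_pos[:, None] >= k_pos[None, :]    # True where attention is allowed
print(mask)                                # all True: these queries follow every key in the chunk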

5.2 Load Imbalance Under Naive CP Partitioning

With naive sequential partitioning (GPU 0 gets tokens $1, \ldots, s/N_{\text{CP}}$; GPU 1 gets tokens $s/N_{\text{CP}}+1, \ldots, 2s/N_{\text{CP}}$; etc.):

  • GPU 0 holds the earliest tokens. Due to causal masking, these tokens attend only to themselves and need no $K/V$ from other GPUs. GPU 0 finishes almost immediately and idles.
  • GPU $N_{\text{CP}}-1$ holds the latest tokens. These tokens attend to the entire preceding sequence, requiring $K/V$ from all other GPUs and performing maximum computation.

The number of non-masked entries (FLOPs) for GPU $i$ holding tokens in the range $[a_i, b_i]$ scales as:

$$ \text{FLOPs}^{(i)} \propto \sum_{q=a_i}^{b_i} q = \frac{(b_i - a_i + 1)(a_i + b_i)}{2} $$

For the first GPU ($a_0 = 1$, $b_0 = s/N_{\text{CP}}$):

$$ \text{FLOPs}^{(0)} \propto \frac{(s/N_{\text{CP}})((s/N_{\text{CP}}) + 1)}{2} \approx \frac{s^2}{2 N_{\text{CP}}^2} $$

For the last GPU ($a_{N-1} = s - s/N_{\text{CP}} + 1$, $b_{N-1} = s$):

$$ \text{FLOPs}^{(N-1)} \propto \frac{(s/N_{\text{CP}})(2s - s/N_{\text{CP}} + 1)}{2} \approx \frac{s^2}{N_{\text{CP}}} $$

The imbalance ratio is:

$$ \frac{\text{FLOPs}^{(N-1)}}{\text{FLOPs}^{(0)}} \approx \frac{s^2 / N_{\text{CP}}}{s^2 / (2 N_{\text{CP}}^2)} = 2 N_{\text{CP}} $$

For $N_{\text{CP}} = 8$, the last GPU performs $\sim 16\times$ more work than the first GPU — a catastrophic imbalance.
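
A quick numeric check of this ratio under contiguous partitioning (sizes are illustrative):

import torch

s, N_cp = 16_384, 8
pos = torch.arange(s)
chunks = pos.reshape(N_cp, -1)             # contiguous partitioning: row i -> GPU i
flops = (chunks + 1).sum(dim=1)            # causal: position p attends to p + 1 keys
print((flops[-1] / flops[0]).item())       # ~15x for N_cp = 8, in line with the ~2*N_CP estimate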


6. Zig-Zag Ring Attention — Balanced Computation

6.1 Token Assignment Strategy

Instead of contiguous blocks, Zig-Zag Attention assigns tokens to GPUs in an interleaved, folded pattern. For $N_{\text{CP}} = 4$ and $s = 16$, the assignment is:

$$ \begin{aligned} \text{GPU 0:} & \quad \{1, 8, 9, 16\} \\ \text{GPU 1:} & \quad \{2, 7, 10, 15\} \\ \text{GPU 2:} & \quad \{3, 6, 11, 14\} \\ \text{GPU 3:} & \quad \{4, 5, 12, 13\} \end{aligned} $$

The pattern follows a zig-zag (or folded) ordering: within each “fold” of $N_{\text{CP}}$ tokens, the first fold is assigned $0, 1, 2, \ldots, N_{\text{CP}}-1$ and the second fold is assigned $N_{\text{CP}}-1, N_{\text{CP}}-2, \ldots, 0$, then repeating.

Formally, for the $k$-th token ($0$-indexed), the GPU assignment is:

$$ \text{GPU}(k) = \begin{cases} k \bmod N_{\text{CP}} & \text{if } \left\lfloor k / N_{\text{CP}} \right\rfloor \text{ is even} \\ (N_{\text{CP}} - 1) - (k \bmod N_{\text{CP}}) & \text{if } \left\lfloor k / N_{\text{CP}} \right\rfloor \text{ is odd} \end{cases} $$

6.2 Why This Balances Computation

Each GPU now holds a mixture of early and late tokens. Under the causal mask, early tokens contribute few non-masked entries while late tokens contribute many. By pairing them together, each GPU’s total non-masked computation sums to approximately:

$$ \text{FLOPs}^{(i)} \approx \frac{1}{N_{\text{CP}}} \cdot \frac{s(s+1)}{2} \quad \forall \; i $$

This achieves near-perfect load balancing across all GPUs.

6.3 Communication Implication

Under zig-zag assignment, every GPU needs $K/V$ from every other GPU to complete all its row computations (since each GPU holds tokens scattered across the full sequence range). This is compatible with both ring-based and all-gather-based communication patterns.

6.4 Distinction from Striped Attention

Striped Attention uses a simpler round-robin assignment:

$$ \text{GPU}(k) = k \bmod N_{\text{CP}} $$

This also achieves load balancing but with a different scattering pattern. Zig-Zag achieves slightly better contiguity within each GPU’s token set (tokens come in pairs of consecutive indices from alternating ends), which can improve memory access patterns and cache utilization. The practical difference is minor but relevant for highly optimized implementations.
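
A small sketch instantiating both assignment formulas side by side (for $s = 16$, $N_{\text{CP}} = 4$):

import torch

s, N_cp = 16, 4
k = torch.arange(s)
striped = k % N_cp                                          # round-robin (Striped Attention)
odd_fold = (k // N_cp) % 2 == 1
zigzag = torch.where(odd_fold, (N_cp - 1) - (k % N_cp), k % N_cp)
print(striped.tolist())  # [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
print(zigzag.tolist())   # [0, 1, 2, 3, 3, 2, 1, 0, 0, 1, 2, 3, 3, 2, 1, 0]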


7. Communication Strategies for Zig-Zag / Ring Attention

7.1 All-Gather Implementation

All GPUs simultaneously execute an all-gather collective on the $K$ and $V$ tensors:

$$ \text{AllGather}\!\left(\{K^{(0)}, K^{(1)}, \ldots, K^{(N_{\text{CP}}-1)}\}\right) \to K_{\text{full}} \in \mathbb{R}^{b \times s \times d_k} \quad \text{on every GPU} $$

Properties:

  • Communication pattern: Single collective operation; all GPUs receive all data.
  • Temporary memory per GPU: Must store the full $K_{\text{full}}$ and $V_{\text{full}}$:
$$ M_{\text{temp}}^{\text{all-gather}} = 2 \cdot s \cdot d_k \cdot b \cdot n_h $$

This negates the memory savings of CP for the $K/V$ tensors during the attention computation window.

  • Latency: Dominated by a single all-gather; bandwidth-optimal but high peak memory.

7.2 All-to-All (Ring) Implementation

GPUs exchange $K/V$ chunks in a ring pattern, one chunk per step, as described in Section 4.

Properties:

  • Communication pattern: $N_{\text{CP}} - 1$ point-to-point send/receive steps in a ring.
  • Temporary memory per GPU: Only one additional $K/V$ chunk at a time:
$$ M_{\text{temp}}^{\text{ring}} = 2 \cdot \frac{s}{N_{\text{CP}}} \cdot d_k \cdot b \cdot n_h $$

A factor of $N_{\text{CP}}$ reduction compared to the all-gather approach.

  • Latency: $N_{\text{CP}} - 1$ communication steps, each with startup latency, but each overlapped with computation.

7.3 Comparison Summary

| Property | All-Gather | All-to-All (Ring) |
| --- | --- | --- |
| Temporary $K/V$ memory | $\mathcal{O}(s)$ per GPU | $\mathcal{O}(s / N_{\text{CP}})$ per GPU |
| Communication steps | $1$ | $N_{\text{CP}} - 1$ |
| Overlap potential | Limited (compute starts after gather) | High (compute overlapped with send/recv) |
| Implementation complexity | Low | Moderate (ring scheduling + online softmax) |
| Best suited for | Moderate $N_{\text{CP}}$, high bandwidth | Large $N_{\text{CP}}$, memory-constrained |

8. Combined Parallelism Dimensions — Where CP Fits

At this point in the parallelism hierarchy:

| Parallelism | What It Shards | Communication Cost | Scaling Regime |
| --- | --- | --- | --- |
| Data Parallelism (DP) | Data (batches) | All-reduce on gradients | Scales across nodes |
| Tensor Parallelism (TP) | Weight matrices (columns/rows) | All-reduce on activations per layer | Intra-node (high bandwidth required) |
| Sequence Parallelism (SP) | Activations in non-TP regions | Scatter/gather at TP↔SP boundaries | Tied to TP group |
| Context Parallelism (CP) | Sequence across entire model | Ring attention + gradient all-reduce | Intra- or inter-node |
| Pipeline Parallelism (PP) | Layers across stages | Point-to-point activation transfers | Scales across nodes |

The total number of GPUs used is:

$$ N_{\text{total}} = N_{\text{DP}} \times N_{\text{TP}} \times N_{\text{CP}} \times N_{\text{PP}} $$

CP addresses a specific gap: TP+SP reduce per-GPU memory for both weights and activations, but the sequence length still appears unsharded inside the TP region’s attention computation. CP eliminates this last remaining sequence-length bottleneck, enabling training on sequences of $128\text{k}$+ tokens without running out of activation memory.

Limitation of TP: TP requires high-bandwidth interconnects (e.g., NVLink within a node) because it communicates activations at every layer. It does not scale well across nodes with lower-bandwidth interconnects. When model weights exceed the memory of a single node even after TP sharding, Pipeline Parallelism (PP) becomes necessary — partitioning the model’s layers across nodes with only point-to-point activation transfers at stage boundaries.

Context Parallelism: SOTA Implementation with PyTorch + Triton


1. System Architecture Overview

The implementation consists of six composable modules, each addressing a distinct concern:

┌─────────────────────────────────────────────────────────────┐
│                ContextParallelAttention (nn.Module)          │
│  ┌───────────────────────────────────────────────────────┐  │
│  │          RingAttentionFunction (autograd.Function)     │  │
│  │  ┌──────────┐  ┌──────────────┐  ┌────────────────┐  │  │
│  │  │ ZigZag   │  │    Ring      │  │ Triton Kernel: │  │  │
│  │  │Partition │──│ Communicator │──│ Fused Attn +   │  │  │
│  │  │          │  │ (send/recv)  │  │ Online Softmax │  │  │
│  │  └──────────┘  └──────────────┘  └────────────────┘  │  │
│  └───────────────────────────────────────────────────────┘  │
└─────────────────────────────────────────────────────────────┘

Each GPU in the CP group of size $N_{\text{CP}}$ holds:

$$ Q^{(i)}, K^{(i)}, V^{(i)} \in \mathbb{R}^{B \times n_h \times (s / N_{\text{CP}}) \times d_k} $$

where $i$ is the CP rank, $B$ is the batch size, $n_h$ is the number of heads, $s$ is the full sequence length, and $d_k$ is the head dimension.


2. Zig-Zag Sequence Partitioning

2.1 Mathematical Formulation

For global token index $k \in \{0, 1, \ldots, s-1\}$, the zig-zag assignment to GPU rank is:

$$ \text{rank}(k) = \begin{cases} k \bmod N_{\text{CP}} & \text{if } \left\lfloor k / N_{\text{CP}} \right\rfloor \text{ is even} \\[6pt] (N_{\text{CP}} - 1) - (k \bmod N_{\text{CP}}) & \text{if } \left\lfloor k / N_{\text{CP}} \right\rfloor \text{ is odd} \end{cases} $$

This ensures each GPU holds a balanced mixture of early and late tokens, equalizing the causal-mask workload:

$$ \text{FLOPs}^{(i)} \approx \frac{1}{N_{\text{CP}}} \cdot \frac{s(s+1)}{2} \quad \forall \; i $$

2.2 Implementation

import torch
import torch.distributed as dist
from typing import List, Tuple


def zigzag_partition(
    seq_len: int,
    cp_size: int,
    cp_rank: int,
) -> torch.Tensor:
    """
    Compute global position indices assigned to `cp_rank` under zig-zag ordering.

    Args:
        seq_len:  Total sequence length (must be divisible by cp_size).
        cp_size:  Context-parallel world size  (N_CP).
        cp_rank:  This GPU's rank in the CP group.

    Returns:
        positions: LongTensor of shape [seq_len // cp_size] containing
                   the global token indices owned by this rank.

    Example (seq_len=16, cp_size=4):
        rank 0 → [ 0,  7,  8, 15]
        rank 1 → [ 1,  6,  9, 14]
        rank 2 → [ 2,  5, 10, 13]
        rank 3 → [ 3,  4, 11, 12]
    """
    assert seq_len % cp_size == 0, (
        f"seq_len ({seq_len}) must be divisible by cp_size ({cp_size})"
    )

    indices = torch.arange(seq_len, dtype=torch.long)
    # Reshape into folds of size cp_size: [num_folds, cp_size]
    folds = indices.reshape(-1, cp_size)
    # Reverse odd-numbered folds to create zig-zag pattern
    folds[1::2] = folds[1::2].flip(dims=[1])
    # Column `cp_rank` across all folds gives this rank's tokens
    positions = folds[:, cp_rank].contiguous()
    return positions


def zigzag_split(
    x: torch.Tensor,
    cp_size: int,
    cp_rank: int,
    seq_dim: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Split a full-sequence tensor according to zig-zag assignment.

    Args:
        x:        Input tensor with a sequence dimension.
        cp_size:  Context-parallel world size.
        cp_rank:  This GPU's CP rank.
        seq_dim:  Which dimension is the sequence axis (default: 2).

    Returns:
        x_local:   The subsequence assigned to this rank.
        positions:  The global position indices (for causal masking).
    """
    seq_len = x.shape[seq_dim]
    positions = zigzag_partition(seq_len, cp_size, cp_rank).to(x.device)
    x_local = x.index_select(seq_dim, positions)
    return x_local, positions

2.3 Verification

# Quick sanity check: every position appears exactly once across all ranks
all_positions = torch.cat([
    zigzag_partition(16, 4, r) for r in range(4)
])
assert all_positions.sort()[0].equal(torch.arange(16))
# Verify balanced causal workload:
# FLOPs ~ sum of positions (each token attends to all preceding tokens)
for r in range(4):
    pos = zigzag_partition(16, 4, r)
    print(f"Rank {r}: positions={pos.tolist()}, "
          f"causal_flops ∝ {pos.sum().item()}")
# Output:
# Rank 0: positions=[0, 7, 8, 15], causal_flops ∝ 30
# Rank 1: positions=[1, 6, 9, 14], causal_flops ∝ 30
# Rank 2: positions=[2, 5, 10, 13], causal_flops ∝ 30
# Rank 3: positions=[3, 4, 11, 12], causal_flops ∝ 30

3. Ring Communication Layer

3.1 Ring Topology

The $N_{\text{CP}}$ GPUs form a logical ring where GPU $i$ sends to GPU $(i+1) \bmod N_{\text{CP}}$ and receives from GPU $(i-1) \bmod N_{\text{CP}}$. Double buffering allows overlapping communication with computation: while the kernel processes the current buffers (k_cur, v_cur), the next chunk is received into the spare buffers (k_buf, v_buf).

class RingCommunicator:
    """
    Manages asynchronous ring send/recv of K/V tensors
    with double buffering for compute-communication overlap.
    """

    def __init__(self, process_group: dist.ProcessGroup):
        self.group = process_group
        self.cp_size = dist.get_world_size(group=process_group)
        self.cp_rank = dist.get_rank(group=process_group)

        # Ring neighbors
        self.send_rank = (self.cp_rank + 1) % self.cp_size
        self.recv_rank = (self.cp_rank - 1) % self.cp_size

    def _get_global_rank(self, group_rank: int) -> int:
        """Convert group-local rank to global rank."""
        return dist.get_global_rank(self.group, group_rank)

    def send_recv_kv(
        self,
        k_send: torch.Tensor,
        v_send: torch.Tensor,
        k_recv: torch.Tensor,
        v_recv: torch.Tensor,
    ) -> List[dist.Work]:
        """
        Initiate non-blocking ring send/recv for K and V tensors.

        Args:
            k_send, v_send: Tensors to send to next rank.
            k_recv, v_recv: Pre-allocated buffers to receive from prev rank.

        Returns:
            List of async work handles (call .wait() to synchronize).
        """
        ops = []
        # Send to next neighbor
        ops.append(dist.isend(
            k_send, dst=self._get_global_rank(self.send_rank), group=self.group
        ))
        ops.append(dist.isend(
            v_send, dst=self._get_global_rank(self.send_rank), group=self.group
        ))
        # Receive from previous neighbor
        ops.append(dist.irecv(
            k_recv, src=self._get_global_rank(self.recv_rank), group=self.group
        ))
        ops.append(dist.irecv(
            v_recv, src=self._get_global_rank(self.recv_rank), group=self.group
        ))
        return ops

    @staticmethod
    def wait_all(handles: List[dist.Work]):
        """Block until all async operations complete."""
        for h in handles:
            h.wait()

4. Online Softmax Merge — Mathematical Foundation

4.1 Problem Statement

At each ring step $t$, GPU $i$ computes a partial attention output using only the $K/V$ chunk currently in local memory. The challenge is combining these partial results into the exact full-sequence attention output without ever materializing the full $s \times s$ attention matrix.

4.2 Derivation

Let the full attention for query row $q$ be:

$$ O_q = \frac{\sum_{j=1}^{s} \exp(s_{q,j}) \cdot v_j}{\sum_{j=1}^{s} \exp(s_{q,j})} $$

where $s_{q,j} = q \cdot k_j / \sqrt{d_k}$.

Partition the key positions into two disjoint sets $A$ and $B$ (e.g., two ring steps). Define per-set statistics:

$$ m_A = \max_{j \in A} s_{q,j}, \quad \ell_A = \sum_{j \in A} \exp(s_{q,j} - m_A), \quad \tilde{O}_A = \sum_{j \in A} \exp(s_{q,j} - m_A) \cdot v_j $$

The merge of sets $A$ and $B$ is:

$$ m_{AB} = \max(m_A, m_B) $$
$$ \ell_{AB} = \ell_A \cdot \exp(m_A - m_{AB}) + \ell_B \cdot \exp(m_B - m_{AB}) $$
$$ \tilde{O}_{AB} = \tilde{O}_A \cdot \exp(m_A - m_{AB}) + \tilde{O}_B \cdot \exp(m_B - m_{AB}) $$

Final normalized output after all ring steps:

$$ O_q = \frac{\tilde{O}_{\text{final}}}{\ell_{\text{final}}} $$

This is numerically exact — not an approximation.

4.3 PyTorch Implementation

def online_softmax_merge(
    O_acc: torch.Tensor,    # [B, H, S_local, D]  — unnormalized accumulated output
    m_acc: torch.Tensor,    # [B, H, S_local]     — running max
    l_acc: torch.Tensor,    # [B, H, S_local]     — running sum of exp
    O_new: torch.Tensor,    # [B, H, S_local, D]  — new block's unnormalized output
    m_new: torch.Tensor,    # [B, H, S_local]     — new block's max
    l_new: torch.Tensor,    # [B, H, S_local]     — new block's sum of exp
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Merge two sets of online softmax statistics.
    All inputs are in float32 for numerical stability.
    """
    m_merged = torch.maximum(m_acc, m_new)                   # [B, H, S_local]

    alpha = torch.exp(m_acc - m_merged)                       # rescale factor for old
    beta  = torch.exp(m_new - m_merged)                       # rescale factor for new

    l_merged = alpha * l_acc + beta * l_new                   # [B, H, S_local]

    O_merged = (
        alpha.unsqueeze(-1) * O_acc
        + beta.unsqueeze(-1) * O_new
    )                                                         # [B, H, S_local, D]

    return O_merged, m_merged, l_merged
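
A minimal sanity check of the merge against a full softmax (two key blocks on random tensors; assumes the `online_softmax_merge` function defined above):

import torch

B, H, S, D = 1, 2, 8, 16
q = torch.randn(B, H, S, D)
k = torch.randn(B, H, 2 * S, D)
v = torch.randn(B, H, 2 * S, D)
scores = (q @ k.transpose(-2, -1)) / D ** 0.5        # [B, H, S, 2S]

def block_stats(s_blk, v_blk):
    m = s_blk.max(dim=-1).values                      # row-wise max
    p = torch.exp(s_blk - m.unsqueeze(-1))
    return p @ v_blk, m, p.sum(dim=-1)                # unnormalized O, max, denominator

O_a, m_a, l_a = block_stats(scores[..., :S], v[:, :, :S])
O_b, m_b, l_b = block_stats(scores[..., S:], v[:, :, S:])
O, m, l = online_softmax_merge(O_a, m_a, l_a, O_b, m_b, l_b)
reference = torch.softmax(scores, dim=-1) @ v
torch.testing.assert_close(O / l.unsqueeze(-1), reference, rtol=1e-5, atol=1e-5)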

5. PyTorch Reference: Ring Attention Forward

This is a complete, correct, pure-PyTorch implementation suitable for validating the Triton version. Each step corresponds to one rotation of the $K/V$ ring.

def ring_attention_forward_reference(
    q: torch.Tensor,              # [B, H, S_local, D]  — local queries
    k: torch.Tensor,              # [B, H, S_local, D]  — local keys
    v: torch.Tensor,              # [B, H, S_local, D]  — local values
    q_positions: torch.Tensor,    # [S_local]            — global positions of local queries
    all_positions: List[torch.Tensor],  # list of [S_local] per CP rank
    comm: RingCommunicator,
    causal: bool = True,
) -> torch.Tensor:
    """
    Reference ring attention forward pass (no Triton, no overlap).

    Returns:
        O: [B, H, S_local, D]  — attention output for local queries.
    """
    B, H, S_local, D = q.shape
    device = q.device
    scale = D ** -0.5

    # ---- Initialize online softmax accumulators ----
    O_acc = torch.zeros(B, H, S_local, D, device=device, dtype=torch.float32)
    m_acc = torch.full((B, H, S_local,), float("-inf"), device=device, dtype=torch.float32)
    l_acc = torch.zeros(B, H, S_local, device=device, dtype=torch.float32)

    # ---- Double buffers for K/V ----
    k_cur, v_cur = k.clone(), v.clone()
    k_buf, v_buf = torch.empty_like(k), torch.empty_like(v)

    for step in range(comm.cp_size):
        # -- Step A: initiate async send/recv (except last step) --
        if step < comm.cp_size - 1:
            handles = comm.send_recv_kv(k_cur, v_cur, k_buf, v_buf)

        # -- Step B: compute partial attention for current K/V chunk --
        # Determine which rank's K/V we currently hold
        source_rank = (comm.cp_rank - step) % comm.cp_size
        k_positions = all_positions[source_rank].to(device)     # [S_local]

        # Attention scores: [B, H, S_local, S_local]
        S = torch.matmul(q.float(), k_cur.float().transpose(-2, -1)) * scale

        # Causal mask based on global positions
        if causal:
            # q_positions: [S_local], k_positions: [S_local]
            causal_mask = q_positions.unsqueeze(-1) >= k_positions.unsqueeze(-2)
            # Expand to [1, 1, S_local, S_local]
            causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
            S = S.masked_fill(~causal_mask, float("-inf"))

        # Block-local online softmax statistics
        m_block = S.max(dim=-1).values                       # [B, H, S_local]
        # Guard against all-masked rows (m_block = -inf): substituting 0 keeps
        # exp(-inf - 0) = 0, so l_block and O_block stay 0 instead of becoming NaN.
        m_block = torch.where(
            torch.isneginf(m_block), torch.zeros_like(m_block), m_block
        )
        P = torch.exp(S - m_block.unsqueeze(-1))             # [B, H, S_local, S_local]
        l_block = P.sum(dim=-1)                              # [B, H, S_local]
        O_block = torch.matmul(P, v_cur.float())             # [B, H, S_local, D]
        # O_block is unnormalized: O_block = sum_j exp(s_j - m_block) * v_j

        # Merge with running accumulator
        O_acc, m_acc, l_acc = online_softmax_merge(
            O_acc, m_acc, l_acc,
            O_block, m_block, l_block,
        )

        # -- Step C: wait for communication, swap buffers --
        if step < comm.cp_size - 1:
            RingCommunicator.wait_all(handles)
            k_cur, k_buf = k_buf, k_cur
            v_cur, v_buf = v_buf, v_cur

    # ---- Final normalization ----
    O = O_acc / l_acc.unsqueeze(-1).clamp(min=1e-6)
    return O.to(q.dtype)
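
As a lightweight check, a hypothetical single-rank harness (cp_size = 1, so the loop never issues any send/recv) lets the reference be compared against PyTorch's built-in attention; the `_SingleRankComm` stub below is an illustrative assumption, not part of the implementation:

import torch
import torch.nn.functional as F

class _SingleRankComm:
    """Stand-in communicator: one rank, so no communication is ever issued."""
    cp_size, cp_rank = 1, 0

B, H, S, D = 1, 2, 32, 16
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))
pos = torch.arange(S)

O = ring_attention_forward_reference(q, k, v, pos, [pos], _SingleRankComm(), causal=True)
reference = F.scaled_dot_product_attention(q, k, v, is_causal=True)
torch.testing.assert_close(O, reference, rtol=1e-4, atol=1e-4)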

Complexity per GPU:

| Metric | Value |
| --- | --- |
| Compute (FLOPs) | $N_{\text{CP}} \times \mathcal{O}\!\left(\frac{s}{N_{\text{CP}}} \cdot \frac{s}{N_{\text{CP}}} \cdot d_k\right) = \mathcal{O}\!\left(\frac{s^2 d_k}{N_{\text{CP}}}\right)$ |
| Memory (activations) | $\mathcal{O}\!\left(\frac{s}{N_{\text{CP}}} \cdot d_k\right)$ — only local $Q$ + one $K/V$ chunk + accumulators |
| Communication (volume) | $2 \cdot s \cdot d_k \cdot B \cdot n_h$ (same total as all-gather, but pipelined) |

6. Triton Kernel: Fused Block Attention with Online Softmax Merge

6.1 Kernel Design

This kernel processes one ring step: it computes attention of the local queries $Q$ against one $K/V$ chunk, tiled over $K/V$ in blocks of BLOCK_N, maintaining online softmax internally. At the end, it merges the result with the inter-step accumulators $(\tilde{O}_{\text{acc}}, m_{\text{acc}}, \ell_{\text{acc}})$ stored in global memory.

Grid: $(\texttt{batch} \times \texttt{num\_heads},\; \lceil S_{\text{local}} / \texttt{BLOCK\_M} \rceil)$

Each program instance handles BLOCK_M query rows and loops over all S_K key columns of the current chunk.

import triton
import triton.language as tl


@triton.jit
def _ring_attn_fwd_kernel(
    # ---- Tensor pointers ----
    Q_ptr, K_ptr, V_ptr,
    O_acc_ptr, M_acc_ptr, L_acc_ptr,
    Q_pos_ptr, K_pos_ptr,
    # ---- Strides (elements, not bytes) ----
    stride_q_bh, stride_q_s, stride_q_d,    # Q: [BH, S_Q, D]
    stride_k_bh, stride_k_s, stride_k_d,    # K: [BH, S_K, D]
    stride_v_bh, stride_v_s, stride_v_d,    # V: [BH, S_K, D]
    stride_o_bh, stride_o_s, stride_o_d,    # O_acc: [BH, S_Q, D]
    # ---- Dimensions / scalars ----
    S_Q,                                      # local query seq length
    S_K,                                      # current K/V chunk length
    sm_scale,                                 # softmax scale = 1 / sqrt(D)
    # ---- Compile-time constants ----
    D: tl.constexpr,
    BLOCK_M: tl.constexpr,                   # tile size along queries
    BLOCK_N: tl.constexpr,                   # tile size along keys
    CAUSAL: tl.constexpr,
):
    """
    Fused ring-attention step kernel.

    Computes partial attention of Q against one K/V chunk,
    then merges with the inter-step accumulator in-place.

    Numerically equivalent to:
        S = Q @ K^T / sqrt(D)
        if CAUSAL: S = masked_fill(S, q_pos < k_pos, -inf)
        m_block, l_block = rowmax(S), rowsum(exp(S - m_block))
        O_block = exp(S - m_block) @ V    (unnormalized)
        (O_acc, m_acc, l_acc) = merge(O_acc, m_acc, l_acc,
                                       O_block, m_block, l_block)
    """
    # ================================================================
    # Program ID → (batch_head, query_block)
    # ================================================================
    pid_bh = tl.program_id(0)
    pid_m  = tl.program_id(1)

    # ================================================================
    # Offset calculations
    # ================================================================
    off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)     # [BLOCK_M]
    off_d = tl.arange(0, D)                              # [D]
    mask_m = off_m < S_Q

    # Base pointers for this batch-head slice
    Q_bh = Q_ptr + pid_bh * stride_q_bh
    K_bh = K_ptr + pid_bh * stride_k_bh
    V_bh = V_ptr + pid_bh * stride_v_bh
    O_bh = O_acc_ptr + pid_bh * stride_o_bh

    # ================================================================
    # Load Q block: [BLOCK_M, D]
    # ================================================================
    q = tl.load(
        Q_bh + off_m[:, None] * stride_q_s + off_d[None, :] * stride_q_d,
        mask=mask_m[:, None],
        other=0.0,
    ).to(tl.float32)

    # Load query global positions for causal masking
    q_pos = tl.load(Q_pos_ptr + off_m, mask=mask_m, other=0).to(tl.int64)

    # ================================================================
    # Initialize block-local accumulators (online softmax within chunk)
    # ================================================================
    m_i = tl.full([BLOCK_M], value=float("-inf"), dtype=tl.float32)
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    o_i = tl.zeros([BLOCK_M, D], dtype=tl.float32)

    scale = sm_scale                          # softmax scale, passed from the host wrapper

    # ================================================================
    # Inner loop: tile over K/V in blocks of BLOCK_N
    # ================================================================
    for start_n in range(0, S_K, BLOCK_N):
        off_n = start_n + tl.arange(0, BLOCK_N)          # [BLOCK_N]
        mask_n = off_n < S_K

        # Load K tile: [BLOCK_N, D]
        k = tl.load(
            K_bh + off_n[:, None] * stride_k_s + off_d[None, :] * stride_k_d,
            mask=mask_n[:, None],
            other=0.0,
        ).to(tl.float32)

        # S = Q @ K^T * scale : [BLOCK_M, BLOCK_N]
        s = tl.dot(q, tl.trans(k)) * scale

        # ---- Apply causal mask ----
        if CAUSAL:
            k_pos = tl.load(
                K_pos_ptr + off_n, mask=mask_n, other=2147483647
            ).to(tl.int64)
            causal_mask = q_pos[:, None] >= k_pos[None, :]
            s = tl.where(causal_mask, s, float("-inf"))

        # Mask out-of-bounds keys
        s = tl.where(mask_n[None, :], s, float("-inf"))

        # ---- Online softmax update (within-chunk) ----
        m_ij = tl.max(s, axis=1)                          # [BLOCK_M]
        m_new = tl.maximum(m_i, m_ij)
        # Guard rows that are fully masked so far (m_new = -inf): use 0 in the
        # exponents so they evaluate to exp(-inf) = 0 instead of NaN.
        m_safe = tl.where(m_new == float("-inf"), 0.0, m_new)

        # Rescale previous accumulator
        alpha = tl.exp(m_i - m_safe)
        # Compute exp(s - m_safe) for current tile
        p = tl.exp(s - m_safe[:, None])                   # [BLOCK_M, BLOCK_N]

        l_i = alpha * l_i + tl.sum(p, axis=1)
        o_i = o_i * alpha[:, None]

        # Load V tile: [BLOCK_N, D]
        v = tl.load(
            V_bh + off_n[:, None] * stride_v_s + off_d[None, :] * stride_v_d,
            mask=mask_n[:, None],
            other=0.0,
        ).to(tl.float32)

        # Accumulate: O += P @ V
        o_i += tl.dot(p.to(tl.float32), v)

        m_i = m_new

    # ================================================================
    # Merge with inter-step accumulator (across ring steps)
    # ================================================================
    # Load previous accumulator state
    o_acc = tl.load(
        O_bh + off_m[:, None] * stride_o_s + off_d[None, :] * stride_o_d,
        mask=mask_m[:, None],
        other=0.0,
    ).to(tl.float32)
    m_acc = tl.load(
        M_acc_ptr + pid_bh * S_Q + off_m,
        mask=mask_m,
        other=float("-inf"),
    ).to(tl.float32)
    l_acc = tl.load(
        L_acc_ptr + pid_bh * S_Q + off_m,
        mask=mask_m,
        other=0.0,
    ).to(tl.float32)

    # Merge formula:
    #   m_merged = max(m_acc, m_i)
    #   l_merged = l_acc * exp(m_acc - m_merged) + l_i * exp(m_i - m_merged)
    #   O_merged = O_acc * exp(m_acc - m_merged) + O_i * exp(m_i - m_merged)
    m_merged = tl.maximum(m_acc, m_i)
    # Same -inf guard as the inner loop, for rows masked in every chunk seen so far
    m_merged_safe = tl.where(m_merged == float("-inf"), 0.0, m_merged)
    alpha_acc = tl.exp(m_acc - m_merged_safe)
    alpha_new = tl.exp(m_i - m_merged_safe)

    l_merged = l_acc * alpha_acc + l_i * alpha_new
    o_merged = o_acc * alpha_acc[:, None] + o_i * alpha_new[:, None]

    # ================================================================
    # Store merged state back to global memory
    # ================================================================
    tl.store(
        O_bh + off_m[:, None] * stride_o_s + off_d[None, :] * stride_o_d,
        o_merged,
        mask=mask_m[:, None],
    )
    tl.store(
        M_acc_ptr + pid_bh * S_Q + off_m,
        m_merged,
        mask=mask_m,
    )
    tl.store(
        L_acc_ptr + pid_bh * S_Q + off_m,
        l_merged,
        mask=mask_m,
    )

6.2 Kernel Launch Wrapper

def triton_ring_attn_step(
    q: torch.Tensor,              # [B*H, S_Q, D]
    k: torch.Tensor,              # [B*H, S_K, D]
    v: torch.Tensor,              # [B*H, S_K, D]
    o_acc: torch.Tensor,          # [B*H, S_Q, D]  (float32, in-place)
    m_acc: torch.Tensor,          # [B*H, S_Q]      (float32, in-place)
    l_acc: torch.Tensor,          # [B*H, S_Q]      (float32, in-place)
    q_positions: torch.Tensor,    # [S_Q]           (int64)
    k_positions: torch.Tensor,    # [S_K]           (int64)
    causal: bool = True,
    BLOCK_M: int = 128,
    BLOCK_N: int = 64,
):
    """
    Launch the Triton kernel for one ring attention step.
    Computes partial attention of `q` against `k, v` and merges
    into accumulators `o_acc, m_acc, l_acc` in-place.
    """
    BH, S_Q, D = q.shape
    _, S_K, _ = k.shape

    # Ensure D is power of 2 (required by tl.arange)
    assert D & (D - 1) == 0, f"Head dim D={D} must be a power of 2"

    grid = (BH, triton.cdiv(S_Q, BLOCK_M))

    _ring_attn_fwd_kernel[grid](
        # Tensor pointers
        q, k, v,
        o_acc, m_acc, l_acc,
        q_positions, k_positions,
        # Strides
        q.stride(0), q.stride(1), q.stride(2),
        k.stride(0), k.stride(1), k.stride(2),
        v.stride(0), v.stride(1), v.stride(2),
        o_acc.stride(0), o_acc.stride(1), o_acc.stride(2),
        # Dimensions / scalars
        S_Q, S_K, D ** -0.5,
        # Compile-time constants
        D=D,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        CAUSAL=causal,
    )

7. Ring Attention Forward — Triton Orchestration

This function combines the ring communication (Section 3) with the Triton kernel (Section 6), implementing the full overlapped pipeline.

def ring_attention_forward_triton(
    q: torch.Tensor,              # [B, H, S_local, D]
    k: torch.Tensor,              # [B, H, S_local, D]
    v: torch.Tensor,              # [B, H, S_local, D]
    q_positions: torch.Tensor,    # [S_local]
    all_positions: List[torch.Tensor],
    comm: RingCommunicator,
    causal: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Triton-accelerated ring attention forward with compute-communication overlap.

    Returns:
        O:     [B, H, S_local, D]   — normalized attention output
        m_acc: [B*H, S_local]       — final row-wise max  (saved for backward)
        l_acc: [B*H, S_local]       — final row-wise lse  (saved for backward)
    """
    B, H, S_local, D = q.shape
    BH = B * H
    device = q.device

    # Flatten batch and heads for the kernel: [BH, S_local, D]
    q_3d = q.reshape(BH, S_local, D).contiguous()

    # ---- Initialize accumulators in float32 ----
    o_acc = torch.zeros(BH, S_local, D, device=device, dtype=torch.float32)
    m_acc = torch.full((BH, S_local), float("-inf"), device=device, dtype=torch.float32)
    l_acc = torch.zeros(BH, S_local, device=device, dtype=torch.float32)

    # ---- Double-buffered K/V ----
    k_cur = k.reshape(BH, S_local, D).contiguous()
    v_cur = v.reshape(BH, S_local, D).contiguous()
    k_buf = torch.empty_like(k_cur)
    v_buf = torch.empty_like(v_cur)

    q_pos = q_positions.to(device)

    for step in range(comm.cp_size):
        # ---- (A) Start async communication (non-blocking) ----
        if step < comm.cp_size - 1:
            handles = comm.send_recv_kv(k_cur, v_cur, k_buf, v_buf)

        # ---- (B) Compute: Triton kernel for this K/V chunk ----
        source_rank = (comm.cp_rank - step) % comm.cp_size
        k_pos = all_positions[source_rank].to(device)

        triton_ring_attn_step(
            q_3d, k_cur, v_cur,
            o_acc, m_acc, l_acc,
            q_pos, k_pos,
            causal=causal,
        )

        # ---- (C) Wait for communication, swap buffers ----
        if step < comm.cp_size - 1:
            RingCommunicator.wait_all(handles)
            k_cur, k_buf = k_buf, k_cur
            v_cur, v_buf = v_buf, v_cur

    # ---- Final normalization: O = O_unnorm / l ----
    O = o_acc / l_acc.unsqueeze(-1).clamp(min=1e-6)
    O = O.reshape(B, H, S_local, D).to(q.dtype)

    return O, m_acc, l_acc

7.1 Timing Overlap Condition

Effective overlap requires the compute time of one ring step to exceed the communication time:

$$ T_{\text{compute}}^{(\text{step})} = \frac{2 \cdot (s/N_{\text{CP}})^2 \cdot d_k \cdot B \cdot n_h}{\text{GPU FLOPS}} \;\geq\; T_{\text{comm}}^{(\text{step})} = \frac{2 \cdot (s/N_{\text{CP}}) \cdot d_k \cdot B \cdot n_h \cdot \texttt{sizeof(dtype)}}{\text{Interconnect BW}} $$

Simplifying:

$$ \frac{s}{N_{\text{CP}}} \;\geq\; \frac{\text{GPU FLOPS} \cdot \texttt{sizeof(dtype)}}{\text{Interconnect BW}} $$

For an A100 (312 TFLOPS bf16) with NVLink (600 GB/s), this gives $s / N_{\text{CP}} \geq 312 \times 10^{12} \times 2 / (600 \times 10^9) \approx 1024$ tokens — easily satisfied in long-context training scenarios.


8. Backward Pass

8.1 Mathematical Derivation

Given the forward computation at ring step $t$ on GPU $i$:

$$ S^{(t)} = \frac{Q_i \left(K^{(t)}\right)^T}{\sqrt{d_k}}, \quad P^{(t)} = \frac{\exp\!\left(S^{(t)} - m_{\text{final}}\right)}{\ell_{\text{final}}}, \quad O_i = \sum_{t=0}^{N_{\text{CP}}-1} P^{(t)} V^{(t)} $$

The gradients follow the FlashAttention backward recurrence, with the row-wise correction term $\delta = \text{rowsum}\!\left(dO_i \odot O_i\right)$ (independent of $t$):

$$ dS^{(t)} = P^{(t)} \odot \left(dO_i \cdot \left(V^{(t)}\right)^T - \delta\right) $$
$$ dQ_i \mathrel{+}= \frac{dS^{(t)} \cdot K^{(t)}}{\sqrt{d_k}}, \quad dK^{(t)} \mathrel{+}= \frac{\left(dS^{(t)}\right)^T \cdot Q_i}{\sqrt{d_k}}, \quad dV^{(t)} \mathrel{+}= \left(P^{(t)}\right)^T \cdot dO_i $$

8.2 Ring Backward Algorithm

The backward pass mirrors the forward ring structure:

  1. $K/V$ chunks rotate in the same direction as the forward pass
  2. At each step, recompute $S^{(t)}$ and $P^{(t)}$ using the saved $m_{\text{final}}, \ell_{\text{final}}$ (avoids storing the full attention matrix)
  3. $dQ$ accumulates locally (each GPU’s queries are fixed)
  4. $dK, dV$ are accumulated into buffers that rotate together with their $K/V$ chunks; one extra rotation after the last step returns each chunk's gradient to its owning GPU, i.e., after $N_{\text{CP}}$ rotations in total

8.3 PyTorch Reference Implementation

def ring_attention_backward_reference(
    dO: torch.Tensor,             # [B, H, S_local, D] — upstream gradient
    q: torch.Tensor,              # [B, H, S_local, D] — saved from forward
    k: torch.Tensor,              # [B, H, S_local, D] — saved from forward
    v: torch.Tensor,              # [B, H, S_local, D] — saved from forward
    O: torch.Tensor,              # [B, H, S_local, D] — saved output
    m_final: torch.Tensor,        # [B*H, S_local]     — saved from forward
    l_final: torch.Tensor,        # [B*H, S_local]     — saved from forward
    q_positions: torch.Tensor,
    all_positions: List[torch.Tensor],
    comm: RingCommunicator,
    causal: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Backward pass for ring attention.
    Returns dQ, dK, dV each of shape [B, H, S_local, D].
    """
    B, H, S_local, D = q.shape
    device = q.device
    scale = D ** -0.5

    dQ = torch.zeros_like(q, dtype=torch.float32)

    # Pre-compute delta = rowsum(dO * O): [B, H, S_local]
    delta = (dO.float() * O.float()).sum(dim=-1)

    # Reshape m_final, l_final to [B, H, S_local]
    m_f = m_final.reshape(B, H, S_local)
    l_f = l_final.reshape(B, H, S_local)

    # ---- K/V double buffer (same rotation as forward) ----
    k_cur, v_cur = k.clone(), v.clone()
    k_buf, v_buf = torch.empty_like(k), torch.empty_like(v)

    # ---- dK/dV accumulators (rotate together with their K/V chunks) ----
    # At step t, this rank holds the K/V chunk from `source_rank` together with
    # the running dK/dV accumulator for that same chunk. It adds its local
    # contribution, then forwards both the chunk and its accumulator around the
    # ring. One extra rotation after the last step returns each accumulated
    # gradient to the GPU that owns the corresponding K/V chunk.
    dk_acc = torch.zeros_like(k, dtype=torch.float32)
    dv_acc = torch.zeros_like(v, dtype=torch.float32)
    dk_recv = torch.empty_like(dk_acc)
    dv_recv = torch.empty_like(dv_acc)

    for step in range(comm.cp_size):
        # ---- Async K/V forward rotation (except last step) ----
        if step < comm.cp_size - 1:
            kv_handles = comm.send_recv_kv(k_cur, v_cur, k_buf, v_buf)

        # ---- Determine source and positions ----
        source_rank = (comm.cp_rank - step) % comm.cp_size
        k_positions = all_positions[source_rank].to(device)

        # ---- Recompute S and P for this chunk ----
        S = torch.matmul(q.float(), k_cur.float().transpose(-2, -1)) * scale

        if causal:
            causal_mask = q_positions.unsqueeze(-1) >= k_positions.unsqueeze(-2)
            causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
            S = S.masked_fill(~causal_mask, float("-inf"))

        # Reconstruct P using saved global statistics
        P = torch.exp(S - m_f.unsqueeze(-1)) / l_f.unsqueeze(-1).clamp(min=1e-6)

        # ---- Compute gradients ----
        # dS = P * (dO @ V^T - delta)
        dP = torch.matmul(dO.float(), v_cur.float().transpose(-2, -1))
        dS = P * (dP - delta.unsqueeze(-1))

        # dQ += dS @ K / sqrt(d_k)
        dQ += torch.matmul(dS, k_cur.float()) * scale

        # dK, dV for this chunk
        dk_step = torch.matmul(dS.transpose(-2, -1), q.float()) * scale
        dv_step = torch.matmul(P.transpose(-2, -1), dO.float())
        dk_acc += dk_step
        dv_acc += dv_step

        # ---- Wait for K/V rotation ----
        if step < comm.cp_size - 1:
            RingCommunicator.wait_all(kv_handles)
            k_cur, k_buf = k_buf, k_cur
            v_cur, v_buf = v_buf, v_cur

        # ---- Rotate dk_acc/dv_acc forward, in the same direction as K/V ----
        # This runs on every step (including the last) so that each chunk's
        # accumulated gradient lands on the rank that owns that chunk.
        grad_handles = comm.send_recv_kv(dk_acc, dv_acc, dk_recv, dv_recv)
        RingCommunicator.wait_all(grad_handles)
        dk_acc, dk_recv = dk_recv, dk_acc
        dv_acc, dv_recv = dv_recv, dv_acc

    return dQ.to(q.dtype), dk_acc.to(k.dtype), dv_acc.to(v.dtype)

9. Autograd Function — Complete Integration

class RingAttentionFunction(torch.autograd.Function):
    """
    Autograd-compatible ring attention with zig-zag partitioning.
    Supports both Triton-accelerated and pure-PyTorch execution.
    """

    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,              # [B, H, S_local, D]
        k: torch.Tensor,
        v: torch.Tensor,
        q_positions: torch.Tensor,
        all_positions: List[torch.Tensor],
        comm: RingCommunicator,
        causal: bool,
        use_triton: bool,
    ) -> torch.Tensor:

        if use_triton:
            O, m_final, l_final = ring_attention_forward_triton(
                q, k, v, q_positions, all_positions, comm, causal
            )
        else:
            O = ring_attention_forward_reference(
                q, k, v, q_positions, all_positions, comm, causal
            )
            # For backward, we'd need m_final, l_final from reference too
            # (omitted here for brevity — extend reference to return them)
            m_final = l_final = None

        # Save tensors for backward
        ctx.save_for_backward(q, k, v, O)
        ctx.q_positions = q_positions
        ctx.all_positions = all_positions
        ctx.comm = comm
        ctx.causal = causal
        ctx.m_final = m_final
        ctx.l_final = l_final

        return O

    @staticmethod
    def backward(ctx, dO: torch.Tensor):
        q, k, v, O = ctx.saved_tensors

        dQ, dK, dV = ring_attention_backward_reference(
            dO, q, k, v, O,
            ctx.m_final, ctx.l_final,
            ctx.q_positions, ctx.all_positions,
            ctx.comm, ctx.causal,
        )
        # Return gradients for each forward input
        # (None for non-Tensor / non-differentiable args)
        return dQ, dK, dV, None, None, None, None, None

10. ContextParallelAttention Module

import torch.nn as nn


class ContextParallelAttention(nn.Module):
    """
    Drop-in multi-head attention module with context parallelism.

    Handles zig-zag partitioning, ring attention, and gradient sync.
    Input shape: [B, S_full, hidden_dim] (full sequence; this module applies
    the zig-zag partition itself) — or pre-partitioned inputs when integrated
    into a CP-aware pipeline.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_heads: int,
        cp_group: dist.ProcessGroup,
        causal: bool = True,
        use_triton: bool = True,
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.causal = causal
        self.use_triton = use_triton

        # Projections (weights replicated across CP group, sharded by TP if combined)
        self.q_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.o_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)

        # Communication
        self.comm = RingCommunicator(cp_group)

        # Pre-compute zig-zag positions for all ranks
        self._positions_cache = {}

    def _get_positions(
        self, seq_len: int, device: torch.device
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Cached zig-zag position computation."""
        key = (seq_len, self.comm.cp_size)
        if key not in self._positions_cache:
            all_pos = [
                zigzag_partition(seq_len, self.comm.cp_size, r).to(device)
                for r in range(self.comm.cp_size)
            ]
            self._positions_cache[key] = all_pos

        all_pos = self._positions_cache[key]
        return all_pos[self.comm.cp_rank], all_pos

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, S_full, hidden_dim] — full sequence input.

        Returns:
            out: [B, S_local, hidden_dim] — output for this rank's token subset.
        """
        B, S_full, _ = x.shape

        # ---- Zig-zag partition input ----
        q_pos, all_pos = self._get_positions(S_full, x.device)
        x_local = x.index_select(1, q_pos)                  # [B, S_local, hidden_dim]

        # ---- Project to Q, K, V ----
        S_local = x_local.shape[1]
        q = self.q_proj(x_local).reshape(B, S_local, self.num_heads, self.head_dim)
        k = self.k_proj(x_local).reshape(B, S_local, self.num_heads, self.head_dim)
        v = self.v_proj(x_local).reshape(B, S_local, self.num_heads, self.head_dim)

        # Transpose to [B, H, S_local, D]
        q = q.transpose(1, 2).contiguous()
        k = k.transpose(1, 2).contiguous()
        v = v.transpose(1, 2).contiguous()

        # ---- Ring Attention ----
        O = RingAttentionFunction.apply(
            q, k, v,
            q_pos, all_pos,
            self.comm,
            self.causal,
            self.use_triton,
        )                                                     # [B, H, S_local, D]

        # ---- Output projection ----
        O = O.transpose(1, 2).reshape(B, S_local, self.hidden_dim)
        out = self.o_proj(O)

        return out

11. End-to-End Usage Example

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def setup_context_parallel(cp_size: int) -> dist.ProcessGroup:
    """
    Create context-parallel process groups.
    Assumes torch.distributed is already initialized.
    """
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    assert world_size % cp_size == 0

    # Each CP group contains `cp_size` consecutive ranks.
    # Note: dist.new_group() is collective; every rank must create every
    # group in the same order, even groups it does not belong to.
    num_cp_groups = world_size // cp_size
    cp_group = None
    for i in range(num_cp_groups):
        ranks = list(range(i * cp_size, (i + 1) * cp_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            cp_group = group

    return cp_group


def main():
    # ---- Initialize distributed ----
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
    torch.cuda.set_device(device)

    # ---- Hyperparameters ----
    CP_SIZE     = 4
    BATCH_SIZE  = 2
    SEQ_LEN     = 8192       # full sequence length
    HIDDEN_DIM  = 4096
    NUM_HEADS   = 32
    HEAD_DIM    = HIDDEN_DIM // NUM_HEADS  # 128

    # ---- Create CP group ----
    cp_group = setup_context_parallel(CP_SIZE)

    # ---- Instantiate model ----
    # Seed before construction so every CP rank starts from identical
    # (replicated) weights, matching the gradient all-reduce below.
    torch.manual_seed(0)
    model = ContextParallelAttention(
        hidden_dim=HIDDEN_DIM,
        num_heads=NUM_HEADS,
        cp_group=cp_group,
        causal=True,
        use_triton=True,
    ).to(device).to(torch.bfloat16)

    # ---- Synthetic input (all CP ranks must hold the same full sequence;
    #      the module selects this rank's token subset) ----
    x = torch.randn(
        BATCH_SIZE, SEQ_LEN, HIDDEN_DIM,
        device=device, dtype=torch.bfloat16,
    )

    # ---- Forward ----
    out = model(x)   # [B, S_local, HIDDEN_DIM]  where S_local = SEQ_LEN / CP_SIZE

    # ---- Backward ----
    loss = out.sum()
    loss.backward()

    # ---- Gradient sync across CP group (weights are replicated) ----
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=cp_group)
            param.grad /= CP_SIZE

    if rank == 0:
        print(f"Output shape: {out.shape}")          # [2, 2048, 4096]
        print(f"Tokens per GPU: {SEQ_LEN // CP_SIZE}")  # 2048
        print(f"Memory per GPU (activations): "
              f"~{out.element_size() * out.nelement() / 1e6:.1f} MB")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

Launch command:

torchrun --nproc_per_node=4 context_parallel.py
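
The command above assumes a single node whose local GPU count is divisible by CP_SIZE. To span multiple nodes, the same script can be launched with torchrun's rendezvous flags; the command below is a sketch assuming two 8-GPU nodes and a placeholder head-node address (adjust CP_SIZE in the script to the desired group size):

torchrun --nnodes=2 --nproc_per_node=8 \
    --rdzv_backend=c10d --rdzv_endpoint=<head_node_ip>:29500 \
    context_parallel.py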

12. Performance Analysis and Optimization Notes

12.1 Memory Footprint Comparison

| Configuration | Per-GPU Activation Memory (Attention) |
|---|---|
| No parallelism | $\mathcal{O}(B \cdot n_h \cdot s^2)$ (full attention matrix) |
| FlashAttention only | $\mathcal{O}(B \cdot n_h \cdot s)$ (no materialized attention matrix) |
| CP + Ring Attention | $\mathcal{O}\!\left(B \cdot n_h \cdot \frac{s}{N_{\text{CP}}} \cdot d_k\right)$ (accumulators only) |
| CP + Ring + FlashAttn | $\mathcal{O}\!\left(B \cdot n_h \cdot \frac{s}{N_{\text{CP}}}\right)$ (minimal: accumulators + LSE) |
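
To make the table concrete, the short calculation below plugs in the Section 11 configuration (B = 2, s = 8192, n_h = 32, d_k = 128, N_CP = 4, bf16). The numbers are illustrative order-of-magnitude estimates derived from the formulas above, not measurements:

# Rough attention-memory estimate for the Section 11 configuration (bf16 = 2 bytes).
B, s, n_h, d_k, N_CP, BYTES = 2, 8192, 32, 128, 4, 2

full_attn_matrix = B * n_h * s * s * BYTES              # no parallelism: S x S score matrix
cp_ring_accum    = B * n_h * (s // N_CP) * d_k * BYTES  # CP + ring: local accumulators only

print(f"Full attention matrix : {full_attn_matrix / 1e9:.2f} GB")   # ~8.59 GB
print(f"CP + ring accumulators: {cp_ring_accum / 1e6:.2f} MB")      # ~33.55 MB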

12.2 Kernel Performance Tuning

import triton

# Auto-tune BLOCK_M and BLOCK_N for different hardware
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64},  num_warps=4, num_stages=3),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_M": 64,  "BLOCK_N": 64},  num_warps=4, num_stages=4),
        triton.Config({"BLOCK_M": 64,  "BLOCK_N": 128}, num_warps=8, num_stages=3),
    ],
    key=["S_Q", "S_K", "D"],
)
@triton.jit
def _ring_attn_fwd_kernel_autotuned(
    # ... same signature as _ring_attn_fwd_kernel ...
):
    pass  # identical body
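
triton.autotune benchmarks each listed config the first time the kernel is launched with a new value of the key tuple (S_Q, S_K, D) and caches the winner, so the tuning cost is paid once per distinct problem shape. The configs above are a reasonable starting grid rather than a definitive set: num_warps sets how many warps execute each program instance, and num_stages controls the depth of software pipelining used to overlap the K/V tile loads with compute.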

12.3 Critical Implementation Considerations

| Consideration | Detail |
|---|---|
| Numerical precision | All accumulations ($\tilde{O}$, $\ell$, $m$) must be float32 regardless of input dtype to prevent overflow in $\exp(\cdot)$ |
| All-masked rows | When $\ell = 0$ (all keys masked for a query), clamp to $\ell \geq \epsilon$ before division to avoid NaN |
| Position index transport | Positions are computed analytically from `source_rank`, never sent over the network |
| Backward recomputation | Forward saves only $(Q, K, V, O, m_{\text{final}}, \ell_{\text{final}})$; $P$ is recomputed during backward using $m_{\text{final}}$ and $\ell_{\text{final}}$, trading compute for memory |
| Combining with TP | When TP shards attention heads, CP operates on the local heads; the CP group is orthogonal to the TP group |
| GQA / MQA compatibility | For grouped-query attention, $K/V$ have fewer heads than $Q$; rotate only the $n_{kv}$ heads and broadcast to query head groups (see the sketch below) |
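
To illustrate the GQA / MQA row above, here is a minimal sketch (hypothetical shapes; not part of the reference implementation) of broadcasting the rotated K/V heads to their query-head groups just before the local attention computation, so that only the n_kv heads ever travel around the ring:

import torch

# Hypothetical GQA shapes: 32 query heads, 8 K/V heads -> group size 4.
B, S_local, n_q, n_kv, d_k = 1, 256, 32, 8, 128
group_size = n_q // n_kv

q = torch.randn(B, n_q,  S_local, d_k)   # [B, H_q,  S, D]  (stays local)
k = torch.randn(B, n_kv, S_local, d_k)   # [B, H_kv, S, D]  (rotated around the ring)
v = torch.randn(B, n_kv, S_local, d_k)   #                  (n_q/n_kv times less traffic)

# Expand each K/V head to its group of query heads only for the local compute.
k_exp = k.repeat_interleave(group_size, dim=1)   # [B, H_q, S, D]
v_exp = v.repeat_interleave(group_size, dim=1)   # [B, H_q, S, D]

scores = torch.einsum("bhqd,bhkd->bhqk", q, k_exp) / d_k ** 0.5
out = torch.einsum("bhqk,bhkd->bhqd", scores.softmax(dim=-1), v_exp)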

12.4 Communication-Computation Overlap Profiling

def profile_overlap_efficiency(
    seq_len: int,
    cp_size: int,
    head_dim: int,
    num_heads: int,
    batch_size: int,
    gpu_tflops: float = 312.0,         # A100 bf16 peak (dense)
    bw_gbps: float = 600.0,            # A100 NVLink 3.0 aggregate bandwidth, in GB/s
    dtype_bytes: int = 2,              # bfloat16
) -> dict:
    """Estimate whether communication is fully hidden behind compute."""
    S_local = seq_len // cp_size

    # FLOPs per ring step (QK^T matmul only): 2 * S_local^2 * head_dim * batch * heads.
    # The P @ V matmul adds roughly the same again, so this underestimates compute
    # and is conservative for the "is communication hidden?" question.
    flops_per_step = 2 * S_local * S_local * head_dim * batch_size * num_heads
    t_compute = flops_per_step / (gpu_tflops * 1e12)

    # Bytes per ring step: 2 * S_local * head_dim * batch * heads * dtype_bytes
    # (factor 2 for K and V)
    bytes_per_step = 2 * S_local * head_dim * batch_size * num_heads * dtype_bytes
    t_comm = bytes_per_step / (bw_gbps * 1e9)

    overlap_ratio = t_compute / t_comm

    return {
        "t_compute_ms": t_compute * 1e3,
        "t_comm_ms": t_comm * 1e3,
        "overlap_ratio": overlap_ratio,
        "fully_hidden": overlap_ratio >= 1.0,
        "arithmetic_intensity": flops_per_step / bytes_per_step,
    }

# Example:
stats = profile_overlap_efficiency(
    seq_len=131072, cp_size=8, head_dim=128,
    num_heads=32, batch_size=1
)
# t_compute_ms ≈ 7.05,  t_comm_ms ≈ 0.45,  overlap_ratio ≈ 15.8x  → fully hidden

13. Integration with Multi-Dimensional Parallelism

The total GPU count factorizes as:

N_{\text{total}} = N_{\text{DP}} \times N_{\text{TP}} \times N_{\text{CP}} \times N_{\text{PP}}

The helper below constructs the corresponding non-overlapping process groups, placing TP innermost so that it maps onto the fastest interconnect:

def create_parallel_groups(
    world_size: int,
    dp_size: int,
    tp_size: int,
    cp_size: int,
    pp_size: int,
) -> dict:
    """
    Create non-overlapping process groups for 4D parallelism.

    Ordering (innermost to outermost): TP → CP → PP → DP
    This places TP within a node (NVLink), CP across nearby nodes,
    PP across distant nodes, and DP across replicas.
    """
    assert world_size == dp_size * tp_size * cp_size * pp_size

    rank = dist.get_rank()
    groups = {}

    # TP groups: tp_size consecutive ranks
    for dp in range(dp_size):
        for pp in range(pp_size):
            for cp in range(cp_size):
                ranks = [
                    dp * (tp_size * cp_size * pp_size)
                    + pp * (tp_size * cp_size)
                    + cp * tp_size
                    + tp
                    for tp in range(tp_size)
                ]
                g = dist.new_group(ranks)
                if rank in ranks:
                    groups["tp"] = g

    # CP groups: cp_size ranks, strided by tp_size
    for dp in range(dp_size):
        for pp in range(pp_size):
            for tp in range(tp_size):
                ranks = [
                    dp * (tp_size * cp_size * pp_size)
                    + pp * (tp_size * cp_size)
                    + cp * tp_size
                    + tp
                    for cp in range(cp_size)
                ]
                g = dist.new_group(ranks)
                if rank in ranks:
                    groups["cp"] = g

    # PP and DP groups follow analogous patterns (omitted for brevity)

    return groups

This places TP on the fastest interconnect (intra-node NVLink), CP on the next tier (inter-node NVLink or InfiniBand), and PP/DP on the outermost communication rings — matching communication intensity to available bandwidth at each level.
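
As a concrete (hypothetical) example of the resulting layout, a 64-GPU job with 8 GPUs per node and DP = 2, PP = 2, CP = 2, TP = 8 would be wired up as follows; the group memberships in the comments follow directly from the indexing formula in create_parallel_groups:

# Sketch only: assumes 64 processes launched via torchrun across eight 8-GPU nodes.
dist.init_process_group(backend="nccl")

groups = create_parallel_groups(
    world_size=64, dp_size=2, tp_size=8, cp_size=2, pp_size=2,
)

# With the indexing above, rank 0 belongs to:
#   TP group {0, 1, ..., 7}   (consecutive ranks: intra-node NVLink)
#   CP group {0, 8}           (stride tp_size = 8: spans two nodes)
# and rank 63 to TP group {56, ..., 63} and CP group {55, 63}.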

