Context Parallelism and Ring Attention
1. Motivation: The Memory Wall of Long Sequences
1.1 Recap of Tensor Parallelism + Sequence Parallelism
With Tensor Parallelism (TP), model weight matrices are sharded across the GPUs of a TP group (column- or row-wise), at the cost of per-layer all-reduces on activations. Sequence Parallelism (SP) complements TP by splitting the activations along the sequence dimension in the regions outside the TP blocks (LayerNorm, Dropout), so those activations are sharded rather than replicated.
Under TP+SP, the per-GPU activation memory for a single transformer layer scales approximately as:

$$M_{\text{act}} \approx \frac{c \cdot s \cdot b \cdot h}{t}$$

where:

- $s$ = sequence length
- $b$ = micro-batch size
- $h$ = hidden dimension
- $t$ = tensor-parallel degree
- $c$ = a constant capturing the number of intermediate tensors retained per layer
1.2 The Residual Scaling Problem
Even with TP+SP and full activation recomputation (which incurs a substantial extra-compute cost to re-run the forward pass during backward), the per-GPU activation memory still grows linearly with the sequence length:

$$M_{\text{act}} \propto \frac{s \cdot b \cdot h}{t},$$

where $s$ is the full sequence length that every GPU in the TP group still processes end to end.
Core problem statement: TP+SP distributes weights and some activations, but every GPU still processes the full sequence inside the TP region (the attention and MLP blocks). For very long sequences, this remaining per-GPU memory footprint becomes the binding constraint.
2. Context Parallelism — Definition and Mechanism
2.1 Core Idea
Context Parallelism (CP) splits the input along the sequence dimension and applies this split across the entire model, including the TP regions (attention, MLP), not just the SP-only regions.
Given a context-parallel group of size $N_{CP}$, the sequence of $s$ tokens is divided into $N_{CP}$ chunks of $s / N_{CP}$ tokens each.

Each GPU $i$ in the group processes only its assigned chunk through the entire forward and backward pass; weights are replicated across the CP group (or sharded separately by TP), but activations exist only for the local tokens.
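As a minimal illustration (assuming a hidden-states tensor of shape [batch, seq, hidden] and a hypothetical cp_rank/cp_size pair), a contiguous split along the sequence dimension looks like the sketch below; Section 6 replaces the contiguous split with a zig-zag assignment.

import torch

def naive_cp_split(x: torch.Tensor, cp_size: int, cp_rank: int) -> torch.Tensor:
    """Contiguous split of hidden states [B, S, H] along the sequence axis."""
    B, S, H = x.shape
    assert S % cp_size == 0
    chunk = S // cp_size
    # Each CP rank keeps only its own slice of the sequence.
    return x[:, cp_rank * chunk : (cp_rank + 1) * chunk, :]

x = torch.randn(2, 16, 8)
local_chunks = [naive_cp_split(x, cp_size=4, cp_rank=r) for r in range(4)]
assert torch.cat(local_chunks, dim=1).equal(x)   # chunks tile the full sequence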
2.2 Impact on Different Module Types
| Module | Token Interaction Pattern | CP Communication Required? |
|---|---|---|
| LayerNorm | Per-token (independent) | No — each token normalizes independently |
| MLP / FFN | Per-token (independent) | No — pointwise or token-local computation |
| Dropout | Per-token (independent) | No |
| Self-Attention | All-to-all across sequence | Yes — each query must access all keys/values |
For token-independent modules (LayerNorm, MLP), splitting along the sequence dimension requires zero inter-GPU communication — each GPU applies the module to its local token subset identically to how it would process the full sequence. The weight matrices are not split (unlike TP), so no expensive all-reduce on activations is needed for these modules.
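A quick way to see this is to check that applying a token-local module to each chunk and concatenating gives the same result as applying it to the full sequence. A minimal sketch (single process, no real CP group; the layer choices are illustrative):

import torch
import torch.nn as nn

torch.manual_seed(0)
B, S, H, N_CP = 2, 16, 32, 4
x = torch.randn(B, S, H)

ln = nn.LayerNorm(H)
mlp = nn.Sequential(nn.Linear(H, 4 * H), nn.GELU(), nn.Linear(4 * H, H))

full = mlp(ln(x))                                   # full-sequence result
chunks = x.chunk(N_CP, dim=1)                       # contiguous CP-style split
sharded = torch.cat([mlp(ln(c)) for c in chunks], dim=1)

# Token-local modules commute with the sequence split: no communication needed.
assert torch.allclose(full, sharded, atol=1e-5)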
2.3 Gradient Synchronization
After the backward pass, each CP rank holds gradients computed from a different subset of the sequence. Since the weights are replicated across the CP group (analogous to Data Parallelism), an all-reduce over the CP group is required to aggregate gradients before the optimizer step:

$$g = \sum_{i=0}^{N_{CP}-1} g_i \quad \text{(optionally averaged over } N_{CP}\text{, depending on loss normalization)}$$
This all-reduce is identical in structure to the gradient synchronization in standard data parallelism.
2.4 Per-GPU Activation Memory Under CP
With CP, the per-GPU activation memory becomes:

$$M_{\text{act}} \approx \frac{c \cdot s \cdot b \cdot h}{t \cdot N_{CP}}$$

The sequence-length dependence is now reduced by a factor of $N_{CP}$: doubling the CP degree halves the activation memory per GPU, independent of the TP degree.
2.5 Relationship to Sequence Parallelism
| Property | Sequence Parallelism (SP) | Context Parallelism (CP) |
|---|---|---|
| Where applied | Non-TP regions only (LayerNorm, Dropout) | Entire model including TP regions |
| Split dimension | Sequence | Sequence |
| Weight sharding | Via TP (in TP regions) | Weights replicated across CP group |
| Gradient sync | Handled by TP all-reduce | Separate CP all-reduce |
| Communication for attention | Not applicable (attention is inside TP) | Required — Ring Attention |
3. The Attention Bottleneck Under Context Parallelism
3.1 Why Attention Is Special
The scaled dot-product attention for a single head is defined as:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

where:

- $Q \in \mathbb{R}^{s \times d_k}$ — queries
- $K \in \mathbb{R}^{s \times d_k}$ — keys
- $V \in \mathbb{R}^{s \times d_k}$ — values
- $d_k$ — head dimension

The attention score matrix $QK^\top$ has shape $s \times s$: every query interacts with every key, making attention the only transformer module in which tokens at different positions exchange information. Under CP, each GPU holds queries for only $s / N_{CP}$ tokens, but those queries still need the keys and values of all $s$ tokens.
3.2 Naive Communication Cost
A naive implementation would perform an all-gather of all $K$ and $V$ chunks before computing attention, so every GPU temporarily materializes the full-sequence keys and values. The per-GPU communication volume is approximately

$$2 \cdot \frac{N_{CP} - 1}{N_{CP}} \cdot s \cdot h_{kv} \ \text{elements},$$

where $h_{kv}$ is the per-token key/value width and the factor 2 accounts for $K$ and $V$. Beyond the volume itself, the problem is that all of this communication must finish before any attention computation can start, and the gathered tensors give back part of CP's memory savings.
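For concreteness, a minimal sketch of the naive approach (assuming a CP process group named cp_group, local tensors of shape [B, H, S_local, D], and PyTorch 2.x for scaled_dot_product_attention); this is the baseline that Ring Attention improves on:

import torch
import torch.distributed as dist
import torch.nn.functional as F

def naive_cp_attention(q, k, v, cp_group):
    """All-gather K/V across the CP group, then attend locally (no overlap)."""
    cp_size = dist.get_world_size(group=cp_group)
    k_list = [torch.empty_like(k) for _ in range(cp_size)]
    v_list = [torch.empty_like(v) for _ in range(cp_size)]
    dist.all_gather(k_list, k.contiguous(), group=cp_group)
    dist.all_gather(v_list, v.contiguous(), group=cp_group)
    # Concatenation in rank order assumes contiguous sequence partitioning;
    # causal masking over global positions is omitted for brevity.
    k_full = torch.cat(k_list, dim=2)   # [B, H, S_full, D]
    v_full = torch.cat(v_list, dim=2)
    # Local queries attend to the full key/value sequence.
    return F.scaled_dot_product_attention(q, k_full, v_full)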
4. Ring Attention — Efficient Communication for CP
4.1 Core Principle
Ring Attention organizes the $N_{CP}$ GPUs of the context-parallel group into a logical ring. Queries stay resident on their owning GPU; only $K$/$V$ chunks circulate. At every step, each GPU:

- Sends its current $K$/$V$ chunk to the next GPU in the ring (asynchronous, non-blocking).
- Computes partial attention using the $K$/$V$ chunk it currently holds in local memory.
- Receives a $K$/$V$ chunk from the previous GPU in the ring.
- Repeats for $N_{CP}$ steps total.
This overlaps communication with computation, hiding the latency of data transfer behind useful arithmetic.
4.2 Detailed Step-by-Step Execution
Consider $N_{CP} = 4$ GPUs $G_0 \ldots G_3$, each initially holding its own query, key, and value chunks $Q_i, K_i, V_i$ for its assigned tokens.

At time step $t$, GPU $i$ holds the $K$/$V$ chunk that originated on rank $(i - t) \bmod N_{CP}$; after $N_{CP}$ steps every GPU has seen every chunk exactly once.
Operations at each step
Step A (Communication — non-blocking): GPU $i$ posts an asynchronous send of its current $K$/$V$ chunk to GPU $(i+1) \bmod N_{CP}$ and an asynchronous receive from GPU $(i-1) \bmod N_{CP}$.

Step B (Compute — overlapped with Step A):

Compute the partial attention score block for the chunk currently held, $S_{i,j} = Q_i K_j^\top / \sqrt{d_k}$ (with $j$ the chunk's source rank), and fold its contribution into the running output using the online softmax update of Section 4.3.

Step C (Synchronization):

Wait for the receive of the next $K$/$V$ chunk to complete, swap the communication buffers, and proceed to the next step.

After all $N_{CP}$ steps, each GPU has accumulated the attention contributions of every $K$/$V$ chunk for its local queries; normalizing the accumulator yields the exact attention output.
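A single-process sanity check of this schedule (no GPUs or process groups involved; purely illustrative) confirms that every (query chunk, key chunk) pair is visited exactly once:

N_CP = 4
visited = set()
for step in range(N_CP):
    for gpu in range(N_CP):
        source = (gpu - step) % N_CP          # whose K/V chunk this GPU holds now
        assert (gpu, source) not in visited   # no pair is computed twice
        visited.add((gpu, source))
assert len(visited) == N_CP * N_CP            # all pairs covered after N_CP steps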
4.3 Online Softmax Aggregation
A critical subtlety: softmax is computed row-wise over the full key sequence, but each GPU sees only one chunk of keys at each step. A naive approach would require storing all per-chunk score blocks (or their exponentials) until the last chunk arrives, reintroducing the very memory cost CP is meant to remove.
Online softmax (also used in FlashAttention) resolves this. At each step $t$, every GPU maintains three running statistics per query row:

- $m$ — running row-wise maximum of logits
- $\ell$ — running row-wise sum of exponentiated logits (the softmax denominator)
- $O$ — running (unnormalized) weighted output accumulator

Update rules (per query row), with block statistics $m^{(t)}, \ell^{(t)}, O^{(t)}$ computed from the chunk processed at step $t$:

At step $t$, compute the new running maximum:

$$m_{\text{new}} = \max\!\left(m,\; m^{(t)}\right)$$

Rescale the running accumulators and fold in the new block:

$$\ell \leftarrow e^{m - m_{\text{new}}}\,\ell + e^{m^{(t)} - m_{\text{new}}}\,\ell^{(t)}, \qquad O \leftarrow e^{m - m_{\text{new}}}\,O + e^{m^{(t)} - m_{\text{new}}}\,O^{(t)}, \qquad m \leftarrow m_{\text{new}}$$

After all $N_{CP}$ steps, the final output is $O / \ell$, which equals the exact softmax-weighted sum over the full key sequence.
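The recurrence is easy to verify on a toy example (a minimal sketch: one query row, logits split into two chunks, plain PyTorch):

import math
import torch

torch.manual_seed(0)
logits = torch.randn(8)                  # one query row's scores over 8 keys
values = torch.randn(8, 4)               # matching values, head_dim = 4
exact = torch.softmax(logits, dim=0) @ values

m, l, O = float("-inf"), 0.0, torch.zeros(4)
for s_blk, v_blk in zip(logits.chunk(2), values.chunk(2)):   # two "ring steps"
    m_blk = s_blk.max().item()
    m_new = max(m, m_blk)
    alpha, beta = math.exp(m - m_new), math.exp(m_blk - m_new)
    p_blk = torch.exp(s_blk - m_blk)
    l = alpha * l + beta * p_blk.sum().item()      # rescale old denom, add new
    O = alpha * O + beta * (p_blk @ v_blk)         # rescale old output, add new
    m = m_new

assert torch.allclose(O / l, exact, atol=1e-6)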
4.4 Communication Volume Analysis
At each of the $N_{CP} - 1$ communication steps, every GPU sends one $K$/$V$ chunk of $\frac{s}{N_{CP}} \cdot h_{kv}$ elements per tensor and receives one of the same size.

Total communication volume per GPU:

$$V_{\text{comm}} = 2 \cdot \frac{N_{CP} - 1}{N_{CP}} \cdot s \cdot h_{kv} \ \text{elements (factor 2 for } K \text{ and } V\text{)}$$

This is the same total volume as the naive all-gather, but the latency is hidden because computation overlaps with communication. The effective wall-clock overhead approaches zero when:

$$T_{\text{compute}}^{(\text{step})} \;\ge\; T_{\text{comm}}^{(\text{step})},$$

i.e., the time to compute one partial attention block exceeds the time to transfer one $K$/$V$ chunk (quantified in Section 7.1 of the implementation part).
4.5 Relationship to FlashAttention
| Aspect | FlashAttention | Ring Attention (CP) |
|---|---|---|
| Scope | Single GPU, single attention kernel | Multi-GPU, distributed attention |
| Tiling axis | Tiles $K$/$V$ into SRAM-sized blocks within one GPU | Tiles $K$/$V$ into per-GPU chunks around the ring |
| Memory savings | Avoids materializing the full $s \times s$ score matrix in HBM | Avoids storing full-sequence activations per GPU |
| Online softmax | Yes — essential for block-wise tiling | Yes — essential for incremental aggregation |
| Complementary? | Yes — can be used within each GPU’s local attention block | Yes — each Ring Attention step can use FlashAttention for its local block computation |
5. The Causal Mask Imbalance Problem
5.1 Causal Attention Mask
For autoregressive (decoder-only) models, the attention mask enforces causality — token $i$ may attend to token $j$ only if $j \le i$:

$$M_{ij} = \begin{cases} 0 & \text{if } j \le i \\ -\infty & \text{if } j > i \end{cases}$$

The masked attention score matrix becomes:

$$S_{\text{masked}} = \frac{QK^\top}{\sqrt{d_k}} + M$$
This produces a lower-triangular structure, meaning earlier tokens attend to fewer keys than later tokens.
5.2 Load Imbalance Under Naive CP Partitioning
With naive sequential partitioning (GPU 0 gets tokens $0 \ldots \tfrac{s}{N_{CP}}-1$, GPU 1 the next contiguous block, and so on), the causal mask creates a severe load imbalance:

- GPU 0 holds the earliest tokens. Due to causal masking, these tokens attend only to themselves and need no $K$/$V$ from other GPUs. GPU 0 finishes almost immediately and idles.
- GPU $N_{CP}-1$ holds the latest tokens. These tokens attend to the entire preceding sequence, requiring $K$/$V$ from all other GPUs and performing the maximum amount of computation.
The number of non-masked score entries (and hence FLOPs) for GPU $i$ grows with its chunk index. With chunk size $c = s / N_{CP}$, the queries on GPU $i$ cover positions $ic \ldots (i+1)c - 1$, each attending to all preceding tokens, so

$$W_i = \sum_{p=ic}^{(i+1)c - 1} (p + 1) = i\,c^2 + \frac{c(c+1)}{2}.$$

For the first GPU ($i = 0$): $W_0 \approx \frac{c^2}{2}$.

For the last GPU ($i = N_{CP}-1$): $W_{N_{CP}-1} \approx \left(N_{CP} - \tfrac{1}{2}\right) c^2$.

The imbalance ratio is:

$$\frac{W_{N_{CP}-1}}{W_0} \approx 2 N_{CP} - 1.$$

For $N_{CP} = 8$, the last GPU performs roughly $15\times$ the attention work of the first, and every ring step is gated by the slowest rank.
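A quick numerical check of this imbalance under contiguous partitioning (illustrative only; compare with the balanced zig-zag sums in the verification snippet of the implementation part):

import torch

seq_len, cp_size = 16, 4
chunk = seq_len // cp_size
for rank in range(cp_size):
    pos = torch.arange(rank * chunk, (rank + 1) * chunk)       # contiguous block
    work = (pos + 1).sum().item()                               # causal non-masked entries
    print(f"Rank {rank}: positions={pos.tolist()}, causal_flops ∝ {work}")
# Rank 0 → 10, Rank 1 → 26, Rank 2 → 42, Rank 3 → 58  (≈ 6× imbalance for N_CP = 4)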
6. Zig-Zag Ring Attention — Balanced Computation
6.1 Token Assignment Strategy
Instead of contiguous blocks, Zig-Zag Attention assigns tokens to GPUs in an interleaved, folded pattern. For $N_{CP} = 4$ and $s = 16$, the assignment is rank 0 → {0, 7, 8, 15}, rank 1 → {1, 6, 9, 14}, rank 2 → {2, 5, 10, 13}, rank 3 → {3, 4, 11, 12}.

The pattern follows a zig-zag (or folded) ordering: within each "fold" of $N_{CP}$ consecutive tokens, the rank assignment runs $0, 1, \ldots, N_{CP}-1$ on even folds and is reversed on odd folds.

Formally, for the token at global position $p$, write $p = f \cdot N_{CP} + j$; the owning rank is $j$ if the fold index $f$ is even and $N_{CP} - 1 - j$ if $f$ is odd.
6.2 Why This Balances Computation
Each GPU now holds a mixture of early and late tokens. Under the causal mask, early tokens contribute few non-masked entries while late tokens contribute many. By pairing them together, each GPU's total non-masked computation sums to approximately:

$$W_r \approx \frac{1}{N_{CP}} \cdot \frac{s^2}{2} \quad \text{for every rank } r.$$
This achieves near-perfect load balancing across all GPUs.
6.3 Communication Implication
Under zig-zag assignment, every GPU still needs $K$/$V$ chunks from every other GPU (its queries span both early and late positions), so the ring communication schedule is unchanged. What changes is bookkeeping: the causal mask must be evaluated against global token positions rather than local chunk indices, since a chunk no longer corresponds to a contiguous range.
6.4 Distinction from Striped Attention
Striped Attention uses a simpler round-robin assignment: token $p$ is owned by rank $p \bmod N_{CP}$.
This also achieves load balancing but with a different scattering pattern. Zig-Zag achieves slightly better contiguity within each GPU’s token set (tokens come in pairs of consecutive indices from alternating ends), which can improve memory access patterns and cache utilization. The practical difference is minor but relevant for highly optimized implementations.
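A minimal sketch of the striped assignment (illustrative). On this toy length the per-rank causal sums are not exactly equal, but the relative spread shrinks as the sequence grows, so striping is also effectively balanced:

import torch

seq_len, cp_size = 16, 4
for rank in range(cp_size):
    pos = torch.arange(rank, seq_len, cp_size)    # round-robin: p % cp_size == rank
    print(f"Rank {rank}: positions={pos.tolist()}, causal_flops ∝ {(pos + 1).sum().item()}")
# Rank 0: [0, 4, 8, 12] → 28   Rank 1: [1, 5, 9, 13] → 32
# Rank 2: [2, 6, 10, 14] → 36  Rank 3: [3, 7, 11, 15] → 40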
7. Communication Strategies for Zig-Zag / Ring Attention
7.1 All-Gather Implementation
All GPUs simultaneously execute an all-gather collective on the $K$ and $V$ tensors, so each GPU materializes the full-sequence keys and values before computing its local attention block.
Properties:
- Communication pattern: Single collective operation; all GPUs receive all data.
- Temporary memory per GPU: Must store the full-sequence $K$ and $V$:

$$M_{\text{tmp}} = 2 \cdot s \cdot h_{kv} \ \text{elements}$$

This negates the memory savings of CP for the key/value tensors (queries and other activations remain sharded).
- Latency: Dominated by a single all-gather; bandwidth-optimal but high peak memory.
7.2 All-to-All (Ring) Implementation
GPUs exchange $K$/$V$ chunks pairwise around the ring, one chunk per step, exactly as in Ring Attention (Section 4).
Properties:
- Communication pattern: $N_{CP} - 1$ point-to-point send/receive steps in a ring.
- Temporary memory per GPU: Only one additional $K$/$V$ chunk at a time:

$$M_{\text{tmp}} = 2 \cdot \frac{s}{N_{CP}} \cdot h_{kv} \ \text{elements}$$

A factor of $N_{CP}$ less temporary memory than the all-gather variant.

- Latency: $N_{CP} - 1$ communication steps, each with startup latency, but each overlapped with computation.
7.3 Comparison Summary
| Property | All-Gather | All-to-All (Ring) |
|---|---|---|
| Temporary $K$/$V$ memory | $2 \cdot s \cdot h_{kv}$ (full sequence) | $2 \cdot \frac{s}{N_{CP}} \cdot h_{kv}$ (one chunk) |
| Communication steps | 1 collective | $N_{CP} - 1$ point-to-point steps |
| Overlap potential | Limited (compute starts after gather) | High (compute overlapped with send/recv) |
| Implementation complexity | Low | Moderate (ring scheduling + online softmax) |
| Best suited for | Moderate $s$ and small CP groups | Large $s$ and memory-constrained settings |
8. Combined Parallelism Dimensions — Where CP Fits
At this point in the parallelism hierarchy:
| Parallelism | What It Shards | Communication Cost | Scaling Regime |
|---|---|---|---|
| Data Parallelism (DP) | Data (batches) | All-reduce on gradients | Scales across nodes |
| Tensor Parallelism (TP) | Weight matrices (columns/rows) | All-reduce on activations per layer | Intra-node (high bandwidth required) |
| Sequence Parallelism (SP) | Activations in non-TP regions | Scatter/gather at TP↔SP boundaries | Tied to TP group |
| Context Parallelism (CP) | Sequence across entire model | Ring attention + gradient all-reduce | Intra- or inter-node |
| Pipeline Parallelism (PP) | Layers across stages | Point-to-point activation transfers | Scales across nodes |
The total number of GPUs used is:

$$N_{\text{GPUs}} = N_{DP} \times N_{PP} \times N_{CP} \times N_{TP}$$
CP addresses a specific gap: TP+SP reduce per-GPU memory for both weights and activations, but the sequence length still appears unsharded inside the TP region's attention computation. CP eliminates this last remaining sequence-length bottleneck, enabling training on sequences of hundreds of thousands of tokens and beyond, with per-GPU cost growing as $s / N_{CP}$.
Limitation of TP: TP requires high-bandwidth interconnects (e.g., NVLink within a node) because it communicates activations at every layer. It does not scale well across nodes with lower-bandwidth interconnects. When model weights exceed the memory of a single node even after TP sharding, Pipeline Parallelism (PP) becomes necessary — partitioning the model’s layers across nodes with only point-to-point activation transfers at stage boundaries.
Context Parallelism: SOTA Implementation with PyTorch + Triton
1. System Architecture Overview
The implementation consists of six composable modules, each addressing a distinct concern:
┌─────────────────────────────────────────────────────────────┐
│ ContextParallelAttention (nn.Module) │
│ ┌───────────────────────────────────────────────────────┐ │
│ │ RingAttentionFunction (autograd.Function) │ │
│ │ ┌──────────┐ ┌──────────────┐ ┌────────────────┐ │ │
│ │ │ ZigZag │ │ Ring │ │ Triton Kernel: │ │ │
│ │ │Partition │──│ Communicator │──│ Fused Attn + │ │ │
│ │ │ │ │ (send/recv) │ │ Online Softmax │ │ │
│ │ └──────────┘ └──────────────┘ └────────────────┘ │ │
│ └───────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────┘
Each GPU in the CP group of size $N_{CP}$ holds local Q/K/V/O tensors of shape $[B, H, S_{\text{local}}, D]$,

where $S_{\text{local}} = S / N_{CP}$, $B$ is the micro-batch size, $H$ the number of attention heads, and $D$ the head dimension.
2. Zig-Zag Sequence Partitioning
2.1 Mathematical Formulation
For global token index $p$, let $f = \lfloor p / N_{CP} \rfloor$ be its fold and $j = p \bmod N_{CP}$ its offset within the fold. The owning rank is

$$\text{rank}(p) = \begin{cases} j & \text{if } f \text{ is even} \\ N_{CP} - 1 - j & \text{if } f \text{ is odd.} \end{cases}$$

This ensures each GPU holds a balanced mixture of early and late tokens, equalizing the causal-mask workload:

$$\sum_{p \,\in\, \text{rank } r} (p + 1) \;\approx\; \frac{1}{N_{CP}} \cdot \frac{s(s+1)}{2} \quad \text{for every rank } r.$$
2.2 Implementation
import torch
import torch.distributed as dist
from typing import List, Tuple
def zigzag_partition(
seq_len: int,
cp_size: int,
cp_rank: int,
) -> torch.Tensor:
"""
Compute global position indices assigned to `cp_rank` under zig-zag ordering.
Args:
seq_len: Total sequence length (must be divisible by cp_size).
cp_size: Context-parallel world size (N_CP).
cp_rank: This GPU's rank in the CP group.
Returns:
positions: LongTensor of shape [seq_len // cp_size] containing
the global token indices owned by this rank.
Example (seq_len=16, cp_size=4):
rank 0 → [ 0, 7, 8, 15]
rank 1 → [ 1, 6, 9, 14]
rank 2 → [ 2, 5, 10, 13]
rank 3 → [ 3, 4, 11, 12]
"""
assert seq_len % cp_size == 0, (
f"seq_len ({seq_len}) must be divisible by cp_size ({cp_size})"
)
indices = torch.arange(seq_len, dtype=torch.long)
# Reshape into folds of size cp_size: [num_folds, cp_size]
folds = indices.reshape(-1, cp_size)
# Reverse odd-numbered folds to create zig-zag pattern
folds[1::2] = folds[1::2].flip(dims=[1])
# Column `cp_rank` across all folds gives this rank's tokens
positions = folds[:, cp_rank].contiguous()
return positions
def zigzag_split(
x: torch.Tensor,
cp_size: int,
cp_rank: int,
seq_dim: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Split a full-sequence tensor according to zig-zag assignment.
Args:
x: Input tensor with a sequence dimension.
cp_size: Context-parallel world size.
cp_rank: This GPU's CP rank.
seq_dim: Which dimension is the sequence axis (default: 2).
Returns:
x_local: The subsequence assigned to this rank.
positions: The global position indices (for causal masking).
"""
seq_len = x.shape[seq_dim]
positions = zigzag_partition(seq_len, cp_size, cp_rank).to(x.device)
x_local = x.index_select(seq_dim, positions)
return x_local, positions
2.3 Verification
# Quick sanity check: every position appears exactly once across all ranks
all_positions = torch.cat([
zigzag_partition(16, 4, r) for r in range(4)
])
assert all_positions.sort()[0].equal(torch.arange(16))
# Verify balanced causal workload:
# FLOPs ~ sum of positions (each token attends to all preceding tokens)
for r in range(4):
pos = zigzag_partition(16, 4, r)
print(f"Rank {r}: positions={pos.tolist()}, "
f"causal_flops ∝ {pos.sum().item()}")
# Output:
# Rank 0: positions=[0, 7, 8, 15], causal_flops ∝ 30
# Rank 1: positions=[1, 6, 9, 14], causal_flops ∝ 30
# Rank 2: positions=[2, 5, 10, 13], causal_flops ∝ 30
# Rank 3: positions=[3, 4, 11, 12], causal_flops ∝ 30
3. Ring Communication Layer
3.1 Ring Topology
The $N_{CP}$ GPUs form a logical ring: each rank sends its chunk to rank $(r+1) \bmod N_{CP}$ and receives from rank $(r-1) \bmod N_{CP}$. K/V tensors are double-buffered: while attention is computed on the chunk in kv_current, the next chunk is received into kv_next.
class RingCommunicator:
"""
Manages asynchronous ring send/recv of K/V tensors
with double buffering for compute-communication overlap.
"""
def __init__(self, process_group: dist.ProcessGroup):
self.group = process_group
self.cp_size = dist.get_world_size(group=process_group)
self.cp_rank = dist.get_rank(group=process_group)
# Ring neighbors
self.send_rank = (self.cp_rank + 1) % self.cp_size
self.recv_rank = (self.cp_rank - 1) % self.cp_size
def _get_global_rank(self, group_rank: int) -> int:
"""Convert group-local rank to global rank."""
return dist.get_global_rank(self.group, group_rank)
def send_recv_kv(
self,
k_send: torch.Tensor,
v_send: torch.Tensor,
k_recv: torch.Tensor,
v_recv: torch.Tensor,
) -> List[dist.Work]:
"""
Initiate non-blocking ring send/recv for K and V tensors.
Args:
k_send, v_send: Tensors to send to next rank.
k_recv, v_recv: Pre-allocated buffers to receive from prev rank.
Returns:
List of async work handles (call .wait() to synchronize).
"""
        # Batch all four ops into a single grouped call. With the NCCL backend,
        # issuing separate isend/irecv calls on one stream can deadlock when every
        # rank enqueues its send before its recv; batch_isend_irecv groups them so
        # sends and receives are matched without an ordering hazard.
        send_peer = self._get_global_rank(self.send_rank)
        recv_peer = self._get_global_rank(self.recv_rank)
        p2p_ops = [
            dist.P2POp(dist.isend, k_send, peer=send_peer, group=self.group),
            dist.P2POp(dist.isend, v_send, peer=send_peer, group=self.group),
            dist.P2POp(dist.irecv, k_recv, peer=recv_peer, group=self.group),
            dist.P2POp(dist.irecv, v_recv, peer=recv_peer, group=self.group),
        ]
        return dist.batch_isend_irecv(p2p_ops)
@staticmethod
def wait_all(handles: List[dist.Work]):
"""Block until all async operations complete."""
for h in handles:
h.wait()
4. Online Softmax Merge — Mathematical Foundation
4.1 Problem Statement
At each ring step $t$, a GPU produces an unnormalized partial output $O^{(t)}$, a block row-max $m^{(t)}$, and a block denominator $\ell^{(t)}$ for its queries against one $K$/$V$ chunk. These per-chunk statistics must be combined across steps without ever holding logits for more than one chunk at a time.
4.2 Derivation
Let the full attention for query row $i$ be

$$O_i = \frac{\sum_j e^{s_{ij}}\, v_j}{\sum_j e^{s_{ij}}},$$

where $s_{ij} = q_i^\top k_j / \sqrt{d_k}$. For stability, track a running row max $m$ and keep unnormalized quantities relative to it: $\tilde O = \sum_j e^{s_{ij} - m} v_j$ and $\ell = \sum_j e^{s_{ij} - m}$.

Partition the key positions into two disjoint sets $A$ and $B$ (the chunks already processed and the newly arrived chunk), each with its own statistics $(\tilde O_A, m_A, \ell_A)$ and $(\tilde O_B, m_B, \ell_B)$.

The merge of sets $A$ and $B$ uses $m_{AB} = \max(m_A, m_B)$:

$$\ell_{AB} = e^{m_A - m_{AB}}\,\ell_A + e^{m_B - m_{AB}}\,\ell_B, \qquad \tilde O_{AB} = e^{m_A - m_{AB}}\,\tilde O_A + e^{m_B - m_{AB}}\,\tilde O_B.$$

Final normalized output after all ring steps:

$$O_i = \frac{\tilde O}{\ell}.$$
This is numerically exact — not an approximation.
4.3 PyTorch Implementation
def online_softmax_merge(
O_acc: torch.Tensor, # [B, H, S_local, D] — unnormalized accumulated output
m_acc: torch.Tensor, # [B, H, S_local] — running max
l_acc: torch.Tensor, # [B, H, S_local] — running sum of exp
O_new: torch.Tensor, # [B, H, S_local, D] — new block's unnormalized output
m_new: torch.Tensor, # [B, H, S_local] — new block's max
l_new: torch.Tensor, # [B, H, S_local] — new block's sum of exp
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Merge two sets of online softmax statistics.
All inputs are in float32 for numerical stability.
"""
m_merged = torch.maximum(m_acc, m_new) # [B, H, S_local]
alpha = torch.exp(m_acc - m_merged) # rescale factor for old
beta = torch.exp(m_new - m_merged) # rescale factor for new
l_merged = alpha * l_acc + beta * l_new # [B, H, S_local]
O_merged = (
alpha.unsqueeze(-1) * O_acc
+ beta.unsqueeze(-1) * O_new
) # [B, H, S_local, D]
return O_merged, m_merged, l_merged
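A quick correctness check of this helper against a direct softmax over the concatenated chunks (single process, random inputs; shapes follow the function's docstring):

import torch

def _test_online_softmax_merge():
    torch.manual_seed(0)
    B, H, S, D = 2, 3, 5, 8
    q = torch.randn(B, H, S, D)
    k1, v1 = torch.randn(B, H, S, D), torch.randn(B, H, S, D)   # chunk A
    k2, v2 = torch.randn(B, H, S, D), torch.randn(B, H, S, D)   # chunk B

    def block_stats(k, v):
        s = q @ k.transpose(-2, -1) / D ** 0.5
        m = s.max(dim=-1).values
        p = torch.exp(s - m.unsqueeze(-1))
        return p @ v, m, p.sum(dim=-1)          # unnormalized O, row max, denom

    O, m, l = online_softmax_merge(*block_stats(k1, v1), *block_stats(k2, v2))
    merged = O / l.unsqueeze(-1)

    s_full = q @ torch.cat([k1, k2], dim=-2).transpose(-2, -1) / D ** 0.5
    exact = torch.softmax(s_full, dim=-1) @ torch.cat([v1, v2], dim=-2)
    assert torch.allclose(merged, exact, atol=1e-5)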
5. PyTorch Reference: Ring Attention Forward
This is a complete, correct, pure-PyTorch implementation suitable for validating the Triton version. Each step corresponds to one rotation of the $K$/$V$ chunks around the ring.
def ring_attention_forward_reference(
q: torch.Tensor, # [B, H, S_local, D] — local queries
k: torch.Tensor, # [B, H, S_local, D] — local keys
v: torch.Tensor, # [B, H, S_local, D] — local values
q_positions: torch.Tensor, # [S_local] — global positions of local queries
all_positions: List[torch.Tensor], # list of [S_local] per CP rank
comm: RingCommunicator,
causal: bool = True,
) -> torch.Tensor:
"""
Reference ring attention forward pass (no Triton, no overlap).
Returns:
O: [B, H, S_local, D] — attention output for local queries.
"""
B, H, S_local, D = q.shape
device = q.device
scale = D ** -0.5
# ---- Initialize online softmax accumulators ----
O_acc = torch.zeros(B, H, S_local, D, device=device, dtype=torch.float32)
m_acc = torch.full((B, H, S_local,), float("-inf"), device=device, dtype=torch.float32)
l_acc = torch.zeros(B, H, S_local, device=device, dtype=torch.float32)
# ---- Double buffers for K/V ----
k_cur, v_cur = k.clone(), v.clone()
k_buf, v_buf = torch.empty_like(k), torch.empty_like(v)
for step in range(comm.cp_size):
# -- Step A: initiate async send/recv (except last step) --
if step < comm.cp_size - 1:
handles = comm.send_recv_kv(k_cur, v_cur, k_buf, v_buf)
# -- Step B: compute partial attention for current K/V chunk --
# Determine which rank's K/V we currently hold
source_rank = (comm.cp_rank - step) % comm.cp_size
k_positions = all_positions[source_rank].to(device) # [S_local]
# Attention scores: [B, H, S_local, S_local]
S = torch.matmul(q.float(), k_cur.float().transpose(-2, -1)) * scale
# Causal mask based on global positions
if causal:
# q_positions: [S_local], k_positions: [S_local]
causal_mask = q_positions.unsqueeze(-1) >= k_positions.unsqueeze(-2)
# Expand to [1, 1, S_local, S_local]
causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
S = S.masked_fill(~causal_mask, float("-inf"))
# Block-local online softmax statistics
        m_block = S.max(dim=-1).values               # [B, H, S_local]
        # Guard against all-masked rows (m_block = -inf): exponentiating against
        # -inf yields NaN, so use 0.0 as the reference for those rows instead.
        # They then get P = 0 and l_block = 0 and contribute nothing in the merge.
        m_safe = torch.where(torch.isfinite(m_block), m_block, torch.zeros_like(m_block))
        P = torch.exp(S - m_safe.unsqueeze(-1))      # [B, H, S_local, S_local]
l_block = P.sum(dim=-1) # [B, H, S_local]
O_block = torch.matmul(P, v_cur.float()) # [B, H, S_local, D]
# O_block is unnormalized: O_block = sum_j exp(s_j - m_block) * v_j
# Merge with running accumulator
O_acc, m_acc, l_acc = online_softmax_merge(
O_acc, m_acc, l_acc,
O_block, m_block, l_block,
)
# -- Step C: wait for communication, swap buffers --
if step < comm.cp_size - 1:
RingCommunicator.wait_all(handles)
k_cur, k_buf = k_buf, k_cur
v_cur, v_buf = v_buf, v_cur
# ---- Final normalization ----
O = O_acc / l_acc.unsqueeze(-1).clamp(min=1e-6)
return O.to(q.dtype)
Complexity per GPU:
| Metric | Value |
|---|---|
| Compute (FLOPs) | $\approx 4\,B\,H\,S_{\text{local}}\,S\,D$ per layer (local queries against all keys, summed over ring steps) |
| Memory (activations) | $O(B\,H\,S_{\text{local}}\,D)$ for local Q/K/V/O plus one in-flight K/V chunk buffer |
| Communication (volume) | $2\,(N_{CP}-1)\,B\,H\,S_{\text{local}}\,D$ elements sent (and the same received) for K and V |
6. Triton Kernel: Fused Block Attention with Online Softmax Merge
6.1 Kernel Design
This kernel processes one ring step: it computes attention of the local queries against the K/V chunk currently held, tiling over keys in blocks of BLOCK_N and maintaining online softmax statistics internally. At the end, it merges the result with the inter-step accumulators (O_acc, m_acc, l_acc) that live in global memory.

Grid: (B·H, ⌈S_Q / BLOCK_M⌉).

Each program instance handles BLOCK_M query rows and loops over all S_K key columns of the chunk.
import triton
import triton.language as tl
@triton.jit
def _ring_attn_fwd_kernel(
# ---- Tensor pointers ----
Q_ptr, K_ptr, V_ptr,
O_acc_ptr, M_acc_ptr, L_acc_ptr,
Q_pos_ptr, K_pos_ptr,
# ---- Strides (elements, not bytes) ----
stride_q_bh, stride_q_s, stride_q_d, # Q: [BH, S_Q, D]
stride_k_bh, stride_k_s, stride_k_d, # K: [BH, S_K, D]
stride_v_bh, stride_v_s, stride_v_d, # V: [BH, S_K, D]
stride_o_bh, stride_o_s, stride_o_d, # O_acc: [BH, S_Q, D]
# ---- Dimensions ----
S_Q, # local query seq length
S_K, # current K/V chunk length
# ---- Compile-time constants ----
D: tl.constexpr,
BLOCK_M: tl.constexpr, # tile size along queries
BLOCK_N: tl.constexpr, # tile size along keys
CAUSAL: tl.constexpr,
):
"""
Fused ring-attention step kernel.
Computes partial attention of Q against one K/V chunk,
then merges with the inter-step accumulator in-place.
Numerically equivalent to:
S = Q @ K^T / sqrt(D)
if CAUSAL: S = masked_fill(S, q_pos < k_pos, -inf)
m_block, l_block = rowmax(S), rowsum(exp(S - m_block))
O_block = exp(S - m_block) @ V (unnormalized)
(O_acc, m_acc, l_acc) = merge(O_acc, m_acc, l_acc,
O_block, m_block, l_block)
"""
# ================================================================
# Program ID → (batch_head, query_block)
# ================================================================
pid_bh = tl.program_id(0)
pid_m = tl.program_id(1)
# ================================================================
# Offset calculations
# ================================================================
off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # [BLOCK_M]
off_d = tl.arange(0, D) # [D]
mask_m = off_m < S_Q
# Base pointers for this batch-head slice
Q_bh = Q_ptr + pid_bh * stride_q_bh
K_bh = K_ptr + pid_bh * stride_k_bh
V_bh = V_ptr + pid_bh * stride_v_bh
O_bh = O_acc_ptr + pid_bh * stride_o_bh
# ================================================================
# Load Q block: [BLOCK_M, D]
# ================================================================
q = tl.load(
Q_bh + off_m[:, None] * stride_q_s + off_d[None, :] * stride_q_d,
mask=mask_m[:, None],
other=0.0,
).to(tl.float32)
# Load query global positions for causal masking
q_pos = tl.load(Q_pos_ptr + off_m, mask=mask_m, other=0).to(tl.int64)
# ================================================================
# Initialize block-local accumulators (online softmax within chunk)
# ================================================================
m_i = tl.full([BLOCK_M], value=float("-inf"), dtype=tl.float32)
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
o_i = tl.zeros([BLOCK_M, D], dtype=tl.float32)
    # D is a compile-time constant, so the scale can be computed with plain
    # Python arithmetic (tl.sqrt operates on tensors, not constexpr ints).
    scale = 1.0 / (D ** 0.5)
# ================================================================
# Inner loop: tile over K/V in blocks of BLOCK_N
# ================================================================
for start_n in range(0, S_K, BLOCK_N):
off_n = start_n + tl.arange(0, BLOCK_N) # [BLOCK_N]
mask_n = off_n < S_K
# Load K tile: [BLOCK_N, D]
k = tl.load(
K_bh + off_n[:, None] * stride_k_s + off_d[None, :] * stride_k_d,
mask=mask_n[:, None],
other=0.0,
).to(tl.float32)
# S = Q @ K^T * scale : [BLOCK_M, BLOCK_N]
s = tl.dot(q, tl.trans(k)) * scale
# ---- Apply causal mask ----
if CAUSAL:
k_pos = tl.load(
K_pos_ptr + off_n, mask=mask_n, other=2147483647
).to(tl.int64)
causal_mask = q_pos[:, None] >= k_pos[None, :]
s = tl.where(causal_mask, s, float("-inf"))
# Mask out-of-bounds keys
s = tl.where(mask_n[None, :], s, float("-inf"))
# ---- Online softmax update (within-chunk) ----
        m_ij = tl.max(s, axis=1)                       # [BLOCK_M]
        m_new = tl.maximum(m_i, m_ij)
        # Guard fully-masked rows (m_new == -inf): exponentiating against -inf
        # would produce NaN, so use 0.0 as the exponent reference for those rows;
        # their p and alpha become exp(-inf) = 0 and they contribute nothing.
        m_exp = tl.where(m_new == float("-inf"), 0.0, m_new)
        # Rescale previous accumulator
        alpha = tl.exp(m_i - m_exp)
        # Compute exp(s - m_exp) for current tile
        p = tl.exp(s - m_exp[:, None])                 # [BLOCK_M, BLOCK_N]
        l_i = alpha * l_i + tl.sum(p, axis=1)
        o_i = o_i * alpha[:, None]
# Load V tile: [BLOCK_N, D]
v = tl.load(
V_bh + off_n[:, None] * stride_v_s + off_d[None, :] * stride_v_d,
mask=mask_n[:, None],
other=0.0,
).to(tl.float32)
# Accumulate: O += P @ V
o_i += tl.dot(p.to(tl.float32), v)
m_i = m_new
# ================================================================
# Merge with inter-step accumulator (across ring steps)
# ================================================================
# Load previous accumulator state
o_acc = tl.load(
O_bh + off_m[:, None] * stride_o_s + off_d[None, :] * stride_o_d,
mask=mask_m[:, None],
other=0.0,
).to(tl.float32)
m_acc = tl.load(
M_acc_ptr + pid_bh * S_Q + off_m,
mask=mask_m,
other=float("-inf"),
).to(tl.float32)
l_acc = tl.load(
L_acc_ptr + pid_bh * S_Q + off_m,
mask=mask_m,
other=0.0,
).to(tl.float32)
# Merge formula:
# m_merged = max(m_acc, m_i)
# l_merged = l_acc * exp(m_acc - m_merged) + l_i * exp(m_i - m_merged)
# O_merged = O_acc * exp(m_acc - m_merged) + O_i * exp(m_i - m_merged)
m_merged = tl.maximum(m_acc, m_i)
alpha_acc = tl.exp(m_acc - m_merged)
alpha_new = tl.exp(m_i - m_merged)
l_merged = l_acc * alpha_acc + l_i * alpha_new
o_merged = o_acc * alpha_acc[:, None] + o_i * alpha_new[:, None]
# ================================================================
# Store merged state back to global memory
# ================================================================
tl.store(
O_bh + off_m[:, None] * stride_o_s + off_d[None, :] * stride_o_d,
o_merged,
mask=mask_m[:, None],
)
tl.store(
M_acc_ptr + pid_bh * S_Q + off_m,
m_merged,
mask=mask_m,
)
tl.store(
L_acc_ptr + pid_bh * S_Q + off_m,
l_merged,
mask=mask_m,
)
6.2 Kernel Launch Wrapper
def triton_ring_attn_step(
q: torch.Tensor, # [B*H, S_Q, D]
k: torch.Tensor, # [B*H, S_K, D]
v: torch.Tensor, # [B*H, S_K, D]
o_acc: torch.Tensor, # [B*H, S_Q, D] (float32, in-place)
m_acc: torch.Tensor, # [B*H, S_Q] (float32, in-place)
l_acc: torch.Tensor, # [B*H, S_Q] (float32, in-place)
q_positions: torch.Tensor, # [S_Q] (int64)
k_positions: torch.Tensor, # [S_K] (int64)
causal: bool = True,
BLOCK_M: int = 128,
BLOCK_N: int = 64,
):
"""
Launch the Triton kernel for one ring attention step.
Computes partial attention of `q` against `k, v` and merges
into accumulators `o_acc, m_acc, l_acc` in-place.
"""
BH, S_Q, D = q.shape
_, S_K, _ = k.shape
# Ensure D is power of 2 (required by tl.arange)
assert D & (D - 1) == 0, f"Head dim D={D} must be a power of 2"
grid = (BH, triton.cdiv(S_Q, BLOCK_M))
_ring_attn_fwd_kernel[grid](
# Tensor pointers
q, k, v,
o_acc, m_acc, l_acc,
q_positions, k_positions,
# Strides
q.stride(0), q.stride(1), q.stride(2),
k.stride(0), k.stride(1), k.stride(2),
v.stride(0), v.stride(1), v.stride(2),
o_acc.stride(0), o_acc.stride(1), o_acc.stride(2),
# Dimensions
S_Q, S_K,
# Compile-time constants
D=D,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
CAUSAL=causal,
)
7. Ring Attention Forward — Triton Orchestration
This function combines the ring communication (Section 3) with the Triton kernel (Section 6), implementing the full overlapped pipeline.
def ring_attention_forward_triton(
q: torch.Tensor, # [B, H, S_local, D]
k: torch.Tensor, # [B, H, S_local, D]
v: torch.Tensor, # [B, H, S_local, D]
q_positions: torch.Tensor, # [S_local]
all_positions: List[torch.Tensor],
comm: RingCommunicator,
causal: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Triton-accelerated ring attention forward with compute-communication overlap.
Returns:
O: [B, H, S_local, D] — normalized attention output
m_acc: [B*H, S_local] — final row-wise max (saved for backward)
l_acc: [B*H, S_local] — final row-wise lse (saved for backward)
"""
B, H, S_local, D = q.shape
BH = B * H
device = q.device
# Flatten batch and heads for the kernel: [BH, S_local, D]
q_3d = q.reshape(BH, S_local, D).contiguous()
# ---- Initialize accumulators in float32 ----
o_acc = torch.zeros(BH, S_local, D, device=device, dtype=torch.float32)
m_acc = torch.full((BH, S_local), float("-inf"), device=device, dtype=torch.float32)
l_acc = torch.zeros(BH, S_local, device=device, dtype=torch.float32)
# ---- Double-buffered K/V ----
k_cur = k.reshape(BH, S_local, D).contiguous()
v_cur = v.reshape(BH, S_local, D).contiguous()
k_buf = torch.empty_like(k_cur)
v_buf = torch.empty_like(v_cur)
q_pos = q_positions.to(device)
for step in range(comm.cp_size):
# ---- (A) Start async communication (non-blocking) ----
if step < comm.cp_size - 1:
handles = comm.send_recv_kv(k_cur, v_cur, k_buf, v_buf)
# ---- (B) Compute: Triton kernel for this K/V chunk ----
source_rank = (comm.cp_rank - step) % comm.cp_size
k_pos = all_positions[source_rank].to(device)
triton_ring_attn_step(
q_3d, k_cur, v_cur,
o_acc, m_acc, l_acc,
q_pos, k_pos,
causal=causal,
)
# ---- (C) Wait for communication, swap buffers ----
if step < comm.cp_size - 1:
RingCommunicator.wait_all(handles)
k_cur, k_buf = k_buf, k_cur
v_cur, v_buf = v_buf, v_cur
# ---- Final normalization: O = O_unnorm / l ----
O = o_acc / l_acc.unsqueeze(-1).clamp(min=1e-6)
O = O.reshape(B, H, S_local, D).to(q.dtype)
return O, m_acc, l_acc
7.1 Timing Overlap Condition
Effective overlap requires the compute time of one ring step to exceed the communication time:

$$\frac{2\,B\,H\,S_{\text{local}}^2\,D}{P_{\text{FLOPs}}} \;\ge\; \frac{2\,B\,H\,S_{\text{local}}\,D \cdot b_{\text{dtype}}}{BW}$$

Simplifying:

$$S_{\text{local}} \;\ge\; \frac{P_{\text{FLOPs}} \cdot b_{\text{dtype}}}{BW}$$

For an A100 (312 TFLOPS bf16) with NVLink (600 GB/s) and bf16 transfers (2 bytes/element), this gives $S_{\text{local}} \gtrsim 1{,}000$ tokens per GPU — a threshold comfortably exceeded in any long-context training run.
8. Backward Pass
8.1 Mathematical Derivation
Given the forward computation at ring step $t$, $S = QK^\top/\sqrt{d_k}$, $P = \text{softmax}(S)$, $O = PV$, and defining the row-wise quantity $\Delta = \text{rowsum}(dO \odot O)$:

The gradients are:

$$dV = P^\top dO, \qquad dP = dO\,V^\top, \qquad dS = P \odot (dP - \Delta), \qquad dQ = \frac{dS\,K}{\sqrt{d_k}}, \qquad dK = \frac{dS^\top Q}{\sqrt{d_k}}$$
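These closed-form gradients can be checked against autograd on a small unmasked block (a standalone sketch; the sizes are arbitrary):

import torch

torch.manual_seed(0)
S_len, D = 8, 16
Q = torch.randn(S_len, D, requires_grad=True)
K = torch.randn(S_len, D, requires_grad=True)
V = torch.randn(S_len, D, requires_grad=True)

S = Q @ K.T / D ** 0.5
P = S.softmax(dim=-1)
O = P @ V
dO = torch.randn_like(O)
O.backward(dO)                                             # autograd reference

delta = (dO * O.detach()).sum(dim=-1, keepdim=True)        # rowsum(dO * O)
dP = dO @ V.detach().T
dS = P.detach() * (dP - delta)
dQ = dS @ K.detach() / D ** 0.5
dK = dS.T @ Q.detach() / D ** 0.5
dV = P.detach().T @ dO

assert torch.allclose(dQ, Q.grad, atol=1e-5)
assert torch.allclose(dK, K.grad, atol=1e-5)
assert torch.allclose(dV, V.grad, atol=1e-5)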
8.2 Ring Backward Algorithm
The backward pass mirrors the forward ring structure:

- $K$/$V$ chunks rotate in the same direction as in the forward pass.
- At each step, $S$ and $P$ for the current chunk are recomputed using the saved statistics $m$ and $\ell$ (this avoids storing the full attention matrix).
- $dQ$ accumulates locally, since each GPU's queries never move.
- $dK$/$dV$ partial gradients are accumulated into buffers that rotate along with the corresponding $K$/$V$ chunk, so after a final rotation each gradient buffer lands back on the GPU that owns that chunk.
8.3 PyTorch Reference Implementation
def ring_attention_backward_reference(
dO: torch.Tensor, # [B, H, S_local, D] — upstream gradient
q: torch.Tensor, # [B, H, S_local, D] — saved from forward
k: torch.Tensor, # [B, H, S_local, D] — saved from forward
v: torch.Tensor, # [B, H, S_local, D] — saved from forward
O: torch.Tensor, # [B, H, S_local, D] — saved output
m_final: torch.Tensor, # [B*H, S_local] — saved from forward
l_final: torch.Tensor, # [B*H, S_local] — saved from forward
q_positions: torch.Tensor,
all_positions: List[torch.Tensor],
comm: RingCommunicator,
causal: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Backward pass for ring attention.
Returns dQ, dK, dV each of shape [B, H, S_local, D].
"""
B, H, S_local, D = q.shape
device = q.device
scale = D ** -0.5
dQ = torch.zeros_like(q, dtype=torch.float32)
# Pre-compute delta = rowsum(dO * O): [B, H, S_local]
delta = (dO.float() * O.float()).sum(dim=-1)
# Reshape m_final, l_final to [B, H, S_local]
m_f = m_final.reshape(B, H, S_local)
l_f = l_final.reshape(B, H, S_local)
# ---- K/V double buffer (same rotation as forward) ----
k_cur, v_cur = k.clone(), v.clone()
k_buf, v_buf = torch.empty_like(k), torch.empty_like(v)
    # ---- dK/dV accumulators (co-rotate with K/V) ----
    # At step t, we compute dK/dV contributions for the K/V chunk currently held
    # and add them to dk_acc/dv_acc. These buffers travel around the ring in the
    # SAME direction as K/V so that each chunk's partial gradients stay with the
    # chunk; after a final rotation they land back on the chunk's owning rank.
    dk_acc = torch.zeros_like(k, dtype=torch.float32)
    dv_acc = torch.zeros_like(v, dtype=torch.float32)
    dk_recv = torch.empty_like(dk_acc)
    dv_recv = torch.empty_like(dv_acc)
for step in range(comm.cp_size):
# ---- Async K/V forward rotation (except last step) ----
if step < comm.cp_size - 1:
kv_handles = comm.send_recv_kv(k_cur, v_cur, k_buf, v_buf)
# ---- Determine source and positions ----
source_rank = (comm.cp_rank - step) % comm.cp_size
k_positions = all_positions[source_rank].to(device)
# ---- Recompute S and P for this chunk ----
S = torch.matmul(q.float(), k_cur.float().transpose(-2, -1)) * scale
if causal:
causal_mask = q_positions.unsqueeze(-1) >= k_positions.unsqueeze(-2)
causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
S = S.masked_fill(~causal_mask, float("-inf"))
# Reconstruct P using saved global statistics
P = torch.exp(S - m_f.unsqueeze(-1)) / l_f.unsqueeze(-1).clamp(min=1e-6)
# ---- Compute gradients ----
# dS = P * (dO @ V^T - delta)
dP = torch.matmul(dO.float(), v_cur.float().transpose(-2, -1))
dS = P * (dP - delta.unsqueeze(-1))
# dQ += dS @ K / sqrt(d_k)
dQ += torch.matmul(dS, k_cur.float()) * scale
# dK, dV for this chunk
dk_step = torch.matmul(dS.transpose(-2, -1), q.float()) * scale
dv_step = torch.matmul(P.transpose(-2, -1), dO.float())
dk_acc += dk_step
dv_acc += dv_step
# ---- Wait for K/V rotation ----
if step < comm.cp_size - 1:
RingCommunicator.wait_all(kv_handles)
k_cur, k_buf = k_buf, k_cur
v_cur, v_buf = v_buf, v_cur
        # ---- Rotate dk_acc/dv_acc along with K/V (every step, including last) ----
        # The gradient buffers follow the chunk they belong to; the rotation after
        # the final step returns each accumulated dK/dV to the rank that owns the
        # corresponding K/V chunk.
        grad_handles = comm.send_recv_kv(dk_acc, dv_acc, dk_recv, dv_recv)
        RingCommunicator.wait_all(grad_handles)
        dk_acc, dk_recv = dk_recv, dk_acc
        dv_acc, dv_recv = dv_recv, dv_acc
return dQ.to(q.dtype), dk_acc.to(k.dtype), dv_acc.to(v.dtype)
9. Autograd Function — Complete Integration
class RingAttentionFunction(torch.autograd.Function):
"""
Autograd-compatible ring attention with zig-zag partitioning.
Supports both Triton-accelerated and pure-PyTorch execution.
"""
@staticmethod
def forward(
ctx,
q: torch.Tensor, # [B, H, S_local, D]
k: torch.Tensor,
v: torch.Tensor,
q_positions: torch.Tensor,
all_positions: List[torch.Tensor],
comm: RingCommunicator,
causal: bool,
use_triton: bool,
) -> torch.Tensor:
if use_triton:
O, m_final, l_final = ring_attention_forward_triton(
q, k, v, q_positions, all_positions, comm, causal
)
else:
O = ring_attention_forward_reference(
q, k, v, q_positions, all_positions, comm, causal
)
# For backward, we'd need m_final, l_final from reference too
# (omitted here for brevity — extend reference to return them)
m_final = l_final = None
# Save tensors for backward
ctx.save_for_backward(q, k, v, O)
ctx.q_positions = q_positions
ctx.all_positions = all_positions
ctx.comm = comm
ctx.causal = causal
ctx.m_final = m_final
ctx.l_final = l_final
return O
@staticmethod
def backward(ctx, dO: torch.Tensor):
q, k, v, O = ctx.saved_tensors
dQ, dK, dV = ring_attention_backward_reference(
dO, q, k, v, O,
ctx.m_final, ctx.l_final,
ctx.q_positions, ctx.all_positions,
ctx.comm, ctx.causal,
)
# Return gradients for each forward input
# (None for non-Tensor / non-differentiable args)
return dQ, dK, dV, None, None, None, None, None
10. ContextParallelAttention Module
import torch.nn as nn
class ContextParallelAttention(nn.Module):
"""
Drop-in multi-head attention module with context parallelism.
Handles zig-zag partitioning, ring attention, and gradient sync.
    Input shape: [B, S_full, hidden_dim] — the full sequence; the module applies
    the zig-zag partition itself. Alternatively, pre-partitioned inputs can be
    used when integrated into a CP-aware pipeline.
"""
def __init__(
self,
hidden_dim: int,
num_heads: int,
cp_group: dist.ProcessGroup,
causal: bool = True,
use_triton: bool = True,
):
super().__init__()
self.hidden_dim = hidden_dim
self.num_heads = num_heads
self.head_dim = hidden_dim // num_heads
self.causal = causal
self.use_triton = use_triton
# Projections (weights replicated across CP group, sharded by TP if combined)
self.q_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
self.k_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
self.v_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
self.o_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
# Communication
self.comm = RingCommunicator(cp_group)
# Pre-compute zig-zag positions for all ranks
self._positions_cache = {}
def _get_positions(
self, seq_len: int, device: torch.device
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""Cached zig-zag position computation."""
key = (seq_len, self.comm.cp_size)
if key not in self._positions_cache:
all_pos = [
zigzag_partition(seq_len, self.comm.cp_size, r).to(device)
for r in range(self.comm.cp_size)
]
self._positions_cache[key] = all_pos
all_pos = self._positions_cache[key]
return all_pos[self.comm.cp_rank], all_pos
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: [B, S_full, hidden_dim] — full sequence input.
Returns:
out: [B, S_local, hidden_dim] — output for this rank's token subset.
"""
B, S_full, _ = x.shape
# ---- Zig-zag partition input ----
q_pos, all_pos = self._get_positions(S_full, x.device)
x_local = x.index_select(1, q_pos) # [B, S_local, hidden_dim]
# ---- Project to Q, K, V ----
S_local = x_local.shape[1]
q = self.q_proj(x_local).reshape(B, S_local, self.num_heads, self.head_dim)
k = self.k_proj(x_local).reshape(B, S_local, self.num_heads, self.head_dim)
v = self.v_proj(x_local).reshape(B, S_local, self.num_heads, self.head_dim)
# Transpose to [B, H, S_local, D]
q = q.transpose(1, 2).contiguous()
k = k.transpose(1, 2).contiguous()
v = v.transpose(1, 2).contiguous()
# ---- Ring Attention ----
O = RingAttentionFunction.apply(
q, k, v,
q_pos, all_pos,
self.comm,
self.causal,
self.use_triton,
) # [B, H, S_local, D]
# ---- Output projection ----
O = O.transpose(1, 2).reshape(B, S_local, self.hidden_dim)
out = self.o_proj(O)
return out
11. End-to-End Usage Example
import os
import torch
import torch.distributed as dist
def setup_context_parallel(cp_size: int) -> dist.ProcessGroup:
"""
Create context-parallel process groups.
Assumes torch.distributed is already initialized.
"""
world_size = dist.get_world_size()
rank = dist.get_rank()
assert world_size % cp_size == 0
# Each CP group contains `cp_size` consecutive ranks
num_cp_groups = world_size // cp_size
cp_group = None
for i in range(num_cp_groups):
ranks = list(range(i * cp_size, (i + 1) * cp_size))
group = dist.new_group(ranks)
if rank in ranks:
cp_group = group
return cp_group
def main():
# ---- Initialize distributed ----
dist.init_process_group(backend="nccl")
rank = dist.get_rank()
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
torch.cuda.set_device(device)
# ---- Hyperparameters ----
CP_SIZE = 4
BATCH_SIZE = 2
SEQ_LEN = 8192 # full sequence length
HIDDEN_DIM = 4096
NUM_HEADS = 32
HEAD_DIM = HIDDEN_DIM // NUM_HEADS # 128
# ---- Create CP group ----
cp_group = setup_context_parallel(CP_SIZE)
# ---- Instantiate model ----
model = ContextParallelAttention(
hidden_dim=HIDDEN_DIM,
num_heads=NUM_HEADS,
cp_group=cp_group,
causal=True,
use_triton=True,
).to(device).to(torch.bfloat16)
# ---- Synthetic input (each rank gets full sequence, module partitions it) ----
x = torch.randn(
BATCH_SIZE, SEQ_LEN, HIDDEN_DIM,
device=device, dtype=torch.bfloat16,
)
# ---- Forward ----
out = model(x) # [B, S_local, HIDDEN_DIM] where S_local = SEQ_LEN / CP_SIZE
# ---- Backward ----
loss = out.sum()
loss.backward()
# ---- Gradient sync across CP group (weights are replicated) ----
for param in model.parameters():
if param.grad is not None:
dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=cp_group)
param.grad /= CP_SIZE
if rank == 0:
print(f"Output shape: {out.shape}") # [2, 2048, 4096]
print(f"Tokens per GPU: {SEQ_LEN // CP_SIZE}") # 2048
print(f"Memory per GPU (activations): "
f"~{out.element_size() * out.nelement() / 1e6:.1f} MB")
dist.destroy_process_group()
if __name__ == "__main__":
main()
Launch command:
torchrun --nproc_per_node=4 context_parallel.py
12. Performance Analysis and Optimization Notes
12.1 Memory Footprint Comparison
| Configuration | Per-GPU Activation Memory (Attention) |
|---|---|
| No parallelism | $O(S^2 + S \cdot h)$ — full score matrix plus full-sequence activations |
| FlashAttention only | $O(S \cdot h)$ — score matrix never materialized, but the full sequence still resides on one GPU |
| CP + Ring Attention | $O\!\left(\frac{S}{N_{CP}} \cdot h\right)$ — only the local chunk plus one in-flight K/V buffer |
| CP + Ring + FlashAttn | $O\!\left(\frac{S}{N_{CP}} \cdot h\right)$ with no materialized score blocks — the configuration implemented here |
12.2 Kernel Performance Tuning
# Auto-tune BLOCK_M and BLOCK_N for different hardware
@triton.autotune(
configs=[
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=4, num_stages=3),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8, num_stages=3),
],
key=["S_Q", "S_K", "D"],
)
@triton.jit
def _ring_attn_fwd_kernel_autotuned(
# ... same signature as _ring_attn_fwd_kernel ...
):
pass # identical body
12.3 Critical Implementation Considerations
| Consideration | Detail |
|---|---|
| Numerical precision | All accumulators ($O$, $m$, $\ell$) are kept in float32 regardless of input dtype to prevent overflow and precision loss in the exponentials and running sums |
| All-masked rows | When a query row is fully masked against a K/V chunk, its block max is $-\infty$; the exponentiation must be guarded (see the kernel and reference code) so the row contributes zero rather than NaN |
| Position index transport | Positions are computed analytically from source_rank, never sent over the network |
| Backward recomputation | Forward saves only $Q$, $K$, $V$, $O$ and the softmax statistics $m$, $\ell$; scores are recomputed chunk-by-chunk during backward |
| Combining with TP | When TP shards attention heads, CP operates on the local heads; the CP group is orthogonal to the TP group |
| GQA / MQA compatibility | For grouped-query attention, only the (smaller) K/V head groups are rotated around the ring, reducing communication volume proportionally |
12.4 Communication-Computation Overlap Profiling
def profile_overlap_efficiency(
seq_len: int,
cp_size: int,
head_dim: int,
num_heads: int,
batch_size: int,
gpu_tflops: float = 312.0, # A100 bf16 peak
    bw_gbps: float = 600.0,      # NVLink 3.0 (A100) aggregate bandwidth
dtype_bytes: int = 2, # bfloat16
) -> dict:
"""Estimate whether communication is fully hidden behind compute."""
S_local = seq_len // cp_size
# FLOPs per ring step: 2 * S_local * S_local * head_dim * batch * heads
flops_per_step = 2 * S_local * S_local * head_dim * batch_size * num_heads
t_compute = flops_per_step / (gpu_tflops * 1e12)
# Bytes per ring step: 2 * S_local * head_dim * batch * heads * dtype_bytes
# (factor 2 for K and V)
bytes_per_step = 2 * S_local * head_dim * batch_size * num_heads * dtype_bytes
t_comm = bytes_per_step / (bw_gbps * 1e9)
overlap_ratio = t_compute / t_comm
return {
"t_compute_ms": t_compute * 1e3,
"t_comm_ms": t_comm * 1e3,
"overlap_ratio": overlap_ratio,
"fully_hidden": overlap_ratio >= 1.0,
"arithmetic_intensity": flops_per_step / bytes_per_step,
}
# Example:
stats = profile_overlap_efficiency(
seq_len=131072, cp_size=8, head_dim=128,
num_heads=32, batch_size=1
)
# t_compute_ms ≈ 7.05, t_comm_ms ≈ 0.45, overlap_ratio ≈ 15.8x → fully hidden
13. Integration with Multi-Dimensional Parallelism
The total GPU count factorizes as:

$$N_{\text{GPUs}} = N_{DP} \times N_{PP} \times N_{CP} \times N_{TP}$$
def create_parallel_groups(
world_size: int,
dp_size: int,
tp_size: int,
cp_size: int,
pp_size: int,
) -> dict:
"""
Create non-overlapping process groups for 4D parallelism.
Ordering (innermost to outermost): TP → CP → PP → DP
This places TP within a node (NVLink), CP across nearby nodes,
PP across distant nodes, and DP across replicas.
"""
assert world_size == dp_size * tp_size * cp_size * pp_size
rank = dist.get_rank()
groups = {}
# TP groups: tp_size consecutive ranks
for dp in range(dp_size):
for pp in range(pp_size):
for cp in range(cp_size):
ranks = [
dp * (tp_size * cp_size * pp_size)
+ pp * (tp_size * cp_size)
+ cp * tp_size
+ tp
for tp in range(tp_size)
]
g = dist.new_group(ranks)
if rank in ranks:
groups["tp"] = g
# CP groups: cp_size ranks, strided by tp_size
for dp in range(dp_size):
for pp in range(pp_size):
for tp in range(tp_size):
ranks = [
dp * (tp_size * cp_size * pp_size)
+ pp * (tp_size * cp_size)
+ cp * tp_size
+ tp
for cp in range(cp_size)
]
g = dist.new_group(ranks)
if rank in ranks:
groups["cp"] = g
# PP and DP groups follow analogous patterns (omitted for brevity)
return groups
This places TP on the fastest interconnect (intra-node NVLink), CP on the next tier (inter-node NVLink or InfiniBand), and PP/DP on the outermost communication rings — matching communication intensity to available bandwidth at each level.
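To make the rank layout used by create_parallel_groups explicit, the same ordering can be expressed as a rank-to-coordinate mapping (a small sketch consistent with the indexing above; the function name is illustrative):

def rank_to_coords(rank: int, tp_size: int, cp_size: int, pp_size: int) -> dict:
    """Decompose a global rank into (tp, cp, pp, dp) coordinates, TP innermost."""
    tp = rank % tp_size
    cp = (rank // tp_size) % cp_size
    pp = (rank // (tp_size * cp_size)) % pp_size
    dp = rank // (tp_size * cp_size * pp_size)
    return {"tp": tp, "cp": cp, "pp": pp, "dp": dp}

# Example: 16 GPUs with tp=2, cp=2, pp=2, dp=2.
# Ranks 0-1 share a TP group; ranks {0, 2} share a CP group (stride = tp_size).
assert rank_to_coords(0, 2, 2, 2) == {"tp": 0, "cp": 0, "pp": 0, "dp": 0}
assert rank_to_coords(3, 2, 2, 2) == {"tp": 1, "cp": 1, "pp": 0, "dp": 0}
assert rank_to_coords(5, 2, 2, 2) == {"tp": 1, "cp": 0, "pp": 1, "dp": 0}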