Tensor Parallelism (TP) and Sequence Parallelism (SP)
1. Motivation: Why Tensor Parallelism?
ZeRO (Zero Redundancy Optimizer) successfully shards parameters, gradients, and optimizer states across GPUs. However, activation memory — the intermediate tensors produced during the forward pass — remains replicated on every device. As model size and sequence length grow, activation memory comes to dominate the per-GPU memory budget, creating a bottleneck that ZeRO alone cannot remove.
Tensor Parallelism (TP) resolves this by sharding not only parameters, gradients, and optimizer states but also activations — and critically, it does so without requiring a full gather of all shards before computation. Instead, TP exploits the algebraic structure of matrix multiplication to distribute computation natively across devices.
2. Mathematical Foundations of Tensor Parallelism
Tensor parallelism rests on two fundamental decomposition properties of matrix multiplication. Given matrices $A \in \mathbb{R}^{m \times k}$ and $B \in \mathbb{R}^{k \times n}$:
2.1 Column-wise Decomposition (Splitting B by Columns)
Partition $B$ into $N$ column blocks:

$$B = [B_1 \;\; B_2 \;\; \cdots \;\; B_N], \quad B_i \in \mathbb{R}^{k \times (n/N)}$$

Then:

$$A \cdot B = A \cdot [B_1 \;\; B_2 \;\; \cdots \;\; B_N] = [AB_1 \;\; AB_2 \;\; \cdots \;\; AB_N]$$

Each partial product $AB_i \in \mathbb{R}^{m \times (n/N)}$ can be computed independently on GPU $i$, and the final result is obtained by concatenating (all-gather) the partial outputs along the column dimension.
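This identity is easy to check numerically. A minimal sketch with NumPy, simulating the $N = 4$ "GPUs" as loop iterations over shards (toy dimensions, chosen here for illustration):

```python
import numpy as np

np.random.seed(0)
m, k, n, N = 8, 16, 32, 4  # toy sizes; n must be divisible by N

A = np.random.randn(m, k)
B = np.random.randn(k, n)

# Split B into N column blocks: each B_i has shape (k, n/N)
B_shards = np.split(B, N, axis=1)

# Each "GPU" computes its partial product A @ B_i independently
partials = [A @ B_i for B_i in B_shards]  # each (m, n/N)

# All-gather: concatenate partial outputs along the column dimension
result = np.concatenate(partials, axis=1)

assert np.allclose(result, A @ B)  # identical to the unsharded matmul
```

On real hardware the list comprehension becomes per-rank computation and the concatenation becomes a collective all-gather; the algebra is unchanged.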
2.2 Row-wise Decomposition (Splitting $B$ by Rows and $A$ by Columns)

Partition $A$ into $N$ column blocks and $B$ into $N$ row blocks:

$$A = [A_1 \;\; A_2 \;\; \cdots \;\; A_N], \quad A_i \in \mathbb{R}^{m \times (k/N)}, \qquad B = \begin{bmatrix} B_1 \\ B_2 \\ \vdots \\ B_N \end{bmatrix}, \quad B_i \in \mathbb{R}^{(k/N) \times n}$$

Then:

$$A \cdot B = \sum_{i=1}^{N} A_i B_i$$

Each partial product $A_i B_i \in \mathbb{R}^{m \times n}$ can be computed independently on GPU $i$, and the final result is obtained by summing (all-reduce) the partial products.

3. Column-Linear Parallelism

3.1 Procedure

Given $N$ GPUs and weight matrix $W \in \mathbb{R}^{h \times h'}$:
Broadcast (or replicate) the full input $X$ to all $N$ GPUs.

Shard $W$ along its column dimension:

$$W = [W_1 \;\; W_2 \;\; \cdots \;\; W_N], \quad W_i \in \mathbb{R}^{h \times (h'/N)}$$

Each GPU $i$ computes:

$$Y_i = X \cdot W_i \in \mathbb{R}^{b \times s \times (h'/N)}$$

All-gather across GPUs to reconstruct:

$$Y = [Y_1 \;\; Y_2 \;\; \cdots \;\; Y_N] \in \mathbb{R}^{b \times s \times h'}$$
3.2 Communication Primitives
| Step | Operation | Communication |
|---|---|---|
| Input distribution | Broadcast | $X$ replicated to all ranks |
| Computation | Local matmul | No communication |
| Output combination | All-Gather | Concatenate partial outputs |
4. Row-Linear Parallelism
4.1 Procedure
Given $N$ GPUs and weight matrix $W \in \mathbb{R}^{h \times h'}$:

Scatter the input $X$ along the hidden (or appropriate) dimension:

$$X = [X_1 \;\; X_2 \;\; \cdots \;\; X_N], \quad X_i \in \mathbb{R}^{b \times s \times (h/N)}$$

Shard $W$ along its row dimension:

$$W = \begin{bmatrix} W_1 \\ W_2 \\ \vdots \\ W_N \end{bmatrix}, \quad W_i \in \mathbb{R}^{(h/N) \times h'}$$

Each GPU $i$ computes:

$$Y_i = X_i \cdot W_i \in \mathbb{R}^{b \times s \times h'}$$

All-reduce (sum) across GPUs:

$$Y = \sum_{i=1}^{N} Y_i \in \mathbb{R}^{b \times s \times h'}$$
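The four steps above can be simulated on a single machine by treating each shard as a separate rank. A sketch with NumPy, assuming $N = 4$ and toy dimensions:

```python
import numpy as np

np.random.seed(0)
b, s, h, h_out, N = 2, 6, 16, 8, 4  # toy sizes; h divisible by N

X = np.random.randn(b, s, h)
W = np.random.randn(h, h_out)

# Scatter X along the hidden dimension; shard W along its rows
X_shards = np.split(X, N, axis=-1)   # each (b, s, h/N)
W_shards = np.split(W, N, axis=0)    # each (h/N, h_out)

# Each "GPU" computes a partial output of the FULL shape (b, s, h_out)
partials = [Xi @ Wi for Xi, Wi in zip(X_shards, W_shards)]

# All-reduce: sum the partial outputs
Y = sum(partials)

assert np.allclose(Y, X @ W)  # matches the unsharded result
```

Note the contrast with column-linear: here each partial output already has the full output shape, so combination is a sum rather than a concatenation.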
4.2 Communication Primitives
| Step | Operation | Communication |
|---|---|---|
| Input distribution | Scatter | $X$ split across ranks |
| Computation | Local matmul | No communication |
| Output combination | All-Reduce | Sum partial outputs |
5. Tensor Parallelism in a Transformer Block
A standard Transformer decoder layer comprises two sub-blocks:
Multi-Head Attention (MHA) block
Feedforward MLP block
Each is amenable to tensor parallelism because it contains naturally independent dimensions.
5.1 Feedforward MLP Block
A typical MLP in a Transformer consists of two linear projections with a nonlinearity:
$$\text{MLP}(X) = \text{GELU}(X W_1) \cdot W_2$$

where $W_1 \in \mathbb{R}^{h \times 4h}$ (up-projection) and $W_2 \in \mathbb{R}^{4h \times h}$ (down-projection).
Optimal TP strategy for MLP:
| Layer | Parallelism Type | Rationale |
|---|---|---|
| $W_1$ (FC1) | Column-linear | Split output dimension; GELU applied independently per shard |
| $W_2$ (FC2) | Row-linear | Accepts sharded input from FC1; produces full output via all-reduce |
Why this ordering (column → row) is superior to (row → column):
Column-linear first requires only a broadcast (or no-op if inputs are already synchronized) to distribute X.
Row-linear second requires an all-reduce to combine results.
Total: 1 broadcast + 1 all-reduce per MLP block in forward pass.
The reverse ordering (row → column) would require an all-reduce between the two linear layers, because GELU is nonlinear and must be applied to the fully summed activation ($\text{GELU}(\sum_i Y_i) \neq \sum_i \text{GELU}(Y_i)$), plus an all-gather after the second layer, making it strictly less efficient.
The GELU nonlinearity is applied element-wise within each shard, requiring no cross-GPU communication — this is precisely why column-linear sharding of W1 is essential (sharding along the output dimension preserves the independence of the nonlinearity).
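The full column-then-row MLP can be checked the same way. A sketch with toy sizes (GELU written here in its tanh approximation), showing that GELU applied per shard followed by a single final sum reproduces the unsharded MLP:

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

np.random.seed(0)
b, s, h, N = 2, 4, 8, 4
X  = np.random.randn(b, s, h)
W1 = np.random.randn(h, 4 * h)   # up-projection (FC1)
W2 = np.random.randn(4 * h, h)   # down-projection (FC2)

# Column-linear FC1: shard W1 by columns; GELU is element-wise,
# so each shard applies it locally with no communication
Z_shards = [gelu(X @ W1_i) for W1_i in np.split(W1, N, axis=1)]

# Row-linear FC2: shard W2 by rows; partial outputs are summed (all-reduce)
Y = sum(Z_i @ W2_i for Z_i, W2_i in zip(Z_shards, np.split(W2, N, axis=0)))

assert np.allclose(Y, gelu(X @ W1) @ W2)  # one all-reduce, correct result
```

The only collective in the whole forward pass is the final sum, which is exactly the 1 broadcast + 1 all-reduce pattern described above.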
5.2 Multi-Head Attention (MHA) Block
The attention mechanism computes:
$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$

where $Q = XW_Q$, $K = XW_K$, $V = XW_V$, and $d_k = h / n_{\text{heads}}$ is the per-head dimension.

Natural parallelism: Each attention head operates independently. With $n_{\text{heads}}$ attention heads distributed across $N$ GPUs, each GPU handles $n_{\text{heads}}/N$ heads.
| Projection | Parallelism Type | Rationale |
|---|---|---|
| $W_Q, W_K, W_V$ | Column-linear | Each column shard corresponds to a subset of attention heads |
| $W_O$ (output projection) | Row-linear | Accepts concatenated head outputs (already sharded); produces full hidden via all-reduce |
The communication pattern is identical to the MLP block: a broadcast (or no-op) on entry via the column-linear projections, and an all-reduce on exit via the row-linear output projection.

Grouped-Query Attention (GQA)

In GQA, the number of key/value heads $n_{\text{kv\_heads}}$ is smaller than the number of query heads $n_{\text{attention\_heads}}$:

$$n_{\text{attention\_heads}} \geq n_{\text{kv\_heads}}$$

Constraint on TP degree:

$$\text{TP} \leq n_{\text{attention\_heads}}$$

When $\text{TP} > n_{\text{kv\_heads}}$, K/V heads must be replicated or carefully synchronized across TP ranks. For example, Llama-3 8B has:

$n_{\text{attention\_heads}} = 32$

$n_{\text{kv\_heads}} = 8$

Maximum TP = 32, but for $\text{TP} > 8$, K/V heads require cross-rank synchronization
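A small helper (hypothetical, written here for illustration) makes the constraint concrete: it reports how many query and K/V heads land on each TP rank, and whether K/V heads must be replicated:

```python
def gqa_sharding(n_attention_heads: int, n_kv_heads: int, tp: int):
    """Return (query heads per rank, K/V heads per rank, K/V replicated?)."""
    assert tp <= n_attention_heads, "TP degree cannot exceed query head count"
    q_per_rank = n_attention_heads // tp
    if tp <= n_kv_heads:
        # Each rank owns a disjoint subset of K/V heads
        return q_per_rank, n_kv_heads // tp, False
    # More TP ranks than K/V heads: each K/V head is replicated
    # on tp / n_kv_heads ranks
    return q_per_rank, 1, True

# Llama-3 8B: 32 query heads, 8 K/V heads
print(gqa_sharding(32, 8, tp=8))   # → (4, 1, False): no replication needed
print(gqa_sharding(32, 8, tp=16))  # → (2, 1, True): each K/V head on 2 ranks
```

At TP = 8 the K/V heads divide evenly; beyond that, the replication adds memory and synchronization cost, which is one more reason to keep TP modest.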
6. Communication Analysis of Tensor Parallelism
6.1 Communication Primitives in the Critical Path
For each Transformer decoder layer (MHA + MLP), the forward pass requires:
Forward: $2 \times$ all-reduce (one for MHA, one for MLP)

Backward: $2 \times$ all-reduce (the conjugate operations)
These all-reduce operations sit directly on the critical path of computation — they cannot be trivially overlapped with compute because subsequent operations (e.g., LayerNorm, residual addition) depend on the synchronized result.
Critical path: The longest chain of sequentially dependent operations determining the minimum wall-clock time for a forward or backward pass.
6.2 Communication Volume
For an all-reduce of a tensor of size $M$ across $N$ GPUs, the total communication volume per GPU (using ring all-reduce) is:

$$V_{\text{all-reduce}} = 2 \cdot \frac{N-1}{N} \cdot M$$

For the MLP block, $M = b \cdot s \cdot h$, giving:

$$V_{\text{MLP}} = 2 \cdot \frac{N-1}{N} \cdot b \cdot s \cdot h \quad \text{per layer (forward)}$$
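Plugging in representative numbers (assumed here for illustration: $b = 8$, $s = 4096$, $h = 8192$, bf16, $N = 8$) gives a feel for the per-layer volume:

```python
# Ring all-reduce volume per GPU: 2 * (N-1)/N * M
b, s, h, N = 8, 4096, 8192, 8    # assumed toy configuration
bytes_per_elem = 2               # bf16

M = b * s * h                        # elements in the MLP output tensor
volume_elems = 2 * (N - 1) / N * M   # elements moved per GPU per all-reduce
volume_gb = volume_elems * bytes_per_elem / 1e9

print(f"{volume_gb:.2f} GB per all-reduce per layer")  # → 0.94 GB
```

Roughly a gigabyte per all-reduce, twice per layer, every forward pass: this is why the interconnect bandwidth in the next subsection dominates TP scaling behavior.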
6.3 Scaling Behavior and Interconnect Dependence
| TP Degree | Interconnect | Observed Behavior |
|---|---|---|
| TP ≤ 8 | NVLink (intra-node; ~600 GB/s bidirectional on A100, ~900 GB/s on H100) | High throughput; communication overhead manageable |
| TP = 16 | Inter-node (InfiniBand/EFA, ~100–400 GB/s) | Significant throughput degradation |
| TP = 32 | Inter-node | Steep decline; communication dominates compute |
Practical guideline:
$$\text{TP degree} \leq \text{GPUs per node} \quad (\text{typically } 8)$$
7. Memory Benefits of Tensor Parallelism
With TP degree $N$, the per-GPU memory for a linear layer with weight $W \in \mathbb{R}^{h \times h'}$ is:

7.1 Parameters

$$\text{Params per GPU} = \frac{h \times h'}{N}$$

7.2 Gradients

$$\text{Gradients per GPU} = \frac{h \times h'}{N}$$

7.3 Optimizer States (Adam)

Adam maintains a first moment $m$ and a second moment $v$, each the same size as the parameters:

$$\text{Optimizer states per GPU} = \frac{2 \times h \times h'}{N} \text{ elements} \quad \left(\text{in fp32, } \frac{2 \times 4 \times h \times h'}{N} \text{ bytes}\right)$$

7.4 Activations (Partial Benefit)

Intermediate activations within TP regions are sharded:

$$\text{Activation per GPU (TP region)} = \frac{b \cdot s \cdot h'}{N} \quad \text{(for column-linear output)}$$

However, operations like LayerNorm and dropout still require the full activation tensor of shape $b \times s \times h$, limiting the activation memory savings. This is the precise limitation that Sequence Parallelism addresses.
8. Sequence Parallelism (SP)
8.1 Core Idea
Sequence parallelism shards the activations along the sequence dimension for operations that lie outside the tensor-parallel regions — specifically LayerNorm and dropout.

These operations require the full hidden dimension $h$ (e.g., LayerNorm computes statistics across $h$), so they cannot be sharded along $h$. However, they operate independently across sequence positions, making sharding along $s$ natural.
8.2 LayerNorm Definition
$$\text{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

where:

$$\mu = \frac{1}{h} \sum_{j=1}^{h} x_j, \qquad \sigma^2 = \frac{1}{h} \sum_{j=1}^{h} (x_j - \mu)^2$$

Both $\mu$ and $\sigma^2$ are computed across the hidden dimension $h$ for each sequence position independently. Therefore LayerNorm:

Cannot be sharded along $h$ (this would produce incorrect statistics)

Can be sharded along $s$ (each sequence position is normalized independently)
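Because the statistics are per-position, sharding along $s$ is exact. A NumPy sketch comparing a sequence-sharded LayerNorm against the full computation (with $\gamma = 1$, $\beta = 0$ for simplicity):

```python
import numpy as np

def layernorm(x, eps=1e-5):
    # Statistics over the hidden dimension (last axis), per position
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)   # gamma=1, beta=0 for simplicity

np.random.seed(0)
b, s, h, N = 2, 8, 16, 4
X = np.random.randn(b, s, h)

# Shard along the sequence dimension; each "rank" normalizes its own chunk
out_sharded = np.concatenate(
    [layernorm(Xi) for Xi in np.split(X, N, axis=1)], axis=1)

assert np.allclose(out_sharded, layernorm(X))  # exact, no communication
```

Sharding along the last axis instead would give each rank only a slice of each position's features, producing wrong means and variances.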
8.3 Communication Operators

The interplay between TP regions and SP regions requires carefully designed communication operators:
TP Region Operators ($f$ and $f^*$)

| Pass | $f$ | $f^*$ |
|---|---|---|
| Forward | No-op (activations already replicated) | All-reduce (synchronize partial results) |
| Backward | All-reduce (synchronize gradients) | No-op (gradients already replicated) |

$f$ and $f^*$ are conjugate pairs: when one is a no-op, the other is an all-reduce, and vice versa across forward and backward passes.
SP ↔ TP Transition Operators ($g$ and $g^*$)

| Pass | $g$ | $g^*$ |
|---|---|---|
| Forward | All-gather (reconstruct full sequence for TP) | Reduce-scatter (shard sequence for SP) |
| Backward | Reduce-scatter (distribute gradients) | All-gather (reconstruct gradients) |

$g$ and $g^*$ are also conjugate pairs.
8.4 Data Flow Through a Transformer Layer with TP+SP
Consider a two-GPU setup ($N = 2$) with input $X \in \mathbb{R}^{b \times s \times h}$:

Step 1: LayerNorm (SP region)

Each GPU holds $X_i^* \in \mathbb{R}^{b \times (s/N) \times h}$ (sharded along the sequence).

Each GPU computes LayerNorm independently on its chunk:

$$Y_i^* = \text{LayerNorm}(X_i^*) \in \mathbb{R}^{b \times (s/N) \times h}$$

Step 2: SP → TP transition ($g$: all-gather)

Reconstruct the full sequence on each GPU:

$$Y = \text{AllGather}(Y_1^*, Y_2^*, \ldots, Y_N^*) \in \mathbb{R}^{b \times s \times h}$$

Step 3: Column-linear / FC1 (TP region)

Each GPU computes with its column shard $W_1^{(i)}$:

$$Z_i^* = \text{GELU}\!\left(Y \cdot W_1^{(i)}\right) \in \mathbb{R}^{b \times s \times (4h/N)}$$

Step 4: Row-linear / FC2 (TP region)

Each GPU computes with its row shard $W_2^{(i)}$:

$$O_i = Z_i^* \cdot W_2^{(i)} \in \mathbb{R}^{b \times s \times h}$$

Step 5: TP → SP transition ($g^*$: reduce-scatter)

$$O_j^* = \text{ReduceScatter}(O_1, O_2, \ldots, O_N) \in \mathbb{R}^{b \times (s/N) \times h}$$
This simultaneously:
Reduces (sums) the partial row-linear outputs for correctness
Scatters the result along the sequence dimension for the subsequent SP region
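The five steps can be simulated end to end. A sketch with $N = 2$, modeling all-gather as concatenation along $s$ and reduce-scatter as a sum followed by a split (toy dimensions, LayerNorm and GELU as in the earlier sketches):

```python
import numpy as np

def layernorm(x, eps=1e-5):
    mu = x.mean(-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(-1, keepdims=True) + eps)

def gelu(x):
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

np.random.seed(0)
b, s, h, N = 2, 8, 4, 2
X  = np.random.randn(b, s, h)
W1 = np.random.randn(h, 4 * h)
W2 = np.random.randn(4 * h, h)

# Step 1 (SP): each rank LayerNorms its own sequence chunk
Y_star = [layernorm(Xi) for Xi in np.split(X, N, axis=1)]

# Step 2 (g, all-gather): reconstruct the full sequence on every rank
Y = np.concatenate(Y_star, axis=1)

# Steps 3-4 (TP): column-linear FC1 + GELU, then row-linear FC2, per rank
O = [gelu(Y @ W1_i) @ W2_i
     for W1_i, W2_i in zip(np.split(W1, N, axis=1), np.split(W2, N, axis=0))]

# Step 5 (g*, reduce-scatter): sum the partials, then re-shard along s
O_star = np.split(sum(O), N, axis=1)   # rank j keeps O_star[j]

# Reference: unsharded LayerNorm + MLP
ref = gelu(layernorm(X) @ W1) @ W2
assert np.allclose(np.concatenate(O_star, axis=1), ref)
```

At no point does any rank hold an unsharded tensor except inside the TP region, where the sharding has simply moved from $s$ to the hidden dimension.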
8.5 Activation Memory Comparison
| Configuration | Maximum Activation Size per GPU |
|---|---|
| No parallelism | $b \cdot s \cdot h$ |
| TP only | $b \cdot s \cdot h$ (LayerNorm/dropout still need the full tensor) |
| TP + SP | $\dfrac{b \cdot s \cdot h}{\text{TP}}$ |
With TP+SP, at every point in the computation, activations are sharded along either the hidden dimension (in TP regions) or the sequence dimension (in SP regions), ensuring:
$$\text{Max activation per GPU} = \frac{b \cdot s \cdot h}{\text{TP}}$$
8.6 Summary Table: Activation Shape Throughout Forward Pass
| Region | TP Only | TP + SP |
|---|---|---|
| Enter TP (column-linear) | $h$: sharded, $s$: full | $h$: sharded, $s$: all-gather → full |
| TP region (between linears) | $h$: sharded, $s$: full | $h$: sharded, $s$: full |
| Exit TP (row-linear) | $h$: full (all-reduce), $s$: full | $h$: full (reduce-scatter), $s$: reduce-scatter → sharded |
| SP region (LayerNorm, dropout) | $h$: full, $s$: full | $h$: full, $s$: sharded |
| Embedding layer (row-linear) | $h$: full (all-reduce), $s$: full | $h$: full (reduce-scatter), $s$: reduce-scatter → sharded |
9. Communication Equivalence: TP vs. TP+SP
9.1 Per-Layer Communication Count
| Method | Forward Operations per Layer |
|---|---|
| TP only | 2 × all-reduce |
| TP + SP | 2 × all-gather + 2 × reduce-scatter |
9.2 Why They Are Equivalent
A single all-reduce can be decomposed into:

$$\text{all-reduce} = \text{reduce-scatter} + \text{all-gather}$$

Therefore, 2 all-reduce operations have the same total communication volume as 2 all-gather + 2 reduce-scatter operations. TP+SP thus gains its activation memory savings at no additional communication cost.
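The decomposition can be seen directly. A sketch simulating reduce-scatter followed by all-gather over $N$ per-rank tensors and comparing against a plain sum (which is what all-reduce leaves on every rank):

```python
import numpy as np

np.random.seed(0)
N, M = 4, 16                         # 4 ranks, tensors of 16 elements each
tensors = [np.random.randn(M) for _ in range(N)]

# All-reduce: every rank ends up with the elementwise sum
all_reduced = sum(tensors)

# Reduce-scatter: rank i receives only the i-th chunk of that sum
chunk = M // N
chunks = [sum(t[i * chunk:(i + 1) * chunk] for t in tensors)
          for i in range(N)]

# All-gather: concatenating the per-rank chunks reconstructs the full sum
assert np.allclose(np.concatenate(chunks), all_reduced)
```

Each half moves $\frac{N-1}{N} \cdot M$ elements per rank in a ring implementation, which is exactly where the $2 \cdot \frac{N-1}{N} \cdot M$ all-reduce volume from Section 6.2 comes from.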