Tensor Parallelism (TP) and Sequence Parallelism (SP)
1. Motivation and Problem Statement
ZeRO-style data parallelism shards parameters, gradients, and optimizer states across GPUs. However, as model scale increases, activation memory becomes the dominant bottleneck. ZeRO requires an all-gather to reconstruct the full parameter tensor before each computation step, which limits scalability.
Tensor Parallelism (TP) addresses this by sharding weights, gradients, optimizer states, and activations simultaneously, performing computation directly on the shards without gathering the full tensors beforehand.
2. Mathematical Foundations of Tensor Parallelism
Tensor parallelism exploits two fundamental properties of matrix multiplication. Given matrices A ∈ R^{m×k} and B ∈ R^{k×n}:
2.1 Column-wise Partitioning (Equation 1)
Partition B into N column blocks:
B = [B_1  B_2  ⋯  B_N],  B_i ∈ R^{k × n/N}
Then:
A·B = A·[B_1  B_2  ⋯  B_N] = [A·B_1  A·B_2  ⋯  A·B_N]
Each partial product A·B_i ∈ R^{m × n/N} can be computed independently on GPU i. The full result is recovered via concatenation (an all-gather along the column dimension).
Each GPU stores only 1/N of the weight columns and produces 1/N of the output columns. The input is replicated, but the output activations are sharded.
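Equation 1 can be verified numerically. A minimal NumPy sketch (toy sizes; the "all-gather" is simulated by a plain concatenation):

```python
import numpy as np

rng = np.random.default_rng(0)
m, k, n, N = 4, 6, 8, 2  # toy dimensions; N = number of simulated GPUs

A = rng.standard_normal((m, k))
B = rng.standard_normal((k, n))

# Column-wise partition of B into N blocks of n/N columns each.
B_shards = np.split(B, N, axis=1)

# Each partial product A @ B_i is computed independently (one per GPU),
# then concatenated along the column dimension (the all-gather).
partials = [A @ B_i for B_i in B_shards]
result = np.concatenate(partials, axis=1)

assert np.allclose(result, A @ B)
```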
4. Row-Linear (Row-Parallel) Sharding
4.1 Mechanism
Partition the weight matrix W along its input (row) dimension:
W = [W_1; W_2; ⋮; W_N]  (stacked row blocks),  W_i ∈ R^{(h_in/N) × h_out}
This requires a corresponding partition of the input along the hidden dimension:
X = [X_1  X_2  ⋯  X_N],  X_i ∈ R^{b × s × (h_in/N)}
4.2 Forward Pass
| Step | Operation | Description |
|---|---|---|
| 1 | Scatter | Split and distribute the input X so GPU i receives X_i |
| 2 | Local matmul | GPU i computes Y_i = X_i · W_i ∈ R^{b × s × h_out} |
| 3 | All-reduce | Sum {Y_1, Y_2, …, Y_N} element-wise: Y = Σ_{i=1}^{N} Y_i |
4.3 Communication Primitives
Scatter: O(b·s·h_in) — distribute input chunks
All-reduce: O(b·s·h_out) — sum partial results (equivalent to reduce-scatter + all-gather)
4.4 Key Property
Each GPU stores only 1/N of the weight rows. The input is sharded, and the output is full-sized but requires a summation (all-reduce) for correctness.
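The row-linear identity above can be checked directly: the sum of the shard-local matmuls equals the full matmul. A minimal NumPy sketch:

```python
import numpy as np

rng = np.random.default_rng(1)
b, s, h_in, h_out, N = 2, 3, 8, 4, 2  # toy sizes

X = rng.standard_normal((b, s, h_in))
W = rng.standard_normal((h_in, h_out))

# Scatter: shard X along the hidden (last) dim, W along its row dim.
X_shards = np.split(X, N, axis=-1)
W_shards = np.split(W, N, axis=0)

# Local matmuls yield full-sized but *partial* outputs...
partials = [X_i @ W_i for X_i, W_i in zip(X_shards, W_shards)]
# ...which the all-reduce (element-wise sum) combines into the exact result.
Y = sum(partials)

assert np.allclose(Y, X @ W)
```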
5. Tensor Parallelism in the Transformer Block
A standard Transformer decoder layer consists of two primary sub-blocks:
Transformer Layer = LayerNorm → MHA → LayerNorm → MLP
5.1 MLP Block
The feedforward MLP in a Transformer consists of two linear projections with a nonlinearity (e.g., GeLU or SiLU):
MLP(X) = σ(X·W_1)·W_2
where:
W_1 ∈ R^{h × 4h} (up-projection, or gate projection)
W_2 ∈ R^{4h × h} (down-projection)
σ is an element-wise activation function (GeLU, SiLU, etc.)
TP Strategy for MLP:
| Layer | Parallelism Type | Rationale |
|---|---|---|
| W_1 (FC1 / up-projection) | Column-linear | Splits the output hidden dim; input is broadcast (or already synced) |
| σ (activation) | Local | Applied element-wise on sharded activations |
| W_2 (FC2 / down-projection) | Row-linear | Takes sharded input, produces full output; requires all-reduce |
Why Column-Linear → Row-Linear (not vice versa)?
If we used Row-Linear → Column-Linear, we would need an intermediate all-reduce between the two layers (to get correct row-linear output) followed by another communication for the column-linear input. The Column-Linear → Row-Linear ordering requires:
Forward pass: One broadcast (often a no-op if inputs are already synced) + one all-reduce
Backward pass: One all-reduce + one no-op
This yields one all-reduce per MLP sub-block per direction — the minimal communication.
Rationale: denote by f the operator that is the identity in the forward pass and an all-reduce in the backward pass, and by f* its conjugate (all-reduce forward, identity backward). In the forward pass, inputs are already replicated across TP ranks (making f a no-op), but outputs from row-linear layers must be summed (f* = all-reduce). In the backward pass, gradient flow reverses: gradients arriving at f* are already correct (no-op), while gradients at f must be synchronized (all-reduce).
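The key enabler of the single all-reduce is that the activation is element-wise, so it commutes with the column sharding and runs locally between the two linears. A NumPy sketch of the full column-linear → row-linear MLP (tanh-approximate GeLU used for illustration):

```python
import numpy as np

rng = np.random.default_rng(2)
h, N = 8, 2
X = rng.standard_normal((2, 3, h))
W1 = rng.standard_normal((h, 4 * h))
W2 = rng.standard_normal((4 * h, h))

def gelu(x):
    # tanh approximation of GeLU, applied element-wise
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

# Reference: full MLP on one device.
ref = gelu(X @ W1) @ W2

# TP: column-shard W1, row-shard W2. The element-wise activation runs
# locally on the column-sharded intermediate -- no sync between the layers.
W1_cols = np.split(W1, N, axis=1)
W2_rows = np.split(W2, N, axis=0)
partials = [gelu(X @ W1_cols[i]) @ W2_rows[i] for i in range(N)]
Y = sum(partials)  # the single forward all-reduce

assert np.allclose(Y, ref)
```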
6.4 Critical Path and Overlap Limitations
Unlike ZeRO, where communication (all-gather of parameters) can be overlapped with computation, TP places communication directly on the critical path:
T_layer = T_compute + T_all-reduce^{exposed}
The all-reduce at the end of each MLP and MHA block cannot begin until the local matrix multiplication completes, and the subsequent LayerNorm cannot begin until the all-reduce finishes. This creates a synchronization barrier on every layer.
Partial mitigation (Megatron-LM, Nanotron): Overlap all-gather with FC1 computation using block/chunked matrix multiplication:
Divide the weight into chunks
As each chunk’s all-gather completes, begin matmul for that chunk while the next chunk’s all-gather proceeds asynchronously
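The correctness of the chunked approach rests on matmul decomposing cleanly over chunks. A minimal NumPy sketch (the overlap itself is not simulated; real implementations issue the gather and the per-chunk matmuls on separate streams):

```python
import numpy as np

rng = np.random.default_rng(3)
b, s, h, n_chunks = 2, 8, 4, 4
X = rng.standard_normal((b, s, h))   # pretend each sequence chunk arrives via all-gather
W1 = rng.standard_normal((h, 4 * h))

# As each chunk "arrives", its matmul can start immediately while the
# next chunk is still in flight (here the chunks are processed serially).
out_chunks = [X_c @ W1 for X_c in np.split(X, n_chunks, axis=1)]
Y = np.concatenate(out_chunks, axis=1)

assert np.allclose(Y, X @ W1)  # chunked result matches the monolithic matmul
```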
Advanced mitigation (Domino): Novel scheduling techniques to maximize overlap of communication and computation within TP regions.
7. Scaling Behavior and Trade-offs
7.1 Intra-node vs. Inter-node Communication
| Regime | Interconnect | Bandwidth (typical) | Latency |
|---|---|---|---|
| N_TP ≤ 8 (intra-node) | NVLink / NVSwitch | 600–900 GB/s (bidirectional, per GPU) | ~µs |
| N_TP > 8 (inter-node) | InfiniBand / EFA | 50–400 GB/s (per node) | ~10–100 µs |
Empirical observation: Throughput drops significantly at N_TP = 16 (crossing the node boundary) and precipitously at N_TP = 32.
7.2 Memory Reduction
For a model with P total parameters and TP degree N:
Parameters per GPU = P/N,  Gradients per GPU = P/N,  Optimizer states per GPU = S·P/N
where S is the optimizer state multiplier (e.g., S = 12 bytes per parameter for Adam in mixed precision: 4 bytes fp32 master weights + 4 bytes first moment + 4 bytes second moment).
Activation memory in TP regions:
Activation per GPU (TP region) = b·s·h/N  (intermediate activations)
However, operations like LayerNorm and Dropout still require the full activation tensor (b,s,h), partially negating activation memory savings.
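Plugging numbers into the formulas above gives a feel for the savings. A back-of-envelope sketch (the 70B parameter count and the 2-byte bf16 sizes for parameters and gradients are illustrative assumptions, not figures from the text; S = 12 is the Adam multiplier given above):

```python
def per_gpu_memory_gb(n_params, N, bytes_param=2, bytes_grad=2, S=12):
    """Parameter + gradient + optimizer-state memory per GPU at TP degree N."""
    per_gpu_bytes = n_params * (bytes_param + bytes_grad + S) / N
    return per_gpu_bytes / 1e9

# Hypothetical 70B-parameter model:
print(per_gpu_memory_gb(70e9, N=1))  # 1120.0 GB -- far beyond any single GPU
print(per_gpu_memory_gb(70e9, N=8))  # 140.0 GB per GPU with TP across one node
```

Note that this covers only parameters, gradients, and optimizer states; activations add the b·s·h/N term above on top.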
7.3 Throughput–Memory Trade-off
Effective throughput per GPU ∝ T_compute / (T_compute + T_communication)
Net effect: diminishing returns beyond N=8 (single node)
The benefit is that reduced memory per GPU allows larger batch sizes, which can compensate for per-GPU throughput loss at the system level.
8. Sequence Parallelism (SP)
8.1 Motivation
Even with TP, operations outside the attention and MLP sub-blocks (specifically LayerNorm and Dropout) require the full hidden dimension h and therefore cannot be sharded along h. These operations still store activations of shape (b, s, h) on each GPU, creating memory bottlenecks.
Sequence Parallelism shards these operations along the sequence dimension instead:
SP shard on GPU i:  X_i* ∈ R^{b × (s/N) × h}
8.2 Why LayerNorm Requires Full Hidden Dimension
LayerNorm computes:
LayerNorm(x) = γ · (x − μ)/√(σ² + ϵ) + β
where the statistics are computed across the hidden dimensionh:
μ = (1/h) Σ_{j=1}^{h} x_j,  σ² = (1/h) Σ_{j=1}^{h} (x_j − μ)²
Since μ and σ² require access to all h elements, we cannot shard LayerNorm along h. However, different sequence positions are independent for LayerNorm, so we can shard along s.
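The independence across sequence positions is easy to confirm: normalizing each sequence chunk locally reproduces the full-tensor result exactly. A minimal NumPy sketch (γ = 1, β = 0 for brevity):

```python
import numpy as np

def layernorm(x, eps=1e-5):
    # Statistics are taken over the hidden (last) dimension only.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

rng = np.random.default_rng(4)
X = rng.standard_normal((2, 8, 16))  # (b, s, h)

# Shard along s (axis 1) and normalize each chunk independently:
chunks = np.split(X, 2, axis=1)
sharded = np.concatenate([layernorm(c) for c in chunks], axis=1)

assert np.allclose(sharded, layernorm(X))
```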
TP+SP introduces two new conjugate operators at the SP ↔ TP boundary: g (all-gather in the forward pass, reduce-scatter in the backward pass) and g* (reduce-scatter forward, all-gather backward). These replace the f/f* operators at the boundaries between SP regions and TP regions.
8.4 Forward Pass Through a TP+SP Transformer Layer (Step-by-Step)
Consider an MLP sub-block with TP+SP on N=2 GPUs:
Step 1: Initial LayerNorm (SP Region)
Each GPU i holds X_i* ∈ R^{b × (s/2) × h} (sharded along the sequence dimension).
LayerNorm is computed independently per sequence position:
Y_i* = LayerNorm(X_i*) ∈ R^{b × (s/2) × h}
Step 2: Transition SP → TP (g = all-gather)
Gather sequence chunks to reconstruct the full sequence on each GPU:
Y = AllGather(Y_1*, Y_2*) ∈ R^{b × s × h}
This is necessary because the column-linear layer (FC1) needs the full input.
Step 3: First Linear Layer — Column-Linear (TP Region)
GPU i holds the column shard W_1^(i) ∈ R^{h × 4h/N}:
Z_i* = σ(Y · W_1^(i)) ∈ R^{b × s × 4h/N}
where σ is the activation function (GeLU/SiLU), applied element-wise.
Step 4: Second Linear Layer — Row-Linear (TP Region)
GPU i holds the row shard W_2^(i) ∈ R^{4h/N × h}:
Ŷ_i = Z_i* · W_2^(i) ∈ R^{b × s × h}
Step 5: Transition TP → SP (g∗ = reduce-scatter)
Ŷ_i* = ReduceScatter(Ŷ_1, Ŷ_2) ∈ R^{b × (s/N) × h}
This simultaneously:
Reduces (sums) the partial results from the row-linear operation (required for correctness: Ŷ = Σ_i Ŷ_i)
Scatters the result along the sequence dimension (returning to SP sharding)
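The five steps can be simulated end-to-end on one process. A NumPy sketch, using ReLU as a stand-in for GeLU/SiLU and omitting the LayerNorm so the check isolates the TP+SP data movement:

```python
import numpy as np

rng = np.random.default_rng(5)
b, s, h, N = 2, 4, 8, 2
X = rng.standard_normal((b, s, h))
W1 = rng.standard_normal((h, 4 * h))
W2 = rng.standard_normal((4 * h, h))
relu = lambda x: np.maximum(x, 0)  # stand-in activation

# Reference (single device):
ref = relu(X @ W1) @ W2

# SP state: rank i holds sequence chunk X_i* of shape (b, s/N, h).
seq_chunks = np.split(X, N, axis=1)
W1_cols = np.split(W1, N, axis=1)  # column shards for FC1
W2_rows = np.split(W2, N, axis=0)  # row shards for FC2

# g: all-gather along s -> every rank sees the full sequence.
Y_full = np.concatenate(seq_chunks, axis=1)
# TP region: FC1 (column-linear) + activation + FC2 (row-linear), per rank.
partials = [relu(Y_full @ W1_cols[i]) @ W2_rows[i] for i in range(N)]
# g*: reduce-scatter = sum the partials, then scatter along s.
reduced = sum(partials)
out_chunks = np.split(reduced, N, axis=1)  # rank i keeps out_chunks[i]

assert np.allclose(np.concatenate(out_chunks, axis=1), ref)
```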
8.5 Activation Shape Summary Table
| Region | Vanilla TP | TP + SP |
|---|---|---|
| Enter TP (column-linear input) | h: full, s: full | h: full, s: all-gather to full |
| TP region (between FC1 and FC2) | h: sharded (h/N), s: full | h: sharded (h/N), s: full |
| Exit TP (row-linear output) | h: full + all-reduce, s: full | h: full + reduce-scatter, s: reduce-scatter to sharded |
| SP region (LayerNorm, Dropout) | h: full, s: full | h: full, s: sharded (s/N) |
8.6 Maximum Activation Size
Vanilla TP:
Max activation per GPU = b·s·h
(Full activations required in LayerNorm/Dropout regions.)
TP + SP:
Max activation per GPU = b·s·h/N
At every point in the computation, activations are sharded along either h (in TP regions) or s (in SP regions), never requiring the full (b, s, h) tensor on any single GPU.
An all-reduce is equivalent to a reduce-scatter followed by an all-gather, and TP+SP replaces each all-reduce with one all-gather and one reduce-scatter (at different points in the computation):
Total communication (TP+SP) = Total communication (Vanilla TP)
SP achieves strictly better activation memory with no additional communication overhead.
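The decomposition behind this equality can be checked directly: summing the partials (all-reduce) gives the same result as summing, slicing per rank (reduce-scatter), and reassembling the slices (all-gather). A toy NumPy sketch:

```python
import numpy as np

rng = np.random.default_rng(6)
N = 4
partials = [rng.standard_normal(8) for _ in range(N)]  # one partial per rank

# All-reduce: every rank ends up with the full element-wise sum.
allreduce = sum(partials)

# Reduce-scatter: rank i ends up with the i-th 1/N slice of that sum...
slices = np.split(sum(partials), N)
# ...and a subsequent all-gather reassembles the full sum on every rank.
gathered = np.concatenate(slices)

assert np.allclose(gathered, allreduce)
```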
8.8 Gradient Synchronization Notes
LayerNorm weights in SP regions: Since each TP rank sees different sequence positions (but shares the same LayerNorm parameters γ, β), the gradients for γ and β will differ across ranks. An all-reduce of the LayerNorm gradients is required during the backward pass:
∇γ = AllReduce(∇γ^(1), ∇γ^(2), …, ∇γ^(N))
This overhead is negligible since LayerNorm has only 2h parameters (compared to the full model’s billions).
Dropout in SP regions: Random masks must be synchronized across TP ranks to maintain deterministic behavior. In practice, this is achieved by synchronizing the random seed across ranks:
seed_dropout^(i) = seed_global,  ∀ i ∈ {1, …, N}
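The seed-synchronization requirement reduces to a simple invariant: ranks that seed their RNG identically draw identical masks. A minimal sketch:

```python
import numpy as np

def dropout_mask(shape, p, seed):
    # Each rank draws its keep-mask from its own RNG, seeded explicitly.
    return np.random.default_rng(seed).random(shape) >= p

# With a synchronized seed, every TP rank generates the identical mask,
# so element-wise dropout stays consistent on the replicated activations.
mask_rank0 = dropout_mask((4, 8), p=0.1, seed=1234)
mask_rank1 = dropout_mask((4, 8), p=0.1, seed=1234)

assert np.array_equal(mask_rank0, mask_rank1)
```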
9. Embedding Layer Treatment
The vocabulary embedding E ∈ R^{V × h} (where V is the vocabulary size) is typically sharded along the vocabulary dimension (row-linear):
| Configuration | Sharding | Communication |
|---|---|---|
| Vanilla TP | h: full (all-reduce for correctness), s: full | All-reduce |
| TP + SP | h: full (reduce-scatter for correctness), s: reduce-scatter to sharded | Reduce-scatter |
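The row-linear embedding lookup can be sketched in NumPy: each rank owns a vocabulary slice, serves only the token ids that fall in its range, zeroes the rest, and a sum across ranks (the all-reduce) reconstructs the full lookup. The masking scheme follows the standard vocab-parallel approach (cf. Megatron-LM's VocabParallelEmbedding); the sizes here are toy values:

```python
import numpy as np

rng = np.random.default_rng(7)
V, h, N = 16, 4, 2
E = rng.standard_normal((V, h))
tokens = np.array([[0, 5, 9, 15]])  # (b=1, s=4)

# Row-linear vocab sharding: rank i owns rows [i*V/N, (i+1)*V/N).
partials = []
for i in range(N):
    lo, hi = i * V // N, (i + 1) * V // N
    local = E[lo:hi]                           # this rank's vocab slice
    in_range = (tokens >= lo) & (tokens < hi)  # tokens this rank can serve
    idx = np.where(in_range, tokens - lo, 0)   # clamp foreign ids to a dummy row
    out = local[idx] * in_range[..., None]     # zero out-of-range lookups
    partials.append(out)

# All-reduce sums the partials; exactly one rank contributed each position.
emb = sum(partials)

assert np.allclose(emb, E[tokens])
```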
10. Limitations of TP + SP
| Limitation | Description |
|---|---|
| Sequence length scaling | In the TP region, activations are (b, s, h/N); as s grows, activation memory still scales linearly with s in these regions |
| Inter-node communication | For N_TP > 8 (exceeding a single node), bandwidth drops from NVLink (~900 GB/s) to network interconnect (~100–400 GB/s), causing severe throughput degradation |
| Critical path communication | All-gather and reduce-scatter remain on the critical path and cannot be fully overlapped with computation |
| Head count constraint | N_TP ≤ n_heads (attention heads must divide evenly across ranks), limiting the maximum parallelism degree |
Solutions to these limitations:
Context Parallelism (CP): Addresses sequence-length-induced activation memory blowup by sharding the attention computation across the sequence dimension
Pipeline Parallelism (PP): Addresses model-too-large-for-one-node by partitioning layers across nodes, avoiding the need for NTP>8
The TP+SP communication pattern (all-gather on entering a TP region, reduce-scatter on exiting it) repeats for both the MHA and MLP sub-blocks within each Transformer layer, yielding four communication operations per layer per pass (two for MHA + two for MLP).