Data Parallelism: A Comprehensive Technical Treatment
1. Foundational Concept
Data Parallelism (DP) is the most fundamental distributed training strategy for deep learning. The core idea is straightforward: replicate the entire model on every GPU, feed each replica a different shard of the training batch, and average the resulting gradients across all replicas before each parameter update.
Formal Setup
Let the model parameters be denoted $\theta$, and suppose each of $N$ GPUs computes a local gradient $g_i$ on its own micro-batch.

The synchronized (averaged) gradient used for the parameter update is:

$$\bar{g} = \frac{1}{N} \sum_{i=1}^{N} g_i$$

This averaged gradient $\bar{g}$ is mathematically equivalent to the gradient of a single large batch containing all $N$ micro-batches.

The parameter update then proceeds identically on every GPU (with learning rate $\eta$):

$$\theta \leftarrow \theta - \eta \, \bar{g}$$
Because every GPU applies the same averaged gradient to the same parameters, all replicas remain synchronized after every step.
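The synchronization logic can be sketched in a few lines of plain Python. This is a toy simulation with made-up numbers (no real GPUs or collectives): each "replica" is a list, and the all-reduce is an explicit average.

```python
# Toy simulation of one synchronous data-parallel step.
N = 4          # data-parallel degree (number of "GPUs")
lr = 0.1       # learning rate eta

# identical replicas of a 3-parameter model
params = [[1.0, 2.0, 3.0] for _ in range(N)]

# each replica sees a different micro-batch, hence a different local gradient
local_grads = [[0.1 * (i + 1), 0.2, -0.1 * i] for i in range(N)]

# "all-reduce": average the gradients element-wise across replicas
avg_grad = [sum(g[j] for g in local_grads) / N for j in range(3)]

# every replica applies the same averaged gradient
for p in params:
    for j in range(3):
        p[j] -= lr * avg_grad[j]

# replicas remain bitwise identical after the step
assert all(p == params[0] for p in params)
```

Because each replica performs exactly the same arithmetic on exactly the same averaged gradient, the final `assert` holds on every step, which is the invariant DP relies on.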
2. The All-Reduce Communication Primitive
The operation that computes $\bar{g}$ across all GPUs is the all-reduce collective: it combines (sums or averages) a tensor contributed by every rank and distributes the result back to all ranks.

After the all-reduce completes, every GPU holds the identical averaged gradient $\bar{g}$.
Communication Cost of All-Reduce (Ring All-Reduce)
For a tensor of size $\Phi$ bytes across $N$ GPUs, a ring all-reduce moves $2\frac{N-1}{N}\Phi \approx 2\Phi$ bytes per GPU.

This decomposes into a reduce-scatter phase (volume $\frac{N-1}{N}\Phi$) followed by an all-gather phase (volume $\frac{N-1}{N}\Phi$).
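A small helper makes the cost decomposition concrete. The function name and the byte-based accounting are illustrative (not from any library); it simply evaluates the two formulas above.

```python
def ring_all_reduce_volume(phi_bytes: float, n_gpus: int) -> dict:
    """Per-GPU communication volume of a ring all-reduce on a tensor
    of phi_bytes, split into its reduce-scatter and all-gather phases."""
    per_phase = (n_gpus - 1) / n_gpus * phi_bytes
    return {
        "reduce_scatter": per_phase,
        "all_gather": per_phase,
        "total": 2 * per_phase,   # approaches 2 * phi as n_gpus grows
    }

# Example: 1 GB of gradients across 8 GPUs -> 1.75 GB moved per GPU
vol = ring_all_reduce_volume(1e9, 8)
```

Note how the total is nearly independent of $N$: adding GPUs barely changes per-GPU volume, which is why ring all-reduce scales well in bandwidth terms (latency is another matter, discussed in Section 7).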
3. Naive Data Parallelism and Its Inefficiency
A naive implementation proceeds sequentially:
1. Forward pass on each GPU (computation)
2. Backward pass on each GPU (computation)
3. All-reduce over gradients (communication)
4. Optimizer step (computation)
The critical inefficiency: during step 3, all GPUs are idle — they have finished computation and are waiting for communication to complete. This sequential dependency between computation and communication is fundamentally wasteful.
4. Three Key Optimizations
4.1 First Optimization: Overlapping Gradient Synchronization with the Backward Pass
Key Insight: In the backward pass, gradients are computed layer by layer, starting from the last layer and moving toward the first. The gradient of the last layer is therefore available long before the backward pass finishes.

Therefore, we can trigger the all-reduce for a layer's gradients as soon as they are computed, while the backward pass continues through the earlier layers.
Implementation Mechanism: In PyTorch, this is achieved by registering a post-accumulate gradient hook on each parameter:
```python
def register_backward_hook(self, hook):
    """
    Registers a backward hook for all parameters of the model that
    require gradients.
    """
    for p in self.module.parameters():
        if p.requires_grad:
            p.register_post_accumulate_grad_hook(hook)
```
When the gradient for a parameter has been fully accumulated, its hook fires and enqueues an asynchronous all-reduce for that gradient, without waiting for the rest of the backward pass.

Result: The all-reduce communication is overlapped with backward computation. In the ideal case, by the time the backward pass finishes computing the first layer's gradients, the gradients of all later layers have already been synchronized.
4.2 Second Optimization: Bucketing Gradients
Key Insight: GPU kernels and network operations are significantly more efficient on large contiguous tensors than on many small tensors. Launching an independent all-reduce for each individual parameter tensor incurs excessive kernel launch overhead and underutilizes network bandwidth.
Solution: Group gradients into buckets of a fixed size (e.g., 25 MB in PyTorch DDP). A single all-reduce is launched for the entire bucket once all gradients within that bucket have been computed.
Procedure:
- Assign model parameters to buckets in reverse computation order (so that the last-computed gradients fill the first bucket).
- When all gradients in a bucket are ready, launch a single all-reduce for that bucket.
- This reduces the number of communication operations from $n$ (the number of parameter tensors) to the number of buckets, which is orders of magnitude smaller.
Analogy: Instead of shipping many small packages individually, pack items into a few large boxes and ship those — reducing per-item shipping overhead.
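The bucket-assignment step can be sketched in a few lines. This is a hypothetical, simplified packer (real PyTorch DDP buckets by bytes and dtype and handles unused parameters); it only illustrates the reverse-order greedy grouping.

```python
def assign_buckets(param_sizes_mb, bucket_cap_mb=25.0):
    """Greedily pack parameter tensors (listed in forward order) into
    buckets, walking in reverse computation order so that the gradients
    produced first (last layers) fill the first bucket."""
    buckets, current, current_mb = [], [], 0.0
    for idx in reversed(range(len(param_sizes_mb))):  # last layer first
        size = param_sizes_mb[idx]
        if current and current_mb + size > bucket_cap_mb:
            buckets.append(current)   # bucket full: one all-reduce each
            current, current_mb = [], 0.0
        current.append(idx)
        current_mb += size
    if current:
        buckets.append(current)
    return buckets

# 6 parameter tensors of 10 MB each -> 3 buckets instead of 6 all-reduces
print(assign_buckets([10.0] * 6))  # -> [[5, 4], [3, 2], [1, 0]]
```

Each inner list is one all-reduce launch; the indices run backward because that is the order in which gradients become ready.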
4.3 Third Optimization: Interplay with Gradient Accumulation
Gradient accumulation performs $K$ sequential forward-backward passes on separate micro-batches, summing the gradients locally, before a single optimizer step is applied.
Problem with naive combination: If DP is active, an all-reduce is triggered after every backward pass. But during gradient accumulation, we only need the synchronized gradient after the final accumulation step — the intermediate all-reduces are wasteful.
Solution: Use a no-sync context manager to disable gradient synchronization during the first $K-1$ backward passes, so the all-reduce fires only on the final one:
```python
for k in range(K):
    if k < K - 1:
        with model.no_sync():            # disable all-reduce for this backward
            loss = model(micro_batch[k])
            loss.backward()
    else:
        loss = model(micro_batch[k])     # all-reduce triggered on this backward
        loss.backward()
optimizer.step()
```
This reduces communication overhead by a factor of $K$: one all-reduce per optimizer step instead of one per backward pass.
Memory Contiguity Note
Communication operations require tensors to be contiguous in memory. In practice, pre-allocated contiguous communication buffers are used to avoid redundant memory copies. While this accelerates communication, it contributes to peak memory usage during training.
5. Global Batch Size Equation
With data parallelism and gradient accumulation, the relationship between batch sizes is:

$$\text{gbs} = \text{mbs} \times \text{grad\_acc} \times \text{dp}$$
Where:
| Symbol | Definition |
|---|---|
| `gbs` | Global batch size: total number of samples processed per optimizer step |
| `mbs` | Micro-batch size: number of samples per forward pass on a single GPU |
| `grad_acc` | Gradient accumulation steps: number of sequential forward-backward passes before an optimizer step |
| `dp` | Data parallel degree: number of GPU replicas |
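The equation is trivial but worth encoding, since every capacity-planning calculation in the next section builds on it. The function name is ours; the values are from the concrete example later in this section.

```python
def global_batch_size(mbs: int, grad_acc: int, dp: int) -> int:
    """gbs = mbs * grad_acc * dp, in samples per optimizer step."""
    return mbs * grad_acc * dp

# mbs=2 per GPU, 4 accumulation steps, 128 replicas -> 1024 samples per step
assert global_batch_size(mbs=2, grad_acc=4, dp=128) == 1024
```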
Key Practical Principle
Maximize `dp` before resorting to gradient accumulation:

- Data parallelism is inherently parallel: all `dp` GPUs compute simultaneously.
- Gradient accumulation is inherently sequential: the `grad_acc` steps execute one after another.

Therefore, use gradient accumulation only to reach the target global batch size once `dp` and `mbs` are already maxed out.
6. Practical Recipe for 1D Data-Parallel Training
Step-by-Step Procedure
1. Determine the optimal global batch size in tokens, `gbs_tokens`, from the literature or convergence experiments.
2. Select the training sequence length `seq`: typically 2,048 to 8,192 tokens works reliably for current evaluation benchmarks. (Longer documents are rare on the web; shorter sequences suffice for most pretraining.)
3. Convert to samples: `gbs = gbs_tokens / seq`.
4. Find the maximum micro-batch size `mbs` by increasing it on a single GPU until out-of-memory.
5. Determine the available GPUs `dp`.
6. Compute the required gradient accumulation steps: `grad_acc = gbs / (mbs × dp)`.
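The steps above can be condensed into a small planning helper. The function is hypothetical and assumes the quantities divide evenly; a real script would round to the nearest power of two or search nearby configurations.

```python
def plan_data_parallel(gbs_tokens, seq_len, max_mbs, n_gpus):
    """Turn a token budget into a (gbs, mbs, grad_acc, dp) configuration."""
    gbs = gbs_tokens // seq_len                   # samples per optimizer step
    grad_acc = max(1, gbs // (max_mbs * n_gpus))  # sequential passes needed
    return {"gbs": gbs, "mbs": max_mbs, "grad_acc": grad_acc, "dp": n_gpus}

# ~4M-token batches at 4k sequence length, mbs=2 per GPU, 128 GPUs
plan = plan_data_parallel(gbs_tokens=4 * 1024 * 1024, seq_len=4096,
                          max_mbs=2, n_gpus=128)
# plan["gbs"] == 1024, plan["grad_acc"] == 4
```

Re-running with `n_gpus=512` gives `grad_acc == 1`: the same global batch is reached with no sequential accumulation at all, matching the principle of maximizing `dp` first.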
Concrete Example
- Target: `gbs_tokens` = 4M tokens, `seq` = 4,096
- Batch size in samples: `gbs` = 1,024 (nearest power of 2)
- Observation: a single GPU fits `mbs` = 2
| GPUs (`dp`) | `grad_acc` | Behavior |
|---|---|---|
| 128 | 4 | 4 sequential accumulation steps |
| 512 | 1 | No accumulation needed: faster training |
| 1024+ | <1 | GPU-rich: reduce `mbs` (or raise `gbs`) to keep all GPUs busy |
7. Scaling Limits of Data Parallelism
Communication Overhead at Scale
At large `dp` (several hundred GPUs and beyond), three effects erode scaling efficiency:
- Ring latency: The minimum time for a signal to traverse all `dp` nodes in a ring topology scales linearly with `dp`.
- Network bandwidth saturation: The aggregate gradient traffic approaches network capacity limits.
- Overlap breakdown: The backward pass computation can no longer fully mask the growing communication time.
As a result, throughput per GPU decreases with each additional DP rank beyond a critical point (empirically around `dp` ≈ 512).
Memory Limitation
Data parallelism requires that at least one complete layer (and, without any sharding, one full model replica plus a forward pass at `mbs` = 1) fit in a single GPU's memory.
Quick memory estimate for parameters alone:

$$m_{\text{params}} = \Psi \times 2 \ \text{bytes (BF16)}$$

For example, a 70B parameter model requires approximately 140 GB for its BF16 parameters alone, more than a single 80 GB GPU can hold.
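The estimate is a one-liner; the helper below (our naming) uses decimal gigabytes for simplicity.

```python
def param_memory_gb(n_params: float, bytes_per_param: int = 2) -> float:
    """Memory for the parameter tensor alone, in GB (1 GB = 1e9 bytes).
    bytes_per_param=2 corresponds to BF16/FP16 storage."""
    return n_params * bytes_per_param / 1e9

# 70B parameters in BF16: 140 GB, exceeding a single 80 GB accelerator
assert param_memory_gb(70e9) == 140.0
```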
8. Zero Redundancy Optimizer (ZeRO)
Motivation
In vanilla DP, every GPU stores a complete copy of:
- Model parameters
- Gradients
- Optimizer states (e.g., Adam’s first and second moments)
This is massively redundant. ZeRO eliminates this redundancy by partitioning (sharding) these tensors across DP ranks, reconstructing them on demand when needed.
Memory Baseline (Mixed-Precision Training with Adam)
Let $\Psi$ denote the number of model parameters. Standard mixed-precision training with Adam stores the following state:
| Component | Precision | Memory |
|---|---|---|
| Parameters | BF16/FP16 | $2\Psi$ |
| Gradients | BF16/FP16 | $2\Psi$ |
| FP32 master copy of parameters | FP32 | $4\Psi$ |
| Adam first moment ($m$) | FP32 | $4\Psi$ |
| Adam second moment ($v$) | FP32 | $4\Psi$ |
The optimizer states memory multiplier is $k = 12$ bytes per parameter (master copy 4 + $m$ 4 + $v$ 4).

Without FP32 gradient accumulation: $m_{\text{total}} = 2\Psi + 2\Psi + k\Psi = 16\Psi$ bytes.

With FP32 gradient accumulation (optional additional $4\Psi$): $m_{\text{total}} = 20\Psi$ bytes.
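The accounting above fits in one small function (our naming; byte counts exactly as in the table):

```python
def model_state_memory(psi: float, fp32_grad_accum: bool = False) -> float:
    """Bytes of model state for mixed-precision Adam training:
    BF16 params (2) + BF16 grads (2) + FP32 master (4) + Adam m, v (4 + 4),
    optionally + an FP32 gradient-accumulation buffer (4)."""
    k = 12  # optimizer states per parameter: master copy + m + v
    total = 2 * psi + 2 * psi + k * psi
    if fp32_grad_accum:
        total += 4 * psi
    return total

# per-parameter multipliers: 16 bytes, or 20 with FP32 accumulation
assert model_state_memory(1) == 16
assert model_state_memory(1, fp32_grad_accum=True) == 20
```

For the 70B example, `model_state_memory(70e9)` gives about 1.12 TB of model state, which motivates the sharding stages that follow.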
8.1 ZeRO Stage 1: Optimizer State Partitioning
What is sharded: Optimizer states only (FP32 master weights, Adam $m$ and $v$). Each GPU stores $1/N_d$ of them, where $N_d$ is the data-parallel degree.
How it works:
1. Forward pass: Each GPU uses the full BF16 parameters $\theta$ (identical across replicas) on its own micro-batch.
2. Backward pass: Each GPU computes full gradients $g$ on its micro-batch.
3. Reduce-scatter on gradients: Instead of an all-reduce, perform a reduce-scatter. After this operation, GPU $i$ holds only the $i$-th shard of the summed gradients, precisely the shard corresponding to its optimizer state partition.
4. Local optimizer step: Each GPU updates only its $1/N_d$ shard of optimizer states and produces $1/N_d$ of the updated FP32 parameters, which are cast back to BF16.
5. All-gather on BF16 parameters: Reconstruct the full BF16 parameter set on every GPU for the next forward pass.
Memory per GPU: $2\Psi + 2\Psi + \frac{k\Psi}{N_d}$

As $N_d \to \infty$, this approaches $4\Psi$, a $4\times$ reduction over the $16\Psi$ baseline.
Communication:
| Operation | Volume per GPU | When |
|---|---|---|
| Reduce-scatter (gradients) | $\approx \Phi$ | After backward pass |
| All-gather (BF16 parameters) | $\approx \Phi$ | After optimizer step |
Note on reduce-scatter vs. all-reduce: A reduce-scatter has half the communication volume of an all-reduce ($\Phi$ vs. $2\Phi$), so the reduce-scatter plus all-gather here totals $2\Phi$, the same as vanilla DP's all-reduce.
Overlapping Strategies for the All-Gather
The all-gather of BF16 parameters (step 5) is a new communication cost not present in vanilla DP. Two strategies exist to overlap it:
- During the optimizer step: Initiate the all-gather as soon as the first shard is updated, overlapping with updates of remaining shards.
- During the forward pass: Prefetch parameters layer by layer: all-gather layer $\ell{+}1$'s parameters while computing the forward pass for layer $\ell$.
8.2 ZeRO Stage 2: Optimizer State + Gradient Partitioning
Key Insight: Since each GPU only needs its $1/N_d$ gradient shard to update its optimizer state partition, there is no reason to keep the full gradient tensor in memory after the reduce-scatter.
What is sharded: Optimizer states and gradients.
How it works:
The procedure is identical to ZeRO-1, except:
- After the reduce-scatter in the backward pass, each GPU retains only its gradient shard and discards the rest.
- Gradients are released from memory on the fly as they are scattered.
Memory per GPU: $2\Psi + \frac{(2 + k)\Psi}{N_d}$

As $N_d \to \infty$, this approaches $2\Psi$.

Compared to the baseline $16\Psi$, this is up to an $8\times$ reduction.
Communication: Identical to ZeRO-1: one reduce-scatter plus one all-gather, $2\Phi$ total per GPU.
Practical Note: ZeRO-2 has no communication overhead relative to ZeRO-1 while providing strictly better memory savings. Therefore, ZeRO-2 is generally preferred over ZeRO-1.
8.3 ZeRO Stage 3: Full Partitioning (FSDP)
What is sharded: Optimizer states, gradients, and parameters.
PyTorch’s native implementation of ZeRO-3 is called FSDP (Fully Sharded Data Parallelism).
How it works:
Each GPU permanently stores only $1/N_d$ of the parameters, gradients, and optimizer states; the full parameters of a layer exist only transiently, while that layer is being computed.
Forward Pass
For each layer $\ell$, in order:

1. All-gather the full parameters $\theta_\ell$ for layer $\ell$ from all GPUs.
2. Compute the forward pass for layer $\ell$.
3. Discard the non-local parameter shards (free memory).
Backward Pass
For each layer $\ell$, in reverse order:

1. All-gather the full parameters $\theta_\ell$ again (they were discarded after the forward pass).
2. Compute the backward pass for layer $\ell$, producing gradients.
3. Reduce-scatter the gradients to retain only the local shard.
4. Discard the non-local parameter shards.
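The gather/compute/discard cycle of the forward pass can be sketched with plain Python lists. Hypothetical toy setup (4 ranks, 3 layers of 8 parameters each); the point is that only one full layer is ever resident at a time.

```python
N, n_layers, width = 4, 3, 8
shard = width // N

# the "true" full parameters of each layer (used to check the gather)
full = {l: [float(l * 10 + j) for j in range(width)] for l in range(n_layers)}

# persistent per-rank storage: rank r keeps only shard r of every layer
local = {r: {l: full[l][r * shard:(r + 1) * shard] for l in range(n_layers)}
         for r in range(N)}

gathered_ok = True
for l in range(n_layers):
    # all-gather: reconstruct layer l from every rank's shard
    layer_params = [x for r in range(N) for x in local[r][l]]
    gathered_ok = gathered_ok and (layer_params == full[l])
    # ... forward computation for layer l would happen here ...
    del layer_params  # discard the full layer before gathering the next
```

The backward pass repeats the same cycle in reverse layer order, with a reduce-scatter of gradients added after each layer's computation.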
Memory per GPU: $\frac{(2 + 2 + k)\Psi}{N_d} = \frac{16\Psi}{N_d}$

As $N_d$ grows, per-GPU model-state memory shrinks toward zero (activations excluded; see below).
Communication Cost Analysis
| Operation | Count per Step | Volume per Operation | Total |
|---|---|---|---|
| All-gather (forward pass) | 1 (spread over layers) | $\Phi$ | $\Phi$ |
| All-gather (backward pass) | 1 (spread over layers) | $\Phi$ | $\Phi$ |
| Reduce-scatter (gradients) | 1 | $\Phi$ | $\Phi$ |
The total is $3\Phi$, a 1.5× increase over ZeRO-2's $2\Phi$ communication volume.
Prefetching to Overlap Communication
The additional all-gathers can be overlapped with computation via prefetching:
- Forward pass: While computing layer $\ell$, initiate the all-gather for layer $\ell{+}1$'s parameters.
- Backward pass: While computing layer $\ell$, initiate the all-gather for layer $\ell{-}1$'s parameters.
This overlap is effective as long as each layer's communication time stays below its computation time; once $N_d$ or per-layer communication grows too large, the all-gathers become exposed and slow down the step.
Critical Limitation
ZeRO-3 partitions parameters, gradients, and optimizer states but cannot partition activations. Since each DP replica processes a different micro-batch, the activations are unique to each GPU (not duplicated) and therefore cannot be sharded across DP ranks.
Activation memory scales as:

$$m_{\text{act}} \propto L \times \text{seq} \times \text{mbs} \times h$$

where $L$ is the number of layers, $\text{seq}$ the sequence length, $\text{mbs}$ the micro-batch size, and $h$ the hidden dimension.
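A rough estimator for this proportionality (our naming; the constant factor here counts only BF16 layer outputs and ignores attention-specific terms and activation recomputation, which change it substantially):

```python
def activation_memory_bytes(n_layers: int, seq: int, mbs: int,
                            hidden: int, bytes_per: int = 2) -> int:
    """Order-of-magnitude activation memory: L * seq * mbs * h * bytes."""
    return n_layers * seq * mbs * hidden * bytes_per

# Example: 32 layers, 4k sequence, mbs=1, hidden size 4096 -> ~1 GB,
# and this grows linearly in mbs and seq with no help from ZeRO.
gb = activation_memory_bytes(32, 4096, 1, 4096) / 1e9
```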
9. Comparative Summary of ZeRO Stages
| Stage | What is Sharded | Memory per GPU | Communication Volume |
|---|---|---|---|
| Vanilla DP | Nothing | $16\Psi$ | $2\Phi$ |
| ZeRO-1 | Optimizer states | $4\Psi + \frac{12\Psi}{N_d}$ | $2\Phi$ |
| ZeRO-2 | Optimizer states + gradients | $2\Psi + \frac{14\Psi}{N_d}$ | $2\Phi$ |
| ZeRO-3 (FSDP) | Optimizer states + gradients + parameters | $\frac{16\Psi}{N_d}$ | $3\Phi$ |
With $k = 12$ (mixed-precision Adam), the limiting per-GPU model-state memory is:

| Stage | Memory per GPU | Memory as $N_d \to \infty$ |
|---|---|---|
| Vanilla DP | $16\Psi$ | $16\Psi$ |
| ZeRO-1 | $4\Psi + \frac{12\Psi}{N_d}$ | $4\Psi$ |
| ZeRO-2 | $2\Psi + \frac{14\Psi}{N_d}$ | $2\Psi$ |
| ZeRO-3 | $\frac{16\Psi}{N_d}$ | $0$ |
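The per-stage formulas can be evaluated directly. A small calculator (our naming), applied to a 70B-parameter model across 64 DP ranks:

```python
def zero_memory_per_gpu(psi: float, n_d: int, k: int = 12) -> dict:
    """Model-state bytes per GPU for each ZeRO stage (mixed-precision Adam):
    2*psi BF16 params + 2*psi BF16 grads + k*psi optimizer states,
    with the sharded portions divided by the DP degree n_d."""
    return {
        "vanilla_dp": (2 + 2 + k) * psi,
        "zero_1": 2 * psi + 2 * psi + k * psi / n_d,
        "zero_2": 2 * psi + (2 + k) * psi / n_d,
        "zero_3": (2 + 2 + k) * psi / n_d,
    }

# 70B parameters, 64 DP ranks, reported in GB (1 GB = 1e9 bytes)
gb = {stage: b / 1e9 for stage, b in zero_memory_per_gpu(70e9, 64).items()}
# vanilla_dp: 1120 GB, zero_1: ~293 GB, zero_2: ~155 GB, zero_3: 17.5 GB
```

Only ZeRO-3 brings the model state under a typical 80 GB accelerator budget in this scenario, at the cost of the extra $\Phi$ of all-gather traffic analyzed above.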
10. Transition to Further Parallelism Dimensions
Data parallelism with ZeRO provides powerful memory savings for model states (parameters, gradients, optimizer states) but faces three fundamental limits:

- Communication overhead grows with $N_d$, eventually dominating computation time.
- Activation memory is not shardable via ZeRO, because activations differ across DP replicas.
- Single-layer memory must fit on one GPU: ZeRO-3 gathers full layer parameters, so each layer must fit in GPU memory.
These limitations motivate orthogonal parallelism dimensions:
- Tensor Parallelism (TP): Shards parameters, gradients, optimizer states, and activations across devices within a layer — without communicating full model parameters between GPUs.
- Pipeline Parallelism (PP): Distributes different layers across different GPUs.
- Context/Sequence Parallelism (CP/SP): Shards along the sequence length dimension.
Data parallelism constitutes the first dimension of parallelism (1D parallelism), upon which these additional dimensions are composed to enable training of models that exceed single-GPU or single-node capacity.