Finding the Best Training Configuration for Distributed Large Model Training
1. Problem Statement and Decision Framework
Given a model with P_total parameters, a target global batch size B_global, a cluster of N_nodes nodes each containing G GPUs (total GPU count N_GPU = N_nodes × G), and per-GPU memory capacity M_GPU, the objective is to determine the optimal 5D parallelism configuration (D_DP, D_TP, D_PP, D_CP, D_EP) that maximizes Model FLOPs Utilization (MFU):

MFU = (3 · F_model · B_global) / (T_step · N_GPU · Φ_peak)

where F_model is the forward-pass FLOPs per token (approximately 2P for dense transformers), the factor 3 accounts for the backward pass costing roughly twice the forward pass, T_step is the wall-clock time per training step, and Φ_peak is the peak FLOPs/s per GPU (e.g., ≈989 TFLOP/s for H100 SXM in BF16).
2. Step 1 — Fitting a Training Step in Memory
2.1 Memory Accounting
The total per-GPU memory requirement decomposes as:

M_total = M_params + M_grads + M_opt + M_act + M_temp + M_frag

| Component | Symbol | Per-GPU size |
|---|---|---|
| Parameters | M_params | 2P / (D_TP · D_PP) bytes (BF16 working copy on each shard) |
| Gradients | M_grads | 2P / (D_TP · D_PP) bytes |
| Optimizer states | M_opt | 12P / (D_TP · D_PP) bytes (FP32 master copy + first & second moments) |
| Activations | M_act | ∝ b_mbs · s · h · L_local (depends on recomputation strategy) |
| Temporary buffers | M_temp | Communication buffers, workspace allocations |
| Fragmentation overhead | M_frag | CUDA memory allocator overhead |
Here h is the hidden dimension and L_local = L / D_PP is the number of transformer layers assigned to each pipeline stage, with L total layers.
2.2 ZeRO Optimization Stages and Their Memory Impact
ZeRO (Zero Redundancy Optimizer) partitions optimizer states, gradients, and optionally parameters across the D_DP data-parallel ranks:
| ZeRO Stage | What is Sharded | Per-GPU Memory for Params + Grads + Optimizer |
|---|---|---|
| Stage 0 (baseline) | Nothing | (2 + 2 + 12)P / (D_TP · D_PP) = 16P / (D_TP · D_PP) |
| Stage 1 | Optimizer states | (2 + 2)P / (D_TP · D_PP) + 12P / (D_DP · D_TP · D_PP) |
| Stage 2 | Optimizer states + gradients | 2P / (D_TP · D_PP) + (2 + 12)P / (D_DP · D_TP · D_PP) |
| Stage 3 | Optimizer states + gradients + parameters | (2 + 2 + 12)P / (D_DP · D_TP · D_PP) = 16P / (D_DP · D_TP · D_PP) |
The memory feasibility constraint is therefore:
M_total(P, D_DP, D_TP, D_PP, Z, b_mbs, s, recomp) ≤ M_GPU
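The per-stage formulas above can be sketched as a small calculator; this is a minimal illustration (function name and defaults are ours, not from any library):

```python
# Sketch: per-GPU model-state memory (params + grads + optimizer states)
# for each ZeRO stage, following the table above.

def model_state_bytes(P, zero_stage, d_dp=1, d_tp=1, d_pp=1):
    """Per-GPU bytes for BF16 params (2), BF16 grads (2), and FP32
    optimizer states (12: master copy + two Adam moments)."""
    shard = d_tp * d_pp  # TP and PP always shard the model states
    if zero_stage == 0:
        return (2 + 2 + 12) * P / shard
    if zero_stage == 1:
        return (2 + 2) * P / shard + 12 * P / (d_dp * shard)
    if zero_stage == 2:
        return 2 * P / shard + (2 + 12) * P / (d_dp * shard)
    if zero_stage == 3:
        return (2 + 2 + 12) * P / (d_dp * shard)
    raise ValueError("ZeRO stage must be 0-3")

# 7B model on 8 GPUs with ZeRO-3: 16 * 7e9 / 8 = 14 GB per GPU.
print(model_state_bytes(7e9, zero_stage=3, d_dp=8) / 1e9)  # → 14.0
```

Plugging a candidate configuration into this function against M_GPU is exactly the feasibility check of the constraint above.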
2.3 GPU-Rich Case — Decision Heuristics
The decision tree is determined by parameter count P and available GPU count N_GPU:
Case A: P<10B (Small-to-Medium Models)
A single parallelism dimension typically suffices:
Option 1: Tensor Parallelism with D_TP ≤ 8 (intra-node).
Per-GPU parameter memory:
M_params^TP = 2P / D_TP bytes
Option 2: ZeRO-3 with Data Parallelism across 8 GPUs plus full activation recomputation.
Per-GPU total model-state memory:
M_model^Z3 = 16P / D_DP bytes
For a 7B-parameter model on 8 GPUs with ZeRO-3:
M_model^Z3 = 16 × 7×10^9 / 8 = 14 GB per GPU
This easily fits within the 80 GB of an H100, leaving ample room for activations.
Case B: 10B ≤ P ≤ 100B (Large Models, Multi-Node)
Multiple parallelism dimensions must be composed. The number of GPUs required exceeds one node (>8 GPUs), introducing inter-node communication as a critical consideration. Viable configurations combine intra-node TP with either ZeRO-3 data parallelism or pipeline parallelism across nodes.
Communication volume analysis dictates the preferred choice. For ZeRO-3, the per-step communication volume per GPU scales as:
V_ZeRO-3 = 2 × (2P / (D_DP · D_TP · D_PP)) × (D_DP − 1) ≈ 4P / (D_TP · D_PP)  for large D_DP
The factor of 2 accounts for one all-gather (forward) and one reduce-scatter (backward) per training step.
Case C: N_GPU ≥ 512 (Large-Scale Clusters)
At this scale, pure ZeRO-3 becomes communication-bound because the all-gather and reduce-scatter collectives span hundreds of GPUs across many nodes. The communication time for a ring all-gather over D_DP ranks is:

T_all-gather = ((D_DP − 1) / D_DP) × M_params / β_eff

which approaches M_params / β_eff and stops shrinking as D_DP grows. The preferred configuration at this scale therefore combines TP, PP, and DP with ZeRO Stage 2 (sharding optimizer states and gradients but not parameters, thus avoiding the additional all-gather for parameters during the forward pass).
Special Considerations
Context Parallelism (CP) for very long sequences (s ≫ 4096): Activation memory scales linearly with sequence length. When s is large, activations dominate:
M_act ∝ b_mbs · s · h · L_local
Context parallelism partitions the sequence dimension across D_CP GPUs:
M_act^CP = M_act / D_CP
CP is placed across nodes since its communication (ring-attention point-to-point KV exchange) is less bandwidth-intensive than TP.
Expert Parallelism (EP) for Mixture-of-Experts (MoE): If the model has E experts per MoE layer, each with P_expert parameters, expert parallelism distributes experts across D_EP GPUs:
Experts per GPU = E / D_EP,  M_expert^per-GPU = (E / D_EP) · P_expert × bytes per param
EP is placed across nodes because the all-to-all communication pattern (dispatching tokens to experts) can tolerate higher latency compared to the tight synchronization required by TP.
2.4 GPU-Poor Case — Memory Reduction Techniques
When Mtotal>MGPU and additional GPUs are unavailable, two primary strategies reduce memory:
Full Activation Recomputation: Instead of storing all intermediate activations during the forward pass, discard them and recompute them during the backward pass. This eliminates M_act at the cost of one additional forward pass:

T_compute^recomp = 2·T_fwd + T_bwd ≈ (4/3) · T_compute  (since T_bwd ≈ 2·T_fwd)

This is approximately a 33% increase in compute time for a near-complete elimination of activation memory.
Gradient Accumulation: Process the global batch as NGAS sequential micro-batches, accumulating gradients before the optimizer step:
B_global = D_DP × N_GAS × b_mbs × s
Increasing N_GAS allows using a smaller b_mbs, reducing peak activation memory:
M_act ∝ b_mbs · s · h · L_local  ⇒  reducing b_mbs reduces M_act
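The key property making gradient accumulation safe is that averaging equal-sized micro-batch gradients reproduces the full-batch gradient exactly. A minimal sketch on a toy least-squares problem (model and names are illustrative, not from the benchmarks):

```python
# Sketch: gradient accumulation on a toy least-squares problem, showing
# that N_GAS accumulated micro-batch gradients match the full-batch one.
import numpy as np

rng = np.random.default_rng(0)
X, y = rng.normal(size=(32, 4)), rng.normal(size=32)
w = np.zeros(4)

def grad(w, Xb, yb):
    # Gradient of mean squared error over a (micro-)batch.
    return 2 * Xb.T @ (Xb @ w - yb) / len(yb)

# Full-batch gradient in one shot.
g_full = grad(w, X, y)

# Same batch as N_GAS = 4 micro-batches of b_mbs = 8 samples, averaging
# micro-batch gradients before the (single) optimizer step.
n_gas = 4
g_acc = sum(grad(w, Xb, yb) for Xb, yb in
            zip(np.split(X, n_gas), np.split(y, n_gas))) / n_gas

print(np.allclose(g_full, g_acc))  # → True
```

Only the peak activation footprint changes: each backward pass sees b_mbs samples instead of the full batch.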
3. Step 2 — Achieving the Target Global Batch Size
After Step 1 establishes a memory-feasible configuration with some initial values of D_DP, N_GAS, and b_mbs, the current effective global batch size is:
B_current = D_DP × D_CP × N_GAS × b_mbs × s
The target is B_global = B_target (e.g., 1M tokens).
3.1 Increasing B_current to B_target
If B_current < B_target, increase the batch size via:
| Mechanism | Action | Trade-off |
|---|---|---|
| Scale D_DP | Add more GPUs to data parallelism | More hardware, more communication |
| Scale N_GAS | Increase gradient accumulation steps | Longer step time (sequential micro-batches), no memory increase |
If instead B_current > B_target, shrink D_DP by reallocating GPUs to other parallelism dimensions, e.g. D_DP → D_DP / r and D_PP → r · D_PP, where r is the reallocation factor, preserving N_GPU = D_DP × D_TP × D_PP × D_CP × D_EP.
4. Step 3 — Optimizing Training Throughput
With memory feasibility and correct batch size established, the final objective is maximizing MFU. The per-step training time decomposes as:
T_step = T_compute + T_comm − T_overlap + T_idle
where:
T_compute: time for forward + backward + optimizer computation
T_comm: total communication time across all parallelism dimensions
T_overlap: time during which communication is hidden behind computation
T_idle: idle time due to pipeline bubbles, synchronization barriers, and load imbalance
4.1 Communication Time Analysis per Parallelism Dimension
Tensor Parallelism communication per layer (2 all-reduce operations in the forward pass, 2 in the backward pass):
T_comm^TP = 4 · L_local × (2(D_TP − 1) / D_TP) × M_tensor / β_intra
where M_tensor ∝ b_mbs × s × h is the activation tensor size, and β_intra is the intra-node bandwidth (NVLink: ∼900 GB/s bidirectional on H100).
Data Parallelism (ZeRO-2) communication per step (one all-reduce of gradients, decomposed as reduce-scatter + all-gather of optimizer states):
T_comm^DP = 2 × ((D_DP − 1) / D_DP) × M_grad / β_eff
where the factor 2 covers the reduce-scatter and all-gather phases, M_grad = 2P / (D_TP · D_PP), and β_eff is the effective cross-node bandwidth.
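A quick numeric sketch of this formula (the bandwidth figure and model/parallelism sizes below are illustrative assumptions, not benchmark values):

```python
# Sketch: estimated ZeRO-2 gradient-synchronization time per step from
# the formula above.

def dp_comm_seconds(P, d_dp, d_tp=1, d_pp=1, beta_eff=50e9):
    """Reduce-scatter + all-gather of BF16 gradients over d_dp ranks.
    beta_eff: effective cross-node bandwidth in bytes/s."""
    m_grad = 2 * P / (d_tp * d_pp)  # BF16 gradient bytes per model shard
    return 2 * (d_dp - 1) / d_dp * m_grad / beta_eff

# Hypothetical 70B model with TP=8, PP=4, DP=16 at 50 GB/s effective:
print(f"{dp_comm_seconds(70e9, d_dp=16, d_tp=8, d_pp=4):.3f} s")
```

Comparing this estimate against the measured backward-pass time tells you whether the gradient synchronization can still be hidden by overlap.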
Pipeline Parallelism introduces bubble overhead rather than bandwidth-limited communication. For a 1F1B schedule with D_PP stages and m micro-batches:
T_idle^PP = ((D_PP − 1) / m) × T_step^ideal
The pipeline bubble fraction is:
f_bubble = (D_PP − 1) / (m + D_PP − 1) ≈ (D_PP − 1) / m  when m ≫ D_PP
where m = N_GAS is the number of micro-batches flowing through the pipeline per optimizer step.
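The bubble fraction is simple enough to check numerically; a minimal sketch:

```python
# Sketch: 1F1B pipeline bubble fraction from the formula above.

def bubble_fraction(d_pp, m):
    """Fraction of step time lost to pipeline fill/drain with d_pp
    stages and m micro-batches."""
    return (d_pp - 1) / (m + d_pp - 1)

# 4 stages, 16 micro-batches: 3 / 19 ≈ 15.8% of the step is bubble.
print(bubble_fraction(4, 16))

# With m >> d_pp the approximation (d_pp - 1) / m becomes accurate:
print(bubble_fraction(8, 512), 7 / 512)
```

This makes the scaling pressure explicit: doubling D_PP without also growing m roughly doubles the wasted fraction.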
4.2 Throughput Optimization Heuristics (Ordered by Priority)
Priority 1 — Maximize TP within a node:
D_TP → G (e.g., D_TP = 8 on 8-GPU nodes)
Since TP communication occurs over NVLink (βintra≫βinter), this minimizes communication latency. However, TP introduces synchronization points per layer, so there exists a diminishing-returns threshold where computation per GPU becomes too small relative to communication.
The compute-to-communication ratio for TP must satisfy:

ρ_TP = T_compute^TP / T_comm^TP ≳ 1

When ρ_TP drops below ∼1, TP becomes communication-bound.
Priority 2 — Scale DP with ZeRO-3 while maintaining B_target:
If D_DP × N_GAS × b_mbs × s = B_target can be maintained, increasing D_DP distributes computation while keeping the batch size constant (by reducing N_GAS). This is beneficial because fewer sequential accumulation steps shorten the wall-clock step time while per-GPU memory stays unchanged.
Priority 3 — Transition to PP when DP communication saturates:
When DP communication time T_comm^DP can no longer be overlapped with backward computation (i.e., T_comm^DP > T_compute^bwd), introduce pipeline parallelism to reduce D_DP:
D_DP^new = D_DP^old / D_PP^new
This trades DP communication overhead for pipeline bubble overhead. The transition is beneficial when:
((D_PP − 1) / m) · T_compute < T_comm^DP-excess
Priority 4 — Tune micro-batch size b_mbs:
The micro-batch size affects multiple performance dimensions simultaneously. Increasing b_mbs:
- raises M_act (higher memory)
- raises GPU utilization (larger matrix multiplications, better SM occupancy)
- lowers m = B_target / (D_DP · b_mbs · s) (fewer micro-batches, larger pipeline bubble)
- raises T_comm^TP (larger activation tensors to communicate)
The optimal b_mbs* balances these competing effects and must be found empirically.
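The micro-batch/bubble tension can be made concrete with a short sweep; the batch, sequence, and parallelism sizes below are illustrative assumptions:

```python
# Sketch: how b_mbs trades micro-batch count against pipeline bubble,
# using the relations above (1M-token batch, s=4096, D_DP=8, D_PP=4).

B_target, s, d_dp, d_pp = 1_048_576, 4096, 8, 4

for b_mbs in (1, 2, 4, 8):
    m = B_target // (d_dp * b_mbs * s)       # micro-batches per step
    f_bubble = (d_pp - 1) / (m + d_pp - 1)   # pipeline bubble fraction
    print(f"b_mbs={b_mbs}: m={m:3d}, bubble={f_bubble:.1%}")
```

Each doubling of b_mbs halves m, so the bubble fraction climbs steeply even as per-GPU matrix multiplications get more efficient.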
5. Benchmarking Thousands of Configurations
5.1 Search Space Enumeration
For a given model size P and cluster size N_GPU, the total number of valid configurations is:

|S| = Σ_{(D_DP, D_TP, D_PP, D_CP, D_EP)} Σ_{Z ∈ {0,1,2,3}} Σ_{b_mbs ∈ B} Σ_{N_GAS ∈ G} 1[memory and batch-size constraints hold]

where B is the set of candidate micro-batch sizes, G is the set of valid gradient accumulation steps, and the indicator function 1[·] filters configurations that satisfy memory constraints and batch size targets. Even after pruning infeasible configurations, |S| remains in the thousands across all model sizes and cluster sizes.
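The enumeration can be sketched for one (P, N_GPU) point; CP and EP are omitted, activations are ignored, and the candidate sets and 80 GB limit are illustrative assumptions:

```python
# Sketch: enumerating the search space S for one (P, N_GPU) point and
# filtering by the model-state memory constraint of Step 1 (activation
# memory ignored for brevity).
from itertools import product

P, n_gpu, m_gpu = 70e9, 64, 80e9
B_target, s = 1_048_576, 4096

configs = []
for d_tp, d_pp, zero, b_mbs in product((1, 2, 4, 8), (1, 2, 4, 8),
                                       (0, 1, 2, 3), (1, 2, 4)):
    if n_gpu % (d_tp * d_pp):
        continue
    d_dp = n_gpu // (d_tp * d_pp)
    # ZeRO sharding denominators for params, grads, optimizer states.
    p_shard = d_tp * d_pp * (d_dp if zero == 3 else 1)
    g_shard = d_tp * d_pp * (d_dp if zero >= 2 else 1)
    o_shard = d_tp * d_pp * (d_dp if zero >= 1 else 1)
    model_state = 2 * P / p_shard + 2 * P / g_shard + 12 * P / o_shard
    if model_state <= m_gpu and B_target % (d_dp * b_mbs * s) == 0:
        configs.append((d_dp, d_tp, d_pp, zero, b_mbs))

print(len(configs), "feasible configurations before benchmarking")
```

Even this heavily reduced space yields dozens of candidates; adding D_CP, D_EP, recomputation choices, and N_GAS candidates is what pushes |S| into the thousands.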
5.2 Benchmark Setup and Experimental Conditions
The benchmarks referenced here were conducted with:

| Parameter | Value |
|---|---|
| Sequence length s | 4096 tokens |
| Global batch size B_global | 1M tokens (= 1,048,576) |
| GPUs per node G | 8 × H100 SXM |
| Nodes N_nodes | 1 to 64 |
| Total GPUs N_GPU | 8 to 512 |
| Interconnect (intra-node) | NVLink (900 GB/s bidirectional) |
| Interconnect (inter-node) | InfiniBand (400 Gb/s) |
| Precision | BF16 mixed precision |
For each (P, N_GPU) pair, every valid configuration C ∈ S was benchmarked, and the MFU was recorded.
5.3 Heatmap Analysis and Key Insights
The heatmap visualization plots the optimal configuration C* and its corresponding MFU for each combination of model size P and node count N_nodes.
Insight 1: Efficiency Decreases with Increasing Node Count (Especially for Small Models)
For a fixed model size P, increasing N_nodes (and hence N_GPU) reduces MFU. The root cause is the drop in arithmetic intensity:

Arithmetic Intensity = FLOPs per GPU / Communication bytes per GPU = (F_model · B_global / N_GPU) / V_comm

When P is small, F_model = 2P is small, so the numerator shrinks as N_GPU increases while V_comm does not decrease proportionally. For small models, even increasing B_global to compensate is impossible when it is constrained to 1M tokens. The scaling efficiency η_scale (MFU at N_GPU relative to MFU on a single node) therefore decays sharply, because the compute-to-communication ratio falls below the threshold for efficient overlap.
Insight 2: Large Models Face Memory Walls on Small Clusters
For large P (e.g., 80B) on few nodes (e.g., 4 nodes = 32 GPUs):
M_model^min = 16P / N_GPU = 16 × 80×10^9 / 32 = 40 GB per GPU
This is the absolute minimum (ZeRO-3 with no activations). With activations, the memory exceeds 80 GB, forcing aggressive recomputation and small b_mbs, which in turn leads to extra recomputation FLOPs, poor SM occupancy from small matrix multiplications, and communication kernels competing with compute kernels for SM resources. This SM contention is the fundamental reason why the assumption "computation and communication can be efficiently overlapped without throughput impact" is violated in practice: the effective compute throughput during overlap falls below Φ_peak because communication consumes SMs and memory bandwidth.
The entire process is iterative and empirical: theoretical analysis narrows the search space S, but the final optimal C∗ must be discovered through systematic benchmarking on the target hardware, as implementation quality, network topology, and GPU-level resource contention introduce performance variations that analytical models cannot fully capture.