Scaling Distributed Training: Foundations and First Principles
1. High-Level Overview: The Three Fundamental Challenges
Every technique in large-scale model training addresses one or more of three orthogonal resource constraints. These constraints form the scaling trilemma of distributed deep learning.
1.1 Memory Usage — The Hard Constraint
Memory is a binary gate: if the aggregate memory required by a single training step exceeds the available GPU high-bandwidth memory (HBM), training cannot proceed at all — there is no graceful degradation, only an Out-Of-Memory (OOM) crash.
During a single training step, four principal memory occupants coexist on the GPU:
| Memory Occupant | Description |
|---|---|
| Model Parameters | The learnable weight tensors W of the network |
| Gradients | ∇_W L, partial derivatives of the loss w.r.t. each parameter |
| Optimizer States | Auxiliary running statistics (e.g., first and second moment estimates in Adam) |
| Activations | Intermediate tensors stored during the forward pass, required for gradient computation in the backward pass |
Additionally, there are minor but non-negligible constant overheads:
CUDA context/kernels: typically 1–2 GB upon first CUDA tensor allocation.
Fragmentation & temporary buffers: memory that exists but cannot be utilized due to allocator fragmentation or short-lived intermediate results.
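These occupants can be sized with simple arithmetic. Below is a minimal sketch — the `training_memory_gb` helper and its inputs are illustrative assumptions, not framework measurements — assuming FP32 storage and the Adam optimizer:

```python
# Back-of-envelope sizing of the four memory occupants plus the constant
# CUDA context overhead. Assumes FP32 everywhere and Adam (two FP32
# states per parameter). Numbers are illustrative, not framework-exact.

GB = 1024**3

def training_memory_gb(num_params: int, activation_bytes: int,
                       cuda_context_gb: float = 1.5) -> dict:
    """Estimate per-occupant memory for one training step, in GiB."""
    return {
        "parameters":       4 * num_params / GB,   # FP32 weights
        "gradients":        4 * num_params / GB,   # one FP32 grad per weight
        "optimizer_states": 8 * num_params / GB,   # Adam m and v, FP32
        "activations":      activation_bytes / GB, # batch/seq dependent
        "cuda_context":     cuda_context_gb,       # constant overhead
    }

est = training_memory_gb(num_params=1_000_000_000, activation_bytes=20 * GB)
print(est)  # a 1B model: 4 + 4 + 8 = 16 GB static, plus activations + overhead
```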
1.2 Compute Efficiency
Modern accelerators (e.g., NVIDIA H100 SXM5) deliver peak throughput of approximately 989 TFLOPS in dense BF16 Tensor Core operations. However, achieved throughput is invariably lower due to:
Memory-bound operations: kernels whose execution time is dominated by data movement (loads/stores) rather than arithmetic.
Kernel launch overhead: CPU-side latency to dispatch GPU kernels.
Pipeline bubbles: idle time when certain stages of computation must wait for others.
The goal is to maximize the compute utilization ratio:

$$\text{utilization} = \frac{\text{time spent on useful computation}}{\text{total wall-clock time}}$$

Every second the GPU spends waiting — for data transfers from CPU to GPU, for memory allocation, or for synchronization barriers — directly reduces this ratio.
1.3 Communication Overhead — Minimizing GPU Idle Time
In multi-GPU and multi-node settings, GPUs must exchange data (gradients, activations, parameters). Communication occurs over interconnects with heterogeneous bandwidths:
| Interconnect Type | Example Technology | Typical Bandwidth |
|---|---|---|
| Intra-node (fast) | NVLink 4.0 (H100) | 900 GB/s bidirectional |
| Inter-node (slower) | InfiniBand NDR 400 | 400 Gb/s (≈50 GB/s) per port |
Communication overhead keeps GPUs idle (not performing useful arithmetic). Two primary strategies mitigate this:
Overlap communication with computation: launch asynchronous collective operations (e.g., AllReduce) concurrently with backward-pass gradient computation on independent layers.
Topology-aware placement: assign communication-heavy operations to the fast intra-node links and minimize the volume of data traversing slow inter-node links.
1.4 The Scaling Trilemma: Trading Off Resources
These three resources — memory, compute, and communication — are not independent. Optimizing one often comes at the cost of another. Two canonical examples:
| Technique | Saves | Costs |
|---|---|---|
| Activation Recomputation | Memory (activations discarded) | Compute (activations recomputed during backward) |
| Tensor Parallelism | Memory (model sharded across GPUs) | Communication (intermediate activations exchanged) |
Formally, if we denote total training step wall-clock time as $T$, we can decompose it as:

$$T = T_{\text{compute}} + T_{\text{comm}}^{\text{exposed}} + T_{\text{idle}}, \qquad \text{subject to } M_{\text{peak}} \le M_{\text{GPU}}$$

where $T_{\text{comm}}^{\text{exposed}}$ is communication time not overlapped with computation, $M_{\text{peak}}$ is the peak memory during a training step, and $M_{\text{GPU}}$ is the physical HBM capacity.
2. First Steps: Training on One GPU
2.1 The Three Phases of a Training Step
A single training step on one GPU consists of three sequential phases:
Phase 1: Forward Pass
The input batch $X \in \mathbb{R}^{bs \times seq \times d_{\text{input}}}$ is propagated through $L$ successive layers of the model. For a generic layer $\ell$ with parameters $W_\ell$, the forward computation is:

$$a^{(\ell)} = f_\ell\left(a^{(\ell-1)}; W_\ell\right), \quad \ell = 1, 2, \ldots, L$$

where $a^{(0)} = X$ is the input embedding and $a^{(\ell)}$ is the activation (hidden state) output of layer $\ell$. The final output $\hat{y} = a^{(L)}$ is used to compute the scalar loss:

$$\mathcal{L} = \mathcal{L}(\hat{y}, y_{\text{target}})$$
Critical point: Every intermediate activation a(ℓ) must be stored in GPU memory because it is needed for the backward pass (see Phase 2).
Phase 2: Backward Pass (Backpropagation)
Using the chain rule, gradients of the loss with respect to each layer's parameters are computed in reverse order ($\ell = L, L-1, \ldots, 1$):

$$\frac{\partial \mathcal{L}}{\partial W_\ell} = \frac{\partial \mathcal{L}}{\partial a^{(\ell)}} \cdot \frac{\partial a^{(\ell)}}{\partial W_\ell}$$

The upstream gradient $\frac{\partial \mathcal{L}}{\partial a^{(\ell)}}$ is propagated backward through:

$$\frac{\partial \mathcal{L}}{\partial a^{(\ell-1)}} = \frac{\partial \mathcal{L}}{\partial a^{(\ell)}} \cdot \frac{\partial a^{(\ell)}}{\partial a^{(\ell-1)}}$$

As each layer $\ell$ completes its gradient computation, the stored activation $a^{(\ell)}$ is freed from memory, while the gradient tensor $\nabla_{W_\ell} \mathcal{L}$ is allocated.
Phase 3: Optimizer Step
The optimizer uses the accumulated gradients to update all parameters. For the Adam optimizer, the update rule for each parameter tensor $\theta$ at step $t$ is:

$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2$$

$$\hat{m}_t = \frac{m_t}{1-\beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1-\beta_2^t}, \qquad \theta_t = \theta_{t-1} - \eta \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$

where $g_t = \nabla_\theta \mathcal{L}$, $m_t$ is the first moment (momentum), $v_t$ is the second moment (variance), $\beta_1, \beta_2$ are the moment decay rates, $\eta$ is the learning rate, and $\epsilon$ is a numerical stability constant (typically $10^{-8}$).
After the optimizer step, the gradient buffers are zeroed and the cycle repeats.
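The three phases map directly onto a few lines of PyTorch. This is a toy sketch — the model, optimizer, and data below are placeholders, not from the text:

```python
# A complete single-GPU training step on a toy MLP: forward, backward,
# optimizer update, then gradient zeroing.
import torch

model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(),
                            torch.nn.Linear(32, 4))
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
x, y = torch.randn(8, 16), torch.randint(0, 4, (8,))

# Phase 1: forward pass — activations are cached by autograd for Phase 2
loss = torch.nn.functional.cross_entropy(model(x), y)
# Phase 2: backward pass — gradients are allocated, cached activations freed
loss.backward()
# Phase 3: optimizer step, then zero the gradient buffers
opt.step()
opt.zero_grad(set_to_none=True)
```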
2.2 Batch Size: Definitions, Impact, and Practical Ranges
Definition
The batch size (bs) is the number of independent input samples processed in a single forward–backward pass before the optimizer updates the parameters.
In the LLM pretraining community, batch sizes are reported in tokens ($bs_t$) to decouple from the choice of sequence length ($seq$):

$$bs_t = bs \times seq$$
Impact on Convergence
| Regime | Gradient Estimate Quality | Convergence Behavior |
|---|---|---|
| Small bs | High variance (noisy) | Fast early exploration; difficulty converging to sharp optima |
| Large bs | Low variance (accurate) | Diminishing returns per token; slower convergence per token seen |
OpenAI's seminal study on large batch training [1] demonstrated the existence of a critical batch size $B_{\text{crit}}$ for each model and dataset, below which doubling the batch size nearly halves the number of required optimization steps, and above which diminishing returns set in rapidly.
Key practical insight: The sensitivity of final model performance to the exact batch size is low around the optimum — i.e., batch size can be varied within a broad range near $B_{\text{crit}}$ without significant degradation.
Real-World Examples
| Model | Batch Size (tokens) | Total Training Tokens | Notes |
|---|---|---|---|
| Llama 1 | ~4M | 1.4T | Fixed batch size throughout |
| DeepSeek-V3/R1 | 3,072 → 15,360 sequences (~60M tokens) | 14.8T | Batch size ramped during first 469B tokens |
The sweet spot for contemporary LLM pretraining is typically:
$$bs_t \in [4 \times 10^6,\ 60 \times 10^6] \text{ tokens per global batch}$$
3. Memory Usage in Transformers: Detailed Breakdown
3.1 The Four Principal Memory Occupants
| # | Occupant | Depends on Batch Size? | Depends on Sequence Length? |
|---|---|---|---|
| 1 | Model weights W | No | No |
| 2 | Gradients ∇_W L | No | No |
| 3 | Optimizer states (m, v for Adam) | No | No |
| 4 | Activations a^(ℓ) | Yes (linear) | Yes (quadratic) |
Items 1–3 are static with respect to batch size and sequence length — they depend only on the model architecture (parameter count N). Item 4 is dynamic and is the dominant memory consumer for large batch sizes and long sequences.
3.2 Numerical Precision Formats
| Format | Bytes per Value | Exponent Bits | Mantissa Bits | Dynamic Range | Precision |
|---|---|---|---|---|---|
| FP32 | 4 | 8 | 23 | ~10^±38 | High |
| BF16 | 2 | 8 | 7 | ~10^±38 | Reduced mantissa |
| FP16 | 2 | 5 | 10 | max ~6.5×10^4 | Narrower range |
| FP8 (E4M3) | 1 | 4 | 3 | max ~448 | Very low |
The memory footprint of any tensor is:
Memory (bytes)=number of elements×bytes per element
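In PyTorch this formula corresponds to `nelement() * element_size()`. A quick check against the byte widths in the table above:

```python
# Memory footprint = number of elements × bytes per element, verified
# for FP32 (4 bytes) and BF16 (2 bytes) tensors.
import torch

t_fp32 = torch.zeros(1024, 1024, dtype=torch.float32)
t_bf16 = torch.zeros(1024, 1024, dtype=torch.bfloat16)

bytes_fp32 = t_fp32.nelement() * t_fp32.element_size()  # 1024*1024*4
bytes_bf16 = t_bf16.nelement() * t_bf16.element_size()  # 1024*1024*2
print(bytes_fp32, bytes_bf16)  # 4194304 2097152
```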
4. Memory for Weights, Gradients, and Optimizer States
4.1 Parameter Count of a Transformer LLM
For a decoder-only transformer with:
h: hidden dimension
v: vocabulary size
L: number of layers
No fixed positional embeddings (e.g., using RoPE)
The total parameter count is:
$$N = h \cdot v + L \cdot (12h^2 + 13h) + 2h$$

Breakdown of the per-layer term $12h^2 + 13h$:
| Component | Parameters | Count |
|---|---|---|
| Self-attention: W_Q, W_K, W_V | 3 × h × h | 3h² |
| Self-attention: W_O (output projection) | h × h | h² |
| Feed-forward: W_1 (up-projection) | h × 4h | 4h² |
| Feed-forward: W_2 (down-projection) | 4h × h | 4h² |
| LayerNorm (×2 per layer, each with γ, β) | 2 × 2h | 4h |
| Biases in attention and FFN (varies) | — | ~9h |
| Total per layer | — | 12h² + 13h |
The term h·v accounts for the token embedding matrix, and the final 2h accounts for the final LayerNorm (gain and bias).
Scaling observation: As h grows, the dominant term is 12Lh², which grows quadratically in the hidden dimension. The linear terms (13Lh, hv, 2h) become negligible for very large models.
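The closed-form count is easy to evaluate directly. A small sketch (the dimensions below are illustrative, not a specific model's):

```python
# N = h*v + L*(12*h^2 + 13*h) + 2*h, with the dominant 12*L*h^2 term
# shown separately to confirm the scaling observation above.
def param_count(h: int, v: int, L: int) -> int:
    return h * v + L * (12 * h**2 + 13 * h) + 2 * h

h, v, L = 4096, 50304, 32   # illustrative GPT-style dimensions
N = param_count(h, v, L)
dominant = 12 * L * h**2
print(N, dominant / N)  # the 12*L*h^2 term supplies most of N at this scale
```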
4.2 Full Precision (FP32) Training Memory
In pure FP32 training, every value — parameters, gradients, optimizer states — is stored in 32-bit floating point (4 bytes):

$$m_{\text{total}}^{\text{FP32}} = \underbrace{4N}_{\text{params}} + \underbrace{4N}_{\text{grads}} + \underbrace{8N}_{\text{Adam } m,\, v} = 16N \text{ bytes}$$

4.3 Mixed Precision (BF16) Training Memory
In BF16 mixed precision, parameters and gradients are held in BF16 (2 bytes each) for fast compute, while an FP32 master copy of the weights and the FP32 Adam states are kept for the update:

$$m_{\text{total}}^{\text{mixed}} = \underbrace{2N}_{\text{BF16 params}} + \underbrace{2N}_{\text{BF16 grads}} + \underbrace{4N}_{\text{FP32 master weights}} + \underbrace{8N}_{\text{Adam } m,\, v} = 16N \text{ bytes}$$

If gradients are additionally accumulated in FP32 (for stability with small gradient values in BF16), an additional 4N bytes is required:

$$m_{\text{total}}^{\text{mixed, FP32 grad}} = 2N + 2N + 4N + 4N + 8N = 20N \text{ bytes}$$
Key insight: Mixed precision training does not reduce the total memory for weights + gradients + optimizer states compared to full FP32 training (both are 16N bytes without FP32 gradient accumulation). The advantages of mixed precision are:
Faster computation: BF16 Tensor Core operations are 2× faster than FP32.
Reduced activation memory: activations stored in BF16 use half the memory of FP32.
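The 16N and 20N totals can be tabulated programmatically. A small helper (the function name is ours), using the decimal gigabytes the tables in this section use:

```python
# Static training memory (weights + gradients + optimizer states) per
# parameter count, for the two regimes discussed above. Uses 1 GB = 1e9
# bytes to match the section's round numbers.
def static_memory_gb(n_params: float, fp32_grad_accum: bool = False) -> float:
    bytes_per_param = 20 if fp32_grad_accum else 16
    return n_params * bytes_per_param / 1e9

for n in (1e9, 7e9, 70e9, 405e9):
    print(f"{n/1e9:.0f}B: {static_memory_gb(n):,.0f} GB "
          f"(or {static_memory_gb(n, fp32_grad_accum=True):,.0f} GB with FP32 grads)")
```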
4.4 Practical Memory Requirements Table
| Model Size (N) | FP32 or BF16 (no FP32 grad acc) — 16N | BF16 with FP32 grad acc — 20N |
|---|---|---|
| 1B | 16 GB | 20 GB |
| 7B | 112 GB | 140 GB |
| 70B | 1,120 GB | 1,400 GB |
| 405B | 6,480 GB | 8,100 GB |
For reference, an NVIDIA H100 SXM5 has 80 GB HBM3. At 7B parameters, the weights + gradients + optimizer states alone (112–140 GB) already exceed a single GPU’s capacity — before accounting for activations.
5. Memory for Activations
5.1 Why Activations Must Be Stored
During the backward pass, computing $\frac{\partial \mathcal{L}}{\partial W_\ell}$ requires the input activation $a^{(\ell-1)}$ to layer $\ell$:

$$\frac{\partial \mathcal{L}}{\partial W_\ell} = \left(\frac{\partial \mathcal{L}}{\partial a^{(\ell)}}\right)^{\top} \cdot \frac{\partial a^{(\ell)}}{\partial W_\ell}$$

For a linear layer $a^{(\ell)} = W_\ell\, a^{(\ell-1)}$, this becomes:

$$\frac{\partial \mathcal{L}}{\partial W_\ell} = \left(\frac{\partial \mathcal{L}}{\partial a^{(\ell)}}\right)^{\top} \cdot a^{(\ell-1)}$$
Without the stored activation a(ℓ−1), this gradient cannot be computed. Hence, all intermediate activations must be retained in memory from the forward pass until they are consumed in the backward pass.
5.2 Activation Memory Formula
For a transformer model in mixed precision (BF16 activations), the total activation memory is:

$$m_{\text{act}} = L \cdot seq \cdot bs \cdot h \cdot \left(34 + \frac{5 \cdot n_{\text{heads}} \cdot seq}{h}\right)$$
where:
| Symbol | Meaning |
|---|---|
| L | Number of transformer layers |
| seq | Sequence length (number of tokens per sample) |
| bs | Batch size (number of samples) |
| h | Hidden dimension |
| n_heads | Number of attention heads |
Derivation Sketch (per layer, per sample)
Within each transformer layer, the following intermediate tensors must be stored for backpropagation:
| Activation Tensor | Shape | Bytes (BF16, 2 bytes) |
|---|---|---|
| Input to attention LayerNorm | seq × h | 2·seq·h |
| Q, K, V projections | 3 × seq × h | 6·seq·h |
| Attention scores softmax(QKᵀ/√d_k) | n_heads × seq × seq | 2·n_heads·seq² |
| Attention output before W_O | seq × h | 2·seq·h |
| Dropout masks (attention + FFN) | — | ~2·seq·h (1 byte each × 2) |
| Input to FFN LayerNorm | seq × h | 2·seq·h |
| FFN intermediate (4h hidden) | seq × 4h | 8·seq·h |
| FFN output | seq × h | 2·seq·h |
| Residual connections, GeLU input, etc. | — | remaining terms |
Summing all per-layer contributions for a single sample and then scaling by L layers and bs samples yields the formula above. The full accounting is detailed in the NVIDIA recomputation paper [4].
Critical scaling behavior:
Linear in bs (batch size)
Linear in L (depth)
Quadratic in seq (due to the n_heads·seq²/h attention-score term)
The attention score memory ∝ seq² dominates for long sequences
This means:

$$m_{\text{act}} = O(L \cdot bs \cdot n_{\text{heads}} \cdot seq^2) \quad \text{for large } seq$$

For short sequences, the 34·seq·h term dominates and memory is approximately linear in seq. For long sequences, the quadratic term takes over and activation memory explodes.
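The formula can be coded directly to locate the crossover where the quadratic attention term overtakes the constant 34. The helper names and dimensions below are illustrative:

```python
# Activation memory per the formula above, plus the sequence length at
# which the attention term 5*n_heads*seq/h equals the linear constant 34.
def activation_memory_bytes(L, seq, bs, h, n_heads):
    """m_act = L * seq * bs * h * (34 + 5 * n_heads * seq / h), BF16 bytes."""
    return L * seq * bs * h * (34 + 5 * n_heads * seq / h)

def crossover_seq(h, n_heads):
    """Sequence length where the quadratic term matches the linear term."""
    return 34 * h / (5 * n_heads)

# Illustrative dimensions (not a specific model):
print(crossover_seq(h=4096, n_heads=32))  # ~870 tokens
print(activation_memory_bytes(32, 4096, 1, 4096, 32) / 1024**3)  # 97.0 GiB
```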
6. Activation Recomputation
6.1 Core Idea
Activation recomputation trades compute for memory: instead of storing all intermediate activations during the forward pass, we discard most of them and recompute them on-the-fly during the backward pass from a small set of saved checkpoints.
Without recomputation, memory for activations is O(L) (all L layers' activations stored). With recomputation, if we checkpoint only every k-th layer, activation memory becomes O(L/k), but we pay an additional forward-pass compute cost to recompute the discarded activations.
6.2 Strategies
Strategy 1: Full Recomputation
What is stored: only the activation at the boundary of each transformer layer (i.e., $a^{(\ell)}$ for $\ell = 0, 1, \ldots, L$, but none of the intermediate tensors within each layer).
Recomputation cost: during the backward pass, for each layer ℓ, the entire forward pass through layer ℓ must be re-executed to reconstruct internal activations (attention scores, FFN intermediates, etc.).
Net effect: approximately one additional full forward pass is performed during each backward pass.
Compute overhead: typically +30% to +40% wall-clock time increase per training step.
Memory savings: maximal — all intra-layer activations are freed.
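Full recomputation is available in PyTorch as gradient checkpointing via `torch.utils.checkpoint`. A toy sketch with placeholder blocks:

```python
# Full recomputation sketch: only each block's boundary activation is
# kept; intra-block tensors are recomputed on the fly during backward.
import torch
from torch.utils.checkpoint import checkpoint

blocks = torch.nn.ModuleList(
    [torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.GELU())
     for _ in range(4)]
)

x = torch.randn(8, 64, requires_grad=True)
h = x
for block in blocks:
    # checkpoint() discards the block's internal activations and replays
    # the block's forward pass when gradients are needed.
    h = checkpoint(block, h, use_reentrant=False)
h.sum().backward()
```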
Strategy 2: Selective Recomputation
The key observation from the NVIDIA paper [4] is that not all activations are equally costly to store or cheap to recompute:
| Activation Type | Memory Footprint | Recompute FLOPS Cost |
|---|---|---|
| Attention scores (QKᵀ/√d_k) | Large (∝ n_heads·seq²) | Low (matrix multiply, softmax) |
| FFN intermediates | Moderate (∝ seq·4h) | High (large matrix multiplications) |
Selective strategy: discard and recompute only the attention-related activations (which are large but cheap to recompute), while retaining the FFN activations (which are expensive to recompute).
For GPT-3 (175B):
Activation memory reduction: ∼70%
Compute overhead: only ∼2.7%
This is a dramatically better trade-off than full recomputation.
Example — DeepSeek-V3: uses Multi-Head Latent Attention (MLA), which compresses the Q, K, V representations into a low-rank latent space. This reduces the attention activation memory even further, making selective checkpointing even more effective:
$$\tilde{K}, \tilde{V} \in \mathbb{R}^{seq \times d_c}, \quad d_c \ll h$$

where $d_c$ is the compressed latent dimension, drastically reducing the stored activation size for key–value pairs.
6.3 FlashAttention and Native Recomputation
FlashAttention (Dao et al.) is a hardware-aware exact attention algorithm that:
Computes attention in tiled blocks that fit in GPU SRAM (on-chip memory), avoiding materialization of the full seq×seq attention matrix in HBM.
Natively integrates selective recomputation: during the backward pass, attention scores are recomputed from Q, K, V rather than loaded from HBM.
Since FlashAttention is now the default in most training frameworks, practitioners using it are already benefiting from selective recomputation without explicit gradient checkpointing configuration for the attention layers.
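In PyTorch, these fused kernels are reachable through `torch.nn.functional.scaled_dot_product_attention`; when a fused backend is eligible, the full seq×seq score matrix is never materialized in HBM. The shapes below are illustrative:

```python
# FlashAttention-style exact attention via PyTorch's fused SDPA entry
# point. Falls back to the math backend on CPU; exactness is unchanged.
import torch
import torch.nn.functional as F

bs, n_heads, seq, d_head = 2, 8, 128, 64
q = torch.randn(bs, n_heads, seq, d_head)
k = torch.randn(bs, n_heads, seq, d_head)
v = torch.randn(bs, n_heads, seq, d_head)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 128, 64])
```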
6.4 Hardware FLOPS Utilization (HFU) vs. Model FLOPS Utilization (MFU)
Recomputation adds “extra” floating-point operations that are not part of the theoretical minimum computation for a forward + backward pass. This creates an important distinction in efficiency metrics:
Hardware FLOPS Utilization (HFU):

$$\text{HFU} = \frac{F_{\text{forward}} + F_{\text{backward}} + F_{\text{recompute}}}{\Delta t \cdot \text{Peak FLOPS}}$$

This measures how effectively the hardware arithmetic units are utilized, including recomputation. A high HFU means the GPU is kept busy, but some of that work is "redundant."
Model FLOPS Utilization (MFU):

$$\text{MFU} = \frac{F_{\text{forward}} + F_{\text{backward}}}{\Delta t \cdot \text{Peak FLOPS}}$$

This measures how much useful (non-redundant) computation the hardware performs per unit time. MFU is the better metric for comparing different hardware or different training configurations, because it rewards setups that can avoid recomputation (e.g., by having more memory available).
For an ideal training setup with no recomputation: HFU=MFU.
For a setup with full recomputation: HFU>MFU (because the denominator is the same but HFU’s numerator includes the extra forward pass).
7. Gradient Accumulation
7.1 Core Mechanism
Gradient accumulation decouples the global batch size (which determines the gradient quality and convergence behavior) from the micro-batch size (which determines the activation memory per forward–backward pass).
The procedure for a single optimizer step with G gradient accumulation steps:
For i = 1, 2, …, G:
  1. Forward pass on micro-batch B_i of size mbs
  2. Backward pass: compute ∇_W L(B_i)
  3. Accumulate: ḡ ← ḡ + ∇_W L(B_i)
After the loop:
  4. Average: ḡ ← ḡ / G
  5. Optimizer step: W ← Adam(W, ḡ)
  6. Zero: ḡ ← 0
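The procedure above maps onto a standard PyTorch loop (toy model and data; dividing each micro-batch loss by G before `backward()` is equivalent to averaging the summed gradients):

```python
# Gradient accumulation: G sequential micro-batch passes per optimizer
# step. Gradients sum into .grad across backward() calls until zeroed.
import torch

model = torch.nn.Linear(16, 1)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
G, mbs = 4, 8
data = [(torch.randn(mbs, 16), torch.randn(mbs, 1)) for _ in range(G)]

opt.zero_grad(set_to_none=True)
for xb, yb in data:                      # G sequential micro-batch passes
    loss = torch.nn.functional.mse_loss(model(xb), yb) / G
    loss.backward()                      # grads accumulate into .grad
opt.step()                               # one optimizer step for gbs = G*mbs
opt.zero_grad(set_to_none=True)
```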
7.2 Batch Size Relationship
$$gbs = mbs \times grad\_acc$$
where:
| Symbol | Meaning |
|---|---|
| gbs | Global batch size — total number of samples per optimizer step |
| mbs | Micro-batch size — number of samples per single forward–backward pass |
| grad_acc (G) | Gradient accumulation steps — number of sequential forward–backward passes per optimizer step |
In tokens:
$$gbs_t = gbs \times seq = mbs \times grad\_acc \times seq$$
7.3 Memory Trade-Off
| Without Gradient Accumulation | With Gradient Accumulation |
|---|---|
| m_act ∝ gbs·seq (full batch in memory) | m_act ∝ mbs·seq (only one micro-batch in memory) |
| Single forward–backward pass | G sequential forward–backward passes |

Activation memory reduction factor: gbs / mbs = G
Drawback: Gradient accumulation requires G sequential forward–backward passes per optimizer step. The wall-clock time per optimizer step increases approximately linearly with G:

$$T_{\text{step}} \approx G \cdot (T_{\text{fwd}} + T_{\text{bwd}}) + T_{\text{opt}}$$

Subtle memory note: Gradient accumulation requires a persistent gradient buffer of size m_grad that persists across all G micro-batch passes. Without gradient accumulation, gradients can be computed and immediately consumed during the backward pass, enabling slightly lower peak memory through operator fusion. With accumulation, gradients from previous micro-batches must persist, creating a small memory overhead.
7.4 The Parallelism Opportunity
The G micro-batch forward–backward passes are independent computations (different input samples, no inter-dependencies). This independence is precisely what enables Data Parallelism: distribute the G micro-batches across G GPUs, perform forward–backward passes simultaneously, and synchronize gradients via an AllReduce collective before the optimizer step. This is covered in the next section of the book.
8. Profiling GPU Compute and Communication
8.1 The PyTorch Profiler
PyTorch provides a built-in profiler that traces CPU and GPU activity at the kernel level. The API:
import torch

def train_step():
    pass  # placeholder: one forward/backward/optimizer iteration

steps = 5  # wait(1) + warmup(1) + active(3)

with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=3),
    on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/profile'),
    with_stack=True,
) as prof:
    for step in range(steps):
        train_step()
        prof.step()  # advance the profiler's schedule at each step boundary
Parameters explained:
| Parameter | Meaning |
|---|---|
| activities | Which devices to trace (CPU thread activity, CUDA kernel launches and execution) |
| schedule(wait=1, warmup=1, active=3) | Skip 1 step (cold start), warm up for 1 step (JIT compilation, allocator), then actively trace 3 steps |
| on_trace_ready | Callback to export the trace (here, to TensorBoard format) |
| with_stack=True | Capture Python call stacks for each operation (enables source-code-level attribution) |
8.2 Anatomy of a Profiler Trace
A profiler trace visualized in TensorBoard or Chrome’s chrome://tracing reveals multiple concurrent timelines:
| Timeline | What It Shows |
|---|---|
| CPU threads | Python execution, kernel launch commands, data loading, synchronization calls |
| GPU memory | Allocation and deallocation events, peak memory watermark |
Key patterns to identify:
Sequential compute and communication: If a gradient AllReduce operation appears after the backward pass completes (rather than overlapping with it), there is unnecessary GPU idle time. The fix is to launch AllReduce asynchronously as soon as each layer’s gradients are ready.
Idle GPU time (gaps between kernels): May indicate:
CPU bottleneck (kernel launch overhead, data preprocessing)
Memory allocation stalls (fragmentation forcing the CUDA caching allocator to search or defragment)
CUDA Syncs and CPU↔GPU data transfers: cudaMemcpy (host-to-device or device-to-host) operations appear as blocking periods. These should be minimized or overlapped with computation using pinned memory and non-blocking transfers.
First step anomaly: The first training step exhibits a markedly different profile due to:
CUDA context initialization: loading CUDA kernels, JIT-compiling fused operations.
PyTorch caching allocator warm-up: the allocator performs trial allocations to build a free-block cache (see Zach DeVito’s blog on the PyTorch CUDA caching allocator). Subsequent steps reuse these cached blocks, making allocation nearly free.
Optimizer state initialization: Adam's m_t and v_t tensors are allocated after the first backward pass and persist for all subsequent steps.
Practical consequence: A model that fits in memory during step 1 may OOM at step 2, because the optimizer states (8N bytes for Adam in FP32) are allocated only after the first backward pass, permanently increasing the memory baseline.
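This lazy allocation is easy to observe: PyTorch's Adam creates its per-parameter state entries only on the first `step()` call. A minimal demonstration:

```python
# The "step 1 memory jump": Adam's m and v tensors do not exist until
# the first optimizer step runs.
import torch

model = torch.nn.Linear(32, 32)
opt = torch.optim.Adam(model.parameters())

n_state_before = len(opt.state)          # no optimizer states yet
loss = model(torch.randn(4, 32)).sum()
loss.backward()
opt.step()                               # m and v allocated here
n_state_after = len(opt.state)
print(n_state_before, n_state_after)     # 0, then one entry per parameter
```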
9. Summary: The Landscape Before Multi-GPU Scaling
| Concept | Key Formula / Insight |
|---|---|
| Training step | Forward → Backward → Optimizer update |
| Batch size (tokens) | bs_t = bs × seq |
| Parameter count (Transformer) | N = hv + L(12h² + 13h) + 2h |
| Memory: params + grads + opt (FP32) | 16N bytes |
| Memory: params + grads + opt (mixed, no FP32 grad) | 16N bytes |
| Memory: params + grads + opt (mixed, FP32 grad) | 20N bytes |
| Memory: activations (mixed precision) | m_act = L·seq·bs·h·(34 + 5·n_heads·seq/h) |
| Activation recomputation | Trade compute (+30–40% full, +2.7% selective) for memory (up to 70% activation reduction) |
| Gradient accumulation | gbs = mbs × grad_acc; constant activation memory regardless of gbs |
| Scaling trilemma | Memory ↔ Compute ↔ Communication — optimizing one often costs another |
These single-GPU foundations — memory anatomy, activation recomputation, and gradient accumulation — form the building blocks upon which all multi-GPU parallelism strategies (Data Parallelism, Tensor Parallelism, Pipeline Parallelism, ZeRO, etc.) are constructed.
Corrected Production-Accurate Memory Analysis for Transformer Training
0. The Fundamental Error in the Naïve “16N Bytes” Estimate
The statement “16N bytes for FP32 training” commits a critical analytical error: it treats all four memory occupants — parameters, gradients, optimizer states, and activations — as if they coexist simultaneously at their full sizes throughout the entire training step. In reality:
Gradients do not exist before the backward pass begins. They are allocated layer-by-layer during backpropagation.
Activations do not vanish the instant the backward pass starts. They are freed progressively as each layer’s backward computation consumes them.
The true peak memory occurs at a specific transient moment — the start of the backward pass — when gradients begin accumulating while nearly all activations are still resident.
The corrected analysis requires tracing memory occupancy as a function of the training phase, not collapsing it to a single static number.
1. Notation and Assumptions
1.1 Symbols
| Symbol | Definition |
|---|---|
| N | Total number of learnable parameters |
| L | Number of transformer layers |
| h | Hidden dimension |
| v | Vocabulary size |
| n_heads | Number of attention heads |
| seq | Sequence length (tokens per sample) |
| bs | Batch size (samples) |
| A_ℓ | Activation memory (bytes) stored for layer ℓ during the forward pass |
| A_tot | Total activation memory across all layers: A_tot = Σ_{ℓ=1}^{L} A_ℓ |
| A_rem(k) | Remaining (unreleased) activation memory when the backward pass has reached layer k |
1.2 Precision Assumption (FP32 Baseline)
All tensors — parameters, gradients, optimizer states — are stored in FP32 (4 bytes per scalar). The Adam optimizer stores two auxiliary states per parameter: first moment m_t and second moment v_t.
1.3 Generalized Parameter Count
The formula N = hv + L(12h² + 13h) + 2h is architecture-specific (standard GPT-style decoder with 4h FFN intermediate dimension, bias terms in all projections, two LayerNorms per layer, and no tied embeddings). A generalized parameter count for an arbitrary transformer should account for:

$$N = \underbrace{N_{\text{embed}}}_{\text{embedding}} + \underbrace{\sum_{\ell=1}^{L} N_\ell}_{\text{per-layer}} + \underbrace{N_{\text{head}}}_{\text{output head + final norm}}$$

where each layer $\ell$ contributes:

$$N_\ell = \underbrace{3 \cdot n_{\text{heads}} \cdot d_{\text{head}} \cdot h}_{\text{QKV projections}} + \underbrace{n_{\text{heads}} \cdot d_{\text{head}} \cdot h}_{W_O} + \underbrace{h \cdot d_{\text{ff}} + d_{\text{ff}} \cdot h}_{\text{FFN up + down}} + \underbrace{N_{\text{norm}}^{\ell} + N_{\text{bias}}^{\ell}}_{\text{norms, biases}}$$

with $d_{\text{head}} = h / n_{\text{heads}}$ and $d_{\text{ff}}$ being the FFN intermediate dimension (often 4h, but 8h/3 in SwiGLU-based models like Llama). Whether biases exist, whether embeddings are tied to the output head, whether GQA (grouped-query attention) is used — all change N. The specific formula must be derived per architecture; no single formula is universal.
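As a sanity check of the generalized count, it can be instantiated for a SwiGLU, no-bias, untied-embedding decoder using the publicly documented Llama-7B dimensions; the function name and the reading of those dimensions are our assumptions:

```python
# Generalized per-layer parameter count for a Llama-style decoder:
# SwiGLU FFN (gate + up + down), RMSNorm (gain only), no biases,
# untied input embedding and output head.
def llama_style_param_count(h, L, d_ff, v, n_heads):
    d_head = h // n_heads
    per_layer = (
        3 * n_heads * d_head * h     # Q, K, V projections
        + n_heads * d_head * h       # output projection W_O
        + 3 * h * d_ff               # SwiGLU FFN: gate, up, down
        + 2 * h                      # two RMSNorm gains
    )
    # embedding + layers + final norm + output head
    return v * h + L * per_layer + h + v * h

N = llama_style_param_count(h=4096, L=32, d_ff=11008, v=32000, n_heads=32)
print(N)  # ≈ 6.7e9, consistent with Llama-7B's reported size
```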
For the remainder of this analysis, we treat N as a known constant for a given model and focus on the memory dynamics during training, which are architecture-independent in structure.
2. Fixed Memory: Parameters and Optimizer States
These components depend only on N and are independent of batch size, sequence length, and training phase (once initialized):
| Component | Memory (bytes) | When Allocated | When Freed |
|---|---|---|---|
| Parameters W | 4N | Model initialization | Never (persistent) |
| Adam first moment m_t | 4N | After first backward pass | Never (persistent) |
| Adam second moment v_t | 4N | After first backward pass | Never (persistent) |
| Gradients ∇_W L | 4N | During backward pass | Zeroed after optimizer step |

$$m_{\text{param+opt}} = 4N + 4N + 4N = 12N \text{ bytes (persistent baseline after step 1)}$$

$$m_{\text{grad}} = 4N \text{ bytes (transient, exists only during backward + optimizer)}$$
Critical observation: The gradient tensor's 4N bytes is not always resident. Before the backward pass, no gradient memory is allocated (or it is zero-filled from the previous step's clearing). It builds up during the backward pass. This temporal behavior is what makes the simple "16N" summary misleading.
3. Phase-by-Phase Memory Analysis
3.1 Phase 0: Before Forward Pass (Steady-State Baseline)
At this point in the training loop, the previous step’s optimizer update has completed, gradients have been zeroed, and activations from the previous step have been fully freed.
| Component | Memory |
|---|---|
| Parameters W | 4N |
| Adam m_t | 4N |
| Adam v_t | 4N |
| Activations | 0 |
| Gradients | 0 |

$$M_{\text{phase 0}} = 12N$$
Note on Step 0 vs. Step 1+: At the very first step, the optimizer states m_0 and v_0 do not yet exist (they are initialized to zero tensors only after the first backward pass). Therefore, memory before the first forward pass is only 4N (parameters alone). After the first optimizer step, the persistent baseline jumps to 12N. This explains why a model that fits in GPU memory at step 0 can OOM at step 1.
3.2 Phase 1: During the Forward Pass
As the forward pass proceeds through layers ℓ = 1, 2, …, L, each layer produces intermediate activations a^(ℓ) that must be retained for the backward pass. The activation memory monotonically increases as more layers are processed.
After completing the forward pass through layer k (i.e., layers 1 through k are done):

$$M_{\text{fwd}}(k) = 12N + \sum_{\ell=1}^{k} A_\ell$$
3.3 Phase 2: End of Forward Pass (All Activations Stored)
When the forward pass is complete (k=L), all layer activations are simultaneously resident:
$$M_{\text{end-fwd}} = 12N + A_{\text{tot}}, \quad \text{where } A_{\text{tot}} = \sum_{\ell=1}^{L} A_\ell$$
At this point, the loss L is computed, and the backward pass is about to begin. No gradients have been allocated yet.
For the activation memory formula in mixed precision (BF16):

$$A_{\text{tot}} = L \cdot seq \cdot bs \cdot h \cdot \left(34 + \frac{5 \cdot n_{\text{heads}} \cdot seq}{h}\right)$$

For FP32, all activation bytes double (replace the constant 34 and coefficient 5 with their FP32 equivalents derived from the same accounting with 4-byte storage per element instead of 2-byte).
3.4 Phase 3: During the Backward Pass — The True Peak
The backward pass proceeds in reverse layer order: ℓ = L, L−1, …, 1. At each backward step (processing layer k in reverse):
The activation a^(k−1) (input to layer k) is read to compute ∂L/∂W_k.
The gradient ∇_{W_k} L is written to the gradient buffer (4N_k bytes, where N_k is the parameter count of layer k).
After layer k's backward computation completes, activation a^(k) is freed.
Define the remaining activation memory when the backward pass has just reached layer k (i.e., layers L, L−1, …, k+1 have completed their backward, and layer k is about to begin):

$$A_{\text{rem}}(k) = \sum_{\ell=1}^{k} A_\ell$$

And the accumulated gradient memory at this point (gradients for layers L, L−1, …, k+1 have been computed, and layer k's gradient is about to be computed):

$$G_{\text{acc}}(k) = \sum_{\ell=k+1}^{L} 4N_\ell \approx 4N \cdot \frac{L-k}{L}$$

(Approximate, assuming equal parameter distribution across layers.)
3.5 Identifying the True Peak: Start of Backward (k=L)
The peak memory occurs at the very beginning of the backward pass, when:
All activations are still resident: A_rem(L) = A_tot
Gradients are beginning to be allocated: G_acc is initially small, but the gradient buffer for the full model is typically pre-allocated
In most frameworks (PyTorch, JAX), the full gradient buffer of size 4N is allocated at once when .backward() is called, not incrementally. Therefore, at the first instant of the backward pass:
$$M_{\text{peak}} = \underbrace{4N}_{\text{params}} + \underbrace{8N}_{\text{Adam states}} + \underbrace{4N}_{\text{gradients}} + A_{\text{tot}} = 16N + A_{\text{tot}}$$
This is the worst-case peak memory — the absolute maximum GPU memory required during training.
This is strictly greater than 16N:

$$M_{\text{peak}} = 16N + A_{\text{tot}} > 16N \quad \text{since } A_{\text{tot}} > 0 \text{ for any non-trivial model/batch}$$
3.6 Memory Decrease During Backward Pass
As the backward pass progresses from layer L down to layer 1:
A_rem(k) decreases (activations are freed after each layer's backward)
G_acc(k) increases (gradients accumulate) — but since the gradient buffer is pre-allocated, this doesn't actually change total memory
Net effect: memory decreases monotonically during the backward pass, from the peak of 16N + A_tot toward 16N (all activations freed, all gradients computed).
3.7 Phase 4: Optimizer Step
After the backward pass completes:
| Component | Memory |
|---|---|
| Parameters W | 4N |
| Gradients ∇_W L | 4N |
| Adam m_t | 4N |
| Adam v_t | 4N |
| Activations | 0 |

$$M_{\text{opt}} = 16N$$
The optimizer reads ∇WL, updates mt, vt, and W in-place. After the update, gradients are zeroed (or freed and re-allocated next step).
3.8 Phase 5: After Optimizer Step (Return to Baseline)
Mpost-opt=12N
Parameters (4N) + Adam states (8N) persist. Gradient buffers are zeroed/freed. The cycle repeats.
In words: The peak training memory occurs at the start of the backward pass and equals the full parameter + gradient + optimizer state memory (16N) plus all unreleased activations (Atot).
This is an exact invariant (modulo constant CUDA context overhead and allocator fragmentation) that holds for:
Any transformer architecture (encoder, decoder, encoder-decoder)
The naïve “16N” table dramatically underestimates the true memory requirement by ignoring Atot, which can be comparable to or larger than the parameter-related memory, especially for long sequences and large batch sizes.
7. Why This Correction Matters in Practice
7.1 OOM Debugging
When a training run crashes with OOM, practitioners often check only whether “16N fits in GPU memory.” This is insufficient. The true check must be:
$$16N + A_{\text{tot}}(bs, seq) \le M_{\text{GPU}} - M_{\text{CUDA context}}$$
where MCUDA context≈1–2 GB.
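The check is a one-liner in code; a sketch (the function name and the 2 GiB context default are assumptions):

```python
# OOM feasibility predicate: peak = 16N + A_tot must fit under HBM minus
# the CUDA context overhead. All quantities in bytes.
def fits_in_memory(n_params, a_tot_bytes, hbm_bytes,
                   cuda_context_bytes=2 * 1024**3):
    peak = 16 * n_params + a_tot_bytes
    return peak <= hbm_bytes - cuda_context_bytes

HBM_H100 = 80 * 1024**3
# A 1B model with 40 GiB of activations fits on one H100; a 4B model with
# the same activations does not:
print(fits_in_memory(1_000_000_000, 40 * 1024**3, HBM_H100))  # True
print(fits_in_memory(4_000_000_000, 40 * 1024**3, HBM_H100))  # False
```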
7.2 Batch Size / Sequence Length Selection
Since A_tot ∝ bs·seq (linearly) and A_tot ∝ seq² (quadratically, through the attention term), the maximum feasible bs and seq are not determined by 16N alone but by the residual memory left after the persistent 12N baseline and the transient 4N gradient buffer:

$$A_{\text{tot}}(bs, seq) \le M_{\text{GPU}} - M_{\text{CUDA context}} - 16N$$

For large batch sizes or long sequences where A_tot ≫ N, the peak can be many times the baseline, making the naïve 16N estimate not just slightly wrong but qualitatively misleading about whether a training configuration will fit in memory.