High-Level Overview of Distributed Training: Foundations, Memory Analysis, and First-Step Techniques
1. The Three Fundamental Challenges of Scalable Training
Every technique in large-scale model training addresses one or more of three orthogonal resource constraints:
1.1 Memory Usage (Hard Constraint)
Memory is a binary gate: if the aggregate of model parameters, gradients, optimizer states, and activations exceeds available GPU VRAM, the training step cannot execute at all. There is no graceful degradation — the process terminates with an Out-Of-Memory (OOM) error.
1.2 Compute Efficiency (Utilization Constraint)
Hardware accelerators achieve peak throughput only when arithmetic logic units (ALUs) are saturated with floating-point operations. Any time spent on:
- Memory reads/writes (bandwidth-bound operations)
- Kernel launch overhead
- Waiting for synchronization barriers
represents wasted compute capacity. The goal is to maximize the ratio:

$$\text{compute utilization} = \frac{\text{time spent on useful floating-point work}}{\text{total elapsed time}}$$
1.3 Communication Overhead (Coordination Constraint)
In multi-GPU settings, GPUs must exchange data (gradients, activations, parameters). Communication has two distinct regimes:
| Link Type | Typical Bandwidth | Latency |
|---|---|---|
| Intra-node (NVLink, NVSwitch) | 450–900 GB/s per GPU (H100) | ~1–5 µs |
| Inter-node (InfiniBand, RoCE) | 50–400 Gb/s per NIC | ~1–10 µs |
Any time a GPU is idle waiting for a remote tensor constitutes communication overhead. The primary mitigation strategies are:
- Overlap communication with computation (hide latency behind useful work)
- Minimize total bytes transferred
- Prefer intra-node links over inter-node links when possible
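To make the gap between the two regimes concrete, a simple alpha-beta cost model (latency plus bytes over bandwidth) can be sketched in a few lines; the specific bandwidth and latency numbers below are illustrative values in the ranges from the table above, not measurements:

```python
def transfer_time_s(n_bytes: float, bandwidth_gbps: float, latency_us: float) -> float:
    """Alpha-beta model: time = latency + bytes / bandwidth."""
    return latency_us * 1e-6 + n_bytes / (bandwidth_gbps * 1e9)

# BF16 gradients of a 7B-parameter model: 2 bytes/param = 14 GB.
grad_bytes = 7e9 * 2

t_nvlink = transfer_time_s(grad_bytes, 450, 5)   # intra-node, ~450 GB/s
t_ib = transfer_time_s(grad_bytes, 50, 10)       # inter-node, 400 Gb/s = 50 GB/s
print(f"NVLink: {t_nvlink * 1e3:.0f} ms, InfiniBand: {t_ib * 1e3:.0f} ms")
```

At these sizes the transfer is overwhelmingly bandwidth-bound, which is why minimizing bytes moved and preferring intra-node links matter more than shaving latency.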
1.4 The Fundamental Trade-off Triangle
These three resources are partially fungible — one can often be traded for another:
| Technique | Saves | Costs |
|---|---|---|
| Activation Recomputation | Memory | Compute |
| Tensor Parallelism | Memory | Communication |
| Gradient Accumulation | Memory (activations) | Time (sequential passes) |
| Gradient Compression | Communication | Compute + slight accuracy |
Finding the optimal operating point within this triangle is the central systems-level challenge of distributed training.
2. Training on One GPU: The Canonical Three-Step Loop
2.1 The Three Phases
All neural network training, regardless of scale, consists of three atomic operations per optimization step:
Phase 1 — Forward Pass:
Given a mini-batch of inputs $x$, compute the network's output layer by layer:

$$a^{(0)} = x, \qquad a^{(\ell)} = f_\ell\!\left(a^{(\ell-1)}; \theta_\ell\right), \quad \ell = 1, \dots, L$$

The final output $\hat{y} = a^{(L)}$ is compared to the targets $y$ via the loss $\mathcal{L} = \text{loss}(\hat{y}, y)$,

where $\theta_\ell$ denotes the parameters of layer $\ell$.

Phase 2 — Backward Pass:
Compute gradients via the chain rule (backpropagation). For each layer $\ell$:

$$\frac{\partial \mathcal{L}}{\partial \theta_\ell} = \frac{\partial \mathcal{L}}{\partial a^{(\ell)}} \cdot \frac{\partial a^{(\ell)}}{\partial \theta_\ell}$$

This requires the stored activations $a^{(\ell-1)}$ from the forward pass.

Phase 3 — Optimizer Step: Update parameters using the computed gradients. For the Adam optimizer:

$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2$$

$$\hat{m}_t = \frac{m_t}{1-\beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1-\beta_2^t}, \qquad \theta_t = \theta_{t-1} - \eta\, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$

where $g_t$ is the gradient at step $t$, $\eta$ the learning rate, $\beta_1, \beta_2$ the moment decay rates, and $\epsilon$ a small constant for numerical stability.
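The update rule above can be sketched for a single scalar parameter in plain Python; this is a minimal illustration with common default hyperparameters, not a production optimizer:

```python
import math

def adam_step(theta, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a single scalar parameter."""
    m = beta1 * m + (1 - beta1) * g        # first moment: EMA of gradients
    v = beta2 * v + (1 - beta2) * g * g    # second moment: EMA of squared gradients
    m_hat = m / (1 - beta1 ** t)           # bias correction for zero-initialized EMAs
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v

theta, m, v = 1.0, 0.0, 0.0
theta, m, v = adam_step(theta, g=0.5, m=m, v=v, t=1)
print(theta)  # first step moves theta by ~lr, regardless of gradient magnitude
```

Note how the per-parameter state ($m$ and $v$) persists across steps; this is exactly the optimizer-state memory accounted for in the sections below.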
3. Batch Size: Convergence, Throughput, and Token-Based Reporting
3.1 Definition and Impact on Convergence
The batch size ($bs$) is the number of samples used to compute each gradient estimate, and it affects both convergence dynamics and hardware throughput:

- Small $bs$: Gradient estimates have high variance, causing noisy updates. Useful early in training for rapid exploration of the loss landscape, but impedes precise convergence later.
- Large $bs$: Gradient estimates approach the true gradient, but each token contributes less unique information per optimizer step — convergence in terms of tokens consumed becomes slower, potentially wasting compute.

The critical batch size $bs_{crit}$ marks the transition between these regimes: below it, growing the batch reduces the number of optimizer steps almost proportionally; above it, additional samples mostly average out noise that is already small, so returns diminish. Here, the useful batch size depends on the gradient noise level, which itself evolves over the course of training.
3.2 Practical Batch Size Schedules
Modern LLM training frequently uses batch size warm-up:
| Training Run | Initial $bs$ (samples) | Final $bs$ (samples) | Transition Point |
|---|---|---|---|
| DeepSeek-V3/R1 | 3,072 | 15,360 | After 469B tokens |
3.3 Token-Based Batch Size
Because input sequences may vary in length across training configurations, the community reports batch size in tokens:

$$bst = bs \cdot seq$$

where $seq$ is the sequence length.
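The conversion is trivial but worth making concrete; the sample count and context length below are hypothetical values chosen to land near the ~4M-token scale cited for Llama 1:

```python
def batch_size_tokens(bs_samples: int, seq_len: int) -> int:
    """bst = bs * seq: batch size measured in tokens."""
    return bs_samples * seq_len

bst = batch_size_tokens(1024, 4096)   # 1,024 samples of 4,096 tokens each
steps = 1.4e12 / bst                  # optimizer steps needed to consume 1.4T tokens
print(bst, round(steps))
```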
Typical ranges for modern LLM pretraining:
| Model | $bst$ (tokens) | Total Training Tokens |
|---|---|---|
| Llama 1 | ~4M | 1.4T |
| DeepSeek-V3 | ~60M | 14T |
3.4 Sensitivity
An important empirical observation: final model performance exhibits low sensitivity to the exact batch size value in a neighborhood around the optimum. This provides practical flexibility when tuning batch size for hardware constraints.
4. Memory Usage in Transformer Training
4.1 The Four Memory Occupants
During training, GPU VRAM must simultaneously hold:
- Model weights ($\theta$)
- Gradients ($\nabla_\theta \mathcal{L}$)
- Optimizer states (e.g., Adam's $m$ and $v$)
- Activations and all intermediate tensors needed for gradient computation

Additional (constant) overhead:

- CUDA kernel context: ~1–2 GB (verified by running `import torch; torch.ones((1,1)).to("cuda")` and then checking `nvidia-smi`)
- Internal buffers and memory fragmentation (typically small)
4.2 Numerical Precision and Bytes Per Element
| Format | Bits | Bytes per element | Exponent bits | Mantissa bits |
|---|---|---|---|---|
| FP32 | 32 | 4 | 8 | 23 |
| BF16 | 16 | 2 | 8 | 7 |
| FP16 | 16 | 2 | 5 | 10 |
| FP8 (E4M3) | 8 | 1 | 4 | 3 |
5. Parameter Count for Transformer LLMs
For a standard decoder-only transformer without fixed positional embeddings, the total parameter count is:

$$N = h \cdot v + L \cdot (12h^2 + 13h) + 2h$$

where:

| Symbol | Meaning |
|---|---|
| $h$ | Hidden dimension |
| $v$ | Vocabulary size |
| $L$ | Number of transformer layers |
Derivation of the per-layer term $12h^2 + 13h$:

Each transformer layer contains:

| Sub-module | Parameters |
|---|---|
| Self-attention QKV projection | $3h^2$ |
| Self-attention output projection | $h^2$ |
| MLP up-projection (typically $h \to 4h$) | $4h^2$ |
| MLP down-projection ($4h \to h$) | $4h^2$ |
| Total weight matrices | $12h^2$ |
| LayerNorm (×2, each with scale + bias) | $4h$ |
| Bias terms in QKV ($3h$), output ($h$), MLP up ($4h$), MLP down ($h$) | $9h$ |
| Total biases + norms | $13h$ |

The embedding layer contributes $h \cdot v$ parameters (shared with the output projection when embeddings are tied), and the final LayerNorm contributes the remaining $2h$.
Scaling insight: As $h$ grows, the $12 \cdot L \cdot h^2$ term dominates, so for large models $N \approx 12 \cdot L \cdot h^2$; the embedding and bias terms become negligible.
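The closed-form count is easy to check numerically. The configuration below (h=4096, v=32,000, L=32) is a hypothetical Llama-like example; real models with SwiGLU MLPs, untied embeddings, or no biases will deviate somewhat:

```python
def transformer_params(h: int, v: int, L: int) -> int:
    """N = h*v + L*(12*h^2 + 13*h) + 2*h for the standard architecture above."""
    return h * v + L * (12 * h * h + 13 * h) + 2 * h

# Hypothetical Llama-like config: h=4096, v=32,000, L=32.
N = transformer_params(4096, 32000, 32)
print(f"{N / 1e9:.2f}B parameters")  # ~6.6B, in the right ballpark for a "7B" model
```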
6. Memory for Weights, Gradients, and Optimizer States
6.1 Full Precision (FP32) Training
All tensors stored in FP32 (4 bytes each):

$$m_{params} = 4N, \qquad m_{grads} = 4N$$

For the Adam optimizer, which maintains first moment $m$ and second moment $v$ per parameter:

$$m_{opt} = 4N + 4N = 8N$$

Total (FP32): $m_{params} + m_{grads} + m_{opt} = 16N$ bytes.
6.2 Mixed Precision (BF16 + FP32 Master Weights) Training
The standard mixed precision scheme:
| Component | Precision | Bytes per parameter |
|---|---|---|
| Working parameters (forward/backward) | BF16 | 2 |
| Gradients (forward/backward) | BF16 | 2 |
| Master weights (optimizer copy) | FP32 | 4 |
| Adam first moment $m$ | FP32 | 4 |
| Adam second moment $v$ | FP32 | 4 |
Total (mixed precision without FP32 gradient accumulation): $2N + 2N + 4N + 4N + 4N = 16N$ bytes.

Some libraries (e.g., Nanotron) store an additional FP32 copy of gradients for numerical stability during accumulation.

Total (mixed precision with FP32 gradient accumulation): $16N + 4N = 20N$ bytes.
6.3 Key Insight
Mixed precision does not reduce total weight/gradient/optimizer memory; the total bytes per parameter remain $16$ (or rise to $20$ with FP32 gradient accumulation). Its benefits lie elsewhere:
- Faster arithmetic: BF16 matrix multiplications achieve 2× or greater throughput on tensor cores compared to FP32.
- Reduced activation memory: Activations during forward/backward are stored in BF16 (2 bytes instead of 4), which is the dominant memory consumer at scale.
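A quick calculator for the static (weight + gradient + optimizer) footprint under the mixed-precision recipe above:

```python
def static_memory_gb(n_params: float, fp32_grad_accum: bool = False) -> float:
    """Weights + gradients + optimizer states for mixed-precision Adam:
    2 (BF16 params) + 2 (BF16 grads) + 4 (FP32 master) + 4 + 4 (Adam m, v)
    = 16 bytes/param, plus 4 more with an FP32 gradient-accumulation buffer."""
    return n_params * (20 if fp32_grad_accum else 16) / 1e9

print(static_memory_gb(7e9))        # 112.0 GB
print(static_memory_gb(7e9, True))  # 140.0 GB
```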
6.4 Practical Memory Table
| Model Size ($N$) | FP32 or BF16 without FP32 grad acc ($16N$) | BF16 with FP32 grad acc ($20N$) |
|---|---|---|
| 1B | 16 GB | 20 GB |
| 7B | 112 GB | 140 GB |
| 70B | 1,120 GB | 1,400 GB |
| 405B | 6,480 GB | 8,100 GB |
For reference, a single NVIDIA H100 SXM has 80 GB of HBM3 VRAM. At 7B parameters, the weight/gradient/optimizer memory alone (112–140 GB) already exceeds the capacity of a single GPU — before counting a single activation.
7. Memory for Activations
7.1 Why Activations Must Be Stored
During the backward pass, computing $\partial \mathcal{L} / \partial \theta_\ell$ for each layer requires that layer's input activations from the forward pass. Every such tensor must therefore remain in memory from the moment it is produced until it is consumed by the corresponding backward computation.
7.2 Activation Memory Formula
For a transformer model in mixed precision (BF16 activations), the total activation memory is approximately:

$$m_{act} = L \cdot seq \cdot bs \cdot h \cdot \left(34 + \frac{5 \cdot n_{heads} \cdot seq}{h}\right) \text{ bytes}$$

where:

| Symbol | Meaning |
|---|---|
| $L$ | Number of transformer layers |
| $seq$ | Sequence length |
| $bs$ | Batch size (in samples) |
| $h$ | Hidden dimension |
| $n_{heads}$ | Number of attention heads |
Derivation sketch (following Korthikanti et al., 2022):
Within each transformer layer, the intermediate activations that must be stored include:
| Operation | Stored Activation | Size (elements) | Bytes (BF16 unless noted) |
|---|---|---|---|
| Attention LayerNorm | block input | $seq \cdot bs \cdot h$ | $2 \, seq \cdot bs \cdot h$ |
| Q, K, V projections | shared input | $seq \cdot bs \cdot h$ | $2 \, seq \cdot bs \cdot h$ |
| $QK^\top$ matmul | $Q$ and $K$ | $2 \, seq \cdot bs \cdot h$ | $4 \, seq \cdot bs \cdot h$ |
| Softmax | attention weights (post-softmax) | $n_{heads} \cdot seq^2 \cdot bs$ | $2 \, n_{heads} \cdot seq^2 \cdot bs$ |
| Dropout (attention) | mask, 1 byte/element | $n_{heads} \cdot seq^2 \cdot bs$ | $n_{heads} \cdot seq^2 \cdot bs$ |
| Attention over values | dropout output and $V$ | $n_{heads} \cdot seq^2 \cdot bs + seq \cdot bs \cdot h$ | $2 \, n_{heads} \cdot seq^2 \cdot bs + 2 \, seq \cdot bs \cdot h$ |
| Output projection | input | $seq \cdot bs \cdot h$ | $2 \, seq \cdot bs \cdot h$ |
| Dropout (post-attention) | mask, 1 byte/element | $seq \cdot bs \cdot h$ | $seq \cdot bs \cdot h$ |
| MLP LayerNorm | block input | $seq \cdot bs \cdot h$ | $2 \, seq \cdot bs \cdot h$ |
| MLP up-projection | input | $seq \cdot bs \cdot h$ | $2 \, seq \cdot bs \cdot h$ |
| GeLU | up-projected input | $4 \, seq \cdot bs \cdot h$ | $8 \, seq \cdot bs \cdot h$ |
| MLP down-projection | GeLU output | $4 \, seq \cdot bs \cdot h$ | $8 \, seq \cdot bs \cdot h$ |
| Dropout (post-MLP) | mask, 1 byte/element | $seq \cdot bs \cdot h$ | $seq \cdot bs \cdot h$ |
Summing the terms proportional to $seq \cdot bs \cdot h$ gives $34 \cdot seq \cdot bs \cdot h$ bytes per layer, and the attention-matrix terms give $5 \cdot n_{heads} \cdot seq^2 \cdot bs$ bytes per layer; multiplying by $L$ layers yields $m_{act} = L \cdot seq \cdot bs \cdot h \cdot (34 + 5 \cdot n_{heads} \cdot seq / h)$.
7.3 Scaling Behavior
Two critical observations:
- Linear scaling with $bs$: Doubling the batch size doubles activation memory.
- Quadratic scaling with $seq$: The attention score matrices contribute $5 \cdot n_{heads} \cdot seq^2 \cdot bs$ bytes per layer.

For short sequences, the $34 \cdot seq \cdot bs \cdot h$ term dominates and activation memory grows roughly linearly in $seq$. By contrast, beyond the crossover $seq \approx 34h / (5 \cdot n_{heads})$, the quadratic attention term dominates and activation memory rapidly dwarfs the fixed weight/gradient/optimizer memory.
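The formula is simple to evaluate directly. The sketch below uses a hypothetical 7B-like configuration (L=32, h=4096, 32 heads) to show the superlinear growth with sequence length:

```python
def activation_memory_gb(L, seq, bs, h, n_heads):
    """m_act = L * seq * bs * h * (34 + 5 * n_heads * seq / h) bytes,
    BF16 activations, no recomputation."""
    return L * seq * bs * h * (34 + 5 * n_heads * seq / h) / 1e9

# Hypothetical 7B-like config, batch of a single sample:
for seq in (2048, 8192, 32768):
    print(seq, round(activation_memory_gb(32, seq, 1, 4096, 32), 1))
```

Even at batch size 1, long contexts push activations into the hundreds of gigabytes — which is what motivates the recomputation techniques of the next section.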
8. Activation Recomputation (Gradient Checkpointing / Rematerialization)
8.1 Core Idea
Trade compute for memory. Instead of storing all intermediate activations during the forward pass, discard most of them and recompute them on-the-fly during the backward pass from a small set of saved “checkpoint” activations.
Formally, without recomputation, we store every intermediate activation produced during the forward pass:

$$\{a^{(\ell, k)} : \text{all layers } \ell, \text{all sub-operations } k\}$$

With recomputation, we store only a small set of checkpoints, typically the layer inputs:

$$\{a^{(\ell)} : \ell = 1, \dots, L\}$$

When the backward pass requires a discarded activation, the forward computation is re-run from the nearest upstream checkpoint to regenerate it.
8.2 Strategies
8.2.1 Full Recomputation
Checkpoint only at layer boundaries: store the input to each transformer layer and discard all intra-layer intermediates.
- Memory saved: All intra-layer activation intermediates (the dominant component).
- Compute cost: Essentially one additional full forward pass during the backward pass.
- Overhead: Typically 30–40% increase in wall-clock time.
Activation memory with full recomputation reduces to approximately:

$$m_{act} \approx 2 \cdot L \cdot seq \cdot bs \cdot h \text{ bytes}$$

(one BF16 layer input per layer).
8.2.2 Selective Recomputation
Observation from Korthikanti et al. (2022): the attention score matrices ($n_{heads} \cdot seq^2 \cdot bs$ elements per layer) are expensive to store but cheap to recompute — a matmul and a softmax — whereas the large MLP activations are comparatively cheap to store relative to their recompute cost.
Strategy: Discard attention scores and softmax outputs; keep the expensive feedforward (MLP) activations.
- GPT-3 (175B) empirical result: ~70% activation memory reduction at only 2.7% compute cost.
- FlashAttention natively implements this strategy: it recomputes attention scores in the backward pass from $Q$, $K$, $V$ blocks held on-chip, never materializing the full attention matrix.
8.2.3 DeepSeek-V3 / Multi-Head Latent Attention (MLA)
MLA compresses the key-value cache into a low-rank latent space, reducing the activation footprint of attention even further beyond standard selective recomputation.
8.3 FLOPS Utilization Metrics
Recomputation changes the total number of floating-point operations performed, which affects how we measure hardware efficiency:
Hardware FLOPS Utilization (HFU):

$$\text{HFU} = \frac{\text{FLOPs actually executed per second (including recomputation)}}{\text{peak hardware FLOPS}}$$

Model FLOPS Utilization (MFU):

$$\text{MFU} = \frac{\text{FLOPs required by the model per second (forward + backward only)}}{\text{peak hardware FLOPS}}$$
MFU is the preferred metric for comparing accelerators and training configurations, because it measures useful work per unit time, independent of implementation-level choices like recomputation.
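A minimal MFU estimate can use the standard ~6N FLOPs-per-token approximation for a forward+backward pass (2N forward, 4N backward); the throughput and peak-FLOPS numbers below are illustrative assumptions, not measurements:

```python
def mfu(n_params: float, tokens_per_s: float, peak_flops: float) -> float:
    """MFU via the ~6*N FLOPs/token estimate, which deliberately
    excludes any recomputation FLOPs (those belong to HFU)."""
    return 6 * n_params * tokens_per_s / peak_flops

# Illustrative: 7B model, 3,000 tokens/s per GPU, ~989 TFLOPS peak (H100 BF16, dense)
print(f"MFU = {mfu(7e9, 3000, 989e12):.1%}")
```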
8.4 A Counter-Intuitive Performance Effect
Although recomputation increases total FLOPs, it reduces memory traffic. On bandwidth-limited hardware (which GPUs often are), fewer memory accesses can make overall execution faster despite more arithmetic — a net win in both memory and speed.
9. Gradient Accumulation
9.1 Problem Statement
Even with activation recomputation, activation memory scales linearly with the batch size $bs$ — the per-layer checkpoints alone grow proportionally to it. A global batch size chosen for convergence reasons may therefore still not fit on a single GPU.
9.2 Mechanism
Split the global batch (size $gbs$) into $grad\_acc$ micro-batches of size $mbs = gbs / grad\_acc$, process them sequentially, and accumulate their gradients before a single optimizer step.
Execute the training loop as:
```python
gradient_buffer = 0
for i in range(grad_acc):
    micro_batch = get_micro_batch(i)
    loss = forward(micro_batch) / grad_acc  # normalize
    loss.backward()                         # accumulates into gradient_buffer
optimizer.step()
optimizer.zero_grad()
```
The division by $grad\_acc$ ensures the accumulated gradient equals the gradient of the mean loss over the full global batch: $\sum_{i=1}^{grad\_acc} \frac{1}{grad\_acc} \nabla \mathcal{L}_i = \nabla \Big( \frac{1}{grad\_acc} \sum_i \mathcal{L}_i \Big)$.
This makes the result mathematically identical to processing the full global batch at once (assuming no batch normalization or similar batch-dependent operations).
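The equivalence can be verified on a toy scalar model (a hypothetical quadratic per-sample loss, chosen so the gradient is trivial to write by hand):

```python
# Toy per-sample loss L_i(w) = (w - x_i)^2, gradient dL_i/dw = 2*(w - x_i).
w = 0.5
xs = [1.0, 2.0, 3.0, 4.0]   # global batch of 4 samples
grad_acc = 4                 # micro-batch size 1

grad_buffer = 0.0
for x in xs:                               # one backward per micro-batch
    grad_buffer += 2 * (w - x) / grad_acc  # loss scaled by 1/grad_acc

full_batch_grad = sum(2 * (w - x) for x in xs) / len(xs)
print(grad_buffer, full_batch_grad)  # identical: both -4.0
```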
9.3 Memory Analysis
| Component | Without Gradient Accumulation | With Gradient Accumulation |
|---|---|---|
| Parameters | $m_{params}$ | $m_{params}$ (unchanged) |
| Gradients | $m_{grads}$ | $m_{grads}$ (unchanged) |
| Optimizer | $m_{opt}$ | $m_{opt}$ (unchanged) |
| Activations | $\propto gbs$ | $\propto mbs = gbs / grad\_acc$ |
The activation memory is reduced by a factor of $grad\_acc$, because only one micro-batch's activations are alive at any moment.
Trade-off: The gradient buffer must persist across all micro-batch iterations (it is not freed until after optimizer.step()), creating a small additional memory overhead compared to the non-accumulation case where gradients are computed and freed layer-by-layer during the backward pass. However, this is vastly outweighed by the activation memory savings.
9.4 Compute Cost
Gradient accumulation is sequential: each optimizer step now requires $grad\_acc$ forward/backward passes executed one after another. Total FLOPs per global batch are unchanged, but wall-clock time per step grows, and very small micro-batches may underutilize the GPU.
9.5 The Path to Data Parallelism
A critical observation: the $grad\_acc$ forward/backward passes over different micro-batches are mutually independent. They read the same frozen weights and interact only at the final gradient summation, so nothing forces them to run sequentially on a single device.
10. Profiling GPU Compute and Communication
10.1 PyTorch Profiler
The PyTorch profiler instruments both CPU and CUDA activity, generating traces viewable in TensorBoard or Chrome’s chrome://tracing:
```python
with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=3),
    on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/profile'),
    with_stack=True
) as prof:
    for step in range(steps):
        train_step()
        prof.step()
```
| Parameter | Purpose |
|---|---|
| `wait=1` | Skip 1 step (cold start) |
| `warmup=1` | Profile 1 step but discard (cache warming) |
| `active=3` | Record 3 steps for analysis |
| `with_stack=True` | Include Python call stacks in trace |
10.2 What the Trace Reveals
The profiler trace shows multiple concurrent tracks:
- CPU thread(s): Launching CUDA kernels asynchronously, managing data loading, executing Python logic.
- CUDA compute stream(s): Executing matrix multiplications, activations, normalization kernels.
- CUDA communication stream(s): Executing NCCL collectives (AllReduce, AllGather, ReduceScatter) for gradient synchronization.
10.3 Key Bottleneck Patterns to Identify
| Pattern | Symptom in Trace | Root Cause |
|---|---|---|
| Sequential compute + communication | Communication kernel starts only after all backward kernels finish | Missing overlap of gradient sync with backward pass |
| GPU idle gaps | Empty regions in CUDA compute stream | CPU-side bottleneck (data loading, Python overhead) or CUDA synchronization barriers |
| Excessive `cudaMemcpy` | Large H2D/D2H blocks | Data not pre-pinned or pre-staged on GPU |
| Kernel launch overhead | Many tiny CUDA kernels with gaps between them | Operator fusion needed (e.g., via `torch.compile`) |
| First step anomaly | Longer first iteration with memory allocation plateaus | PyTorch caching allocator warming up memory pools |
10.4 First-Step vs. Steady-State Behavior
The profiler reveals a characteristic difference:
- Step 1: Activations ramp up, then plateau as the PyTorch CUDA caching allocator pre-allocates memory blocks. Optimizer states do not yet exist.
- Step 2+: Optimizer states ($m$ and $v$ for Adam) are allocated after the first `optimizer.step()`, permanently increasing the memory baseline by $8N$ bytes (two FP32 moments). This explains why training can succeed on step 1 but OOM on step 2.
11. Memory Budget Summary
For a transformer with $N$ parameters trained with Adam in mixed precision, the peak memory requirement is:

$$m_{total} = m_{params} + m_{grads} + m_{opt} + m_{act}$$

where

$$m_{params} + m_{grads} + m_{opt} = 16N \text{ bytes} \quad (20N \text{ with FP32 gradient accumulation})$$

and

$$m_{act} = L \cdot seq \cdot bs \cdot h \cdot \left(34 + \frac{5 \cdot n_{heads} \cdot seq}{h}\right) \text{ bytes}$$

With full activation recomputation, $m_{act} \approx 2 \cdot L \cdot seq \cdot bs \cdot h$ bytes.

With selective recomputation, $m_{act} \approx 34 \cdot L \cdot seq \cdot bs \cdot h$ bytes (the $seq^2$ attention term is eliminated).
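Putting the budget together, a sketch combining the static and activation terms above; the 7B configuration is a hypothetical example:

```python
def training_memory_gb(n_params, L, seq, bs, h, n_heads,
                       recompute="none", fp32_grad_accum=False):
    """m_total = static (16N or 20N bytes) + activation term, in GB."""
    static = n_params * (20 if fp32_grad_accum else 16)
    if recompute == "full":          # only BF16 layer inputs survive
        act = 2 * L * seq * bs * h
    elif recompute == "selective":   # seq^2 attention term eliminated
        act = 34 * L * seq * bs * h
    else:                            # no recomputation
        act = L * seq * bs * h * (34 + 5 * n_heads * seq / h)
    return (static + act) / 1e9

# Hypothetical 7B configuration, seq=4096, bs=4: effect of recomputation strategy
for mode in ("none", "selective", "full"):
    print(mode, round(training_memory_gb(7e9, 32, 4096, 4, 4096, 32, mode), 1))
```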
12. Conceptual Map: From Single GPU to Distributed Training
```
Single-GPU Training
├── Forward → Backward → Optimize
├── Memory constraints
│   ├── Weights + Grads + Optimizer: 16N–20N bytes (fixed per model)
│   └── Activations: O(L · bs · seq · h + L · bs · n_heads · seq²) (variable)
├── Memory mitigation
│   ├── Activation Recomputation (trade compute ↔ memory)
│   └── Gradient Accumulation (trade time ↔ memory)
└── Next step: Data Parallelism
    └── Parallelize independent micro-batch computations across GPUs
```
The independent micro-batch computations identified through gradient accumulation form the natural entry point to data parallelism, where multiple GPUs execute forward-backward passes on different micro-batches simultaneously, synchronize gradients via collective communication (AllReduce), and perform a unified optimizer step — the subject of the next stage of scaling.