Diving into the GPUs — Fusing, Threading, and Mixing
1. GPU Architecture Primer
1.1 Compute Hierarchy
An NVIDIA GPU is organized as a hierarchical array of compute units. The fundamental building block is the Streaming Multiprocessor (SM), each of which contains multiple streaming processors (cores).
| Level | Description | Scheduling |
|---|---|---|
| Thread | Smallest unit of execution; executes one instance of the kernel | Runs on a single core |
| Warp | Group of exactly 32 threads executing in lockstep (SIMD) | Scheduled as an atomic unit on an SM |
| Block | Programmer-defined grouping of threads (e.g., 256 or 1024 threads) | Assigned to exactly one SM; an SM may host multiple blocks |
| Grid | Collection of all blocks for a kernel launch | Distributed across all available SMs |
1.2 Memory Hierarchy
The memory system is equally hierarchical, with a fundamental trade-off: smaller memories are faster but private; larger memories are slower but shared.
| Memory Level | Scope | Capacity (H100) | Latency | Bandwidth |
|---|---|---|---|---|
| Registers | Private to each thread | 256 KB per SM | ~1 cycle | Highest |
| Shared Memory / L1 Cache | Shared across threads in a block (one SM) | 228 KB per SM (configurable) | ~20–30 cycles | ~19 TB/s aggregate |
| L2 Cache | Shared across all SMs | 50 MB | ~200 cycles | ~12 TB/s |
| Global Memory (HBM3) | Shared across entire GPU | 80 GB | ~400–600 cycles | 3.35 TB/s |
The performance optimization objective is:
Maximize data reuse in fast memories (registers, shared memory) to minimize accesses to slow global memory (HBM)
1.3 Kernel Execution Model
A kernel is a function that runs on the GPU. It is written in CUDA (C/C++ extension) or Triton (Python-based) and compiled to PTX (Parallel Thread Execution) — NVIDIA’s low-level virtual ISA.
tl.program_id(0): Returns the unique block ID in dimension 0 (analogous to blockIdx.x in CUDA)
BLOCK_SIZE: Compile-time constant defining how many elements each block processes
tl.arange(0, BLOCK_SIZE): Creates a vector of consecutive indices within a block
valid_mask: Bounds-checks to prevent out-of-range memory access
tl.load / tl.store: Masked memory operations
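The Triton kernel these bullets describe is not reproduced above, so here is a pure-Python emulation of the same block/program model (the names `vector_add_block`, `BLOCK_SIZE`, and the ceil-division grid are illustrative, not Triton API):

```python
# Pure-Python emulation of the Triton block-program model described above.

def vector_add_block(pid, x, y, out, BLOCK_SIZE):
    """Emulates one Triton 'program' (block): computes out[i] = x[i] + y[i]
    for the BLOCK_SIZE indices owned by program `pid`, with bounds masking."""
    offsets = [pid * BLOCK_SIZE + i for i in range(BLOCK_SIZE)]  # tl.arange analogue
    valid_mask = [off < len(x) for off in offsets]               # bounds check
    for off, valid in zip(offsets, valid_mask):
        if valid:                                                # masked load/store
            out[off] = x[off] + y[off]

def vector_add(x, y, BLOCK_SIZE=4):
    out = [0.0] * len(x)
    num_programs = -(-len(x) // BLOCK_SIZE)   # ceil-division grid size
    for pid in range(num_programs):           # a GPU runs these in parallel
        vector_add_block(pid, x, y, out, BLOCK_SIZE)
    return out

print(vector_add([1, 2, 3, 4, 5], [10, 20, 30, 40, 50]))  # [11, 22, 33, 44, 55]
```

On a real GPU, Triton launches all programs in parallel and `tl.load`/`tl.store` apply the mask in hardware; the sequential loop here only mirrors the indexing and masking logic.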
Triton vs. CUDA control granularity:
| Aspect | Triton | CUDA |
|---|---|---|
| Programming unit | Block (program) | Thread |
| Shared memory management | Automatic | Manual (`__shared__`) |
| Warp-level primitives | Not directly exposed | Full access (`__shfl_sync`, etc.) |
| Scheduling within SM | Automatic | Manual (warp/thread indexing) |
| Memory coalescing | Automatic (vectorized loads) | Manual (access pattern design) |
3. CUDA Optimization Techniques
3.1 Memory Coalescing
Definition: Memory coalescing is the hardware mechanism by which a GPU combines multiple memory access requests from threads within a single warp into a minimal number of memory transactions, exploiting the burst transfer behavior of DRAM.
DRAM burst mechanism: When any single address M in global memory (HBM/DRAM) is accessed, the DRAM chip reads a contiguous segment of B_burst bytes (typically 32 or 128 bytes) around M in a single operation. Coalescing ensures threads access addresses within the same burst segment.
Coalescing condition: For a warp of 32 threads accessing addresses a_0, a_1, …, a_31, memory is perfectly coalesced when:
a_i = a_0 + i · sizeof(element), for all i ∈ {0, 1, …, 31}
That is, consecutive threads access consecutive memory locations.
Row-major storage convention: An element at row r, column c of a matrix with N_cols columns is stored at linear address:
addr(r, c) = base + (r × N_cols + c) × sizeof(element)
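A quick sanity check of this formula (plain Python; the `addr` helper is illustrative):

```python
# Row-major linear address: addr(r, c) = base + (r * n_cols + c) * elem_size.

def addr(base, r, c, n_cols, elem_size=4):
    return base + (r * n_cols + c) * elem_size

# A 4-byte float matrix with 1024 columns starting at address 0:
print(addr(0, 0, 0, 1024))  # 0     (first element)
print(addr(0, 0, 1, 1024))  # 4     (next column: contiguous)
print(addr(0, 1, 0, 1024))  # 4096  (next row: N_cols * 4 bytes away)
```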
Naive 2D kernel:
```cuda
__global__ void matmul_naive(int M, int N, int K,
                             const float *A, const float *B, float *C) {
    const uint x = blockIdx.x * blockDim.x + threadIdx.x; // row
    const uint y = blockIdx.y * blockDim.y + threadIdx.y; // column
    if (x < M && y < N) {
        float tmp = 0.0;
        for (int i = 0; i < K; ++i) {
            tmp += A[x * K + i] * B[i * N + y];
        }
        C[x * N + y] = tmp;
    }
}
```
Problem analysis: Two threads in the same warp with consecutive threadIdx.x values (e.g., thread (0,0) and thread (1,0)) have the same y but different x. At iteration i = 0:
Thread (0,0) reads A[0·K + 0] = A[0]
Thread (1,0) reads A[1·K + 0] = A[K]
These addresses are K elements apart in memory — not consecutive. The accesses to A are uncoalesced, resulting in ~32 separate memory transactions instead of 1.
The fix is to swap the index mapping so that threadIdx.x indexes the column rather than the row: consecutive threadIdx.x values then produce the same x (row) but a different y (column). At iteration i = 0:
Thread 0 reads A[x·K + 0] and B[0·N + y_0]
Thread 1 reads A[x·K + 0] (same!) and B[0·N + y_0 + 1]
The B accesses are now consecutive in memory (coalesced), and the A accesses are identical (broadcast). The result is a ~10× improvement in both memory throughput and execution time.
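A small model of this effect: counting how many burst segments one warp touches under each mapping (segment size and matrix shape are illustrative):

```python
# Count 128-byte memory segments touched by one warp's loads,
# comparing the two thread->(row, col) mappings discussed above.

SEGMENT = 128       # bytes per DRAM burst segment (illustrative)
K = 1024            # inner dimension of A (MxK), FP32 elements
ELEM = 4

def segments_touched(addresses):
    return len({a // SEGMENT for a in addresses})

# Uncoalesced: consecutive lanes read consecutive *rows* -> stride K elements.
uncoalesced = [lane * K * ELEM for lane in range(32)]   # A[lane*K + 0]
# Coalesced: consecutive lanes read consecutive *columns*.
coalesced = [lane * ELEM for lane in range(32)]         # B[0*N + y0 + lane]

print(segments_touched(uncoalesced))  # 32 transactions
print(segments_touched(coalesced))    # 1 transaction
```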
3.2 Tiling (Shared Memory Optimization)
Motivation: Even with coalesced access, global memory bandwidth is limited. For matrix multiplication, each element of A and B is loaded multiple times by different threads. Without optimization, the total number of global memory loads is:
Global loads (naive) = M × N × K × 2
(Each of the M × N output elements requires K loads from A and K loads from B.)
Tiling principle: Partition the computation into tiles of size T_M × T_K (from A) and T_K × T_N (from B). A block of threads cooperatively loads one tile of each matrix into shared memory (SRAM), then all threads in the block compute using the fast shared memory.
The output tile C_tile ∈ R^(T_M × T_N) is accumulated over ⌈K / T_K⌉ iterations:
```cuda
for (int tileIdx = 0; tileIdx < K; tileIdx += TILE_SIZE) {
    // Cooperative load: each thread loads one element of A and B
    sharedA[localRow * TILE_SIZE + localCol] = A[localRow * K + localCol];
    sharedB[localRow * TILE_SIZE + localCol] = B[localRow * N + localCol];
    __syncthreads(); // Barrier: ensure all loads complete

    // Compute partial dot product from shared memory
    for (int i = 0; i < TILE_SIZE; ++i) {
        sum += sharedA[localRow * TILE_SIZE + i]
             * sharedB[i * TILE_SIZE + localCol];
    }
    __syncthreads(); // Barrier before next tile load

    A += TILE_SIZE;     // Advance tile in A (across columns)
    B += TILE_SIZE * N; // Advance tile in B (down rows)
}
```
Memory access reduction: With tiling, each element of A and B is loaded from global memory only once per tile, then reused T times from shared memory:
Global loads (tiled) = (M × N × K × 2) / T
The reuse factor is T, giving a proportional reduction in global memory traffic.
__syncthreads() is a block-level barrier: all threads in the block must reach this point before any can proceed. Two barriers are required per tile iteration:
After loading data into shared memory (ensure all data is available before computation)
After computation (ensure all threads finish before overwriting shared memory with the next tile)
Shared memory requirement:
M_shared = (T_M × T_K + T_K × T_N) × sizeof(float)
For T_M = T_K = T_N = T = 32 with FP32:
M_shared = (32 × 32 + 32 × 32) × 4 = 8,192 bytes = 8 KB
This easily fits within the 228 KB per SM on H100.
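The tiling scheme can be mimicked in plain Python, with a small buffer standing in for shared memory (tile size and helper names are illustrative):

```python
# Pure-Python sketch of the tiled matmul loop: each "block" stages one
# TILE x TILE tile of A and B in a small buffer (standing in for shared
# memory) before computing, mirroring the CUDA fragment above.

def matmul_naive(A, B):
    M, K, N = len(A), len(B), len(B[0])
    return [[sum(A[i][k] * B[k][j] for k in range(K)) for j in range(N)]
            for i in range(M)]

def matmul_tiled(A, B, TILE=2):
    M, K, N = len(A), len(B), len(B[0])    # assumes TILE divides M, K, N
    C = [[0.0] * N for _ in range(M)]
    for bi in range(0, M, TILE):           # one "block" per output tile
        for bj in range(0, N, TILE):
            for t in range(0, K, TILE):    # loop over K in TILE-wide steps
                # cooperative load into the "shared memory" buffers
                sharedA = [[A[bi + r][t + c] for c in range(TILE)] for r in range(TILE)]
                sharedB = [[B[t + r][bj + c] for c in range(TILE)] for r in range(TILE)]
                # compute partial products from the staged tiles
                for r in range(TILE):
                    for c in range(TILE):
                        for i in range(TILE):
                            C[bi + r][bj + c] += sharedA[r][i] * sharedB[i][c]
    return C

A = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
B = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
print(matmul_tiled(A, B) == matmul_naive(A, B))  # True
```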
3.3 Thread Coarsening
Problem identified via profiling: After tiling, warp stall analysis reveals that warps spend significant cycles in the stalled_mio_throttle state — stalled waiting for the Memory Input/Output (MIO) pipeline to service shared memory requests.
The root cause is that each thread computes a single output element, requiring many shared memory loads per thread. The shared memory access pipeline becomes a bottleneck.
Thread coarsening merges C_f threads into a single coarsened thread, where each coarsened thread computes C_f output elements instead of 1.
Before coarsening: Each thread computes 1 element, requiring 2·T_K shared memory loads per tile iteration:
Shared mem loads per thread per tile = 2 × T_K
(One row of A_shared, one column of B_shared.)
After coarsening by factor C_f: Each thread computes C_f elements in the same row, sharing the same row of A_shared:
Shared mem loads per thread per tile = T_K + C_f × T_K = (1 + C_f) × T_K
Without coarsening, C_f separate threads would have loaded:
Total without coarsening = C_f × 2 × T_K = 2·C_f × T_K
The savings ratio is:
Reduction factor = (2·C_f × T_K) / ((1 + C_f) × T_K) = 2·C_f / (1 + C_f)
For C_f = 8: reduction factor = 16/9 ≈ 1.78× fewer shared memory accesses.
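Plugging numbers into these formulas (helper names are illustrative):

```python
# Shared-memory load counts per tile for Cf coarsened outputs,
# following the formulas above (TK = tile width).

def loads_without_coarsening(Cf, TK):
    return Cf * 2 * TK            # Cf independent threads, 2*TK loads each

def loads_with_coarsening(Cf, TK):
    return (1 + Cf) * TK          # one row of A_shared reused for Cf outputs

for Cf in (1, 2, 4, 8):
    r = loads_without_coarsening(Cf, 32) / loads_with_coarsening(Cf, 32)
    print(f"Cf={Cf}: reduction {r:.2f}x")   # Cf=8 gives 1.78x
```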
3.4 Minimizing Control Divergence
GPUs execute in the SIMD (Single Instruction, Multiple Data) model at the warp level: all 32 threads in a warp execute the same instruction simultaneously on different data.
Control divergence occurs when threads within a warp encounter a conditional branch and take different execution paths:
```cuda
if (condition) {
    // Path A — executed by some threads
} else {
    // Path B — executed by remaining threads
}
```
When divergence occurs, the hardware serializes execution:
Threads taking Path A execute while Path B threads are masked (idle)
Then Path B threads execute while Path A threads are masked
The effective warp throughput drops proportionally:
Throughput_divergent = Throughput_peak / N_paths
For a simple if-else, N_paths = 2, yielding 50% throughput.
Mitigation strategies:
| Strategy | Description |
|---|---|
| Predication | Compiler replaces branches with conditional assignments; both paths execute but results are selectively written |
| Data reorganization | Arrange data so threads in the same warp follow the same path (e.g., sort by condition) |
| Warp-aligned branching | Ensure the branch condition is uniform across all 32 threads in a warp |
4. Fused Kernels
4.1 The Kernel Launch Overhead Problem
In standard (unfused) execution, each operation is a separate kernel launch. For a sequence of n point-wise operations f_1, f_2, …, f_n on a tensor x:
x_1 = f_1(x), x_2 = f_2(x_1), …, x_n = f_n(x_(n−1))
Each kernel launch requires:
Write x_i from SM registers/shared memory → global memory (HBM)
Read x_i from global memory → SM for the next kernel
The total HBM traffic for n unfused kernels:
V_HBM^unfused = 2(n − 1) × |x| × bytes_per_element
(Each intermediate result is written once and read once from HBM.)
4.2 Kernel Fusion
Definition: Kernel fusion combines multiple operations into a single kernel that executes all computations without materializing intermediate results in global memory.
x_n = f_n ∘ f_(n−1) ∘ ⋯ ∘ f_1(x) (computed entirely in registers/shared memory)
For LayerNorm, which involves ~5 point-wise operations (subtract mean, square, average, reciprocal square root, scale+shift), the savings factor is ∼4×.
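A back-of-the-envelope traffic model based on the formula above (the helper names and the inclusion of the initial read and final write are assumptions):

```python
# HBM traffic for n chained point-wise ops on a tensor of `numel` elements,
# unfused vs. fused: V_unfused = 2(n-1) * |x| * bytes_per_element for the
# intermediates, plus the unavoidable input read and output write.

def hbm_bytes_unfused(n_ops, numel, bytes_per_elem):
    traffic = 2 * (n_ops - 1) * numel * bytes_per_elem  # intermediates
    traffic += 2 * numel * bytes_per_elem               # read x, write x_n
    return traffic

def hbm_bytes_fused(numel, bytes_per_elem):
    return 2 * numel * bytes_per_elem   # read x once, write result once

numel = 4096 * 4096
n_ops = 5   # e.g., the ~5 point-wise ops in LayerNorm
ratio = hbm_bytes_unfused(n_ops, numel, 2) / hbm_bytes_fused(numel, 2)
print(ratio)  # 5.0 -- in the ballpark of the ~4x savings quoted above
```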
4.3 Applicability in Transformers
Fusion is most beneficial for memory-bound operations — those where HBM bandwidth, not compute throughput, is the bottleneck. The arithmetic intensity (AI) determines this:
AI = FLOPs / Bytes accessed from memory
| Operation | FLOPs | Bytes | AI | Bound |
|---|---|---|---|---|
| MatMul (M×K×N) | 2MKN | ~2(MK + KN + MN) × bytes | High | Compute-bound |
| LayerNorm | ~5d per token | ~2d × bytes per token | Low (~2.5) | Memory-bound |
| Activation (GELU, ELU) | ~1 per element | 2 × bytes per element | ~0.5 | Memory-bound |
| Softmax | ~5N per row | ~2N × bytes per row | Low | Memory-bound |
Fusion provides the greatest speedup for memory-bound operations since reducing HBM traffic directly reduces the bottleneck.
5. FlashAttention
5.1 Standard Attention and Its Memory Bottleneck
The standard scaled dot-product attention for a single head is:
Attention(Q, K, V) = softmax(QK⊤ / √d_k) V
where Q, K, V ∈ R^(N × d_k), N is the sequence length, and d_k is the head dimension.
The naive computation proceeds in three steps:
Compute score matrix: S = QK⊤ / √d_k ∈ R^(N × N)
Compute attention weights: P = softmax(S) ∈ R^(N × N)
Compute output: O = PV ∈ R^(N × d_k)
Memory problem: Both S and P must be materialized in HBM. Their size is:
M_S = M_P = N² × bytes_per_element
For N = 4096, d_k = 128, in BF16:
M_S = 4096² × 2 = 33.6 MB per head
With n_h = 96 heads (as in a large model):
M_attn^total = 96 × 33.6 MB = 3.2 GB
This is a significant fraction of the 80 GB HBM on an H100, and the HBM read/write traffic for these matrices becomes the dominant bottleneck.
HBM traffic for naive attention: Each of the three steps reads its inputs from HBM and writes its output back. The dominant terms are all O(N²), making naive attention memory-bandwidth-bound for typical N ≫ d_k.
5.2 FlashAttention Algorithm
Core idea: Compute the output O by processing Q, K, V in tiles that fit in SRAM, never materializing the full N × N matrices S or P in HBM.
The key mathematical challenge is that softmax requires global normalization across the entire row:
softmax(s_i)_j = e^(s_ij) / Σ_(k=1)^N e^(s_ik)
FlashAttention uses the online softmax algorithm (Milakov & Gimelshein, 2018) to compute softmax incrementally over tiles. The algorithm maintains running statistics m_i (row-wise maximum) and ℓ_i (row-wise sum of exponentials) that are updated as each new tile is processed.
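The running-statistics idea can be written in a few lines of pure Python (a list-based sketch for one row, not the tiled GPU kernel):

```python
import math

# Online (streaming) softmax over tiles, maintaining a running max m and a
# running sum of exponentials l, as in Milakov & Gimelshein (2018).

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

def online_softmax(tiles):
    m, l = float("-inf"), 0.0
    for tile in tiles:                   # one pass over the tiles
        m_new = max(m, max(tile))
        # rescale the old sum to the new max, then add this tile's terms
        l = l * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in tile)
        m = m_new
    # second lightweight pass produces the normalized weights
    return [math.exp(x - m) / l for tile in tiles for x in tile]

xs = [0.5, 2.0, -1.0, 3.0, 0.0, 1.5]
a = softmax(xs)
b = online_softmax([xs[0:2], xs[2:4], xs[4:6]])
print(all(abs(p - q) < 1e-12 for p, q in zip(a, b)))  # True
```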
Tiled computation: Partition K and V into blocks of B_c rows (columns of the attention matrix), and Q into blocks of B_r rows:
K = [K_1; K_2; …; K_(T_c)], V = [V_1; V_2; …; V_(T_c)]
where T_c = ⌈N / B_c⌉.
For each query block Q_i (B_r rows), iterate over KV blocks j = 1, …, T_c:
Step 1: Compute the local score tile in SRAM:
S_ij = Q_i K_j⊤ / √d_k ∈ R^(B_r × B_c)
Step 2: Compute local row-wise maximum and update running maximum:
m̃_ij = rowmax(S_ij) ∈ R^(B_r), m_i^new = max(m_i^old, m̃_ij)
Step 3: Compute local exponentials and update the running sum, rescaling the old sum to the new maximum:
ℓ_i^new = e^(m_i^old − m_i^new) · ℓ_i^old + rowsum(e^(S_ij − m_i^new))
Since M_SRAM ≫ d_k typically, the HBM access is reduced by a factor of approximately M_SRAM / d_k. For H100 with M_SRAM ≈ 228 KB and d_k = 128 in BF16 (256 bytes per row), this factor is on the order of 228 KB / 256 B ≈ 900.
Memory reduction: The O(N²) memory for the attention matrix is eliminated, reducing to O(N) auxiliary storage.
Wall-clock speedup: By reducing HBM traffic — the true bottleneck for attention — FlashAttention achieves 2–4× speedup over standard attention implementations despite performing the same number of FLOPs.
Enabling longer sequences: Without the N² memory overhead, much longer sequences become feasible. This is why earlier linear/subquadratic attention approximations have been largely abandoned in favor of FlashAttention — exact attention is now fast enough.
5.5 FlashAttention-2 and FlashAttention-3
| Version | Key Improvements |
|---|---|
| FlashAttention-2 | (1) Reduced non-matmul FLOPs (rewrote online softmax to minimize non-GEMM operations); (2) Better work partitioning among warps within a thread block (parallelism along sequence length instead of head dimension); (3) Added parallelism across the sequence length dimension |
| FlashAttention-3 | (1) Optimized for Hopper (H100) architecture — exploits asynchronous WGMMA (Warpgroup Matrix Multiply-Accumulate) instructions; (2) FP8 attention support with per-tile quantization; (3) Exploits TMA (Tensor Memory Accelerator) for efficient data movement between HBM and shared memory |
6. Mixed Precision Training
6.1 Floating-Point Number Representation
A floating-point number x in IEEE 754 format is represented as:
x = (−1)^s × 2^(E − bias) × (1 + Σ_(i=1)^p b_i · 2^(−i))
where:
s∈{0,1}: sign bit
E: stored exponent value (unsigned integer)
bias = 2^(e−1) − 1: exponent bias (e = number of exponent bits)
p: number of mantissa bits
bi: individual mantissa bits
The three components control distinct properties:
| Component | Controls | More bits → |
|---|---|---|
| Sign (s) | Positive/negative | Always 1 bit |
| Exponent (E, e bits) | Dynamic range (magnitude span) | Wider range of representable magnitudes |
| Mantissa (b_i, p bits) | Precision (significant figures) | Finer resolution between consecutive numbers |
6.2 Format Comparison
| Format | Total Bits | Sign | Exponent (e) | Mantissa (p) | Bias | ϵ (machine epsilon) | Dynamic Range |
|---|---|---|---|---|---|---|---|
| FP32 | 32 | 1 | 8 | 23 | 127 | 2^−23 ≈ 1.19×10^−7 | ~10^±38 |
| FP16 | 16 | 1 | 5 | 10 | 15 | 2^−10 ≈ 9.77×10^−4 | ~10^±4.8 |
| BF16 | 16 | 1 | 8 | 7 | 127 | 2^−7 ≈ 7.81×10^−3 | ~10^±38 |
| FP8 (E4M3) | 8 | 1 | 4 | 3 | 7 | 2^−3 = 0.125 | ~10^±2.4 |
| FP8 (E5M2) | 8 | 1 | 5 | 2 | 15 | 2^−2 = 0.25 | ~10^±4.8 |
Machine epsilon ϵ is the gap between 1 and the next larger representable number, equivalently the smallest ϵ > 0 such that fl(1 + ϵ) ≠ 1:
ϵ = 2^−p
where p is the number of mantissa bits.
Dynamic range is determined by the exponent bits:
x_max = (2 − 2^−p) × 2^(2^(e−1) − 1), x_min^normal = 2^(−(2^(e−1) − 2))
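These formulas are easy to check numerically (illustrative `fp_format` helper; IEEE-style normal numbers only, so it does not model E4M3's non-IEEE extended range):

```python
# Derived quantities for a float format with e exponent bits and p mantissa
# bits, matching the formulas above (normal numbers only, no subnormals).

def fp_format(e, p):
    bias = 2 ** (e - 1) - 1
    eps = 2.0 ** (-p)                                     # machine epsilon
    x_max = (2 - 2.0 ** (-p)) * 2.0 ** (2 ** (e - 1) - 1)
    x_min_normal = 2.0 ** (-(2 ** (e - 1) - 2))
    return bias, eps, x_max, x_min_normal

print(fp_format(5, 10))  # FP16: (15, 0.0009765625, 65504.0, 6.103515625e-05)
print(fp_format(8, 7))   # BF16: bias 127, eps 2^-7 = 0.0078125, max ~3.4e38
```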
Key trade-off:
BF16 sacrifices precision (only 7 mantissa bits vs. FP16’s 10) but preserves the full FP32 dynamic range (8 exponent bits). This is critical for training stability because gradient magnitudes can span many orders of magnitude.
FP16 has better precision but a much narrower dynamic range (∼10±4.8), making overflow/underflow more likely.
The number of representable values between consecutive powers of 2 (e.g., in [1, 2]) is 2^p:

| Format | Values in [1, 2] |
|---|---|
| FP32 | 2^23 = 8,388,608 |
| FP16 | 2^10 = 1,024 |
| BF16 | 2^7 = 128 |
| FP8 (E4M3) | 2^3 = 8 |
| FP8 (E5M2) | 2^2 = 4 |
6.3 FP16/BF16 Mixed Precision Training
Naively replacing all FP32 tensors with FP16/BF16 causes training divergence due to three failure modes, each addressed by a specific technique:
Trick 1: FP32 Master Copy of Weights
Problem — Weight update underflow: If a weight w has magnitude ~1 and the gradient-based update Δw has magnitude ~10^−5, then in FP16:
fl16(w + Δw) = fl16(1.0 + 0.00001) = fl16(1.0) = 1.0
because Δw < ϵ_FP16 × |w| ≈ 10^−3. The update is lost entirely. Once weights reach zero through underflow, they remain zero permanently (no gradient signal).
Solution: Maintain a FP32 master copy of weights w^(32). The training loop becomes:
w^(16) = cast16(w^(32)) (for forward/backward)
g^(16) = ∇_(w^(16)) L (computed in 16-bit)
w^(32) ← w^(32) − η · cast32(g^(16)) (update in FP32)
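A pure-Python illustration of the underflow and the master-copy fix, using an illustrative `quantize` function as a stand-in for an FP16 cast (round-to-nearest on p mantissa bits, ignoring range limits):

```python
import math

def quantize(x, p=10):                  # illustrative FP16-style rounding
    if x == 0.0:
        return 0.0
    m, ex = math.frexp(x)               # x = m * 2**ex with 0.5 <= |m| < 1
    scale = 2.0 ** (p + 1)
    return round(m * scale) / scale * 2.0 ** ex

lr_times_grad = 1e-5                    # update magnitude ~1e-5 << eps * |w|

# Updating the 16-bit weight directly: every step rounds away entirely.
w16 = 1.0
for _ in range(100):
    w16 = quantize(w16 - lr_times_grad)
print(w16)                              # 1.0 -- no learning happened

# FP32 master copy: updates accumulate, then cast down for compute.
w32 = 1.0
for _ in range(100):
    w32 -= lr_times_grad
print(w32, quantize(w32))               # ~0.999, and the cast reflects it
```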
Trick 2: Loss Scaling
Problem — Gradient underflow: Gradients are often much smaller than 1 (e.g., |g| ~ 10^−6), falling below the minimum representable value in FP16 (~5.96×10^−8) or being too imprecise in BF16.
Solution: Scale the loss before the backward pass and unscale gradients afterward:
L̂ = α · L (scaled loss, α ≫ 1)
ĝ^(16) = ∇_w L̂ = α · ∇_w L = α · g
g_true^(16) = ĝ^(16) / α (unscale before optimizer step)
By linearity of differentiation, scaling the loss by α scales all gradients by α, shifting them into a representable range. The unscaling restores the correct values before any further processing (clipping, optimizer step).
Dynamic loss scaling starts with a large α and halves it whenever overflow (Inf/NaN) is detected in gradients, doubling it after a fixed number of successful steps.
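A minimal sketch of such a scaler (class and parameter names are illustrative, not a real AMP API):

```python
import math

# Dynamic loss-scaler sketch: halve the scale on overflow,
# double it after `growth_interval` consecutive clean steps.

class DynamicLossScaler:
    def __init__(self, init_scale=2.0 ** 16, growth_interval=2000):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self._good_steps = 0

    def update(self, grads):
        """Return unscaled grads, or None if this step must be skipped."""
        if any(math.isinf(g) or math.isnan(g) for g in grads):
            self.scale /= 2           # overflow: shrink and skip the step
            self._good_steps = 0
            return None
        self._good_steps += 1
        if self._good_steps >= self.growth_interval:
            self.scale *= 2           # long stretch of clean steps: grow
            self._good_steps = 0
        return [g / self.scale for g in grads]

scaler = DynamicLossScaler()
print(scaler.update([float("inf")]), scaler.scale)  # None 32768.0
print(scaler.update([65536.0]))                     # [2.0]
```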
Trick 3: FP32 Accumulation
Problem — Accumulation error: When summing many small values (e.g., computing means, batch norms, reductions), the running sum grows large while individual addends remain small, so the small addends are rounded away (absorbed):
fl16(Σ_(i=1)^N x_i) ≠ Σ_(i=1)^N x_i when N is large and the x_i are small
The relative error of naive summation in precision ϵ is bounded by:
|fl(Σ x_i) − Σ x_i| / |Σ x_i| ≤ (N − 1) · ϵ + O(ϵ²)
For N = 4096 and ϵ_BF16 ≈ 7.8×10^−3: relative error ≤ 32, which means the result can be completely wrong.
Solution: Accumulate intermediate sums in FP32 even when inputs and outputs are in 16-bit.
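The absorption effect and the FP32-accumulator fix, again with an illustrative BF16-style `quantize` stand-in (round-to-nearest on p mantissa bits):

```python
import math

def quantize(x, p=7):                   # illustrative BF16-style rounding
    if x == 0.0:
        return 0.0
    m, ex = math.frexp(x)               # x = m * 2**ex with 0.5 <= |m| < 1
    scale = 2.0 ** (p + 1)
    return round(m * scale) / scale * 2.0 ** ex

xs = [0.01] * 4096                      # many small addends

# 16-bit accumulator: once the sum is large, each 0.01 rounds away.
acc16 = 0.0
for x in xs:
    acc16 = quantize(acc16 + quantize(x))

# FP32-style accumulator with a single final downcast.
acc32 = 0.0
for x in xs:
    acc32 += quantize(x)

print(acc16, acc32)                     # the 16-bit sum stalls far below ~41
```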
The full mixed precision training memory per parameter is:

| Component | FP32 Baseline | BF16 Mixed | Description |
|---|---|---|---|
| Master weights | 4 B | 4 B (FP32) | Always FP32 for accurate updates |
| Working weights | — | 2 B (BF16) | Used in forward/backward |
| Gradients | 4 B | 2 B (BF16) | Computed in BF16 |
| Optimizer state 1 (momentum) | 4 B | 4 B (FP32) | Adam first moment |
| Optimizer state 2 (variance) | 4 B | 4 B (FP32) | Adam second moment |
| Grad accumulation buffer | — | 4 B (FP32) | Optional FP32 accumulation |
| Total | 16 B | 20 B (with accum) or 16 B (without) | |
6.5 FP8 Pretraining
FP8 matrix multiplications on H100 achieve twice the peak throughput of BF16:
Φ_FP8^H100 = 1,979 TFLOP/s ≈ 2 × Φ_BF16^H100 = 2 × 989 TFLOP/s
However, FP8 introduces severe stability challenges due to the extremely limited dynamic range and precision.
Quantization for FP8
Converting a tensor x from high precision to FP8 requires scaling to fit within FP8's representable range:
x_FP8 = cast_FP8(x / s_x), s_x = max(|x|) / x_max^FP8
where x_max^FP8 = 448 for E4M3 or 57,344 for E5M2.
DeepSeek-V3’s tile-wise quantization computes separate scaling factors per tile to reduce the impact of outlier values:
Activations/inputs: tiles of size 1×128 (per-token granularity)
Weights: tiles of size 128×128
For a tile T:
s_T = max_(x∈T)(|x|) / x_max^FP8
This is much more robust than per-tensor scaling because a single outlier value in one tile doesn’t compress the dynamic range of all other tiles.
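A sketch of the per-tile scale computation (tile length shortened to 4 for readability; only the scale factors are computed, not the actual FP8 cast):

```python
# Tile-wise scaling for an E4M3-style cast (x_max = 448): each tile gets
# its own scale factor, so an outlier in one tile does not shrink the
# resolution of the others.

FP8_MAX = 448.0

def tile_scales(row, tile=4):
    tiles = [row[i:i + tile] for i in range(0, len(row), tile)]
    return [max(abs(v) for v in t) / FP8_MAX for t in tiles]

row = [0.1, -0.2, 0.05, 0.15,      # ordinary tile
       0.1,  0.2, 900.0, 0.15]     # tile containing one outlier

per_tensor = max(abs(v) for v in row) / FP8_MAX
print(tile_scales(row))   # only the second tile carries the large scale
print(per_tensor)         # one outlier inflates the scale for everything
```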
FP8 Mixed Precision Configurations
| Configuration | GEMM Precision | Master Weights | Gradients | Optimizer States | Total Memory per Param |
|---|---|---|---|---|---|
| BF16 baseline (with FP32 accum) | BF16 | FP32 (4 B) | BF16 (2 B) | FP32+FP32 (8 B) | ~20 B |
| Transformer Engine | FP8 | FP32 (4 B) | FP32 (4 B) | FP32+FP32 (8 B) | 16 B (20% reduction) |
| FP8-LM O3 | FP8 | FP16 (2 B) | FP8 (1 B) | FP8+FP16 (3 B) | 9 B (55% reduction) |
| DeepSeek-V3 | FP8 | FP32 (4 B) | BF16 (2 B) | BF16+BF16 (4 B) | 15 B (25% reduction) |
| Nanotron FP8 | FP8 | FP32 (4 B) | FP8 (1 B) | FP8+FP8 (2 B) | 10 B (50% reduction) |
The primary stability risk of FP8 pretraining is that the E4M3 format has only 2^3 = 8 representable values per binade and a dynamic range of only ~10^±2.4. Loss divergence typically manifests as:
Gradient underflow → zero gradients → stalled learning
Activation overflow → NaN propagation → loss explosion
Beyond the memory savings, FP8 halves the bytes moved per tensor: for memory-bound operations this directly translates to a 2× speedup, while for compute-bound operations (GEMMs) the doubled peak FLOPS provides the speedup.
7. Connecting All Concepts: The GPU Performance Model
The overall throughput of a training step on a single GPU is determined by the roofline model:
Attainable FLOP/s = min(Φ_peak, AI × β_HBM)
where:
Φpeak: peak compute throughput (FLOP/s) at the chosen precision
AI = FLOPs / Bytes transferred: arithmetic intensity of the operation
βHBM: HBM bandwidth (bytes/s)
The ridge point — the AI where the transition from memory-bound to compute-bound occurs — is:
AI_ridge = Φ_peak / β_HBM
For H100 in BF16:
AI_ridge^BF16 = 989×10^12 / (3.35×10^12) ≈ 295 FLOPs/byte
For H100 in FP8:
AI_ridge^FP8 = 1,979×10^12 / (3.35×10^12) ≈ 591 FLOPs/byte
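The ridge-point numbers can be reproduced directly from the roofline formula:

```python
# Roofline model: attainable FLOP/s = min(phi_peak, AI * beta_hbm).

def attainable_flops(ai, phi_peak, beta_hbm):
    return min(phi_peak, ai * beta_hbm)

PHI_BF16 = 989e12       # H100 BF16 peak, FLOP/s
BETA_HBM = 3.35e12      # H100 HBM bandwidth, bytes/s

ridge = PHI_BF16 / BETA_HBM
print(round(ridge))                                # ~295 FLOPs/byte
print(attainable_flops(2.5, PHI_BF16, BETA_HBM))   # LayerNorm: memory-bound
print(attainable_flops(1000, PHI_BF16, BETA_HBM))  # big GEMM: compute-bound
```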
Operations with AI < AI_ridge (attention, LayerNorm, activations) benefit most from kernel fusion and FlashAttention (which reduce bytes transferred). Operations with AI > AI_ridge (large GEMMs) benefit most from lower precision (which increases Φ_peak).
This unified view explains why all three techniques — kernel fusion, FlashAttention, and mixed precision — are complementary and simultaneously necessary for maximizing GPU utilization in modern large-scale training.