
Tensor Parallelism (TP) and Sequence Parallelism (SP)


1. Motivation and Problem Statement

ZeRO-style data parallelism shards parameters, gradients, and optimizer states across GPUs. However, as model scale increases, activation memory becomes the dominant bottleneck. ZeRO requires an all-gather to reconstruct the full parameter tensor before each computation step, which limits scalability.

Tensor Parallelism (TP) addresses this by sharding weights, gradients, optimizer states, and activations simultaneously, performing computation directly on the shards without gathering the full tensors beforehand.


2. Mathematical Foundations of Tensor Parallelism

Tensor parallelism exploits two fundamental properties of matrix multiplication. Given matrices $A \in \mathbb{R}^{m \times k}$ and $B \in \mathbb{R}^{k \times n}$:

2.1 Column-wise Partitioning (Equation 1)

Partition $B$ into $N$ column blocks:

$$ B = \begin{bmatrix} B_1 & B_2 & \cdots & B_N \end{bmatrix}, \quad B_i \in \mathbb{R}^{k \times \frac{n}{N}} $$

Then:

$$ A \cdot B = A \cdot \begin{bmatrix} B_1 & B_2 & \cdots & B_N \end{bmatrix} = \begin{bmatrix} AB_1 & AB_2 & \cdots & AB_N \end{bmatrix} $$

Each partial product $AB_i \in \mathbb{R}^{m \times \frac{n}{N}}$ can be computed independently on GPU $i$. The full result is recovered via concatenation (all-gather along the column dimension).

2.2 Row-wise Partitioning (Equation 2)

Partition $A$ into $N$ column blocks and $B$ into $N$ corresponding row blocks:

$$ A = \begin{bmatrix} A_1 & A_2 & \cdots & A_N \end{bmatrix}, \quad A_i \in \mathbb{R}^{m \times \frac{k}{N}} $$
$$ B = \begin{bmatrix} B_1 \\ B_2 \\ \vdots \\ B_N \end{bmatrix}, \quad B_i \in \mathbb{R}^{\frac{k}{N} \times n} $$

Then:

$$ A \cdot B = \sum_{i=1}^{N} A_i B_i $$

Each partial product $A_i B_i \in \mathbb{R}^{m \times n}$ is computed independently on GPU $i$. The full result requires summation across GPUs (all-reduce).
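Both identities can be checked numerically. The following is a minimal sketch using NumPy, with the $N$ "GPUs" simulated as array slices in a single process:

```python
import numpy as np

rng = np.random.default_rng(0)
m, k, n, N = 4, 6, 8, 2
A = rng.standard_normal((m, k))
B = rng.standard_normal((k, n))

# Equation 1: column-wise partition of B -> concatenate partial products.
B_cols = np.split(B, N, axis=1)                                 # B_i in R^{k x n/N}
col_result = np.concatenate([A @ Bi for Bi in B_cols], axis=1)  # "all-gather"

# Equation 2: row-wise partition of B with matching column split of A -> sum.
A_cols = np.split(A, N, axis=1)                                 # A_i in R^{m x k/N}
B_rows = np.split(B, N, axis=0)                                 # B_i in R^{k/N x n}
row_result = sum(Ai @ Bi for Ai, Bi in zip(A_cols, B_rows))     # "all-reduce"

assert np.allclose(col_result, A @ B)
assert np.allclose(row_result, A @ B)
```

Concatenation stands in for the all-gather and the Python `sum` for the all-reduce; in a real multi-GPU setting these would be collective communication calls.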

2.3 Neural Network Convention

In neural networks, the standard linear layer computes:

$$ Y = X \cdot W + b $$

where:

  • $X \in \mathbb{R}^{b \times s \times h_{\text{in}}}$ — input activations (batch $b$, sequence length $s$, hidden dimension $h_{\text{in}}$)
  • $W \in \mathbb{R}^{h_{\text{in}} \times h_{\text{out}}}$ — weight matrix
  • $b \in \mathbb{R}^{h_{\text{out}}}$ — bias vector
  • $Y \in \mathbb{R}^{b \times s \times h_{\text{out}}}$ — output activations

3. Column-Linear (Column-Parallel) Sharding

3.1 Mechanism

Given $N$ GPUs (TP degree $= N$), partition the weight matrix $W$ along its output (column) dimension:

$$ W = \begin{bmatrix} W_1 & W_2 & \cdots & W_N \end{bmatrix}, \quad W_i \in \mathbb{R}^{h_{\text{in}} \times \frac{h_{\text{out}}}{N}} $$

3.2 Forward Pass

| Step | Operation | Description |
| --- | --- | --- |
| 1 | Broadcast | Copy the full input $X \in \mathbb{R}^{b \times s \times h_{\text{in}}}$ to all $N$ GPUs |
| 2 | Local matmul | GPU $i$ computes $Y_i = X \cdot W_i \in \mathbb{R}^{b \times s \times \frac{h_{\text{out}}}{N}}$ |
| 3 | All-gather | Concatenate $\{Y_1, Y_2, \ldots, Y_N\}$ to reconstruct $Y \in \mathbb{R}^{b \times s \times h_{\text{out}}}$ |

3.3 Communication Primitives

  • Broadcast: $O(b \cdot s \cdot h_{\text{in}})$ — replicate the full input
  • All-gather: $O\!\left(\frac{h_{\text{out}}}{N} \cdot b \cdot s \cdot (N-1)\right)$ — reconstruct the output

3.4 Key Property

Each GPU stores only $\frac{1}{N}$ of the weight columns and produces $\frac{1}{N}$ of the output columns. The input is replicated, but the output activations are sharded.
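A minimal single-process sketch of the column-linear forward pass, with each "GPU" holding one column shard `W_i` and the concatenation standing in for the all-gather:

```python
import numpy as np

rng = np.random.default_rng(1)
b, s, h_in, h_out, N = 2, 3, 8, 12, 4
X = rng.standard_normal((b, s, h_in))     # replicated on every rank (broadcast)
W = rng.standard_normal((h_in, h_out))
W_shards = np.split(W, N, axis=1)         # W_i in R^{h_in x h_out/N}

Y_shards = [X @ Wi for Wi in W_shards]    # step 2: local matmuls
Y = np.concatenate(Y_shards, axis=-1)     # step 3: all-gather along columns

assert Y.shape == (b, s, h_out)
assert np.allclose(Y, X @ W)
```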


4. Row-Linear (Row-Parallel) Sharding

4.1 Mechanism

Partition the weight matrix $W$ along its input (row) dimension:

$$ W = \begin{bmatrix} W_1 \\ W_2 \\ \vdots \\ W_N \end{bmatrix}, \quad W_i \in \mathbb{R}^{\frac{h_{\text{in}}}{N} \times h_{\text{out}}} $$

This requires a corresponding partition of the input along the hidden dimension:

$$ X = \begin{bmatrix} X_1 & X_2 & \cdots & X_N \end{bmatrix}, \quad X_i \in \mathbb{R}^{b \times s \times \frac{h_{\text{in}}}{N}} $$

4.2 Forward Pass

| Step | Operation | Description |
| --- | --- | --- |
| 1 | Scatter | Split and distribute the input $X$ so that GPU $i$ receives $X_i$ |
| 2 | Local matmul | GPU $i$ computes $Y_i = X_i \cdot W_i \in \mathbb{R}^{b \times s \times h_{\text{out}}}$ |
| 3 | All-reduce | Sum $\{Y_1, Y_2, \ldots, Y_N\}$ element-wise: $Y = \sum_{i=1}^{N} Y_i$ |

4.3 Communication Primitives

  • Scatter: $O(b \cdot s \cdot h_{\text{in}})$ — distribute input chunks
  • All-reduce: $O(b \cdot s \cdot h_{\text{out}})$ — sum partial results (equivalent to reduce-scatter + all-gather)

4.4 Key Property

Each GPU stores only $\frac{1}{N}$ of the weight rows. The input is sharded, and the output is full-sized but requires summation for correctness.
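The corresponding row-linear sketch, with the Python `sum` standing in for the all-reduce:

```python
import numpy as np

rng = np.random.default_rng(2)
b, s, h_in, h_out, N = 2, 3, 8, 12, 4
X = rng.standard_normal((b, s, h_in))
W = rng.standard_normal((h_in, h_out))
X_shards = np.split(X, N, axis=-1)        # step 1: X_i in R^{b x s x h_in/N}
W_shards = np.split(W, N, axis=0)         # W_i in R^{h_in/N x h_out}

# Steps 2-3: local matmul per rank, then sum the partial outputs.
Y = sum(Xi @ Wi for Xi, Wi in zip(X_shards, W_shards))

assert np.allclose(Y, X @ W)
```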


5. Tensor Parallelism in the Transformer Block

A standard Transformer decoder layer consists of two primary sub-blocks:

$$ \text{Transformer Layer} = \text{LayerNorm} \rightarrow \text{MHA} \rightarrow \text{LayerNorm} \rightarrow \text{MLP} $$

5.1 MLP Block

The feedforward MLP in a Transformer consists of two linear projections with a nonlinearity (e.g., GeLU or SiLU):

$$ \text{MLP}(X) = \sigma(X W_1) W_2 $$

where:

  • $W_1 \in \mathbb{R}^{h \times 4h}$ (up-projection, or gate projection)
  • $W_2 \in \mathbb{R}^{4h \times h}$ (down-projection)
  • $\sigma$ is an element-wise activation function (GeLU, SiLU, etc.)

TP Strategy for MLP:

| Layer | Parallelism Type | Rationale |
| --- | --- | --- |
| $W_1$ (FC1 / up-projection) | Column-linear | Splits the output hidden dim; the input is broadcast (or already synced) |
| $\sigma$ (activation) | Local | Applied element-wise on sharded activations |
| $W_2$ (FC2 / down-projection) | Row-linear | Takes sharded input, produces the full output; requires an all-reduce |

Why Column-Linear → Row-Linear (not vice versa)?

If we used Row-Linear → Column-Linear, we would need an intermediate all-reduce between the two layers (to get correct row-linear output) followed by another communication for the column-linear input. The Column-Linear → Row-Linear ordering requires:

  • Forward pass: One broadcast (often a no-op if inputs are already synced) + one all-reduce
  • Backward pass: One all-reduce + one no-op

This yields one all-reduce per MLP sub-block per direction — the minimal communication.
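The ordering can be verified numerically: the intermediate activation stays sharded through the element-wise nonlinearity, and a single sum (the all-reduce) at the end recovers the exact MLP output. A sketch using the tanh approximation of GeLU:

```python
import numpy as np

def gelu(x):
    # tanh approximation of GeLU
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

rng = np.random.default_rng(3)
b, s, h, N = 2, 4, 8, 2
X = rng.standard_normal((b, s, h))
W1 = rng.standard_normal((h, 4 * h))
W2 = rng.standard_normal((4 * h, h))

W1_shards = np.split(W1, N, axis=1)   # FC1: column-linear
W2_shards = np.split(W2, N, axis=0)   # FC2: row-linear

# Each rank: FC1 shard -> local GeLU -> FC2 shard; then one all-reduce.
partials = [gelu(X @ W1_shards[i]) @ W2_shards[i] for i in range(N)]
Y_tp = sum(partials)

assert np.allclose(Y_tp, gelu(X @ W1) @ W2)
```

This works precisely because $\sigma$ is element-wise: the column split of the FC1 output lines up with the row split of FC2, so no communication is needed between the two layers.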

5.2 Multi-Head Attention (MHA) Block

Multi-head attention computes:

$$ \text{MHA}(X) = \text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_{n_h}) \cdot W^O $$

where each attention head $i$ computes:

$$ \text{head}_i = \text{Attention}(X W^Q_i, \; X W^K_i, \; X W^V_i) $$
$$ \text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V $$

with $d_k = \frac{h}{n_h}$ the per-head dimension.

TP Strategy for MHA:

| Component | Parallelism Type | Rationale |
| --- | --- | --- |
| $W^Q, W^K, W^V$ projections | Column-linear | Naturally partition along the $n_h$ (num_heads) dimension; each GPU handles $\frac{n_h}{N}$ heads |
| Attention computation | Local | Each GPU computes attention for its assigned heads independently |
| $W^O$ (output projection) | Row-linear | Recombines head outputs; requires an all-reduce |

Natural interpretation: With TP degree $N$, GPU $i$ computes attention for heads $\left\{\frac{(i-1)\, n_h}{N} + 1, \ldots, \frac{i \, n_h}{N}\right\}$.
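The head assignment can be sketched as a small (hypothetical) helper; note it uses 0-indexed ranks and heads, whereas the formula above is 1-indexed:

```python
def heads_for_rank(rank: int, n_heads: int, tp: int) -> range:
    """Which attention heads land on TP rank `rank` (0-indexed)."""
    assert n_heads % tp == 0, "head count must divide evenly across TP ranks"
    per_rank = n_heads // tp
    return range(rank * per_rank, (rank + 1) * per_rank)

# Example: 32 query heads, TP degree 8 -> 4 contiguous heads per GPU.
assert list(heads_for_rank(0, 32, 8)) == [0, 1, 2, 3]
assert list(heads_for_rank(7, 32, 8)) == [28, 29, 30, 31]
```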

5.3 Constraints on TP Degree

Hard constraint:

$$ N_{\text{TP}} \leq n_h \quad \text{(number of query attention heads)} $$

since we partition along the head dimension, and each GPU must receive at least one head.

GQA/MQA constraint: In Grouped Query Attention:

  • $n_q$ query heads, $n_{kv}$ key/value heads, with $n_q \geq n_{kv}$
  • The TP degree can go up to $n_q$, but when $N_{\text{TP}} > n_{kv}$, multiple GPUs share the same K/V heads
  • This requires careful K/V head replication and synchronization

Example — Llama-3 8B:

  • $n_q = 32$ query heads, $n_{kv} = 8$ key/value heads
  • The theoretical maximum is $N_{\text{TP}} = 32$, but practical implementations typically use $N_{\text{TP}} \leq 8$ to avoid K/V synchronization overhead and inter-node communication
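The grouped-query mapping behind this constraint can be sketched with a hypothetical helper: query head $q$ reads K/V head $\lfloor q / (n_q / n_{kv}) \rfloor$, so once the TP degree exceeds $n_{kv}$, two ranks inevitably need the same K/V head:

```python
def kv_head_for_query(q: int, n_q: int, n_kv: int) -> int:
    """K/V head serving query head q (0-indexed) under GQA."""
    assert n_q % n_kv == 0
    return q // (n_q // n_kv)

# Llama-3 8B: 32 query heads, 8 K/V heads -> groups of 4 queries per K/V head.
assert kv_head_for_query(0, 32, 8) == 0
assert kv_head_for_query(31, 32, 8) == 7
```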

6. Communication Analysis in Tensor Parallelism

6.1 Communication Primitives Summary

| Primitive | Description | Volume |
| --- | --- | --- |
| Broadcast | Replicate data from one GPU to all | $O(D)$ |
| Scatter | Split data and distribute chunks | $O(D)$ |
| All-gather | Each GPU contributes a chunk; all receive the full tensor | $O\!\left(\frac{N-1}{N} \cdot D\right)$ |
| Reduce-scatter | Reduce (sum) and scatter result chunks | $O\!\left(\frac{N-1}{N} \cdot D\right)$ |
| All-reduce | Sum across all GPUs; all receive the full result | $O\!\left(2 \cdot \frac{N-1}{N} \cdot D\right)$ |

Critical identity:

$$ \text{All-Reduce} = \text{Reduce-Scatter} + \text{All-Gather} $$
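This identity can be checked in a few lines, simulating the $N$ ranks as list entries: reduce-scatter leaves rank $i$ with the summed chunk $i$, and the subsequent all-gather reconstructs the full summed tensor everywhere.

```python
import numpy as np

rng = np.random.default_rng(4)
N, D = 4, 8
partials = [rng.standard_normal(D) for _ in range(N)]  # one partial per rank

# Reduce-scatter: rank i ends up with the sum of everyone's chunk i.
chunks = [sum(np.split(p, N)[i] for p in partials) for i in range(N)]
# All-gather: every rank reconstructs the full summed tensor.
result = np.concatenate(chunks)

assert np.allclose(result, sum(partials))  # same as a direct all-reduce
```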

6.2 Per-Transformer-Layer Communication (Vanilla TP)

Each Transformer layer requires:

| Sub-block | Forward | Backward |
| --- | --- | --- |
| MHA | 1 × all-reduce | 1 × all-reduce |
| MLP | 1 × all-reduce | 1 × all-reduce |
| Total | 2 × all-reduce | 2 × all-reduce |

6.3 Conjugate Operator Pairs ($f$, $f^*$)

Tensor parallelism uses conjugate operator pairs that swap roles between forward and backward passes:

$$ \begin{aligned} &\textbf{Forward:} \quad f = \text{no-op}, \quad f^* = \text{all-reduce} \\ &\textbf{Backward:} \quad f = \text{all-reduce}, \quad f^* = \text{no-op} \end{aligned} $$

Rationale: In the forward pass, inputs are already replicated across TP ranks (making $f$ a no-op), but outputs from row-linear layers must be summed ($f^*$ = all-reduce). In the backward pass, gradient flow reverses: gradients arriving at $f^*$ are already correct (no-op), while gradients at $f$ must be synchronized (all-reduce).

6.4 Critical Path and Overlap Limitations

Unlike ZeRO, where communication (all-gather of parameters) can be overlapped with computation, TP places communication directly on the critical path:

$$ T_{\text{layer}} = T_{\text{compute}} + T_{\text{all-reduce}}^{\text{exposed}} $$

The all-reduce at the end of each MLP and MHA block cannot begin until the local matrix multiplication completes, and the subsequent LayerNorm cannot begin until the all-reduce finishes. This creates a synchronization barrier on every layer.

Partial mitigation (Megatron-LM, Nanotron): Overlap all-gather with FC1 computation using block/chunked matrix multiplication:

  • Divide the weight into chunks
  • As each chunk’s all-gather completes, begin matmul for that chunk while the next chunk’s all-gather proceeds asynchronously

Advanced mitigation (Domino): Novel scheduling techniques to maximize overlap of communication and computation within TP regions.


7. Scaling Behavior and Trade-offs

7.1 Intra-node vs. Inter-node Communication

| Regime | Interconnect | Bandwidth (typical) | Latency |
| --- | --- | --- | --- |
| $N_{\text{TP}} \leq 8$ (intra-node) | NVLink / NVSwitch | 600–900 GB/s (bidirectional, per GPU) | ~µs |
| $N_{\text{TP}} > 8$ (inter-node) | InfiniBand / EFA | 50–400 GB/s (per node) | ~10–100 µs |

Empirical observation: Throughput drops significantly at $N_{\text{TP}} = 16$ (crossing the node boundary) and precipitously at $N_{\text{TP}} = 32$.

7.2 Memory Reduction

For a model with $P$ total parameters and TP degree $N$:

$$ \text{Parameters per GPU} = \frac{P}{N} $$
$$ \text{Gradients per GPU} = \frac{P}{N} $$
$$ \text{Optimizer states per GPU} = \frac{S \cdot P}{N} $$

where $S$ is the optimizer state multiplier (e.g., $S = 12$ bytes per parameter for Adam in mixed precision: 4 bytes of fp32 master weights + 4 bytes of first moment + 4 bytes of second moment).
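A back-of-the-envelope sketch of per-GPU model-state memory under TP, via a hypothetical helper `tp_memory_gb`. It additionally assumes bf16 parameters and gradients (2 bytes each, an assumption not stated above) plus $S = 12$ bytes of Adam state, and ignores activations:

```python
def tp_memory_gb(params: float, tp: int, s_bytes: int = 12) -> float:
    """Model-state memory per GPU in GB: bf16 params (2 B) + bf16 grads (2 B)
    + s_bytes of optimizer state per parameter, all sharded N ways by TP."""
    bytes_per_gpu = params / tp * (2 + 2 + s_bytes)
    return bytes_per_gpu / 1e9

# Example: an 8B-parameter model at TP = 8 -> 16 GB of model states per GPU.
assert tp_memory_gb(8e9, 8) == 16.0
```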

Activation memory in TP regions:

$$ \text{Activation per GPU (TP region)} = \frac{b \cdot s \cdot h}{N} \quad \text{(intermediate activations)} $$

However, operations like LayerNorm and Dropout still require the full activation tensor $(b, s, h)$, partially negating the activation memory savings.

7.3 Throughput–Memory Trade-off

$$ \text{Effective throughput per GPU} \propto \frac{T_{\text{compute}}}{T_{\text{compute}} + T_{\text{communication}}} $$

As $N$ increases:

  • $T_{\text{compute}}$ decreases (less work per GPU)
  • $T_{\text{communication}}$ increases (more participants, potentially crossing node boundaries)
  • Net effect: diminishing returns beyond $N = 8$ (a single node)

The benefit is that reduced memory per GPU allows larger batch sizes, which can compensate for per-GPU throughput loss at the system level.


8. Sequence Parallelism (SP)

8.1 Motivation

Even with TP, operations outside the attention and MLP sub-blocks — specifically LayerNorm and Dropout — require the full hidden dimension $h$ and therefore cannot be sharded along $h$. These operations still store activations of shape $(b, s, h)$ on each GPU, creating memory bottlenecks.

Sequence Parallelism shards these operations along the sequence dimension $s$ instead:

$$ \text{SP shard on GPU } i: \quad X_i^{*} \in \mathbb{R}^{b \times \frac{s}{N} \times h} $$

8.2 Why LayerNorm Requires Full Hidden Dimension

LayerNorm computes:

$$ \text{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta $$

where the statistics are computed across the hidden dimension $h$:

$$ \mu = \frac{1}{h} \sum_{j=1}^{h} x_j, \qquad \sigma^2 = \frac{1}{h} \sum_{j=1}^{h} (x_j - \mu)^2 $$

Since $\mu$ and $\sigma^2$ require access to all $h$ elements, we cannot shard LayerNorm along $h$. However, different sequence positions are independent for LayerNorm, so we can shard along $s$.
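The independence across sequence positions means sequence-sharded LayerNorm matches the unsharded computation exactly. A minimal check (omitting the affine $\gamma, \beta$ for brevity):

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize over the last (hidden) dimension; gamma/beta omitted.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

rng = np.random.default_rng(5)
b, s, h, N = 2, 8, 16, 4
X = rng.standard_normal((b, s, h))

full = layer_norm(X)
shards = np.split(X, N, axis=1)   # shard along the sequence dimension s
sharded = np.concatenate([layer_norm(c) for c in shards], axis=1)

assert np.allclose(full, sharded)
```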

8.3 Conjugate Operator Pairs for SP ($g$, $g^*$)

$$ \begin{aligned} &\textbf{Forward:} \quad g = \text{all-gather}, \quad g^* = \text{reduce-scatter} \\ &\textbf{Backward:} \quad g = \text{reduce-scatter}, \quad g^* = \text{all-gather} \end{aligned} $$

These replace the $f$/$f^*$ operators at the boundaries between SP regions and TP regions.

8.4 Forward Pass Through a TP+SP Transformer Layer (Step-by-Step)

Consider an MLP sub-block with TP+SP on $N = 2$ GPUs:


Step 1: Initial LayerNorm (SP Region)

Each GPU $i$ holds $X_i^* \in \mathbb{R}^{b \times \frac{s}{2} \times h}$ (sharded along the sequence dimension).

LayerNorm is computed independently per sequence position:

$$ Y_i^* = \text{LayerNorm}(X_i^*) \in \mathbb{R}^{b \times \frac{s}{2} \times h} $$

Step 2: Transition SP → TP ($g$ = all-gather)

Gather the sequence chunks to reconstruct the full sequence on each GPU:

$$ Y = \text{AllGather}(Y_1^*, Y_2^*) \in \mathbb{R}^{b \times s \times h} $$

This is necessary because the column-linear layer (FC1) needs the full input.


Step 3: First Linear Layer — Column-Linear (TP Region)

GPU $i$ holds the column shard $W_1^{(i)} \in \mathbb{R}^{h \times \frac{4h}{N}}$:

$$ Z_i^* = \sigma(Y \cdot W_1^{(i)}) \in \mathbb{R}^{b \times s \times \frac{4h}{N}} $$

where $\sigma$ is the activation function (GeLU/SiLU), applied element-wise.


Step 4: Second Linear Layer — Row-Linear (TP Region)

GPU $i$ holds the row shard $W_2^{(i)} \in \mathbb{R}^{\frac{4h}{N} \times h}$:

$$ \hat{Y}_i = Z_i^* \cdot W_2^{(i)} \in \mathbb{R}^{b \times s \times h} $$

Step 5: Transition TP → SP ($g^*$ = reduce-scatter)

$$ \hat{Y}_j^* = \text{ReduceScatter}(\hat{Y}_1, \hat{Y}_2) \in \mathbb{R}^{b \times \frac{s}{N} \times h} $$

This simultaneously:

  1. Reduces (sums) the partial results from the row-linear operation (required for correctness: $\hat{Y} = \sum_i \hat{Y}_i$)
  2. Scatters the result along the sequence dimension (returning to SP sharding)
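Steps 1 through 5 can be run end to end in a single-process sketch on $N = 2$ simulated ranks and checked against the plain computation. ReLU is used here as a stand-in activation, and the affine LayerNorm parameters are omitted, for brevity:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

rng = np.random.default_rng(6)
b, s, h, N = 2, 4, 8, 2
X = rng.standard_normal((b, s, h))
W1 = rng.standard_normal((h, 4 * h))
W2 = rng.standard_normal((4 * h, h))
relu = lambda x: np.maximum(x, 0)                # stand-in activation

# Step 1: each rank LayerNorms its own sequence shard.
Y_sp = [layer_norm(c) for c in np.split(X, N, axis=1)]
# Step 2: SP -> TP all-gather along the sequence dimension.
Y = np.concatenate(Y_sp, axis=1)
# Steps 3-4: column-linear FC1 + activation, then row-linear FC2, per rank.
W1_shards, W2_shards = np.split(W1, N, axis=1), np.split(W2, N, axis=0)
partials = [relu(Y @ W1_shards[i]) @ W2_shards[i] for i in range(N)]
# Step 5: TP -> SP reduce-scatter (sum, then keep one sequence chunk each).
out_sp = np.split(sum(partials), N, axis=1)

ref = relu(layer_norm(X) @ W1) @ W2              # unsharded reference
assert all(np.allclose(o, r) for o, r in zip(out_sp, np.split(ref, N, axis=1)))
```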

8.5 Activation Shape Summary Table

| Region | Vanilla TP | TP + SP |
| --- | --- | --- |
| Enter TP (column-linear input) | $h$: full, $s$: full | $h$: full, $s$: all-gather to full |
| TP region (between FC1 and FC2) | $h$: sharded ($\frac{h}{N}$), $s$: full | $h$: sharded ($\frac{h}{N}$), $s$: full |
| Exit TP (row-linear output) | $h$: full + all-reduce, $s$: full | $h$: full + reduce-scatter, $s$: reduce-scatter to sharded |
| SP region (LayerNorm, Dropout) | $h$: full, $s$: full | $h$: full, $s$: sharded ($\frac{s}{N}$) |

8.6 Maximum Activation Size

Vanilla TP:

$$ \text{Max activation per GPU} = b \cdot s \cdot h $$

(Full activations required in LayerNorm/Dropout regions.)

TP + SP:

$$ \text{Max activation per GPU} = \frac{b \cdot s \cdot h}{N} $$

At every point in the computation, activations are sharded along either $h$ (in TP regions) or $s$ (in SP regions), never requiring the full $(b, s, h)$ tensor on any single GPU.

8.7 Communication Cost Equivalence

Vanilla TP per transformer layer: 2 × all-reduce (forward), 2 × all-reduce (backward)

TP+SP per transformer layer: 2 × all-gather + 2 × reduce-scatter (forward), 2 × reduce-scatter + 2 × all-gather (backward)

Since:

$$ \text{cost}(\text{all-reduce}) = \text{cost}(\text{all-gather}) + \text{cost}(\text{reduce-scatter}) $$

and TP+SP replaces each all-reduce with one all-gather and one reduce-scatter (at different points in the computation):

$$ \text{Total communication}(\text{TP+SP}) = \text{Total communication}(\text{Vanilla TP}) $$

SP achieves strictly better activation memory with no additional communication overhead.

8.8 Gradient Synchronization Notes

LayerNorm weights in SP regions: Since each TP rank sees different sequence positions (but shares the same LayerNorm parameters $\gamma, \beta$), the gradients for $\gamma$ and $\beta$ differ across ranks. An all-reduce of the LayerNorm gradients is therefore required during the backward pass:

$$ \nabla_\gamma = \text{AllReduce}\!\left(\nabla_\gamma^{(1)}, \nabla_\gamma^{(2)}, \ldots, \nabla_\gamma^{(N)}\right) $$

This overhead is negligible since LayerNorm has only $2h$ parameters (compared to the full model's billions).

Dropout in SP regions: Random masks must be synchronized across TP ranks to maintain deterministic behavior. In practice, this is achieved by synchronizing the random seed across ranks:

$$ \text{seed}_{\text{dropout}}^{(i)} = \text{seed}_{\text{global}}, \quad \forall \; i \in \{1, \ldots, N\} $$

9. Embedding Layer Treatment

The vocabulary embedding $E \in \mathbb{R}^{V \times h}$ (where $V$ is the vocabulary size) is typically sharded along the vocabulary dimension (row-linear):

| Configuration | Sharding | Communication |
| --- | --- | --- |
| Vanilla TP | $h$: full (all-reduce for correctness), $s$: full | All-reduce |
| TP + SP | $h$: full (reduce-scatter for correctness), $s$: reduce-scatter to sharded | Reduce-scatter |

10. Limitations of TP + SP

| Limitation | Description |
| --- | --- |
| Sequence length scaling | In the TP region, activations are $(b, s, \frac{h}{N})$; as $s$ grows, activation memory still scales linearly with $s$ in these regions |
| Inter-node communication | For $N_{\text{TP}} > 8$ (exceeding a single node), bandwidth drops from NVLink (~900 GB/s) to the network interconnect (~100–400 GB/s), causing severe throughput degradation |
| Critical path communication | All-gather and reduce-scatter remain on the critical path and cannot be fully overlapped with computation |
| Head count constraint | $N_{\text{TP}} \leq n_h$ limits the maximum parallelism degree |

Solutions to these limitations:

  • Context Parallelism (CP): Addresses sequence-length-induced activation memory blowup by sharding the attention computation across the sequence dimension
  • Pipeline Parallelism (PP): Addresses models too large for a single node by partitioning layers across nodes, avoiding the need for $N_{\text{TP}} > 8$

11. Complete Communication Pattern Summary

For a single Transformer layer with TP+SP:

$$ \boxed{ \begin{aligned} &\textbf{Forward:} \quad \underbrace{g \;(\text{all-gather})}_{\text{SP} \to \text{TP}} \;\to\; \text{Column-Linear} \;\to\; \sigma \;\to\; \text{Row-Linear} \;\to\; \underbrace{g^* \;(\text{reduce-scatter})}_{\text{TP} \to \text{SP}} \\[6pt] &\textbf{Backward:} \quad \underbrace{g^* \;(\text{all-gather})}_{\text{SP} \to \text{TP}} \;\to\; \nabla\text{Row-Linear} \;\to\; \nabla\sigma \;\to\; \nabla\text{Column-Linear} \;\to\; \underbrace{g \;(\text{reduce-scatter})}_{\text{TP} \to \text{SP}} \end{aligned} } $$

This pattern repeats for both the MHA and MLP sub-blocks within each Transformer layer, yielding 4 communication operations per layer per pass (2 for MHA + 2 for MLP).


Generated from llm_training_at_scale.