Noam Shazeer, Youlong Cheng, Niki Parmar
Dustin Tran, Ashish Vaswani, Penporn Koanantakool, Peter Hawkins, HyoukJoong Lee
Mingsheng Hong, Cliff Young, Ryan Sepassi, Blake Hechtman
Google Brain
{noam, ylc, nikip, trandustin, avaswani, penporn, phawkins,
hyouklee, hongm, cliffy, rsepassi, blakehechtman}@google.com
32nd Conference on Neural Information Processing Systems (NIPS 2018), Montréal, Canada.
Batch-splitting (data-parallelism) is the dominant distributed Deep Neural Network (DNN) training strategy, due to its universal applicability and its amenability to Single-Program-Multiple-Data (SPMD) programming. However, batch-splitting suffers from problems including the inability to train very large models (due to memory constraints), high latency, and inefficiency at small batch sizes. All of these can be solved by more general distribution strategies (model-parallelism). Unfortunately, efficient model-parallel algorithms tend to be complicated to discover, describe, and implement, particularly on large clusters. We introduce Mesh-TensorFlow, a language for specifying a general class of distributed tensor computations. Where data-parallelism can be viewed as splitting tensors and operations along the "batch" dimension, in Mesh-TensorFlow, the user can specify any tensor-dimensions to be split across any dimensions of a multi-dimensional mesh of processors. A Mesh-TensorFlow graph compiles into a SPMD program consisting of parallel operations coupled with collective communication primitives such as Allreduce. We use Mesh-TensorFlow to implement an efficient data-parallel, model-parallel version of the Transformer [21] sequence-to-sequence model. Using TPU meshes of up to 512 cores, we train Transformer models with up to 5 billion parameters, surpassing state-of-the-art results on the WMT'14 English-to-French translation task and the one-billion-word language modeling benchmark. Mesh-TensorFlow is available at https://github.com/tensorflow/mesh .
Executive Summary: Training large deep neural networks on supercomputers faces significant hurdles today, as models grow to billions of parameters to tackle complex tasks like machine translation and language modeling. Traditional data parallelism, which splits training data across processors while replicating model weights, hits limits: it demands excessive memory for big models, slows down with small batches, and causes high synchronization delays. This matters now because advancing AI requires scaling compute efficiently on hardware like Google's Tensor Processing Units (TPUs), yet current methods struggle to use massive clusters without waste or failure, delaying breakthroughs in natural language processing and beyond.
This paper introduces Mesh-TensorFlow, a simple extension to the TensorFlow framework that enables flexible distribution of computations across processor arrays, or "meshes." It aims to overcome data parallelism's flaws by allowing users to split not just data batches but also model dimensions, like hidden layers or attention heads, across processors, compiling into efficient programs that run the same code on all devices with minimal communication.
The authors developed Mesh-TensorFlow as a high-level language in which tensor dimensions have names, and users map those names to mesh dimensions to split them. This generates single-program-multiple-data code using local operations and collective communications, such as summing gradients across groups (Allreduce). They focused on reliable, identical processors in multi-dimensional meshes, assuming dimension sizes divide evenly, and tested the system on TPU clusters of up to 512 cores. They implemented and benchmarked a distributed version of the Transformer model, a key architecture for sequence tasks, on standard datasets: WMT'14 English-to-French translation (millions of sentence pairs) and the one-billion-word language modeling benchmark (a vast text corpus).
The core findings highlight Mesh-TensorFlow's power. First, it scales Transformer training to 5 billion parameters on 512 TPU cores with over 50% efficiency (6 PFLOP/s out of a possible 11.5 PFLOP/s), while keeping per-processor memory and communication steady as hardware grows. Second, larger models outperform prior results: a 2.9-billion-parameter version hit a 43.9 BLEU score on WMT'14 English-to-French translation, topping the previous best of 41.8 from smaller ensembles; similarly, a 4.9-billion-parameter model reached 24.0 perplexity on language modeling, beating the prior 26.1. Third, quality rises predictably with size: perplexity dropped from 35.0 at 0.14 billion parameters to 24.0 at 4.9 billion, yielding more coherent text, like realistic sentences on tech topics versus fragmented ones from tiny models. Fourth, hybrid layouts (mixing data and model splitting) let the number of processors grow quadratically on a 2-D mesh, or cubically on a 3-D mesh, with only linear increases in batch and layer sizes, keeping the communication-to-computation ratio constant where pure data parallelism would need a proportionally larger batch. Fifth, training times stayed practical: 13 hours for the largest language model, 22 hours for translation.
These results mean organizations can train superior AI models faster and cheaper on existing supercomputers, bypassing memory walls that once forced smaller models or costly ensembles. For instance, single, massive models can match or exceed multi-model setups, improving translation accuracy by about 5% and language-modeling perplexity by 8-10% over baselines, which matters for applications in global communication, search engines, and chatbots. This advances beyond earlier work, like basic model parallelism, by automating complex strategies without hand-coding, though its efficiency analysis assumes physical links between logically adjacent processors; on networks that do not match the mesh, efficiency may drop. Overall, it makes training models of this size practical, with the largest runs completing in under a day.
Leaders should integrate Mesh-TensorFlow into TensorFlow workflows for large-scale language AI projects, starting with Transformer-based systems on TPUs or similar accelerators and prioritizing the hybrid layout on larger clusters to balance speed and scale. As for alternatives, pure model parallelism suits memory-bound tasks but needs larger layers as processors are added; adding data splitting suits high-throughput needs, trading some communication for larger batches. Before full rollout, run pilots on internal datasets to tune layouts. Further steps include automating layout optimization via search algorithms, extending to convolutions or non-TPU hardware, and gathering more data on varied networks to refine assumptions.
Confidence in these findings is high for TPU environments: the benchmarks are backed by open-source code and set new state-of-the-art results. However, limitations include reliance on even dimension splits (risking redesigns for sizes that do not divide evenly), a focus on reliable meshes (leaving gaps for unreliable hardware), and untested scalability beyond 512 cores. Proceed cautiously on diverse setups without validation, but the framework's generality promises broad impact.
Section Summary: Batch-splitting, or data-parallelism, is the most common way to train large deep neural networks across multiple computers because it's straightforward and fits simple programming models. However, it runs into big problems with extremely large models, like needing too much memory for parameters and activations and taking too long to sync updates, while alternative approaches like model-parallelism are hard to specify and, in current implementations, produce very large programs that are difficult to compile and optimize. To fix this, the authors created Mesh-TensorFlow, a tool that lets users easily divide computations across a grid of processors in flexible ways, which they used to train massive Transformer models with billions of parameters on hundreds of processors, achieving top results in language translation and modeling tasks.
Batch-splitting (data-parallelism) is the dominant distributed Deep Neural Network (DNN) training strategy, due to its universal applicability and its amenability to Single-Program-Multiple-Data (SPMD) programming. However, batch-splitting suffers from several major problems when training very large models. The memory required to store parameters and/or activations and the time necessary to synchronize parameters can make purely-data-parallel algorithms impossible or inefficient. Different distribution strategies (model-parallelism [2]) can solve these issues, but specifying these strategies can be complicated, and the current MIMD implementations generate very large programs which can be difficult to compile and to optimize.
We solve this problem by introducing Mesh-TensorFlow, a language for specifying a general class of distributed tensor computations. Where data-parallelism can be viewed as splitting tensors and operations along the "batch" dimension, in Mesh-TensorFlow, the user can specify any tensor-dimensions to be split across any dimensions of a multi-dimensional mesh of processors. A Mesh-TensorFlow graph compiles into a SPMD program consisting of parallel operations coupled with collective communication primitives such as Allreduce. We use Mesh-TensorFlow to implement an efficient data-parallel, model-parallel version of the Transformer [1] sequence-to-sequence model. Using TPU meshes of up to 512 cores, we train Transformer models with up to 5 billion parameters, surpassing state-of-the-art results on the WMT'14 English-to-French translation task and the one-billion-word language modeling benchmark.
Section Summary: This section assumes computing clusters made up of identical, dependable processors, each with its own local memory, and it defines a "mesh" simply as a way to organize these into an n-dimensional grid. The mesh is just an abstract label and doesn't dictate the actual physical connections between processors, so the same hardware—like a 512-core cluster—can be viewed as a 3D grid of 16x16x2, a 2D grid of 32x16, or even a 1D line of 512 processors. However, the real network layout still influences performance, particularly for operations like MPI Allreduce, which work best when groups of processors are directly linked.
While much work deals with heterogeneous and/or unreliable hardware, we focus on clusters of identical, reliable processors, each with a local memory. We define a mesh as an n-dimensional array of such processors. The mesh is only a naming abstraction and does not imply a physical network topology. As such, different meshes can be defined over the same set of physical processors. For example, a 512-core TPU cluster with a 16x16x2 toroidal network interconnect could be represented by a 3-dimensional mesh with shape [16, 16, 2], a two-dimensional mesh with shape [32, 16], a one-dimensional mesh with shape [512], etc. The physical network topology does affect performance; particularly important is the performance of MPI-allreduce within the groups of processors obtained by partitioning the mesh along a subset of its dimensions, which can be very efficient [3] [4] if each such group is physically connected.
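The two ideas above (one set of processors, many possible meshes; allreduce within groups defined by a subset of mesh dimensions) can be sketched in a few lines of Python. The helpers `processor_coords` and `allreduce_groups` are hypothetical, and the sketch assumes processors are numbered in row-major order over their mesh coordinates, which the text does not specify.

```python
import itertools

def processor_coords(mesh_shape):
    """Mesh coordinate of each processor, numbered in row-major order.

    mesh_shape is a list of (name, size) pairs, e.g. [("rows", 2), ("cols", 3)].
    """
    sizes = [size for _, size in mesh_shape]
    return list(itertools.product(*(range(s) for s in sizes)))

def allreduce_groups(mesh_shape, reduce_dims):
    """Group processor ids that allreduce together when reducing over reduce_dims.

    Two processors share a group iff they agree on every mesh coordinate
    that is NOT being reduced over.
    """
    names = [name for name, _ in mesh_shape]
    keep = [i for i, name in enumerate(names) if name not in reduce_dims]
    groups = {}
    for pid, coord in enumerate(processor_coords(mesh_shape)):
        key = tuple(coord[i] for i in keep)
        groups.setdefault(key, []).append(pid)
    return list(groups.values())

# The same 512 processors named by three different meshes.
for shape in ([("x", 16), ("y", 16), ("z", 2)], [("rows", 32), ("cols", 16)], [("all", 512)]):
    assert len(processor_coords(shape)) == 512

print(allreduce_groups([("rows", 2), ("cols", 3)], {"cols"}))  # [[0, 1, 2], [3, 4, 5]]
print(allreduce_groups([("rows", 2), ("cols", 3)], {"rows"}))  # [[0, 3], [1, 4], [2, 5]]
```

Each group is independent, so on hardware whose links follow the mesh, all groups can run their allreduces concurrently.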
Section Summary: In synchronous data-parallel training, each processor holds a full copy of the model's parameters and works on a portion of the training data batch, computing local gradients from forward and backward passes on its share. These gradients are then combined across all processors through a summing operation called allreduce, after which every processor updates its own copy of the parameters with the summed result. This approach, known as Single-Program-Multiple-Data (SPMD) batch-splitting, runs the same program on all processors while dividing computations along the batch dimension, and it inspires broader methods to split across other data dimensions.
We first review a commonly-used variant of synchronous data-parallelism where each processor keeps an identical copy of all parameters (Algorithm 1). For each step, the batch of training examples is split into sub-batches, one for each processor. Each processor computes the forward and backward passes on its sub-batch, resulting in gradients on the model parameters. These gradients are then summed across all processors and the results broadcast to all processors (MPI-allreduce). Finally, each processor updates its own copy of the parameters.
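Algorithm 1: Synchronous data-parallelism with replicated parameters. Each processor keeps a complete copy of the parameters $W^{(t)}$, and the batch $b^{(t)}$ for step $t$ is split into sub-batches $b^{(t)}_p$, one per processor $p \in P$.

1. $\forall p \in P$: compute partial parameter gradients $\nabla Q(W^{(t)}, b^{(t)}_p)$ *(local computation)*
2. $\nabla Q(W^{(t)}, b^{(t)}) = \sum_{p' \in P} \nabla Q(W^{(t)}, b^{(t)}_{p'})$ *(Allreduce)*
3. $W^{(t+1)} = Update(W^{(t)}, \nabla Q(W^{(t)}, b^{(t)}))$ *(local computation)*
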
This algorithm is typically implemented using Single-Program-Multiple-Data (SPMD) programming, with every processor running the same program of local operations and MPI-allreduce primitives.
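A minimal NumPy simulation of Algorithm 1 may help make the three steps concrete. It assumes a least-squares loss for $Q$ and plain SGD as the `Update` rule (both illustrative choices, not from the text); the "processors" are list entries and the allreduce is simulated by a Python-side sum.

```python
import numpy as np

def local_grad(w, x, y):
    """Gradient of 0.5 * ||x @ w - y||^2 on one sub-batch.

    Using a sum (not a mean) makes the per-processor gradients add up
    exactly to the full-batch gradient.
    """
    return x.T @ (x @ w - y)

def allreduce_sum(values):
    """Simulated MPI-allreduce: every processor receives the sum of all inputs."""
    total = np.sum(values, axis=0)
    return [total.copy() for _ in values]

def data_parallel_step(w_replicas, x, y, lr):
    n = len(w_replicas)                                   # one replica per processor
    x_parts, y_parts = np.split(x, n), np.split(y, n)     # batch must divide evenly
    grads = [local_grad(w, xp, yp)                        # step 1: local computation
             for w, xp, yp in zip(w_replicas, x_parts, y_parts)]
    summed = allreduce_sum(grads)                         # step 2: allreduce
    return [w - lr * g for w, g in zip(w_replicas, summed)]  # step 3: local update

rng = np.random.default_rng(0)
x, y = rng.normal(size=(8, 3)), rng.normal(size=8)
w0 = np.zeros(3)
replicas = data_parallel_step([w0.copy() for _ in range(4)], x, y, lr=0.01)
single = w0 - 0.01 * local_grad(w0, x, y)                 # one device, full batch
assert all(np.allclose(r, single) for r in replicas)
```

Because every replica applies the same summed gradient, the copies stay identical, which is what lets every processor run the same program.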
One way to see this algorithm is that every tensor and every operation in the computation is either split across all processors (if it has a "batch" dimension), or fully replicated across all processors (if it does not have a "batch" dimension). Operations which reduce out the "batch" dimension require an additional MPI-allreduce to produce the correct result. We can describe this as splitting the computation across the "batch" dimension. Mesh-TensorFlow generalizes this idea to splitting computations across arbitrary dimensions.
Section Summary: Mesh-TensorFlow builds on a basic method for splitting batches of data across multiple processors by allowing splits along various dimensions of the data tensors, ensuring each processor handles a portion of the work with minimal communication between them. It introduces named dimensions for tensors, such as "batch," to keep splitting consistent across different parts of the computation, and organizes processors into an n-dimensional grid with its own named dimensions. A key feature is the computation layout, which specifies how tensor dimensions map to processor grid dimensions—for instance, mapping the "batch" dimension to all processors replicates everything else while distributing the batch evenly.
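Mesh-TensorFlow generalizes the batch-splitting algorithm described in Algorithm 1 to allow for splitting across different tensor dimensions. The similarities are as follows:

- Each processor keeps one slice of every tensor.
- All processors run the same program (SPMD).
- Each operation is implemented by local computation on every processor's slices of its inputs, sometimes followed by collective communication such as MPI-allreduce.

The new elements in Mesh-TensorFlow are as follows:

- Tensors have named dimensions, such as "batch", so that a given tensor-dimension is split consistently wherever it appears in the computation.
- The processors are arranged in an n-dimensional mesh, whose dimensions are also named.
- A computation layout is a partial map from tensor-dimension names to mesh-dimension names, specifying which tensor-dimensions are split across which dimensions of the mesh. For example, batch-splitting (data-parallelism) would be expressed by using a one-dimensional mesh with dimension "all_processors" and using the computation layout [("batch", "all_processors")]. This means that all tensors with a "batch" dimension are split along that dimension across all processors, while all other tensors are fully replicated.

Section Summary: In Mesh-TensorFlow, each tensor is represented as one slice per processor in the computing mesh, guided by a layout that assigns the tensor's dimensions to distinct mesh dimensions, never two to the same one. If no dimensions are assigned, the entire tensor is replicated on every processor. For each assigned dimension, a processor's slice is limited to the stripe of that dimension matching the processor's position in the mesh, and the tensor-dimension's size must be evenly divisible by the mesh dimension's size.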
A tensor is represented as one slice of the tensor per processor. The layout of a tensor is an injective partial map from the tensor's dimensions to dimensions of the mesh, and is computed as the restriction of the global computation layout to that tensor's dimensions. It is illegal for two dimensions of the same tensor to map to the same mesh dimension. If a tensor's layout is empty, it is fully replicated on each processor. For every (tensor-dimension, mesh-dimension) pair in the tensor's layout, the slice on a processor is restricted along that tensor-dimension to a stripe corresponding to that processor's coordinate along that mesh-dimension. The current implementation of Mesh-TensorFlow requires the size of the tensor-dimension to be evenly divisible by the size of the mesh-dimension.
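The slicing rule can be sketched directly. The helpers below (`tensor_layout`, `local_slice`) are hypothetical, not Mesh-TensorFlow functions; they represent a tensor's dimensions as a tuple of names and a layout as a list of (tensor-dimension, mesh-dimension) pairs, mirroring the layouts written elsewhere in this paper.

```python
import numpy as np

def tensor_layout(tensor_dims, computation_layout):
    """Restrict a computation layout to one tensor's dimensions.

    Raises ValueError if two of the tensor's dimensions map to the same
    mesh dimension (the tensor's layout must be injective).
    """
    layout = {t: m for t, m in computation_layout if t in tensor_dims}
    if len(set(layout.values())) != len(layout):
        raise ValueError(f"two dimensions of {tensor_dims} share a mesh dimension")
    return layout

def local_slice(array, tensor_dims, mesh_shape, computation_layout, coord):
    """Return the slice of `array` held by the processor at mesh coordinate `coord`."""
    layout = tensor_layout(tensor_dims, computation_layout)
    mesh_names = [name for name, _ in mesh_shape]
    mesh_sizes = dict(mesh_shape)
    index = []
    for axis, dim in enumerate(tensor_dims):
        if dim not in layout:                  # not split: keep the whole dimension
            index.append(slice(None))
            continue
        mesh_dim = layout[dim]
        parts, size = mesh_sizes[mesh_dim], array.shape[axis]
        if size % parts:
            raise ValueError(f"{dim} (size {size}) is not divisible by {mesh_dim} ({parts})")
        stripe = size // parts                 # this processor's stripe of `dim`
        start = coord[mesh_names.index(mesh_dim)] * stripe
        index.append(slice(start, start + stripe))
    return array[tuple(index)]

mesh_shape = [("rows", 2), ("cols", 2)]
computation_layout = [("batch", "rows"), ("hidden", "cols")]
h = np.arange(24).reshape(4, 6)                # dims ("batch", "hidden")
x = np.arange(20).reshape(4, 5)                # dims ("batch", "io")

h_slice = local_slice(h, ("batch", "hidden"), mesh_shape, computation_layout, (1, 0))
x_10 = local_slice(x, ("batch", "io"), mesh_shape, computation_layout, (1, 0))
x_11 = local_slice(x, ("batch", "io"), mesh_shape, computation_layout, (1, 1))
print(h_slice.shape, np.array_equal(x_10, x_11))   # (2, 3) True

try:
    tensor_layout(("batch", "hidden"), [("batch", "all"), ("hidden", "all")])
    illegal_rejected = False
except ValueError:
    illegal_rejected = True
```

Here $x$ has no dimension mapped to "cols", so processors in the same row hold identical slices of it.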
Section Summary: Mesh-TensorFlow implements operations through parallel computing on multiple processors, often involving communication for distributed tensors. Simple component-wise operations and reductions like summing or finding maximums work by computing locally on each processor's data slice and then sharing results across the network if needed, while more complex Einstein summations—such as matrix multiplications—follow a similar pattern of local calculations followed by collective reductions. Reshaping tensors can also require network exchanges, like gathering data from all processors or redistributing it, especially when changing how dimensions are split across the system to switch between data and model parallelism.
Each operation is implemented by parallel computation on every processor, and sometimes collective communication. We describe the implementations of some important operations here:
Component-wise Operations
Mesh-TensorFlow supports component-wise operations where the shapes (and hence the layouts) of the input and output tensors are identical. These are trivially implemented by parallel operations on each processor to compute that processor's slice of the output from that processor's slice(s) of the input(s).
Reduction (reduce_sum(), reduce_max(), etc.)
Mesh-TensorFlow supports reductions where the output dimensions are a subset of the input dimensions. These can be implemented by local reductions of each slice, followed by MPI-allreduce across any mesh dimensions corresponding to reduced-out tensor dimensions. The allreduce is necessary because each local reduction covers only that processor's stripe of a split tensor-dimension. Bandwidth-efficient implementations of allreduce exist when the processors for each group are connected in any type of tree [3] [4].
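The local-reduce-then-allreduce pattern can be checked with a short NumPy simulation. It assumes a [batch, io] tensor split along "batch" over a one-dimensional mesh, and that the allreduce combines partial results with the same operator as the reduction (sum for reduce_sum, max for reduce_max); the list of slices stands in for the processors.

```python
import numpy as np

n = 4                                        # processors along one mesh dimension
rng = np.random.default_rng(0)
x = rng.normal(size=(8, 5))                  # dims [batch, io]; "batch" is split
slices = np.split(x, n, axis=0)              # processor p holds rows [2p, 2p + 2)

# reduce_sum over "batch": local partial sums, then an allreduce (sum) across
# the mesh dimension that "batch" is split over.
partial = [s.sum(axis=0) for s in slices]
reduced = np.sum(partial, axis=0)            # every processor ends up with this
assert np.allclose(reduced, x.sum(axis=0))

# reduce_max follows the same pattern with max as the combining operator.
partial_max = [s.max(axis=0) for s in slices]
assert np.allclose(np.max(partial_max, axis=0), x.max(axis=0))

# Reducing over "io" (not split) needs no communication: each processor's
# result is already its own stripe of the output.
local_io_sums = [s.sum(axis=1) for s in slices]
assert np.allclose(np.concatenate(local_io_sums), x.sum(axis=1))
```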
Einstein Summation (matrix multiplication, etc.)
Einstein-summation (einsum) notation (as defined in numpy, TensorFlow, etc.) is a way of expressing a class of operations including (batch) matrix multiplication, reductions and broadcasts, where the operation is defined by the names of the dimensions of the input and output tensors. Mesh-TensorFlow's use of named dimensions makes using einsum particularly convenient. Einsum can be defined as broadcasting all inputs to a shape consisting of the union of all their dimensions, multiplying them component-wise, then reducing out all dimensions not in the specified output shape. Einsum is implemented by parallel einsum operations on each processor of that processor's input slices, followed by MPI-allreduce across any mesh dimensions corresponding to reduced-out tensor dimensions.
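Both the definition and the distributed implementation can be checked with `numpy.einsum`. This sketch assumes a layout that splits the contracted "io" dimension across a one-dimensional mesh of four simulated processors; the sizes are illustrative.

```python
import numpy as np

n = 4                                        # processors along one mesh dimension
rng = np.random.default_rng(1)
x = rng.normal(size=(6, 8))                  # dims [batch, io]
w = rng.normal(size=(8, 3))                  # dims [io, hidden]

# Definition: broadcast both inputs to the union of dimensions
# [batch, io, hidden], multiply component-wise, then reduce out "io".
union = x[:, :, None] * w[None, :, :]
assert np.allclose(union.sum(axis=1), np.einsum("bi,ih->bh", x, w))

# Layout [("io", "all")]: both inputs are split along the contracted dimension.
x_slices = np.split(x, n, axis=1)
w_slices = np.split(w, n, axis=0)

# Each processor runs the same einsum on its own slices ...
partials = [np.einsum("bi,ih->bh", xs, ws) for xs, ws in zip(x_slices, w_slices)]
# ... and an allreduce sums the partial products, because "io" was reduced out.
y = np.sum(partials, axis=0)
assert np.allclose(y, np.einsum("bi,ih->bh", x, w))
```

If instead only "batch" or "hidden" were split, no dimension in the reduced-out set would be split, and each processor's local einsum would already be its final output slice.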
Reshape
While reshape is simple in the non-distributed case, Mesh-TensorFlow reshape can require network communication, since the layout of the output tensor may differ from that of the input tensor. Even keeping the same dimension sizes, changing the dimension names (and hence the layout) can result in several different communication patterns: If a dimension is split in the input but not in the output, the implementation involves MPI-allgather communication across the corresponding mesh-dimension. If a dimension is split in the output but not in the input, the implementation involves no communication, just slicing on each processor. MPI-alltoall is used in the case where different dimensions in the input and the output are split across the same mesh dimension, as might be the case when switching between data-parallelism and model-parallelism for different layers of the same model, as in [5].
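The three cases can be summarized as a small decision function. `reshape_communication` is a hypothetical helper that classifies the communication per mesh dimension, given the tensor's layout before and after the reshape (each a dict from tensor-dimension name to mesh-dimension name); it is a simplification for illustration, not Mesh-TensorFlow's lowering code.

```python
def reshape_communication(in_layout, out_layout):
    """Classify the communication needed on each mesh dimension when a tensor's
    layout changes from in_layout to out_layout.

    Both layouts map tensor-dimension names to mesh-dimension names.
    """
    plan = {}
    for mesh_dim in sorted(set(in_layout.values()) | set(out_layout.values())):
        src = [d for d, m in in_layout.items() if m == mesh_dim]
        dst = [d for d, m in out_layout.items() if m == mesh_dim]
        if src == dst:
            plan[mesh_dim] = "none"        # same dimension split the same way
        elif not dst:
            plan[mesh_dim] = "allgather"   # split in the input only
        elif not src:
            plan[mesh_dim] = "slice"       # split in the output only: local slicing
        else:
            plan[mesh_dim] = "alltoall"    # different dimensions share this mesh dim
    return plan

# Switching a layer from data-parallel to model-parallel on a 1-D mesh "all":
print(reshape_communication({"batch": "all"}, {"hidden": "all"}))  # {'all': 'alltoall'}
```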
Section Summary: Mesh-TensorFlow is a programming language very similar to TensorFlow, featuring familiar elements like computational graphs, data structures called tensors, operations, variables, devices (called meshes), and automatic gradient computation. The main difference is that each tensor dimension has both a name and a size, giving every tensor a statically known shape that the system infers automatically whenever possible; operations like addition broadcast implicitly when one operand's shape is a subset of the other's. Built as a Python library, it lets users construct these graphs in Python code, which are then lowered into TensorFlow for running on TPUs or on multiple CPUs and GPUs.
The Mesh-TensorFlow language is nearly identical to TensorFlow [6], with the familiar notions of graphs, tensors, operations, variables, devices (called meshes), and automatic gradient computation. The principal difference is that in Mesh-TensorFlow, tensor-dimensions have a name as well as a size. The shape of each tensor is a statically-known tuple of such dimensions. Shapes are inferred automatically when possible, as they are in TensorFlow. Binary component-wise operations like addition employ implicit broadcasting in the case where the shape of one operand is a subset of the shape of the other.
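To illustrate broadcasting by dimension name rather than by position, here is a small NumPy sketch. The `broadcast_add` helper is hypothetical (not part of Mesh-TensorFlow): it adds two arrays whose axes are identified by names, broadcasting the operand whose dimensions are a subset of the other's, even if that operand lists them in a different order.

```python
import numpy as np

def broadcast_add(a, a_dims, b, b_dims):
    """Add two arrays with named dimensions, broadcasting the smaller operand.

    Assumes b_dims is a subset of a_dims (the case described in the text).
    """
    if not set(b_dims) <= set(a_dims):
        raise ValueError("shape of one operand must be a subset of the other")
    # Reorder b's axes to follow a's dimension order, then insert size-1 axes
    # for the dimensions b lacks so NumPy broadcasting lines them up by name.
    order = [b_dims.index(d) for d in a_dims if d in b_dims]
    b_aligned = np.transpose(b, order)
    shape = [b.shape[b_dims.index(d)] if d in b_dims else 1 for d in a_dims]
    return a + b_aligned.reshape(shape), a_dims

h = np.zeros((2, 3))                                   # dims ("batch", "hidden")
bias = np.arange(3.0)                                  # dims ("hidden",)
out, out_dims = broadcast_add(h, ("batch", "hidden"), bias, ("hidden",))
print(out_dims, out.shape)   # ('batch', 'hidden') (2, 3)
```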
The initial implementation of Mesh-TensorFlow is a Python library. The user builds a Mesh-TensorFlow graph in Python, which the library "lowers" to generate part of a TensorFlow graph. As of the writing of this paper, implementations exist for generating SPMD TensorFlow code for TPUs, or MIMD code (using device placement) for multi-CPU/GPU configurations.
Section Summary: This section explains a basic neural network setup with two fully connected layers, where input data passes through a hidden layer activated by ReLU before reaching the output, implemented using Mesh-TensorFlow code that handles batches of data. It demonstrates different ways to distribute the computation across multiple processors for efficiency, such as splitting the data batch for parallel processing, dividing the hidden layer's units, or combining both methods on multi-dimensional processor grids, which vary in communication overhead and speed. The discussion highlights efficient layouts that minimize replication and communication, warns against wasteful or invalid setups, and includes a table comparing computational and memory costs.
We consider a simple example of two fully-connected layers in the middle of a neural network. The input layer $x$ and the output layer $y$ each have $d_{io}$ units, and the hidden layer $h$ has $d_h$ units. The hidden layer also has a bias and $Relu$ activation.
$y = Relu(xw + bias)v$
This Mesh-TensorFlow code fragment runs these layers on a batch $x$ of $b$ inputs.
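```python
import mesh_tensorflow as mtf
...
# Named dimensions: each has a name and a size.
batch = mtf.Dimension("batch", b)
io = mtf.Dimension("io", d_io)
hidden = mtf.Dimension("hidden", d_h)
# x.shape == [batch, io]; x and mesh are created in the elided code above.
w = mtf.get_variable(mesh, "w", shape=[io, hidden])
bias = mtf.get_variable(mesh, "bias", shape=[hidden])
v = mtf.get_variable(mesh, "v", shape=[hidden, io])
# einsum takes a list of inputs; dimensions absent from output_shape
# ("io" here, "hidden" below) are reduced out.
h = mtf.relu(mtf.einsum([x, w], output_shape=[batch, hidden]) + bias)
y = mtf.einsum([h, v], output_shape=[batch, io])
...
```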
The code above defines only the mathematical model. We now discuss several different computation layouts. Each will produce identical results, but will have different performance characteristics. We also provide illustrations of the layouts in Appendix A.
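For checking any of the layouts below, a single-processor NumPy reference of the same two layers is handy. This is a plain NumPy sketch, not Mesh-TensorFlow code: the einsum letters stand in for the named dimensions (b = batch, i = io, h = hidden), and the sizes are illustrative.

```python
import numpy as np

def two_layer(x, w, bias, v):
    """y = Relu(xw + bias)v, with einsum letters b=batch, i=io, h=hidden."""
    h = np.maximum(np.einsum("bi,ih->bh", x, w) + bias, 0.0)
    return np.einsum("bh,hi->bi", h, v)

b, d_io, d_h = 8, 4, 6                       # illustrative sizes
rng = np.random.default_rng(0)
x = rng.normal(size=(b, d_io))               # [batch, io]
w = rng.normal(size=(d_io, d_h))             # [io, hidden]
bias = rng.normal(size=d_h)                  # [hidden]
v = rng.normal(size=(d_h, d_io))             # [hidden, io]
y = two_layer(x, w, bias, v)                 # [batch, io]
print(y.shape)   # (8, 4)
```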
To train the above model in data-parallel mode on a mesh of $n$ processors, we would define:
```python
mesh_shape = [("all", n)]
computation_layout = [("batch", "all")]
```
When the Mesh-TensorFlow graph is compiled with this layout, the parameter tensors $w$, $v$, and $bias$ are replicated on all processors, but the activation matrices $x$, $h$, $y$, etc. are split across the batch dimension. For example, each processor keeps a slice of $x$ with shape $[\frac{b}{n}, d_{io}]$.
There is no inter-processor communication in the forward pass. However, the gradient computations for the parameters are mtf.einsum operations which reduce out the batch dimension, and hence produce $Allreduce$ operations when they are compiled. The number of values allreduced per processor is equal to the number of parameters, approximately $2d_{io}d_h$.
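A NumPy simulation of this data-parallel layout, with four simulated processors and illustrative sizes, shows both claims: the forward pass needs no communication, and a parameter gradient (here for $v$, given an arbitrary upstream gradient `dy`, an assumption of the sketch) needs an allreduce over the batch split.

```python
import numpy as np

def forward(x, w, bias, v):
    """Returns (hidden activations, output) for y = Relu(xw + bias)v."""
    h = np.maximum(np.einsum("bi,ih->bh", x, w) + bias, 0.0)
    return h, np.einsum("bh,hi->bi", h, v)

b, d_io, d_h, n = 8, 4, 6, 4                   # illustrative sizes
rng = np.random.default_rng(0)
x = rng.normal(size=(b, d_io))
w = rng.normal(size=(d_io, d_h))
bias = rng.normal(size=d_h)
v = rng.normal(size=(d_h, d_io))

# Layout [("batch", "all")]: x is split by rows; w, bias and v are replicated.
x_slices = np.split(x, n, axis=0)
outs = [forward(xs, w, bias, v) for xs in x_slices]   # forward pass: no communication
y = np.concatenate([y_p for _, y_p in outs], axis=0)
assert np.allclose(y, forward(x, w, bias, v)[1])

# Gradient of v given an (arbitrary) upstream gradient dy = dL/dy: the einsum
# reduces out "batch", so the per-processor partial sums must be allreduced.
dy = rng.normal(size=(b, d_io))
partials = [np.einsum("bh,bi->hi", h_p, dy_p)
            for (h_p, _), dy_p in zip(outs, np.split(dy, n, axis=0))]
grad_v = np.sum(partials, axis=0)                     # the allreduce
assert np.allclose(grad_v, np.einsum("bh,bi->hi", forward(x, w, bias, v)[0], dy))

# Values allreduced per processor: one gradient value per parameter.
print(w.size + bias.size + v.size, "vs. approx.", 2 * d_io * d_h)   # 54 vs. approx. 48
```

The small gap between 54 and 48 is the bias, one of the lower-order terms the approximation drops.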
Rather than splitting the batch, we can split the units in the hidden layer:
```python
mesh_shape = [("all", n)]
computation_layout = [("hidden", "all")]
```
When the Mesh-TensorFlow graph is compiled with this layout, the input and output layers $x$ and $y$ are replicated on all processors, but the hidden activations $h$ and the parameter tensors $w$, $v$ and $bias$ are all split across the hidden dimension. For example, each processor keeps a slice of $w$ with shape $[d_{io}, \frac{d_h}{n}]$ and a slice of $v$ with shape $[\frac{d_h}{n}, d_{io}]$.
When computing $y$, the split hidden dimension is reduced out. Consequently, the results of that computation get allreduced across all processors. A similar allreduce happens in computing the gradients on $x$. In all, the number of values allreduced per processor is $2bd_{io}$.
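The forward half of this can be simulated in NumPy with three processors and illustrative sizes. Because $Relu$ is component-wise, each processor computes its own stripe of $h$ without communication; only the final contraction over the split hidden dimension needs an allreduce.

```python
import numpy as np

b, d_io, d_h, n = 8, 4, 6, 3                   # illustrative sizes
rng = np.random.default_rng(0)
x = rng.normal(size=(b, d_io))
w, bias, v = rng.normal(size=(d_io, d_h)), rng.normal(size=d_h), rng.normal(size=(d_h, d_io))
reference = np.maximum(x @ w + bias, 0.0) @ v  # single-processor result

# Layout [("hidden", "all")]: w, bias, v are split along "hidden"; x is replicated.
w_s, b_s, v_s = np.split(w, n, axis=1), np.split(bias, n), np.split(v, n, axis=0)

# Each processor computes its stripe of h and a partial y over its hidden units.
partial_y = [np.maximum(x @ ws + bs, 0.0) @ vs for ws, bs, vs in zip(w_s, b_s, v_s)]
y = np.sum(partial_y, axis=0)                  # allreduce: "hidden" is reduced out
assert np.allclose(y, reference)
print(y.size, "values allreduced for y (b * d_io)")   # 32 values allreduced for y (b * d_io)
```

The gradient with respect to $x$ contributes a second allreduce of the same size, giving the $2bd_{io}$ total.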
On a two-dimensional mesh of $r \times c$ processors, we can employ both data-parallelism and model-parallelism:
```python
mesh_shape = [("rows", r), ("cols", c)]
computation_layout = [("batch", "rows"), ("hidden", "cols")]
```
In this layout, each row of processors handles a fraction of the batch, while each column of processors handles a fraction of the hidden units. Each processor keeps a slice of $x$ with shape $[\frac{b}{r}, d_{io}]$, with processors in the same row having identical slices. The hidden activation tensor $h$ is tiled in two dimensions, with each processor keeping a slice with shape $[\frac{b}{r}, \frac{d_h}{c}]$.
This layout causes partitioned-allreduce operations in several places. For example, in computing $y$, we reduce out the hidden dimension, which is split over the cols dimension of the mesh, so the results of the operation need to be summed only along the cols dimension (among the $c$ processors of each row), as opposed to over the entire mesh. In all, the number of values allreduced per processor is $\frac{2bd_{io}}{r} + \frac{2d_{io}d_h}{c}$.
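A NumPy simulation of the forward pass on an $r \times c = 2 \times 3$ mesh (illustrative sizes) shows the partitioned allreduce: each row of processors sums only its own $c$ partial results, and no communication crosses rows.

```python
import numpy as np

b, d_io, d_h, r, c = 8, 4, 6, 2, 3             # illustrative sizes
rng = np.random.default_rng(0)
x = rng.normal(size=(b, d_io))
w, bias, v = rng.normal(size=(d_io, d_h)), rng.normal(size=d_h), rng.normal(size=(d_h, d_io))
reference = np.maximum(x @ w + bias, 0.0) @ v

x_rows = np.split(x, r, axis=0)                        # "batch" -> "rows"
w_cols, b_cols = np.split(w, c, axis=1), np.split(bias, c)   # "hidden" -> "cols"
v_cols = np.split(v, c, axis=0)

y_rows = []
for i in range(r):
    # Processor (i, j) holds x_rows[i] and the j-th hidden stripe of w, bias, v.
    partial = [np.maximum(x_rows[i] @ w_cols[j] + b_cols[j], 0.0) @ v_cols[j]
               for j in range(c)]
    # Partitioned allreduce: sum only over the c processors in row i.
    y_rows.append(np.sum(partial, axis=0))
assert np.allclose(np.concatenate(y_rows, axis=0), reference)
```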
If we have a three-dimensional mesh of processors, we can even split the computation in three dimensions:
```python
mesh_shape = [("rows", r), ("cols", c), ("planes", p)]
computation_layout = [
    ("batch", "rows"), ("hidden", "cols"), ("io", "planes")]
```
In this case, every matrix in the computation is tiled across two mesh dimensions and replicated in the third, and every einsum requires an allreduce across one mesh dimension.
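The three-dimensional case can also be simulated in NumPy on a $2 \times 3 \times 2$ mesh (illustrative sizes). Each 2-D matrix is cut into a grid of tiles indexed by two mesh coordinates and replicated along the third; the first einsum reduces out "io" (an allreduce across planes) and the second reduces out "hidden" (an allreduce across cols).

```python
import numpy as np

b, d_io, d_h = 4, 4, 6
r, c, p = 2, 3, 2                              # rows, cols, planes
rng = np.random.default_rng(0)
x = rng.normal(size=(b, d_io))
w, bias, v = rng.normal(size=(d_io, d_h)), rng.normal(size=d_h), rng.normal(size=(d_h, d_io))
reference = np.maximum(x @ w + bias, 0.0) @ v

def tiles(a, parts0, parts1):
    """Split a 2-D array into a parts0 x parts1 grid of equal tiles."""
    return [np.split(row, parts1, axis=1) for row in np.split(a, parts0, axis=0)]

x_t = tiles(x, r, p)        # [batch, io]     -> [rows, planes], replicated over cols
w_t = tiles(w, p, c)        # [io, hidden]    -> [planes, cols], replicated over rows
v_t = tiles(v, c, p)        # [hidden, io]    -> [cols, planes], replicated over rows
b_t = np.split(bias, c)     # [hidden]        -> [cols]

y = np.zeros_like(reference)
for i in range(r):
    # First einsum reduces out "io" (split over planes): allreduce across planes.
    h = [np.maximum(sum(x_t[i][k] @ w_t[k][j] for k in range(p)) + b_t[j], 0.0)
         for j in range(c)]
    for k in range(p):
        # Second einsum reduces out "hidden" (split over cols): allreduce across cols.
        y_ik = sum(h[j] @ v_t[j][k] for j in range(c))
        y[i * (b // r):(i + 1) * (b // r), k * (d_io // p):(k + 1) * (d_io // p)] = y_ik
assert np.allclose(y, reference)
```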
For a computation layout to be efficient, all expensive operations need to be split (as opposed to replicated) across all mesh dimensions. For example, the empty layout below produces correct results, but since it replicates all computation on every processor, it saves no time or memory. A general rule is that any expensive einsum operation should have one input dimension that is split across each mesh dimension.
```python
mesh_shape = [("all", n)]
computation_layout = []
```
The computation layout below is illegal, because it causes the tensor $h$ to have two dimensions which are split across the same dimension of the mesh.
```python
mesh_shape = [("all", n)]
# Illegal: both dimensions of h, "batch" and "hidden", map to mesh dimension "all".
computation_layout = [("batch", "all"), ("hidden", "all")]
```
::: {caption="Table 1: Computation, communication and memory costs for different layouts of the computation in Algorithm 1. Constant factors and lower-order terms are dropped."}

:::
Table 1 shows the computational costs associated with our example computation layouts. The computation time is dominated by that of einsum operations. The communication time comes from the Allreduce operations, which are necessary whenever the inner dimension of einsum is split. Assuming that the mesh has physical links between all pairs of logically adjacent processors, each Allreduce operation can be done in time proportional to the size of one slice divided by the per-link network bandwidth [4].
The network-boundedness of the computation is proportional to the value shown in the table column marked $\frac{communication}{computation}$, with the constant of proportionality depending on the ratio of communication and computation speeds on the given hardware. In the data-parallel layout, the value is $\frac{n}{b}$, the inverse of the per-processor batch size, so performance suffers if the per-processor batch is too small. In the model-parallel layout, the value is $\frac{n}{d_h}$, the inverse of the number of hidden units per processor, so performance suffers if the hidden layer is sliced too finely. Here the batch size does not affect efficiency, but the hidden layer must grow as we increase the number of processors. In the two-dimensional data-parallel, model-parallel layout, the value is $\frac{c}{d_h} + \frac{r}{b}$. In this layout, we can quadratically increase the number of processors while only linearly increasing the batch size and hidden layer sizes necessary to maintain good efficiency. The final, three-dimensional layout lets us cubically increase the number of processors, while only linearly increasing the batch size and the layer sizes.
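The scaling claim for the two-dimensional layout can be checked with exact arithmetic, using the communication/computation ratios stated above (constant factors dropped) and illustrative sizes: growing $r$ and $c$ by $k$ multiplies the processor count by $k^2$, yet growing $b$ and $d_h$ by only $k$ keeps the ratio fixed.

```python
from fractions import Fraction

# communication/computation ratios stated in the text (constant factors dropped)
def data_parallel(b, d_h, n):
    return Fraction(n, b)

def model_parallel(b, d_h, n):
    return Fraction(n, d_h)

def data_model_parallel(b, d_h, r, c):
    return Fraction(c, d_h) + Fraction(r, b)

b, d_h = 1024, 4096                                    # illustrative sizes
base = data_model_parallel(b, d_h, r=8, c=8)           # 64 processors
for k in (2, 4, 8):
    # k^2 times more processors; batch and hidden size grow only k times.
    assert data_model_parallel(k * b, k * d_h, r=8 * k, c=8 * k) == base

# Pure data-parallelism on 4x the processors needs a 4x larger batch instead.
assert data_parallel(4 * b, d_h, n=64 * 4) == data_parallel(b, d_h, n=64)
print(base)   # 3/2048
```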
Section Summary: Researchers created a version of the Transformer model that spreads its key components—like vocabulary size, hidden layer depth, and attention heads—across multiple processors to handle larger models efficiently on TPU clusters. By scaling these elements alongside the number of processors and combining it with data splitting for batches, they trained massive models up to 5 billion parameters while maintaining high computational speed on up to 512 cores. Tests on language modeling and machine translation tasks showed that bigger models performed better, achieving top scores like a perplexity of 23.5 on a billion-word dataset and a BLEU score of 43.9 on English-to-French translation.
We implemented a model-parallel layout of the Transformer attention-based sequence-to-sequence model described in [1]. The complete implementation is available in the tensor2tensor library on GitHub. The layout is given by:
```python
mesh_shape = [("all", n)]
computation_layout = [
    ("vocab", "all"), ("d_ff", "all"), ("heads", "all")]
```
That is, the dimensions representing the vocabulary size, the size of the feed-forward hidden layer, and the number of attention heads are each split across all processors. This layout works because every expensive operation in the model has exactly one of these dimensions, and no tensor in the model has more than one. Similarly to the model-parallel layout for our example network (Section 8.2), network-boundedness and memory usage per processor remain constant if we scale all of these dimensions proportionally to the number of processors. We did just this, training Transformer models with ever larger hidden layers and numbers of attention heads on ever larger TPU clusters (we did not increase the vocabulary size). As expected, we saw very similar performance characteristics between the models. This scaling turns out to be highly beneficial to model quality (Section 9.1).
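The legality argument can be checked mechanically. The tensor shapes below are hypothetical, loosely modeled on one Transformer layer (they are assumptions for illustration, not the tensor2tensor implementation); the sketch confirms that no tensor has more than one split dimension and that each listed einsum involves exactly one.

```python
# Mesh-dimension assignment from the layout above: all three map to "all".
SPLIT_DIMS = {"vocab", "d_ff", "heads"}

# Hypothetical tensor shapes loosely modeled on one Transformer layer.
tensors = {
    "embedding": ["vocab", "d_model"],
    "logits": ["batch", "length", "vocab"],
    "ffn_in_weights": ["d_model", "d_ff"],
    "ffn_hidden": ["batch", "length", "d_ff"],
    "ffn_out_weights": ["d_ff", "d_model"],
    "q_weights": ["d_model", "heads", "d_k"],
    "attention_logits": ["batch", "heads", "length", "memory_length"],
    "layer_output": ["batch", "length", "d_model"],
}

def split_dims(shape):
    return [d for d in shape if d in SPLIT_DIMS]

# Legal: no tensor has two dimensions mapped to the same mesh dimension.
assert all(len(split_dims(shape)) <= 1 for shape in tensors.values())

# Expensive einsums (inputs listed) should each involve exactly one split dimension.
einsums = [
    ("layer_output", "ffn_in_weights"),   # -> ffn_hidden
    ("ffn_hidden", "ffn_out_weights"),    # -> layer_output; reduces d_ff, so allreduce
    ("layer_output", "embedding"),        # -> logits
]
for a, b in einsums:
    assert len((set(tensors[a]) | set(tensors[b])) & SPLIT_DIMS) == 1

replicated = [name for name, shape in tensors.items() if not split_dims(shape)]
print(replicated)   # ['layer_output']
```

In this toy list, only the residual-stream activations are fully replicated; every weight matrix is split, which is what keeps per-processor parameter memory constant as the split dimensions grow with the mesh.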
To use even more processors, we combined this model-parallelism with data parallelism, splitting the batch across one dimension of a 2-dimensional TPU mesh and the dimensions described above across the other dimension of the mesh:
```python
mesh_shape = [("rows", r), ("cols", c)]
computation_layout = [("batch", "rows"), ("vocab", "cols"),
                      ("d_ff", "cols"), ("heads", "cols")]
```
This layout maintains constant performance if the batch size is scaled proportionally to r and the mentioned model dimensions are scaled proportionally to c. Using this layout, we trained Transformer models with feed-forward hidden dimensions up to 262144 and up to 256 attention heads on 2-dimensional TPUv2 meshes of up to 16x32=512 cores, maintaining computational efficiency of over 50% (6 PFLOP/s out of a maximum 11.5 PFLOP/s) on the largest models.
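A quick calculation puts the largest configuration in per-core terms. It assumes that the size-32 mesh dimension carries the model dimensions (as stated for the language-modeling runs below); the mesh-dimension names follow the layout above.

```python
mesh = {"rows": 16, "cols": 32}     # assumed: model dimensions split along "cols"
d_ff, heads = 262144, 256

cores = mesh["rows"] * mesh["cols"]
d_ff_per_core = d_ff // mesh["cols"]       # "d_ff" split across "cols"
heads_per_core = heads // mesh["cols"]     # "heads" split across "cols"
efficiency = 6 / 11.5                      # achieved / peak PFLOP/s

print(cores, d_ff_per_core, heads_per_core, round(efficiency, 3))
# 512 8192 8 0.522
```

So each core holds a feed-forward stripe of 8192 units and 8 attention heads, comparable in size to a single conventional Transformer layer.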
To examine the benefit of scaling the Transformer model in the manner suggested by the previous section, we trained such models on machine translation and language modeling tasks. Results are given in Table 2 and Table 3.
For the billion-word language modeling benchmark, we trained the models for 10 epochs. The largest model (4.9B parameters) took 13 hours to train on a 512-core TPUv2 cluster. Batch size for all models was 256 sequences of 256 tokens each (each sequence was the concatenation of multiple training sentences). The batch was split along the mesh dimension of size 16 and the model dimensions were split along the mesh dimension of size 32. Per-word dev-perplexity for the largest model was 24.0, but dropped to 23.5 when the model was evaluated with the logits multiplied by 0.9 (likely due to overfitting). This represents the best published result on this dataset. As expected, perplexity was lower for larger models. We have included random samples from these models in Appendix C. On the languagemodel_wiki_noref_v128k_l1k dataset from the Tensor2Tensor library[^1], consisting of over 5 billion tokens of text from Wikipedia, perplexity continued to improve significantly with a model size of 5 billion parameters.
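The batch arithmetic for these runs follows directly from the stated numbers. The mesh-dimension names `rows`/`cols` are assumptions borrowed from the layout above.

```python
sequences, tokens_per_sequence = 256, 256
rows, cols = 16, 32            # "batch" split along size 16, model dims along size 32

tokens_per_step = sequences * tokens_per_sequence        # tokens in one batch
sequences_per_row = sequences // rows                    # batch stripe per mesh row
tokens_per_row = sequences_per_row * tokens_per_sequence # shared by the row's 32 cores

print(tokens_per_step, sequences_per_row, tokens_per_row)  # 65536 16 4096
```

Each group of 32 cores sharing a batch stripe processes 4096 tokens per step, with each core computing its 1/32 share of the model dimensions.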
[^1]: No published results exist for this dataset.
On the WMT14 En-Fr translation task (Table 3), we trained the models for 3 epochs. The largest model (2.9B parameters) was trained for 22 hours on a 128-core TPUv2 cluster. Quality improved with model size, with the largest model achieving a BLEU score of 43.9 (evaluated using sacrebleu), the best published result to date. For the WMT14 En-De dataset, gains from model size were smaller, presumably due to the smaller size of the training data.
Additional details about the configurations for these experiments are available as part of the tensor2tensor library on GitHub.
::: {caption="Table 2: Transformer-Decoder Language Models: $d_{model}=1024$, $d_k = d_v = 256$ "}

:::
:Table 3: Transformer Machine-Translation Results. $d_{model}=1024$, $d_k = d_v = 128$
| $d_{ff}$ | heads | $d_k, d_v$ | Parameters (billions) | WMT14 EN-DE BLEU | WMT14 EN-FR BLEU |
|---|---|---|---|---|---|
| 2048 | 4 | 128 | 0.15 | 25.5 | 41.8 |
| 4096 | 8 | 128 | 0.24 | 26.5 | 42.5 |
| 8192 | 16 | 128 | 0.42 | 27.1 | 43.3 |
| 16384 | 32 | 128 | 0.77 | 27.5 | 43.5 |
| 32768 | 64 | 128 | 1.48 | 27.5 | 43.8 |
| 65536 | 128 | 128 | 2.89 | 26.7 | 43.9 |
| 4096 | 16 | 64 | 0.21 | 28.4 | 41.8 |
Section Summary: Deep learning relies heavily on matrix multiplications and tensor operations, and researchers in high-performance computing have long studied ways to distribute these tasks across machines to reduce communication delays, using techniques like partitioning the computation space rather than just the output. Mesh-TensorFlow builds on these ideas by supporting a variety of efficient partitioning strategies from established algorithms, while simplifying the process for users by letting them specify dimensions to split, unlike older methods that require detailed layouts for each step; it shares traits with specialized tools like the Cyclops Tensor Framework but applies them more broadly. Recent deep learning research has explored blending data and model parallelism for better efficiency, such as theoretical work by Gholami and others or Jia's cost-based framework, but Mesh-TensorFlow improves on their approaches by enabling more optimal communication and easier experimentation.
A large part of deep learning computation consists of a series of matrix multiplications and tensor contractions (Einsums). Distributed matrix multiplication is a well-studied problem in high performance computing. To minimize communication, efficient algorithms partition the computational (iteration) space instead of partitioning work by the output matrix/tensor (owner computes). This technique is sometimes called iteration space tiling [8], replication [9], or task parallelism [10]. Mesh-TensorFlow can express a wide range of uniform partitionings of the iteration space and therefore can adopt many of the best known mappings, e.g., 3D [11, 12] and 2.5D [9] algorithms for square matrices, CARMA [13] for rectangular matrices, the 1.5D [14] algorithm for matrices with different sparsities, and best tile sizes for direct convolutions [15], although sometimes with higher memory requirements. Furthermore, in most existing work, when multiple multiplications are composed together, the user has to specify the data layout for each matrix separately [16]. Mesh-TensorFlow lets the user name the dimensions to split, simplifying the process and making mapping exploration much easier. Feature-wise, Mesh-TensorFlow shares many similarities with the Cyclops Tensor Framework [17], a distributed tensor contraction library originally developed for quantum chemistry applications, which also supports replication and arbitrary mappings.
In the context of deep learning, partitioning the iteration space, e.g., interpolating between data and model parallelism, is relatively new. Gholami et al. [18] showed analytically that using data and model parallelism at the same time can be more beneficial than using just one of them. Building on top of 1.5D matrix multiplication algorithms, their algorithm can support replication and arbitrary processor grid shapes. However, they only explored the parallelization of AlexNet [19] and did not implement the algorithm. Jia et al. [20, 21] implemented a framework that uses cost modeling to pick the best parallelization strategy, including how to partition work for each operation. Their parallelizable dimensions are defined as the set of all divisible dimensions in the output tensor (owner computes), so their mapping can be suboptimal in terms of communication. We expand on this in Appendix B.
Section Summary: The Mesh-TensorFlow library, which helps distribute machine learning computations across multiple devices, is openly available on GitHub and continues to be actively developed. Future improvements could include tools to automatically find the best ways to arrange computations, as well as support for more types of models and operations, such as handling image-processing tasks that require sharing border data between devices. Additionally, the library aims to expand single-program execution across clusters of CPUs and GPUs for broader compatibility.
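The Mesh-TensorFlow library is available at https://github.com/tensorflow/mesh and is under active development. Some potential areas for development are:

- Automated search for optimal computation layouts.
- Implementations of different models and operations. For example, convolutions on spatially-partitioned tensors will require the communication of "halo" regions, as described in [22].
- Implementation of SPMD programming on CPU/GPU clusters.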
Section Summary: This paper presents Mesh-TensorFlow, a new programming tool that makes it easier to spread out complex mathematical computations across multiple computers for AI tasks. By using this tool with the Transformer AI model, the researchers trained massive systems with 5 billion parameters on clusters of up to 512 processors. This approach set new records in English-to-French translation accuracy and in predicting words from a huge language dataset.
In this paper, we introduce the Mesh-TensorFlow language, facilitating a broad class of SPMD distributed tensor computations. Applying Mesh-TensorFlow to the Transformer model, we are able to train models with 5 billion parameters on up to 512-core clusters, establishing new state-of-the-art results for WMT14 En-Fr translation task and the One Billion Word language modeling benchmark.
Section Summary: This appendix provides visual illustrations for different ways to parallelize computations in a two-layer neural network example, showing data-parallel, model-parallel, and mixed layouts across small numbers of processors like 2 to 8, with diagrams explaining how data matrices are split and replicated. It then explores parallelization techniques for operations like matrix multiplication, highlighting how communication between processors often creates bottlenecks, and introduces concepts such as iteration space—the full set of calculations needed—and strategies like 1D, 2D, or 3D partitioning to divide work. These include owner-compute approaches, where each processor handles slices of the data it owns, though they may not always minimize communication efficiently.
This section illustrates the four layouts mentioned in the Two Fully-Connected Layers example in Section 8. The overall computation is shown in Figure 1. We draw a matrix multiplication $C=AB$ by putting $A$ to the left of $C$ and $B$ above $C$. For each matrix, we write its name inside it, its number of rows on its left or right side, and its number of columns above or below it. We omit row or column counts that can be inferred from adjacent matrices, since the multiplication dimensions must match.

Figure 2 presents the purely data-parallel layout with $n=2$ processors, which we number 0 and 1. The ranks of the processors that store each matrix part are written in blue on that part, and the matrix names are moved to the bottom-left corners. The entire matrices $w$ and $v$ are labeled with both 0 and 1 because they are fully replicated on both processors. The purely model-parallel layout is drawn similarly in Figure 3.
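The two single-axis layouts can be checked numerically. The sketch below assumes the example computes $y = \mathrm{relu}(xw)v$ with $x$ of shape [batch, io], $w$ of shape [io, hidden], and $v$ of shape [hidden, io]; these shapes, the toy values, and the function names are illustrative assumptions, and the processors are simulated in a single Python process. In the data-parallel layout each processor holds a batch slice of $x$ and full copies of $w$ and $v$; in the model-parallel layout each holds a hidden-dimension slice of $w$ and $v$, and the partial outputs are summed, which is the role an Allreduce plays on real hardware.

```python
def matmul(a, b):
    """Plain-Python matrix product of nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def relu(m):
    return [[max(0, x) for x in row] for row in m]


def forward(x, w, v):
    """y = relu(x w) v on one (simulated) processor."""
    return matmul(relu(matmul(x, w)), v)


def data_parallel(x, w, v, n):
    """Split x along batch; w and v are replicated on all n processors."""
    step = len(x) // n
    parts = [forward(x[p * step:(p + 1) * step], w, v) for p in range(n)]
    return [row for part in parts for row in part]  # concatenate batch slices


def model_parallel(x, w, v, n):
    """Split the hidden dim (columns of w, rows of v); x is replicated."""
    step = len(v) // n
    partials = []
    for p in range(n):
        w_p = [row[p * step:(p + 1) * step] for row in w]
        v_p = v[p * step:(p + 1) * step]
        # relu is elementwise, so each hidden slice is processed independently.
        partials.append(forward(x, w_p, v_p))
    # Summing the partial outputs stands in for an Allreduce.
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]


x = [[1, -2], [3, 4], [0, 1], [-1, 2]]   # batch=4, io=2
w = [[1, 0, -1, 2], [0, 1, 1, -1]]       # io=2, hidden=4
v = [[1, 2], [0, -1], [3, 1], [-2, 0]]   # hidden=4, io=2
print(data_parallel(x, w, v, 2) == forward(x, w, v))   # True
print(model_parallel(x, w, v, 2) == forward(x, w, v))  # True
```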
Figure 4 and Figure 5 show the mixed data-and-model-parallel layouts on 2-by-2 and 2-by-2-by-2 processor meshes, respectively. We give each processor a serialized rank as shown in the figures, and use the serialized ranks to label matrix slices.
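A serialized rank is a single integer standing for a multi-dimensional mesh coordinate. The sketch below assumes row-major order (the last mesh axis varies fastest), which may differ from the numbering drawn in Figures 4 and 5, and lists which ranks hold a given slice of a tensor that is split along some mesh axes and replicated along the rest. The function names and the choice of which dimension goes on which mesh axis are hypothetical.

```python
from itertools import product


def serialized_rank(coord, mesh_shape):
    """Row-major rank of the processor at `coord` (last axis varies fastest)."""
    rank = 0
    for c, size in zip(coord, mesh_shape):
        rank = rank * size + c
    return rank


def ranks_holding(mesh_shape, split_axes, slice_index):
    """Ranks that store one slice of a tensor.

    The tensor is split along the mesh axes in `split_axes`; `slice_index`
    gives the slice position along each of those axes. Along every other
    mesh axis the slice is replicated, so several ranks hold it.
    """
    return [serialized_rank(coord, mesh_shape)
            for coord in product(*(range(s) for s in mesh_shape))
            if all(coord[a] == i for a, i in zip(split_axes, slice_index))]


# 2-by-2 mesh, assuming batch is split along axis 0 and hidden along axis 1.
print(ranks_holding((2, 2), (0,), (1,)))      # batch slice 1 of x: [2, 3]
print(ranks_holding((2, 2), (1,), (0,)))      # hidden slice 0 of w: [0, 2]
print(serialized_rank((1, 0, 1), (2, 2, 2)))  # 5
```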




Communication is much more expensive than computation and is usually the bottleneck in a parallel program, especially in a distributed setting. This section shows how the more common owner-compute parallelization strategies can be communication-suboptimal. We start with a simplified overview of the parallelization schemes used in distributed matrix multiplication and tensor contractions (Einsums), focusing on their communication bandwidth costs. (See [23, 24, 25] for rigorous analyses of communication lower bounds.) We discuss only distributed matrix multiplication here, since the concepts generalize directly to tensor contractions.
Iteration Space.
The iteration space is the set of all index tuples required to compute a problem. For example, the matrix multiplication problem $C=AB$ computes $c_{ij} = \sum_k a_{ik}b_{kj}$. Its iteration space consists of all possible tuples $(i, j, k) \in \mathbb{Z}^3$ with $0 \le i < u$, $0 \le j < v$, $0 \le k < w$, where $C$ is $u$-by-$v$, $A$ is $u$-by-$w$, and $B$ is $w$-by-$v$, as shown in Figure 6. We refer to each tuple as a voxel of this $u \times v \times w$ box.
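The iteration space maps directly onto a triple loop: each $(i, j, k)$ tuple is one multiply-add. The plain-Python sketch below (helper names are illustrative) enumerates the space and computes $C = AB$ voxel by voxel.

```python
from itertools import product


def iteration_space(u, v, w):
    """All (i, j, k) with 0 <= i < u, 0 <= j < v, 0 <= k < w."""
    return list(product(range(u), range(v), range(w)))


def matmul_by_voxels(A, B):
    """C = AB, visiting each voxel (i, j, k) of the iteration space once."""
    u, w, v = len(A), len(B), len(B[0])
    C = [[0] * v for _ in range(u)]
    for i, j, k in iteration_space(u, v, w):
        C[i][j] += A[i][k] * B[k][j]  # one multiply-add per voxel
    return C


print(len(iteration_space(2, 3, 4)))                          # 24 voxels
print(matmul_by_voxels([[1, 2], [3, 4]], [[5, 6], [7, 8]]))   # [[19, 22], [43, 50]]
```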

Parallelization.
Let $n$ be the number of processors. Parallelization corresponds to partitioning the set of voxels into $n$ (not necessarily equal) subsets for the processors to compute. For matrix multiplication, the most widely used partitionings fall into three categories [26], 1D, 2D, and 3D, based on the number of iteration-space dimensions that are split. Figure 7 shows an example of each category. The left image is a 1D partitioning ($n=8$) because only the $j$ axis is split. The middle image splits axes $i$ and $j$, so it is a 2D partitioning ($n=64$). The right image is a 3D partitioning ($n=64$) because all three axes are split.
:::: {cols="1"}


Figure 7: Example partitionings of the iteration space for the matrix multiplication problem. The names 1D, 2D, and 3D come from the number of dimensions that are split. ::::
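The three categories differ only in how many axes are cut. In the sketch below (plain Python; names are illustrative, and every axis length is assumed divisible by its number of blocks), a grid `(g_i, g_j, g_k)` cuts each axis into equal blocks and assigns each voxel to the processor owning its block. Run on a hypothetical $8 \times 8 \times 8$ iteration space, it reproduces the processor counts of the three examples in Figure 7.

```python
from collections import Counter
from itertools import product


def processor_of(voxel, dims, grid):
    """Processor id of voxel (i, j, k) when axis a (length dims[a]) is cut
    into grid[a] equal blocks; ids are row-major over the block grid."""
    pid = 0
    for x, d, g in zip(voxel, dims, grid):
        pid = pid * g + x // (d // g)
    return pid


def voxels_per_processor(dims, grid):
    """Count how many voxels each processor receives."""
    return Counter(processor_of(vx, dims, grid)
                   for vx in product(*(range(d) for d in dims)))


dims = (8, 8, 8)  # (u, v, w) for axes i, j, k
for name, grid in [("1D", (1, 8, 1)), ("2D", (8, 8, 1)), ("3D", (4, 4, 4))]:
    counts = voxels_per_processor(dims, grid)
    print(name, "processors:", len(counts), "voxels each:", set(counts.values()))
```

All three grids give every processor the same number of voxels; what differs, as the following paragraphs show, is the shape of each processor's subset and therefore how much data it must touch.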
Owner computes.
Owner-compute strategies split a matrix (or matrices) equally among processors, and each processor is responsible for all computations related to the matrix chunk it owns. The 1D and 2D partitionings in Figure 7 are owner-compute strategies. In the 1D case, each processor owns a slice of $B$ and a slice of $C$, and computes the whole slab required for its slice of $C$. In the 2D case, each processor owns a patch of $C$ and computes the pencil corresponding to it. The 3D partitioning goes beyond the owner-compute rule, since no processor is responsible for all computations associated with the data it holds. Owner-compute schemes are more common because they are the most intuitive: we often view the output data as the unit of work. The next paragraph shows why their communication costs are usually suboptimal.
Communication.
Here, we focus on the number of elements that a processor has to transfer. Let $V$ be the voxel subset assigned to a processor, and let $V_A$, $V_B$, and $V_C$ be the projections of $V$ onto the $A$, $B$, and $C$ planes, respectively. The total number of elements a processor has to access to complete its computation is simply $|V_A| + |V_B| + |V_C|$, where $|\cdot|$ denotes set cardinality. Since a processor can hold only a limited amount of data in memory, the remaining elements must arrive through communication. The volume $|V|$ of the subset determines the computational workload. As mentioned in the paper, we would like to maximize the computation-to-communication ratio, so we want $V$ to have as low a surface-to-volume ratio as possible. If $V$ is restricted to cuboid shapes, the best shape is a cube.
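For a cuboid block of size $b_i \times b_j \times b_k$, the projections have sizes $|V_A| = b_i b_k$, $|V_B| = b_k b_j$, and $|V_C| = b_i b_j$. The brute-force sketch below (illustrative helper names, integer block sides only) searches all cuboids of a fixed volume and confirms that the cube touches the fewest elements.

```python
def elements_touched(bi, bj, bk):
    """|V_A| + |V_B| + |V_C| for a bi x bj x bk block of the iteration space."""
    return bi * bk + bk * bj + bi * bj


def best_cuboid(volume):
    """Integer cuboid with the given volume that touches the fewest elements."""
    best = None
    for bi in range(1, volume + 1):
        if volume % bi:
            continue
        rest = volume // bi
        for bj in range(1, rest + 1):
            if rest % bj:
                continue
            bk = rest // bj
            cost = elements_touched(bi, bj, bk)
            if best is None or cost < best[0]:
                best = (cost, (bi, bj, bk))
    return best[1]


print(best_cuboid(4096))             # (16, 16, 16)
print(elements_touched(16, 16, 16))  # 768
print(elements_touched(8, 16, 32))   # 896: same volume, more data touched
```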
Owner-compute methods fall short when they cannot partition the space into cubes. To illustrate, we compare the 2D and 3D partitionings for $n=64$ processors in Figure 7, counting two floating-point operations (a multiply and an add) per voxel. When $u=v=w$, each $u/8 \times u/8 \times u$ pencil in the 2D partitioning has a computation-to-communication ratio
$$ r_{\text{2D}} = \dfrac{2u^3/64}{u^2/64 + 2u^2/8} = \dfrac{2u}{17} \approx 0.12u. $$
Each $u/4 \times u/4 \times u/4$ cube in the 3D partitioning has a higher computation-to-communication ratio,
$$ r_{\text{3D}} = \dfrac{2u^3/64}{3u^2/16} = \dfrac{u}{6} \approx 0.17u. $$
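Both ratios can be checked with exact arithmetic. The sketch below generalizes the two formulas to a grid that cuts each axis of a $u \times u \times u$ iteration space into $(n_i, n_j, n_k)$ blocks, keeping the same convention of two flops per voxel, and returns the ratio as a multiple of $u$. The helper name is illustrative.

```python
from fractions import Fraction


def ratio_over_u(n_i, n_j, n_k):
    """Computation-to-communication ratio divided by u, for u = v = w.

    Each processor gets a block with sides u/n_i, u/n_j, u/n_k.
    """
    bi, bj, bk = Fraction(1, n_i), Fraction(1, n_j), Fraction(1, n_k)
    flops = 2 * bi * bj * bk                # in units of u^3
    elements = bi * bk + bk * bj + bi * bj  # |V_A| + |V_B| + |V_C|, units of u^2
    return flops / elements


print(ratio_over_u(8, 8, 1))  # 2D, n = 64: 2/17
print(ratio_over_u(4, 4, 4))  # 3D, n = 64: 1/6
```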
Mesh-TensorFlow.
Mesh-TensorFlow can express more fine-grained parallelism than owner-compute strategies, even though the user only specifies the data layout of each tensor. This is because our layouts allow a tensor to be replicated. Replication, combined with the layouts of the multiple tensors involved in an operation, can split the iteration space along as many dimensions as necessary (up to the rank of the iteration space).
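As an illustration, consider $C = AB$ on a hypothetical $2 \times 2 \times 2$ mesh with $i$ laid out on mesh axis 0, $j$ on axis 1, and $k$ on axis 2. Then $A$ (dimensions $i, k$) is replicated along axis 1, $B$ (dimensions $k, j$) is replicated along axis 0, and each processor computes one cube of the iteration space; the partial products are then summed along axis 2. The single-process simulation below checks that this 3D split reproduces the full product. It is a sketch of the idea only, not Mesh-TensorFlow code, and all names are illustrative.

```python
from itertools import product


def matmul(a, b):
    """Plain-Python matrix product of nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def block(m, r, c, rows, cols):
    """Block (r, c) of matrix m, with blocks of size rows x cols."""
    return [row[c * cols:(c + 1) * cols] for row in m[r * rows:(r + 1) * rows]]


def mesh_matmul_3d(A, B, mesh=(2, 2, 2)):
    """Simulate C = AB with i -> mesh axis 0, j -> axis 1, k -> axis 2."""
    na, nb, nc = mesh
    u, w, v = len(A), len(B), len(B[0])
    bi, bj, bk = u // na, v // nb, w // nc
    C = [[0] * v for _ in range(u)]
    for pa, pb, pc in product(range(na), range(nb), range(nc)):
        # Processor (pa, pb, pc) holds A-block (pa, pc) and B-block (pc, pb);
        # A's block is the same for every pb and B's for every pa (replication).
        partial = matmul(block(A, pa, pc, bi, bk), block(B, pc, pb, bk, bj))
        # Accumulating over pc plays the role of a reduction along mesh axis 2.
        for di in range(bi):
            for dj in range(bj):
                C[pa * bi + di][pb * bj + dj] += partial[di][dj]
    return C


A = [[i * 4 + j for j in range(4)] for i in range(4)]
B = [[(i + 2 * j) % 5 for j in range(4)] for i in range(4)]
print(mesh_matmul_3d(A, B) == matmul(A, B))  # True
```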
The following samples were randomly generated from the Transformer language models described in the paper. All sentences were seeded with the initial words "According to Ray Kurzweil", and continued randomly by the model. While all the models produce mostly grammatical sentences, the larger models exhibit more world knowledge.
According to Ray Kurzweil ...
According to Ray Kurzweil ...
According to Ray Kurzweil ...
According to Ray Kurzweil ...
[1] Vaswani et al. (2017). Attention Is All You Need. CoRR. abs/1706.03762. http://arxiv.org/abs/1706.03762. arXiv:1706.03762.
[2] Dean et al. (2012). Large Scale Distributed Deep Networks. In Proceedings of the 25th International Conference on Neural Information Processing Systems - Volume 1. pp. 1223–1231. http://dl.acm.org/citation.cfm?id=2999134.2999271.
[3] Pitch Patarasuk and Xin Yuan (2009). Bandwidth optimal all-reduce algorithms for clusters of workstations. 69. pp. 117–124.
[4] Nikhil Jain and Yogish Sabharwal (2010). Optimal bucket algorithms for large MPI collectives on torus interconnects. In Proceedings of the 24th ACM International Conference on Supercomputing. pp. 27–36.
[5] Shazeer et al. (2017). Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer. CoRR. abs/1701.06538. http://arxiv.org/abs/1701.06538. arXiv:1701.06538.
[6] Abadi et al. (2015). TensorFlow: Large-Scale Machine Learning on Heterogeneous Systems. Software available from tensorflow.org. https://www.tensorflow.org/.
[7] Jozefowicz et al. (2016). Exploring the Limits of Language Modeling. CoRR. abs/1602.02410. http://arxiv.org/abs/1602.02410. arXiv:1602.02410.
[8] M. Wolfe (1989). More Iteration Space Tiling. In Proceedings of the 1989 ACM/IEEE Conference on Supercomputing. pp. 655–664. doi:10.1145/76263.76337. http://doi.acm.org/10.1145/76263.76337.
[9] Edgar Solomonik and James Demmel (2011). Communication-Optimal Parallel 2.5D Matrix Multiplication and LU Factorization Algorithms.
[10] Calvin et al. (2015). Scalable Task-based Algorithm for Multiplication of Block-rank-sparse Matrices. In Proceedings of the 5th Workshop on Irregular Applications: Architectures and Algorithms. pp. 4:1–4:8. doi:10.1145/2833179.2833186. http://doi.acm.org/10.1145/2833179.2833186.
[11] Aggarwal et al. (1990). Communication Complexity of PRAMs. Theor. Comput. Sci. 71(1). pp. 3–28. doi:10.1016/0304-3975(90)90188-N. http://dx.doi.org/10.1016/0304-3975(90)90188-N.
[12] Jarle Berntsen (1989). Communication efficient matrix multiplication on hypercubes. Parallel Computing. 12(3). pp. 335–342.
[13] Demmel et al. (2013). Communication-optimal parallel recursive rectangular matrix multiplication. In Parallel & Distributed Processing (IPDPS), 2013 IEEE 27th International Symposium on. pp. 261–272.
[14] Koanantakool et al. (2016). Communication-Avoiding Parallel Sparse-Dense Matrix-Matrix Multiplication. In 2016 IEEE International Parallel and Distributed Processing Symposium (IPDPS). pp. 842–853. doi:10.1109/IPDPS.2016.117.
[15] James Demmel and Grace Dinh (2018). Communication-Optimal Convolutional Neural Nets. arXiv preprint arXiv:1802.06905.
[16] Koanantakool et al. (2018). Communication-Avoiding Optimization Methods for Distributed Massive-Scale Sparse Inverse Covariance Estimation. In Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics. pp. 1376–1386. http://proceedings.mlr.press/v84/koanantakool18a.html.
[17] Solomonik et al. (2014). A massively parallel tensor contraction framework for coupled-cluster computations. Journal of Parallel and Distributed Computing. 74(12). pp. 3176–3190.
[18] Gholami et al. (2017). Integrated Model and Data Parallelism in Training Neural Networks. arXiv preprint arXiv:1712.04432.
[19] Krizhevsky et al. (2012). ImageNet classification with deep convolutional neural networks. In Advances in Neural Information Processing Systems. pp. 1097–1105.
[20] Jia et al. (2018). Exploring Hidden Dimensions in Parallelizing Convolutional Neural Networks. arXiv preprint arXiv:1802.04924.
[21] Jia et al. (2018). Beyond Data and Model Parallelism for Deep Neural Networks. arXiv preprint arXiv:1807.05358.
[22] Jin et al. (2018). Spatially Parallel Convolutions. https://openreview.net/forum?id=S1Yt0d1vG.
[23] Irony et al. (2004). Communication lower bounds for distributed-memory matrix multiplication. Journal of Parallel and Distributed Computing. 64(9). pp. 1017–1026.
[24] Ballard et al. (2011). Minimizing communication in numerical linear algebra. SIAM Journal on Matrix Analysis and Applications. 32(3). pp. 866–901.
[25] Christ et al. (2013). Communication Lower Bounds and Optimal Algorithms for Programs That Reference Arrays - Part 1. http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-61.html.
[26] Ballard et al. (2013). Communication optimal parallel multiplication of sparse random matrices. In Proceedings of the Twenty-Fifth Annual ACM Symposium on Parallelism in Algorithms and Architectures. pp. 222–231.