Neural network scaling has been critical for improving the model quality in many real-world machine learning applications with vast amounts of training data and compute. Although this trend of scaling is affirmed to be a sure-fire approach for better model quality, there are challenges along the way, such as the computation cost, ease of programming, and efficient implementation on parallel devices. GShard is a module composed of a set of lightweight annotation APIs and an extension to the XLA compiler. It provides an elegant way to express a wide range of parallel computation patterns with minimal changes to the existing model code. GShard enabled us to scale up a multilingual neural machine translation Transformer model with Sparsely-Gated Mixture-of-Experts beyond 600 billion parameters using automatic sharding. We demonstrate that such a giant model can efficiently be trained on 2048 TPU v3 accelerators in 4 days to achieve far superior quality for translation from 100 languages to English compared to the prior art.
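Notes: Trained for a total of 235.5 TPU v3 core-years.
Hardware estimate: 235.5 core-years * (365.25 * 24 * 3600) s/year * (1.23e14 / 2) FLOP/s per core * 0.3 utilization = 1.371e23 FLOP. Here 1.23e14 FLOP/s is the TPU v3 peak per chip, halved because each chip has 2 cores, and 0.3 is an assumed utilization.
Arithmetic estimate: footnote 10 indicates 300k steps at 4M tokens/step, i.e. 1.2T tokens, so 6 * 2.3B parameters * 1.2T tokens = 1.656e22 FLOP.
Geometric mean: sqrt(1.371e23 * 1.656e22) = 4.765e22 FLOP.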
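The three compute figures above can be reproduced with a short script. This is a sketch of the arithmetic only: the per-chip peak (1.23e14 FLOP/s), 2 cores per chip and 0.3 utilization are the assumptions stated in the note, and the variable names are illustrative.

```python
import math

# Hardware-based estimate: core-years -> core-seconds -> FLOP
CORE_YEARS = 235.5
SECONDS_PER_YEAR = 365.25 * 24 * 3600
PEAK_FLOPS_PER_CHIP = 1.23e14  # TPU v3 peak FLOP/s per chip
CORES_PER_CHIP = 2             # the "/ 2" in the note converts to per-core
UTILIZATION = 0.3              # assumed utilization, not measured

hardware_flop = (CORE_YEARS * SECONDS_PER_YEAR
                 * (PEAK_FLOPS_PER_CHIP / CORES_PER_CHIP) * UTILIZATION)

# Arithmetic estimate: 6 * parameters * training tokens
PARAMS = 2.3e9
STEPS = 300_000
TOKENS_PER_STEP = 4e6
tokens = STEPS * TOKENS_PER_STEP
arithmetic_flop = 6 * PARAMS * tokens

# Combine the two independent estimates with a geometric mean
combined_flop = math.sqrt(hardware_flop * arithmetic_flop)

print(f"hardware:   {hardware_flop:.3e}")
print(f"arithmetic: {arithmetic_flop:.3e}")
print(f"geo mean:   {combined_flop:.3e}")
```

The geometric mean is used because the two estimates differ by almost an order of magnitude; it splits that gap multiplicatively rather than letting the larger value dominate.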
Size Notes: "We focus on improving the translation quality (measured in terms of BLEU score [48]) from all 100 languages to English. This resulted in approximately 13 billion training examples to be used for model training". Each example is a sentence pair. Assuming 20 words per sentence and 4/3 tokens per word, that is 13B * 20 * 4/3 ≈ 347 billion (3.47e11) tokens.
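The dataset-size arithmetic can be written out as follows. The 20 words per sentence and 4/3 tokens per word are the note's assumptions, not values from the paper, and, as in the note, one sentence is counted per example.

```python
# Dataset size estimate under the note's assumptions (not measured values)
EXAMPLES = 13e9            # ~13 billion sentence-pair examples (from the paper)
WORDS_PER_SENTENCE = 20    # assumed; one sentence counted per example
TOKENS_PER_WORD = 4 / 3    # assumed

dataset_tokens = EXAMPLES * WORDS_PER_SENTENCE * TOKENS_PER_WORD
print(f"{dataset_tokens:.3e} tokens")
```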
Notes: "Our best quality dense single Transformer model (2.3B parameters) achieving ∆BLEU of 6.1, was trained with GPipe [15] on 2048 TPU v3 cores for 6 weeks or total of 235.5 TPU v3 core-years."
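The quoted 235.5 core-years is consistent with 2048 cores running for 6 weeks, as a quick check shows. The only assumption is a 365.25-day year, matching the hardware estimate above.

```python
# Consistency check: 2048 TPU v3 cores for 6 weeks, expressed in core-years
CORES = 2048
DAYS = 6 * 7
DAYS_PER_YEAR = 365.25

core_years = CORES * DAYS / DAYS_PER_YEAR
print(f"{core_years:.1f} core-years")
```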