We present an end-to-end differentiable training method for retrieval-augmented open-domain question answering systems that combine information from multiple retrieved documents when generating answers. We model retrieval decisions as latent variables over sets of relevant documents. Since marginalizing over sets of retrieved documents is computationally hard, we approximate this using an expectation-maximization algorithm. We iteratively estimate the value of our latent variable (the set of relevant documents for a given question) and then use this estimate to update the retriever and reader parameters. We hypothesize that such end-to-end training allows training signals to flow to the reader and then to the retriever better than stage-wise training. This results in a retriever that is able to select more relevant documents for a question and a reader that is trained on more accurate documents to generate an answer. Experiments on three benchmark datasets demonstrate that our proposed method outperforms all existing approaches of comparable size by 2-3 absolute exact match points, achieving new state-of-the-art results. Our results also demonstrate the feasibility of learning to retrieve to improve answer generation without explicit supervision of retrieval decisions.
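As a rough illustration of the objective described above, the sketch below (not the authors' implementation; all names, shapes, and the choice of K are hypothetical) marginalizes the answer likelihood over a set of K retrieved documents so that the gradient of a single loss reaches both the retriever scores and the reader scores.

```python
# Minimal sketch, assuming a retriever that scores top-K documents and a reader
# that scores the answer given each document; both sets of scores are random
# placeholders here rather than outputs of real models.
import torch
import torch.nn.functional as F

K = 8  # number of retrieved documents per question (illustrative choice)

# Stand-ins for learned model outputs:
#   retriever_logits[k] ~ similarity(question, document_k)
#   reader_logprobs[k]  ~ log p(answer | question, document_k)
retriever_logits = torch.randn(K, requires_grad=True)
reader_logprobs = torch.randn(K, requires_grad=True)

# Distribution over the retrieved set (the latent "relevant documents" estimate).
log_p_doc = F.log_softmax(retriever_logits, dim=0)

# Marginal log-likelihood: log sum_k p(doc_k | question) * p(answer | question, doc_k).
log_marginal = torch.logsumexp(log_p_doc + reader_logprobs, dim=0)

loss = -log_marginal
loss.backward()  # one training signal updates both retriever and reader inputs
print(loss.item(), retriever_logits.grad, reader_logprobs.grad)
```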
Notes: "We run all of our experiments on a machine with 96 CPUs, 1.3TB physical memory, and 16 A100 GPUs. We use PyTorch (Paszke et al., 2019) to implement our proposed model. With this hardware setup, our experiments on NQ and TriviaQA took approximately 25 hours to complete, while experiments on WebQ took roughly 8 hours to complete. Before supervised training, we also perform a one-time unsupervised MSS pre-training for 82,000 steps that took roughly 1 week." 1 week + 25 hours * 16 A100s = ~193 * 16 A100-hours = 193 * 16 * 3600 * 312 trillion * 0.3 = 1.04e21 Additionally, the model uses BERT, ICT, and T5 models. These required: - BERT: 6 * 110M parameters * (1M * 256 * 256) inputs = 4.33e19 FLOP - ICT: 6 * 220M parameters * (100k * 4096 * 256) inputs = 1.38e20 FLOP - T5: 6 * 220M parameters * (1M * 2048 * 256) inputs = 6.92e20 FLOP Total: 1.04e21 + 4.33e19 + 1.38e20 + 6.92e20 = 1.91e21
Size Notes: At the time of publication there were about 4B words (5.3B tokens) on English Wikipedia (https://en.wikipedia.org/wiki/Wikipedia:Size_of_Wikipedia#Yearly_statistics). BookCorpus has about 1B words (1.3B tokens), C4 has about 156B tokens, and OpenWebText has about 9B tokens. From Table 6, it looks like all datasets were trained on for more than one epoch:
- BERT: 1M steps * batch size 256 * sequence length 256 = 65.5B tokens seen vs. 6.6B tokens in Wikipedia + BookCorpus
- ICT: 100k steps * batch size 4096 * sequence length 256 = 104.9B tokens seen vs. 5.3B tokens in Wikipedia
- T5: 1M steps * batch size 2048 * sequence length 256 = 524.3B tokens seen vs. 170.3B tokens in C4 + Wikipedia + OpenWebText

Total unique pre-training tokens: 171.6 billion. Some tokens were probably seen more often than others, but the roughly 694.7B tokens seen in total correspond to about 4.05 epochs over the pre-training data.
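The same arithmetic for the token counts and the epoch estimate, using the approximate corpus sizes stated above:

```python
# Recomputes the token-count and epoch estimate above, using the note's
# approximate corpus sizes (in tokens).
WIKIPEDIA, BOOKCORPUS, C4, OPENWEBTEXT = 5.3e9, 1.3e9, 156e9, 9e9

def tokens_seen(steps, batch, seq_len):
    return steps * batch * seq_len

bert = tokens_seen(1e6, 256, 256)     # 65.5B vs 6.6B in Wikipedia + BookCorpus
ict  = tokens_seen(100e3, 4096, 256)  # 104.9B vs 5.3B in Wikipedia
t5   = tokens_seen(1e6, 2048, 256)    # 524.3B vs 170.3B in C4 + Wikipedia + OpenWebText

unique_tokens = WIKIPEDIA + BOOKCORPUS + C4 + OPENWEBTEXT   # 171.6B
epochs = (bert + ict + t5) / unique_tokens                  # ~4.05
print(f"{unique_tokens / 1e9:.1f}B unique tokens, {epochs:.2f} epochs")
```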
Notes: Benchmark results from Table 2.