Introduction

Recently, Retrieval-Augmented Generation (RAG) has attracted wide attention in the Artificial Intelligence (AI) field and has become a hot topic among researchers and developers. By combining retrieval with generation, RAG has achieved remarkable results in applications such as dialogue generation, text summarisation, and fact verification. It offers a new way of tackling complex problems, helping AI models understand and respond to user needs more accurately and efficiently. However, the sub-modules of current RAG systems, chiefly the retriever and the generator, are often poorly coordinated, and overall performance falls short of expectations.

This article introduces Stochastic RAG, an end-to-end RAG framework optimised through expected utility maximisation. The approach aims to build a more efficient and powerful system by optimising its components jointly, and it achieves state-of-the-art results on six of the seven KILT datasets evaluated.
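To make "expected utility maximisation" concrete, here is a generic formulation in notation of our own, which may differ from the paper's. Retriever parameters θ and generator parameters φ are chosen to maximise the average utility U(y, ŷ) of generated outputs ŷ against reference outputs y over a training set 𝒯. The expectation runs over ranked lists d = (d_1, ..., d_k) of k documents sampled without replacement from a candidate pool 𝒞. The second formula assumes the retriever's scores s_θ(x, d) define a softmax, so a list's probability is a product of per-step probabilities renormalised over the documents not yet chosen. U could be, for example, exact match or F1 against the reference; these are illustrative choices, not taken from the source.

```latex
\max_{\theta,\,\phi} \sum_{(x,\,y) \in \mathcal{T}}
  \mathbb{E}_{\mathbf{d} \sim p_\theta(\mathbf{d} \mid x)}
  \Big[ \mathbb{E}_{\hat{y} \sim p_\phi(\hat{y} \mid x,\, \mathbf{d})} \big[ U(y, \hat{y}) \big] \Big],
\qquad
p_\theta(\mathbf{d} \mid x) = \prod_{i=1}^{k}
  \frac{\exp\big(s_\theta(x, d_i)\big)}
       {\sum_{d \in \mathcal{C} \setminus \{d_1, \dots, d_{i-1}\}} \exp\big(s_\theta(x, d)\big)}
```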

Traditional RAG

While traditional RAG systems perform well in retrieving information and generating responses, they have several limitations:

  • Marginalisation Assumption: Traditional RAG systems often assume that retrieved documents are independent of one another and produce the response by marginalising over them, that is, summing the generator's output probability for each document weighted by that document's retrieval probability. Because the relationships between documents are ignored, contextual coherence can be lost. Documents may also contain conflicting information that is never recognised or reconciled, resulting in inconsistent or contradictory responses.
  • Module Separation: The retrieval and generation modules in traditional RAG are usually developed and trained independently. This lack of coordination between the retrieval model and the generation model can lead to suboptimal final results.
  • Off-the-Shelf Retrieval Models: Traditional RAG systems often rely on pre-existing retrieval methods, such as commercial search engine APIs or term-matching models, without further customisation for the specific task. This can leave the retrieval output misaligned with the needs of the downstream generation task.
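As a toy illustration of the marginalisation assumption, the sketch below assumes the formulation of the original RAG model, where p(y | x) = Σ_d p(d | x) · p(y | x, d). Each document is scored and read on its own, and the per-document answer probabilities are combined with retrieval-probability weights. The function name and all numbers are hypothetical.

```python
def marginal_likelihood(retrieval_probs, generation_probs):
    """Return p(y | x) = sum over d of p(d | x) * p(y | x, d).

    retrieval_probs:  p(d | x) for each retrieved document (should sum to 1)
    generation_probs: p(y | x, d), the generator's probability of answer y
                      when it is conditioned on document d alone
    """
    if len(retrieval_probs) != len(generation_probs):
        raise ValueError("need one generation probability per document")
    return sum(p_d * p_y for p_d, p_y in zip(retrieval_probs, generation_probs))


# Three documents, each processed independently: the generator never sees them
# together, so it cannot notice that document 2 contradicts document 1.
retrieval = [0.5, 0.3, 0.2]   # hypothetical retrieval probabilities
generation = [0.9, 0.1, 0.6]  # hypothetical per-document answer probabilities
print(round(marginal_likelihood(retrieval, generation), 2))  # 0.6
```

The weighted sum is what makes each document's contribution independent: no term depends on more than one document at a time.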

Stochastic RAG

Stochastic RAG optimises retrieval-augmented generation models through expected utility maximisation, relaxing the simplifying assumptions of traditional RAG systems. The framework casts retrieval as stochastic sampling without replacement and uses the straight-through Gumbel-top-k method, explained below, to provide a differentiable approximation of this sampling, which enables end-to-end optimisation. The main improvements are:

1. Stochastic Sampling Without Replacement:

Stochastic RAG casts retrieval as stochastic sampling without replacement: once a document is selected, it is removed from the pool for subsequent draws. This avoids redundancy, because a document that has already been retrieved cannot be selected again. It also means that each new selection depends on the documents selected so far, which breaks the traditional assumption that every document is scored and processed independently. By accounting for the relationships between documents, the system can better synthesise and use information from different sources. Sampling without replacement also increases the diversity and coverage of the retrieved information, leading to more comprehensive and richer generated results.
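A minimal pure-Python sketch of sequential sampling without replacement follows. It assumes the retriever's scores are turned into a softmax distribution over a small candidate pool; the scores and function names are hypothetical. Each draw is renormalised over the documents still in the pool, which is the dependency described above. The product of the per-step probabilities is the probability of the whole ordered list (the Plackett-Luce model).

```python
import math
import random


def softmax(scores):
    """Convert raw retrieval scores into a probability distribution."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def sample_without_replacement(scores, k, rng):
    """Draw k distinct document indices, one at a time.

    At each step the softmax is recomputed over the documents still in the
    pool, so every draw depends on the documents already selected.
    Returns the ordered list of indices and the probability of that ordering.
    """
    remaining = list(range(len(scores)))
    chosen, prob = [], 1.0
    for _ in range(k):
        probs = softmax([scores[i] for i in remaining])
        pick = rng.choices(range(len(remaining)), weights=probs)[0]
        prob *= probs[pick]
        chosen.append(remaining.pop(pick))  # removed, so it cannot be drawn again
    return chosen, prob


rng = random.Random(0)
scores = [2.0, 1.0, 0.5, -1.0]  # hypothetical retrieval scores for 4 documents
docs, p = sample_without_replacement(scores, k=3, rng=rng)
print(docs, round(p, 4))
```

Every sampled list contains distinct documents, and higher-scoring documents tend to appear earlier without being guaranteed a place, which is where the extra diversity comes from.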

2. Straight-through Gumbel-top-k Approach:

In machine learning, end-to-end optimisation means optimising all parts of a model simultaneously through backpropagation, which requires every operation in the model to be differentiable. Gradient descent, the optimisation algorithm behind this, finds a local minimum of a differentiable function by repeatedly stepping in the direction of the negative gradient. However, the traditional sampling process (e.g., directly selecting the top k documents) is non-differentiable: sorting and selection are discrete operations, so gradients cannot flow through them. The straight-through Gumbel-top-k approach solves this by making sampling without replacement differentiable, allowing the retrieval-augmented generation (RAG) model to be optimised end-to-end. It involves:

  • Straight-through Mechanism: Hard selection (e.g., argmax) is used in the forward pass, while gradients are computed through a soft selection (e.g., softmax) in the backward pass. This preserves the discrete nature of the selection while still enabling gradient-based optimisation.
  • Gumbel-top-k Method: Gumbel noise (the Gumbel distribution is a special case of the generalised extreme value distribution) is added to each document's retrieval score, and the top k documents are selected by sorting the perturbed scores. In the backward pass, softmax (which converts a vector of real numbers into a probability distribution) approximates this selection, making it differentiable.
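The two bullets above can be sketched without any dependencies. The sketch assumes the retriever's scores act as softmax logits over a small candidate pool. The relaxation (k rounds of tempered softmax), the temperature `tau`, the toy `usefulness` utility and all function names are our own illustrative choices, not necessarily what the paper uses. Taking the top k of Gumbel-perturbed scores yields the same distribution over ordered lists as the sequential sampling without replacement described in point 1. Plain Python has no autodiff, so the straight-through step is emulated: the value comes from the hard mask, and the gradient comes from finite differences through the soft mask. In an autodiff framework the same idea is commonly written as `hard + soft - stop_gradient(soft)` (for example with `jax.lax.stop_gradient`, or `soft.detach()` in PyTorch).

```python
import math
import random


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def standard_gumbel(rng):
    """Draw one standard Gumbel sample as -log(-log(U)), U ~ Uniform(0, 1)."""
    u = rng.random()
    while u == 0.0:  # log(0) is undefined; redraw (vanishingly rare)
        u = rng.random()
    return -math.log(-math.log(u))


def gumbel_top_k(scores, k, rng):
    """Perturb every score with independent Gumbel noise and keep the k
    largest. Returns (ordered top-k indices, the noise that was used)."""
    noise = [standard_gumbel(rng) for _ in scores]
    perturbed = [s + g for s, g in zip(scores, noise)]
    order = sorted(range(len(scores)), key=lambda i: perturbed[i], reverse=True)
    return order[:k], noise


def hard_top_k(perturbed, k):
    """Forward pass: exact k-hot mask of the k largest perturbed scores."""
    order = sorted(range(len(perturbed)), key=lambda i: perturbed[i], reverse=True)
    top = set(order[:k])
    return [1.0 if i in top else 0.0 for i in range(len(perturbed))]


def soft_top_k(perturbed, k, tau):
    """Backward-pass surrogate: a smooth k-hot relaxation built from k rounds
    of tempered softmax. After each round, log(1 - p) is added to the logits
    so that mass already assigned is discouraged from being assigned again."""
    logits = list(perturbed)
    khot = [0.0] * len(perturbed)
    for _ in range(k):
        probs = softmax([z / tau for z in logits])
        khot = [h + p for h, p in zip(khot, probs)]
        logits = [z + math.log(max(1.0 - p, 1e-12)) for z, p in zip(logits, probs)]
    return khot


def finite_diff_grad(fn, xs, eps=1e-4):
    """Central finite-difference gradient of a scalar function fn(xs)."""
    grad = []
    for j in range(len(xs)):
        up, down = list(xs), list(xs)
        up[j] += eps
        down[j] -= eps
        grad.append((fn(up) - fn(down)) / (2 * eps))
    return grad


def straight_through(scores, noise, k, tau, usefulness):
    """Emulate the straight-through estimator: the value comes from the hard
    selection, the gradient from the soft relaxation on the same noise."""
    def utility(mask):
        # hypothetical downstream utility: total usefulness of selected documents
        return sum(m * u for m, u in zip(mask, usefulness))

    def perturb(s):
        return [si + gi for si, gi in zip(s, noise)]

    value = utility(hard_top_k(perturb(scores), k))
    grad = finite_diff_grad(lambda s: utility(soft_top_k(perturb(s), k, tau)), scores)
    return value, grad


rng = random.Random(0)
scores = [2.0, 1.0, 0.5, -1.0]     # hypothetical retrieval scores
usefulness = [1.0, 0.0, 0.5, 0.2]  # hypothetical per-document utility
picked, noise = gumbel_top_k(scores, k=2, rng=rng)
value, grad = straight_through(scores, noise, k=2, tau=0.5, usefulness=usefulness)
print(picked, value, [round(g, 3) for g in grad])
```

Differentiating the hard mask in the same way gives a zero gradient almost everywhere, because a tiny change to a score does not change which k documents are selected. That is the non-differentiability the straight-through trick works around. Lowering `tau` makes the soft mask closer to the hard one, at the cost of noisier, larger gradients.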

Experimental Results

To assess the effectiveness of Stochastic RAG, the authors compared it against the best-performing entries on the KILT leaderboard (a benchmark for knowledge-intensive language tasks) as of February 1, 2024, using the official KILT-score metrics. The competing methods, listed in the table below, span a variety of approaches, including dense retrieval coupled with BART or T5 for generation, generative retrieval models, retrieval-and-reranking models, and large pre-trained language models without retrieval augmentation. FiD-Light, which extends the Fusion-in-Decoder architecture to generate document identifiers for re-ranking during inference, serves as the main baseline; it has demonstrated state-of-the-art performance on six of the seven datasets evaluated. Applying stochastic expected utility maximisation improved performance over this baseline on all datasets.

Table: Comparison of models with the top-performing entries on the KILT leaderboard according to KILT-scores, as of February 1, 2024. Results are reported on the blind KILT test sets. Source: https://doi.org/10.48550/arXiv.2405.02816

The last two rows compare the same model with different sizes of the downstream language model: T5-Base (220 million parameters) and T5-XL (3 billion parameters). Both model sizes benefited from stochastic expected utility maximisation, with the larger model generally performing better.

Conclusion

Stochastic RAG enhances the performance of retrieval-augmented generation (RAG) systems by optimising their components end-to-end through expected utility maximisation. It achieves state-of-the-art results on six of the seven KILT datasets, and both downstream language model sizes tested (220M and 3B parameters) benefit from this end-to-end optimisation. Moreover, its stochastic nature can be used to increase the diversity of outputs generated by RAG systems, which is particularly important in scenarios where multiple outputs are generated to collect human feedback. Despite these advances, the current work focuses solely on relatively short text generation; further research is needed to test the impact of Stochastic RAG on longer text generation and to explore the various utility functions that can be defined for RAG optimisation.
