January 21, 2025
4 mins

Deliberation in Latent Space via Differentiable Cache Augmentation

This paper introduces a novel approach to enhance LLMs by augmenting the key-value cache with latent embeddings generated by an offline coprocessor. The method is differentiable, efficient, and improves reasoning performance on a variety of tasks.
Paper Link

Key Takeaways

  • Introduces a novel method to enhance LLMs using differentiable cache augmentation.
  • Employs an offline coprocessor to generate latent embeddings for kv-cache augmentation.
  • Enables end-to-end backpropagation for efficient coprocessor training.
  • Demonstrates consistent reduction in perplexity and improved reasoning.
  • Achieves performance gains on various reasoning-intensive tasks such as MMLU and GSM8K.
  • Opens new avenues for computationally intensive and strategic deliberation in LLMs.

Introduction

Large language models (LLMs) have shown remarkable capabilities across a wide range of tasks, but their reasoning abilities are often limited by the way they process information. Existing methods for letting models "think more" generate sequences of discrete intermediate tokens before responding, which adds latency and is difficult to optimize end-to-end. This work introduces a novel approach that enhances LLMs by augmenting their key-value (kv) cache with latent embeddings generated by an offline coprocessor.

Relevant Background Work

Recent studies have explored enabling LLMs to "think more" by generating intermediate reasoning steps. Techniques such as Chain-of-Thought (CoT) prompting and Tree-of-Thoughts aim to improve reasoning by allowing models to explore intermediate steps before producing final answers. However, many of these approaches generate discrete intermediate outputs that are difficult to train end-to-end, and they perform the extra computation just in time, as part of output generation. Our approach is inspired by techniques in kv-cache compression that seek to refine an LLM's internal memory. Rather than compressing the cache, it is augmented, allowing the LLM to refine its internal memory before generating a final output.

Methodology

This section details our approach, which enhances a frozen LLM by training a coprocessor that takes the kv-cache as input and augments it with latent embeddings generated from a set of trainable soft tokens.

Problem Statement

Given an input x, a desired target output y, and a pre-trained, frozen LLM parameterized by θ, researchers seek to learn a coprocessor, denoted by f. This coprocessor takes as input the kv-cache (k<sub>θ,x</sub>, v<sub>θ,x</sub>) generated by the frozen LLM when processing the input x, and outputs a sequence of latent representations z:

f(k<sub>θ,x</sub>, v<sub>θ,x</sub>) → z

The objective of learning f is to produce latent embeddings z that, when combined with the input x, improve the frozen LLM’s ability to generate the correct target y. Specifically, researchers aim to maximize the expected log-likelihood of the target y given the input x and the learned latent embeddings z, as predicted by the frozen LLM:

max E<sub>x</sub>[log p<sub>θ</sub>(y | x, z)]
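
Read as code, this objective is ordinary next-token cross-entropy on y, with the frozen LLM additionally conditioned on z. The sketch below is a shape-level illustration under assumed interfaces: `prefill` and `decode_with_latents` are hypothetical helper names, not the paper's API.

```python
import torch
import torch.nn.functional as F

def augmented_nll(frozen_llm, coprocessor, x_ids, y_ids):
    """Negative log-likelihood -log p_theta(y | x, z), with z = f(kv-cache of x).

    `frozen_llm.prefill` and `frozen_llm.decode_with_latents` are hypothetical
    helpers used only for this sketch.
    """
    # Build the kv-cache for x. Nothing trainable affects this step,
    # so no gradient graph is needed here.
    with torch.no_grad():
        kv_cache = frozen_llm.prefill(x_ids)
    # The trainable coprocessor maps the cache to latent embeddings z.
    z = coprocessor(kv_cache)                                    # (batch, num_latents, d_model)
    # Decode y against the cache augmented with z. The LLM's weights have
    # requires_grad=False, but this pass still builds a graph, so the loss
    # backpropagates through the frozen LLM into the coprocessor.
    logits = frozen_llm.decode_with_latents(kv_cache, z, y_ids)  # (batch, len(y), vocab)
    return F.cross_entropy(logits.transpose(1, 2), y_ids)
```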

Model Architecture

Our architecture introduces a dedicated coprocessor module that operates on the kv-cache of a frozen, pre-trained LLM. Figure 1 shows the interaction between these components, which occurs in three stages (a shape-level sketch follows the list):

  • KV-cache Generation: The input sequence, x, is first processed by the frozen LLM to generate the corresponding kv-cache (kθ,x, vθ,x). The LLM's weights remain frozen throughout the entire process.
  • Augmentation: The kv-cache is then passed to the coprocessor module, which also receives a sequence of distinct extra soft tokens with trainable embeddings. The coprocessor ingests the kv-cache and these tokens to produce a sequence of latent embeddings, z.
  • LLM Generation with Augmented Context: Finally, z is appended to the original kv-cache. This augmented cache is then fed back into the frozen LLM, providing it with enriched contextual information derived from the coprocessor. The LLM proceeds to generate the output sequence, y, conditioned on both the original input x and the coprocessor’s output z.
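
Mechanically, "appending z to the kv-cache" comes down to extending each layer's cached key/value tensors along the sequence axis with entries for the latent positions (how those entries are derived from z is abstracted away below). The toy shapes are illustrative assumptions, not Gemma-2's actual configuration:

```python
import torch

# Illustrative shapes only.
batch, heads, head_dim = 2, 8, 64
x_len, num_latents = 10, 16

# Stage 1: one layer's kv-cache produced by the frozen LLM for the input x.
k_x = torch.randn(batch, heads, x_len, head_dim)
v_x = torch.randn(batch, heads, x_len, head_dim)

# Stage 2: stand-ins for the key/value entries at the latent positions
# (derived from the coprocessor's output z; the derivation is abstracted away).
k_z = torch.randn(batch, heads, num_latents, head_dim)
v_z = torch.randn(batch, heads, num_latents, head_dim)

# Stage 3: the augmented cache is the original cache with the latent entries
# appended along the sequence axis; the frozen LLM then decodes y while
# attending over both x's entries and the latents.
k_aug = torch.cat([k_x, k_z], dim=2)   # (2, 8, 10 + 16, 64)
v_aug = torch.cat([v_x, v_z], dim=2)
print(k_aug.shape, v_aug.shape)
```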

Training optimizes only the weights of the coprocessor and the trainable soft-token embeddings; the coprocessor shares the same model architecture as the pretrained LLM, and the loss is computed on the final output. This enables efficient fine-tuning without altering the pretrained LLM and allows the coprocessor's augmentation to be performed offline and asynchronously, in parallel with the LLM's decoding process.
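
A minimal sketch of that training focus, assuming `frozen_llm`, `coprocessor`, `num_latents`, and `d_model` are defined as in the sketches above; the learning rate is an arbitrary placeholder:

```python
import torch

# Freeze the pretrained LLM; it still participates in the forward/backward
# graph, but its weights never change.
for p in frozen_llm.parameters():
    p.requires_grad_(False)

# Trainable soft-token embeddings handed to the coprocessor with the cache.
soft_tokens = torch.nn.Parameter(0.02 * torch.randn(num_latents, d_model))

# Only the coprocessor and the soft tokens receive gradient updates.
optimizer = torch.optim.AdamW(
    list(coprocessor.parameters()) + [soft_tokens],
    lr=1e-4,  # illustrative value, not from the paper
)
```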

Pretraining Setup

Researchers use a pretraining strategy that encourages the coprocessor to learn augmentations useful for predicting larger segments of text beyond the next token after the augmentation. Instead of training on a single split, they augment at multiple points within each sequence. As shown in Figure 2(a), given an input text sequence, a subset of positions is randomly selected. For each selected position, the coprocessor generates a configurable number of latent embeddings. The training objective is to predict a number of tokens beyond the placement of the augmentation in a teacher-forced manner. For instance, in Figure 2(a), if 'b' is chosen and two tokens ahead are predicted, the coprocessor uses the generated latent embeddings (b', b") and the kv-cache of the preceding text 'a' and 'b' to predict 'c' and 'd'. This can be viewed as latent-space interpolation, which allows the coprocessor to bridge the gap between the known and future context.
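
A simplified sketch of the sampling in Figure 2(a), using a hypothetical helper rather than the paper's code; the defaults mirror the experimental setup of 128 positions and 16 ahead tokens, and in practice these pairs are packed into one forward pass, as described next:

```python
import torch

def sample_augmentation_examples(token_ids, num_positions=128, num_ahead=16):
    """For each sampled position i, the coprocessor sees the kv-cache of
    token_ids[: i + 1] and must help predict the next `num_ahead` tokens,
    teacher-forced from the ground truth. `token_ids` is a 1-D tensor.
    """
    seq_len = token_ids.shape[0]
    usable = seq_len - num_ahead                        # leave room for targets
    positions = torch.randperm(usable)[:num_positions]
    examples = []
    for i in positions.tolist():
        context = token_ids[: i + 1]                    # e.g. "a b" when i points at 'b'
        targets = token_ids[i + 1 : i + 1 + num_ahead]  # e.g. "c d" for num_ahead=2
        examples.append((context, targets))
    return examples
```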

Researchers implement an efficient training framework by modifying the input, attention mask, position index, and target so that all augmentation positions in a sequence are trained together in a single forward pass. Figure 2(b) illustrates this modification.
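
The exact packed layout is given in Figure 2(b), which is not reproduced here. As an assumption-level illustration of the idea, the mask block for a single augmentation position might look like the following, with the prefix attending causally, the latents attending to their prefix, and the ahead-token targets attending to the prefix plus the latents:

```python
import torch

def packed_attention_mask(prefix_len, num_latents, num_ahead):
    """Illustrative mask for one augmentation position, packed as
    [prefix tokens | latents | ahead-token targets]; True = may attend.
    Not the paper's exact Figure 2(b) layout.
    """
    total = prefix_len + num_latents + num_ahead
    mask = torch.zeros(total, total, dtype=torch.bool)
    # Prefix tokens: ordinary causal attention among themselves.
    mask[:prefix_len, :prefix_len] = torch.tril(
        torch.ones(prefix_len, prefix_len, dtype=torch.bool))
    # Latents: attend to the whole prefix and causally to earlier latents.
    lat = slice(prefix_len, prefix_len + num_latents)
    mask[lat, :prefix_len] = True
    mask[lat, lat] = torch.tril(
        torch.ones(num_latents, num_latents, dtype=torch.bool))
    # Ahead-token targets: attend to the prefix, all latents, and earlier targets.
    tgt = slice(prefix_len + num_latents, total)
    mask[tgt, :prefix_len + num_latents] = True
    mask[tgt, tgt] = torch.tril(
        torch.ones(num_ahead, num_ahead, dtype=torch.bool))
    return mask

print(packed_attention_mask(prefix_len=3, num_latents=2, num_ahead=2).int())
```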

Rather than capturing different aspects of a single token's context, these embeddings are trained to generate information useful for predicting future tokens, effectively enabling the model to perform a form of "latent thinking" before making predictions. By learning to anticipate future tokens in this manner, the coprocessor develops a stronger understanding of sequential dependencies within text, which proves valuable in downstream tasks.

Experiments

The approach is validated using the frozen Gemma-2 2B model. The augmented Gemma-2 models are trained on the same 2-trillion-token dataset used for Gemma-2 pretraining. The coprocessor is trained for 100,000 steps with 16 ahead tokens and 128 randomly sampled augmentation positions. Importantly, no task-specific training is performed.

Perplexity Evaluation

The augmented Gemma model achieves lower perplexity on the validation dataset than the pre-trained Gemma model at many positions ahead of the augmentation, even beyond the number of ahead tokens used during training. Figure 3 presents perplexity curves for the baseline and augmented models with varying numbers of latent embeddings. Across all latent sizes, this approach consistently reduces perplexity, with the improvement scaling with the number of latent embeddings. Table 1 quantifies the perplexity reduction achieved by this approach. The consistent reduction confirms that the benefit of this method extends to multiple subsequent token predictions, suggesting improved internal representations.
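
For reference, the perplexity being compared is simply the exponential of the mean per-token negative log-likelihood; a generic helper, not tied to this paper's models:

```python
import math
import torch.nn.functional as F

def perplexity(logits, targets):
    """Perplexity = exp(mean per-token negative log-likelihood).

    logits: (batch, seq_len, vocab_size); targets: (batch, seq_len) token ids.
    """
    nll = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),   # (batch * seq_len, vocab_size)
        targets.reshape(-1),                   # (batch * seq_len,)
        reduction="mean",
    )
    return math.exp(nll.item())
```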

Public Benchmark Evaluation

The method is evaluated on a range of public benchmarks spanning natural language understanding and reasoning tasks (Table 2). This method consistently improves performance compared to the baseline, with particularly substantial gains on reasoning-intensive benchmarks. Several tasks exhibit a strong correlation between the number of latent embeddings and performance improvement. For example, on GSM8K, accuracy steadily climbs from +1.29% with 4 latent embeddings to +10.05% with 64. Similarly, MMLU shows a jump from +0.45% with 4 to +4.70% with 64. This trend suggests that for challenging reasoning tasks, providing more latent embeddings allows the model to perform more extensive “thinking” in the latent space.

Business and Research Implications

This work has significant implications for both business and research. The ability to augment LLMs with offline coprocessors opens the door to more efficient and effective deployment of these powerful models.  Here are a few implications:

Improved Reasoning

The ability of LLMs to "think more" using latent embeddings and deliberation can substantially improve complex problem solving, enabling LLMs to have a greater impact on domains that rely on reasoning and critical thinking.

Efficiency

The coprocessor operates offline and asynchronously, allowing the main LLM to keep decoding efficiently while compute is allocated strategically to deliberation. Because the augmentation refines the LLM's internal memory rather than emitting extra discrete tokens, it can improve responses without the latency of explicit intermediate reasoning steps, which expands accessibility and makes the approach useful for deployments where resources are constrained.

Scalability

Because the base LLM stays frozen and only the coprocessor and its soft-token embeddings are trained, the approach is modular: coprocessors can in principle be scaled to larger models or deployed as multiple specialized modules, directions the authors highlight for future work. This makes it practical to add deliberation capacity without retraining or redeploying the underlying model.

Conclusion

In conclusion, our work introduces differentiable cache augmentation, a novel method for enhancing frozen decoder-only language models. This approach consistently reduces perplexity and significantly improves performance on a variety of reasoning-intensive tasks, even in zero/few-shot settings. Future work will explore scaling the coprocessor to larger models, using many modular coprocessors, investigating different coprocessor architectures and applying this method to more diverse downstream tasks.

What Makes This Work?

The unique insight of this work lies in the ability to perform "latent thinking" by augmenting the kv-cache with embeddings generated by the coprocessor. This gives the LLM access to an enriched contextual representation and enables a form of deliberation within the model's memory, without generating explicit intermediate tokens. The coprocessor's training objective of predicting multiple tokens ahead also contributes to the performance gains by encouraging more effective latent embeddings. Furthermore, the coprocessor can operate offline and asynchronously, leading to compute efficiencies. The end-to-end differentiable framework makes the coprocessor easy to train and scale, further extending the capabilities of large language models.

