Large language models (LLMs) have shown remarkable capabilities in various tasks, but their reasoning abilities are often limited by the way they process information. Traditional methods involve generating sequences of discrete tokens before responding, which incurs latency and optimization challenges. This work introduces a novel approach to enhance LLMs by augmenting their key-value (kv) cache with latent embeddings generated by an offline coprocessor.
Recent studies have explored enabling LLMs to "think more" by generating intermediate reasoning steps. Techniques such as Chain-of-Thought (CoT) prompting and Tree-of-Thoughts aim to improve reasoning by allowing models to explore intermediate steps before producing final answers. However, many of these approaches generate discrete intermediate outputs that are difficult to train end-to-end, and they perform the extra computation just in time, as part of output generation. Our approach is instead inspired by techniques in kv-cache compression that seek to refine an LLM's internal memory. Rather than compressing the cache, we augment it, allowing the LLM to refine its internal memory before generating a final output.
This section details our approach, which enhances a frozen LLM by training a coprocessor that takes the LLM's kv-cache as input and augments it with a set of soft tokens.
Given an input x, a desired target output y, and a pre-trained, frozen LLM parameterized by θ, we seek to learn a coprocessor, denoted f. The coprocessor takes as input the kv-cache $(k_{\theta,x}, v_{\theta,x})$ that the frozen LLM produces when processing the input x, and outputs a sequence of latent representations z:
$$f(k_{\theta,x},\, v_{\theta,x}) \rightarrow z$$
The objective of learning f is to produce latent embeddings z that, when combined with the input x, improve the frozen LLM's ability to generate the correct target y. Specifically, we aim to maximize the expected log-likelihood of the target y given the input x and the learned latent embeddings z, as predicted by the frozen LLM:
$$\max_{f}\; \mathbb{E}_{x}\left[\log p_{\theta}(y \mid x, z)\right]$$
Our architecture introduces a dedicated coprocessor module that operates on the kv-cache of a frozen, pre-trained LLM. Figure 1 shows the interaction between these components, which occurs in three stages: first, the frozen LLM processes the input x and produces its kv-cache; second, the coprocessor reads this kv-cache, together with a set of trainable soft tokens, and outputs a sequence of latent embeddings; third, the latent embeddings are appended to the kv-cache, and the frozen LLM decodes its output conditioned on the augmented cache.
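As a rough illustration of this flow, here is a minimal sketch of the three stages in Python. The objects `frozen_llm` and `coprocessor` and their methods (`prefill`, `augment`, `append_to_cache`, `decode`) are hypothetical stand-ins for this sketch, not a real library API.

```python
def generate_with_cache_augmentation(frozen_llm, coprocessor, input_ids, num_latents):
    """Three-stage interaction sketch; all model objects are hypothetical stand-ins."""
    # Stage 1: the frozen LLM processes the input and emits its kv-cache.
    kv_cache = frozen_llm.prefill(input_ids)

    # Stage 2: the coprocessor reads that kv-cache (plus its trainable soft
    # tokens) and produces latent embeddings. Because it only needs the cache,
    # this step can run offline and asynchronously with respect to decoding.
    latent_embeddings = coprocessor.augment(kv_cache, num_latents=num_latents)

    # Stage 3: the latent embeddings are appended to the kv-cache, and the
    # frozen LLM decodes conditioned on the augmented cache.
    augmented_cache = frozen_llm.append_to_cache(kv_cache, latent_embeddings)
    return frozen_llm.decode(input_ids, past_kv=augmented_cache)
```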
Training updates only the coprocessor and the trainable soft-token embeddings; the coprocessor shares the same model architecture as the pretrained LLM, and the loss is computed on the LLM's final output. This enables efficient fine-tuning without altering the pretrained LLM, and it allows the coprocessor's augmentation to be performed offline and asynchronously, in parallel with the LLM's decoding process.
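A training step under the same assumptions might look like the sketch below: the optimizer holds only the coprocessor's parameters and the soft-token embeddings, while the loss on the final output still backpropagates through the frozen LLM, which is what makes the cache augmentation end-to-end differentiable. The model methods are again illustrative stand-ins rather than a real API.

```python
import torch
import torch.nn.functional as F

def train_step(frozen_llm, coprocessor, soft_tokens, optimizer, input_ids, target_ids):
    """One sketch training step. The optimizer holds only the coprocessor's
    parameters and the trainable soft-token embeddings; the LLM stays frozen."""
    for p in frozen_llm.parameters():
        p.requires_grad_(False)          # never update the pretrained LLM

    kv_cache = frozen_llm.prefill(input_ids)               # cache for input x
    latents = coprocessor.augment(kv_cache, soft_tokens)   # latent embeddings z
    logits = frozen_llm.decode_logits(input_ids, kv_cache, latents)

    # Loss on the final output only: maximize log p_theta(y | x, z).
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target_ids.reshape(-1))
    loss.backward()   # gradients flow through the frozen LLM into the coprocessor
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()
```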
We use a pretraining strategy that encourages the coprocessor to learn augmentations useful for predicting larger segments of text, not just the next token after the augmentation. Instead of training on a single split point, we augment at multiple points within each sequence. As shown in Figure 2(a), given an input text sequence, a subset of positions is randomly selected. For each selected position, the coprocessor generates a configurable number of latent embeddings, and the training objective is to predict a number of tokens beyond the augmentation point in a teacher-forcing way. For instance, in Figure 2(a), if 'b' is chosen and two tokens ahead are predicted, the coprocessor uses the generated latent embeddings (b', b") and the kv-cache of the preceding text 'a' and 'b' to predict 'c' and 'd'. This can be viewed as latent-space interpolation, which allows the coprocessor to bridge the gap between the known context and future tokens.
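To make this data construction concrete, here is a minimal sketch of how augmentation positions and their ahead targets could be sampled from a tokenized sequence; the function and field names are illustrative, not the paper's actual pipeline.

```python
import random

def sample_augmentation_examples(token_ids, num_positions, num_ahead):
    """Pick random augmentation points in a training sequence and the
    teacher-forced targets that follow each one. Names are illustrative."""
    # A position is valid only if at least `num_ahead` tokens follow it.
    valid = list(range(len(token_ids) - num_ahead))
    positions = sorted(random.sample(valid, k=min(num_positions, len(valid))))
    examples = []
    for i in positions:
        examples.append({
            "position": i,                                    # where the latents are inserted
            "prefix": token_ids[: i + 1],                     # context whose kv-cache is read
            "targets": token_ids[i + 1 : i + 1 + num_ahead],  # tokens to predict after the latents
        })
    return examples

# For the sequence ['a', 'b', 'c', 'd'] with two ahead tokens, choosing the
# position of 'b' yields prefix ['a', 'b'] and targets ['c', 'd'], matching
# the Figure 2(a) example above.
```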
We implement an efficient training framework by modifying the input, attention mask, position indices, and targets so that all augmentation positions are trained together in a single forward pass. Figure 2(b) illustrates this modification.
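To give a concrete (if simplified) picture of that modification, the sketch below builds the kind of block-structured attention mask such a packed forward pass needs. It assumes a layout in which the original tokens come first and each sampled position contributes a group of [latents | ahead targets] appended afterwards; this is a reconstruction under those assumptions, not the paper's exact implementation, and position-index remapping is omitted.

```python
import numpy as np

def build_augmented_mask(seq_len, positions, num_latents, num_ahead):
    """Boolean attention mask (True = may attend) packing all augmentation
    groups into one forward pass. Simplified reconstruction, not the actual code."""
    total = seq_len + len(positions) * (num_latents + num_ahead)
    mask = np.zeros((total, total), dtype=bool)
    # Original tokens keep their ordinary causal mask.
    mask[:seq_len, :seq_len] = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    offset = seq_len
    for pos in positions:
        lat = slice(offset, offset + num_latents)
        tgt = slice(offset + num_latents, offset + num_latents + num_ahead)
        # Latents see the prefix up to (and including) their augmentation point,
        # plus earlier latents in the same group.
        mask[lat, : pos + 1] = True
        mask[lat, lat] = np.tril(np.ones((num_latents, num_latents), dtype=bool))
        # Ahead tokens (teacher-forced) see the same prefix, all latents of
        # their group, and earlier ahead tokens of the same group.
        mask[tgt, : pos + 1] = True
        mask[tgt, lat] = True
        mask[tgt, tgt] = np.tril(np.ones((num_ahead, num_ahead), dtype=bool))
        offset += num_latents + num_ahead
    return mask
```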
Rather than capturing different aspects of a single token's context, these embeddings are trained to generate information useful for predicting future tokens, effectively enabling the model to perform a form of "latent thinking" before making predictions. By learning to anticipate future tokens in this manner, the coprocessor develops a stronger understanding of sequential dependencies within text, which proves valuable in downstream tasks.
The approach is validated using the frozen Gemma-2 2B model. The augmented Gemma-2 model is trained on the same 2-trillion-token dataset used for Gemma-2 pretraining, for 100,000 steps with 16 ahead tokens and 128 randomly sampled augmentation positions. Importantly, no task-specific training is performed.
The augmented Gemma model achieves lower perplexity than the pre-trained Gemma model on the validation dataset at many positions ahead of the augmentation, even beyond the number of ahead tokens used during training. Figure 3 presents perplexity curves for the baseline and augmented models with varying numbers of latent embeddings. Across all latent sizes, the approach consistently reduces perplexity, with the improvement scaling with the number of latent embeddings. Table 1 quantifies the perplexity reduction achieved by the approach. The consistent reduction confirms that the benefit of the method extends to multiple subsequent token predictions, suggesting improved internal representations.
The method is evaluated on a range of public benchmarks spanning natural language understanding and reasoning tasks (Table 2). It consistently improves performance over the baseline, with particularly substantial gains on reasoning-intensive benchmarks. Several tasks exhibit a strong correlation between the number of latent embeddings and the size of the improvement. For example, on GSM8K, the accuracy gain climbs steadily from +1.29% with 4 latent embeddings to +10.05% with 64. Similarly, MMLU improves from +0.45% with 4 to +4.70% with 64. This trend suggests that, for challenging reasoning tasks, providing more latent embeddings allows the model to perform more extensive "thinking" in the latent space.
This work has significant implications for both business and research: the ability to augment LLMs with offline coprocessors opens the door to more efficient and effective deployment of these powerful models. A few stand out:
The ability of LLMs to "think more" through latent embeddings and deliberation can substantially improve complex problem solving, giving LLMs a greater impact on domains that rely on reasoning and critical thinking.
The coprocessor operates offline and asynchronously, allowing the main LLM to function efficiently, and provides a way to strategically allocate compute. This means that we can potentially refine the internal memory of LLMs over time, leading to improved efficiency and faster response times. The ability to augment LLMs also expands accessibility, and can be useful for deployment where resources are constrained.
In conclusion, our work introduces differentiable cache augmentation, a novel method for enhancing frozen decoder-only language models. The approach consistently reduces perplexity and significantly improves performance on a variety of reasoning-intensive tasks, even in zero-/few-shot settings. Future work will explore scaling the coprocessor to larger models, using multiple modular coprocessors, investigating different coprocessor architectures, and applying the method to more diverse downstream tasks.
The unique insight of this work lies in the ability to perform "latent thinking" by augmenting the kv-cache with embeddings generated by the coprocessor. This gives the LLM access to a richer contextual representation and enables a form of deliberation within the model's memory, without generating explicit intermediate tokens. The coprocessor's ability to predict multiple tokens ahead also contributes to the performance gains by enabling more effective generation of latent embeddings. Furthermore, the coprocessor can operate offline and asynchronously, leading to compute efficiencies. The end-to-end differentiable framework makes the approach scalable and easy to train, further extending the capabilities of large language models.