June 16, 2024

Transformers meet Neural Algorithmic Reasoners

This paper from Google DeepMind combines Transformers with Neural Algorithmic Reasoning, resulting in a hybrid architecture that makes language models markedly better at algorithmic reasoning tasks.
This is a potentially groundbreaking paper from Google DeepMind, a step toward LLMs that can carry robust reasoning over to everyday tasks.

Key takeaways

  • This paper proposes a novel approach to combine the strengths of Transformers and Neural Algorithmic Reasoners (NARs) to create a hybrid model called TransNAR.
  • TransNAR significantly outperforms Transformer-only models on the CLRS-Text benchmark, a text-based version of the CLRS-30 benchmark for algorithmic reasoning.
  • TransNAR demonstrates impressive generalization capabilities, particularly in out-of-distribution scenarios, showing significant improvements for various algorithmic tasks.
  • The paper also sheds light on the limitations of Transformers for reasoning and suggests avenues for future research.

Introduction

Imagine a computer that can understand and execute complex algorithms, just like a human programmer. That's the goal of Neural Algorithmic Reasoning (NAR). NAR uses artificial neural networks (ANNs) to learn how to perform algorithmic computations. One of the key approaches to NAR is using Graph Neural Networks (GNNs). GNNs are a type of ANN designed specifically for graph-structured data: nodes (like variables) connected by edges (like relationships). These GNNs can be trained to solve various algorithmic problems, such as sorting, searching, or even planning.
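To make the link between GNNs and algorithms concrete, here is a minimal sketch of one message-passing round, written by hand rather than learned. Each node's state is a distance estimate, each message is the sender's state plus the edge weight, and each node keeps the minimum of what it receives. Repeating this round is exactly the Bellman-Ford shortest-path algorithm (one of the CLRS-30 tasks). A trained NAR instead learns its message and update functions; the function names and the small graph below are made up for illustration.

```python
import math

def message_passing_step(dist, edges):
    """One synchronous round of min-aggregation message passing.

    dist:  dict mapping node -> current distance estimate (the node's "state")
    edges: list of (sender, receiver, weight) tuples
    """
    new_dist = dict(dist)
    for sender, receiver, weight in edges:
        message = dist[sender] + weight  # message uses the sender's state from the previous round
        new_dist[receiver] = min(new_dist[receiver], message)  # min aggregation at the receiver
    return new_dist

def shortest_paths(nodes, edges, source):
    dist = {n: math.inf for n in nodes}
    dist[source] = 0.0
    # |V| - 1 rounds suffice for Bellman-Ford when there are no negative cycles.
    for _ in range(len(nodes) - 1):
        dist = message_passing_step(dist, edges)
    return dist

nodes = ["a", "b", "c", "d"]
edges = [("a", "b", 1.0), ("b", "c", 2.0), ("a", "c", 5.0), ("c", "d", 1.0)]
print(shortest_paths(nodes, edges, "a"))  # {'a': 0.0, 'b': 1.0, 'c': 3.0, 'd': 4.0}
```

The structure of the computation (local messages, a fixed aggregation, repeated rounds) is what GNN-based NARs share with classical algorithms, which is one reason they generalize well to larger graphs.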

However, there are limitations to NARs. From the paper:

NARs are, however, still relatively narrow forms of AI, as they require rigidly structured formatting of inputs, and they hence cannot be directly applied to problems posed in more noisy forms—such as in natural language—even when the underlying problem is still algorithmic in nature.

On the other hand, Transformers are a powerful architecture for natural language processing. They excel at understanding text and generating human-like language, as evidenced by their success in language models like GPT-4. But Transformers are notoriously brittle when it comes to tasks that involve precise reasoning.

In spite of their unrivalled natural language understanding properties, they are also notoriously brittle when faced with even the simplest algorithmic tasks [9]—especially if out-of-distribution generalisation is required.

This paper aims to bridge this gap by combining the best of both worlds. The authors propose TransNAR, a hybrid architecture that leverages the language understanding capabilities of Transformers with the robust reasoning power of GNN-based NARs.

The impact of this work could be significant. It might lead to the development of language models that are much better at solving real-world problems involving complex algorithms, such as those found in finance, healthcare, or engineering.

Background of the paper

The paper builds upon previous research in both NAR and Transformers.

  • NARs have been shown to be effective for solving algorithmic tasks, even when the input size is much larger than what they were trained on. The authors cite work showing that NARs can generalize to inputs 6x larger than the training set.
  • Transformers have been successful at various language tasks, but they struggle with tasks that require reasoning, especially in out-of-distribution scenarios.

Related Work

The authors discuss three main areas of related research:

Neural algorithmic reasoning

Previous work has demonstrated that NARs can be effective in a variety of scenarios:

Multi-task NARs

Some researchers have created NARs that can solve multiple algorithms simultaneously. The paper cites work on the Triplet-GMPNN, a NAR capable of solving 30 algorithms from the CLRS benchmark.

Applications of NAR

NARs have been applied to various domains, including reinforcement learning, self-supervised learning, combinatorial optimization, computational biology, and neuroscience.

Length generalisation in LLMs

While NARs can generalize well to larger inputs, Transformers have struggled to do so. One likely factor is that Transformers are usually trained with an autoregressive objective, predicting output tokens one at a time in a fixed order, and that order doesn't always match the logical flow of the underlying algorithm.

Several approaches have been proposed to improve length generalization in LLMs, including:

Careful prompting

Using specific prompts can help Transformers reason about larger inputs.

Randomized positional encoding

Randomizing the positional encoding of tokens can improve generalization.
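Randomized positional encoding (Ruoss et al., 2023) can be sketched in a few lines. Rather than numbering a sequence of length n as 0..n-1, training samples n distinct positions from a much larger range and sorts them, so the model learns to rely on the order of positions rather than their absolute values, and longer test sequences still use position values seen during training. The function name and the range of 8,192 below are illustrative choices, not code from the paper.

```python
import random

def randomized_positions(seq_len, max_len, rng=random):
    """Sample `seq_len` distinct positions from range(max_len), in increasing order.

    These replace the usual positions 0..seq_len-1 when computing positional encodings.
    """
    if seq_len > max_len:
        raise ValueError("sequence longer than the positional range")
    return sorted(rng.sample(range(max_len), seq_len))

rng = random.Random(0)  # seeded for reproducibility
print(randomized_positions(8, 8192, rng))
```

A fresh sample is drawn for every training sequence, so the model never sees the same absolute numbering twice.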

Curricula

Training models on tasks with increasing difficulty can enhance their ability to reason about longer sequences.
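A curriculum can be as simple as ordering training data by a difficulty proxy and unlocking it in stages. The sketch below assumes input length is that proxy (a hypothetical choice; real curricula may use other signals) and yields progressively larger training pools.

```python
def curriculum(examples, num_stages):
    """Yield training pools of increasing difficulty, using input length as the difficulty proxy.

    Stage k contains the shortest (k + 1) / num_stages fraction of examples, so later
    stages keep revisiting easy cases while adding harder ones.
    """
    ordered = sorted(examples, key=len)
    for stage in range(num_stages):
        cutoff = max(1, round(len(ordered) * (stage + 1) / num_stages))
        yield ordered[:cutoff]

data = [[3, 1, 2], [5], [9, 8, 7, 6, 5], [4, 2], [1, 2, 3, 4]]
for pool in curriculum(data, 3):
    print(len(pool), "examples, longest input:", max(len(x) for x in pool))
```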

Scratchpads

Providing a separate space for the model to store intermediate results can facilitate longer computations.
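As an illustration of the scratchpad idea, the sketch below runs insertion sort and writes out the array after every outer-loop step, followed by the final answer. The line format is hypothetical; the point is that a model trained on such text learns to produce the intermediate states instead of jumping straight to the result.

```python
def insertion_sort_with_scratchpad(values):
    """Sort `values`, recording the array after each outer-loop step as a scratchpad line."""
    arr = list(values)
    lines = [f"start: {arr}"]
    for i in range(1, len(arr)):
        key = arr[i]
        j = i - 1
        # Shift larger elements one slot right to make room for `key`.
        while j >= 0 and arr[j] > key:
            arr[j + 1] = arr[j]
            j -= 1
        arr[j + 1] = key
        lines.append(f"step {i}: {arr}")
    lines.append(f"answer: {arr}")
    return "\n".join(lines)

print(insertion_sort_with_scratchpad([3, 1, 2]))
```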

Tool use and multimodality

Another way to improve reasoning in LLMs is to teach them to use external tools, like algorithms or APIs. This approach has led to some success in reasoning tasks.

Arguably, most of the major successes of reasoning with LLMs [18, 25, 29] can primarily be attributed to an LLM’s clever usage of a tool rather than the LLM itself, as a tool will by definition not have issues in generalising to diverse inputs.

However, the authors are interested in understanding the reasoning abilities of LLMs themselves, so they do not allow tool use in their baselines. They envision the NAR as an “internal tool,” helping the Transformer reason by providing it with more robust embeddings.

TransNAR

TransNAR is a hybrid architecture that combines Transformers and NARs. It's like giving a Transformer a "superpower" to access a specialized reasoning engine. Here's a breakdown:

TransNAR (Transformer + NAR)

The paper explains TransNAR using a series of equations, but let's break it down in plain English:

Input

The model takes two inputs:

  1. Textual input: The algorithmic problem described in natural language.
  2. Graph input: The same problem represented as a graph, which the NAR can understand.
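The sketch below builds both inputs for one hypothetical sorting problem. The prompt wording and the graph encoding (one node per element carrying its value and normalized position, with fully connected edges) are assumptions loosely modeled on how CLRS-style benchmarks represent arrays, not the exact CLRS-Text or CLRS-30 format.

```python
def make_dual_input(values):
    """Build a (text, graph) pair for one hypothetical sorting instance.

    The text is what the Transformer reads; the graph is what the NAR consumes.
    """
    text = f"Sort the list in ascending order: {values}"
    n = len(values)
    # One node per list element, with its value and position scaled to [0, 1).
    nodes = [{"value": v, "pos": i / n} for i, v in enumerate(values)]
    # Fully connected directed edges without self-loops, so any element can message any other.
    edges = [(i, j) for i in range(n) for j in range(n) if i != j]
    return text, {"nodes": nodes, "edges": edges}

text, graph = make_dual_input([0.5, 0.1, 0.9])
print(text)
print(graph["nodes"])
print(len(graph["edges"]), "edges")
```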

Transformer Processing

The Transformer processes the textual input, creating token embeddings that represent the meaning of the words.

NAR Processing

The NAR processes the graph input, creating node and edge embeddings that capture the structure and relationships of the problem.

Cross-Attention

The Transformer accesses the NAR's embeddings through a cross-attention mechanism: token representations act as queries over the NAR's embeddings, so the Transformer can draw on the NAR's computations when generating its answer.
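A minimal single-head version of this cross-attention step is sketched below in plain Python. Token embeddings provide the queries and NAR node embeddings provide the keys and values, so every token receives a weighted mix of node information. The projection matrices, the identity weights in the example, and the omission of multiple heads, residual connections and feed-forward layers are all simplifications; this is not the paper's implementation.

```python
import math

def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def cross_attention(tokens, nodes, Wq, Wk, Wv):
    """Single-head cross-attention: tokens give queries, NAR node embeddings give keys and values."""
    keys = [matvec(Wk, h) for h in nodes]
    values = [matvec(Wv, h) for h in nodes]
    d_k = len(keys[0])
    outputs = []
    for t in tokens:
        q = matvec(Wq, t)
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in keys]
        weights = softmax(scores)
        # Each token's output is an attention-weighted average of the node values.
        outputs.append([sum(w * v[i] for w, v in zip(weights, values)) for i in range(len(values[0]))])
    return outputs

I = [[1.0, 0.0], [0.0, 1.0]]     # identity projections, for illustration only
tokens = [[1.0, 0.0], [0.0, 1.0]]  # two token embeddings
nodes = [[5.0, 0.0], [0.0, 5.0]]   # two NAR node embeddings
print(cross_attention(tokens, nodes, I, I, I))
```

Here the first token's query aligns with the first node's key, so most of its attention weight (and most of its output) comes from that node.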

Output

The Transformer uses the combined information from its own processing and the NAR to generate a textual response to the algorithmic problem.

Experiments

The authors conducted a series of experiments to evaluate TransNAR's performance. Here's a breakdown of their methodology:

Transformer architecture and initialisation

The experiments used a decoder-only Transformer model from the Chinchilla family, pre-trained on a massive text dataset. They evaluated the model with both pre-trained and randomly initialized weights.

Randomized positional encoding

Previous research has shown that randomized positional encoding can improve generalization in Transformers. The authors also used this technique in their experiments.


Pre-training the NAR

They pre-trained a multi-task MPNN-based NAR on the CLRS-30 benchmark, using graph inputs up to a size of 16. This pre-training was intended to let the NAR generalize to graphs larger than those seen during training.

Combining cross-attention contributions from nodes and edges

The NAR generated both node and edge embeddings, which were combined through concatenation and a linear layer before being used in cross-attention.
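The concatenate-then-project step can be sketched as follows. How edge embeddings are pooled into a single vector is not described here, so `edge_vec` is simply assumed to be given; the weights are hypothetical values chosen so the result is easy to check by hand.

```python
def linear(W, b, x):
    """Affine map y = W x + b, with W given as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) + bi for row, bi in zip(W, b)]

def combine_node_edge(node_vec, edge_vec, W, b):
    """Concatenate node- and edge-derived vectors, then project back to the model width."""
    return linear(W, b, node_vec + edge_vec)  # list `+` concatenates the two feature vectors

# Hypothetical 2x4 weights that simply add the node and edge features together.
W = [[1, 0, 1, 0], [0, 1, 0, 1]]
print(combine_node_edge([1, 2], [3, 4], W, [0, 0]))  # [4, 6]
```

In a trained model `W` and `b` are learned, so the layer can weight node and edge information however the task requires.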


Datasets

The experiments used the CLRS-Text dataset, which is a text-based version of the CLRS-30 benchmark. It consists of 2.4 million data points across various algorithmic tasks.

Training details

The models were trained for seven epochs with a batch size of 256, using the Adam optimizer and a learning rate of 10⁻⁴. Randomized positional encoding with a maximum length of 8,192 was used in addition to the base Chinchilla Transformer's Rotary Positional Encoding.
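For a sense of scale, the reported hyperparameters can be collected into a small config. The field names are hypothetical, and the step count assumes, purely for illustration, that all 2.4 million CLRS-Text data points are used for training: under that assumption 2,400,000 / 256 = 9,375 steps per epoch, or 65,625 steps over seven epochs.

```python
import math
from dataclasses import dataclass

@dataclass(frozen=True)
class TrainConfig:
    # Values as reported in the article; the field names are hypothetical.
    epochs: int = 7
    batch_size: int = 256
    learning_rate: float = 1e-4
    randomized_pos_max_len: int = 8192

def total_steps(cfg: TrainConfig, num_examples: int) -> int:
    """Optimizer steps over all epochs, counting a final partial batch as one step."""
    return cfg.epochs * math.ceil(num_examples / cfg.batch_size)

cfg = TrainConfig()
print(total_steps(cfg, 2_400_000))  # 65625
```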

Evaluation metrics

Instead of relying solely on exact string matching, the authors used three metrics to capture different aspects of model performance:

Shape score

Checks if the output has the correct shape (e.g., a list of the correct length, a matrix with the correct dimensions).

Parse score

Checks if the output is syntactically valid (e.g., contains only numbers when a numerical result is expected).

CLRS score

The percentage of elements in the output that match the ground truth answer.
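The three metrics can be sketched for list-valued outputs as below. The bracketed, space-separated output format, the exact-equality element match, and the choice to give a CLRS score of 0 when the output fails to parse or has the wrong shape are assumptions for illustration; the paper's own scoring code may differ (for example, by comparing floats with a tolerance).

```python
def shape_score(pred, target):
    """1.0 if the prediction is a list with the same length as the target, else 0.0."""
    return float(isinstance(pred, list) and len(pred) == len(target))

def parse_output(text):
    """Parse a string like '[1 2 3]' into floats; return None if it is not a valid numeric list."""
    body = text.strip()
    if not (body.startswith("[") and body.endswith("]")):
        return None
    try:
        return [float(tok) for tok in body[1:-1].split()]
    except ValueError:
        return None

def parse_score(text):
    return float(parse_output(text) is not None)

def clrs_score(text, target):
    """Fraction of elements matching the target; 0.0 if the output doesn't parse or has the wrong shape."""
    pred = parse_output(text)
    if pred is None or not shape_score(pred, target):
        return 0.0
    if not target:
        return 1.0
    matches = sum(p == t for p, t in zip(pred, target))
    return matches / len(target)

target = [1, 2, 3, 4]
for output in ["[1 2 5 4]", "[1 2 x 4]", "[1 2 3]"]:
    pred = parse_output(output)
    print(output, parse_score(output), shape_score(pred, target), clrs_score(output, target))
```

For a target of [1, 2, 3, 4], the output "[1 2 5 4]" parses and has the right shape, and three of its four elements match, so its CLRS score is 0.75.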

Results

TransNAR consistently outperformed the baseline Transformer across various algorithmic tasks, particularly in out-of-distribution scenarios.

  • The CLRS score showed a significant improvement for TransNAR compared to the baseline. This improvement was consistent across both pre-trained and randomly initialized Transformer settings.
  • The shape score demonstrated that TransNAR was much better at producing outputs with the correct shape, indicating it addressed a common failure mode of Transformers.
  • The parse score showed that TransNAR's outputs were generally well-formed, so it rarely made simple syntax errors.

However, the authors identified a few algorithms where TransNAR did not outperform the baseline: Binary Search, Find Maximum Subarray, Minimum, and Quickselect. These algorithms involve searching for a specific index within an input list.

A closer look at the results indicates that such tasks (Binary Search, Find Maximum Subarray, Minimum, and Quickselect) all involve an element of searching for a particular index in an input list. This hints at a unified failure mode: as these failures persist both when interpolating and extrapolating, the model as implemented is not able to generalise to novel index boundaries unseen in the training data. We therefore suspect that the use of index hints— as already demonstrated by Zhou et al. [40]—is a promising avenue for ameliorating this behaviour.

The authors suggested two potential reasons for these failures:

Difficulty decoding NAR outputs

The final hidden states from the NAR might be difficult for the cross-attention layers to decode in a generalizable way. This could be addressed by increasing the capacity of cross-attention or using a more progressive decoding scheme.

Lack of index hints

The model might struggle to generalize to unseen index boundaries. Using index hints as a way to guide the model's reasoning could potentially improve performance.
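One simple way to make positions explicit is to tag each list element with its index, so that an answer such as "index 5" can point at a visible marker instead of requiring the model to count. The `i0:` format below is a hypothetical one invented for this illustration, not taken from Zhou et al. [40] or from the TransNAR paper.

```python
def with_index_hints(values):
    """Prefix every list element with an explicit position marker, e.g. [7, 3] -> 'i0:7 i1:3'."""
    return " ".join(f"i{idx}:{v}" for idx, v in enumerate(values))

def strip_index_hints(text):
    """Recover the plain integer values from a hinted string (inverse of with_index_hints)."""
    return [int(tok.split(":", 1)[1]) for tok in text.split()]

hinted = with_index_hints([7, 3, 9])
print(hinted)  # i0:7 i1:3 i2:9
print(strip_index_hints(hinted))
```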

Limitations

While promising, TransNAR has some limitations:

Dual input requirement

TransNAR requires both textual and graph inputs to be trained and used effectively. This limits its applicability to cases where a graph-based representation of the problem is available.


Business implications

TransNAR's ability to perform robust algorithmic reasoning in natural language could have significant business implications, especially for industries that rely heavily on complex algorithms and data analysis. Here are some potential applications:

Automated code generation

TransNAR could be used to generate code from natural language descriptions of algorithms, reducing the time and effort required for software development.

Financial modeling

TransNAR could assist financial analysts in building complex financial models, analyzing market data, and making investment decisions.

Healthcare diagnostics

TransNAR could help analyze structured patient data and clinical notes to assist in diagnosis and treatment planning.

Personalized learning

TransNAR could be used to create personalized learning experiences, tailoring the difficulty and content of educational materials to individual students' needs.

Conclusion

TransNAR represents a significant step forward in the field of neural algorithmic reasoning. The model demonstrates that combining Transformers with NARs can lead to improved performance on algorithmic tasks, particularly in out-of-distribution scenarios. The results highlight the importance of combining different types of representations (textual and graph) to enable robust reasoning in language models.

How could it change LLMs?

If implemented successfully, TransNAR could significantly change the capabilities of large language models (LLMs). Here's how:

Improved reasoning skills

TransNAR could equip LLMs with the ability to understand and execute complex algorithms, unlocking new possibilities for problem-solving and decision-making.

Enhanced generalizability

By leveraging the generalization capabilities of NARs, TransNAR could help LLMs overcome their inherent brittleness in out-of-distribution scenarios, leading to more robust and reliable performance.

Unlocking new applications

LLMs equipped with TransNAR's reasoning abilities could be applied to a broader range of tasks, including code generation, financial modeling, healthcare, and education, leading to new innovations and advancements in these industries.

While TransNAR is still under development and faces some challenges, it holds enormous potential for shaping the future of LLMs and their impact on our world.
