Large Language Models (LLMs) have achieved impressive results in various natural language tasks. However, when faced with complex problems, they often fall short of human-like reasoning ability. Imagine asking an LLM a tricky math problem. If it gets the answer wrong, can we give it more time to "think" and improve its response? This paper explores exactly that.
Scaling up test-time compute, i.e., giving the LLM more computational resources during inference, can enhance its performance on challenging prompts. The authors aim to answer the crucial question:
if an LLM is allowed to use a fixed but non-trivial amount of inference-time compute, how much can it improve its performance on a challenging prompt?
The answer to this question has significant implications for:
- Achievable performance of LLMs: Can we unlock a higher level of accuracy by allowing models to "think harder"?
- Future of LLM pretraining: Should we focus on training larger models or shift our attention to more sophisticated inference-time strategies?
- Tradeoff between inference-time and pretraining compute: Can we achieve similar results with a smaller model and more test-time compute compared to a larger pretrained model?
This research is particularly crucial as it challenges the prevailing focus on scaling up model size and pretraining compute, potentially opening new avenues for LLM development and deployment.
The paper builds on existing research exploring the use of test-time compute to improve LLM outputs. Previous studies have reported mixed results, with some showing promise on specific tasks and others highlighting limitations in more complex reasoning scenarios.
One common approach is best-of-N sampling: sampling N outputs in "parallel" from a base LLM and selecting the one that scores highest according to a learned verifier or reward model. While simple, this approach might not be the most efficient way to utilize test-time compute. The paper argues that by modifying how LLMs generate responses and evaluate their correctness, we can significantly enhance the effectiveness of scaling test-time compute.
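To make the baseline concrete, here is a minimal sketch of best-of-N selection. The `generate` and `verifier_score` callables are hypothetical stand-ins for the base LLM and the learned verifier, and the "weighted" variant (summing verifier scores over samples that share a final answer) is one common aggregation; this illustrates the pattern rather than the paper's implementation.

```python
from collections import defaultdict
from typing import Callable, List


def best_of_n(prompt: str,
              generate: Callable[[str], str],
              verifier_score: Callable[[str, str], float],
              n: int = 16,
              weighted: bool = True) -> str:
    """Sample N candidate answers and pick one using a learned verifier."""
    candidates: List[str] = [generate(prompt) for _ in range(n)]
    scores = [verifier_score(prompt, c) for c in candidates]

    if not weighted:
        # Plain best-of-N: return the single highest-scoring sample.
        return max(zip(candidates, scores), key=lambda cs: cs[1])[0]

    # "Weighted" best-of-N: sum verifier scores over samples that share the
    # same final answer, then return the answer with the largest total.
    totals = defaultdict(float)
    for cand, score in zip(candidates, scores):
        final_answer = cand.strip().splitlines()[-1]  # crude final-answer extraction
        totals[final_answer] += score
    return max(totals, key=totals.get)
```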
The authors propose a unified perspective on test-time computation by framing it as a process of modifying the LLM's predicted distribution adaptively at test-time, conditioned on a given prompt. This modification can be achieved through two primary mechanisms:
The proposer aims to improve the proposal distribution, that is, the distribution from which the LLM samples its responses. Modifying the proposal distribution allows for a more targeted search for the correct answer. This can be achieved through techniques like sequential revisions, i.e., instructing the model to critique and revise its own outputs in an iterative fashion.
The verifier evaluates the quality of the proposals generated by the proposer. Instead of simply using an outcome-level reward model, the authors explore the use of process-based verifiers, also known as Process Reward Models (PRMs), which predict the correctness of each intermediate step in a solution rather than just the final answer. By evaluating individual steps, PRMs can guide the search process more effectively, leading to more efficient utilization of test-time compute.
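As an illustration of how a process-based verifier is used, the sketch below scores every prefix of a step-by-step solution with a PRM and collapses the per-step scores into one solution-level score. The `prm_score_step` interface and the aggregation options are assumptions for illustration; common choices include trusting the final step's score or taking the minimum over steps.

```python
from typing import Callable, List


def score_solution(question: str,
                   steps: List[str],
                   prm_score_step: Callable[[str, List[str]], float],
                   aggregate: str = "last") -> float:
    """Aggregate per-step PRM scores into one solution-level score.

    prm_score_step(question, steps_so_far) is a hypothetical interface to a
    process reward model that rates the latest step given everything so far.
    """
    step_scores = [prm_score_step(question, steps[:i + 1])
                   for i in range(len(steps))]
    if aggregate == "last":   # trust the PRM's estimate at the final step
        return step_scores[-1]
    if aggregate == "min":    # a solution is only as strong as its weakest step
        return min(step_scores)
    raise ValueError(f"unknown aggregation: {aggregate}")
```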
The effectiveness of test-time compute strategies is not uniform across all problems. Some problems might be "easier" for the base LLM to solve, while others require a more extensive search process. This leads to the concept of compute-optimal scaling: for a given test-time strategy, choose the hyper-parameters that yield the maximal performance benefit on a given prompt at test time.
To implement a compute-optimal scaling strategy, we need to estimate the difficulty of a given prompt. To achieve this, the authors bin questions into five difficulty levels based on the base model's pass@1 rate. Specifically, the model's pass@1 rate on each question in the test set, estimated from 2,048 samples, is binned into five quantiles corresponding to increasing difficulty levels. This notion of question difficulty acts as a crucial factor in determining the optimal allocation of test-time compute.
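As a rough sketch of the binning step (assuming pass@1 rates have already been estimated per question; the quantile logic below is an illustration, not the paper's code):

```python
import numpy as np


def assign_difficulty_levels(pass_at_1: np.ndarray, n_bins: int = 5) -> np.ndarray:
    """Map per-question pass@1 rates to difficulty levels 1 (easiest) .. n_bins (hardest).

    pass_at_1[i] is the fraction of samples (e.g. out of 2048) that solved question i.
    """
    # Quantile edges over the observed pass@1 distribution.
    edges = np.quantile(pass_at_1, np.linspace(0.0, 1.0, n_bins + 1))
    # Index each question by which quantile its pass@1 falls into.
    bins = np.clip(np.digitize(pass_at_1, edges[1:-1], right=True), 0, n_bins - 1)
    # High pass@1 means the question is easy, so invert to make n_bins the hardest.
    return n_bins - bins
```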
The paper focuses its analysis on the challenging MATH benchmark, which consists of high-school competition-level math problems. The PaLM 2-S* (Codey) model is used as the base LLM; it provides a good test-bed due to its non-trivial performance on MATH and its clear room for further improvement.
The authors trained PRMs using a method that bypasses the need for expensive human labels. They used Monte Carlo rollouts to estimate per-step correctness, similar to recent work in the field.
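A minimal sketch of this kind of rollout-based labeling, assuming hypothetical `rollout` and `is_correct` interfaces to the base model and an answer checker (the rollout count and labeling scheme are illustrative):

```python
from typing import Callable, List


def label_steps_with_rollouts(question: str,
                              steps: List[str],
                              rollout: Callable[[str, List[str]], str],
                              is_correct: Callable[[str, str], bool],
                              n_rollouts: int = 16) -> List[float]:
    """Estimate a soft correctness label for every prefix of a solution.

    For each prefix, sample completions from the base model (via the
    hypothetical `rollout` interface) and record the fraction that reach a
    correct final answer (checked by `is_correct`). These fractions can then
    serve as PRM training targets without human step-level annotation.
    """
    labels: List[float] = []
    for i in range(1, len(steps) + 1):
        prefix = steps[:i]
        hits = sum(is_correct(question, rollout(question, prefix))
                   for _ in range(n_rollouts))
        labels.append(hits / n_rollouts)
    return labels
```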
Three search approaches were compared:
- Best-of-N Weighted: sampling N answers and selecting the best one based on the PRM's judgment of the final answer.
- Beam Search: optimizing against the PRM by searching over its per-step predictions, effectively exploring different solution paths (a minimal sketch follows this list).
- Lookahead Search: an extension of beam search that uses lookahead rollouts to improve the accuracy of the PRM's value estimates during the search.
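For intuition, here is a compact sketch of beam search guided by a PRM. The `propose_steps`, `prm_score`, and `is_finished` callables are hypothetical stand-ins for the base LLM, the process reward model, and a stopping check; lookahead search would differ mainly in replacing the prefix score with the score obtained after rolling the prefix out a few additional steps.

```python
from typing import Callable, List, Tuple


def prm_beam_search(question: str,
                    propose_steps: Callable[[str, List[str]], List[str]],
                    prm_score: Callable[[str, List[str]], float],
                    is_finished: Callable[[List[str]], bool],
                    beam_width: int = 4,
                    max_steps: int = 10) -> List[str]:
    """Search over solution steps, keeping the prefixes the PRM rates highest.

    propose_steps(question, prefix) samples a handful of candidate next steps
    from the base LLM; prm_score(question, prefix) rates a partial solution.
    Both are hypothetical stand-ins for model calls.
    """
    beams: List[Tuple[List[str], float]] = [([], 0.0)]
    for _ in range(max_steps):
        candidates: List[Tuple[List[str], float]] = []
        for prefix, score in beams:
            if is_finished(prefix):
                candidates.append((prefix, score))  # keep completed solutions as-is
                continue
            for step in propose_steps(question, prefix):
                new_prefix = prefix + [step]
                candidates.append((new_prefix, prm_score(question, new_prefix)))
        # Keep only the highest-scoring prefixes for the next round.
        candidates.sort(key=lambda item: item[1], reverse=True)
        beams = candidates[:beam_width]
        if all(is_finished(prefix) for prefix, _ in beams):
            break
    return beams[0][0]  # highest-scoring solution found
```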
The analysis revealed that:
- Beam search outperforms best-of-N at low generation budgets, but the improvement diminishes as the budget increases, highlighting a potential for over-optimization.
- Lookahead search generally underperforms other methods due to the additional compute cost of simulating lookahead rollouts.
- The effectiveness of search methods varies significantly across difficulty levels. Beam search is more effective on harder questions, while best-of-N performs better on easier questions and at higher generation budgets.
These findings led to the implementation of a compute-optimal search strategy based on question difficulty, which significantly outperformed best-of-N while using 4x less test-time compute.
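A toy illustration of what such a difficulty-aware dispatch could look like; the thresholds and method names below are made up for the example and would have to be fit empirically, as in the paper:

```python
def choose_search_strategy(difficulty_level: int, budget: int) -> dict:
    """Toy difficulty-aware dispatch; thresholds and names are illustrative only."""
    if difficulty_level <= 2 or budget >= 64:
        # Easier questions / large budgets: best-of-N weighted holds up well.
        return {"method": "best_of_n_weighted", "n": budget}
    # Harder questions at modest budgets: spend the budget on beam search.
    return {"method": "prm_beam_search", "beam_width": 4, "max_steps": budget // 4}
```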
We can improve the proposal distribution by enabling the LLM to revise its own answers iteratively, thereby refining its response over time.
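A minimal sketch of this sequential-revision loop, with `generate` as a hypothetical interface to a revision-tuned model that conditions on its earlier attempts; in practice the resulting chain would then be fed to a verifier or majority vote to pick the final answer.

```python
from typing import Callable, List


def sequential_revisions(prompt: str,
                         generate: Callable[[str, List[str]], str],
                         n_revisions: int = 4) -> List[str]:
    """Produce a chain of answers, each conditioned on the earlier attempts.

    generate(prompt, previous_attempts) is a hypothetical interface to a
    revision-tuned model; the returned chain is later scored by a verifier
    or resolved by majority voting.
    """
    attempts: List[str] = []
    for _ in range(n_revisions):
        attempts.append(generate(prompt, attempts))
    return attempts
```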
Revision models are trained using a modified version of an existing approach that finetunes on trajectories of incorrect answers followed by a correct answer. The authors use a character edit distance metric to pair each correct answer with incorrect answers that are actually correlated with it, which strengthens the learning signal.
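A sketch of how such pairing could be done, using difflib's similarity ratio as a stand-in for a character edit distance (the trajectory format and the cutoff `k` are illustrative choices, not the paper's exact recipe):

```python
import difflib
from typing import List, Tuple


def build_revision_trajectory(correct: str,
                              incorrect: List[str],
                              k: int = 3) -> Tuple[List[str], str]:
    """Build one training trajectory: a few incorrect attempts, then the correct answer.

    Incorrect answers are ranked by similarity to the correct one (difflib's
    ratio stands in here for a character edit distance), so the model learns
    revisions that are actually related to the answer they lead to.
    """
    ranked = sorted(
        incorrect,
        key=lambda attempt: difflib.SequenceMatcher(None, attempt, correct).ratio(),
        reverse=True,
    )
    return ranked[:k], correct
```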
The key findings include:
- Sequential revisions (proposing answers sequentially, each based on previous attempts) outperform parallel sampling (generating N independent answers) when selecting the best answer using either a verifier or majority voting.
- An ideal ratio exists between sequential and parallel compute, balancing the benefits of local refinement and global search. This optimal ratio varies depending on the difficulty of the question.
- By implementing a compute-optimal revision strategy based on question difficulty, the authors achieved significant performance improvements, surpassing the parallel best-of-N baseline while using significantly less compute.
Suppose a model was pre-trained with X FLOPs. Assume that we plan to run Y FLOPs of inference with this model. If we want to improve performance by increasing the total FLOPs budget by a factor of M (i.e., M(X + Y) total FLOPs across both pretraining and inference), should we spend our FLOPs on increased pretraining compute or on additional test-time compute?
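One way to reason about this tradeoff, under the common approximations of roughly 6·N·D FLOPs for pretraining and 2·N FLOPs per token processed at inference (these approximations, and all numbers below, are assumptions made purely for illustration):

```python
def total_flops(n_params: float, pretrain_tokens: float, inference_tokens: float) -> float:
    """Total FLOPs under common approximations: ~6*N*D for pretraining and
    ~2*N per token at inference (an assumption, not an exact accounting)."""
    return 6 * n_params * pretrain_tokens + 2 * n_params * inference_tokens


# Illustrative numbers only: a small model vs. a ~14x larger one, same data budget.
small_n, big_n = 10e9, 140e9   # parameters
pretrain_d = 1e12              # pretraining tokens
inference_d = 1e9              # total tokens expected to be served at inference

budget = total_flops(big_n, pretrain_d, inference_d)   # cost of the big model
# Inference tokens the small model could afford under the same total budget:
extra_inference = (budget - 6 * small_n * pretrain_d) / (2 * small_n)
print(f"the smaller model can spend ~{extra_inference / inference_d:.0f}x "
      f"the inference tokens for the same total FLOPs")
```

The smaller the expected inference load relative to the pretraining budget, the more test-time compute the smaller model can afford at parity, which is exactly the ratio the analysis turns on.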
The analysis reveals that the answer depends on the difficulty of the question and on the ratio of inference tokens to pretraining tokens: on easier and medium-difficulty questions, or when the expected inference load is small, additional test-time compute can substitute for a larger pretrained model; on the hardest questions, or under heavy inference loads, scaling up pretraining remains the more effective use of FLOPs.
This finding suggests a potential shift in focus from solely scaling up pretraining compute towards a more balanced approach that leverages the power of test-time compute, especially for tasks within a model's capabilities.
This paper's findings have profound implications for businesses deploying and developing LLMs:
With more test-time compute, models have more tokens to "think" with, and hence they can reason about complex problems that were previously deemed beyond their scope.
The research suggests a shift in focus from solely increasing model size towards developing more sophisticated test-time compute strategies. This necessitates investing in research and development of techniques like sequential revisions, advanced search methods against PRMs, and robust methods for assessing question difficulty.
Recognizing the interplay between question difficulty, pretraining, and test-time compute allows for developing tailored solutions for specific business needs. For tasks with predictable difficulty levels, we can optimize model selection and resource allocation accordingly, maximizing efficiency and performance.
This paper presents a thorough and insightful analysis of scaling LLM test-time compute, highlighting its potential to significantly enhance model performance on reasoning tasks. The introduction of compute-optimal scaling strategies based on question difficulty represents a significant step towards more efficient utilization of test-time compute. While the research demonstrates the promising potential of test-time compute as a substitute for pretraining compute in certain scenarios, it also acknowledges its limitations, particularly with challenging problems. This necessitates further research and development of more advanced test-time strategies to fully unlock the potential of LLMs for complex reasoning tasks.
Test-time compute allows for exploring a wider range of potential solutions, going beyond the limitations of a single forward pass through the LLM. Methods like beam search and sequential revisions enable the model to explore different solution paths and iteratively refine its response.
Process-based verifiers like PRMs provide a more granular evaluation of the reasoning process, guiding the search towards more promising solutions. By evaluating individual steps, PRMs can identify errors early on and prevent the model from going down incorrect paths.
Recognizing that the efficacy of test-time compute strategies varies with question difficulty allows for a more targeted allocation of resources. By implementing compute-optimal strategies based on difficulty, we can maximize the benefits of test-time compute while minimizing unnecessary computation.
Essentially, test-time compute scaling provides the LLM with the opportunity to "think harder" about a problem, iteratively refine its approach, and leverage more sophisticated evaluation mechanisms to arrive at a more accurate solution. As test-time strategies continue to evolve, we can expect even greater improvements in LLM reasoning capabilities.