Transformer-based language models think in terms of tokens. By design, the architecture allocates and spreads FLOPs uniformly across the input tokens, which is part of what makes language modelling so computationally expensive. If compute is to a language model what time and effort are to a person solving a problem, then a standard transformer expends the same amount of compute on every token in a forward pass. That is overkill, because not all problems need the same effort. The Mixture of Depths (MoD) paper addresses this by dynamically allocating compute based on token-level decisions, allowing the model to focus on the most relevant parts of the input.
Conditional computation is a technique that tries to reduce total compute by expending it only where it is needed. Various algorithms offer solutions but fall short because of hardware constraints: they rely on dynamic computation graphs, whereas the hardware wants computation graphs to be static and tensor sizes to be known in advance to maximize utilization.
Say we have a static compute budget for an LLM. The network must learn how to allocate that budget dynamically by making per-token decisions, in each layer, about where to spend compute. Because the total compute is known ahead of time, the efficiency of a static graph can be exploited without sacrificing performance.
This paper uses an approach similar to Mixture of Experts (MoE), in which dynamic token-level routing decisions are made across the network depth. The twist is that the model chooses either to apply a computation to a token (as a standard transformer would) or to pass it through a residual connection (leaving it unchanged and saving compute). Also in contrast to MoE, this routing is applied to both the feed-forward MLPs (multilayer perceptrons) and the multi-head attention. The router therefore decides not only which tokens are updated, but also which tokens are made available to attend to. This is the basic idea behind Mixture of Depths.
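To make the mechanism concrete, here is a minimal PyTorch sketch of an MoD-style block, under my own assumptions (the names `MoDBlock`, `block`, and `capacity`, and the exact handling of the output scaling, are illustrative rather than the paper's implementation): a scalar router weight is produced for every token, the top-k tokens are run through the wrapped attention-plus-MLP computation, and the remaining tokens ride the residual stream untouched.

```python
import torch
import torch.nn as nn

class MoDBlock(nn.Module):
    """Illustrative Mixture-of-Depths block: only the top-k tokens (by router
    weight) pass through the wrapped transformer computation; the rest flow
    through the residual connection unchanged."""

    def __init__(self, d_model: int, block: nn.Module, capacity: float = 0.125):
        super().__init__()
        self.router = nn.Linear(d_model, 1)   # one scalar routing weight per token
        self.block = block                    # e.g. a pre-norm attention + MLP sub-block
        self.capacity = capacity              # fraction of tokens that get processed

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, seq_len, d_model]
        b, s, d = x.shape
        k = max(1, int(self.capacity * s))                   # fixed ahead of time

        weights = self.router(x).squeeze(-1)                 # [b, s]
        topk_w, topk_idx = torch.topk(weights, k, dim=1)     # expert-choice: the block picks k tokens
        topk_idx, order = torch.sort(topk_idx, dim=1)        # keep positional order for causal attention
        topk_w = torch.gather(topk_w, 1, order)

        # Only the selected tokens are updated, and (for the attention part)
        # only they are available to attend to inside this block.
        gather_idx = topk_idx.unsqueeze(-1).expand(-1, -1, d)
        selected = torch.gather(x, 1, gather_idx)            # [b, k, d]
        processed = self.block(selected)

        # Scale by the router weight so the router sits on the gradient path,
        # then add the updates back into the residual stream.
        out = x.clone()
        out.scatter_add_(1, gather_idx, processed * topk_w.unsqueeze(-1))
        return out
```

With `capacity = 0.125`, only one token in eight pays the cost of the block in that layer; the paper finds it works well to interleave such routed blocks with ordinary dense ones rather than routing every layer.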
This approach also allows performance to be traded off against speed. MoD transformers achieve better log probability (about 1.5% higher) when trained for the same wall-clock time as a dense transformer. Alternatively, an MoD transformer can reach training-loss parity with an isoFLOP-optimal vanilla transformer while using a fraction of the FLOPs (upwards of 50% fewer) per forward pass, and is therefore faster to step.
A wide variety of recent work has developed conditional computation methods for transformers. Some of this work focuses on "early exiting", that is, learning to decide when to end computation on a given token, allowing the token to skip any remaining transformer layers after the exit decision is made. In MoD, unlike in early-exit methods, a token can skip middle layers and then be updated via self-attention with tokens that have gone through all the middle layers. One successful formulation of conditional computation is the "mixture-of-experts" (MoE) layer. Unlike other conditional computation approaches that try to conserve or expend additional compute, MoE transformers use conditional logic to route tokens to one of many expert MLPs while keeping total compute expenditure constant. The mixture-of-depths method can be thought of as using the routing logic from MoE transformers, but rather than having multiple experts, MoD deploys a single expert which can be dynamically skipped.
It is the token capacity that determines the total FLOPs for transformers that use conditional computation, rather than the outcomes of any routing decisions. This is because static-graph implementations account for the worst case: a computation's inputs are padded up to its capacity even if relatively few tokens actually route to it, and tokens are dropped from the computation if the capacity is exceeded.
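A quick back-of-the-envelope sketch (my own illustrative sizes and approximate FLOP formulas, not numbers from the paper) makes the point: once the capacity k is fixed, the per-block cost is pinned down before any routing decision is made.

```python
def block_flops(n_tokens: int, d_model: int, d_ff: int) -> float:
    """Rough FLOP count for one transformer block processing n_tokens tokens."""
    attn_proj = 8 * n_tokens * d_model * d_model        # Q, K, V and output projections
    attn_mat  = 4 * n_tokens * n_tokens * d_model       # QK^T and attention-weighted sum
    mlp       = 4 * n_tokens * d_model * d_ff           # two feed-forward matmuls
    return attn_proj + attn_mat + mlp

seq_len, d_model, d_ff = 2048, 1024, 4096
capacity = 0.125
k = int(capacity * seq_len)          # 256 tokens per block, fixed before seeing any data

dense = block_flops(seq_len, d_model, d_ff)
routed = block_flops(k, d_model, d_ff)
print(f"dense block: {dense / 1e9:.1f} GFLOPs")
print(f"routed block at 12.5% capacity: {routed / 1e9:.1f} GFLOPs")
```

The routed block's cost depends only on k, not on which tokens the router actually selects, and that is exactly the property that lets the computation graph stay static.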
The core hypothesis is that certain tokens may not require as much compute and that can be identified through training and learning. Therefore, if the network learns to choose the right tokens to fill up its capacities, then it may preserve its performance.
Learned Routing
Learned routing employs a router network that assigns a weight to each token based on its embedding, allowing the model to learn which tokens require more processing. There are two kinds of learned routing: token-choice and expert-choice. In token-choice routing, a router produces per-token probability distributions across computational paths (e.g., across expert identities in MoE transformers). Tokens are then shuttled to the path they prefer, i.e., the one with the highest probability, and auxiliary losses ensure that all tokens don't converge to the same path. Token-choice routing can have load-balancing problems, since there is no guarantee that tokens divide themselves appropriately between the possible paths. Expert-choice routing flips this recipe on its head: rather than having tokens choose the path they prefer, each path chooses the top-k tokens based on the tokens' preferences. This ensures a perfect load balance, since exactly k tokens are guaranteed to be shuttled to each path. However, it can result in over- or under-processing of some tokens, since a token may be among the top-k for multiple paths, or for none of them.
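A toy comparison of the two selection rules (sizes and scores are arbitrary, purely for illustration):

```python
import torch

torch.manual_seed(0)
n_tokens, n_paths, k = 8, 2, 4              # 8 tokens, 2 paths, capacity of 4 tokens per path
logits = torch.randn(n_tokens, n_paths)     # router scores: one row per token

# Token-choice: every token picks its preferred path, so load can be lopsided
# and an auxiliary balancing loss is typically needed.
token_choice = logits.argmax(dim=-1)
print("tokens per path:", torch.bincount(token_choice, minlength=n_paths).tolist())

# Expert-choice: every path picks its own top-k tokens, so load is perfectly
# balanced, but a token can be claimed by several paths or by none.
expert_choice = logits.topk(k, dim=0).indices
print("tokens chosen by each path:", expert_choice.T.tolist())
```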
Expert choice is preferable for two reasons: there is no need for an auxiliary balancing loss, and the most critical tokens can be guaranteed a place in the top-k by design, which token-choice routing cannot ensure.
Stochastic Routing
Stochastic routing randomly selects tokens for processing and serves as a baseline for comparison.
Expert-choice routing is employed, where each block selects the top-k tokens based on their router weights. This ensures a balanced distribution of tokens across computational paths and allows the model to prioritize important tokens.
While expert-choice routing has a number of advantages, it has one distinct problem: the top-k operation is non-causal. This means that whether a given token’s routing weight is among the top-k for the sequence depends on the values of the routing weights for tokens that come after it, which we don’t have access to when autoregressively sampling.
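A tiny example (router weights made up for illustration) of why the top-k is non-causal: whether an early token makes the cut can be overturned by a token that arrives later.

```python
import torch

k = 2
weights = torch.tensor([0.9, 0.2, 0.8])          # router weights for the first three tokens
print(weights.topk(k).indices.tolist())          # [0, 2]: token 2 is currently in the top-k

weights = torch.tensor([0.9, 0.2, 0.8, 0.95])    # a later token arrives with a higher weight...
print(weights.topk(k).indices.tolist())          # [3, 0]: ...and token 2 is evicted after the fact
```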
To address the non-causal nature of top-k routing during inference, an auxiliary predictor network is trained to predict whether a token will be among the top-k, enabling efficient autoregressive sampling.
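A rough sketch of the idea, under my own assumptions about shapes and names (the exact predictor architecture, inputs, and loss weighting here are placeholders, not the paper's): a small head looks only at the current token's activation, is trained with binary cross-entropy against the non-causal top-k targets, and is the thing consulted at sampling time.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

batch, seq_len, d_model = 4, 128, 512
capacity = 0.125
k = int(capacity * seq_len)

router = nn.Linear(d_model, 1)
# Small causal head: it sees only the current token, so it can be used
# step by step during autoregressive sampling.
aux_predictor = nn.Sequential(nn.Linear(d_model, 128), nn.ReLU(), nn.Linear(128, 1))

x = torch.randn(batch, seq_len, d_model)
weights = router(x).squeeze(-1)                     # [batch, seq_len]

# Targets: was this token among the (non-causal) top-k of its sequence?
targets = torch.zeros_like(weights)
targets.scatter_(1, weights.topk(k, dim=1).indices, 1.0)

# Auxiliary loss; the detach keeps it from interfering with the main router.
logits = aux_predictor(x.detach()).squeeze(-1)
aux_loss = F.binary_cross_entropy_with_logits(logits, targets)

# At inference, the newest token is routed by thresholding the predictor alone.
route_newest_token = torch.sigmoid(logits[:, -1]) > 0.5
```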
MoD models are trained using standard language modeling objectives, with the addition of the auxiliary predictor loss for autoregressive sampling.
MoD transformers achieve equal or better performance compared to isoFLOP-optimal baseline models while requiring fewer FLOPs per forward pass and maintaining similar training times. This translates to faster inference and the potential for larger model sizes or longer training durations within the same computational budget. Compared to MoE, MoD offers additional efficiency gains by allowing tokens to bypass computations altogether, leading to a sparser allocation of resources.
Autoregressive evaluation demonstrates minimal performance degradation when switching from top-k routing to the predictor-based approach during inference.
The auxiliary predictor achieves high accuracy, ensuring efficient and accurate decoding.
MoD can be effectively integrated with MoE, and the performance and efficiency improvements from MoD compound with those of MoE. There are two variants of this combined MoDE model, staged and integrated, and both show promising results.
Staged MoDE routes tokens around or towards blocks prior to the self-attention step, while integrated MoDE implements MoD routing by adding "no-op" experts alongside the conventional MLP experts. The former is advantageous because it allows tokens to skip the self-attention step; the latter is advantageous because it simplifies the routing machinery.
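As a cartoon of the integrated variant (my own simplification: greedy token-choice routing and a dense masked loop are used purely for brevity, and all names here are illustrative), the router's set of destinations simply includes a "no-op" slot, and tokens sent there receive no MLP update at all:

```python
import torch
import torch.nn as nn

class IntegratedMoDELayer(nn.Module):
    """Toy integrated MoDE layer: routing slot 0 is a no-op, so tokens routed
    there skip the expert MLP computation entirely."""

    def __init__(self, d_model: int, d_ff: int, n_slots: int = 4):
        super().__init__()
        self.router = nn.Linear(d_model, n_slots)      # slot 0 is reserved for the no-op
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
            for _ in range(n_slots - 1)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, seq_len, d_model]
        probs = self.router(x).softmax(dim=-1)
        choice = probs.argmax(dim=-1)                        # [batch, seq_len]
        gate = probs.gather(-1, choice.unsqueeze(-1))        # weight of the chosen slot
        out = x.clone()                                      # residual stream
        # Tokens whose chosen slot is 0 keep x unchanged. A real kernel would
        # gather only the routed tokens; this dense masked loop is for clarity.
        for e, expert in enumerate(self.experts, start=1):
            mask = (choice == e).unsqueeze(-1).float()
            out = out + mask * gate * expert(x)
        return out
```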
Implementing MoDE in the integrated manner was distinctly better than simply reducing the capacity of experts in conventional MoE models.
MoD offers several potential benefits for businesses utilizing LLMs:
Fewer FLOPs per forward pass translates to faster and more cost-effective inference, particularly beneficial for real-time applications like chatbots or machine translation.
The saved compute resources can be used to train larger models with improved performance, leading to better accuracy and generalization.
Dynamic allocation ensures that resources are focused on the most critical parts of the input, leading to a more efficient use of computational power.
Today, many models are reaching their limits because of implementation and architectural issues with transformers. The significant gains from MoD suggest we could have very large models (with more than 3T parameters) that retain more knowledge from the training data and are more effective at problem solving.
MoE gave us a step-function change in capabilities. MoD might provide the next one, marking a significant step in the evolution of LLMs. Perhaps even a step closer to AGI?
In a way, this is a pathway to running a powerful language model on your smartphone or computer, given that the compute requirements are much lower than those of existing transformer models.
MoD raises interesting questions for future research:
Analyzing how the model learns to prioritize tokens for processing can provide insights into the inner workings of the model and its understanding of language.
Integrating MoD with long-term memory mechanisms could further enhance context awareness and improve performance on tasks requiring long-range dependencies.
Exploring the routing of tokens to different types of computations beyond self-attention and MLP can expand the capabilities of the model and tailor it to specific tasks.
MoD identifies tokens requiring more processing and allocates additional compute to them, while allowing simpler tokens to bypass certain computations, leading to overall efficiency gains. The model picks and chooses where to expend its compute, which is more efficient than allocating it equally across tokens.
The model learns to make intelligent routing decisions based on token embeddings, ensuring that resources are directed where they have the most impact on prediction accuracy.
By maintaining a fixed compute budget, MoD ensures hardware efficiency and avoids the challenges associated with dynamic computation graphs, making it suitable for real-world deployments.
MoD presents a significant advancement in efficient language modeling, offering a compelling alternative to traditional transformer architectures. Its ability to dynamically allocate compute resources leads to improved performance, faster inference, and greater resource efficiency, opening up new possibilities for LLM applications in various domains. Further exploration of this technique and its integration with other efficiency methods promises exciting advancements in the field of natural language processing.
Conditional computation refers to the ability of a model to dynamically adjust the amount of computation it performs based on the specific input it receives.
Think of it like this: when reading a book, you don't spend the same amount of time and effort on every word of a sentence. You might skim through easy parts, but slow down and reread complex passages. Similarly, conditional computation allows a language model to identify the most important tokens and skim through other tokens.