Over the last year, one of the most prominent use cases to emerge for Generative AI applications is augmenting an LLM's knowledge with organizational knowledge. Whether it's chatting with a PDF, chatting over your own documents, code generation, or natural language queries, most AI apps aim to help employees navigate a wall of unstructured data. Many of these apps use what is called a RAG-based approach - retrieving the most similar subset of indexed data at runtime and passing it along with the prompt - which works well to a degree, but has its limitations.
LLMs in their current form are monolithic in structure. They are trained on a large corpus of data and demonstrate non-trivial skills across domains, but it is expensive to impart new knowledge or skills that were not in the original dataset. Typical approaches are to pretrain a new model from scratch, to spend a lot of time on continual learning and hope each iteration avoids catastrophic forgetting, or, in specific cases, to efficiently finetune a large foundational model for a specific skill, e.g. code generation.
When open source models came out, the community spent a lot of time finetuning them for domain-specific skillsets. Today, you have specialized models for code generation, sales, and more, but a model excelling at code generation might be poor at reasoning.
Moreover, these models are finetuned, meaning they are still limited by the training data of the foundational LLM and fail on tasks outside of it. Further, finetuning models is computationally expensive, and processing new data may not be feasible due to privacy concerns and org boundaries.
In an ideal world:
CaLM stands for 'Composition to Augment Language Models'. It introduces a compositional technique to combine two LLMs such that:
Sounds too good to be true? Stay with me on this one.
From the paper:
In this work, we propose a novel Composition to Augment Language Models (CALM) framework to address the general model composition setting mentioned above. Rather than a shallow combination of the augmenting and anchor LMs (Wortsman et al., 2022; Ilharco et al., 2022), CALM introduces a small number of trainable parameters over both augmenting and anchor models’ intermediate layer representations. CALM finds an effective combination of the given models to perform new challenging tasks more accurately than either of the models alone, while preserving the capabilities of individual models. Figure 1 highlights few motivating scenarios for CALM.
That is, you pick an anchor language model and an augmenting model, then train a small number of parameters over the combination's intermediate layer representations to get the functionality of both models.
In practical terms, say you have an anchor model (model A, or mA) that is great at problem solving but poor at everything else (fine-tuned only on reasoning examples). Another model, our augmenting model B (or mB), is good at code generation but poor at most other things. By combining the two models using CaLM, you get a composite model (mAB) that can solve a complex coding problem and write efficient code for it. You can also combine model A with another model mC that is great at physics knowledge, and now you have a combination that can, in theory, answer physics-related questions.
For this paper, the authors worked under the following assumptions:
This is where the magic happens. The cross-attention layer is what enables the composition of capabilities in the anchor model (mA) and the augmenting model (mB): cross-attention layers let the anchor model attend to the representations of the augmenting model. The algorithm is quite neat, actually:
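Roughly, both models run their own forward passes, and a handful of new cross-attention blocks project the augmenting model's hidden states into the anchor's representation space so that selected anchor layers can attend to them, with a residual connection preserving the anchor's own representation. Here is a minimal sketch of what one such block might look like in PyTorch - my illustration, not the authors' code; `anchor_dim`, `aug_dim`, and the head count are placeholders:

```python
import torch
import torch.nn as nn

class CrossAttentionBlock(nn.Module):
    """Illustrative CALM-style composition block (not the authors' code).

    Projects the augmenting model's hidden states into the anchor model's
    dimension and lets the anchor attend to them. Only these parameters are
    trained; both base models stay frozen.
    """

    def __init__(self, anchor_dim: int, aug_dim: int, n_heads: int = 8):
        super().__init__()
        self.proj = nn.Linear(aug_dim, anchor_dim)  # map mB's states into mA's space
        self.attn = nn.MultiheadAttention(anchor_dim, n_heads, batch_first=True)

    def forward(self, anchor_hidden: torch.Tensor, aug_hidden: torch.Tensor) -> torch.Tensor:
        # Queries come from the anchor model, keys/values from the augmenting model.
        kv = self.proj(aug_hidden)
        attended, _ = self.attn(query=anchor_hidden, key=kv, value=kv)
        # Residual connection: the anchor's own representation is preserved.
        return anchor_hidden + attended
```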
Since our target, the composed model mAB, involves composition over models A and B, we construct a set of training examples DC that depict the combined skill. The goal of this training run is to teach the composed model to recognize the individual skillsets of the two models and how to use them to generate output. Going back to our code + reasoning example, the training data would be examples that reason their way through a given problem statement and then produce the code for the solution. This way the composition knows when to call on the reasoning model and when to call on the code generation model.
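For intuition, a couple of entries in such a set might look like this - purely illustrative examples, not drawn from the paper's datasets:

```python
# Purely illustrative combined-skill examples for D_C (not from the paper's
# datasets): each target shows the reasoning first, then the code, so the
# composed model learns when to lean on each underlying model.
D_C = [
    {
        "problem": "Return the k largest elements of a list.",
        "target": (
            "Reasoning: sorting the whole list is O(n log n); a heap of size k "
            "is enough, since anything smaller than its root can be discarded.\n"
            "import heapq\n"
            "def k_largest(xs, k):\n"
            "    return heapq.nlargest(k, xs)"
        ),
    },
    # ... more (problem -> reasoning + code) pairs covering both skills
]
```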
The best part is that only a fraction of those examples suffices for the cross-attention layers to generalize to the full set. This saves compute, time, and cost, and can help get to market quickly.
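Putting it together, the training setup could look roughly like the sketch below, assuming the composition block above and two HuggingFace-style decoder models that expose intermediate hidden states; `anchor_model_with_composition`, `dataloader_DC`, the layer count, and the dimensions are all hypothetical stand-ins:

```python
# Assumes the imports and CrossAttentionBlock from the sketch above, plus two
# pre-loaded decoder models (anchor_model, aug_model) and a dataloader over D_C.
n_composition_layers = 4  # illustrative; the paper composes over a few chosen layers

composition_blocks = nn.ModuleList(
    [CrossAttentionBlock(anchor_dim=4096, aug_dim=2048) for _ in range(n_composition_layers)]
)

# Both base models stay frozen; only the composition parameters are trained.
for p in anchor_model.parameters():
    p.requires_grad = False
for p in aug_model.parameters():
    p.requires_grad = False

optimizer = torch.optim.AdamW(composition_blocks.parameters(), lr=1e-4)

for batch in dataloader_DC:  # batches drawn from the combined-skill set D_C
    with torch.no_grad():
        aug_hidden = aug_model(
            batch["input_ids"], output_hidden_states=True
        ).hidden_states
    # Hypothetical wrapper: the anchor's forward pass calls the k-th composition
    # block after its k-th selected layer, mixing in the matching augmenting states.
    outputs = anchor_model_with_composition(batch, aug_hidden, composition_blocks)
    outputs.loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```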
From the paper:
In other real world settings, a clear distinction in specializing tasks for each model might be difficult to formulate and hence defining a task that captures the combined skills can be challenging. We find that using a set of examples that capture certain capabilities of the two models suffices, i.e., some rough notion of tA∪B. For our language inclusivity task, we use a mixture of examples containing a small amount of low-resource language and high-resource language data.
While the paper does not mention it, a slight modification of the technique could be a real gamechanger for many enterprises. Today, most larger companies with datasets of more than 1B tokens are struggling to implement AI effectively in their organizations. The reasons are multiple: RAG is less effective at that scale, they are not willing to share their data, and deploying a model on premises for basic RAG-based apps is overkill both cost- and functionality-wise. (Deploying a Mistral 7B model on premises costs about $10,000 a year for inference alone, without even considering training or fine-tuning costs.)
What fascinated the executives at these companies was the promise of ChatGPT and GPT-4: an AI trained on the internet that could simplify the world's knowledge for everyone. Unfortunately, it set a benchmark and a wow factor that most AI apps have failed to replicate, because most of an org's data is private and GPT-4 is not trained on it. RAG can only go so far when a model has no context about what it's being asked. LLM scaling laws tell us that a model's performance scales with the data it's trained on, the compute spent on training, and the number of parameters. So you have a model that was never trained on your data at all, trying to guess what an answer looks like from the context window and zero/few-shot learning. That is not going to match expectations, no matter how realistic they are.
CaLM helps address all of these issues effectively.
A typical hypothetical example would be an anchor model trained on all the effective marketing strategies of the last five years, and another model trained on all the information about an org's product. With the right training, you can get a composition that understands the company's product and business lines and suggests the right kind of marketing strategies for them.
Of course, training even a smaller LLM is non-trivial. That is where we come in at Clio AI. We specialize in training models privately to enable specific use cases or implementations like the one described in the paper. You can get in touch with us here.