November 22, 2024
6 mins

CaLM - Composition of LLMs by augmentation

CaLM provides composition for LLMs, similar to how libraries provide reusable functionality in a programming language. It's a powerful method for combining the skills of multiple LLMs depending on the use case.
Paper Link

Key Takeaways

  • CaLM introduces a cross attention layer that combines two LLMs so that each preserves its individual skills while the composition produces output drawing on both. For example, if one LLM is good at code generation and another at problem solving, a CaLM-based composition could reason through a problem and generate the code for it with minimal finetuning.
  • CaLM could be used to augment models, enabling enterprises to get the best of both worlds - train an AI on their private data and use it in conjunction with LLMs trained on public datasets.

Introduction

Over the last year, one of the most prominent use cases to emerge for Generative AI applications has been augmenting an LLM's knowledge with organizational knowledge. Whether it's chatting with a PDF, chatting over your own documents, code generation, or natural language queries, most AI apps aim to help employees navigate the wall of unstructured data. Many of these apps use what is called a RAG-based approach - retrieving the most similar subset of indexed data at runtime and passing it along with a prompt - which works well to a degree, but has its limitations.

LLMs in their current form are monolithic in structure. They are trained on a large corpus of data and demonstrate non-trivial skills in different domains, but it's expensive to impart new knowledge or skills that were not in the original dataset. Typical approaches are to pretrain a new model from scratch, to spend a lot of time on continual learning while hoping each iteration works and avoids catastrophic forgetting, or, in specific cases, to efficiently finetune a large foundational model for a specific skill, e.g. code generation.

Rise of domain specialized models

When open source models came out, the community spent a lot of time finetuning them for domain-specific skillsets. Today, you have specialized models for code generation, sales, and more, but a model excelling at code generation might be poor at reasoning. Moreover, these models are finetuned, meaning they are still limited by the training data of the foundational LLM and fail for tasks outside of it. Further, finetuning models is computationally expensive, and processing new data may not be feasible due to privacy concerns and organizational boundaries.

A modular approach to LLMs

In an ideal world:

  • We may want to keep and reuse specialized models just like we reuse libraries. That is, we do not want to alter the weights of the individual models.
  • We want to combine the capabilities of two or more models to give us the skillset we desire without having to finetune a model on both datasets.
  • We do not want to deal with issues like catastrophic forgetting, lack of control, etc.

Introducing CaLM

CaLM stands for 'Composition to Augment Language Models'. It introduces a compositional technique to combine two LLMs such that:

  • Individual models can be reused multiple times in combinations with other models, just like libraries in programs.
  • The weights of neither model are modified, yet we still get the functionality we would get by combining the two LLMs.
  • It's done with a fraction of the examples and a fraction of the resources that would be needed to finetune.

Sounds too good to be true? Stay with me on this one.

From the paper:

In this work, we propose a novel Composition to Augment Language Models (CALM) framework to address the general model composition setting mentioned above. Rather than a shallow combination of the augmenting and anchor LMs (Wortsman et al., 2022; Ilharco et al., 2022), CALM introduces a small number of trainable parameters over both augmenting and anchor models’ intermediate layer representations. CALM finds an effective combination of the given models to perform new challenging tasks more accurately than either of the models alone, while preserving the capabilities of individual models. Figure 1 highlights few motivating scenarios for CALM.

That is, you pick an anchor language model and an augmenting model, then train a small number of new parameters over both models' intermediate layer representations to get the functionality of both.

In a practical sense, say you have an anchor model (model A, or mA) which is great at problem solving but poor at everything else (fine-tuned only on reasoning examples). Another model, our augmenting model B (or mB), is good at code generation but poor at most other things. By combining the two models using CaLM, you can have a composite model (mAB) which can solve a complex coding problem and write efficient code for it. You can also combine model A with another model mC which is great at physics knowledge, and now you have a combination that can theoretically solve physics-related questions.

Key Assumptions

For this paper, the authors worked under the following assumptions; a rough code sketch of the access pattern they imply follows the list:

  1. We can access the weights, run forward and backward passes, and access intermediate representations of both mA and mB.
  2. We are not allowed to change the weights of either model.
  3. We do not have access to the training data, hyperparameters, or training states of either base model.
  4. We are provided a few examples from the target composition domain.
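
Here is that sketch, using Hugging Face transformers. The checkpoint names are placeholders rather than the models used in the paper, and a shared tokenizer is assumed for simplicity.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical checkpoints standing in for the anchor (mA) and augmenting (mB) models.
anchor = AutoModelForCausalLM.from_pretrained("anchor-model")
augmenting = AutoModelForCausalLM.from_pretrained("augmenting-model")

# Assumption 2: the base weights stay frozen; only the new composition
# parameters (introduced in the next section) will receive gradients.
for model in (anchor, augmenting):
    for p in model.parameters():
        p.requires_grad = False

# Assumption 1: we can run forward passes and read intermediate
# representations of both models.
tokenizer = AutoTokenizer.from_pretrained("anchor-model")
inputs = tokenizer("Write a function that reverses a string.", return_tensors="pt")
out_a = anchor(**inputs, output_hidden_states=True)
out_b = augmenting(**inputs, output_hidden_states=True)
hidden_a = out_a.hidden_states  # per-layer activations of mA
hidden_b = out_b.hidden_states  # per-layer activations of mB
```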

Cross Attention Layer

This is where the magic happens. The cross attention layer is what enables the composition of capabilities in the anchor model (mA) and the augmenting model (mB): cross attention layers allow the anchor model to attend to the representations of the augmenting model. The algorithm is quite neat actually (a short code sketch follows the list):

  • Select a set of layers from mA and mB. Let's denote these sets as LA and LB respectively.
  • For each layer in LB, project the augmenting model's representation to match the dimensionality of mA using a learned projection function fproj.
  • For each pair of layers (i, j) from (LA, LB), compute cross-attention in which the anchor's representation Hi from mA provides the queries and the projected augmenting representation fproj(Hj) from mB provides the keys and values.
  • The cross-attention output is added as a residual connection to the anchor's layer representation, effectively folding the augmenting model's information into mA's forward pass.
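
A minimal PyTorch sketch of one such composition block might look like this. The class and argument names are illustrative (they assume hidden sizes d_anchor and d_augment), not the paper's reference implementation.

```python
import torch
import torch.nn as nn

class CompositionBlock(nn.Module):
    """One cross-attention block composing a pair of layers (i, j)."""

    def __init__(self, d_anchor: int, d_augment: int, n_heads: int = 8):
        super().__init__()
        # Learned projection: map the augmenting model's representation
        # into the anchor model's dimensionality.
        self.proj = nn.Linear(d_augment, d_anchor)
        # Cross-attention: anchor representations act as queries,
        # projected augmenting representations as keys and values.
        self.cross_attn = nn.MultiheadAttention(d_anchor, n_heads, batch_first=True)

    def forward(self, h_anchor: torch.Tensor, h_augment: torch.Tensor) -> torch.Tensor:
        kv = self.proj(h_augment)                        # (batch, seq_b, d_anchor)
        attn_out, _ = self.cross_attn(h_anchor, kv, kv)  # queries come from mA
        return h_anchor + attn_out                       # residual into the anchor stream
```

In practice you would instantiate one such block per selected layer pair in (LA, LB); these blocks and their projections are the only trainable parameters.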

Composition Training Data (DC)

Since our target composition mA∪B involves composing models A and B, we construct a set of training examples DC that depict the combined skill. The goal of this training run is to teach the composition to recognize the individual skillsets of the two models and how to use them to generate output. Going back to our code + reasoning example, the training data would be examples that reason their way through a given problem statement and then produce the code for the output. This way the composition knows when to rely on the reasoning model and when to rely on the code generation model.
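
For the code + reasoning case, a handful of composition examples might look something like the sketch below. The prompt/target structure is illustrative, and during composition training only the new cross-attention and projection parameters are passed to the optimizer.

```python
# Illustrative composition training examples (DC): each target walks
# through the reasoning and then the code, so the composed model learns
# when to lean on each underlying model.
composition_examples = [
    {
        "prompt": "Given a list of meeting intervals, find the minimum number of rooms needed.",
        "target": (
            "Reasoning: sort the start and end times, sweep through them, and "
            "track how many meetings overlap at once. "
            "Code: def min_rooms(intervals): ..."
        ),
    },
    {
        "prompt": "Count how many numbers below N are divisible by 3 or 5.",
        "target": (
            "Reasoning: use inclusion-exclusion over multiples of 3, 5, and 15. "
            "Code: def count_multiples(n): ..."
        ),
    },
    # A small fraction of the full task's data is enough for the
    # cross-attention layers to generalize, per the paper.
]
```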

The best part here is that only a fraction of the examples suffices for the cross attention layers to generalize to the full task. This saves compute, time, and cost, and can help get to market quickly.

From the paper:

In other real world settings, a clear distinction in specializing tasks for each model might be difficult to formulate and hence defining a task that captures the combined skills can be challenging. We find that using a set of examples that capture certain capabilities of the two models suffices, i.e., some rough notion of tA∪B. For our language inclusivity task, we use a mixture of examples containing a small amount of low-resource language and high-resource language data.


Business Implications for enterprises

While the paper does not mention it, a slight modification of the technique could be a real gamechanger for many enterprises. Today, most larger companies with datasets of more than 1B tokens are struggling to implement AI effectively in their organizations. The reasons are multiple: RAG is less effective at that scale, they are not willing to share their data, and deploying a model on premises for basic RAG-based apps is overkill both cost- and functionality-wise. (Deploying a Mistral 7B model on premises costs about $10,000 a year for inference, without even considering training or fine-tuning costs.)

What fascinated the executives at these companies was the promise of ChatGPT and GPT-4: an AI trained on the internet could simplify the world's knowledge for everyone. Unfortunately, it set a benchmark and a wow factor that most AI apps failed to replicate, because most of an org's data is private and GPT-4 is not trained on it. RAG can only go so far when a model has no context about what it's being asked. LLM scaling laws dictate that a model's efficacy depends on the data it's trained on, the compute spent training it, and its number of parameters. So you have a model not trained on your data at all, trying to guess what an answer looks like from the context window and zero/few-shot learning. That is not going to match expectations, no matter how realistic they are.

CaLM helps address all of these issues effectively.

  • Enterprises can train a small 1B-5B parameter model on their own data. This model can use the same architecture as the open source models.
  • With a cross attention mechanism, these models can be composed with the open source models by passing high quality examples through the cross attention layer.
  • This combination gives companies the best of both worlds - their data stays private, they can access a SOTA model for instructions over public data, and they can use the same composed model for questions about their own datasets.

A typical hypothetical example would be an anchor model trained on all the effective marketing strategies of the last five years, and another model trained on all the information about an org's product. With the right training, you get a composition that understands the company's product and business lines and can suggest the right kind of marketing strategies for them.

Of course, training even a smaller LLM is non-trivial. That is where we come in at Clio AI. We specialize in training models privately to enable specific use cases or implementations like the one described in the paper. You can get in touch with us here.
