August 7, 2024
13 mins

Gemma 2: Improving Open Language Models at a Practical Size

Gemma 2 research paper by Google DeepMind detailing knowledge distillation, Grouped-Query Attention, and extensive safety training
Paper Link

Key Takeaways

  • Knowledge distillation is an effective method for training smaller language models: By training on the output probabilities of a larger "teacher" model, smaller models can achieve performance levels comparable to much larger models trained on raw text data.
  • Gemma 2 models are competitive with much larger models: These models significantly advance state-of-the-art performance for open, lightweight language models, even challenging models more than twice their size.
  • Responsible development is a key priority: The researchers have taken extensive steps to mitigate safety and security risks, including data filtering, rigorous evaluation, and ongoing monitoring of model usage.
  • Open-weight models have the potential to democratize AI: Gemma 2's open-weight approach promotes accessibility and collaboration, potentially unlocking new research avenues and beneficial applications.

Introduction

Much of the recent progress in large language models (LLMs) can be attributed to scale, with new capabilities emerging as models grow larger. However, this trend poses a challenge for accessibility, since training and deploying these massive models requires significant computational resources. Gemma, an initiative from Google DeepMind, aims to create high-performing, open-weight language models at a more practical size.

Previous work has shown that smaller models can benefit from extended training lengths. However, these gains are limited by the logarithmic scaling of performance with dataset size. Gemma 2 explores alternative approaches to boost the performance of smaller models without solely relying on extended training. The paper states:

In this work, we explore alternatives to improve small model performance without solely increasing training length. One solution is to improve the quality of information received by the network at each training step by replacing the next token prediction task with a richer objective.

The researchers focused on knowledge distillation as a key technique for improving the information quality provided to smaller models. This utilizes a larger, pre-trained language model as a "teacher" to guide the training of a smaller "student" model. By training on the teacher model's output probabilities rather than just raw text data, the student model can learn more efficiently and achieve superior performance. Llama 3's 8B and 70B models have used a similar approach.

Background

  • Rotary Position Embeddings (RoPE) (Su et al., 2021)
  • GeGLU non-linearity (Shazeer, 2020)
  • Knowledge Distillation (Hinton et al., 2015)
  • Interleaving local-global attention (Beltagy et al., 2020a): This approach combines global attention with local attention to improve efficiency and capture long-range dependencies.
  • Grouped-Query Attention (GQA) (Ainslie et al., 2023): GQA lets groups of query heads share key/value heads, reducing the memory cost of attention and speeding up inference.

Model Architecture

Gemma 2 models are built using the decoder-only transformer architecture, with specific design choices and parameters outlined in Table 1 of the paper. This section will delve into the architectural elements that distinguish Gemma 2 from its predecessor and other comparable models.

RoPE and GeGLU Non-linearity

Gemma 2 models inherit Rotary Position Embeddings (RoPE) and the GeGLU non-linearity from the original Gemma models. RoPE is a technique that effectively encodes positional information in the input sequence. Instead of directly adding positional embeddings, RoPE applies a rotation to the query and key vectors in each attention layer, enabling the model to capture the relative positions between tokens.
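To make the rotation idea concrete, here is a minimal sketch of RoPE applied to a single attention head. It uses the standard base frequency of 10000 and the common "rotate-half" pairing of channels; the head dimension and these conventions are illustrative assumptions, not Gemma 2's exact implementation.

```python
import torch

def rope(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Apply Rotary Position Embeddings to x of shape (seq_len, dim)."""
    seq_len, dim = x.shape
    half = dim // 2
    # Rotation frequency for each channel pair: theta_i = base^(-2i / dim).
    freqs = base ** (-torch.arange(half, dtype=torch.float32) * 2.0 / dim)
    # Rotation angle for every (position, pair) combination.
    angles = torch.arange(seq_len, dtype=torch.float32)[:, None] * freqs[None, :]
    cos, sin = angles.cos(), angles.sin()
    # "Rotate-half" convention: channel i is paired with channel i + dim/2.
    x1, x2 = x[:, :half], x[:, half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

q = torch.randn(8, 64)   # 8 token positions, head dimension 64
q_rot = rope(q)          # positions are now encoded as rotations of q
```

Because the rotation angle depends only on the token position, the dot product between two rotated vectors depends on their relative distance, which is what gives RoPE its relative-position behaviour.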

GeGLU is a variant of the Gated Linear Unit (GLU) that uses the Gaussian Error Linear Unit (GELU) as its activation: one linear projection of the input is passed through GELU and used to gate a second linear projection. This gated formulation has been shown to improve the quality of the feed-forward layer over a plain GELU MLP at a comparable computational cost.
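As a reference, a GeGLU feed-forward block can be sketched as below, following the gate/up/down projection layout common in open implementations. The dimensions are illustrative, not Gemma 2's actual sizes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GeGLUFeedForward(nn.Module):
    """Feed-forward block with a GELU-gated linear unit (GeGLU)."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)  # gated branch
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)    # value branch
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)  # back to model dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # GeGLU: GELU(x W_gate) elementwise-multiplied with (x W_up).
        return self.down_proj(F.gelu(self.gate_proj(x)) * self.up_proj(x))

ffn = GeGLUFeedForward(dim=512, hidden_dim=2048)
y = ffn(torch.randn(4, 16, 512))   # (batch, seq_len, dim)
```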

Local Sliding Window and Global Attention

Gemma 2 alternates between local sliding-window attention layers and global attention layers. This hybrid approach aims to capture both local and global context effectively while maintaining computational efficiency.

The local sliding window attention mechanism restricts the attention scope to a fixed-size window of tokens around each target token. This reduces the number of pairwise attention computations, enhancing speed and efficiency. In Gemma 2, the sliding window size for local attention layers is set to 4096 tokens.

In contrast, global attention layers attend to all tokens in the input sequence, allowing the model to capture long-range dependencies and broader context. Gemma 2 utilizes global attention layers with a span of 8192 tokens, matching the model's maximum context length. Alternating between the two, Gemma 2 strikes a balance between efficiency and the ability to capture both short-range and long-range context.
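The difference between the two layer types comes down to the attention mask each one uses. Here is a small illustrative sketch of the two masks; the alternation pattern and the toy sizes are assumptions for readability (Gemma 2 uses a 4096-token window and an 8192-token global span).

```python
import torch

def local_mask(seq_len: int, window: int) -> torch.Tensor:
    """Causal mask restricted to the previous `window` tokens (sliding window)."""
    pos = torch.arange(seq_len)
    dist = pos[:, None] - pos[None, :]      # query position minus key position
    return (dist >= 0) & (dist < window)    # causal, and within the window

def global_mask(seq_len: int) -> torch.Tensor:
    """Plain causal mask: every token attends to all earlier tokens."""
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

# Gemma 2 alternates the two layer types; small sizes are used here so the
# masks are easy to print and inspect.
layer_masks = [local_mask(16, window=4) if i % 2 == 0 else global_mask(16)
               for i in range(4)]
```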

Logit Soft-capping

Logits are the raw output values of the model before they are transformed into probabilities. Soft-capping involves applying a non-linear function, such as a tanh function, to constrain the logits within a specified range. The paper elaborates:

We cap logits (Bello et al., 2016) in each attention layer and the final layer such that the value of the logits stays between -soft_cap and +soft_cap.

In Gemma 2, the soft_cap parameter is set to 50.0 for self-attention layers and 30.0 for the final layer. This technique helps to prevent extreme logit values, stabilizing training and potentially leading to improved generalization.
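In code, soft-capping is essentially a one-liner. The cap values below are the ones reported in the paper; the function itself is an illustrative sketch.

```python
import torch

def soft_cap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    """Smoothly squash logits into the range (-cap, +cap) using tanh."""
    return cap * torch.tanh(logits / cap)

attn_scores = soft_cap(torch.randn(4, 4) * 100, cap=50.0)   # attention logits
final_logits = soft_cap(torch.randn(4, 256_000), cap=30.0)  # final-layer logits
```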

Post-norm and Pre-norm with RMSNorm

Normalization techniques like RMSNorm help stabilize training by keeping the distribution of each layer's inputs consistent, mitigating internal covariate shift.

Post-norm applies normalization after the attention and feedforward layers, while pre-norm applies normalization before these layers. By using both, Gemma 2 can potentially benefit from the strengths of each approach, leading to more stable and efficient training.
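A minimal sketch of this layout, with RMSNorm normalizing both the input and the output of a sublayer before the residual is added back. The epsilon value and the exact ordering around the residual are assumptions based on common open implementations, not a verbatim copy of Gemma 2's code.

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Root-mean-square normalization: rescale by the RMS of the activations."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))   # learned per-channel gain

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * x / rms

def sublayer_with_norms(x, sublayer, pre_norm, post_norm):
    # Normalize the sublayer's input (pre-norm) and its output (post-norm)
    # before the residual connection is added back.
    return x + post_norm(sublayer(pre_norm(x)))

dim = 512
x = torch.randn(2, 16, dim)
out = sublayer_with_norms(x, nn.Linear(dim, dim), RMSNorm(dim), RMSNorm(dim))
```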

Grouped-Query Attention (GQA)

GQA is a modified attention mechanism in which query heads are divided into groups that share a single key and value head. This shrinks the key/value projections and the KV cache, which is especially beneficial during inference.

From the paper:

We use GQA with num_groups = 2, based on ablations showing increased speed at inference time while maintaining downstream performance.

By using two groups, Gemma 2 can reduce the computational cost of attention while maintaining performance comparable to models using full Multi-Head Attention (MHA). This trade-off between speed and performance is a critical consideration for deploying models on resource-constrained devices.
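Here is an illustrative sketch of grouped-query attention with num_groups = 2. The paper specifies only the number of groups; the head counts and dimensions below are assumptions for the example.

```python
import torch
import torch.nn.functional as F

def grouped_query_attention(q, k, v, num_groups: int):
    """q: (batch, n_q_heads, seq, d); k, v: (batch, num_groups, seq, d).

    Each group of query heads shares one key/value head, shrinking the KV
    cache by a factor of n_q_heads / num_groups versus full multi-head attention.
    """
    n_q_heads = q.shape[1]
    repeat = n_q_heads // num_groups
    # Broadcast each shared KV head to the query heads in its group.
    k = k.repeat_interleave(repeat, dim=1)
    v = v.repeat_interleave(repeat, dim=1)
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

batch, seq, d = 1, 16, 64
q = torch.randn(batch, 8, seq, d)   # 8 query heads
k = torch.randn(batch, 2, seq, d)   # only 2 KV heads (num_groups = 2)
v = torch.randn(batch, 2, seq, d)
out = grouped_query_attention(q, k, v, num_groups=2)   # (1, 8, 16, 64)
```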

Pretraining

Training Data

Gemma 2 models are trained on different quantities of data, scaling with model size. The 27B model is trained on 13 trillion tokens, the 9B model on 8 trillion tokens, and the 2B model on 2 trillion tokens. This data is primarily English text sourced from a diverse range of sources, including web documents, code, and scientific articles.

Tokenizer

Gemma 2 utilizes the same SentencePiece tokenizer as its predecessor Gemma 1 and as Gemini, with a vocabulary size of 256k entries.

Data filtering plays a crucial role in ensuring model safety and mitigating potential harms. Gemma 2 inherits the filtering techniques from Gemma 1, focusing on:

  • Reducing the risk of unwanted or unsafe utterances: filtering out toxic language, hate speech, and other harmful content.
  • Protecting personal information and sensitive data: removing or masking personally identifiable information and sensitive content.
  • Decontaminating evaluation sets: ensuring that evaluation data is not present in the pretraining data, preventing inflated performance estimates.
  • Minimizing recitation: reducing the risk of the model reciting memorized text from the training data by limiting the proliferation of sensitive outputs.

Knowledge Distillation

From the paper:

Given a large model used as a teacher, we learn smaller models by distilling from the probability given by the teacher of each token x given its context xc, i.e., Pr(x | xc).

Instead of predicting the next token in a sequence, the student model is trained to match the probability distribution of the teacher model's output. This provides a richer training signal, as the student model learns not only the correct token but also the relative probabilities of other potential tokens.
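A minimal sketch of a distillation loss of this form is shown below: each position is supervised with the teacher's full probability distribution rather than a one-hot next-token target. Any temperature or loss weighting Gemma 2 may use is not covered here; this is illustrative only.

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits: torch.Tensor,
                      teacher_logits: torch.Tensor) -> torch.Tensor:
    """Cross-entropy between the teacher's token distribution and the student's.

    Both tensors have shape (batch, seq_len, vocab_size). Every position is
    supervised with the teacher's full distribution over the vocabulary.
    """
    teacher_probs = F.softmax(teacher_logits, dim=-1)
    student_log_probs = F.log_softmax(student_logits, dim=-1)
    # -sum_x P_teacher(x | context) * log P_student(x | context), averaged.
    return -(teacher_probs * student_log_probs).sum(dim=-1).mean()

student_logits = torch.randn(2, 16, 256_000, requires_grad=True)
teacher_logits = torch.randn(2, 16, 256_000)   # from the frozen teacher
loss = distillation_loss(student_logits, teacher_logits)
loss.backward()
```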

In Gemma 2, the 2B and 9B models are trained with knowledge distillation, using a larger language model as the teacher. Notably, they are trained on a quantity of tokens more than 50x the compute-optimal amount, with distillation providing a richer signal that, in effect, simulates training on far more data than is actually available.

Compute Infrastructure

Gemma 2 leverages a combination of TPUs (Tensor Processing Units), specifically TPUv4, TPUv5e, and TPUv5p, depending on model size. The models are trained using data parallelism and model parallelism techniques to distribute the computational load across multiple chips. The optimizer state is further sharded using techniques similar to ZeRO-3, an approach that reduces memory requirements by distributing the optimizer state across multiple devices. For scales beyond a single pod, the Pathways approach is employed to enable data-replica reduction over the data center network.

Post Training

Supervised Fine-Tuning (SFT)

From the paper: 

We fine-tune our pre-trained models into instruction-tuned models. First, we apply supervised fine-tuning (SFT) on a mix of text-only, English-only synthetic and human-generated prompt-response pairs.

Reinforcement Learning from Human Feedback (RLHF)

In RLHF, we train a reward model to score the quality of model-generated responses based on human feedback. The model is then fine-tuned to maximize its expected reward, leading to outputs that are more aligned with human preferences.

In Gemma 2, the reward model for RLHF is significantly larger than the policy model, allowing for more nuanced and accurate evaluation of response quality. The reward model is trained on labelled English-only preference data, and the policy is trained on the same prompts used in the SFT phase.

Model Merging

Finally, models obtained from different runs of the post-training pipeline with varying hyperparameters are averaged together.

Model merging leverages the diversity of models trained with different configurations, combining their strengths and reducing the risk of overfitting to a specific set of hyperparameters.
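In its simplest form, merging is a parameter-wise average of checkpoints. The sketch below assumes plain PyTorch state dicts and uniform averaging; the paper does not spell out the exact averaging scheme used for Gemma 2.

```python
import torch
import torch.nn as nn

def average_checkpoints(state_dicts):
    """Parameter-wise (uniform) average of several model checkpoints."""
    merged = {}
    for name in state_dicts[0]:
        # Stack the same tensor from every run and take the element-wise mean.
        merged[name] = torch.stack(
            [sd[name].float() for sd in state_dicts]).mean(dim=0)
    return merged

# Toy demonstration with two runs of the same (hypothetical) model.
run_a, run_b = nn.Linear(4, 4), nn.Linear(4, 4)
merged_weights = average_checkpoints([run_a.state_dict(), run_b.state_dict()])
```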

Synthetic Data and Data Filtering

Synthetic data refers to artificially generated examples that mimic the structure and content of real-world data.

From the paper:

When using synthetic data, we run several stages of filtering to remove examples that show certain personal information, unsafe or toxic model outputs, mistaken self-identification data, and duplicated examples.

Ablations

Distillation vs. Training from Scratch

A core finding of the paper is the significant impact of knowledge distillation on smaller language models. Table 6 showcases this impact by comparing a 2B model trained from scratch to a 2B model trained with distillation from a 7B teacher model. Both models were trained on 500B tokens, exceeding the computationally optimal number by a factor of 10.

Distillation leads to a substantial improvement in performance, highlighting its effectiveness in enhancing the learning efficiency of smaller models.

Impact of Distillation vs. Model Size

The results indicate that the performance gain from distillation remains consistent across different model sizes.

This finding suggests that knowledge distillation is a valuable technique for training smaller language models across a range of sizes, not just limited to extremely small models.

GQA vs. MHA

The results show minimal differences in performance between the two approaches, suggesting that GQA offers a viable alternative to MHA, achieving comparable performance while reducing computational requirements. This is particularly beneficial for deploying models on resource-constrained devices.

Impact of Formatting

The results indicate that Gemma 2B models exhibit slightly higher sensitivity to formatting compared to the larger models.

Evaluations

Pretraining Evaluations

The researchers evaluated the 27B model, as well as the 2B and 9B models, on a suite of standard benchmarks after pretraining.

Post-training Evaluations

The researchers evaluated Gemma 2 IT models on the LMSYS Chatbot Arena.

Notably, Gemma 2 27B outperforms Llama 3 70B and ranks similarly to GPT-4-0314.

Human Evaluation Single Turn

Human Evaluation Multi Turn

Standard Benchmarks

Memorization and Privacy

Verbatim Memorization

Gemma 2 models exhibit significantly lower memorization rates compared to previous models, with rates below 0.1%.

Approximate Memorization

While approximate memorization rates are higher than verbatim rates, they are still considerably lower than those observed in previous models.

Personal Data

From the paper:

We found no instances of high-severity data being emitted, and found a very low rate of 0.00026% of memorized data to contain lower-severity personal information.

What caught my attention

Model Merging: This is one of the few papers that discusses model merging as a way to reduce errors and produce a generally strong model. I think more researchers should evaluate this technique.

