Detailed explanation of GQA, the attention mechanism commonly used in large models, and Pytorch code implementation

Apr 03, 2024, 05:40 PM
python pytorch Large language model gqa

Grouped Query Attention (GQA) is an attention variant used in large language models. Its goal is to achieve the quality of multi-head attention (MHA) while maintaining the speed of multi-query attention (MQA). GQA divides the query heads into groups, and the query heads within each group share a single key head and value head, which reduces the cost of computing and storing keys and values and increases inference speed.

In this article, we will explain the idea of GQA and how to translate it into code.

GQA was proposed in the paper GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. It is a fairly simple and clean idea that builds on the strengths of multi-head attention.

GQA

The standard multi-head attention (MHA) layer consists of H query heads, H key heads, and H value heads. Each head has D dimensions. The PyTorch code is as follows:

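import torch
from torch.nn.functional import scaled_dot_product_attention

# shapes: (batch_size, seq_len, num_heads, head_dim)
query = torch.randn(1, 256, 8, 64)
key = torch.randn(1, 256, 8, 64)
value = torch.randn(1, 256, 8, 64)

# scaled_dot_product_attention expects (batch_size, num_heads, seq_len, head_dim),
# so move the heads dimension forward for the call and back afterwards
output = scaled_dot_product_attention(
    query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
).transpose(1, 2)
print(output.shape)  # torch.Size([1, 256, 8, 64])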

Each query head has its own corresponding key head and value head. This is shown in the figure below:

[Figure: multi-head attention, where each query head has its own key and value head]

GQA, by contrast, divides the query heads into G groups, and each group shares one key head and one value head. It can be pictured as:

[Figure: grouped-query attention, where each group of query heads shares one key and value head]

The visual representation makes the working principle of GQA easy to see. As noted above, GQA is a fairly simple and clean idea.
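One way to see why sharing keys and values helps is to count the memory they take. The sketch below assumes a decoder that caches keys and values during generation and stores them in 2 bytes per element; the function name `kv_cache_bytes` and all numbers are illustrative assumptions, not taken from the article or the paper.

```python
def kv_cache_bytes(num_kv_heads, head_dim, seq_len, num_layers=1,
                   batch_size=1, bytes_per_value=2):
    """Bytes needed to hold cached keys and values (hypothetical helper)."""
    # factor 2: one tensor for the keys and one for the values
    return (2 * batch_size * num_layers * seq_len
            * num_kv_heads * head_dim * bytes_per_value)


# 8 query heads in all three cases; only the number of key/value heads changes
mha = kv_cache_bytes(num_kv_heads=8, head_dim=64, seq_len=256)
gqa = kv_cache_bytes(num_kv_heads=2, head_dim=64, seq_len=256)
mqa = kv_cache_bytes(num_kv_heads=1, head_dim=64, seq_len=256)
print(mha, gqa, mqa)  # 524288 131072 65536
```

Under these assumptions the cache shrinks in proportion to the number of key/value heads, which is why fewer groups means less memory traffic per generated token.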

Pytorch code implementation

Let us write code that divides the query heads into G groups, with each group sharing one key head and one value head. We can use the einops library to perform complex tensor operations concisely.

First, define the queries, keys, and values, and set the number of attention heads. The numbers are arbitrary, but the number of query heads must be divisible by the number of key heads, that is, num_heads_for_query % num_heads_for_key == 0. Our definition is as follows:

import torch

# shapes: (batch_size, seq_len, num_heads, head_dim)
query = torch.randn(1, 256, 8, 64)
key = torch.randn(1, 256, 2, 64)
value = torch.randn(1, 256, 2, 64)

num_head_groups = query.shape[2] // key.shape[2]
print(num_head_groups)  # 4: each group is of size 4 since there are 2 kv_heads
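The divisibility requirement can be checked up front. The helper below is a small sketch; the function name `query_heads_per_kv_head` is hypothetical, and the parameter names simply follow the condition stated above.

```python
def query_heads_per_kv_head(num_heads_for_query, num_heads_for_key):
    """Return the group size, or fail early if the heads cannot be grouped evenly."""
    if num_heads_for_query % num_heads_for_key != 0:
        raise ValueError(
            f"{num_heads_for_query} query heads cannot be split evenly "
            f"across {num_heads_for_key} key/value heads"
        )
    return num_heads_for_query // num_heads_for_key


print(query_heads_per_kv_head(8, 2))  # 4
```

Calling it with 8 query heads and 3 key heads raises a ValueError instead of silently producing a wrong group size through integer division.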

To make the later matrix multiplications straightforward, we swap the seq_len and num_heads dimensions. With einops this can be done simply as follows:

from einops import rearrange

query = rearrange(query, "b n h d -> b h n d")
key = rearrange(key, "b s h d -> b h s d")
value = rearrange(value, "b s h d -> b h s d")

Then the concept of "grouping" needs to be introduced into the query matrix.

from einops import rearrange

query = rearrange(query, "b (h g) n d -> b g h n d", g=num_head_groups)
print(query.shape)  # torch.Size([1, 4, 2, 256, 64])

With the code above we reshape the 4D query tensor into a 5D one: for the tensor we defined, the original dimension of size 8 (the number of query heads) is now split into two groups (to match the number of heads in the keys and values), each of size 4.
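To see which query heads end up in which group, the sketch below labels each query head with its index and applies the same split using plain torch `view` and `permute`. This substitution for einops is made here only to keep the example dependency-light; `head_ids` and `mapping` are illustrative names.

```python
import torch

num_query_heads, num_kv_heads = 8, 2
group_size = num_query_heads // num_kv_heads  # 4

# label each query head with its index: shape (b=1, heads=8, n=1, d=1)
head_ids = torch.arange(num_query_heads).view(1, num_query_heads, 1, 1)

# plain-torch equivalent of rearrange(x, "b (h g) n d -> b g h n d", g=group_size)
grouped = head_ids.view(1, num_kv_heads, group_size, 1, 1).permute(0, 2, 1, 3, 4)
print(grouped.shape)  # torch.Size([1, 4, 2, 1, 1])

# grouped[0, :, h] lists the original query heads paired with key/value head h
mapping = {h: grouped[0, :, h, 0, 0].tolist() for h in range(num_kv_heads)}
print(mapping)  # {0: [0, 1, 2, 3], 1: [4, 5, 6, 7]}
```

Consecutive query heads land in the same group: heads 0 to 3 use the first key/value head and heads 4 to 7 use the second.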

The last and hardest part is calculating the attention scores. In fact, it can be done in one line with the einsum operation.

from einops import einsum, rearrange

# g indexes the query heads within a group (num_head_groups = 4 of them)
# h indexes the key/value heads (2 of them)
# n and s are equal and stand for the sequence length
scores = einsum(query, key, "b g h n d, b h s d -> b h n s")
print(scores.shape)  # torch.Size([1, 2, 256, 256])

The scores tensor has the same number of heads as the value tensor above. Let's see how it works.

einsum does two things for us:

1. Matrix multiplication of the queries and keys. In our case, the shapes of these tensors are (1, 4, 2, 256, 64) and (1, 2, 256, 64), so multiplying along the last two dimensions (with the keys transposed and broadcast over g) gives (1, 4, 2, 256, 256).

2. Summation over the second dimension (dimension g). If a dimension is omitted from the specified output shape, einsum sums over it automatically. This summation is what makes the number of heads match the keys and values.
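The two steps can be written out explicitly and compared with the one-line version. This is a sketch with smaller tensors, using `torch.einsum` in place of `einops.einsum` (a substitution made here; the subscripts are the same).

```python
import torch

torch.manual_seed(0)
query = torch.randn(1, 4, 2, 16, 8)  # (b, g, h, n, d)
key = torch.randn(1, 2, 16, 8)       # (b, h, s, d)

# one-line version: g is absent from the output, so it is summed over
scores = torch.einsum("bghnd,bhsd->bhns", query, key)

# step 1: batched matrix multiplication, keys broadcast over g
per_query_head = query @ key.transpose(-2, -1)  # (1, 4, 2, 16, 16)
# step 2: sum over the g dimension
summed = per_query_head.sum(dim=1)              # (1, 2, 16, 16)

matches = torch.allclose(scores, summed, atol=1e-4)
print(matches, scores.shape)
```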

Finally, scale the scores, apply softmax, and multiply the resulting attention weights by the values in the standard way:

import torch.nn.functional as F

scale = query.size(-1) ** 0.5
attention = F.softmax(scores / scale, dim=-1)

# here we do just a standard matrix multiplication
out = einsum(attention, value, "b h n s, b h s d -> b h n d")

# finally, just reshape back to (batch_size, seq_len, num_kv_heads, head_dim)
out = rearrange(out, "b h n d -> b n h d")
print(out.shape)  # torch.Size([1, 256, 2, 64])
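Because the scores are summed over g, the output above has one head per key/value head (2) rather than one per query head (8). A different way to share keys and values, which is not taken from this article's code, keeps one output per query head by repeating each key/value head across its group. The sketch below assumes PyTorch 2.0 or later for `scaled_dot_product_attention`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# shapes: (batch_size, num_heads, seq_len, head_dim)
query = torch.randn(1, 8, 16, 64)
key = torch.randn(1, 2, 16, 64)
value = torch.randn(1, 2, 16, 64)

group_size = query.shape[1] // key.shape[1]  # 4 query heads per key/value head

# repeat each key/value head group_size times so that query heads 0-3 see
# key/value head 0 and query heads 4-7 see key/value head 1
key_rep = key.repeat_interleave(group_size, dim=1)      # (1, 8, 16, 64)
value_rep = value.repeat_interleave(group_size, dim=1)  # (1, 8, 16, 64)

out_per_head = F.scaled_dot_product_attention(query, key_rep, value_rep)
print(out_per_head.shape)  # torch.Size([1, 8, 16, 64])
```

Which variant is appropriate depends on what the following output projection expects: 2 heads with the summed scores, or 8 heads with the repeated keys and values.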

The simplest GQA implementation is now complete, requiring fewer than 16 lines of Python code.
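The steps above can be collected into a single function. The following is a sketch, not the article's original listing: it uses plain torch (`reshape`, `permute`, `torch.einsum`) in place of einops, and the function name `grouped_query_attention` is chosen here for illustration.

```python
import torch
import torch.nn.functional as F


def grouped_query_attention(query, key, value):
    """query: (b, n, h_q, d); key and value: (b, s, h_kv, d) -> (b, n, h_kv, d)."""
    b, n, h_q, d = query.shape
    h_kv = key.shape[2]
    if h_q % h_kv != 0:
        raise ValueError("query heads must be divisible by key/value heads")
    g = h_q // h_kv

    # (b, n, h_q, d) -> (b, g, h_kv, n, d), same as "b (h g) n d -> b g h n d"
    q = query.transpose(1, 2).reshape(b, h_kv, g, n, d).permute(0, 2, 1, 3, 4)
    k = key.transpose(1, 2)    # (b, h_kv, s, d)
    v = value.transpose(1, 2)  # (b, h_kv, s, d)

    # g is summed over because it does not appear in the output subscripts
    scores = torch.einsum("bghnd,bhsd->bhns", q, k)
    attention = F.softmax(scores / d ** 0.5, dim=-1)
    out = torch.einsum("bhns,bhsd->bhnd", attention, v)
    return out.transpose(1, 2)  # back to (b, n, h_kv, d)


query = torch.randn(1, 256, 8, 64)
key = torch.randn(1, 256, 2, 64)
value = torch.randn(1, 256, 2, 64)
out = grouped_query_attention(query, key, value)
print(out.shape)  # torch.Size([1, 256, 2, 64])
```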

Finally, a brief word about MQA: Multi-Query Attention (MQA) is another popular method for simplifying MHA. All query heads share the same single key head and value head. The schematic diagram is as follows:

[Figure: multi-query attention, where all query heads share a single key and value head]

As you can see, both MQA and MHA can be derived from GQA. GQA with a single key head and value head is equivalent to MQA, while GQA with as many groups as there are query heads is equivalent to MHA.
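The MHA limit can be checked numerically. The sketch below uses a compact plain-torch version of the grouping shown earlier (the helper name `gqa` is hypothetical) and compares it with `scaled_dot_product_attention`, assuming PyTorch 2.0 or later. With a single key/value head it only reports the output shape.

```python
import torch
import torch.nn.functional as F


def gqa(query, key, value):
    # inputs are (batch_size, num_heads, seq_len, head_dim)
    b, h_q, n, d = query.shape
    h_kv = key.shape[1]
    q = query.reshape(b, h_kv, h_q // h_kv, n, d)      # (b, h, g, n, d)
    scores = torch.einsum("bhgnd,bhsd->bhns", q, key)  # sums over g
    attention = F.softmax(scores / d ** 0.5, dim=-1)
    return attention @ value                           # (b, h_kv, n, d)


torch.manual_seed(0)
query = torch.randn(1, 8, 16, 64)

# as many key/value heads as query heads: groups of size 1, plain MHA
key, value = torch.randn(1, 8, 16, 64), torch.randn(1, 8, 16, 64)
reference = F.scaled_dot_product_attention(query, key, value)
same_as_mha = torch.allclose(gqa(query, key, value), reference, atol=1e-4)

# a single key/value head shared by all 8 query heads: the MQA setting
key, value = torch.randn(1, 1, 16, 64), torch.randn(1, 1, 16, 64)
mqa_out = gqa(query, key, value)
print(same_as_mha, mqa_out.shape)
```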

What are the benefits of GQA?

GQA is a very good trade-off between the best performance (MQA) and the best model quality (MHA).

The following figure shows that with GQA you can obtain almost the same model quality as MHA while running about 3 times faster, which is close to the performance of MQA. This may be essential for high-load systems.

[Figure: model quality versus processing time for MHA, GQA, and MQA]

At the time of writing, there was no official implementation of GQA in PyTorch, so I looked for a good unofficial implementation. If you are interested, you can try it:

https://www.php.cn/link/5b52e27a9d5bf294f5b593c4c071500e

GQA paper:

https://www.php.cn/link/e4ba31fba036a999321d5460f7f2d1d1
