目錄
GQA
Pytorch程式碼實作
GQA的好處是什麼?
首頁 科技週邊 人工智慧 大模型中常用的注意力機制GQA詳解以及Pytorch程式碼實現

大模型中常用的注意力機制GQA詳解以及Pytorch程式碼實現

Apr 03, 2024 pm 05:40 PM
python pytorch 大型語言模型 gqa

群組查詢注意力(Grouped Query Attention)是大型語言模型中的一種多查詢注意力力方法,它的目標是在保持 MQA 速度的同時實現 MHA 的品質。 Grouped Query Attention 將查詢分組,每個群組內的查詢共享相同的注意力權重,這有助於降低計算複雜度和提高推理速度。

這篇文章中,我們將解釋GQA的想法以及如何將其轉化為程式碼。

GQA是在論文GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints paper.中提出,這是一個相當簡單和乾淨的想法,並且建立在多頭注意力之上。

大模型中常用的注意力機制GQA詳解以及Pytorch程式碼實現

GQA

#標準多頭注意層(MHA)由H個查詢頭、鍵頭和值頭組成。每個頭都有D個維度。 Pytorch的程式碼如下:

from torch.nn.functional import scaled_dot_product_attention  # shapes: (batch_size, seq_len, num_heads, head_dim) query = torch.randn(1, 256, 8, 64) key = torch.randn(1, 256, 8, 64) value = torch.randn(1, 256, 8, 64)  output = scaled_dot_product_attention(query, key, value) print(output.shape) # torch.Size([1, 256, 8, 64])
登入後複製

對於每個查詢頭,都有一個對應的鍵。這個過程如下圖所示:

大模型中常用的注意力機制GQA詳解以及Pytorch程式碼實現

而GQA將查詢頭分成G組,每組共享一個鍵和值。可以表示為:

大模型中常用的注意力機制GQA詳解以及Pytorch程式碼實現

使用視覺化的表達就能非常清楚地了解GQA的工作原理,就像我們上面說的。 GQA是一個相當簡單和乾淨的想法。

Pytorch程式碼實作

讓我們寫程式將這個將查詢頭分割成G組,每個組共用一個鍵和值。我們可以使用einops庫有效地執行對張量的複雜操作。

首先,定義查詢、鍵和值。然後設定注意力頭的數量,數量是隨意的,但是要確保num_heads_for_query % num_heads_for_key = 0,也就是說要能夠整除。我們的定義如下:

import torch  # shapes: (batch_size, seq_len, num_heads, head_dim) query = torch.randn(1, 256, 8, 64) key = torch.randn(1, 256, 2, 64) value = torch.randn(1, 256, 2, 64)  num_head_groups = query.shape[2] // key.shape[2] print(num_head_groups) # each group is of size 4 since there are 2 kv_heads
登入後複製

為了提高效率,交換seq_len和num_heads維度,einops可以像下面這樣簡單地完成:

#
from einops import rearrange  query = rearrange(query, "b n h d -> b h n d") key = rearrange(key, "b s h d -> b h s d") value = rearrange(value, "b s h d -> b h s d")
登入後複製

然後就是需要在查詢矩陣中引入」分組「的概念。

from einops import rearrange query = rearrange(query, "b (h g) n d -> b g h n d", g=num_head_groups) print(query.shape) # torch.Size([1, 4, 2, 256, 64])
登入後複製

上面的程式碼我們將二維重塑為二維:對於我們定義的張量,原始維度8(查詢的頭數)現在被分成兩組(以匹配鍵和值中的頭數),每組大小為4。

最後最難的部分是計算注意力的分數。但其實它可以在一行中透過insum操作完成的

from einops import einsum, rearrange # g stands for the number of groups # h stands for the hidden dim # n and s are equal and stands for sequence length scores = einsum(query, key, "b g h n d, b h s d -> b h n s") print(scores.shape) # torch.Size([1, 2, 256, 256])
登入後複製

scores張量和上面的value張量的形狀是一樣的。我們來看看到底是怎麼操作的

einsum幫我們做了兩件事:

1、一個查詢和鍵的矩陣乘法。在我們的例子中,這些張量的形狀是(1,4,2,256,64)和(1,2,256,64),所以沿著最後兩個維度的矩陣乘法得到(1,4,2,256,256)。

2、對第二個維度(維度g)上的元素求和-如果在指定的輸出形狀中省略了維度,einsum將自動完成這項工作,這樣的求和是用來匹配鍵和值中的頭的數量。

最後是注意分數與值的標準乘法:

import torch.nn.functional as F  scale = query.size(-1) ** 0.5 attention = F.softmax(similarity / scale, dim=-1)  # here we do just a standard matrix multiplication out = einsum(attention, value, "b h n s, b h s d -> b h n d")  # finally, just reshape back to the (batch_size, seq_len, num_kv_heads, hidden_dim) out = rearrange(out, "b h n d -> b n h d") print(out.shape) # torch.Size([1, 256, 2, 64])
登入後複製

這樣最簡單的GQA實作就完成了,只需要不到16行python程式碼:

大模型中常用的注意力機制GQA詳解以及Pytorch程式碼實現

最後再簡單提一句MQA:多查詢注意(MQA)是另一種簡化MHA的流行方法。所有查詢將共享相同的鍵和值。原理圖如下:

大模型中常用的注意力機制GQA詳解以及Pytorch程式碼實現

可以看到,MQA和MHA都可以從GQA推導出來。具有單一鍵和值的GQA相當於MQA,而具有與頭部數量相等的組的GQA相當於MHA。

GQA的好處是什麼?

#GQA是最佳性能(MQA)和最佳模型品質(MHA)之間的一個很好的權衡。

下圖顯示,使用GQA,可以獲得與MHA幾乎相同的模型質量,同時將處理時間提高3倍,達到MQA的效能。這對於高負載系統來說可能是必不可少的。

大模型中常用的注意力機制GQA詳解以及Pytorch程式碼實現

在pytorch中沒有GQA的官方實作。所以我找到了一個比較好的非官方實現,有興趣的可以試試:

#https://www.php.cn/link/5b52e27a9d5bf294f5b593c4c071500e

#GQA論文:

#https://www.php.cn/link/e4ba31fba036a99321d5460f7f2d1d1

#

以上是大模型中常用的注意力機制GQA詳解以及Pytorch程式碼實現的詳細內容。更多資訊請關注PHP中文網其他相關文章!

本網站聲明
本文內容由網友自願投稿,版權歸原作者所有。本站不承擔相應的法律責任。如發現涉嫌抄襲或侵權的內容,請聯絡admin@php.cn

熱AI工具

Undresser.AI Undress

Undresser.AI Undress

人工智慧驅動的應用程序,用於創建逼真的裸體照片

AI Clothes Remover

AI Clothes Remover

用於從照片中去除衣服的線上人工智慧工具。

Undress AI Tool

Undress AI Tool

免費脫衣圖片

Clothoff.io

Clothoff.io

AI脫衣器

Video Face Swap

Video Face Swap

使用我們完全免費的人工智慧換臉工具,輕鬆在任何影片中換臉!

熱工具

記事本++7.3.1

記事本++7.3.1

好用且免費的程式碼編輯器

SublimeText3漢化版

SublimeText3漢化版

中文版,非常好用

禪工作室 13.0.1

禪工作室 13.0.1

強大的PHP整合開發環境

Dreamweaver CS6

Dreamweaver CS6

視覺化網頁開發工具

SublimeText3 Mac版

SublimeText3 Mac版

神級程式碼編輯軟體(SublimeText3)

熱門話題

Java教學
1664
14
CakePHP 教程
1423
52
Laravel 教程
1317
25
PHP教程
1268
29
C# 教程
1243
24
PHP和Python:解釋了不同的範例 PHP和Python:解釋了不同的範例 Apr 18, 2025 am 12:26 AM

PHP主要是過程式編程,但也支持面向對象編程(OOP);Python支持多種範式,包括OOP、函數式和過程式編程。 PHP適合web開發,Python適用於多種應用,如數據分析和機器學習。

在PHP和Python之間進行選擇:指南 在PHP和Python之間進行選擇:指南 Apr 18, 2025 am 12:24 AM

PHP適合網頁開發和快速原型開發,Python適用於數據科學和機器學習。 1.PHP用於動態網頁開發,語法簡單,適合快速開發。 2.Python語法簡潔,適用於多領域,庫生態系統強大。

PHP和Python:深入了解他們的歷史 PHP和Python:深入了解他們的歷史 Apr 18, 2025 am 12:25 AM

PHP起源於1994年,由RasmusLerdorf開發,最初用於跟踪網站訪問者,逐漸演變為服務器端腳本語言,廣泛應用於網頁開發。 Python由GuidovanRossum於1980年代末開發,1991年首次發布,強調代碼可讀性和簡潔性,適用於科學計算、數據分析等領域。

Python vs. JavaScript:學習曲線和易用性 Python vs. JavaScript:學習曲線和易用性 Apr 16, 2025 am 12:12 AM

Python更適合初學者,學習曲線平緩,語法簡潔;JavaScript適合前端開發,學習曲線較陡,語法靈活。 1.Python語法直觀,適用於數據科學和後端開發。 2.JavaScript靈活,廣泛用於前端和服務器端編程。

sublime怎麼運行代碼python sublime怎麼運行代碼python Apr 16, 2025 am 08:48 AM

在 Sublime Text 中運行 Python 代碼,需先安裝 Python 插件,再創建 .py 文件並編寫代碼,最後按 Ctrl B 運行代碼,輸出會在控制台中顯示。

Golang vs. Python:性能和可伸縮性 Golang vs. Python:性能和可伸縮性 Apr 19, 2025 am 12:18 AM

Golang在性能和可擴展性方面優於Python。 1)Golang的編譯型特性和高效並發模型使其在高並發場景下表現出色。 2)Python作為解釋型語言,執行速度較慢,但通過工具如Cython可優化性能。

vscode在哪寫代碼 vscode在哪寫代碼 Apr 15, 2025 pm 09:54 PM

在 Visual Studio Code(VSCode)中編寫代碼簡單易行,只需安裝 VSCode、創建項目、選擇語言、創建文件、編寫代碼、保存並運行即可。 VSCode 的優點包括跨平台、免費開源、強大功能、擴展豐富,以及輕量快速。

notepad 怎麼運行python notepad 怎麼運行python Apr 16, 2025 pm 07:33 PM

在 Notepad 中運行 Python 代碼需要安裝 Python 可執行文件和 NppExec 插件。安裝 Python 並為其添加 PATH 後,在 NppExec 插件中配置命令為“python”、參數為“{CURRENT_DIRECTORY}{FILE_NAME}”,即可在 Notepad 中通過快捷鍵“F6”運行 Python 代碼。

See all articles