Rumah > Peranti teknologi > AI > Penjelasan terperinci tentang GQA, mekanisme perhatian yang biasa digunakan dalam model besar, dan pelaksanaan kod Pytorch

Penjelasan terperinci tentang GQA, mekanisme perhatian yang biasa digunakan dalam model besar, dan pelaksanaan kod Pytorch

WBOY
Lepaskan: 2024-04-03 17:40:09
ke hadapan
994 orang telah melayarinya

Perhatian Pertanyaan Berkumpulan ialah kaedah perhatian berbilang pertanyaan dalam model bahasa besar Matlamatnya adalah untuk mencapai kualiti MHA sambil mengekalkan kelajuan MQA. Pertanyaan Berkumpulan Pertanyaan kumpulan perhatian supaya pertanyaan dalam setiap kumpulan berkongsi berat perhatian yang sama, yang membantu mengurangkan kerumitan pengiraan dan meningkatkan kelajuan inferens.

Dalam artikel ini, kami akan menerangkan idea GQA dan cara menterjemahkannya ke dalam kod.

GQA telah dicadangkan dalam kertas GQA: Latihan Model Transformer Berbilang Pertanyaan Umum daripada kertas Pusat Pemeriksaan Berbilang Kepala Ia adalah idea yang agak mudah dan bersih, dan dibina berdasarkan perhatian berbilang kepala.

Penjelasan terperinci tentang GQA, mekanisme perhatian yang biasa digunakan dalam model besar, dan pelaksanaan kod Pytorch

GQA

Lapisan perhatian berbilang kepala standard (MHA) terdiri daripada pengepala pertanyaan H, pengepala utama dan pengepala nilai. Setiap kepala mempunyai dimensi D. Kod Pytorch adalah seperti berikut:

from torch.nn.functional import scaled_dot_product_attention  # shapes: (batch_size, seq_len, num_heads, head_dim) query = torch.randn(1, 256, 8, 64) key = torch.randn(1, 256, 8, 64) value = torch.randn(1, 256, 8, 64)  output = scaled_dot_product_attention(query, key, value) print(output.shape) # torch.Size([1, 256, 8, 64])
Salin selepas log masuk

Untuk setiap pengepala pertanyaan, terdapat kunci yang sepadan. Proses ini ditunjukkan dalam rajah di bawah:

Penjelasan terperinci tentang GQA, mekanisme perhatian yang biasa digunakan dalam model besar, dan pelaksanaan kod Pytorch

Dan GQA membahagikan pengepala pertanyaan kepada kumpulan G, setiap kumpulan berkongsi kunci dan nilai. Ia boleh dinyatakan sebagai:

Penjelasan terperinci tentang GQA, mekanisme perhatian yang biasa digunakan dalam model besar, dan pelaksanaan kod Pytorch

Menggunakan ekspresi visual, anda boleh memahami dengan jelas prinsip kerja GQA, seperti yang kami katakan di atas. GQA ialah idea yang agak mudah dan bersih.

Pelaksanaan kod pytorch

Mari tulis kod untuk membahagikan pengepala pertanyaan kepada kumpulan G, setiap kumpulan berkongsi kunci dan nilai. Kita boleh menggunakan perpustakaan einops untuk melaksanakan operasi kompleks pada tensor dengan cekap.

Pertama, tentukan pertanyaan, kunci dan nilai. Kemudian tetapkan bilangan kepala perhatian Nombor itu adalah sewenang-wenangnya, tetapi ia mesti dipastikan bahawa num_heads_for_query % num_heads_for_key = 0, yang bermaksud ia mesti boleh dibahagikan. Definisi kami adalah seperti berikut:

import torch  # shapes: (batch_size, seq_len, num_heads, head_dim) query = torch.randn(1, 256, 8, 64) key = torch.randn(1, 256, 2, 64) value = torch.randn(1, 256, 2, 64)  num_head_groups = query.shape[2] // key.shape[2] print(num_head_groups) # each group is of size 4 since there are 2 kv_heads
Salin selepas log masuk

Untuk meningkatkan kecekapan, tukar dimensi seq_len dan num_heads, einops boleh dilengkapkan dengan mudah seperti berikut:

from einops import rearrange  query = rearrange(query, "b n h d -> b h n d") key = rearrange(key, "b s h d -> b h s d") value = rearrange(value, "b s h d -> b h s d")
Salin selepas log masuk

kita perlu memperkenalkan "kumpulan itu" konsep matriks pertanyaan.

from einops import rearrange query = rearrange(query, "b (h g) n d -> b g h n d", g=num_head_groups) print(query.shape) # torch.Size([1, 4, 2, 256, 64])
Salin selepas log masuk

Dengan kod di atas kita membentuk semula 2D menjadi 2D: Untuk tensor yang kami takrifkan, dimensi asal 8 (bilangan kepala dalam pertanyaan) kini dibahagikan kepada dua kumpulan (untuk memadankan kepala dalam kekunci dan nombor nilai), setiap saiz kumpulan ialah 4.

Bahagian terakhir dan paling sukar ialah mengira markah perhatian. Tetapi sebenarnya, ia boleh dilakukan dalam satu baris melalui operasi insum. Mari lihat cara ia berfungsi

einsum melakukan dua perkara untuk kita:

1. Pertanyaan dan pendaraban matriks kunci. Dalam kes kita, bentuk tensor ini ialah (1,4,2,256,64) dan (1,2,256,64), jadi pendaraban matriks sepanjang dua dimensi terakhir memberi kita (1,4,2,256,256).

2 Jumlahkan elemen dalam dimensi kedua (dimensi g) - jika dimensi ditinggalkan dalam bentuk keluaran yang ditentukan, einsum akan melengkapkan kerja ini secara automatik dan penjumlahan tersebut digunakan untuk memadankan kekunci dan bilangan kepala dalam. nilai.

Akhir sekali, perhatikan pendaraban piawai pecahan dan nilai:

from einops import einsum, rearrange # g stands for the number of groups # h stands for the hidden dim # n and s are equal and stands for sequence length scores = einsum(query, key, "b g h n d, b h s d -> b h n s") print(scores.shape) # torch.Size([1, 2, 256, 256])
Salin selepas log masuk

Pelaksanaan GQA yang paling mudah kini selesai, memerlukan kurang daripada 16 baris kod python:

Penjelasan terperinci tentang GQA, mekanisme perhatian yang biasa digunakan dalam model besar, dan pelaksanaan kod Pytorch

Akhir sekali, sebutan ringkas tentang MQA: Perhatian Berbilang Pertanyaan (MQA) ialah satu lagi kaedah popular untuk memudahkan MHA. Semua pertanyaan akan berkongsi kunci dan nilai yang sama. Gambarajah skematik adalah seperti berikut:

Penjelasan terperinci tentang GQA, mekanisme perhatian yang biasa digunakan dalam model besar, dan pelaksanaan kod Pytorch

Seperti yang anda lihat, kedua-dua MQA dan MHA boleh diperolehi daripada GQA. GQA dengan satu kunci dan nilai adalah bersamaan dengan MQA, manakala GQA dengan kumpulan yang sama dengan bilangan pengepala adalah bersamaan dengan MHA.

Apakah faedah GQA

GQA ialah pertukaran yang baik antara prestasi terbaik (MQA) dan kualiti model terbaik (MHA).

Rajah di bawah menunjukkan bahawa menggunakan GQA, anda boleh mendapatkan kualiti model yang hampir sama dengan MHA, sambil meningkatkan masa pemprosesan sebanyak 3 kali ganda, mencapai prestasi MQA. Ini mungkin penting untuk sistem beban tinggi.

Penjelasan terperinci tentang GQA, mekanisme perhatian yang biasa digunakan dalam model besar, dan pelaksanaan kod Pytorch

Tiada pelaksanaan rasmi GQA dalam pytorch. Jadi saya temui pelaksanaan tidak rasmi yang lebih baik. Jika anda berminat, anda boleh mencubanya:

https://www.php.cn/link/5b52e27a9d5bf294f5b593c4c071500e

kertas: https ://www.php.cn/link/e4ba31fba036a999321d5460f7f2d1d1

🎜🎜

Atas ialah kandungan terperinci Penjelasan terperinci tentang GQA, mekanisme perhatian yang biasa digunakan dalam model besar, dan pelaksanaan kod Pytorch. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!

sumber:51cto.com
Kenyataan Laman Web ini
Kandungan artikel ini disumbangkan secara sukarela oleh netizen, dan hak cipta adalah milik pengarang asal. Laman web ini tidak memikul tanggungjawab undang-undang yang sepadan. Jika anda menemui sebarang kandungan yang disyaki plagiarisme atau pelanggaran, sila hubungi admin@php.cn
Tutorial Popular
Lagi>
Muat turun terkini
Lagi>
kesan web
Kod sumber laman web
Bahan laman web
Templat hujung hadapan