Perhatian Pertanyaan Berkumpulan ialah kaedah perhatian berbilang pertanyaan dalam model bahasa besar Matlamatnya adalah untuk mencapai kualiti MHA sambil mengekalkan kelajuan MQA. Pertanyaan Berkumpulan Pertanyaan kumpulan perhatian supaya pertanyaan dalam setiap kumpulan berkongsi berat perhatian yang sama, yang membantu mengurangkan kerumitan pengiraan dan meningkatkan kelajuan inferens.
Dalam artikel ini, kami akan menerangkan idea GQA dan cara menterjemahkannya ke dalam kod.
GQA telah dicadangkan dalam kertas GQA: Latihan Model Transformer Berbilang Pertanyaan Umum daripada kertas Pusat Pemeriksaan Berbilang Kepala Ia adalah idea yang agak mudah dan bersih, dan dibina berdasarkan perhatian berbilang kepala.
Lapisan perhatian berbilang kepala standard (MHA) terdiri daripada pengepala pertanyaan H, pengepala utama dan pengepala nilai. Setiap kepala mempunyai dimensi D. Kod Pytorch adalah seperti berikut:
from torch.nn.functional import scaled_dot_product_attention # shapes: (batch_size, seq_len, num_heads, head_dim) query = torch.randn(1, 256, 8, 64) key = torch.randn(1, 256, 8, 64) value = torch.randn(1, 256, 8, 64) output = scaled_dot_product_attention(query, key, value) print(output.shape) # torch.Size([1, 256, 8, 64])
Untuk setiap pengepala pertanyaan, terdapat kunci yang sepadan. Proses ini ditunjukkan dalam rajah di bawah:
Dan GQA membahagikan pengepala pertanyaan kepada kumpulan G, setiap kumpulan berkongsi kunci dan nilai. Ia boleh dinyatakan sebagai:
Menggunakan ekspresi visual, anda boleh memahami dengan jelas prinsip kerja GQA, seperti yang kami katakan di atas. GQA ialah idea yang agak mudah dan bersih.
Mari tulis kod untuk membahagikan pengepala pertanyaan kepada kumpulan G, setiap kumpulan berkongsi kunci dan nilai. Kita boleh menggunakan perpustakaan einops untuk melaksanakan operasi kompleks pada tensor dengan cekap.
Pertama, tentukan pertanyaan, kunci dan nilai. Kemudian tetapkan bilangan kepala perhatian Nombor itu adalah sewenang-wenangnya, tetapi ia mesti dipastikan bahawa num_heads_for_query % num_heads_for_key = 0, yang bermaksud ia mesti boleh dibahagikan. Definisi kami adalah seperti berikut:
import torch # shapes: (batch_size, seq_len, num_heads, head_dim) query = torch.randn(1, 256, 8, 64) key = torch.randn(1, 256, 2, 64) value = torch.randn(1, 256, 2, 64) num_head_groups = query.shape[2] // key.shape[2] print(num_head_groups) # each group is of size 4 since there are 2 kv_heads
Untuk meningkatkan kecekapan, tukar dimensi seq_len dan num_heads, einops boleh dilengkapkan dengan mudah seperti berikut:
from einops import rearrange query = rearrange(query, "b n h d -> b h n d") key = rearrange(key, "b s h d -> b h s d") value = rearrange(value, "b s h d -> b h s d")
kita perlu memperkenalkan "kumpulan itu" konsep matriks pertanyaan.
from einops import rearrange query = rearrange(query, "b (h g) n d -> b g h n d", g=num_head_groups) print(query.shape) # torch.Size([1, 4, 2, 256, 64])
Dengan kod di atas kita membentuk semula 2D menjadi 2D: Untuk tensor yang kami takrifkan, dimensi asal 8 (bilangan kepala dalam pertanyaan) kini dibahagikan kepada dua kumpulan (untuk memadankan kepala dalam kekunci dan nombor nilai), setiap saiz kumpulan ialah 4.
Bahagian terakhir dan paling sukar ialah mengira markah perhatian. Tetapi sebenarnya, ia boleh dilakukan dalam satu baris melalui operasi insum. Mari lihat cara ia berfungsi
einsum melakukan dua perkara untuk kita:1. Pertanyaan dan pendaraban matriks kunci. Dalam kes kita, bentuk tensor ini ialah (1,4,2,256,64) dan (1,2,256,64), jadi pendaraban matriks sepanjang dua dimensi terakhir memberi kita (1,4,2,256,256).
2 Jumlahkan elemen dalam dimensi kedua (dimensi g) - jika dimensi ditinggalkan dalam bentuk keluaran yang ditentukan, einsum akan melengkapkan kerja ini secara automatik dan penjumlahan tersebut digunakan untuk memadankan kekunci dan bilangan kepala dalam. nilai.
Akhir sekali, perhatikan pendaraban piawai pecahan dan nilai:
from einops import einsum, rearrange # g stands for the number of groups # h stands for the hidden dim # n and s are equal and stands for sequence length scores = einsum(query, key, "b g h n d, b h s d -> b h n s") print(scores.shape) # torch.Size([1, 2, 256, 256])
Pelaksanaan GQA yang paling mudah kini selesai, memerlukan kurang daripada 16 baris kod python:
Akhir sekali, sebutan ringkas tentang MQA: Perhatian Berbilang Pertanyaan (MQA) ialah satu lagi kaedah popular untuk memudahkan MHA. Semua pertanyaan akan berkongsi kunci dan nilai yang sama. Gambarajah skematik adalah seperti berikut:
Seperti yang anda lihat, kedua-dua MQA dan MHA boleh diperolehi daripada GQA. GQA dengan satu kunci dan nilai adalah bersamaan dengan MQA, manakala GQA dengan kumpulan yang sama dengan bilangan pengepala adalah bersamaan dengan MHA.
GQA ialah pertukaran yang baik antara prestasi terbaik (MQA) dan kualiti model terbaik (MHA).
Rajah di bawah menunjukkan bahawa menggunakan GQA, anda boleh mendapatkan kualiti model yang hampir sama dengan MHA, sambil meningkatkan masa pemprosesan sebanyak 3 kali ganda, mencapai prestasi MQA. Ini mungkin penting untuk sistem beban tinggi.
Tiada pelaksanaan rasmi GQA dalam pytorch. Jadi saya temui pelaksanaan tidak rasmi yang lebih baik. Jika anda berminat, anda boleh mencubanya:
https://www.php.cn/link/5b52e27a9d5bf294f5b593c4c071500e
kertas: https ://www.php.cn/link/e4ba31fba036a999321d5460f7f2d1d1
Atas ialah kandungan terperinci Penjelasan terperinci tentang GQA, mekanisme perhatian yang biasa digunakan dalam model besar, dan pelaksanaan kod Pytorch. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!