Heim > Technologie-Peripheriegeräte > KI > Detaillierte Erläuterung der GQA, des in großen Modellen häufig verwendeten Aufmerksamkeitsmechanismus und der Pytorch-Codeimplementierung

Detaillierte Erläuterung der GQA, des in großen Modellen häufig verwendeten Aufmerksamkeitsmechanismus und der Pytorch-Codeimplementierung

WBOY
Freigeben: 2024-04-03 17:40:09
nach vorne
1019 Leute haben es durchsucht

Grouped Query Attention ist eine Multi-Query-Attention-Methode in großen Sprachmodellen. Ihr Ziel ist es, die Qualität von MHA zu erreichen und gleichzeitig die Geschwindigkeit von MQA beizubehalten. Grouped Query Attention gruppiert Abfragen so, dass Abfragen innerhalb jeder Gruppe das gleiche Aufmerksamkeitsgewicht haben, was dazu beiträgt, die Rechenkomplexität zu reduzieren und die Inferenzgeschwindigkeit zu erhöhen.

In diesem Artikel erklären wir die Idee der GQA und wie man sie in Code übersetzt.

GQA wurde im Artikel „GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints“ vorgeschlagen. Es handelt sich um eine ziemlich einfache und saubere Idee, die auf der Aufmerksamkeit mehrerer Köpfe basiert.

Detaillierte Erläuterung der GQA, des in großen Modellen häufig verwendeten Aufmerksamkeitsmechanismus und der Pytorch-Codeimplementierung

GQA

Der standardmäßige Multi-Head Attention Layer (MHA) besteht aus H-Abfrageheadern, Schlüsselheadern und Wertheadern. Jeder Kopf hat D-Maße. Der Pytorch-Code lautet wie folgt:

from torch.nn.functional import scaled_dot_product_attention  # shapes: (batch_size, seq_len, num_heads, head_dim) query = torch.randn(1, 256, 8, 64) key = torch.randn(1, 256, 8, 64) value = torch.randn(1, 256, 8, 64)  output = scaled_dot_product_attention(query, key, value) print(output.shape) # torch.Size([1, 256, 8, 64])
Nach dem Login kopieren

Für jeden Abfrageheader gibt es einen entsprechenden Schlüssel. Dieser Prozess ist in der folgenden Abbildung dargestellt:

Detaillierte Erläuterung der GQA, des in großen Modellen häufig verwendeten Aufmerksamkeitsmechanismus und der Pytorch-Codeimplementierung

Und GQA unterteilt den Abfrageheader in G-Gruppen, jede Gruppe teilt einen Schlüssel und einen Wert. Es kann wie folgt ausgedrückt werden:

Detaillierte Erläuterung der GQA, des in großen Modellen häufig verwendeten Aufmerksamkeitsmechanismus und der Pytorch-Codeimplementierung

Anhand visueller Ausdrücke können Sie das Funktionsprinzip der GQA klar verstehen, genau wie wir oben gesagt haben. GQA ist eine ziemlich einfache und saubere Idee.

Pytorch-Code-Implementierung

Schreiben wir den Code, um den Abfrageheader in G-Gruppen zu unterteilen, wobei jede Gruppe einen Schlüssel und einen Wert teilt. Wir können die einops-Bibliothek verwenden, um komplexe Operationen an Tensoren effizient durchzuführen.

Zuerst definieren Sie die Abfrage, Schlüssel und Werte. Legen Sie dann die Anzahl der Aufmerksamkeitsköpfe fest. Die Anzahl ist beliebig, es muss jedoch sichergestellt sein, dass num_heads_for_query % num_heads_for_key = 0 ist, was bedeutet, dass sie teilbar sein muss. Unsere Definition lautet wie folgt:

import torch  # shapes: (batch_size, seq_len, num_heads, head_dim) query = torch.randn(1, 256, 8, 64) key = torch.randn(1, 256, 2, 64) value = torch.randn(1, 256, 2, 64)  num_head_groups = query.shape[2] // key.shape[2] print(num_head_groups) # each group is of size 4 since there are 2 kv_heads
Nach dem Login kopieren

Um die Effizienz zu verbessern, tauschen Sie die Dimensionen seq_len und num_heads aus. Einops können einfach wie folgt vervollständigt werden:

from einops import rearrange  query = rearrange(query, "b n h d -> b h n d") key = rearrange(key, "b s h d -> b h s d") value = rearrange(value, "b s h d -> b h s d")
Nach dem Login kopieren

Dann müssen wir „Gruppierung“ einführen das Abfragematrixkonzept.

from einops import rearrange query = rearrange(query, "b (h g) n d -> b g h n d", g=num_head_groups) print(query.shape) # torch.Size([1, 4, 2, 256, 64])
Nach dem Login kopieren

Mit dem obigen Code formen wir 2D in 2D um: Für den von uns definierten Tensor wird die ursprüngliche Dimension 8 (die Anzahl der Köpfe in der Abfrage) jetzt in zwei Gruppen aufgeteilt (um mit den Köpfen in den Schlüsseln übereinzustimmen). und Werteanzahl), jede Gruppengröße beträgt 4.

Der letzte und schwierigste Teil ist die Berechnung des Aufmerksamkeitswerts. Tatsächlich kann dies jedoch in einer Zeile durch die Insum-Operation erfolgen. Der

from einops import einsum, rearrange # g stands for the number of groups # h stands for the hidden dim # n and s are equal and stands for sequence length scores = einsum(query, key, "b g h n d, b h s d -> b h n s") print(scores.shape) # torch.Size([1, 2, 256, 256])
Nach dem Login kopieren

Scores-Tensor hat die gleiche Form wie der Wertetensor oben. Mal sehen, wie es funktioniert

einsum erledigt zwei Dinge für uns:

1 Eine Abfrage und Matrixmultiplikation von Schlüsseln. In unserem Fall sind die Formen dieser Tensoren (1,4,2,256,64) und (1,2,256,64), sodass die Matrixmultiplikation entlang der letzten beiden Dimensionen (1,4,2,256,256) ergibt.

2. Summieren Sie die Elemente in der zweiten Dimension (Dimension g) – wenn die Dimension in der angegebenen Ausgabeform weggelassen wird, schließt einsum diese Arbeit automatisch ab und diese Summierung wird verwendet, um Schlüssel und die Anzahl der Köpfe abzugleichen der Wert.

Beachten Sie abschließend die Standardmultiplikation von Brüchen und Werten:

import torch.nn.functional as F  scale = query.size(-1) ** 0.5 attention = F.softmax(similarity / scale, dim=-1)  # here we do just a standard matrix multiplication out = einsum(attention, value, "b h n s, b h s d -> b h n d")  # finally, just reshape back to the (batch_size, seq_len, num_kv_heads, hidden_dim) out = rearrange(out, "b h n d -> b n h d") print(out.shape) # torch.Size([1, 256, 2, 64])
Nach dem Login kopieren

Die einfachste GQA-Implementierung ist jetzt abgeschlossen und erfordert weniger als 16 Zeilen Python-Code:

Detaillierte Erläuterung der GQA, des in großen Modellen häufig verwendeten Aufmerksamkeitsmechanismus und der Pytorch-Codeimplementierung

Abschließend noch eine kurze Erwähnung von MQA: Multi-Query Attention (MQA) ist eine weitere beliebte Methode zur Vereinfachung von MHA. Alle Abfragen verwenden dieselben Schlüssel und Werte. Das schematische Diagramm sieht wie folgt aus:

Detaillierte Erläuterung der GQA, des in großen Modellen häufig verwendeten Aufmerksamkeitsmechanismus und der Pytorch-Codeimplementierung

Wie Sie sehen, können sowohl MQA als auch MHA aus GQA abgeleitet werden. GQA mit einem einzelnen Schlüssel und Wert entspricht MQA, während GQA mit Gruppen, die der Anzahl der Header entsprechen, MHA entspricht.

Was sind die Vorteile von GQA?

GQA ist ein guter Kompromiss zwischen bester Leistung (MQA) und bester Modellqualität (MHA).

Die folgende Abbildung zeigt, dass Sie mit GQA fast die gleiche Modellqualität wie mit MHA erzielen und gleichzeitig die Verarbeitungszeit um das Dreifache verlängern können, wodurch Sie die Leistung von MQA erreichen. Dies kann für Hochlastsysteme unerlässlich sein.

Detaillierte Erläuterung der GQA, des in großen Modellen häufig verwendeten Aufmerksamkeitsmechanismus und der Pytorch-Codeimplementierung

Es gibt keine offizielle Implementierung von GQA in Pytorch. Also habe ich eine bessere inoffizielle Implementierung gefunden. Wenn Sie interessiert sind, können Sie es versuchen:

https://www.php.cn/link/5b52e27a9d5bf294f5b593c4c071500e

GQA-Papier:

https ://www.php.cn/link/e4ba31fba036a999321d5460f7f2d1d1

Das obige ist der detaillierte Inhalt vonDetaillierte Erläuterung der GQA, des in großen Modellen häufig verwendeten Aufmerksamkeitsmechanismus und der Pytorch-Codeimplementierung. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Verwandte Etiketten:
Quelle:51cto.com
Erklärung dieser Website
Der Inhalt dieses Artikels wird freiwillig von Internetnutzern beigesteuert und das Urheberrecht liegt beim ursprünglichen Autor. Diese Website übernimmt keine entsprechende rechtliche Verantwortung. Wenn Sie Inhalte finden, bei denen der Verdacht eines Plagiats oder einer Rechtsverletzung besteht, wenden Sie sich bitte an admin@php.cn
Beliebte Tutorials
Mehr>
Neueste Downloads
Mehr>
Web-Effekte
Quellcode der Website
Website-Materialien
Frontend-Vorlage