Dans le domaine actuel de l'IA, l'architecture principale utilisée par les grands modèles de langage est Transformer. Cependant, avec l'avènement d'architectures telles que RWKV et Mamba, il existe une tendance évidente : les grands modèles de langage cycliques qui rivalisent avec Transformer en termes de perplexité en matière de modélisation du langage attirent rapidement l'attention des gens.
Ce qui est passionnant, c'est que ces architectures utilisent une quantité constante de mémoire lors de l'inférence. Cependant, en raison d’une mémoire limitée, les modèles de langage récurrents (LM) ne peuvent pas mémoriser et utiliser toutes les informations dans des contextes longs, ce qui conduit à une mauvaise qualité d’apprentissage en contexte (ICL). Par conséquent, un défi clé pour obtenir de grands modèles de langage efficaces consiste à choisir les informations à stocker ou à supprimer.
Dans l'article récent « Il suffit de lire deux fois : combler l'écart de rappel pour les modèles de langage récurrents », des chercheurs de l'Université de Stanford et de l'Université de Buffalo ont découvert par une simple observation que les données affluaient dans le pôle de tri des modèles de langage récurrents lors de l'inférence. affecte la difficulté de prédire quelles informations stocker dans une mémoire limitée.
Nous supposons que nous posons la question sur la base du document D (tel que le Wikipédia détaillé de Galileo Galilei) : Quand Galilée a-t-il déménagé à Florence ? À ce stade, si les invites suivent l'ordre [Q, D], le modèle n'a besoin de mémoriser qu'un seul fait dans le document D. En revanche, si les signaux suivent l’ordre [D, Q], le modèle doit mémoriser tous les faits. Ceci est illustré dans la figure 1 (à gauche) ci-dessous.
Par conséquent, cet article formalise d'abord théoriquement la façon dont le tri des données affecte les besoins en mémoire, puis propose deux méthodes pour atténuer la dépendance au tri des données, à savoir la Stratégie d'invite Just-read-twice (JRT) et la Boucle JRT architecture. Cet article est principalement divisé en parties suivantes :
Comprendre le rôle du tri des données. La première idée que les chercheurs ont acquise était que la difficulté du problème de la mémoire devrait être réduite au même niveau que la disjonction des ensembles (SD), qui est le problème le plus typique de la théorie de la complexité de la communication qui dure depuis des décennies. SD nécessite un algorithme de streaming (tel qu'un modèle récurrent) pour décider s'il faut retirer l'ensemble d'entrées fourni dans le contexte :
L'analyse théorique et les résultats expérimentaux montrent que le premier ensemble |A| domine la mémoire requise pour résoudre DAKOTA DU SUD. Un modèle causal nécessite de stocker tous les éléments de A pour les comparer avec les éléments de B. Cela montre que l'utilisation du « classement correct des données » dans le contexte (par exemple, en plaçant le plus petit min (|A|, |B|) en premier) aidera les modèles à mémoire limitée. De plus, on observe que les modèles avec une logique causale contextuelle peuvent résoudre le SD dans le plus petit espace (|A|, |B|) sans prendre en compte l'ordre des données.
La deuxième consiste à utiliser le tri "correct". Cet article propose une stratégie JRT-Prompt très simple qui répète les informations plusieurs fois en contexte avant que le modèle ne génère une réponse (illustré dans la figure 1 ci-dessus, à droite). Lors du deuxième cycle et des suivants, le modèle de langage est conditionné au contexte complet pour décider quelles informations stocker, évitant ainsi le problème de la « réforme » du tri des données.
Les résultats montrent que JRT-Prompt réalise une amélioration moyenne de 11,0 ± 1,3 points de pourcentage sur 16 modèles de langage récurrents existants et 6 tâches ICL, tandis que le débit est de 11,9 fois celui de FlashAttention-2 (longueur 32 Ko, taille du lot 16). Bien que JRT-Prompt augmente la longueur du contexte, il reste asymptotiquement plus efficace en termes de calcul et de mémoire que l'attention.
Au-delà du modèle causal. Cet article propose JRT-RNN, qui s'inspire de la conception simple de l'architecture codeur-décodeur Prefix-LM. La plupart des entrées d'apprentissage contextuel contiennent deux parties, à savoir les invites de saisie (contexte, instructions) et le texte généré par le modèle en sortie. Dans l'architecture Prefix-LM, le LM ne traite pas la région de repère selon une logique causale, mais décode de manière causale la sortie, où seule la perte de prédiction de jeton suivant standard est utilisée sur la région causale, et la perte sur la région non causale.
Cependant, malheureusement, la méthode de formation précédente du modèle Prefix-LM a obtenu un succès limité et a utilisé le squelette inefficace du Transformer. Par conséquent, cet article améliore la qualité et l'efficacité grâce à quelques changements simples, notamment l'amélioration de la perte d'entraînement et l'utilisation d'une formule d'attention linéaire appelée « Préfixe d'attention linéaire, PLA ». Les chercheurs ont découvert qu'en utilisant leur implémentation compatible IO, JRT-RNN peut fournir une amélioration moyenne de la qualité de 13,7 et 6,9 points de pourcentage aux paramètres de 360 m et 1,3b, respectivement, avec un débit de 19,2 fois supérieur à celui de FA2.
Adresse papier : https://arxiv.org/pdf/2407.05483
Page d'accueil du projet : https://github.com/HazyResearch/prefix-linear-attention
JRT- Aperçu de la méthode rapide
Les tâches d'apprentissage contextuel prennent (C, Q, Y) comme entrée, où C est une source de contexte (telle qu'un document ou un référentiel de code), Q est une question ou une demande adressée au modèle compte tenu du contexte et Y est la réponse. Pour l'apprentissage contextuel standard utilisant LM A autorégressif, le chercheur prend les entrées C et Q et évalue la sortie résultante Yˆ = A (C, Q) par rapport à l'achèvement correct Y .
JRT-Prompt est une méthode extrêmement simple qui répète les informations de l'invite (telles que des questions et des documents) dans leur contexte avant d'inviter le modèle à générer la réponse, telle que Yˆ = A (C, Q, C à droite dans Figure 1 ci-dessous), Q). Par conséquent, la deuxième fois que le contexte apparaît, le modèle décide quelles informations stocker en fonction du contexte complet.
De plus, JRT-Prompt peut être utilisé avec un LLM prêt à l'emploi. Les chercheurs ont évalué le LM suivant sur une série de tâches contextuelles gourmandes en mémoire sous des invites à échantillon nul :
Basé sur un LM pré-entraîné avec une taille de paramètre de 1,3 B, entraîné sur les 10 à 50 B de jetons de Pile ; LM pré-entraîné Mamba, la taille des paramètres est de 130M, 370M, 1,4B et 2,8B, formé sur les jetons 300B de Pile ;
LM pré-entraîné Gated Linear Attention, la taille des paramètres est de 1,3B et 2,7B, formé sur les jetons 100B de l'ensemble de données SlimPajama
Mamba-2 LM pré-entraîné, avec des tailles de paramètres de 130 M, 370 M, 1,3 B et 2,7 B, entraîné sur 300 B de jetons de Pile.
Les résultats sont présentés dans le tableau 1 ci-dessous. En augmentant la taille de l'état, le chercheur a constaté que la méthode JRT-Prompt apportait une amélioration moyenne des performances de 11,0 ± 1,3 points de pourcentage sur divers modèles et tâches utilisant ce modèle. La méthode surpasse en moyenne les modèles Transformer en utilisant des astuces standard.
JRT-RNN s'inspire des préfixes-LM, mais se concentre sur l'extension de la frontière de Pareto de l'espace de compromis qualité-efficacité. Pour améliorer la qualité, JRT-RNN utilise des mappages k_e et v_e distincts côté codeur et des mappages k_d et v_d côté décodeur. Bien que le modèle Prefix LM utilise des poids de mappage partagés pour les régions d'encodeur et de décodeur, nous avons constaté que l'utilisation de deux ensembles de mappages améliore la qualité. Pour améliorer l'efficacité, JRT-RNN utilise une attention linéaire non causale pour l'encodeur et une attention linéaire causale standard pour le décodeur. Les chercheurs l'appellent Prefix Linear Attention (PLA) (à droite sur la figure 1), et la formule est la suivante :
Objectif d'entraînement JRT-RNN. Les LM de préfixe ne calculent généralement pas les pertes dans les régions non causales, tandis que JRT-RNN combine la prédiction du prochain jeton avec un objectif de modélisation de langage masqué (MLM). Et pour l'objectif MLM supplémentaire, les chercheurs ont remplacé les jetons avec la proportion P de la région d'encodeur {u_1, ..., u_M} par un jeton [MASK] et ont mesuré la perte d'entropie croiséelors de la prédiction du jeton d'origine.
Les pertes sont les suivantes :Résultats expérimentaux
Dans l'expérience, les chercheurs ont évalué la qualité et l'efficacité du JRT-RNN sur les trois indicateurs suivants :
Qualité de l'apprentissage du contexte
Modélisation globale du langage
Génération
Comme le montre le tableau 2 ci-dessous, les chercheurs ont constaté que JRT-RNN est meilleur que la référence du décodeur uniquement (basée) lorsque les paramètres sont de 360 M. (30 milliards de jetons) La moyenne est de 13,7 points de pourcentage plus élevée et la moyenne est de 6,9 points de pourcentage plus élevée lorsque le paramètre est de 1,3 milliard (50 milliards de jetons). Dans le même temps, l'écart entre JRT-RNN et Transformer++ est réduit à 0,5 point de pourcentage et 1,9 point de pourcentage lorsque les paramètres sont respectivement de 360M et 1,3B. Dans le tableau 3 ci-dessous, les chercheurs comparent les performances de JRT-RNN avec des stratégies d'inférence similaires lorsque la longueur de pré-remplissage l est inférieure à la longueur de l'encodeur M. Compréhension globale du langage naturel Sur la base de recherches antérieures, les chercheurs ont divisé la perplexité en deux groupes : La mémoire associative "AR slice" comprend des jetons appelés "AR hits", qui nécessitent que le modèle suive La mémoire est effectuée séquentiellement pour prédire correctement le prochain jeton ; et "Autre tranche" contient les jetons restants (tels que les connaissances mémorisées). Pour la fréquence mémoire, JRT-RNN fonctionne bien en "AR slice". Pour les bigrammes qui sont rares pendant la formation (c'est-à-dire moins susceptibles d'être mémorisés dans les paramètres du modèle), la perplexité de JRT-RNN s'améliore par rapport à Based et Mamba, deux lignes de base de boucle causale fortes. Pour la distance de mémoire, dans la « tranche AR », l'écart entre JRT-RNN et la ligne de base du décodeur uniquement s'élargit à mesure que le nombre de bigrammes répétés dans le contexte augmente. Cela prouve en outre que JRT-RNN peut aider à accomplir des tâches de mémoire contextuelle plus longues. Fréquence sans mémoire. Pour les « autres tranches » de bigrammes sans mémoire qui sont rarement vues pendant l'entraînement, JRT-RNN a une perplexité pire que le LM avec décodeur uniquement. C'est un résultat attendu puisque JRT-RNN calcule la perte pour seulement 65% des tokens du décodeur LM. Nous nous attendons à ce que cet écart diminue avec l'échelle et le temps de formation (augmentant avec la fréquence des bigrammes) (Figure 3, en haut à gauche). Débit de génération La génération peut être décomposée en deux étapes : inviter le « traitement de pré-remplissage » et décoder la « prédiction du prochain jeton ». Par rapport au modèle de boucle standard avec décodeur uniquement, JRT-RNN ne modifie pas l'étape de décodage, la discussion se concentre donc sur l'étape de pré-remplissage. En utilisant le noyau CUDAn basé proposé dans l'article "Les modèles de langage d'attention linéaire simple équilibrent le compromis rappel-débit" par Simran Arora et al., le débit de JRT-Prompt lors du traitement du pré-remplissage est de 11,9 et 13,7 de FlashAttention-2 et noyaux FLA Triton respectivement, comme indiqué dans le tableau 5 ci-dessous. Lorsque les chercheurs ont augmenté la taille du lot à 64, le débit de JRT-Prompt était respectivement 6,1 fois et 7,2 fois supérieur à celui des noyaux FlashAttention-2 et FLA Triton. Ensuite, ils ont étendu le noyau basé pour prendre en charge JRT-RNN et ont démontré qu'en augmentant la longueur de la séquence à 32768, le débit était respectivement 19,2 fois et 22,0 fois celui de FlashAttention-2 et FLA. En augmentant la taille du lot à 64, JRT-RNN offre respectivement des améliorations de débit supplémentaires de 9,7x et 11,5x. Le temps requis par JRT-RNN est 1,24 fois supérieur à celui du pré-remplissage basé, ce qui est plus efficace que JRT-Prompt. Veuillez vous référer à l'article original pour plus de détails techniques et de résultats expérimentaux.
Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!