テキストを入力する前に、まず ChatGPT などの Transformer 言語モデル (LM) のプロンプトを検討してください:
毎日何百万ものユーザーとクエリが生成されるため、ChatGPT はセルフアテンション メカニズムを使用してプロンプトを繰り返しエンコードし、その時間とメモリの複雑さは入力長に応じて二次関数的に増大します。プロンプトのトランスフォーマーのアクティベーションをキャッシュすると、部分的な再計算が防止されますが、この戦略でも、キャッシュされたプロンプトの数が増加するにつれて、かなりのメモリとストレージのコストが発生します。大規模な場合には、プロンプトの長さがわずかに短縮されただけでも、計算量、メモリ、ストレージが節約されると同時に、ユーザーが LM の限られたコンテキスト ウィンドウにより多くのコンテンツを収めることができるようになります。 ############それで。プロンプトのコストを削減するにはどうすればよいですか?典型的なアプローチは、おそらくパラメーター効率の高い適応手法を使用して、プロンプトなしで元のモデルと同様に動作するようにモデルを微調整または抽出することです。ただし、このアプローチの根本的な欠点は、新しいプロンプトが表示されるたびにモデルを再トレーニングする必要があることです (以下の図 1 の中央に示されています)。
この記事では、スタンフォード大学の研究者が要点モデルを提案しました (図 1 の下)これは、プレフィックスの微調整と同様に、プロンプトをより小さな仮想「Gist」トークンのセットに圧縮します。ただし、プレフィックスの微調整には勾配降下法による各タスクのプレフィックスの学習が必要ですが、Gisting ではメタ学習手法を使用して、各タスクのプレフィックスを学習せずにプロンプトのみで Gist プレフィックスを予測します。これにより、タスクごとのプレフィックス学習のコストが償却され、追加のトレーニングなしで未知の命令への一般化が可能になります。
さらに、「Gist」トークンは完全なプロンプトよりもはるかに短いため、Gisting を使用するとプロンプトを圧縮、キャッシュし、再利用して計算効率を向上させることができます。
#論文アドレス: https://arxiv.org/pdf/2304.08467 v1.pdf
研究者は、命令が従う要点モデルを学習するための非常に簡単な方法を提案しました。単に命令を微調整し、プロンプトの後に gish トークンを挿入し、変更後アテンション マスクは、要点トークンの後のトークンが要点トークンの前のトークンを参照するのを防ぎます。これにより、モデルは追加のトレーニング コストをかけずに、即時圧縮と次の命令を同時に学習することができます。
デコーダー専用 (LLaMA-7B) およびエンコーダー/デコーダー (FLAN-T5-XXL) LM では、Gisting は元のモデルと同じパフォーマンスを維持しながら最大 26 倍の即時圧縮を達成します。同様の出力品質。これにより、従来のプロンプト キャッシュ方法と比較して、推論中の FLOP が 40% 削減され、レイテンシが 4.2% 加速され、ストレージ コストが大幅に削減されます。
ギスティング研究者らはまず、指導の微調整という文脈でのギスティングについて説明します。データセット
に続く命令の場合、t は自然言語プロンプトでエンコードされたタスク (例: これをフランス語に翻訳) を表し、x はタスクの (オプションの) 入力 (例: 猫) を表します。 y は、必要な出力 (例: Le chat) を表します。命令微調整の目的は、t と x を連結して分布 pLM(y | t,x) を学習し、通常は事前トレーニングされた言語モデルに y を自己回帰的に予測させることです。推論中、新しいタスク t と入力 x を使用して、予測結果を取得するためのプロンプトとモデルからのデコードを行うことができます。ただし、t と x を接続するこのパターンには欠点があります。Transformer ベースの LM のコンテキスト ウィンドウは限られており、アーキテクチャまたはコンピューティング能力によって制限されます。後者は、自己注意が入力長に応じて二次関数的に変化するため、解決するのが特に困難です。したがって、非常に長いプロンプト、特に繰り返し再利用されるプロンプトは、計算効率が低くなります。プロンプトのコストを削減するにはどのようなオプションが利用できますか?
簡単なアプローチは、特定のタスク t に対して LM 微調整を実行することです。つまり、タスク t の下でのみ入出力の例を含むデータセットが与えられた場合、具体的には次のことを学ぶことができます。 、tについて考える必要がないので高速です。
さらに優れた、プレフィックス/プロンプト微調整やアダプターなどのパラメーター効率の高い微調整方法を使用すると、本格的な微調整よりもはるかに低いコストで同じ目標を達成できます。ただし、問題は残ります。各タスクのモデルの重みの少なくとも一部を保存する必要があり、さらに重要なことに、各タスク t について、対応する入出力ペアのデータセット D^t を収集し、モデルを再トレーニングする必要があります。
Gisting は、2 つのコストを償却する別のアプローチです: (1) t で p_LM を条件付けする推論時間コスト、(2) 各 t の学習 新しい p^t_LM のトレーニング時間コスト。このアイデアは、微調整中に t G (t) の圧縮バージョンを学習し、p_G (y | G (t),x) からの推論が p_LM (y|t,x) からの推論よりも高速になるようにすることです。
LM の用語では、G (t) は「仮想」Gist トークンのセットになります。これは t のトークンより数が少ないですが、それでも LM で同様の問題を引き起こします。行動。 G (t) 上のトランスフォーマーのアクティベーション (キーと値の行列など) をキャッシュして再利用することで、計算効率を向上させることができます。重要なのは、研究者らは G が目に見えないタスクに一般化できることを期待していることです。つまり、新しいタスク t が与えられると、追加のトレーニングなしで、対応する Gist 活性化 G(t) を予測して使用できるようになります。
上記では Gisting の一般的なフレームワークについて説明しました。次に、そのようなモデルを学習する非常に簡単な方法を検討します。LM 自体を使用します。要点予測子 G として。これにより、LM の既存の知識を活用するだけでなく、標準的な命令の微調整を実行し、Transformer アテンション マスクを変更してプロンプト圧縮を強化するだけで、要点を学習することもできます。つまり、Gisting には追加のトレーニング費用は発生せず、標準的な指示に基づいて微調整するだけで済みます。
具体的には、このようなモデルで一般的な文の開始/終了トークンと同様に、特別な gist トークンをモデルの語彙と埋め込み行列に追加します。次に、指定された (タスク、入力) タプル (t, x) に対して、(t, g_1, ..., g_k, x) 内の k 個の連続する要点トークンのセットを使用して、t と x を連結します。例: 。このシーケンスは、Gist トークンに続く入力トークンが前のプロンプト トークンを参照できない (ただし、Gist トークンは参照できる) という制限付きでモデルに入力されます。これにより、入力 x (出力 y) はプロンプト t を処理できないため、モデルはプロンプト内の情報を gist トークンに強制的に圧縮します。
#下の図 2 は、必要な変更を示しています。 GPT-3 や LLaMA などのデコーダ専用 LM の場合、通常は自己回帰因果的注意マスクを使用するため、図 2a に示す三角形の左下隅をマスクするだけで済みます。双方向エンコーダと自己回帰デコーダを備えたエンコーダ/デコーダ LM の場合、2 つの修正が必要です (図 2b を参照)。
まず、エンコーダー内のプロンプト トークン t を参照して入力トークン x をブロックします。エンコーダーには通常マスクがありません。ただし、プロンプト t と要点トークン g_i が入力トークン x を参照しないようにすることも必要です。そうしないと、エンコーダーは入力に応じて異なる要点表現を学習することになります。最後に、デコーダは、デコーダがプロンプト トークン t を参照するのを防ぐ必要があるクロスアテンション期間を除いて、通常どおり動作します。
実験結果
Gist トークンの数が異なる場合、LLaMA- 7B と FLAN-T5-XXL の ROUGE-L と ChatGPT の評価結果を以下の図 3 に示します。
# モデルは通常、gist トークンの数 k の影響を受けません。プロンプトを単一のトークンに圧縮しても、パフォーマンスが大幅に低下することはありません。実際、場合によっては、Gist トークンが多すぎるとパフォーマンスが低下することがあります (LLaMA-7B、Gist トークン 10 個など)。これはおそらく、容量の増加がトレーニング分布にオーバーフィットするためです。したがって、研究者らは、単一トークン モデルの具体的な値を以下の表 1 に示し、残りの実験では単一の要点モデルを使用します。
表示された手順では、Gist モデルは対応する肯定的な結果とほぼ同じ結果を取得しました。コントロールモデル ROUGE と ChatGPT のパフォーマンスが同じ場合、LLaMA-7B FLANT5-XXL の勝率はそれぞれ 48.6% と 50.8% です。ここで研究者が最も興味を持っているのは、目に見えないタスクに対する一般化能力であり、これは他の 2 つのデータセットを通じて測定する必要があります。 Alpaca トレーニング データセットの目に見えないプロンプトでは、Gist モデルが目に見えないプロンプトに対して強力な汎化能力を持っていることがわかります。対照グループと比較して、49.7% (LLaMA) )、勝率は 46.2% (FLAN-T5) でした。最も困難な OOD Human スプリットでは、Gist モデルの勝率はわずかに低下し、45.8% (LLaMA) と 42.5% (FLANT5) になります。 この記事の目的は、Gist モデルに元のモデルの機能を厳密に模倣させることです。そのため、Gist モデルがコントロール グループと正確に区別できなくなるのはいつなのかと疑問に思う人もいるかもしれません。以下の図 4 は、これがどのくらいの頻度で起こるかを示しています。目に見えるタスク (ただし目に見えない入力) については、要点モデルはほぼ半分の時間で対照グループと同等です。目に見えないタスクの場合、この数値は 20 ~ 25% に低下します。 OOD Human タスクの場合、この数値は 10% に戻ります。いずれにせよ、Gist モデルの出力の品質は非常に高いです。
全体的に、これらの結果は、Gist モデルがプロンプトを確実に圧縮できることを示しています。トレーニング配布外の特定のプロンプト、特に LLaMA のようなデコーダーのみの因果 LM に対しても実行されます。 FLAN-T5 などのエンコーダ-デコーダ モデルのパフォーマンスはわずかに劣ります。考えられる理由の 1 つは、要点マスクがエンコーダの双方向のアテンション フローを抑制するためであり、これは自己回帰デコーダで履歴の一部を単にマスクするよりも困難です。今後この仮説を調査するにはさらなる研究が必要です。 最後に、この作業の核となる動機の 1 つに戻ります。Gisting はどのような効率向上をもたらすのでしょうか? 以下の表 2 は、PyTorch 2.0 アナライザーを使用したモデルの単一の前方パス (つまり、単一の入力トークンを使用した自己回帰デコードの 1 ステップ) と Human eval の結果を示しています。分割された 252 個の命令が平均化されます。 Gist キャッシュにより、最適化されていないモデルと比較して効率が大幅に向上します。両方のモデルで、FLOP の 40% の節約とクロック時間の 4 ~ 7% の削減が達成されました。
ただし、より重要なのは、命令キャッシュと比較して、Gist キャッシュにはレイテンシーがあることです。その他の主な利点: 26 個のトークンを 1 つに圧縮すると、絶対位置の埋め込みや GPU VRAM によって制限される入力コンテキスト ウィンドウのスペースをさらに解放できます。特に LLaMA-7B の場合、KV キャッシュ内の各トークンには 1.05MB の記憶域が必要です。 KV キャッシュは、テストされたプロンプトの長さでの LLaMA-7B 推論に必要な総メモリに比べればほとんど寄与しませんが、開発者が多数のユーザーにわたって多くのプロンプトをキャッシュするシナリオがますます一般的になり、ストレージ コストが急速に増加する可能性があります。同じ記憶領域で、要点キャッシュは完全な命令キャッシュよりも 26 倍多くのプロンプトを処理できます。 コンピューティング、メモリ、およびストレージの効率
以上が26 個のトークンを 1 つの新しいメソッドに圧縮して、ChatGPT 入力ボックスのスペースを節約しますの詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。