In today's AI landscape, the mainstream architecture for large language models is the Transformer. However, with the advent of architectures such as RWKV and Mamba, a clear trend has emerged: recurrent large language models that match the Transformer in language modeling perplexity are rapidly attracting attention.
What makes these architectures exciting is that they use a constant amount of memory during inference. However, because of this limited memory, recurrent language models (LMs) cannot remember and use all of the information in a long context, which hurts the quality of in-context learning (ICL). A key challenge in building efficient large language models is therefore choosing which information to store and which to discard.
In the recent paper "Just read twice: closing the recall gap for recurrent language models", researchers from Stanford University and the University at Buffalo made a simple observation: the order in which data is fed into a recurrent language model during inference greatly affects how hard it is to predict which information should be stored in its limited memory.
Suppose we ask a question based on a document D (such as the detailed Wikipedia article on Galileo Galilei): When did Galileo move to Florence? If the prompt follows the ordering [Q, D], the model only needs to remember one fact from document D. In contrast, if the prompt follows the ordering [D, Q], the model needs to remember all of the facts. This is shown in Figure 1 (left) below.
Therefore, the paper first theoretically formalizes how data ordering affects memory requirements, and then proposes two methods to reduce the dependence on data ordering: the Just-read-twice (JRT) prompting strategy and the JRT recurrent architecture (JRT-RNN). The article covers the following parts:
Understanding the role of data ordering. The researchers' first insight was that the hardness of this memorization problem reduces to set disjointness (SD), the canonical problem in communication complexity theory that has been studied for decades. SD requires a streaming algorithm (such as a recurrent model) to decide whether two input sets provided in the context are disjoint.
Theoretical analysis and experimental results show that the size of the first set, |A|, dominates the memory required to solve SD: a causal model must store all elements of A in order to compare them against the elements of B. This shows that using the "right data ordering" in context (e.g., placing the smaller of the two sets, of size min(|A|, |B|), first) helps memory-constrained models. Moreover, models that can process the context non-causally can solve SD in space min(|A|, |B|) regardless of the data ordering.
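To make the streaming intuition concrete, here is a small illustrative sketch (not from the paper; the function name and interface are hypothetical) of a one-pass algorithm for set disjointness. A causal, single-pass algorithm must buffer whichever set arrives first, so presenting the smaller set first minimizes the memory it needs.

```python
def disjoint_streaming(stream_a, stream_b):
    """One-pass ("causal") set disjointness: the first set must be stored in full,
    so memory scales with the size of whichever set is presented first."""
    seen = set()
    for a in stream_a:      # must buffer every element of the first set
        seen.add(a)
    for b in stream_b:      # the second set can be checked on the fly, element by element
        if b in seen:
            return False    # the sets intersect
    return True             # the sets are disjoint

# Presenting the smaller set first (e.g., disjoint_streaming(B, A) when |B| < |A|)
# reduces the memory requirement to min(|A|, |B|), mirroring the "right ordering" argument.
```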
Exploiting the "right" ordering. The paper proposes a very simple strategy, JRT-Prompt, which repeats the information in the context multiple times before the model generates an answer (shown in Figure 1, right). On the second and subsequent passes, the language model is conditioned on the full context when deciding what information to store, effectively sidestepping the problem of getting the data ordering "right".
The results show that JRT-Prompt yields an average improvement of 11.0 ± 1.3 percentage points across 16 existing recurrent language models and 6 ICL tasks, while delivering 11.9× the throughput of FlashAttention-2 (sequence length 32k, batch size 16). Although JRT-Prompt increases the context length, it remains asymptotically more compute- and memory-efficient than attention.
Going beyond causal models. The paper proposes JRT-RNN, inspired by the simple Prefix-LM encoder-decoder architecture. Most in-context learning inputs consist of two parts: the input prompt (context, instructions) and the text the model generates as output. In the Prefix-LM architecture, the LM does not process the prompt region causally but decodes the output causally, and the standard next-token prediction loss is applied only to the causal region, with no loss computed on the non-causal region.
Unfortunately, however, previous approaches to training Prefix-LM models have had limited success and rely on inefficient Transformer backbones. This paper therefore improves quality and efficiency through a few simple changes, including improving the training loss and using a linear attention formulation called Prefix Linear Attention (PLA). The researchers found that, with their IO-aware implementation, JRT-RNN provides average quality improvements of 13.7 and 6.9 percentage points at the 360M and 1.3B parameter scales, respectively, with 19.2× the throughput of FA2.
Paper address: https://arxiv.org/pdf/2407.05483
Project homepage: https://github.com/HazyResearch/prefix-linear-attention
JRT-Prompt Method Overview
An in-context learning task takes (C, Q, Y) as input, where C is some context source (such as a document or a code repository), Q is a question or request posed to the model given the context, and Y is the answer. For standard in-context learning with an autoregressive LM A, the researchers feed in C and Q and evaluate the generated output Yˆ = A(C, Q) against the correct completion Y.
JRT-Prompt is an extremely simple method that repeats the information in the prompt (such as the question and the document) in the context before asking the model to output an answer, e.g., Yˆ = A(C, Q, C, Q) as in Figure 1 (right). As a result, on the second occurrence of the context, the model decides which information to store based on the full context, as in the sketch below.
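As a concrete illustration, a JRT prompt can be assembled by simply concatenating the context and question twice before querying an off-the-shelf model. The helper below is a minimal sketch; the function name and the `generate` call are hypothetical and not part of the paper's released code.

```python
def build_jrt_prompt(context: str, question: str) -> str:
    """Repeat the context and question so the model reads them twice before
    answering, i.e., Y-hat = A(C, Q, C, Q) in the paper's notation."""
    return f"{context}\n{question}\n{context}\n{question}\n"

# Hypothetical usage with any off-the-shelf recurrent LM exposing a generate() method:
# answer = recurrent_lm.generate(build_jrt_prompt(galileo_wiki, "When did Galileo move to Florence?"))
```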
In addition, JRT-Prompt can be used with off-the-shelf LLMs. Under zero-shot prompting, the researchers evaluated the following LMs on a suite of memory-intensive in-context tasks:
Based pretrained LM, 1.3B parameters, trained on 10-50B tokens of the Pile;
Mamba pretrained LMs, 130M, 370M, 1.4B, and 2.8B parameters, trained on 300B tokens of the Pile;
Gated Linear Attention pretrained LM, 2.7B parameters, trained on 100B tokens of the SlimPajama dataset;
Mamba-2 pretrained LMs, 130M, 370M, 1.3B, and 2.7B parameters, trained on 300B tokens of the Pile.
The results are shown in Table 1 below. Across the different state sizes, the researchers found that JRT-Prompt delivers an average improvement of 11.0 ± 1.3 percentage points across models and tasks, and that the Based model with this method outperforms, on average, a Transformer model with standard prompting.
They also found that JRT-Prompt can benefit Transformer models, and that on some tasks (Appendix 2) the method is more effective than few-shot learning. Notably, Springer et al., in the paper "Repetition improves language model embeddings", proposed repeating the context with an autoregressive Transformer model for the purpose of generating embeddings; the findings here are similar. The researchers focus on sub-quadratic architectures and in-context learning tasks.
Although JRT-Prompt increases the context length due to repetition, the sub-quadratic recurrent architectures it uses are still more efficient than quadratic Transformer models. The researchers found that at sequence length N = 32768 and batch size 16, JRT-Prompt (sequence length 2N) delivers 11.9× the throughput of FlashAttention-2 (sequence length N) on an NVIDIA H100.
JRT-RNN: An Encoder-Decoder Recurrent Architecture
JRT-RNN is inspired by Prefix-LMs, but focuses on expanding the Pareto frontier of the quality-efficiency trade-off space. To improve quality, JRT-RNN uses separate k_e and v_e projections on the encoder side and k_d and v_d projections on the decoder side. Whereas Prefix-LM models share the projection weights between the encoder and decoder regions, the researchers found that using two separate sets of projections improves quality.
To improve efficiency, JRT-RNN uses non-causal linear attention for the encoder and standard causal linear attention for the decoder. The researchers call this Prefix Linear Attention (PLA) (Figure 1, right): each decoder position combines a non-causal sum over the encoder-side key/value projections with a causal sum over the decoder-side projections.
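Below is a minimal NumPy sketch of this idea, written for illustration rather than as the paper's implementation: it assumes a generic positive feature map, a single head, and omits the normalization term, so only the structure (a non-causal encoder state plus a causal decoder state) should be read as faithful.

```python
import numpy as np

def feature_map(x):
    # Hypothetical stand-in for the feature map phi used by linear attention;
    # the actual Based/PLA kernels use a specific (e.g., Taylor-based) map.
    return np.maximum(x, 0.0) + 1e-6

def prefix_linear_attention(q, k_e, v_e, k_d, v_d, M):
    """Sketch of Prefix Linear Attention (single head, no normalization).
    q, k_e, v_e, k_d, v_d: arrays of shape (T, d); M: length of the non-causal prefix."""
    T, d = q.shape
    phi_q = feature_map(q)

    # Non-causal encoder state: sums over the whole prefix, visible to every position.
    S_enc = feature_map(k_e[:M]).T @ v_e[:M]           # (d, d)

    out = np.zeros_like(v_d)
    S_dec = np.zeros((d, d))                           # causal decoder state
    for t in range(T):
        if t >= M:
            # Decoder region: accumulate state causally with decoder-side projections.
            S_dec += np.outer(feature_map(k_d[t]), v_d[t])
        out[t] = phi_q[t] @ (S_enc + S_dec)
    return out
```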
JRT-RNN training objective. Prefix-LMs typically compute no loss on the non-causal region, whereas JRT-RNN combines next-token prediction with a masked language modeling (MLM) objective. For the added MLM objective, the researchers replace a proportion P of the tokens in the encoder region {u_1, ..., u_M} with a [MASK] token and measure the cross-entropy loss when predicting the original tokens.
The combined loss is sketched below (see the paper for the exact form).
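Schematically, and consistent with the description above, the objective can be written as follows, where the relative weighting λ is an assumption made here for illustration:

$$
\mathcal{L}(\theta) \;=\; \underbrace{\sum_{t > M} -\log p_\theta\!\left(u_t \mid u_{<t}\right)}_{\text{next-token prediction on the decoder region}} \;+\; \lambda \, \underbrace{\sum_{i \in \mathcal{M}} -\log p_\theta\!\left(u_i \mid \tilde{u}_{1:M}\right)}_{\text{MLM cross-entropy on masked encoder tokens}}
$$

Here, $\mathcal{M}$ denotes the set of masked positions and $\tilde{u}_{1:M}$ the encoder region with a fraction P of its tokens replaced by [MASK].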
Experimental Results
In the experiments, the researchers evaluated the quality and efficiency of JRT-RNN along three axes: in-context learning quality, overall natural language understanding, and generation throughput.
As shown in Table 2 below, the researchers found that JRT-RNN outperforms the decoder-only baseline (Based) by an average of 13.7 percentage points at 360M parameters (30B tokens) and by an average of 6.9 percentage points at 1.3B parameters (50B tokens).
Meanwhile, JRT-RNN narrows the gap to Transformer++ to within 0.5 and 1.9 percentage points at 360M and 1.3B parameters, respectively.
In Table 3 below, the researchers compare JRT-RNN with comparable inference strategies when the prefill length l is smaller than the encoder length M.
Overall Natural Language Understanding
Following prior work, the researchers further split perplexity into two groups: the associative recall "AR slice" contains the tokens known as "AR hits", which require the model to recall information from the context in order to correctly predict the next token, while the "Other slice" contains the remaining tokens (e.g., memorized knowledge).
Recall frequency. JRT-RNN performs well on the "AR slice". For bigrams that are rarely seen during training (i.e., unlikely to be memorized in the model parameters), JRT-RNN improves perplexity relative to Based and Mamba, two strong causal recurrent baselines.
Recall distance. On the "AR slice", the gap between JRT-RNN and the decoder-only baselines widens as the number of repeated bigrams in the context grows. This further shows that JRT-RNN can help with longer-range in-context recall tasks.
Non-recall frequency. On the non-recall "Other slice" of bigrams that are rarely seen during training, JRT-RNN has worse perplexity than the decoder-only LMs. This is expected, since JRT-RNN computes the loss on only 65% of the tokens that the decoder-only LM does.
The researchers expect this gap to shrink with greater scale and longer training (and as bigram frequency increases) (Figure 3, top left).
Generation Throughput
Generation can be decomposed into two steps: prompt "prefill" and decoding ("next-token prediction"). Since JRT-RNN does not modify the decoding step relative to standard decoder-only recurrent models, the discussion focuses on the prefill stage.
Using the Based CUDA kernels proposed in Simran Arora et al.'s paper "Simple linear attention language models balance the recall-throughput tradeoff", JRT-Prompt achieves higher prefill throughput than both FlashAttention-2 and the FLA Triton kernels, as shown in Table 5 below.
When the researchers increase the batch size to 64, JRT-Prompt's throughput is 6.1× and 7.2× that of FlashAttention-2 and the FLA Triton kernels, respectively.
They then extended the Based kernels to support JRT-RNN and showed that, when the sequence length is increased to 32768, the throughput is 19.2× and 22.0× that of FlashAttention-2 and FLA, respectively. When the batch size is increased to 64, JRT-RNN again provides 9.7× and 11.5× higher throughput, respectively. JRT-RNN's prefill takes 1.24× as long as Based prefill, making it more efficient than JRT-Prompt.
For more technical details and experimental results, please refer to the original paper.