單GPU訓練一天,Transformer在100位數字加法上就達能到99%準確率

王林
發布: 2024-06-13 14:06:09
原創
752 人瀏覽過

乘法和排序也有效。

自 2017 年被提出以來,Transformer 已成為 AI 大模型的主流架構,一直穩站 C 位元。

然而,雖然所有研究者都不得不承認的是,Transformer 在算數任務中表現異常糟糕,儘管是加法,這一缺陷在很大程度上源於Transformer 無法跟踪大範圍數字中每個數字的準確位置。

為了解決這個問題,來自馬裡蘭大學、CMU等機構的研究者們向這個問題發起了挑戰。他們透過在每個數字中添加一個嵌入來解決這個問題,該嵌入編碼數字相對於開頭的位置。研究發現,只花一天時間在單一GPU上訓練20位數,就可以達到最新的表現水平,100位數數字加法問題高達99%的準確率。

單GPU訓練一天,Transformer在100位數字加法上就達能到99%準確率

論文網址:https://arxiv.org/pdf/2405.17399

專案網址:https://github.com/mcleish7/arithmetic

標題:Transformers Can Do Arithmetic with the Right Embeddings

具體而言,研究者建議對資料表顯示進行一個簡單的修改,就能解決這個缺點。他們提出了 Abacus 嵌入用於編碼每個數字符號 token 範圍內的位置。將Abacus 嵌入與標準位置嵌入結合使用後,研究觀察到Transformer 在算數任務上的準確率有顯著提高,以至於最多只訓練了20 位數操作數的模型可擴展到120 位數操作數的問題。這個數字代表了 6 倍的 SOTA 擴展因子,而前的最先進的擴展因子也只有 2.5 倍。據了解,這是迄今為止被證明的最長的學習加法序列。

除了研究優化Transformer在算術和泛化方面的表現之外,本文還探討了幾種其他方法來改善Transformer的表現。他們發現,透過在輸入註入(input injection)層和每個解碼器層之間插入跳躍連接,可以在Abacus嵌入基線上減少50%的泛化誤差。本文也發現,與嵌入結合使用的looped Transformer架構可以在加法問題上實現幾乎完美的泛化。

本文的貢獻可以總結如下:

  • 本文提出了一個新的位置嵌入,稱為Abacus 嵌入,以更好地捕捉每個數字的重要性,從而實現近乎完美的分佈內泛化;

  • 研究表明,當將Abacus 嵌入與輸入註入和looped transformer 相結合時,性能會進一步提高,分佈外準確率從92.9% 提高到99.1%,與單獨使用標準架構的嵌入相比,誤差降低了87%;

  • 研究者將這些發現擴展到更複雜的問題,包括乘法和排序,在這些領域也展現了長度泛化。

實現加法的長度泛化

作者研究了一系列方法,旨在提高從頭開始訓練的語言模型在算術能力上的表現。他們主要關注兩個假設:1)數字內各位數的位置資訊正在遺失;2)循環可以提高 Transformer 架構在多步驟算術推理問題上的推理能力。在詳細描述每項改進之前,作者簡要討論了訓練和評估設定。

實驗設定

作者訓練了僅包含解碼器的因果語言模型來解決加法問題。

他們考慮了兩種標準 transformer 架構。首先,他們使用標準的自回歸 transformer 模型,多個解碼器層以前饋方式堆疊。其次,他們透過輸入註入(input injection)增強了這個標準 transformer 模型,也就是把嵌入的輸入加入到每個解碼器層的輸入中。作者在圖 20 中直觀地描述了這些架構。

單GPU訓練一天,Transformer在100位數字加法上就達能到99%準確率

Abacus 嵌入幫助對齊數字

透過先前的研究和初步實驗,作者發現,即使輸入的數字是先顯示最不重要的數字,訓練資料是分層的、豐富的(幾百萬個例子),標準transformer 也很難學習多位數加法。他們也觀察到,人類在進行長加法運算時,會先將數位相同的數字排列成列。因此,作者的第一個假設是,對於 transformer 來說,每個數字的數字並不容易表示,而且這個子問題比實際加法本身帶來的障礙更大。

為了解決 transformer 在表示位置資訊方面的局限性,作者設計了一種特殊的位置嵌入,它可以編碼每個數字相對於當前數位起始位置的位置。作者將其稱之為 Abacus 嵌入。他們將相同的位置嵌入應用於所有具有相同數字的數字,從而提供一個明確的訊號,供模型用於對齊數字,如圖 2 所示。

單GPU訓練一天,Transformer在100位數字加法上就達能到99%準確率

Abacus 嵌入解決加法問題

對於標準 transformer 架構,Abacus 嵌入可將泛化效能提高到 100 位元及以上。在圖 3(左)中,作者強調了 Abacus 嵌入與標準 transformer 架構和嵌入相比,在進行加法運算時所具有的比較優勢,取三種模型在所有情況下的平均準確度。

單GPU訓練一天,Transformer在100位數字加法上就達能到99%準確率

圖 1 也顯示了使用 FIRE 和 Abacus 訓練的標準 transformer 模型的準確度結果,這些模型經過了域內 (ID) 和域外 (OOD) 測試。 單GPU訓練一天,Transformer在100位數字加法上就達能到99%準確率

Transformer 中的循環提高了效能

在解決位置嵌入問題後,接下來作者探討了循環架構能否進一步提高transformer 執行多位數加法的能力。他們使用「循環區塊(recurrent block)」一詞來指一組具有不同權重的解碼器層,而「循環(recurrence)」則指循環塊的重複次數。作者使用有效深度(effective depth)一詞來指稱 transformer 中使用的層數,無論其權重是否唯一。除非另有說明,否則他們使用的是最大循環架構,即只循環一個唯一層來達到有效深度。他們也採用了輸入註入、 殘差連接的方式,將輸入的副本傳播到網路中的每一層。

循環的優勢

在圖3(右)中,作者比較了使用FIRE 和NoPE 嵌入對操作數多達40 位的加法進行訓練的所有架構變體。儘管參數數量僅相當於其他模型的 1/10,但可以看到,looped transformer(循環的、有輸入註入和漸進損失)在使用任何一種位置嵌入時都取得了最佳的分佈外性能。在圖 8 中,作者展示了這項結果在多種訓練資料規模下的穩健性。

單GPU訓練一天,Transformer在100位數字加法上就達能到99%準確率

對於循環模型,可以選擇在訓練時改變每次前向傳遞的循環次數。這往往會提高模型測試時對較難任務的泛化能力,這也被稱為漸進式損失計算(progressive loss computation)。這個損失函數是兩個前向傳遞的損失值的凸組合,一個使用字面上的循環數(1 × 16 模型為 16),另一個使用隨機的較小循環數。

接下來,作者探討了在保持有效深度固定的同時改變循環區塊大小的效果。他們將循環區塊中的層數減半,循環次數增加一倍,從區塊中有16 層、循環次數只有一次(16 × 1,即標準transformer)的模型,過渡到區塊中只有一層、循環次數有16 次(1 × 16)的模型。

透過圖 4 分析這些結果,作者發現在某些情況下,結合循環和 Abacus 嵌入可以進一步提高效能。具體來說,在OOD 問題上,有兩個循環的模型(8 × 2)產生的誤差是純非循環模型(16 × 1)的一半,而在100 + 的OOD 問題上,其準確率也有所提高。

最後,在附錄 A.7.3 中,作者改變了模型的有效深度,以分析參數數量對這項任務的影響,包括 Abacus、FIRE 和 NoPE 嵌入。雖然圖 4 中的實驗是對不同深度的公平比較,但純粹的標準 transformer 模型比相應的循環模型擁有更多的參數。在附錄的表 3 中,作者記錄了最接近百萬的參數量。

單GPU訓練一天,Transformer在100位數字加法上就達能到99%準確率

單GPU訓練一天,Transformer在100位數字加法上就達能到99%準確率

實驗

#研究者不僅對加法問題進行了探討,也對乘法和排序進行了研究。

整數乘法

圖5 展示了Abacus 嵌入模型在15 位數乘法的分佈內準確率超過了先前的工作,且不需要用零將每個操作數填入相同長度。特別地,研究強調,與僅使用 FIRE 的基線相比,將 Abacus 嵌入與 FIRE 結合也提高了分佈問題中最難的分佈準確率 (右下)。

單GPU訓練一天,Transformer在100位數字加法上就達能到99%準確率

陣列排序

表 1 展示了使用不同嵌入 ——FIRE、Abacus 及其組合 —— 訓練的標準 transformer(八層)的表現。結果顯示,組合嵌入法增強了模型的泛化能力。

單GPU訓練一天,Transformer在100位數字加法上就達能到99%準確率

如表2 所示,研究者觀察到在將Abacus+FIRE 嵌入組合與不同的模型架構(有效深度為8)配對時,結果表現出混合性。

單GPU訓練一天,Transformer在100位數字加法上就達能到99%準確率

Abacus 和相關嵌入

#圖6 展示了將Abacus 嵌入整合到更通用系統中的真正潛力,顯示出Abacus 嵌入與FIRE 結合可以解鎖遠超FIRE 嵌入解決問題的能力。 

單GPU訓練一天,Transformer在100位數字加法上就達能到99%準確率

更多研究細節,請參考原文。

以上是單GPU訓練一天,Transformer在100位數字加法上就達能到99%準確率的詳細內容。更多資訊請關注PHP中文網其他相關文章!

相關標籤:
來源:jiqizhixin.com
本網站聲明
本文內容由網友自願投稿,版權歸原作者所有。本站不承擔相應的法律責任。如發現涉嫌抄襲或侵權的內容,請聯絡admin@php.cn
最新問題
熱門教學
更多>
最新下載
更多>
網站特效
網站源碼
網站素材
前端模板
關於我們 免責聲明 Sitemap
PHP中文網:公益線上PHP培訓,幫助PHP學習者快速成長!