尽管取得了很多显著的成就,但训练深度神经网络(DNN)的实践进展在很大程度上独立于理论依据。大多数成功的现代 DNN 依赖残差连接和归一化层的特定排列,但如何在新架构中使用这些组件的一般原则仍然未知,并且它们在现有架构中的作用也依然未能完全搞清楚。
残差架构是最流行和成功的,最初是在卷积神经网络(CNN)的背景下开发的,后来自注意力网络中产生了无处不在的 transformer 架构。残差架构之所以取得成功,一种原因是与普通 DNN 相比具有更好的信号传播能力,其中信号传播指的是几何信息通过 DNN 层的传输,并由内核函数表示。
最近,使用信号传播原则来训练更深度的 DNN 并且残差架构中没有残差连接和 / 或归一化层的参与,成为了社区感兴趣的领域。原因有两个:首先验证了残差架构有效性的信号传播假设,从而阐明对 DNN 可解释性的理解;其次这可能会实现超越残差范式的 DNN 可训练性的一般原则和方法。
对于 CNN,Xiao et al. (2018)的工作表明,通过更好初始化提升的信号传播能够高效地训练普通深度网络,尽管与残差网络比速度显著降低。Martens et al. (2021) 的工作提出了 Deep Kernel Shaping (DKS),使用激活函数转换来控制信号传播,使用 K-FAC 等强二阶优化器在 ImageNet 上实现了普通网络和残差网络的训练速度相等。Zhang et al. (2022) 的工作将 DKS 扩展到了更大类的激活函数,在泛化方面也实现了接近相等。
信号传播中需要分析的关键量是 DNN 的初始化时间内核,或者更准确地说,是无限宽度限制下的近似内核。对于多层感知机(MLP)以及使用 Delta 初始化的 CNN,该内核可以编写为仅包含 2D 函数的简单层递归,以便于进行直接分析。跨层 transformer 的内核演化更加复杂,因此 DKS 等现有方法不适用 transformer 或实际上任何包含自注意力层的架构。
在 MLP 中,信号传播是通过查看(一维)内核的行为来判断的,而 transformer 中的信号传播可以通过查看(高维)内核矩阵在网络层中的演化来判断。
该研究必须避免一种情况:对角线元素随深度增加快速增长或收缩,这与不受控制的激活范数有关,可能导致饱和损失或数值问题。避免秩崩溃(rank collapse)对于深度 transformer 的可训练性是必要的,而是否可以训练深度无残差 transformer 仍是一个悬而未决的问题。
ICLR 2023 盲审阶段的这篇论文解决了这个问题,首次证明了无需残差连接或归一化层时也可能成功训练深度 transformer。为此,他们研究了深度无残差 transformer 中的信号传播和秩崩溃问题,并推导出三种方法来阻止它们。具体而言,方法中使用了以下组合:参数初始化、偏置矩阵和位置相关的重缩放,并强调了 transformer 中信号传播特有的几种复杂性,包括与位置编码和因果掩蔽的交互。研究者实证证明了他们的方法可以生成可训练的深度无残差 transformer。
在实验部分,在 WikiText-103 和 C4 数据集上,研究者展示了使用他们主要的方法——指数信号保持注意力(Exponential Signal Preserving Attention, E-SPA),可以通过延长大约五倍的训练时间使得标准 transformer 与文中无残差 transformer 的训练损失相当。此外通过将这一方法与残差连接结合,研究者还表明无归一化层的 transformer 能够实现与标准 transformer 相当的训练速度。
论文地址:https://openreview.net/pdf?id=NPrsUQgMjKK
对于这篇论文,Google AI 首席工程师 Rohan Anil 认为是 Transformer 架构向前迈出的一大步,还是一个基础性的改进。
#迄今為止,修正Transformer 秩崩潰(rank collapse)的唯一策略依賴於殘差連接,該方式跳過了自註意力層固有的可訓練性問題。與此相反,該研究直接解決這個問題。首先透過注意力層更能理解訊號傳播,然後根據見解(insights)進行修改,以在深度 transformer 中實現對忠實訊號的傳輸,無論是否使用殘差連接,都可以對訊號進行訓練。
具體而言,首先,該研究對僅存在註意力的深度vanilla transformer 進行了一下簡單設置,之後他們假設該transformer 具有單一頭(h = 1)設置或具有多頭設置,其中註意力矩陣A 在不同頭之間不會變化。如果區塊l≤L 初始化時有註意力矩陣A_l,則最終區塊的表示形式為X_L:
對於上式而言,如果和採用正交初始化,那麼就可以在初始化時正交。
在上述假設下,如果採用表示跨位置輸入核矩陣,經過一些簡化處理後,可以得到以下公式:
從這個簡化公式(深度僅注意力transformer 中的核矩陣)中,可以確定(A_l)_l 的三個要求:
在接下來的3.1 和3.2 節中,研究專注於尋找滿足上述需求的注意力矩陣,他們提出了3 種方法E-SPA、U- SPA 和Value-Skipinit,每種方法都用來控制transformer 的注意力矩陣,即使在很深的深度也能實現忠實的訊號傳播。此外,3.3 節示範如何修改 softmax 注意力以實現這些注意力矩陣。
下圖中,該研究對提出的兩個SPA 方案進行了驗證,U-SPA 和E-SPA,結果顯示即使在網路較深時也能成功地避免僅注意vanilla transformers 中的秩崩潰現象。
#WikiText-103 基線:首先,該研究驗證了沒有殘差連接的標準深度transformer 是不可訓練的,即使它們有歸一化層(LN) 和transformed 激活,但本文的方法可以解決這個問題。如圖 2 所示,可以清楚地看到,從標準 transformer 中移除殘差連接使其不可訓練,訓練損失穩定在 7.5 左右。如圖 1 所示,標準 transformer 遭受了秩崩潰。
另一方面,研究提出的 E-SPA 方法優於 U-SPA 和 Value-Skipinit。然而,與本文無殘差方法相比,具有殘差和 LN 的預設 transformer 仍然保持訓練速度優勢。
在表 1 中,研究使用提出的方法評估了 MLP 區塊中不同活化函數的影響,以及 LN 在無殘差 transformer 的使用。可以看到在深度為 36 處,本文方法針對一系列激活實現了良好的訓練性能:DKS-transformed GeLU、TAT-transformed Leaky ReLU 以及 untransformed GeLU ,但不是 untransformed Sigmoid。透過實驗也看到,層歸一化對於訓練速度而言相對不重要,甚至在使用 SPA 時對 transformed activation 的活化有害,因為 SPA 已經具有控制激活規範的內建機制。
在圖3 中,我們看到一個不需要更多迭代就能匹配預設transformer 訓練損失的方法是使用歸一化殘差連接。
表 2 顯示有歸一化殘差和 LN 的 E-SPA 優於預設的 PreLN transformer。
下圖4(a)顯示E-SPA 再次優於其他方法;4(b)顯示訓練損失差距可以透過簡單地增加訓練時間來消除。
#以上是ICLR盲審階段就被評審讚不絕口的論文:會是Transformer架構的一大創新嗎?的詳細內容。更多資訊請關注PHP中文網其他相關文章!