域適應是解決遷移學習的重要方法,目前域適應當法依賴原域和目標域資料進行同步訓練。當源域資料不可得,同時目標域資料不完全可見時,測試階段訓練(Test- Time Training)成為新的域適應方法。目前針對 Test-Time Training(TTT)的研究廣泛利用了自監督學習、對比學習、自訓練等方法,然而,如何定義真實環境下的 TTT 卻被經常忽略,以至於不同方法間缺乏可比性。
近日,華南理工、A*STAR 團隊和鵬城實驗室聯合提出了針對TTT 問題的系統性分類準則,透過區分方法是否具備順序推理能力(Sequential Inference)和是否需要修改源域訓練目標,對目前方法做了詳細分類。同時,提出了基於目標域資料定錨聚類(Anchored Clustering)的方法,在多種TTT 分類下取得了最高的分類準確率,本文對TTT 的後續研究指明了正確的方向,避免了實驗設定混淆帶來的結果不可比問題。研究論文已被 NeurIPS 2022 接收。
深度學習的成功主要歸功於大量的標註資料和訓練集與測試集獨立同分佈的假設。在一般情況下,需要在合成資料上訓練,然後在真實資料上測試時,以上假設就沒辦法滿足,這也稱為域偏移。為了緩解這個問題,域適應 (Domain Adaptation, DA) 誕生了。現有的 DA 工作要么需要在訓練期間存取源域和目標域的數據,要么同時在多個域進行訓練。前者需要模型在做適應 (Adaptation) 訓練期間總是能存取到源域數據,而後者需要更昂貴的計算量。為了降低對源域資料的依賴,由於隱私問題或儲存開銷無法存取源域數據,無需源域資料的域適應 (Source-Free Domain Adaptation, SFDA) 解決無法存取源域資料的域適應問題。作者發現 SFDA 需要在整個目標資料集上訓練多個輪次才能達到收斂,在面對串流資料需要及時做出推斷預測的時候 SFDA 無法解決此類問題。這種面對串流資料需要及時適應並做出推斷預測的更現實的設定,被稱為測試時訓練 (Test-Time Training, TTT) 或測試時適應(Test-Time Adaptation, TTA)。
作者註意到在社群裡對 TTT 的定義存在混亂從而導致比較的不公平。論文以兩個關鍵的因素對現有的TTT 方法進行分類:
這篇論文的目標是解決最現實和最具挑戰性的 TTT 協議,即單輪適應並無需修改訓練損失方程式。這個設定類似於 TENT[1]提出的 TTA,但不限於使用來自源域的輕量級訊息,如特徵的統計量。鑑於 TTT 在測試時高效適應的目標,該假設在計算上是高效的,並大大提高了 TTT 的性能。作者將這個新的 TTT 協議命名為順序測試時訓練(sequential Test Time Training, sTTT)。
除了上述不同 TTT 方法的分類外,論文還提出了兩種技術讓 sTTT 更有效、更準確:
論文分了四部分來闡述所提出的方法,分別是1)介紹測試時訓練(TTT) 的錨定聚類模組,如圖1 中的Anchored Clustering 部分;2)介紹用於過濾偽標籤的一些策略,如圖1 中的Pseudo Label Filter 部分;3)不同於TTT [2]中的使用L2 距離來衡量兩個分佈的距離,作者使用了KL 散度來測量兩個全局特徵分佈間的距離;4)介紹在測試時訓練(TTT) 過程的特徵統計量的有效更新迭代方法。最後第五小節給出了整個演算法的過程程式碼。
第一部分在錨定聚類裡,作者首先使用混合高斯對目標域的特徵進行建模,其中每個高斯分量代表一個被發現的聚集。然後,作者使用來源域中每個類別的分佈作為目標域分佈的錨點來進行配對。透過這種方式,測試資料特徵可以同時形成集群,並且集群與來源域類別相關聯,從而達到了對目標域的推廣。概述來說就是,將源域和目標域的特徵分別根據類別資訊建模成:
#然後透過KL 散度度量兩個混合高斯分佈的距離,並透過減少KL 散度來達到兩個域特徵的匹配。可是,在兩個混合高斯分佈上直接求解 KL 散度並沒有閉式解,這導致了無法使用有效的梯度最佳化方法。在這篇論文中,作者在源域和目標域中分配相同數量的集群,每個目標域集群被分配給一個源域集群,這樣就可以將整個混合高斯的KL 散度求解變成了各對高斯之間的KL 散度總和。如下式:
上式的閉式解形式為:
在公式2 中,源域群集的參數可以線下收集完,而且由於只用到了輕量化統計數據,所以不會導致隱私洩漏問題且只使用了少量的計算和存儲開銷。對於目標域的變量,涉及了偽標籤的使用,作者為此設計了一套有效的且輕量的偽標籤過濾策略。
第二部分偽標籤篩選的策略主要分為兩部分:
1)時序上一致性預測的篩選:
2)根據後驗機率的篩選:
最後,使用篩選後的樣本來求解目標域群集的統計量:
#第三部分 由於在錨定聚類中,部分被濾除的樣本並沒有參與目標域的估計。作者也對所有測試樣本進行全域特徵對齊,類似錨定聚類中對集群的做法,這裡將所有樣本看作一個整體的集群,在源域和目標域分別定義
#然後再次以最小化KL 散度為目標對齊全域特徵分佈:
第四部分以上三部分都在介紹一些域對齊的手段,但在TTT 過程中,想要估計一個目標域的分佈是不簡單的,因為我們無法觀測整個目標域的資料。在前沿的工作中,TTT [2]使用了一個特徵隊列來儲存過去的部分樣本,來計算一個局部分佈來估計整體分佈。但這樣不但帶來了記憶體開銷也導致了精度與記憶體之間的 trade off。在這篇論文中,作者提出了迭代更新統計量的方式來緩解記憶體開銷。具體的迭代更新式子如下:
總的來說,整個演算法如下演算法1 所示:
如同引言部分所說,這篇論文中作者非常注重不同TTT 策略下的不同方法的公平比較。作者將所有TTT 方法根據以下兩個關鍵因素來分類:1)是否單輪適應協議(One-Pass Adaptation) 和2)修改源域的訓練損失方程,分別記為Y/N 表示需要或不需要修改源域訓練方程,O/M 表示單輪適應或多輪適應。除此之外,作者在 6 個基準的資料集上進行了充分的比較實驗和一些進一步的分析。
如表一所示,TTT [2]同時出現在了N-O 和Y-O 的協定下,是因為TTT [2]擁有一個額外的自監督分支,我們在N-O協議下將不添加自監督分支的損失,而在Y-O 下可以正常使用此分子的損失。 TTAC 在 Y-O 下也是使用了跟 TTT [2]一樣的自監督分支。從表中可以看到,在所有的 TTT 協定下所有資料集下,TTAC 均取得到最優的結果;在 CIFAR10-C 和 CIFAR100-C 資料集上,TTAC 都取得了 3% 以上的提升。從表 2 - 表 5 分別是 ImageNet-C、CIFAR10.1、VisDA 上的數據,TTAC 均取到了最優的結果。
#此外,作者在多個TTT 協議下同時做了嚴格的消融實驗,清楚地看出了每個部件的作用,如表6 所示。首先從L2 Dist 和KLD 的對比中,可以看出使用KL 散度來衡量兩個分佈具有更優的效果;其次,發現如果單單使用Anchored Clustering 或單獨使用偽標籤監督提升只有14%,但如果結合了Anchored Cluster 和Pseudo Label Filter 就可以看到效能顯著提高29.15% -> 11.33%。這也可以看出每個部件的必要性和有效的結合。
Finally, the author fully analyzes TTAC from five dimensions at the end of the text, namely the cumulative performance under sTTT (N-O) and the TSNE visualization of TTAC features. , source domain independent TTT analysis, analysis of test sample queues and update rounds, computational overhead measured in wall-clock time. There are more interesting proofs and analyzes shown in the appendix of the article.
This article only briefly introduces the contribution points of this work of TTAC: classification and comparison of existing TTT methods, proposed methods, and various Experiments under TTT protocol classification. There will be more detailed discussion and analysis in the paper and appendix. We hope that this work can provide a fair benchmark for TTT methods and that future studies should compare within their respective protocols.
以上是如何正確定義測試階段訓練?順序推理和域適應聚類方法的詳細內容。更多資訊請關注PHP中文網其他相關文章!