目錄
01 知識蒸餾誕生的背景
02 Knowledge Distillation
簡介 
方法詳解 
03 FitNet 
04 總結
作者簡介 
首頁 科技週邊 人工智慧 大模型精準反哺小模型,知識蒸餾可協助提升 AI 演算法效能

大模型精準反哺小模型,知識蒸餾可協助提升 AI 演算法效能

Apr 08, 2023 pm 08:01 PM
電腦 ai 神經網路

01 知識蒸餾誕生的背景

來,深度神經網路(DNN)在工業界和學術界都取得了巨大成功,尤其是在 電腦視覺任務 方面。深度學習的成功很大程度上歸功於其具有數十億參數的用於編碼資料的可擴展性架構,其訓練目標是在已有的訓練資料集上建模輸入和輸出之間的關係,其效能高度依賴網路的複雜程度及有標註訓練資料的數量和品質。

相較於電腦視覺領域的傳統演算法,大多數基於DNN 的模型都因為 過參數化 而具備強大的 泛化能力 ,這種泛化能力體現在對於某個問題輸入的所有數據上,模型能給出較好的預測結果,無論是訓練數據、測試數據,或是屬於該問題的未知數據。

在當前深度學習的背景下,演算法工程師為了提升業務演算法的預測效果,常常會有兩種方案:

使用過參數化的更複雜的網絡,這類網絡學習能力非常強,但需要大量的運算資源來訓練,而且推理速度較慢。

整合模型,將許多效果較弱的模型整合起來,通常包括參數的整合和結果的整合。

這兩種方案能顯著提升現有演算法的效果,但都提升了模型的規模,產生了較大的運算負擔,所需的運算和儲存資源很大。

在工作中,各種演算法模型的最終目的都是要 服務某個應用 。就像在買賣中我們需要控制收入和支出一樣。在工業應用中,除了要求模型要有好的預測以外, 計算資源的使用也要嚴格控制,不能只考慮結果不考慮效率。在輸入資料編碼量高的電腦視覺領域,運算資源較顯有限,控制演算法的資源佔用就更為重要。

通常來說,規模較大的模型預測效果更好,但訓練時間長、推理速度慢的問題使得模型難以即時部署。尤其是在視訊監控、自動駕駛汽車和高吞吐量雲端環境等運算資源有限的設備上,響應速度顯然不夠用。規模較小的模型雖然推理速度較快,但是因為參數量不足,推理效果和泛化表現可能就沒那麼好。如何權衡大規模模型和小規模模型一直是個熱門話題,目前的解決方法大多是 根據部署環境的終端設備效能選擇合適規模的 DNN 模型。

如果我們希望有一個規模較小的模型,能在保持較快推理速度的前提下,達到和大模型相當或接近的效果該如何做到呢?

在機器學習中,我們常常假定輸入到輸出有一個潛在的映射函數關係,從頭學習一個新模型就是輸入資料和對應標籤中一個 近似 未知的映射函數。在輸入資料不變的前提下,從頭訓練一個小模型,從經驗上來看很難接近大模型的效果。為了提升小模型演算法的效能,一般來說最有效的方式是標註更多的輸入數據,也就是提供更多的監督信息,這可以讓學習到的映射函數更魯棒,性能更好。舉兩個例子,在電腦視覺領域中,實例分割任務透過額外提供掩膜訊息,可以提高目標包圍框檢測的效果;遷移學習任務透過提供在更大資料集上的預訓練模型,顯著提升新任務的預測效果。因此 提供更多的監督資訊 ,可能是縮短小規模模型和大規模模式差距的關鍵。

依照先前的說法,想要獲取更多的監督資訊意味著標註更多的訓練數據,這往往需要龐大的成本,那麼有沒有一種低成本又高效的監督資訊取得方法呢? 2006 年的文獻[1]中指出,可以讓新模型近似(approximate)原始模型(模型即函數)。因為原模型的函數是已知的,新模型訓練時等於天然地增加了更多的監督訊息,這顯然要更可行。

進一步思考,原模型帶來的監督資訊可能蘊含著不同維度的知識,這些與眾不同的資訊可能是新模型自己無法捕捉到的,在某種程度上來說,這對於新模型也是一種「跨域」的學習。

2015年Hinton在論文《Distilling the Knowledge in a Neural Network》[2] 中沿用近似的思想,率先提出「 知識蒸餾 (Knowledge Distillation, KD)」的概念:可以先訓練出一個大而強的模型,然後將其包含的知識轉移給小的模型,就實現了“保持小模型較快推理速度的同時,達到和大模型相當或接近的效果”的目的。這其中先訓練的大模型可以稱為教師模型,後訓練的小模型則稱為學生模型,整個訓練過程可以形像地比喻為「師生學習」。隨後幾年,湧現了大量的知識蒸餾與師生學習的工作,為工業界提供了更多新的解決想法。目前,KD 已廣泛應用於兩個不同的領域:模型壓縮和知識遷移[3]。

大模型精準反哺小模型,知識蒸餾可協助提升 AI 演算法效能


02 Knowledge Distillation

簡介 

Knowledge Distillation 是一種基於「教師-學生網路」思想的模型壓縮方法,由於簡單有效,在工業界被廣泛應用。其目的是將已經訓練好的大模型所包含的知識-蒸餾(Distill),提取到另一個小的模型中去。那怎麼讓大模型的知識,或者說泛化能力轉移到小模型身上呢? KD 論文把大模型對樣本輸出的機率向量當作軟目標(soft targets)提供給小模型,讓小模型的輸出盡量去向這個軟目標靠(原來是往one-hot 編碼上靠),去近似學習大模型的行為。

在傳統的硬標籤訓練過程中,所有負標籤都被統一對待,但這種方式把類別間的關係割裂開了。比如說辨識手寫數字,同是標籤為“3”的圖片,可能有的比較像“8”,有的比較像“2”,硬標籤區分不出來這個信息,但是一個訓練良好的大模型可以給出。大模型 softmax 層的輸出,除了正例之外,負標籤也帶有大量的訊息,例如某些負標籤對應的機率遠大於其他負標籤。近似學習此行為使得每個樣本帶給學生網路的資訊量大於傳統的訓練方式。

因此,作者在訓練學生網路時修改了一下損失函數,讓小模型在擬合訓練資料的真值(ground truth)標籤的同時,也要擬合大模型輸出的機率分佈。這個方法叫做知識 蒸餾訓練 (Knowledge Distillation Training, KD Training)。知識蒸餾過程所使用的訓練樣本可以和訓練大模型用的訓練樣本一樣,或是另找一個獨立的 Transfer set。

大模型精準反哺小模型,知識蒸餾可協助提升 AI 演算法效能

方法詳解 

具體來說,知識蒸餾使用的是Teacher—Student 模型,其中teacher 是「知識」的輸出者,student 是「知識」的接受者。知識蒸餾的過程分為2 個階段:

  • 教師模型訓練:訓練」Teacher 模型「, 簡稱為Net-T,它的特徵是模型相對複雜,也可以由多個分別訓練的模型整合而成。對「Teacher模型」不作任何關於模型架構、參數量、是否集成方面的限制,因為該模型不需要部署,唯一的要求就是,對於輸入X, 其都能輸出Y,其中Y 經過softmax 的映射,輸出值對應相應類別的機率值。
  • 學生模型訓練:訓練「Student 模型」, 簡稱為 Net-S,它是參數量較小、模型結構相對簡單的單一模型。同樣的,對於輸入 X,其都能輸出 Y,Y 經過 softmax 映射後同樣能輸出對應對應類別的機率值。

由於使用softmax 的網路的結果很容易走向極端,即某一類的置信度超高,其他類別的置信度都很低,此時學生模型關注到的正類信息可能還是僅屬於某一類。除此之外,因為不同類別的負類資訊也有相對的重要性,所有負類分數都差不多也不好,達不到知識蒸餾的目的。為了解決這個問題,引入溫度(Temperature)的概念,使用高溫將小機率值所攜帶的資訊蒸餾出來。具體來說,在 logits 過 softmax 函數前除以溫度 T。

訓練時先將教師模型學習到的知識蒸餾給小模型,具體來說對樣本X,大模型的倒數第二層先除以一個溫度T,然後透過softmax 預測一個軟目標Soft target,小模型也是一樣,倒數第二層除以同樣的溫度T,然後透過softmax 預測一個結果,再把這個結果和軟目標的交叉熵當作訓練的total loss 的一部分。然後再將小模型正常的輸出和真值標籤(hard target)的交叉熵作為訓練的 total loss 的另一部分。 Total loss 把這兩個損失加權合起來作為訓練小模型的最終的 loss。

在小模型訓練好了要預測時,就不需要再有溫度 T 了,直接按照常規的 softmax 輸出就可以了。

03 FitNet 

簡介 

FitNet 論文在蒸餾時引入了中間層隱藏映射(intermediate-level hints)來指導學生模型的訓練。使用一個寬而淺的教師模型來訓練一個窄而深的學生模型。在進行 hint 引導時,提出使用一個層來匹配 hint 層和 guided 層的輸出 shape,在後人的工作裡面常被稱為 adaptation layer。

總的來說,相當於在做知識蒸餾時,不僅用到了教師模型的 logit 輸出,還用到了教師模型的中間層特徵圖作為監督資訊。可以想到的是,直接讓小模型在輸出端模仿大模型,這個對於小模型來說太難了(模型越深越難訓,最後一層的監督信號要傳到前面去還是挺累的),不如在中間加一些監督訊號,使得模型在訓練時可以從逐層接受學習更難的映射函數,而不是直接學習最難的映射函數;除此之外,hint 引導加速了學生模型的收斂,在在一個非凸問題上找到更好的局部最小值,使得學生網路能更深的同時,還能訓練得更快。這感覺就好像是,我們的目的是讓學生做高考題,那麼就先把國中的題目給他教會了(先讓小模型用前半個模型學會提取圖像底層特徵),然後再回到本來的目的、去學高考題(用KD 調整小模型的全部參數)。

這篇文章是提出蒸餾中間特徵圖的始祖,提出的演算法很簡單,但思路具有開創性。

大模型精準反哺小模型,知識蒸餾可協助提升 AI 演算法效能

方法詳解 

FitNets 的具體做法是:

  • 確定教師網路,並訓練成熟,將教師網路的中間特徵層hint 提取出來。
  • 設定學生網絡,一般較教師網絡較窄且較深。訓練學生網路使得學生網路的中間特徵層與教師模型的 hint 相符。由於學生網路的中間特徵層和與教師 hint 尺寸不同,因此需要在學生網路中間特徵層後添加回歸器用於特徵升維,以匹配 hint 層尺寸。其中匹配教師網路的 hint 層與回歸器轉換後的學生網路的中間特徵層的損失函數為均方差損失函數。

實際訓練的時候往往和上一節的KD Training 聯合使用,用兩階段法訓練:先用hint training 去pretrain 小模型前半部分的參數,再用KD Training 去訓練全體參數。由於蒸餾過程中使用了更多的監督信息, 基於中間特徵圖的蒸餾方法比基於結果 logits 的蒸餾方法效果要好 ,但是訓練時間更久。

04 總結

知識蒸餾對於將知識從整合或從高度正規化的大型模型轉移到較小的模型中非常有效。即使在用於訓練蒸餾模型的遷移資料集中缺少任何一個或多個類別的資料時,蒸餾的效果也非常好。在經典之作 KD 和 FitNet 提出之後,各種各樣的蒸餾方法如雨後春筍般湧現。未來我們也希望能在模型壓縮和知識遷移領域中做出更進一步的探索。

作者簡介 

馬佳良,網易易盾資深電腦視覺演算法工程師,主要負責電腦視覺演算法在內容安全領域的研發、最佳化與創新。

以上是大模型精準反哺小模型,知識蒸餾可協助提升 AI 演算法效能的詳細內容。更多資訊請關注PHP中文網其他相關文章!

本網站聲明
本文內容由網友自願投稿,版權歸原作者所有。本站不承擔相應的法律責任。如發現涉嫌抄襲或侵權的內容,請聯絡admin@php.cn

熱AI工具

Undresser.AI Undress

Undresser.AI Undress

人工智慧驅動的應用程序,用於創建逼真的裸體照片

AI Clothes Remover

AI Clothes Remover

用於從照片中去除衣服的線上人工智慧工具。

Undress AI Tool

Undress AI Tool

免費脫衣圖片

Clothoff.io

Clothoff.io

AI脫衣器

Video Face Swap

Video Face Swap

使用我們完全免費的人工智慧換臉工具,輕鬆在任何影片中換臉!

熱門文章

<🎜>:泡泡膠模擬器無窮大 - 如何獲取和使用皇家鑰匙
3 週前 By 尊渡假赌尊渡假赌尊渡假赌
北端:融合系統,解釋
3 週前 By 尊渡假赌尊渡假赌尊渡假赌
Mandragora:巫婆樹的耳語 - 如何解鎖抓鉤
3 週前 By 尊渡假赌尊渡假赌尊渡假赌

熱工具

記事本++7.3.1

記事本++7.3.1

好用且免費的程式碼編輯器

SublimeText3漢化版

SublimeText3漢化版

中文版,非常好用

禪工作室 13.0.1

禪工作室 13.0.1

強大的PHP整合開發環境

Dreamweaver CS6

Dreamweaver CS6

視覺化網頁開發工具

SublimeText3 Mac版

SublimeText3 Mac版

神級程式碼編輯軟體(SublimeText3)

熱門話題

Java教學
1666
14
CakePHP 教程
1426
52
Laravel 教程
1328
25
PHP教程
1273
29
C# 教程
1255
24
C  中的chrono庫如何使用? C 中的chrono庫如何使用? Apr 28, 2025 pm 10:18 PM

使用C 中的chrono庫可以讓你更加精確地控制時間和時間間隔,讓我們來探討一下這個庫的魅力所在吧。 C 的chrono庫是標準庫的一部分,它提供了一種現代化的方式來處理時間和時間間隔。對於那些曾經飽受time.h和ctime折磨的程序員來說,chrono無疑是一個福音。它不僅提高了代碼的可讀性和可維護性,還提供了更高的精度和靈活性。讓我們從基礎開始,chrono庫主要包括以下幾個關鍵組件:std::chrono::system_clock:表示系統時鐘,用於獲取當前時間。 std::chron

怎樣在C  中處理高DPI顯示? 怎樣在C 中處理高DPI顯示? Apr 28, 2025 pm 09:57 PM

在C 中處理高DPI顯示可以通過以下步驟實現:1)理解DPI和縮放,使用操作系統API獲取DPI信息並調整圖形輸出;2)處理跨平台兼容性,使用如SDL或Qt的跨平台圖形庫;3)進行性能優化,通過緩存、硬件加速和動態調整細節級別來提升性能;4)解決常見問題,如模糊文本和界面元素過小,通過正確應用DPI縮放來解決。

如何理解C  中的DMA操作? 如何理解C 中的DMA操作? Apr 28, 2025 pm 10:09 PM

DMA在C 中是指DirectMemoryAccess,直接內存訪問技術,允許硬件設備直接與內存進行數據傳輸,不需要CPU干預。 1)DMA操作高度依賴於硬件設備和驅動程序,實現方式因係統而異。 2)直接訪問內存可能帶來安全風險,需確保代碼的正確性和安全性。 3)DMA可提高性能,但使用不當可能導致系統性能下降。通過實踐和學習,可以掌握DMA的使用技巧,在高速數據傳輸和實時信號處理等場景中發揮其最大效能。

C  中的實時操作系統編程是什麼? C 中的實時操作系統編程是什麼? Apr 28, 2025 pm 10:15 PM

C 在實時操作系統(RTOS)編程中表現出色,提供了高效的執行效率和精確的時間管理。 1)C 通過直接操作硬件資源和高效的內存管理滿足RTOS的需求。 2)利用面向對象特性,C 可以設計靈活的任務調度系統。 3)C 支持高效的中斷處理,但需避免動態內存分配和異常處理以保證實時性。 4)模板編程和內聯函數有助於性能優化。 5)實際應用中,C 可用於實現高效的日誌系統。

給MySQL表添加和刪除字段的操作步驟 給MySQL表添加和刪除字段的操作步驟 Apr 29, 2025 pm 04:15 PM

在MySQL中,添加字段使用ALTERTABLEtable_nameADDCOLUMNnew_columnVARCHAR(255)AFTERexisting_column,刪除字段使用ALTERTABLEtable_nameDROPCOLUMNcolumn_to_drop。添加字段時,需指定位置以優化查詢性能和數據結構;刪除字段前需確認操作不可逆;使用在線DDL、備份數據、測試環境和低負載時間段修改表結構是性能優化和最佳實踐。

怎樣在C  中測量線程性能? 怎樣在C 中測量線程性能? Apr 28, 2025 pm 10:21 PM

在C 中測量線程性能可以使用標準庫中的計時工具、性能分析工具和自定義計時器。 1.使用庫測量執行時間。 2.使用gprof進行性能分析,步驟包括編譯時添加-pg選項、運行程序生成gmon.out文件、生成性能報告。 3.使用Valgrind的Callgrind模塊進行更詳細的分析,步驟包括運行程序生成callgrind.out文件、使用kcachegrind查看結果。 4.自定義計時器可靈活測量特定代碼段的執行時間。這些方法幫助全面了解線程性能,並優化代碼。

量化交易所排行榜2025 數字貨幣量化交易APP前十名推薦 量化交易所排行榜2025 數字貨幣量化交易APP前十名推薦 Apr 30, 2025 pm 07:24 PM

交易所內置量化工具包括:1. Binance(幣安):提供Binance Futures量化模塊,低手續費,支持AI輔助交易。 2. OKX(歐易):支持多賬戶管理和智能訂單路由,提供機構級風控。獨立量化策略平台有:3. 3Commas:拖拽式策略生成器,適用於多平台對沖套利。 4. Quadency:專業級算法策略庫,支持自定義風險閾值。 5. Pionex:內置16 預設策略,低交易手續費。垂直領域工具包括:6. Cryptohopper:雲端量化平台,支持150 技術指標。 7. Bitsgap:

數字虛擬幣交易平台top10 安全可靠的十大數字貨幣交易所 數字虛擬幣交易平台top10 安全可靠的十大數字貨幣交易所 Apr 30, 2025 pm 04:30 PM

數字虛擬幣交易平台top10分別是:1. Binance,2. OKX,3. Coinbase,4. Kraken,5. Huobi Global,6. Bitfinex,7. KuCoin,8. Gemini,9. Bitstamp,10. Bittrex,這些平台均提供高安全性和多種交易選項,適用於不同用戶需求。

See all articles