最近Google又發布了全新的文字-圖像生成Muse模型,沒有採用當下大火的擴散(diffusion)模型,而是採用了經典的Transformer模型就實現了最先進的圖像生成性能,相比擴散或自回歸(autoregressive)模型,Muse模型的效率也提升非常多。
論文連結:https://arxiv.org/pdf/2301.00704.pdf
#專案連結:https://muse-model.github.io/
Muse以masked modeling任務在離散token空間上進行訓練:給定從預先訓練的大型語言模型(LLM)中提取的文字嵌入,Muse的訓練過程就是預測隨機masked掉的圖像token。
與像素空間的擴散模型(如Imagen和DALL-E 2)相比,由於Muse使用了離散的token,只需要較少的採樣迭代,所以效率得到了明顯提高;
與自回歸模型(如Parti)相比,由於Muse使用了並行解碼,所以效率更高。
使用預先訓練好的LLM可以實現細粒度的語言理解,從而轉化為高保真的圖像生成和對視覺概念的理解,如物體、空間關係、姿態、cardinality等。
在實驗結果中,只有900M參數的Muse模型在CC3M上實現了新的SOTA效能,FID分數為6.06。
Muse 3B參數模型在zero-shot COCO評估中實作了7.88的FID,同時還有0.32的CLIP分數。
Muse還可以在不對模型進行微調或反轉(invert)直接實現一些圖像編輯應用程式:修復(inpainting)、擴充(outpainting )和無遮罩編輯(mask-free editing)。
Muse模型的框架包含多個元件,訓練pipeline由T5-XXL預訓練文字編碼器,基礎模型(base model)和超分辨率模型組成。
1. 預訓練文字編碼器
與先前研究中得出的結論類似,研究人員發現利用預先訓練的大型語言模型(LLM)有利於提升高品質影像的生成結果。
例如從語言模型T5-XXL中提取的嵌入(embedding)帶有關於物體(名詞)、行動(動詞)、視覺屬性(形容詞)、空間關係(介詞)以及其他屬性(如卡片性和組成)的豐富資訊。
所以研究者提出假設(hypothesis):Muse模型學會將LLM嵌入中的這些豐富的視覺和語義概念映射到生成的圖像上。
最近也有一些工作已經證明了,由LLM學習到的概念表徵與由視覺任務訓練的模型學習的概念表徵大致上是可以「線性映射」的。
給定一個輸入的文字標題,將其傳遞給凍結參數的T5-XXL編碼器,可以得到一個4096維的語言嵌入向量,然後將這些向量線性地投射到Transformer模型(base和超解析度)的hidden size維度上。
2. 使用VQGAN進行Semantic Tokenization
#VQGAN模型由一個編碼器和一個解碼器組成,其中的量化層(quantization layer)將輸入影像映射成來自一個學習過的codebook的token序列。
然後完全用卷積層建立編碼器和解碼器,以支援對不同解析度的影像進行編碼。
編碼器中包含幾個下取樣區塊來減少輸入的空間維度,而解碼器中則是有對應數量的上取樣區塊來將latents映射回原始影像大小。
研究人員訓練了兩個VQGAN模型:一個是下取樣率f=16,模型在256×256像素的影像上獲得基本模型的標記,從而得到空間尺寸為16×16的標記;另一個是下取樣率f=8,在512×512的影像上獲得超解析度模型的token,對應的空間尺寸為64×64。
編碼後得到的離散token可以捕捉影像的高層次語義,同時也可以消除低層次的噪聲,並且根據token的離散性可以在輸出端使用交叉熵損失來預測下一階段的masked token
3. Base Model
Muse的基礎模型是一個masked Transformer,其中輸入是映射的T5嵌入和圖像token.
研究人員將所有的文本嵌入設置為unmasked,隨機mask掉一部分不同的圖像token後,用一個特殊的[MASK]標記來代替原token.
然後將圖像token線性地映射到所需的Transformer輸入或hidden size維度的圖像輸入embedding中,並同時學習2D position embedding
和原始的Transformer架構一樣,包括幾個transformer層,使用自註意塊、交叉注意力塊和MLP塊來提取特徵。
在輸出層,使用一個MLP將每個masked映像嵌入轉換為一組logits(對應於VQGAN codebook的大小),並以ground truth的token為目標使用交叉熵損失。
在訓練階段,基礎模型的訓練目標為預測每一步的所有msked tokens;但在推理階段,mask預測是以迭代的方式進行的,這種方式可以極大提高品質。
4. 超解析度模型
#研究人員發現,直接預測512× 512解析度的圖像會導致模型專注於低層次的細節而非高層次的語義。
使用級聯模型(cascade of models)則可以改善這種情況:
首先使用一個產生16×16 latent map(對應256×256的影像)的基礎模型;然後是一個超解析度模型,將基礎latent map上取樣為64×64(對應512×512的影像)。其中超解析度模型是在基礎模型訓練完成後再進行訓練的。
如前所述,研究人員總共訓練了兩個VQGAN模型,一個是16×16潛分辨率和256×256空間分辨率,另一個是64×64潛伏解析度和512×512空間解析度。
由於基礎模型輸出對應於16×16 latent map的token,所以超解析度模組學會了將低解析度的latent map 「翻譯」成高解析度的latent map ,然後透過高解析度的VQGAN解碼,得到最終的高解析度圖像;此翻譯模型也是以類似於基礎模型的方式進行text conditioning和交叉注意力的訓練。
5. 解碼器微調
#為了進一步提高模型產生細節的能力,研究人員選擇透過增加VQGAN解碼器的容量,增加更多的殘差層(residual layer)和通道的同時保持編碼器的容量不變。
然後對新的解碼器進行微調,同時保持VQGAN編碼器的權重、codebook和Transformers(即基礎模型和超解析度模型)不變。這種方式能夠提高生成影像的視覺質量,而不需要重新訓練任何其他的模型組件(因為視覺token保持固定)。
可以看到,經過微調的解碼器以重建更多更清晰的細節。
6. 可變遮罩率(Masking Rate)
研究者使用基於Csoine scheduling的可變遮罩率來訓練模型:對於每個訓練例子,從截斷的arccos分佈中抽出一個掩碼率r∈[0,1],其密度函數如下.
#遮罩率的期望值為0.64,也就是說更偏向選擇更高的遮罩率,使得預測問題更加困難。
隨機的遮罩率不僅對平行取樣方案至關重要,而且還能實現一些零散的、開箱即用的編輯功能。
7. Classifier Free Guidance(CFG)
研究者採用無分類指導(CFG)來提高影像的生成品質和文字-影像對齊。
在訓練時,在隨機選擇的10%的樣本上去除文字條件,注意力機制降為圖像token本身的自註意力。
在推理階段,為每個被mask的token計算一個條件logit lc和一個無條件logit lu,然後透過從無條件logit中移出一個量t作為指導尺度,形成最終的logit lg:
直覺來看,CFG是以多樣性換取保真度,但與先前方法不同的是,Muse透過取樣過程線性地增加指導尺度t來減少多樣性的損失,使得early token可以在低引導或無引導的情況下更自由地被取樣,不過也增加了對later tokens條件提示的影響。
研究人員也利用這個機制,透過將無條件的logit lu替換為以negative prompt為條件的logit,促進了生成影像具有與postive prompt相關的特徵。
8. 推理時迭代並行解碼
#在提升模型推理時間效率的關鍵部分是使用並行解碼來預測單一前向通道中的多個輸出token,其中一個關鍵假設是馬爾科夫屬性,即許多token是有條件地獨立於給定的其他token的。
其中解碼是根據cosine schedule進行的,選擇固定比例中最高置信度的掩碼進行預測,其中token在剩餘的步驟中被設定為unmasked,並且適當減少masked tokens。
根據上述過程,就可以在基本模型中只用24個解碼步(step)實現對256個token的推理,在超解析度模型中用8個解碼步對4096個token進行推理,相較之下,自迴歸模型需要256或4096步,擴散模型需要數百步。
雖然最近的一些研究包括progressive distillation、better ODE solver大大減少了擴散模型的採樣步驟,但這些方法還沒有在大規模的文本到圖像生成中得到廣泛驗證。
研究人員以不同的參數量(從600M到3B),基於T5-XXL訓練了一系列基礎Transformer模型。
產生圖像的品質
#實驗中測試了Muse模型對於不同屬性的文字提示的能力,包括對cardinality的基本理解,對於非單數的物體,Muse並沒有多次產生相同的物體像素,而是增加了上下文的變化,使整個影像更加真實。
例如,大象的大小和方向、酒瓶包裝紙的顏色以及網球的旋轉等等。
定量比較
#研究人員在CC3M和COCO資料集上與其他研究方法進行了實驗對比,指標包括衡量樣本品質和多樣性的Frechet Inception Distance(FID),以及衡量圖像/文字對齊的CLIP分數。
實驗結果證明了632M的Muse模型在CC3M上取得了SOTA結果,在FID分數方面得到了改善,同時也取得了最先進的CLIP得分。
在MS-COCO資料集上,3B模型取得了7.88分的FID得分,略好於相似參數量的Parti-3B模型取得的8.1分。
以上是Transformer再勝Diffusion!谷歌發布新世代文字-圖像生成模型Muse:生成效率提升十倍的詳細內容。更多資訊請關注PHP中文網其他相關文章!