Heim > Technologie-Peripheriegeräte > KI > Neues Werk von Yan Shuicheng/Cheng Mingming! DiT-Training, die Kernkomponente von Sora, wird um das Zehnfache beschleunigt und Masked Diffusion Transformer V2 ist Open Source

Neues Werk von Yan Shuicheng/Cheng Mingming! DiT-Training, die Kernkomponente von Sora, wird um das Zehnfache beschleunigt und Masked Diffusion Transformer V2 ist Open Source

王林
Freigeben: 2024-03-13 17:58:18
nach vorne
445 Leute haben es durchsucht

Als eine der überzeugenden Kerntechnologien von Sora nutzt DiT Diffusion Transformer, um das generative Modell auf einen größeren Maßstab zu skalieren und so herausragende Bilderzeugungseffekte zu erzielen.

Größere Modellgrößen führen jedoch dazu, dass die Schulungskosten in die Höhe schnellen.

Das Forschungsteam von Yan Shuicheng und Cheng Mingming vom Sea AI Lab der Nankai University und dem Kunlun Wanwei 2050 Research Institute schlug auf der ICCV 2023-Konferenz ein neues Modell namens Masked Diffusion Transformer vor. Dieses Modell verwendet die Maskenmodellierungstechnologie, um das Training des Diffusion Transformers durch das Erlernen semantischer Darstellungsinformationen zu beschleunigen und SoTA-Ergebnisse im Bereich der Bilderzeugung zu erzielen. Diese Innovation bringt neue Durchbrüche bei der Entwicklung von Bilderzeugungsmodellen und bietet Forschern eine effizientere Trainingsmethode. Durch die Kombination von Fachwissen und Technologie aus verschiedenen Bereichen schlug das Forschungsteam erfolgreich eine Lösung vor, die die Trainingsgeschwindigkeit erhöht und die Generierungsergebnisse verbessert. Ihre Arbeit hat wichtige innovative Ideen zur Entwicklung des Bereichs der künstlichen Intelligenz beigetragen und nützliche Inspiration für zukünftige Forschung und Praxis geliefert , Masked Diffusion Transformer V2 hat SoTA erneut aktualisiert, die Trainingsgeschwindigkeit im Vergleich zu DiT um mehr als das Zehnfache erhöht und den ImageNet-Benchmark-Score von 1,58 erreicht.

Die neueste Version des Papiers und des Codes sind Open Source. 颜水成/程明明新作!Sora核心组件DiT训练提速10倍,Masked Diffusion Transformer V2开源

Hintergrund

Obwohl die von DiT dargestellten Diffusionsmodelle im Bereich der Bilderzeugung erhebliche Erfolge erzielt haben, haben Forscher herausgefunden, dass es mit Diffusionsmodellen häufig schwierig ist, die semantischen Beziehungen zwischen Teilen von Objekten in Bildern effizient zu erlernen Die Einschränkung führt zu einer geringen Konvergenzeffizienz des Trainingsprozesses.

Bilder

Zum Beispiel hat DiT, wie im Bild oben gezeigt, gelernt, beim 50. k-ten Trainingsschritt die Haartextur eines Hundes zu erzeugen, und dann beim 200. k-ten Trainingsschritt gelernt, eines der Augen des Hundes zu erzeugen Trainingsschritt und Mund, aber ein weiteres Auge fehlte.

Selbst beim 300-km-Trainingsschritt ist die durch DiT erzeugte relative Position der beiden Ohren des Hundes nicht sehr genau.

Dieser Trainings- und Lernprozess zeigt, dass das Diffusionsmodell die semantische Beziehung zwischen verschiedenen Teilen des Objekts im Bild nicht effizient lernen kann, sondern nur die semantischen Informationen jedes Objekts unabhängig lernt. 颜水成/程明明新作!Sora核心组件DiT训练提速10倍,Masked Diffusion Transformer V2开源

Die Forscher vermuten, dass der Grund für dieses Phänomen darin liegt, dass das Diffusionsmodell die Verteilung realer Bilddaten lernt, indem es den Vorhersageverlust jedes Pixels minimiert. Bei diesem Prozess wird die semantische relative Beziehung zwischen den verschiedenen Teilen des Objekts ignoriert Bild, was dazu führt, dass das Modell langsam konvergiert.

Methode: Masked Diffusion Transformer

Inspiriert durch die obigen Beobachtungen schlug der Forscher Masked Diffusion Transformer (MDT) vor, um die Trainingseffizienz und Generierungsqualität des Diffusionsmodells zu verbessern.

MDT schlägt eine für Diffusion Transformer entwickelte Lernstrategie für die Darstellung von Maskenmodellen vor, um die Lernfähigkeit des Diffusion Transformers für kontextuelle semantische Informationen explizit zu verbessern und das Assoziationslernen semantischer Informationen zwischen Objekten im Bild zu verbessern.

Bild

Wie in der Abbildung oben gezeigt, führt MDT eine Lernstrategie für die Maskenmodellierung ein und behält gleichzeitig den Diffusionstrainingsprozess bei. Durch die Maskierung des verrauschten Bildtokens verwendet MDT eine asymmetrische Diffusionstransformator-Architektur (Asymmetric Diffusion Transformer), um das maskierte Bildtoken aus dem nicht maskierten verrauschten Bildtoken vorherzusagen und so gleichzeitig die Prozesse der Maskenmodellierung und des Diffusionstrainings zu erreichen.

Während des Inferenzprozesses behält MDT weiterhin den Standardprozess der Diffusionsgenerierung bei. Das Design von MDT hilft Diffusion Transformer dabei, sowohl die Fähigkeit zum Ausdruck semantischer Informationen zu nutzen, die durch das Lernen der Maskenmodellierungsdarstellung entsteht, als auch die Fähigkeit des Diffusionsmodells, Bilddetails zu generieren.

Konkret ordnet MDT Bilder über den VAE-Encoder dem latenten Raum zu und verarbeitet sie im latenten Raum, um Rechenkosten zu sparen.

Während des Trainingsprozesses maskiert MDT zunächst einen Teil der Bild-Tokens, nachdem Rauschen hinzugefügt wurde, und sendet die verbleibenden Tokens an den Asymmetric Diffusion Transformer, um nach dem Entrauschen alle Bild-Tokens vorherzusagen.

Asymmetric Diffusion Transformer-Architektur

颜水成/程明明新作!Sora核心组件DiT训练提速10倍,Masked Diffusion Transformer V2开源Bild

Wie in der Abbildung oben gezeigt, umfasst die Asymmetric Diffusion Transformer-Architektur einen Encoder, einen Seiteninterpolator (Hilfsinterpolator) und einen Decoder.

颜水成/程明明新作!Sora核心组件DiT训练提速10倍,Masked Diffusion Transformer V2开源Bilder

Während des Trainingsprozesses verarbeitet der Encoder nur Token, die nicht maskiert sind; während des Inferenzprozesses verarbeitet er alle Token, da es keinen Maskierungsschritt gibt.

Um sicherzustellen, dass der Decoder während der Trainings- oder Inferenzphase immer alle Token verarbeiten kann, schlugen die Forscher daher eine Lösung vor: während des Trainingsprozesses durch einen Hilfsinterpolator, der aus DiT-Blöcken besteht (wie in der Abbildung dargestellt). oben), interpolieren und prognostizieren das maskierte Token aus der Ausgabe des Encoders und entfernen es während der Inferenzphase, ohne einen Inferenz-Overhead hinzuzufügen.

MDTs Encoder und Decoder fügen globale und lokale Positionscodierungsinformationen in den Standard-DiT-Block ein, um die Vorhersage des Tokens im Maskenteil zu erleichtern. Asymmetric Diffusion Transformer V2 Prozess der Modellierung.

Dazu gehört die Integration einer langen Verknüpfung im U-Net-Stil in den Encoder und einer dichten Eingabeverknüpfung in den Decoder.

Unter diesen sendet die dichte Eingabeverknüpfung das maskierte Token, nachdem Rauschen an den Decoder hinzugefügt wurde, wobei die dem maskierten Token entsprechenden Rauschinformationen beibehalten werden, wodurch das Training des Diffusionsprozesses erleichtert wird. 颜水成/程明明新作!Sora核心组件DiT训练提速10倍,Masked Diffusion Transformer V2开源

Darüber hinaus hat MDT auch bessere Trainingsstrategien eingeführt, einschließlich der Verwendung eines schnelleren Adan-Optimierers, zeitschrittbezogener Verlustgewichte und erweiterter Maskenverhältnisse, um den Trainingsprozess des Masked Diffusion-Modells weiter zu beschleunigen.

Experimentelle Ergebnisse

Qualitätsvergleich der ImageNet 256-Benchmark-Generierung

Bilder

Die obige Tabelle vergleicht die Leistung von MDT und DiT unter dem ImageNet 256-Benchmark bei verschiedenen Modellgrößen.

Es ist offensichtlich, dass MDT bei allen Modellgrößen höhere FID-Werte mit weniger Schulungskosten erzielt.

Die Parameter und Inferenzkosten von MDT sind grundsätzlich die gleichen wie bei DiT, da, wie oben erwähnt, der mit DiT konsistente Standarddiffusionsprozess im MDT-Inferenzprozess weiterhin beibehalten wird. 颜水成/程明明新作!Sora核心组件DiT训练提速10倍,Masked Diffusion Transformer V2开源

Für das größte XL-Modell übertrifft MDTv2-XL/2, trainiert mit 400.000 Schritten, DiT-XL/2, trainiert mit 7.000.000 Schritten, deutlich mit einer FID-Score-Verbesserung von 1,92. Unter dieser Einstellung zeigen die Ergebnisse, dass MDT etwa 18-mal schneller trainiert als DiT.

Bei kleinen Modellen erreicht MDTv2-S/2 immer noch eine deutlich bessere Leistung als DiT-S/2 mit deutlich weniger Trainingsschritten. Beispielsweise hat MDTv2 bei gleichem Training mit 400.000 Schritten einen FID-Index von 39,50, was deutlich über dem FID-Index von DiT von 68,40 liegt.

Noch wichtiger ist, dass dieses Ergebnis auch die Leistung des größeren Modells DiT-B/2 bei 400.000 Trainingsschritten übertrifft (39,50 vs. 43,47).

ImageNet 256 Benchmark-CFG-Generierungsqualitätsvergleich

颜水成/程明明新作!Sora核心组件DiT训练提速10倍,Masked Diffusion Transformer V2开源Bilder

In der obigen Tabelle haben wir auch die Bildgenerierungsleistung von MDT mit vorhandenen Methoden unter klassifikatorfreier Anleitung verglichen.

MDT übertrifft frühere SOTA DiT und andere Methoden mit einem FID-Score von 1,79. MDTv2 verbessert die Leistung weiter und treibt den SOTA-FID-Score für die Bilderzeugung mit weniger Trainingsschritten auf einen neuen Tiefstwert von 1,58.

Ähnlich wie bei DiT beobachteten wir während des Trainings keine Sättigung des FID-Scores des Modells, während wir mit dem Training fortfuhren.

颜水成/程明明新作!Sora核心组件DiT训练提速10倍,Masked Diffusion Transformer V2开源MDT aktualisiert SoTA auf der Rangliste von PaperWithCode

Konvergenzgeschwindigkeitsvergleich

颜水成/程明明新作!Sora核心组件DiT训练提速10倍,Masked Diffusion Transformer V2开源Bild

Das obige Bild vergleicht 8×A100 unter dem ImageNet 6 Benchmark DiT-S/ auf GPU 2 FID Leistung von Baseline, MDT-S/2 und MDTv2-S/2 unter verschiedenen Trainingsschritten/Trainingszeiten.

Dank besserer kontextbezogener Lernfähigkeiten übertrifft MDT DiT sowohl in der Leistung als auch in der Generierungsgeschwindigkeit. Die Trainingskonvergenzgeschwindigkeit von MDTv2 ist mehr als zehnmal höher als die von DiT.

MDT ist in Bezug auf Trainingsschritte und Trainingszeit etwa dreimal schneller als DiT. MDTv2 verbessert die Trainingsgeschwindigkeit im Vergleich zu MDT um etwa das Fünffache.

Zum Beispiel zeigt MDTv2-S/2 in nur 13 Stunden (15.000 Schritte) eine bessere Leistung als DiT-S/2, dessen Training etwa 100 Stunden (1.500.000 Schritte) dauert, was zeigt, dass das Lernen der kontextuellen Darstellung wichtig ist Ein schnelleres generatives Lernen von Diffusionsmodellen ist von entscheidender Bedeutung.

Zusammenfassung und Diskussion

MDT führt im Diffusionstrainingsprozess ein MAE-ähnliches Lernschema für die Maskenmodellierungsdarstellung ein, das die Kontextinformationen von Bildobjekten verwenden kann, um die vollständigen Informationen des unvollständigen Eingabebilds zu rekonstruieren und so zu lernen Die Semantik im Bild Die Korrelation zwischen Teilen, wodurch die Qualität der Bilderzeugung und die Lerngeschwindigkeit verbessert werden.

Forscher glauben, dass die Verbesserung des semantischen Verständnisses der physischen Welt durch das Lernen visueller Repräsentationen den Simulationseffekt des generativen Modells auf die physische Welt verbessern kann. Dies deckt sich mit Soras Vision, durch generative Modelle einen physischen Weltsimulator zu bauen. Hoffentlich wird diese Arbeit weitere Arbeiten zur Vereinheitlichung von Repräsentationslernen und generativem Lernen inspirieren.

Referenz:

https://arxiv.org/abs/2303.14389

Das obige ist der detaillierte Inhalt vonNeues Werk von Yan Shuicheng/Cheng Mingming! DiT-Training, die Kernkomponente von Sora, wird um das Zehnfache beschleunigt und Masked Diffusion Transformer V2 ist Open Source. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Verwandte Etiketten:
Quelle:51cto.com
Erklärung dieser Website
Der Inhalt dieses Artikels wird freiwillig von Internetnutzern beigesteuert und das Urheberrecht liegt beim ursprünglichen Autor. Diese Website übernimmt keine entsprechende rechtliche Verantwortung. Wenn Sie Inhalte finden, bei denen der Verdacht eines Plagiats oder einer Rechtsverletzung besteht, wenden Sie sich bitte an admin@php.cn
Beliebte Tutorials
Mehr>
Neueste Downloads
Mehr>
Web-Effekte
Quellcode der Website
Website-Materialien
Frontend-Vorlage