Ces dernières années, porté par Transformer, le machine learning connaît une renaissance. Au cours des cinq dernières années, les architectures neuronales pour le traitement du langage naturel, la vision par ordinateur et d’autres domaines ont été largement dominées par les transformateurs.
Cependant, il existe de nombreux modèles génératifs au niveau de l'image qui ne sont toujours pas affectés par cette tendance. Par exemple, les modèles de diffusion ont obtenu des résultats étonnants en matière de génération d'images au cours de l'année écoulée, et presque tous ces modèles utilisent U- convolutif. Net comme colonne vertébrale. C'est un peu surprenant ! La grande histoire de l’apprentissage profond au cours des dernières années a été la domination de Transformer dans tous les domaines. Y a-t-il quelque chose de spécial à propos d'U-Net ou des convolutions qui les rend si performants dans les modèles de diffusion ?
La recherche qui a introduit pour la première fois le réseau fédérateur U-Net dans le modèle de diffusion remonte à Ho et al. Ce modèle de conception hérite du modèle génératif autorégressif PixelCNN++ avec seulement de légers changements. PixelCNN++ se compose de couches convolutives, qui contiennent de nombreux blocs ResNet. Par rapport au U-Net standard, le bloc d’auto-attention spatiale supplémentaire de PixelCNN++ devient un composant de base du transformateur. Contrairement à d'autres études, Dhariwal et Nichol et al. éliminent plusieurs choix architecturaux d'U-Net, tels que l'utilisation de couches de normalisation adaptatives pour injecter des informations sur les conditions et le nombre de canaux dans les couches convolutives.
Dans cet article, William Peebles de l'UC Berkeley et Xie Saining de l'Université de New York ont écrit "Modèles de diffusion évolutifs avec transformateurs". L'objectif est de découvrir l'importance des choix architecturaux dans les modèles de diffusion et de fournir une base empirique pour les futures générations. recherche de modèles. Cette étude montre que la polarisation inductive U-Net n'est pas essentielle aux performances des modèles de diffusion et peut être facilement remplacée par des conceptions standards telles que des transformateurs.
Cette découverte montre que les modèles de diffusion peuvent bénéficier des tendances d'unification architecturale. Par exemple, les modèles de diffusion peuvent hériter des meilleures pratiques et méthodes de formation d'autres domaines, conservant l'évolutivité, la robustesse et l'efficacité de ces modèles. Une architecture standardisée ouvrira également de nouvelles possibilités pour la recherche inter-domaines.
Cette recherche se concentre sur un nouveau type de modèle de diffusion basé sur un transformateur : les transformateurs de diffusion (DiTs en abrégé). Les DiT suivent les meilleures pratiques des Vision Transformers (ViT), avec quelques ajustements mineurs mais importants. Il a été démontré que DiT évolue plus efficacement que les réseaux convolutifs traditionnels tels que ResNet.
Plus précisément, cet article étudie le comportement de mise à l'échelle de Transformer en termes de complexité du réseau et de qualité des échantillons. L'étude montre qu'en construisant et en évaluant l'espace de conception DiT dans le cadre du modèle de diffusion latente (LDM), où le modèle de diffusion est formé dans l'espace latent de VAE, il est possible de remplacer avec succès le squelette U-Net par un transformateur. Cet article montre en outre que DiT est une architecture évolutive pour les modèles de diffusion : il existe une forte corrélation entre la complexité du réseau (mesurée par Gflops) et la qualité des échantillons (mesurée par FID). En étendant simplement DiT et en entraînant un LDM avec un réseau fédérateur haute capacité (118,6 Gflops), des résultats de pointe de 2,27 FID sont obtenus sur le benchmark de génération ImageNet 256 × 256 conditionnel à la classe.
DiTs est une nouvelle architecture de modèles de diffusion qui vise à être la plus fidèle possible à l'architecture standard du transformateur afin de conserver son évolutivité. DiT conserve bon nombre des meilleures pratiques de ViT, et la figure 3 montre l'architecture complète de DiT. L'entrée de
DiT est la représentation spatiale z (pour une image 256 × 256 × 3, la forme de z est 32 × 32 × 4). La première couche de DiT est patchify, qui convertit l'entrée spatiale en une séquence de jetons T en intégrant linéairement chaque patch dans l'entrée. Après patchify, nous appliquons des intégrations positionnelles standard basées sur la fréquence ViT à tous les jetons d'entrée.
Le nombre de tokens T créés par patchify est déterminé par l'hyperparamètre de taille du patch p. Comme le montre la figure 4, la réduction de moitié de p quadruple T et donc au moins quadruple les Gflops du transformateur. Cet article ajoute p = 2,4,8 à l'espace de conception DiT.
Conception du bloc DiT : après patchify, le jeton d'entrée est traité par une série de blocs de transformateur. En plus de l'entrée d'image bruitée, les modèles de diffusion gèrent parfois des informations conditionnelles supplémentaires, telles que le pas de temps de bruit t, l'étiquette de classe c, le langage naturel, etc. Cet article explore quatre variantes de blocs de transformateur qui gèrent l'entrée conditionnelle de différentes manières. Ces conceptions comportent des modifications mineures mais significatives par rapport à la conception standard du bloc ViT. La conception de tous les modules est illustrée à la figure 3.
Cet article a essayé quatre configurations qui varient selon la profondeur et la largeur du modèle : DiT-S, DiT-B, DiT-L et DiT-XL. Ces configurations de modèles vont de 33 M à 675 M de paramètres et de Gflops de 0,4 à 119.
Les chercheurs ont formé quatre modèles DiT-XL/2 avec les Gflops les plus élevés, chacun utilisant une conception de bloc différente - en contexte (119,4 Gflops), attention croisée (137,6 Gflops), norme de couche adaptative (adaLN , 118,6 Gflops) ou adaLN-zéro (118,6 Gflops). Le FID a ensuite été mesuré pendant l'entraînement et la figure 5 montre les résultats.
Taille du modèle étendue et taille du patch. La figure 2 (à gauche) donne un aperçu des Gflops pour chaque modèle et de leur FID à 400 000 itérations de formation. On peut constater que l’augmentation de la taille du modèle et la réduction de la taille des patchs produisent des améliorations considérables du modèle de diffusion.
La figure 6 (en haut) montre comment le FID change à mesure que la taille du modèle augmente et que la taille du patch reste constante. Dans les quatre paramètres, des améliorations significatives du FID sont obtenues à toutes les étapes de la formation en rendant le Transformer plus profond et plus large. De même, la figure 6 (en bas) montre le FID lorsque la taille du patch est réduite et que la taille du modèle reste constante. Les chercheurs ont de nouveau observé que le FID s'est considérablement amélioré en augmentant simplement le nombre de jetons traités par DiT et en maintenant les paramètres à peu près fixes tout au long du processus de formation.
La figure 8 montre la comparaison du FID-50K avec le modèle Gflops à 400K étapes de formation :
Modèle de diffusion SOTA 256×256 ImageNet. Après l’analyse approfondie, les chercheurs ont continué à entraîner le modèle Gflop le plus élevé, DiT-XL/2, avec un nombre de pas de 7 M. La figure 1 montre un échantillon de ce modèle et le compare au modèle SOTA de génération conditionnelle de catégorie, et les résultats sont présentés dans le tableau 2.
Lors de l'utilisation d'un guidage sans classificateur, DiT-XL/2 surpasse tous les modèles de diffusion précédents, réduisant le précédent meilleur FID-50K de 3,60 obtenu par LDM à 2,27. Comme le montre la figure 2 (à droite), comparé aux modèles U-Net à espace latent tels que LDM-4 (103,6 Gflops), DiT-XL/2 (118,6 Gflops) est beaucoup plus efficace en termes de calcul que ADM (1 120 Gflops) ou). ADM-U (742 Gflops), les modèles U-Net à espace pixel sont beaucoup plus efficaces.
Le tableau 3 montre la comparaison avec les méthodes SOTA. XL/2 surpasse encore une fois tous les modèles de diffusion précédents à cette résolution, améliorant le précédent meilleur FID d'ADM de 3,85 à 3,04.
Pour plus de détails sur la recherche, veuillez vous référer à l'article original.
Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!