从U-Net到DiT:Transformer技术在统治扩散模型中的应用
近几年,在 Transformer 的推动下,机器学习正在经历复兴。过去五年中,用于自然语言处理、计算机视觉以及其他领域的神经架构在很大程度上已被 transformer 所占据。
不过还有许多图像级生成模型仍然不受这一趋势的影响,例如过去一年扩散模型在图像生成方面取得了惊人的成果,几乎所有这些模型都使用卷积 U-Net 作为主干。这有点令人惊讶!在过去的几年中,深度学习的大事件一直是跨领域的 Transformer 的主导地位。U-Net 或卷积是否有什么特别之处使它们在扩散模型中表现得如此出色?
将 U-Net 主干网络首次引入扩散模型的研究可追溯到 Ho 等人,这种设计模式继承了自回归生成模型 PixelCNN++,只是稍微进行了一些改动。而 PixelCNN++ 由卷积层组成,其包含许多的 ResNet 块。其与标准的 U-Net 相比,PixelCNN++ 附加的空间自注意力块成为 transformer 中的基本组件。不同于其他人的研究,Dhariwal 和 Nichol 等人消除了 U-Net 的几种架构选择,例如使用自适应归一化层为卷积层注入条件信息和通道计数。
本文中来自 UC 伯克利的 William Peebles 以及纽约大学的谢赛宁撰文《 Scalable Diffusion Models with Transformers 》,目标是揭开扩散模型中架构选择的意义,并为未来的生成模型研究提供经验基线。该研究表明,U-Net 归纳偏置对扩散模型的性能不是至关重要的,并且可以很容易地用标准设计(如 transformer)取代。
这一发现表明,扩散模型可以从架构统一趋势中受益,例如,扩散模型可以继承其他领域的最佳实践和训练方法,保留这些模型的可扩展性、鲁棒性和效率等有利特性。标准化架构也将为跨领域研究开辟新的可能性。
- 论文地址:https://arxiv.org/pdf/2212.09748.pdf
- 项目地址:https://github.com/facebookresearch/DiT
- 论文主页:https://www.wpeebles.com/DiT
该研究专注于一类新的基于 Transformer 的扩散模型:Diffusion Transformers(简称 DiTs)。DiTs 遵循 Vision Transformers (ViTs) 的最佳实践,有一些小但重要的调整。DiT 已被证明比传统的卷积网络(例如 ResNet )具有更有效地扩展性。
具体而言,本文研究了 Transformer 在网络复杂度与样本质量方面的扩展行为。研究表明,通过在潜在扩散模型 (LDM) 框架下构建 DiT 设计空间并对其进行基准测试,其中扩散模型在 VAE 的潜在空间内进行训练,可以成功地用 transformer 替换 U-Net 主干。本文进一步表明 DiT 是扩散模型的可扩展架构:网络复杂性(由 Gflops 测量)与样本质量(由 FID 测量)之间存在很强的相关性。通过简单地扩展 DiT 并训练具有高容量主干(118.6 Gflops)的 LDM,可以在类条件 256 × 256 ImageNet 生成基准上实现 2.27 FID 的最新结果。
Diffusion Transformers
DiTs 是一种用于扩散模型的新架构,目标是尽可能忠实于标准 transformer 架构,以保留其可扩展性。DiT 保留了 ViT 的许多最佳实践,图 3 显示了完整 DiT 体系架构。
DiT 的输入为空间表示 z(对于 256 × 256 × 3 图像,z 的形状为 32 × 32 × 4)。DiT 的第一层是 patchify,该层通过将每个 patch 线性嵌入到输入中,以此将空间输入转换为一个 T token 序列。patchify 之后,本文将标准的基于 ViT 频率的位置嵌入应用于所有输入 token。
patchify 创建的 token T 的数量由 patch 大小超参数 p 决定。如图 4 所示,将 p 减半将使 T 翻四倍,因此至少能使 transformer Gflops 翻四倍。本文将 p = 2,4,8 添加到 DiT 设计空间。
DiT 块设计:在 patchify 之后,输入 token 由一系列 transformer 块处理。除了噪声图像输入之外,扩散模型有时还会处理额外的条件信息,例如噪声时间步长 t、类标签 c、自然语言等。本文探索了四种以不同方式处理条件输入的 transformer 块变体。这些设计对标准 ViT 块设计进行了微小但重要的修改。所有模块的设计如图 3 所示。
本文尝试了四种因模型深度和宽度而异的配置:DiT-S、DiT-B、DiT-L 和 DiT-XL。这些模型配置范围从 33M 到 675M 参数,Gflops 从 0.4 到 119 。
实验
研究者训练了四个最高 Gflop 的 DiT-XL/2 模型,每个模型使用不同的 block 设计 ——in-context(119.4Gflops)、cross-attention(137.6Gflops)、adaptive layer norm(adaLN,118.6Gflops)或 adaLN-zero(118.6Gflops)。然后在训练过程中测量 FID,图 5 为结果。
扩展模型大小和 patch 大小。图 2(左)给出了每个模型的 Gflops 和它们在 400K 训练迭代时的 FID 概况。可以发现,增加模型大小和减少 patch 大小会对扩散模型产生相当大的改进。
图 6(顶部)展示了 FID 是如何随着模型大小的增加和 patch 大小保持不变而变化的。在四种设置中,通过使 Transformer 更深、更宽,训练的所有阶段都获得了 FID 的明显提升。同样,图 6(底部)展示了 patch 大小减少和模型大小保持不变时的 FID。研究者再次观察到,在整个训练过程中,通过简单地扩大 DiT 处理的 token 数量,并保持参数的大致固定,FID 会得到相当大的改善。
图 8 中展示了 FID-50K 在 400K 训练步数下与模型 Gflops 的对比:
SOTA 扩散模型 256×256 ImageNet。在对扩展分析之后,研究者继续训练最高 Gflop 模型 DiT-XL/2,步数为 7M。图 1 展示了该模型的样本,并与类别条件生成 SOTA 模型进行比较,表 2 中展示了结果。
当使用无分类器指导时,DiT-XL/2 优于之前所有的扩散模型,将之前由 LDM 实现的 3.60 的最佳 FID-50K 降至 2.27。如图 2(右)所示,相对于 LDM-4(103.6 Gflops)这样的潜在空间 U-Net 模型来说,DiT-XL/2(118.6 Gflops)计算效率高得多,也比 ADM(1120 Gflops)或 ADM-U(742 Gflops)这样的像素空间 U-Net 模型效率高很多。
表 3 展示了与 SOTA 方法的比较。XL/2 在这一分辨率下再次胜过之前的所有扩散模型,将 ADM 之前取得的 3.85 的最佳 FID 提高到 3.04。
更多研究细节,可参考原论文。
以上是从U-Net到DiT:Transformer技术在统治扩散模型中的应用的详细内容。更多信息请关注PHP中文网其他相关文章!

热AI工具

Undresser.AI Undress
人工智能驱动的应用程序,用于创建逼真的裸体照片

AI Clothes Remover
用于从照片中去除衣服的在线人工智能工具。

Undress AI Tool
免费脱衣服图片

Clothoff.io
AI脱衣机

Video Face Swap
使用我们完全免费的人工智能换脸工具轻松在任何视频中换脸!

热门文章

热工具

记事本++7.3.1
好用且免费的代码编辑器

SublimeText3汉化版
中文版,非常好用

禅工作室 13.0.1
功能强大的PHP集成开发环境

Dreamweaver CS6
视觉化网页开发工具

SublimeText3 Mac版
神级代码编辑软件(SublimeText3)

使用C 中的chrono库可以让你更加精确地控制时间和时间间隔,让我们来探讨一下这个库的魅力所在吧。C 的chrono库是标准库的一部分,它提供了一种现代化的方式来处理时间和时间间隔。对于那些曾经饱受time.h和ctime折磨的程序员来说,chrono无疑是一个福音。它不仅提高了代码的可读性和可维护性,还提供了更高的精度和灵活性。让我们从基础开始,chrono库主要包括以下几个关键组件:std::chrono::system_clock:表示系统时钟,用于获取当前时间。std::chron

DMA在C 中是指DirectMemoryAccess,直接内存访问技术,允许硬件设备直接与内存进行数据传输,不需要CPU干预。1)DMA操作高度依赖于硬件设备和驱动程序,实现方式因系统而异。2)直接访问内存可能带来安全风险,需确保代码的正确性和安全性。3)DMA可提高性能,但使用不当可能导致系统性能下降。通过实践和学习,可以掌握DMA的使用技巧,在高速数据传输和实时信号处理等场景中发挥其最大效能。

在C 中处理高DPI显示可以通过以下步骤实现:1)理解DPI和缩放,使用操作系统API获取DPI信息并调整图形输出;2)处理跨平台兼容性,使用如SDL或Qt的跨平台图形库;3)进行性能优化,通过缓存、硬件加速和动态调整细节级别来提升性能;4)解决常见问题,如模糊文本和界面元素过小,通过正确应用DPI缩放来解决。

C 在实时操作系统(RTOS)编程中表现出色,提供了高效的执行效率和精确的时间管理。1)C 通过直接操作硬件资源和高效的内存管理满足RTOS的需求。2)利用面向对象特性,C 可以设计灵活的任务调度系统。3)C 支持高效的中断处理,但需避免动态内存分配和异常处理以保证实时性。4)模板编程和内联函数有助于性能优化。5)实际应用中,C 可用于实现高效的日志系统。

在C 中测量线程性能可以使用标准库中的计时工具、性能分析工具和自定义计时器。1.使用库测量执行时间。2.使用gprof进行性能分析,步骤包括编译时添加-pg选项、运行程序生成gmon.out文件、生成性能报告。3.使用Valgrind的Callgrind模块进行更详细的分析,步骤包括运行程序生成callgrind.out文件、使用kcachegrind查看结果。4.自定义计时器可灵活测量特定代码段的执行时间。这些方法帮助全面了解线程性能,并优化代码。

在MySQL中,添加字段使用ALTERTABLEtable_nameADDCOLUMNnew_columnVARCHAR(255)AFTERexisting_column,删除字段使用ALTERTABLEtable_nameDROPCOLUMNcolumn_to_drop。添加字段时,需指定位置以优化查询性能和数据结构;删除字段前需确认操作不可逆;使用在线DDL、备份数据、测试环境和低负载时间段修改表结构是性能优化和最佳实践。

交易所内置量化工具包括:1. Binance(币安):提供Binance Futures量化模块,低手续费,支持AI辅助交易。2. OKX(欧易):支持多账户管理和智能订单路由,提供机构级风控。独立量化策略平台有:3. 3Commas:拖拽式策略生成器,适用于多平台对冲套利。4. Quadency:专业级算法策略库,支持自定义风险阈值。5. Pionex:内置16 预设策略,低交易手续费。垂直领域工具包括:6. Cryptohopper:云端量化平台,支持150 技术指标。7. Bitsgap:

C 中使用字符串流的主要步骤和注意事项如下:1.创建输出字符串流并转换数据,如将整数转换为字符串。2.应用于复杂数据结构的序列化,如将vector转换为字符串。3.注意性能问题,避免在处理大量数据时频繁使用字符串流,可考虑使用std::string的append方法。4.注意内存管理,避免频繁创建和销毁字符串流对象,可以重用或使用std::stringstream。
