目录
OpenDiT 方法介绍
安装与使用
DiT 复现结果
首页 科技周边 人工智能 想训练类Sora模型吗?尤洋团队OpenDiT实现80%加速

想训练类Sora模型吗?尤洋团队OpenDiT实现80%加速

Feb 29, 2024 pm 04:34 PM
ai 模型

Sora 在 2024 年初的惊艳表现成为了新的标杆,激励着所有研究文生视频的人士争相追赶。每个研究者都怀着复现 Sora 成果的渴望,争分夺秒地努力着。

根据 OpenAI 披露的技术报告,Sora 的一个重要创新点是将视觉数据转换为 patch 的统一表示形式,并通过 Transformer 和扩散模型相结合,展现了出色的扩展性。随着报告的发布,Sora 的核心研发人员 William Peebles 和纽约大学计算机科学助理教授谢赛宁合作撰写的《Scalable Diffusion Models with Transformers》论文备受研究者关注。研究界希望通过论文中提出的 DiT 架构,探索再现 Sora 的可行性途径。

最近,新加坡国立大学尤洋团队开源的一个名为 OpenDiT 的项目为训练和部署 DiT 模型打开了新思路。

OpenDiT是一个专为提升DiT应用程序的训练和推理效率而设计的系统,它不仅易于操作,而且速度快且内存利用高效。该系统涵盖了文本到视频生成和文本到图像生成等功能,旨在为用户提供高效、便捷的体验。

想训练类Sora模型吗?尤洋团队OpenDiT实现80%加速

项目地址:https://github.com/NUS-HPC-AI-Lab/OpenDiT

想训练类Sora模型吗?尤洋团队OpenDiT实现80%加速

OpenDiT 方法介绍

OpenDiT 提供由 Colossal-AI 支持的 Diffusion Transformer (DiT) 的高性能实现。在训练时,视频和条件信息分别被输入到相应的编码器中,作为DiT模型的输入。随后,通过扩散方法进行训练和参数更新,最终将更新后的参数同步至EMA(Exponential Moving Average)模型。推理阶段则直接使用EMA模型,将条件信息作为输入,从而生成对应的结果。

想训练类Sora模型吗?尤洋团队OpenDiT实现80%加速

图源:https://www.zhihu.com/people/berkeley-you-yang

OpenDiT 利用了 ZeRO 并行策略,将 DiT 模型参数分布到多台机器上,初步降低了显存压力。为了取得更好的性能与精度平衡,OpenDiT 还采用了混合精度的训练策略。具体而言,模型参数和优化器使用 float32 进行存储,以确保更新的准确性。在模型计算的过程中,研究团队为 DiT 模型设计了 float16 和 float32 的混合精度方法,以在维持模型精度的同时加速计算过程。

DiT 模型中使用的 EMA 方法是一种用于平滑模型参数更新的策略,可以有效提高模型的稳定性和泛化能力。但是会额外产生一份参数的拷贝,增加了显存的负担。为了进一步降低这部分显存,研究团队将 EMA 模型分片,并分别存储在不同的 GPU 上。在训练过程中,每个 GPU 只需计算和存储自己负责的部分 EMA 模型参数,并在每次 step 后等待 ZeRO 完成更新后进行同步更新。

FastSeq

在 DiT 等视觉生成模型领域,序列并行性对于有效的长序列训练和低延迟推理是必不可少的。

然而,DeepSpeed-Ulysses、Megatron-LM Sequence Parallelism 等现有方法在应用于此类任务时面临局限性 —— 要么是引入过多的序列通信,要么是在处理小规模序列并行时缺乏效率。

为此,研究团队提出了 FastSeq,一种适用于大序列和小规模并行的新型序列并行。FastSeq 通过为每个 transformer 层仅使用两个通信运算符来最小化序列通信,利用 AllGather 来提高通信效率,并策略性地采用异步 ring 将 AllGather 通信与 qkv 计算重叠,进一步优化性能。

想训练类Sora模型吗?尤洋团队OpenDiT实现80%加速

想训练类Sora模型吗?尤洋团队OpenDiT实现80%加速

算子优化

在 DiT 模型中引入 adaLN 模块将条件信息融入视觉内容,虽然这一操作对模型的性能提升至关重要,但也带来了大量的逐元素操作,并且在模型中被频繁调用,降低了整体的计算效率。为了解决这个问题,研究团队提出了高效的 Fused adaLN Kernel,将多次操作合并成一次,从而增加了计算效率,并且减少了视觉信息的 I/O 消耗。

想训练类Sora模型吗?尤洋团队OpenDiT实现80%加速

图源:https://www.zhihu.com/people/berkeley-you-yang

简单来说,OpenDiT 具有以下性能优势:

1、在 GPU 上加速高达 80%,50%的内存节省

  • 设计了高效的算子,包括针对DiT设计的 Fused AdaLN,以及 FlashAttention、Fused Layernorm 和HybridAdam。
  • 采用混合并行方法,包括 ZeRO、Gemini 和 DDP。对 ema 模型进行分片也进一步降低了内存成本。

2、FastSeq:一种新颖的序列并行方法

  • 专为类似 DiT 的工作负载而设计,在这些应用中,序列通常较长,但参数相比于 LLM 较小。
  • 节点内序列并行可节省高达 48% 的通信量。
  • 打破单个 GPU 的内存限制,减少整体训练和推理时间。

3、易于使用

  • 只需几行代码的修改,即可获得巨大的性能提升。
  • 用户无需了解分布式训练的实现方式。

4、文本到图像和文本到视频生成完整 pipeline

  • 研究人员和工程师可以轻松使用 OpenDiT pipeline 并将其应用于实际应用,而无需修改并行部分。
  • 研究团队通过在 ImageNet 上进行文本到图像训练来验证 OpenDiT 的准确性,并发布了检查点(checkpoint)。

安装与使用

要使用 OpenDiT,首先要安装先决条件:

  • Python >= 3.10
  • PyTorch >= 1.13(建议使用 >2.0 版本)
  • CUDA >= 11.6

建议使用 Anaconda 创建一个新环境(Python >= 3.10)来运行示例:

conda create -n opendit pythnotallow=3.10 -yconda activate opendit
登录后复制

安装 ColossalAI:

git clone https://github.com/hpcaitech/ColossalAI.gitcd ColossalAIgit checkout adae123df3badfb15d044bd416f0cf29f250bc86pip install -e .
登录后复制

安装 OpenDiT:

git clone https://github.com/oahzxl/OpenDiTcd OpenDiTpip install -e .
登录后复制

(可选但推荐)安装库以加快训练和推理速度:

# Install Triton for fused adaln kernelpip install triton# Install FlashAttentionpip install flash-attn# Install apex for fused layernorm kernelgit clone https://github.com/NVIDIA/apex.gitcd apexgit checkout 741bdf50825a97664db08574981962d66436d16apip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-optinotallow=--cpp_ext" --config-settings "--build-optinotallow=--cuda_ext" ./--global-optinotallow="--cuda_ext" --global-optinotallow="--cpp_ext"
登录后复制

图像生成

你可以通过执行以下命令来训练 DiT 模型:

# Use scriptbash train_img.sh# Use command linetorchrun --standalone --nproc_per_node=2 train.py \--model DiT-XL/2 \--batch_size 2
登录后复制

默认禁用所有加速方法。以下是训练过程中一些关键要素的详细信息:

  • plugin: 支持 ColossalAI、zero2 和 ddp 使用的 booster 插件。默认是 zero2,建议启用 zero2。
  • mixed_ precision:混合精度训练的数据类型,默认是 fp16。
  • grad_checkpoint: 是否启用梯度检查点。这节省了训练过程的内存成本。默认值为 False。建议在内存足够的情况下禁用它。
  • enable_modulate_kernel: 是否启用 modulate 内核优化,以加快训练过程。默认值为 False,建议在 GPU
  • enable_layernorm_kernel: 是否启用 layernorm 内核优化,以加快训练过程。默认值为 False,建议启用它。
  • enable_flashattn: 是否启用 FlashAttention,以加快训练过程。默认值为 False,建议启用。
  • sequence_parallel_size:序列并行度大小。当设置值 > 1 时将启用序列并行。默认值为 1,如果内存足够,建议禁用它。

如果你想使用 DiT 模型进行推理,可以运行如下代码,需要将检查点路径替换为你自己训练的模型。

# Use scriptbash sample_img.sh# Use command linepython sample.py --model DiT-XL/2 --image_size 256 --ckpt ./model.pt
登录后复制

视频生成

你可以通过执行以下命令来训练视频 DiT 模型:

# train with sciptbash train_video.sh# train with command linetorchrun --standalone --nproc_per_node=2 train.py \--model vDiT-XL/222 \--use_video \--data_path ./videos/demo.csv \--batch_size 1 \--num_frames 16 \--image_size 256 \--frame_interval 3# preprocess# our code read video from csv as the demo shows# we provide a code to transfer ucf101 to csv formatpython preprocess.py
登录后复制

使用 DiT 模型执行视频推理的代码如下所示:

# Use scriptbash sample_video.sh# Use command linepython sample.py \--model vDiT-XL/222 \--use_video \--ckpt ckpt_path \--num_frames 16 \--image_size 256 \--frame_interval 3
登录后复制

DiT 复现结果

为了验证 OpenDiT 的准确性,研究团队使用 OpenDiT 的 origin 方法对 DiT 进行了训练,在 ImageNet 上从头开始训练模型,在 8xA100 上执行 80k step。以下是经过训练的 DiT 生成的一些结果:

想训练类Sora模型吗?尤洋团队OpenDiT实现80%加速

损失也与 DiT 论文中列出的结果一致:

想训练类Sora模型吗?尤洋团队OpenDiT实现80%加速

要复现上述结果,需要更改 train_img.py 中的数据集并执行以下命令:

torchrun --standalone --nproc_per_node=8 train.py \--model DiT-XL/2 \--batch_size 180 \--enable_layernorm_kernel \--enable_flashattn \--mixed_precision fp16
登录后复制

感兴趣的读者可以查看项目主页,了解更多研究内容。

以上是想训练类Sora模型吗?尤洋团队OpenDiT实现80%加速的详细内容。更多信息请关注PHP中文网其他相关文章!

本站声明
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn

热AI工具

Undresser.AI Undress

Undresser.AI Undress

人工智能驱动的应用程序,用于创建逼真的裸体照片

AI Clothes Remover

AI Clothes Remover

用于从照片中去除衣服的在线人工智能工具。

Undress AI Tool

Undress AI Tool

免费脱衣服图片

Clothoff.io

Clothoff.io

AI脱衣机

AI Hentai Generator

AI Hentai Generator

免费生成ai无尽的。

热门文章

R.E.P.O.能量晶体解释及其做什么(黄色晶体)
3 周前 By 尊渡假赌尊渡假赌尊渡假赌
R.E.P.O.最佳图形设置
3 周前 By 尊渡假赌尊渡假赌尊渡假赌
R.E.P.O.如果您听不到任何人,如何修复音频
3 周前 By 尊渡假赌尊渡假赌尊渡假赌
WWE 2K25:如何解锁Myrise中的所有内容
4 周前 By 尊渡假赌尊渡假赌尊渡假赌

热工具

记事本++7.3.1

记事本++7.3.1

好用且免费的代码编辑器

SublimeText3汉化版

SublimeText3汉化版

中文版,非常好用

禅工作室 13.0.1

禅工作室 13.0.1

功能强大的PHP集成开发环境

Dreamweaver CS6

Dreamweaver CS6

视觉化网页开发工具

SublimeText3 Mac版

SublimeText3 Mac版

神级代码编辑软件(SublimeText3)

Laravel的地理空间:互动图和大量数据的优化 Laravel的地理空间:互动图和大量数据的优化 Apr 08, 2025 pm 12:24 PM

利用地理空间技术高效处理700万条记录并创建交互式地图本文探讨如何使用Laravel和MySQL高效处理超过700万条记录,并将其转换为可交互的地图可视化。初始挑战项目需求:利用MySQL数据库中700万条记录,提取有价值的见解。许多人首先考虑编程语言,却忽略了数据库本身:它能否满足需求?是否需要数据迁移或结构调整?MySQL能否承受如此大的数据负载?初步分析:需要确定关键过滤器和属性。经过分析,发现仅少数属性与解决方案相关。我们验证了过滤器的可行性,并设置了一些限制来优化搜索。地图搜索基于城

mysql 无法启动怎么解决 mysql 无法启动怎么解决 Apr 08, 2025 pm 02:21 PM

MySQL启动失败的原因有多种,可以通过检查错误日志进行诊断。常见原因包括端口冲突(检查端口占用情况并修改配置)、权限问题(检查服务运行用户权限)、配置文件错误(检查参数设置)、数据目录损坏(恢复数据或重建表空间)、InnoDB表空间问题(检查ibdata1文件)、插件加载失败(检查错误日志)。解决问题时应根据错误日志进行分析,找到问题的根源,并养成定期备份数据的习惯,以预防和解决问题。

mysql安装后怎么使用 mysql安装后怎么使用 Apr 08, 2025 am 11:48 AM

文章介绍了MySQL数据库的上手操作。首先,需安装MySQL客户端,如MySQLWorkbench或命令行客户端。1.使用mysql-uroot-p命令连接服务器,并使用root账户密码登录;2.使用CREATEDATABASE创建数据库,USE选择数据库;3.使用CREATETABLE创建表,定义字段及数据类型;4.使用INSERTINTO插入数据,SELECT查询数据,UPDATE更新数据,DELETE删除数据。熟练掌握这些步骤,并学习处理常见问题和优化数据库性能,才能高效使用MySQL。

了解 ACID 属性:可靠数据库的支柱 了解 ACID 属性:可靠数据库的支柱 Apr 08, 2025 pm 06:33 PM

数据库ACID属性详解ACID属性是确保数据库事务可靠性和一致性的一组规则。它们规定了数据库系统处理事务的方式,即使在系统崩溃、电源中断或多用户并发访问的情况下,也能保证数据的完整性和准确性。ACID属性概述原子性(Atomicity):事务被视为一个不可分割的单元。任何部分失败,整个事务回滚,数据库不保留任何更改。例如,银行转账,如果从一个账户扣款但未向另一个账户加款,则整个操作撤销。begintransaction;updateaccountssetbalance=balance-100wh

偏远的高级后端工程师(平台)需要圈子 偏远的高级后端工程师(平台)需要圈子 Apr 08, 2025 pm 12:27 PM

远程高级后端工程师职位空缺公司:Circle地点:远程办公职位类型:全职薪资:$130,000-$140,000美元职位描述参与Circle移动应用和公共API相关功能的研究和开发,涵盖整个软件开发生命周期。主要职责独立完成基于RubyonRails的开发工作,并与React/Redux/Relay前端团队协作。为Web应用构建核心功能和改进,并在整个功能设计过程中与设计师和领导层紧密合作。推动积极的开发流程,并确定迭代速度的优先级。要求6年以上复杂Web应用后端

mysql 能返回 json 吗 mysql 能返回 json 吗 Apr 08, 2025 pm 03:09 PM

MySQL 可返回 JSON 数据。JSON_EXTRACT 函数可提取字段值。对于复杂查询,可考虑使用 WHERE 子句过滤 JSON 数据,但需注意其性能影响。MySQL 对 JSON 的支持在不断增强,建议关注最新版本及功能。

Bangla 部分模型检索中的 Laravel Eloquent ORM) Bangla 部分模型检索中的 Laravel Eloquent ORM) Apr 08, 2025 pm 02:06 PM

LaravelEloquent模型检索:轻松获取数据库数据EloquentORM提供了简洁易懂的方式来操作数据库。本文将详细介绍各种Eloquent模型检索技巧,助您高效地从数据库中获取数据。1.获取所有记录使用all()方法可以获取数据库表中的所有记录:useApp\Models\Post;$posts=Post::all();这将返回一个集合(Collection)。您可以使用foreach循环或其他集合方法访问数据:foreach($postsas$post){echo$post->

掌握SQL LIMIT子句:控制查询中的行数 掌握SQL LIMIT子句:控制查询中的行数 Apr 08, 2025 pm 07:00 PM

SQLLIMIT子句:控制查询结果行数SQL中的LIMIT子句用于限制查询返回的行数,这在处理大型数据集、分页显示和测试数据时非常有用,能有效提升查询效率。语法基本语法:SELECTcolumn1,column2,...FROMtable_nameLIMITnumber_of_rows;number_of_rows:指定返回的行数。带偏移量的语法:SELECTcolumn1,column2,...FROMtable_nameLIMIToffset,number_of_rows;offset:跳过

See all articles