FSDP-Zero123分片简介

随着模型越来越大,单卡显存已经无法满足训练需求,人们使用多种手段来扩大训练规模,如使用混合精度减少参数内存占用,或是使用多卡通信的方法进行并行训练。ZeRO 是常见的数据并行模型状态分片方案,思路为将模型权重、梯度、优化器状态等训练状态分片到多张卡上。

这里需要先区分 ZeRO 分片和 tensor parallel。ZeRO-3 也会分片模型权重,但它分片的是模型状态的常驻存储:每张卡平时只保存一部分权重,真正计算某一层时再临时 AllGather 出这一层需要的完整参数。tensor parallel 则是把一层里的矩阵乘、attention head、MLP hidden dim 等计算本身切开,多张卡共同完成同一个 batch 的同一层计算。ZeRO 的核心问题是“状态放不下,怎么存”,tensor parallel 的核心问题是“单层算不动,怎么切计算”。

全量微调一个7.5B模型时,显存中的参数组成,以及zero并行方案示意图
全量微调一个7.5B模型时,显存中的参数组成,以及zero并行方案示意图

参数组成

假设使用AdamW + BF16 进行混合精度训练,则在训练过程中,需要存储的参数组成如下:

组件精度每参数字节数用途
模型权重(计算用)BF162B前向/反向计算
梯度BF162B反向传播产出,供优化器更新参数
FP32 主权重(优化器)FP324B保存更高精度的参数副本,用于稳定更新
一阶动量 m_tFP324BAdamW 对梯度均值的指数滑动平均
二阶动量 v_tFP324BAdamW 对梯度平方的指数滑动平均

BF16 模型权重主要服务于矩阵乘等实际计算,可以显著降低显存和带宽压力;梯度是反向传播阶段产生的临时训练信号,优化器会读取它来决定参数更新方向;FP32 主权重、m_tv_t 则属于 AdamW 的优化器状态,通常用更高精度保存,以减少长时间训练中的数值误差。

可见训练情况下,参数占用的大头在优化器状态上

不分片时:每张卡都持有完整的一份,即 N × (2+2+4+4+4) = 16N 字节(不算激活值)。

训练流程梳理

先不考虑 ZeRO 分片,只看普通多卡 DDP 的训练流程。DDP 中每张 GPU 都持有一份完整的模型权重和优化器状态,但每张卡读取的数据不同。例如 8 卡训练时,一个 global batch 会被切成 8 份 micro-batch,每张卡只负责其中一份。前向传播时,每张卡用自己本地的完整 BF16 模型权重计算输出,并保存反向传播需要的激活值。

反向传播时,每张卡根据自己的 micro-batch 计算本地梯度。此时不同 GPU 上的梯度并不相同,因为它们看到的数据不同。如果直接用各自的梯度更新参数,8 张卡的模型副本会从这一步开始分叉,后续训练就不再等价于单个大 batch 的训练。因此 DDP 会在反向传播过程中对梯度做 AllReduce,把所有卡的梯度求和或平均,并把同步后的梯度写回每张卡本地。

梯度同步完成后,每张卡都拥有相同的模型权重、相同的平均梯度,以及相同的 AdamW 优化器状态。随后每张卡各自在本地执行 optimizer step,读取 BF16 梯度、FP32 主权重、m_tv_t,计算出完全相同的参数更新。因为更新前的状态相同、梯度也已经同步,所以更新后的模型副本仍然保持一致。

这个流程的好处是实现简单、计算并行度高,每张卡都可以独立完成完整模型的前向和反向。代价是显存冗余很大:模型权重、梯度和优化器状态在每张卡上都完整复制了一份。ZeRO 后续要优化的正是这部分冗余存储。

ZeRO Stage 1:分片优化器状态

ZeRO-1 先处理最大的一块冗余:优化器状态。在普通 DDP 里,每张卡不仅保存完整的 BF16 计算权重和 BF16 梯度,还保存完整的 FP32 主权重、m_tv_t。如果使用 AdamW,这三项合起来是 12N 字节,比 BF16 模型权重本身大很多。ZeRO-1 的做法是:BF16 模型权重和 BF16 梯度仍然完整复制在每张卡上,但把 FP32 主权重、m_tv_t 按参数范围切成 8 份,每张卡只负责其中 1 份。

text
┌──────────────────────────────────────────────────┐
│              8 张 GPU 各自持有                      │
│                                                    │
│  完整 BF16 模型权重           N × 2B              │
│  完整 BF16 梯度               N × 2B              │
│                                                    │
│  分片优化器状态                                    │
│     • 1/8 的 FP32 主权重       N/8 × 4B            │
│     • 1/8 的 m_t              N/8 × 4B             │
│     • 1/8 的 v_t              N/8 × 4B             │
└──────────────────────────────────────────────────┘

训练时,ZeRO-1 的前向和反向基本仍然像 DDP 一样执行。每张卡都有完整 BF16 权重,所以可以独立完成自己 micro-batch 的 forward/backward;反向传播得到的梯度也会像普通 DDP 一样做 AllReduce,使每张卡都拿到完整的平均梯度。

区别发生在 optimizer step。由于 GPU0 只保存第 0 片参数的 FP32 主权重和 AdamW 动量,GPU1 只保存第 1 片,以此类推,所以每张卡只更新自己负责的那 1/8 参数。更新完成后,各卡会把自己更新好的参数分片同步给其他卡,所有卡再把这些分片拼回本地完整的 BF16 模型权重。这样下一轮 forward 时,每张卡仍然能看到完整且一致的模型。

从通信角度看,ZeRO-1 仍然保留 DDP 的梯度 AllReduce;额外需要在参数更新后同步各自更新的参数分片。它的优点是对计算流程改动较小,因为 forward/backward 期间每张卡仍有完整模型;缺点是梯度和 BF16 权重还没有省,显存优化只作用在优化器状态上。

每卡显存:2N + 2N + (4N+4N+4N)/8 = 5.5N 字节(vs 原来 16N)

节省的是:FP32 主权重、m_tv_t 这 12N 优化器状态从完整保存变成按卡分片。

ZeRO Stage 2:分片优化器状态 + 梯度

ZeRO-2 在 ZeRO-1 的基础上继续观察一个事实:既然每张卡只负责更新 1/8 的参数,那么它其实也只需要这 1/8 参数对应的平均梯度。Stage 1 中每张卡保留完整梯度,是为了和普通 DDP 的流程保持一致;Stage 2 则把这部分也分片掉。

text
┌──────────────────────────────────────────────────┐
│              8 张 GPU 各自持有                      │
│                                                    │
│  完整 BF16 模型权重           N × 2B              │
│                                                    │
│  分片梯度                                          │
│     • 1/8 的 BF16 梯度        N/8 × 2B             │
│                                                    │
│  分片优化器状态                                    │
│     • 1/8 的 FP32 主权重       N/8 × 4B            │
│     • 1/8 的 m_t              N/8 × 4B             │
│     • 1/8 的 v_t              N/8 × 4B             │
└──────────────────────────────────────────────────┘

普通 DDP 的梯度同步可以理解为“先把所有卡的梯度求和,再把完整结果发回每张卡”。ZeRO-2 把这个过程改成 Reduce-Scatter:仍然对所有卡的梯度求和或平均,但最终结果不是完整复制给每张卡,而是按参数范围 scatter 出去。GPU0 只拿到第 0 片平均梯度,GPU1 只拿到第 1 片平均梯度,刚好和各自保存的优化器状态对应。

这样做在数学上仍然等价于 DDP 的全局平均梯度更新,只是每张卡不再保存自己不负责的梯度部分。随后 optimizer step 和 ZeRO-1 类似:每张卡读取自己的梯度分片、FP32 主权重分片、m_t 分片和 v_t 分片,更新自己负责的参数。因为 BF16 模型权重在 Stage 2 里仍然完整复制在每张卡上,所以更新后的参数分片仍需要同步给所有卡,让每张卡本地的完整 BF16 权重保持一致。

实际实现通常会按 bucket 或按 layer 边反向、边 Reduce-Scatter、边释放梯度缓存,而不是等整个模型的梯度都算完再统一处理。因此下面的公式更接近稳定状态下的模型状态占用;真实峰值还会受到 bucket 大小、梯度累积、通信重叠策略影响。

每卡显存:2N + (2N+4N+4N+4N)/8 = 3.75N 字节

节省的是:在 Stage 1 基础上,梯度也从 2N → 0.25N。

ZeRO Stage 3:分片一切(优化器状态 + 梯度 + 模型权重)

ZeRO-3 再进一步,把计算用的 BF16 模型权重也分片保存。到这一步,每张卡常驻的模型状态只剩自己负责的 1/8:BF16 参数分片、BF16 梯度分片、FP32 主权重分片、m_t 分片和 v_t 分片。它和 tensor parallel 的关键区别是:ZeRO-3 并不是把某个矩阵乘的计算永久切给不同 GPU,而是把参数平时分片存放;真正执行某个 module 的计算前,再临时收集出这个 module 需要的完整参数。

text
┌──────────────────────────────────────────────────┐
│              8 张 GPU 各自持有                      │
│                                                    │
│  分片 BF16 模型权重           N/8 × 2B            │
│  分片 BF16 梯度               N/8 × 2B            │
│  分片 FP32 主权重             N/8 × 4B            │
│  分片 m_t                     N/8 × 4B            │
│  分片 v_t                     N/8 × 4B            │
└──────────────────────────────────────────────────┘

前向传播时,框架会在进入某一层或某个 FSDP unit 前,对这一段参数做 AllGather。AllGather 之后,每张卡临时拥有这一层的完整 BF16 权重,于是可以像普通数据并行一样,用自己的 micro-batch 完成这一层计算。计算结束后,非本卡负责的参数分片可以释放,只保留本卡原本负责的 shard。

反向传播时也类似。计算某一层的 backward 需要对应的权重,因此可能再次 AllGather 这一层参数;得到梯度后,再用 Reduce-Scatter 把梯度规约并切回各个 rank。最终每张卡只保留自己负责的梯度分片,并用它更新本地的 FP32 主权重、m_tv_t 和 BF16 参数分片。

和 ZeRO-1/2 不同,ZeRO-3 更新完参数后不需要把完整模型权重常驻同步到每张卡,因为每张卡本来就不保存完整权重。下一次 forward 需要某一层参数时,再从各卡当前持有的参数分片 AllGather 出来即可。这个设计把常驻显存压到最低,但把参数通信放进了 forward/backward 的关键路径。

每卡显存:(2N+2N+4N+4N+4N)/8 = 2N 字节

这个 2N 指的是理想情况下每卡常驻的模型状态。真实训练峰值会更高,因为某些时刻会临时持有当前层 AllGather 出来的完整参数、通信 bucket、prefetch 参数以及激活值。ZeRO-3 的核心取舍就是:用更多、更频繁的参数 AllGather,换取模型状态显存近似按数据并行卡数线性下降。

通信代价的直觉

  • ZeRO-1:梯度同步仍然接近普通 DDP,额外多了参数更新后的同步,整体通信压力通常可控。
  • ZeRO-2:用 Reduce-Scatter 替代 AllReduce 梯度同步,通信总量和 DDP 同阶,但每张卡只保留梯度分片。
  • ZeRO-3:前向和反向会按层 AllGather 参数,反向后再 Reduce-Scatter 梯度。它的显存收益最大,但速度更依赖网络带宽、分片粒度、通信计算重叠和框架实现。

Zero分片与张量并行组合使用

假设有 8 张 GPU,设置 tp_size=2dp_size=4,那么可以把相邻两张卡组成一个 TP group,共同切分同一个模型副本中的张量计算:

text
TP groups:
  [GPU0, GPU1], [GPU2, GPU3], [GPU4, GPU5], [GPU6, GPU7]

这 4 个 TP group 之间再构成数据并行关系,每个 group 处理不同的 mini-batch。为了让同一个 tensor-parallel 分片位置之间同步梯度或分片优化器状态,通常会按 TP rank 位置组成 DP group:

text
DP groups:
  [GPU0, GPU2, GPU4, GPU6]  # 每个模型副本里的第 0 个 TP 分片
  [GPU1, GPU3, GPU5, GPU7]  # 每个模型副本里的第 1 个 TP 分片

这样,一层内部的计算由 TP group 内的两张卡协作完成;不同数据 batch 之间的梯度同步、优化器状态分片,则发生在 DP group 内。ZeRO/FSDP 通常作用在 DP group 维度上,而不是替代 tensor parallel 的算子切分。

FSDP 和 ZeRO 的关系

PyTorch FSDP 可以理解为 ZeRO 思路在 PyTorch 里的实现之一。常见对应关系是:

FSDP 策略近似对应说明
NO_SHARDDDP不分片参数、梯度和优化器状态
SHARD_GRAD_OPZeRO-2分片梯度和优化器状态,参数在计算时保持完整
FULL_SHARDZeRO-3分片参数、梯度和优化器状态
HYBRID_SHARDZeRO-3 + 层级分组组内分片,组间复制,常用于多机多卡

实际训练时还需要额外考虑激活值、通信 bucket、临时 AllGather 出来的完整参数、CUDA allocator 碎片以及 checkpoint 保存方式,所以这些公式更适合用来建立显存组成的直觉,而不是直接等同于最终峰值显存。

← 返回文章列表