理解 PyTorch DDP 分布式数据并行模式

PyTorch 使用 DDP(Distributed Data Parallel) 实现了真正的分布式数据并行,在下面的两个场景下都可以使用 DDP 实现模型的分布式训练:

  • 单机、多 GPU(单进程多线程的伪分布式)
  • 多机、多 GPU(多机多进程的真正分布式)

上面第一种方式,就是类似使用简单的 DP 数据并行模式,但是 DP 使用的单进程、多线程的范式来实现的;而 DDP 完全使用了多进程的方式,包括单机多进程、多机多进程,如果是多机的情形则对应着物理上的分布式多进程模式。为了获得更好的性能,最好是使用 DDP 模式来训练模型,即使是在单机、多 GPU 的情况下,也建议使用 DDP 模式来实现基于数据并行的模型训练,使用单机 DDP 模式训练模型的性能要比 DP 模式好很多。
DDP 基于集合通信(Collective Communications)来实现分布式训练过程中的梯度同步。在反向传播过程中,DDP 使用 AllReduce 来实现分布式梯度计算和同步。
下面,我们从与 DDP 相关的几个方面来理解 DDP 的设计与实现,包括:

  • 集合通信(Collective Communication)
  • 通信后端(Communication Backend)
  • DDP 内部实现概览
  • 分布式梯度 AllReduce
  • DDP API 实现

集合通信(Collective Communication)

DDP 模式是基于分布式数据并行的,所以在模型训练过程中要进行大量通信及数据传输,而且 DDP 是依赖 c10d 的 ProcessGroup 来实现集合通信(Collective Communication),要理解 DDP 对于 PyTorch 实现的集合通信是非常重要的一部分。集合通信是一种通信模式,它能够根据在配置好进程组中,实现方便地跨多进程间通信。
PyTorch 实现了 6 个集合通信函数,它们对应的功能,如下图所示:
Collective-Communications
对每个通信函数说明,如下表所示:

通信函数 功能说明
dist.scatter(tensor, scatter_list, src, group) 从 src 进程将 tensor scatter_list[i] 复制到第 i 个进程
dist.gather(tensor, gather_list, dst, group) 将所有进程的 tensor 收集到进程 dist
dist.reduce(tensor, dst, op, group) 将操作 op 应用到每个 tensor 并在 dist 进程保存结果
dist.all_reduce(tensor, op, group) 将操作 op 应用到每个 tensor 并在所有进程保存结果
dist.broadcast(tensor, src, group) 从进程 src 复制 tensor 到所有其它进程
dist.all_gather(tensor_list, tensor, group) 将所有进程的 tensor 复制到 tensor_list 中
dist.barrier(group) 阻塞所有进程,直到每个进程都进入该函数

通信后端(Communication Backend)

在 PyTorch 中实现分布式通信,是直接绑定到设置的通信后端(Communication Backends)上,可以根据使用场景选择合适的通信后端,而每个通信后端实现可以使用集合通信提供的各种通信函数。如果愿意,甚至可以直接使用点对点通信方式来实现更加灵活的通信方式。下面我们了解一下 PyTorch 支持的 3 个通信后端,可以直接在创建 DDP 的时候进行配置并创建。

  • Gloo Backend

Gloo 通信后端在开发时使用,它被预编译进入 PyTorch 库。Gloo 通信后端在 CPU 上支持所有的点对通信函数,在 GPU 上支持所有的集合通信函数。

  • MPI Backend

MPI 后端既支持点对点通信,也支持集合通信。PyTorch 二进制发行包并没有包括 MPI 实现,如果使用它需要我们自己手动编译。

  • NCCL Backend

NCCL 通信后端是直接预编译进 PyTorch 二进制发行包以支持 CUDA,它提供了基于 CUDA Tensor 的集合通信的优化实现。如果使用 CUDA Tensor 进行集合通信,NCCL 能够带来最好的性能。

关于每个通信后端,支持在 CPU 及 GPU 上使用的集合通信函数的情况,如下表所示:

通信后端 GLOO MPI NCCL
计算设备 CPU GPU CPU GPU CPU GPU
scatter × ×
gather × ? ×
reduce × ? ×
all_reduce ? ×
broadcast ? ×
all_gather × ? ×
barrier × ? ×
reduce_scatter × × × × ×

DDP 内部实现概览

在说明 DDP 内部实现之前,我们有必要从更 High-Level 的视角看一下,PyTorch 实现的 DDP API 所处的位置的,如下图所示:
DDP-Stacked-Code-Components
从代码文件的层面看,Python 语言实现的 DDP API 是如何依赖底层的 C++ 模块的实现的。DDP 通过使用 Python 语言实现面向模型开发用户友好的接口,使用用户专注于自己业务领域内的模型开发。在分布式训练模型过程中,需要进行频繁的通信,DDP 都是通过调用底层 C++ 提供的基本接口来实现的,所以使用 DDP 模式分布式训练模型具有非常好的性能。

  • 1.DDP 初始化

首先,在创建 DDP 的时候会在 rank0 进程上,将模型的副本同步到分布式训练集群中其它每个进程上,使每个进程持有的模型副本从开始具有相同的状态。
然后,每个 DDP 进程创建一个本地的 Reducer,Reducer 用来支持反向传播过程中的梯度同步。为了提高通信效率,Reducer 采用基于 Bucket 的方式来组织和管理模型的参数,亦即在多个 Bucket 中将模型参数(Model.parameters())以逆序的方式保存,这样可以方便在反向传播过程中直接对 Bucket 中的参数计算梯度,而不用在这个时候过多考虑和处理模型参数的顺序问题。
采用 Bucket 结构,会将模型参数映射到 Bucket 中,这样能够对 DDP 的速度产生重要的影响。在每一轮迭代中的反向传播阶段,会在所有模型参数与 Bucket 之间发生一次双向复制:先将模型参数复制到 Bucket 中,在执行 AllReduce 计算后再将平均梯度从 Bucket 中复制更新到模型参数上。为了加速复制操作,需要在模型参数保存的设备上创建与之对应的 Bucket。如果模型跨多个设备,这时就不再是单纯的 DDP 分布式数据并行了,模型无法放到单个设备中所以也必须进行分片,每个设备保存模型的一个分片,比如使用 PyTorch 的同时支持数据并行与模型并行的 FSDP(Fully Sharded Data Parallel) 模型,DDP 都会考虑模型设备亲和性(Model Device Affinity),也就是确保在同一个 Bucket 中的所有模型参数都在同一个设备上。
关于 Bucket 的具体结构,如下图所示:
DDP-Bucket-based-Gradient-Synchronization
图中示例了两个进程,其中模型参数 param0 和 param1 保存在 bucket1 中,param2 和 param3 保存在 bucket2 中。当反向传播时计算得到的 grad0 和 grad1 会保存在 bucket1 中,grad2 和 grad3 会保存在 bucket2 中。在这里还会为模型的每个参数注册一个 Autograd Hook,是为是为每个参数的梯度累加器(Gradient Accumulator)注册一个 Autograd Hook,当反向传播过程中参数的梯度更新完成后会触发 Autograd Hook 执行。

  • 2.前向传播过程

在前向传播过程中,DDP 将拿到的一批输入数据传给本地持有的模型,通过计算得到输出。

  • 3.反向传播过程

在反向传播过程中,通过调用 loss.backward() 执行反向传播计算,该调用实际上是不受 DDP 控制的,因为 DDP 在创建的时候为每个参数注册了一个 Autograd Hook,DDP 会等待 Autograd Hook 被触发从而判断是否需要进行梯度的同步。反向传播过程 Reducer 会遍历创建 DDP 时的多个存储模型参数的 Bucket,如果存在某个 Bucket 已经 ready 了,即 Bucket 中的所有参数对应的 Autograd Hook 都已经被触发,则会触发一个异步 AllReduce 信号。这种使用 Bucket 和注册 Autograd Hook 的机制,能够很好地实现的通信的 Overlapping,极大地降低了多次频繁通信带来的开销。
当所有的 Bucket 中还有没 ready 的,Reducer 会一直阻塞等待直到所有的 AllReduce 操作完成。当得到梯度均值结果以后,所有的进程中模型参数的梯度都要更新,对应的是每个参数的 .grad 字段,这样就保证训练集群中每个进程中的参数梯度是相同的。

  • 4.优化器执行优化

因为 DDP 模式中,所有的模型都被复制到每个进程中,所有对于优化器来说都是在本地进行模型的优化。多进程中的每个模型都保持状态同步,所以在每一轮训练迭代后,梯度的结果都是相同的。

分布式梯度 AllReduce

在整个模型训练过程中,分布式梯度的 AllReduce 具有非常重要的作用,在兼顾通信性能的前提下,AllReduce 能够使多个进程中的模型参数最终得到同步并保持状态一致。下面我们基于 DDP 的论文中给出的分布式 AllReduce 流程,来理解这个过程,如下图所示:
Distributed-Gradient-Reduction
图中示例了在两个 DDP 进程内维护着模型参数对应的 Bucket,当需要在多个进程之间同步梯度信息时,同步更新对应的 Bucket。然后,将 Bucket 内部保存的模型参数的梯度更新到 GPU 内。所以,使用 DDP 模式推荐的方法是:每个进程对应一个模型副本,而每个模型副本对应多个 GPU 设备。
前面我们提到的 Bucket、Reducer 以及 AllReduce,他们都是 PyTorch 最核心的内容,是在 reducer.cpp 中用 C++ 实现的。DDP 进程会在本地创建 Reducer,它会被用来管理整个反向传播阶段梯度的计算和同步。
下面,我们从模型参数变化的角度,来看一下在进行 AllReduce 阶段各个参数在计算过程是如何进行操作,整个过程可以分为 2 个大的阶段:

  • 1.Scatter-Reduce 阶段

在 Scatter-Reduce 阶段,在每个 GPU 中对一个 Mini-Batch 的数据执行模型训练,会经过一轮一轮迭代,以确定的顺序对各个 GPU 中参数梯度进行交换,最终的结果是每个 GPU 上都有模型一个参数的所有梯度,并且对其进行 Reduce 累加计算平均梯度,过程如下图所示:
Scatter-Reduce

  • 2.Allgather 阶段

经过上一个阶段,每个 GPU 上的一些参数并没有获取到所有进程上计算得到的梯度,所以并不能进行累加求平均,还需要进行 Allgather,把所有 GPU 每个参数的所有梯度都完整收集过来,从而才能进行累加并计算梯度的平均值,过程如下图所示:
Allgather

DDP API 实现

PyTorch DDP 提供的 API 如下所示:

class torch.nn.parallel.DistributedDataParallel(
  module, device_ids=None, output_device=None, dim=0, broadcast_buffers=True,
  process_group=None, bucket_cap_mb=25, find_unused_parameters=False,
  check_reduction=False, gradient_as_bucket_view=False, static_graph=False,
  delay_all_reduce_named_params=None, param_to_hook_all_reduce=None,
  mixed_precision=None, device_mesh=None
)

参数 bucket_cap_mb 是配置 Bucket 的大小,默认是 25MB,它会影响 DDP 在初始化时创建 Bucket 的数量,一旦 DDP 初始化完成以后,Bucket 的数量就确定不变了,所以在实际使用时要根据具体的场景来确定 bucket_cap_mb 的大小。
参数 find_unused_parameters 设置为 True,在一定程度上会影响计算的速度,因为它会遍历 Autograd 计算图,如果不理解该参数的实际影响就使用默认值 False。
其它参数的含义和使用方法,可以参考官网文档。
创建一个 DDP 模型,代码示例:

 torch.distributed.init_process_group(
     backend='nccl', world_size=N, init_method='...'
 )
 model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[0, 1, 2...], ...)

参考资源

Creative Commons License

本文基于署名-非商业性使用-相同方式共享 4.0许可协议发布,欢迎转载、使用、重新发布,但务必保留文章署名时延军(包含链接:http://shiyanjun.cn),不得用于商业目的,基于本文修改后的作品务必以相同的许可发布。如有任何疑问,请与我联系

发表评论

电子邮件地址不会被公开。 必填项已用*标注

您可以使用这些HTML标签和属性: <a href="" title=""> <abbr title=""> <acronym title=""> <b> <blockquote cite=""> <cite> <code> <del datetime=""> <em> <i> <q cite=""> <strike> <strong>