PyTorch 流水线并行模式设计分析

流水线并行(Pipeline Parallelism)最早是 Google 在 Gpipe 论文中提出的,这种并行训练模式能够充分利用多 GPU 的资源高效地训练评估大模型。目前 PyTorch 最新版本是 2.2,流水线并行的功能是基于 torchgpipe 论文中的设计来实现的,该功能当前还处于试验阶段。

问题背景

大模型无法直接放到单个 GPU 中进行训练,通过模型并行(Model Parallelism)的方法可以把模型进行分片,每一个分片放置到一个 GPU 上,这样能够很好实现模型并行且利用多 GPU 的资源。虽然使用这种较为初级的方式能够实现大模型的训练,但在训练的过程中并不能充分利用 GPU 资源,因为对顺序(Sequential)模型来说它每次只能激活一个 GPU 来进行训练,其它的 GPU 此时是闲置的,所以在底层设备上其实仍然是顺序执行。
例如,对一个有 4 层的顺序(Sequential)神经网络模型,经过模型分片后,训练过程中每一层(或 Subnetwork)放在一个 GPU 上,先进行前向传播计算得到 Loss,然后反向传播计算梯度,如下图所示:
4-Layer-EXample-NN
使用这种方式利用 GPU 训练,我们可以看到在训练过程中 GPU 完全没有被充分利用,如下图示:
NN-Trained-In-4-GPUs-Sequentially
训练过程中,在同一时刻基本上只有一个 GPU 在使用,其它的都处于闲置状态,只有在完成前向传播计算、反向传播计算之后,同时使用多个 GPU 更新梯度信息。
对于上述多 GPU 不能被充分利用的问题,使用流水线并行功能在一定程度上能够得到很好的解决。
流水线并行的思路是:

  • 为简化说明,我们假设模型的每一层(或者子网络粒度)放到不同 GPU 上,将分派到每个 GPU 上每一小批数据(Mini-Batch),再次分成更小的微批(Micro-Batch)。
  • 前向传播阶段,第一个 GPU 上可以顺序处理每一个微批数据,即将这些微批数据输入给模型第一层处理;只要第一个微批处理完成,结果发到第二个 GPU 上,作为模型第二层的输入,此时第二个 GPU 和第一个 GPU 是在并行计算的;以此类推下去,就实现了流水线并行。
  • 反向传播阶段,等所有前向传播计算都完成后即开始,计算的思路和前向传播类似,只不过按照与前向传播计算相反顺序计算。

上面描述的流水线并行处理过程,如下图所示:
NN-Pipelined-In-4-GPUs
可见,相对于完全顺序执行,使用流水线并行处理极大地充分利用了多 GPU 资源并行处理的能力,模型训练时间也减少了。上图描述的流程中,标识 “Bubble” 表示有一些时刻还是存在只有单个或少量几个 GPU 同时并行处理,所以优化的目标也很明确,尽量减少 “Bubble” 的大小,就能够实现更好地利用 GPU 资源的目标。

流水线并行的基本思想

对要进行训练的神经网络进行分割,可以得到 N 个子网络(Subnetwork),假设模型对应的神经网络为 f,则分割后可以形式化表达为:
NN-Splitable-Subnetworks
上面对神经网络划分的 N 个子网络,是按顺序依赖的。同样,我们把训练数据集划分的每一个小批(Mini-Batch),分成 M 个微批(Micro-Batc),则对 M 个微批进行处理也是有顺序依赖的,它们要保证和未进行划分微批的一个小批数据进行处理的效果是相同的。
在设计实现流水线并行计算过程中,只要在满足上述两类依赖的前提下进行设计,就可以方便地实现不同的流水线并行的设计,而且可以不断地进行优化以使 GPU 的利用最大化。
假设 M = 4,N = 3,即训练数据集的每一个小批分割为 4 个微批,模型对应的神经网络分割为 3 个子网络。这样可以得到一个对前向传播计算和反向传播计算过程中,面向任务(Task)的最小依赖图,如下图所示:
NN-Minimal-Task-Dependency-Graph
上图中 F 和 B 都是计算任务,F 表示前向计算的任务,B 表示反向计算的任务。其中,Fi,j 表示第 i 个微批数据输入第 j 个子网络进行前向计算的任务,Bi,j 表示对 Fi,j 的计算结果进行反向的计算梯度任务。所以,合理调度上述任务到指定的 GPU 上计算,就能够实现模型并行训练。
基于上面的例子我们知道,对于第一个微批的前向和反向计算过程:

  • 任务 B1,1 要使用任务 F1,1 的结果计算梯度
  • 任务 B1,2 要使用任务 F1,2 的结果计算梯度
  • 任务 B1,3 要使用任务 F1,3 的结果计算梯度

其它各个微批的计算也是类似的,所以为方便理解说明,我们可以设计一个简单的不完全的调度策略,使反向计算任务能够复用前向任务计算的结果,即按照上述条件,把 F1,1 和 B1,1 分配到 GPU1 上处理,F1,2 和 B1,2 分配到 GPU2 上处理,F1,3 和 B1,3 分配到 GPU3 上处理,如下图所示:
Assign-FB-Tasks-To-GPUs
可以看到上面情况的调度过程,以及每个 GPU 上并行的状况:

  • F1,1(GPU1)计算完成,可以同时启动计算任务 F1,2(GPU2)和 F2,1(GPU1)
  • F1,2(GPU2) 和 F2,1(GPU1)计算完成,可以同时启动计算任务 F1,3(GPU3)、F2,2(GPU2)、F3,1(GPU1)
  • F1,3(GPU3)、F2,2(GPU2)、F3,1(GPU1) 计算完成,可以同时启动计算任务 B1,3(GPU3)、F2,3(GPU2)、F4,1(GPU1),F3,2 在 GPU2 上排队等待
  • 其它情况类似,满足任务计算顺序依赖的前提下,尽量保证每个 GPU 上都有任务执行

使用 Checkpointing 机制优化内存使用

Checkpointing 机制的原理是,在前向传播计算过程中,在每个数据分区的边界处保存计算结果张量,从而为下次使用该结果的计算提供输入,而其它所有的中间计算结果都不保存,直接丢弃掉。
对于任务依赖,无论采取什么样的调度策略,任务顺序都必须严格满足下面条件:

  • 只有 Fi,j-1 计算完成以后,才能计算 Fi,j
  • 只有 Bi,j+1 计算完成以后,才能计算 Bi,j
  • 只有 Fi,j 计算完成以后,才能计算 Bi,j

同时,也考虑任务调度到所在的 GPU,GPipe 给出的流水线并行的一组前置条件如下:

  • 全部任务的集合为 {Fi,j} 和 {Bi,j} ,其中 i=1,2,3…M,j=1,2,3…N
  • 模型分割成 N 个子网络,存在 K 个 GPU 设备,满足 K>=N 才能把模型的所有分片都放到所有的 GPU 内训练
  • 约束 Fi,j 和 Bi,j 都在第 j 个 GPU 上计算,其中 j=1,2,3…K,这个条件非常关键

所以基于 Checkpointing 的机制,对于前面给出的 M = 4,N = 3 的例子,对应的具有 Checkpointing 机制的流水线并行的任务依赖图如下所示:
Pipeline-Parallelism-With-Checkpointing
图中,相同的颜色表示同一个 GPU,实线箭头表示任务之间必须满足的前后依赖关系,虚线箭头表示根据微批数据的顺序得到的任务执行顺序,F’i,j 和 B’i,j 分别表示对任务 Fi,j 和 Bi,j 进行了重新计算。另外,j 也标识了 GPU 设备的编号,具有相同的索引位置 j 的任务 Fi,j、Bi,j、Fi,j(其中 i=1,2,3,4)都在第 j 个 GPU 上进行处理。

实现基于 Checkpointing 机制的流水线并行

Checkpointing 机制的核心思想是:

  • 设置一定策略,定期对计算过程中的数据进行临时保存,如果后续计算中会使用到之前计算的结果,直接读取结果能够减少重新计算带来的资源消耗;但是如果内存资源不足的情况下,就会删除之前保存的一些结果数据,释放内存资源以供当前其它计算任务使用;而将来的任务计算时用到之前需要的计算结果,并且这些结果已经被清除了,则会重新计算之前任务得到结果以供这类任务使用。

在 PyTorch 中,通过定义了一个特殊的 Autograd 函数来实现 Checkpointing 功能,Autograd 函数的功能如下:

  • 前向传播计算过程中和正常计算一样,不需要存储中间结果 Tensor,但是需要保存输入 Tensor;
  • 在反向传播计算过程中,基于前向计算已经保存的输入 Tensor,重新计算 Autograd 函数,从而创建一个本地计算图(Local Computing Graph),并通过该本地计算图来计算反向传播的梯度结果。

需要注意的是,在计算任务 Bi,j 时,需要重新计算任务 F’i,j,并且也需要将在第 j+1 个 GPU 内计算任务 Bi,j+1 的结果复制到 第 j 个 GPU 内,这两个步骤其实是可以同时进行的,PyTorch 通过定义 Checkpoint 和 Recompute 两个 Autograd 函数,实现了更加细粒度的灵活控制。在运行任务 Fi,j 时,会生成一块供 Checkpoint 和 Recompute 共享的内存,这块内存被用来传输本地计算图,这样就能够使重新计算任务 F’i,j 和传输任务 Bi,j+1 的结果同步进行。
PyTorch 实现了流水线并行,对应的 API 如下所示:

class torch.distributed.pipeline.sync.Pipe(
  module, chunks=1, checkpoint='except_last', deferred_batch_norm=False
)

其中参数 checkpoint 表示内存优化的选项,PyTorch 提供了三个 Checkpointing 的取值:

  • always:对所有的微批(Micro-Batch)都进行 Checkpointing
  • except_last:除了一个小批(Mini-Batch)中最后一个微批,其它所有的微批都进行 Checkpointing
  • never:完全禁用 Checkpointing

使用 PyTorch 实现的 Pipe 类非常直观,例如,将两个全连接层跨两个 GPU 实现流水线并行,示例代码如下所示:

# Need to initialize RPC framework first.
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
torch.distributed.rpc.init_rpc('worker', rank=0, world_size=1)

# Build pipe.
fc1 = nn.Linear(16, 8).cuda(0)
fc2 = nn.Linear(8, 4).cuda(1)
model = nn.Sequential(fc1, fc2)
model = Pipe(model, chunks=8)
input = torch.rand(16, 16).cuda(0)
output_rref = model(input)

Pipe 类只支持实现 nn.Sequential 的模型,即顺序搭建神经网络各层(Module)的模型,而对于不完全是 nn.Sequential 的模型 PyTorch 也提供了支持,通过 torch.distributed.pipeline.sync.skip.skippable.skippable(stash=(), pop=()) 可以跳过对应的非 nn.Sequential 层,能够实现对模型中部分支持 nn.Sequential 的层进行流水线并行处理以实现模型的训练。
下面是一个使用 skippable(stash=(), pop=()) 函数实现的例子:

@skippable(stash=['1to3'])
class Layer1(nn.Module):
    def forward(self, input):
        yield stash('1to3', input)
        return f1(input)

class Layer2(nn.Module):
    def forward(self, input):
        return f2(input)

@skippable(pop=['1to3'])
class Layer3(nn.Module):
    def forward(self, input):
        skip_1to3 = yield pop('1to3')
        return f3(input) + skip_1to3

model = nn.Sequential(Layer1(), Layer2(), Layer3())

想要跳过对应的 Layer,需要为该 Layer 创建一个 nn.Module 装饰器,使用 @skippable 来装饰 class,表示该 nn.Module 是可跳过的。每一个需要跳过的 Tensor 通过为其静态地标记一个名称来使用,如上面的 ’1to3′。上面代码在第一层 Layer1 中对 input 张量进行 stash(),将张量 input 与 ’1to3′ 绑定;在最后一层 Layer3 进行 pop(),解除绑定。可见,上面代码实现了对三个 Layer 的跳过处理。

参考资源

Creative Commons License

本文基于署名-非商业性使用-相同方式共享 4.0许可协议发布,欢迎转载、使用、重新发布,但务必保留文章署名时延军(包含链接:http://shiyanjun.cn),不得用于商业目的,基于本文修改后的作品务必以相同的许可发布。如有任何疑问,请与我联系

发表评论

电子邮件地址不会被公开。 必填项已用*标注

您可以使用这些HTML标签和属性: <a href="" title=""> <abbr title=""> <acronym title=""> <b> <blockquote cite=""> <cite> <code> <del datetime=""> <em> <i> <q cite=""> <strike> <strong>