理解 PyTorch 分布式 Autograd 设计

Autograd 是一个反向自动微分系统(或梯度计算引擎),基于记录所有的操作来构建一个有向无环图——Autograd 计算图,其中叶子节点是输入 Tensor,根节点 root 是输出 Tensor,通过跟踪图中从根节点 root 到叶子节点的路径上的操作,能够自动地计算出梯度。
在 PyTorch 中,模型训练的每一轮迭代,都会创建对应的 Autograd 计算图:在前向传播阶段动态地创建 Autograd 计算图,在反向传播阶段根据 Autograd 计算图来进行梯度的计算。

构建分布式 Autograd 计算图

对于分布式模型训练环境下,需要在各个节点(主机)之间进行大量的 RPC 调用,统一协调各个过程来完成模型的训练。PyTorch 实现的分布式 Autograd,在前向传播过程中构建 Autograd 计算图,并且基于 Autograd 计算图在反向传播过程中计算梯度。在前向传播过程中,PyTorch 持续跟踪各个 RPC 调用的情况,必须确保在反向传播过程中计算是正确的,所以 PyTorch 在实现过程中使用了 send、recv 这一对函数来进行跟踪,当执行 RPC 调用时将 send 和 recv 绑定到 Autograd 计算图上。

  • send 函数被绑定到 RPC 调用的源节点(Source Node)端,send 函数的输出边指向 RPC 的输入 Tensor 变量;在反向传播阶段,send 函数会接收从目的节点(Destination Node)端与之对应的 recv 函数发送过来的结果,作为 send 函数的输入
  • recv 函数被绑定到 RPC 调用的目的节点端,通过在目的节点端查询对应前向计算得到的结果作为 recv 的输入;在反向传播阶段,recv 函数执行得到的梯度结果被发送到源节点端对应的 send 函数

为了说明这个过程,以 PyTorch 文档中下面的简单计算为例:

import torch
import torch.distributed.rpc as rpc

def my_add(t1, t2):
  return torch.add(t1, t2)

# On worker 0:
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)

# Perform some computation remotely.
t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))

# Perform some computation locally based on remote result.
t4 = torch.rand((3, 3), requires_grad=True)
t5 = torch.mul(t3, t4)

# Compute some loss.
loss = t5.sum()

这个例子非常简单,计算 t5 = (t1 + t2) * t4 的结果,其中 t1 + t2 的计算是通过 RPC 调用在 worker1 节点上计算,计算结果 t3 返回到 worker0 节点上,继续后面的乘法计算 t5 = t3 * t4。这个分布式计算的例子,在执行前向传播过程中生成 Autograd 计算图,在反向传播阶段使用它来计算梯度,如下图所示:
send_recv_functions
图中存在两个 send-recv 调用对,其中,在 Worker 0 上有 2 个 Autograd 计算图,分别以 mul 和 send 函数为根的 root;在 Worker 1 有 1 个计算图,根 root 是 send 函数。

FAST mode 算法

目前,PyTorch 已经实现了 FAST mode 算法,该算法考虑了对性能要求比较敏感的应用场景,通过设置较强的假设越是来简化分布式梯度计算。而 SMART mode 算法是更通用意义上的算法,当前正在进行中,还没有完成实现。
FAST mode 算法的关键假设:
每一个 send 函数在反向传播阶段只存在一个依赖,也就是说,通过一个 RPC 调用 send 函数只需要从目的节点端接收一个梯度结果。下面是 FAST mode 算法的基本流程:

  1. 在一个 Worker 节点上开始执行反向传播计算,从起始的根 root 开始,这就要求所有的根 root 必须是本地的。
  2. 为当前的分布式 Autograd 上下文对象(通过 dist_autograd.context()获得)查询所有的 send 函数(一次训练迭代中,Autograd 计算图中所有的 send-recv 对都保存在分布式 Autograd 上下文对象中)。
  3. 根据已经确定好的根 root,在本地计算这些根 root 的依赖关系,也包括所有 send 函数的依赖关系,然后启动本地 Autograd 引擎开始计算梯度。
  4. 当 Autograd 引擎执行 recv 函数时,recv 函数基于 RPC 调用将梯度发送给与之对应的 send 函数所在的节点端,实际上 recv 函数只需要发送对应的两个 ID 即可:autograd_context_id(唯一对应于一次迭代的分布式 Autograd 上下文对象)和 autograd_message_id(唯一对应于一个 send-recv 对)。
  5. 远程节点端接收到对应的 autograd_context_id 和 autograd_message_id,查询找到对应的 send 函数,如果这是第一次在该节点上接收到 autograd_context_id,会在该节点本地计算与 autograd_context_id 对应的所有依赖关系。
  6. send 函数被放到执行队列中,等待调度执行,得到该 send-recv 对对应的 RPC 调用结果,并返回给调用端。
  7. 单独为每个分布式 Autograd 上下文对象计算累加梯度(Accumulated Gradient),计算结果保存在 Dict[Tensor, Tensor] 结构中,基于这个 Dict 可以通过一个给定的 Tensor 得到它对应的累加梯度。

为了说明 FAST mode 算法的流程,下面通过一个简单的例子来加深理解计算的过程:

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc

def my_add(t1, t2):
  return torch.add(t1, t2)

# On worker 0:

# Setup the autograd context. Computations that take
# part in the distributed backward pass must be within
# the distributed autograd context manager.
with dist_autograd.context() as context_id:
  t1 = torch.rand((3, 3), requires_grad=True)
  t2 = torch.rand((3, 3), requires_grad=True)

  # Perform some computation remotely.
  t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))

  # Perform some computation locally based on remote result.
  t4 = torch.rand((3, 3), requires_grad=True)
  t5 = torch.mul(t3, t4)

  # Compute some loss.
  loss = t5.sum()

  # Run the backward pass.
  dist_autograd.backward(context_id, [loss])

  # Retrieve the gradients from the context.
  dist_autograd.get_gradients(context_id)

通过前向传播计算,会得到 Autograd 计算图,下图描述了 Autograd 计算图以及在进行分布式 Autograd 计算过程中得到的依赖关系:
fast_mode_distributed_dependencies_computed
通过上图可以看到,在前向传播阶段,构建好了 Autograd 计算图,其中:Worker 0 上有两个子图,它们的根 root 分别是 mul1 和 send1;Worker 2 上有一个计算图,根 root 为 send2。
分布式 Autograd 梯度的详细计算过程,描述如下:

  1. 在 Worker 0 上,从 loss 和 send1 开始,计算依赖关系:send1 有 1 个依赖、mul1 有 1 个依赖。
  2. 在 Worker 0 上启动 Autograd 引擎,首先执行 mul1 函数,将结果在对应的分布式 Autograd 上下文对象中进行累加,对应着 t4;然后执行 recv2 函数,将计算结果梯度发送给 Worker 1 上。
  3. Worker 1 第一次得知在进行反向传播计算,所以首先计算本地依赖关系:send2 有 1 个依赖、add1 有 1 个依赖、recv1 有 1 个依赖。
  4. 然后在 Worker 1 上启动本地 Autograd 引擎,并将 send2 加入到执行队列等待调度,接着依次执行 add1、recv1 函数,当 recv1 计算完成后,会将梯度计算结果发送给 Worker 0。
  5. Worker 0 接收到 recv1 的梯度结果,在本地执行 send1 函数。
  6. 最后,t1、 t2、t4 的梯度都会在当前的分布式 Autograd 上下文对象中进行累加计算,这样就完成了一轮迭代的分布式梯度计算。

参考资源

Creative Commons License

本文基于署名-非商业性使用-相同方式共享 4.0许可协议发布,欢迎转载、使用、重新发布,但务必保留文章署名时延军(包含链接:http://shiyanjun.cn),不得用于商业目的,基于本文修改后的作品务必以相同的许可发布。如有任何疑问,请与我联系

发表评论

电子邮件地址不会被公开。 必填项已用*标注

您可以使用这些HTML标签和属性: <a href="" title=""> <abbr title=""> <acronym title=""> <b> <blockquote cite=""> <cite> <code> <del datetime=""> <em> <i> <q cite=""> <strike> <strong>