Autograd 是一个反向自动微分系统(或梯度计算引擎),基于记录所有的操作来构建一个有向无环图——Autograd 计算图,其中叶子节点是输入 Tensor,根节点 root 是输出 Tensor,通过跟踪图中从根节点 root 到叶子节点的路径上的操作,能够自动地计算出梯度。
在 PyTorch 中,模型训练的每一轮迭代,都会创建对应的 Autograd 计算图:在前向传播阶段动态地创建 Autograd 计算图,在反向传播阶段根据 Autograd 计算图来进行梯度的计算。
构建分布式 Autograd 计算图
对于分布式模型训练环境下,需要在各个节点(主机)之间进行大量的 RPC 调用,统一协调各个过程来完成模型的训练。PyTorch 实现的分布式 Autograd,在前向传播过程中构建 Autograd 计算图,并且基于 Autograd 计算图在反向传播过程中计算梯度。在前向传播过程中,PyTorch 持续跟踪各个 RPC 调用的情况,必须确保在反向传播过程中计算是正确的,所以 PyTorch 在实现过程中使用了 send、recv 这一对函数来进行跟踪,当执行 RPC 调用时将 send 和 recv 绑定到 Autograd 计算图上。
- send 函数被绑定到 RPC 调用的源节点(Source Node)端,send 函数的输出边指向 RPC 的输入 Tensor 变量;在反向传播阶段,send 函数会接收从目的节点(Destination Node)端与之对应的 recv 函数发送过来的结果,作为 send 函数的输入
- recv 函数被绑定到 RPC 调用的目的节点端,通过在目的节点端查询对应前向计算得到的结果作为 recv 的输入;在反向传播阶段,recv 函数执行得到的梯度结果被发送到源节点端对应的 send 函数
为了说明这个过程,以 PyTorch 文档中下面的简单计算为例:
import torch import torch.distributed.rpc as rpc def my_add(t1, t2): return torch.add(t1, t2) # On worker 0: t1 = torch.rand((3, 3), requires_grad=True) t2 = torch.rand((3, 3), requires_grad=True) # Perform some computation remotely. t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2)) # Perform some computation locally based on remote result. t4 = torch.rand((3, 3), requires_grad=True) t5 = torch.mul(t3, t4) # Compute some loss. loss = t5.sum()
这个例子非常简单,计算 t5 = (t1 + t2) * t4 的结果,其中 t1 + t2 的计算是通过 RPC 调用在 worker1 节点上计算,计算结果 t3 返回到 worker0 节点上,继续后面的乘法计算 t5 = t3 * t4。这个分布式计算的例子,在执行前向传播过程中生成 Autograd 计算图,在反向传播阶段使用它来计算梯度,如下图所示:
图中存在两个 send-recv 调用对,其中,在 Worker 0 上有 2 个 Autograd 计算图,分别以 mul 和 send 函数为根的 root;在 Worker 1 有 1 个计算图,根 root 是 send 函数。
FAST mode 算法
目前,PyTorch 已经实现了 FAST mode 算法,该算法考虑了对性能要求比较敏感的应用场景,通过设置较强的假设越是来简化分布式梯度计算。而 SMART mode 算法是更通用意义上的算法,当前正在进行中,还没有完成实现。
FAST mode 算法的关键假设:
每一个 send 函数在反向传播阶段只存在一个依赖,也就是说,通过一个 RPC 调用 send 函数只需要从目的节点端接收一个梯度结果。下面是 FAST mode 算法的基本流程:
- 在一个 Worker 节点上开始执行反向传播计算,从起始的根 root 开始,这就要求所有的根 root 必须是本地的。
- 为当前的分布式 Autograd 上下文对象(通过 dist_autograd.context()获得)查询所有的 send 函数(一次训练迭代中,Autograd 计算图中所有的 send-recv 对都保存在分布式 Autograd 上下文对象中)。
- 根据已经确定好的根 root,在本地计算这些根 root 的依赖关系,也包括所有 send 函数的依赖关系,然后启动本地 Autograd 引擎开始计算梯度。
- 当 Autograd 引擎执行 recv 函数时,recv 函数基于 RPC 调用将梯度发送给与之对应的 send 函数所在的节点端,实际上 recv 函数只需要发送对应的两个 ID 即可:autograd_context_id(唯一对应于一次迭代的分布式 Autograd 上下文对象)和 autograd_message_id(唯一对应于一个 send-recv 对)。
- 远程节点端接收到对应的 autograd_context_id 和 autograd_message_id,查询找到对应的 send 函数,如果这是第一次在该节点上接收到 autograd_context_id,会在该节点本地计算与 autograd_context_id 对应的所有依赖关系。
- send 函数被放到执行队列中,等待调度执行,得到该 send-recv 对对应的 RPC 调用结果,并返回给调用端。
- 单独为每个分布式 Autograd 上下文对象计算累加梯度(Accumulated Gradient),计算结果保存在 Dict[Tensor, Tensor] 结构中,基于这个 Dict 可以通过一个给定的 Tensor 得到它对应的累加梯度。
为了说明 FAST mode 算法的流程,下面通过一个简单的例子来加深理解计算的过程:
import torch import torch.distributed.autograd as dist_autograd import torch.distributed.rpc as rpc def my_add(t1, t2): return torch.add(t1, t2) # On worker 0: # Setup the autograd context. Computations that take # part in the distributed backward pass must be within # the distributed autograd context manager. with dist_autograd.context() as context_id: t1 = torch.rand((3, 3), requires_grad=True) t2 = torch.rand((3, 3), requires_grad=True) # Perform some computation remotely. t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2)) # Perform some computation locally based on remote result. t4 = torch.rand((3, 3), requires_grad=True) t5 = torch.mul(t3, t4) # Compute some loss. loss = t5.sum() # Run the backward pass. dist_autograd.backward(context_id, [loss]) # Retrieve the gradients from the context. dist_autograd.get_gradients(context_id)
通过前向传播计算,会得到 Autograd 计算图,下图描述了 Autograd 计算图以及在进行分布式 Autograd 计算过程中得到的依赖关系:
通过上图可以看到,在前向传播阶段,构建好了 Autograd 计算图,其中:Worker 0 上有两个子图,它们的根 root 分别是 mul1 和 send1;Worker 2 上有一个计算图,根 root 为 send2。
分布式 Autograd 梯度的详细计算过程,描述如下:
- 在 Worker 0 上,从 loss 和 send1 开始,计算依赖关系:send1 有 1 个依赖、mul1 有 1 个依赖。
- 在 Worker 0 上启动 Autograd 引擎,首先执行 mul1 函数,将结果在对应的分布式 Autograd 上下文对象中进行累加,对应着 t4;然后执行 recv2 函数,将计算结果梯度发送给 Worker 1 上。
- Worker 1 第一次得知在进行反向传播计算,所以首先计算本地依赖关系:send2 有 1 个依赖、add1 有 1 个依赖、recv1 有 1 个依赖。
- 然后在 Worker 1 上启动本地 Autograd 引擎,并将 send2 加入到执行队列等待调度,接着依次执行 add1、recv1 函数,当 recv1 计算完成后,会将梯度计算结果发送给 Worker 0。
- Worker 0 接收到 recv1 的梯度结果,在本地执行 send1 函数。
- 最后,t1、 t2、t4 的梯度都会在当前的分布式 Autograd 上下文对象中进行累加计算,这样就完成了一轮迭代的分布式梯度计算。
参考资源
- https://pytorch.org/docs/stable/rpc.html#distributed-autograd-framework
- https://pytorch.org/docs/stable/notes/autograd.html
- https://pytorch.org/docs/stable/rpc/distributed_autograd.html
- https://pytorch.org/blog/overview-of-pytorch-autograd-engine/
本文基于署名-非商业性使用-相同方式共享 4.0许可协议发布,欢迎转载、使用、重新发布,但务必保留文章署名时延军(包含链接:http://shiyanjun.cn),不得用于商业目的,基于本文修改后的作品务必以相同的许可发布。如有任何疑问,请与我联系。