DistriFusion
Abstract
DistriFusion 将模型输入分割成多个 patch 后分配给 GPU。但是直接实现这样的算法会破坏 patch 之间的交互并失去保真度,而同步 GPU 之间的激活将产生巨大的通信开销。为了克服这一困境,根据观察到的相邻扩散步输入之间的高度相似性提出了 displaced patch parallelism,该方法通过重用前一个时间步骤中预先计算的 feature map 来利用扩散过程的顺序性,为当前步提供 context. 该方法支持异步通信,可以通过计算实现流水线化。
Introduction
加速扩散模型推理主要集中在两种方法上:减少采样步骤和优化网络推理。随着计算资源的快速增长,利用多个 GPU 来加速推理是很有吸引力的。例如在 NLP 中, LLM 已经成功地利用了 GPU 之间的张量并行性,从而显著降低了延迟。然而,对于扩散模型,由于激活尺寸大,张量并行这样的技术不太适合扩散模型。多个 GPU 通常只用于 batch 推理,当生成单个图像时,通常只涉及一个GPU.
Techniques like tensor parallelism are less suitable for diffusion models due to the large activation size, as communication costs outweigh savings from distributed computation.
自然而然的一种方法是将图像分成几个 patch 后分配给不同的设备进行生成。由于各个 patch 之间缺乏相互作用,它在每个 patch 的边界处都有一个清晰可见的分界线。
DistriFusion 也是基于 patch parallelism. 关键在于扩散模型中相邻去噪步骤的输入是相似的,因此,只在第一步采用同步通信。后续步骤重用前一步中预先计算的激活,为当前步骤提供全局上下文和 patch 交互。通过异步通信有效地隐藏了计算中的通信开销。并且还稀疏地在指定的区域上进行卷积和注意力计算,从而按比例减少每个设备的计算量。
Method
Displaced Patch Parallelism.
在预测
我们对除第一层 (采用同步通信获得其他设备上的输入) 外的每一层重复这个过程。然后将最终输出 Gather 在一起以近似
Sparse Operations
对于每一层 l,如果原始算子 Fl 是一个卷积层、线性层或交叉注意层,调整使其专门作用于新激活的区域。这可以通过从 scatter 输出中提取最新部分并将其输入到 Fl 中来实现。对于 self-attention,将其转换为 cross-attention,仅在设备上保留来自新激活的 Q,而 KV 仍然包含整个特征图。
Corrected Asynchronous GroupNorm
仅对新 patch 进行归一化或重用旧特征都会降低图像质量。同步 AllGather 所有均值和方差将产生相当大的开销。为了解决这一困境,DistriFusion 在陈旧的统计数据中引入了一个校正项。计算公式如下
同样对二阶矩
Code Implementation
Distrifusion 中主要就是将 UNet2DConditionModel 中的 Conv2d, Attention 和 GroupNorm 替换成对应的 patch 实现的网络结构 DistriUNetPP. 这里继承的 BaseModel 类为集成了 PatchParallelismCommManager 类 (介绍见后文) 的网络。
1 |
|
PatchParallelismCommManager
PatchParallelismCommManager 类主要处理异步通信的部分。
1 |
|
成员函数功能介绍如下
register_tensor(self, shape: tuple[int, ...] or list[int], torch_dtype: torch.dtype, layer_type: str = None) -> int
: 用于注册张量的形状和数据类型,同时计算并记录张量在缓冲区中的起始位置和结束位置。- 如果尚未指定
torch_dtype
,则将传入的torch_dtype
设为类成员的默认数据类型。 - 计算传入张量形状的总元素数
numel
,并更新starts
、ends
和shapes
列表。 - 如果指定了
layer_type
,更新numel_dict
中该层类型对应的元素数目。
- 如果尚未指定
create_buffer(self)
: 每个设备上为所有注册的张量创建一个统一的缓冲区。- 为每个设备创建一个形状为
(numel,)
的张量,并将其放入buffer_list
中。 - 输出在各设备上创建的缓冲区总参数量。
- 为每个设备创建一个形状为
get_buffer_list(self, idx: int) -> list[torch.Tensor]
: 返回每个设备上对应于指定索引idx
的缓冲区张量。- 根据
starts
和ends
信息,从buffer_list
中提取指定索引idx
的张量片段并调整其形状。
- 根据
communicate(self)
: 调用dist.all_gather
将缓冲区中的张量在不同设备间进行广播。- 确定当前需要通信的张量范围 (根据
idx_queue
中的索引). - 调用
dist.all_gather
在设备组内进行异步广播通信,并将句柄存储在handles
中。
- 确定当前需要通信的张量范围 (根据
enqueue(self, idx: int, tensor: torch.Tensor)
: 将指定索引idx
处的张量数据复制到buffer_list
中,并将索引添加到通信队列idx_queue
。- 如果通信队列不为空且索引为 0,则先执行一次通信操作。
- 将张量数据复制到
buffer_list
中的对应位置。 - 当通信队列长度达到
distri_config
中设定的通信检查点值时,进行通信。
clear(self)
: 执行一次所有待通信张量的通信,并等待所有异步操作完成。- 如果通信队列不为空,则进行通信操作。
- 遍历所有句柄,等待所有异步操作完成后,将句柄设为
None
.
DistriConv2dPP
DistriConv2dPP 计算自己负责 patch 部分的卷积,需要通信其他设备需要自己负责 patch 的上下 padding 部分。
__init__
:构造函数,初始化成员变量,设置是否为第一层卷积。naive_forward
:执行标准的前向传播,不进行任何切片操作。这是单个设备处理时的普通卷积操作。sliced_forward
:处理输入张量的切片操作。根据当前设备索引 (split_idx
) 计算输入张量在高度方向的起始和结束位置,并在必要时为切片后的张量添加 padding 后进行卷积操作。
1 |
|
DistriSelfAttentionPP
DistriSelfAttentionPP 只负责计算自己 patch 的输出,需要完整的 KV,将 self attention 运算变成 cross-attention 计算。需要通信自己的 KV.
1 |
|
DistriGroupNorm
DistriGroupNorm 根据上一步全特征图的以及当前步 patch 的均值和二阶矩近似当前步的全特征图均值和方差。需要通信 patch 均值和二阶矩。
1 |
|