TVM Learning (11)-Add Model Architeture in MLC LLM
IRModule: The key concept in TVM Unity
IRModule 是张量函数的集合,代表我们需要在模型中执行的计算子集。例如,在 MLC-LLM 中,它可以是一个 Transformer 模块。
机器学习编译框架中的 IRModule 就像深度学习框架中的张量,是一切的基础。在整个编译流程中,模型将以 IRModule 的形式导入,然后以 IRModule 到 IRModule 的方式进行转换和优化,然后我们就可以在任何支持的平台上将 IRModule 转化为可运行的模块。IRModule 可以用 python 方式访问,例如,我们可以用 python AST 的形式显示它,以便检查、调整和调试。unity 的主要设计目标之一是实现单一抽象,将所有主要元素封装在同一模块中。这样,我们就能在此基础上进行有机的增量转换。
TVMScript 是 IRModule 的 python AST 格式,用于在整套转换过程中检查 IRModules 并与之交互。与 IRModule 的交互都可以使用 TVMScript 在 python 中进行。用户将 TVMScript 解析为 IRModule 内部结构,使用 python API 操作 IRModule,并将 IRModule 打印为 TVMScript 格式。
TVMScript Examples
用 Pytorch 框架实现矩阵乘法一般调用 torch.matmul
或者使用 @
算子。
1 |
|
在 Relax 中可以用 IRModule 实现相同的功能。
1 |
|
通过上述 TVMScript 创建的 IRModule 是一个完全图级别的抽象,只包含一个 R.function (Relax 函数: IRModule 中计算图的表示形式)
上述示例包含 Relax 函数中的两个重要概念:高级 Relax 算子和数据流块。
- Relax 函数包含高级 Relax 算子
R.matmul
,它描述计算图中的节点,不包含其底层实现的信息。一个高级 Relax 算子可以映射到不同的底层实现,TVM Unity 的编译流程会生成性能良好的实现。 R.dataflow()
是数据流块的一个重要作用域注解。具体来说,在数据流块内,所有操作都必须是 side-effect free. 而在数据流块之外,操作可能包含副作用。
A more complex TVMScript example: 2-layer MLP
下面我们以一个更复杂的两层 MLP 为例,模型结构如下。
其对应的 Pytoch 实现如下
1 |
|
对应的 IRModule 的 TVMScript 表示如下
1 |
|
上述 Relax 函数只包含高级 Relax 算子。在 pytorch 中,torch.nn.Linear
计算 在 relax 中,转置由 permute_dims 实现,其次是 矩阵乘法和加法分别由 R.matmul
和 R.add
实现。
Compilation Flow in TVM Unity
- 将模型导入 IRModule. 对于静态模型,我们可以使用 pytorch dynamo 将 pytorch 程序跟踪为 fx 图,然后转换为 IRModule。然而,LLM 通常是动态的,因为序列长度和 kv cache 长度都是可变的。在这种情况下,我们需要直接在 IRModule 中建立模型。第一步可以抽象为 LLM -> IRModule 转换。
- 优化模型。与传统编译器一样,我们可以在 IRModule 上应用 pass (IRModule 到 IRModule 的变换,改变计算但保留了原始 IRModule 的语义)。在这一步中,我们的目标是加速模型计算。在消费类设备上以适当速度运行 LLM 的大多数关键技术,如量化、算子融合和张量函数调度,都是在这一步实现的。
- 在设备上部署 IRModule。对于每个 IRM 模块,我们都能将其转化为可运行模块,并在 tvm 运行时支持的任何平台上运行。IRModule 上的每个函数都将成为环境中的本地可运行函数。
以下是 2 层 MLP 模型的编译流程
1 |
|
Build IRModule in Pytorch Style
构建 IRModule 最直接的方法是手动编写 TVMScript。这种方法适用于小型模型,但 LLM 的 IRModule 非常庞大和复杂,手工编写并不现实。TVM Unity 提供了另一个类 nn.Module,可以像 pytorch 模块一样轻松构建 IRModule.
用 Pytorch 手动编写的一个 Linear 层如下
1 |
|
在 Relax 中的实现如下
1 |
|
与 Pytorch 的结构非常相似,只是前向函数实际上并不执行计算。它使用作为输入传递的占位符跟踪算子的计算图。nn.emit(relax.op.linear(input, self.weight, self.bias))
表示在构建的 IRModule 中添加高级 linear 算子。
通过堆叠 1 个线性层、1 个 relu 层和 1 个线性层,就可以构建例子中的 MLP.
1 |
|
直接调用 nn.Module 的前向函数就可以代替原先在 with bb.dataflow():
下的操作,将 nn.Module
构建成 IRModule 的步骤如下
1 |
|
Custom Operator Support
在某些情况下,我们要表示的模型包含一些自定义运算符,而这些运算符没有被提供的 Relax 运算符覆盖(如 LLaMA 中的 Rotary Embedding),或者我们要进行底层优化以加速单个内核。下面介绍如何在 IRModule 中编写自定义算子。
TensorIR: Low-level tensor function
TVM Unity 在 IRModule TensorIR 中提供了底层张量函数的表示方法,用户可以在其中定义自定义操作符并执行细粒度调度。
下面对比了一个矩阵乘法生成的 TVMScript TensorIR 代码和 low-level Pytorch 代码。@T.prim_func
装饰器表示下面的函数是一个原始的张量函数,包含运算符实现的底层细节。
T.prim_func
采用 destination-passing 约定,即在函数外部明确分配输入和输出空间,并将其作为参数传入。destination-passing 约定可以对内存分配进行精细调度,例如合并两个实时间隔不相交的变量的内存分配,这是在内存有限的设备上运行大型模型的关键。
1 |
|
Interaction between Relax function and TensorIR
为了支持 T.prim_func
(底层部分)和 R.function
(高层部分)之间的交互,TVM 引入了 call_tir
, Relax 中的一个特殊运算符,用于描述计算图中的节点及其张量函数的实现。torch_call_tir
是一个参考实现,用来说明 call_tir 的含义。实际上,可以有不同的底层方法来优化执行。例如,我们可能会选择提前分配所有输出内存,然后再运行执行。
1 |
|
下面是 2 层 MLP 的 IRModule,我们使用 call_tir
和张量原语函数 matmul
来替换 Relax 运算符 R.matmul
1 |
|
Implement Custom TensorIR Function
nn.Module
不仅支持高级 Relax 运算符,还支持自定义 TensorIR 函数。
要构建 TensorIR 函数并在 Relax 图中调用它,我们需要使用 nn.emit_te(f_te_expr,*args)
。
f_te_expr
是一个返回张量表达式(Tensor Expression,TE)的函数,是描述张量计算的 DSL.args
是f_te_expr
的参数。
创建 TE 表达式的方法如下
1 |
|
它描述如下的计算模式
在 Python 的 itertools 模块中,product
函数用于生成可迭代对象的笛卡尔积。
product
函数接受一个或多个可迭代对象作为参数,并返回一个迭代器,该迭代器生成所有可能的组合,其中每个组合包含来自每个输入可迭代对象的单个元素。
1 |
|
product
函数还支持重复元素,可以使用 repeat 参数指定每个可迭代对象需要重复的次数。
1 |
|
product
应用场景
- 组合生成: 生成所有可能的组合,例如密码生成、彩票号码生成等。
- 多维数组遍历: 遍历多维数组的所有元素。
- 测试用例生成: 生成测试用例,覆盖所有可能的输入组合。
1 |
|
用 emit_te
实现 Linear 层来构建 IRModule 的代码如下
1 |
|