TVM Learning (10)-Computational Graph Optimization
Pattern Match and Rewriting
下面代码中 MyModule
包含一个带有两个高级算子 relax.opmultiply
和 relax.op.add
的 relax 函数。我们的目标是找到这两个算子,并将其替换为对 relax.ewise_fma
算子的调用。
1 |
|
每个 IRModule 都包含一组函数,函数体由一组称为抽象语法树(AST)的数据结构组成。
抽象语法树(Abstract Syntax Tree,AST)是一种广泛用于编程语言处理的树状数据结构。它是一种对源代码语法结构的抽象表示,去掉了编程语言的具体语法细节,但保留了代码的结构和语义信息。
AST 是一棵树状结构,其节点表示源代码中的语法结构。例如,变量声明、操作符、函数调用、控制结构(如条件语句、循环)等。每个节点包含与相应语法结构相关的信息,如操作符的类型、变量的名称、常量的值等。
1 |
|
这个代码可以转换为如下形式的 AST:
1 |
|
每个函数都由一个 relax.expr.Function
节点表示。
1 |
|
该函数包含一系列参数
1 |
|
该函数包含一个返回值表达式,和函数中的一组 binding blocks.
1 |
|
函数主体 SeqExpr 包含一系列 binding.
1 |
|
在 DataflowBlock 中,我们可以访问各个 binding ,包括 value 和 var.
1 |
|
改写程序可以通过递归遍历 MyModule 的 AST ,并生成转换后的 AST 来实现。但是我们可以使用额外的工具支持来简化流程。下面的代码遵循一种称为 visitor pattern 的设计模式,允许我们访问每个 AST 节点并将它们重写为转换后的版本。主要目的是将形如 a * b + c
的表达式转换为 ewise_fma(a, b, c)
的形式。
EwiseFMARewriter
继承自 relax.PyExprMutator
,这是 TVM 中的一个基类,用于遍历和修改表达式树中的节点。visit_call_
方法被重载来处理 relax.Call
节点,被重载来处理 relax.Call
节点。
如果当前节点不是加法操作,直接返回该节点,表示对该节点不进行任何修改。如果加法的第一个操作数不是乘法操作,或者第一个操作数的绑定值不是一个 relax.Call
节点,直接返回该加法节点。如果匹配成功,构造一个新的 ewise_fma
操作节点,将乘法的两个操作数和加法的第二个操作数作为参数传入。
1 |
|
使用 remove_all_unused
来删除代码中没有用到的 DataflowBlocks 和 VarBindings.
1 |
|
Fuse Linear and ReLU
下面在端到端模型上进行计算图的改写。采用的还是之前使用的 FashionMNIST MLP 模型。为了简化过程,直接使用高级运算符构建模型。
1 |
|
我们的目标是对 matmul 和 add 进行算子融合。具体实现步骤与 FMA 相似:
- 识别 matmul 和 add 算子。
- 生成另一个调用 matmul 和 add 算子的子函数。
- 将 matmul 和 add 替换为融合后的子函数。
下面代码定义了一个名为 DenseAddFusor
的类,用于在 TVM 的 Relax 框架中将特定的矩阵乘法和加法操作模式融合成一个高效的原语函数。
transform
方法遍历模块中的每个函数。如果函数已经被标记为 primitive(即已经被融合过),则跳过。对每个函数应用visit_expr
以进行模式匹配和潜在的融合操作,然后删除未使用的变量,并更新函数。最后,返回更新后的IRModule
。visit_call_
方法用于访问relax.Call
节点(表示操作符调用)。它首先递归处理子表达式,然后尝试匹配特定模式。match_call
是一个内部函数,用于检查某个节点是否是特定操作符的调用。如果当前节点不是add
操作,或者add
操作的第一个参数不是matmul
(矩阵乘法)操作,则直接返回当前节点,不进行修改。如果匹配成功,则提取matmul
的两个操作数x
和w
以及add
的第二个操作数b
,准备进行融合。- 通过
relax.BlockBuilder
定义一个名为fused_dense_addX
新的融合函数,其中X
是一个递增的计数器。该函数接收x
、w
、b
作为参数,首先进行矩阵乘法,然后将结果与b
相加,最终输出结果。 - 给新生成的融合函数添加一个属性 Primitive,标记为已经融合的原语函数。通过
builder_
更新全局模块,将融合函数添加到模块中 (GlobalVar 用于指代存储在 IRModule 中的全局函数)。返回一个新的relax.Call
节点,该节点调用生成的融合函数,并传递原始的输入参数x
、w
、b
。
TVM 中的 VisitExpr
流程是一种递归遍历 IR 节点的机制,它是实现各种 IR 转换和优化的基础。具体流程如下:
- 首先创建一个
ExprVisitor
或ExprMutator
的子类实例,这个子类会实现各种具体的访问逻辑。 - 调用
visit_expr
方法,传入根 IR 节点。这个方法会触发整个遍历过程的启动。 visit_expr
方法会首先调用visit_expr_post_order
方法,这个方法会以深度优先的方式遍历所有子节点。- 对于每个子节点,
visit_expr_post_order
会根据节点的具体类型,调用相应的visit_XXX_
方法。这些visit_XXX_
方法是由访问器子类实现的,包含了具体的访问逻辑。 - 在
visit_XXX_
方法中,如果遇到子节点,会递归调用visit_expr_post_order
方法继续遍历。 - 当遍历完整个 IR 树后,
visit_expr
方法会返回最终的结果,即经过转换和修改的 IR 节点。
1 |
|
融合后的 MLPFused 对应的 TensorIR 如下
TVM 框架中使用 module_pass 来管理各种优化操作。这种机制允许将不同的优化操作(如图优化、代码生成、算子融合等)组织成一个流水线(pipeline),按顺序对模块进行处理。将 DenseAddFusor 封装为一个 module_pass,使得它能够轻松集成到 TVM 的 Pass 流水线中,与其他 Pass 一起工作,从而保证优化过程的整体性和一致性。
1 |
|
上面的例子中,我们创建了两个前缀为 fuse_matmul_add 的子函数。 这些子函数包含有融合后算子的计算信息。 这种重写的替代方法是简单地为融合算子创建一个单独的原语算子(如ewise_fma)。 但是,当我们尝试融合更多算子时,可能存在指数级数量的组合。 将融合操作分组在一起的子函数为后续的 pass 保留了原始信息,进而便于分析,无需为每个融合 pattern 引入专用的高级算子。
Map to TensorIR Calls
为了进一步进行底层优化和代码生成,我们需要将这些高级原语运算转换为相应的 TensorIR 函数。下面代码主要功能是将 Relax 表达式树中的高层次算子( matmul
、add
、relu
)转换为对应的 TensorIR 表示,从而使得这些算子能够映射到底层的张量操作(tensor operations)。这种转换使得编译器可以生成更接近硬件的高效代码,并为后续的代码优化和生成做好准备。
- 调用
transform
方法会遍历mod_
中的所有函数:- 对于每个函数,首先调用
visit_expr
方法,这会触发VisitExpr
流程 visit_expr
方法会调用visit_expr_post_order
方法进行深度优先遍历- 在遍历过程中对于每个
relax.Call
节点,会调用visit_call_
方法 visit_call_
方法会检查op_map
字典,如果当前操作在字典中,则调用对应的转换函数(map_dense
,map_add
,map_relu
)- 这些转换函数会使用
bb.call_te
方法,将 Relax IR 操作转换为 TensorIR 操作
- 对于每个函数,首先调用
- 在
transform
方法的最后,会调用builder_.get()
方法,返回转换后的新 IR 模块。 - 最后
LowerToTensorIRPass
类将LowerToTensorIR
转换器包装成一个可注册到 TVM 优化 pipeline 的 pass.
module_pass
的 opt_level
参数决定了优化 pass 在优化 pipeline 中的执行顺序。 TVM 的优化 pipeline 是由多个 module_pass
组成的,每个 module_pass
都有一个 opt_level
属性来指定它的优化级别。
当 TVM 进行优化时,它会按照 opt_level
从低到高的顺序依次应用各个 module_pass
. opt_level=0
的 pass 会首先被执行。这些 pass 通常会执行一些基础的、必要的转换,为后续的优化奠定基础。 随后会执行 opt_level=1
的 pass,这些 pass 可能会执行一些更复杂的优化,比如循环优化、内存访问优化等。依此类推,opt_level
越高的 pass 会在优化 pipeline 的后期执行,它们执行的优化通常也越复杂和深入。
通过合理地设置 opt_level
,开发者可以控制各个优化 pass 的执行顺序,从而构建出针对性强、性能优秀的优化 pipeline 。这种灵活的优化管理机制是 TVM 的一大特点。
对于 LowerToTensorIRPass
,它的 opt_level
被设置为 0, 说明它是一个基础的 pass, 主要用于将高级的 Relax IR 操作转换为底层的 TensorIR 操作。
1 |
|
融合后的 TensorIR 如下
1 |
|
在上面的 IRModule 中 fused_matmul_add0
和 fused_matmul_add1
仍然是 relax 函数,它们调用相应的 TensorIR matmul
和 add
函数。 我们可以将它们变成一个单一的 TensorIR 函数。
1 |
|