TVM Learning (7)-Integration with Machine Learning Frameworks
Build an IRModule Through a Builder
下面用一个矩阵乘法回顾一下如何从张量表达式创建IRModule. 先创建 placeholder
对象表示 T.prim_func
函数的输入。
1 |
|
同样我们可以使用 *
运算符对tuple解引用来实现对不同维度大小的输入进行ReLU.
1 |
|
我们仅仅对传入输入和输出参数来创建T.prim_func,这样可以使得中间结果仅仅被分配临时内存(在Schedule.compute_at已介绍过)。可以看到矩阵乘法的中间结果 matmul
被 T.alloc_buffer
分配。
1 |
|
Use BlockBuilder to Create an IRModule
在chapter3_exercise中也介绍了使用 relax.BlockBuilder
来创建IRModule. BlockBuilder
自带的作用域与 relax 函数中的作用域相对应。例如,bb.dataflow()
会创建一个数据流代码块,其中所有 BlockBuilder
的方法调用的作用域都属于数据流作用域。每个中间结果都是一个 relax.Var
,对应于一个存储计算结果的变量。DataflowVar
表示该变量是数据流块(计算图)中的一个中间步骤。在底层,bb.emit_te
会执行以下操作:
- 为 A 和 B 创建输入
te.placeholder
- 调用
te_matmul
函数运行它们。 - 调用
te.create_prim_func
创建一个 TensorIR 函数。 - 通过
call_tir
生成对该函数的调用。
最后,函数输出由 bb.emit_func_output
标记。在每个函数作用域中,我们只能调用一次 emit_func_output
。
1 |
|
值得注意,我们可以在 emit_func_output
指定函数的输入参数列表,这样做有助于我们随时获取参数列表。我们也可以在最开始的函数作用域里面声明。
1 |
|
Import Model From PyTorch
我们了解了以编程方式构建 IRModule 的工具。也可以用它们将 PyTorch 中的模型转换成 IRModule. 用Pytorch实现矩阵乘法+ReLU的网络如下
1 |
|
TorchFX是用来变换 nn.Module
实例的工具包。FX 由三个主要组件组成:symbolic_trace、中间表示和 Python 代码生成。
Symbolic Trace
symbolic_trace
函数用于对一个 PyTorch 模型进行符号追踪,它会执行模型的 forward 函数,并记录所有操作(如卷积、线性层、激活函数等)以及它们之间的依赖关系。返回一个包含了模型的计算图表示的 GraphModule
对象。
1 |
|
- 在 FX 中,方法输入通过特殊的
placeholder
节点指定。在本例中,我们有一个placeholder
节点,其target
为x
,这意味着我们有一个名为 x 的(非自身)参数。 get_attr
、call_function
、call_module
和call_method
节点表示方法中的操作。所有这些语义的完整处理可以在Node
文档中找到。Graph
中的返回值由特殊的output
节点指定。
Graph IR
symbolic_traced.graph
属性是一个 torch.fx.Graph
对象,代表了模型的计算图的 IR 表示。
graph():
定义了一个名为graph
的函数,它代表整个计算图。%x : [num_users=1] = placeholder[target=x]
定义了一个名为%x
的占位符节点,它代表模型的输入数据。[num_users=1]
表示这个节点在计算图中被使用了一次。target=x
表示这个占位符节点对应于模型的x
输入参数。
%weight : [num_users=1] = get_attr[target=weight]
定义了一个名为%weight
的节点,它代表模型的权重参数。target=weight
表示这个节点对应于模型的weight
属性。
%matmul : [num_users=1] = call_function[target=torch.matmul](args = (%x, %weight), kwargs = {})
定义了一个名为%matmul
的节点,它代表对输入数据%x
和权重参数%weight
进行矩阵乘法操作。target=torch.matmul
表示这个节点对应于 PyTorch 的torch.matmul
函数。args = (%x, %weight)
表示该操作的输入参数是%x
和%weight
。kwargs = {}
表示该操作没有额外的关键字参数。
%relu : [num_users=1] = call_function[target=torch.relu](args = (%matmul,), kwargs = {})
定义了一个名为%relu
的节点,它代表对矩阵乘法的结果%matmul
应用 ReLU 激活函数。target=torch.relu
表示这个节点对应于 PyTorch 的torch.relu
函数。args = (%matmul,)
表示该操作的输入参数是%matmul
。
return relu
表示计算图的输出是%relu
节点,即经过 ReLU 激活后的结果。
1 |
|
Graph Code
symbolic_traced.code
属性是一个字符串,它包含了模型计算图的 Python 代码表示。对于每个计算图 IR,创建与图语义匹配的有效 Python 代码。
1 |
|
Create Map Function
整个翻译逻辑的主要流程如下:
- 创建一个
node_map
,将fx.Node
映射到相应的relax.Var
以表示 IRModule 中的节点。 - 按拓扑顺序遍历 fx 图中的节点。
- 根据映射输入计算节点的映射输出。
Map Parameter
map_param(param: nn.Parameter)
函数将 PyTorch 的 nn.Parameter
对象转换为 TVM Relax 的常量节点。它首先获取参数的形状和数据类型,然后使用 relax.const
函数创建一个常量节点,并将参数数据转换为 NumPy 数组。
1 |
|
Fetch Attribution
fetch_attr(fx_mod, target: str)
函数用于从 fx_mod
对象中获取指定属性值。它将 target
字符串拆分为多个部分,并依次访问 fx_mod
对象的属性,直到找到目标属性。
1 |
|
Translate from TorchFX
from_fx(fx_mod, input_shapes, call_function_map, call_module_map)
函数是核心转换函数,它将 fx_mod
对象转换为 TVM Relax 的 IRModule
对象。
它首先定义了几个变量:
input_index
: 用于跟踪输入节点的索引。node_map
: 用于存储fx_mod
中每个节点对应的 Relax 节点。named_modules
: 用于存储fx_mod
中所有模块的名称和对象。bb
: 一个relax.BlockBuilder
对象,用于构建 Relax 函数。fn_inputs
: 用于存储函数的输入参数。fn_output
: 用于存储函数的输出节点。
然后使用 bb.function
创建一个名为 “main” 的 Relax 函数。在函数中,遍历 fx_mod
的所有节点,并根据节点类型进行不同的处理:
placeholder
: 创建一个输入占位符节点。get_attr
: 使用map_param
函数将参数转换为常量节点。call_function
: 使用call_function_map
字典中指定的函数来处理函数调用。call_module
: 使用call_module_map
字典中指定的函数来处理模块调用。output
: 设置函数的输出节点。
最后,使用 bb.get()
获取生成的 IRModule
对象。
1 |
|
创建的IRModule如下
1 |
|
Translate by reusing pre-defined TE libraries
TOPI (TVM OPeration Inventory) 提供了 numpy 风格的通用操作和 schedule,其抽象程度高于 TVM. 使用它里面已有的模块可以省去自己定义张量表达式的工作。
topi.nn.dense(x, w)
执行转置矩阵乘法x @ w.T
topi.add
执行广播加法。
我们可以将下面的Pytorch MLP网络翻译成IRModule
1 |
|
Translating into High-level Operators
在大多数机器学习框架中,首先翻译成高级内置原语算子有时会很有帮助,因为这些算子已经被很大程度上优化过。我们通过调用内置算子将模型导入 IRModule. 这些内置运算符是比TensorIR函数高级的抽象。我们可以利用不同的方法,将这些原语算子进一步转化为库函数或TensorIR函数。
可以看见relax函数里面都是调用的原始算子而不是使用call_tir
1 |
|