Sketching Out a Dseign TableGen 也可以用来定义 dialect. 本文将定义一个单未知数多项式运算的 dialect,系数用 uint32_t 类型表示。,并提供通过从标准 MLIR 类型指定多项式系数来定义多项式的操作,提取关于多项式的数据以将结果存储在标准MLIR类型中,以及对多项式进行算术运算。
An Empty Dialect 我们首先用 TableGen 定义一个空的 dialect. 它和上一章定义Pass没什么不同,只不过 include 的是 DialectBase.td 文件。同时也定义了命名空间为 ::mlir::tutorial::poly
.
1 2 3 4 5 6 7 8 9 10 11 12 include "mlir/IR/DialectBase.td" def Poly_Dialect : Dialect { let name = "poly" ; let summary = "A dialect for polynomial math" ; let description = [{ The poly dialect defines types and operations for single-variable polynomials over integers. }]; let cppNamespace = "::mlir::tutorial::poly" ; }
我们需要在 include 目录下的 CMakeLists.txt 文件中添加
1 2 3 4 5 set (TARGET_NAME "${PROJECT_TARGET_PREFIX}-Dialect-PolyDialect-IncGen" )set (LLVM_TARGET_DEFINITIONS mlir-learning/Dialect/Poly/PolyDialect.td) mlir_tablegen(mlir-learning/Dialect/Poly/PolyDialect.hpp.inc --gen-dialect-decls) mlir_tablegen(mlir-learning/Dialect/Poly/PolyDialect.cpp.inc --gen-dialect-defs) add_public_tablegen_target(${TARGET_NAME} )
然后在 tutorial-opt.cpp 中仅仅注册所有 mlir 自带的所有 dialect 后进行构建,我们可以查看生成的 .hpp.inc 和.cpp.inc 文件。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 namespace mlir { namespace tutorial {class PolyDialect : public ::mlir::Dialect { explicit PolyDialect (::mlir::MLIRContext *context) ; void initialize () ; friend class : :mlir::MLIRContext; public: ~PolyDialect() override; static constexpr ::llvm::StringLiteral getDialectNamespace () { return ::llvm::StringLiteral("poly" ); } }; } } MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::tutorial::PolyDialect)
编译器会报错,因为没有包含 Dialect 等类所在的头文件。这需要我们自己在 PolyDialect.h 文件中进行 include,这样 当重新构建的时候该文件注入变不会报错。这也是我觉得很匪夷所思的一点,明明生成这个代码的工具都已经使用这个库了帮我们完成了,为什么还要我们手动 include 呢?
1 2 3 4 5 6 7 8 9 #ifndef LIB_DIALECT_POLY_POLYDIALECT_H #define LIB_DIALECT_POLY_POLYDIALECT_H #include "mlir/IR/DialectImplementation.h" #include "mlir-learning/Dialect/Poly/PolyDialect.hpp.inc" #endif
生成的 .cpp.inc 如下,他只包含了该类基本的构造函数和析构函数。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 MLIR_DEFINE_EXPLICIT_TYPE_ID (::mlir::tutorial::poly::PolyDialect)namespace mlir {namespace tutorial {namespace poly { PolyDialect::PolyDialect (::mlir::MLIRContext *context) : ::mlir::Dialect (getDialectNamespace (), context, ::mlir::TypeID::get <PolyDialect>()) { initialize (); } PolyDialect::~PolyDialect () = default ; } } }
然后我们可以在 tutorial-opt.cpp 中注册该 dialect.
1 2 3 4 5 6 7 8 9 10 11 12 #include "mlir-learning/Dialect/Poly/PolyDialect.h" int main (int argc, char ** argv) { mlir::DialectRegistry registry; registry.insert <mlir::tutorial::poly::PolyDialect>(); mlir::registerAllDialects (registry); return mlir::asMainReturnCode ( mlir::MlirOptMain (argc, argv, "Tutorial Pass Driver" , registry)); }
Adding a Trival Type 下面我们需要定义自己的 poly.poly 类型.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 #ifndef LIB_DIALECT_POLY_POLYTYPES_TD_ #define LIB_DIALECT_POLY_POLYTYPES_TD_ include "mlir-learning/Dialect/Poly/PolyDialect.td" include "mlir/IR/AttrTypeBase.td" class Poly_Type<string name , string typeMnemonic> : TypeDef <Poly_Dialect , name> { let mnemonic = typeMnemonic; }def Polynomial : Poly_Type <"Polynomial" , "poly" > { let summary = "A polynomial with u32 coefficients" ; let description = [{ A type for polynomials with integer coefficients in a single-variable polynomial ring . }]; } #endif
在 MLIR 的 TableGen 文件中,class 和 def 的用法和含义有所不同
class
用于定义一个模板或基类,可以被其他类型或定义继承和重用。它本身不会创建实际的对象或具体类型,它只是一种结构,可以包含参数和默认属性。其他定义可以通过继承该类来获得其功能。
def
用于创建一个具体的实例,比如一个类型、操作或属性。它会将所定义的内容应用到 TableGen 中,使其成为可用的具体类型或功能。
这里我们定义了一个名为 Poly_Type
的类,参数为 name
(类型的名称)和 typeMnemonic
(类型的简写或助记符)。这个类继承自 TypeDef<Poly_Dialect, name>
. 然后 def
特定的多项式类型 Polynomial
,继承自 Poly_Type
.
在 MLIR 的 TableGen 中,TypeDef
本身也是一个类,它接受模板参数,用于指定该类型所属的 dialect 和名称字段。其作用包括将生成的C++类与该 dialect 的命名空间相关联。
生成的 .hpp.inc 文件如下。生成的类 PolynomialType
就是在我们的 TableGen 文件中定义的 Polynomial
类型后面加上了 Type.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 #ifdef GET_TYPEDEF_CLASSES #undef GET_TYPEDEF_CLASSES namespace mlir {class AsmParser ;class AsmPrinter ; } namespace mlir {namespace tutorial {namespace poly {class PolynomialType ;class PolynomialType : public ::mlir::Type::TypeBase<PolynomialType, ::mlir::Type, ::mlir::TypeStorage> {public : using Base::Base; static constexpr ::llvm::StringLiteral name = "poly.poly" ; static constexpr ::llvm::StringLiteral dialectName = "poly" ; static constexpr ::llvm::StringLiteral getMnemonic () { return {"poly" }; } }; } } } MLIR_DECLARE_EXPLICIT_TYPE_ID (::mlir::tutorial::poly::PolynomialType)#endif
生成的 .cpp.inc 文件如下。TableGen 试图为 dialect 中的 PolynomialType
自动生成一个 类型解析器 (type parser) 和类型打印器 (type printer). 不过此时这些功能还不可用,构建项目时会看到一些编译警告。
代码中使用了 头文件保护 (header guards) 来将 cpp
文件分隔为两个受保护的部分。这样可以分别管理类型声明和函数实现。
GET_TYPEDEF_LIST
只包含类名的逗号分隔列表。原因在于 PolyDialect.cpp
文件需要负责将类型注册到 dialect 中,而该注册过程通过在方言初始化函数中将这些 C++ 类名作为模板参数来实现。换句话说,GET_TYPEDEF_LIST
提供了一种简化机制,使得 PolyDialect.cpp
可以自动获取所有类名称列表,便于统一注册,而不需要手动添加每一个类型。
generatedTypeParser
函数是为 PolynomialType
定义的解析器。当解析器遇到 PolynomialType
的助记符(poly
)时,会将 PolynomialType
类型实例化。KeywordSwitch
使用 getMnemonic()
来匹配 PolynomialType
的助记符(poly
)。如果匹配成功,则调用 PolynomialType::get()
来获取类型实例。Default
子句在助记符不匹配时执行,记录未知的助记符,并返回 std::nullopt
表示解析失败。
generatedTypePrinter
函数为 PolynomialType
提供了打印功能。当类型为 PolynomialType
时,打印其助记符(poly
),否则返回失败。TypeSwitch
用于检查 def
类型是否是 PolynomialType
。如果是,打印助记符;否则返回失败,表示该类型不属于此方言。
PolyDialect::parseType
和 PolyDialect::printType
作为方言接口调用这两个函数,从而实现类型的解析和打印功能。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 #ifdef GET_TYPEDEF_LIST #undef GET_TYPEDEF_LIST ::mlir::tutorial::poly::PolynomialType#endif #ifdef GET_TYPEDEF_CLASSES #undef GET_TYPEDEF_CLASSES static ::mlir::OptionalParseResult generatedTypeParser (::mlir::AsmParser &parser, ::llvm::StringRef *mnemonic, ::mlir::Type &value) { return ::mlir::AsmParser::KeywordSwitch <::mlir::OptionalParseResult>(parser) .Case (::mlir::tutorial::poly::PolynomialType::getMnemonic (), [&](llvm::StringRef, llvm::SMLoc) { value = ::mlir::tutorial::poly::PolynomialType::get (parser.getContext ()); return ::mlir::success (!!value); }) .Default ([&](llvm::StringRef keyword, llvm::SMLoc) { *mnemonic = keyword; return std::nullopt ; }); }static ::llvm::LogicalResult generatedTypePrinter (::mlir::Type def, ::mlir::AsmPrinter &printer) { return ::llvm::TypeSwitch <::mlir::Type, ::llvm::LogicalResult>(def) .Case <::mlir::tutorial::poly::PolynomialType>([&](auto t) { printer << ::mlir::tutorial::poly::PolynomialType::getMnemonic (); return ::mlir::success (); }) .Default ([](auto ) { return ::mlir::failure (); }); }namespace mlir {namespace tutorial {namespace poly { } } } MLIR_DEFINE_EXPLICIT_TYPE_ID (::mlir::tutorial::poly::PolynomialType)namespace mlir {namespace tutorial {namespace poly { ::mlir::Type PolyDialect::parseType (::mlir::DialectAsmParser &parser) const { ::llvm::SMLoc typeLoc = parser.getCurrentLocation (); ::llvm::StringRef mnemonic; ::mlir::Type genType; auto parseResult = generatedTypeParser (parser, &mnemonic, genType); if (parseResult.has_value ()) return genType; parser.emitError (typeLoc) << "unknown type `" << mnemonic << "` in dialect `" << getNamespace () << "`" ; return {}; }void PolyDialect::printType (::mlir::Type type, ::mlir::DialectAsmPrinter &printer) const { if (::mlir::succeeded (generatedTypePrinter (type, printer))) return ; } } } } #endif
在设置 C++ 接口以使用 TableGen 文件时,通常会按照以下步骤来组织代码文件和包含关系。
PolyTypes.h
是唯一被允许包含 PolyTypes.h.inc
的文件。
PolyTypes.cpp.inc
文件包含了 TableGen 为 PolyDialect
中的类型生成的实现。我们需要在 PolyDialect.cpp
中将其包含进去,以确保所有实现都能在该方言的主文件中使用。
PolyTypes.cpp
文件应该包含 PolyTypes.h
,以便访问类型声明,并在该文件中实现所有需要的额外功能。
为了让类型解析器和打印器能够正确编译和运行,需要最后在方言的 TableGen 文件中添加 let useDefaultTypePrinterParser = 1
;,这个指令告诉 TableGen 使用默认的类型解析和打印器。当这个选项启用后,TableGen 会生成相应的解析和打印代码,并将这些实现作为 PolyDialect
类的成员函数。
1 2 3 4 5 6 ::mlir::Type parseType (::mlir::DialectAsmParser &parser) const override ; void printType (::mlir::Type type, ::mlir::DialectAsmPrinter &os) const override ;
我们可以写一个 .mlir 来测试属性是是否获取正确。在 MLIR 中自定义的 dialect 前都需要加上 !
.
1 2 3 4 5 // CHECK-LABEL: test_type_syntax func.func @test_type_syntax(%arg0: !poly.poly<10>) -> !poly.poly<10> { // CHECK: poly.poly return %arg0: !poly.poly<10> }
Add a Poly Type Parameter 我们需要为多项式类型添加一个属性,表示它的次数上限。
1 2 3 // include/mlir-learning/Dialect/Poly/PolyTypes.tdlet parameters = (ins "int" :$degreeBound );let assemblyFormat = "`<` $degreeBound `>`" ;
第一行定义了类型的一个参数 degreeBound
,类型为 int
. 表示在实例化该类型时,用户可以指定一个整数值作为类型的参数。parameters
中的 (ins "int":$degreeBound
) 指定了输入参数的类型和名称,其中 int 是数据类型,$degreeBound
是参数的占位符。assemblyFormat
用于定义该类型在 MLIR 文本格式中的打印和解析格式。"<" $degreeBound ">"
表示该类型的参数会用尖括号包裹。第二行是必需的,因为现在一个 Poly 类型有了这个关联的数据,我们需要能够将它打印出来并从文本 IR 表示中解析它。
加上这两行代码后进行 build 会发现多了一些新的内容。PolynomialType
有一个新的 int getDegreeBound()
方法,以及一个静态 get
工厂方法。
parse
和 print
升级为新格式。
有一个名为 typestorage
的新类,它包含 int 形参,并隐藏在内部细节名称空间中。
MLIR会自动生成简单类型的 storage 类,因为它们不需要复杂的内存管理。如果参数更复杂,就需要开发者手动编写 storage 类来定义构造、析构和其他语义。复杂的 storage 类需要实现更多细节,以确保类型能够在 MLIR 的 dialect 系统中顺利运行。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 static ::mlir::Type parse (::mlir::AsmParser &odsParser) ; void print (::mlir::AsmPrinter &odsPrinter) const ; int getDegreeBound () const ;struct PolynomialTypeStorage : public ::mlir::TypeStorage { };PolynomialType PolynomialType::get (::mlir::MLIRContext *context, int degreeBound) { return Base::get (context, std::move (degreeBound)); } ::mlir::Type PolynomialType::parse (::mlir::AsmParser &odsParser) { }void PolynomialType::print (::mlir::AsmPrinter &odsPrinter) const { ::mlir::Builder odsBuilder (getContext()) ; odsPrinter << "<" ; odsPrinter.printStrippedAttrOrType (getDegreeBound ()); odsPrinter << ">" ; }int PolynomialType::getDegreeBound () const { return getImpl ()->degreeBound; }
Adding Some Simple Operations 下面我们定义一个简单的多项式加法操作
1 2 3 4 5 6 7 8 9 10 // include/mlir-learning/Dialect/Poly/PolyOps.td include "PolyDialect.td" include "PolyTypes.td" def Poly_AddOp : Op<Poly_Dialect, "add" > { let summary = "Addition operation between polynomials." ; let arguments = (ins Polynomial:$lhs , Polynomial:$rhs ); let results = (outs Polynomial:$output ); let assemblyFormat = "$lhs `,` $rhs attr-dict `:` `(` type($lhs ) `,` type($rhs ) `)` `->` type($output )" ; }
看起来非常类似于类型,但基类是 Op,arguments 对应于操作的输入,assemblyFormat 更复杂。生成的 .hpp.inc 和 .cpp.inc 非常复杂。我们可以编写一个 .mlir 来测试。
1 2 3 4 5 6 // CHECK-LABEL: test_add_syntax func.func @test_add_syntax(%arg0: !poly.poly<10>, %arg1: !poly.poly<10>) -> !poly.poly<10> { // CHECK: poly.add %0 = poly.add %arg0, %arg1 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10> return %0 : !poly.poly<10> }