MLIR-Defining a New Dialect

Sketching Out a Dseign

TableGen 也可以用来定义 dialect. 本文将定义一个单未知数多项式运算的 dialect,系数用 uint32_t 类型表示。,并提供通过从标准 MLIR 类型指定多项式系数来定义多项式的操作,提取关于多项式的数据以将结果存储在标准MLIR类型中,以及对多项式进行算术运算。

An Empty Dialect

我们首先用 TableGen 定义一个空的 dialect. 它和上一章定义Pass没什么不同,只不过 include 的是 DialectBase.td 文件。同时也定义了命名空间为 ::mlir::tutorial::poly.

1
2
3
4
5
6
7
8
9
10
11
12
include "mlir/IR/DialectBase.td"

def Poly_Dialect : Dialect {
let name = "poly";
let summary = "A dialect for polynomial math";
let description = [{
The poly dialect defines types and operations for single-variable
polynomials over integers.
}];

let cppNamespace = "::mlir::tutorial::poly";
}

我们需要在 include 目录下的 CMakeLists.txt 文件中添加

1
2
3
4
5
set(TARGET_NAME "${PROJECT_TARGET_PREFIX}-Dialect-PolyDialect-IncGen")
set(LLVM_TARGET_DEFINITIONS mlir-learning/Dialect/Poly/PolyDialect.td)
mlir_tablegen(mlir-learning/Dialect/Poly/PolyDialect.hpp.inc --gen-dialect-decls)
mlir_tablegen(mlir-learning/Dialect/Poly/PolyDialect.cpp.inc --gen-dialect-defs)
add_public_tablegen_target(${TARGET_NAME})

然后在 tutorial-opt.cpp 中仅仅注册所有 mlir 自带的所有 dialect 后进行构建,我们可以查看生成的 .hpp.inc 和.cpp.inc 文件。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
namespace mlir {
namespace tutorial {

class PolyDialect : public ::mlir::Dialect {
explicit PolyDialect(::mlir::MLIRContext *context);

void initialize();
friend class ::mlir::MLIRContext;
public:
~PolyDialect() override;
static constexpr ::llvm::StringLiteral getDialectNamespace() {
return ::llvm::StringLiteral("poly");
}
};
} // namespace tutorial
} // namespace mlir
MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::tutorial::PolyDialect)

编译器会报错,因为没有包含 Dialect 等类所在的头文件。这需要我们自己在 PolyDialect.h 文件中进行 include,这样 当重新构建的时候该文件注入变不会报错。这也是我觉得很匪夷所思的一点,明明生成这个代码的工具都已经使用这个库了帮我们完成了,为什么还要我们手动 include 呢?

1
2
3
4
5
6
7
8
9
// include/mlir-learning/Dialect/Poly/PolyDialect.h
#ifndef LIB_DIALECT_POLY_POLYDIALECT_H
#define LIB_DIALECT_POLY_POLYDIALECT_H

#include "mlir/IR/DialectImplementation.h" // include mannually

#include "mlir-learning/Dialect/Poly/PolyDialect.hpp.inc"

#endif

生成的 .cpp.inc 如下,他只包含了该类基本的构造函数和析构函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::tutorial::poly::PolyDialect)
namespace mlir {
namespace tutorial {
namespace poly {

PolyDialect::PolyDialect(::mlir::MLIRContext *context)
: ::mlir::Dialect(getDialectNamespace(), context, ::mlir::TypeID::get<PolyDialect>())

{

initialize();
}

PolyDialect::~PolyDialect() = default;

} // namespace poly
} // namespace tutorial
} // namespace mlir

然后我们可以在 tutorial-opt.cpp 中注册该 dialect.

1
2
3
4
5
6
7
8
9
10
11
12
/* other includes */
#include "mlir-learning/Dialect/Poly/PolyDialect.h"

int main(int argc, char** argv) {
// Register all built-in MLIR dialects
mlir::DialectRegistry registry;
// Register our Dialect
registry.insert<mlir::tutorial::poly::PolyDialect>();
mlir::registerAllDialects(registry);
return mlir::asMainReturnCode(
mlir::MlirOptMain(argc, argv, "Tutorial Pass Driver", registry));
}

Adding a Trival Type

下面我们需要定义自己的 poly.poly 类型.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
#ifndef LIB_DIALECT_POLY_POLYTYPES_TD_
#define LIB_DIALECT_POLY_POLYTYPES_TD_

include "mlir-learning/Dialect/Poly/PolyDialect.td"
include "mlir/IR/AttrTypeBase.td"

// a base class for all types in the dialect
class Poly_Type<string name, string typeMnemonic> : TypeDef<Poly_Dialect, name> {
let mnemonic = typeMnemonic;
}

def Polynomial: Poly_Type<"Polynomial", "poly"> {
let summary = "A polynomial with u32 coefficients";

let description = [{
A type for polynomials with integer coefficients in a single-variable polynomial ring.
}];
}
#endif

在 MLIR 的 TableGen 文件中,class 和 def 的用法和含义有所不同

  • class 用于定义一个模板或基类,可以被其他类型或定义继承和重用。它本身不会创建实际的对象或具体类型,它只是一种结构,可以包含参数和默认属性。其他定义可以通过继承该类来获得其功能。
  • def 用于创建一个具体的实例,比如一个类型、操作或属性。它会将所定义的内容应用到 TableGen 中,使其成为可用的具体类型或功能。

这里我们定义了一个名为 Poly_Type 的类,参数为 name(类型的名称)和 typeMnemonic(类型的简写或助记符)。这个类继承自 TypeDef<Poly_Dialect, name>. 然后 def 特定的多项式类型 Polynomial,继承自 Poly_Type.

在 MLIR 的 TableGen 中,TypeDef 本身也是一个类,它接受模板参数,用于指定该类型所属的 dialect 和名称字段。其作用包括将生成的C++类与该 dialect 的命名空间相关联。

生成的 .hpp.inc 文件如下。生成的类 PolynomialType 就是在我们的 TableGen 文件中定义的 Polynomial 类型后面加上了 Type.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
#ifdef GET_TYPEDEF_CLASSES
#undef GET_TYPEDEF_CLASSES


namespace mlir {
class AsmParser;
class AsmPrinter;
} // namespace mlir
namespace mlir {
namespace tutorial {
namespace poly {
class PolynomialType;
class PolynomialType : public ::mlir::Type::TypeBase<PolynomialType, ::mlir::Type, ::mlir::TypeStorage> {
public:
using Base::Base;
static constexpr ::llvm::StringLiteral name = "poly.poly";
static constexpr ::llvm::StringLiteral dialectName = "poly";
static constexpr ::llvm::StringLiteral getMnemonic() {
return {"poly"};
}

};
} // namespace poly
} // namespace tutorial
} // namespace mlir
MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::tutorial::poly::PolynomialType)

#endif // GET_TYPEDEF_CLASSES

生成的 .cpp.inc 文件如下。TableGen 试图为 dialect 中的 PolynomialType 自动生成一个 类型解析器 (type parser) 和类型打印器 (type printer). 不过此时这些功能还不可用,构建项目时会看到一些编译警告。

代码中使用了 头文件保护 (header guards) 来将 cpp 文件分隔为两个受保护的部分。这样可以分别管理类型声明和函数实现。

GET_TYPEDEF_LIST 只包含类名的逗号分隔列表。原因在于 PolyDialect.cpp 文件需要负责将类型注册到 dialect 中,而该注册过程通过在方言初始化函数中将这些 C++ 类名作为模板参数来实现。换句话说,GET_TYPEDEF_LIST 提供了一种简化机制,使得 PolyDialect.cpp 可以自动获取所有类名称列表,便于统一注册,而不需要手动添加每一个类型。

  • generatedTypeParser 函数是为 PolynomialType 定义的解析器。当解析器遇到 PolynomialType 的助记符(poly)时,会将 PolynomialType 类型实例化。KeywordSwitch 使用 getMnemonic() 来匹配 PolynomialType 的助记符(poly)。如果匹配成功,则调用 PolynomialType::get() 来获取类型实例。Default 子句在助记符不匹配时执行,记录未知的助记符,并返回 std::nullopt 表示解析失败。
  • generatedTypePrinter 函数为 PolynomialType 提供了打印功能。当类型为 PolynomialType 时,打印其助记符(poly),否则返回失败。TypeSwitch 用于检查 def 类型是否是 PolynomialType。如果是,打印助记符;否则返回失败,表示该类型不属于此方言。
  • PolyDialect::parseTypePolyDialect::printType 作为方言接口调用这两个函数,从而实现类型的解析和打印功能。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#ifdef GET_TYPEDEF_LIST
#undef GET_TYPEDEF_LIST

::mlir::tutorial::poly::PolynomialType

#endif // GET_TYPEDEF_LIST

#ifdef GET_TYPEDEF_CLASSES
#undef GET_TYPEDEF_CLASSES

static ::mlir::OptionalParseResult generatedTypeParser(::mlir::AsmParser &parser, ::llvm::StringRef *mnemonic, ::mlir::Type &value) {
return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser)
.Case(::mlir::tutorial::poly::PolynomialType::getMnemonic(), [&](llvm::StringRef, llvm::SMLoc) {
value = ::mlir::tutorial::poly::PolynomialType::get(parser.getContext());
return ::mlir::success(!!value);
})
.Default([&](llvm::StringRef keyword, llvm::SMLoc) {
*mnemonic = keyword;
return std::nullopt;
});
}

static ::llvm::LogicalResult generatedTypePrinter(::mlir::Type def, ::mlir::AsmPrinter &printer) {
return ::llvm::TypeSwitch<::mlir::Type, ::llvm::LogicalResult>(def) .Case<::mlir::tutorial::poly::PolynomialType>([&](auto t) {
printer << ::mlir::tutorial::poly::PolynomialType::getMnemonic();
return ::mlir::success();
})
.Default([](auto) { return ::mlir::failure(); });
}

namespace mlir {
namespace tutorial {
namespace poly {
} // namespace poly
} // namespace tutorial
} // namespace mlir
MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::tutorial::poly::PolynomialType)
namespace mlir {
namespace tutorial {
namespace poly {

/// Parse a type registered to this dialect.
::mlir::Type PolyDialect::parseType(::mlir::DialectAsmParser &parser) const {
::llvm::SMLoc typeLoc = parser.getCurrentLocation();
::llvm::StringRef mnemonic;
::mlir::Type genType;
auto parseResult = generatedTypeParser(parser, &mnemonic, genType);
if (parseResult.has_value())
return genType;

parser.emitError(typeLoc) << "unknown type `"
<< mnemonic << "` in dialect `" << getNamespace() << "`";
return {};
}
/// Print a type registered to this dialect.
void PolyDialect::printType(::mlir::Type type,
::mlir::DialectAsmPrinter &printer) const {
if (::mlir::succeeded(generatedTypePrinter(type, printer)))
return;

}
} // namespace poly
} // namespace tutorial
} // namespace mlir

#endif // GET_TYPEDEF_CLASSES

在设置 C++ 接口以使用 TableGen 文件时,通常会按照以下步骤来组织代码文件和包含关系。

  • PolyTypes.h 是唯一被允许包含 PolyTypes.h.inc 的文件。
  • PolyTypes.cpp.inc 文件包含了 TableGen 为 PolyDialect 中的类型生成的实现。我们需要在 PolyDialect.cpp 中将其包含进去,以确保所有实现都能在该方言的主文件中使用。
  • PolyTypes.cpp 文件应该包含 PolyTypes.h,以便访问类型声明,并在该文件中实现所有需要的额外功能。

为了让类型解析器和打印器能够正确编译和运行,需要最后在方言的 TableGen 文件中添加 let useDefaultTypePrinterParser = 1;,这个指令告诉 TableGen 使用默认的类型解析和打印器。当这个选项启用后,TableGen 会生成相应的解析和打印代码,并将这些实现作为 PolyDialect 类的成员函数。

1
2
3
4
5
6
/// Parse a type registered to this dialect.
::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;

/// Print a type registered to this dialect.
void printType(::mlir::Type type,
::mlir::DialectAsmPrinter &os) const override;

我们可以写一个 .mlir 来测试属性是是否获取正确。在 MLIR 中自定义的 dialect 前都需要加上 !.

1
2
3
4
5
// CHECK-LABEL: test_type_syntax
func.func @test_type_syntax(%arg0: !poly.poly<10>) -> !poly.poly<10> {
// CHECK: poly.poly
return %arg0: !poly.poly<10>
}

Add a Poly Type Parameter

我们需要为多项式类型添加一个属性,表示它的次数上限。

1
2
3
// include/mlir-learning/Dialect/Poly/PolyTypes.td
let parameters = (ins "int":$degreeBound);
let assemblyFormat = "`<` $degreeBound `>`";

第一行定义了类型的一个参数 degreeBound,类型为 int. 表示在实例化该类型时,用户可以指定一个整数值作为类型的参数。parameters 中的 (ins "int":$degreeBound) 指定了输入参数的类型和名称,其中 int 是数据类型,$degreeBound 是参数的占位符。assemblyFormat 用于定义该类型在 MLIR 文本格式中的打印和解析格式。"<" $degreeBound ">" 表示该类型的参数会用尖括号包裹。第二行是必需的,因为现在一个 Poly 类型有了这个关联的数据,我们需要能够将它打印出来并从文本 IR 表示中解析它。

  • 加上这两行代码后进行 build 会发现多了一些新的内容。
    PolynomialType 有一个新的 int getDegreeBound() 方法,以及一个静态 get 工厂方法。
  • parseprint 升级为新格式。
  • 有一个名为 typestorage 的新类,它包含 int 形参,并隐藏在内部细节名称空间中。

MLIR会自动生成简单类型的 storage 类,因为它们不需要复杂的内存管理。如果参数更复杂,就需要开发者手动编写 storage 类来定义构造、析构和其他语义。复杂的 storage 类需要实现更多细节,以确保类型能够在 MLIR 的 dialect 系统中顺利运行。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
// include/mlir-learning/Dialect/Poly/PolyTypes.hpp.inc
static ::mlir::Type parse(::mlir::AsmParser &odsParser);
void print(::mlir::AsmPrinter &odsPrinter) const;
int getDegreeBound() const;

// include/mlir-learning/Dialect/Poly/PolyTypes.cpp.inc

struct PolynomialTypeStorage : public ::mlir::TypeStorage {
/* lots of code */
};

PolynomialType PolynomialType::get(::mlir::MLIRContext *context, int degreeBound) {
return Base::get(context, std::move(degreeBound));
}

::mlir::Type PolynomialType::parse(::mlir::AsmParser &odsParser) {
/* code to parse the type */
}

void PolynomialType::print(::mlir::AsmPrinter &odsPrinter) const {
::mlir::Builder odsBuilder(getContext());
odsPrinter << "<";
odsPrinter.printStrippedAttrOrType(getDegreeBound());
odsPrinter << ">";
}

int PolynomialType::getDegreeBound() const {
return getImpl()->degreeBound;
}

Adding Some Simple Operations

下面我们定义一个简单的多项式加法操作

1
2
3
4
5
6
7
8
9
10
// include/mlir-learning/Dialect/Poly/PolyOps.td
include "PolyDialect.td"
include "PolyTypes.td"

def Poly_AddOp : Op<Poly_Dialect, "add"> {
let summary = "Addition operation between polynomials.";
let arguments = (ins Polynomial:$lhs, Polynomial:$rhs);
let results = (outs Polynomial:$output);
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)";
}

看起来非常类似于类型,但基类是 Op,arguments 对应于操作的输入,assemblyFormat 更复杂。生成的 .hpp.inc 和 .cpp.inc 非常复杂。我们可以编写一个 .mlir 来测试。

1
2
3
4
5
6
// CHECK-LABEL: test_add_syntax
func.func @test_add_syntax(%arg0: !poly.poly<10>, %arg1: !poly.poly<10>) -> !poly.poly<10> {
// CHECK: poly.add
%0 = poly.add %arg0, %arg1 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
return %0 : !poly.poly<10>
}

MLIR-Defining a New Dialect
https://darkenstar.github.io/2024/11/07/MLIR-Defining a New Dialect/
Author
ANNIHILATE_RAY
Posted on
November 7, 2024
Licensed under