MLIR 之 Dialect Conversion

在之前的 MLIR 语法介绍视频中有介绍,MLIR 的一大特点就是它提出了 Dialect 的概念,深度学习模型通过在多层 Dialect 的 IR 间进行转换、优化,最终形成能够在硬件上运行的二进制文件,从一个 Dialect 转换到另一个 Dialect 的过程就是 Conversion

MLIR 中提供了一个 Dialect Conversion 框架来支持这个转换过程,本质上其实就是将一系列在目标 Dialect 上非法的算子转为合法算子

不同于传统编译器只能进行全转换,MLIR 提供了较为灵活的转换模式,既可以只进行全转换,也可以进行部分转换。

其中 partialConversion 是指保留没有被明确标记为非法的 op,然后尽可能多地对其余 op 进行合法化,这样即使原 IR 中存在某些未知的 op,也能进行转换工作;AnalysisConversion 则会通过部分转换并记录 op 转换成功与否,来分析哪些 op 是可合法化的;FullConversion 顾名思义,会合法化所有的 op,所以转换过后的 IR 中只存在已知的 op。

为了使用这些转换模式,我们需要先准备好 Dialect Conversion 所需的组件。使用 Dialect Conversion 主要需要三个组件:

首先是 Target,主要用于明确在转换过程中哪些算子和 Dialect 是合法的,算子和 Dialect 可以被标记为合法,动态与非法三种 actions,其中动态是指某些算子只有在部分实例中是合法的。

既可以继承 **ConversionTarget **父类创建一个自定义的 target,也可以在 ConversionTarget 的实例中直接通过 add 系列的 function 添加合法与非法 Op。另外,还可以通过 markOpRecursivelyLegal 来将某个区域,即某个 Op 中嵌套的所有 Op 都定义为合法。

定义完合法与非法算子后,我们就需要合法化 patterns 来将非法算子转为合法算子,所以 rewrite pattern 是用于实现非法算子转换为合法算子的转换逻辑。

Dialect Conversion 框架会自动根据所提供的的 patterns 生成一个转换图用于合法化,从而简化整个改写的流程,例如我们的 patterns 中只提到 Dialect A 的 op0 可以合法化为 B 中的 op0,B 中的 op0 可以合法化为 C 中的 op0,Conversion 框架就会自动检测 DialectA 的 op0 可以合法化为 DialectC 中的 op0,而不用经过中间的算子转换。

RewritePattern 中还有个特例,ConversionPattern,相对于传统的 rewritepattern,它多了一个 operands 的输入参数,这个 operands 数组用于记录那些被重映射的操作数,主要用于处理那些存在 type 转换的算子,算子会对 type 转换后的 values 进行操作,但仍需要与原操作数相匹配。

当存在 type 转换时,则需要 Type Converter 来定义 type 在与 pattern 交互时的转换方式,重映射后的操作数类型需要与 type converter 规定的一致,如果没有提供 type converter,则这些操作数的 type 需要与原操作数相匹配,否则 pattern 的应用就会在调用 matchAndRewrite 之前失败,以此来保证 type 的合法性,所以 patterns 并不需要担心 type 的安全性问题。

TypeConverter 主要由两个方面组成,Conversion 和 Materialization,两者都是用来定义 Type 的转换方式,不过后者能够生成 IR,并且可以根据转换过程中的不同需求,实现从 type 的来回转换。在 Dialect 转换的 pass 中准备好以上所有所需组件后,便可以通过转换接口实现 Dialect 间的 IR 转换。