主要工作
TpuLang提供了mlir对外的接口函数。用户通过Tpulang可以直接组建用户自己的网络,将模型转换为 Top 层(芯片无关层) mlir 模型
(不包含 Canonicalize 部分, 因此生成的文件名为“*_origin.mlir”)。这个过程会根据输入的接口函数逐
一创建并添加算子(Op), 最终生成 mlir 文件与保存权重的 npz 文件。
算子转换样例
本节以 Conv 算子为例, 将单 Conv 算子模型转换为 Top mlir, 原模型定义如图所示(单 Conv 模型)
转换流程为:
接口定义
conv_v2 接口定义如下:
def conv_v2(tensor_i,
weight,
bias = None,
stride = None,
dilation = None,
pad = None,
group = 1,
input_zp = None,
weight_zp = None,
out_dtype = None,
out_name = None):
# pass
参数说明
tensor_i:Tensor类型,表示输入Tensor,4维NCHW格式。
weight:Tensor类型,表示卷积核Tensor,4维[oc, ic, kh, kw]格式。其中oc表示输出Channel数,ic表示输入channel数,kh是kernel_h,kw是kernel_w。
bias:Tensor类型,表示偏置Tensor。为None时表示无偏置,反之则要求shape为[1, oc, 1, 1]。
dilation:List[int],表示空洞大小,取None则表示[1,1],不为None时要求长度为2。List中顺序为[长,宽]
pad:List[int],表示填充大小,取None则表示[0,0,0,0],不为None时要求长度为4。List中顺序为[上, 下, 左, 右]
stride:List[int],表示步长大小,取None则表示[1,1],不为None时要求长度为2。List中顺序为[长,宽]
groups:int型,表示卷积层的组数。若ic=oc=groups时,则卷积为depthwise conv
input_zp:List[int]型或int型,表示输入偏移。取None则表示0,取List时要求长度为ic。
weight_zp:List[int]型或int型,表示卷积核偏移。取None则表示0,取List时要求长度为ic,其中ic表示输入的Channel数。
out_dtype:string类型或None,表示输出Tensor的类型。输入tensor类型为float16/float32时,取None表示输出tensor类型与输入一致,否则取None表示为int32。取值范围:/int32/uint32/float32/float16
out_name:string类型或None,表示输出Tensor的名称,为None时内部会自动产生名称。
在 TopOps.td 中定义 Top.Conv 算子, 算子定义如图所示(Conv 算子定义)
构建 Graph
init_MLIRImporter:
根据 input_names 与 output_names 从 shapes 中获取了对应的 input_shape 与 output_shape, 加上model_name, 生成了初始的 mlir 文本 MLIRImporter.mlir_module, 如图所示(初始 mlir 文本)。
generate_mlir
build input op, 生成的 Top.inputOp 会被插入到 MLIRImporter.mlir_module 中。
调用Operation.create 来创建 Top.ConvOp, 而 create 函数需要的参数有:
输入 op: 从接口定义可知, Conv 算子的 inputs 一共包含了 input, weight 与 bias, inputOp 已被创建好, weight 与 bias 的 op 则通过 getWeightOp()创建。
output_shape: 利用 Operator 中存储的输出 tensor 中获取其 shape。
Attributes: 从 Operator 中获取 attributes,并将attributes转换为MLIRImporter识别的Attributes
Top.ConvOp 创建后会被插入到 mlir 文本中
根据 output_names 从 operands 中获取相应的 op, 创建 return_op 并插入到 mlir 文本中。到此为止, 生成的 mlir 文本如图所示(完整的 mlir 文本)。
输出
将 mlir 文本保存为 Conv_origin.mlir, tensors 中的权重保存为 Conv_TOP_F32_all_weight.npz。