Porting and Deploying Stable Diffusion on BM1684X with TPU-MLIR

Background

The TPU-MLIR compiler converts deep learning models that run on GPUs and other platforms into bmodel models that run on Sophgo processors. This article describes how to use TPU-MLIR to port a GPU-hosted Stable Diffusion model to a Sophgo processor and how to deploy it. Since Stable Diffusion comes in several versions, Stable Diffusion 1.5 is used as the example here.

Overview of the Stable Diffusion model

Stable Diffusion consists of three parts: the text_encoder, the UNet, and the VAE. The models commonly used for these parts are CLIP ViT, a Latent Diffusion Model, and the AutoencoderKL VAE, respectively. The overall execution flow is: the text_encoder turns the prompt into a text embedding, the UNet iteratively denoises a latent conditioned on that embedding, and the VAE decodes the final latent into an image.

To port Stable Diffusion, each of the three models is ported to the Sophgo processor separately, and the three parts are then chained together to obtain a complete Stable Diffusion pipeline.
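As an orientation before the individual conversion steps, the sketch below shows how the three ported models cooperate at inference time. It is a minimal, simplified flow: the model and scheduler arguments are placeholders for the wrappers and the diffusers scheduler set up later in this article, classifier-free guidance is omitted, and shapes assume Stable Diffusion 1.5 at 512x512.

import numpy as np

def generate(text_encoder, unet, vae_decoder, scheduler, tokens):
    """Chain the three ported models; classifier-free guidance omitted for brevity."""
    embeds = text_encoder(tokens)                                # (1, 77, 768) text condition
    latents = np.random.randn(1, 4, 64, 64).astype(np.float32)   # start from Gaussian noise
    for t in scheduler.timesteps:                                # iterative denoising
        noise_pred = unet(latents, t, embeds)                    # predict the noise at step t
        latents = scheduler.step(noise_pred, t, latents)         # remove it
    return vae_decoder(latents / 0.18215)                        # decode latents to a 512x512 image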

Model Porting

Obtaining the model

The Stable Diffusion v1.5 model can be obtained from Hugging Face as follows:

from diffusers import StableDiffusionPipeline
import torch

model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)

Reference: https://huggingface.co/runwayml/stable-diffusion-v1-5 (the official Hugging Face model page).
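Before exporting anything, it can be useful to confirm that the downloaded pipeline really contains the three sub-modules to be ported. The quick check below only uses standard diffusers/PyTorch attributes of the pipe object created above:

# Sanity check of the three sub-modules that will be exported below.
for name, module in [("text_encoder", pipe.text_encoder),
                     ("unet", pipe.unet),
                     ("vae", pipe.vae)]:
    n_params = sum(p.numel() for p in module.parameters())
    print(f"{name}: {module.__class__.__name__}, {n_params / 1e6:.1f}M parameters")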

Porting the text_encoder

The text_encoder is first exported to an ONNX model; a conversion script then turns the ONNX model into a bmodel.

Exporting to ONNX

def export_textencoder(pipe):
    for para in pipe.text_encoder.parameters():
        para.requires_grad = False
    batch = 1
    fake_input = torch.randint(0, 1000, (batch, 77))
    onnx_model_path = "./encoder.onnx"
    torch.onnx.export(pipe.text_encoder, fake_input, onnx_model_path,
                      verbose=True, opset_version=14,
                      input_names=["input_ids"], output_names=["output"])

export_textencoder(pipe)

We construct a fake_input with batch size 1 and export the text_encoder to ONNX with torch.onnx.export.
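To gain some confidence in the export before handing it to TPU-MLIR, the ONNX model can be compared with the original PyTorch module on the same kind of input. This is an optional check; onnxruntime is an extra dependency not used elsewhere in this article, and if the pipeline was loaded in float16 the comparison runs in half precision:

import numpy as np
import onnxruntime as ort

# Run the exported ONNX model and the original text_encoder on the same input.
fake_input = torch.randint(0, 1000, (1, 77))
sess = ort.InferenceSession("./encoder.onnx")
onnx_out = sess.run(None, {"input_ids": fake_input.numpy()})[0]
with torch.no_grad():
    torch_out = pipe.text_encoder(fake_input)[0].numpy()
print("max abs diff:", np.abs(onnx_out - torch_out).max())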

The conversion script

model_transform.py  --model_name encoder  --input_shape [[1,77]]  --model_def encoder.onnx  --mlir encoder.mlir
model_deploy.py --mlir encoder.mlir --quantize F32 --processor bm1684x --model text_encoder_1684x_f32.bmodel

Running this script converts the ONNX model into a bmodel.
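If a BM1684X board is available, the resulting bmodel can be loaded through sophon.sail (the same runtime used in the deployment section below) and run once as a smoke test. This is a minimal sketch that reuses only the sail calls shown later in this article; the int32 token dtype and the expected output shape are assumptions about the compiled graph:

import numpy as np
import sophon.sail as sail

# Load the compiled bmodel on device 0 and run it on a dummy token sequence.
engine = sail.Engine("text_encoder_1684x_f32.bmodel", 0, sail.IOMode.SYSIO)
graph = engine.get_graph_names()[0]
in_name = engine.get_input_names(graph)[0]
out_name = engine.get_output_names(graph)[0]
tokens = np.random.randint(0, 1000, (1, 77)).astype(np.int32)  # dtype assumed
out = engine.process(graph, {in_name: tokens})[out_name]
print("bmodel output shape:", out.shape)   # (1, 77, 768) expected for SD 1.5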

Porting the UNet

The UNet is first exported to a TorchScript (.pt) model; a conversion script then turns the .pt model into a bmodel. Note that if the UNet is combined with a ControlNet plug-in, the conversion procedure differs somewhat.

Exporting to a .pt model

def export_unet(pipe):
    img_size = (512, 512)
    latent_model_input = torch.rand(2, 4, img_size[0] // 8, img_size[1] // 8)
    t = torch.tensor([999])
    prompt_embeds = torch.rand(2, 77, 768)
    fake_input = (latent_model_input, t, prompt_embeds)

    # wrap the UNet so that encoder_hidden_states can be passed positionally for tracing
    def myunet(latent, timestep, encoder_hidden_states):
        return pipe.unet(latent, timestep, encoder_hidden_states=encoder_hidden_states)

    jit_model = torch.jit.trace(myunet, fake_input)
    jit_model.save("unet.pt")

export_unet(pipe)

We construct a fake_input with batch size 2 and trace the UNet with torch.jit.trace to obtain a .pt model. Batch size 2 is used because the negative prompt and the positive prompt are processed together, so the model runs with a batch of 2.
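The batch of 2 corresponds to classifier-free guidance: at every denoising step the same latent is paired with the unconditional and the conditional text embedding, and the two predicted noises are blended. The numpy sketch below mirrors how the deployment pipeline later in this article assembles that batch; the random arrays and the stand-in for the UNet call are only there to make the shapes concrete.

import numpy as np

guidance_scale = 7.5
latents = np.random.randn(1, 4, 64, 64).astype(np.float32)
uncond = np.random.randn(1, 77, 768).astype(np.float32)   # embedding of ""
cond = np.random.randn(1, 77, 768).astype(np.float32)     # embedding of the prompt
latent_model_input = np.concatenate([latents, latents], axis=0)         # (2, 4, 64, 64)
text_embeddings = np.concatenate([uncond, cond], axis=0)                # (2, 77, 768)
# noise_pred = unet(latent_model_input, t, text_embeddings)             # real UNet call
noise_pred = np.random.randn(2, 4, 64, 64).astype(np.float32)           # stand-in output
guided = noise_pred[0] + guidance_scale * (noise_pred[1] - noise_pred[0])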

The conversion script

model_transform.py  --model_name unet  --input_shape [[2,4,64,64],[1],[2,77,768]]  --model_def unet.pt  --mlir unet.mlir
model_deploy.py --mlir unet.mlir --quantize F16 --processor bm1684x --model unet_1684x_f16.bmodel

Porting the VAE

In Stable Diffusion the VAE is used as two separate pieces, an encoder and a decoder. Each piece is exported to a .pt model, and a conversion script then turns each .pt model into a bmodel.
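The encoder maps a 1x3x512x512 image to the parameters of a 4x64x64 latent distribution, and the decoder maps a latent back to an image. The deployment pipeline later multiplies encoded latents by the Stable Diffusion 1.x scaling factor 0.18215 and divides by it again before decoding; the small numpy sketch below only illustrates that convention, with a stand-in for the encoder output:

import numpy as np

scale = 0.18215   # latent scaling factor used by Stable Diffusion 1.x

# Encoding: the VAE encoder outputs mean/logvar of the latent distribution.
moments = np.random.randn(1, 8, 64, 64).astype(np.float32)   # stand-in encoder output
mean, logvar = np.split(moments, 2, axis=1)
latent = (mean + np.exp(0.5 * logvar) * np.random.randn(*mean.shape)) * scale

# Decoding: undo the scaling before handing the latent to the VAE decoder.
latent_for_decoder = latent / scale          # shape (1, 4, 64, 64)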

Porting the VAE encoder

Exporting to a .pt model

def export_vaencoder(pipe):
    for para in pipe.vae.parameters():
        para.requires_grad = False

    vae = pipe.vae

    class Encoder(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder1 = vae.encoder
            self.encoder2 = vae.quant_conv

        def forward(self, x):
            x = self.encoder1(x)
            x = self.encoder2(x)
            return x

    encoder = Encoder()
    img_size = (512, 512)
    img_size = (img_size[0] // 8, img_size[1] // 8)
    latent = torch.rand(1, 3, img_size[0] * 8, img_size[1] * 8)
    jit_model = torch.jit.trace(encoder, latent)
    encoder_model_path = './' + "vae_encoder" + '.pt'
    jit_model.save(encoder_model_path)

export_vaencoder(pipe)

The conversion script

model_transform.py  --model_name vae_encoder  --input_shape [[1,3,512,512]]  --model_def vae_encoder.pt  --mlir vae_encoder.mlir
model_deploy.py --mlir vae_encoder.mlir --quantize F16 --processor bm1684x --model vae_encoder_1684x_f16.bmodel

Porting the VAE decoder

Exporting to a .pt model

def export_vaedecoder(pipe):
    for para in pipe.vae.parameters():
        para.requires_grad = False
    vae = pipe.vae

    class Decoder(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.decoder2 = vae.post_quant_conv
            self.decoder1 = vae.decoder

        def forward(self, x):
            x = self.decoder2(x)
            x = self.decoder1(x)
            return x

    decoder = Decoder()
    img_size = (512, 512)
    latent = torch.rand(1, 4, img_size[0] // 8, img_size[1] // 8)
    jit_model = torch.jit.trace(decoder, latent)
    decoder_model_path = "./vae_decoder" + '.pt'
    jit_model.save(decoder_model_path)

export_vaedecoder(pipe)

The conversion script

model_transform.py  --model_name vae_decoder  --input_shape [[1,4,64,64]]  --model_def vae_decoder.pt  --mlir vae_decoder.mlir
model_deploy.py --mlir vae_decoder.mlir --quantize F16 --processor bm1684x --model vae_decoder_1684x_f16.bmodel

At this point all the models have been ported and the corresponding bmodel files have been produced. Next, the three parts are chained together and deployed to obtain a complete Stable Diffusion pipeline.

Model Deployment

Building the deployment scripts

Version dependencies

diffusers==0.2.4
streamlit==1.15.1
streamlit-drawable-canvas==0.9.2
tokenizers==0.13.2
tqdm==4.64.1
transformers==4.24.0

Building the runtime

We use SAIL to build the model runtime; for installing SAIL, see https://doc.sophgo.com/sdk-docs/v23.05.01/docs_latest_release/docs/sophon-sail/docs/zh/html/. To make it easier to call, we first wrap it in a thin layer.

import numpy as np
import time
import os
import sophon.sail as sail


class EngineOV:
    def __init__(self, model_path="", device_id=0):
        self.model_path = model_path
        self.device_id = device_id
        try:
            self.model = sail.Engine(model_path, device_id, sail.IOMode.SYSIO)
        except Exception as e:
            print("load model error; please check model path and device status;")
            print(">>>> model_path: ", model_path)
            print(">>>> device_id: ", device_id)
            print(">>>> sail.Engine error: ", e)
            raise e
        sail.set_print_flag(True)
        self.graph_name = self.model.get_graph_names()[0]
        self.input_name = self.model.get_input_names(self.graph_name)
        self.output_name = self.model.get_output_names(self.graph_name)

    def __str__(self):
        return "EngineOV: model_path={}, device_id={}".format(self.model_path, self.device_id)

    def __call__(self, args):
        if isinstance(args, list):
            values = args
        elif isinstance(args, dict):
            values = list(args.values())
        else:
            raise TypeError("args is not list or dict")
        args = {}
        for i in range(len(values)):
            args[self.input_name[i]] = values[i]
        output = self.model.process(self.graph_name, args)
        res = []
        for name in self.output_name:
            res.append(output[name])
        # return the ndarray directly when the graph has a single output;
        # the pipeline below relies on indexing the result as a plain array
        return res[0] if len(res) == 1 else res
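A quick usage sketch of this wrapper is shown below. The bmodel path is the one produced by model_deploy.py earlier (adjust it to wherever the file was copied on the board), and the token dtype and output shape are assumptions about the compiled text encoder:

import numpy as np

# Load the text encoder bmodel and run it on a dummy token sequence.
engine = EngineOV("./bmodel/text_encoder_1684x_f32.bmodel", device_id=0)
print(engine)                                    # EngineOV: model_path=..., device_id=0
tokens = np.random.randint(0, 1000, (1, 77)).astype(np.int32)
out = engine([tokens])                           # list inputs are mapped to input names in order
print(out.shape)                                 # single-output graph: (1, 77, 768) expected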

Building the pipeline

import inspect
import numpy as np
# tokenizer
from transformers import CLIPTokenizer
# utils
from tqdm import tqdm
from diffusers import LMSDiscreteScheduler
import cv2
# engine
from .npuengine import EngineOV


class StableDiffusionPipeline:
    def __init__(
        self,
        scheduler,
        model="",
        tokenizer="openai/clip-vit-large-patch14",
        device="NPU"
    ):
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer)
        self.scheduler = scheduler
        # models
        # text features
        self.text_encoder = EngineOV("./bmodel/text_encoder_fp32.bmodel")
        # diffusion
        self.unet = EngineOV("./bmodel/unet_dtype_fix.bmodel", device_id=0)
        self.latent_shape = (4, 64, 64)
        # decoder
        self.vae_decoder = EngineOV("./bmodel/vae_decoder_fp16.bmodel")
        # encoder
        self.vae_encoder = EngineOV("./bmodel/vae_encoder.bmodel")
        self.init_image_shape = (512, 512)

    def _preprocess_mask(self, mask):
        h, w = mask.shape
        if h != self.init_image_shape[0] and w != self.init_image_shape[1]:
            mask = cv2.resize(
                mask,
                (self.init_image_shape[1], self.init_image_shape[0]),
                interpolation=cv2.INTER_NEAREST
            )
        mask = cv2.resize(
            mask,
            (self.init_image_shape[1] // 8, self.init_image_shape[0] // 8),
            interpolation=cv2.INTER_NEAREST
        )
        mask = mask.astype(np.float32) / 255.0
        mask = np.tile(mask, (4, 1, 1))
        mask = mask[None].transpose(0, 1, 2, 3)
        mask = 1 - mask
        return mask

    def _preprocess_image(self, image):
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        h, w = image.shape[:2]
        if h != self.init_image_shape[0] and w != self.init_image_shape[1]:
            image = cv2.resize(
                image,
                (self.init_image_shape[1], self.init_image_shape[0]),
                interpolation=cv2.INTER_LANCZOS4
            )
        # normalize
        image = image.astype(np.float32) / 255.0
        image = 2.0 * image - 1.0
        # to batch
        image = image[None].transpose(0, 3, 1, 2)
        return image

    def _encode_image(self, init_image):
        moments = self.vae_encoder({
            "input.1": self._preprocess_image(init_image)
        })
        mean, logvar = np.split(moments, 2, axis=1)
        std = np.exp(logvar * 0.5)
        latent = (mean + std * np.random.randn(*mean.shape)) * 0.18215
        return latent

    def __call__(
        self,
        prompt,
        init_image=None,
        mask=None,
        strength=0.5,
        num_inference_steps=32,
        guidance_scale=7.5,
        eta=0.0
    ):
        # extract condition
        tokens = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True
        ).input_ids

        # text embedding, computed on the NPU engine
        text_embeddings = self.text_encoder({"tokens": np.array([tokens])})
        # do classifier-free guidance
        if guidance_scale > 1.0:
            tokens_uncond = self.tokenizer(
                "",
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True
            ).input_ids
            uncond_embeddings = self.text_encoder({"tokens": np.array([tokens_uncond])})
            text_embeddings = np.concatenate((uncond_embeddings, text_embeddings), axis=0)

        # set timesteps
        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {}
        offset = 0
        if accepts_offset:
            offset = 1
            extra_set_kwargs["offset"] = 1

        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        # initialize latents
        if init_image is None:
            latents = np.random.randn(*self.latent_shape)
            init_timestep = num_inference_steps
        else:
            init_latents = self._encode_image(init_image)
            init_timestep = int(num_inference_steps * strength) + offset
            init_timestep = min(init_timestep, num_inference_steps)
            timesteps = np.array([[self.scheduler.timesteps[-init_timestep]]]).astype(np.long)
            noise = np.random.randn(*self.latent_shape)
            latents = self.scheduler.add_noise(init_latents, noise, timesteps)[0]

        if init_image is not None and mask is not None:
            mask = self._preprocess_mask(mask)
        else:
            mask = None

        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
        if isinstance(self.scheduler, LMSDiscreteScheduler):
            latents = latents * self.scheduler.sigmas[0]

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in the DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        t_start = max(num_inference_steps - init_timestep + offset, 0)
        for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
            # expand the latents if we are doing classifier-free guidance
            latent_model_input = np.stack([latents, latents], 0) if guidance_scale > 1.0 else latents[None]
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                sigma = self.scheduler.sigmas[i]
                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

            # the unet model has been modified so that `t` is a batched sequence,
            # so we need to expand t: (1) -> (batch_size)
            batch_size = latent_model_input.shape[0]
            newt = np.tile(t, (batch_size)).astype(np.long)

            noise_pred = self.unet({
                "input.1": latent_model_input,
                "timesteps.1": newt,
                "input0.1": text_embeddings
            })

            # perform guidance
            if guidance_scale > 1.0:
                noise_pred = noise_pred[0] + guidance_scale * (noise_pred[1] - noise_pred[0])

            # compute the previous noisy sample x_t -> x_t-1
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs)["prev_sample"]
            else:
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]

            # masking for inpainting
            if mask is not None:
                init_latents_proper = self.scheduler.add_noise(init_latents, noise, t)
                latents = ((init_latents_proper * mask) + (latents * (1 - mask)))[0]

        latents = 1 / 0.18215 * latents
        image = self.vae_decoder({"input.1": np.expand_dims(latents, 0)})
        # convert tensor to opencv's image format
        image = (image / 2 + 0.5).clip(0, 1)
        image = (image[0].transpose(1, 2, 0)[:, :, ::-1] * 255).astype(np.uint8)
        return image
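Before wiring the pipeline into a web UI, it can be exercised directly from Python. The sketch below mirrors the scheduler parameters used by the run script in the next section; the module path stable_diffusion and the prompt are illustrative assumptions:

import cv2
from diffusers import PNDMScheduler
from stable_diffusion import StableDiffusionPipeline

scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012,
                          beta_schedule="scaled_linear",
                          skip_prk_steps=True, tensor_format="np")
pipe = StableDiffusionPipeline(scheduler=scheduler)
image = pipe("a photo of an astronaut riding a horse",
             num_inference_steps=32, guidance_scale=7.5)
cv2.imwrite("result.png", image)   # the pipeline returns a BGR uint8 array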

The run script

# -*- coding: utf-8 -*-
import argparse
import os
import random
import streamlit as st
from streamlit_drawable_canvas import st_canvas
import numpy as np
import cv2
from PIL import Image, ImageEnhance
# engine
from stable_diffusion import StableDiffusionPipeline
# scheduler
from diffusers import PNDMScheduler


def run(engine):
    st.header("Sophon AI ZOOM")
    with st.form(key="request"):
        with st.sidebar:
            prompt = st.text_area(label='Enter prompt(please use English)')

            with st.expander("Initial image"):
                init_image = st.file_uploader("init_image", type=['jpg', 'png', 'jpeg'])
                stroke_width = st.slider("stroke_width", 1, 100, 50)
                stroke_color = st.color_picker("stroke_color", "#00FF00")
                canvas_result = st_canvas(
                    fill_color="rgb(0, 0, 0)",
                    stroke_width=stroke_width,
                    stroke_color=stroke_color,
                    background_color="#000000",
                    background_image=Image.open(init_image) if init_image else None,
                    height=512,
                    width=512,
                    drawing_mode="freedraw",
                    key="canvas"
                )

            if init_image is not None:
                init_image = cv2.cvtColor(np.array(Image.open(init_image)), cv2.COLOR_RGB2BGR)

            if canvas_result.image_data is not None:
                mask = cv2.cvtColor(canvas_result.image_data, cv2.COLOR_BGRA2GRAY)
                mask[mask > 0] = 255
            else:
                mask = None

            num_inference_steps = st.select_slider(
                label='num_inference_steps',
                options=range(1, 150),
                value=32
            )

            guidance_scale = st.select_slider(
                label='guidance_scale',
                options=range(1, 21),
                value=7
            )

            strength = st.slider(
                label='strength',
                min_value=0.0,
                max_value=1.0,
                value=0.5
            )

            seed = st.number_input(
                label='seed',
                min_value=0,
                max_value=2 ** 31,
                value=random.randint(0, 2 ** 31)
            )

            generate = st.form_submit_button(label='Generate')

    if prompt:
        np.random.seed(seed)
        image = engine(
            prompt=prompt,
            init_image=init_image,
            mask=mask,
            strength=strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale
        )
        st.image(Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)), width=512)


@st.cache(allow_output_mutation=True)
def load_pipeline(args):
    scheduler = PNDMScheduler(
        beta_start=args.beta_start,
        beta_end=args.beta_end,
        beta_schedule=args.beta_schedule,
        skip_prk_steps=True,
        tensor_format="np"
    )
    pipeline = StableDiffusionPipeline(
        model=args.model,
        scheduler=scheduler,
        tokenizer=args.tokenizer
    )
    return pipeline


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # pipeline configure
    parser.add_argument("--model", type=str, default="bes-dev/stable-diffusion-v1-4-openvino", help="model name")
    # scheduler params
    parser.add_argument("--beta-start", type=float, default=0.00085, help="LMSDiscreteScheduler::beta_start")
    parser.add_argument("--beta-end", type=float, default=0.012, help="LMSDiscreteScheduler::beta_end")
    parser.add_argument("--beta-schedule", type=str, default="scaled_linear", help="LMSDiscreteScheduler::beta_schedule")
    # tokenizer
    parser.add_argument("--tokenizer", type=str, default="openai/clip-vit-large-patch14", help="tokenizer")

    try:
        args = parser.parse_args()
    except SystemExit as e:
        # This exception will be raised if --help or invalid command line arguments
        # are used. Currently streamlit prevents the program from exiting normally
        # so we have to do a hard exit.
        os._exit(e.code)

    engine = load_pipeline(args)
    run(engine)

This starts a Streamlit service that can be accessed from a browser.
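Assuming the run script above is saved as app.py (the filename is not specified in the original), the service can be launched with "streamlit run app.py" and then opened in a browser at the address Streamlit prints.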