Transforming and Deploying Stable Diffusion with TPU-MLIR

The TPU-MLIR compiler can transform deep learning models that run on platforms such as GPUs into bmodel files that run on TPU processors (the BM1684X is used below). This document gives an overview of how to use TPU-MLIR to port the stable diffusion model from GPU to a TPU processor, along with the deployment steps involved. Since stable diffusion has multiple versions, stable diffusion v1.5 is used as the example here.

Introduction to the Stable Diffusion Model#

The stable diffusion model consists of three components: text_encoder, unet, and vae. The readily available models for these components are CLIP_ViT, the Latent Diffusion UNet, and the AutoencoderKL VAE, respectively. The model execution flow is illustrated in the following figure:

Transforming stable diffusion involves porting each of the three model components to the TPU processor and then connecting the three components together to form a complete stable diffusion pipeline.

Transform the Stable Diffusion Model#

Get the Stable Diffusion Model#

The stable diffusion model can be obtained from the Hugging Face model hub. The following code snippet shows how to load it:

from diffusers import StableDiffusionPipeline
import torch

model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)

Please refer to https://huggingface.co/runwayml/stable-diffusion-v1-5 for details.
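
The loaded pipeline exposes the three sub-models described above as attributes; the following optional check (a minimal sketch, not required for the conversion) confirms what will be exported in the steps below:

# Optional: inspect the three components that will be converted below.
print(type(pipe.text_encoder))  # CLIPTextModel (CLIP_ViT)
print(type(pipe.unet))          # UNet2DConditionModel (latent diffusion UNet)
print(type(pipe.vae))           # AutoencoderKL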

Transform the Text Encoder#

The text encoder is a CLIP_ViT model, which can be transformed into an onnx model; the onnx model can then be transformed into a bmodel.

Transform into an onnx model#

def export_textencoder(pipe):
    for para in pipe.text_encoder.parameters():
        para.requires_grad = False
    batch = 1
    fake_input = torch.randint(0, 1000, (batch, 77))
    onnx_model_path = "./encoder.onnx"
    torch.onnx.export(pipe.text_encoder, fake_input, onnx_model_path,
                      verbose=True, opset_version=14,
                      input_names=["input_ids"], output_names=["output"])

export_textencoder(pipe)
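
Before converting, the exported ONNX file can optionally be sanity-checked with onnxruntime (assumed to be installed; it is not required by TPU-MLIR itself):

# Optional check of encoder.onnx with onnxruntime (assumption: onnxruntime is installed).
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("./encoder.onnx")
ids = np.random.randint(0, 1000, (1, 77)).astype(np.int64)
out = sess.run(None, {"input_ids": ids})
print(out[0].shape)  # text-embedding output, e.g. (1, 77, 768)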

Make a transformation script#

model_transform.py  --model_name encoder  --input_shape [[1,77]]  --model_def encoder.onnx  --mlir encoder.mlir
model_deploy.py --mlir encoder.mlir --quantize F32 --processor bm1684x --model text_encoder_1684x_f32.bmodel

Run this transformation script to convert the onnx model into a bmodel.

Transform the Unet Model#

The unet model is the Latent Diffusion UNet, which can be converted into a pt (TorchScript) model; the pt model can then be converted into a bmodel.

Transform into a pt model#

# Wrapper so that encoder_hidden_states is passed as a keyword argument during tracing.
def myunet(latent, timestep, encoder_hidden_states):
    return pipe.unet(latent, timestep, encoder_hidden_states=encoder_hidden_states)

def export_unet(pipe):
    img_size = (512, 512)
    latent_model_input = torch.rand(2, 4, img_size[0]//8, img_size[1]//8)
    t = torch.tensor([999])
    prompt_embeds = torch.rand(2, 77, 768)
    fake_input = (latent_model_input, t, prompt_embeds)
    jit_model = torch.jit.trace(myunet, fake_input)
    jit_model.save("unet.pt")

export_unet(pipe)

We construct a fake_input with a batch size of 2 and convert the unet model into a pt model using torch.jit.trace. The batch size of 2 accounts for the negative-prompt and positive-prompt cases, which is why the model uses a batch of 2.
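
At inference time, the two batch entries hold the unconditional and the conditional prompt embeddings. The following minimal numpy sketch (the array contents are placeholders) shows how the batched unet inputs are assembled, matching the pipeline code later in this document:

import numpy as np

latents = np.random.randn(4, 64, 64).astype(np.float32)
uncond_embeddings = np.random.randn(1, 77, 768).astype(np.float32)  # embeddings of the empty prompt ""
cond_embeddings = np.random.randn(1, 77, 768).astype(np.float32)    # embeddings of the user prompt
latent_model_input = np.stack([latents, latents], axis=0)                       # (2, 4, 64, 64)
text_embeddings = np.concatenate([uncond_embeddings, cond_embeddings], axis=0)  # (2, 77, 768)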

Make a transformation script#

model_transform.py  --model_name unet  --input_shape [[2,4,64,64],[1],[2,77,768]]  --model_def unet.pt  --mlir unet.mlir
model_deploy.py --mlir unet.mlir --quantize F16 --processor bm1684x --model unet_1684x_f16.bmodel

Transform the VAE Model#

In stable diffusion, the VAE is split into two parts: the encoder and the decoder. Here we convert the encoder and the decoder into pt models, and then build conversion scripts to convert the pt models into bmodels.

Transform VAE Encoder#

Transform into a pt model#

def export_vaencoder(pipe):
    for para in pipe.vae.parameters():
        para.requires_grad = False

    vae = pipe.vae

    class Encoder(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder1 = vae.encoder
            self.encoder2 = vae.quant_conv

        def forward(self, x):
            x = self.encoder1(x)
            x = self.encoder2(x)
            return x

    encoder = Encoder()
    img_size = (512, 512)
    img_size = (img_size[0]//8, img_size[1]//8)
    latent = torch.rand(1, 3, img_size[0]*8, img_size[1]*8)
    jit_model = torch.jit.trace(encoder, latent)
    encoder_model_path = './' + "vae_encoder" + '.pt'
    jit_model.save(encoder_model_path)

export_vaencoder(pipe)

Make a transformation script#

model_transform.py  --model_name vae_encoder  --input_shape [[1,3,512,512]]  --model_def vae_encoder.pt  --mlir vae_encoder.mlir
model_deploy.py --mlir vae_encoder.mlir --quantize F16 --processor bm1684x --model vae_encoder_1684x_f16.bmodel

Transform VAE Decoder#

Transform into a pt model#

def export_vaedecoder(pipe):
    for para in pipe.vae.parameters():
        para.requires_grad = False
    vae = pipe.vae

    class Decoder(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.decoder2 = vae.post_quant_conv
            self.decoder1 = vae.decoder

        def forward(self, x):
            x = self.decoder2(x)
            x = self.decoder1(x)
            return x

    decoder = Decoder()
    img_size = (512, 512)
    latent = torch.rand(1, 4, img_size[0]//8, img_size[1]//8)
    jit_model = torch.jit.trace(decoder, latent)
    decoder_model_path = "./vae_decoder" + '.pt'
    jit_model.save(decoder_model_path)

export_vaedecoder(pipe)
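
Optionally, the traced pt files can be reloaded with torch.jit.load to confirm their input and output shapes before conversion. This is a minimal sketch; the input dtype is taken from the traced weights, which follow the dtype the pipeline was loaded with:

# Optional shape check of the traced VAE models.
import torch

enc = torch.jit.load("./vae_encoder.pt")
dec = torch.jit.load("./vae_decoder.pt")
dtype = next(enc.parameters()).dtype
print(enc(torch.rand(1, 3, 512, 512, dtype=dtype)).shape)  # expected (1, 8, 64, 64): mean and logvar
print(dec(torch.rand(1, 4, 64, 64, dtype=dtype)).shape)    # expected (1, 3, 512, 512)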

Make a transformation script#

model_transform.py  --model_name vae_decoder  --input_shape [[1,4,64,64]]  --model_def vae_decoder.pt  --mlir vae_decoder.mlir
model_deploy.py --mlir vae_decoder.mlir --quantize F16 --processor bm1684x --model vae_decoder_1684x_f16.bmodel

With this, all the model transformations are complete and the corresponding bmodel files have been produced. Next, we string the three parts of the model together, deploy them, and obtain a complete stable diffusion pipeline.

Deploy the Stable Diffusion Model#

Make a deployment script#

Some version dependency information#

diffusers==0.2.4
streamlit==1.15.1
streamlit-drawable-canvas==0.9.2
tokenizers==0.13.2
tqdm==4.64.1
transformers==4.24.0
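
These pins can be installed with pip, for example:

pip3 install diffusers==0.2.4 streamlit==1.15.1 streamlit-drawable-canvas==0.9.2 tokenizers==0.13.2 tqdm==4.64.1 transformers==4.24.0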

Build the runtime#

Use sophon-sail to build the model runtime. Please refer to https://doc.sophgo.com/sdk-docs/v23.05.01/docs_latest_release/docs/sophon-sail/docs/zh/html/安装sail for installation instructions. For convenience, we wrap the sail engine in a class.

import numpy as np
import time
import os
import sophon.sail as sail


class EngineOV:

    def __init__(self, model_path="", device_id=0):
        self.model_path = model_path
        self.device_id = device_id
        try:
            self.model = sail.Engine(model_path, device_id, sail.IOMode.SYSIO)
        except Exception as e:
            print("load model error; please check model path and device status;")
            print(">>>> model_path: ", model_path)
            print(">>>> device_id: ", device_id)
            print(">>>> sail.Engine error: ", e)
            raise e
        sail.set_print_flag(True)
        self.graph_name = self.model.get_graph_names()[0]
        self.input_name = self.model.get_input_names(self.graph_name)
        self.output_name = self.model.get_output_names(self.graph_name)

    def __str__(self):
        return "EngineOV: model_path={}, device_id={}".format(self.model_path, self.device_id)

    def __call__(self, args):
        if isinstance(args, list):
            values = args
        elif isinstance(args, dict):
            values = list(args.values())
        else:
            raise TypeError("args is not list or dict")
        args = {}
        for i in range(len(values)):
            args[self.input_name[i]] = values[i]
        output = self.model.process(self.graph_name, args)
        res = []
        for name in self.output_name:
            res.append(output[name])
        return res
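
A hypothetical usage example of the wrapper (the bmodel path follows the text-encoder conversion above; the token dtype must match the converted model's input):

# Hypothetical usage of EngineOV with the text-encoder bmodel produced earlier.
import numpy as np

text_encoder = EngineOV("./bmodel/text_encoder_1684x_f32.bmodel", device_id=0)
tokens = np.zeros((1, 77), dtype=np.int32)  # placeholder token ids; dtype must match the bmodel input
embeddings = text_encoder([tokens])[0]      # __call__ returns a list of output arrays
print(embeddings.shape)                     # e.g. (1, 77, 768)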

Build Pipeline#

import inspect
import numpy as np
# tokenizer
from transformers import CLIPTokenizer
# utils
from tqdm import tqdm
from diffusers import LMSDiscreteScheduler
import cv2
# engine
from .npuengine import EngineOV


class StableDiffusionPipeline:
    def __init__(
        self,
        scheduler,
        model="",
        tokenizer="openai/clip-vit-large-patch14",
        device="NPU"
    ):
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer)
        self.scheduler = scheduler
        # models; the bmodel paths below must point at the files produced by the
        # conversion steps above
        # text features
        self.text_encoder = EngineOV("./bmodel/text_encoder_fp32.bmodel")
        # diffusion
        self.unet = EngineOV("./bmodel/unet_dtype_fix.bmodel", device_id=0)
        self.latent_shape = (4, 64, 64)
        # decoder
        self.vae_decoder = EngineOV("./bmodel/vae_decoder_fp16.bmodel")
        # encoder
        self.vae_encoder = EngineOV("./bmodel/vae_encoder.bmodel")
        self.init_image_shape = (512, 512)

    def _preprocess_mask(self, mask):
        h, w = mask.shape
        if h != self.init_image_shape[0] and w != self.init_image_shape[1]:
            mask = cv2.resize(
                mask,
                (self.init_image_shape[1], self.init_image_shape[0]),
                interpolation=cv2.INTER_NEAREST
            )
        mask = cv2.resize(
            mask,
            (self.init_image_shape[1] // 8, self.init_image_shape[0] // 8),
            interpolation=cv2.INTER_NEAREST
        )
        mask = mask.astype(np.float32) / 255.0
        mask = np.tile(mask, (4, 1, 1))
        mask = mask[None].transpose(0, 1, 2, 3)
        mask = 1 - mask
        return mask

    def _preprocess_image(self, image):
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        h, w = image.shape[:2]
        if h != self.init_image_shape[0] and w != self.init_image_shape[1]:
            image = cv2.resize(
                image,
                (self.init_image_shape[1], self.init_image_shape[0]),
                interpolation=cv2.INTER_LANCZOS4
            )
        # normalize
        image = image.astype(np.float32) / 255.0
        image = 2.0 * image - 1.0
        # to batch
        image = image[None].transpose(0, 3, 1, 2)
        return image

    def _encode_image(self, init_image):
        # EngineOV.__call__ returns a list of outputs; take the first (and only) one
        moments = self.vae_encoder({
            "input.1": self._preprocess_image(init_image)
        })[0]
        mean, logvar = np.split(moments, 2, axis=1)
        std = np.exp(logvar * 0.5)
        latent = (mean + std * np.random.randn(*mean.shape)) * 0.18215
        return latent

    def __call__(
        self,
        prompt,
        init_image=None,
        mask=None,
        strength=0.5,
        num_inference_steps=32,
        guidance_scale=7.5,
        eta=0.0
    ):
        # extract condition
        tokens = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True
        ).input_ids

        # text_embedding: use the npu engine to run inference
        text_embeddings = self.text_encoder({"tokens": np.array([tokens])})[0]
        # do classifier free guidance
        if guidance_scale > 1.0:
            tokens_uncond = self.tokenizer(
                "",
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True
            ).input_ids
            uncond_embeddings = self.text_encoder({"tokens": np.array([tokens_uncond])})[0]
            text_embeddings = np.concatenate((uncond_embeddings, text_embeddings), axis=0)

        # set timesteps
        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {}
        offset = 0
        if accepts_offset:
            offset = 1
            extra_set_kwargs["offset"] = 1

        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        # initialize the latents
        if init_image is None:
            latents = np.random.randn(*self.latent_shape)
            init_timestep = num_inference_steps
        else:
            init_latents = self._encode_image(init_image)
            init_timestep = int(num_inference_steps * strength) + offset
            init_timestep = min(init_timestep, num_inference_steps)
            timesteps = np.array([[self.scheduler.timesteps[-init_timestep]]]).astype(np.long)
            noise = np.random.randn(*self.latent_shape)
            latents = self.scheduler.add_noise(init_latents, noise, timesteps)[0]

        if init_image is not None and mask is not None:
            mask = self._preprocess_mask(mask)
        else:
            mask = None

        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
        if isinstance(self.scheduler, LMSDiscreteScheduler):
            latents = latents * self.scheduler.sigmas[0]

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in the DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        t_start = max(num_inference_steps - init_timestep + offset, 0)
        for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = np.stack([latents, latents], 0) if guidance_scale > 1.0 else latents[None]
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                sigma = self.scheduler.sigmas[i]
                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

            # the unet model we converted takes `t` as a batch sequence, so expand it
            # t: (1) -> (batch_size)
            batch_size = latent_model_input.shape[0]
            newt = np.tile(t, (batch_size)).astype(np.long)

            noise_pred = self.unet({
                "input.1": latent_model_input,
                "timesteps.1": newt,
                "input0.1": text_embeddings
            })[0]

            # perform guidance
            if guidance_scale > 1.0:
                noise_pred = noise_pred[0] + guidance_scale * (noise_pred[1] - noise_pred[0])

            # compute the previous noisy sample x_t -> x_t-1
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs)["prev_sample"]
            else:
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]

            # masking for inpainting
            if mask is not None:
                init_latents_proper = self.scheduler.add_noise(init_latents, noise, t)
                latents = ((init_latents_proper * mask) + (latents * (1 - mask)))[0]

        latents = 1 / 0.18215 * latents
        image = self.vae_decoder({"input.1": np.expand_dims(latents, 0)})[0]
        # convert the tensor to OpenCV's image format
        image = (image / 2 + 0.5).clip(0, 1)
        image = (image[0].transpose(1, 2, 0)[:, :, ::-1] * 255).astype(np.uint8)
        return image

Build Runtime Script#

# -*- coding: utf-8 -*-
import argparse
import os
import random
import streamlit as st
from streamlit_drawable_canvas import st_canvas
import numpy as np
import cv2
from PIL import Image, ImageEnhance
# engine
from stable_diffusion import StableDiffusionPipeline
# scheduler
from diffusers import PNDMScheduler


def run(engine):
    st.header("Sophon AI ZOOM")
    with st.form(key="request"):
        with st.sidebar:
            prompt = st.text_area(label='Enter prompt(please use English)')

            with st.expander("Initial image"):
                init_image = st.file_uploader("init_image", type=['jpg', 'png', 'jpeg'])
                stroke_width = st.slider("stroke_width", 1, 100, 50)
                stroke_color = st.color_picker("stroke_color", "#00FF00")
                canvas_result = st_canvas(
                    fill_color="rgb(0, 0, 0)",
                    stroke_width=stroke_width,
                    stroke_color=stroke_color,
                    background_color="#000000",
                    background_image=Image.open(init_image) if init_image else None,
                    height=512,
                    width=512,
                    drawing_mode="freedraw",
                    key="canvas"
                )

            if init_image is not None:
                init_image = cv2.cvtColor(np.array(Image.open(init_image)), cv2.COLOR_RGB2BGR)

            if canvas_result.image_data is not None:
                mask = cv2.cvtColor(canvas_result.image_data, cv2.COLOR_BGRA2GRAY)
                mask[mask > 0] = 255
            else:
                mask = None

            num_inference_steps = st.select_slider(
                label='num_inference_steps',
                options=range(1, 150),
                value=32
            )

            guidance_scale = st.select_slider(
                label='guidance_scale',
                options=range(1, 21),
                value=7
            )

            strength = st.slider(
                label='strength',
                min_value=0.0,
                max_value=1.0,
                value=0.5
            )

            seed = st.number_input(
                label='seed',
                min_value=0,
                max_value=2 ** 31,
                value=random.randint(0, 2 ** 31)
            )

            generate = st.form_submit_button(label='Generate')

        if prompt:
            np.random.seed(seed)
            image = engine(
                prompt=prompt,
                init_image=init_image,
                mask=mask,
                strength=strength,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale
            )
            st.image(Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)), width=512)


@st.cache(allow_output_mutation=True)
def load_pipeline(args):
    scheduler = PNDMScheduler(
        beta_start=args.beta_start,
        beta_end=args.beta_end,
        beta_schedule=args.beta_schedule,
        skip_prk_steps=True,
        tensor_format="np"
    )
    pipeline = StableDiffusionPipeline(
        model=args.model,
        scheduler=scheduler,
        tokenizer=args.tokenizer
    )
    return pipeline


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # pipeline configure
    parser.add_argument("--model", type=str, default="bes-dev/stable-diffusion-v1-4-openvino", help="model name")
    # scheduler params
    parser.add_argument("--beta-start", type=float, default=0.00085, help="LMSDiscreteScheduler::beta_start")
    parser.add_argument("--beta-end", type=float, default=0.012, help="LMSDiscreteScheduler::beta_end")
    parser.add_argument("--beta-schedule", type=str, default="scaled_linear", help="LMSDiscreteScheduler::beta_schedule")
    # tokenizer
    parser.add_argument("--tokenizer", type=str, default="openai/clip-vit-large-patch14", help="tokenizer")

    try:
        args = parser.parse_args()
    except SystemExit as e:
        # This exception will be raised if --help or invalid command line arguments
        # are used. Currently streamlit prevents the program from exiting normally
        # so we have to do a hard exit.
        os._exit(e.code)

    engine = load_pipeline(args)
    run(engine)

This starts a Streamlit server that is accessible through a web browser.
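
Assuming the script above is saved as app.py (a hypothetical file name), it can be launched with Streamlit; arguments for the script itself go after the -- separator:

streamlit run app.py -- --tokenizer openai/clip-vit-large-patch14

The page is then available at the URL Streamlit prints (http://localhost:8501 by default).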