
Commit f7f8d9a

[WIP] QNN GPU Procyon recipes
1 parent ee39438 commit f7f8d9a

9 files changed

Lines changed: 608 additions & 14 deletions


Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
{
    "input_model": {
        "type": "HfModel",
        "model_path": "Salesforce/blip-image-captioning-base",
        "task": "image-text-to-text",
        "load_kwargs": {
            "output_attentions": true
        },
        "io_config": {
            "input_names": [
                "pixel_values",
                "input_ids",
                "attention_mask"
            ],
            "input_shapes": [
                [ 1, 3, 384, 384 ],
                [ 1, 31 ],
                [ 1, 31 ]
            ],
            "input_types": [
                "float32",
                "int64",
                "int64"
            ],
            "output_names": [ "logits", "encoder_hidden_states" ]
        }
    },
    "systems": {
        "target_system": {
            "type": "LocalSystem",
            "accelerators": [ { "device": "gpu", "execution_providers": [ "QNNExecutionProvider" ] } ]
        }
    },
    "passes": {
        "cs": {
            "type": "CaptureSplitInfo",
            "block_to_split": [ "vision_model", "text_decoder" ],
            "num_splits": 2
        },
        "conversion": { "device": "cpu", "type": "OnnxConversion", "target_opset": 20 },
        "add_metadata": { "type": "AddOliveMetadata", "graph_name": "main_graph" },
        "surgery": {
            "type": "GraphSurgeries",
            "surgeries": [ { "surgeon": "MatMulAddToGemm" } ]
        },
        "sp": { "type": "SplitModel" }
    },
    "host": "target_system",
    "target": "target_system",
    "cache_dir": "cache",
    "output_dir": "output",
    "evaluate_input_model": false,
    "clean_cache": true
}
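
This recipe targets the QNN execution provider on GPU and splits BLIP into its vision encoder and text decoder: CaptureSplitInfo records the two split points, OnnxConversion exports at opset 20 on CPU, GraphSurgeries rewrites MatMul+Add pairs into Gemm, and SplitModel materializes the splits. A minimal sketch of running it, assuming the JSON above is saved as blip_qnn_gpu.json (hypothetical name; the commit's real filename is not shown) and that Olive's Python entry point olive.workflows.run behaves as in the Olive examples:

from olive.workflows import run as olive_run

# Runs the pass pipeline above in declaration order
# (cs -> conversion -> add_metadata -> surgery -> sp)
# and writes the split models under "output".
# "blip_qnn_gpu.json" is a hypothetical filename for the config above.
olive_run("blip_qnn_gpu.json")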
Lines changed: 246 additions & 0 deletions
@@ -0,0 +1,246 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

import argparse
import math

import onnx
import torch

from onnxsim import simplify
from torch import nn
from transformers import Swin2SRForImageSuperResolution


def window_partition(input_feature, window_size):
    """
    Partitions the given input into windows, using rank-4 reshapes in place
    of the usual rank-6 ones to keep tensor ranks low.
    """
    batch_size, height, width, num_channels = input_feature.shape
    input_feature = input_feature.reshape(
        batch_size * height // window_size, window_size, width // window_size, window_size * num_channels
    )
    windows = input_feature.permute(0, 2, 1, 3).contiguous().view(-1, window_size, window_size, num_channels)
    return windows


def window_reverse(windows, window_size, height, width):
    """
    Merges windows to produce higher resolution features, mirroring the rank
    reduction done in window_partition.
    """
    num_channels = windows.shape[-1]
    windows = windows.reshape(-1, width // window_size, window_size, window_size * num_channels)
    windows = windows.permute(0, 2, 1, 3).contiguous().view(-1, height, width, num_channels)
    return windows


class ModSwin2SRSelfAttention(nn.Module):
    """
    Swin2SR self attention with modifications to reduce tensor ranks.
    """

    def __init__(self, self_attention, device):
        super().__init__()
        self.model = self_attention

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.FloatTensor | None = None,
        head_mask: torch.FloatTensor | None = None,
        output_attentions: bool | None = False,
    ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor]:
        batch_size, dim, num_channels = hidden_states.shape
        query_layer = (
            self.model.query(hidden_states)
            .view(batch_size, -1, self.model.num_attention_heads, self.model.attention_head_size)
            .transpose(1, 2)
        )
        key_layer = (
            self.model.key(hidden_states)
            .view(batch_size, -1, self.model.num_attention_heads, self.model.attention_head_size)
            .transpose(1, 2)
        )
        value_layer = (
            self.model.value(hidden_states)
            .view(batch_size, -1, self.model.num_attention_heads, self.model.attention_head_size)
            .transpose(1, 2)
        )

        # Fold the head dimension into the batch dimension so downstream
        # matmuls operate on rank-3 tensors.
        query_layer = query_layer.reshape(batch_size * self.model.num_attention_heads, -1, self.model.attention_head_size)
        key_layer = key_layer.reshape(batch_size * self.model.num_attention_heads, -1, self.model.attention_head_size)
        value_layer = value_layer.reshape(batch_size * self.model.num_attention_heads, -1, self.model.attention_head_size)

        # cosine attention
        attention_scores = nn.functional.normalize(query_layer, dim=-1) @ nn.functional.normalize(
            key_layer, dim=-1
        ).transpose(-2, -1)
        logit_scale = torch.clamp(self.model.logit_scale, max=math.log(1.0 / 0.01)).exp()
        logit_scale = logit_scale.expand(batch_size, self.model.num_attention_heads, 1, 1).reshape(
            batch_size * self.model.num_attention_heads, 1, 1
        )

        attention_scores = attention_scores * logit_scale
        relative_position_bias_table = self.model.continuous_position_bias_mlp(self.model.relative_coords_table).view(
            -1, self.model.num_attention_heads
        )
        # [window_height*window_width, window_height*window_width, num_attention_heads]
        relative_position_bias = relative_position_bias_table[self.model.relative_position_index.view(-1)].view(
            self.model.window_size[0] * self.model.window_size[1], self.model.window_size[0] * self.model.window_size[1], -1
        )
        # [num_attention_heads, window_height*window_width, window_height*window_width]
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        relative_position_bias = 16 * torch.sigmoid(relative_position_bias)

        # Broadcast the bias across the batch, then fold heads into the batch
        # dimension to match the rank-3 attention scores.
        relative_position_bias = relative_position_bias.expand(
            batch_size,
            self.model.num_attention_heads,
            self.model.window_size[0] * self.model.window_size[1],
            self.model.window_size[0] * self.model.window_size[1],
        )
        relative_position_bias = relative_position_bias.reshape(
            batch_size * self.model.num_attention_heads,
            self.model.window_size[0] * self.model.window_size[1],
            self.model.window_size[0] * self.model.window_size[1],
        )
        attention_scores = attention_scores + relative_position_bias

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the Swin2SRModel forward() function),
            # expanded and folded into the batch dimension to stay rank-3.
            mask_shape = attention_mask.shape[0]
            attention_mask = attention_mask.unsqueeze(1).expand(mask_shape, self.model.num_attention_heads, dim, dim)
            attention_mask = attention_mask.reshape(mask_shape * self.model.num_attention_heads, dim, dim)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.model.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        # Restore the (batch, heads, tokens, head_size) layout before merging heads.
        context_layer = context_layer.reshape(batch_size, self.model.num_attention_heads, -1, self.model.attention_head_size)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.model.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


class Swin2SRLayer(nn.Module):
    def __init__(self, layer, device):
        super().__init__()
        self.model = layer
        self.model.attention.self = ModSwin2SRSelfAttention(self.model.attention.self, device)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: tuple[int, int],
        head_mask: torch.FloatTensor | None = None,
        output_attentions: bool | None = False,
    ) -> tuple[torch.Tensor] | tuple[torch.Tensor, torch.Tensor]:
        height, width = input_dimensions
        batch_size, _, channels = hidden_states.size()
        shortcut = hidden_states

        # pad hidden_states to multiples of window size
        hidden_states = hidden_states.view(batch_size, height, width, channels)
        hidden_states, pad_values = self.model.maybe_pad(hidden_states, height, width)
        _, height_pad, width_pad, _ = hidden_states.shape
        # cyclic shift
        if self.model.shift_size > 0:
            shifted_hidden_states = torch.roll(
                hidden_states, shifts=(-self.model.shift_size, -self.model.shift_size), dims=(1, 2)
            )
        else:
            shifted_hidden_states = hidden_states

        # partition windows
        hidden_states_windows = window_partition(shifted_hidden_states, self.model.window_size)
        hidden_states_windows = hidden_states_windows.view(-1, self.model.window_size * self.model.window_size, channels)
        attn_mask = self.model.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
        if attn_mask is not None:
            attn_mask = attn_mask.to(hidden_states_windows.device)

        attention_outputs = self.model.attention(
            hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
        )

        attention_output = attention_outputs[0]

        attention_windows = attention_output.view(-1, self.model.window_size, self.model.window_size, channels)
        shifted_windows = window_reverse(attention_windows, self.model.window_size, height_pad, width_pad)

        # reverse cyclic shift
        if self.model.shift_size > 0:
            attention_windows = torch.roll(
                shifted_windows, shifts=(self.model.shift_size, self.model.shift_size), dims=(1, 2)
            )
        else:
            attention_windows = shifted_windows

        # remove padding, if any, before the residual connection
        was_padded = pad_values[3] > 0 or pad_values[5] > 0
        if was_padded:
            attention_windows = attention_windows[:, :height, :width, :].contiguous()

        attention_windows = attention_windows.view(batch_size, height * width, channels)
        hidden_states = self.model.layernorm_before(attention_windows)
        hidden_states = shortcut + self.model.drop_path(hidden_states)

        layer_output = self.model.intermediate(hidden_states)
        layer_output = self.model.output(layer_output)
        layer_output = hidden_states + self.model.drop_path(self.model.layernorm_after(layer_output))

        layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
        return layer_outputs


class Swin2SRModel(nn.Module):
    def __init__(
        self,
        swin2sr,
        device,
    ) -> None:
        super().__init__()
        self.model = swin2sr
        # Replace every encoder layer with the rank-reduced wrapper.
        for i, stage in enumerate(self.model.swin2sr.encoder.stages):
            for j, layer in enumerate(stage.layers):
                self.model.swin2sr.encoder.stages[i].layers[j] = Swin2SRLayer(layer, device)

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ):
        # Forward by keyword so the flags land on the right parameters of the
        # wrapped model, and unpack kwargs instead of passing the dict itself.
        return self.model(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", choices=["npu", "gpu"], default="npu")
    args = parser.parse_args()

    model = Swin2SRModel(
        Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr"), args.device
    )

    inputs = {"pixel_values": torch.rand((1, 3, 136, 136))}

    # Export with the legacy TorchScript exporter (dynamo=False) at opset 20.
    with torch.no_grad():
        torch.onnx.export(
            model,
            inputs,
            "swin2sr.onnx",
            opset_version=20,
            do_constant_folding=True,
            dynamo=False,
            input_names=list(inputs.keys()),
        )

    # Simplify the exported graph; only overwrite it if the simplifier's check passes.
    onnx_model = onnx.load("swin2sr.onnx")
    simplified_onnx_model, check = simplify(onnx_model)

    if check:
        onnx.save(simplified_onnx_model, "swin2sr.onnx", save_as_external_data=False)


if __name__ == "__main__":
    main()
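
A quick way to sanity-check the exported swin2sr.onnx before handing it to an Olive recipe is to run it once through onnxruntime. This is a minimal sketch, assuming the onnxruntime package is installed; the input name and the 1x3x136x136 float32 shape come from the export script above:

import numpy as np
import onnxruntime as ort

# Load the simplified export and run a random frame through it on CPU.
session = ort.InferenceSession("swin2sr.onnx", providers=["CPUExecutionProvider"])
outputs = session.run(None, {"pixel_values": np.random.rand(1, 3, 136, 136).astype(np.float32)})

# This checkpoint is an x4 super-resolution model, so the output should be
# roughly four times the input's spatial resolution.
print([o.shape for o in outputs])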
