# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# -------------------------------------------------------------------------

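"""Export Swin2SR image super-resolution to ONNX with rank-reduced attention.

The Hugging Face Swin2SR encoder layers are wrapped so that window partitioning
and window self-attention avoid high-rank intermediate tensors, then the model
is exported with torch.onnx.export and simplified with onnxsim.
"""
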
import argparse
import math

import onnx
import torch
import torch.nn.functional as F

from onnxsim import simplify
from torch import nn
from transformers import Swin2SRForImageSuperResolution

def window_partition(input_feature, window_size):
    """
    Partitions the given input into windows.

    A single rank-4 reshape is used (instead of the usual rank-6 view) so that
    intermediate tensors never exceed rank 4.
    """
    batch_size, height, width, num_channels = input_feature.shape
    input_feature = input_feature.reshape(
        batch_size * height // window_size, window_size, width // window_size, window_size * num_channels
    )
    windows = input_feature.permute(0, 2, 1, 3).contiguous().view(-1, window_size, window_size, num_channels)
    return windows


def window_reverse(windows, window_size, height, width):
    """
    Merges windows to produce higher resolution features.

    Like window_partition, this uses only rank-4 intermediate tensors.
    """
    num_channels = windows.shape[-1]
    windows = windows.reshape(-1, width // window_size, window_size, window_size * num_channels)
    windows = windows.permute(0, 2, 1, 3).contiguous().view(-1, height, width, num_channels)
    return windows


class ModSwin2SRSelfAttention(nn.Module):
    """
    Swin2SR self attention with modifications to reduce tensor ranks.
    """

    def __init__(self, self_attention, device):
        super().__init__()
        self.model = self_attention

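    # The forward pass below folds the attention heads into the batch dimension
    # (batch_size * num_attention_heads) and expands the logit scale, relative
    # position bias, and attention mask to match, so the attention math runs on
    # rank-3 tensors instead of the rank-4/rank-5 tensors used upstream.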
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.FloatTensor | None = None,
        head_mask: torch.FloatTensor | None = None,
        output_attentions: bool | None = False,
    ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor]:
        batch_size, dim, num_channels = hidden_states.shape
        query_layer = (
            self.model.query(hidden_states)
            .view(batch_size, -1, self.model.num_attention_heads, self.model.attention_head_size)
            .transpose(1, 2)
        )
        key_layer = (
            self.model.key(hidden_states)
            .view(batch_size, -1, self.model.num_attention_heads, self.model.attention_head_size)
            .transpose(1, 2)
        )
        value_layer = (
            self.model.value(hidden_states)
            .view(batch_size, -1, self.model.num_attention_heads, self.model.attention_head_size)
            .transpose(1, 2)
        )

        query_layer = query_layer.reshape(batch_size * self.model.num_attention_heads, -1, self.model.attention_head_size)
        key_layer = key_layer.reshape(batch_size * self.model.num_attention_heads, -1, self.model.attention_head_size)
        value_layer = value_layer.reshape(batch_size * self.model.num_attention_heads, -1, self.model.attention_head_size)

        # cosine attention
        attention_scores = nn.functional.normalize(query_layer, dim=-1) @ nn.functional.normalize(
            key_layer, dim=-1
        ).transpose(-2, -1)
        logit_scale = torch.clamp(self.model.logit_scale, max=math.log(1.0 / 0.01)).exp()

        logit_scale = logit_scale.expand(batch_size, self.model.num_attention_heads, 1, 1)
        logit_scale = logit_scale.reshape(batch_size * self.model.num_attention_heads, 1, 1)

        attention_scores = attention_scores * logit_scale
        relative_position_bias_table = self.model.continuous_position_bias_mlp(self.model.relative_coords_table).view(
            -1, self.model.num_attention_heads
        )
        # [window_height*window_width, window_height*window_width, num_attention_heads]
        relative_position_bias = relative_position_bias_table[self.model.relative_position_index.view(-1)].view(
            self.model.window_size[0] * self.model.window_size[1], self.model.window_size[0] * self.model.window_size[1], -1
        )
        # [num_attention_heads, window_height*window_width, window_height*window_width]
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        relative_position_bias = 16 * torch.sigmoid(relative_position_bias)

        relative_position_bias = relative_position_bias.expand(batch_size,
                                                               self.model.num_attention_heads,
                                                               self.model.window_size[0] * self.model.window_size[1],
                                                               self.model.window_size[0] * self.model.window_size[1])
        relative_position_bias = relative_position_bias.reshape(batch_size * self.model.num_attention_heads,
                                                                self.model.window_size[0] * self.model.window_size[1],
                                                                self.model.window_size[0] * self.model.window_size[1])
        attention_scores = attention_scores + relative_position_bias

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the Swin2SRModel forward() function)
            mask_shape = attention_mask.shape[0]
            attention_mask = attention_mask.unsqueeze(1).expand(mask_shape,
                                                                self.model.num_attention_heads,
                                                                dim,
                                                                dim)
            attention_mask = attention_mask.reshape(mask_shape * self.model.num_attention_heads,
                                                    dim,
                                                    dim)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.model.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.reshape(batch_size, self.model.num_attention_heads, -1, self.model.attention_head_size)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.model.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs

class Swin2SRLayer(nn.Module):
    def __init__(self, layer, device):
        super().__init__()
        self.model = layer
        self.model.attention.self = ModSwin2SRSelfAttention(self.model.attention.self, device)

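    # forward mirrors the stock Swin2SR layer, but routes window partitioning
    # through the rank-4 helpers above and attention through the patched module.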
    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: tuple[int, int],
        head_mask: torch.FloatTensor | None = None,
        output_attentions: bool | None = False,
    ) -> tuple[torch.Tensor] | tuple[torch.Tensor, torch.Tensor]:
        height, width = input_dimensions
        batch_size, _, channels = hidden_states.size()
        shortcut = hidden_states

        # pad hidden_states to multiples of window size
        hidden_states = hidden_states.view(batch_size, height, width, channels)
        hidden_states, pad_values = self.model.maybe_pad(hidden_states, height, width)
        _, height_pad, width_pad, _ = hidden_states.shape
        # cyclic shift
        if self.model.shift_size > 0:
            shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.model.shift_size, -self.model.shift_size), dims=(1, 2))
        else:
            shifted_hidden_states = hidden_states

        # partition windows
        hidden_states_windows = window_partition(shifted_hidden_states, self.model.window_size)
        hidden_states_windows = hidden_states_windows.view(-1, self.model.window_size * self.model.window_size, channels)
        attn_mask = self.model.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
        if attn_mask is not None:
            attn_mask = attn_mask.to(hidden_states_windows.device)

        attention_outputs = self.model.attention(
            hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
        )

        attention_output = attention_outputs[0]

        attention_windows = attention_output.view(-1, self.model.window_size, self.model.window_size, channels)
        shifted_windows = window_reverse(attention_windows, self.model.window_size, height_pad, width_pad)

        # reverse cyclic shift
        if self.model.shift_size > 0:
            attention_windows = torch.roll(shifted_windows, shifts=(self.model.shift_size, self.model.shift_size), dims=(1, 2))
        else:
            attention_windows = shifted_windows

        was_padded = pad_values[3] > 0 or pad_values[5] > 0
        if was_padded:
            attention_windows = attention_windows[:, :height, :width, :].contiguous()

        attention_windows = attention_windows.view(batch_size, height * width, channels)
        hidden_states = self.model.layernorm_before(attention_windows)
        hidden_states = shortcut + self.model.drop_path(hidden_states)

        layer_output = self.model.intermediate(hidden_states)
        layer_output = self.model.output(layer_output)
        layer_output = hidden_states + self.model.drop_path(self.model.layernorm_after(layer_output))

        layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
        return layer_outputs

class Swin2SRModel(nn.Module):
    def __init__(
        self,
        swin2sr,
        device,
    ) -> None:
        super().__init__()
        self.model = swin2sr
        # Swap every encoder layer for the rank-reduced wrapper defined above.
        for i, stage in enumerate(self.model.swin2sr.encoder.stages):
            for j, layer in enumerate(stage.layers):
                self.model.swin2sr.encoder.stages[i].layers[j] = Swin2SRLayer(layer, device)

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ):
        return self.model(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )

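# The export below traces the model with a fixed 1x3x256x256 input; no
# dynamic_axes are passed to torch.onnx.export, so the resulting ONNX graph
# has static input and output shapes.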
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", choices=["npu", "gpu"], default="npu")
    args = parser.parse_args()

    model = Swin2SRModel(
        Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr"),
        args.device,
    )

    inputs = {"pixel_values": torch.rand((1, 3, 256, 256))}

    with torch.no_grad():
        torch.onnx.export(model, inputs, "swin2sr.onnx", opset_version=20, do_constant_folding=True,
                          dynamo=False, input_names=list(inputs.keys()))

    onnx_model = onnx.load("swin2sr.onnx")
    simplified_onnx_model, check = simplify(onnx_model)

    if check:
        onnx.save(simplified_onnx_model, "swin2sr.onnx", save_as_external_data=True)


if __name__ == "__main__":
    main()