
Commit 95b1c18

Shiva Chilukamari authored and committed
Lintrunner done
1 parent 9d41e6f commit 95b1c18

4 files changed

Lines changed: 186 additions & 124 deletions

File tree

sam-vit-base/Olive/config.py

Lines changed: 6 additions & 6 deletions
@@ -1,11 +1,11 @@
-class config:
+class ModelConfig:
     model_name = "facebook/sam-vit-base"
     data_dir = "quantization_dataset"
+    image_dataset = "nielsr/coco-panoptic-val2017"
     ve_input_name = "pixel_values"
     ve_sample_size = 1024
     ve_channel_size = 3
-    mask_point_input_names = ['input_points', 'image_embeddings']
-    mask_point_input_shapes = [(1,1,2), (256, 64, 64)]
-
-    mask_box_input_names = ['input_boxes', 'image_embeddings']
-    mask_box_input_shapes = [(1,4), (256, 64, 64)]
+    mask_point_input_names = ("input_points", "image_embeddings")
+    mask_point_input_shapes = ((1, 1, 2), (256, 64, 64))
+    mask_box_input_names = ("input_boxes", "image_embeddings")
+    mask_box_input_shapes = ((1, 4), (256, 64, 64))
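
For context, a minimal sketch of how these ModelConfig names and shapes could be consumed when building export inputs. This is an illustration, not code from this commit; make_dummy_inputs is a hypothetical helper.

import torch

from config import ModelConfig


def make_dummy_inputs():
    # Vision encoder input: (batch, channels, size, size) from the ve_* fields.
    pixel_values = torch.randn(
        1,
        ModelConfig.ve_channel_size,
        ModelConfig.ve_sample_size,
        ModelConfig.ve_sample_size,
    )
    # Point-prompt mask decoder inputs, zipping each input name to its shape.
    point_inputs = {
        name: torch.randn(shape)
        for name, shape in zip(
            ModelConfig.mask_point_input_names, ModelConfig.mask_point_input_shapes
        )
    }
    return pixel_values, point_inputs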

sam-vit-base/Olive/model_patches.py

Lines changed: 79 additions & 63 deletions
@@ -1,11 +1,11 @@
-import torch.nn.functional as F
-from torch import nn
-import torch
 from typing import Optional
 
+import torch
+from torch import nn
+
+
 class Conv2DInplaceLinear(nn.Module):
-    """
-    An implementation of Linear / Conv1D that uses a 1x1 Conv2D op instead.
+    """An implementation of Linear / Conv1D that uses a 1x1 Conv2D op instead.
 
     The Conv2D implementation for Qualcomm DSPs is faster than the Linear/Conv1D implementation.
     """
@@ -21,7 +21,7 @@ def from_linear(mod: torch.nn.Linear | torch.nn.Conv1d):
         elif isinstance(mod, torch.nn.Conv1d):
             weight, bias = mod.weight.T, mod.bias
         else:
-            raise NotImplementedError()
+            raise NotImplementedError
 
         out_features, in_features = weight.shape
         linear = Conv2DInplaceLinear(
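
The docstring's claim that a 1x1 Conv2D can stand in for a Linear layer is easy to verify. A self-contained sketch (not from the commit) that copies a Linear's weights into a (out, in, 1, 1) conv kernel and compares outputs:

import torch
from torch import nn

linear = nn.Linear(64, 128)
conv = nn.Conv2d(64, 128, kernel_size=1)
# A 1x1 conv kernel is just the Linear weight with two singleton dims appended.
conv.weight.data.copy_(linear.weight.data[:, :, None, None])
conv.bias.data.copy_(linear.bias.data)

x = torch.randn(1, 16, 16, 64)  # channels-last activation (B, H, W, C)
y_linear = linear(x)
# Conv2d wants NCHW, so permute in and back out.
y_conv = conv(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
assert torch.allclose(y_linear, y_conv, atol=1e-5)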
@@ -76,9 +76,10 @@ def forward(self, x: torch.Tensor):
             pass
         return x
 
+
 class SplitHeadSamVisionSdpaAttention(nn.Module):
     def __init__(self, attention_block):
-        super().__init__()
+        super().__init__()
         self.out_feature, self.in_feature = (
             attention_block.qkv.weight.shape[0] // 3,
             attention_block.qkv.weight.shape[1],
@@ -93,8 +94,8 @@ def __init__(self, attention_block):
         self.use_rel_pos = attention_block.use_rel_pos
         self.model = attention_block
 
-        for chunk, projList in enumerate([self.q, self.k, self.v]):
-            projList.conv2d.weight.data.copy_(
+        for chunk, proj_list in enumerate([self.q, self.k, self.v]):
+            proj_list.conv2d.weight.data.copy_(
                 attention_block.qkv.weight[
                     (chunk) * self.out_feature : (chunk + 1) * self.out_feature,
                     :,
@@ -103,11 +104,9 @@ def __init__(self, attention_block):
                 ]
             )
 
-            assert projList.conv2d.bias is not None
-            projList.conv2d.bias.data.copy_(
-                attention_block.qkv.bias[
-                    (chunk) * self.out_feature : (chunk + 1) * self.out_feature,
-                ]
+            assert proj_list.conv2d.bias is not None
+            proj_list.conv2d.bias.data.copy_(
+                attention_block.qkv.bias[(chunk) * self.out_feature : (chunk + 1) * self.out_feature,]
             )
 
     def get_decomposed_rel_pos(
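
For reference, the copy_ loop above slices the fused qkv weight into row-thirds. A quick standalone check (an illustration, not repo code) that slicing the weight this way matches chunking the fused projection's output:

import torch
from torch import nn

dim = 32
qkv = nn.Linear(dim, dim * 3)  # fused q/k/v projection
x = torch.randn(4, dim)

fused = qkv(x)
q_ref, k_ref, v_ref = fused.chunk(3, dim=-1)

out_feature = qkv.weight.shape[0] // 3
for chunk, ref in enumerate([q_ref, k_ref, v_ref]):
    # Row-slice of the fused weight/bias is an independent projection.
    w = qkv.weight[chunk * out_feature : (chunk + 1) * out_feature, :]
    b = qkv.bias[chunk * out_feature : (chunk + 1) * out_feature]
    assert torch.allclose(x @ w.T + b, ref, atol=1e-6)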
@@ -118,10 +117,10 @@ def get_decomposed_rel_pos(
         q_size: tuple[int, int],
         k_size: tuple[int, int],
     ) -> torch.Tensor:
-        """
-        Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
+        """Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
+
         https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
-
+
         Args:
             query (`torch.Tensor`):
                 query q in the attention layer with shape (batch_size, query_height * query_width, channel).
@@ -133,55 +132,53 @@ def get_decomposed_rel_pos(
                 spatial sequence size of query q with (query_height, query_width).
             k_size (tuple):
                 spatial sequence size of key k with (key_height, key_width).
-
+
         Returns:
             decomposed_rel_pos (`torch.Tensor`):
                 decomposed relative position embeddings.
+
         """
         query_height, query_width = q_size
         key_height, key_width = k_size
         relative_position_height = self.model.get_rel_pos(query_height, key_height, rel_pos_h)
         relative_position_width = self.model.get_rel_pos(query_width, key_width, rel_pos_w)
-
+
         batch_size, _, dim = query.shape
         reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
-
+
         # Original
         # rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
         # rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
-
+
         # Using MatMul
         rel_h = reshaped_query @ relative_position_height.transpose(1, 2)
         rel_w = (reshaped_query.transpose(1, 2) @ relative_position_width.transpose(1, 2)).transpose(1, 2)
-
-        decomposed_rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
-
-        return decomposed_rel_pos
-
+
+        return rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
+
     def forward(self, hidden_states: torch.Tensor, output_attentions=None) -> tuple[torch.Tensor, torch.Tensor]:
         x = hidden_states
         batch_size, height, width, _ = x.shape
-        B, H, W = batch_size, height, width
 
         key = (
             self.k(x)
-            .reshape(B, H * W, self.num_heads, -1)
+            .reshape(batch_size, height * width, self.num_heads, -1)
            .permute(0, 2, 1, 3)
-            .reshape(B * self.num_heads, H * W, -1)
+            .reshape(batch_size * self.num_heads, height * width, -1)
         )
         value = (
             self.v(x)
-            .reshape(B, H * W, self.num_heads, -1)
+            .reshape(batch_size, height * width, self.num_heads, -1)
             .permute(0, 2, 1, 3)
-            .reshape(B * self.num_heads, H * W, -1)
+            .reshape(batch_size * self.num_heads, height * width, -1)
         )
         query = (
             self.q(x)
-            .reshape(B, H * W, self.num_heads, -1)
+            .reshape(batch_size, height * width, self.num_heads, -1)
             .permute(0, 2, 1, 3)
-            .reshape(B * self.num_heads, H * W, -1)
+            .reshape(batch_size * self.num_heads, height * width, -1)
         )
-
+
         # # qkv with shape (3, batch_size, nHead, height * width, channel)
         # qkv = (
         #     self.model.qkv(hidden_states)
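
The "# Using MatMul" rewrite above replaces the einsums with plain matmuls, which typically lower more cleanly to ONNX/QNN. A standalone equivalence check under small assumed shapes (illustration only, not repo code):

import torch

B, H, W, C, K = 2, 8, 8, 32, 8
reshaped_query = torch.randn(B, H, W, C)
rel_h_table = torch.randn(H, K, C)  # plays the role of relative_position_height
rel_w_table = torch.randn(W, K, C)  # plays the role of relative_position_width

# Height term: einsum vs batched matmul over the H dimension.
rel_h_einsum = torch.einsum("bhwc,hkc->bhwk", reshaped_query, rel_h_table)
rel_h_matmul = reshaped_query @ rel_h_table.transpose(1, 2)

# Width term: transpose H and W so the matmul batches over W, then undo it.
rel_w_einsum = torch.einsum("bhwc,wkc->bhwk", reshaped_query, rel_w_table)
rel_w_matmul = (reshaped_query.transpose(1, 2) @ rel_w_table.transpose(1, 2)).transpose(1, 2)

assert torch.allclose(rel_h_einsum, rel_h_matmul, atol=1e-5)
assert torch.allclose(rel_w_einsum, rel_w_matmul, atol=1e-5)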
@@ -195,7 +192,11 @@ def forward(self, hidden_states: torch.Tensor, output_attentions=None) -> tuple[
 
         if self.use_rel_pos:
             decomposed_rel_pos = self.get_decomposed_rel_pos(
-                query, self.model.rel_pos_h, self.model.rel_pos_w, (height, width), (height, width)
+                query,
+                self.model.rel_pos_h,
+                self.model.rel_pos_w,
+                (height, width),
+                (height, width),
             )
             decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights)
             attn_weights = attn_weights + decomposed_rel_pos
@@ -210,10 +211,9 @@ def forward(self, hidden_states: torch.Tensor, output_attentions=None) -> tuple[
         attn_output = self.model.proj(attn_output)
         return attn_output, attn_weights
 
+
 class Conv2DInplaceLinearSAMMLPBlock(nn.Module):
-    """
-    SAM MLPBlock that uses 1x1 Conv2D in place of linear layers.
-    """
+    """SAM MLPBlock that uses 1x1 Conv2D in place of linear layers."""
 
     def __init__(self, mlp_block) -> None:
         super().__init__()
@@ -231,46 +231,57 @@ def __init__(self, model):
         self.model = model
         self.attn = SplitHeadSamVisionSdpaAttention(self.model.attn)
         self.mlp = Conv2DInplaceLinearSAMMLPBlock(self.model.mlp)
-
+
     def window_partition(self, hidden_states: torch.Tensor, window_size: int) -> tuple[torch.Tensor, tuple[int, int]]:
         batch_size, height, width, channel = hidden_states.shape
-
+
         pad_h = (window_size - height % window_size) % window_size
         pad_w = (window_size - width % window_size) % window_size
 
-        # hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h), mode = 'constant', value = 0.0)
-
         c1 = torch.zeros((batch_size, pad_h, width, channel))
         c2 = torch.zeros((batch_size, height + pad_h, pad_w, channel))
 
-        hidden_states = torch.concatenate((hidden_states, c1), axis = 1)
-        hidden_states = torch.concatenate((hidden_states, c2), axis = 2)
-
+        hidden_states = torch.concatenate((hidden_states, c1), axis=1)
+        hidden_states = torch.concatenate((hidden_states, c2), axis=2)
+
         pad_height, pad_width = height + pad_h, width + pad_w
-
+
         hidden_states = hidden_states.reshape(
-            pad_height // window_size, window_size, pad_width // window_size, window_size, channel
+            pad_height // window_size,
+            window_size,
+            pad_width // window_size,
+            window_size,
+            channel,
         )
         windows = hidden_states.permute(0, 2, 1, 3, 4).contiguous().reshape(-1, window_size, window_size, channel)
         return windows, (pad_height, pad_width)
-
+
     def window_unpartition(
-        self, windows: torch.Tensor, window_size: int, padding_shape: tuple[int, int], original_shape: tuple[int, int]
+        self,
+        windows: torch.Tensor,
+        window_size: int,
+        padding_shape: tuple[int, int],
+        original_shape: tuple[int, int],
     ) -> torch.Tensor:
         pad_height, pad_width = padding_shape
         height, width = original_shape
         batch_size = windows.shape[0] // (pad_height * pad_width // window_size // window_size)
         hidden_states = windows.reshape(
-            pad_height // window_size, pad_width // window_size, window_size, window_size, -1
-        )
-        hidden_states = (
-            hidden_states.permute(0, 2, 1, 3, 4).contiguous().reshape(batch_size, pad_height, pad_width, -1)
+            pad_height // window_size,
+            pad_width // window_size,
+            window_size,
+            window_size,
+            -1,
         )
-
-        hidden_states = hidden_states[:, :height, :width, :].contiguous()
-        return hidden_states
-
-    def forward(self, hidden_states: torch.Tensor, output_attentions: Optional[bool] = False,) -> tuple[torch.FloatTensor]:
+        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).contiguous().reshape(batch_size, pad_height, pad_width, -1)
+
+        return hidden_states[:, :height, :width, :].contiguous()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        output_attentions: Optional[bool] = False,
+    ) -> tuple[torch.FloatTensor]:
         residual = hidden_states
         hidden_states = self.model.layer_norm1(hidden_states)
         # Window partition
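
window_partition pads by concatenating zero tensors rather than calling F.pad (the commented-out line this commit removes); the two are numerically identical, as this standalone sketch (not repo code) checks:

import torch
import torch.nn.functional as F

batch_size, height, width, channel = 1, 5, 7, 4
pad_h, pad_w = 3, 1
hidden_states = torch.randn(batch_size, height, width, channel)

# F.pad pads dims last-to-first: (C_lo, C_hi, W_lo, W_hi, H_lo, H_hi).
padded_ref = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h))

# Same padding built from explicit zero blocks, as in window_partition above.
c1 = torch.zeros((batch_size, pad_h, width, channel))
c2 = torch.zeros((batch_size, height + pad_h, pad_w, channel))
padded = torch.concatenate((hidden_states, c1), axis=1)
padded = torch.concatenate((padded, c2), axis=2)

assert torch.equal(padded_ref, padded)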
@@ -283,13 +294,16 @@ def forward(self, hidden_states: torch.Tensor, output_attentions: Optional[bool]
         )
         # Reverse window partition
         if self.model.window_size > 0:
-            hidden_states = self.window_unpartition(hidden_states, self.model.window_size, padding_shape, (height, width))
+            hidden_states = self.window_unpartition(
+                hidden_states, self.model.window_size, padding_shape, (height, width)
+            )
 
         hidden_states = residual + hidden_states
         layernorm_output = self.model.layer_norm2(hidden_states)
         hidden_states = hidden_states + self.mlp(layernorm_output)
         return hidden_states, attn_weights
 
+
 class ModSamModel(nn.Module):
     def __init__(self, model):
         super().__init__()
@@ -298,7 +312,7 @@ def __init__(self, model):
             self.model.vision_encoder.layers[i] = ModSamVisionLayer(self.model.vision_encoder.layers[i])
 
     def forward(self, pixel_values, input_points):
-        return self.model(pixel_values = pixel_values, input_points = input_points)
+        return self.model(pixel_values=pixel_values, input_points=input_points)
 
 
 class ModSamVisionEncoder(nn.Module):
class ModSamVisionEncoder(nn.Module):
@@ -307,20 +321,22 @@ def __init__(self, model):
307321
self.vision_encoder = ModSamModel(model).model.vision_encoder
308322

309323
def forward(self, pixel_values):
310-
return self.vision_encoder(pixel_values = pixel_values)
324+
return self.vision_encoder(pixel_values=pixel_values)
325+
311326

312327
class ModSamMaskPointDecoder(nn.Module):
313328
def __init__(self, model):
314329
super().__init__()
315330
self.model = model
316331

317332
def forward(self, input_points, image_embeddings):
318-
return self.model(input_points = input_points, input_boxes = input_boxes, image_embeddings = image_embeddings)
333+
return self.model(input_points=input_points, image_embeddings=image_embeddings)
334+
319335

320336
class ModSamMaskBoxDecoder(nn.Module):
321337
def __init__(self, model):
322338
super().__init__()
323339
self.model = model
324340

325341
def forward(self, input_boxes, image_embeddings):
326-
return self.model(input_boxes = input_boxes, image_embeddings = image_embeddings)
342+
return self.model(input_boxes=input_boxes, image_embeddings=image_embeddings)
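
A sketch of how these wrapper modules might be instantiated around a Hugging Face SamModel for export. This is illustrative only; the exact Olive wiring is not shown in this commit.

import torch
from transformers import SamModel

from model_patches import ModSamVisionEncoder

# Wrap the pretrained SAM so its vision layers get the patched attention/MLP blocks.
sam = SamModel.from_pretrained("facebook/sam-vit-base")
encoder = ModSamVisionEncoder(sam).eval()

# Input shape follows ModelConfig: (1, ve_channel_size, ve_sample_size, ve_sample_size).
pixel_values = torch.randn(1, 3, 1024, 1024)
with torch.no_grad():
    image_embeddings = encoder(pixel_values)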
Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
-transformers
 datasets
 olive-ai
 onnx
-torchvision
 torch
+torchvision
+transformers
