@@ -1,11 +1,11 @@
-import torch.nn.functional as F
-from torch import nn
-import torch
 from typing import Optional
 
+import torch
+from torch import nn
+
+
 class Conv2DInplaceLinear(nn.Module):
-    """
-    An implementation of Linear / Conv1D that uses a 1x1 Conv2D op instead.
+    """An implementation of Linear / Conv1D that uses a 1x1 Conv2D op instead.
 
     The Conv2D implementation for Qualcomm DSPs is faster than the Linear/Conv1D implementation.
     """
@@ -21,7 +21,7 @@ def from_linear(mod: torch.nn.Linear | torch.nn.Conv1d):
         elif isinstance(mod, torch.nn.Conv1d):
             weight, bias = mod.weight.T, mod.bias
         else:
-            raise NotImplementedError()
+            raise NotImplementedError
 
         out_features, in_features = weight.shape
         linear = Conv2DInplaceLinear(
@@ -76,9 +76,10 @@ def forward(self, x: torch.Tensor):
             pass
         return x
 
+
 class SplitHeadSamVisionSdpaAttention(nn.Module):
     def __init__(self, attention_block):
-        super().__init__()
+        super().__init__()
         self.out_feature, self.in_feature = (
             attention_block.qkv.weight.shape[0] // 3,
             attention_block.qkv.weight.shape[1],
@@ -93,8 +94,8 @@ def __init__(self, attention_block):
         self.use_rel_pos = attention_block.use_rel_pos
         self.model = attention_block
 
-        for chunk, projList in enumerate([self.q, self.k, self.v]):
-            projList.conv2d.weight.data.copy_(
+        for chunk, proj_list in enumerate([self.q, self.k, self.v]):
+            proj_list.conv2d.weight.data.copy_(
                 attention_block.qkv.weight[
                     (chunk) * self.out_feature : (chunk + 1) * self.out_feature,
                     :,
@@ -103,11 +104,9 @@ def __init__(self, attention_block):
                 ]
             )
 
-            assert projList.conv2d.bias is not None
-            projList.conv2d.bias.data.copy_(
-                attention_block.qkv.bias[
-                    (chunk) * self.out_feature : (chunk + 1) * self.out_feature,
-                ]
+            assert proj_list.conv2d.bias is not None
+            proj_list.conv2d.bias.data.copy_(
+                attention_block.qkv.bias[(chunk) * self.out_feature : (chunk + 1) * self.out_feature]
             )
 
     def get_decomposed_rel_pos(
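For reference, the slicing in `__init__` above splits a fused qkv projection into three equal row-chunks, one per projection. A standalone sketch of the same idea (the embedding size is an illustrative assumption):

```python
import torch

dim = 64  # illustrative embedding size
qkv = torch.nn.Linear(dim, 3 * dim)  # fused projection: rows [q | k | v]
x = torch.randn(5, dim)

# Slice the fused weight/bias into per-projection chunks, as the loop above does.
q_weight = qkv.weight[0 * dim : 1 * dim, :]
q_bias = qkv.bias[0 * dim : 1 * dim]

q = x @ q_weight.T + q_bias
fused = qkv(x)
assert torch.allclose(q, fused[:, :dim], atol=1e-5)
```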
@@ -118,10 +117,10 @@ def get_decomposed_rel_pos(
         q_size: tuple[int, int],
         k_size: tuple[int, int],
     ) -> torch.Tensor:
-        """
-        Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
+        """Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
+
         https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
-
+
         Args:
             query (`torch.Tensor`):
                 query q in the attention layer with shape (batch_size, query_height * query_width, channel).
@@ -133,55 +132,53 @@ def get_decomposed_rel_pos(
                 spatial sequence size of query q with (query_height, query_width).
             k_size (tuple):
                 spatial sequence size of key k with (key_height, key_width).
-
+
         Returns:
             decomposed_rel_pos (`torch.Tensor`):
                 decomposed relative position embeddings.
+
         """
         query_height, query_width = q_size
         key_height, key_width = k_size
         relative_position_height = self.model.get_rel_pos(query_height, key_height, rel_pos_h)
         relative_position_width = self.model.get_rel_pos(query_width, key_width, rel_pos_w)
-
+
         batch_size, _, dim = query.shape
         reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
-
+
         # Original
         # rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
         # rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
-
+
         # Using MatMul
         rel_h = reshaped_query @ relative_position_height.transpose(1, 2)
         rel_w = (reshaped_query.transpose(1, 2) @ relative_position_width.transpose(1, 2)).transpose(1, 2)
-
-        decomposed_rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
-
-        return decomposed_rel_pos
-
+
+        return rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
+
     def forward(self, hidden_states: torch.Tensor, output_attentions=None) -> tuple[torch.Tensor, torch.Tensor]:
         x = hidden_states
         batch_size, height, width, _ = x.shape
-        B, H, W = batch_size, height, width
 
         key = (
             self.k(x)
-            .reshape(B, H * W, self.num_heads, -1)
+            .reshape(batch_size, height * width, self.num_heads, -1)
             .permute(0, 2, 1, 3)
-            .reshape(B * self.num_heads, H * W, -1)
+            .reshape(batch_size * self.num_heads, height * width, -1)
         )
         value = (
             self.v(x)
-            .reshape(B, H * W, self.num_heads, -1)
+            .reshape(batch_size, height * width, self.num_heads, -1)
             .permute(0, 2, 1, 3)
-            .reshape(B * self.num_heads, H * W, -1)
+            .reshape(batch_size * self.num_heads, height * width, -1)
         )
         query = (
             self.q(x)
-            .reshape(B, H * W, self.num_heads, -1)
+            .reshape(batch_size, height * width, self.num_heads, -1)
             .permute(0, 2, 1, 3)
-            .reshape(B * self.num_heads, H * W, -1)
+            .reshape(batch_size * self.num_heads, height * width, -1)
         )
-
+
         # # qkv with shape (3, batch_size, nHead, height * width, channel)
         # qkv = (
         #     self.model.qkv(hidden_states)
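The `# Using MatMul` rewrite in the hunk above is a pure re-expression of the original einsums (batched matmul is typically better supported on DSP backends than `torch.einsum`). A standalone equivalence check, with shapes chosen purely for illustration:

```python
import torch

# Illustrative shapes: batch, height, width, channel, relative-position key size.
b, h, w, c, k = 2, 7, 7, 32, 7
reshaped_query = torch.randn(b, h, w, c)
rel_pos_h = torch.randn(h, k, c)
rel_pos_w = torch.randn(w, k, c)

# Original einsum formulation.
rel_h_ref = torch.einsum("bhwc,hkc->bhwk", reshaped_query, rel_pos_h)
rel_w_ref = torch.einsum("bhwc,wkc->bhwk", reshaped_query, rel_pos_w)

# MatMul formulation from the diff: batched matmul broadcasting over (b, h) / (b, w).
rel_h = reshaped_query @ rel_pos_h.transpose(1, 2)
rel_w = (reshaped_query.transpose(1, 2) @ rel_pos_w.transpose(1, 2)).transpose(1, 2)

assert torch.allclose(rel_h_ref, rel_h, atol=1e-5)
assert torch.allclose(rel_w_ref, rel_w, atol=1e-5)
```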
@@ -195,7 +192,11 @@ def forward(self, hidden_states: torch.Tensor, output_attentions=None) -> tuple[
 
         if self.use_rel_pos:
             decomposed_rel_pos = self.get_decomposed_rel_pos(
-                query, self.model.rel_pos_h, self.model.rel_pos_w, (height, width), (height, width)
+                query,
+                self.model.rel_pos_h,
+                self.model.rel_pos_w,
+                (height, width),
+                (height, width),
             )
             decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights)
             attn_weights = attn_weights + decomposed_rel_pos
@@ -210,10 +211,9 @@ def forward(self, hidden_states: torch.Tensor, output_attentions=None) -> tuple[
         attn_output = self.model.proj(attn_output)
         return attn_output, attn_weights
 
+
 class Conv2DInplaceLinearSAMMLPBlock(nn.Module):
-    """
-    SAM MLPBlock that uses 1x1 Conv2D in place of linear layers.
-    """
+    """SAM MLPBlock that uses 1x1 Conv2D in place of linear layers."""
 
     def __init__(self, mlp_block) -> None:
         super().__init__()
@@ -231,46 +231,57 @@ def __init__(self, model):
         self.model = model
         self.attn = SplitHeadSamVisionSdpaAttention(self.model.attn)
         self.mlp = Conv2DInplaceLinearSAMMLPBlock(self.model.mlp)
-
+
     def window_partition(self, hidden_states: torch.Tensor, window_size: int) -> tuple[torch.Tensor, tuple[int, int]]:
         batch_size, height, width, channel = hidden_states.shape
-
+
         pad_h = (window_size - height % window_size) % window_size
         pad_w = (window_size - width % window_size) % window_size
 
-        # hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h), mode = 'constant', value = 0.0)
-
         c1 = torch.zeros((batch_size, pad_h, width, channel))
         c2 = torch.zeros((batch_size, height + pad_h, pad_w, channel))
 
-        hidden_states = torch.concatenate((hidden_states, c1), axis=1)
-        hidden_states = torch.concatenate((hidden_states, c2), axis=2)
-
+        hidden_states = torch.concatenate((hidden_states, c1), axis=1)
+        hidden_states = torch.concatenate((hidden_states, c2), axis=2)
+
         pad_height, pad_width = height + pad_h, width + pad_w
-
+
         hidden_states = hidden_states.reshape(
-            pad_height // window_size, window_size, pad_width // window_size, window_size, channel
+            pad_height // window_size,
+            window_size,
+            pad_width // window_size,
+            window_size,
+            channel,
         )
         windows = hidden_states.permute(0, 2, 1, 3, 4).contiguous().reshape(-1, window_size, window_size, channel)
         return windows, (pad_height, pad_width)
-
+
     def window_unpartition(
-        self, windows: torch.Tensor, window_size: int, padding_shape: tuple[int, int], original_shape: tuple[int, int]
+        self,
+        windows: torch.Tensor,
+        window_size: int,
+        padding_shape: tuple[int, int],
+        original_shape: tuple[int, int],
     ) -> torch.Tensor:
         pad_height, pad_width = padding_shape
         height, width = original_shape
         batch_size = windows.shape[0] // (pad_height * pad_width // window_size // window_size)
         hidden_states = windows.reshape(
-            pad_height // window_size, pad_width // window_size, window_size, window_size, -1
-        )
-        hidden_states = (
-            hidden_states.permute(0, 2, 1, 3, 4).contiguous().reshape(batch_size, pad_height, pad_width, -1)
+            pad_height // window_size,
+            pad_width // window_size,
+            window_size,
+            window_size,
+            -1,
         )
-
-        hidden_states = hidden_states[:, :height, :width, :].contiguous()
-        return hidden_states
-
-    def forward(self, hidden_states: torch.Tensor, output_attentions: Optional[bool] = False,) -> tuple[torch.FloatTensor]:
+        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).contiguous().reshape(batch_size, pad_height, pad_width, -1)
+
+        return hidden_states[:, :height, :width, :].contiguous()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        output_attentions: Optional[bool] = False,
+    ) -> tuple[torch.FloatTensor]:
         residual = hidden_states
         hidden_states = self.model.layer_norm1(hidden_states)
         # Window partition
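In the hunks above, `window_partition` pads with explicit `torch.zeros` + `torch.concatenate` (in place of the removed `F.pad`, which some DSP toolchains handle poorly), and `window_unpartition` inverts it. Note that the partition reshape carries no batch dimension, so it effectively assumes `batch_size == 1`. A round-trip sketch under that assumption (`layer` is a hypothetical `ModSamVisionLayer` instance; sizes are illustrative):

```python
import torch

# layer = ModSamVisionLayer(...)  # hypothetical: wraps an existing SAM vision layer
x = torch.randn(1, 10, 10, 8)  # (batch=1, height, width, channel)

# Height/width 10 pad up to 12 with window_size=4, giving 3 * 3 = 9 windows of 4x4.
windows, padding_shape = layer.window_partition(x, window_size=4)
assert windows.shape == (9, 4, 4, 8)

# Unpartition restores the original tensor exactly; the zero padding is sliced away.
restored = layer.window_unpartition(windows, 4, padding_shape, (10, 10))
assert torch.equal(restored, x)
```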
@@ -283,13 +294,16 @@ def forward(self, hidden_states: torch.Tensor, output_attentions: Optional[bool]
         )
         # Reverse window partition
         if self.model.window_size > 0:
-            hidden_states = self.window_unpartition(hidden_states, self.model.window_size, padding_shape, (height, width))
+            hidden_states = self.window_unpartition(
+                hidden_states, self.model.window_size, padding_shape, (height, width)
+            )
 
         hidden_states = residual + hidden_states
         layernorm_output = self.model.layer_norm2(hidden_states)
         hidden_states = hidden_states + self.mlp(layernorm_output)
         return hidden_states, attn_weights
 
+
 class ModSamModel(nn.Module):
     def __init__(self, model):
         super().__init__()
@@ -298,7 +312,7 @@ def __init__(self, model):
             self.model.vision_encoder.layers[i] = ModSamVisionLayer(self.model.vision_encoder.layers[i])
 
     def forward(self, pixel_values, input_points):
-        return self.model(pixel_values=pixel_values, input_points=input_points)
+        return self.model(pixel_values=pixel_values, input_points=input_points)
 
 
 class ModSamVisionEncoder(nn.Module):
@@ -307,20 +321,22 @@ def __init__(self, model):
         self.vision_encoder = ModSamModel(model).model.vision_encoder
 
     def forward(self, pixel_values):
-        return self.vision_encoder(pixel_values=pixel_values)
+        return self.vision_encoder(pixel_values=pixel_values)
+
 
 class ModSamMaskPointDecoder(nn.Module):
     def __init__(self, model):
         super().__init__()
         self.model = model
 
     def forward(self, input_points, image_embeddings):
-        return self.model(input_points=input_points, input_boxes=input_boxes, image_embeddings=image_embeddings)
+        return self.model(input_points=input_points, image_embeddings=image_embeddings)
+
 
 class ModSamMaskBoxDecoder(nn.Module):
     def __init__(self, model):
         super().__init__()
         self.model = model
 
     def forward(self, input_boxes, image_embeddings):
-        return self.model(input_boxes=input_boxes, image_embeddings=image_embeddings)
+        return self.model(input_boxes=input_boxes, image_embeddings=image_embeddings)
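A sketch of how these wrappers might be driven end to end with Hugging Face's SAM; the checkpoint name and input resolution are assumptions, not part of this file:

```python
import torch
from transformers import SamModel

# Assumed checkpoint; any SamModel with the same architecture should work.
sam = SamModel.from_pretrained("facebook/sam-vit-base")

# Swap the stock vision layers for the Conv2D-based ones defined above.
encoder = ModSamVisionEncoder(sam)

pixel_values = torch.randn(1, 3, 1024, 1024)  # SAM's standard input resolution
with torch.no_grad():
    vision_output = encoder(pixel_values)  # image embeddings for the mask decoders
```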