# Licensed under the MIT License.
# --------------------------------------------------------------------------

6+ import argparse
67import os
78
89import onnx
@@ -99,7 +100,7 @@ def forward(self, x: torch.Tensor):
99100
100101
101102class SplitHeadSAMEncoderAttention (nn .Module ):
102- """SAM Attention block with the following modifications necessary to run on QNN.
103+ """SAM Attention block with the following modifications necessary to run on QNN NPU .
103104
104105 * Heads are split into separate ops, rather than all heads running in a single op.
105106 * QKV is unpacked from 1 tensor into 3 tensors.
@@ -176,6 +177,53 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
176177
177178 return self .proj (x )
178179
class GpuSAMEncoderAttention(nn.Module):
    """SAM attention block with the modifications necessary to run on the QNN GPU.

    Unlike the NPU variant, heads are kept in fused ops; the head dimension is
    folded into the batch dimension around scaled-dot-product attention.
    """

    def __init__(self, attention_block) -> None:
        super().__init__()
        # Reuse the projections and configuration of the wrapped attention block.
        self.qkv = attention_block.qkv
        self.proj = attention_block.proj
        self.num_heads = attention_block.num_heads
        self.q_pool = attention_block.q_pool

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, height, width, _ = x.shape
        tokens = height * width

        packed = self.qkv(x)
        if self.num_heads == 1:
            # (B, H*W, 3, C) -> three tensors of shape (B, H*W, C)
            q, k, v = packed.reshape(batch, tokens, 3, -1).unbind(2)
        else:
            # (B, H*W, 3, nHead, C) -> three tensors of shape (B, H*W, nHead, C)
            q, k, v = packed.reshape(batch, tokens, 3, self.num_heads, -1).unbind(2)

        def fold_heads(t):
            # Fold heads into batch: (..., tokens, nHead*headDim) -> (B*nHead, tokens, headDim).
            # Reads `tokens` at call time so it picks up the post-pooling value for q.
            return (
                t.reshape(batch, tokens, self.num_heads, -1)
                .permute(0, 2, 1, 3)
                .reshape(batch * self.num_heads, tokens, -1)
            )

        k = fold_heads(k)
        v = fold_heads(v)

        # Q pooling (for downsample at stage changes).
        if self.q_pool:
            q = do_pool(q.reshape(batch, height, width, -1), self.q_pool)
            height, width = q.shape[1:3]  # downsampled spatial shape
            tokens = height * width

        q = fold_heads(q)

        attn = F.scaled_dot_product_attention(q, k, v)

        # Undo the head folding and restore the (B, H, W, C) layout.
        attn = attn.reshape(batch, self.num_heads, tokens, -1).transpose(1, 2)
        attn = attn.reshape(batch, height, width, -1)

        return self.proj(attn)
226+
179227
180228class Conv2DInplaceLinearSAMTransformerMLPBlock (nn .Module ):
181229 """SAM MLPBlock that uses 1x1 Conv2D in place of linear layers."""
@@ -240,11 +288,14 @@ def window_unpartition(x, window_size, pad_hw, hw):
240288
241289
242290class ModMultiScaleBlock (nn .Module ):
243- def __init__ (self , block ):
291+ def __init__ (self , block , device ):
244292 super ().__init__ ()
245293 self .model = block
246294 self .model .mlp = Conv2DInplaceLinearSAMTransformerMLPBlock (self .model .mlp )
247- self .model .attn = SplitHeadSAMEncoderAttention (self .model .attn )
295+ if device == "npu" :
296+ self .model .attn = SplitHeadSAMEncoderAttention (self .model .attn )
297+ else :
298+ self .model .attn = GpuSAMEncoderAttention (self .model .attn )
248299
249300 def forward (self , x : torch .Tensor ) -> torch .Tensor :
250301 shortcut = x # B, h, w, C
@@ -286,6 +337,7 @@ class SAM2Encoder(nn.Module):
286337 def __init__ (
287338 self ,
288339 sam2 ,
340+ device ,
289341 ) -> None :
290342 super ().__init__ ()
291343 self .model = sam2
@@ -295,7 +347,7 @@ def __init__(
295347 (256 , 64 , 64 ),
296348 ]
297349 for i , block in enumerate (self .model .image_encoder .trunk .blocks ):
298- self .model .image_encoder .trunk .blocks [i ] = ModMultiScaleBlock (block )
350+ self .model .image_encoder .trunk .blocks [i ] = ModMultiScaleBlock (block , device )
299351
300352 def forward (self , input : torch .Tensor ) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
301353 """Run SAM2 input encoder and returns image_embeddings, high_res_features1, high_res_features2.
@@ -352,7 +404,7 @@ def _embed_masks(
352404 ) * self .prompt_encoder .no_mask_embed .weight .reshape (1 , - 1 , 1 , 1 )
353405 return mask_embedding
354406
355-
407+
356408 def _embed_points (
357409 self ,
358410 points : torch .Tensor ,
@@ -364,7 +416,7 @@ def _embed_points(
364416 padding_label = - torch .ones ((labels .shape [0 ], 1 ), device = labels .device )
365417 points = torch .cat ([points , padding_point ], dim = 1 )
366418 labels = torch .cat ([labels , padding_label ], dim = 1 )
367-
419+
368420 point_embedding = self .prompt_encoder .pe_layer .forward_with_coords (
369421 points , self .prompt_encoder .input_image_size
370422 )
@@ -395,7 +447,7 @@ def _embed_points(
395447 point_embedding ,
396448 )
397449 return point_embedding
398-
450+
399451 def forward (
400452 self ,
401453 image_embeddings : torch .Tensor , # [1,256,64,64]
@@ -432,7 +484,7 @@ def forward(
432484
433485 sparse_embedding = self ._embed_points (point_coords , point_labels )
434486 dense_embedding = self ._embed_masks (mask_input , has_mask_input )
435-
487+
436488 low_res_masks , iou_predictions , _ , _ = self .mask_decoder .predict_masks (
437489 image_embeddings = image_embeddings ,
438490 image_pe = self .prompt_encoder .get_dense_pe (),
@@ -442,7 +494,7 @@ def forward(
442494 high_res_features = [high_res_features1 , high_res_features2 ],
443495 )
444496 masks = F .interpolate (low_res_masks , size = (1024 , 1024 ), mode = "bilinear" , align_corners = False )
445- return masks , iou_predictions , low_res_masks
497+ return masks , iou_predictions , low_res_masks
446498
447499
448500model_weights_url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt"
@@ -452,16 +504,20 @@ def forward(
452504
453505
454506def main ():
507+ parser = argparse .ArgumentParser ()
508+ parser .add_argument ("--device" , choices = ["npu" , "gpu" ], default = "npu" )
509+ args = parser .parse_args ()
510+
455511 download_file (model_weights_url , checkpoint )
456512 download_file (model_config_url , model_cfg )
457513
458514 GlobalHydra .instance ().clear ()
459515 initialize (config_path = "./" , job_name = "sam2_inference" , version_base = None )
460516
461517 sam2_model = build_sam2 (model_cfg , checkpoint , device = "cpu" )
462- encoder = SAM2Encoder (sam2_model )
518+ encoder = SAM2Encoder (sam2_model , args . device )
463519 decoder = SAM2Decoder (sam2_model )
464-
520+
465521 en_inputs = {"input" : torch .rand ((1 , 3 , 1024 , 1024 ))}
466522
467523 with torch .no_grad ():
0 commit comments