@@ -290,18 +290,18 @@ def __init__(
         super().__init__()
         self.model = sam2
         self._bb_feat_sizes = [
-            (256, 256),
-            (128, 128),
-            (64, 64),
+            (32, 256, 256),
+            (64, 128, 128),
+            (256, 64, 64),
         ]
         for i, block in enumerate(self.model.image_encoder.trunk.blocks):
             self.model.image_encoder.trunk.blocks[i] = ModMultiScaleBlock(block)
 
-    def forward(self, Image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Run SAM2 Image encoder and returns image_embeddings, high_res_features1, high_res_features2.
+    def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Run the SAM2 image encoder and return image_embeddings, high_res_features1, high_res_features2.
 
         Args:
-            Image:
+            input:
                 Raw floating point pixel values for encoder consumption.
                 3-channel color space: RGB, range [0, 1]
 
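For reference, a minimal preprocessing sketch that produces an encoder input in this format; the helper below is illustrative only and not part of this commit:

    import torch
    import torch.nn.functional as F

    def preprocess(image_uint8: torch.Tensor) -> torch.Tensor:
        """Turn an HWC uint8 RGB image into the (1, 3, 1024, 1024) float input in [0, 1]."""
        x = image_uint8.permute(2, 0, 1).float() / 255.0  # HWC -> CHW, scale to [0, 1]
        return F.interpolate(x.unsqueeze(0), size=(1024, 1024), mode="bilinear", align_corners=False)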
@@ -311,15 +311,14 @@ def forward(self, Image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torc
             high_res_features2: Shape (1, 64, 128, 128)
 
         """
-        # x = self.normalize(Image)
-        x = Image
+        x = input
         backbone_out = self.model.forward_image(x)
         _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
         # Add no_mem_embed, which is added to the lowest rez feat. map during training on videos
        if self.model.directly_add_no_mem_embed:
             vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
         feats = [
-            feat.permute(1, 2, 0).view(1, -1, *feat_size)
+            feat.permute(1, 2, 0).view(1, *feat_size)
             for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
         ][::-1]
         image_embeddings = feats[2]
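The `_bb_feat_sizes` entries now carry the channel count explicitly as (C, H, W), so the reshape no longer relies on an inferred -1 dimension. A standalone sketch of what this line does for the lowest-resolution level (dummy tensor; shapes taken from the docstring above):

    import torch

    feat = torch.rand(64 * 64, 1, 256)                 # vision_feats entry: (H*W, B, C) tokens
    feat_size = (256, 64, 64)                          # (C, H, W), as in _bb_feat_sizes
    fmap = feat.permute(1, 2, 0).view(1, *feat_size)   # -> (1, 256, 64, 64) feature map
    assert fmap.shape == (1, 256, 64, 64)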
@@ -342,13 +341,70 @@ def __init__(self, sam2) -> None:
         self.mask_decoder = self.model.sam_mask_decoder
         self.prompt_encoder = self.model.sam_prompt_encoder
 
+    def _embed_masks(
+        self, input_mask: torch.Tensor, has_mask_input: torch.Tensor
+    ) -> torch.Tensor:
+        """Embeds mask prompts, falling back to no_mask_embed when has_mask_input is 0."""
+        mask_embedding = has_mask_input * self.prompt_encoder.mask_downscaling(
+            input_mask
+        )
+        mask_embedding = mask_embedding + (
+            1 - has_mask_input
+        ) * self.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
+        return mask_embedding
+
+    def _embed_points(
+        self,
+        points: torch.Tensor,
+        labels: torch.Tensor,
+    ) -> torch.Tensor:
+        """Embeds point prompts."""
+        points = points + 0.5  # Shift to center of pixel
+        padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
+        padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
+        points = torch.cat([points, padding_point], dim=1)
+        labels = torch.cat([labels, padding_label], dim=1)
+
+        point_embedding = self.prompt_encoder.pe_layer.forward_with_coords(
+            points, self.prompt_encoder.input_image_size
+        )
+
+        point_embedding = torch.where(
+            (labels == -1).unsqueeze(-1),
+            torch.zeros_like(point_embedding) + self.prompt_encoder.not_a_point_embed.weight,
+            point_embedding,
+        )
+        point_embedding = torch.where(
+            (labels == 0).unsqueeze(-1),
+            point_embedding + self.prompt_encoder.point_embeddings[0].weight,
+            point_embedding,
+        )
+        point_embedding = torch.where(
+            (labels == 1).unsqueeze(-1),
+            point_embedding + self.prompt_encoder.point_embeddings[1].weight,
+            point_embedding,
+        )
+        point_embedding = torch.where(
+            (labels == 2).unsqueeze(-1),
+            point_embedding + self.prompt_encoder.point_embeddings[2].weight,
+            point_embedding,
+        )
+        point_embedding = torch.where(
+            (labels == 3).unsqueeze(-1),
+            point_embedding + self.prompt_encoder.point_embeddings[3].weight,
+            point_embedding,
+        )
+        return point_embedding
+
     def forward(
         self,
         image_embeddings: torch.Tensor,  # [1, 256, 64, 64]
         high_res_features1: torch.Tensor,  # [1, 32, 256, 256]
         high_res_features2: torch.Tensor,  # [1, 64, 128, 128]
-        unnorm_coords: torch.Tensor,  # [num_labels, num_points, 2]
-        labels: torch.Tensor,  # [num_labels, num_points]
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+        point_coords: torch.Tensor,  # [num_labels, num_points, 2]
+        point_labels: torch.Tensor,  # [num_labels, num_points]
+        mask_input: torch.Tensor,  # [1, 1, 256, 256]
+        has_mask_input: torch.Tensor,  # [1]
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Run SAM2 lightweight decoder and return generated mask for given points.
 
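Why these helpers replace the stock prompt-encoder call (my reading of the change, not stated in the commit): the original prompt encoder branches in Python on which prompts are present, which does not trace into a static ONNX graph; `torch.where` and the `has_mask_input` multiplier keep both alternatives in the graph and let a runtime flag select between them. The pattern in isolation:

    import torch

    def select(flag: torch.Tensor, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Export-friendly "a if flag else b": pure arithmetic, no data-dependent branching
        return flag * a + (1 - flag) * b

    out = select(torch.zeros(1), torch.ones(3), torch.full((3,), 5.0))
    assert torch.equal(out, torch.full((3,), 5.0))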
@@ -359,22 +415,24 @@ def forward(
                 First set of high-resolution features.
             high_res_features2: torch.Tensor of shape [1, high_res_2_dim, high_res_2_size, high_res_2_size]
                 Second set of high-resolution features.
-            unnorm_coords: torch.Tensor of shape [1, k, 2]
+            point_coords: torch.Tensor of shape [1, k, 2]
                 Point coordinates from the input image for segmentation, mapped to the resized image
-            labels: torch.Tensor of shape [1, k]
+            point_labels: torch.Tensor of shape [1, k]
                 Point labels to select/de-select a given point for segmentation,
                 e.g. the corresponding value is 1 if this point is to be included, otherwise 0
-
+            mask_input: torch.Tensor of shape [1, 1, 256, 256]
+                Mask corresponding to the focus area
+            has_mask_input: torch.Tensor of shape [1]
+                1 if a mask is provided, else 0
+
         Returns:
-            masks: torch.Tensor of shape [1, 1, 256, 256]
+            masks: torch.Tensor of shape [1, 1, 1024, 1024]
+                Mask logits upsampled to the 1024x1024 encoder input resolution
             scores: torch.Tensor of shape [1, 1]
+            low_res_masks: torch.Tensor of shape [1, 1, 256, 256]
 
         """
-        sparse_embedding, dense_embedding = self.prompt_encoder(
-            points=(unnorm_coords, labels),
-            boxes=None,
-            masks=None,
-        )
+        sparse_embedding = self._embed_points(point_coords, point_labels)
+        dense_embedding = self._embed_masks(mask_input, has_mask_input)
+
         low_res_masks, iou_predictions, _, _ = self.mask_decoder.predict_masks(
             image_embeddings=image_embeddings,
             image_pe=self.prompt_encoder.get_dense_pe(),
@@ -383,7 +441,8 @@ def forward(
             repeat_image=False,
             high_res_features=[high_res_features1, high_res_features2],
         )
-        return low_res_masks, iou_predictions
+        masks = F.interpolate(low_res_masks, size=(1024, 1024), mode="bilinear", align_corners=False)
+        return masks, iou_predictions, low_res_masks
 
 
 model_weights_url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt"
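Note that the `F.interpolate` call assumes `torch.nn.functional` is imported as `F` elsewhere in the file (the import is not part of this hunk). Downstream, SAM's convention is to threshold the upsampled logits at 0 to get a binary mask; a quick sketch with a dummy tensor:

    import torch

    masks = torch.randn(1, 1, 1024, 1024)   # decoder "masks" output (dummy logits here)
    binary = (masks > 0.0)[0, 0]             # bool (1024, 1024) segmentation mask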
@@ -402,26 +461,31 @@ def main():
     sam2_model = build_sam2(model_cfg, checkpoint, device="cpu")
     encoder = SAM2Encoder(sam2_model)
     decoder = SAM2Decoder(sam2_model)
-    en_inputs = {"Image": torch.rand((1, 3, 1024, 1024))}
-    de_inputs = {
-        "image_embeddings": torch.rand([1, 256, 64, 64]),
-        "high_res_features1": torch.rand([1, 32, 256, 256]),
-        "high_res_features2": torch.rand((1, 64, 128, 128)),
-        "unnorm_coords": torch.randn((1, 5, 2)),
-        "labels": torch.ones((1, 5)),
-    }
+
+    en_inputs = {"input": torch.rand((1, 3, 1024, 1024))}
 
     with torch.no_grad():
-        torch.onnx.export(encoder, en_inputs, "sam21_vision_encoder.onnx", opset_version=20, do_constant_folding=True, dynamo=False)
-    with torch.no_grad():
-        torch.onnx.export(decoder, de_inputs, "sam21_mask_decoder.onnx", opset_version=20, do_constant_folding=True, dynamo=False)
+        torch.onnx.export(encoder, en_inputs, "sam21_vision_encoder.onnx", opset_version=20, do_constant_folding=True,
+                          dynamo=False, input_names=list(en_inputs.keys()))
 
     encoder_onnx_model = onnx.load("sam21_vision_encoder.onnx")
     simplified_encoder_onnx_model, check = simplify(encoder_onnx_model)
 
     if check:
         onnx.save(simplified_encoder_onnx_model, "sam21_vision_encoder.onnx")
 
+    de_inputs = {
+        "image_embeddings": torch.rand([1, 256, 64, 64]),
+        "high_res_features1": torch.rand([1, 32, 256, 256]),
+        "high_res_features2": torch.rand((1, 64, 128, 128)),
+        "point_coords": torch.randn((1, 5, 2)),
+        "point_labels": torch.ones((1, 5)),
+        "mask_input": torch.zeros((1, 1, 256, 256)),
+        "has_mask_input": torch.zeros([1]),
+    }
+    with torch.no_grad():
+        torch.onnx.export(decoder, de_inputs, "sam21_mask_decoder.onnx", opset_version=20, do_constant_folding=True,
+                          dynamo=False, input_names=list(de_inputs.keys()))
     decoder_onnx_model = onnx.load("sam21_mask_decoder.onnx")
     simplified_decoder_onnx_model, check = simplify(decoder_onnx_model)
 
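Once both models are exported and simplified, a quick end-to-end smoke test with onnxruntime might look like the following. This is a sketch under two assumptions: the encoder returns its three outputs in the documented order, and since the export above declares no dynamic_axes, the traced shapes (five points, a 1024x1024 input) are baked into the graphs, so the dummy inputs keep them:

    import numpy as np
    import onnxruntime as ort

    enc = ort.InferenceSession("sam21_vision_encoder.onnx")
    dec = ort.InferenceSession("sam21_mask_decoder.onnx")

    image = np.random.rand(1, 3, 1024, 1024).astype(np.float32)
    image_embeddings, high_res_1, high_res_2 = enc.run(None, {"input": image})

    masks, scores, low_res_masks = dec.run(None, {
        "image_embeddings": image_embeddings,
        "high_res_features1": high_res_1,
        "high_res_features2": high_res_2,
        "point_coords": (np.random.rand(1, 5, 2) * 1024).astype(np.float32),
        "point_labels": np.ones((1, 5), dtype=np.float32),
        "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
        "has_mask_input": np.zeros(1, dtype=np.float32),
    })
    print(masks.shape, scores.shape, low_res_masks.shape)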