Commit 377baa7

Shiva Chilukamari authored and committed
Add Mask input for decoder model. Renamed Encoder input to input
1 parent 81f4b3b commit 377baa7

3 files changed

Lines changed: 96 additions & 46 deletions

sam2.1-hiera-small/QNN/generate_model.py

Lines changed: 95 additions & 31 deletions
@@ -290,18 +290,18 @@ def __init__(
         super().__init__()
         self.model = sam2
         self._bb_feat_sizes = [
-            (256, 256),
-            (128, 128),
-            (64, 64),
+            (32, 256, 256),
+            (64, 128, 128),
+            (256, 64, 64),
         ]
         for i, block in enumerate(self.model.image_encoder.trunk.blocks):
             self.model.image_encoder.trunk.blocks[i] = ModMultiScaleBlock(block)
 
-    def forward(self, Image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Run SAM2 Image encoder and returns image_embeddings, high_res_features1, high_res_features2.
+    def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Run SAM2 input encoder and returns image_embeddings, high_res_features1, high_res_features2.
 
         Args:
-            Image:
+            input:
                 Raw floating point pixel values for encoder consumption.
                 3-channel Color Space: RGB, range [0, 1]
 
@@ -311,15 +311,14 @@ def forward(self, Image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
             high_res_features2: Shape (1, 64, 128, 128)
 
         """
-        # x = self.normalize(Image)
-        x = Image
+        x = input
         backbone_out = self.model.forward_image(x)
         _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
         # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos
         if self.model.directly_add_no_mem_embed:
             vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
         feats = [
-            feat.permute(1, 2, 0).view(1, -1, *feat_size)
+            feat.permute(1, 2, 0).view(1, *feat_size)
             for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
         ][::-1]
         image_embeddings = feats[2]
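
The reshape change above pairs with the new _bb_feat_sizes entries: each tuple now carries the channel dimension explicitly, so view(1, *feat_size) produces the same (1, C, H, W) layout that view(1, -1, *feat_size) previously inferred. A minimal sketch of that equivalence, using the (256, 64, 64) entry from the diff and a random stand-in for the backbone feature:

import torch

# (HW, B, C) -> (B, C, HW) -> (B, C, H, W), mirroring SAM2Encoder.forward
feat = torch.rand(64 * 64, 1, 256)                 # stand-in for the lowest-resolution vision feature
old = feat.permute(1, 2, 0).view(1, -1, 64, 64)    # channel dim inferred with -1
new = feat.permute(1, 2, 0).view(1, 256, 64, 64)   # channel dim taken from (256, 64, 64)
assert old.shape == new.shape == (1, 256, 64, 64)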
@@ -342,13 +341,70 @@ def __init__(self, sam2) -> None:
         self.mask_decoder = self.model.sam_mask_decoder
         self.prompt_encoder = self.model.sam_prompt_encoder
 
+    def _embed_masks(
+        self, input_mask: torch.Tensor, has_mask_input: torch.Tensor
+    ) -> torch.Tensor:
+        mask_embedding = has_mask_input * self.prompt_encoder.mask_downscaling(
+            input_mask
+        )
+        mask_embedding = mask_embedding + (
+            1 - has_mask_input
+        ) * self.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
+        return mask_embedding
+
+    def _embed_points(
+        self,
+        points: torch.Tensor,
+        labels: torch.Tensor
+    ) -> torch.Tensor:
+        """Embeds point prompts."""
+        points = points + 0.5  # Shift to center of pixel
+        padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
+        padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
+        points = torch.cat([points, padding_point], dim=1)
+        labels = torch.cat([labels, padding_label], dim=1)
+
+        point_embedding = self.prompt_encoder.pe_layer.forward_with_coords(
+            points, self.prompt_encoder.input_image_size
+        )
+
+        point_embedding = torch.where(
+            (labels == -1).unsqueeze(-1),
+            torch.zeros_like(point_embedding) + self.prompt_encoder.not_a_point_embed.weight,
+            point_embedding,
+        )
+        point_embedding = torch.where(
+            (labels == 0).unsqueeze(-1),
+            point_embedding + self.prompt_encoder.point_embeddings[0].weight,
+            point_embedding,
+        )
+        point_embedding = torch.where(
+            (labels == 1).unsqueeze(-1),
+            point_embedding + self.prompt_encoder.point_embeddings[1].weight,
+            point_embedding,
+        )
+        point_embedding = torch.where(
+            (labels == 2).unsqueeze(-1),
+            point_embedding + self.prompt_encoder.point_embeddings[2].weight,
+            point_embedding,
+        )
+        point_embedding = torch.where(
+            (labels == 3).unsqueeze(-1),
+            point_embedding + self.prompt_encoder.point_embeddings[3].weight,
+            point_embedding,
+        )
+        return point_embedding
+
     def forward(
         self,
         image_embeddings: torch.Tensor,  # [1,256,64,64]
         high_res_features1: torch.Tensor,  # [1, 32, 256, 256]
         high_res_features2: torch.Tensor,  # [1, 64, 128, 128]
-        unnorm_coords: torch.Tensor,  # [num_labels,num_points,2]
-        labels: torch.Tensor,  # [num_labels,num_points]
+        point_coords: torch.Tensor,  # [num_labels,num_points,2]
+        point_labels: torch.Tensor,  # [num_labels,num_points]
+        mask_input: torch.Tensor,  # [1, 1, 256, 256]
+        has_mask_input: torch.Tensor  # [1]
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """Run SAM2 lightweight decoder and return generated mask for given points.
 
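
The two helpers added here reproduce the prompt encoder's point and mask branches with tensor arithmetic only (torch.where selections and a 0/1 gate), so the exported decoder graph needs no data-dependent Python branching. A minimal sketch of the has_mask_input gate used in _embed_masks; downscaled and no_mask are illustrative stand-ins for the real embedding tensors:

import torch

def gate(downscaled: torch.Tensor, no_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
    # has_mask_input is a [1] tensor holding 0.0 or 1.0: the blend keeps the
    # downscaled mask embedding when 1 and falls back to the learned
    # "no mask" embedding when 0, without an if-statement.
    return has_mask_input * downscaled + (1 - has_mask_input) * no_mask

downscaled = torch.rand(1, 256, 64, 64)   # stand-in for prompt_encoder.mask_downscaling(input_mask)
no_mask = torch.rand(1, 256, 1, 1)        # stand-in for no_mask_embed.weight.reshape(1, -1, 1, 1)
assert torch.equal(gate(downscaled, no_mask, torch.ones(1)), downscaled)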
@@ -359,22 +415,24 @@ def forward(
                 First set of high-resolution features.
             high_res_features2: torch.Tensor of shape [1, high_res_2_dim, high_res_2_size, high_res_2_size]
                 Second set of high-resolution features.
-            unnorm_coords: torch.Tensor of shape [1, k, 2]
+            point_coords: torch.Tensor of shape [1, k, 2]
                 Point coordinates from input image for segmentation, mapped to the resized image
-            labels: torch.Tensor of shape [1, k]
+            point_labels: torch.Tensor of shape [1, k]
                 Point Labels to select/de-select given point for segmentation
                 e.g. Corresponding value is 1 if this point is to be included, otherwise 0
-
+            mask_input: torch.Tensor of shape [1, 1, 256, 256]
+                Mask corresponding to focus area
+            has_mask_input: torch.Tensor of shape [1]
+                1 if Mask provided, else 0
         Returns:
             masks: torch.Tensor of shape [1, 1, 256, 256]
             scores: torch.Tensor of shape [1, 1]
 
         """
-        sparse_embedding, dense_embedding = self.prompt_encoder(
-            points=(unnorm_coords, labels),
-            boxes=None,
-            masks=None,
-        )
+
+        sparse_embedding = self._embed_points(point_coords, point_labels)
+        dense_embedding = self._embed_masks(mask_input, has_mask_input)
+
         low_res_masks, iou_predictions, _, _ = self.mask_decoder.predict_masks(
             image_embeddings=image_embeddings,
             image_pe=self.prompt_encoder.get_dense_pe(),
@@ -383,7 +441,8 @@ def forward(
             repeat_image=False,
             high_res_features=[high_res_features1, high_res_features2],
         )
-        return low_res_masks, iou_predictions
+        masks = F.interpolate(low_res_masks, size=(1024, 1024), mode="bilinear", align_corners=False)
+        return masks, iou_predictions, low_res_masks
 
 
 model_weights_url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt"
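
The decoder now returns the 1024x1024 upsampled logits alongside the raw 256x256 low_res_masks. Keeping the low-resolution logits around matches the usual SAM refinement loop, where they are fed back as the next call's mask_input; that usage is an assumption here, not something the diff states. A brief sketch:

import torch
import torch.nn.functional as F

low_res_masks = torch.randn(1, 1, 256, 256)   # stand-in for the decoder's third output
# Assumed refinement pattern: reuse the previous low-res logits as the mask prompt.
next_mask_input = low_res_masks
next_has_mask_input = torch.ones(1)
# The full-resolution output corresponds to the same upsampling done in forward():
masks = F.interpolate(low_res_masks, size=(1024, 1024), mode="bilinear", align_corners=False)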
@@ -402,26 +461,31 @@ def main():
     sam2_model = build_sam2(model_cfg, checkpoint, device="cpu")
     encoder = SAM2Encoder(sam2_model)
     decoder = SAM2Decoder(sam2_model)
-    en_inputs = {"Image": torch.rand((1, 3, 1024, 1024))}
-    de_inputs = {
-        "image_embeddings": torch.rand([1, 256, 64, 64]),
-        "high_res_features1": torch.rand([1, 32, 256, 256]),
-        "high_res_features2": torch.rand((1, 64, 128, 128)),
-        "unnorm_coords": torch.randn((1, 5, 2)),
-        "labels": torch.ones((1, 5)),
-    }
+
+    en_inputs = {"input": torch.rand((1, 3, 1024, 1024))}
 
     with torch.no_grad():
-        torch.onnx.export(encoder, en_inputs, "sam21_vision_encoder.onnx", opset_version=20, do_constant_folding=True, dynamo = False)
-    with torch.no_grad():
-        torch.onnx.export(decoder, de_inputs, "sam21_mask_decoder.onnx", opset_version=20, do_constant_folding=True, dynamo = False)
+        torch.onnx.export(encoder, en_inputs, "sam21_vision_encoder.onnx", opset_version=20, do_constant_folding=True,
+                          dynamo = False, input_names = list(en_inputs.keys()))
 
     encoder_onnx_model = onnx.load("sam21_vision_encoder.onnx")
     simplified_encoder_onnx_model, check = simplify(encoder_onnx_model)
 
     if check:
         onnx.save(simplified_encoder_onnx_model, "sam21_vision_encoder.onnx")
 
+    de_inputs = {
+        "image_embeddings": torch.rand([1, 256, 64, 64]),
+        "high_res_features1": torch.rand([1, 32, 256, 256]),
+        "high_res_features2": torch.rand((1, 64, 128, 128)),
+        "point_coords": torch.randn((1, 5, 2)),
+        "point_labels": torch.ones((1, 5)),
+        "mask_input": torch.zeros((1, 1, 256, 256)),
+        "has_mask_input": torch.zeros([1])
+    }
+    with torch.no_grad():
+        torch.onnx.export(decoder, de_inputs, "sam21_mask_decoder.onnx", opset_version=20, do_constant_folding=True,
+                          dynamo = False, input_names = list(de_inputs.keys()))
     decoder_onnx_model = onnx.load("sam21_mask_decoder.onnx")
     simplified_decoder_onnx_model, check = simplify(decoder_onnx_model)
427491

sam2.1-hiera-small/QNN/sam21_vision_encoder_qnn_ctx.json

Lines changed: 0 additions & 14 deletions
@@ -69,20 +69,6 @@
         }
     },
     "passes": {
-        "surgeries": {
-            "type": "GraphSurgeries",
-            "surgeries": [
-                {
-                    "surgeon": "RenameInputs",
-                    "old_names": [
-                        "input.1"
-                    ],
-                    "new_names": [
-                        "pixel_values"
-                    ]
-                }
-            ]
-        },
         "quantization": {
             "type": "OnnxStaticQuantization",
             "data_config": "quantize_data_config",

sam2.1-hiera-small/QNN/user_script.py

Lines changed: 1 addition & 1 deletion
@@ -117,4 +117,4 @@ def ve_generate_quant_data(num_samples):
         image = sample["image"]
         inputs = processor(image, return_tensors="pt")
         pixel_values = inputs["pixel_values"].detach().cpu().numpy()
-        np.savez(f"{ModelConfig.data_dir}/input_{i}_images.npz", pixel_values=pixel_values)
+        np.savez(f"{ModelConfig.data_dir}/input_{i}_images.npz", input=pixel_values)
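
The keyword argument to np.savez becomes the array's name inside the .npz archive, and the quantization data loader is expected to feed each array to the ONNX graph input of the same name, so the key has to track the encoder's renamed input. A quick sanity check on one generated sample (the path is illustrative; the real files live under ModelConfig.data_dir):

import numpy as np

sample = np.load("data/input_0_images.npz")   # illustrative path
print(list(sample.keys()))                    # expect ['input'] to match the encoder graph input
print(sample["input"].shape)                  # expect (1, 3, 1024, 1024)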
