@@ -36,6 +36,12 @@ def str2bool(v):
3636 default = None ,
3737 help = "Path to the ONNX language encoder model (for SAM3)" ,
3838)
39+ argparser .add_argument (
40+ "--text_prompt" ,
41+ type = str ,
42+ default = None ,
43+ help = "Text prompt for SAM3 (e.g. 'truck'). Overrides any text entry in the prompt JSON." ,
44+ )
3945argparser .add_argument (
4046 "--image" ,
4147 type = str ,
@@ -90,11 +96,14 @@ def str2bool(v):
9096
9197text_prompt = None
9298if args .sam_variant == "sam3" :
93- # Extract text prompt from JSON if available, otherwise default to "visual"
94- for p in prompt :
95- if p ["type" ] == "text" :
96- text_prompt = p ["data" ]
97- break
99+ # --text_prompt takes priority; otherwise use the first text entry in the JSON (defaults to "visual" if neither is set).
100+ if args .text_prompt :
101+ text_prompt = args .text_prompt
102+ else :
103+ for p in prompt :
104+ if p ["type" ] == "text" :
105+ text_prompt = p ["data" ]
106+ break
98107 if text_prompt is None :
99108 text_prompt = "visual"
100109
@@ -109,14 +118,14 @@ def str2bool(v):
109118# Merge masks
110119mask = np .zeros ((masks .shape [2 ], masks .shape [3 ], 3 ), dtype = np .uint8 )
111120if args .sam_variant == "sam3" :
112- # SAM3 returns (N, 1, H, W) – render all N detected instances.
121+ # SAM3 returns bool (N, 1, H, W) – render all N detected instances.
113122 for i in range (masks .shape [0 ]):
114- m = masks [i , 0 ] # (H, W)
115- mask [m > 0.5 ] = [255 , 0 , 0 ]
123+ m = masks [i , 0 ] # (H, W) bool
124+ mask [m ] = [255 , 0 , 0 ]
116125else :
117- # SAM1/SAM2 returns (1, 3, H, W) – merge all quality levels .
126+ # SAM1/SAM2 return raw logits (1, 3, H, W) – threshold each mask at 0 (= sigmoid 0.5) and merge them all.
118127 for m in masks [0 , :, :, :]:
119- mask [m > 0.5 ] = [255 , 0 , 0 ]
128+ mask [m > 0.0 ] = [255 , 0 , 0 ]
120129
121130# Binding image and mask
122131visualized = cv2 .addWeighted (image , 0.5 , mask , 0.5 , 0 )
0 commit comments