
Commit ff8f78f

Add files via upload
1 parent 19dcc90 commit ff8f78f

10 files changed

Lines changed: 804 additions & 13 deletions


README.md

Lines changed: 5 additions & 2 deletions
@@ -1,4 +1,4 @@
-# SAMLabeler Pro: An image annotation tool assisted by the [Segment Anything Model](https://github.com/facebookresearch/segment-anything), with support for simultaneous remote multi-user annotation
+# SAMLabeler Pro: An image annotation tool assisted by [SAM](https://github.com/facebookresearch/segment-anything) / [MobileSAM](https://github.com/ChaoningZhang/MobileSAM), with support for simultaneous remote multi-user annotation
 
 ![image](https://user-images.githubusercontent.com/69880398/235317010-2ec560cf-1de9-436d-81a4-79654e533de1.png)
 
@@ -10,9 +10,12 @@
 - To avoid import conflicts when using this tool, do not install the SAM source package in the runtime environment; the segment_anything folder in this project is the SAM source with some modifications.
 - If QT errors occur, opencv is the most likely cause; replace opencv_python with opencv_python_headless in requirements.txt.
 
+## 0 Recent updates
+[2023/6/28] The lightweight [MobileSAM](https://github.com/ChaoningZhang/MobileSAM) model is now supported, reaching nearly 100x the speed of SAM with comparable accuracy and lower GPU memory usage. [Download the model here](https://github.com/LSH9832/SAMLabelerPro/releases/download/v0.2.0/mobile_sam.pt). Unless configured otherwise, this model is loaded first.
+
 ## 1 Upcoming updates
 
-- The CV field is brutally competitive; luckily I have graduated, because I really cannot keep up anymore. When time allows, the FastSAM and MobileSAM models will be added (MobileSAM first).
+- Nothing for now; if you find a bug, please leave a message in Issues.
 
 ## 2 New features compared to the original

demo/box2segment.py

Lines changed: 1 addition & 1 deletion
@@ -68,7 +68,7 @@ def draw_mask(self, mask, image=None, color=(0, 255, 0), ratio=0.5):
 
 def main(args):
-    assert args.size.lower() in ["b", "l", "h"]
+    assert args.size.lower() in ["mobile", "b", "l", "h"]
     boxSegmenter = SegBox(force_size=args.size, half=args.half)
 
     if not (args.cfg.startswith("/") or args.cfg[1] == ":"):
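With the assert relaxed, the demo accepts the new size. A minimal usage sketch, assuming SegBox is importable from demo/box2segment.py as the diff context suggests (only the force_size and half keyword arguments are confirmed above):

```python
# Sketch under assumptions: SegBox lives in demo/box2segment.py and takes the
# force_size/half kwargs shown in the diff; anything else is hypothetical.
from demo.box2segment import SegBox

# "mobile" is now accepted alongside "b", "l", and "h". MobileSAM does not run
# in half precision (see segment_any.py below), so half is left off here.
segmenter = SegBox(force_size="mobile", half=False)
```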

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ torchvision
 pycocotools
 matplotlib
 requests
+timm
 flask # only server need this.
 # onnxruntime
 # onnx
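The new timm dependency is presumably needed because MobileSAM's TinyViT encoder builds on timm layers, as in upstream MobileSAM; the exact import site is not shown in this commit. A defensive check one might add before loading the mobile model:

```python
# Assumption: tiny_vit_sam.py imports timm at module load time, as in upstream
# MobileSAM. Fail early with an actionable message if the dependency is absent.
import importlib.util

if importlib.util.find_spec("timm") is None:
    raise ImportError("the TinyViT encoder needs timm: pip install timm")
```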

segment_any/segment_any.py

Lines changed: 6 additions & 2 deletions
@@ -11,13 +11,16 @@ def __init__(self, checkpoint, half=True, force_model_type=None):
 
         success = False
 
-        all_model_type = ["h", "l", "b"]
+        all_model_type = ["mobile", "h", "l", "b"]
         if force_model_type is not None and force_model_type in all_model_type:
             all_model_type = [force_model_type]
 
         for self.model_type in [f"vit_{t}" for t in all_model_type]:
             try:
                 half = half and torch.cuda.is_available() and not self.model_type.endswith("h")
+                if (self.model_type.endswith("h") or self.model_type.endswith("mobile")) and half:
+                    print(f"{self.model_type} can not run with half precision, using full precision.")
+                    half = False
                 print(f"try load weights '{checkpoint}' with model size '{self.model_type}'")
                 sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
                 sam.to(device=self.device, dtype=torch.float16 if half else torch.float32)
@@ -35,14 +38,15 @@ def __init__(self, checkpoint, half=True, force_model_type=None):
         self.success = success
 
     def set_image(self, image):
+        # print("set image")
         self.predictor.set_image(image)
+        # print("done")
 
     def reset_image(self):
         self.predictor.reset_image()
         self.image = None
         torch.cuda.empty_cache()
 
-
     def predict_box(self, box, xyxy=True, expand=0):
 
         def modify_box(bbox: (list, np.ndarray), xy2=True):
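Because "mobile" now heads all_model_type, the loader tries vit_mobile first, which is what the README means by the model being loaded preferentially. A minimal sketch of pinning one model type, assuming the class holding this __init__ is exported as SegAny (the class name is an assumption; only the constructor signature and the success flag appear in the diff):

```python
# Sketch under assumptions: the class is named SegAny and is importable from
# segment_any.segment_any; __init__(checkpoint, half=True, force_model_type=None)
# and the self.success flag are confirmed by the diff above.
from segment_any.segment_any import SegAny

# force_model_type narrows the trial list ["mobile", "h", "l", "b"] to one entry.
# half is ignored for "mobile" and "h", which fall back to full precision.
seg = SegAny(checkpoint="mobile_sam.pt", half=True, force_model_type="mobile")
assert seg.success, "checkpoint did not load for any attempted model type"
```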

segment_anything/build_sam.py

Lines changed: 50 additions & 1 deletion
@@ -8,7 +8,7 @@
 
 from functools import partial
 
-from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
+from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer, TinyViT
 
 
 def build_sam_vit_h(checkpoint=None):
@@ -44,11 +44,60 @@ def build_sam_vit_b(checkpoint=None):
     )
 
 
+def build_sam_vit_mobile(checkpoint=None):
+    prompt_embed_dim = 256
+    image_size = 1024
+    vit_patch_size = 16
+    image_embedding_size = image_size // vit_patch_size
+    mobile_sam = Sam(
+        image_encoder=TinyViT(img_size=1024, in_chans=3, num_classes=1000,
+                              embed_dims=[64, 128, 160, 320],
+                              depths=[2, 2, 6, 2],
+                              num_heads=[2, 4, 5, 10],
+                              window_sizes=[7, 7, 14, 7],
+                              mlp_ratio=4.,
+                              drop_rate=0.,
+                              drop_path_rate=0.0,
+                              use_checkpoint=False,
+                              mbconv_expand_ratio=4.0,
+                              local_conv_size=3,
+                              layer_lr_decay=0.8
+                              ),
+        prompt_encoder=PromptEncoder(
+            embed_dim=prompt_embed_dim,
+            image_embedding_size=(image_embedding_size, image_embedding_size),
+            input_image_size=(image_size, image_size),
+            mask_in_chans=16,
+        ),
+        mask_decoder=MaskDecoder(
+            num_multimask_outputs=3,
+            transformer=TwoWayTransformer(
+                depth=2,
+                embedding_dim=prompt_embed_dim,
+                mlp_dim=2048,
+                num_heads=8,
+            ),
+            transformer_dim=prompt_embed_dim,
+            iou_head_depth=3,
+            iou_head_hidden_dim=256,
+        ),
+        pixel_mean=[123.675, 116.28, 103.53],
+        pixel_std=[58.395, 57.12, 57.375],
+    )
+    mobile_sam.eval()
+    if checkpoint is not None:
+        with open(checkpoint, "rb") as f:
+            state_dict = torch.load(f)
+        mobile_sam.load_state_dict(state_dict)
+    return mobile_sam
+
+
 sam_model_registry = {
     "default": build_sam_vit_h,
     "vit_h": build_sam_vit_h,
     "vit_l": build_sam_vit_l,
     "vit_b": build_sam_vit_b,
+    "vit_mobile": build_sam_vit_mobile
 }
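With the registry entry in place, MobileSAM builds like any other SAM variant. A minimal end-to-end sketch, assuming this repo's modified segment_anything package re-exports SamPredictor the way upstream SAM does (only sam_model_registry is confirmed by the diff):

```python
import cv2
import numpy as np
import torch

# Assumption: SamPredictor is re-exported by the bundled package, as upstream.
from segment_anything import sam_model_registry, SamPredictor

# "vit_mobile" maps to build_sam_vit_mobile, added in this commit; the
# checkpoint is the mobile_sam.pt file linked from the README.
sam = sam_model_registry["vit_mobile"](checkpoint="mobile_sam.pt")
sam.to("cuda" if torch.cuda.is_available() else "cpu")

predictor = SamPredictor(sam)
image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)
predictor.set_image(image)

# Prompt with a box in xyxy pixel coordinates; returns masks and IoU scores.
masks, scores, _ = predictor.predict(box=np.array([100, 100, 400, 400]))
```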

segment_anything/modeling/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -9,3 +9,4 @@
 from .mask_decoder import MaskDecoder
 from .prompt_encoder import PromptEncoder
 from .transformer import TwoWayTransformer
+from .tiny_vit_sam import TinyViT

segment_anything/modeling/sam.py

Lines changed: 3 additions & 2 deletions
@@ -11,6 +11,7 @@
 from typing import Any, Dict, List, Tuple
 
 from .image_encoder import ImageEncoderViT
+from .tiny_vit_sam import TinyViT
 from .mask_decoder import MaskDecoder
 from .prompt_encoder import PromptEncoder
 
@@ -22,7 +23,7 @@ class Sam(nn.Module):
 
     def __init__(
         self,
-        image_encoder: ImageEncoderViT,
+        image_encoder: [ImageEncoderViT, TinyViT],
         prompt_encoder: PromptEncoder,
         mask_decoder: MaskDecoder,
         pixel_mean: List[float] = [123.675, 116.28, 103.53],
@@ -32,7 +33,7 @@ def __init__(
         SAM predicts object masks from an image and input prompts.
 
         Arguments:
-          image_encoder (ImageEncoderViT): The backbone used to encode the
+          image_encoder (ImageEncoderViT, TinyViT): The backbone used to encode the
             image into image embeddings that allow for efficient mask prediction.
           prompt_encoder (PromptEncoder): Encodes various types of input prompts.
           mask_decoder (MaskDecoder): Predicts masks from the image embeddings
