-
Notifications
You must be signed in to change notification settings - Fork 163
Expand file tree
/
Copy pathconvert_rmbg.py
More file actions
87 lines (69 loc) · 2.23 KB
/
convert_rmbg.py
File metadata and controls
87 lines (69 loc) · 2.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""
Convert BRIA RMBG-1.4 background removal to CoreML (1 model, INT8).
Architecture:
IS-Net based U2Net variant with mean-0.5 normalization (not ImageNet statistics) baked in.
Input: 1024x1024 RGB image → Output: alpha mask [1, 1, 1024, 1024]
Requirements:
pip install transformers torch coremltools>=9.0
Usage:
python convert_rmbg.py
"""
import torch
import torch.nn as nn
import coremltools as ct
from coremltools.optimize.coreml import (
OpLinearQuantizerConfig,
OptimizationConfig,
linear_quantize_weights,
)
from transformers import AutoModelForImageSegmentation
class RMBGWrapper(nn.Module):
    """Wrap RMBG-1.4 to bake input normalization and sigmoid into the graph.

    RMBG-1.4 normalizes with mean [0.5, 0.5, 0.5] and std [1.0, 1.0, 1.0]
    (NOT the usual ImageNet statistics), so normalization reduces to a
    constant 0.5 shift. The sigmoid maps the model's mask logits to a
    [0, 1] alpha matte, so the exported CoreML model needs no Python
    post-processing.
    """

    def __init__(self, model):
        super().__init__()
        # Underlying segmentation model; called as model(x) in forward().
        self.model = model

    def forward(self, x):
        # normalize(x, mean=[0.5]*3, std=[1.0]*3) simplifies to x - 0.5.
        x = x - 0.5
        # model(x)[0][0] selects the primary mask logits from the model's
        # nested output (presumably the full-resolution head of the IS-Net
        # multi-scale outputs — structure comes from the remote model code,
        # verify against briaai/RMBG-1.4 if the model revision changes).
        return torch.sigmoid(self.model(x)[0][0])
def main():
    """Download RMBG-1.4, trace it, convert to CoreML FP16, then INT8-quantize.

    Produces RMBG_1_4.mlpackage in the working directory. The exported model
    takes a 1024x1024 RGB image (pixel values scaled by 1/255) and returns an
    alpha mask tensor named "alpha_mask".
    """
    print("Loading briaai/RMBG-1.4 ...")
    base_model = AutoModelForImageSegmentation.from_pretrained(
        "briaai/RMBG-1.4", trust_remote_code=True
    )
    base_model.eval()

    wrapped = RMBGWrapper(base_model)
    wrapped.eval()

    # Trace at the fixed 1024x1024 input size with a [0, 1]-ranged example.
    example_input = torch.randn(1, 3, 1024, 1024).clamp(0, 1)
    with torch.no_grad():
        traced_module = torch.jit.trace(wrapped, example_input)

    print("Converting to CoreML FP16 ...")
    # ImageType lets CoreML consume CVPixelBuffers directly; the 1/255 scale
    # maps 8-bit pixels into the [0, 1] range the traced model expects.
    image_input = ct.ImageType(
        name="image",
        shape=(1, 3, 1024, 1024),
        scale=1.0 / 255.0,
        color_layout=ct.colorlayout.RGB,
    )
    mlmodel = ct.convert(
        traced_module,
        inputs=[image_input],
        outputs=[ct.TensorType(name="alpha_mask")],
        minimum_deployment_target=ct.target.iOS17,
        compute_precision=ct.precision.FLOAT16,
    )

    print("Quantizing to INT8 ...")
    # Symmetric per-op weight quantization halves the FP16 package size again.
    opt_config = OptimizationConfig(
        global_config=OpLinearQuantizerConfig(mode="linear_symmetric", dtype="int8")
    )
    mlmodel = linear_quantize_weights(mlmodel, opt_config)

    mlmodel.author = "CoreML-Models"
    mlmodel.short_description = (
        "RMBG-1.4 background removal. "
        "1024x1024 RGB → alpha mask [1, 1, 1024, 1024]. INT8."
    )
    mlmodel.license = "Apache-2.0"
    mlmodel.save("RMBG_1_4.mlpackage")
    print("Saved RMBG_1_4.mlpackage")
# Run the conversion only when executed as a script (not on import).
if __name__ == "__main__":
    main()