Skip to content

Commit dd50e6e

Browse files
DennySORAclaude
andcommitted
feat(ultra): add full Ultra backend config UI, model compat fixes, and noise suppression
- Add complete Ultra backend configuration flow in modern UI with settings persistence (SettingsHistory) - Fix transformers 5.x compatibility: extract shared model_compat.py module that patches meta device and all_tied_weights_keys issues for BiRefNet models - Fix CUDA OOM: replace torch.compile(mode="reduce-overhead") with TF32 matmul precision - Fix RMBG-2.0 model output indexing (use [-1].sigmoid().squeeze() for full resolution) - Fix backend registration (import UltraBackend in backends/__init__.py) - Suppress verbose model loading noise (httpx logs, tqdm progress bars, timm FutureWarning, transformers Loading weights output) - Condense logger.info messages across ultra.py and portrait_matting.py - Add kornia and timm dependencies Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent ed484b5 commit dd50e6e

10 files changed

Lines changed: 569 additions & 112 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@ readme = "README.md"
66
requires-python = ">=3.13"
77
dependencies = [
88
"inquirerpy>=0.3.4",
9+
"kornia>=0.8.2",
910
"numpy>=1.24.0",
1011
"opencv-contrib-python>=4.10.0",
1112
"pillow>=12.0.0",
1213
"pydantic>=2.12.5",
1314
"pydantic-settings>=2.12.0",
1415
"rich>=13.0.0",
1516
"scikit-learn>=1.3.0",
17+
"timm>=1.0.24",
1618
"torch>=2.0.0",
1719
"torchvision>=0.15.0",
1820
"transformers>=4.45.0",

src/app.py

Lines changed: 62 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,14 @@
77
import logging
88

99
from src.backends.registry import BackendRegistry
10-
from src.common import ColorFilter, ColorFilterConfig
10+
from src.common import (
11+
AlphaConfig,
12+
AlphaMode,
13+
ColorFilter,
14+
ColorFilterConfig,
15+
ResolutionConfig,
16+
ResolutionMode,
17+
)
1118
from src.core.interfaces import BackendProtocol
1219
from src.core.processor import ImageProcessor
1320
from src.data_model import ProcessConfig, ProcessResult
@@ -106,15 +113,8 @@ def _create_backend(self, config: ProcessConfig) -> BackendProtocol:
106113
"""
107114
backend_kwargs = {}
108115

109-
# 如果需要色彩過濾,建立配置
110-
if config.backend_name == "ultra" and "color_filter" in config.extra_config:
111-
color_value = str(config.extra_config["color_filter"])
112-
color_filter = ColorFilterConfig(
113-
enabled=True,
114-
color=ColorFilter(color_value),
115-
edge_refine_strength=config.strength,
116-
)
117-
backend_kwargs["color_filter"] = color_filter
116+
if config.backend_name == "ultra":
117+
backend_kwargs = self._build_ultra_kwargs(config)
118118

119119
# 使用註冊表建立後端(工廠模式)
120120
return self.backend_registry.create(
@@ -124,6 +124,58 @@ def _create_backend(self, config: ProcessConfig) -> BackendProtocol:
124124
**backend_kwargs,
125125
)
126126

127+
def _build_ultra_kwargs(self, config: ProcessConfig) -> dict[str, object]:
128+
"""
129+
從 extra_config 建構 Ultra 後端的完整參數
130+
131+
Args:
132+
config: 處理配置
133+
134+
Returns:
135+
Ultra 後端建構參數
136+
"""
137+
extra = config.extra_config
138+
kwargs: dict[str, object] = {}
139+
140+
# 色彩過濾
141+
color_value = str(extra.get("color_filter", "none"))
142+
if color_value != "none":
143+
kwargs["color_filter"] = ColorFilterConfig(
144+
enabled=True,
145+
color=ColorFilter(color_value),
146+
edge_refine_strength=config.strength,
147+
)
148+
149+
# Trimap 精修
150+
if "use_trimap_refine" in extra:
151+
kwargs["use_trimap_refine"] = bool(extra["use_trimap_refine"])
152+
153+
# 人像 Matting 精修
154+
if "use_portrait_matting" in extra:
155+
kwargs["use_portrait_matting"] = bool(extra["use_portrait_matting"])
156+
if "portrait_matting_strength" in extra:
157+
kwargs["portrait_matting_strength"] = float(
158+
extra["portrait_matting_strength"] # type: ignore[arg-type]
159+
)
160+
if "portrait_matting_model" in extra:
161+
kwargs["portrait_matting_model"] = str(extra["portrait_matting_model"])
162+
163+
# Alpha 設定
164+
alpha_mode = str(extra.get("alpha_mode", "straight"))
165+
edge_decontam = bool(extra.get("edge_decontamination", True))
166+
kwargs["alpha_config"] = AlphaConfig(
167+
mode=AlphaMode(alpha_mode),
168+
edge_decontamination=edge_decontam,
169+
)
170+
171+
# 解析度設定
172+
resolution = str(extra.get("resolution_mode", "1024"))
173+
kwargs["resolution_config"] = ResolutionConfig(
174+
mode=ResolutionMode(resolution),
175+
)
176+
177+
return kwargs
178+
127179
def _display_result(self, result: ProcessResult) -> None:
128180
"""
129181
顯示處理結果

src/backends/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,12 @@
77
from .gemini_watermark import GeminiWatermarkBackend
88
from .image_splitter import ImageSplitterBackend
99
from .registry import BackendRegistry
10+
from .ultra import UltraBackend
1011

1112

12-
# Note: UltraBackend is imported from src.features.background_removal.ultra
13-
# to avoid circular imports. Import it directly from there if needed.
14-
1513
__all__ = [
1614
"BackendRegistry",
1715
"GeminiWatermarkBackend",
1816
"ImageSplitterBackend",
17+
"UltraBackend",
1918
]

src/common/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
unpremultiply_alpha,
1717
)
1818
from .color_filter import ColorFilter, ColorFilterConfig
19+
from .model_compat import load_pretrained_no_meta
1920
from .preset_config import (
2021
BackgroundRemovalPreset,
2122
PresetLevel,
@@ -42,4 +43,5 @@
4243
"get_preset",
4344
"list_presets",
4445
"print_preset_comparison",
46+
"load_pretrained_no_meta",
4547
]

src/common/model_compat.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
"""
2+
模型載入相容性修補模組
3+
4+
修補 transformers >= 5.0 與自訂模型程式碼(如 BiRefNet)的相容性問題:
5+
1. get_init_context 無條件使用 torch.device("meta"),
6+
但自訂模型在 __init__ 中呼叫 .item(),meta tensor 不支援。
7+
2. 自訂模型未呼叫 post_init(),導致 all_tied_weights_keys 未設定。
8+
"""
9+
10+
import logging
11+
import os
12+
import warnings
13+
from collections.abc import Iterator
14+
from contextlib import contextmanager
15+
from typing import Any
16+
17+
import torch
18+
import transformers
19+
from transformers import AutoModelForImageSegmentation
20+
from transformers.modeling_utils import PreTrainedModel
21+
22+
23+
@contextmanager
24+
def _suppress_loading_noise() -> Iterator[None]:
25+
"""暫時抑制模型載入期間的冗餘輸出(httpx、transformers、timm、tqdm)"""
26+
# 保存原始狀態
27+
orig_verbosity = transformers.logging.get_verbosity()
28+
httpx_logger = logging.getLogger("httpx")
29+
orig_httpx_level = httpx_logger.level
30+
hf_logger = logging.getLogger("huggingface_hub")
31+
orig_hf_level = hf_logger.level
32+
orig_tqdm_disable = os.environ.get("TQDM_DISABLE")
33+
34+
# 抑制: transformers 日誌、httpx HTTP 請求、huggingface_hub、tqdm 進度條
35+
transformers.logging.set_verbosity_error() # type: ignore[no-untyped-call]
36+
httpx_logger.setLevel(logging.WARNING)
37+
hf_logger.setLevel(logging.WARNING)
38+
os.environ["TQDM_DISABLE"] = "1"
39+
40+
# 抑制: timm FutureWarning
41+
with warnings.catch_warnings():
42+
warnings.filterwarnings("ignore", category=FutureWarning, module="timm")
43+
try:
44+
yield
45+
finally:
46+
transformers.logging.set_verbosity(orig_verbosity) # type: ignore[no-untyped-call]
47+
httpx_logger.setLevel(orig_httpx_level)
48+
hf_logger.setLevel(orig_hf_level)
49+
if orig_tqdm_disable is None:
50+
os.environ.pop("TQDM_DISABLE", None)
51+
else:
52+
os.environ["TQDM_DISABLE"] = orig_tqdm_disable
53+
54+
55+
def load_pretrained_no_meta(model_name: str) -> Any:
56+
"""
57+
載入預訓練 ImageSegmentation 模型,修補 meta device 相容性問題
58+
59+
自動抑制載入期間的冗餘輸出(HTTP 請求日誌、進度條、FutureWarning 等)
60+
61+
Args:
62+
model_name: HuggingFace 模型名稱
63+
64+
Returns:
65+
載入完成的模型
66+
"""
67+
68+
# 修補 1: 移除 meta device context
69+
orig_context = PreTrainedModel.__dict__["get_init_context"]
70+
71+
@classmethod # type: ignore[misc]
72+
def _safe_context(
73+
cls: type,
74+
dtype: torch.dtype,
75+
is_quantized: bool,
76+
_is_ds_init_called: bool,
77+
) -> list[Any]:
78+
bound_original = orig_context.__get__(None, cls)
79+
contexts: list[Any] = bound_original(dtype, is_quantized, _is_ds_init_called)
80+
return [
81+
c
82+
for c in contexts
83+
if not (isinstance(c, torch.device) and c.type == "meta")
84+
]
85+
86+
# 修補 2: 確保 all_tied_weights_keys 存在
87+
orig_finalize = PreTrainedModel.__dict__["_finalize_model_loading"]
88+
89+
@classmethod # type: ignore[misc]
90+
def _safe_finalize(cls: type, model: Any, *args: Any, **kwargs: Any) -> Any:
91+
if not hasattr(model, "all_tied_weights_keys"):
92+
model.all_tied_weights_keys = {}
93+
return orig_finalize.__get__(None, cls)(model, *args, **kwargs)
94+
95+
PreTrainedModel.get_init_context = _safe_context # type: ignore[assignment]
96+
PreTrainedModel._finalize_model_loading = _safe_finalize # type: ignore[assignment]
97+
try:
98+
with _suppress_loading_noise():
99+
return AutoModelForImageSegmentation.from_pretrained(
100+
model_name, trust_remote_code=True
101+
)
102+
finally:
103+
PreTrainedModel.get_init_context = orig_context # type: ignore[method-assign]
104+
PreTrainedModel._finalize_model_loading = orig_finalize # type: ignore[method-assign]

src/features/background_removal/portrait_matting.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,9 @@ def __init__(
9999
self._model_loaded = False
100100
self._transform: transforms.Compose | None = None
101101

102-
logger.info("Portrait matting refiner initialized")
103-
logger.info(" Model: %s", self.model_name)
104-
logger.info(" Device: %s", self.device)
105-
logger.info(" High-res mode: %s", self.enable_hr_mode)
102+
logger.info(
103+
"Portrait matting: model=%s, device=%s", self.model_name, self.device
104+
)
106105

107106
def load_model(self) -> None:
108107
"""載入人像 matting 模型"""
@@ -131,30 +130,21 @@ def _load_birefnet(self) -> None:
131130
- Hugging Face: ZhengPeng7/BiRefNet-matting
132131
- 授權:MIT License
133132
"""
134-
from transformers import AutoModelForImageSegmentation
133+
from src.common.model_compat import load_pretrained_no_meta
135134

136135
repo_id = BIREFNET_MODELS[self.model_name]
137136
input_size = BIREFNET_INPUT_SIZES[self.model_name]
138137

139138
try:
140139
logger.info("Loading BiRefNet model: %s ...", repo_id)
141140

142-
self._model = AutoModelForImageSegmentation.from_pretrained(
143-
repo_id, trust_remote_code=True
144-
)
141+
self._model = load_pretrained_no_meta(repo_id)
145142
self._model.to(self.device)
146143
self._model.eval()
147144

148-
# torch.compile() 加速(MPS 不支援)
149-
if hasattr(torch, "compile") and self.device.type != "mps":
150-
try:
151-
self._model = torch.compile(self._model, mode="reduce-overhead")
152-
logger.info("torch.compile() enabled for BiRefNet")
153-
except Exception:
154-
logger.debug(
155-
"torch.compile() unavailable for BiRefNet",
156-
exc_info=True,
157-
)
145+
# 啟用 TF32(Ampere+ GPU 自動加速 float32 矩陣運算)
146+
if self.device.type == "cuda":
147+
torch.set_float32_matmul_precision("high")
158148

159149
# 建立預處理轉換器
160150
self._transform = transforms.Compose(
@@ -166,7 +156,7 @@ def _load_birefnet(self) -> None:
166156
)
167157

168158
self._model_loaded = True
169-
logger.info("BiRefNet loaded successfully (%s)", self.model_name)
159+
logger.info("BiRefNet loaded on %s", self.device)
170160

171161
except Exception as e:
172162
logger.warning(

0 commit comments

Comments
 (0)