|
| 1 | +""" |
| 2 | +模型載入相容性修補模組 |
| 3 | +
|
| 4 | +修補 transformers >= 5.0 與自訂模型程式碼(如 BiRefNet)的相容性問題: |
| 5 | +1. get_init_context 無條件使用 torch.device("meta"), |
| 6 | + 但自訂模型在 __init__ 中呼叫 .item(),meta tensor 不支援。 |
| 7 | +2. 自訂模型未呼叫 post_init(),導致 all_tied_weights_keys 未設定。 |
| 8 | +""" |
| 9 | + |
| 10 | +import logging |
| 11 | +import os |
| 12 | +import warnings |
| 13 | +from collections.abc import Iterator |
| 14 | +from contextlib import contextmanager |
| 15 | +from typing import Any |
| 16 | + |
| 17 | +import torch |
| 18 | +import transformers |
| 19 | +from transformers import AutoModelForImageSegmentation |
| 20 | +from transformers.modeling_utils import PreTrainedModel |
| 21 | + |
| 22 | + |
@contextmanager
def _suppress_loading_noise() -> Iterator[None]:
    """Temporarily silence noisy output while a model is being loaded.

    For the duration of the ``with`` body this:

    * sets the transformers logging verbosity to "error",
    * sets the ``httpx`` and ``huggingface_hub`` loggers to ``WARNING``,
    * sets the ``TQDM_DISABLE`` environment variable to ``"1"``,
    * ignores ``FutureWarning`` raised from ``timm`` modules.

    Every setting is put back to its previous value on exit, including when
    the body raises.
    """
    # Snapshot the current state so it can be restored afterwards.
    saved_verbosity = transformers.logging.get_verbosity()
    saved_logger_levels = [
        (noisy_logger, noisy_logger.level)
        for noisy_logger in (
            logging.getLogger("httpx"),
            logging.getLogger("huggingface_hub"),
        )
    ]
    saved_tqdm_disable = os.environ.get("TQDM_DISABLE")

    # Quieten transformers logs, HTTP request logs, hub logs and progress bars.
    transformers.logging.set_verbosity_error()  # type: ignore[no-untyped-call]
    for noisy_logger, _ in saved_logger_levels:
        noisy_logger.setLevel(logging.WARNING)
    os.environ["TQDM_DISABLE"] = "1"

    # The warnings filter is scoped by catch_warnings(); everything else is
    # restored by hand in the finally clause.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=FutureWarning, module="timm")
        try:
            yield
        finally:
            transformers.logging.set_verbosity(saved_verbosity)  # type: ignore[no-untyped-call]
            for noisy_logger, previous_level in saved_logger_levels:
                noisy_logger.setLevel(previous_level)
            if saved_tqdm_disable is not None:
                os.environ["TQDM_DISABLE"] = saved_tqdm_disable
            else:
                # The variable was unset before; remove the override again.
                os.environ.pop("TQDM_DISABLE", None)
| 53 | + |
| 54 | + |
def load_pretrained_no_meta(model_name: str) -> Any:
    """
    Load a pretrained image-segmentation model without the meta-device init.

    While ``from_pretrained`` runs, two ``PreTrainedModel`` hooks are
    temporarily replaced on the class:

    1. ``get_init_context``: any ``torch.device("meta")`` entry is removed
       from the list of contexts the original hook returns. Custom model
       code (e.g. BiRefNet) calls ``.item()`` inside ``__init__``, which
       meta tensors do not support.
    2. ``_finalize_model_loading``: ``all_tied_weights_keys`` is set to an
       empty dict when the model does not have it, which is the case for
       custom models that never call ``post_init()``.

    Both hooks are restored in a ``finally`` clause, so they are put back
    even when loading fails. Loading runs inside ``_suppress_loading_noise``
    to hide HTTP request logs, progress bars and FutureWarnings.

    NOTE(review): the hooks are patched on the shared ``PreTrainedModel``
    class, so other threads loading models at the same time would also see
    the patched hooks - confirm callers load models serially.

    Args:
        model_name: Hugging Face model name.

    Returns:
        The loaded model.
    """
    # Read the raw descriptors from the class __dict__ (not via getattr) so
    # exactly the same objects can be assigned back afterwards.
    original_init_context = PreTrainedModel.__dict__["get_init_context"]
    original_finalize = PreTrainedModel.__dict__["_finalize_model_loading"]

    # Patch 1: drop the meta-device context.
    # Arguments are forwarded verbatim instead of naming the private hook's
    # parameters, so an upstream signature change does not break the wrapper
    # (the same pass-through approach as _safe_finalize below).
    @classmethod  # type: ignore[misc]
    def _safe_context(cls: type, *args: Any, **kwargs: Any) -> list[Any]:
        bound_original = original_init_context.__get__(None, cls)
        contexts: list[Any] = bound_original(*args, **kwargs)
        return [
            ctx
            for ctx in contexts
            if not (isinstance(ctx, torch.device) and ctx.type == "meta")
        ]

    # Patch 2: make sure all_tied_weights_keys exists before finalizing.
    @classmethod  # type: ignore[misc]
    def _safe_finalize(cls: type, model: Any, *args: Any, **kwargs: Any) -> Any:
        if not hasattr(model, "all_tied_weights_keys"):
            model.all_tied_weights_keys = {}
        return original_finalize.__get__(None, cls)(model, *args, **kwargs)

    PreTrainedModel.get_init_context = _safe_context  # type: ignore[assignment]
    PreTrainedModel._finalize_model_loading = _safe_finalize  # type: ignore[assignment]
    try:
        with _suppress_loading_noise():
            # trust_remote_code=True executes Python code shipped in the model
            # repository; only pass model names from trusted sources.
            return AutoModelForImageSegmentation.from_pretrained(
                model_name, trust_remote_code=True
            )
    finally:
        PreTrainedModel.get_init_context = original_init_context  # type: ignore[method-assign]
        PreTrainedModel._finalize_model_loading = original_finalize  # type: ignore[method-assign]
0 commit comments