diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..6a1e0d4 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,103 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + workflow_dispatch: + +# Cancel in-progress runs of the same workflow on the same branch / PR. +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: read + +jobs: + lint: + name: Ruff lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + cache: pip + - name: Install ruff + run: | + python -m pip install --upgrade pip + pip install "ruff>=0.5.0" + - name: Run ruff + run: ruff check . + + tests: + name: Tests (Python ${{ matrix.python-version }}) + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.10", "3.11", "3.12"] + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: pip + - name: Install minimal runtime deps + # We deliberately install CPU-only ``torch`` to keep CI fast and avoid + # pulling CUDA / cuDNN wheels. ``bitsandbytes`` is also skipped (it is + # GPU-only). Tests stub out the heavy I/O and never touch a real GPU. + run: | + python -m pip install --upgrade pip + pip install --index-url https://download.pytorch.org/whl/cpu "torch>=2.0.0" + pip install \ + "transformers>=4.36.0" \ + "datasets>=2.14.0" \ + "accelerate>=0.24.0" \ + "peft>=0.6.0" \ + "scipy>=1.10.0" \ + "scikit-learn>=1.3.0" \ + "tqdm>=4.65.0" \ + "rich>=13.0.0" \ + "huggingface_hub>=0.20.0" \ + "psutil" \ + "gguf" \ + "py-cpuinfo" \ + "pytest>=7.4.0" + - name: Install QuantLLM (no deps; we already installed them above) + # ``--no-deps`` skips re-resolving the heavy dependency set (notably + # ``bitsandbytes``, which is GPU-only and not needed by the test + # suite). The import-only install is what makes ``import quantllm`` + # work in the test workers. + run: pip install --no-deps -e . + - name: Run pytest + env: + QUANTLLM_BANNER: "0" + run: pytest tests/ -ra + + build: + name: Build sdist + wheel + runs-on: ubuntu-latest + needs: [lint, tests] + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + cache: pip + - name: Install build tooling + run: | + python -m pip install --upgrade pip + pip install build twine + - name: Build distribution + run: python -m build + - name: Validate artifacts + run: twine check dist/* + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + name: dist-${{ github.sha }} + path: dist/ + retention-days: 14 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..aa4c47e --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,31 @@ +# Pre-commit configuration for QuantLLM contributors. +# +# Setup: +# pip install pre-commit +# pre-commit install +# +# The hooks below run a fast subset of CI locally before each commit. The full +# test suite still runs in GitHub Actions; pre-commit only blocks obviously +# broken commits (lint failures, leftover merge markers, accidentally +# committed large files, etc.). + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-toml + - id: check-merge-conflict + - id: check-added-large-files + args: ["--maxkb=1024"] + - id: debug-statements + - id: mixed-line-ending + args: ["--fix=lf"] + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.5.7 + hooks: + - id: ruff + args: ["--fix"] diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..83b7d43 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,95 @@ +# Changelog + +All notable changes to QuantLLM are recorded here. The format follows +[Keep a Changelog](https://keepachangelog.com/en/1.1.0/) and the project +adheres to [Semantic Versioning](https://semver.org/). + +## [Unreleased] — production hardening on top of v2.1.0rc1 + +### Fixed + +- **`is_quantized` no longer lies about the loaded model state.** The + attribute is now a derived property reading + `model.config.quantization_config` (and BitsAndBytes layer types) at + call time. This fixes three concrete bugs in v2.1.0rc1: + * `from_config_only=True` previously left `_is_quantized=True` even + though `AutoModelForCausalLM.from_config(...)` returns a random- + weights model with no quantization. The flag is now `False` and a + warning is emitted to make the random-weights nature explicit. + * A missing `bitsandbytes` install used to silently fall through to + full precision while keeping `_is_quantized=True`. We now log a + descriptive warning and report `False`. + * Pre-quantized HF repos that already shipped a `quantization_config` + (GPTQ, AWQ, etc.) are now correctly reported as quantized regardless + of the user's `quantize=False` flag. +- **`DEFAULT_ARCHITECTURE_FALLBACKS` is now actually consulted.** The + fallback table introduced by PR #27 was dead code whenever HF returned + a non-empty `model_type` (i.e. always). `resolve_model_type` now + checks the table directly and recognises common version-suffix + patterns (`qwen3` → `qwen2`, `llama4` → `llama`, `phi4` → `phi3`, + `gemma3` → `gemma2`, etc.). +- **`register_architecture` class lookup now uses the natural API.** + Calling `register_architecture("newmodel", base_model_type="llama", + model_class=NewModel)` previously stored the class under `"newmodel"` + but looked it up under `"llama"`, so the fallback path silently + ignored it. The lookup now tries the original `config.model_type` + first and falls back to the resolved base family. +- Removed an accidentally duplicated `if is_bnb and is_8bit ...` block + in the existing-quant detection branch of + `TurboModel.from_pretrained`. + +### Added + +- **`TurboModel.is_quantized` public property** plus + **`TurboModel.report()`** returning a structured dict (`model_id`, + `params_billion`, `requested_bits`, `effective_loading_bits`, + `is_quantized`, `quant_method`, `device`, `dtype`, `finetuned`, + `lora_applied`). Use `report()` to assert programmatically what the + loader actually produced. +- **Pre-quantized repo detection.** Repository names matching + `*-bnb-4bit`, `*-bnb-8bit`, `*-AWQ`, `*-GPTQ`, `*-INT4`, `*-INT8`, + `*-FP8`, `*-EETQ`, `*-HQQ`, `*-AQLM` log a friendly hint that the + embedded `quantization_config` will be honoured rather than + re-quantized. +- **GGUF-only repo hint.** When a name contains `-gguf` / `.gguf`, + `from_pretrained` warns and points the user at `from_gguf`. +- **Expanded `DEFAULT_ARCHITECTURE_FALLBACKS` table** covering Llama 2/3/4, + Mistral / Mixtral, Qwen 2 / 2-MoE / 3, Phi / Phi-3 / Phi-4, Gemma / + Gemma 2 / Gemma 3, Falcon, Cohere / Command-R, DeepSeek (V2/V3), + OLMo / OLMo 2, SmolLM / SmolLM 2 / SmolLM 3, Yi, StarCoder / + StarCoder 2, InternLM / InternLM 2, Baichuan, ChatGLM and StableLM. +- **Real CI workflow** at `.github/workflows/ci.yml` running ruff, + pytest on Python 3.10 / 3.11 / 3.12, and `python -m build` + + `twine check` on every PR. +- **`pyproject.toml`** providing PEP 517 / 518 build metadata, a + conservative ruff lint profile and pytest defaults. +- **`.pre-commit-config.yaml`** for local enforcement (whitespace, + end-of-file fixer, large-file guard, ruff with autofix). +- **`docs/guide/consumer-hardware.md`** documenting expected behaviour + on every tier of consumer hardware (CPU-only, ≤ 8 GB VRAM, + 12 – 24 GB, Apple Silicon, multi-GPU) and how to inspect the loaded + state. +- **Regression tests** for every fix above: + * `tests/test_quantization_state.py` — runtime quantization state + tracking, `from_config_only` honesty, `report()` schema. + * `tests/test_resolve_model_type.py` — fallback table consultation, + family-suffix matching, registry-class lookup ergonomics. + +### Changed + +- `TurboModel.__repr__` now reads from the new `is_quantized` property + and degrades gracefully when `num_parameters()` is unavailable + (mocked / lazily-loaded models). +- `TurboModel.from_gguf` now sets `_is_quantized_override = True` + rather than mutating an attribute the type system thought was a + property -- this is functionally identical but more honest about the + contract. +- The "bitsandbytes not installed" warning now explains how to install + it and explicitly states that loading falls back to full precision. + +## [2.0.0] — 2025-12-21 + +Initial public release of the `turbo()` API and the GGUF / ONNX / MLX +export pipeline. See the GitHub +[releases page](https://github.com/codewithdark-git/QuantLLM/releases/tag/v2.0.0) +for the full notes. diff --git a/docs/guide/consumer-hardware.md b/docs/guide/consumer-hardware.md new file mode 100644 index 0000000..55999a1 --- /dev/null +++ b/docs/guide/consumer-hardware.md @@ -0,0 +1,184 @@ +# 🖥️ Consumer Hardware Guide + +QuantLLM is designed to run modern LLMs on the kind of hardware most +developers actually have — gaming GPUs, laptops, Apple Silicon Macs and +even pure-CPU machines. This page is a flat index of what works, where +to expect compromises, and which knobs to turn for each tier. + +`turbo()` auto-detects your hardware via :class:`SmartConfig` and picks +sensible defaults, but you can always override any value explicitly. + +--- + +## Quick decision table + +| Hardware | Recommended bits | Recommended path | Notes | +|------------------------------|------------------|--------------------------------------------|-----------------------------------------------------------------| +| CPU only / no GPU | 4 (GGUF) | `TurboModel.from_gguf(...)` (llama.cpp) | BitsAndBytes is GPU-only; GGUF runs on CPU. | +| Apple Silicon (M-series) | 4 (MLX or GGUF) | Export to MLX or GGUF | Native Metal kernels; see *Apple Silicon* below. | +| ≤ 4 GB VRAM (GTX 1650, etc.) | 4 (BnB) + offload| `turbo(..., bits=4)` with CPU offload | Auto-enabled by SmartConfig; expect slower generation. | +| 6 – 8 GB VRAM (3060/3070) | 4 (BnB) | `turbo(..., bits=4)` | The default sweet spot for 7B-class models. | +| 12 GB VRAM (3060 12 / 3080) | 4 or 8 | `turbo(...)` (auto) | 13B at 4-bit fits comfortably; 7B can run at 8-bit. | +| 16 – 24 GB VRAM (4080/4090) | 4 or 8 | `turbo(...)` (auto) | 70B-class needs 4-bit + cpu_offload; 13B at 8-bit fits. | +| Multi-GPU / server | 4 / 8 / 16 | `turbo(..., device_map="auto")` | `accelerate` shards weights across visible GPUs. | + +--- + +## CPU-only inference + +If `torch.cuda.is_available()` is `False`, BitsAndBytes is unavailable +(it depends on CUDA). Use one of: + +```python +# Option A: load full precision on CPU (slow but works) +model = turbo("microsoft/phi-2", quantize=False, device="cpu") + +# Option B (recommended): load a pre-quantized GGUF via llama.cpp +from quantllm import TurboModel +model = TurboModel.from_gguf( + "TheBloke/phi-2-GGUF", + filename="phi-2.Q4_K_M.gguf", +) +``` + +`from_gguf` uses `llama-cpp-python` under the hood, which produces real +4-bit / 5-bit / 8-bit inference on CPU at usable speeds. + +--- + +## ≤ 8 GB VRAM (gaming laptops, GTX 1660 / RTX 2060 / 3060 6 GB) + +```python +from quantllm import turbo + +# 4-bit NF4 with double quant, automatic CPU offload, fp16 compute +model = turbo("meta-llama/Llama-3.2-3B-Instruct", bits=4) + +# Fine-tune with LoRA at 4-bit (QLoRA-style) +model.finetune("my_data.json", epochs=3) +``` + +Inspect the actual loaded state via `model.report()` if generation feels +slow — the `device` and `quant_method` keys tell you whether the model +ended up on GPU or partly offloaded. + +### What SmartConfig does for small VRAM + +* Picks `bits=4` and `BitsAndBytesConfig(load_in_4bit=True, + bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True)`. +* Sets `cpu_offload=True` automatically when the model exceeds ~90% of + free VRAM (with a 1.5× headroom factor for inference, 3× for training). +* Falls back to `torch.float16` compute when bf16 is not supported. +* Keeps `compile_model=False` (Torch compile only kicks in at ≥ 16 GB + total VRAM, where its compile-time cost amortises). + +--- + +## 12 – 24 GB VRAM (consumer flagship: 3060 12 / 3080 / 4080 / 4090) + +```python +# Auto: SmartConfig will pick 4-bit for 13B, 8-bit for 7B +model = turbo("mistralai/Mistral-7B-Instruct-v0.3") + +# Or be explicit +model = turbo("meta-llama/Llama-3.1-8B-Instruct", bits=8) +``` + +At 24 GB you can fit: +* 70B-class models at 4-bit *with* CPU offload for the largest layers. +* 13B-class models at 8-bit comfortably. +* 7B-class models at 16-bit if you want zero-loss inference for + benchmarking. + +--- + +## Apple Silicon (M1 / M2 / M3 / M4) + +There is no CUDA on Apple Silicon and BitsAndBytes does not run there. +The two supported paths are: + +```python +# Export an MLX-format model that uses Apple's native Metal kernels +model = turbo("microsoft/phi-2", quantize=False) # load on CPU first +model.export(format="mlx", path="./phi-2-mlx") + +# Or download a GGUF and run it through llama.cpp's Metal backend +from quantllm import TurboModel +model = TurboModel.from_gguf( + "TheBloke/phi-2-GGUF", + filename="phi-2.Q4_K_M.gguf", +) +``` + +`pip install quantllm[mlx]` pulls `mlx` and `mlx-lm`. `from_gguf` uses +`llama-cpp-python`, which auto-detects Metal at install time when built +with `CMAKE_ARGS="-DLLAMA_METAL=on"`. + +--- + +## Multi-GPU & server-class hardware + +```python +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" + +model = turbo( + "meta-llama/Llama-3.1-70B-Instruct", + bits=4, + # Pass through to ``transformers``: shard weights across all GPUs + device_map="auto", +) +``` + +For training jobs, prefer `accelerate launch` with a YAML config; the +SmartConfig auto-tuner only inspects GPU 0 and won't pick up multi-GPU +batch-size headroom on its own. + +--- + +## When QuantLLM cannot quantize + +A few situations produce a full-precision model even when you asked for +4-bit. They are now reflected honestly in `model.is_quantized`: + +| Situation | `is_quantized` | What QuantLLM does | +|------------------------------------------|----------------|--------------------------------------------------------------------| +| `bitsandbytes` not installed (no CUDA) | `False` | Logs a warning and loads in fp16 / bf16. | +| Model already shipped with `quantization_config` (GPTQ / AWQ / etc.) | `True` | Honours the embedded config; skips dynamic BnB. | +| `from_config_only=True` | `False` | Returns random-init weights; warns that the model is not usable. | +| `bits=16` (explicit full precision) | `False` | Loads the original checkpoint without any quantization layer. | + +Always use `model.report()` if you need to *programmatically* assert +which quantization path actually ran — it is the canonical source of +truth. + +--- + +## Troubleshooting + +### `ImportError: bitsandbytes is not installed` + +```bash +pip install bitsandbytes +# or, on Windows: +pip install bitsandbytes --extra-index-url https://jllllll.github.io/bitsandbytes-windows-webui +``` + +If you cannot install it (no CUDA, ROCm, MPS), use the GGUF path above. + +### CUDA OOM on a 6 – 8 GB GPU + +1. Confirm you got 4-bit via `model.report()`. If not, force `bits=4`. +2. Close other GPU processes (`nvidia-smi`); VRAM is sticky. +3. Lower `max_length` (e.g. `turbo(..., max_length=2048)`). +4. Set `device_map="auto"` to let `accelerate` offload entire blocks. + +### Slow generation on consumer GPUs + +1. Make sure `flash-attn` is installed (`pip install quantllm[flash]`). +2. Check that the model lives on GPU: `model.report()["device"]` should + start with `cuda`. If you see `cpu`, free up VRAM or drop to a + smaller model. +3. For longer-running deployments, export to GGUF and serve through + `llama.cpp` — it produces ~2–3× the tokens/sec of HF on the same + hardware below 7B. diff --git a/docs/guide/loading-models.md b/docs/guide/loading-models.md index 12de986..9e8f90e 100644 --- a/docs/guide/loading-models.md +++ b/docs/guide/loading-models.md @@ -76,14 +76,36 @@ model = turbo( ### New Architecture Fallbacks (for very recent model releases) -If `transformers` does not recognize a just-released architecture yet, register a fallback family: +QuantLLM ships a built-in fallback table covering common model-type +suffixes — `qwen3` → `qwen2`, `llama4` → `llama`, `phi4` → `phi3`, +`gemma3` → `gemma2`, and many others — so brand-new releases load with +the same one-line API as established models: + +```python +from quantllm import turbo + +# Works without manual registration: qwen3 falls back to qwen2 automatically +model = turbo("Qwen/Qwen3-8B", trust_remote_code=True) +``` + +When the built-in mapping does not cover your model, register an +explicit fallback before loading: ```python from quantllm import turbo, register_architecture -# Map new architecture/model_type to a compatible base family +# Map a brand-new architecture/model_type to a compatible base family register_architecture("newmodel", base_model_type="llama") +# Optionally provide an explicit ``model_class`` (most useful for +# fine-tuned variants that ship their own modelling code): +from transformers import LlamaForCausalLM +register_architecture( + "newmodel", + base_model_type="llama", + model_class=LlamaForCausalLM, +) + model = turbo( "new-model-org/NewModel-7B", model_type_override="llama", # optional explicit override @@ -95,14 +117,36 @@ model = turbo( > ⚠️ **Security note:** `trust_remote_code=True` executes model-provided code. > Only enable it for trusted publishers, especially when loading unregistered or very new architectures. -You can also load from config only (no checkpoint weights) while waiting for upstream support: +#### Pre-quantized HuggingFace repos + +QuantLLM detects pre-quantized repository names (Unsloth `*-bnb-4bit` / +`*-bnb-8bit`, AWQ, GPTQ, AQLM, HQQ, FP8, EETQ, etc.) and lets the model's +own `quantization_config` win — so you don't accidentally re-quantize a +model that ships at-rest in 4-bit: + +```python +# Loaded as 4-bit BitsAndBytes from the repo's embedded config -- no +# additional dynamic quantization is applied on top. +model = turbo("unsloth/Llama-3.2-3B-Instruct-bnb-4bit") + +# Verify what actually got loaded: +print(model.report()) +# {'quant_method': 'bitsandbytes', 'is_quantized': True, ...} +``` + +#### `from_config_only` is for skeleton inspection only ```python +# Loads a randomly-initialised model from the config -- useful for +# inspecting layer shapes or wiring up tests, NOT for inference. model = turbo( "new-model-org/NewModel-7B", from_config_only=True, trust_remote_code=True, ) + +# ``model.is_quantized`` will correctly report False here even when you +# also passed ``bits=4`` -- there are no real weights to quantize. ``` #### Fast contribution template for new architectures @@ -111,22 +155,34 @@ model = turbo( - `register_architecture("new-arch", base_model_type="llama")` 2. Validate loading with: - `turbo("org/model", base_model_fallback=True, trust_remote_code=True)` -3. Add/extend a focused test in `tests/test_architecture_fallback.py`. +3. Add/extend a focused test in `tests/test_architecture_fallback.py` + or `tests/test_resolve_model_type.py`. -#### Real-world style "released yesterday" example +#### Inspecting the loaded state ```python -from quantllm import turbo, register_architecture - -# Example: transformers doesn't recognize Qwen3 yet -register_architecture("qwen3", base_model_type="qwen2") - -model = turbo( - "Qwen/Qwen3-8B", - trust_remote_code=True, -) +model = turbo("Qwen/Qwen3-8B", bits=4) + +report = model.report() +# { +# 'model_id': 'Qwen/Qwen3-8B', +# 'params_billion': 8.0, +# 'requested_bits': 4, +# 'effective_loading_bits': 4, +# 'is_quantized': True, +# 'quant_method': 'bitsandbytes', +# 'device': 'cuda:0', +# 'dtype': 'torch.bfloat16', +# 'finetuned': False, +# 'lora_applied': False, +# } ``` +`model.is_quantized` is derived from the actual loaded model state +(`config.quantization_config` and BitsAndBytes layer types). It is not +a cached snapshot of your load-time intent, so `from_config_only=True` +or a missing `bitsandbytes` install will correctly report `False`. + ### Memory Options ```python diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..58b6c96 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,67 @@ +[build-system] +requires = ["setuptools>=68", "wheel"] +build-backend = "setuptools.build_meta" + +# NOTE: Project metadata, dependencies and extras are still declared in +# ``setup.py`` for now to keep a single source of truth and avoid duplicating +# the install-requires list. ``setup.py`` defines: +# +# * ``install_requires`` (mandatory deps) +# * ``extras_require`` (``gguf``, ``onnx``, ``mlx``, ``triton``, ``flash``, +# ``hub``, ``full``, ``dev``) +# * ``entry_points`` (the ``quantllm`` CLI) +# +# Adding a ``[project]`` table here would silently override those values in +# editable installs, which has tripped up several contributors. Once the +# ``setup.py`` migration to PEP 621 is complete, this file will own the full +# project metadata. + +[tool.ruff] +line-length = 120 +target-version = "py310" +extend-exclude = [ + "build", + "dist", + ".venv", + "examples/output", + "docs/_build", +] + +[tool.ruff.lint] +# Minimal "blocker" ruleset enforced on every push. The codebase predates +# ruff, so we deliberately start small -- this catches genuine bugs (syntax +# errors, undefined names, broken comparisons) without producing a giant +# reformatting diff. Stricter rules can be opted-in incrementally. +select = [ + "E9", # pycodestyle runtime errors (syntax, indentation) + "F63", # invalid `is` comparisons / `not in` issues + "F7", # syntax errors in expressions + "F82", # undefined names actually used + "F811", # redefinition of unused name + "F821", # undefined-name reference + "F823", # local variable referenced before assignment + # NOTE: F841 (unused local) is intentionally NOT enabled yet. The existing + # codebase has several intentional unused-binding patterns (progress + # tasks, documenting intent, etc.) and fixing them is out of scope. + "B006", # mutable default arguments (genuine bug class) + "B017", # ``assertRaises(Exception)`` -- masks real failures +] +ignore = [ + "E501", # line length: not enforced strictly + "E402", # module-level import not at top of file + "E741", # ambiguous variable names +] + +[tool.ruff.lint.per-file-ignores] +"tests/*" = ["F401", "F811", "F841"] +"quantllm/__init__.py" = ["F401", "E402"] +"quantllm/core/__init__.py" = ["F401"] + +[tool.pytest.ini_options] +minversion = "7.0" +testpaths = ["tests"] +addopts = "-ra --strict-markers" +filterwarnings = [ + "ignore::DeprecationWarning", + "ignore::FutureWarning", +] diff --git a/quantllm/core/turbo_model.py b/quantllm/core/turbo_model.py index ffabb37..7a3c9ae 100644 --- a/quantllm/core/turbo_model.py +++ b/quantllm/core/turbo_model.py @@ -35,15 +35,86 @@ "quantization": "Q4_K_M", "push_quantization": None, } -DEFAULT_ARCHITECTURE_FALLBACKS = { +# Default mapping of HuggingFace ``config.model_type`` values (or model-name +# tokens) to a known-loadable base family. Used as a best-effort fallback for +# brand-new architectures that ``transformers`` does not yet recognize. The +# mapping is consulted only when the user has not registered an explicit +# fallback via :func:`register_architecture`. +# +# Order matters: more specific patterns must come before more generic ones +# (e.g. ``qwen2_moe`` before ``qwen``). +DEFAULT_ARCHITECTURE_FALLBACKS: Dict[str, str] = { + # Llama family and direct derivatives "llama": "llama", + "llama2": "llama", + "llama3": "llama", + "llama4": "llama", + "code_llama": "llama", + "codellama": "llama", + "tinyllama": "llama", + "smollm": "llama", + "smollm2": "llama", + "smollm3": "llama", + "yi": "llama", + "deepseek": "llama", + "deepseek_v2": "llama", + "deepseek_v3": "llama", + "command_r": "llama", + "cohere": "llama", + "olmo": "llama", + "olmo2": "llama", + "stablelm": "llama", + "starcoder": "llama", + "starcoder2": "llama", + "internlm": "llama", + "internlm2": "llama", + "baichuan": "llama", + "chatglm": "llama", + # Mistral / Mixtral "mistral": "mistral", "mixtral": "mistral", + # Qwen family (note: qwen2_moe must come before qwen) + "qwen2_moe": "qwen2", + "qwen2": "qwen2", + "qwen3": "qwen2", "qwen": "qwen2", + # Phi family + "phi3": "phi3", + "phi4": "phi3", "phi": "phi", + "phi2": "phi", + # Gemma family + "gemma3": "gemma2", + "gemma2": "gemma2", "gemma": "gemma", + # Falcon + "falcon": "falcon", } +# Substring markers in HF repo names that indicate the model is already +# pre-quantized at rest. When detected, QuantLLM should let ``transformers`` +# load the existing quantized weights instead of dynamically applying its own +# BitsAndBytes quantization on top. +PREQUANTIZED_NAME_MARKERS: tuple = ( + "-bnb-4bit", + "-bnb-8bit", + "-4bit", + "-8bit", + "-awq", + "-gptq", + "-int4", + "-int8", + "-fp8", + "-eetq", + "-hqq", + "-aqlm", +) + +# Markers in HF repo names indicating GGUF-only repositories. Loading these +# via :meth:`TurboModel.from_pretrained` (instead of ``from_gguf``) is almost +# always a user mistake; we surface a helpful hint. +GGUF_NAME_MARKERS: tuple = ("-gguf", ".gguf") + class TurboModel: """ @@ -91,7 +162,11 @@ def __init__( self.model = model self.tokenizer = tokenizer self.config = config - self._is_quantized = False + # ``_is_quantized_override`` is consulted by :pyattr:`is_quantized` + # *only* when the caller explicitly asserts a quantization state + # (e.g. :meth:`from_gguf` knows GGUF is always quantized). When + # ``None`` the property derives the answer from the loaded model. + self._is_quantized_override: Optional[bool] = None self._is_finetuned = False self._lora_applied = False self.export_push_config = self._build_export_push_config(export_push_config) @@ -132,32 +207,59 @@ def resolve_model_type( model_type_override: Optional[str] = None, ) -> Optional[str]: """ - Resolve model type using override, registry, and default family patterns. - - If config_model_type is provided but unregistered, the original config value - is returned unchanged. + Resolve a HuggingFace ``model_type`` to a known-loadable base family. + + Resolution order (first non-``None`` match wins): + + 1. Explicit ``model_type_override`` from the caller. + 2. Exact match in :attr:`_architecture_registry` (user-registered alias). + 3. Exact match in :data:`DEFAULT_ARCHITECTURE_FALLBACKS`. + 4. Family-style match against the config's ``model_type`` (e.g. + ``qwen3`` -> ``qwen``). + 5. Family-style match against the repository name (e.g. + ``Qwen/Qwen3-8B`` -> ``qwen``). + 6. The original ``config_model_type`` unchanged, or ``None`` when no + config was loadable. + + The function never raises; callers are expected to handle ``None``. """ if model_type_override: return model_type_override.lower().strip() - + model_type = (config_model_type or "").lower().strip() - if model_type: - return cls._architecture_registry.get(model_type, model_type) - name = model_name.lower() + + # 2. Exact registry hit. + if model_type and model_type in cls._architecture_registry: + return cls._architecture_registry[model_type] + + # 3. Exact default-fallback hit. + if model_type and model_type in DEFAULT_ARCHITECTURE_FALLBACKS: + return DEFAULT_ARCHITECTURE_FALLBACKS[model_type] + + # 4. Family-style match against model_type itself (qwen3 -> qwen). + if model_type: + for pattern, fallback in cls._architecture_registry.items(): + if cls._matches_family(model_type, pattern): + return fallback + for pattern, fallback in DEFAULT_ARCHITECTURE_FALLBACKS.items(): + if cls._matches_family(model_type, pattern): + return fallback + + # 5. Token-boundary match against the repo name. for pattern, fallback in cls._architecture_registry.items(): if cls._matches_model_name_pattern(name, pattern): return fallback - for pattern, fallback in DEFAULT_ARCHITECTURE_FALLBACKS.items(): if cls._matches_model_name_pattern(name, pattern): return fallback - - return None + + # 6. Nothing matched. + return model_type or None @classmethod def _matches_model_name_pattern(cls, model_name: str, pattern: str) -> bool: - """Return True when pattern appears as a token in model_name.""" + """Return True when ``pattern`` appears as a token in ``model_name``.""" return cls._compiled_model_name_pattern(pattern).search(model_name) is not None @staticmethod @@ -168,15 +270,60 @@ def _compiled_model_name_pattern(pattern: str): # Match architecture tokens as standalone chunks split by separators. return re.compile(rf"(^|[^a-z0-9]){escaped}([^a-z0-9]|$)") + @classmethod + def _matches_family(cls, model_type: str, family: str) -> bool: + """ + Decide whether ``model_type`` belongs to ``family``. + + Recognises common version-suffix patterns used by HuggingFace, e.g. + ``qwen2``, ``qwen2_5``, ``qwen-2``, ``qwen3`` all match family ``qwen``. + Plain prefix matches (``llamafication``) are intentionally rejected; + only digit / underscore / dash separators count as family suffixes. + """ + if not model_type or not family: + return False + if model_type == family: + return True + return bool(cls._compiled_family_pattern(family).match(model_type)) + + @staticmethod + @lru_cache(maxsize=None) + def _compiled_family_pattern(family: str): + """Cache regex used by :meth:`_matches_family`.""" + return re.compile(rf"^{re.escape(family)}[\d_\-]") + @staticmethod def _should_apply_quantization( quantize: bool, bits: int, from_config_only: bool, ) -> bool: - """Check whether quantization arguments should be added for loading.""" + """Decide whether ``BitsAndBytes`` kwargs should be added at load time. + + Returns False whenever the model is being constructed from config only + (no weights to quantize) or whenever the user explicitly disabled + quantization or asked for full precision. + """ return quantize and bits < 16 and not from_config_only + @staticmethod + def _looks_prequantized(model_name: str) -> Optional[str]: + """ + Return a marker (e.g. ``"-bnb-4bit"``) when the repo name suggests it + is already pre-quantized at rest. ``None`` otherwise. + """ + lowered = model_name.lower() + for marker in PREQUANTIZED_NAME_MARKERS: + if marker in lowered: + return marker + return None + + @staticmethod + def _looks_like_gguf_repo(model_name: str) -> bool: + """Heuristic: repo name looks like a GGUF-only weights repository.""" + lowered = model_name.lower() + return any(marker in lowered for marker in GGUF_NAME_MARKERS) + @classmethod def _load_model_with_fallback( cls, @@ -245,16 +392,24 @@ def _load_model_with_fallback( except Exception as fallback_config_error: fallback_error = fallback_config_error - if resolved_model_type: + # Look up an explicit user-registered model class. Try the + # original ``config.model_type`` first (most natural API: + # ``register_architecture("newmodel", model_class=NewModel)``) + # and fall back to the resolved base family for users that prefer + # to register a class under the family name. + registered_cls: Optional[Type[PreTrainedModel]] = None + if config_model_type: + registered_cls = cls._model_class_registry.get(config_model_type) + if registered_cls is None and resolved_model_type: registered_cls = cls._model_class_registry.get(resolved_model_type) - if registered_cls is not None: - class_kwargs = dict(model_kwargs) - if resolved_config is not None: - class_kwargs["config"] = resolved_config - try: - return registered_cls.from_pretrained(model_name, **class_kwargs) - except Exception as fallback_registered_error: - fallback_error = fallback_registered_error + if registered_cls is not None: + class_kwargs = dict(model_kwargs) + if resolved_config is not None: + class_kwargs["config"] = resolved_config + try: + return registered_cls.from_pretrained(model_name, **class_kwargs) + except Exception as fallback_registered_error: + fallback_error = fallback_registered_error error_details = f" Last fallback error: {fallback_error}" if fallback_error else "" architecture_label = config_model_type or "" @@ -331,10 +486,36 @@ def from_pretrained( # Disable default progress bars disable_hf_progress_bar() disable_ds_progress_bar() - + if verbose: print_header(f"Loading {model_name}") - + + # Friendly hint when a user accidentally points ``from_pretrained`` at + # a GGUF repository. ``transformers`` *can* load some GGUF repos via + # ``from_pretrained`` with a ``gguf_file`` arg, but the dedicated + # :meth:`from_gguf` path handles tokenizer fall-back and version + # validation more safely. + if cls._looks_like_gguf_repo(model_name): + logger.warning( + "Repository name '%s' looks like a GGUF-only repo. " + "Use TurboModel.from_gguf(...) for GGUF weights; " + "from_pretrained() is intended for standard transformers " + "checkpoints (safetensors / pytorch_model.bin).", + model_name, + ) + + # Friendly hint when the repo name advertises pre-quantization. We + # still attempt to load it: ``transformers`` honours the embedded + # ``quantization_config`` automatically. + prequant_marker = cls._looks_prequantized(model_name) + if prequant_marker and verbose: + logger.info( + "Detected pre-quantized repo (marker '%s'); honouring the " + "model's own quantization config and skipping dynamic " + "BitsAndBytes quantization.", + prequant_marker, + ) + # Auto-configure everything if verbose: logger.info("🚀 Detecting hardware and configuration...") @@ -405,8 +586,6 @@ def from_pretrained( is_bnb = "BitsAndBytesConfig" in existing_quant.__class__.__name__ is_8bit = getattr(existing_quant, "load_in_8bit", False) - if is_bnb and is_8bit and smart_config.bits == 4: - allow_requantize = True if is_bnb and is_8bit and smart_config.bits == 4: allow_requantize = True if verbose: @@ -475,8 +654,48 @@ def from_pretrained( logger.info("") instance = cls(model, tokenizer, smart_config, export_push_config=config) - instance._is_quantized = quantize and smart_config.bits < 16 - + instance.verbose = verbose + + # Reflect the *actual* runtime state of the loaded model rather than + # the user's load-time intent. ``from_config_only=True`` returns a + # randomly-initialised model with no quantization regardless of the + # ``quantize`` flag, and a missing ``bitsandbytes`` install also + # silently falls back to full precision -- both of which used to leave + # ``_is_quantized=True`` set incorrectly. + actual_quantized = instance._has_runtime_quantization() + if from_config_only: + # ``AutoModelForCausalLM.from_config`` returns a model with random + # weights and never honours ``quantization_config`` -- so the + # actual quantization state is whatever the loader produced + # (almost always ``False``). + instance._is_quantized_override = bool(actual_quantized) + if verbose: + print_warning( + "from_config_only=True returned a model with random " + "weights and no quantization. Call model.load_weights(...) " + "or reload with from_config_only=False before using it " + "for inference." + ) + else: + # Let the property derive truth from the model state. Override is + # only set when the caller explicitly asked for quantization but + # the runtime layer silently skipped it (e.g. bitsandbytes + # missing) -- in that case we set False so downstream code does + # not try to call BnB-only training paths on a full-precision + # model. + wanted_quantization = cls._should_apply_quantization( + quantize, smart_config.bits, from_config_only=False + ) + if wanted_quantization and not actual_quantized: + instance._is_quantized_override = False + if verbose: + print_warning( + "Requested quantization was not applied at load time " + "(typically because ``bitsandbytes`` is not installed " + "or the model was already pre-quantized). Continuing " + "in full precision." + ) + return instance @classmethod @@ -606,7 +825,10 @@ def from_gguf( print_info(f"Parameters: {params:.2f}B") instance = cls(model, tokenizer, smart_config, verbose=verbose) - instance._is_quantized = True + # GGUF models are inherently quantized; set the override so the + # property does not need to introspect the (often opaque) loaded + # weights. + instance._is_quantized_override = True return instance @staticmethod @@ -698,7 +920,11 @@ def _get_quantization_kwargs(config: SmartConfig) -> Dict[str, Any]: return {"quantization_config": quantization_config} except ImportError: - logger.warning("⚠ bitsandbytes not installed, loading without quantization") + logger.warning( + "\u26a0 bitsandbytes is not installed; falling back to full " + "precision. Install with ``pip install bitsandbytes`` to " + "enable 4-bit / 8-bit quantization on CUDA." + ) return {} @staticmethod @@ -1580,21 +1806,24 @@ def _export_gguf( return output_path def _is_bnb_quantized(self) -> bool: - """Check if model is BitsAndBytes quantized.""" - # Check config for quantization_config + """Return True iff the loaded model is BitsAndBytes-quantized. + + Checks both the model's ``quantization_config`` metadata and the + actual layer types (``Linear4bit`` / ``Linear8bitLt``) so it works + whether the model came from a pre-quantized HF repo or from a + dynamic BitsAndBytes load. + """ if hasattr(self.model, 'config'): quant_config = getattr(self.model.config, 'quantization_config', None) if quant_config: - # Check if it's BitsAndBytes quant_method = getattr(quant_config, 'quant_method', None) - if quant_method in ['bitsandbytes', 'bnb']: + if quant_method in ('bitsandbytes', 'bnb'): return True if getattr(quant_config, 'load_in_4bit', False): return True if getattr(quant_config, 'load_in_8bit', False): return True - - # Check for BNB linear layers in the model + try: import bitsandbytes as bnb for module in self.model.modules(): @@ -1602,9 +1831,107 @@ def _is_bnb_quantized(self) -> bool: return True except ImportError: pass - + return False - + + def _has_runtime_quantization(self) -> bool: + """Return True iff the loaded model carries *any* quantization. + + Detects BitsAndBytes (4-bit / 8-bit), GPTQ, AWQ, AQLM, HQQ, FP8 + and EETQ via the standard ``quantization_config.quant_method`` slot + on a ``transformers`` ``PretrainedConfig``. This is the canonical + source-of-truth used by :pyattr:`is_quantized`. + """ + if self._is_bnb_quantized(): + return True + + if hasattr(self.model, 'config'): + quant_config = getattr(self.model.config, 'quantization_config', None) + if quant_config: + quant_method = getattr(quant_config, 'quant_method', None) + if quant_method: + return True + if isinstance(quant_config, dict) and quant_config.get('quant_method'): + return True + return False + + @property + def is_quantized(self) -> bool: + """Whether the underlying model is currently quantized. + + Derived from the loaded model state (``config.quantization_config`` + and the actual layer types). When :meth:`from_gguf` or another + loader explicitly knows the quantization status it can set + :pyattr:`_is_quantized_override` to short-circuit the introspection. + """ + if self._is_quantized_override is not None: + return self._is_quantized_override + return self._has_runtime_quantization() + + # Backwards-compatible alias kept for existing internal callers and any + # downstream code that read the previous attribute name. New code should + # prefer the :pyattr:`is_quantized` public property. + @property + def _is_quantized(self) -> bool: # type: ignore[override] + return self.is_quantized + + @_is_quantized.setter + def _is_quantized(self, value: Optional[bool]) -> None: + self._is_quantized_override = None if value is None else bool(value) + + def report(self) -> Dict[str, Any]: + """Return a structured snapshot of the actual loaded-model state. + + Keys: + * ``model_id``: HF repo id or local path (when known). + * ``params_billion``: parameter count in billions. + * ``requested_bits``: bits the user (or :class:`SmartConfig`) + asked for. + * ``effective_loading_bits``: bits actually used for BnB loading + (4 / 8 / 16). Differs from ``requested_bits`` when GGUF export + targets sub-4-bit quantization but loading falls back to 4-bit. + * ``is_quantized``: real runtime quantization state. + * ``quant_method``: e.g. ``"bitsandbytes"`` / ``"gptq"`` / ``None``. + * ``device``: torch device the model lives on. + * ``dtype``: torch dtype of the model parameters. + * ``finetuned`` / ``lora_applied``: training-state flags. + """ + params_billion: Optional[float] + try: + params_billion = self.model.num_parameters() / 1e9 + except Exception: + params_billion = None + + quant_method = None + if hasattr(self.model, 'config'): + quant_config = getattr(self.model.config, 'quantization_config', None) + if quant_config is not None: + quant_method = ( + getattr(quant_config, 'quant_method', None) + or (quant_config.get('quant_method') if isinstance(quant_config, dict) else None) + ) + if not quant_method and self._is_bnb_quantized(): + quant_method = 'bitsandbytes' + + device = getattr(self.model, 'device', None) + try: + dtype = next(self.model.parameters()).dtype + except (StopIteration, AttributeError): + dtype = getattr(self.config, 'dtype', None) + + return { + "model_id": getattr(getattr(self.model, 'config', None), '_name_or_path', None), + "params_billion": params_billion, + "requested_bits": getattr(self.config, 'bits', None), + "effective_loading_bits": getattr(self.config, 'effective_loading_bits', None), + "is_quantized": self.is_quantized, + "quant_method": quant_method, + "device": str(device) if device is not None else None, + "dtype": str(dtype) if dtype is not None else None, + "finetuned": self._is_finetuned, + "lora_applied": self._lora_applied, + } + def _dequantize_model(self) -> nn.Module: """ Dequantize a BitsAndBytes model to full precision for GGUF export. @@ -2032,15 +2359,20 @@ def _export_mlx( return output_path def __repr__(self) -> str: - params = self.model.num_parameters() / 1e9 + try: + params = self.model.num_parameters() / 1e9 + params_str = f"{params:.2f}B" + except Exception: + params_str = "?" + model_id = getattr(getattr(self.model, "config", None), "_name_or_path", "?") return ( - f"TurboModel(\n" - f" model={self.model.config._name_or_path},\n" - f" params={params:.2f}B,\n" + "TurboModel(\n" + f" model={model_id},\n" + f" params={params_str},\n" f" bits={self.config.bits},\n" - f" quantized={self._is_quantized},\n" + f" quantized={self.is_quantized},\n" f" finetuned={self._is_finetuned}\n" - f")" + ")" ) diff --git a/tests/test_quantization_state.py b/tests/test_quantization_state.py new file mode 100644 index 0000000..60da837 --- /dev/null +++ b/tests/test_quantization_state.py @@ -0,0 +1,194 @@ +""" +Regression tests for runtime quantization state tracking. + +These tests cover bugs reproduced in the v2.1.0rc1 review: + +* ``from_config_only=True`` used to set ``_is_quantized = True`` even though + the loader returns a model with random weights and no quantization. +* A missing ``bitsandbytes`` install used to fall through silently while the + ``_is_quantized`` flag remained ``True``. +* ``is_quantized`` should now reflect the loaded model's ``quantization_config`` + rather than the user's load-time intent. +""" + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import transformers + +from quantllm.core.turbo_model import TurboModel +import quantllm.core.turbo_model as turbo_model_module + + +def _smart_config(bits: int = 16): + return SimpleNamespace( + bits=bits, + effective_loading_bits=4 if bits <= 4 else (8 if bits < 16 else 16), + dtype="float16", + cpu_offload=False, + device="cpu", + gradient_checkpointing=False, + use_flash_attention=False, + compile_model=False, + ) + + +def _tokenizer(): + return SimpleNamespace(pad_token=None, eos_token="", eos_token_id=2) + + +def _patch_common(monkeypatch, *, model_type: str = "llama", quant_config=None, smart_bits: int = 16): + monkeypatch.setattr(TurboModel, "_architecture_registry", {}) + monkeypatch.setattr(TurboModel, "_model_class_registry", {}) + monkeypatch.setattr( + turbo_model_module.SmartConfig, + "detect", + lambda *a, **kw: _smart_config(bits=smart_bits), + ) + monkeypatch.setattr( + turbo_model_module.AutoTokenizer, + "from_pretrained", + lambda *a, **kw: _tokenizer(), + ) + monkeypatch.setattr( + transformers.AutoConfig, + "from_pretrained", + lambda *a, **kw: SimpleNamespace( + model_type=model_type, + quantization_config=quant_config, + ), + ) + + +def test_from_config_only_does_not_lie_about_quantization(monkeypatch): + """``from_config_only=True`` returns a random-weights model and must not + advertise itself as quantized just because the user asked for 4 bits.""" + _patch_common(monkeypatch, smart_bits=4) + + class _FakeAutoModel: + @classmethod + def from_pretrained(cls, *a, **kw): # should not be called + raise AssertionError("from_pretrained should not run when from_config_only=True") + + @classmethod + def from_config(cls, *a, **kw): + # ``from_config`` cannot quantize -- model has no + # ``quantization_config`` attribute on its config. + return SimpleNamespace(config=SimpleNamespace(model_type="llama")) + + monkeypatch.setattr(turbo_model_module, "AutoModelForCausalLM", _FakeAutoModel) + + loaded = TurboModel.from_pretrained( + "org/llama-like-7b", + quantize=True, + bits=4, + verbose=False, + from_config_only=True, + ) + + assert loaded.is_quantized is False + assert loaded._is_quantized is False # back-compat alias must agree + + +def test_runtime_quantization_property_reads_model_config(monkeypatch): + """``is_quantized`` should return True when the loaded model's + ``config.quantization_config.quant_method`` is set, regardless of the + user's load-time flags.""" + _patch_common( + monkeypatch, + quant_config=SimpleNamespace(quant_method="gptq"), + ) + + fake_model_config = SimpleNamespace( + model_type="llama", + quantization_config=SimpleNamespace(quant_method="gptq"), + _name_or_path="org/llama-gptq", + ) + fake_model = SimpleNamespace( + config=fake_model_config, + modules=lambda: iter([]), + num_parameters=lambda: 7_000_000_000, + device="cpu", + parameters=lambda: iter([]), + ) + + class _FakeAutoModel: + @classmethod + def from_pretrained(cls, *a, **kw): + return fake_model + + monkeypatch.setattr(turbo_model_module, "AutoModelForCausalLM", _FakeAutoModel) + + loaded = TurboModel.from_pretrained( + "org/llama-gptq", + quantize=False, + verbose=False, + ) + + # User explicitly disabled dynamic quantization, but the underlying model + # IS already GPTQ-quantized: the property must reflect that. + assert loaded.is_quantized is True + assert loaded.report()["quant_method"] == "gptq" + + +def test_is_quantized_override_accessor(): + """The ``_is_quantized`` setter should record an explicit override that + short-circuits the runtime introspection.""" + instance = TurboModel.__new__(TurboModel) + instance.model = SimpleNamespace(config=SimpleNamespace(quantization_config=None)) + instance.config = _smart_config() + instance._is_quantized_override = None + instance._is_finetuned = False + instance._lora_applied = False + + assert instance.is_quantized is False + instance._is_quantized = True + assert instance.is_quantized is True + instance._is_quantized = None # clears override -> derives again + assert instance.is_quantized is False + + +def test_report_returns_structured_state(monkeypatch): + """``report()`` should expose a stable, machine-readable summary.""" + _patch_common(monkeypatch) + + fake_model = SimpleNamespace( + config=SimpleNamespace( + model_type="llama", + quantization_config=None, + _name_or_path="org/llama-7b", + ), + modules=lambda: iter([]), + num_parameters=lambda: 7_000_000_000, + device="cpu", + parameters=lambda: iter([]), + ) + + class _FakeAutoModel: + @classmethod + def from_pretrained(cls, *a, **kw): + return fake_model + + monkeypatch.setattr(turbo_model_module, "AutoModelForCausalLM", _FakeAutoModel) + + loaded = TurboModel.from_pretrained("org/llama-7b", quantize=False, verbose=False) + report = loaded.report() + + expected_keys = { + "model_id", + "params_billion", + "requested_bits", + "effective_loading_bits", + "is_quantized", + "quant_method", + "device", + "dtype", + "finetuned", + "lora_applied", + } + assert set(report) == expected_keys + assert report["model_id"] == "org/llama-7b" + assert report["params_billion"] == 7.0 + assert report["is_quantized"] is False + assert report["finetuned"] is False + assert report["lora_applied"] is False diff --git a/tests/test_resolve_model_type.py b/tests/test_resolve_model_type.py new file mode 100644 index 0000000..f7032bd --- /dev/null +++ b/tests/test_resolve_model_type.py @@ -0,0 +1,104 @@ +""" +Tests for :meth:`TurboModel.resolve_model_type`. + +PR #27 introduced ``DEFAULT_ARCHITECTURE_FALLBACKS`` but the resolution +function consulted them only when the HF config returned an empty +``model_type`` -- which never happens in practice. These tests pin the +post-fix behaviour: the default fallback table is consulted for unknown +``model_type`` values, family-style suffixes are recognised, and explicit +registrations still win. +""" + +import pytest + +from quantllm.core.turbo_model import TurboModel + + +@pytest.fixture(autouse=True) +def _clean_registry(monkeypatch): + monkeypatch.setattr(TurboModel, "_architecture_registry", {}) + monkeypatch.setattr(TurboModel, "_model_class_registry", {}) + + +def test_unknown_model_type_falls_back_to_family(): + """``qwen3`` is not registered in transformers <= 4.x but is a Qwen2 + derivative; resolution should return ``qwen2``.""" + assert TurboModel.resolve_model_type( + "Qwen/Qwen3-8B", + config_model_type="qwen3", + ) == "qwen2" + + +def test_family_match_against_model_type_directly(): + """Resolving by ``model_type`` alone should still pick up the family + even when the repo name does not contain the family marker.""" + assert TurboModel.resolve_model_type( + "Qwen/private-fork", + config_model_type="qwen3", + ) == "qwen2" + + +def test_specific_family_wins_over_generic(): + """``phi3`` has its own entry and must NOT be flattened to ``phi``.""" + assert TurboModel.resolve_model_type( + "microsoft/phi-4", + config_model_type="phi4", + ) == "phi3" + + +def test_user_registered_alias_takes_precedence(): + TurboModel.register_architecture("zoolm", base_model_type="mistral") + assert TurboModel.resolve_model_type( + "org/zoolm-13b", + config_model_type="zoolm", + ) == "mistral" + + +def test_override_takes_precedence_over_everything(): + TurboModel.register_architecture("zoolm", base_model_type="mistral") + assert TurboModel.resolve_model_type( + "Qwen/Qwen3-8B", + config_model_type="qwen3", + model_type_override="llama", + ) == "llama" + + +def test_truly_unknown_model_type_is_returned_unchanged(): + """When nothing matches we surface the original ``model_type`` so the + caller can decide how to react (registering, overriding, etc.).""" + assert TurboModel.resolve_model_type( + "org/exotic-1b", + config_model_type="something-totally-new", + ) == "something-totally-new" + + +def test_no_config_returns_none_when_name_has_no_marker(): + assert TurboModel.resolve_model_type("org/exotic-1b") is None + + +def test_name_pattern_used_when_config_missing(): + """When ``config_model_type`` is empty the name is consulted for tokens.""" + assert TurboModel.resolve_model_type("meta-llama/Llama-3.2-3B") == "llama" + + +def test_register_architecture_class_lookup_uses_original_name(monkeypatch): + """Bug from review: ``register_architecture("newmodel", base_model_type="llama", + model_class=Cls)`` used to register the class under ``"newmodel"`` but + look it up under ``"llama"`` and find nothing. Verify the class is now + discoverable by the original architecture name.""" + sentinel = object() + + class _Stub: + @classmethod + def from_pretrained(cls, *a, **kw): + return sentinel + + TurboModel.register_architecture( + "newmodel", + base_model_type="llama", + model_class=_Stub, + ) + + # Direct registry assertions. + assert TurboModel._architecture_registry["newmodel"] == "llama" + assert TurboModel._model_class_registry["newmodel"] is _Stub