Skip to content

Commit aeb641f

Browse files
authored
Fix exllamav3_torch import under meta-device context (#2825)
Module import of `gptqmodel.nn_modules.exllamav3_torch` evaluates two constants via `_half_scalar_from_bits(...)`, which previously called `torch.tensor([bits], dtype=torch.uint16).view(torch.float16).item()`. Under a `with torch.device('meta'):` preload context (transformers' `replace_with_awq_linear` imports `gptqmodel.quantization` inside such a guard), this raises `RuntimeError: Tensor.item() cannot be called on meta tensors`. This blocks any GPTQ load on transformers >= 5.6 that goes through the AWQ preload pathway. Replace the tensor round-trip with a `struct`-based IEEE-754 binary16 conversion (`<e` format). The conversion is bit-equivalent to the tensor-based path and does not allocate a tensor, so it is unaffected by the active device. Verified on 0x0000, 0x3C00 (+1.0), 0xBC00 (-1.0), 0x1EEE (`_EXL3_MUL1_INV`), 0xC931 (`_EXL3_MUL1_BIAS`), 0x7BFF (max finite), 0xFBFF (min finite), and the NaN bit patterns 0x7FFF / 0xFFFF. Adds `tests/test_exllamav3_meta_import.py` with three regression cases: (1) `_half_scalar_from_bits` is bit-equivalent to the historical tensor-based path across the canonical bit set; (2) the module imports cleanly under `with torch.device('meta'):`; (3) the two module-level constants `_EXL3_MUL1_INV` / `_EXL3_MUL1_BIAS` are pinned to their canonical float64 values.
1 parent 011128d commit aeb641f

2 files changed

Lines changed: 114 additions & 1 deletion

File tree

gptqmodel/nn_modules/exllamav3_torch.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from __future__ import annotations
88

99
import math
10+
import struct
1011
from functools import lru_cache
1112
from typing import Any, Dict, Optional
1213

@@ -26,7 +27,20 @@
2627

2728

2829
def _half_scalar_from_bits(bits: int) -> float:
29-
return float(torch.tensor([bits], dtype=torch.uint16).view(torch.float16).item())
30+
# Convert a uint16 bit pattern to its IEEE-754 binary16 (float16) value
31+
# without allocating a torch tensor. The previous implementation used
32+
# ``torch.tensor([bits], dtype=torch.uint16).view(torch.float16).item()``
33+
# at module import, which fails under a meta-device preload context with
34+
# ``RuntimeError: Tensor.item() cannot be called on meta tensors``. The
35+
# transformers AWQ pathway (``replace_with_awq_linear``) imports
36+
# ``gptqmodel.quantization`` inside ``with torch.device('meta'):``, which
37+
# triggered the failure for any GPTQ load on transformers >= 5.6.
38+
#
39+
# ``struct`` format ``<e`` is IEEE-754 binary16 little-endian and matches
40+
# ``torch.float16``'s byte layout on all supported platforms (x86_64 /
41+
# aarch64). The conversion is bit-equivalent to the tensor-based path.
42+
packed = struct.pack("<H", int(bits) & 0xFFFF)
43+
return float(struct.unpack("<e", packed)[0])
3044

3145

3246
# Binary16 bit pattern 0x1EEE decoded to a plain Python float once, at import time.
_EXL3_MUL1_INV = _half_scalar_from_bits(0x1EEE)
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# SPDX-FileCopyrightText: 2026 ModelCloud.ai
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Regression test for module-level import of exllamav3_torch under a
5+
meta-device preload context.
6+
7+
Background
8+
----------
9+
``gptqmodel.nn_modules.exllamav3_torch`` evaluates two module-level constants
10+
(``_EXL3_MUL1_INV``, ``_EXL3_MUL1_BIAS``) by calling ``_half_scalar_from_bits``
11+
during import. The previous implementation built a torch tensor and called
12+
``.item()`` on it, which fails on a meta tensor with::
13+
14+
RuntimeError: Tensor.item() cannot be called on meta tensors
15+
16+
Transformers' AWQ pathway (``replace_with_awq_linear``) imports
17+
``gptqmodel.quantization`` inside ``with torch.device('meta'):``, so any
18+
GPTQ load path on transformers >= 5.6 would surface that error transitively.
19+
This test guards against regression of that import path.
20+
"""
21+
22+
from __future__ import annotations
23+
24+
import struct
25+
import sys
26+
27+
import torch
28+
29+
30+
# Dotted name of the module under test; used as the key to evict from ``sys.modules``.
_MOD = "gptqmodel.nn_modules.exllamav3_torch"
31+
32+
33+
def _drop_module():
    """Evict the module under test from the import cache.

    The test session has normally imported the full ``gptqmodel`` package
    already, so the module is cached in ``sys.modules`` and a plain
    ``import`` would not re-run its body. Removing the entry forces the next
    import to execute the module top level again, which is where an
    import-time failure would surface. Does nothing if the module is not
    currently cached.
    """
    sys.modules.pop(_MOD, None)
39+
40+
41+
def test_half_scalar_from_bits_matches_torch_view_path():
    """Check the ``struct``-based conversion against the historical torch path.

    ``_half_scalar_from_bits`` must return the same float as
    ``torch.tensor(...).view(torch.float16).item()`` for every uint16 bit
    pattern of interest: the two canonical EXL3 constants plus a handful of
    edge cases (zero, +/-1.0, max/min finite). NaN bit patterns are checked
    separately because NaN never compares equal to itself.
    """
    _drop_module()
    from gptqmodel.nn_modules.exllamav3_torch import _half_scalar_from_bits

    def via_struct(bits):
        # Independent ``struct`` round-trip, kept separate from the function
        # under test.
        return float(struct.unpack("<e", struct.pack("<H", bits & 0xFFFF))[0])

    def via_torch(bits):
        # The pre-fix implementation; serves as the reference value.
        return float(
            torch.tensor([bits], dtype=torch.uint16).view(torch.float16).item()
        )

    finite_cases = (
        0x0000,  # +0.0
        0x3C00,  # +1.0
        0xBC00,  # -1.0
        0x1EEE,  # _EXL3_MUL1_INV
        0xC931,  # _EXL3_MUL1_BIAS
        0x7BFF,  # max finite
        0xFBFF,  # min finite
    )
    for bits in finite_cases:
        got = _half_scalar_from_bits(bits)
        want_struct = via_struct(bits)
        want_torch = via_torch(bits)
        assert got == want_struct == want_torch, (
            f"bits=0x{bits:04X} got={got!r} want_struct={want_struct!r} "
            f"want_torch={want_torch!r}"
        )

    # NaN bit patterns: torch and struct must both produce NaN. NaN is
    # detected via ``x != x`` because NaN != NaN.
    for bits in (0x7FFF, 0xFFFF):
        got = _half_scalar_from_bits(bits)
        want_torch = via_torch(bits)
        assert got != got, f"bits=0x{bits:04X} expected NaN, got {got!r}"
        assert want_torch != want_torch, (
            f"bits=0x{bits:04X} expected torch reference NaN, "
            f"got {want_torch!r}"
        )
75+
76+
77+
def test_module_imports_under_meta_device():
    """Import ``exllamav3_torch`` while ``torch.device('meta')`` is active.

    The import must not raise. This reproduces the transformers AWQ pathway,
    which calls ``replace_with_awq_linear`` while a meta-device guard is
    active. The module is evicted from the import cache first so that its
    body really runs inside the guard.
    """
    _drop_module()
    with torch.device("meta"):
        import gptqmodel.nn_modules.exllamav3_torch as exl  # noqa: F401

    # Module-level constants must still be plain Python floats (not tensors).
    for const_name in ("_EXL3_MUL1_INV", "_EXL3_MUL1_BIAS"):
        assert isinstance(getattr(exl, const_name), float)
88+
89+
90+
def test_module_constants_are_canonical_values():
    """Pin the exact float64 values of the two module-level constants.

    Any change to the bit-pattern math (or to the underlying conversion) is
    then caught here rather than downstream in a kernel correctness test.
    """
    _drop_module()
    import gptqmodel.nn_modules.exllamav3_torch as exl

    # Float64-precise values reproduced from the IEEE-754 binary16 patterns.
    canonical = (
        ("_EXL3_MUL1_INV", 0.00676727294921875),
        ("_EXL3_MUL1_BIAS", -10.3828125),
    )
    for const_name, expected in canonical:
        assert getattr(exl, const_name) == expected

0 commit comments

Comments
 (0)