Skip to content

Commit b8762ff

Browse files
authored
fix brumby compat, thread safety (#2749)
* fix brumby compat * fix marlin jit test generation * fix thread pool warmup race * fix virtual warmup handoff
1 parent 89a297d commit b8762ff

8 files changed

Lines changed: 232 additions & 24 deletions

File tree

gptqmodel/models/writer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from ..utils.backend import BACKEND
4343
from ..utils.exllamav3 import build_exllamav3_tensor_storage
4444
from ..utils.hf import (
45+
_normalize_legacy_tied_weights_keys,
4546
prepare_remote_code_compat,
4647
sanitize_generation_config_file,
4748
sanitize_model_config,
@@ -609,6 +610,7 @@ def strip_attention_impl_fields(target: Any) -> Dict[str, Any]:
609610
removed_config_attention_attrs = strip_attention_impl_fields(self.model.config)
610611
if generation_config is not None:
611612
removed_generation_attention_attrs = strip_attention_impl_fields(generation_config)
613+
_normalize_legacy_tied_weights_keys(self.model)
612614

613615
# Save model config, including generation_config
614616
# Use empty state_dict hack to bypass saving weights

gptqmodel/utils/threadx.py

Lines changed: 70 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,56 @@ def __exit__(self, exc_type, exc, tb):
269269
return self._group.__exit__(exc_type, exc, tb)
270270

271271

272+
class _WorkerWarmupState:
273+
"""
274+
Shared once-per-physical-device warmup coordination.
275+
276+
The first worker that reaches this state performs the warmup. Other workers
277+
wait on the completion event without holding the pool registry lock.
278+
"""
279+
280+
def __init__(self, warmup_fn: Callable[[torch.device], None]):
281+
self._warmup_fn = warmup_fn
282+
self._claim_lock = threading.Lock()
283+
self._started = False
284+
self._done = threading.Event()
285+
self._error: Optional[BaseException] = None
286+
287+
def run(self, *, device: torch.device, rwlock: _RWLock) -> None:
288+
if self._done.is_set():
289+
self._raise_if_failed()
290+
return
291+
292+
should_run = False
293+
with self._claim_lock:
294+
if self._done.is_set():
295+
pass
296+
elif not self._started:
297+
self._started = True
298+
should_run = True
299+
300+
if should_run:
301+
try:
302+
with ctx(rwlock.reader(), _device_ctx(device)):
303+
self._warmup_fn(device)
304+
except BaseException as exc:
305+
with self._claim_lock:
306+
self._error = exc
307+
raise
308+
finally:
309+
self._done.set()
310+
else:
311+
self._done.wait()
312+
313+
self._raise_if_failed()
314+
315+
def _raise_if_failed(self) -> None:
316+
with self._claim_lock:
317+
error = self._error
318+
if error is not None:
319+
raise error
320+
321+
272322
# --------------------------- Worker Thread ---------------------------
273323
# Each worker is bound to a specific device and runs a single thread. Tasks are
274324
# executed under the device’s read lock; GC acquires the writer lock to keep
@@ -292,15 +342,15 @@ def __init__(
292342
name: Optional[str] = None,
293343
inference_mode: bool = False,
294344
cpu_core: Optional[int] = None,
295-
warmup_fn: Optional[Callable[[torch.device], None]] = None,
345+
warmup_state: Optional[_WorkerWarmupState] = None,
296346
*,
297347
key_override: Optional[str] = None,
298348
):
299349
self.device = device
300350
self.rwlock = rwlock
301351
self._on_task_finished = on_task_finished
302352
self._on_worker_exit = on_worker_exit
303-
self._warmup_fn = warmup_fn
353+
self._warmup_state = warmup_state
304354

305355
if key_override is not None:
306356
self.key = key_override
@@ -375,14 +425,11 @@ def _apply_cpu_affinity(self) -> None:
375425
self._affinity_applied = True
376426

377427
def _run_warmup(self) -> None:
378-
warmup_fn = self._warmup_fn
379-
if warmup_fn is None:
428+
warmup_state = self._warmup_state
429+
if warmup_state is None:
380430
return
381-
try:
382-
with ctx(self.rwlock.reader(), _device_ctx(self.device)):
383-
warmup_fn(self.device)
384-
finally:
385-
self._warmup_fn = None
431+
warmup_state.run(device=self.device, rwlock=self.rwlock)
432+
self._warmup_state = None
386433

387434
def _run(self):
388435
"""
@@ -636,7 +683,7 @@ def __init__(
636683
{str(k).lower(): fn for k, fn in warmups.items()} if warmups else None
637684
)
638685
self._warmup_lock = threading.Lock()
639-
self._warmup_ran_keys: Set[str] = set()
686+
self._warmup_states: Dict[str, _WorkerWarmupState] = {}
640687

641688
workers_cfg = workers or {}
642689
base_workers: Dict[str, int] = {}
@@ -890,7 +937,11 @@ def _priority(dev_type: str) -> int:
890937

891938
return plan
892939

893-
def _resolve_worker_warmup(self, dev: torch.device, key: str) -> Optional[Callable[[torch.device], None]]:
940+
def _resolve_worker_warmup(
941+
self,
942+
dev: torch.device,
943+
key: str,
944+
) -> Optional[_WorkerWarmupState]:
894945
mapping = self._worker_warmups
895946
if not mapping:
896947
return None
@@ -904,13 +955,14 @@ def _resolve_worker_warmup(self, dev: torch.device, key: str) -> Optional[Callab
904955
if warmup is None:
905956
return None
906957

907-
# Map virtual workers back to their parent key so warmup runs once per physical device.
958+
# Virtual workers share the same physical-device warmup state as their parent.
908959
physical_key = self._virtual_to_parent.get(key, key)
909960
with self._warmup_lock:
910-
if physical_key in self._warmup_ran_keys:
911-
return None
912-
self._warmup_ran_keys.add(physical_key)
913-
return warmup
961+
state = self._warmup_states.get(physical_key)
962+
if state is None:
963+
state = _WorkerWarmupState(warmup)
964+
self._warmup_states[physical_key] = state
965+
return state
914966

915967
def _spawn_worker(
916968
self,
@@ -922,7 +974,7 @@ def _spawn_worker(
922974
"""
923975
Create and start a worker bound to the provided device.
924976
"""
925-
warmup_fn = self._resolve_worker_warmup(dev, key)
977+
warmup_state = self._resolve_worker_warmup(dev, key)
926978
w = _DeviceWorker(
927979
device=dev,
928980
rwlock=self._locks[key],
@@ -931,7 +983,7 @@ def _spawn_worker(
931983
name=name,
932984
inference_mode=self._inference_mode,
933985
cpu_core=cpu_core,
934-
warmup_fn=warmup_fn,
986+
warmup_state=warmup_state,
935987
key_override=key,
936988
)
937989
return w

tests/models/model_test.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,11 @@ class ModelTest(unittest.TestCase):
179179
INPUTS_MAX_LENGTH = 2048
180180
MODEL_MAX_LEN = 4096
181181
DATASET_SIZE = 512
182+
DATASET_SIZE_FAST = None
183+
DATASET_SIZE_SLOW = None
182184
DATASET_CONCAT_SIZE = None
185+
DATASET_CONCAT_SIZE_FAST = None
186+
DATASET_CONCAT_SIZE_SLOW = None
183187
DATASET_CONCAT_SEPARATOR = None
184188
DATASET_SORT = "desc"
185189
DELETE_QUANTIZED_MODEL = True
@@ -251,6 +255,8 @@ def setUpClass(cls):
251255
STOP_AFTER_LAYER: Optional[int] = None
252256
MOE_CONFIG: Optional[MoEConfig] = None
253257
OFFLOAD_TO_DISK: bool = True
258+
OFFLOAD_TO_DISK_FAST = None
259+
OFFLOAD_TO_DISK_SLOW = None
254260

255261
GENERIC_TEST_PROMPTS = [
256262
{"prompt": "Which city is the capital city of France?", "keywords": ["paris"]},
@@ -339,6 +345,14 @@ def _mode_specific_baseline_value(self, attr_name: str):
339345

340346
return self._resolve_metric_baseline_value(getattr(self, attr_name, None))
341347

348+
def _mode_specific_test_setting(self, attr_name: str):
349+
mode_suffix = "FAST" if self._is_fast_model_test_mode() else "SLOW"
350+
preferred = f"{attr_name}_{mode_suffix}"
351+
value = getattr(self, preferred, None)
352+
if value is not None:
353+
return value
354+
return getattr(self, attr_name, None)
355+
342356
def _legacy_metric_ceil_pct(self) -> float:
343357
if self._is_fast_model_test_mode():
344358
return 1.0
@@ -1499,7 +1513,7 @@ def _build_quantize_config(self):
14991513
dynamic=self.DYNAMIC,
15001514
hessian=HessianConfig(chunk_size=self.HESSIAN_CHUNK_SIZE),
15011515
moe=self.MOE_CONFIG,
1502-
offload_to_disk=self.OFFLOAD_TO_DISK,
1516+
offload_to_disk=self._mode_specific_test_setting("OFFLOAD_TO_DISK"),
15031517
)
15041518

15051519
def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", need_eval=True, batch_size: int = QUANT_BATCH_SIZE, call_perform_post_quant_validation: bool = True, **kwargs):
@@ -1551,9 +1565,21 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne
15511565

15521566
self._apply_model_compat_quant_overrides(model)
15531567

1568+
dataset_size = self._mode_specific_test_setting("DATASET_SIZE")
1569+
dataset_concat_size = self._mode_specific_test_setting("DATASET_CONCAT_SIZE")
1570+
log.info(
1571+
"Calibration dataset config: size=%s, concat_size=%s",
1572+
dataset_size,
1573+
dataset_concat_size,
1574+
)
1575+
15541576
is_image_to_text_model = MODALITY.IMAGE_TO_TEXT in model.modality
15551577
if quantize_config.requires_calibration_dataset():
1556-
calibration_dataset = get_calib_dataset(model) if is_image_to_text_model else self.load_dataset(tokenizer, self.DATASET_SIZE)
1578+
calibration_dataset = (
1579+
get_calib_dataset(model)
1580+
if is_image_to_text_model
1581+
else self.load_dataset(tokenizer, dataset_size)
1582+
)
15571583
else:
15581584
calibration_dataset = None
15591585

@@ -1577,7 +1603,7 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne
15771603
log.info(f"Quantized model artifacts will be saved to: {planned_save_path}")
15781604
model.quantize(
15791605
calibration_dataset,
1580-
calibration_concat_size=self.DATASET_CONCAT_SIZE,
1606+
calibration_concat_size=dataset_concat_size,
15811607
calibration_concat_separator=self.DATASET_CONCAT_SEPARATOR,
15821608
calibration_sort=self.DATASET_SORT,
15831609
backend=self.QUANT_BACKEND,

tests/models/test_brumby.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@
1313
class TestBrumby(ModelTest):
1414
GROUP_SIZE = 32
1515
DATASET_SIZE = 1024
16+
DATASET_SIZE_FAST = 128
17+
# Brumby decoder layers are structurally uniform, so fast mode can quantize
18+
# the first layers and avoid replaying 38 untouched layers during calibration.
19+
MODEL_COMPAT_FAST_LAYER_COUNT = 1
20+
MODEL_COMPAT_FAST_LAYER_POSITION = "first"
21+
OFFLOAD_TO_DISK_FAST = False
1622
NATIVE_MODEL_ID = "/monster/data/model/Brumby-14B-Base"
1723
TRUST_REMOTE_CODE = True
1824
LOAD_MODEL_EXTRA_ARGS = {"use_cache": False}
@@ -41,7 +47,13 @@ class TestBrumby(ModelTest):
4147
"acc": {"value": 0.71, "floor_pct": 0.05, "ceil_pct": 0.10},
4248
},
4349
}
44-
EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW)
50+
EVAL_TASKS_FAST = {
51+
"arc_challenge": {
52+
"evalution_batch_size": 8,
53+
"evalution_suite_kwargs": {"max_rows": 32},
54+
"acc": {"value": 0.89, "floor_pct": 0.10, "ceil_pct": 1.0},
55+
},
56+
}
4557

4658
@classmethod
4759
def setUpClass(cls):

tests/models/test_model_test_fast_mode.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,19 @@ def runTest(self):
1515
return None
1616

1717

18+
class _DatasetCompatCase(ModelTest):
19+
__test__ = False
20+
DATASET_SIZE = 512
21+
DATASET_SIZE_FAST = 128
22+
DATASET_CONCAT_SIZE = 2048
23+
DATASET_CONCAT_SIZE_FAST = 1024
24+
OFFLOAD_TO_DISK = True
25+
OFFLOAD_TO_DISK_FAST = False
26+
27+
def runTest(self):
28+
return None
29+
30+
1831
class _FakeQuantModel:
1932
def __init__(self, layer_count: int):
2033
self.model = SimpleNamespace(layers=nn.ModuleList([nn.Linear(1, 1) for _ in range(layer_count)]))
@@ -55,3 +68,21 @@ def test_model_test_fast_mode_first_layers_remain_configurable(monkeypatch):
5568
"-:^layers\\.4\\.",
5669
"-:^layers\\.5\\.",
5770
]
71+
72+
73+
def test_model_test_fast_mode_uses_fast_dataset_overrides(monkeypatch):
74+
monkeypatch.setenv("GPTQMODEL_MODEL_TEST_MODE", "fast")
75+
case = _DatasetCompatCase(methodName="runTest")
76+
77+
assert case._mode_specific_test_setting("DATASET_SIZE") == 128
78+
assert case._mode_specific_test_setting("DATASET_CONCAT_SIZE") == 1024
79+
assert case._mode_specific_test_setting("OFFLOAD_TO_DISK") is False
80+
81+
82+
def test_model_test_slow_mode_uses_default_dataset_settings(monkeypatch):
83+
monkeypatch.setenv("GPTQMODEL_MODEL_TEST_MODE", "slow")
84+
case = _DatasetCompatCase(methodName="runTest")
85+
86+
assert case._mode_specific_test_setting("DATASET_SIZE") == 512
87+
assert case._mode_specific_test_setting("DATASET_CONCAT_SIZE") == 2048
88+
assert case._mode_specific_test_setting("OFFLOAD_TO_DISK") is True

tests/test_hf_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
# SPDX-License-Identifier: Apache-2.0
44
# Contact: qubitium@modelcloud.ai, x.com/qubitium
55

6+
import tempfile
7+
68
from torch import nn
79
from transformers import PretrainedConfig, PreTrainedModel
810
from transformers.modeling_utils import _get_tied_weight_keys
@@ -46,3 +48,14 @@ def test_legacy_list_tied_weights_are_normalized_to_input_embeddings():
4648
}
4749
assert model._tied_weights_keys == {"lm_head.weight": "embed_tokens.weight"}
4850
assert _get_tied_weight_keys(model) == ["lm_head.weight"]
51+
52+
53+
def test_legacy_list_tied_weights_allow_save_pretrained():
54+
model = _LegacyTiedWeightsModel(_DummyConfig())
55+
56+
with tempfile.TemporaryDirectory() as tmp_dir:
57+
model._tied_weights_keys = ["lm_head.weight"]
58+
model.get_expanded_tied_weights_keys(all_submodels=False)
59+
model._tied_weights_keys = ["lm_head.weight"]
60+
_hf_utils._normalize_legacy_tied_weights_keys(model)
61+
model.save_pretrained(tmp_dir, state_dict={}, is_main_process=True)

tests/test_marlin_jit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def test_sm75_turing_contract_is_present_in_marlin_sources():
268268

269269

270270
def test_stage2_dense_four_bit_tiles_stay_in_sync_between_selector_and_codegen():
271-
marlin_root = marlin_utils._marlin_root()
271+
marlin_root = marlin_utils._ensure_generated_marlin_kernels()
272272
gemm_cu = (marlin_root / "gptq_marlin.cu").read_text(encoding="utf-8")
273273
generator_py = (marlin_root / "generate_kernels.py").read_text(encoding="utf-8")
274274
kernel_u4 = (marlin_root / "kernel_fp16_ku4.cu").read_text(encoding="utf-8")

0 commit comments

Comments
 (0)