|
| 1 | +# GPU=-1 |
| 2 | +from types import SimpleNamespace |
| 3 | + |
| 4 | +from gptqmodel.adapter.adapter import Lora |
| 5 | +from gptqmodel.models import auto |
| 6 | + |
| 7 | + |
| 8 | +class _FakeNativeModel: |
| 9 | + def __init__(self): |
| 10 | + self.generate_calls = [] |
| 11 | + |
| 12 | + def _eora_generate(self, **kwargs): |
| 13 | + self.generate_calls.append(kwargs) |
| 14 | + |
| 15 | + |
| 16 | +def _run_adapter_generate(tmp_path, monkeypatch, *, device): |
| 17 | + load_calls = [] |
| 18 | + find_modules_calls = [] |
| 19 | + |
| 20 | + quantized_model = SimpleNamespace(quantize_config="qcfg", model="quantized-model") |
| 21 | + native_model = _FakeNativeModel() |
| 22 | + |
| 23 | + def fake_load(cls, model_id_or_path, *args, **kwargs): |
| 24 | + load_calls.append((model_id_or_path, kwargs.copy())) |
| 25 | + if model_id_or_path == "quantized": |
| 26 | + return quantized_model |
| 27 | + if model_id_or_path == "native": |
| 28 | + return native_model |
| 29 | + raise AssertionError(f"unexpected load target: {model_id_or_path}") |
| 30 | + |
| 31 | + monkeypatch.setattr(auto.GPTQModel, "load", classmethod(fake_load)) |
| 32 | + monkeypatch.setattr( |
| 33 | + auto, |
| 34 | + "find_modules", |
| 35 | + lambda module, layers: find_modules_calls.append((module, layers)) or {"module": object()}, |
| 36 | + ) |
| 37 | + monkeypatch.setattr(auto, "torch_empty_cache", lambda: None) |
| 38 | + |
| 39 | + adapter = Lora(path=str(tmp_path / "adapter"), rank=8) |
| 40 | + kwargs = { |
| 41 | + "adapter": adapter, |
| 42 | + "model_id_or_path": "native", |
| 43 | + "quantized_model_id_or_path": "quantized", |
| 44 | + "calibration_dataset": ["sample"], |
| 45 | + } |
| 46 | + if device is not None: |
| 47 | + kwargs["device"] = device |
| 48 | + |
| 49 | + auto.GPTQModel.adapter.generate(**kwargs) |
| 50 | + |
| 51 | + return load_calls, find_modules_calls, native_model.generate_calls |
| 52 | + |
| 53 | + |
| 54 | +def test_adapter_generate_defaults_to_loader_device_selection(tmp_path, monkeypatch): |
| 55 | + load_calls, find_modules_calls, generate_calls = _run_adapter_generate( |
| 56 | + tmp_path, |
| 57 | + monkeypatch, |
| 58 | + device=None, |
| 59 | + ) |
| 60 | + |
| 61 | + assert [kwargs["device"] for _, kwargs in load_calls] == [None, None] |
| 62 | + assert find_modules_calls == [("quantized-model", [auto.TorchLinear])] |
| 63 | + assert generate_calls[0]["quantized_modules"].keys() == {"module"} |
| 64 | + |
| 65 | + |
| 66 | +def test_adapter_generate_forwards_explicit_device(tmp_path, monkeypatch): |
| 67 | + load_calls, _, _ = _run_adapter_generate( |
| 68 | + tmp_path, |
| 69 | + monkeypatch, |
| 70 | + device="cuda:2", |
| 71 | + ) |
| 72 | + |
| 73 | + assert [kwargs["device"] for _, kwargs in load_calls] == ["cuda:2", "cuda:2"] |
0 commit comments