Skip to content

Commit bb87910

Browse files
authored
Use loader device selection for EoRA adapter generation (#2800)
1 parent ff9dfd5 commit bb87910

2 files changed

Lines changed: 77 additions & 3 deletions

File tree

gptqmodel/models/auto.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@
6868
)
6969
from ..utils.hub import list_repo_files # noqa: E402
7070
from ..utils.model import find_modules # noqa: E402
71-
from ..utils.torch import CPU, torch_empty_cache # noqa: E402
71+
from ..utils.torch import torch_empty_cache # noqa: E402
7272
from .base import BaseQModel # noqa: E402
7373
from .definitions.afmoe import AfMoeQModel # noqa: E402
7474
from .definitions.apertus import ApertusQModel # noqa: E402
@@ -686,6 +686,7 @@ def generate(
686686
# pass-through vars for load()
687687
trust_remote_code: bool = False,
688688
dtype: Optional[Union[str, torch.dtype]] = None,
689+
device: Optional[Union[str, torch.device]] = None,
689690
):
690691
if not adapter or not isinstance(adapter, Lora):
691692
raise ValueError(f"Adapter: expected `adapter` type to be `Lora`: actual = `{adapter}`.")
@@ -696,7 +697,7 @@ def generate(
696697
quantized_model = GPTQModel.load(
697698
model_id_or_path=quantized_model_id_or_path,
698699
backend=BACKEND.GPTQ_TORCH,
699-
device=CPU,
700+
device=device,
700701
trust_remote_code=trust_remote_code,
701702
dtype=dtype,
702703
)
@@ -715,7 +716,7 @@ def generate(
715716
backend=BACKEND.GPTQ_TORCH,
716717
trust_remote_code=trust_remote_code,
717718
dtype=dtype,
718-
device=CPU,
719+
device=device,
719720
)
720721

721722
log.info("Model: Adapter generation started")
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# GPU=-1
2+
from types import SimpleNamespace
3+
4+
from gptqmodel.adapter.adapter import Lora
5+
from gptqmodel.models import auto
6+
7+
8+
class _FakeNativeModel:
9+
def __init__(self):
10+
self.generate_calls = []
11+
12+
def _eora_generate(self, **kwargs):
13+
self.generate_calls.append(kwargs)
14+
15+
16+
def _run_adapter_generate(tmp_path, monkeypatch, *, device):
17+
load_calls = []
18+
find_modules_calls = []
19+
20+
quantized_model = SimpleNamespace(quantize_config="qcfg", model="quantized-model")
21+
native_model = _FakeNativeModel()
22+
23+
def fake_load(cls, model_id_or_path, *args, **kwargs):
24+
load_calls.append((model_id_or_path, kwargs.copy()))
25+
if model_id_or_path == "quantized":
26+
return quantized_model
27+
if model_id_or_path == "native":
28+
return native_model
29+
raise AssertionError(f"unexpected load target: {model_id_or_path}")
30+
31+
monkeypatch.setattr(auto.GPTQModel, "load", classmethod(fake_load))
32+
monkeypatch.setattr(
33+
auto,
34+
"find_modules",
35+
lambda module, layers: find_modules_calls.append((module, layers)) or {"module": object()},
36+
)
37+
monkeypatch.setattr(auto, "torch_empty_cache", lambda: None)
38+
39+
adapter = Lora(path=str(tmp_path / "adapter"), rank=8)
40+
kwargs = {
41+
"adapter": adapter,
42+
"model_id_or_path": "native",
43+
"quantized_model_id_or_path": "quantized",
44+
"calibration_dataset": ["sample"],
45+
}
46+
if device is not None:
47+
kwargs["device"] = device
48+
49+
auto.GPTQModel.adapter.generate(**kwargs)
50+
51+
return load_calls, find_modules_calls, native_model.generate_calls
52+
53+
54+
def test_adapter_generate_defaults_to_loader_device_selection(tmp_path, monkeypatch):
55+
load_calls, find_modules_calls, generate_calls = _run_adapter_generate(
56+
tmp_path,
57+
monkeypatch,
58+
device=None,
59+
)
60+
61+
assert [kwargs["device"] for _, kwargs in load_calls] == [None, None]
62+
assert find_modules_calls == [("quantized-model", [auto.TorchLinear])]
63+
assert generate_calls[0]["quantized_modules"].keys() == {"module"}
64+
65+
66+
def test_adapter_generate_forwards_explicit_device(tmp_path, monkeypatch):
67+
load_calls, _, _ = _run_adapter_generate(
68+
tmp_path,
69+
monkeypatch,
70+
device="cuda:2",
71+
)
72+
73+
assert [kwargs["device"] for _, kwargs in load_calls] == ["cuda:2", "cuda:2"]

0 commit comments

Comments
 (0)