Skip to content

Commit 25dd8ea

Browse files
authored
add gemma4 (#2663)
* add gemma4 * Update README.md * Bump version from 6.0.0 to 6.0.2
1 parent a856afe commit 25dd8ea

13 files changed

Lines changed: 672 additions & 14 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
</p>
2121

2222
## Latest News
23-
* 04/02/2026 [6.0.1](https://github.com/ModelCloud/GPTQModel/releases/tag/v6.0.1): 🎉 New quantization methods: `ParoQuant`, `GGUF`, `FP8`, `EXL3`, and `FOEM: First-Order Error Matters`. Added PrismML/Bonsai 1bit model quantization (inference only), faster ParoQuant/AWQ kernels, ParoQuant `optimization scope` control: `module` (Paro Lite) or `layer` (Paro reference), plus `MiniCPM-O`, `MiniCPM-V`, and `GLM4 MOE lite` model support.
23+
* 04/03/2026 [6.0.2](https://github.com/ModelCloud/GPTQModel/releases/tag/v6.0.2): 🎉 New quantization methods: `ParoQuant`, `GGUF`, `FP8`, `EXL3`, and `FOEM: First-Order Error Matters`. Added PrismML/Bonsai 1bit model quantization (inference only), faster ParoQuant/AWQ kernels, ParoQuant `optimization scope` control: `module` (Paro Lite) or `layer` (Paro reference), plus `Gemma4`, `MiniCPM-O`, `MiniCPM-V`, and `GLM4 MOE lite` model support.
2424
* 03/19/2026 [5.8.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v5.8.0): ✨HF Transformers 5.3.0 support with auto-defusing of `fused` models via pypi pkg: [Defuser](https://github.com/ModelCloud/Defuser). Qwen 3.5 family support added. New fast HF `cpu` kernels for GPTQ/AWQ added. Experimental INT8 `cpu` kernel added for GPTQ.
2525
* 03/09/2026 [main]: ✨Qwen 3.5 MoE model support added. New HF Kernel support added for AWQ.
2626
HF Kernel for both gptq/awq are now used by default for cpu devices for best performance. New INT8 kernel ported from Intel for gptq.

gptqmodel/looper/forward_executor.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,12 @@ def run_single(
239239
additional_inputs["kv_last_layer"] = nested_move_to(prev_kv, device=exec_device)
240240

241241
additional_inputs["use_cache"] = False
242+
additional_inputs = self.looper.gptq_model.prepare_layer_replay_kwargs(
243+
layer=module,
244+
layer_input=layer_input,
245+
additional_inputs=additional_inputs,
246+
target_device=exec_device,
247+
)
242248

243249
if not preserve_module_devices:
244250
rehome_module_to_device(module, cur_layer_device, move_parameters=True, move_buffers=True)
@@ -489,6 +495,7 @@ def _replica_progress(idx: int, total: int, device: torch.device, step: str) ->
489495
layer_input_kwargs[batch_idx],
490496
attention_masks[batch_idx],
491497
position_ids[batch_idx] if position_ids else None,
498+
gptq_model=self.looper.gptq_model,
492499
support_batch_quantize=self.looper.support_batch_quantize,
493500
is_lm_head_module=is_lm_head_module,
494501
need_output=need_outputs,

gptqmodel/looper/stage_inputs_capture.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,11 +141,11 @@ def store_input_hook(module, args, kwargs):
141141
else:
142142
batch_device = data_device
143143

144-
layer_input: List[torch.Tensor] = []
145-
if kwargs.get("hidden_states") is not None:
146-
layer_input.append(move_to(kwargs["hidden_states"], device=batch_device))
147-
else:
148-
layer_input.append(move_to(args[0], device=batch_device))
144+
layer_input = self.gptq_model.capture_first_layer_positional_inputs(
145+
args=args,
146+
kwargs=kwargs,
147+
batch_device=batch_device,
148+
)
149149

150150
layer_inputs.append(layer_input)
151151

@@ -161,6 +161,12 @@ def store_input_hook(module, args, kwargs):
161161
for (k, v) in kwargs.items():
162162
if k not in ["hidden_states", "attention_mask", "position_ids"]:
163163
one_kwargs[k] = nested_move_to(v, device=batch_device)
164+
one_kwargs = self.gptq_model.capture_first_layer_input_kwargs(
165+
args=args,
166+
kwargs=kwargs,
167+
batch_device=batch_device,
168+
layer_input_kwargs=one_kwargs,
169+
)
164170
layer_input_kwargs.append(one_kwargs)
165171

166172
# In normal repeating layer/sbuset early stop happens on the last module forward

gptqmodel/models/auto.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292
from .definitions.falcon_h1 import FalconH1QModel # noqa: E402
9393
from .definitions.gemma2 import Gemma2QModel # noqa: E402
9494
from .definitions.gemma3 import Gemma3ForConditionalGenerationGPTQ, Gemma3QModel # noqa: E402
95+
from .definitions.gemma4 import Gemma4ForConditionalGenerationGPTQ, Gemma4TextQModel # noqa: E402
9596
from .definitions.glm import GlmQModel # noqa: E402
9697
from .definitions.glm4_moe import GLM4MoEGPTQ # noqa: E402
9798
from .definitions.glm4_moe_lite import Glm4MoeLiteQModel # noqa: E402
@@ -210,6 +211,8 @@
210211
"gemma2": Gemma2QModel,
211212
"gemma3_text": Gemma3QModel,
212213
"gemma3": Gemma3ForConditionalGenerationGPTQ,
214+
"gemma4_text": Gemma4TextQModel,
215+
"gemma4": Gemma4ForConditionalGenerationGPTQ,
213216
"phi": PhiQModel,
214217
"phi3": Phi3QModel,
215218
"phi4mm": Phi4MMGPTQ,

gptqmodel/models/base.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1365,6 +1365,42 @@ def pre_quantize_generate_hook_end(self):
13651365
# offload_to_disk(model=self.model, module=self.get_base_modules(model=self.model), disk_path=self.quantize_config.offload_to_disk_path)
13661366
pass
13671367

1368+
def capture_first_layer_positional_inputs(
1369+
self,
1370+
args: tuple[Any, ...],
1371+
kwargs: Dict[str, Any],
1372+
batch_device: torch.device,
1373+
) -> List[torch.Tensor]:
1374+
"""Normalize first-layer positional inputs so cached forwards can replay decoder layers directly."""
1375+
1376+
if kwargs.get("hidden_states") is not None:
1377+
return [move_to(kwargs["hidden_states"], device=batch_device)]
1378+
if args:
1379+
return [move_to(args[0], device=batch_device)]
1380+
return []
1381+
1382+
def capture_first_layer_input_kwargs(
1383+
self,
1384+
args: tuple[Any, ...],
1385+
kwargs: Dict[str, Any],
1386+
batch_device: torch.device,
1387+
layer_input_kwargs: Dict[str, Any],
1388+
) -> Dict[str, Any]:
1389+
"""Allow model definitions to persist extra first-layer replay metadata during calibration capture."""
1390+
1391+
return layer_input_kwargs
1392+
1393+
def prepare_layer_replay_kwargs(
1394+
self,
1395+
layer: nn.Module,
1396+
layer_input: List[torch.Tensor],
1397+
additional_inputs: Dict[str, Any],
1398+
target_device: torch.device,
1399+
) -> Dict[str, Any]:
1400+
"""Allow model definitions to refresh layer-specific kwargs before cached layer replay."""
1401+
1402+
return additional_inputs
1403+
13681404
def lm_head_pre_quantize_generate_hook(self, inputs: List[List[torch.tensor]]) -> List[List[torch.tensor]]:
13691405
if self.pre_lm_head_norm_module:
13701406
norm, _ = get_module_by_name_prefix(self.model, [self.pre_lm_head_norm_module])

gptqmodel/models/definitions/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from .ernie4_5_moe import Ernie4_5_MoeQModel
2727
from .gemma2 import Gemma2QModel
2828
from .gemma3 import Gemma3QModel
29+
from .gemma4 import Gemma4ForConditionalGenerationGPTQ, Gemma4TextQModel
2930
from .glm import GlmQModel
3031
from .gpt2 import GPT2QModel
3132
from .gpt_bigcode import GptBigCodeQModel
Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
2+
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
3+
# SPDX-License-Identifier: Apache-2.0
4+
# Contact: qubitium@modelcloud.ai, x.com/qubitium
5+
6+
import torch
7+
from types import MethodType
8+
9+
from ..base import BaseQModel
10+
from ...utils.device import get_device
11+
from ...utils.model import get_module_by_name_prefix, move_to, nested_move_to
12+
from . import LlamaQModel
13+
14+
15+
_GEMMA4_ALL_PER_LAYER_INPUTS = "__gptqmodel_gemma4_all_per_layer_inputs"
16+
17+
18+
def _gemma4_module_tree():
19+
"""Return the Gemma 4 decoder traversal with optional attention and per-layer input modules."""
20+
21+
return [
22+
"model",
23+
"layers",
24+
"#",
25+
{
26+
"input_layernorm": ("input_layernorm:!",),
27+
"self_attn": (
28+
"q_norm:!",
29+
"q_proj:0",
30+
"k_norm:!",
31+
"k_proj:0",
32+
"v_norm:!",
33+
"v_proj:0",
34+
"o_proj:1",
35+
),
36+
"post_attention_layernorm": ("post_attention_layernorm:!",),
37+
"pre_feedforward_layernorm": ("pre_feedforward_layernorm:!",),
38+
"mlp": ("gate_proj:0", "up_proj:0", "down_proj:1"),
39+
"post_feedforward_layernorm": ("post_feedforward_layernorm:!",),
40+
"per_layer_input_gate": ("per_layer_input_gate:0",),
41+
"post_per_layer_input_norm": ("post_per_layer_input_norm:!",),
42+
"per_layer_projection": ("per_layer_projection:1",),
43+
},
44+
]
45+
46+
47+
def _capture_gemma4_positional_inputs(model_def, args, kwargs, batch_device):
48+
"""Preserve Gemma 4 per-layer adapter inputs that flow through decoder layers positionally."""
49+
50+
layer_input = super(type(model_def), model_def).capture_first_layer_positional_inputs(args, kwargs, batch_device)
51+
per_layer_input = args[1] if len(args) > 1 else kwargs.get("per_layer_input")
52+
if per_layer_input is not None:
53+
layer_input.append(move_to(per_layer_input, device=batch_device))
54+
return layer_input
55+
56+
57+
def _prepare_gemma4_replay_kwargs(model_def, layer, layer_input, additional_inputs, target_device):
58+
"""Refresh Gemma 4 rotary kwargs per layer so replay follows sliding/full attention boundaries."""
59+
60+
rotary_path = getattr(model_def, "rotary_embedding", None)
61+
if not rotary_path or not layer_input:
62+
return additional_inputs
63+
64+
rotary, _ = get_module_by_name_prefix(model_def.model, [rotary_path])
65+
if rotary is None:
66+
return additional_inputs
67+
68+
layer_type = getattr(getattr(layer, "self_attn", None), "layer_type", None)
69+
if layer_type is None:
70+
return additional_inputs
71+
72+
hidden_states = layer_input[0]
73+
seq_len = hidden_states.shape[1] if hidden_states.dim() >= 2 else hidden_states.shape[0]
74+
batch_dim = hidden_states.shape[0] if hidden_states.dim() >= 2 else 1
75+
76+
position_ids = additional_inputs.get("position_ids")
77+
if position_ids is None or position_ids.shape[-1] != seq_len:
78+
position_ids = torch.arange(seq_len, device=target_device, dtype=torch.long).unsqueeze(0).expand(batch_dim, -1)
79+
additional_inputs["position_ids"] = position_ids
80+
81+
try:
82+
rotary_device = get_device(rotary)
83+
except Exception:
84+
rotary_device = position_ids.device
85+
86+
rotary_position_ids = move_to(position_ids, device=rotary_device)
87+
rotary_input = torch.empty(1, device=rotary_device, dtype=hidden_states.dtype)
88+
additional_inputs["position_embeddings"] = nested_move_to(
89+
rotary(rotary_input, rotary_position_ids, layer_type),
90+
device=target_device,
91+
)
92+
93+
if len(layer_input) == 1:
94+
all_per_layer_inputs = additional_inputs.pop(_GEMMA4_ALL_PER_LAYER_INPUTS, None)
95+
layer_index = getattr(getattr(layer, "self_attn", None), "layer_idx", None)
96+
if all_per_layer_inputs is not None and layer_index is not None:
97+
additional_inputs["per_layer_input"] = move_to(
98+
all_per_layer_inputs[:, :, layer_index, :],
99+
device=target_device,
100+
)
101+
else:
102+
additional_inputs.pop(_GEMMA4_ALL_PER_LAYER_INPUTS, None)
103+
104+
return additional_inputs
105+
106+
107+
def _resolve_gemma4_language_model(model_def):
108+
"""Return the Gemma 4 text stack that owns per-layer input projection state."""
109+
110+
if hasattr(model_def.model, "model") and hasattr(model_def.model.model, "language_model"):
111+
return model_def.model.model.language_model
112+
return model_def.model.model
113+
114+
115+
def _patch_gemma4_per_layer_input_capture(model_def):
116+
"""Capture projected per-layer inputs during calibration so later decoder replays can slice them by layer."""
117+
118+
language_model = _resolve_gemma4_language_model(model_def)
119+
if getattr(language_model, "_gptqmodel_project_per_layer_inputs_patched", False):
120+
return
121+
122+
original = language_model.project_per_layer_inputs
123+
124+
def patched(self, inputs_embeds, per_layer_inputs=None):
125+
result = original(inputs_embeds, per_layer_inputs)
126+
setattr(self, "_gptqmodel_cached_all_per_layer_inputs", result)
127+
return result
128+
129+
language_model._gptqmodel_original_project_per_layer_inputs = original
130+
language_model.project_per_layer_inputs = MethodType(patched, language_model)
131+
language_model._gptqmodel_project_per_layer_inputs_patched = True
132+
133+
134+
def _restore_gemma4_per_layer_input_capture(model_def):
135+
"""Restore Gemma 4 per-layer input helpers after calibration capture completes."""
136+
137+
language_model = _resolve_gemma4_language_model(model_def)
138+
original = getattr(language_model, "_gptqmodel_original_project_per_layer_inputs", None)
139+
if original is not None:
140+
language_model.project_per_layer_inputs = original
141+
delattr(language_model, "_gptqmodel_original_project_per_layer_inputs")
142+
if hasattr(language_model, "_gptqmodel_project_per_layer_inputs_patched"):
143+
delattr(language_model, "_gptqmodel_project_per_layer_inputs_patched")
144+
if hasattr(language_model, "_gptqmodel_cached_all_per_layer_inputs"):
145+
delattr(language_model, "_gptqmodel_cached_all_per_layer_inputs")
146+
147+
148+
class Gemma4TextQModel(LlamaQModel):
149+
"""Quantization definition for text-only Gemma 4 checkpoints."""
150+
151+
# Gemma 4 mixes optional KV projections and per-layer residual adapters across variants.
152+
layer_modules_strict = False
153+
# Gemma 4 input preparation uses per-layer embeddings, so batch quantization stays conservative.
154+
support_batch_quantize = False
155+
pre_lm_head_norm_module = "model.norm"
156+
rotary_embedding = "model.rotary_emb"
157+
module_tree = _gemma4_module_tree()
158+
159+
def capture_first_layer_positional_inputs(self, args, kwargs, batch_device):
160+
"""Keep Gemma 4 per-layer adapter inputs when decoder layers are replayed in isolation."""
161+
162+
return _capture_gemma4_positional_inputs(self, args, kwargs, batch_device)
163+
164+
def capture_first_layer_input_kwargs(self, args, kwargs, batch_device, layer_input_kwargs):
165+
"""Persist Gemma 4 per-layer adapter tensors for later decoder replays."""
166+
167+
layer_input_kwargs = super().capture_first_layer_input_kwargs(args, kwargs, batch_device, layer_input_kwargs)
168+
language_model = _resolve_gemma4_language_model(self)
169+
all_per_layer_inputs = getattr(language_model, "_gptqmodel_cached_all_per_layer_inputs", None)
170+
if all_per_layer_inputs is not None:
171+
layer_input_kwargs[_GEMMA4_ALL_PER_LAYER_INPUTS] = move_to(all_per_layer_inputs, device=batch_device)
172+
return layer_input_kwargs
173+
174+
def prepare_layer_replay_kwargs(self, layer, layer_input, additional_inputs, target_device):
175+
"""Refresh Gemma 4 layer kwargs during cached replay."""
176+
177+
return _prepare_gemma4_replay_kwargs(self, layer, layer_input, additional_inputs, target_device)
178+
179+
def pre_quantize_generate_hook_start(self):
180+
_patch_gemma4_per_layer_input_capture(self)
181+
182+
def pre_quantize_generate_hook_end(self):
183+
_restore_gemma4_per_layer_input_capture(self)
184+
super().pre_quantize_generate_hook_end()
185+
186+
187+
class Gemma4ForConditionalGenerationGPTQ(BaseQModel):
188+
"""Quantization definition for composite Gemma 4 checkpoints."""
189+
190+
# Gemma 4 composite checkpoints share the same decoder quirks as the text-only model.
191+
layer_modules_strict = False
192+
support_batch_quantize = False
193+
pre_lm_head_norm_module = "model.language_model.norm"
194+
rotary_embedding = "model.language_model.rotary_emb"
195+
196+
module_tree = [
197+
"model",
198+
"language_model",
199+
"layers",
200+
"#",
201+
{
202+
"input_layernorm": ("input_layernorm:!",),
203+
"self_attn": (
204+
"q_norm:!",
205+
"q_proj:0",
206+
"k_norm:!",
207+
"k_proj:0",
208+
"v_norm:!",
209+
"v_proj:0",
210+
"o_proj:1",
211+
),
212+
"post_attention_layernorm": ("post_attention_layernorm:!",),
213+
"pre_feedforward_layernorm": ("pre_feedforward_layernorm:!",),
214+
"mlp": ("gate_proj:0", "up_proj:0", "down_proj:1"),
215+
"post_feedforward_layernorm": ("post_feedforward_layernorm:!",),
216+
"per_layer_input_gate": ("per_layer_input_gate:0",),
217+
"post_per_layer_input_norm": ("post_per_layer_input_norm:!",),
218+
"per_layer_projection": ("per_layer_projection:1",),
219+
},
220+
]
221+
222+
def capture_first_layer_positional_inputs(self, args, kwargs, batch_device):
223+
"""Keep Gemma 4 per-layer adapter inputs when decoder layers are replayed in isolation."""
224+
225+
return _capture_gemma4_positional_inputs(self, args, kwargs, batch_device)
226+
227+
def capture_first_layer_input_kwargs(self, args, kwargs, batch_device, layer_input_kwargs):
228+
"""Persist Gemma 4 per-layer adapter tensors for later decoder replays."""
229+
230+
layer_input_kwargs = super().capture_first_layer_input_kwargs(args, kwargs, batch_device, layer_input_kwargs)
231+
language_model = _resolve_gemma4_language_model(self)
232+
all_per_layer_inputs = getattr(language_model, "_gptqmodel_cached_all_per_layer_inputs", None)
233+
if all_per_layer_inputs is not None:
234+
layer_input_kwargs[_GEMMA4_ALL_PER_LAYER_INPUTS] = move_to(all_per_layer_inputs, device=batch_device)
235+
return layer_input_kwargs
236+
237+
def prepare_layer_replay_kwargs(self, layer, layer_input, additional_inputs, target_device):
238+
"""Refresh Gemma 4 layer kwargs during cached replay."""
239+
240+
return _prepare_gemma4_replay_kwargs(self, layer, layer_input, additional_inputs, target_device)
241+
242+
def pre_quantize_generate_hook_start(self):
243+
_patch_gemma4_per_layer_input_capture(self)
244+
245+
def pre_quantize_generate_hook_end(self):
246+
_restore_gemma4_per_layer_input_capture(self)
247+
super().pre_quantize_generate_hook_end()

gptqmodel/quantization/awq/quantize/scale.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,15 @@
2727
from gptqmodel.quantization.awq.utils.utils import get_best_device
2828

2929

30+
try:
31+
from transformers.models.gemma4.modeling_gemma4 import Gemma4RMSNorm
32+
except Exception: # pragma: no cover - older transformers builds do not expose Gemma 4 yet
33+
Gemma4RMSNorm = None
34+
35+
3036
allowed_norms = [nn.LayerNorm, LlamaRMSNorm, GemmaRMSNorm, Gemma2RMSNorm, CohereLayerNorm]
37+
if Gemma4RMSNorm is not None:
38+
allowed_norms.append(Gemma4RMSNorm)
3139
allowed_act_fns = [
3240
nn.GELU,
3341
BloomGelu,

0 commit comments

Comments
 (0)