|
| 1 | +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai |
| 2 | +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | +# Contact: qubitium@modelcloud.ai, x.com/qubitium |
| 5 | + |
| 6 | +import torch |
| 7 | +from types import MethodType |
| 8 | + |
| 9 | +from ..base import BaseQModel |
| 10 | +from ...utils.device import get_device |
| 11 | +from ...utils.model import get_module_by_name_prefix, move_to, nested_move_to |
| 12 | +from . import LlamaQModel |
| 13 | + |
| 14 | + |
| 15 | +_GEMMA4_ALL_PER_LAYER_INPUTS = "__gptqmodel_gemma4_all_per_layer_inputs" |
| 16 | + |
| 17 | + |
| 18 | +def _gemma4_module_tree(): |
| 19 | + """Return the Gemma 4 decoder traversal with optional attention and per-layer input modules.""" |
| 20 | + |
| 21 | + return [ |
| 22 | + "model", |
| 23 | + "layers", |
| 24 | + "#", |
| 25 | + { |
| 26 | + "input_layernorm": ("input_layernorm:!",), |
| 27 | + "self_attn": ( |
| 28 | + "q_norm:!", |
| 29 | + "q_proj:0", |
| 30 | + "k_norm:!", |
| 31 | + "k_proj:0", |
| 32 | + "v_norm:!", |
| 33 | + "v_proj:0", |
| 34 | + "o_proj:1", |
| 35 | + ), |
| 36 | + "post_attention_layernorm": ("post_attention_layernorm:!",), |
| 37 | + "pre_feedforward_layernorm": ("pre_feedforward_layernorm:!",), |
| 38 | + "mlp": ("gate_proj:0", "up_proj:0", "down_proj:1"), |
| 39 | + "post_feedforward_layernorm": ("post_feedforward_layernorm:!",), |
| 40 | + "per_layer_input_gate": ("per_layer_input_gate:0",), |
| 41 | + "post_per_layer_input_norm": ("post_per_layer_input_norm:!",), |
| 42 | + "per_layer_projection": ("per_layer_projection:1",), |
| 43 | + }, |
| 44 | + ] |
| 45 | + |
| 46 | + |
| 47 | +def _capture_gemma4_positional_inputs(model_def, args, kwargs, batch_device): |
| 48 | + """Preserve Gemma 4 per-layer adapter inputs that flow through decoder layers positionally.""" |
| 49 | + |
| 50 | + layer_input = super(type(model_def), model_def).capture_first_layer_positional_inputs(args, kwargs, batch_device) |
| 51 | + per_layer_input = args[1] if len(args) > 1 else kwargs.get("per_layer_input") |
| 52 | + if per_layer_input is not None: |
| 53 | + layer_input.append(move_to(per_layer_input, device=batch_device)) |
| 54 | + return layer_input |
| 55 | + |
| 56 | + |
| 57 | +def _prepare_gemma4_replay_kwargs(model_def, layer, layer_input, additional_inputs, target_device): |
| 58 | + """Refresh Gemma 4 rotary kwargs per layer so replay follows sliding/full attention boundaries.""" |
| 59 | + |
| 60 | + rotary_path = getattr(model_def, "rotary_embedding", None) |
| 61 | + if not rotary_path or not layer_input: |
| 62 | + return additional_inputs |
| 63 | + |
| 64 | + rotary, _ = get_module_by_name_prefix(model_def.model, [rotary_path]) |
| 65 | + if rotary is None: |
| 66 | + return additional_inputs |
| 67 | + |
| 68 | + layer_type = getattr(getattr(layer, "self_attn", None), "layer_type", None) |
| 69 | + if layer_type is None: |
| 70 | + return additional_inputs |
| 71 | + |
| 72 | + hidden_states = layer_input[0] |
| 73 | + seq_len = hidden_states.shape[1] if hidden_states.dim() >= 2 else hidden_states.shape[0] |
| 74 | + batch_dim = hidden_states.shape[0] if hidden_states.dim() >= 2 else 1 |
| 75 | + |
| 76 | + position_ids = additional_inputs.get("position_ids") |
| 77 | + if position_ids is None or position_ids.shape[-1] != seq_len: |
| 78 | + position_ids = torch.arange(seq_len, device=target_device, dtype=torch.long).unsqueeze(0).expand(batch_dim, -1) |
| 79 | + additional_inputs["position_ids"] = position_ids |
| 80 | + |
| 81 | + try: |
| 82 | + rotary_device = get_device(rotary) |
| 83 | + except Exception: |
| 84 | + rotary_device = position_ids.device |
| 85 | + |
| 86 | + rotary_position_ids = move_to(position_ids, device=rotary_device) |
| 87 | + rotary_input = torch.empty(1, device=rotary_device, dtype=hidden_states.dtype) |
| 88 | + additional_inputs["position_embeddings"] = nested_move_to( |
| 89 | + rotary(rotary_input, rotary_position_ids, layer_type), |
| 90 | + device=target_device, |
| 91 | + ) |
| 92 | + |
| 93 | + if len(layer_input) == 1: |
| 94 | + all_per_layer_inputs = additional_inputs.pop(_GEMMA4_ALL_PER_LAYER_INPUTS, None) |
| 95 | + layer_index = getattr(getattr(layer, "self_attn", None), "layer_idx", None) |
| 96 | + if all_per_layer_inputs is not None and layer_index is not None: |
| 97 | + additional_inputs["per_layer_input"] = move_to( |
| 98 | + all_per_layer_inputs[:, :, layer_index, :], |
| 99 | + device=target_device, |
| 100 | + ) |
| 101 | + else: |
| 102 | + additional_inputs.pop(_GEMMA4_ALL_PER_LAYER_INPUTS, None) |
| 103 | + |
| 104 | + return additional_inputs |
| 105 | + |
| 106 | + |
| 107 | +def _resolve_gemma4_language_model(model_def): |
| 108 | + """Return the Gemma 4 text stack that owns per-layer input projection state.""" |
| 109 | + |
| 110 | + if hasattr(model_def.model, "model") and hasattr(model_def.model.model, "language_model"): |
| 111 | + return model_def.model.model.language_model |
| 112 | + return model_def.model.model |
| 113 | + |
| 114 | + |
| 115 | +def _patch_gemma4_per_layer_input_capture(model_def): |
| 116 | + """Capture projected per-layer inputs during calibration so later decoder replays can slice them by layer.""" |
| 117 | + |
| 118 | + language_model = _resolve_gemma4_language_model(model_def) |
| 119 | + if getattr(language_model, "_gptqmodel_project_per_layer_inputs_patched", False): |
| 120 | + return |
| 121 | + |
| 122 | + original = language_model.project_per_layer_inputs |
| 123 | + |
| 124 | + def patched(self, inputs_embeds, per_layer_inputs=None): |
| 125 | + result = original(inputs_embeds, per_layer_inputs) |
| 126 | + setattr(self, "_gptqmodel_cached_all_per_layer_inputs", result) |
| 127 | + return result |
| 128 | + |
| 129 | + language_model._gptqmodel_original_project_per_layer_inputs = original |
| 130 | + language_model.project_per_layer_inputs = MethodType(patched, language_model) |
| 131 | + language_model._gptqmodel_project_per_layer_inputs_patched = True |
| 132 | + |
| 133 | + |
| 134 | +def _restore_gemma4_per_layer_input_capture(model_def): |
| 135 | + """Restore Gemma 4 per-layer input helpers after calibration capture completes.""" |
| 136 | + |
| 137 | + language_model = _resolve_gemma4_language_model(model_def) |
| 138 | + original = getattr(language_model, "_gptqmodel_original_project_per_layer_inputs", None) |
| 139 | + if original is not None: |
| 140 | + language_model.project_per_layer_inputs = original |
| 141 | + delattr(language_model, "_gptqmodel_original_project_per_layer_inputs") |
| 142 | + if hasattr(language_model, "_gptqmodel_project_per_layer_inputs_patched"): |
| 143 | + delattr(language_model, "_gptqmodel_project_per_layer_inputs_patched") |
| 144 | + if hasattr(language_model, "_gptqmodel_cached_all_per_layer_inputs"): |
| 145 | + delattr(language_model, "_gptqmodel_cached_all_per_layer_inputs") |
| 146 | + |
| 147 | + |
| 148 | +class Gemma4TextQModel(LlamaQModel): |
| 149 | + """Quantization definition for text-only Gemma 4 checkpoints.""" |
| 150 | + |
| 151 | + # Gemma 4 mixes optional KV projections and per-layer residual adapters across variants. |
| 152 | + layer_modules_strict = False |
| 153 | + # Gemma 4 input preparation uses per-layer embeddings, so batch quantization stays conservative. |
| 154 | + support_batch_quantize = False |
| 155 | + pre_lm_head_norm_module = "model.norm" |
| 156 | + rotary_embedding = "model.rotary_emb" |
| 157 | + module_tree = _gemma4_module_tree() |
| 158 | + |
| 159 | + def capture_first_layer_positional_inputs(self, args, kwargs, batch_device): |
| 160 | + """Keep Gemma 4 per-layer adapter inputs when decoder layers are replayed in isolation.""" |
| 161 | + |
| 162 | + return _capture_gemma4_positional_inputs(self, args, kwargs, batch_device) |
| 163 | + |
| 164 | + def capture_first_layer_input_kwargs(self, args, kwargs, batch_device, layer_input_kwargs): |
| 165 | + """Persist Gemma 4 per-layer adapter tensors for later decoder replays.""" |
| 166 | + |
| 167 | + layer_input_kwargs = super().capture_first_layer_input_kwargs(args, kwargs, batch_device, layer_input_kwargs) |
| 168 | + language_model = _resolve_gemma4_language_model(self) |
| 169 | + all_per_layer_inputs = getattr(language_model, "_gptqmodel_cached_all_per_layer_inputs", None) |
| 170 | + if all_per_layer_inputs is not None: |
| 171 | + layer_input_kwargs[_GEMMA4_ALL_PER_LAYER_INPUTS] = move_to(all_per_layer_inputs, device=batch_device) |
| 172 | + return layer_input_kwargs |
| 173 | + |
| 174 | + def prepare_layer_replay_kwargs(self, layer, layer_input, additional_inputs, target_device): |
| 175 | + """Refresh Gemma 4 layer kwargs during cached replay.""" |
| 176 | + |
| 177 | + return _prepare_gemma4_replay_kwargs(self, layer, layer_input, additional_inputs, target_device) |
| 178 | + |
| 179 | + def pre_quantize_generate_hook_start(self): |
| 180 | + _patch_gemma4_per_layer_input_capture(self) |
| 181 | + |
| 182 | + def pre_quantize_generate_hook_end(self): |
| 183 | + _restore_gemma4_per_layer_input_capture(self) |
| 184 | + super().pre_quantize_generate_hook_end() |
| 185 | + |
| 186 | + |
| 187 | +class Gemma4ForConditionalGenerationGPTQ(BaseQModel): |
| 188 | + """Quantization definition for composite Gemma 4 checkpoints.""" |
| 189 | + |
| 190 | + # Gemma 4 composite checkpoints share the same decoder quirks as the text-only model. |
| 191 | + layer_modules_strict = False |
| 192 | + support_batch_quantize = False |
| 193 | + pre_lm_head_norm_module = "model.language_model.norm" |
| 194 | + rotary_embedding = "model.language_model.rotary_emb" |
| 195 | + |
| 196 | + module_tree = [ |
| 197 | + "model", |
| 198 | + "language_model", |
| 199 | + "layers", |
| 200 | + "#", |
| 201 | + { |
| 202 | + "input_layernorm": ("input_layernorm:!",), |
| 203 | + "self_attn": ( |
| 204 | + "q_norm:!", |
| 205 | + "q_proj:0", |
| 206 | + "k_norm:!", |
| 207 | + "k_proj:0", |
| 208 | + "v_norm:!", |
| 209 | + "v_proj:0", |
| 210 | + "o_proj:1", |
| 211 | + ), |
| 212 | + "post_attention_layernorm": ("post_attention_layernorm:!",), |
| 213 | + "pre_feedforward_layernorm": ("pre_feedforward_layernorm:!",), |
| 214 | + "mlp": ("gate_proj:0", "up_proj:0", "down_proj:1"), |
| 215 | + "post_feedforward_layernorm": ("post_feedforward_layernorm:!",), |
| 216 | + "per_layer_input_gate": ("per_layer_input_gate:0",), |
| 217 | + "post_per_layer_input_norm": ("post_per_layer_input_norm:!",), |
| 218 | + "per_layer_projection": ("per_layer_projection:1",), |
| 219 | + }, |
| 220 | + ] |
| 221 | + |
| 222 | + def capture_first_layer_positional_inputs(self, args, kwargs, batch_device): |
| 223 | + """Keep Gemma 4 per-layer adapter inputs when decoder layers are replayed in isolation.""" |
| 224 | + |
| 225 | + return _capture_gemma4_positional_inputs(self, args, kwargs, batch_device) |
| 226 | + |
| 227 | + def capture_first_layer_input_kwargs(self, args, kwargs, batch_device, layer_input_kwargs): |
| 228 | + """Persist Gemma 4 per-layer adapter tensors for later decoder replays.""" |
| 229 | + |
| 230 | + layer_input_kwargs = super().capture_first_layer_input_kwargs(args, kwargs, batch_device, layer_input_kwargs) |
| 231 | + language_model = _resolve_gemma4_language_model(self) |
| 232 | + all_per_layer_inputs = getattr(language_model, "_gptqmodel_cached_all_per_layer_inputs", None) |
| 233 | + if all_per_layer_inputs is not None: |
| 234 | + layer_input_kwargs[_GEMMA4_ALL_PER_LAYER_INPUTS] = move_to(all_per_layer_inputs, device=batch_device) |
| 235 | + return layer_input_kwargs |
| 236 | + |
| 237 | + def prepare_layer_replay_kwargs(self, layer, layer_input, additional_inputs, target_device): |
| 238 | + """Refresh Gemma 4 layer kwargs during cached replay.""" |
| 239 | + |
| 240 | + return _prepare_gemma4_replay_kwargs(self, layer, layer_input, additional_inputs, target_device) |
| 241 | + |
| 242 | + def pre_quantize_generate_hook_start(self): |
| 243 | + _patch_gemma4_per_layer_input_capture(self) |
| 244 | + |
| 245 | + def pre_quantize_generate_hook_end(self): |
| 246 | + _restore_gemma4_per_layer_input_capture(self) |
| 247 | + super().pre_quantize_generate_hook_end() |
0 commit comments