Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion gptqmodel/looper/forward_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,13 @@ def run_single(
self.looper._set_processor_mask(processor, keep_mask)
additional_inputs: Dict[str, Optional[torch.Tensor]] = {}
if self.looper.support_batch_quantize and attn_tensor is not None:
additional_inputs["attention_mask"] = attn_tensor
# Some architectures expect a 2D keep-mask ([B, S]) instead of
# a broadcasted 4D mask during layer replay.
if self.looper.gptq_model.ATTENTION_MASKS_REQUIRED_FOR_INPUT and keep_mask is not None:
required_dtype = getattr(self.looper.gptq_model, "ATTENTION_MASKS_DTYPE", torch.bool)
additional_inputs["attention_mask"] = keep_mask.to(dtype=required_dtype)
else:
additional_inputs["attention_mask"] = attn_tensor
else:
additional_inputs["attention_mask"] = None

Expand Down
13 changes: 5 additions & 8 deletions gptqmodel/looper/stage_inputs_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,14 +195,11 @@ def store_input_hook(module, args, kwargs):
# TODO: why data_device sometimes set to cuda (self.gptq_model.quantize_config.device) and sometimes to CPU (cur_layer_device)?
try:
for batch_index, example in enumerate(calibration_data, start=1):
if self.gptq_model.ATTENTION_MASKS_REQUIRED_FOR_INPUT:
data_device = self.gptq_model.quantize_config.device
else:
data_device = (
self.gptq_model.quantize_config.device
if "pixel_values" in example.keys()
else cur_layer_device
)
data_device = (
self.gptq_model.quantize_config.device
if "pixel_values" in example.keys()
else cur_layer_device
)
example = self.gptq_model.move_input_capture_example(example, data_device)
try:
with ctx(
Expand Down
89 changes: 82 additions & 7 deletions gptqmodel/looper/stage_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from concurrent.futures import as_completed
from typing import TYPE_CHECKING, Dict, List, Optional

import defuser
from defuser.modeling.replace_modules import materialize_model
from ..nn_modules.hooked_linear import replace_module_with_hooked_legacy
from ..nn_modules.converter import MODULE_CONVERTER_MAP
Expand All @@ -36,7 +37,7 @@
from ..utils.device import get_device, get_device_new
from ..utils.looper_helpers import normalize_device_like
from ..utils.logger import live_renderables_suppressed, log_time_block, setup_logger
from ..utils.model import find_modules, get_module
from ..utils.model import _module_has_meta_tensors, find_modules, get_module
from ..utils.offload import offload_to_disk
from ..utils.torch import CPU, torch_empty_cache, torch_sync
from .stage_subset import SubsetPlan, build_layer_subset_plans, run_subset_stage
Expand Down Expand Up @@ -130,6 +131,38 @@ def _should_empty_cache_after_sync_finalize(
return any(isinstance(process, ParoQuantProcessor) for process, *_ in finalize_tasks)


def _materialize_remaining_meta_submodules(
    looper: "ModuleLooper",
    *,
    module: torch.nn.Module,
    device: torch.device,
) -> torch.nn.Module:
    """Materialize, best-effort, each submodule of ``module`` still holding meta tensors.

    Args:
        looper: Owner whose ``gptq_model.shell_module_materialize`` is invoked
            for every submodule that still reports meta tensors.
        module: Root module to scan; the root itself (empty name) is not
            passed to the materializer by this helper.
        device: Target device forwarded to ``shell_module_materialize``.

    Returns:
        The same ``module`` object that was passed in. The return value of
        ``shell_module_materialize`` is not used here.
    """
    # Nothing to do when the whole tree is already free of meta tensors.
    if not _module_has_meta_tensors(module):
        return module

    # Snapshot the submodule paths up front and visit the deepest ones first,
    # so parent wrappers are checked only after their children were handled.
    # The sort is stable, so equally deep paths keep their traversal order.
    submodule_paths = sorted(
        (path for path, _ in module.named_modules() if path),
        key=lambda path: -path.count("."),
    )

    for path in submodule_paths:
        # Re-resolve by path on every iteration instead of reusing the objects
        # yielded by named_modules(); a path that no longer resolves is skipped.
        try:
            candidate = get_module(module, path)
        except Exception:
            continue

        if isinstance(candidate, torch.nn.Module) and _module_has_meta_tensors(candidate):
            looper.gptq_model.shell_module_materialize(
                target_submodule=candidate,
                device=device,
                role="forward",
            )

    return module


def _processor_needs_pristine_group_clone(processor) -> bool:
"""Whether grouped capture needs a dedicated pristine layer clone for this processor."""
needs_clone = getattr(processor, "needs_pristine_layer_clone", None)
Expand Down Expand Up @@ -475,25 +508,67 @@ def run_layer_stage(
if needs_group_pristine:
pristine_group_module = copy.deepcopy(module) if needs_pristine_group_clone else None

replace_module_with_hooked_legacy(
module,
quant_lm_head=looper.gptq_model.quantize_config.lm_head,
skip_module_paths=hook_skip_modules,
)

layers[layer_index] = module

if layers_prefix:
layer_descriptor = f"{layers_prefix}.{layer_index}"
else:
layer_descriptor = str(layer_index)

# Materialize the original shell modules first; hooking wrappers before
# this can prevent defuser materialization from resolving lazy weights.
materialize_model(module)
if _module_has_meta_tensors(module):
materialize_device = normalize_device_like(looper.gptq_model.quantize_config.device) or CPU
module = looper.gptq_model.shell_module_materialize(
target_submodule=module,
device=materialize_device,
role="forward",
)
module = _materialize_remaining_meta_submodules(
looper,
module=module,
device=materialize_device,
)
layers[layer_index] = module

# LazyTurtle materializes checkpoint tensors by runtime module path.
# Defuser conversion can mutate those paths/projections before all shell
# tensors are resolved, leaving newly introduced parameters on meta.
# Keep conversion for eager-loaded models, but skip during lazy shell
# quantization where checkpoint-backed materialization is authoritative.
if (
looper.gptq_model.turtle_model is None
and not getattr(module, "_gptqmodel_defuser_converted", False)
):
defuser.convert_model(module, cleanup_original=False)
setattr(module, "_gptqmodel_defuser_converted", True)

replace_module_with_hooked_legacy(
module,
quant_lm_head=looper.gptq_model.quantize_config.lm_head,
skip_module_paths=hook_skip_modules,
)

cur_layer_device = get_device(module)
if getattr(cur_layer_device, "type", None) == "meta":
# Lazy shell layers can stay meta until a later subset stage materializes them.
cur_layer_device = normalize_device_like(looper.gptq_model.quantize_config.device) or CPU
if getattr(cur_layer_device, "type", None) == "cpu":
quant_devices = getattr(looper, "_quant_devices", None) or []
# Keep replay on an accelerator when quant device policy resolved one;
# otherwise large attention replays can silently fall back to CPU.
accel_device = next(
(
normalize_device_like(device)
for device in quant_devices
if (normalize_device_like(device) is not None and normalize_device_like(device).type != "cpu")
),
None,
)
if accel_device is not None:
module.to(accel_device)
cur_layer_device = accel_device
full = find_modules(module, name=looper.gptq_model.lm_head if is_lm_head_module else "")

for p_index, processor in enumerate(looper.processors):
Expand Down
1 change: 1 addition & 0 deletions gptqmodel/models/definitions/minimax_m2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class MiniMaxM2GPTQ(BaseQModel):
layer_modules_strict = False

dynamic_expert_index = "num_local_experts"
ATTENTION_MASKS_REQUIRED_FOR_INPUT = True

# MoE lifecycle hooks for w1/w3/w2 pattern
moe_lifecycle_hooks = W1W3W2MoELifecycleHooks()
Expand Down
16 changes: 13 additions & 3 deletions gptqmodel/models/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
from ..utils.machete import _validate_machete_device_support
from ..utils.marlin import _marlin_capability_supported, _validate_marlin_device_support
from ..utils.model import (
_module_has_meta_tensors,
auto_dtype,
convert_gptq_v1_to_v2_format,
find_config_seq_len,
Expand Down Expand Up @@ -642,12 +643,18 @@ def skip(*args, **kwargs):
)
if getattr(model, "config", None) is config:
model.config = copy.deepcopy(config)
defuser.convert_model(model, cleanup_original=False)
if _module_has_meta_tensors(model):
log.info("Loader: defuser.convert_model deferred until layer materialization (meta shell detected)")
else:
defuser.convert_model(model, cleanup_original=False)
model._model_init_kwargs = fallback_init_kwargs
_maybe_print_module_tree(model=model)
turtle_model = None
else:
defuser.convert_model(model, cleanup_original=False)
if _module_has_meta_tensors(model):
log.info("Loader: defuser.convert_model deferred until layer materialization (meta shell detected)")
else:
defuser.convert_model(model, cleanup_original=False)
shell_model_init_kwargs = dict(model_init_kwargs_without_internal)
shell_model_init_kwargs.update(hf_gguf_load_kwargs)
model._model_init_kwargs = shell_model_init_kwargs
Expand Down Expand Up @@ -682,7 +689,10 @@ def skip(*args, **kwargs):
)
if getattr(model, "config", None) is config:
model.config = copy.deepcopy(config)
defuser.convert_model(model, cleanup_original=False)
if _module_has_meta_tensors(model):
log.info("Loader: defuser.convert_model deferred until layer materialization (meta shell detected)")
else:
defuser.convert_model(model, cleanup_original=False)
direct_model_init_kwargs = dict(model_init_kwargs_without_internal)
direct_model_init_kwargs.update(hf_gguf_load_kwargs)
model._model_init_kwargs = direct_model_init_kwargs
Expand Down
12 changes: 12 additions & 0 deletions gptqmodel/nn_modules/hooked_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,16 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.forward_hook_last:
raise STOP_FORWARD_EXCEPTION.with_traceback(None)

# `meta` is a shape-only device and cannot receive real tensor data.
# During shell/offload flows, some replay paths may pass meta inputs
# while weights execute on a materialized device (e.g. CPU/CUDA).
# In that case, keep the computed output on its real device.
if original_device.type == "meta":
return output
if original_device.type == "meta" or output.device.type == "meta":
return output
if original_device.type == "meta" or output.device.type == "meta":
return output
if output.device != original_device:
output = output.to(device=original_device)
return output
Expand Down Expand Up @@ -248,6 +258,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
self.forward_hook(self, (input,), output)
if self.forward_hook_last:
raise STOP_FORWARD_EXCEPTION.with_traceback(None)
if original_device.type == "meta" or output.device.type == "meta":
return output
if output.device != original_device:
output = output.to(device=original_device)
return output
Expand Down
15 changes: 11 additions & 4 deletions gptqmodel/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,17 @@ def _module_has_meta_tensors(module: nn.Module) -> bool:
def move_to(obj: torch.Tensor | nn.Module, device: torch.device, dtype: torch.dtype = None):
if isinstance(obj, nn.Module) and _module_has_meta_tensors(obj):
if not accelerate.utils.has_offloaded_params(obj):
raise NotImplementedError(
"Cannot move a module that still contains meta tensors without offload hooks. "
"Materialize it first before calling move_to()."
)
# Some quant/offload paths can leave non-participating leaves on meta.
# Best-effort move concrete tensors and keep meta placeholders as-is.
for _, param in obj.named_parameters(recurse=True):
if getattr(param, "is_meta", False) or param.device.type == "meta":
continue
param.data = param.data.to(device=device, dtype=dtype, non_blocking=False)
for _, buf in obj.named_buffers(recurse=True):
if getattr(buf, "is_meta", False) or buf.device.type == "meta":
continue
buf.data = buf.data.to(device=device, dtype=dtype, non_blocking=False)
return obj

# Accelerate disk-offloaded modules keep meta placeholders until they are
# explicitly restored, so materialize those leaves before the device move.
Expand Down
Loading