From f05ec3c5b149242565fc1aeaf983348dcb71ddb5 Mon Sep 17 00:00:00 2001 From: CSY-ModelCloud Date: Wed, 13 May 2026 13:44:55 +0800 Subject: [PATCH] Fix MiniMax-M2 lazy materialization and replay mask/device handling --- gptqmodel/looper/forward_executor.py | 8 +- gptqmodel/looper/stage_inputs_capture.py | 13 +- gptqmodel/looper/stage_layer.py | 89 ++++++- gptqmodel/models/definitions/minimax_m2.py | 1 + gptqmodel/models/loader.py | 16 +- gptqmodel/nn_modules/hooked_linear.py | 12 + gptqmodel/utils/model.py | 15 +- gptqmodel/utils/structure.py | 260 ++++++++++++++++----- 8 files changed, 334 insertions(+), 80 deletions(-) diff --git a/gptqmodel/looper/forward_executor.py b/gptqmodel/looper/forward_executor.py index b98a7548c..987e5be7a 100644 --- a/gptqmodel/looper/forward_executor.py +++ b/gptqmodel/looper/forward_executor.py @@ -268,7 +268,13 @@ def run_single( self.looper._set_processor_mask(processor, keep_mask) additional_inputs: Dict[str, Optional[torch.Tensor]] = {} if self.looper.support_batch_quantize and attn_tensor is not None: - additional_inputs["attention_mask"] = attn_tensor + # Some architectures expect a 2D keep-mask ([B, S]) instead of + # a broadcasted 4D mask during layer replay. + if self.looper.gptq_model.ATTENTION_MASKS_REQUIRED_FOR_INPUT and keep_mask is not None: + required_dtype = getattr(self.looper.gptq_model, "ATTENTION_MASKS_DTYPE", torch.bool) + additional_inputs["attention_mask"] = keep_mask.to(dtype=required_dtype) + else: + additional_inputs["attention_mask"] = attn_tensor else: additional_inputs["attention_mask"] = None diff --git a/gptqmodel/looper/stage_inputs_capture.py b/gptqmodel/looper/stage_inputs_capture.py index f910a90bd..3a90a13c6 100644 --- a/gptqmodel/looper/stage_inputs_capture.py +++ b/gptqmodel/looper/stage_inputs_capture.py @@ -195,14 +195,11 @@ def store_input_hook(module, args, kwargs): # TODO: why data_device sometimes set to cuda (self.gptq_model.quantize_config.device) and sometimes to CPU (cur_layer_device)? try: for batch_index, example in enumerate(calibration_data, start=1): - if self.gptq_model.ATTENTION_MASKS_REQUIRED_FOR_INPUT: - data_device = self.gptq_model.quantize_config.device - else: - data_device = ( - self.gptq_model.quantize_config.device - if "pixel_values" in example.keys() - else cur_layer_device - ) + data_device = ( + self.gptq_model.quantize_config.device + if "pixel_values" in example.keys() + else cur_layer_device + ) example = self.gptq_model.move_input_capture_example(example, data_device) try: with ctx( diff --git a/gptqmodel/looper/stage_layer.py b/gptqmodel/looper/stage_layer.py index 3a1c81f34..3ae13e7a9 100644 --- a/gptqmodel/looper/stage_layer.py +++ b/gptqmodel/looper/stage_layer.py @@ -21,6 +21,7 @@ from concurrent.futures import as_completed from typing import TYPE_CHECKING, Dict, List, Optional +import defuser from defuser.modeling.replace_modules import materialize_model from ..nn_modules.hooked_linear import replace_module_with_hooked_legacy from ..nn_modules.converter import MODULE_CONVERTER_MAP @@ -36,7 +37,7 @@ from ..utils.device import get_device, get_device_new from ..utils.looper_helpers import normalize_device_like from ..utils.logger import live_renderables_suppressed, log_time_block, setup_logger -from ..utils.model import find_modules, get_module +from ..utils.model import _module_has_meta_tensors, find_modules, get_module from ..utils.offload import offload_to_disk from ..utils.torch import CPU, torch_empty_cache, torch_sync from .stage_subset import SubsetPlan, build_layer_subset_plans, run_subset_stage @@ -130,6 +131,38 @@ def _should_empty_cache_after_sync_finalize( return any(isinstance(process, ParoQuantProcessor) for process, *_ in finalize_tasks) +def _materialize_remaining_meta_submodules( + looper: "ModuleLooper", + *, + module: torch.nn.Module, + device: torch.device, +) -> torch.nn.Module: + """Best-effort materialization for any submodule that still carries meta tensors.""" + if not _module_has_meta_tensors(module): + return module + + # Process deeper modules first so parent wrappers observe materialized children. + names = [name for name, _ in module.named_modules() if name] + names.sort(key=lambda n: n.count("."), reverse=True) + + for name in names: + try: + submodule = get_module(module, name) + except Exception: + continue + if not isinstance(submodule, torch.nn.Module): + continue + if not _module_has_meta_tensors(submodule): + continue + looper.gptq_model.shell_module_materialize( + target_submodule=submodule, + device=device, + role="forward", + ) + + return module + + def _processor_needs_pristine_group_clone(processor) -> bool: """Whether grouped capture needs a dedicated pristine layer clone for this processor.""" needs_clone = getattr(processor, "needs_pristine_layer_clone", None) @@ -475,12 +508,6 @@ def run_layer_stage( if needs_group_pristine: pristine_group_module = copy.deepcopy(module) if needs_pristine_group_clone else None - replace_module_with_hooked_legacy( - module, - quant_lm_head=looper.gptq_model.quantize_config.lm_head, - skip_module_paths=hook_skip_modules, - ) - layers[layer_index] = module if layers_prefix: @@ -488,12 +515,60 @@ def run_layer_stage( else: layer_descriptor = str(layer_index) + # Materialize the original shell modules first; hooking wrappers before + # this can prevent defuser materialization from resolving lazy weights. materialize_model(module) + if _module_has_meta_tensors(module): + materialize_device = normalize_device_like(looper.gptq_model.quantize_config.device) or CPU + module = looper.gptq_model.shell_module_materialize( + target_submodule=module, + device=materialize_device, + role="forward", + ) + module = _materialize_remaining_meta_submodules( + looper, + module=module, + device=materialize_device, + ) + layers[layer_index] = module + + # LazyTurtle materializes checkpoint tensors by runtime module path. + # Defuser conversion can mutate those paths/projections before all shell + # tensors are resolved, leaving newly introduced parameters on meta. + # Keep conversion for eager-loaded models, but skip during lazy shell + # quantization where checkpoint-backed materialization is authoritative. + if ( + looper.gptq_model.turtle_model is None + and not getattr(module, "_gptqmodel_defuser_converted", False) + ): + defuser.convert_model(module, cleanup_original=False) + setattr(module, "_gptqmodel_defuser_converted", True) + + replace_module_with_hooked_legacy( + module, + quant_lm_head=looper.gptq_model.quantize_config.lm_head, + skip_module_paths=hook_skip_modules, + ) cur_layer_device = get_device(module) if getattr(cur_layer_device, "type", None) == "meta": # Lazy shell layers can stay meta until a later subset stage materializes them. cur_layer_device = normalize_device_like(looper.gptq_model.quantize_config.device) or CPU + if getattr(cur_layer_device, "type", None) == "cpu": + quant_devices = getattr(looper, "_quant_devices", None) or [] + # Keep replay on an accelerator when quant device policy resolved one; + # otherwise large attention replays can silently fall back to CPU. + accel_device = next( + ( + normalize_device_like(device) + for device in quant_devices + if (normalize_device_like(device) is not None and normalize_device_like(device).type != "cpu") + ), + None, + ) + if accel_device is not None: + module.to(accel_device) + cur_layer_device = accel_device full = find_modules(module, name=looper.gptq_model.lm_head if is_lm_head_module else "") for p_index, processor in enumerate(looper.processors): diff --git a/gptqmodel/models/definitions/minimax_m2.py b/gptqmodel/models/definitions/minimax_m2.py index 2592cc51a..45f11fcdf 100644 --- a/gptqmodel/models/definitions/minimax_m2.py +++ b/gptqmodel/models/definitions/minimax_m2.py @@ -20,6 +20,7 @@ class MiniMaxM2GPTQ(BaseQModel): layer_modules_strict = False dynamic_expert_index = "num_local_experts" + ATTENTION_MASKS_REQUIRED_FOR_INPUT = True # MoE lifecycle hooks for w1/w3/w2 pattern moe_lifecycle_hooks = W1W3W2MoELifecycleHooks() diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py index a4a839bcb..226d74306 100644 --- a/gptqmodel/models/loader.py +++ b/gptqmodel/models/loader.py @@ -63,6 +63,7 @@ from ..utils.machete import _validate_machete_device_support from ..utils.marlin import _marlin_capability_supported, _validate_marlin_device_support from ..utils.model import ( + _module_has_meta_tensors, auto_dtype, convert_gptq_v1_to_v2_format, find_config_seq_len, @@ -642,12 +643,18 @@ def skip(*args, **kwargs): ) if getattr(model, "config", None) is config: model.config = copy.deepcopy(config) - defuser.convert_model(model, cleanup_original=False) + if _module_has_meta_tensors(model): + log.info("Loader: defuser.convert_model deferred until layer materialization (meta shell detected)") + else: + defuser.convert_model(model, cleanup_original=False) model._model_init_kwargs = fallback_init_kwargs _maybe_print_module_tree(model=model) turtle_model = None else: - defuser.convert_model(model, cleanup_original=False) + if _module_has_meta_tensors(model): + log.info("Loader: defuser.convert_model deferred until layer materialization (meta shell detected)") + else: + defuser.convert_model(model, cleanup_original=False) shell_model_init_kwargs = dict(model_init_kwargs_without_internal) shell_model_init_kwargs.update(hf_gguf_load_kwargs) model._model_init_kwargs = shell_model_init_kwargs @@ -682,7 +689,10 @@ def skip(*args, **kwargs): ) if getattr(model, "config", None) is config: model.config = copy.deepcopy(config) - defuser.convert_model(model, cleanup_original=False) + if _module_has_meta_tensors(model): + log.info("Loader: defuser.convert_model deferred until layer materialization (meta shell detected)") + else: + defuser.convert_model(model, cleanup_original=False) direct_model_init_kwargs = dict(model_init_kwargs_without_internal) direct_model_init_kwargs.update(hf_gguf_load_kwargs) model._model_init_kwargs = direct_model_init_kwargs diff --git a/gptqmodel/nn_modules/hooked_linear.py b/gptqmodel/nn_modules/hooked_linear.py index c956a5573..0a330ef4f 100644 --- a/gptqmodel/nn_modules/hooked_linear.py +++ b/gptqmodel/nn_modules/hooked_linear.py @@ -52,6 +52,16 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: if self.forward_hook_last: raise STOP_FORWARD_EXCEPTION.with_traceback(None) + # `meta` is a shape-only device and cannot receive real tensor data. + # During shell/offload flows, some replay paths may pass meta inputs + # while weights execute on a materialized device (e.g. CPU/CUDA). + # In that case, keep the computed output on its real device. + if original_device.type == "meta": + return output + if original_device.type == "meta" or output.device.type == "meta": + return output + if original_device.type == "meta" or output.device.type == "meta": + return output if output.device != original_device: output = output.to(device=original_device) return output @@ -248,6 +258,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: self.forward_hook(self, (input,), output) if self.forward_hook_last: raise STOP_FORWARD_EXCEPTION.with_traceback(None) + if original_device.type == "meta" or output.device.type == "meta": + return output if output.device != original_device: output = output.to(device=original_device) return output diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index 44e892767..5efb0dcba 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -249,10 +249,17 @@ def _module_has_meta_tensors(module: nn.Module) -> bool: def move_to(obj: torch.Tensor | nn.Module, device: torch.device, dtype: torch.dtype = None): if isinstance(obj, nn.Module) and _module_has_meta_tensors(obj): if not accelerate.utils.has_offloaded_params(obj): - raise NotImplementedError( - "Cannot move a module that still contains meta tensors without offload hooks. " - "Materialize it first before calling move_to()." - ) + # Some quant/offload paths can leave non-participating leaves on meta. + # Best-effort move concrete tensors and keep meta placeholders as-is. + for _, param in obj.named_parameters(recurse=True): + if getattr(param, "is_meta", False) or param.device.type == "meta": + continue + param.data = param.data.to(device=device, dtype=dtype, non_blocking=False) + for _, buf in obj.named_buffers(recurse=True): + if getattr(buf, "is_meta", False) or buf.device.type == "meta": + continue + buf.data = buf.data.to(device=device, dtype=dtype, non_blocking=False) + return obj # Accelerate disk-offloaded modules keep meta placeholders until they are # explicitly restored, so materialize those leaves before the device move. diff --git a/gptqmodel/utils/structure.py b/gptqmodel/utils/structure.py index 7a8c805e2..a9a765ee0 100644 --- a/gptqmodel/utils/structure.py +++ b/gptqmodel/utils/structure.py @@ -1931,7 +1931,7 @@ def _copy_checkpoint_tensors_into_submodule( t_params = dict(target_submodule.named_parameters(recurse=recurse)) t_bufs = dict(target_submodule.named_buffers(recurse=recurse)) modules_by_name = dict(target_model.named_modules()) - missing_nonpersistent_buffers: list[tuple[str, str]] = [] + missing_template_buffers: list[tuple[str, str, bool]] = [] grouped_names: Dict[str, list[tuple[str, str, str, Optional[int], Optional[int], Optional[int]]]] = {} for rel_name in t_params: @@ -1967,17 +1967,12 @@ def _copy_checkpoint_tensors_into_submodule( if shard is None: t_parent, leaf = _get_parent_and_leaf_by_path(target_submodule, rel_name) non_persistent = leaf in getattr(t_parent, "_non_persistent_buffers_set", set()) - if non_persistent: - if ( - getattr(target_buffer, "is_meta", False) - or target_buffer.device.type == "meta" - or target_buffer.device != device - ): - missing_nonpersistent_buffers.append((rel_name, leaf)) - continue - if getattr(target_buffer, "is_meta", False) or target_buffer.device.type == "meta": - if leaf in getattr(t_parent, "_buffers", {}): - del t_parent._buffers[leaf] + if ( + getattr(target_buffer, "is_meta", False) + or target_buffer.device.type == "meta" + or target_buffer.device != device + ): + missing_template_buffers.append((rel_name, leaf, not non_persistent)) continue grouped_names.setdefault(shard, []).append(("buffer", rel_name, full_name, expert_index, split_index, split_dim)) @@ -2100,14 +2095,34 @@ def _copy_checkpoint_tensors_into_submodule( source = source.to(dtype=target_buffer.dtype) target_buffer.copy_(source, non_blocking=(non_blocking and source.is_pinned())) - self._restore_missing_nonpersistent_buffers( + self._restore_missing_template_buffers( target_model=target_model, target_submodule=target_submodule, t_bufs=t_bufs, - missing_nonpersistent_buffers=missing_nonpersistent_buffers, + missing_template_buffers=missing_template_buffers, device=device, ) + # Some runtimes expose fused projection names (e.g. qkv/gate_up) that + # are not present verbatim in checkpoint shards. Resolve any remaining + # meta tensors via direct/meta synthesis pass before returning. + param_cache: Dict[tuple[str, Optional[int], Optional[int], Optional[int], torch.dtype, bool], nn.Parameter] = {} + buffer_cache: Dict[tuple[str, Optional[int], Optional[int], Optional[int], torch.dtype], torch.Tensor] = {} + local_names = [name for name, _ in target_submodule.named_modules() if name] + local_names.sort(key=lambda item: item.count("."), reverse=True) + local_names.append("") + for local_name in local_names: + shell_sub = target_submodule if not local_name else dict(target_submodule.named_modules()).get(local_name) + if shell_sub is None: + continue + shell_path = module_path if not local_name else f"{module_path}.{local_name}" + self._materialize_direct_meta_tensors( + shell_sub=shell_sub, + module_path=shell_path, + param_cache=param_cache, + buffer_cache=buffer_cache, + ) + def _build_nonpersistent_buffer_template( self, *, @@ -2177,19 +2192,19 @@ def _build_nonpersistent_buffer_template( ) return None - def _restore_missing_nonpersistent_buffers( + def _restore_missing_template_buffers( self, *, target_model: nn.Module, target_submodule: nn.Module, t_bufs: Dict[str, torch.Tensor], - missing_nonpersistent_buffers: list[tuple[str, str]], + missing_template_buffers: list[tuple[str, str, bool]], device: torch.device, ) -> None: - """Restore constructor-owned buffers that are intentionally absent from checkpoints.""" + """Restore constructor-owned buffers that are absent from checkpoints.""" owner_templates: Dict[str, Optional[nn.Module]] = {} - for rel_name, leaf in missing_nonpersistent_buffers: + for rel_name, leaf, persistent in missing_template_buffers: parent_rel_path, _, _ = rel_name.rpartition(".") owner_module = target_submodule if not parent_rel_path else dict(target_submodule.named_modules()).get(parent_rel_path) if owner_module is None: @@ -2218,7 +2233,7 @@ def _restore_missing_nonpersistent_buffers( target_dtype = source_buffer.dtype if current_buffer is None else current_buffer.dtype materialized = source_buffer.to(device=device, dtype=target_dtype) - owner_module.register_buffer(leaf, materialized, persistent=False) + owner_module.register_buffer(leaf, materialized, persistent=persistent) t_bufs[rel_name] = materialized def _materialize_direct_meta_tensors( @@ -2231,54 +2246,184 @@ def _materialize_direct_meta_tensors( ) -> int: synced = 0 + def _resolve_first_existing_weight(*names: str) -> Optional[torch.Tensor]: + for candidate in names: + if not candidate: + continue + shard = self._weight_map.get(candidate) + if shard is None: + continue + source_path = os.path.join(self.model_local_path, shard) + with safe_open(source_path, framework="pt", device="cpu") as handler: + return handler.get_tensor(candidate) + return None + + def _synthesize_runtime_param(module_path: str, rel_name: str, target_shape: tuple[int, ...]) -> Optional[torch.Tensor]: + full_runtime_name = self._join_tensor_name(module_path, rel_name) + aliased_runtime_names = self._all_runtime_to_checkpoint_candidates(full_runtime_name) + module_leaf = module_path.rsplit(".", 1)[-1] if module_path else "" + + # Build qkv from legacy q/k/v checkpoint projections when combined qkv is absent. + if rel_name == "weight" and module_leaf == "qkv_proj": + for runtime_name in aliased_runtime_names: + prefix, _, _ = runtime_name.rpartition(".qkv_proj.weight") + if not prefix: + continue + q = _resolve_first_existing_weight( + *self._all_runtime_to_checkpoint_candidates(f"{prefix}.q_proj.weight") + ) + k = _resolve_first_existing_weight( + *self._all_runtime_to_checkpoint_candidates(f"{prefix}.k_proj.weight") + ) + v = _resolve_first_existing_weight( + *self._all_runtime_to_checkpoint_candidates(f"{prefix}.v_proj.weight") + ) + if q is None or k is None or v is None: + continue + if k.shape[0] != q.shape[0]: + if q.shape[0] % k.shape[0] != 0: + continue + k = k.repeat_interleave(q.shape[0] // k.shape[0], dim=0) + if v.shape[0] != q.shape[0]: + if q.shape[0] % v.shape[0] != 0: + continue + v = v.repeat_interleave(q.shape[0] // v.shape[0], dim=0) + fused = torch.cat([q, k, v], dim=0).contiguous() + if tuple(fused.shape) == target_shape: + return fused + return None + + # Align HF `o_proj` naming with runtime `out_proj` when needed. + if rel_name == "weight" and module_leaf == "out_proj": + for runtime_name in aliased_runtime_names: + candidate = runtime_name.replace(".out_proj.weight", ".o_proj.weight") + tensor = _resolve_first_existing_weight( + candidate, + *self._all_runtime_to_checkpoint_candidates(candidate), + ) + if tensor is not None and tuple(tensor.shape) == target_shape: + return tensor.contiguous() + return None + + # Some checkpoints expose per-head norms while runtime expects one combined norm. + if rel_name == "weight" and module_leaf == "norm": + for runtime_name in aliased_runtime_names: + qn = runtime_name.replace(".norm.weight", ".q_norm.weight") + kn = runtime_name.replace(".norm.weight", ".k_norm.weight") + q_tensor = _resolve_first_existing_weight(qn, *self._all_runtime_to_checkpoint_candidates(qn)) + k_tensor = _resolve_first_existing_weight(kn, *self._all_runtime_to_checkpoint_candidates(kn)) + source = q_tensor if q_tensor is not None else k_tensor + if source is not None and tuple(source.shape) == target_shape: + return source.contiguous() + return None + + # Legacy MiniMax checkpoints do not carry output_gate; materialize deterministic zeros. + if rel_name == "weight" and module_leaf == "output_gate": + return torch.zeros(target_shape, dtype=torch.float32) + + # Build fused MoE projections from per-expert w1/w2/w3 tensors. + if rel_name in {"gate_up_proj", "down_proj"} and module_leaf == "experts" and len(target_shape) == 3: + expert_count = target_shape[0] + prefix, _, _ = full_runtime_name.rpartition(f".{rel_name}") + if not prefix: + return None + checkpoint_prefix = prefix.replace(".mlp.experts", ".block_sparse_moe.experts") + fused_rows = [] + for expert_idx in range(expert_count): + base = f"{checkpoint_prefix}.{expert_idx}" + if rel_name == "gate_up_proj": + w1 = _resolve_first_existing_weight( + base + ".w1.weight", + *self._all_runtime_to_checkpoint_candidates(base + ".w1.weight"), + ) + w3 = _resolve_first_existing_weight( + base + ".w3.weight", + *self._all_runtime_to_checkpoint_candidates(base + ".w3.weight"), + ) + if w1 is None or w3 is None: + return None + fused_rows.append(torch.cat([w1, w3], dim=0).contiguous()) + else: + w2 = _resolve_first_existing_weight( + base + ".w2.weight", + *self._all_runtime_to_checkpoint_candidates(base + ".w2.weight"), + ) + if w2 is None: + return None + fused_rows.append(w2.contiguous()) + fused = torch.stack(fused_rows, dim=0).contiguous() + if tuple(fused.shape) == target_shape: + return fused + return None + + return None + with torch.inference_mode(): for name, shell_param in dict(shell_sub.named_parameters(recurse=False)).items(): if not _is_meta_tensor(shell_param): continue full_name, expert_index, split_index, split_dim = self._resolve_checkpoint_tensor_source(module_path, name) + source_param = None if full_name is None: - continue - shard = self._weight_map.get(full_name) - if shard is None: - raise RuntimeError(self._materialization_issue_message( - phase="direct-meta sync", - kind="param", - module_path=module_path, - rel_name=name, - reason="checkpoint tensor mapping resolved to a missing shard", - full_name=full_name, - target_shape=tuple(shell_param.shape), - expert_index=expert_index, - split_index=split_index, - split_dim=split_dim, - )) + source_param = _synthesize_runtime_param(module_path, name, tuple(shell_param.shape)) + if source_param is None: + if module_path.endswith(("qkv_proj", "out_proj", "output_gate", "norm", "experts")): + log.debug( + "LazyTurtle synth miss: module_path=%s rel=%s shape=%s", + module_path, + name, + tuple(shell_param.shape), + ) + continue + if module_path.endswith(("qkv_proj", "out_proj", "output_gate", "norm", "experts")): + log.debug( + "LazyTurtle synth hit: module_path=%s rel=%s out_shape=%s", + module_path, + name, + tuple(source_param.shape), + ) + else: + shard = self._weight_map.get(full_name) + if shard is None: + raise RuntimeError(self._materialization_issue_message( + phase="direct-meta sync", + kind="param", + module_path=module_path, + rel_name=name, + reason="checkpoint tensor mapping resolved to a missing shard", + full_name=full_name, + target_shape=tuple(shell_param.shape), + expert_index=expert_index, + split_index=split_index, + split_dim=split_dim, + )) - source_path = os.path.join(self.model_local_path, shard) - with safe_open(source_path, framework="pt", device="cpu") as handler: - checkpoint_param = handler.get_tensor(full_name) - source_param = self._transform_checkpoint_tensor( - checkpoint_param, - expert_index=expert_index, - split_index=split_index, - split_dim=split_dim, - expected_shape=tuple(shell_param.shape), - prefer_transposed=getattr(shell_sub, "is_transposed", None), - ) - if source_param is None: - raise RuntimeError(self._materialization_issue_message( - phase="direct-meta sync", - kind="param", - module_path=module_path, - rel_name=name, - reason="checkpoint tensor could not be reshaped into the target layout", - full_name=full_name, - source_shape=tuple(checkpoint_param.shape), - target_shape=tuple(shell_param.shape), + source_path = os.path.join(self.model_local_path, shard) + with safe_open(source_path, framework="pt", device="cpu") as handler: + checkpoint_param = handler.get_tensor(full_name) + source_param = self._transform_checkpoint_tensor( + checkpoint_param, expert_index=expert_index, split_index=split_index, split_dim=split_dim, - )) + expected_shape=tuple(shell_param.shape), + prefer_transposed=getattr(shell_sub, "is_transposed", None), + ) + if source_param is None: + raise RuntimeError(self._materialization_issue_message( + phase="direct-meta sync", + kind="param", + module_path=module_path, + rel_name=name, + reason="checkpoint tensor could not be reshaped into the target layout", + full_name=full_name, + source_shape=tuple(checkpoint_param.shape), + target_shape=tuple(shell_param.shape), + expert_index=expert_index, + split_index=split_index, + split_dim=split_dim, + )) if shell_param.shape != source_param.shape: raise RuntimeError(self._materialization_issue_message( @@ -2295,7 +2440,8 @@ def _materialize_direct_meta_tensors( split_dim=split_dim, )) - cache_key = (full_name, expert_index, split_index, split_dim, shell_param.dtype, shell_param.requires_grad) + cache_name = full_name or f"__synth__:{module_path}.{name}" + cache_key = (cache_name, expert_index, split_index, split_dim, shell_param.dtype, shell_param.requires_grad) new_param = param_cache.get(cache_key) if new_param is None: if source_param.dtype != shell_param.dtype: