|
16 | 16 | import copy |
17 | 17 | import glob |
18 | 18 | import inspect |
| 19 | +import json |
19 | 20 | import os |
| 21 | +import re |
20 | 22 | import shutil |
21 | 23 | import sys |
22 | 24 | import warnings |
|
27 | 29 | import transformers |
28 | 30 | from accelerate import infer_auto_device_map, init_empty_weights |
29 | 31 | from accelerate.utils import get_max_memory |
| 32 | +from safetensors.torch import load_file |
30 | 33 | from transformers import ( |
31 | 34 |     AutoConfig, |
32 | 35 |     AutoModelForCausalLM, |
@@ -314,6 +317,106 @@ def get_processor( |
314 | 317 |     return None |
315 | 318 |
|
316 | 319 |
|
| 320 | +def load_mtp_weights_if_needed(model: torch.nn.Module, model_path: str) -> list[str]: |
| 321 | +    """Load MTP weights from separate safetensors if needed (e.g., GLM-4.7). |
| 322 | + |
| 323 | +    Some models store additional layers in separate safetensors files with non-standard |
| 324 | +    names (e.g., mtp.safetensors). HuggingFace's from_pretrained() may not load these |
| 325 | +    files even though they're referenced in model.safetensors.index.json. |
| 326 | + |
| 327 | +    This function detects such cases and explicitly loads the missing weights. |
| 328 | + |
| 329 | +    Args: |
| 330 | +        model: The loaded model that may be missing weights |
| 331 | +        model_path: Path to the model directory |
| 332 | + |
| 333 | +    Returns: |
| 334 | +        List of layer prefixes that were loaded from non-standard safetensors files. |
| 335 | +        These layers should typically be excluded from quantization. |
| 336 | +        Empty list if no additional weights were loaded. |
| 337 | +    """ |
| 338 | +    model_path = Path(model_path) |
| 339 | +    index_file = model_path / "model.safetensors.index.json" |
| 340 | +    mtp_layer_prefixes: list[str] = [] |
| 341 | + |
| 342 | +    if not index_file.exists(): |
| 343 | +        return mtp_layer_prefixes |
| 344 | + |
| 345 | +    # Load the index to find all referenced safetensors files |
| 346 | +    with open(index_file) as f: |
| 347 | +        index = json.load(f) |
| 348 | + |
| 349 | +    # Find all unique safetensors files referenced |
| 350 | +    all_files = set(index["weight_map"].values()) |
| 351 | + |
| 352 | +    # Find non-standard shard files (not matching model-XXXXX-of-XXXXX.safetensors pattern) |
| 353 | +    standard_pattern = re.compile(r"model-\d{5}-of-\d{5}\.safetensors") |
| 354 | +    non_standard_files = [f for f in all_files if not standard_pattern.match(f)] |
| 355 | + |
| 356 | +    if not non_standard_files: |
| 357 | +        return mtp_layer_prefixes |
| 358 | + |
| 359 | +    # Check which non-standard files exist and have missing weights |
| 360 | +    model_state = model.state_dict() |
| 361 | +    total_loaded = 0 |
| 362 | + |
| 363 | +    for filename in non_standard_files: |
| 364 | +        filepath = model_path / filename |
| 365 | +        if not filepath.exists(): |
| 366 | +            continue |
| 367 | + |
| 368 | +        # Find keys that should be in this file |
| 369 | +        expected_keys = [k for k, v in index["weight_map"].items() if v == filename] |
| 370 | + |
| 371 | +        # Check which are missing from the model |
| 372 | +        missing_keys = [k for k in expected_keys if k not in model_state] |
| 373 | + |
| 374 | +        if not missing_keys: |
| 375 | +            # Even if weights are loaded, record the layer prefixes for exclusion |
| 376 | +            # Extract unique layer prefixes (e.g., "model.layers.92" from "model.layers.92.mlp.weight") |
| 377 | +            for key in expected_keys: |
| 378 | +                # Extract layer prefix like "model.layers.92" or "layers.92" |
| 379 | +                parts = key.split(".") |
| 380 | +                for i, part in enumerate(parts): |
| 381 | +                    if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit(): |
| 382 | +                        prefix = ".".join(parts[: i + 2])  # e.g., "model.layers.92" |
| 383 | +                        if prefix not in mtp_layer_prefixes: |
| 384 | +                            mtp_layer_prefixes.append(prefix) |
| 385 | +                        break |
| 386 | +            continue |
| 387 | + |
| 388 | +        print(f"Loading {len(missing_keys)} missing weights from {filename}...") |
| 389 | + |
| 390 | +        # Extract unique layer prefixes for exclusion from quantization |
| 391 | +        for key in missing_keys: |
| 392 | +            parts = key.split(".") |
| 393 | +            for i, part in enumerate(parts): |
| 394 | +                if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit(): |
| 395 | +                    prefix = ".".join(parts[: i + 2])  # e.g., "model.layers.92" |
| 396 | +                    if prefix not in mtp_layer_prefixes: |
| 397 | +                        mtp_layer_prefixes.append(prefix) |
| 398 | +                    break |
| 399 | + |
| 400 | +        # Load the weights to CPU first, load_state_dict will handle device placement |
| 401 | +        weights = load_file(str(filepath), device="cpu") |
| 402 | +        weights_to_load = {k: v for k, v in weights.items() if k in missing_keys} |
| 403 | + |
| 404 | +        # Load into model (strict=False: this dict only covers the keys from this file) |
| 405 | +        _, unexpected = model.load_state_dict(weights_to_load, strict=False) |
| 406 | +        total_loaded += len(weights_to_load) |
| 407 | + |
| 408 | +        if unexpected: |
| 409 | +            print(f"  Warning: {len(unexpected)} keys from {filename} had no matching parameter in the model") |
| 410 | + |
| 411 | +    if total_loaded > 0: |
| 412 | +        print(f"✓ Successfully loaded {total_loaded} weights from non-standard safetensors files") |
| 413 | + |
| 414 | +    if mtp_layer_prefixes: |
| 415 | +        print(f"✓ Detected MTP layers to exclude from quantization: {mtp_layer_prefixes}") |
| 416 | + |
| 417 | +    return mtp_layer_prefixes |
| 418 | + |
| 419 | + |
317 | 420 | def get_dtype(dtype): |
318 | 421 |     if dtype == "bf16": |
319 | 422 |         dtype = torch.bfloat16 |
@@ -473,6 +576,12 @@ def get_model( |
473 | 576 |     if device == "cuda" and not is_model_on_gpu(model): |
474 | 577 |         print("Warning: Some parameters are not on a GPU. Calibration can be slow or hit OOM") |
475 | 578 |
|
| 579 | +    # Load any missing weights from non-standard safetensors files (e.g., GLM-4.7's mtp.safetensors) |
| 580 | +    # Store the MTP layer prefixes on the model for later exclusion from quantization |
| 581 | +    mtp_layer_prefixes = load_mtp_weights_if_needed(model, ckpt_path) |
| 582 | +    if mtp_layer_prefixes: |
| 583 | +        model._mtp_layer_prefixes = mtp_layer_prefixes |
| 584 | + |
476 | 585 |     return model |
477 | 586 |
|
478 | 587 |
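
For reference, the situation the new `load_mtp_weights_if_needed` helper targets looks roughly like the sketch below: the checkpoint's `model.safetensors.index.json` maps some tensors to a shard whose name does not follow the `model-XXXXX-of-XXXXX.safetensors` pattern. The file names and tensor keys here are made up for illustration, not taken from a real checkpoint.

```python
import re

# Hypothetical excerpt of a model.safetensors.index.json "weight_map" for a checkpoint
# that keeps its MTP layer in a separately named shard (names are illustrative only).
weight_map = {
    "model.layers.0.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
    "model.layers.1.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
    "model.layers.92.mlp.down_proj.weight": "mtp.safetensors",  # non-standard shard name
}

# Same detection rule as the helper above: any shard not matching the standard naming
# pattern is treated as a non-standard file that may need to be loaded explicitly.
standard_pattern = re.compile(r"model-\d{5}-of-\d{5}\.safetensors")
non_standard_files = {f for f in weight_map.values() if not standard_pattern.match(f)}
print(non_standard_files)  # {'mtp.safetensors'}
```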
|
|
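The hunk in `get_model` only records the prefixes on `model._mtp_layer_prefixes`; how they are consumed is outside this diff. As a hedged sketch of the intended use (assumed, not shown in this commit), the recorded prefixes could drive a simple skip-predicate when deciding which modules to quantize:

```python
# Hedged sketch only: the commit stores the prefixes on model._mtp_layer_prefixes but
# does not show how they are consumed. One plausible use is a per-module-name predicate
# that a quantization pass can call. The function name is hypothetical.
def should_quantize(module_name: str, mtp_layer_prefixes: list[str]) -> bool:
    """Return False for modules that live under a recorded MTP layer prefix."""
    return not any(
        module_name == prefix or module_name.startswith(prefix + ".")
        for prefix in mtp_layer_prefixes
    )


# Example with a prefix like the ones load_mtp_weights_if_needed records:
prefixes = ["model.layers.92"]
print(should_quantize("model.layers.91.mlp.down_proj", prefixes))  # True  -> quantize
print(should_quantize("model.layers.92.mlp.down_proj", prefixes))  # False -> skip MTP layer
```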