|
| 1 | +""" |
| 2 | +将 ERNIE-Image-Turbo/transformer 的权重键名修正为与 ERNIE-Image/transformer 一致。 |
| 3 | +
|
| 4 | +差异均位于每层 self_attention 子模块,共 6 类 × 36 层 = 216 个键需要重命名: |
| 5 | + k_layernorm -> norm_k |
| 6 | + q_layernorm -> norm_q |
| 7 | + k_proj -> to_k |
| 8 | + q_proj -> to_q |
| 9 | + v_proj -> to_v |
| 10 | + linear_proj -> to_out.0 |
| 11 | +""" |
| 12 | + |
import json
import os
import shutil
from pathlib import Path

import torch
from safetensors import safe_open
from safetensors.torch import load_file, save_file
| 20 | + |
| 21 | +# ── 路径配置 ────────────────────────────────────────────────────────────────── |
| 22 | +TURBO_DIR = Path("/root/paddlejob/gpfsspace/model_weights/turbo/ERNIE-Image-Turbo/transformer") |
| 23 | +# 修正后的文件直接覆盖原目录(先备份),如需输出到新目录请修改此变量 |
| 24 | +OUTPUT_DIR = TURBO_DIR # 或改为 Path("/your/output/path") |
| 25 | +BACKUP_SUFFIX = ".bak" # 原文件备份后缀,设为 None 则不备份 |
| 26 | + |
| 27 | +# ── 键名映射(只处理 self_attention 子键,前缀 layers.N. 由脚本动态拼接)─── |
| 28 | +KEY_REMAP = { |
| 29 | + "self_attention.k_layernorm.weight": "self_attention.norm_k.weight", |
| 30 | + "self_attention.q_layernorm.weight": "self_attention.norm_q.weight", |
| 31 | + "self_attention.k_proj.weight": "self_attention.to_k.weight", |
| 32 | + "self_attention.q_proj.weight": "self_attention.to_q.weight", |
| 33 | + "self_attention.v_proj.weight": "self_attention.to_v.weight", |
| 34 | + "self_attention.linear_proj.weight": "self_attention.to_out.0.weight", |
| 35 | +} |
| 36 | + |
| 37 | +NUM_LAYERS = 36 # layers.0 ~ layers.35 |
| 38 | + |
| 39 | + |
| 40 | +def build_full_remap() -> dict[str, str]: |
| 41 | + """构建完整的旧键名 -> 新键名映射表(含层前缀)。""" |
| 42 | + remap = {} |
| 43 | + for layer_idx in range(NUM_LAYERS): |
| 44 | + prefix = f"layers.{layer_idx}." |
| 45 | + for old_suffix, new_suffix in KEY_REMAP.items(): |
| 46 | + remap[prefix + old_suffix] = prefix + new_suffix |
| 47 | + return remap |
| 48 | + |
| 49 | + |
| 50 | +def rename_keys_in_tensor_dict( |
| 51 | + tensors: dict[str, torch.Tensor], |
| 52 | + remap: dict[str, str], |
| 53 | +) -> tuple[dict[str, torch.Tensor], int]: |
| 54 | + """重命名张量字典中的键,返回新字典和实际重命名的数量。""" |
| 55 | + renamed = 0 |
| 56 | + new_tensors: dict[str, torch.Tensor] = {} |
| 57 | + for key, tensor in tensors.items(): |
| 58 | + new_key = remap.get(key, key) |
| 59 | + if new_key != key: |
| 60 | + renamed += 1 |
| 61 | + new_tensors[new_key] = tensor |
| 62 | + return new_tensors, renamed |
| 63 | + |
| 64 | + |
| 65 | +def backup_file(path: Path) -> None: |
| 66 | + if BACKUP_SUFFIX is None: |
| 67 | + return |
| 68 | + backup = path.with_suffix(path.suffix + BACKUP_SUFFIX) |
| 69 | + shutil.copy2(path, backup) |
| 70 | + print(f" [备份] {path.name} -> {backup.name}") |
| 71 | + |
| 72 | + |
| 73 | +def process_safetensors_files(remap: dict[str, str]) -> None: |
| 74 | + index_path = TURBO_DIR / "diffusion_pytorch_model.safetensors.index.json" |
| 75 | + with open(index_path, "r", encoding="utf-8") as f: |
| 76 | + index = json.load(f) |
| 77 | + |
| 78 | + # 找出所有需要处理的 shard 文件(去重) |
| 79 | + shard_files = sorted(set(index["weight_map"].values())) |
| 80 | + print(f"\n共发现 {len(shard_files)} 个 shard 文件,开始处理...\n") |
| 81 | + |
| 82 | + total_renamed = 0 |
| 83 | + for shard_name in shard_files: |
| 84 | + shard_path = TURBO_DIR / shard_name |
| 85 | + print(f"[处理] {shard_name}") |
| 86 | + |
| 87 | + tensors = load_file(shard_path) |
| 88 | + new_tensors, renamed = rename_keys_in_tensor_dict(tensors, remap) |
| 89 | + total_renamed += renamed |
| 90 | + print(f" 本文件重命名: {renamed} 个键") |
| 91 | + |
| 92 | + if renamed > 0: |
| 93 | + # 保留原始 metadata(如果有) |
| 94 | + metadata = {} |
| 95 | + |
| 96 | + out_path = OUTPUT_DIR / shard_name |
| 97 | + if out_path == shard_path and BACKUP_SUFFIX: |
| 98 | + backup_file(shard_path) |
| 99 | + |
| 100 | + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) |
| 101 | + save_file(new_tensors, out_path, metadata=metadata) |
| 102 | + print(f" [保存] {out_path}") |
| 103 | + else: |
| 104 | + if OUTPUT_DIR != TURBO_DIR: |
| 105 | + shutil.copy2(shard_path, OUTPUT_DIR / shard_name) |
| 106 | + print(f" [复制(无变更)] {shard_name}") |
| 107 | + |
| 108 | + print(f"\n所有 shard 处理完毕,共重命名 {total_renamed} 个键。") |
| 109 | + |
| 110 | + # ── 更新 index.json 中的 weight_map ───────────────────────────────────── |
| 111 | + new_weight_map: dict[str, str] = {} |
| 112 | + for old_key, shard_name in index["weight_map"].items(): |
| 113 | + new_key = remap.get(old_key, old_key) |
| 114 | + new_weight_map[new_key] = shard_name |
| 115 | + |
| 116 | + index["weight_map"] = new_weight_map |
| 117 | + |
| 118 | + out_index_path = OUTPUT_DIR / "diffusion_pytorch_model.safetensors.index.json" |
| 119 | + if out_index_path == index_path and BACKUP_SUFFIX: |
| 120 | + backup_file(index_path) |
| 121 | + |
| 122 | + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) |
| 123 | + with open(out_index_path, "w", encoding="utf-8") as f: |
| 124 | + json.dump(index, f, indent=2, ensure_ascii=False) |
| 125 | + print(f"[更新] index.json 已写入: {out_index_path}\n") |
| 126 | + |
| 127 | + |
| 128 | +def verify_against_base() -> None: |
| 129 | + """(可选)验证修正后的 Turbo 键名与 Base 完全一致。""" |
| 130 | + BASE_DIR = Path("/root/paddlejob/gpfsspace/model_weights/base/ERNIE-Image/transformer") |
| 131 | + base_index_path = BASE_DIR / "diffusion_pytorch_model.safetensors.index.json" |
| 132 | + turbo_index_path = OUTPUT_DIR / "diffusion_pytorch_model.safetensors.index.json" |
| 133 | + |
| 134 | + if not base_index_path.exists() or not turbo_index_path.exists(): |
| 135 | + print("[验证] 找不到 index.json,跳过验证。") |
| 136 | + return |
| 137 | + |
| 138 | + with open(base_index_path, "r") as f: |
| 139 | + base_keys = set(json.load(f)["weight_map"].keys()) |
| 140 | + with open(turbo_index_path, "r") as f: |
| 141 | + turbo_keys = set(json.load(f)["weight_map"].keys()) |
| 142 | + |
| 143 | + only_in_base = base_keys - turbo_keys |
| 144 | + only_in_turbo = turbo_keys - base_keys |
| 145 | + |
| 146 | + if not only_in_base and only_in_turbo: |
| 147 | + print(f"[验证] 警告:Turbo 中多余的键 ({len(only_in_turbo)}):") |
| 148 | + for k in sorted(only_in_turbo): |
| 149 | + print(f" + {k}") |
| 150 | + elif only_in_base: |
| 151 | + print(f"[验证] 警告:Base 中存在但 Turbo 中缺少的键 ({len(only_in_base)}):") |
| 152 | + for k in sorted(only_in_base): |
| 153 | + print(f" - {k}") |
| 154 | + else: |
| 155 | + print("[验证] 通过!修正后 Turbo 的键名与 Base 完全一致。") |
| 156 | + |
| 157 | + |
| 158 | +if __name__ == "__main__": |
| 159 | + remap = build_full_remap() |
| 160 | + print(f"键名映射表共 {len(remap)} 条({NUM_LAYERS} 层 × {len(KEY_REMAP)} 类)") |
| 161 | + |
| 162 | + process_safetensors_files(remap) |
| 163 | + verify_against_base() |