
Commit f8b1395

update

1 parent c482b0d

6 files changed: 251 additions & 143 deletions


docs/source/en/_toctree.yml

Lines changed: 4 additions & 0 deletions
@@ -412,6 +412,8 @@
       title: WanTransformer3DModel
     - local: api/models/z_image_transformer2d
       title: ZImageTransformer2DModel
+    - local: api/models/ernie_image_transformer2d
+      title: ErnieImageTransformer2DModel
     title: Transformers
   - sections:
     - local: api/models/stable_cascade_unet
@@ -634,6 +636,8 @@
       title: VisualCloze
     - local: api/pipelines/z_image
       title: Z-Image
+    - local: api/pipelines/ernie_image
+      title: ERNIE-Image
     title: Image
   - sections:
     - local: api/pipelines/llada2

docs/source/en/api/pipelines/ernie_image.md

Lines changed: 6 additions & 6 deletions
@@ -46,7 +46,7 @@ from diffusers.utils import load_image

 pipe = ErnieImagePipeline.from_pretrained("baidu/ERNIE-Image", torch_dtype=torch.bfloat16)
 pipe.to("cuda")
-# 如果显存不足,可以开启offload
+# If you are running low on GPU VRAM, you can enable offloading
 pipe.enable_model_cpu_offload()

 prompt = "一只黑白相间的中华田园犬"
@@ -55,8 +55,8 @@ images = pipe(
     height=1024,
     width=1024,
     num_inference_steps=50,
-    guidance_scale=5.0,
-    generator=generator,
+    guidance_scale=4.0,
+    generator=torch.Generator("cuda").manual_seed(42),
     use_pe=True,
 ).images
 images[0].save("ernie-image-output.png")
@@ -69,7 +69,7 @@ from diffusers.utils import load_image

 pipe = ErnieImagePipeline.from_pretrained("baidu/ERNIE-Image-Turbo", torch_dtype=torch.bfloat16)
 pipe.to("cuda")
-# 如果显存不足,可以开启offload
+# If you are running low on GPU VRAM, you can enable offloading
 pipe.enable_model_cpu_offload()

 prompt = "一只黑白相间的中华田园犬"
@@ -78,8 +78,8 @@ images = pipe(
     height=1024,
     width=1024,
     num_inference_steps=8,
-    guidance_scale=5.0,
-    generator=generator,
+    guidance_scale=1.0,
+    generator=torch.Generator("cuda").manual_seed(42),
     use_pe=True,
 ).images
 images[0].save("ernie-image-turbo-output.png")
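Note on the parameter changes above: `guidance_scale=1.0` effectively disables classifier-free guidance, which is the usual setting for step-distilled checkpoints such as the Turbo variant, while the base model keeps a moderate CFG scale. Below is a minimal sketch of reproducible batched sampling, assuming `ErnieImagePipeline` follows the standard diffusers convention of accepting a list of prompts with a matching list of generators (not confirmed by this diff):

import torch
from diffusers import ErnieImagePipeline

pipe = ErnieImagePipeline.from_pretrained("baidu/ERNIE-Image-Turbo", torch_dtype=torch.bfloat16)
pipe.to("cuda")

seeds = [0, 1, 2, 3]
generators = [torch.Generator("cuda").manual_seed(s) for s in seeds]  # one generator per image

images = pipe(
    prompt=["一只黑白相间的中华田园犬"] * len(seeds),  # "a black-and-white Chinese rural dog"
    height=1024,
    width=1024,
    num_inference_steps=8,
    guidance_scale=1.0,  # distilled Turbo checkpoint: CFG effectively disabled
    generator=generators,
    use_pe=True,
).images

for seed, image in zip(seeds, images):
    image.save(f"ernie-image-turbo-seed{seed}.png")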

fix_turbo_weight_keys.py

Lines changed: 163 additions & 0 deletions
@@ -0,0 +1,163 @@
"""
Rename the weight keys of ERNIE-Image-Turbo/transformer so that they match
ERNIE-Image/transformer.

All differences live in each layer's self_attention submodule,
6 key kinds × 36 layers = 216 keys to rename:
    k_layernorm -> norm_k
    q_layernorm -> norm_q
    k_proj      -> to_k
    q_proj      -> to_q
    v_proj      -> to_v
    linear_proj -> to_out.0
"""

import json
import shutil
from pathlib import Path

import torch
from safetensors import safe_open
from safetensors.torch import load_file, save_file

# ── Path configuration ──────────────────────────────────────────────────────
TURBO_DIR = Path("/root/paddlejob/gpfsspace/model_weights/turbo/ERNIE-Image-Turbo/transformer")
# Fixed files overwrite the source directory in place (after backing up);
# change this variable to write to a different directory instead.
OUTPUT_DIR = TURBO_DIR  # or e.g. Path("/your/output/path")
BACKUP_SUFFIX = ".bak"  # suffix for backups of the originals; set to None to skip backups

# ── Key remapping (self_attention subkeys only; the `layers.N.` prefix is added dynamically) ──
KEY_REMAP = {
    "self_attention.k_layernorm.weight": "self_attention.norm_k.weight",
    "self_attention.q_layernorm.weight": "self_attention.norm_q.weight",
    "self_attention.k_proj.weight": "self_attention.to_k.weight",
    "self_attention.q_proj.weight": "self_attention.to_q.weight",
    "self_attention.v_proj.weight": "self_attention.to_v.weight",
    "self_attention.linear_proj.weight": "self_attention.to_out.0.weight",
}

NUM_LAYERS = 36  # layers.0 ~ layers.35


def build_full_remap() -> dict[str, str]:
    """Build the complete old-key -> new-key mapping (including layer prefixes)."""
    remap = {}
    for layer_idx in range(NUM_LAYERS):
        prefix = f"layers.{layer_idx}."
        for old_suffix, new_suffix in KEY_REMAP.items():
            remap[prefix + old_suffix] = prefix + new_suffix
    return remap


def rename_keys_in_tensor_dict(
    tensors: dict[str, torch.Tensor],
    remap: dict[str, str],
) -> tuple[dict[str, torch.Tensor], int]:
    """Rename keys in a tensor dict; return the new dict and the number of keys renamed."""
    renamed = 0
    new_tensors: dict[str, torch.Tensor] = {}
    for key, tensor in tensors.items():
        new_key = remap.get(key, key)
        if new_key != key:
            renamed += 1
        new_tensors[new_key] = tensor
    return new_tensors, renamed


def backup_file(path: Path) -> None:
    if BACKUP_SUFFIX is None:
        return
    backup = path.with_suffix(path.suffix + BACKUP_SUFFIX)
    shutil.copy2(path, backup)
    print(f"  [backup] {path.name} -> {backup.name}")


def process_safetensors_files(remap: dict[str, str]) -> None:
    index_path = TURBO_DIR / "diffusion_pytorch_model.safetensors.index.json"
    with open(index_path, "r", encoding="utf-8") as f:
        index = json.load(f)

    # Collect all shard files that need processing (deduplicated)
    shard_files = sorted(set(index["weight_map"].values()))
    print(f"\nFound {len(shard_files)} shard file(s), processing...\n")

    total_renamed = 0
    for shard_name in shard_files:
        shard_path = TURBO_DIR / shard_name
        print(f"[process] {shard_name}")

        tensors = load_file(shard_path)
        new_tensors, renamed = rename_keys_in_tensor_dict(tensors, remap)
        total_renamed += renamed
        print(f"  keys renamed in this file: {renamed}")

        if renamed > 0:
            # Preserve the original metadata (if any); safe_open reads only the header
            with safe_open(str(shard_path), framework="pt") as f:
                metadata = f.metadata() or {}

            out_path = OUTPUT_DIR / shard_name
            if out_path == shard_path and BACKUP_SUFFIX:
                backup_file(shard_path)

            OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
            save_file(new_tensors, out_path, metadata=metadata)
            print(f"  [saved] {out_path}")
        else:
            if OUTPUT_DIR != TURBO_DIR:
                shutil.copy2(shard_path, OUTPUT_DIR / shard_name)
                print(f"  [copied (unchanged)] {shard_name}")
            else:
                print(f"  [unchanged] {shard_name}")

    print(f"\nAll shards processed; {total_renamed} key(s) renamed in total.")

    # ── Update the weight_map inside index.json ─────────────────────────────
    new_weight_map: dict[str, str] = {}
    for old_key, shard_name in index["weight_map"].items():
        new_key = remap.get(old_key, old_key)
        new_weight_map[new_key] = shard_name

    index["weight_map"] = new_weight_map

    out_index_path = OUTPUT_DIR / "diffusion_pytorch_model.safetensors.index.json"
    if out_index_path == index_path and BACKUP_SUFFIX:
        backup_file(index_path)

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    with open(out_index_path, "w", encoding="utf-8") as f:
        json.dump(index, f, indent=2, ensure_ascii=False)
    print(f"[updated] index.json written to: {out_index_path}\n")


def verify_against_base() -> None:
    """(Optional) Verify that the fixed Turbo key names match Base exactly."""
    BASE_DIR = Path("/root/paddlejob/gpfsspace/model_weights/base/ERNIE-Image/transformer")
    base_index_path = BASE_DIR / "diffusion_pytorch_model.safetensors.index.json"
    turbo_index_path = OUTPUT_DIR / "diffusion_pytorch_model.safetensors.index.json"

    if not base_index_path.exists() or not turbo_index_path.exists():
        print("[verify] index.json not found, skipping verification.")
        return

    with open(base_index_path, "r") as f:
        base_keys = set(json.load(f)["weight_map"].keys())
    with open(turbo_index_path, "r") as f:
        turbo_keys = set(json.load(f)["weight_map"].keys())

    only_in_base = base_keys - turbo_keys
    only_in_turbo = turbo_keys - base_keys

    # Report both directions independently, so a mixed mismatch is fully visible.
    if only_in_turbo:
        print(f"[verify] warning: extra keys in Turbo ({len(only_in_turbo)}):")
        for k in sorted(only_in_turbo):
            print(f"  + {k}")
    if only_in_base:
        print(f"[verify] warning: keys present in Base but missing from Turbo ({len(only_in_base)}):")
        for k in sorted(only_in_base):
            print(f"  - {k}")
    if not only_in_base and not only_in_turbo:
        print("[verify] passed! Fixed Turbo key names match Base exactly.")


if __name__ == "__main__":
    remap = build_full_remap()
    print(f"Remap table has {len(remap)} entries ({NUM_LAYERS} layers × {len(KEY_REMAP)} kinds)")

    process_safetensors_files(remap)
    verify_against_base()
src/diffusers/models/transformers/transformer_ernie_image.py

Lines changed: 12 additions & 9 deletions
@@ -28,12 +28,13 @@
 from ..embeddings import Timesteps
 from ..embeddings import TimestepEmbedding
 from ..modeling_utils import ModelMixin
-from ...utils import BaseOutput
+from ...utils import BaseOutput, logging
 from ..normalization import RMSNorm
 from ..attention_processor import Attention
 from ..attention_dispatch import dispatch_attention_fn
 from ..attention import AttentionMixin, AttentionModuleMixin

+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 @dataclass
 class ErnieImageTransformer2DModelOutput(BaseOutput):
@@ -248,7 +249,13 @@ def __init__(self, hidden_size: int, num_heads: int, ffn_hidden_size: int, eps:
         self.adaLN_mlp_ln = RMSNorm(hidden_size, eps=eps)
         self.mlp = ErnieImageFeedForward(hidden_size, ffn_hidden_size)

-    def forward(self, x, rotary_pos_emb, shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, attention_mask=None):
+    def forward(
+        self,
+        x,
+        rotary_pos_emb,
+        temb: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
+        attention_mask: torch.Tensor | None = None,
+    ):
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = temb
         residual = x
         x = self.adaLN_sa_ln(x)
         x = (x.float() * (1 + scale_msa.float()) + shift_msa.float()).to(x.dtype)
@@ -360,21 +367,17 @@ def forward(
         c = self.time_embedding(sample)
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [t.unsqueeze(0).expand(S, -1, -1).contiguous() for t in self.adaLN_modulation(c).chunk(6, dim=-1)]
+        temb = (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp)
         for layer in self.layers:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 x = self._gradient_checkpointing_func(
                     layer,
                     x,
                     rotary_pos_emb,
-                    shift_msa,
-                    scale_msa,
-                    gate_msa,
-                    shift_mlp,
-                    scale_mlp,
-                    gate_mlp,
+                    temb,
                     attention_mask,
                 )
             else:
-                x = layer(x, rotary_pos_emb, shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, attention_mask)
+                x = layer(x, rotary_pos_emb, temb, attention_mask)
         x = self.final_norm(x, c).type_as(x)
         patches = self.final_linear(x)[:N_img].transpose(0, 1).contiguous()
         output = patches.view(B, Hp, Wp, p, p, self.out_channels).permute(0, 5, 1, 3, 2, 4).contiguous().view(B, self.out_channels, H, W)
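For context on the refactor above: packing the six AdaLN tensors into a single `temb` argument keeps the checkpointed and plain call sites identical, because gradient checkpointing forwards positional arguments unchanged. A standalone toy sketch of the pattern, assuming `_gradient_checkpointing_func` ultimately wraps `torch.utils.checkpoint.checkpoint` as diffusers models typically do (this is not the ERNIE block itself):

import torch
from torch.utils.checkpoint import checkpoint

def block(x, temb):
    # Unpack the packed modulation tensors, as the refactored forward() does.
    shift, scale, gate = temb
    return x + gate * (x * (1 + scale) + shift)

x = torch.randn(4, 8, requires_grad=True)
temb = tuple(torch.randn(4, 8) for _ in range(3))

out_plain = block(x, temb)
out_ckpt = checkpoint(block, x, temb, use_reentrant=False)
assert torch.allclose(out_plain, out_ckpt)

out_ckpt.sum().backward()  # gradients flow through the checkpointed call
print("checkpointed and plain outputs match")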

src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py

Lines changed: 3 additions & 2 deletions
@@ -67,7 +67,8 @@ def __init__(
             pe=pe,
             pe_tokenizer=pe_tokenizer,
         )
-        self.vae_scale_factor = 16  # VAE downsample factor
+        self.vae_scale_factor = 2 ** len(self.vae.config.block_out_channels) if getattr(self, "vae", None) else 16
+        print(f"vae_scale_factor: {self.vae_scale_factor}")

     @property
     def guidance_scale(self):
@@ -278,7 +279,7 @@ def __call__(
         # Latent dimensions
         latent_h = height // self.vae_scale_factor
         latent_w = width // self.vae_scale_factor
-        latent_channels = 128  # After patchify
+        latent_channels = self.transformer.config.in_channels  # After patchify

         # Initialize latents
         if latents is None:
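The first hunk derives the previously hard-coded scale factor from the VAE config. A small arithmetic sketch with hypothetical config values (the real `block_out_channels` of baidu/ERNIE-Image is not shown in this diff); note that `2 ** len(block_out_channels)` is one factor of 2 larger than the usual diffusers VAE downsampling of `2 ** (len(block_out_channels) - 1)`, plausibly because ERNIE-Image folds the transformer's patchify step into the latent resolution, consistent with the "After patchify" comment in the second hunk:

# Hypothetical 4-level VAE config; values are assumed, not read from the checkpoint.
block_out_channels = (128, 256, 512, 512)
vae_scale_factor = 2 ** len(block_out_channels)  # 2**4 == 16, the old hard-coded value

height, width = 1024, 1024
latent_h = height // vae_scale_factor  # 64
latent_w = width // vae_scale_factor   # 64
print(vae_scale_factor, latent_h, latent_w)  # 16 64 64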
