Skip to content

Commit 24f454c

Browse files
committed
2.1
1 parent dd9775c commit 24f454c

5 files changed

Lines changed: 71 additions & 53 deletions

File tree

src/diffusers/loaders/single_file_model.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@
5454
create_controlnet_diffusers_config_from_ldm,
5555
create_unet_diffusers_config_from_ldm,
5656
create_vae_diffusers_config_from_ldm,
57-
create_z_image_controlnet_config,
5857
fetch_diffusers_config,
5958
fetch_original_config,
6059
load_single_file_checkpoint,
@@ -176,7 +175,6 @@
176175
},
177176
"ZImageControlNetModel": {
178177
"checkpoint_mapping_fn": convert_z_image_controlnet_checkpoint_to_diffusers,
179-
"config_create_fn": create_z_image_controlnet_config,
180178
},
181179
}
182180

@@ -379,10 +377,6 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
379377
diffusers_model_config = config_mapping_fn(
380378
original_config=original_config, checkpoint=checkpoint, **config_mapping_kwargs
381379
)
382-
elif "config_create_fn" in mapping_functions:
383-
config_create_fn = mapping_functions["config_create_fn"]
384-
config_create_kwargs = _get_mapping_function_kwargs(config_create_fn, **kwargs)
385-
diffusers_model_config = config_create_fn(checkpoint=checkpoint, **config_create_kwargs)
386380
else:
387381
if config is not None:
388382
if isinstance(config, str):

src/diffusers/loaders/single_file_utils.py

Lines changed: 11 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@
122122
"lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
123123
"z-image-turbo": "cap_embedder.0.weight",
124124
"z-image-turbo-controlnet": "control_all_x_embedder.2-1.weight",
125+
"z-image-turbo-controlnet-2.x": "control_layers.14.adaLN_modulation.0.weight",
125126
"sana": [
126127
"blocks.0.cross_attn.q_linear.weight",
127128
"blocks.0.cross_attn.q_linear.bias",
@@ -221,6 +222,8 @@
221222
"cosmos-2.0-v2w-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Video2World"},
222223
"cosmos-2.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Video2World"},
223224
"z-image-turbo": {"pretrained_model_name_or_path": "Tongyi-MAI/Z-Image-Turbo"},
225+
"z-image-turbo-controlnet": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union"},
226+
"z-image-turbo-controlnet-2.x": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.1"},
224227
}
225228

226229
# Use to configure model sample size when original config is provided
@@ -780,6 +783,9 @@ def infer_diffusers_model_type(checkpoint):
780783
else:
781784
raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 2.0 model.")
782785

786+
elif CHECKPOINT_KEY_NAMES["z-image-turbo-controlnet-2.x"] in checkpoint:
787+
model_type = "z-image-turbo-controlnet-2.x"
788+
783789
elif CHECKPOINT_KEY_NAMES["z-image-turbo-controlnet"] in checkpoint:
784790
model_type = "z-image-turbo-controlnet"
785791

@@ -3891,47 +3897,12 @@ def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str)
38913897
return converted_state_dict
38923898

38933899

3894-
def create_z_image_controlnet_config(checkpoint, **kwargs):
3895-
v1_config = {
3896-
"control_in_dim": 16,
3897-
"control_layers_places": [0, 5, 10, 15, 20, 25],
3898-
"dim": 3840,
3899-
"n_heads": 30,
3900-
"n_kv_heads": 30,
3901-
"n_refiner_layers": 2,
3902-
"norm_eps": 1e-05,
3903-
"qk_norm": True,
3904-
"all_f_patch_size": [1],
3905-
"all_patch_size": [2],
3906-
}
3907-
v2_config = {
3908-
"control_in_dim": 33,
3909-
"control_layers_places": [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28],
3910-
"control_refiner_layers_places": [0, 1],
3911-
"add_control_noise_refiner": True,
3912-
"dim": 3840,
3913-
"n_heads": 30,
3914-
"n_kv_heads": 30,
3915-
"n_refiner_layers": 2,
3916-
"norm_eps": 1e-05,
3917-
"qk_norm": True,
3918-
"all_f_patch_size": [1],
3919-
"all_patch_size": [2],
3920-
}
3921-
control_x_embedder_weight_shape = checkpoint["control_all_x_embedder.2-1.weight"].shape[1]
3922-
if control_x_embedder_weight_shape == 64:
3923-
return v1_config
3924-
elif control_x_embedder_weight_shape == 132:
3925-
return v2_config
3926-
else:
3927-
raise ValueError("Unknown Z-Image Turbo ControlNet type.")
3928-
3929-
3930-
def convert_z_image_controlnet_checkpoint_to_diffusers(checkpoint, **kwargs):
3931-
control_x_embedder_weight_shape = checkpoint["control_all_x_embedder.2-1.weight"].shape[1]
3932-
if control_x_embedder_weight_shape == 64:
3900+
def convert_z_image_controlnet_checkpoint_to_diffusers(checkpoint, config, **kwargs):
3901+
if config["add_control_noise_refiner"] is None:
3902+
return checkpoint
3903+
elif config["add_control_noise_refiner"] == "control_noise_refiner":
39333904
return checkpoint
3934-
elif control_x_embedder_weight_shape == 132:
3905+
elif config["add_control_noise_refiner"] == "control_layers":
39353906
converted_state_dict = {
39363907
key: checkpoint.pop(key) for key in list(checkpoint.keys()) if not key.startswith("control_noise_refiner.")
39373908
}

src/diffusers/models/controlnets/controlnet_z_image.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import math
16-
from typing import List, Optional
16+
from typing import List, Literal, Optional
1717

1818
import torch
1919
import torch.nn as nn
@@ -398,7 +398,7 @@ def __init__(
398398
control_layers_places: List[int] = None,
399399
control_refiner_layers_places: List[int] = None,
400400
control_in_dim=None,
401-
add_control_noise_refiner=False,
401+
add_control_noise_refiner: Optional[Literal["control_layers", "control_noise_refiner"]] = None,
402402
all_patch_size=(2,),
403403
all_f_patch_size=(1,),
404404
dim=3840,
@@ -431,8 +431,24 @@ def __init__(
431431
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
432432

433433
self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
434-
if self.add_control_noise_refiner:
434+
if self.add_control_noise_refiner == "control_layers":
435435
self.control_noise_refiner = None
436+
elif self.add_control_noise_refiner == "control_noise_refiner":
437+
self.control_noise_refiner = nn.ModuleList(
438+
[
439+
ZImageControlTransformerBlock(
440+
1000 + layer_id,
441+
dim,
442+
n_heads,
443+
n_kv_heads,
444+
norm_eps,
445+
qk_norm,
446+
modulation=True,
447+
block_id=layer_id,
448+
)
449+
for layer_id in range(n_refiner_layers)
450+
]
451+
)
436452
else:
437453
self.control_noise_refiner = nn.ModuleList(
438454
[
@@ -449,6 +465,7 @@ def __init__(
449465
]
450466
)
451467

468+
self.t_scale: Optional[float] = None
452469
self.t_embedder: Optional[TimestepEmbedder] = None
453470
self.all_x_embedder: Optional[nn.ModuleDict] = None
454471
self.cap_embedder: Optional[nn.Sequential] = None
@@ -624,7 +641,8 @@ def forward(
624641
f_patch_size=1,
625642
):
626643
if (
627-
self.t_embedder is None
644+
self.t_scale is None
645+
or self.t_embedder is None
628646
or self.all_x_embedder is None
629647
or self.cap_embedder is None
630648
or self.rope_embedder is None
@@ -687,8 +705,14 @@ def forward(
687705
for i, seq_len in enumerate(x_item_seqlens):
688706
x_attn_mask[i, :seq_len] = 1
689707

690-
if self.add_control_noise_refiner:
691-
for layer in self.control_layers:
708+
if self.add_control_noise_refiner is not None:
709+
if self.add_control_noise_refiner == "control_layers":
710+
layers = self.control_layers
711+
elif self.add_control_noise_refiner == "control_noise_refiner":
712+
layers = self.control_noise_refiner
713+
else:
714+
raise ValueError(f"Unsupported `add_control_noise_refiner` type: {self.add_control_noise_refiner}.")
715+
for layer in layers:
692716
if torch.is_grad_enabled() and self.gradient_checkpointing:
693717
control_context = self._gradient_checkpointing_func(
694718
layer, control_context, x, x_attn_mask, x_freqs_cis, adaln_input

src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,25 @@
4949
... torch_dtype=torch.bfloat16,
5050
... )
5151
52+
>>> # 2.1
53+
>>> # controlnet = ZImageControlNetModel.from_single_file(
54+
... # hf_hub_download(
55+
... # "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0",
56+
... # filename="Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors",
57+
... # ),
58+
... # torch_dtype=torch.bfloat16,
59+
... # )
60+
61+
>>> # 2.0 - `config` is required
62+
>>> # controlnet = ZImageControlNetModel.from_single_file(
63+
... # hf_hub_download(
64+
... # "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0",
65+
... # filename="Z-Image-Turbo-Fun-Controlnet-Union-2.0.safetensors",
66+
... # ),
67+
... # torch_dtype=torch.bfloat16,
68+
... # config="hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0",
69+
... # )
70+
5271
>>> pipe = ZImageControlNetPipeline.from_pretrained(
5372
... "Tongyi-MAI/Z-Image-Turbo", controlnet=controlnet, torch_dtype=torch.bfloat16
5473
... )

src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,21 @@
4545
>>> controlnet = ZImageControlNetModel.from_single_file(
4646
... hf_hub_download(
4747
... "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0",
48-
... filename="Z-Image-Turbo-Fun-Controlnet-Union-2.0.safetensors",
48+
... filename="Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors",
4949
... ),
5050
... torch_dtype=torch.bfloat16,
5151
... )
5252
53+
>>> # 2.0 - `config` is required
54+
>>> # controlnet = ZImageControlNetModel.from_single_file(
55+
... # hf_hub_download(
56+
... # "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0",
57+
... # filename="Z-Image-Turbo-Fun-Controlnet-Union-2.0.safetensors",
58+
... # ),
59+
... # torch_dtype=torch.bfloat16,
60+
... # config="hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0",
61+
... # )
62+
5363
>>> pipe = ZImageControlNetInpaintPipeline.from_pretrained(
5464
... "Tongyi-MAI/Z-Image-Turbo", controlnet=controlnet, torch_dtype=torch.bfloat16
5565
... )

0 commit comments

Comments
 (0)