-
Notifications
You must be signed in to change notification settings - Fork 7k
Z-Image-Turbo ControlNet #12792
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Z-Image-Turbo ControlNet #12792
Changes from 40 commits
2354fda
1e2009d
0c30839
52f996e
4b446b3
a1ff390
7ab347d
8cab0c9
8688fa6
9051272
0d8c3f1
f789325
5f8ab7b
bc72f9c
13b706a
728ba02
09849a7
f63a5a8
f9540cb
3e472ac
0e7c643
413c7cb
a00f104
a961402
6e1c218
8e7743a
c135170
a737b3c
7bc847a
ffde035
04388f4
62ee1c1
faf5a24
6126f02
c3def6b
f80ed52
721011e
efadd91
f4b7fcc
dd9775c
24f454c
1b49be4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -121,6 +121,7 @@ | |
| "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight", | ||
| "lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"], | ||
| "z-image-turbo": "cap_embedder.0.weight", | ||
| "z-image-turbo-controlnet": "control_all_x_embedder.2-1.weight", | ||
| "sana": [ | ||
| "blocks.0.cross_attn.q_linear.weight", | ||
| "blocks.0.cross_attn.q_linear.bias", | ||
|
|
@@ -779,6 +780,9 @@ def infer_diffusers_model_type(checkpoint): | |
| else: | ||
| raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 2.0 model.") | ||
|
|
||
| elif CHECKPOINT_KEY_NAMES["z-image-turbo-controlnet"] in checkpoint: | ||
| model_type = "z-image-turbo-controlnet" | ||
|
|
||
| else: | ||
| model_type = "v1" | ||
|
|
||
|
|
@@ -3885,3 +3889,52 @@ def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str) | |
| handler_fn_inplace(key, converted_state_dict) | ||
|
|
||
| return converted_state_dict | ||
|
|
||
|
|
||
def create_z_image_controlnet_config(checkpoint, **kwargs):
    """Return the hard-coded model config for a Z-Image Turbo ControlNet checkpoint.

    The variant is chosen from the second dimension of the
    ``control_all_x_embedder.2-1.weight`` entry: ``64`` selects the "v1"
    config and ``132`` selects the "v2" config.

    Args:
        checkpoint: Mapping of parameter names to objects exposing ``.shape``.
            It must contain ``control_all_x_embedder.2-1.weight``.
        **kwargs: Accepted for signature compatibility; unused here.

    Returns:
        A freshly built ``dict`` with the config of the detected variant.

    Raises:
        KeyError: If ``control_all_x_embedder.2-1.weight`` is missing.
        ValueError: If the embedder input dimension is neither 64 nor 132.
    """
    # NOTE(review): in the PR #12792 discussion a reviewer asked whether these
    # configs could be hosted in a model repo and fetched like the other
    # configs instead of being hard-coded in src; the author pointed to
    # earlier comments on that PR for the rationale.
    embedder_key = "control_all_x_embedder.2-1.weight"
    control_x_embedder_in_dim = checkpoint[embedder_key].shape[1]

    # Only the fields that differ between the variants live in the branches.
    if control_x_embedder_in_dim == 64:
        variant_config = {
            "control_in_dim": 16,
            "control_layers_places": [0, 5, 10, 15, 20, 25],
        }
    elif control_x_embedder_in_dim == 132:
        variant_config = {
            "control_in_dim": 33,
            "control_layers_places": [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28],
            "control_refiner_layers_places": [0, 1],
            "add_control_noise_refiner": True,
        }
    else:
        # Report the offending dimension so an unsupported checkpoint can be
        # diagnosed from the message alone.
        raise ValueError(
            "Unknown Z-Image Turbo ControlNet type: unexpected "
            f"{embedder_key} input dimension {control_x_embedder_in_dim} (expected 64 or 132)."
        )

    # Fields common to both variants; rebuilt per call so callers never share
    # mutable lists between returned configs.
    shared_config = {
        "dim": 3840,
        "n_heads": 30,
        "n_kv_heads": 30,
        "n_refiner_layers": 2,
        "norm_eps": 1e-05,
        "qk_norm": True,
        "all_f_patch_size": [1],
        "all_patch_size": [2],
    }
    # Variant fields first, shared fields second: same key order as before.
    return {**variant_config, **shared_config}
|
|
||
|
|
||
def convert_z_image_controlnet_checkpoint_to_diffusers(checkpoint, **kwargs):
    """Convert a Z-Image Turbo ControlNet state dict to the layout diffusers loads.

    The variant is detected from the second dimension of
    ``control_all_x_embedder.2-1.weight``:

    * ``64``: the checkpoint is returned as-is (same object, no copy).
    * ``132``: every entry whose key does not start with
      ``"control_noise_refiner."`` is moved into a new dict, which is returned.
      Entries under that prefix are excluded from the result.

    Args:
        checkpoint: Mutable mapping of parameter names to objects exposing
            ``.shape``. In the ``132`` case it is modified in place: moved
            entries are popped, so only the ``control_noise_refiner.*``
            entries remain in it afterwards.
        **kwargs: Accepted for signature compatibility; unused here.

    Returns:
        The state dict to load: ``checkpoint`` itself for the ``64`` variant,
        or a new filtered dict for the ``132`` variant.

    Raises:
        KeyError: If ``control_all_x_embedder.2-1.weight`` is missing.
        ValueError: If the embedder input dimension is neither 64 nor 132.
    """
    embedder_key = "control_all_x_embedder.2-1.weight"
    control_x_embedder_in_dim = checkpoint[embedder_key].shape[1]

    if control_x_embedder_in_dim == 64:
        return checkpoint

    if control_x_embedder_in_dim == 132:
        # NOTE(review): why the control_noise_refiner.* weights are dropped is
        # not visible in this file; the PR #12792 discussion is the reference.
        # Snapshot the keys first because pop() mutates the dict being walked.
        return {
            name: checkpoint.pop(name)
            for name in list(checkpoint.keys())
            if not name.startswith("control_noise_refiner.")
        }

    # Report the offending dimension so an unsupported checkpoint can be
    # diagnosed from the message alone.
    raise ValueError(
        "Unknown Z-Image Turbo ControlNet type: unexpected "
        f"{embedder_key} input dimension {control_x_embedder_in_dim} (expected 64 or 132)."
    )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a question. Is this to handle cases where you'd have transformer + controlnet in a single checkpoint?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See #12792 (comment)