Skip to content

Commit 25d2fd8

Browse files
authored
Merge branch 'main' into pos_emb_on_npu
2 parents a0f7b63 + 3138e37 commit 25d2fd8

8 files changed

Lines changed: 883 additions & 21 deletions

File tree

examples/community/pipeline_z_image_differential_img2img.py

Lines changed: 844 additions & 0 deletions
Large diffs are not rendered by default.

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ def run(self):
274274

275275
setup(
276276
name="diffusers",
277-
version="0.36.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
277+
version="0.37.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
278278
description="State-of-the-art diffusion in PyTorch and JAX.",
279279
long_description=open("README.md", "r", encoding="utf-8").read(),
280280
long_description_content_type="text/markdown",

src/diffusers/loaders/single_file_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@
162162
"default_subfolder": "transformer",
163163
},
164164
"QwenImageTransformer2DModel": {
165-
"checkpoint_mapping_fn": lambda x: x,
165+
"checkpoint_mapping_fn": lambda checkpoint, **kwargs: checkpoint,
166166
"default_subfolder": "transformer",
167167
},
168168
"Flux2Transformer2DModel": {

src/diffusers/loaders/single_file_utils.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,10 @@
120120
"hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
121121
"instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
122122
"lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
123-
"z-image-turbo": "cap_embedder.0.weight",
123+
"z-image-turbo": [
124+
"model.diffusion_model.layers.0.adaLN_modulation.0.weight",
125+
"layers.0.adaLN_modulation.0.weight",
126+
],
124127
"z-image-turbo-controlnet": "control_all_x_embedder.2-1.weight",
125128
"z-image-turbo-controlnet-2.x": "control_layers.14.adaLN_modulation.0.weight",
126129
"sana": [
@@ -727,10 +730,7 @@ def infer_diffusers_model_type(checkpoint):
727730
):
728731
model_type = "instruct-pix2pix"
729732

730-
elif (
731-
CHECKPOINT_KEY_NAMES["z-image-turbo"] in checkpoint
732-
and checkpoint[CHECKPOINT_KEY_NAMES["z-image-turbo"]].shape[0] == 2560
733-
):
733+
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["z-image-turbo"]):
734734
model_type = "z-image-turbo"
735735

736736
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
@@ -3852,6 +3852,7 @@ def convert_z_image_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
38523852
".attention.k_norm.weight": ".attention.norm_k.weight",
38533853
".attention.q_norm.weight": ".attention.norm_q.weight",
38543854
".attention.out.weight": ".attention.to_out.0.weight",
3855+
"model.diffusion_model.": "",
38553856
}
38563857

38573858
def convert_z_image_fused_attention(key: str, state_dict: dict[str, object]) -> None:
@@ -3886,6 +3887,9 @@ def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str)
38863887

38873888
update_state_dict(converted_state_dict, key, new_key)
38883889

3890+
if "norm_final.weight" in converted_state_dict.keys():
3891+
_ = converted_state_dict.pop("norm_final.weight")
3892+
38893893
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
38903894
# special_keys_remap
38913895
for key in list(converted_state_dict.keys()):

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,8 @@ def apply_rotary_emb(
134134
dropout_p=0.0,
135135
is_causal=False,
136136
backend=self._attention_backend,
137-
parallel_config=self._parallel_config,
137+
# Reference: https://github.com/huggingface/diffusers/pull/12909
138+
parallel_config=None,
138139
)
139140
hidden_states_img = hidden_states_img.flatten(2, 3)
140141
hidden_states_img = hidden_states_img.type_as(query)
@@ -147,7 +148,8 @@ def apply_rotary_emb(
147148
dropout_p=0.0,
148149
is_causal=False,
149150
backend=self._attention_backend,
150-
parallel_config=self._parallel_config,
151+
# Reference: https://github.com/huggingface/diffusers/pull/12909
152+
parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
151153
)
152154
hidden_states = hidden_states.flatten(2, 3)
153155
hidden_states = hidden_states.type_as(query)
@@ -552,9 +554,11 @@ class WanTransformer3DModel(
552554
"blocks.0": {
553555
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
554556
},
555-
"blocks.*": {
556-
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
557-
},
557+
# Reference: https://github.com/huggingface/diffusers/pull/12909
558+
# We need to disable the splitting of encoder_hidden_states because the image_encoder
559+
# (Wan 2.1 I2V) consistently generates 257 tokens for image_embed. This causes the shape
560+
# of encoder_hidden_states—whose token count is always 769 (512 + 257) after concatenation
561+
# —to be indivisible by the number of devices in the CP.
558562
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
559563
"": {
560564
"timestep": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),

src/diffusers/models/transformers/transformer_wan_animate.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,8 @@ def apply_rotary_emb(
609609
dropout_p=0.0,
610610
is_causal=False,
611611
backend=self._attention_backend,
612-
parallel_config=self._parallel_config,
612+
# Reference: https://github.com/huggingface/diffusers/pull/12909
613+
parallel_config=None,
613614
)
614615
hidden_states_img = hidden_states_img.flatten(2, 3)
615616
hidden_states_img = hidden_states_img.type_as(query)
@@ -622,7 +623,8 @@ def apply_rotary_emb(
622623
dropout_p=0.0,
623624
is_causal=False,
624625
backend=self._attention_backend,
625-
parallel_config=self._parallel_config,
626+
# Reference: https://github.com/huggingface/diffusers/pull/12909
627+
parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
626628
)
627629
hidden_states = hidden_states.flatten(2, 3)
628630
hidden_states = hidden_states.type_as(query)

src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def retrieve_latents(
7676
7777
>>> model_id = "nvidia/Cosmos-Predict2.5-2B"
7878
>>> pipe = Cosmos2_5_PredictBasePipeline.from_pretrained(
79-
... model_id, revision="diffusers/base/pre-trianed", torch_dtype=torch.bfloat16
79+
... model_id, revision="diffusers/base/post-trained", torch_dtype=torch.bfloat16
8080
... )
8181
>>> pipe = pipe.to("cuda")
8282

src/diffusers/quantizers/torchao/torchao_quantizer.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@
3636
from ..base import DiffusersQuantizer
3737

3838

39+
logger = logging.get_logger(__name__)
40+
41+
3942
if TYPE_CHECKING:
4043
from ...models.modeling_utils import ModelMixin
4144

@@ -83,11 +86,19 @@ def _update_torch_safe_globals():
8386
]
8487
try:
8588
from torchao.dtypes import NF4Tensor
86-
from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
87-
from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
8889
from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor
8990

90-
safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor])
91+
safe_globals.extend([UintxTensor, UintxAQTTensorImpl, NF4Tensor])
92+
93+
# note: is_torchao_version(">=", "0.16.0") does not work correctly
94+
# with torchao nightly, so using a ">" check which does work correctly
95+
if is_torchao_version(">", "0.15.0"):
96+
pass
97+
else:
98+
from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
99+
from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
100+
101+
safe_globals.extend([UInt4Tensor, Float8AQTTensorImpl])
91102

92103
except (ImportError, ModuleNotFoundError) as e:
93104
logger.warning(
@@ -123,9 +134,6 @@ def fuzzy_match_size(config_name: str) -> Optional[str]:
123134
return None
124135

125136

126-
logger = logging.get_logger(__name__)
127-
128-
129137
def _quantization_type(weight):
130138
from torchao.dtypes import AffineQuantizedTensor
131139
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor

0 commit comments

Comments
 (0)