Skip to content

Commit 764f7ed

Browse files
authored
[core] Flux2 klein kv followups (#13264)
* implement Flux2Transformer2DModelOutput. * add output class to docs. * add Flux2KleinKV to docs. * add pipeline tests for klein kv.
1 parent 8d0f3e1 commit 764f7ed

File tree

5 files changed

+207
-8
lines changed

5 files changed

+207
-8
lines changed

docs/source/en/api/models/flux2_transformer.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,7 @@ A Transformer model for image-like data from [Flux2](https://hf.co/black-forest-
1717
## Flux2Transformer2DModel
1818

1919
[[autodoc]] Flux2Transformer2DModel
20+
21+
## Flux2Transformer2DModelOutput
22+
23+
[[autodoc]] models.transformers.transformer_flux2.Flux2Transformer2DModelOutput

docs/source/en/api/pipelines/flux2.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,5 +41,11 @@ The [official implementation](https://github.com/black-forest-labs/flux2/blob/5a
4141
## Flux2KleinPipeline
4242

4343
[[autodoc]] Flux2KleinPipeline
44+
- all
45+
- __call__
46+
47+
## Flux2KleinKVPipeline
48+
49+
[[autodoc]] Flux2KleinKVPipeline
4450
- all
4551
- __call__

src/diffusers/models/transformers/transformer_flux2.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import inspect
16+
from dataclasses import dataclass
1617
from typing import Any
1718

1819
import torch
@@ -21,7 +22,7 @@
2122

2223
from ...configuration_utils import ConfigMixin, register_to_config
2324
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
24-
from ...utils import apply_lora_scale, logging
25+
from ...utils import BaseOutput, apply_lora_scale, logging
2526
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
2627
from ..attention import AttentionMixin, AttentionModuleMixin
2728
from ..attention_dispatch import dispatch_attention_fn
@@ -32,14 +33,29 @@
3233
apply_rotary_emb,
3334
get_1d_rotary_pos_embed,
3435
)
35-
from ..modeling_outputs import Transformer2DModelOutput
3636
from ..modeling_utils import ModelMixin
3737
from ..normalization import AdaLayerNormContinuous
3838

3939

4040
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
4141

4242

43+
@dataclass
class Flux2Transformer2DModelOutput(BaseOutput):
    """
    The output of [`Flux2Transformer2DModel`].

    Args:
        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
            The hidden states output conditioned on the `encoder_hidden_states` input.
        kv_cache (`Flux2KVCache`, *optional*):
            The populated KV cache for reference image tokens. Only returned when `kv_cache_mode="extract"`.
    """

    # Forward references keep this declaration valid before Flux2KVCache is defined below.
    sample: "torch.Tensor"  # noqa: F821
    kv_cache: "Flux2KVCache | None" = None
57+
58+
4359
class Flux2KVLayerCache:
4460
"""Per-layer KV cache for reference image tokens in the Flux2 Klein KV model.
4561
@@ -1174,7 +1190,7 @@ def forward(
11741190
kv_cache_mode: str | None = None,
11751191
num_ref_tokens: int = 0,
11761192
ref_fixed_timestep: float = 0.0,
1177-
) -> torch.Tensor | Transformer2DModelOutput:
1193+
) -> torch.Tensor | Flux2Transformer2DModelOutput:
11781194
"""
11791195
The [`Flux2Transformer2DModel`] forward method.
11801196
@@ -1356,10 +1372,10 @@ def forward(
13561372

13571373
if kv_cache_mode == "extract":
13581374
if not return_dict:
1359-
return (output,), kv_cache
1360-
return Transformer2DModelOutput(sample=output), kv_cache
1375+
return (output, kv_cache)
1376+
return Flux2Transformer2DModelOutput(sample=output, kv_cache=kv_cache)
13611377

13621378
if not return_dict:
13631379
return (output,)
13641380

1365-
return Transformer2DModelOutput(sample=output)
1381+
return Flux2Transformer2DModelOutput(sample=output)

src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -793,7 +793,7 @@ def __call__(
793793
latent_model_input = torch.cat([image_latents, latents], dim=1).to(self.transformer.dtype)
794794
latent_image_ids = torch.cat([image_latent_ids, latent_ids], dim=1)
795795

796-
output, kv_cache = self.transformer(
796+
noise_pred, kv_cache = self.transformer(
797797
hidden_states=latent_model_input,
798798
timestep=timestep / 1000,
799799
guidance=None,
@@ -805,7 +805,6 @@ def __call__(
805805
kv_cache_mode="extract",
806806
num_ref_tokens=image_latents.shape[1],
807807
)
808-
noise_pred = output[0]
809808

810809
elif kv_cache is not None:
811810
# Steps 1+: use cached ref KV, no ref tokens in input
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
import unittest
2+
3+
import numpy as np
4+
import torch
5+
from PIL import Image
6+
from transformers import Qwen2TokenizerFast, Qwen3Config, Qwen3ForCausalLM
7+
8+
from diffusers import (
9+
AutoencoderKLFlux2,
10+
FlowMatchEulerDiscreteScheduler,
11+
Flux2KleinKVPipeline,
12+
Flux2Transformer2DModel,
13+
)
14+
15+
from ...testing_utils import torch_device
16+
from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
17+
18+
19+
class Flux2KleinKVPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    """Fast, CPU-sized tests for the Flux2 Klein KV-cache pipeline."""

    pipeline_class = Flux2KleinKVPipeline
    params = frozenset(["prompt", "height", "width", "prompt_embeds", "image"])
    batch_params = frozenset(["prompt"])

    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

    supports_dduf = False

    def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
        # Tiny transformer; seeding before construction keeps weights deterministic.
        torch.manual_seed(0)
        transformer = Flux2Transformer2DModel(
            patch_size=1,
            in_channels=4,
            num_layers=num_layers,
            num_single_layers=num_single_layers,
            attention_head_dim=16,
            num_attention_heads=2,
            joint_attention_dim=16,
            timestep_guidance_channels=256,
            axes_dims_rope=[4, 4, 4, 4],
            guidance_embeds=False,
        )

        # Create minimal Qwen3 config
        text_encoder_config = Qwen3Config(
            intermediate_size=16,
            hidden_size=16,
            num_hidden_layers=2,
            num_attention_heads=2,
            num_key_value_heads=2,
            vocab_size=151936,
            max_position_embeddings=512,
        )
        torch.manual_seed(0)
        text_encoder = Qwen3ForCausalLM(text_encoder_config)

        # Use a simple tokenizer for testing
        tokenizer = Qwen2TokenizerFast.from_pretrained(
            "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration"
        )

        torch.manual_seed(0)
        vae = AutoencoderKLFlux2(
            sample_size=32,
            in_channels=3,
            out_channels=3,
            down_block_types=("DownEncoderBlock2D",),
            up_block_types=("UpDecoderBlock2D",),
            block_out_channels=(4,),
            layers_per_block=1,
            latent_channels=1,
            norm_num_groups=1,
            use_quant_conv=False,
            use_post_quant_conv=False,
        )

        scheduler = FlowMatchEulerDiscreteScheduler()

        return {
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "transformer": transformer,
            "vae": vae,
        }

    def get_dummy_inputs(self, device, seed=0):
        # MPS cannot use a CPU-device Generator, so fall back to the global seed there.
        generator = (
            torch.manual_seed(seed)
            if str(device).startswith("mps")
            else torch.Generator(device="cpu").manual_seed(seed)
        )

        return {
            "prompt": "a dog is dancing",
            "image": Image.new("RGB", (64, 64)),
            "generator": generator,
            "num_inference_steps": 2,
            "height": 8,
            "width": 8,
            "max_sequence_length": 64,
            "output_type": "np",
            "text_encoder_out_layers": (1,),
        }

    def test_fused_qkv_projections(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        pipe = self.pipeline_class(**self.get_dummy_components()).to(device)
        pipe.set_progress_bar_config(disable=None)

        # Baseline output before any QKV fusion.
        baseline_slice = pipe(**self.get_dummy_inputs(device)).images[0, -3:, -3:, -1]

        pipe.transformer.fuse_qkv_projections()
        self.assertTrue(
            check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
            ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
        )

        fused_slice = pipe(**self.get_dummy_inputs(device)).images[0, -3:, -3:, -1]

        pipe.transformer.unfuse_qkv_projections()
        unfused_slice = pipe(**self.get_dummy_inputs(device)).images[0, -3:, -3:, -1]

        self.assertTrue(
            np.allclose(baseline_slice, fused_slice, atol=1e-3, rtol=1e-3),
            ("Fusion of QKV projections shouldn't affect the outputs."),
        )
        self.assertTrue(
            np.allclose(fused_slice, unfused_slice, atol=1e-3, rtol=1e-3),
            ("Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."),
        )
        self.assertTrue(
            np.allclose(baseline_slice, unfused_slice, atol=1e-2, rtol=1e-2),
            ("Original outputs should match when fused QKV projections are disabled."),
        )

    def test_image_output_shape(self):
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)

        # Requested sizes get rounded down to a multiple of (vae_scale_factor * 2).
        for height, width in [(32, 32), (72, 57)]:
            multiple = pipe.vae_scale_factor * 2
            expected_height = height - height % multiple
            expected_width = width - width % multiple

            inputs.update({"height": height, "width": width})
            image = pipe(**inputs).images[0]
            output_height, output_width, _ = image.shape
            self.assertEqual(
                (output_height, output_width),
                (expected_height, expected_width),
                f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)}",
            )

    def test_without_image(self):
        device = "cpu"
        pipe = self.pipeline_class(**self.get_dummy_components()).to(device)
        inputs = self.get_dummy_inputs(device)
        inputs.pop("image")  # pipeline should also run as plain text-to-image
        self.assertEqual(pipe(**inputs).images.shape, (1, 8, 8, 3))

    @unittest.skip("Needs to be revisited")
    def test_encode_prompt_works_in_isolation(self):
        pass

0 commit comments

Comments
 (0)