
Commit 7b976cd

Merge branch 'main' into fix/modular-pipeline-lora-missing-transformer
2 parents b05b25e + c8c8401 commit 7b976cd

28 files changed

Lines changed: 2110 additions & 105 deletions

docs/source/en/optimization/speed-memory-optims.md

Lines changed: 2 additions & 0 deletions
@@ -33,6 +33,8 @@ The table below provides a comparison of optimization strategy combinations and

This guide will show you how to compile and offload a quantized model with [bitsandbytes](../quantization/bitsandbytes#torchcompile). Make sure you are using [PyTorch nightly](https://pytorch.org/get-started/locally/) and the latest version of bitsandbytes.

+While we use bitsandbytes in this example, other quantization backends such as [TorchAO](../quantization/torchao.md) also support these features.
+
```bash
pip install -U bitsandbytes
```
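For orientation, a minimal sketch of the combination this doc change refers to: bitsandbytes quantization plus offloading plus torch.compile. The checkpoint id and the exact ordering are illustrative assumptions; the linked guide has the recommended recipe, and TorchAO can be swapped in through its own config.

```python
# Minimal sketch: quantize the FLUX transformer with bitsandbytes, offload the pipeline,
# and compile the denoiser. Checkpoint id and ordering are illustrative, not the guide's exact recipe.
import torch
from diffusers import BitsAndBytesConfig, FluxPipeline, FluxTransformer2DModel

quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()                      # offload idle components to CPU
pipe.transformer = torch.compile(pipe.transformer)   # compile the denoiser (PyTorch nightly recommended)

image = pipe("a tiny astronaut hatching from an egg on the moon", num_inference_steps=28).images[0]
```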

examples/research_projects/pytorch_xla/inference/flux/README.md

Lines changed: 36 additions & 1 deletion
@@ -51,7 +51,42 @@ python flux_inference.py

The script loads the text encoders onto the CPU and the Flux transformer and VAE models onto the TPU. The first time the script runs, compilation takes longer while the cache stores the compiled programs. On subsequent runs, compilation is much faster and subsequent passes are the fastest.

-On a Trillium v6e-4, you should expect ~6 sec / 4 images or 1.5 sec / image (as devices run generation in parallel):
+On a Trillium v6e-4, you should expect ~6 sec / 4 images or 1.5 sec / image (as devices run generation in parallel).
+
+> **Note:** `flux_inference.py` uses `xmp.spawn` (one process per chip) and requires the full model to fit on a single chip. If you run into OOM errors (e.g., on v5e with 16GB HBM per chip), use the SPMD version instead — see below.
+
+### SPMD version (for v5e-8 and similar)
+
+On TPU configurations where a single chip cannot hold the full FLUX transformer (~16GB in bf16), use `flux_inference_spmd.py`. This script uses PyTorch/XLA SPMD to shard the transformer across multiple chips using a `(data, model)` mesh — 4-way model parallel so each chip holds ~4GB of weights, with the remaining chips for data parallelism.
+
+```bash
+python flux_inference_spmd.py --schnell
+```
+
+Key differences from `flux_inference.py`:
+- **Single-process SPMD** instead of multi-process `xmp.spawn` — the XLA compiler handles all collective communication transparently.
+- **Transformer weights are sharded** across the `"model"` mesh axis using `xs.mark_sharding`.
+- **VAE lives on CPU**, moved to XLA only for decode (then moved back), since the transformer stays on device throughout.
+- **Text encoding** runs on CPU before loading the transformer.
+
+On a v5litepod-8 (v5e, 8 chips, 16GB HBM each) with FLUX.1-schnell, expect ~1.76 sec/image at steady state (after compilation):
+
+```
+2026-04-15 02:24:30 [info ] SPMD mesh: (2, 4), axes: ('data', 'model'), devices: 8
+2026-04-15 02:24:30 [info ] encoding prompt on CPU...
+2026-04-15 02:26:20 [info ] loading VAE on CPU...
+2026-04-15 02:26:20 [info ] loading flux transformer from black-forest-labs/FLUX.1-schnell
+2026-04-15 02:27:22 [info ] starting compilation run...
+2026-04-15 02:52:55 [info ] compilation took 1533.4575625509997 sec.
+2026-04-15 02:52:56 [info ] starting inference run...
+2026-04-15 02:56:11 [info ] inference time: 195.74092420299985
+2026-04-15 02:56:13 [info ] inference time: 1.7625778899996476
+2026-04-15 02:56:13 [info ] avg. inference over 2 iterations took 98.75175104649975 sec.
+```
+
+The first inference iteration includes VAE compilation (~195s). The second iteration shows the true steady-state speed (~1.76s).
+
+### v6e-4 results (original `flux_inference.py`)

```bash
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
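In brief, the SPMD approach described above comes down to building a `(data, model)` mesh and calling `xs.mark_sharding` on each large transformer weight. A condensed sketch follows; the full `flux_inference_spmd.py` added in this commit (next section) is the authoritative version, and `transformer` here stands in for the already-loaded Flux transformer on the XLA device.

```python
# Condensed sketch of the SPMD sharding pattern used by flux_inference_spmd.py (full script below).
# `transformer` is assumed to be the Flux transformer already loaded onto the XLA device.
import numpy as np
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr


def shard_transformer(transformer):
    xr.use_spmd()
    num_devices = xr.global_runtime_device_count()

    # (data, model) mesh: 4-way model parallel, remaining devices used for data parallelism.
    mesh = xs.Mesh(np.arange(num_devices), (num_devices // 4, 4), ("data", "model"))
    xs.set_global_mesh(mesh)

    # Shard every >=2-D weight along its largest dimension across the "model" axis.
    for _, param in transformer.named_parameters():
        if param.dim() >= 2:
            spec = [None] * param.dim()
            spec[max(range(param.dim()), key=lambda d: param.shape[d])] = "model"
            xs.mark_sharding(param, mesh, tuple(spec))
    return mesh
```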
examples/research_projects/pytorch_xla/inference/flux/flux_inference_spmd.py

Lines changed: 193 additions & 0 deletions
@@ -0,0 +1,193 @@
"""FLUX inference on TPU using PyTorch/XLA SPMD.

Uses SPMD to shard the transformer across multiple TPU chips, enabling
inference on devices where the model doesn't fit on a single chip (e.g., v5e).
The VAE is loaded on CPU at startup, moved to XLA for decode, then moved back.
"""

from argparse import ArgumentParser
from pathlib import Path
from time import perf_counter

import numpy as np
import structlog
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.debug.profiler as xp
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from torch_xla.experimental.custom_kernel import FlashAttention

from diffusers import AutoencoderKL, FluxPipeline


cache_path = Path("/tmp/data/compiler_cache_eXp")
cache_path.mkdir(parents=True, exist_ok=True)
xr.initialize_cache(str(cache_path), readonly=False)
xr.use_spmd()

logger = structlog.get_logger()
metrics_filepath = "/tmp/metrics_report.txt"
VAE_SCALE_FACTOR = 8


def _vae_decode(latents, vae, height, width, device):
    """Move VAE to XLA, decode latents, move VAE back to CPU."""
    vae.to(device)
    latents = FluxPipeline._unpack_latents(latents, height, width, VAE_SCALE_FACTOR)
    latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
    with torch.no_grad():
        image = vae.decode(latents, return_dict=False)[0]
    vae.to("cpu")
    return image


def main(args):
    # --- SPMD mesh: 4-way model parallel to fit transformer + VAE on v5e chips ---
    num_devices = xr.global_runtime_device_count()
    if num_devices >= 4:
        mesh = xs.Mesh(np.arange(num_devices), (num_devices // 4, 4), ("data", "model"))
    else:
        raise NotImplementedError
    xs.set_global_mesh(mesh)
    logger.info(f"SPMD mesh: {mesh.mesh_shape}, axes: {mesh.axis_names}, devices: {num_devices}")

    # --- Profiler ---
    profile_path = Path("/tmp/data/profiler_out_eXp")
    profile_path.mkdir(parents=True, exist_ok=True)
    profiler_port = 9012
    profile_duration = args.profile_duration
    if args.profile:
        logger.info(f"starting profiler on port {profiler_port}")
        _ = xp.start_server(profiler_port)

    device = xm.xla_device()

    # --- Checkpoint ---
    if args.schnell:
        ckpt_id = "black-forest-labs/FLUX.1-schnell"
    else:
        ckpt_id = "black-forest-labs/FLUX.1-dev"

    # --- Text encoding (CPU) ---
    prompt = "photograph of an electronics chip in the shape of a race car with trillium written on its side"
    logger.info("encoding prompt on CPU...")
    text_pipe = FluxPipeline.from_pretrained(ckpt_id, transformer=None, vae=None, torch_dtype=torch.bfloat16).to("cpu")
    with torch.no_grad():
        prompt_embeds, pooled_prompt_embeds, _ = text_pipe.encode_prompt(
            prompt=prompt, prompt_2=None, max_sequence_length=512
        )
    image_processor = text_pipe.image_processor
    del text_pipe

    # --- Load VAE on CPU (moved to XLA only for decode) ---
    logger.info("loading VAE on CPU...")
    vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=torch.bfloat16)

    # --- Load transformer and shard ---
    logger.info(f"loading flux transformer from {ckpt_id}")
    flux_pipe = FluxPipeline.from_pretrained(
        ckpt_id,
        text_encoder=None,
        tokenizer=None,
        text_encoder_2=None,
        tokenizer_2=None,
        vae=None,
        torch_dtype=torch.bfloat16,
    ).to(device)

    for name, param in flux_pipe.transformer.named_parameters():
        if param.dim() >= 2:
            spec = [None] * param.dim()
            largest_dim = max(range(param.dim()), key=lambda d: param.shape[d])
            spec[largest_dim] = "model"
            xs.mark_sharding(param, mesh, tuple(spec))

    flux_pipe.transformer.enable_xla_flash_attention(partition_spec=("data", None, None, None), is_flux=True)
    FlashAttention.DEFAULT_BLOCK_SIZES = {
        "block_q": 1536,
        "block_k_major": 1536,
        "block_k": 1536,
        "block_b": 1536,
        "block_q_major_dkv": 1536,
        "block_k_major_dkv": 1536,
        "block_q_dkv": 1536,
        "block_k_dkv": 1536,
        "block_q_dq": 1536,
        "block_k_dq": 1536,
        "block_k_major_dq": 1536,
    }

    width = args.width
    height = args.height
    guidance = args.guidance
    n_steps = 4 if args.schnell else 28

    prompt_embeds = prompt_embeds.to(device)
    pooled_prompt_embeds = pooled_prompt_embeds.to(device)
    xs.mark_sharding(prompt_embeds, mesh, ("data", None, None))
    xs.mark_sharding(pooled_prompt_embeds, mesh, ("data", None))

    # --- Compilation run ---
    logger.info("starting compilation run...")
    ts = perf_counter()
    latents = flux_pipe(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        num_inference_steps=28,
        guidance_scale=guidance,
        height=height,
        width=width,
        output_type="latent",
    ).images
    image = _vae_decode(latents, vae, height, width, device)
    image = image_processor.postprocess(image)[0]
    logger.info(f"compilation took {perf_counter() - ts} sec.")
    image.save("/tmp/compile_out.png")

    # --- Inference loop ---
    seed = 4096 if args.seed is None else args.seed
    xm.set_rng_state(seed=seed, device=device)
    times = []
    logger.info("starting inference run...")
    for _ in range(args.itters):
        ts = perf_counter()

        if args.profile:
            xp.trace_detached(f"localhost:{profiler_port}", str(profile_path), duration_ms=profile_duration)
        latents = flux_pipe(
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            num_inference_steps=n_steps,
            guidance_scale=guidance,
            height=height,
            width=width,
            output_type="latent",
        ).images
        image = _vae_decode(latents, vae, height, width, device)
        image = image_processor.postprocess(image)[0]
        inference_time = perf_counter() - ts
        logger.info(f"inference time: {inference_time}")
        times.append(inference_time)

    logger.info(f"avg. inference over {args.itters} iterations took {sum(times) / len(times)} sec.")
    image.save("/tmp/inference_out.png")
    metrics_report = met.metrics_report()
    with open(metrics_filepath, "w+") as fout:
        fout.write(metrics_report)
    logger.info(f"saved metric information as {metrics_filepath}")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--schnell", action="store_true", help="run flux schnell instead of dev")
    parser.add_argument("--width", type=int, default=1024, help="width of the image to generate")
    parser.add_argument("--height", type=int, default=1024, help="height of the image to generate")
    parser.add_argument("--guidance", type=float, default=3.5, help="guidance strength for dev")
    parser.add_argument("--seed", type=int, default=None, help="seed for inference")
    parser.add_argument("--profile", action="store_true", help="enable profiling")
    parser.add_argument("--profile-duration", type=int, default=10000, help="duration for profiling in msec.")
    parser.add_argument("--itters", type=int, default=15, help="iterations to run inference and report avg time in sec.")
    args = parser.parse_args()
    main(args)

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -533,6 +533,7 @@
    "EasyAnimateInpaintPipeline",
    "EasyAnimatePipeline",
    "ErnieImagePipeline",
+    "Flux2KleinInpaintPipeline",
    "Flux2KleinKVPipeline",
    "Flux2KleinPipeline",
    "Flux2Pipeline",
@@ -1317,6 +1318,7 @@
    EasyAnimateInpaintPipeline,
    EasyAnimatePipeline,
    ErnieImagePipeline,
+    Flux2KleinInpaintPipeline,
    Flux2KleinKVPipeline,
    Flux2KleinPipeline,
    Flux2Pipeline,

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 16 additions & 0 deletions
@@ -2331,6 +2331,20 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
            temp_state_dict[new_key] = v
        original_state_dict = temp_state_dict

+    # Some Flux2 checkpoints skip the ai-toolkit `single_blocks` / `double_blocks`
+    # layout and already store expanded diffusers block names. Accept those
+    # directly, and normalize the legacy `sformer_blocks` alias used by some exports.
+    possible_expanded_block_prefixes = {
+        "single_transformer_blocks.": "single_transformer_blocks.",
+        "transformer_blocks.": "transformer_blocks.",
+        "sformer_blocks.": "transformer_blocks.",
+    }
+    for key in list(original_state_dict.keys()):
+        for source_prefix, target_prefix in possible_expanded_block_prefixes.items():
+            if key.startswith(source_prefix):
+                converted_state_dict[target_prefix + key[len(source_prefix) :]] = original_state_dict.pop(key)
+                break
+
    num_double_layers = 0
    num_single_layers = 0
    for key in original_state_dict.keys():
@@ -2421,6 +2435,8 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
        "txt_in": "context_embedder",
        "time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1",
        "time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2",
+        "guidance_in.in_layer": "time_guidance_embed.guidance_embedder.linear_1",
+        "guidance_in.out_layer": "time_guidance_embed.guidance_embedder.linear_2",
        "final_layer.linear": "proj_out",
        "final_layer.adaLN_modulation.1": "norm_out.linear",
        "single_stream_modulation.lin": "single_stream_modulation.linear",

src/diffusers/models/attention_dispatch.py

Lines changed: 8 additions & 9 deletions
@@ -1521,17 +1521,16 @@ def _maybe_modify_attn_mask_npu(query: torch.Tensor, key: torch.Tensor, attn_mas
    if attn_mask is not None and torch.all(attn_mask != 0):
        attn_mask = None

-    # Reshape Attention Mask: [batch_size, seq_len_k] -> [batch_size, 1, sqe_len_q, seq_len_k]
+    # Reshape Attention Mask: [batch_size, seq_len_k] or [batch_size, 1, 1, seq_len_k] -> [batch_size, 1, seq_len_q, seq_len_k]
    # https://www.hiascend.com/document/detail/zh/Pytorch/730/apiref/torchnpuCustomsapi/docs/context/torch_npu-npu_fusion_attention.md
-    if (
-        attn_mask is not None
-        and attn_mask.ndim == 2
-        and attn_mask.shape[0] == query.shape[0]
-        and attn_mask.shape[1] == key.shape[1]
-    ):
-        B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
+    if attn_mask is not None:
+        if attn_mask.ndim == 2 and attn_mask.shape[0] == query.shape[0] and attn_mask.shape[1] == key.shape[1]:
+            batch_size, seq_len_q, seq_len_kv = attn_mask.shape[0], query.shape[1], key.shape[1]
+            attn_mask = attn_mask.unsqueeze(1).expand(batch_size, seq_len_q, seq_len_kv).unsqueeze(1).contiguous()
+        elif attn_mask.ndim == 4 and attn_mask.shape[1:3] == (1, 1):
+            attn_mask = attn_mask.expand(-1, -1, query.shape[1], -1).contiguous()
+
        attn_mask = ~attn_mask.to(torch.bool)
-        attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()

    return attn_mask
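A small shape-only sketch of the two reshape paths introduced above, using plain torch tensors (no NPU required); `query` here is a stand-in with the sequence length at dim 1, as the function assumes.

```python
# Toy sketch of the two mask-reshape paths added above (shapes only; no NPU required).
import torch

batch, seq_q, seq_kv = 2, 5, 7
query = torch.zeros(batch, seq_q, 8, 64)  # stand-in: sequence length at dim 1
key = torch.zeros(batch, seq_kv, 8, 64)

# Path 1: 2-D padding mask [B, Skv] -> [B, 1, Sq, Skv]
mask_2d = torch.ones(batch, seq_kv)
out = mask_2d.unsqueeze(1).expand(batch, seq_q, seq_kv).unsqueeze(1).contiguous()
print(out.shape)  # torch.Size([2, 1, 5, 7])

# Path 2: broadcastable 4-D mask [B, 1, 1, Skv] -> [B, 1, Sq, Skv]
mask_4d = torch.ones(batch, 1, 1, seq_kv)
out = mask_4d.expand(-1, -1, query.shape[1], -1).contiguous()
print(out.shape)  # torch.Size([2, 1, 5, 7])
```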

src/diffusers/modular_pipelines/qwenimage/encoders.py

Lines changed: 10 additions & 10 deletions
@@ -117,15 +117,15 @@ def get_qwen_prompt_embeds_edit(
    ).to(device)

    outputs = text_encoder(
-        input_ids=model_inputs.input_ids,
-        attention_mask=model_inputs.attention_mask,
-        pixel_values=model_inputs.pixel_values,
-        image_grid_thw=model_inputs.image_grid_thw,
+        input_ids=model_inputs["input_ids"],
+        attention_mask=model_inputs["attention_mask"],
+        pixel_values=model_inputs.get("pixel_values"),
+        image_grid_thw=model_inputs.get("image_grid_thw"),
        output_hidden_states=True,
    )

    hidden_states = outputs.hidden_states[-1]
-    split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask)
+    split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs["attention_mask"])
    split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
    attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
    max_seq_len = max([e.size(0) for e in split_hidden_states])
@@ -173,15 +173,15 @@ def get_qwen_prompt_embeds_edit_plus(
        return_tensors="pt",
    ).to(device)
    outputs = text_encoder(
-        input_ids=model_inputs.input_ids,
-        attention_mask=model_inputs.attention_mask,
-        pixel_values=model_inputs.pixel_values,
-        image_grid_thw=model_inputs.image_grid_thw,
+        input_ids=model_inputs["input_ids"],
+        attention_mask=model_inputs["attention_mask"],
+        pixel_values=model_inputs.get("pixel_values"),
+        image_grid_thw=model_inputs.get("image_grid_thw"),
        output_hidden_states=True,
    )

    hidden_states = outputs.hidden_states[-1]
-    split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask)
+    split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs["attention_mask"])
    split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
    attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
    max_seq_len = max([e.size(0) for e in split_hidden_states])
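The switch from attribute access to dict-style access matters because the processor output can omit optional keys, for example `pixel_values` when no image is supplied, and `.get()` then returns `None` instead of raising. A minimal illustration with a hypothetical `BatchFeature`:

```python
# Why dict-style access: a BatchFeature behaves like a dict and may omit optional keys
# such as "pixel_values" when no image is passed. .get() returns None in that case,
# while attribute access on a missing key raises.
from transformers import BatchFeature

model_inputs = BatchFeature(data={"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]]})

print(model_inputs["input_ids"])         # [[1, 2, 3]]
print(model_inputs.get("pixel_values"))  # None -> downstream code treats it as "no image"
# model_inputs.pixel_values              # would raise AttributeError for the missing key
```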

src/diffusers/pipelines/__init__.py

Lines changed: 7 additions & 2 deletions
@@ -160,7 +160,12 @@
    ]
    _import_structure["bria"] = ["BriaPipeline"]
    _import_structure["bria_fibo"] = ["BriaFiboPipeline", "BriaFiboEditPipeline"]
-    _import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline", "Flux2KleinKVPipeline"]
+    _import_structure["flux2"] = [
+        "Flux2Pipeline",
+        "Flux2KleinPipeline",
+        "Flux2KleinInpaintPipeline",
+        "Flux2KleinKVPipeline",
+    ]
    _import_structure["flux"] = [
        "FluxControlPipeline",
        "FluxControlInpaintPipeline",
@@ -697,7 +702,7 @@
        FluxPriorReduxPipeline,
        ReduxImageEncoder,
    )
-    from .flux2 import Flux2KleinKVPipeline, Flux2KleinPipeline, Flux2Pipeline
+    from .flux2 import Flux2KleinInpaintPipeline, Flux2KleinKVPipeline, Flux2KleinPipeline, Flux2Pipeline
    from .glm_image import GlmImagePipeline
    from .helios import HeliosPipeline, HeliosPyramidPipeline
    from .hidream_image import HiDreamImagePipeline

src/diffusers/pipelines/flux2/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,7 @@
else:
    _import_structure["pipeline_flux2"] = ["Flux2Pipeline"]
    _import_structure["pipeline_flux2_klein"] = ["Flux2KleinPipeline"]
+    _import_structure["pipeline_flux2_klein_inpaint"] = ["Flux2KleinInpaintPipeline"]
    _import_structure["pipeline_flux2_klein_kv"] = ["Flux2KleinKVPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
@@ -34,6 +35,7 @@
else:
    from .pipeline_flux2 import Flux2Pipeline
    from .pipeline_flux2_klein import Flux2KleinPipeline
+    from .pipeline_flux2_klein_inpaint import Flux2KleinInpaintPipeline
    from .pipeline_flux2_klein_kv import Flux2KleinKVPipeline
else:
    import sys
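For completeness, a hedged sketch of how the newly exported `Flux2KleinInpaintPipeline` would typically be called, assuming it follows the usual diffusers inpainting interface (`prompt`, `image`, `mask_image`). The checkpoint id is a placeholder; the actual signature lives in `pipeline_flux2_klein_inpaint.py`, which is not shown in this excerpt.

```python
# Hedged usage sketch for the newly registered Flux2KleinInpaintPipeline.
# Assumes the conventional diffusers inpainting call signature; the checkpoint id is a placeholder.
import torch
from diffusers import Flux2KleinInpaintPipeline
from diffusers.utils import load_image

pipe = Flux2KleinInpaintPipeline.from_pretrained(
    "path/or/repo-id-of-a-flux2-klein-checkpoint",  # placeholder, not confirmed by this commit
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()

init_image = load_image("input.png")
mask_image = load_image("mask.png")  # white regions are repainted

result = pipe(prompt="a red sports car", image=init_image, mask_image=mask_image).images[0]
result.save("inpainted.png")
```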
