
Commit 5995a34

Merge branch 'main' into feat/sd3-modular-pipeline
2 parents 2510dba + 160852d commit 5995a34

37 files changed

Lines changed: 2636 additions & 99 deletions

.ai/models.md

Lines changed: 11 additions & 1 deletion
@@ -73,4 +73,14 @@ Consult the implementations in `src/diffusers/models/transformers/` if you need

7. **Forgetting to update `_import_structure` and `_lazy_modules`.** The top-level `src/diffusers/__init__.py` has both -- missing either one causes partial import failures.

-8. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16` in the model's forward pass. Use the dtype of the input tensors or `self.dtype` so the model works with any precision.
+8. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16`, and don't cast activations by reading a weight's dtype (`self.linear.weight.dtype`) — the stored weight dtype isn't the compute dtype under gguf / quantized loading. Always derive the cast target from the input tensor's dtype or `self.dtype`.
+
+9. **`torch.float64` anywhere in the model.** MPS and several NPU backends don't support float64 -- ops will either error out or silently fall back. Reference repos commonly reach for float64 in RoPE frequency bases, timestep embeddings, sinusoidal position encodings, and similar "precision-sensitive" precompute code (`torch.arange(..., dtype=torch.float64)`, `.double()`, `torch.float64` literals). When porting a model, grep for `float64` / `double()` up front and resolve as follows:
+- **Default: just use `torch.float32`.** For inference it is almost always sufficient -- the precision difference in RoPE angles, timestep embeddings, etc. is immaterial to image/video quality. Flip it and move on.
+- **Only if float32 visibly degrades output, fall back to the device-gated pattern** we use in the repo:
+```python
+is_mps = hidden_states.device.type == "mps"
+is_npu = hidden_states.device.type == "npu"
+freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+```
+See `transformer_flux.py`, `transformer_flux2.py`, `transformer_wan.py`, `unet_2d_condition.py` for reference usages. Never leave an unconditional `torch.float64` in the model.
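
To make item 8 concrete, here is a minimal sketch of the recommended pattern (not part of the diff; the module and its names are hypothetical): the cast target comes from the incoming tensor, never from a stored weight.

```python
import torch
import torch.nn as nn


class TimestepEmbed(nn.Module):
    """Toy module illustrating the guideline: cast buffers to the input's dtype."""

    def __init__(self, dim: int = 256):
        super().__init__()
        # Precomputed table stored in float32; do not assume this is the compute dtype.
        self.register_buffer("freqs", torch.linspace(0.0, 1.0, dim))
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Derive dtype/device from the input, not from self.proj.weight.dtype,
        # so the module behaves the same under bf16, fp16, or quantized loading.
        freqs = self.freqs.to(dtype=hidden_states.dtype, device=hidden_states.device)
        return self.proj(hidden_states + freqs)
```

For item 9, the `freqs_dtype` produced by the device gate is typically confined to the frequency precompute (e.g., `torch.arange(..., dtype=freqs_dtype)`), with the results cast back down before they enter the rest of the model; see the referenced transformer files for the exact usage.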

examples/research_projects/pytorch_xla/inference/flux/README.md

Lines changed: 36 additions & 1 deletion
@@ -51,7 +51,42 @@ python flux_inference.py

The script loads the text encoders onto the CPU and the Flux transformer and VAE models onto the TPU. The first time the script runs, compilation takes longer while the cache stores the compiled programs; on subsequent runs compilation is much faster and later passes are the fastest.

-On a Trillium v6e-4, you should expect ~6 sec / 4 images or 1.5 sec / image (as devices run generation in parallel):
+On a Trillium v6e-4, you should expect ~6 sec / 4 images or 1.5 sec / image (as devices run generation in parallel).
+
+> **Note:** `flux_inference.py` uses `xmp.spawn` (one process per chip) and requires the full model to fit on a single chip. If you run into OOM errors (e.g., on v5e with 16GB HBM per chip), use the SPMD version instead — see below.
+
+### SPMD version (for v5e-8 and similar)
+
+On TPU configurations where a single chip cannot hold the full FLUX transformer (~16GB in bf16), use `flux_inference_spmd.py`. This script uses PyTorch/XLA SPMD to shard the transformer across multiple chips using a `(data, model)` mesh — 4-way model parallel so each chip holds ~4GB of weights, with the other mesh axis used for data parallelism.
+
+```bash
+python flux_inference_spmd.py --schnell
+```
+
+Key differences from `flux_inference.py`:
+- **Single-process SPMD** instead of multi-process `xmp.spawn` — the XLA compiler handles all collective communication transparently.
+- **Transformer weights are sharded** across the `"model"` mesh axis using `xs.mark_sharding`.
+- **VAE lives on CPU**, moved to XLA only for decode (then moved back), since the transformer stays on device throughout.
+- **Text encoding** runs on CPU before loading the transformer.
+
+On a v5litepod-8 (v5e, 8 chips, 16GB HBM each) with FLUX.1-schnell, expect ~1.76 sec/image at steady state (after compilation):
+
+```
+2026-04-15 02:24:30 [info ] SPMD mesh: (2, 4), axes: ('data', 'model'), devices: 8
+2026-04-15 02:24:30 [info ] encoding prompt on CPU...
+2026-04-15 02:26:20 [info ] loading VAE on CPU...
+2026-04-15 02:26:20 [info ] loading flux transformer from black-forest-labs/FLUX.1-schnell
+2026-04-15 02:27:22 [info ] starting compilation run...
+2026-04-15 02:52:55 [info ] compilation took 1533.4575625509997 sec.
+2026-04-15 02:52:56 [info ] starting inference run...
+2026-04-15 02:56:11 [info ] inference time: 195.74092420299985
+2026-04-15 02:56:13 [info ] inference time: 1.7625778899996476
+2026-04-15 02:56:13 [info ] avg. inference over 2 iterations took 98.75175104649975 sec.
+```
+
+The first inference iteration includes VAE compilation (~195 s); the second shows the true steady-state speed (~1.76 s). The reported average (~98.75 s) is simply the mean of those two iterations.
+
+### v6e-4 results (original `flux_inference.py`)

```bash
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
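
As an editorial aside, the sharding setup described in the SPMD bullets above can be distilled to a few lines (a sketch assuming a PyTorch/XLA SPMD environment; the full script added by this commit appears in the next file):

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()  # must be enabled before XLA tensors are created

num_devices = xr.global_runtime_device_count()  # e.g. 8 on a v5litepod-8
# 2 x 4 mesh: "data" axis for batch parallelism, "model" axis for weight sharding
mesh = xs.Mesh(np.arange(num_devices), (num_devices // 4, 4), ("data", "model"))
xs.set_global_mesh(mesh)

# Stand-in for one transformer weight: shard its larger dimension across 4 chips,
# so each chip holds a 3072 x 3072 slice instead of the full 3072 x 12288 matrix.
weight = torch.randn(3072, 12288, dtype=torch.bfloat16, device=xm.xla_device())
xs.mark_sharding(weight, mesh, (None, "model"))
```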
Lines changed: 193 additions & 0 deletions
@@ -0,0 +1,193 @@
"""FLUX inference on TPU using PyTorch/XLA SPMD.

Uses SPMD to shard the transformer across multiple TPU chips, enabling
inference on devices where the model doesn't fit on a single chip (e.g., v5e).
The VAE is loaded on CPU at startup, moved to XLA for decode, then moved back.
"""

from argparse import ArgumentParser
from pathlib import Path
from time import perf_counter

import numpy as np
import structlog
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.debug.profiler as xp
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from torch_xla.experimental.custom_kernel import FlashAttention

from diffusers import AutoencoderKL, FluxPipeline


cache_path = Path("/tmp/data/compiler_cache_eXp")
cache_path.mkdir(parents=True, exist_ok=True)
xr.initialize_cache(str(cache_path), readonly=False)
xr.use_spmd()

logger = structlog.get_logger()
metrics_filepath = "/tmp/metrics_report.txt"
VAE_SCALE_FACTOR = 8


def _vae_decode(latents, vae, height, width, device):
    """Move VAE to XLA, decode latents, move VAE back to CPU."""
    vae.to(device)
    latents = FluxPipeline._unpack_latents(latents, height, width, VAE_SCALE_FACTOR)
    latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
    with torch.no_grad():
        image = vae.decode(latents, return_dict=False)[0]
    vae.to("cpu")
    return image


def main(args):
    # --- SPMD mesh: 4-way model parallel to fit transformer + VAE on v5e chips ---
    num_devices = xr.global_runtime_device_count()
    if num_devices >= 4:
        mesh = xs.Mesh(np.arange(num_devices), (num_devices // 4, 4), ("data", "model"))
    else:
        raise NotImplementedError("This script expects at least 4 devices for the (data, model) mesh.")
    xs.set_global_mesh(mesh)
    logger.info(f"SPMD mesh: {mesh.mesh_shape}, axes: {mesh.axis_names}, devices: {num_devices}")

    # --- Profiler ---
    profile_path = Path("/tmp/data/profiler_out_eXp")
    profile_path.mkdir(parents=True, exist_ok=True)
    profiler_port = 9012
    profile_duration = args.profile_duration
    if args.profile:
        logger.info(f"starting profiler on port {profiler_port}")
        _ = xp.start_server(profiler_port)

    device = xm.xla_device()

    # --- Checkpoint ---
    if args.schnell:
        ckpt_id = "black-forest-labs/FLUX.1-schnell"
    else:
        ckpt_id = "black-forest-labs/FLUX.1-dev"

    # --- Text encoding (CPU) ---
    prompt = "photograph of an electronics chip in the shape of a race car with trillium written on its side"
    logger.info("encoding prompt on CPU...")
    text_pipe = FluxPipeline.from_pretrained(ckpt_id, transformer=None, vae=None, torch_dtype=torch.bfloat16).to("cpu")
    with torch.no_grad():
        prompt_embeds, pooled_prompt_embeds, _ = text_pipe.encode_prompt(
            prompt=prompt, prompt_2=None, max_sequence_length=512
        )
    image_processor = text_pipe.image_processor
    del text_pipe

    # --- Load VAE on CPU (moved to XLA only for decode) ---
    logger.info("loading VAE on CPU...")
    vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=torch.bfloat16)

    # --- Load transformer and shard ---
    logger.info(f"loading flux transformer from {ckpt_id}")
    flux_pipe = FluxPipeline.from_pretrained(
        ckpt_id,
        text_encoder=None,
        tokenizer=None,
        text_encoder_2=None,
        tokenizer_2=None,
        vae=None,
        torch_dtype=torch.bfloat16,
    ).to(device)

    for name, param in flux_pipe.transformer.named_parameters():
        if param.dim() >= 2:
            spec = [None] * param.dim()
            largest_dim = max(range(param.dim()), key=lambda d: param.shape[d])
            spec[largest_dim] = "model"
            xs.mark_sharding(param, mesh, tuple(spec))

    flux_pipe.transformer.enable_xla_flash_attention(partition_spec=("data", None, None, None), is_flux=True)
    FlashAttention.DEFAULT_BLOCK_SIZES = {
        "block_q": 1536,
        "block_k_major": 1536,
        "block_k": 1536,
        "block_b": 1536,
        "block_q_major_dkv": 1536,
        "block_k_major_dkv": 1536,
        "block_q_dkv": 1536,
        "block_k_dkv": 1536,
        "block_q_dq": 1536,
        "block_k_dq": 1536,
        "block_k_major_dq": 1536,
    }

    width = args.width
    height = args.height
    guidance = args.guidance
    n_steps = 4 if args.schnell else 28

    prompt_embeds = prompt_embeds.to(device)
    pooled_prompt_embeds = pooled_prompt_embeds.to(device)
    xs.mark_sharding(prompt_embeds, mesh, ("data", None, None))
    xs.mark_sharding(pooled_prompt_embeds, mesh, ("data", None))

    # --- Compilation run ---
    logger.info("starting compilation run...")
    ts = perf_counter()
    latents = flux_pipe(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        num_inference_steps=28,
        guidance_scale=guidance,
        height=height,
        width=width,
        output_type="latent",
    ).images
    image = _vae_decode(latents, vae, height, width, device)
    image = image_processor.postprocess(image)[0]
    logger.info(f"compilation took {perf_counter() - ts} sec.")
    image.save("/tmp/compile_out.png")

    # --- Inference loop ---
    seed = 4096 if args.seed is None else args.seed
    xm.set_rng_state(seed=seed, device=device)
    times = []
    logger.info("starting inference run...")
    for _ in range(args.itters):
        ts = perf_counter()

        if args.profile:
            xp.trace_detached(f"localhost:{profiler_port}", str(profile_path), duration_ms=profile_duration)
        latents = flux_pipe(
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            num_inference_steps=n_steps,
            guidance_scale=guidance,
            height=height,
            width=width,
            output_type="latent",
        ).images
        image = _vae_decode(latents, vae, height, width, device)
        image = image_processor.postprocess(image)[0]
        inference_time = perf_counter() - ts
        logger.info(f"inference time: {inference_time}")
        times.append(inference_time)

    logger.info(f"avg. inference over {args.itters} iterations took {sum(times) / len(times)} sec.")
    image.save("/tmp/inference_out.png")
    metrics_report = met.metrics_report()
    with open(metrics_filepath, "w+") as fout:
        fout.write(metrics_report)
    logger.info(f"saved metric information as {metrics_filepath}")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--schnell", action="store_true", help="run flux schnell instead of dev")
    parser.add_argument("--width", type=int, default=1024, help="width of the image to generate")
    parser.add_argument("--height", type=int, default=1024, help="height of the image to generate")
    parser.add_argument("--guidance", type=float, default=3.5, help="guidance strength for dev")
    parser.add_argument("--seed", type=int, default=None, help="seed for inference")
    parser.add_argument("--profile", action="store_true", help="enable profiling")
    parser.add_argument("--profile-duration", type=int, default=10000, help="duration for profiling in msec.")
    parser.add_argument("--itters", type=int, default=15, help="number of iterations to run inference and average the time")
    args = parser.parse_args()
    main(args)
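
One detail of the sharding loop in `main()` worth calling out: the partition spec is derived purely from which dimension of each weight is largest. A standalone sketch of that logic (hypothetical shapes, plain Python, no TPU needed):

```python
def largest_dim_spec(shape):
    # Mirrors the named_parameters loop above: shard the largest dimension on the
    # "model" mesh axis and replicate the rest.
    spec = [None] * len(shape)
    spec[max(range(len(shape)), key=lambda d: shape[d])] = "model"
    return tuple(spec)


print(largest_dim_spec((3072, 12288)))  # (None, 'model') -- e.g. an up-projection weight
print(largest_dim_spec((12288, 3072)))  # ('model', None) -- e.g. a down-projection weight
print(largest_dim_spec((3072, 3072)))   # ('model', None) -- ties resolve to the first dimension
```

Sharding only the largest dimension keeps the per-chip shards as large-and-even as possible without any per-layer configuration.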

scripts/convert_longcat_audio_dit_to_diffusers.py

Lines changed: 1 addition & 0 deletions
@@ -131,6 +131,7 @@ def convert_longcat_audio_dit(
        cross_attn_norm=config.get("dit_cross_attn_norm", False),
        eps=config.get("dit_eps", 1e-6),
        use_latent_condition=config.get("dit_use_latent_condition", True),
+       ff_mult=config.get("dit_ff_mult", 4),
    )
    transformer.load_state_dict(transformer_state_dict, strict=True)
    transformer = transformer.to(dtype=torch_dtype)

setup.py

Lines changed: 1 addition & 1 deletion
@@ -124,7 +124,7 @@
    "pytest-xdist",
    "python>=3.10.0",
    "ruff==0.9.10",
-   "safetensors>=0.3.1",
+   "safetensors>=0.8.0-rc.0",
    "sentencepiece>=0.1.91,!=0.1.92",
    "GitPython<3.1.19",
    "scipy",

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -458,6 +458,8 @@
    "HeliosPyramidDistilledAutoBlocks",
    "HeliosPyramidDistilledModularPipeline",
    "HeliosPyramidModularPipeline",
+   "HunyuanVideo15AutoBlocks",
+   "HunyuanVideo15ModularPipeline",
    "LTXAutoBlocks",
    "LTXModularPipeline",
    "QwenImageAutoBlocks",
@@ -1246,6 +1248,8 @@
    HeliosPyramidDistilledAutoBlocks,
    HeliosPyramidDistilledModularPipeline,
    HeliosPyramidModularPipeline,
+   HunyuanVideo15AutoBlocks,
+   HunyuanVideo15ModularPipeline,
    LTXAutoBlocks,
    LTXModularPipeline,
    QwenImageAutoBlocks,

src/diffusers/dependency_versions_table.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@
    "pytest-xdist": "pytest-xdist",
    "python": "python>=3.10.0",
    "ruff": "ruff==0.9.10",
-   "safetensors": "safetensors>=0.3.1",
+   "safetensors": "safetensors>=0.8.0-rc.0",
    "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
    "GitPython": "GitPython<3.1.19",
    "scipy": "scipy",

src/diffusers/models/transformers/transformer_longcat_audio_dit.py

Lines changed: 2 additions & 1 deletion
@@ -475,6 +475,7 @@ def __init__(
        cross_attn_norm: bool = False,
        eps: float = 1e-6,
        use_latent_condition: bool = True,
+       ff_mult: float = 4.0,
    ):
        super().__init__()
        dim = dit_dim
@@ -498,7 +499,7 @@
                cross_attn_norm=cross_attn_norm,
                adaln_type=adaln_type,
                adaln_use_text_cond=adaln_use_text_cond,
-               ff_mult=4.0,
+               ff_mult=ff_mult,
            )
            for _ in range(dit_depth)
        ]

src/diffusers/modular_pipelines/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -89,6 +89,10 @@
        "QwenImageLayeredModularPipeline",
        "QwenImageLayeredAutoBlocks",
    ]
+   _import_structure["hunyuan_video1_5"] = [
+       "HunyuanVideo15AutoBlocks",
+       "HunyuanVideo15ModularPipeline",
+   ]
    _import_structure["ltx"] = [
        "LTXAutoBlocks",
        "LTXModularPipeline",
@@ -124,6 +128,10 @@
        HeliosPyramidDistilledModularPipeline,
        HeliosPyramidModularPipeline,
    )
+   from .hunyuan_video1_5 import (
+       HunyuanVideo15AutoBlocks,
+       HunyuanVideo15ModularPipeline,
+   )
    from .ltx import LTXAutoBlocks, LTXModularPipeline
    from .modular_pipeline import (
        AutoPipelineBlocks,
Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["modular_blocks_hunyuan_video1_5"] = [
        "HunyuanVideo15AutoBlocks",
    ]
    _import_structure["modular_pipeline"] = ["HunyuanVideo15ModularPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .modular_blocks_hunyuan_video1_5 import HunyuanVideo15AutoBlocks
        from .modular_pipeline import HunyuanVideo15ModularPipeline
else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
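
Since the new names are registered both here and in the top-level `src/diffusers/__init__.py` (see the diff above), they should resolve through the lazy-module machinery from either import path; a quick, hypothetical smoke test:

```python
# Hypothetical smoke test -- assumes a diffusers build that includes this commit.
from diffusers import HunyuanVideo15AutoBlocks, HunyuanVideo15ModularPipeline
from diffusers.modular_pipelines.hunyuan_video1_5 import HunyuanVideo15AutoBlocks as SameBlocks

# Lazy-module registration means both import paths resolve to the same class object.
assert HunyuanVideo15AutoBlocks is SameBlocks
print(HunyuanVideo15AutoBlocks, HunyuanVideo15ModularPipeline)
```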
