[new-model] Port FLUX.1-dev T2I to FastVideo #1228

Ishxn20 wants to merge 2 commits into hao-ai-lab:main
Conversation
Merge Protections

Your pull request matches the following merge protections and will not be merged until they are valid.

🔴 PR merge requirements — waiting: this rule is failing.
Code Review
This pull request introduces support for the FLUX.1-dev text-to-image model. It includes the implementation of the FLUX transformer architecture with joint and single-stream attention, specialized pipeline stages for latent packing and flow matching, and configuration updates to handle embedded guidance. Feedback focuses on simplifying redundant model layers, correcting timestep scaling in the forward context, and improving error handling by replacing assertions with explicit exceptions in pipeline stages.
```python
self.to_out = nn.ModuleList(
    [
        ReplicatedLinear(self.inner_dim, dim, bias=True),
        nn.Dropout(0.0),
    ]
)
```
The nn.Dropout(0.0) is a no-op and adds unnecessary overhead to the model's forward pass. Additionally, using nn.ModuleList for a single projection layer is redundant. It is recommended to simplify self.to_out to a direct ReplicatedLinear layer and update the forward call accordingly.
Suggested change:

```python
self.to_out = ReplicatedLinear(self.inner_dim, dim, bias=True)
```

The forward call site currently reads:

```python
img_out, _ = self.to_out[0](img_out)
img_out = self.to_out[1](img_out)
```
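To make the simplified call site concrete, here is a minimal sketch using a hypothetical stand-in for `ReplicatedLinear` (the original diff shows it returning an `(output, extra)` tuple, which the stub mimics; the class and dimensions below are illustrative, not code from this PR):

```python
class ReplicatedLinearStub:
    """Hypothetical stand-in for ReplicatedLinear: like vLLM-style
    parallel linear layers, calling it returns an (output, extra) tuple."""

    def __init__(self, in_dim, out_dim, bias=True):
        self.in_dim, self.out_dim = in_dim, out_dim

    def __call__(self, x):
        return x, None  # identity output, no extra; illustration only


# With a direct layer, no ModuleList indexing or no-op Dropout is needed.
to_out = ReplicatedLinearStub(3072, 3072)
img_out = [0.0] * 3072
img_out, _ = to_out(img_out)  # replaces self.to_out[0](...) / self.to_out[1](...)
```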
```python
        get_forward_context()
        forward_context = nullcontext()
    except AssertionError:
        ts0 = int(timestep[0].item()) if timestep.numel() > 0 else 0
```
The fallback calculation for ts0 uses the scaled timestep (in range [0, 1]) directly as an integer, which will result in either 0 or 1. Since the forward pass later scales this by 1000 (line 516), ts0 should be scaled by 1000 here to correctly represent the raw timestep in the forward context (e.g., for TeaCache or logging).
Suggested change:

```diff
- ts0 = int(timestep[0].item()) if timestep.numel() > 0 else 0
+ ts0 = int(timestep[0].item() * 1000) if timestep.numel() > 0 else 0
```
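The intended arithmetic can be checked in isolation with plain Python numbers standing in for tensor values (`fallback_ts0` is a hypothetical helper, not code from this PR):

```python
def fallback_ts0(timesteps):
    """Mirror the suggested fallback: scale a [0, 1] timestep by 1000.

    `timesteps` is a plain list standing in for the tensor; the real code
    uses `timestep[0].item()` and `timestep.numel()`.
    """
    # Scaling by 1000 matches the later `timestep * 1000` in the forward
    # pass, so the forward context sees the raw timestep rather than 0/1.
    return int(timesteps[0] * 1000) if len(timesteps) > 0 else 0


assert fallback_ts0([0.5]) == 500  # scaled sigma 0.5 -> raw timestep 500
assert fallback_ts0([]) == 0       # empty falls back to 0
```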
```python
        ),
    ):
        if use_true_cfg:
            assert neg_enc is not None and neg_pooled is not None
```
Using assert for state validation in pipeline stages is discouraged because assertions can be disabled in production environments (using the -O flag). It is safer to raise a RuntimeError with a descriptive message to handle cases where the required conditioning embeddings are missing.
```python
if neg_enc is None or neg_pooled is None:
    raise RuntimeError(
        "True CFG requires negative prompt embeddings "
        "(neg_enc and neg_pooled) to be populated in batch.extra.")
```
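As a standalone sketch of the pattern (hypothetical free function; in the PR the check lives inside the stage's forward):

```python
def require_true_cfg_embeddings(neg_enc, neg_pooled):
    # An explicit raise survives `python -O`, unlike `assert`,
    # which is stripped when bytecode optimizations are enabled.
    if neg_enc is None or neg_pooled is None:
        raise RuntimeError(
            "True CFG requires negative prompt embeddings "
            "(neg_enc and neg_pooled) to be populated in batch.extra.")


require_true_cfg_embeddings(object(), object())  # both present: no error
try:
    require_true_cfg_embeddings(None, object())  # missing neg_enc
except RuntimeError as err:
    assert "negative prompt embeddings" in str(err)
```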
Pull request overview
This PR adds FLUX.1-dev text-to-image support to FastVideo, including a Diffusers-aligned FLUX transformer implementation, a composed FLUX pipeline with packed-latent + FlowMatch scheduling, and associated registration, tests, and an example script.
Changes:
- Introduces FLUX model + pipeline configs and wiring to support `black-forest-labs/FLUX.1-dev` for T2I.
- Adds FLUX pipeline stages (conditioning, packed-latent prep, FlowMatch `mu`, denoise loop, VAE decode) and a new `FluxPipeline`.
- Adds parity + SSIM + local smoke/loader tests, and extends SSIM utilities to support single-frame image outputs.
Reviewed changes
Copilot reviewed 17 out of 18 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| `tests/local_tests/pipelines/test_flux_dev_pipeline_smoke.py` | Local end-to-end FLUX T2I smoke test from a local checkpoint. |
| `tests/local_tests/flux/test_flux_dev_component_loaders.py` | Local loader smoke tests for FLUX tokenizers/encoders/VAE/scheduler. |
| `tests/local_tests/flux/__init__.py` | Marks local FLUX tests as a package. |
| `fastvideo/tests/utils.py` | Extends SSIM utilities to read images as 1-frame clips. |
| `fastvideo/tests/transformers/test_flux.py` | Adds FastVideo vs Diffusers parity test for FluxTransformer2DModel. |
| `fastvideo/tests/ssim/test_flux_t2i_similarity.py` | Adds FLUX T2I SSIM gate using .png outputs. |
| `fastvideo/tests/ssim/inference_similarity_utils.py` | Generalizes SSIM harness from "video" to "media" (video or image). |
| `fastvideo/registry.py` | Registers FLUX pipeline + sampling param and adds model detectors for discovery. |
| `fastvideo/pipelines/stages/flux_stages.py` | Implements FLUX pipeline stages (pack/unpack, mu shifting, denoise, decode). |
| `fastvideo/pipelines/pipeline_batch_info.py` | Adds embedded-guidance and true-CFG knobs to ForwardBatch. |
| `fastvideo/pipelines/basic/flux/flux_pipeline.py` | Adds composed FluxPipeline definition and stage wiring. |
| `fastvideo/pipelines/basic/flux/__init__.py` | Package marker for the FLUX pipeline. |
| `fastvideo/models/dits/flux.py` | Adds FluxTransformer2DModel implementation compatible with Diffusers weights. |
| `fastvideo/configs/sample/flux.py` | Adds FLUX sampling defaults aligned to FLUX.1-dev. |
| `fastvideo/configs/sample/base.py` | Adds embedded-guidance and true-CFG fields to global sampling params. |
| `fastvideo/configs/pipelines/flux.py` | Adds FluxPipelineConfig (encoders, precisions, tokenization defaults). |
| `fastvideo/configs/models/dits/flux.py` | Adds FluxDiTConfig and arch config for the FLUX transformer. |
| `examples/inference/basic/basic_flux_dev.py` | Adds a minimal CLI example to run FLUX.1-dev T2I and save PNGs. |
```python
args = FastVideoArgs(
    model_path=_FLUX_DEV_ROOT,
    pipeline_config=_FluxDevLoaderPipelineConfig(),
    hsdp_shard_dim=1,
    pin_cpu_memory=False,
)
```
FastVideoArgs defaults enable text_encoder_cpu_offload (and can trigger FSDP2 fully_shard wrapping for CLIP/T5 since their configs define _fsdp_shard_conditions). In that case te_clip.__class__.__name__ / te_t5.__class__.__name__ won’t end with CLIPTextModel/T5EncoderModel, making these assertions flaky. Either disable text encoder CPU offload explicitly in this test (and/or FSDP inference), or unwrap the underlying module before asserting its type/name.
```python
assert te_clip.__class__.__name__.endswith("CLIPTextModel")
assert te_t5.__class__.__name__.endswith("T5EncoderModel")
```
This class-name assertion is not robust when encoders are wrapped by FSDP2 during CPU offload (default FastVideoArgs behavior). Consider unwrapping the wrapped module (or asserting on a stable attribute/config) instead of relying on __class__.__name__ suffix matching.
Suggested change:

```diff
- assert te_clip.__class__.__name__.endswith("CLIPTextModel")
- assert te_t5.__class__.__name__.endswith("T5EncoderModel")
+ def _unwrap_wrapped_module(module: object) -> object:
+     current = module
+     seen: set[int] = set()
+     while id(current) not in seen:
+         seen.add(id(current))
+         wrapped = getattr(current, "module", None)
+         if wrapped is not None and wrapped is not current:
+             current = wrapped
+             continue
+         wrapped = getattr(current, "_fsdp_wrapped_module", None)
+         if wrapped is not None and wrapped is not current:
+             current = wrapped
+             continue
+         break
+     return current
+
+ te_clip_unwrapped = _unwrap_wrapped_module(te_clip)
+ te_t5_unwrapped = _unwrap_wrapped_module(te_t5)
+ assert te_clip_unwrapped.__class__.__name__.endswith("CLIPTextModel")
+ assert te_t5_unwrapped.__class__.__name__.endswith("T5EncoderModel")
```
```python
    """Require height/width divisible by 16 (VAE scale × 2 for FLUX packing)."""

    def forward(
        self,
        batch: ForwardBatch,
        fastvideo_args: FastVideoArgs,
    ) -> ForwardBatch:
        if (batch.height is not None and batch.width is not None
                and (batch.height % 16 != 0 or batch.width % 16 != 0)):
            raise ValueError(
                "FLUX expects height and width divisible by 16 "
                f"(VAE latent grid × 2× packing); got {batch.height}×{batch.width}."
```
FluxInputValidationStage hard-codes height/width divisibility by 16, but later stages derive the VAE spatial compression ratio from fastvideo_args.pipeline_config.vae_config.arch_config.spatial_compression_ratio. If that ratio differs from 8, the correct constraint is height % (2*spatial_ratio) == 0 and width % (2*spatial_ratio) == 0 (2x comes from FLUX packing), otherwise invalid inputs can pass this stage and fail later.
Suggested change:

```diff
-    """Require height/width divisible by 16 (VAE scale × 2 for FLUX packing)."""
-
-    def forward(
-        self,
-        batch: ForwardBatch,
-        fastvideo_args: FastVideoArgs,
-    ) -> ForwardBatch:
-        if (batch.height is not None and batch.width is not None
-                and (batch.height % 16 != 0 or batch.width % 16 != 0)):
-            raise ValueError(
-                "FLUX expects height and width divisible by 16 "
-                f"(VAE latent grid × 2× packing); got {batch.height}×{batch.width}."
+    """Require height/width divisible by 2 × VAE spatial compression ratio."""
+
+    def forward(
+        self,
+        batch: ForwardBatch,
+        fastvideo_args: FastVideoArgs,
+    ) -> ForwardBatch:
+        arch_config = getattr(
+            getattr(getattr(fastvideo_args, "pipeline_config", None),
+                    "vae_config", None), "arch_config", None)
+        spatial_compression_ratio = getattr(arch_config,
+                                            "spatial_compression_ratio", 8)
+        required_divisibility = 2 * spatial_compression_ratio
+        if (batch.height is not None and batch.width is not None
+                and (batch.height % required_divisibility != 0
+                     or batch.width % required_divisibility != 0)):
+            raise ValueError(
+                "FLUX expects height and width divisible by "
+                f"{required_divisibility} (2 × VAE spatial compression ratio "
+                f"{spatial_compression_ratio}); got {batch.height}×{batch.width}."
```
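The generalized constraint reduces to simple arithmetic; here is a minimal standalone sketch (hypothetical helper, defaulting to the ratio of 8 used by the current VAE config):

```python
def validate_flux_resolution(height, width, spatial_compression_ratio=8):
    """Raise if height/width are not divisible by 2 x the VAE ratio.

    The 2x factor comes from FLUX's 2x2 latent packing; with the default
    ratio of 8 this reproduces the original divisible-by-16 rule.
    """
    required = 2 * spatial_compression_ratio
    if height % required != 0 or width % required != 0:
        raise ValueError(
            f"FLUX expects height and width divisible by {required}; "
            f"got {height}x{width}.")
    return required


assert validate_flux_resolution(1024, 1024) == 16     # default ratio 8
assert validate_flux_resolution(1024, 768, 16) == 32  # hypothetical ratio-16 VAE
```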
```diff
@@ -88,25 +106,29 @@ def _assert_similarity(
         raise FileNotFoundError(error_msg)
```
The raised error message still says “Reference video folder…” even though this helper now supports image (non-video) media. Update the message (and the download instructions text if needed) to refer to “reference media” to avoid confusing failures when running T2I SSIM.
```diff
@@ -58,8 +77,8 @@ def compute_video_ssim_torchvision(video1_path, video2_path, use_ms_ssim=True):
     if not os.path.exists(video2_path):
         raise FileNotFoundError(f"Video2 not found: {video2_path}")
```
Now that compute_video_ssim_torchvision accepts image paths too, the FileNotFoundError messages (“Video1 not found”, “Video2 not found”) are misleading. Consider switching these to “Media1/Media2 not found” (or include “video/image”) to match the updated behavior.
```python
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import argparse
import contextlib
```
PR description says the example is examples/inference/basic/flux_dev_t2i.py, but the added example file here is named basic_flux_dev.py. Please align the filename/path in the PR description (or rename/move the script) so contributors can find it easily.
```python
def test_flux_dev_pipeline_short_run_finite_output(
        monkeypatch: pytest.MonkeyPatch) -> None:
```
Function signature indentation is inconsistent with the surrounding codebase’s typical formatting (and may fail auto-formatting checks). Consider reformatting this definition to standard 4-space continuation indentation (e.g., Black-compatible).
Suggested change:

```diff
-        monkeypatch: pytest.MonkeyPatch) -> None:
+        monkeypatch: pytest.MonkeyPatch,
+) -> None:
```
Pre-commit checks failed

Hi @Ishxn20, the pre-commit checks have failed. To fix them locally:

```shell
# Install pre-commit if you haven't already
uv pip install pre-commit
pre-commit install

# Run all checks and auto-fix what's possible
pre-commit run --all-files
```

Common fixes:

After fixing, commit and push the changes. The checks will re-run automatically. For future commits,
This PR has merge conflicts with the base branch. Please rebase:

```shell
git fetch origin main
git rebase origin/main
# Resolve any conflicts, then:
git push --force-with-lease
```
Summary
Adds FLUX.1-dev text-to-image support to FastVideo: Diffusers-aligned packed latents, FlowMatch `mu`, CLIP pooled + T5 sequence conditioning, embedded guidance and optional true CFG via `true_cfg_scale`, plus registry wiring for `black-forest-labs/FLUX.1-dev`. Includes parity, loader, pipeline smoke, and SSIM test hooks, a minimal example script, and contributor-oriented test layout (pipeline smoke and checkpoint loader tests under `tests/local_tests/`).

What changed

- `FluxTransformer2DModel` and `FluxDiTConfig` with Diffusers-compatible forward (`txt_ids`/`img_ids`, timestep scaling, guidance when `guidance_embeds`).
- `FluxPipeline` and FLUX stages (pack/unpack, latent image ids, scheduler `mu`, denoise loop, VAE denormalize, 5D image output).
- `FluxPipelineConfig`, `FluxSamplingParam` (defaults aligned with FLUX.1-dev).
- `examples/inference/basic/flux_dev_t2i.py`.
- Tests: parity (`fastvideo/tests/transformers/test_flux.py`), SSIM (`fastvideo/tests/ssim/test_flux_t2i_similarity.py`), local pipeline smoke (`tests/local_tests/pipelines/test_flux_dev_pipeline_smoke.py`), local component loaders (`tests/local_tests/flux/test_flux_dev_component_loaders.py`).

How to test
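For readers new to FLUX's packed-latent scheduling, the image-token count and the scheduler `mu` can be sketched as follows. The formula follows the Diffusers `calculate_shift` convention, and the constants (base/max sequence length 256/4096, shifts 0.5/1.15, VAE ratio 8) are the usual FLUX.1-dev defaults, stated here as assumptions rather than values read from this PR:

```python
def packed_image_seq_len(height, width, spatial_compression_ratio=8):
    # The VAE downsamples by the compression ratio, then FLUX packs each
    # 2x2 latent patch into a single transformer token.
    latent_h = height // spatial_compression_ratio
    latent_w = width // spatial_compression_ratio
    return (latent_h // 2) * (latent_w // 2)


def calculate_mu(image_seq_len, base_seq_len=256, max_seq_len=4096,
                 base_shift=0.5, max_shift=1.15):
    # Linear interpolation of the FlowMatch shift in sequence length,
    # as in Diffusers' calculate_shift helper.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b


seq_len = packed_image_seq_len(1024, 1024)  # 4096 tokens at 1024x1024
mu = calculate_mu(seq_len)                  # ~1.15 at the max sequence length
```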
From repo root (requires CUDA and `official_weights/FLUX.1-dev` or `FLUX_DEV_ROOT`/`FLUX_TRANSFORMER_PATH` where noted):