Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 140 additions & 0 deletions examples/inference/basic/basic_flux_dev.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import argparse
import contextlib
Comment on lines +1 to +6
Copy link

Copilot AI Apr 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR description says the example is examples/inference/basic/flux_dev_t2i.py, but the added example file here is named basic_flux_dev.py. Please align the filename/path in the PR description (or rename/move the script) so contributors can find it easily.

Copilot uses AI. Check for mistakes.
import os
import re

# Prompts used when no --prompt flags are supplied: one short smoke-test
# prompt and one longer, detail-heavy prompt to exercise quality.
DEFAULT_PROMPTS = [
    "a photo of a cat",
    (
        "a cinematic photo of a red panda wearing a tiny backpack, standing on a "
        "rainy neon-lit street at night, shallow depth of field, sharp focus, "
        "35mm, bokeh"
    ),
]


def _safe_filename(text: str, max_len: int = 100) -> str:
"""Make a stable, filesystem-friendly filename base."""
s = text[:max_len].strip()
s = s.replace(os.sep, "_")
if os.altsep:
s = s.replace(os.altsep, "_")
s = re.sub(r"\s+", " ", s)
s = re.sub(r"[^A-Za-z0-9 .,_-]", "_", s)
s = s.strip(" .")
return s or "prompt"


def _remove_existing_outputs(out_dir: str, filename_base: str) -> None:
"""Delete prior outputs so reruns do not get _1, _2 suffixes."""
if not os.path.isdir(out_dir):
return

pattern = re.compile(rf"^{re.escape(filename_base)}(_\d+)?\.(mp4|png)$")
for fn in os.listdir(out_dir):
if pattern.match(fn):
with contextlib.suppress(FileNotFoundError):
os.remove(os.path.join(out_dir, fn))


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Build and parse the CLI for the FLUX.1-dev text-to-image example.

    Args:
        argv: Optional explicit argument list. Defaults to ``None``, in which
            case ``sys.argv[1:]`` is parsed (the previous behavior); passing a
            list makes the parser unit-testable.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    p = argparse.ArgumentParser(
        description="Run FLUX.1-dev text-to-image with FastVideo VideoGenerator.",
    )
    p.add_argument(
        "--model-path",
        default="official_weights/FLUX.1-dev",
        help="Local Diffusers checkpoint dir or HF repo id.",
    )
    p.add_argument(
        "--out-dir",
        "--outdir",  # alias; both spellings populate args.out_dir
        default="outputs/flux_dev/samples",
        help="Directory for saved PNG outputs.",
    )
    p.add_argument(
        "--prompt",
        action="append",  # repeatable flag; None means "use DEFAULT_PROMPTS"
        default=None,
        help="Prompt. Repeat for multiple images.",
    )
    p.add_argument(
        "--backend",
        default=None,
        help="Set FASTVIDEO_ATTENTION_BACKEND (e.g. TORCH_SDPA).",
    )
    p.add_argument("--seed", type=int, default=42,
                   help="Base seed; each prompt uses seed + index.")
    p.add_argument("--height", type=int, default=1024, help="Output height.")
    p.add_argument("--width", type=int, default=1024, help="Output width.")
    p.add_argument("--steps", type=int, default=28, help="Number of inference steps.")
    p.add_argument("--guidance", type=float, default=3.5, help="Guidance scale.")
    p.add_argument("--num-gpus", type=int, default=1, help="GPU count.")
    return p.parse_args(argv)


def main() -> None:
    """Generate one PNG per prompt with FLUX.1-dev via FastVideo."""
    args = parse_args()
    prompts: list[str] = args.prompt or DEFAULT_PROMPTS

    # Export before importing fastvideo so the backend choice is visible
    # at initialization time.
    if args.backend:
        os.environ["FASTVIDEO_ATTENTION_BACKEND"] = args.backend

    from fastvideo import VideoGenerator

    os.makedirs(args.out_dir, exist_ok=True)

    # Plain single-process setup: no parallelism, no CPU offload, no FSDP.
    engine_kwargs = {
        "num_gpus": args.num_gpus,
        "workload_type": "t2i",
        "sp_size": 1,
        "tp_size": 1,
        "dit_cpu_offload": False,
        "dit_layerwise_offload": False,
        "text_encoder_cpu_offload": False,
        "vae_cpu_offload": False,
        "image_encoder_cpu_offload": False,
        "pin_cpu_memory": False,
        "use_fsdp_inference": False,
    }

    gen = VideoGenerator.from_pretrained(
        model_path=args.model_path,
        **engine_kwargs,
    )
    try:
        for idx, prompt in enumerate(prompts):
            seed = args.seed + idx
            base = f"flux_dev_{idx:02d}_seed{seed}_{_safe_filename(prompt, max_len=80)}"
            _remove_existing_outputs(args.out_dir, base)
            out_path = os.path.join(args.out_dir, f"{base}.png")
            print(f"[flux] prompt_idx={idx} seed={seed} output_path={out_path}")

            # num_frames=1 / fps=1: a single still frame saved as PNG.
            gen.generate_video(
                prompt,
                output_path=out_path,
                height=args.height,
                width=args.width,
                num_frames=1,
                fps=1,
                num_inference_steps=args.steps,
                guidance_scale=args.guidance,
                use_embedded_guidance=True,
                true_cfg_scale=1.0,
                seed=seed,
                save_video=True,
            )

        print(f"[flux] done. outputs written to: {args.out_dir}")
    finally:
        # Always release engine resources, even on failure mid-batch.
        gen.shutdown()


# Script entry point: run only when executed directly, not on import.
if __name__ == "__main__":
    main()
27 changes: 27 additions & 0 deletions fastvideo/configs/models/dits/flux.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass, field

from fastvideo.configs.models.dits.base import DiTArchConfig, DiTConfig


@dataclass
class FluxTransformer2DArchConfig(DiTArchConfig):
    """Architecture hyperparameters for the FLUX.1 transformer.

    Field names and defaults mirror Diffusers' ``FluxTransformer2DModel``
    config for FLUX.1-dev — presumably overridden from the checkpoint config
    at load time; confirm against the model loader.
    """

    patch_size: int = 1
    in_channels: int = 64  # packed latent channels — assumes 2x2-packed VAE latents; TODO confirm
    out_channels: int | None = None  # None presumably falls back to in_channels; confirm in model code
    num_layers: int = 19  # dual-stream (joint) blocks, per Diffusers FLUX naming
    num_single_layers: int = 38  # single-stream blocks
    attention_head_dim: int = 128
    num_attention_heads: int = 24  # hidden size = 24 * 128 = 3072
    joint_attention_dim: int = 4096  # text (T5) feature dim fed to joint attention
    pooled_projection_dim: int = 768  # CLIP pooled-embedding dim (matches CLIP branch in pipeline config)
    guidance_embeds: bool = True  # FLUX.1-dev embeds the distilled guidance scale
    axes_dims_rope: tuple[int, int, int] = (16, 56, 56)  # RoPE split per axis; 16+56+56 = attention_head_dim


@dataclass
class FluxDiTConfig(DiTConfig):
    """DiT config pairing the FLUX architecture defaults with its weight prefix."""

    # Default architecture description for FLUX checkpoints.
    arch_config: DiTArchConfig = field(default_factory=FluxTransformer2DArchConfig)
    # NOTE(review): presumably the parameter-name prefix used during weight
    # loading/mapping — confirm against DiTConfig's use of ``prefix``.
    prefix: str = "flux"
74 changes: 74 additions & 0 deletions fastvideo/configs/pipelines/flux.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass, field

import torch

from fastvideo.configs.models import EncoderConfig
from fastvideo.configs.models.dits.flux import FluxDiTConfig
from fastvideo.configs.models.encoders import (
BaseEncoderOutput,
CLIPTextConfig,
T5LargeConfig,
)
from fastvideo.configs.models.vaes.autoencoder_kl import AutoencoderKLVAEConfig
from fastvideo.configs.pipelines.base import PipelineConfig, preprocess_text


def _flux_clip_pooled_postprocess(outputs: BaseEncoderOutput) -> torch.Tensor:
"""CLIP branch for FLUX: Diffusers uses pooled prompt embeddings only."""
if outputs.pooler_output is None:
raise RuntimeError(
"FLUX CLIP conditioning requires pooler_output. Ensure the CLIP text encoder returns pooled features.")
return outputs.pooler_output


def _flux_t5_sequence_postprocess(outputs: BaseEncoderOutput) -> torch.Tensor:
if outputs.last_hidden_state is None:
raise RuntimeError("FLUX T5 conditioning requires last_hidden_state.")
return outputs.last_hidden_state


@dataclass
class FluxPipelineConfig(PipelineConfig):
    """Pipeline layout for Diffusers FLUX.1-dev (CLIP + T5 + packed DiT + FlowMatch)."""

    # Diffusers component class names for scheduler / transformer / VAE.
    scheduler_arch: str = "FlowMatchEulerDiscreteScheduler"
    transformer_arch: str = "FluxTransformer2DModel"
    vae_arch: str = "AutoencoderKL"
    # Two text encoders, in order: CLIP (pooled embedding) then T5 (sequence).
    text_encoder_archs: tuple[str, ...] = ("CLIPTextModel", "T5EncoderModel")
    tokenizer_archs: tuple[str, ...] = ("CLIPTokenizer", "T5TokenizerFast")

    dit_config: FluxDiTConfig = field(default_factory=FluxDiTConfig)
    vae_config: AutoencoderKLVAEConfig = field(default_factory=AutoencoderKLVAEConfig)

    # Default embedded (distilled) guidance strength.
    embedded_cfg_scale: float = 3.5
    flow_shift: float | None = None  # None presumably defers to a runtime default; confirm

    # Parallel tuples: per-encoder config plus matching pre/post-processing,
    # index 0 = CLIP -> pooled tensor, index 1 = T5 -> token sequence.
    text_encoder_configs: tuple[EncoderConfig, ...] = field(default_factory=lambda: (CLIPTextConfig(), T5LargeConfig()))
    preprocess_text_funcs: tuple[Callable[[str], str],
                                 ...] = field(default_factory=lambda: (preprocess_text, preprocess_text))
    postprocess_text_funcs: tuple[Callable[[BaseEncoderOutput], torch.Tensor], ...] = field(
        default_factory=lambda: (_flux_clip_pooled_postprocess, _flux_t5_sequence_postprocess))

    # Mixed precision: bf16 compute for DiT and T5, fp32 for VAE and CLIP.
    dit_precision: str = "bf16"
    vae_precision: str = "fp32"
    text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: ("fp32", "bf16"))

    def __post_init__(self) -> None:
        """Apply FLUX tokenizer defaults without overriding user-provided kwargs.

        NOTE(review): does not call ``super().__post_init__()`` — confirm the
        base ``PipelineConfig`` has no required post-init logic.
        """
        # list() only avoids reassigning the tuple field; the encoder config
        # objects themselves are mutated in place via their tokenizer_kwargs.
        te_cfgs = list(self.text_encoder_configs)
        # CLIP branch: pad/truncate to CLIP's fixed 77-token context, PyTorch tensors.
        if len(te_cfgs) >= 1:
            te_cfgs[0].tokenizer_kwargs.setdefault("padding", "max_length")
            te_cfgs[0].tokenizer_kwargs.setdefault("max_length", 77)
            te_cfgs[0].tokenizer_kwargs.setdefault("truncation", True)
            te_cfgs[0].tokenizer_kwargs.setdefault("return_tensors", "pt")
        # T5 branch: max_length is hard-capped at 512 — unlike the setdefault
        # calls below, this clamps even a user-supplied larger value.
        if len(te_cfgs) >= 2:
            cap = 512
            te_cfgs[1].tokenizer_kwargs["max_length"] = min(int(te_cfgs[1].tokenizer_kwargs.get("max_length", cap)),
                                                            cap)
            te_cfgs[1].tokenizer_kwargs.setdefault("padding", "max_length")
            te_cfgs[1].tokenizer_kwargs.setdefault("truncation", True)
            te_cfgs[1].tokenizer_kwargs.setdefault("return_tensors", "pt")
4 changes: 4 additions & 0 deletions fastvideo/configs/sample/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ class SamplingParam:
num_inference_steps: int = 50
num_inference_steps_sr: int = 50
guidance_scale: float = 1.0
# Embedded guidance (FLUX): do not treat ``guidance_scale > 1`` as classic CFG.
use_embedded_guidance: bool = False
# Diffusers-style true CFG for FLUX when > 1 (requires negative prompt encoding).
true_cfg_scale: float = 1.0
guidance_rescale: float = 0.0
boundary_ratio: float | None = None
sigmas: list[float] | None = None
Expand Down
27 changes: 27 additions & 0 deletions fastvideo/configs/sample/flux.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from dataclasses import dataclass

from fastvideo.configs.sample.base import SamplingParam


@dataclass
class FluxSamplingParam(SamplingParam):
    """Default sampling parameters for FLUX.1-dev text-to-image."""

    prompt: str | None = "a photo of a cat"
    # Only meaningful with true CFG (true_cfg_scale > 1) — see SamplingParam.
    negative_prompt: str = ""

    num_videos_per_prompt: int = 1
    seed: int = 0

    # Still image: a single frame at FLUX's native 1024x1024 resolution.
    num_frames: int = 1
    height: int = 1024
    width: int = 1024
    fps: int = 1  # placeholder for the video-oriented output path; one frame only

    # FLUX.1-dev defaults: 28 steps, embedded (distilled) guidance at 3.5,
    # and true_cfg_scale == 1.0 leaves classic CFG disabled.
    num_inference_steps: int = 28
    guidance_scale: float = 3.5
    use_embedded_guidance: bool = True
    true_cfg_scale: float = 1.0
Loading
Loading