Skip to content

perf(diffusion): improve Flux training throughput#2251

Open
pthombre wants to merge 4 commits into
mainfrom
pthombre/am-358-improve-flux1-dev-diffusion-training-performance
Open

perf(diffusion): improve Flux training throughput#2251
pthombre wants to merge 4 commits into
mainfrom
pthombre/am-358-improve-flux1-dev-diffusion-training-performance

Conversation

@pthombre
Copy link
Copy Markdown
Contributor

@pthombre pthombre commented May 15, 2026

What does this PR do ?

Improves FLUX.1-dev diffusion training throughput for both full fine-tuning and LoRA by adding measured performance controls to the diffusion recipe and promoting the best validated Flux configs.

Changelog

  • Add diffusion recipe support for Transformer Engine Linear/FP8 autocast, QKV projection fusion controls, and optimizer _target_/fused AdamW kwargs.
  • Add FP8-safe Transformer Engine Linear conversion in NeMoAutoDiffusionPipeline, including safe-subset filtering for batch-only Flux conditioning/output projections.
  • Fix Flux LoRA DDP handling by freezing base weights before DDP wrapping and setting attention backend through the wrapped transformer module.
  • Extend per-layer compile discovery to Flux transformer_blocks and single_transformer_blocks.
  • Make Flux LoRA generation loading more robust to PEFT alias fields and checkpoint key prefixes.
  • Update Flux full fine-tune defaults to TE FP8 delayed scaling, fused AdamW, FSDP2 prefetch 3/2, local/global batch 4/32, and compile disabled.
  • Update Flux LoRA defaults to replicated DDP, fused AdamW, local/global batch 6/48, and TE FP8 disabled.
  • Remove experiment-only torchao FP8 and training profiler plumbing from the production recipe path.

Experiment summary

The selected defaults come from bounded 8x H100 performance sweeps for FLUX.1-dev full fine-tuning and LoRA fine-tuning. Each promoted setting was required to complete the bounded run, keep loss and gradient norms finite, and avoid regressing generation/checkpoint validation where applicable.

Full fine-tuning

  • Baseline FSDP2 with local/global batch 1/8 measured 16.795 samples/s.
  • Fused AdamW was compatible and gave a small throughput improvement, so the config uses foreach=false and fused=true.
  • FSDP2 prefetch 3/2 was the best low-risk prefetch setting. Deeper 4/3 prefetch regressed and used more memory, so the config keeps 3/2.
  • Larger local batch was the main throughput win. Batch 4/32 without compile completed at 34.70 samples/s; batch 6/48 OOMed and batch 5/40 was slower, so the full fine-tune config uses 4/32.
  • torch.compile remains disabled by default because it helped small-batch runs but failed for batch 4/32.
  • Default attention remains flash: FlashAttention 3 was unavailable in the measured environment, while flex and flash_varlen regressed.
  • QKV fusion remains opt-in because both built-in and compact QKV fusion were stable but did not beat the best default path.
  • Transformer Engine FP8 safe-subset with delayed scaling was the best validated full fine-tune path at about 39.6 samples/s. It beat current scaling, tied deeper prefetch, completed a 500-step larger-data validation run, and passed generation validation from the final checkpoint.

LoRA fine-tuning

  • Baseline FSDP2 LoRA with local/global batch 1/8 measured 21.83 samples/s and was dominated by FSDP all-gather.
  • Replicated DDP removed the dominant all-gather path and improved throughput to 36.39 samples/s in the initial DDP check.
  • DDP batch 6/48 was the best stable batch-size point; batch 7/56 OOMed.
  • Fused AdamW on DDP batch 6/48 remained valid and produced a small/noisy improvement, so it is enabled for LoRA too.
  • Transformer Engine FP8 is disabled for LoRA because DDP batch 6 OOMed and DDP batch 4 was slower than the non-FP8 DDP config.
  • The selected LoRA config was confirmed with checkpointing at 53.57 samples/s; a comparable FSDP checkpoint run measured 22.07 samples/s. Image verification from the baseline and optimized adapters produced coherent, comparable outputs.

Validation

  • uv run ruff format nemo_automodel/_diffusers/auto_diffusion_pipeline.py nemo_automodel/components/distributed/parallelizer.py nemo_automodel/recipes/diffusion/train.py examples/diffusion/generate/generate.py
  • uv run ruff check --fix nemo_automodel/_diffusers/auto_diffusion_pipeline.py nemo_automodel/components/distributed/parallelizer.py nemo_automodel/recipes/diffusion/train.py examples/diffusion/generate/generate.py
  • uv run python -m py_compile nemo_automodel/_diffusers/auto_diffusion_pipeline.py nemo_automodel/components/distributed/parallelizer.py nemo_automodel/recipes/diffusion/train.py examples/diffusion/generate/generate.py
  • Follow-up torchao removal: targeted ruff format, targeted ruff check --fix, py_compile, and git diff --cached --check for the changed diffusion files.
  • Follow-up profiler removal: targeted ruff format, targeted ruff check --fix, py_compile, and git diff --cached --check for the changed diffusion recipe/config files.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests? Performance validation is covered by the 8x H100 experiment sweep summarized above.
  • Did you add or update any necessary documentation? PR includes the experiment-backed config rationale.

Additional Information

@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 15, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Signed-off-by: Pranav Prashant Thombre <pthombre@nvidia.com>
@pthombre pthombre force-pushed the pthombre/am-358-improve-flux1-dev-diffusion-training-performance branch from 98bc816 to 0e61f24 Compare May 15, 2026 21:40
@pthombre
Copy link
Copy Markdown
Contributor Author

/claude review

@pthombre
Copy link
Copy Markdown
Contributor Author

/ok to test 0e61f24

Comment thread nemo_automodel/_diffusers/auto_diffusion_pipeline.py Outdated
Comment thread nemo_automodel/recipes/diffusion/train.py
Signed-off-by: Pranav Prashant Thombre <pthombre@nvidia.com>
pthombre added 2 commits May 15, 2026 15:20
Signed-off-by: Pranav Prashant Thombre <pthombre@nvidia.com>
Signed-off-by: Pranav Prashant Thombre <pthombre@nvidia.com>
@pthombre
Copy link
Copy Markdown
Contributor Author

/claude review

@pthombre
Copy link
Copy Markdown
Contributor Author

/ok to test ac3388d

Copy link
Copy Markdown
Contributor

@claude claude Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant