From 8c1f5404c67954a2a1f67bac82401877f972aa37 Mon Sep 17 00:00:00 2001 From: Ao Tang Date: Fri, 3 Apr 2026 07:09:40 -0700 Subject: [PATCH 1/3] ci: add FLUX/diffusion support to scripts/performance/run_recipe.py Add DIFFUSION_FAMILIES set and _get_diffusion_step() helper so the performance CI path can train diffusion models (FLUX, WAN) without overriding their dataset config or using the wrong GPT step function. Also add megatron.bridge.recipes.flux subpackage so get_library_recipe() can locate FLUX recipes via the standard family-name import path. Co-Authored-By: Claude Sonnet 4.6 --- scripts/performance/run_recipe.py | 103 ++++++++++++------- src/megatron/bridge/recipes/flux/__init__.py | 15 +++ 2 files changed, 80 insertions(+), 38 deletions(-) create mode 100644 src/megatron/bridge/recipes/flux/__init__.py diff --git a/scripts/performance/run_recipe.py b/scripts/performance/run_recipe.py index 8f586df86a..973756e974 100644 --- a/scripts/performance/run_recipe.py +++ b/scripts/performance/run_recipe.py @@ -32,9 +32,27 @@ from megatron.bridge.utils.common_utils import get_rank_safe +# Diffusion model families manage their own dataset configs and require +# a dedicated forward step function rather than the standard GPT step. +DIFFUSION_FAMILIES = frozenset({"flux", "wan"}) + + +def _get_diffusion_step(model_family_name: str): + """Return the appropriate forward step instance for a diffusion model family.""" + if model_family_name == "flux": + from megatron.bridge.diffusion.models.flux.flux_step import FluxForwardStep + + return FluxForwardStep() + elif model_family_name == "wan": + from megatron.bridge.diffusion.models.wan.wan_step import WanForwardStep + + return WanForwardStep() + raise ValueError(f"Unknown diffusion model family: {model_family_name!r}") + def set_user_overrides(config, args): """Apply CLI arguments to ConfigContainer fields.""" + is_diffusion = args.model_family_name in DIFFUSION_FAMILIES # Training configuration if args.max_steps: @@ -67,42 +85,46 @@ def set_user_overrides(config, args): config.checkpoint.most_recent_k = args.most_recent_k # Dataset configuration - logging.info(f"Configuring dataset: type={args.data}") - - cp_size = getattr(config.model, "context_parallel_size", 1) or 1 - pad_seq_to_mult = cp_size * 2 if cp_size > 1 else 1 - - # Create dataset configuration based on type - if args.data == "mock": - config.dataset = create_mock_dataset_config(seq_length=args.seq_length or 8192) - elif args.data == "rp2": - if not args.dataset_paths or not args.index_mapping_dir: - raise ValueError("--dataset-paths and --index-mapping-dir are required for rp2 dataset") - config.dataset = create_rp2_dataset_config( - dataset_paths=args.dataset_paths, - seq_length=args.seq_length or 8192, - index_mapping_dir=args.index_mapping_dir, - ) - elif args.data == "squad": - if not args.dataset_root: - raise ValueError("--dataset-root is required for squad dataset") - config.dataset = create_squad_dataset_config( - dataset_root=args.dataset_root, - seq_length=args.seq_length or 8192, - packed=False, - pad_seq_to_mult=pad_seq_to_mult, - ) - elif args.data == "squad_packed": - if not args.dataset_root: - raise ValueError("--dataset-root is required for squad_packed dataset") - config.dataset = create_squad_dataset_config( - dataset_root=args.dataset_root, - seq_length=args.seq_length or 8192, - packed=True, - pad_seq_to_mult=pad_seq_to_mult, - ) - else: - raise ValueError(f"Unknown dataset type: {args.data}") + # Diffusion models (FLUX, WAN) configure their own dataset inside the recipe + # (data_paths=None → mock/synthetic data by default). Replacing config.dataset + # with a GPT-style mock config would break them, so skip this block entirely. + if not is_diffusion: + logging.info(f"Configuring dataset: type={args.data}") + + cp_size = getattr(config.model, "context_parallel_size", 1) or 1 + pad_seq_to_mult = cp_size * 2 if cp_size > 1 else 1 + + # Create dataset configuration based on type + if args.data == "mock": + config.dataset = create_mock_dataset_config(seq_length=args.seq_length or 8192) + elif args.data == "rp2": + if not args.dataset_paths or not args.index_mapping_dir: + raise ValueError("--dataset-paths and --index-mapping-dir are required for rp2 dataset") + config.dataset = create_rp2_dataset_config( + dataset_paths=args.dataset_paths, + seq_length=args.seq_length or 8192, + index_mapping_dir=args.index_mapping_dir, + ) + elif args.data == "squad": + if not args.dataset_root: + raise ValueError("--dataset-root is required for squad dataset") + config.dataset = create_squad_dataset_config( + dataset_root=args.dataset_root, + seq_length=args.seq_length or 8192, + packed=False, + pad_seq_to_mult=pad_seq_to_mult, + ) + elif args.data == "squad_packed": + if not args.dataset_root: + raise ValueError("--dataset-root is required for squad_packed dataset") + config.dataset = create_squad_dataset_config( + dataset_root=args.dataset_root, + seq_length=args.seq_length or 8192, + packed=True, + pad_seq_to_mult=pad_seq_to_mult, + ) + else: + raise ValueError(f"Unknown dataset type: {args.data}") # Tokenizer configuration from megatron.bridge.training.config import TokenizerConfig @@ -122,7 +144,8 @@ def set_user_overrides(config, args): ) # Model configuration - if args.seq_length: + # Diffusion models use fixed image/latent dimensions; seq_length is not applicable. + if args.seq_length and not is_diffusion: config.model.seq_length = args.seq_length if args.tensor_model_parallel_size: config.model.tensor_model_parallel_size = args.tensor_model_parallel_size @@ -207,9 +230,13 @@ def main(): if args.task == "pretrain": logging.info("Starting pretraining") - from megatron.bridge.training.gpt_step import forward_step from megatron.bridge.training.pretrain import pretrain + if args.model_family_name in DIFFUSION_FAMILIES: + forward_step = _get_diffusion_step(args.model_family_name) + else: + from megatron.bridge.training.gpt_step import forward_step + pretrain(config=recipe, forward_step_func=forward_step) elif args.task in ["sft", "lora"]: logging.info("Starting finetuning") diff --git a/src/megatron/bridge/recipes/flux/__init__.py b/src/megatron/bridge/recipes/flux/__init__.py new file mode 100644 index 0000000000..06f8372f8e --- /dev/null +++ b/src/megatron/bridge/recipes/flux/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from megatron.bridge.diffusion.recipes.flux import * # noqa: F401, F403 From 2174f86b964f139825a52ff0602c3e5d3853890c Mon Sep 17 00:00:00 2001 From: Ao Tang Date: Sun, 12 Apr 2026 08:40:05 -0700 Subject: [PATCH 2/3] perf: move tokenizer config inside non-diffusion guard Diffusion models (FLUX, WAN) do not use tokenizers, so the tokenizer configuration block should be skipped for them alongside the dataset block. Co-Authored-By: Claude Sonnet 4.6 --- scripts/performance/run_recipe.py | 32 +++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/scripts/performance/run_recipe.py b/scripts/performance/run_recipe.py index 973756e974..0030b9c347 100644 --- a/scripts/performance/run_recipe.py +++ b/scripts/performance/run_recipe.py @@ -126,22 +126,22 @@ def set_user_overrides(config, args): else: raise ValueError(f"Unknown dataset type: {args.data}") - # Tokenizer configuration - from megatron.bridge.training.config import TokenizerConfig - - if args.tokenizer_type == "NullTokenizer": - config.tokenizer = TokenizerConfig(tokenizer_type="NullTokenizer", vocab_size=args.vocab_size) - elif args.tokenizer_type == "HuggingFaceTokenizer": - if not args.tokenizer_model: - raise ValueError("--tokenizer-model is required when using HuggingFaceTokenizer") - tokenizer_model = args.tokenizer_model - config.tokenizer = TokenizerConfig(tokenizer_type="HuggingFaceTokenizer", tokenizer_model=tokenizer_model) - elif args.tokenizer_type == "SentencePieceTokenizer": - if not args.tokenizer_model: - raise ValueError("--tokenizer-model is required for SentencePieceTokenizer") - config.tokenizer = TokenizerConfig( - tokenizer_type="SentencePieceTokenizer", tokenizer_model=args.tokenizer_model - ) + # Tokenizer configuration + from megatron.bridge.training.config import TokenizerConfig + + if args.tokenizer_type == "NullTokenizer": + config.tokenizer = TokenizerConfig(tokenizer_type="NullTokenizer", vocab_size=args.vocab_size) + elif args.tokenizer_type == "HuggingFaceTokenizer": + if not args.tokenizer_model: + raise ValueError("--tokenizer-model is required when using HuggingFaceTokenizer") + tokenizer_model = args.tokenizer_model + config.tokenizer = TokenizerConfig(tokenizer_type="HuggingFaceTokenizer", tokenizer_model=tokenizer_model) + elif args.tokenizer_type == "SentencePieceTokenizer": + if not args.tokenizer_model: + raise ValueError("--tokenizer-model is required for SentencePieceTokenizer") + config.tokenizer = TokenizerConfig( + tokenizer_type="SentencePieceTokenizer", tokenizer_model=args.tokenizer_model + ) # Model configuration # Diffusion models use fixed image/latent dimensions; seq_length is not applicable. From 79963809d5689f8b4e78ec67d646cdb8ddaf3214 Mon Sep 17 00:00:00 2001 From: Ao Tang Date: Sun, 12 Apr 2026 08:49:04 -0700 Subject: [PATCH 3/3] style: fix import sort in run_recipe.py Co-Authored-By: Claude Sonnet 4.6 --- scripts/performance/run_recipe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/performance/run_recipe.py b/scripts/performance/run_recipe.py index 0030b9c347..24a2b37b32 100644 --- a/scripts/performance/run_recipe.py +++ b/scripts/performance/run_recipe.py @@ -32,6 +32,7 @@ from megatron.bridge.utils.common_utils import get_rank_safe + # Diffusion model families manage their own dataset configs and require # a dedicated forward step function rather than the standard GPT step. DIFFUSION_FAMILIES = frozenset({"flux", "wan"})