-
Notifications
You must be signed in to change notification settings - Fork 300
ci: Add FLUX/diffusion support to scripts/performance/run_recipe.py #3176
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -33,8 +33,27 @@ | |
| from megatron.bridge.utils.common_utils import get_rank_safe | ||
|
|
||
|
|
||
| # Diffusion model families manage their own dataset configs and require | ||
| # a dedicated forward step function rather than the standard GPT step. | ||
| DIFFUSION_FAMILIES = frozenset({"flux", "wan"}) | ||
|
|
||
|
|
||
| def _get_diffusion_step(model_family_name: str): | ||
| """Return the appropriate forward step instance for a diffusion model family.""" | ||
| if model_family_name == "flux": | ||
| from megatron.bridge.diffusion.models.flux.flux_step import FluxForwardStep | ||
|
|
||
| return FluxForwardStep() | ||
| elif model_family_name == "wan": | ||
| from megatron.bridge.diffusion.models.wan.wan_step import WanForwardStep | ||
|
|
||
| return WanForwardStep() | ||
| raise ValueError(f"Unknown diffusion model family: {model_family_name!r}") | ||
|
|
||
|
|
||
| def set_user_overrides(config, args): | ||
| """Apply CLI arguments to ConfigContainer fields.""" | ||
| is_diffusion = args.model_family_name in DIFFUSION_FAMILIES | ||
|
|
||
| # Training configuration | ||
| if args.max_steps: | ||
|
|
@@ -67,62 +86,67 @@ def set_user_overrides(config, args): | |
| config.checkpoint.most_recent_k = args.most_recent_k | ||
|
|
||
| # Dataset configuration | ||
| logging.info(f"Configuring dataset: type={args.data}") | ||
|
|
||
| cp_size = getattr(config.model, "context_parallel_size", 1) or 1 | ||
| pad_seq_to_mult = cp_size * 2 if cp_size > 1 else 1 | ||
|
|
||
| # Create dataset configuration based on type | ||
| if args.data == "mock": | ||
| config.dataset = create_mock_dataset_config(seq_length=args.seq_length or 8192) | ||
| elif args.data == "rp2": | ||
| if not args.dataset_paths or not args.index_mapping_dir: | ||
| raise ValueError("--dataset-paths and --index-mapping-dir are required for rp2 dataset") | ||
| config.dataset = create_rp2_dataset_config( | ||
| dataset_paths=args.dataset_paths, | ||
| seq_length=args.seq_length or 8192, | ||
| index_mapping_dir=args.index_mapping_dir, | ||
| ) | ||
| elif args.data == "squad": | ||
| if not args.dataset_root: | ||
| raise ValueError("--dataset-root is required for squad dataset") | ||
| config.dataset = create_squad_dataset_config( | ||
| dataset_root=args.dataset_root, | ||
| seq_length=args.seq_length or 8192, | ||
| packed=False, | ||
| pad_seq_to_mult=pad_seq_to_mult, | ||
| ) | ||
| elif args.data == "squad_packed": | ||
| if not args.dataset_root: | ||
| raise ValueError("--dataset-root is required for squad_packed dataset") | ||
| config.dataset = create_squad_dataset_config( | ||
| dataset_root=args.dataset_root, | ||
| seq_length=args.seq_length or 8192, | ||
| packed=True, | ||
| pad_seq_to_mult=pad_seq_to_mult, | ||
| ) | ||
| else: | ||
| raise ValueError(f"Unknown dataset type: {args.data}") | ||
|
|
||
| # Tokenizer configuration | ||
| from megatron.bridge.training.config import TokenizerConfig | ||
|
|
||
| if args.tokenizer_type == "NullTokenizer": | ||
| config.tokenizer = TokenizerConfig(tokenizer_type="NullTokenizer", vocab_size=args.vocab_size) | ||
| elif args.tokenizer_type == "HuggingFaceTokenizer": | ||
| if not args.tokenizer_model: | ||
| raise ValueError("--tokenizer-model is required when using HuggingFaceTokenizer") | ||
| tokenizer_model = args.tokenizer_model | ||
| config.tokenizer = TokenizerConfig(tokenizer_type="HuggingFaceTokenizer", tokenizer_model=tokenizer_model) | ||
| elif args.tokenizer_type == "SentencePieceTokenizer": | ||
| if not args.tokenizer_model: | ||
| raise ValueError("--tokenizer-model is required for SentencePieceTokenizer") | ||
| config.tokenizer = TokenizerConfig( | ||
| tokenizer_type="SentencePieceTokenizer", tokenizer_model=args.tokenizer_model | ||
| ) | ||
| # Diffusion models (FLUX, WAN) configure their own dataset inside the recipe | ||
| # (data_paths=None → mock/synthetic data by default). Replacing config.dataset | ||
| # with a GPT-style mock config would break them, so skip this block entirely. | ||
| if not is_diffusion: | ||
| logging.info(f"Configuring dataset: type={args.data}") | ||
|
|
||
| cp_size = getattr(config.model, "context_parallel_size", 1) or 1 | ||
| pad_seq_to_mult = cp_size * 2 if cp_size > 1 else 1 | ||
|
|
||
| # Create dataset configuration based on type | ||
| if args.data == "mock": | ||
| config.dataset = create_mock_dataset_config(seq_length=args.seq_length or 8192) | ||
| elif args.data == "rp2": | ||
| if not args.dataset_paths or not args.index_mapping_dir: | ||
| raise ValueError("--dataset-paths and --index-mapping-dir are required for rp2 dataset") | ||
| config.dataset = create_rp2_dataset_config( | ||
| dataset_paths=args.dataset_paths, | ||
| seq_length=args.seq_length or 8192, | ||
| index_mapping_dir=args.index_mapping_dir, | ||
| ) | ||
| elif args.data == "squad": | ||
| if not args.dataset_root: | ||
| raise ValueError("--dataset-root is required for squad dataset") | ||
| config.dataset = create_squad_dataset_config( | ||
| dataset_root=args.dataset_root, | ||
| seq_length=args.seq_length or 8192, | ||
| packed=False, | ||
| pad_seq_to_mult=pad_seq_to_mult, | ||
| ) | ||
| elif args.data == "squad_packed": | ||
| if not args.dataset_root: | ||
| raise ValueError("--dataset-root is required for squad_packed dataset") | ||
| config.dataset = create_squad_dataset_config( | ||
| dataset_root=args.dataset_root, | ||
| seq_length=args.seq_length or 8192, | ||
| packed=True, | ||
| pad_seq_to_mult=pad_seq_to_mult, | ||
| ) | ||
| else: | ||
| raise ValueError(f"Unknown dataset type: {args.data}") | ||
|
|
||
| # Tokenizer configuration | ||
| from megatron.bridge.training.config import TokenizerConfig | ||
|
|
||
| if args.tokenizer_type == "NullTokenizer": | ||
| config.tokenizer = TokenizerConfig(tokenizer_type="NullTokenizer", vocab_size=args.vocab_size) | ||
| elif args.tokenizer_type == "HuggingFaceTokenizer": | ||
| if not args.tokenizer_model: | ||
| raise ValueError("--tokenizer-model is required when using HuggingFaceTokenizer") | ||
| tokenizer_model = args.tokenizer_model | ||
| config.tokenizer = TokenizerConfig(tokenizer_type="HuggingFaceTokenizer", tokenizer_model=tokenizer_model) | ||
| elif args.tokenizer_type == "SentencePieceTokenizer": | ||
| if not args.tokenizer_model: | ||
| raise ValueError("--tokenizer-model is required for SentencePieceTokenizer") | ||
| config.tokenizer = TokenizerConfig( | ||
| tokenizer_type="SentencePieceTokenizer", tokenizer_model=args.tokenizer_model | ||
| ) | ||
|
|
||
| # Model configuration | ||
| if args.seq_length: | ||
| # Diffusion models use fixed image/latent dimensions; seq_length is not applicable. | ||
| if args.seq_length and not is_diffusion: | ||
| config.model.seq_length = args.seq_length | ||
| if args.tensor_model_parallel_size: | ||
| config.model.tensor_model_parallel_size = args.tensor_model_parallel_size | ||
|
|
@@ -207,9 +231,13 @@ def main(): | |
|
|
||
| if args.task == "pretrain": | ||
| logging.info("Starting pretraining") | ||
| from megatron.bridge.training.gpt_step import forward_step | ||
| from megatron.bridge.training.pretrain import pretrain | ||
|
|
||
| if args.model_family_name in DIFFUSION_FAMILIES: | ||
| forward_step = _get_diffusion_step(args.model_family_name) | ||
| else: | ||
| from megatron.bridge.training.gpt_step import forward_step | ||
|
|
||
|
Comment on lines
+236
to
+240
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Diffusion routing still misses the This new branch fixes diffusion Based on learnings: when a feature is not supported, raise an explicit error instead of silently ignoring the input, to fail fast with a clear message. 🤖 Prompt for AI Agents
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This comment seems relevant, @suiyoubi.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
| pretrain(config=recipe, forward_step_func=forward_step) | ||
| elif args.task in ["sft", "lora"]: | ||
| logging.info("Starting finetuning") | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from megatron.bridge.diffusion.recipes.flux import * # noqa: F401, F403 |
Uh oh!
There was an error while loading. Please reload this page.