diff --git a/examples/megatron_bridge/README.md b/examples/megatron_bridge/README.md index 9ad13424327..b6ad1bd8932 100644 --- a/examples/megatron_bridge/README.md +++ b/examples/megatron_bridge/README.md @@ -102,7 +102,7 @@ torchrun --nproc_per_node 2 prune_minitron.py \ --hf_model_name_or_path Qwen/Qwen3-8B \ --prune_target_memory_mb 12288 \ --seq_length 4096 \ - --calib_mbs 1 \ + --calib_batch_size 1 \ --output_hf_path /tmp/Qwen3-8B-Pruned-12GB ``` diff --git a/examples/megatron_bridge/prune_minitron.py b/examples/megatron_bridge/prune_minitron.py index 275a4f93c33..c963c12246c 100644 --- a/examples/megatron_bridge/prune_minitron.py +++ b/examples/megatron_bridge/prune_minitron.py @@ -53,10 +53,8 @@ import modelopt.torch.prune as mtp import modelopt.torch.utils.distributed as dist from modelopt.torch.utils import get_supported_datasets, print_rank_0, warn_rank_0 -from modelopt.torch.utils.plugins.mbridge import ( - get_hf_mbridge_calibration_loop, - load_mbridge_model_from_hf, -) +from modelopt.torch.utils.plugins.mbridge import load_mbridge_model_from_hf +from modelopt.torch.utils.plugins.megatron_calibration import get_megatron_calibration_forward_loop from modelopt.torch.utils.plugins.megatron_mmlu import megatron_mmlu @@ -104,11 +102,7 @@ def get_args() -> argparse.Namespace: "--calib_num_samples", type=int, default=1024, help="Number of samples for calibration" ) # TODO: Add support for pre-training dataset (pre-tokenized) - # TODO: only allow mbs>1 for pretraining dataset - parser.add_argument( - "--calib_mbs", type=int, default=1, choices=[1], help="Calibration micro-batch size" - ) - parser.add_argument("--calib_gbs", type=int, default=1, help="Calibration global batch size") + parser.add_argument("--calib_batch_size", type=int, default=1, help="Calibration batch size") parser.add_argument("--seq_length", type=int, default=4096) # Pruning parameters parser.add_argument( @@ -164,8 +158,8 @@ def get_args() -> argparse.Namespace: default=None, help=( "Batch size used only for KV-cache sizing in --prune_target_memory_mb. " - "Defaults to --calib_mbs when not set. " - "Use this to target an inference batch size that differs from the calibration micro-batch size." + "Defaults to --calib_batch_size when not set. " + "Use this to target an inference batch size that differs from the calibration batch size." ), ) @@ -296,16 +290,12 @@ def main(args: argparse.Namespace): init_model_parallel=True, moe_grouped_gemm=False, ) - forward_loop = get_hf_mbridge_calibration_loop( - model=model, - provider=provider, - tokenizer=tokenizer, - hf_model_name_or_path=args.hf_model_name_or_path, - trust_remote_code=args.trust_remote_code, + forward_loop = get_megatron_calibration_forward_loop( + tokenizer, dataset_name=args.calib_dataset_name, num_samples=args.calib_num_samples, - micro_batch_size=args.calib_mbs, - global_batch_size=args.calib_gbs, + seq_length=args.seq_length, + batch_size=args.calib_batch_size, ) pruning_config = { @@ -385,7 +375,9 @@ def score_func(m): pruning_config["top_k"] = args.top_k # memory_mb constraint requires batch_size and seq_length pruning_config["batch_size"] = ( - args.inference_batch_size if args.inference_batch_size is not None else args.calib_mbs + args.inference_batch_size + if args.inference_batch_size is not None + else args.calib_batch_size ) pruning_config["seq_length"] = args.seq_length print_rank_0(f"Pruning constraints: {pruning_constraints}") diff --git a/examples/pruning/README.md b/examples/pruning/README.md index 4616b75fc0b..081f5051c32 100644 --- a/examples/pruning/README.md +++ b/examples/pruning/README.md @@ -50,9 +50,9 @@ Please see example snippets of both modes for Minitron pruning on Megatron-Bridg ```python import torch import modelopt.torch.prune as mtp -from modelopt.torch.utils.plugins.mbridge import ( - get_hf_mbridge_calibration_loop, - load_mbridge_model_from_hf, +from modelopt.torch.utils.plugins.mbridge import load_mbridge_model_from_hf +from modelopt.torch.utils.plugins.megatron_calibration import ( + get_megatron_calibration_forward_loop, ) # Import the Megatron-Bridge Qwen3-8B model from Hugging Face checkpoint @@ -67,13 +67,11 @@ bridge, provider, model, unwrapped_model, tokenizer = load_mbridge_model_from_hf ) # Set up the forward loop to run on 1024 train samples -forward_loop = get_hf_mbridge_calibration_loop( - model=model, - provider=provider, - tokenizer=tokenizer, - hf_model_name_or_path="Qwen/Qwen3-8B", +forward_loop = get_megatron_calibration_forward_loop( + tokenizer, dataset_name="nemotron-post-training-dataset-v2", num_samples=1024, + seq_length=4096, ) # Run pruning on the unwrapped model diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index 80ed8f9abdd..3202b127237 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -563,7 +563,7 @@ def get_dataset_dataloader( batch_size: int = 1, num_samples: int | list[int] = 512, max_sample_length: int = 512, - device: torch.device | None = None, + device: torch.device | str | None = None, include_labels: bool = False, apply_chat_template: bool = False, ) -> DataLoader: diff --git a/modelopt/torch/utils/plugins/__init__.py b/modelopt/torch/utils/plugins/__init__.py index fd00e423f05..f2f02852906 100644 --- a/modelopt/torch/utils/plugins/__init__.py +++ b/modelopt/torch/utils/plugins/__init__.py @@ -17,6 +17,9 @@ from modelopt.torch.utils import import_plugin +with import_plugin("megatron_calibration"): + from .megatron_calibration import * + with import_plugin("megatron_generate"): from .megatron_generate import * diff --git a/modelopt/torch/utils/plugins/mbridge.py b/modelopt/torch/utils/plugins/mbridge.py index 06c3466b4ef..cc3ac29f93d 100644 --- a/modelopt/torch/utils/plugins/mbridge.py +++ b/modelopt/torch/utils/plugins/mbridge.py @@ -14,43 +14,23 @@ # limitations under the License. """Megatron-Bridge plugins for using with Model-Optimizer.""" -from collections.abc import Callable from typing import Any -import torch.nn as nn -from datasets import DatasetDict from megatron.bridge import AutoBridge -from megatron.bridge.data.builders.hf_dataset import HFDatasetConfig -from megatron.bridge.data.loaders import setup_data_iterators -from megatron.bridge.data.utils import get_dataset_provider from megatron.bridge.models.gpt_provider import GPTModelProvider from megatron.bridge.models.hf_pretrained.utils import is_safe_repo from megatron.bridge.models.mamba.mamba_provider import MambaModelProvider -from megatron.bridge.training.config import ( - CheckpointConfig, - ConfigContainer, - LoggerConfig, - OptimizerConfig, - SchedulerConfig, - TrainingConfig, - runtime_config_update, -) -from megatron.bridge.training.eval import evaluate_and_print_results -from megatron.bridge.training.gpt_step import forward_step -from megatron.bridge.training.state import GlobalState -from megatron.bridge.training.tokenizers.config import TokenizerConfig from megatron.core.models.gpt import GPTModel from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec from megatron.core.models.mamba import MambaModel -from megatron.core.parallel_state import get_data_parallel_group from megatron.core.transformer.module import MegatronModule from megatron.core.utils import unwrap_model from transformers import AutoTokenizer from modelopt.torch.nas.plugins.megatron import get_te_mamba_stack_spec -from modelopt.torch.utils import get_dataset_samples, print_rank_0, warn_rank_0 +from modelopt.torch.utils import print_rank_0 -__all__ = ["get_hf_mbridge_calibration_loop", "load_mbridge_model_from_hf"] +__all__ = ["load_mbridge_model_from_hf"] def load_mbridge_model_from_hf( @@ -118,134 +98,3 @@ def load_mbridge_model_from_hf( ) return bridge, provider, model, unwrapped_model, tokenizer - - -def _get_dataset_cfg( - dataset_name: str, - num_samples: int, - seq_length: int, - apply_chat_template: bool = True, - tokenizer: AutoTokenizer | None = None, -) -> HFDatasetConfig: - """Get a dataset config for the dataset.""" - dataset = get_dataset_samples( - dataset_name, num_samples, apply_chat_template=apply_chat_template, tokenizer=tokenizer - ) - dataset_cfg = HFDatasetConfig( - dataset_name=f"{dataset_name}_{num_samples}", - dataset_dict=DatasetDict({"train": dataset}), - process_example_fn=lambda example, tokenizer: {"input": example, "output": ""}, - seq_length=seq_length, - dataloader_type="batch", - num_workers=1, - do_validation=False, - do_test=False, - val_proportion=None, - split_val_from_train=False, - rewrite=True, - ) - - return dataset_cfg - - -def get_hf_mbridge_calibration_loop( - *, - model: list[MegatronModule], - provider: GPTModelProvider | MambaModelProvider, - tokenizer: AutoTokenizer, - hf_model_name_or_path: str, - trust_remote_code: bool = False, - dataset_name: str = "nemotron-post-training-dataset-v2", - num_samples: int = 512, - micro_batch_size: int = 1, - global_batch_size: int = 1, -) -> Callable[[nn.Module], None]: - """Get a modelopt calibration loop for a Megatron-Bridge model. - - Args: - model: The model to calibrate. - provider: The provider to use for the model. - tokenizer: The tokenizer to use for the model. - hf_model_name_or_path: The name or path of the HF model. - trust_remote_code: Whether to trust remote code. - dataset_name: The name of the dataset to use for evaluation. - num_samples: The number of samples to use for evaluation. - micro_batch_size: The micro batch size to use for evaluation. - global_batch_size: The global batch size to use for evaluation. - - Returns: - A function that can be used to calibrate the model with a modelopt.torch API. - """ - if global_batch_size < micro_batch_size: - warn_rank_0( - f"{global_batch_size=} is smaller than {micro_batch_size=}. Setting gbs to {micro_batch_size}." - ) - global_batch_size = micro_batch_size - num_iters = num_samples // global_batch_size - - cfg = ConfigContainer( - model=provider, - train=TrainingConfig( - micro_batch_size=micro_batch_size, - global_batch_size=global_batch_size, - train_iters=num_iters, - eval_iters=num_iters, - skip_train=True, - ), - # TODO: Replace validation args in train with validation config in nemo:26.04 - # validation=ValidationConfig(eval_iters=num_iters, eval_interval=1, skip_train=True), - dataset=_get_dataset_cfg( - dataset_name, - num_samples, - provider.seq_length, - apply_chat_template=True, - tokenizer=tokenizer, - ), - tokenizer=TokenizerConfig( - tokenizer_type="HuggingFaceTokenizer", - tokenizer_model=hf_model_name_or_path, - # NOTE: Issue with Nemotron Nano v2 tokenizer returning bool hence using use_fast=True as a WAR - hf_tokenizer_kwargs={ - "trust_remote_code": trust_remote_code, - "use_fast": tokenizer.is_fast, - }, - ), - # Unused - optimizer=OptimizerConfig(optimizer="adam", lr=1e-4, use_distributed_optimizer=False), - scheduler=SchedulerConfig(lr_decay_style="constant"), - logger=LoggerConfig(), - checkpoint=CheckpointConfig(), - ) - runtime_config_update(cfg) - - state = GlobalState() - state.cfg = cfg - - dataset_provider = get_dataset_provider(cfg.dataset) - - def _train_valid_test_datasets_provider( - train_val_test_num_samples: tuple, dataset_cfg: HFDatasetConfig - ): - return dataset_provider(train_val_test_num_samples, dataset_cfg, tokenizer=state.tokenizer) - - train_data_iterator, _, _ = setup_data_iterators( - cfg=cfg, - train_state=state.train_state, - model_length=len(model), - train_valid_test_datasets_provider=_train_valid_test_datasets_provider, - dp_group=get_data_parallel_group(), - ) - - def forward_loop(m): - evaluate_and_print_results( - state, - prefix="iteration 1", - forward_step_func=forward_step, - data_iterator=train_data_iterator, - model=model, - config=cfg, - verbose=True, - write_to_tensorboard=False, - ) - - return forward_loop diff --git a/modelopt/torch/utils/plugins/megatron_calibration.py b/modelopt/torch/utils/plugins/megatron_calibration.py new file mode 100644 index 00000000000..ae8bbbfd458 --- /dev/null +++ b/modelopt/torch/utils/plugins/megatron_calibration.py @@ -0,0 +1,177 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared calibration forward-loop builder for Megatron-Core models. + +Drives a prefill pass through the model over a calibration dataset, producing the +``forward_loop`` callable that ``mtq.quantize`` / ``mtp.prune`` / ``mtq.calibrate`` +expect. + +Picks the best primitives from each existing path: +- ``get_dataset_dataloader`` for dataset surface (HF registry + JSONL auto-detection, + multi-source blending, one-sample-per-row with batch-padding) +- Per-row trim + EOS-at-row-end before forward, matching MBridge's + ``GPTSFTDataset(add_eos=True)`` semantics. +- ``megatron_prefill(skip_return_logits=True)`` for the forward primitive — skips + returning logits / loss compute compared to the legacy training-step path; the LM + head still runs and activation hooks still fire on every layer. + +Context parallelism: this loop targets CP=1. Splitting a calibration sequence across +CP ranks doesn't help (calibration sequences are short and we want the same activations +on every rank), and ``megatron_prefill`` builds its causal mask / position_ids over the +local tensor length, which would silently produce wrong activations under CP>1. +""" + +import copy +from collections.abc import Callable +from typing import TYPE_CHECKING + +import torch +from megatron.core import parallel_state as mpu +from tqdm import tqdm + +from modelopt.torch.utils import distributed as dist +from modelopt.torch.utils.dataset_utils import get_dataset_dataloader + +from .megatron_generate import megatron_prefill + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizerBase + +__all__ = ["get_megatron_calibration_forward_loop"] + + +def get_megatron_calibration_forward_loop( + tokenizer: "PreTrainedTokenizerBase", + *, + dataset_name: str | list[str] = "cnn_dailymail", + num_samples: int | list[int] = 512, + seq_length: int = 512, + batch_size: int = 1, + apply_chat_template: bool = True, + device: torch.device | str | None = "cuda", +) -> Callable[[torch.nn.Module], None]: + """Build a Megatron-Core calibration ``forward_loop(model)``. + + The returned callable iterates a one-sample-per-row dataloader, gathers the real + tokens of each row via boolean indexing on the dataloader's ``attention_mask`` (so + left- and right-padded tokenizers both work), forces EOS at the trimmed row's last + position, and drives a logits-free prefill pass through the model so activation + hooks fire. Padding positions are kept out of the forward entirely — they would + otherwise be hooked into calibration statistics regardless of attention masking. + + Behavior mirrors M-Bridge's ``GPTSFTDataset(add_eos=True)`` for the + truncated-row case: each calibration row is one document, capped at ``seq_length``, + with the last position overwritten by EOS as an explicit end-of-document marker. + Under-cap rows lose their natural last content token in exchange for the marker — + a deliberate trade-off so hooks see a consistent EOS signal at row end. + + Args: + tokenizer: HuggingFace tokenizer. + dataset_name: Dataset key (see :func:`get_supported_datasets`), a ``.jsonl`` + path, or a list mixing the two. + num_samples: Number of raw samples to draw. + seq_length: Truncation / padding target per row. + batch_size: Calibration micro-batch size. With variable-length samples and a + mix of short and long, the forward loop iterates per-row when the batch + contains any padding — so true batched throughput requires uniform-length + samples (or all-long samples where every row is truncated to ``seq_length``). + apply_chat_template: Forwarded to :func:`get_dataset_dataloader`. + device: Forwarded to :func:`get_dataset_dataloader`. + + Returns: + A ``forward_loop(model)`` callable to pass into ``mtq.quantize``, + ``mtp.prune``, or other such APIs. + """ + # Deepcopy before mutating pad_token so the caller's tokenizer isn't silently changed. + if getattr(tokenizer, "pad_token", None) is None: + tokenizer = copy.deepcopy(tokenizer) + tokenizer.pad_token = tokenizer.eos_token + + dataloader = get_dataset_dataloader( + dataset_name=dataset_name, + tokenizer=tokenizer, + batch_size=batch_size, + num_samples=num_samples, + max_sample_length=seq_length, + device=device, + apply_chat_template=apply_chat_template, + ) + + eos_id = getattr(tokenizer, "eos_token_id", None) + + # Sort samples by real length descending so front batches are mostly full-length + # (no padding → batched forward). Back batches end up all-short and fall to the + # per-row path. Calibration statistics are order-invariant aggregates (amax / + # channel importance), so this re-ordering doesn't affect quality, just throughput. + all_ids: list[torch.Tensor] = [] + all_masks: list[torch.Tensor] = [] + for sample in dataloader: + all_ids.append(sample["input_ids"]) + all_masks.append(sample.get("attention_mask", torch.ones_like(sample["input_ids"]))) + cat_ids = torch.cat(all_ids, dim=0) + cat_masks = torch.cat(all_masks, dim=0) + # Pre-compute per-row real lengths once on CPU; sort + per-batch padding check both + # read from this CPU tensor, avoiding a CPU-GPU sync inside the forward hot loop. + lengths_cpu = cat_masks.sum(dim=-1).cpu() + sort_idx = torch.argsort(lengths_cpu, descending=True) + sorted_ids = cat_ids[sort_idx] + sorted_masks = cat_masks[sort_idx] + sorted_lengths = lengths_cpu[sort_idx] + seq_len = sorted_ids.shape[-1] + + def _forward_loop(model: torch.nn.Module) -> None: + # ``megatron_prefill`` builds its causal mask + position_ids over the local input + # tensor length, so splitting a calibration sequence across CP ranks would silently + # produce wrong activations. Calibration sequences are short enough that CP doesn't + # help anyway — fail loud rather than ship broken statistics. + cp_size = mpu.get_context_parallel_world_size() + if cp_size != 1: + raise RuntimeError( + f"get_megatron_calibration_forward_loop requires CP=1, got " + f"context_parallel_world_size={cp_size}. Run calibration without CP." + ) + n = sorted_ids.shape[0] + for start in tqdm(range(0, n, batch_size), disable=not dist.is_master()): + ids = sorted_ids[start : start + batch_size] + mask = sorted_masks[start : start + batch_size] + lens = sorted_lengths[start : start + batch_size] + # If any row in this batch has padding, forward each row at its real length so + # calibration hooks don't fire on padding positions — padding tokens contribute + # their (constant-ish) hidden states to the activation statistics, causing a + # substantial MMLU regression on prune calibration. + if bool((lens < seq_len).any()): + for b in range(ids.shape[0]): + # Boolean-mask gather works for both left- and right-padded sequences: + # we extract exactly the real tokens regardless of which side the + # padding sits on. + row = ids[b][mask[b].bool()].unsqueeze(0).clone() + if row.shape[1] < 1: + continue + if eos_id is not None: + # Overwrites the row's last real token with EOS — matches the + # truncated-row case of M-Bridge's ``GPTSFTDataset(add_eos=True)``. + # For under-cap rows, this loses one content token in exchange for + # an explicit end-of-document marker that hooks see during prune + # importance scoring. + row[0, -1] = eos_id + megatron_prefill(model, row, skip_return_logits=True) + else: + if eos_id is not None: + ids = ids.clone() + ids[:, -1] = eos_id + megatron_prefill(model, ids, skip_return_logits=True) + + return _forward_loop diff --git a/tests/_test_utils/torch/tokenizer/special_tokens_map.json b/tests/_test_utils/torch/tokenizer/special_tokens_map.json index 02ee80b6196..344c8261025 100644 --- a/tests/_test_utils/torch/tokenizer/special_tokens_map.json +++ b/tests/_test_utils/torch/tokenizer/special_tokens_map.json @@ -12,5 +12,12 @@ "normalized": false, "rstrip": false, "single_word": false + }, + "pad_token": { + "content": "<|eot_id|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false } } diff --git a/tests/_test_utils/torch/tokenizer/tokenizer_config.json b/tests/_test_utils/torch/tokenizer/tokenizer_config.json index 66600edeef6..bdd427826a5 100644 --- a/tests/_test_utils/torch/tokenizer/tokenizer_config.json +++ b/tests/_test_utils/torch/tokenizer/tokenizer_config.json @@ -3,6 +3,7 @@ "chat_template": "{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}Q: {{ message['content'] }}{% elif message['role'] == 'assistant' %}A: {{ message['content'] }}{% endif %}{{ eos_token }}{% endfor %}", "clean_up_tokenization_spaces": true, "eos_token": "<|eot_id|>", + "pad_token": "<|eot_id|>", "extra_special_tokens": {}, "model_input_names": [ "input_ids", diff --git a/tests/unit/torch/utils/test_dataset_utils.py b/tests/unit/torch/utils/test_dataset_utils.py index 812d2cd9c3b..03da6cb8d8a 100644 --- a/tests/unit/torch/utils/test_dataset_utils.py +++ b/tests/unit/torch/utils/test_dataset_utils.py @@ -515,30 +515,16 @@ def test_legacy_text_fallback_on_hf_builder_failure(self, tmp_path, monkeypatch) # --------------------------------------------------------------------------- -class _FakeTokenizer: - """Minimal callable tokenizer that mimics the HF tokenizer surface used by the dataloader. +@pytest.fixture +def pad_tokenizer(): + """Real tiny HF tokenizer (vocab=128) shared with other test modules. - Tokenizes by character ordinal and left-pads to the longest sample (capped at max_length). - Avoids a hard dependency on ``transformers`` in the test environment. + Skips the test if ``transformers`` isn't installed. """ + pytest.importorskip("transformers") + from _test_utils.torch.transformers_models import get_tiny_tokenizer - padding_side = "left" - pad_token_id = 0 - - def __call__(self, texts, return_tensors=None, padding=True, truncation=True, max_length=16): - ids = [[ord(c) % 100 + 1 for c in t][:max_length] for t in texts] - n = max(len(x) for x in ids) - input_ids = [[self.pad_token_id] * (n - len(x)) + x for x in ids] - attention = [[0] * (n - len(x)) + [1] * len(x) for x in ids] - return { - "input_ids": torch.tensor(input_ids, dtype=torch.long), - "attention_mask": torch.tensor(attention, dtype=torch.long), - } - - -@pytest.fixture -def pad_tokenizer(): - return _FakeTokenizer() + return get_tiny_tokenizer() class TestGetDatasetDataloaderBlending: diff --git a/tools/launcher/common/megatron_lm/quantize/task.py b/tools/launcher/common/megatron_lm/quantize/task.py index 95833fe3960..7ba99202d1c 100644 --- a/tools/launcher/common/megatron_lm/quantize/task.py +++ b/tools/launcher/common/megatron_lm/quantize/task.py @@ -15,7 +15,21 @@ """Megatron-LM PTQ quantization task with typed configuration. -Example YAML (typed config): +NOTE — currently NOT wired into SandboxPipeline: + Under nemo_run/Fiddle YAML loading, dataclass `__post_init__` runs *before* + nested fields like `config` are populated, so `self.config` is None at that + point and `materialize_from_config()` returns early. The previous fix — + a `materialize_from_config()` hook called explicitly by + `SandboxPipeline.__post_init__` after Fiddle finishes building — was + removed to keep `core.py` minimal. As a result, typed-config YAMLs of the + form `_target_: ...MegatronLMQuantizeTask` no longer materialize. Until + that hook is reinstated, use the raw `script`/`args`/`environment` form in + YAMLs (see `examples/Qwen/Qwen3-8B/megatron_lm_ptq.yaml`). + + This module is intentionally retained as a reference for the eventual + re-enablement of typed task configs. + +Example YAML (typed config — currently disabled, see note above): task_0: _target_: common.megatron_lm.quantize.task.MegatronLMQuantizeTask @@ -29,7 +43,7 @@ _factory_: "slurm_factory" nodes: 1 -Example YAML (raw SandboxTask — still works): +Example YAML (raw SandboxTask — the supported form today): task_0: script: common/megatron_lm/quantize/quantize.sh @@ -39,7 +53,7 @@ environment: - MLM_MODEL_CFG: Qwen/Qwen3-8B - QUANT_CFG: NVFP4_DEFAULT_CFG - - TP: 4 + - TP: "4" """ from dataclasses import dataclass @@ -57,6 +71,7 @@ class MegatronLMQuantizeConfig: tp: Tensor parallelism degree. calib_dataset: Calibration dataset path or HuggingFace repo ID. calib_size: Number of calibration samples. + calib_max_sequence_length: Maximum sequence length for calibration samples. mmlu_dataset: MMLU evaluation dataset path or HuggingFace repo ID. mmlu_fraction: Fraction of MMLU to evaluate (0.0-1.0). mmlu_lower_bound: Minimum MMLU score to pass. @@ -72,6 +87,7 @@ class MegatronLMQuantizeConfig: extra_args: str = "" calib_dataset: str = "abisee/cnn_dailymail" calib_size: int = 32 + calib_max_sequence_length: int = 512 mmlu_dataset: str = "cais/mmlu" mmlu_fraction: float = 0.01 mmlu_lower_bound: float = 0.38 @@ -82,9 +98,10 @@ class MegatronLMQuantizeConfig: class MegatronLMQuantizeTask(SandboxTask): """PTQ quantization task — converts typed config to args/environment. - Set `config` to use typed fields. The task automatically generates - `script`, `args`, and `environment` from the config. You can still - set `slurm_config` directly. + Set `config` to use typed fields. SandboxPipeline calls materialize_from_config() + after Fiddle has populated all fields, which expands the typed config into the + plain `script`, `args`, and `environment` SandboxTask fields. `slurm_config` + is set directly on the task and is not affected. If both `config` and `args`/`environment` are set, `config` takes precedence. """ @@ -92,25 +109,37 @@ class MegatronLMQuantizeTask(SandboxTask): config: MegatronLMQuantizeConfig = None def __post_init__(self): - """Generate script, args, and environment from typed config.""" - if self.config is not None: - c = self.config - self.script = self.script or "common/megatron_lm/quantize/quantize.sh" - args = [ - f"--calib-dataset-path-or-name {c.hf_local}{c.calib_dataset}", - f"--calib-size {c.calib_size}", - ] - if c.extra_args: - args.append(c.extra_args) - self.args = args - self.environment = [ - {"MLM_MODEL_CFG": c.model}, - {"QUANT_CFG": c.quant_cfg}, - {"HF_MODEL_CKPT": f"{c.hf_local}{c.model}"}, - {"MMLU_DATASET": f"{c.hf_local}{c.mmlu_dataset}"}, - {"TP": str(c.tp)}, - {"PP": str(c.pp)}, - {"EP": str(c.ep)}, - {"ETP": str(c.etp)}, - {"MMLU_LOWER_BOUND": str(c.mmlu_lower_bound)}, - ] + # Idempotent: also materializes for direct (non-Fiddle) Python construction, + # where __post_init__ sees `config` already populated as an __init__ kwarg. + # Under nemo_run/Fiddle YAML loading, `config` may still be None here; the + # pipeline calls materialize_from_config() again once the build completes. + self.materialize_from_config() + + def materialize_from_config(self): + """Expand `self.config` into the plain SandboxTask `script`, `args`, `environment` fields. + + Idempotent. Called by SandboxPipeline.__post_init__ once Fiddle has populated `config`. + """ + if self.config is None: + return + c = self.config + self.script = self.script or "common/megatron_lm/quantize/quantize.sh" + args = [ + f"--calib-dataset-path-or-name {c.hf_local}{c.calib_dataset}", + f"--calib-size {c.calib_size}", + f"--calib-max-sequence-length {c.calib_max_sequence_length}", + ] + if c.extra_args: + args.append(c.extra_args) + self.args = args + self.environment = [ + {"MLM_MODEL_CFG": c.model}, + {"QUANT_CFG": c.quant_cfg}, + {"HF_MODEL_CKPT": f"{c.hf_local}{c.model}"}, + {"MMLU_DATASET": f"{c.hf_local}{c.mmlu_dataset}"}, + {"TP": str(c.tp)}, + {"PP": str(c.pp)}, + {"EP": str(c.ep)}, + {"ETP": str(c.etp)}, + {"MMLU_LOWER_BOUND": str(c.mmlu_lower_bound)}, + ] diff --git a/tools/launcher/examples/Qwen/Qwen3-30B-A3B/megatron_lm_ptq.yaml b/tools/launcher/examples/Qwen/Qwen3-30B-A3B/megatron_lm_ptq.yaml index 0eeca6531c9..7f166c5da0c 100644 --- a/tools/launcher/examples/Qwen/Qwen3-30B-A3B/megatron_lm_ptq.yaml +++ b/tools/launcher/examples/Qwen/Qwen3-30B-A3B/megatron_lm_ptq.yaml @@ -5,6 +5,11 @@ # # Usage: # uv run launch.py --yaml examples/Qwen/Qwen3-30B-A3B/megatron_lm_ptq.yaml --yes +# +# NOTE: This file uses the raw `script`/`args`/`environment` form because the +# typed `MegatronLMQuantizeTask` (common/megatron_lm/quantize/task.py) is +# currently not wired into SandboxPipeline — see the docstring there for +# rationale. The typed class is retained for future re-enablement. job_name: Qwen3-30B-A3B_PTQ pipeline: @@ -13,19 +18,20 @@ pipeline: note: task_0: - _target_: common.megatron_lm.quantize.task.MegatronLMQuantizeTask - config: - model: Qwen/Qwen3-30B-A3B - quant_cfg: NVFP4_DEFAULT_CFG - tp: 1 - pp: 1 - ep: 8 - etp: 1 - calib_dataset: abisee/cnn_dailymail - calib_size: 32 - mmlu_dataset: cais/mmlu - mmlu_lower_bound: 0.75 - hf_local: /hf-local/ + script: common/megatron_lm/quantize/quantize.sh + args: + - --calib-dataset-path-or-name /hf-local/abisee/cnn_dailymail + - --calib-size 32 + environment: + - MLM_MODEL_CFG: Qwen/Qwen3-30B-A3B + - QUANT_CFG: NVFP4_DEFAULT_CFG + - HF_MODEL_CKPT: /hf-local/Qwen/Qwen3-30B-A3B + - MMLU_DATASET: /hf-local/cais/mmlu + - TP: "1" + - PP: "1" + - EP: "8" + - ETP: "1" + - MMLU_LOWER_BOUND: "0.75" slurm_config: _factory_: "slurm_factory" nodes: 1 @@ -33,19 +39,20 @@ pipeline: gpus_per_node: 8 task_1: - _target_: common.megatron_lm.quantize.task.MegatronLMQuantizeTask - config: - model: Qwen/Qwen3-30B-A3B - quant_cfg: FP8_DEFAULT_CFG - tp: 1 - pp: 1 - ep: 8 - etp: 1 - calib_dataset: abisee/cnn_dailymail - calib_size: 32 - mmlu_dataset: cais/mmlu - mmlu_lower_bound: 0.75 - hf_local: /hf-local/ + script: common/megatron_lm/quantize/quantize.sh + args: + - --calib-dataset-path-or-name /hf-local/abisee/cnn_dailymail + - --calib-size 32 + environment: + - MLM_MODEL_CFG: Qwen/Qwen3-30B-A3B + - QUANT_CFG: FP8_DEFAULT_CFG + - HF_MODEL_CKPT: /hf-local/Qwen/Qwen3-30B-A3B + - MMLU_DATASET: /hf-local/cais/mmlu + - TP: "1" + - PP: "1" + - EP: "8" + - ETP: "1" + - MMLU_LOWER_BOUND: "0.75" slurm_config: _factory_: "slurm_factory" nodes: 1 diff --git a/tools/launcher/examples/Qwen/Qwen3-8B/megatron_lm_ptq.yaml b/tools/launcher/examples/Qwen/Qwen3-8B/megatron_lm_ptq.yaml index 6ae64fc1ff4..93ed18e1fca 100644 --- a/tools/launcher/examples/Qwen/Qwen3-8B/megatron_lm_ptq.yaml +++ b/tools/launcher/examples/Qwen/Qwen3-8B/megatron_lm_ptq.yaml @@ -5,13 +5,15 @@ # task_1: FP8 quantize → MMLU → export # task_2: TRT-LLM eval MMLU on all exported checkpoints # -# Uses MegatronLMQuantizeTask with typed config — see common/megatron_lm/quantize/task.py -# for all available fields. -# # Usage: # uv run launch.py --yaml examples/Qwen/Qwen3-8B/megatron_lm_ptq.yaml --yes # # For single-GPU local Docker, use megatron_lm_ptq_local.yaml instead. +# +# NOTE: This file uses the raw `script`/`args`/`environment` form because the +# typed `MegatronLMQuantizeTask` (common/megatron_lm/quantize/task.py) is +# currently not wired into SandboxPipeline — see the docstring there for +# rationale. The typed class is retained for future re-enablement. job_name: Qwen3-8B_PTQ pipeline: @@ -20,16 +22,20 @@ pipeline: note: task_0: - _target_: common.megatron_lm.quantize.task.MegatronLMQuantizeTask - config: - model: Qwen/Qwen3-8B - quant_cfg: NVFP4_DEFAULT_CFG - tp: 1 - calib_dataset: abisee/cnn_dailymail - calib_size: 32 - mmlu_dataset: cais/mmlu - mmlu_lower_bound: 0.75 - hf_local: /hf-local/ + script: common/megatron_lm/quantize/quantize.sh + args: + - --calib-dataset-path-or-name /hf-local/abisee/cnn_dailymail + - --calib-size 32 + environment: + - MLM_MODEL_CFG: Qwen/Qwen3-8B + - QUANT_CFG: NVFP4_DEFAULT_CFG + - HF_MODEL_CKPT: /hf-local/Qwen/Qwen3-8B + - MMLU_DATASET: /hf-local/cais/mmlu + - TP: "1" + - PP: "1" + - EP: "1" + - ETP: "1" + - MMLU_LOWER_BOUND: "0.68" slurm_config: _factory_: "slurm_factory" nodes: 1 @@ -37,16 +43,20 @@ pipeline: gpus_per_node: 1 task_1: - _target_: common.megatron_lm.quantize.task.MegatronLMQuantizeTask - config: - model: Qwen/Qwen3-8B - quant_cfg: FP8_DEFAULT_CFG - tp: 1 - calib_dataset: abisee/cnn_dailymail - calib_size: 32 - mmlu_dataset: cais/mmlu - mmlu_lower_bound: 0.68 - hf_local: /hf-local/ + script: common/megatron_lm/quantize/quantize.sh + args: + - --calib-dataset-path-or-name /hf-local/abisee/cnn_dailymail + - --calib-size 32 + environment: + - MLM_MODEL_CFG: Qwen/Qwen3-8B + - QUANT_CFG: FP8_DEFAULT_CFG + - HF_MODEL_CKPT: /hf-local/Qwen/Qwen3-8B + - MMLU_DATASET: /hf-local/cais/mmlu + - TP: "1" + - PP: "1" + - EP: "1" + - ETP: "1" + - MMLU_LOWER_BOUND: "0.75" slurm_config: _factory_: "slurm_factory" nodes: 1 diff --git a/tools/launcher/examples/Qwen/Qwen3-8B/megatron_lm_ptq_local.yaml b/tools/launcher/examples/Qwen/Qwen3-8B/megatron_lm_ptq_local.yaml index 5e852520425..5fd55ee56a8 100644 --- a/tools/launcher/examples/Qwen/Qwen3-8B/megatron_lm_ptq_local.yaml +++ b/tools/launcher/examples/Qwen/Qwen3-8B/megatron_lm_ptq_local.yaml @@ -1,31 +1,12 @@ # Local single-GPU variant of megatron_lm_ptq.yaml. # -# Uses MegatronLMQuantizeTask with typed config (tp=1, 1 GPU). -# See common/megatron_lm/quantize/task.py for all available fields. -# # Usage: # uv run launch.py --yaml examples/Qwen/Qwen3-8B/megatron_lm_ptq_local.yaml hf_local=/mnt/hf-local --yes # -# ----------------------------------------------------------------------------------- -# Equivalent raw SandboxTask (for reference — shows what MegatronLMQuantizeTask generates): -# -# task_0: -# script: common/megatron_lm/quantize/quantize.sh -# args: -# - --calib-dataset-path-or-name /hf-local/abisee/cnn_dailymail -# - --calib-size 32 -# environment: -# - MLM_MODEL_CFG: Qwen/Qwen3-8B -# - QUANT_CFG: NVFP4_DEFAULT_CFG -# - HF_MODEL_CKPT: /hf-local/Qwen/Qwen3-8B -# - MMLU_DATASET: /hf-local/cais/mmlu -# - TP: 1 -# slurm_config: -# _factory_: "slurm_factory" -# nodes: 1 -# ntasks_per_node: 1 -# gpus_per_node: 1 -# ----------------------------------------------------------------------------------- +# NOTE: This file uses the raw `script`/`args`/`environment` form because the +# typed `MegatronLMQuantizeTask` (common/megatron_lm/quantize/task.py) is +# currently not wired into SandboxPipeline — see the docstring there for +# rationale. The typed class is retained for future re-enablement. job_name: Qwen3-8B_NVFP4_local pipeline: @@ -34,15 +15,20 @@ pipeline: note: task_0: - _target_: common.megatron_lm.quantize.task.MegatronLMQuantizeTask - config: - model: Qwen/Qwen3-8B - quant_cfg: NVFP4_DEFAULT_CFG - tp: 1 - calib_dataset: abisee/cnn_dailymail - calib_size: 32 - mmlu_dataset: cais/mmlu - hf_local: /hf-local/ + script: common/megatron_lm/quantize/quantize.sh + args: + - --calib-dataset-path-or-name /hf-local/abisee/cnn_dailymail + - --calib-size 32 + environment: + - MLM_MODEL_CFG: Qwen/Qwen3-8B + - QUANT_CFG: NVFP4_DEFAULT_CFG + - HF_MODEL_CKPT: /hf-local/Qwen/Qwen3-8B + - MMLU_DATASET: /hf-local/cais/mmlu + - TP: "1" + - PP: "1" + - EP: "1" + - ETP: "1" + - MMLU_LOWER_BOUND: "0.38" slurm_config: _factory_: "slurm_factory" nodes: 1 diff --git a/tools/launcher/tests/conftest.py b/tools/launcher/tests/conftest.py index 072518cc795..1886f9bf9cd 100644 --- a/tools/launcher/tests/conftest.py +++ b/tools/launcher/tests/conftest.py @@ -29,16 +29,12 @@ import pytest - -@pytest.fixture(autouse=True) -def add_launcher_to_path(): - """Add the launcher directory to sys.path so core.py and slurm_config.py can be imported.""" - launcher_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - if launcher_dir not in sys.path: - sys.path.insert(0, launcher_dir) - yield - if launcher_dir in sys.path: - sys.path.remove(launcher_dir) +# Make the launcher dir importable so test modules can `import core`, `import slurm_config`, +# etc. at module-load time. conftest.py is imported by pytest before any test module, so +# this mutation is in effect before the first test-module import resolves. +_LAUNCHER_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if _LAUNCHER_DIR not in sys.path: + sys.path.insert(0, _LAUNCHER_DIR) @pytest.fixture diff --git a/tools/launcher/tests/test_core.py b/tools/launcher/tests/test_core.py index 6c7e8f04366..b678af15f5b 100644 --- a/tools/launcher/tests/test_core.py +++ b/tools/launcher/tests/test_core.py @@ -26,14 +26,27 @@ """ import os +from dataclasses import dataclass + +from core import ( + _FACTORY_REGISTRY, + GlobalVariables, + SandboxPipeline, + SandboxTask, + SandboxTask0, + SandboxTask1, + create_task_from_yaml, + get_default_env, + register_factory, + report_versions, + set_slurm_config_type, +) class TestSandboxTask: """Tests for the SandboxTask dataclass.""" def test_defaults(self): - from core import SandboxTask - task = SandboxTask() assert task.script is None assert task.slurm_config is None @@ -42,8 +55,6 @@ def test_defaults(self): assert task.skip is False def test_with_values(self): - from core import SandboxTask - task = SandboxTask( script="test.sh", args=["--foo", "bar"], @@ -60,8 +71,6 @@ class TestSandboxPipeline: """Tests for SandboxPipeline task collection and global_vars interpolation.""" def test_task_slots_collected(self): - from core import SandboxPipeline, SandboxTask0, SandboxTask1 - t0 = SandboxTask0(script="a.sh") t1 = SandboxTask1(script="b.sh") pipeline = SandboxPipeline(task_0=t0, task_1=t1) @@ -70,14 +79,10 @@ def test_task_slots_collected(self): assert pipeline.tasks[1].script == "b.sh" def test_empty_pipeline(self): - from core import SandboxPipeline - pipeline = SandboxPipeline() assert pipeline.tasks == [] def test_global_vars_interpolation_in_environment(self): - from core import GlobalVariables, SandboxPipeline, SandboxTask0 - t0 = SandboxTask0( script="test.sh", environment=[{"MODEL": "<>"}], @@ -89,8 +94,6 @@ def test_global_vars_interpolation_in_environment(self): assert pipeline.tasks[0].environment == [{"MODEL": "/hf-local/Qwen/Qwen3-8B"}] def test_global_vars_interpolation_in_args(self): - from core import GlobalVariables, SandboxPipeline, SandboxTask0 - t0 = SandboxTask0( script="test.sh", args=["--model", "<>"], @@ -102,8 +105,6 @@ def test_global_vars_interpolation_in_args(self): assert pipeline.tasks[0].args == ["--model", "/models/llama"] def test_global_vars_unresolved_passthrough(self): - from core import GlobalVariables, SandboxPipeline, SandboxTask0 - t0 = SandboxTask0( script="test.sh", args=["<>"], @@ -116,8 +117,6 @@ def test_global_vars_unresolved_passthrough(self): assert pipeline.tasks[0].args == ["<>"] def test_skip_and_allow_to_fail(self): - from core import SandboxPipeline - pipeline = SandboxPipeline(skip=True, allow_to_fail=True, note="test note") assert pipeline.skip is True assert pipeline.allow_to_fail is True @@ -128,8 +127,6 @@ class TestFactoryRegistry: """Tests for register_factory and its use in create_task_from_yaml.""" def test_register_and_lookup(self, tmp_yaml): - from core import _FACTORY_REGISTRY, register_factory - # Register a mock factory def mock_factory(nodes=1, **kwargs): return {"nodes": nodes, "factory": "mock"} @@ -139,8 +136,6 @@ def mock_factory(nodes=1, **kwargs): assert _FACTORY_REGISTRY["mock_factory"] is mock_factory def test_create_task_from_yaml_uses_registry(self, tmp_yaml): - from core import create_task_from_yaml, register_factory - def test_factory(nodes=1): return {"nodes": nodes} @@ -161,8 +156,6 @@ def test_factory(nodes=1): assert task.slurm_config == {"nodes": 2} def test_task_configs_resolved_via_registry(self, tmp_yaml): - from core import SandboxPipeline, register_factory - def dummy_factory(nodes=1): return {"nodes": nodes} @@ -187,10 +180,6 @@ class TestSetSlurmConfigType: """Tests for set_slurm_config_type annotation patching.""" def test_patches_annotation(self): - from dataclasses import dataclass - - from core import SandboxTask, set_slurm_config_type - @dataclass class MockSlurmConfig: host: str = "test" @@ -204,8 +193,6 @@ class TestGetDefaultEnv: """Tests for get_default_env utility.""" def test_default_title(self): - from core import get_default_env - slurm_env, local_env = get_default_env() assert slurm_env["TRITON_CACHE_DIR"] == "/cicd/triton-cache" assert slurm_env["HF_HOME"] == "/cicd/hf-cache" @@ -215,8 +202,6 @@ def test_default_title(self): assert "LAUNCH_SCRIPT" not in local_env def test_custom_title(self): - from core import get_default_env - slurm_env, local_env = get_default_env("modelopt") assert slurm_env["TRITON_CACHE_DIR"] == "/modelopt/triton-cache" assert slurm_env["HF_HOME"] == "/modelopt/hf-cache" @@ -227,16 +212,12 @@ class TestReportVersions: """Tests for report_versions git info utility.""" def test_runs_on_repo(self, capsys): - from core import report_versions - # Should not raise — runs git on the current repo report_versions(os.getcwd()) captured = capsys.readouterr() assert "Version Report" in captured.out def test_runs_on_nonexistent_dir(self, capsys): - from core import report_versions - # Should handle gracefully — "unknown" for non-git dirs report_versions("/tmp/nonexistent_dir_12345") captured = capsys.readouterr() diff --git a/tools/launcher/tests/test_core_extended.py b/tools/launcher/tests/test_core_extended.py index 9d4ba560407..698ed0aca4d 100644 --- a/tools/launcher/tests/test_core_extended.py +++ b/tools/launcher/tests/test_core_extended.py @@ -28,14 +28,22 @@ from unittest.mock import MagicMock, patch import pytest +from core import ( + GlobalVariables, + SandboxPipeline, + SandboxTask, + SandboxTask0, + _git_info, + create_task_from_yaml, + get_default_env, + run_jobs, +) class TestCreateTaskFromYamlErrors: """Error handling in create_task_from_yaml.""" def test_missing_factory_raises(self, tmp_yaml): - from core import create_task_from_yaml - yaml_content = """ script: test.sh slurm_config: @@ -47,8 +55,6 @@ def test_missing_factory_raises(self, tmp_yaml): create_task_from_yaml(path, factory_lookup={}) def test_missing_slurm_config_raises(self, tmp_yaml): - from core import create_task_from_yaml - yaml_content = """ script: test.sh """ @@ -57,8 +63,6 @@ def test_missing_slurm_config_raises(self, tmp_yaml): create_task_from_yaml(path, factory_lookup={}) def test_environment_preserved(self, tmp_yaml): - from core import create_task_from_yaml - def factory(nodes=1): return {"nodes": nodes} @@ -81,8 +85,6 @@ class TestSandboxPipelineExtended: def test_dict_environment_interpolation(self): """Global vars resolve in dict-format environment (not list).""" - from core import GlobalVariables, SandboxPipeline, SandboxTask0 - t0 = SandboxTask0( script="test.sh", environment={"MODEL": "<>", "STATIC": "value"}, @@ -98,8 +100,6 @@ def test_dict_environment_interpolation(self): def test_tasks_list_directly(self): """Pipeline can receive tasks as a list directly.""" - from core import SandboxPipeline, SandboxTask - tasks = [ SandboxTask(script="a.sh"), SandboxTask(script="b.sh"), @@ -111,8 +111,6 @@ def test_tasks_list_directly(self): def test_no_global_vars_no_error(self): """Pipeline without global_vars doesn't crash on interpolation.""" - from core import SandboxPipeline, SandboxTask0 - t0 = SandboxTask0( script="test.sh", args=["<>"], @@ -126,23 +124,17 @@ class TestGitInfo: """Direct tests for _git_info helper.""" def test_valid_git_repo(self): - from core import _git_info - commit, branch = _git_info(os.getcwd()) assert commit != "unknown" assert branch != "unknown" assert len(commit) >= 7 # short hash def test_nonexistent_directory(self): - from core import _git_info - commit, branch = _git_info("/tmp/nonexistent_xyz_12345") assert commit == "unknown" assert branch == "unknown" def test_non_git_directory(self): - from core import _git_info - # Use /tmp which is outside any git repo commit, branch = _git_info("/tmp") # /tmp may or may not be inside a git worktree depending on the system @@ -158,8 +150,6 @@ class TestRunJobsExtended: @patch("core.build_docker_executor") def test_environment_list_merged_to_env(self, mock_docker, mock_exp, tmp_path): """List-of-dicts environment is merged into task_env.""" - from core import SandboxPipeline, SandboxTask0, get_default_env, run_jobs - mock_exp_inst = MagicMock() mock_exp_inst._id = "exp_env" mock_exp_inst.__enter__ = MagicMock(return_value=mock_exp_inst) @@ -197,8 +187,6 @@ def test_environment_list_merged_to_env(self, mock_docker, mock_exp, tmp_path): @patch("core.run.Experiment") @patch("core.build_docker_executor") def test_none_env_values_converted_to_empty_string(self, mock_docker, mock_exp, tmp_path): - from core import SandboxPipeline, SandboxTask0, get_default_env, run_jobs - mock_exp_inst = MagicMock() mock_exp_inst._id = "exp_none" mock_exp_inst.__enter__ = MagicMock(return_value=mock_exp_inst) @@ -234,8 +222,6 @@ def test_none_env_values_converted_to_empty_string(self, mock_docker, mock_exp, @patch("core.build_docker_executor") def test_test_level_filters_pipeline(self, mock_docker, mock_exp, tmp_path): """Pipelines with test_level > current are skipped.""" - from core import SandboxPipeline, SandboxTask0, get_default_env, run_jobs - mock_exp_inst = MagicMock() mock_exp_inst._id = "exp_lvl" mock_exp_inst.__enter__ = MagicMock(return_value=mock_exp_inst) @@ -267,8 +253,6 @@ def test_test_level_filters_pipeline(self, mock_docker, mock_exp, tmp_path): @patch("core.run.Experiment") @patch("core.build_docker_executor") def test_skipped_pipeline_not_run(self, mock_docker, mock_exp, tmp_path): - from core import SandboxPipeline, SandboxTask0, get_default_env, run_jobs - slurm_env, local_env = get_default_env() t0 = SandboxTask0(script="test.sh", slurm_config=MagicMock()) @@ -291,8 +275,6 @@ def test_skipped_pipeline_not_run(self, mock_docker, mock_exp, tmp_path): @patch("core.run.Experiment") @patch("core.build_docker_executor") def test_detach_flag_passed_to_experiment(self, mock_docker, mock_exp, tmp_path): - from core import SandboxPipeline, SandboxTask0, get_default_env, run_jobs - mock_exp_inst = MagicMock() mock_exp_inst._id = "exp_detach" mock_exp_inst.__enter__ = MagicMock(return_value=mock_exp_inst) @@ -323,8 +305,6 @@ def test_detach_flag_passed_to_experiment(self, mock_docker, mock_exp, tmp_path) @patch("core.run.Experiment") @patch("core.build_docker_executor") def test_version_report_called(self, mock_docker, mock_exp, tmp_path, capsys): - from core import SandboxPipeline, SandboxTask0, get_default_env, run_jobs - mock_exp_inst = MagicMock() mock_exp_inst._id = "exp_ver" mock_exp_inst.__enter__ = MagicMock(return_value=mock_exp_inst) diff --git a/tools/launcher/tests/test_docker_execution.py b/tools/launcher/tests/test_docker_execution.py index 6b38b6ccf2c..01b418125c1 100644 --- a/tools/launcher/tests/test_docker_execution.py +++ b/tools/launcher/tests/test_docker_execution.py @@ -26,13 +26,20 @@ import os from unittest.mock import MagicMock, patch +from core import ( + SandboxPipeline, + SandboxTask0, + SandboxTask1, + build_docker_executor, + get_default_env, + run_jobs, +) + class TestBuildDockerExecutor: """Tests for build_docker_executor mount and directory setup.""" def test_scratch_dir_created(self, tmp_path): - from core import build_docker_executor - job_dir = str(tmp_path / "experiments") build_docker_executor( hf_local="/tmp/hf-local", @@ -55,8 +62,6 @@ def test_scratch_dir_created(self, tmp_path): assert os.path.isdir(scratch_dir) def test_hf_local_mount(self, tmp_path): - from core import build_docker_executor - job_dir = str(tmp_path / "experiments") executor = build_docker_executor( hf_local="/my/hf-local", @@ -79,8 +84,6 @@ def test_hf_local_mount(self, tmp_path): assert any("/my/hf-local:/hf-local" in v for v in volumes) def test_scratchspace_mount(self, tmp_path): - from core import build_docker_executor - job_dir = str(tmp_path / "experiments") executor = build_docker_executor( hf_local="/tmp/hf", @@ -104,8 +107,6 @@ def test_scratchspace_mount(self, tmp_path): assert any(f"{expected_scratch}:/scratchspace" in v for v in volumes) def test_modelopt_mount(self, tmp_path): - from core import build_docker_executor - job_dir = str(tmp_path / "experiments") executor = build_docker_executor( hf_local="/tmp/hf", @@ -128,8 +129,6 @@ def test_modelopt_mount(self, tmp_path): assert any("/custom/modelopt:/opt/modelopt" in v for v in volumes) def test_experiment_title_mount(self, tmp_path): - from core import build_docker_executor - job_dir = str(tmp_path / "experiments") executor = build_docker_executor( hf_local="/tmp/hf", @@ -153,8 +152,6 @@ def test_experiment_title_mount(self, tmp_path): assert any(f"{exp_title_path}:/modelopt" in v for v in volumes) def test_local_slurm_config_mounts_preserved(self, tmp_path): - from core import build_docker_executor - job_dir = str(tmp_path / "experiments") executor = build_docker_executor( hf_local="/tmp/hf", @@ -184,8 +181,6 @@ class TestRunJobsDockerPath: @patch("core.run.Experiment") @patch("core.build_docker_executor") def test_docker_executor_called_with_hf_local(self, mock_docker, mock_exp, tmp_path): - from core import SandboxPipeline, SandboxTask0, get_default_env, run_jobs - mock_exp_instance = MagicMock() mock_exp_instance._id = "test_exp_001" mock_exp_instance.__enter__ = MagicMock(return_value=mock_exp_instance) @@ -223,8 +218,6 @@ def test_docker_executor_called_with_hf_local(self, mock_docker, mock_exp, tmp_p @patch("core.run.Experiment") @patch("core.build_docker_executor") def test_metadata_written(self, mock_docker, mock_exp, tmp_path): - from core import SandboxPipeline, SandboxTask0, get_default_env, run_jobs - mock_exp_instance = MagicMock() mock_exp_instance._id = "test_exp_meta" mock_exp_instance.__enter__ = MagicMock(return_value=mock_exp_instance) @@ -264,8 +257,6 @@ def test_metadata_written(self, mock_docker, mock_exp, tmp_path): @patch("core.run.Experiment") @patch("core.build_docker_executor") def test_skipped_task_not_submitted(self, mock_docker, mock_exp, tmp_path): - from core import SandboxPipeline, SandboxTask0, SandboxTask1, get_default_env, run_jobs - mock_exp_instance = MagicMock() mock_exp_instance._id = "test_exp_skip" mock_exp_instance.__enter__ = MagicMock(return_value=mock_exp_instance) @@ -300,8 +291,6 @@ def test_skipped_task_not_submitted(self, mock_docker, mock_exp, tmp_path): @patch("core.run.Experiment") @patch("core.build_slurm_executor") def test_slurm_executor_called_without_hf_local(self, mock_slurm, mock_exp, tmp_path): - from core import SandboxPipeline, SandboxTask0, get_default_env, run_jobs - mock_exp_instance = MagicMock() mock_exp_instance._id = "test_exp_slurm" mock_exp_instance.__enter__ = MagicMock(return_value=mock_exp_instance) diff --git a/tools/launcher/tests/test_slurm_config.py b/tools/launcher/tests/test_slurm_config.py index b23c46c24b9..20f8e3bdab0 100644 --- a/tools/launcher/tests/test_slurm_config.py +++ b/tools/launcher/tests/test_slurm_config.py @@ -22,13 +22,16 @@ SLURM_HF_LOCAL), return type """ +import importlib + +import slurm_config +from slurm_config import SlurmConfig, slurm_factory + class TestSlurmConfig: """Tests for the SlurmConfig dataclass.""" def test_defaults(self): - from slurm_config import SlurmConfig - cfg = SlurmConfig() assert cfg.host is None assert cfg.port == 22 @@ -44,8 +47,6 @@ def test_defaults(self): assert cfg.array is None def test_custom_values(self): - from slurm_config import SlurmConfig - cfg = SlurmConfig( host="login.example.com", account="my_account", @@ -66,54 +67,38 @@ class TestSlurmFactory: """Tests for the slurm_factory function.""" def test_default_returns_slurm_config(self): - from slurm_config import slurm_factory - cfg = slurm_factory() # slurm_factory with @run.autoconvert returns a nemo-run Config wrapper assert "SlurmConfig" in repr(cfg) def test_default_container(self): - from slurm_config import slurm_factory - cfg = slurm_factory() assert "tensorrt-llm" in cfg.container def test_default_srun_args(self): - from slurm_config import slurm_factory - cfg = slurm_factory() assert cfg.srun_args == ["--no-container-mount-home"] def test_default_container_mounts_from_env(self, monkeypatch): monkeypatch.setenv("SLURM_HF_LOCAL", "/custom/hf-local") - # Need to re-import to pick up the env var in the default - # The factory reads SLURM_HF_LOCAL at call time via the default arg - import importlib - - import slurm_config - + # Reload to pick up the env var — slurm_factory reads SLURM_HF_LOCAL at module-import + # time via a default arg, so the qualified slurm_config.slurm_factory call below is + # required (the top-level `from slurm_config import slurm_factory` still points at + # the pre-reload function). importlib.reload(slurm_config) cfg = slurm_config.slurm_factory() assert any("/custom/hf-local:/hf-local" in m for m in cfg.container_mounts) def test_override_nodes(self): - from slurm_config import slurm_factory - cfg = slurm_factory(nodes=8) assert cfg.nodes == 8 def test_override_partition(self): - from slurm_config import slurm_factory - cfg = slurm_factory(partition="gpu") assert cfg.partition == "gpu" def test_env_var_host(self, monkeypatch): monkeypatch.setenv("SLURM_HOST", "test-host.example.com") - import importlib - - import slurm_config - importlib.reload(slurm_config) cfg = slurm_config.slurm_factory() assert cfg.host == "test-host.example.com" diff --git a/tools/launcher/tests/test_slurm_executor.py b/tools/launcher/tests/test_slurm_executor.py index 5f2d1b8dac8..900616136e3 100644 --- a/tools/launcher/tests/test_slurm_executor.py +++ b/tools/launcher/tests/test_slurm_executor.py @@ -22,6 +22,8 @@ from unittest.mock import MagicMock, patch +from core import build_slurm_executor + class TestBuildSlurmExecutor: """Tests for build_slurm_executor mount construction and executor params.""" @@ -29,8 +31,6 @@ class TestBuildSlurmExecutor: @patch("core.run.SlurmExecutor") @patch("core.run.SSHTunnel") def test_scratch_and_modelopt_mounts(self, mock_tunnel, mock_executor): - from core import build_slurm_executor - mock_tunnel.return_value = MagicMock() slurm_config = MagicMock( @@ -74,8 +74,6 @@ def test_scratch_and_modelopt_mounts(self, mock_tunnel, mock_executor): @patch("core.run.SlurmExecutor") @patch("core.run.SSHTunnel") def test_scratch_path_uses_experiment_title(self, mock_tunnel, mock_executor): - from core import build_slurm_executor - mock_tunnel.return_value = MagicMock() slurm_config = MagicMock( @@ -111,8 +109,6 @@ def test_scratch_path_uses_experiment_title(self, mock_tunnel, mock_executor): @patch("core.run.SlurmExecutor") @patch("core.run.SSHTunnel") def test_tunnel_created_with_correct_params(self, mock_tunnel, mock_executor): - from core import build_slurm_executor - mock_tunnel.return_value = MagicMock() slurm_config = MagicMock( @@ -151,8 +147,6 @@ def test_tunnel_created_with_correct_params(self, mock_tunnel, mock_executor): @patch("core.run.SlurmExecutor") @patch("core.run.SSHTunnel") def test_executor_params(self, mock_tunnel, mock_executor): - from core import build_slurm_executor - mock_tunnel.return_value = MagicMock() slurm_config = MagicMock( @@ -198,8 +192,6 @@ def test_executor_params(self, mock_tunnel, mock_executor): @patch("core.run.SlurmExecutor") @patch("core.run.SSHTunnel") def test_none_container_mounts_handled(self, mock_tunnel, mock_executor): - from core import build_slurm_executor - mock_tunnel.return_value = MagicMock() slurm_config = MagicMock( diff --git a/tools/launcher/tests/test_yaml_formats.py b/tools/launcher/tests/test_yaml_formats.py index 981c3221684..86a4863156f 100644 --- a/tools/launcher/tests/test_yaml_formats.py +++ b/tools/launcher/tests/test_yaml_formats.py @@ -24,6 +24,14 @@ """ import yaml +from core import ( + GlobalVariables, + SandboxPipeline, + SandboxTask, + SandboxTask0, + SandboxTask1, + register_factory, +) class TestYamlFormatParsing: @@ -81,7 +89,6 @@ def test_bare_pipeline_format(self, tmp_yaml): def test_task_configs_format(self, tmp_yaml): """task_configs lists YAML files that are resolved into tasks.""" - from core import SandboxPipeline, register_factory def local_factory(nodes=1): return {"nodes": nodes} @@ -108,8 +115,6 @@ def local_factory(nodes=1): def test_environment_list_of_dicts(self): """Environment as list-of-single-key-dicts (nemo-run format).""" - from core import SandboxTask - task = SandboxTask( script="test.sh", environment=[{"A": "1"}, {"B": "2"}, {"C": "3"}], @@ -119,8 +124,6 @@ def test_environment_list_of_dicts(self): def test_global_vars_across_multiple_tasks(self, tmp_yaml): """Global vars resolve in both task_0 and task_1.""" - from core import GlobalVariables, SandboxPipeline, SandboxTask0, SandboxTask1 - t0 = SandboxTask0( script="quantize.sh", args=["--model", "<>"],