diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 5c189bd28b..cd8e48e085 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -22,6 +22,7 @@ Changelog **Bug Fixes** +- Fix Megatron utility functions for generation (with pipeline parallelism) and speed up MMLU score evaluation by ~10x (by batching prefill passes). - Fix Minitron pruning (``mcore_minitron``) for MoE models. Importance estimation hooks were incorrectly registered for MoE modules and NAS step was hanging before this. - Fix TRT support for remote autotuning in ONNX Autotune from 10.16+ to 10.15+ and fix TRT versioning check to the ``trtexec`` version instead of the TRT Python API when using ``trtexec`` backend. diff --git a/examples/megatron_bridge/prune_minitron.py b/examples/megatron_bridge/prune_minitron.py index 445fdea863..0fa9a658ff 100644 --- a/examples/megatron_bridge/prune_minitron.py +++ b/examples/megatron_bridge/prune_minitron.py @@ -18,7 +18,7 @@ while skipping pruning of num_attention_heads using following defaults: 1024 samples from nemotron-post-training-dataset-v2 for calibration, at-most 20% depth (num_layers) and 40% width is pruned per prunable hparam (hidden_size, ffn_hidden_size, ...), - top-10 candidates are evaluated for MMLU score (5% sampled data) to select the best model. + top-10 candidates are evaluated for MMLU score (10% sampled data) to select the best model. torchrun --nproc_per_node 2 prune_minitron.py \ --hf_model_name_or_path Qwen/Qwen3-8B \ @@ -140,11 +140,11 @@ def get_args() -> argparse.Namespace: parser.add_argument( "--prune_score_func", type=str, - default="mmlu_5pct", + default="mmlu_10pct", help=( "Score function to use for NAS-based pruning (--prune_target_params). Only supports MMLU at the moment. " "Format: mmlu_<N>pct where <N> is the percentage of MMLU data to sample per subject " - "(e.g. mmlu_5pct for 5%, mmlu_100pct for full eval)." + "(e.g. mmlu_10pct for 10%, mmlu_100pct for full eval)." 
), ) parser.add_argument( @@ -299,16 +299,14 @@ def main(args: argparse.Namespace): match = re.fullmatch(r"mmlu_(\d+)pct", args.prune_score_func) if not match: raise ValueError( - f"Invalid score function: {args.prune_score_func}. " - "Expected format: mmlu_<N>pct (e.g. mmlu_5pct)" + f"Invalid score function: {args.prune_score_func}. Expected format: mmlu_<N>pct (e.g. mmlu_10pct)" ) - mmlu_pct = int(match.group(1)) - if not 0 < mmlu_pct <= 100: - raise ValueError("--prune_score_func percentage must be in the range [1, 100].") - _mmlu_pct = mmlu_pct / 100.0 + mmlu_frac = float(match.group(1)) / 100.0 def score_func(m): - return megatron_mmlu(m, tokenizer, percentage=_mmlu_pct) + return megatron_mmlu( + m, tokenizer, few_shots=0, fraction=mmlu_frac, batch_size=args.calib_mbs + ) pruning_config["score_func"] = score_func pruning_config["max_width_pruning"] = args.max_width_pruning diff --git a/examples/pruning/README.md b/examples/pruning/README.md index 930e9c6d25..9e84622269 100644 --- a/examples/pruning/README.md +++ b/examples/pruning/README.md @@ -124,7 +124,7 @@ This mode can be useful when you don't know the exact dimensions you want to pru from modelopt.torch.utils.plugins.megatron_mmlu import megatron_mmlu def score_func(m): - return megatron_mmlu(m, tokenizer, percentage=0.05) # 5% sampled data for faster eval + return megatron_mmlu(m, tokenizer, fraction=0.1, batch_size=4) # 10% sampled data for faster eval # Specify target parameter count and configure the auto pruning algorithm # Save minitron scores at checkpoint so we can resume pruning without running the forward loop again @@ -147,7 +147,7 @@ mtp.prune(...) 1. **Importance Scoring**: Same as manual pruning - computes activation magnitudes for all parameters (takes ~5 minutes for an 8B model) 2. **Search Space Construction**: Generates a search space of possible architectures based search space config and other configs (`max_width_pruning`, `max_depth_pruning`, `hparams_to_skip`) -3. 
**Architecture Search**: Find candidate architectures that meet the parameter constraint and evaluate `top_k` (based on number of parameters) of them using `score_func` e.g. MMLU, negative validation loss, etc. (takes ~10 mins per candidate for an 8B model pruning) +3. **Architecture Search**: Find candidate architectures that meet the parameter constraint and evaluate `top_k` (based on number of parameters) of them using `score_func` e.g. MMLU, negative validation loss, etc. (takes ~5 min per candidate for an 8B model MMLU score with 10% sampled data) 4. **Best Architecture Selection**: Returns the architecture (best `export_config`) with the highest actual score from the top-K evaluated architectures 5. **Weight Slicing**: Slices the model weights according to the best pruned architecture found diff --git a/modelopt/torch/export/plugins/megatron_importer.py b/modelopt/torch/export/plugins/megatron_importer.py index a156f2cd8c..b1d37c1ad9 100644 --- a/modelopt/torch/export/plugins/megatron_importer.py +++ b/modelopt/torch/export/plugins/megatron_importer.py @@ -747,6 +747,17 @@ def _import_state_dict(self): if hasattr(model, "output_layer") and not model.share_embeddings_and_output_weights: self.rules["output_layer"](model.output_layer) + # For PP with shared embedding/output weights, re-sync the output layer on the last + # pipeline stage from stage 0's (now HF-loaded) embedding. At model init, + # setup_embeddings_and_output_layer() zeros out the last stage's weight and all-reduces + # from stage 0. After importing HF weights into stage 0's embedding, that sync is stale, + # so we re-run it here. 
+ if ( + model.share_embeddings_and_output_weights + and model.config.pipeline_model_parallel_size > 1 + ): + model.setup_embeddings_and_output_layer() + # MTP if hasattr(model, "mtp"): layer_pbar.set_description("Importing MTP") diff --git a/modelopt/torch/prune/plugins/mcore_minitron.py b/modelopt/torch/prune/plugins/mcore_minitron.py index 204d30603e..e99a44a791 100644 --- a/modelopt/torch/prune/plugins/mcore_minitron.py +++ b/modelopt/torch/prune/plugins/mcore_minitron.py @@ -171,6 +171,9 @@ class CandidateSubnet: score: float | None +torch.serialization.add_safe_globals([CandidateSubnet]) + + class MCoreMinitronSearcher(BaseSearcher): """Searcher for Minitron pruning algorithm. diff --git a/modelopt/torch/utils/logging.py b/modelopt/torch/utils/logging.py index ada1b53612..f5aba0d1a1 100644 --- a/modelopt/torch/utils/logging.py +++ b/modelopt/torch/utils/logging.py @@ -105,8 +105,9 @@ def no_stdout(): def print_rank_0(*args, **kwargs): """Prints only on the master process.""" + kwargs.setdefault("flush", True) if dist.is_master(): - print(*args, **kwargs, flush=True) + print(*args, **kwargs) def warn_rank_0(message, *args, **kwargs): diff --git a/modelopt/torch/utils/plugins/megatron_generate.py b/modelopt/torch/utils/plugins/megatron_generate.py index 0891f58b5b..5625013bb4 100644 --- a/modelopt/torch/utils/plugins/megatron_generate.py +++ b/modelopt/torch/utils/plugins/megatron_generate.py @@ -17,11 +17,15 @@ import torch from megatron.core import mpu -from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage +from megatron.core.inference.communication_utils import ( + broadcast_from_last_pipeline_stage, + recv_from_prev_pipeline_rank_, + send_to_next_pipeline_rank, +) from megatron.core.inference.contexts import StaticInferenceContext -from megatron.core.pipeline_parallel import get_forward_backward_func from megatron.core.timers import Timer from megatron.core.transformer import MegatronModule +from megatron.core.utils 
import get_attr_wrapped_model from tqdm import tqdm __all__ = ["megatron_generate", "megatron_prefill"] @@ -48,81 +52,105 @@ def megatron_prefill( image_sizes: torch.LongTensor | None = None, skip_return_logits: bool = False, ) -> torch.Tensor: - """A simple prefill function for Megatron Core V(LM) models.""" + """A simple prefill function for Megatron Core V(LM) models. + + Supports TP, PP, SP, and combinations thereof. For PP, activations are communicated + explicitly between pipeline stages (rather than through get_forward_backward_func) + so that the training pipeline scheduler does not interfere with inference. + """ if not isinstance(model, MegatronModule): raise ValueError("megatron_prefill only supports Megatron Core models.") model.eval() - # Create a static inference context if KV-cache is enabled. - max_batch_size = input_ids.shape[0] + batch_size = input_ids.shape[0] seq_length = input_ids.shape[-1] + device = input_ids.device - def _dummy_loss_func(output_tensor, non_loss_data=True): - """Need a dummy loss function.""" - return output_tensor - - def _forward_step_func(data, model): - """Forward step function.""" - batch_size = data["tokens"].shape[0] - seq_len = data["tokens"].shape[-1] - device = data["tokens"].device - - # ModelOpt transoformer_spec by default use arbitrary attention mask type; hence we need to - # compute the attention_mask for prefilling. Alternatively, if "causal" attention mask type - # is used, the attention_mask is not needed. During generation, the attn_mask_type is overridden - # to "no_mask" by SelfAttention.forward() if inference_context is provided. 
- attention_mask = ( - torch.triu(torch.ones((batch_size, seq_len, seq_len), device=device), diagonal=1) - .bool() - .view(batch_size, 1, seq_len, seq_len) - ) - - position_ids = ( - torch.arange(seq_len, dtype=torch.long, device=device) - .unsqueeze(0) - .expand(batch_size, -1) - ) - - output_tensor = model( - data["tokens"], - position_ids, - attention_mask, - runtime_gather_output=True, - ) - return output_tensor, _dummy_loss_func + pp_first = mpu.is_pipeline_first_stage() + pp_last = mpu.is_pipeline_last_stage() + is_pp = not (pp_first and pp_last) + pp_dtype = model.config.pipeline_dtype or ( + torch.bfloat16 if model.config.bf16 else torch.float32 + ) if model.config.sequence_parallel: tp = model.config.tensor_model_parallel_size - num_pad_tokens = (tp - input_ids.shape[-1] % tp) % tp + num_pad_tokens = (tp - seq_length % tp) % tp else: num_pad_tokens = 0 if num_pad_tokens > 0: - padding_shape = (input_ids.shape[0], num_pad_tokens) - padded_tokens = torch.full(padding_shape, 0, dtype=input_ids.dtype, device=input_ids.device) - tokens = torch.cat((input_ids, padded_tokens), dim=-1) + tokens = torch.cat( + [ + input_ids, + torch.zeros(batch_size, num_pad_tokens, dtype=input_ids.dtype, device=device), + ], + dim=-1, + ) else: tokens = input_ids - list_of_logits = get_forward_backward_func()( - forward_step_func=_forward_step_func, - data_iterator=[{"tokens": tokens}], - model=model, - num_microbatches=1, - seq_length=tokens.shape[-1], - micro_batch_size=max_batch_size, - decoder_seq_length=tokens.shape[-1], - forward_only=True, - collect_non_loss_data=True, + padded_seq_len = tokens.shape[-1] + + # ModelOpt transformer_spec uses arbitrary attention mask type by default; the causal mask + # must be supplied explicitly for prefill. 
+ attention_mask = ( + torch.triu( + torch.ones((batch_size, padded_seq_len, padded_seq_len), device=device), diagonal=1 + ) + .bool() + .view(batch_size, 1, padded_seq_len, padded_seq_len) + ) + position_ids = ( + torch.arange(padded_seq_len, dtype=torch.long, device=device) + .unsqueeze(0) + .expand(batch_size, -1) ) - if skip_return_logits: - return None - if mpu.is_pipeline_last_stage(): - logits = list_of_logits[0][:, :seq_length, :].detach() + # For PP, receive activations from the previous stage before calling forward. + if is_pp and not pp_first: + pp_dtype = model.config.pipeline_dtype or ( + torch.bfloat16 if model.config.bf16 else torch.float32 + ) + recv_buffer = torch.empty( + (padded_seq_len, batch_size, model.config.hidden_size), + dtype=pp_dtype, + device=device, + ) + recv_from_prev_pipeline_rank_(recv_buffer) + get_attr_wrapped_model(model, "set_input_tensor")(recv_buffer) + + has_vision_inputs = ( + pixel_values is not None or image_grid_thw is not None or image_sizes is not None + ) + if has_vision_inputs: + forward_kwargs: dict = { + "input_ids": tokens, + "position_ids": position_ids, + "attention_mask": torch.ones( + (batch_size, padded_seq_len), dtype=torch.bool, device=device + ), + "runtime_gather_output": True, + } + if pixel_values is not None: + forward_kwargs["pixel_values"] = pixel_values + if image_grid_thw is not None: + forward_kwargs["image_grid_thw"] = image_grid_thw + if image_sizes is not None: + forward_kwargs["image_sizes"] = image_sizes + output = model(**forward_kwargs) else: - logits = None + output = model(tokens, position_ids, attention_mask, runtime_gather_output=True) + + # For PP non-last stages, forward activations to the next stage (no early return: every rank still joins the broadcast below). 
+ if is_pp and not pp_last: + pp_dtype = model.config.pipeline_dtype or ( + torch.bfloat16 if model.config.bf16 else torch.float32 + ) + send_to_next_pipeline_rank(output.to(dtype=pp_dtype)) + + logits = output[:, :seq_length, :].detach() if pp_last else None if model.config.bf16: logits_dtype = torch.bfloat16 @@ -130,11 +158,12 @@ def _forward_step_func(data, model): logits_dtype = torch.float16 else: logits_dtype = torch.float32 - logits = broadcast_from_last_pipeline_stage( - [max_batch_size, seq_length, model.vocab_size], logits_dtype, logits - ) - return logits + # All PP ranks must participate in the broadcast to stay in sync. + result = broadcast_from_last_pipeline_stage( + [batch_size, seq_length, model.vocab_size], logits_dtype, logits + ) + return None if skip_return_logits else result def megatron_generate( @@ -182,6 +211,13 @@ def megatron_generate( model.eval() + pp_first = mpu.is_pipeline_first_stage() + pp_last = mpu.is_pipeline_last_stage() + is_pp = not (pp_first and pp_last) + pp_dtype = model.config.pipeline_dtype or ( + torch.bfloat16 if model.config.bf16 else torch.float32 + ) + # Create a static inference context if KV-cache is enabled. 
max_batch_size = input_ids.shape[0] max_seq_len = input_ids.shape[-1] + osl @@ -189,20 +225,45 @@ def megatron_generate( StaticInferenceContext(max_batch_size, max_seq_len) if enable_kv_cache else None ) - def _dummy_loss_func(output_tensor, non_loss_data=True): - """Need a dummy loss function.""" - return output_tensor + disable_tqdm = disable_tqdm or torch.distributed.get_rank() > 0 + + output_ids = torch.tensor([]) + step_pbar = tqdm(range(osl), disable=disable_tqdm, leave=False) + + time_ttft = 0 + time_remaining_outputs = 0 + timer = Timer("generate") + timer.start(barrier=True) + + for step in step_pbar: + step_pbar.set_description(get_current_memory_info()) - def _forward_step_func(data, model): - """Forward step function.""" - batch_size = data["tokens"].shape[0] - seq_len = data["tokens"].shape[-1] - device = data["tokens"].device + if model.config.sequence_parallel: + tp = model.config.tensor_model_parallel_size + num_pad_tokens = (tp - input_ids.shape[-1] % tp) % tp + else: + num_pad_tokens = 0 - # ModelOpt transoformer_spec by default use arbitrary attention mask type; hence we need to - # compute the attention_mask for prefilling. Alternatively, if "causal" attention mask type - # is used, the attention_mask is not needed. During generation, the attn_mask_type is overridden - # to "no_mask" by SelfAttention.forward() if inference_context is provided. + if inference_context is not None and step > 0: + tokens = input_ids[:, -1:] + inference_context.enable_decode_mode() + num_pad_tokens = 0 + elif num_pad_tokens > 0: + padding_shape = (input_ids.shape[0], num_pad_tokens) + padded_tokens = torch.full( + padding_shape, 0, dtype=input_ids.dtype, device=input_ids.device + ) + tokens = torch.cat((input_ids, padded_tokens), dim=-1) + else: + tokens = input_ids + + batch_size = tokens.shape[0] + seq_len = tokens.shape[-1] + device = tokens.device + + # ModelOpt transformer_spec uses arbitrary attention mask type by default; compute causal + # mask for prefill. 
During decode, attn_mask_type is overridden to "no_mask" by + # SelfAttention.forward() when inference_context is provided. if seq_len > 1: attention_mask = ( torch.triu(torch.ones((batch_size, seq_len, seq_len), device=device), diagonal=1) @@ -218,109 +279,57 @@ def _forward_step_func(data, model): .expand(batch_size, -1) ) - # Check if this is a VLM model (has vision inputs) - _has_pixel_values = data.get("pixel_values") is not None - _has_image_grid_thw = data.get("image_grid_thw") is not None - _has_image_sizes = data.get("image_sizes") is not None + # Check if this is a VLM model (vision inputs only passed at step 0 / prefill) + _has_pixel_values = step == 0 and pixel_values is not None + _has_image_grid_thw = step == 0 and image_grid_thw is not None + _has_image_sizes = step == 0 and image_sizes is not None has_vision_inputs = _has_pixel_values or _has_image_grid_thw or _has_image_sizes - if has_vision_inputs: - # For VLM models: - # - position_ids: [batch, seq_len] (required for RoPE with multi-modal positions) - # - attention_mask: [batch, seq_len] (simple 1D boolean mask, not 4D causal) - vlm_position_ids = ( - torch.arange(seq_len, dtype=torch.long, device=device) - .unsqueeze(0) - .expand(batch_size, -1) + # For PP, receive activations from the previous stage before calling forward. 
+ if is_pp and not pp_first: + recv_buffer = torch.empty( + (seq_len, batch_size, model.config.hidden_size), + dtype=pp_dtype, + device=device, ) - vlm_attention_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=device) + recv_from_prev_pipeline_rank_(recv_buffer) + get_attr_wrapped_model(model, "set_input_tensor")(recv_buffer) + if has_vision_inputs: forward_args = { - "input_ids": data["tokens"], - "position_ids": vlm_position_ids, - "attention_mask": vlm_attention_mask, + "input_ids": tokens, + "position_ids": position_ids, + "attention_mask": torch.ones( + (batch_size, seq_len), dtype=torch.bool, device=device + ), "inference_context": inference_context, "runtime_gather_output": True, } - # Add vision inputs if _has_pixel_values: - forward_args["pixel_values"] = data["pixel_values"] + forward_args["pixel_values"] = pixel_values if _has_image_grid_thw: - forward_args["image_grid_thw"] = data["image_grid_thw"] + forward_args["image_grid_thw"] = image_grid_thw if _has_image_sizes: - forward_args["image_sizes"] = data["image_sizes"] - - output_tensor = model(**forward_args) + forward_args["image_sizes"] = image_sizes + output = model(**forward_args) else: - # For text-only LLM models - output_tensor = model( - data["tokens"], + output = model( + tokens, position_ids, attention_mask, inference_context=inference_context, runtime_gather_output=True, ) - return output_tensor, _dummy_loss_func - - disable_tqdm = disable_tqdm or torch.distributed.get_rank() > 0 - - output_ids = torch.tensor([]) - step_pbar = tqdm(range(osl), disable=disable_tqdm, leave=False) - - time_ttft = 0 - time_remaining_outputs = 0 - timer = Timer("generate") - timer.start(barrier=True) - - for step in step_pbar: - step_pbar.set_description(get_current_memory_info()) - - if model.config.sequence_parallel: - tp = model.config.tensor_model_parallel_size - num_pad_tokens = (tp - input_ids.shape[-1] % tp) % tp - else: - num_pad_tokens = 0 - - if inference_context is not None and step > 
0: - tokens = input_ids[:, -1:] - inference_context.enable_decode_mode() - elif num_pad_tokens > 0: - padding_shape = (input_ids.shape[0], num_pad_tokens) - padded_tokens = torch.full( - padding_shape, 0, dtype=input_ids.dtype, device=input_ids.device - ) - tokens = torch.cat((input_ids, padded_tokens), dim=-1) - else: - tokens = input_ids - - data_dict = {"tokens": tokens} - # Vision inputs should only be passed during prefill (step 0), not during decode steps - if pixel_values is not None: - data_dict["pixel_values"] = pixel_values - if image_grid_thw is not None: - data_dict["image_grid_thw"] = image_grid_thw - if image_sizes is not None: - data_dict["image_sizes"] = image_sizes - - list_of_logits = get_forward_backward_func()( - forward_step_func=_forward_step_func, - data_iterator=[data_dict], - model=model, - num_microbatches=1, - seq_length=tokens.shape[-1], - micro_batch_size=max_batch_size, - decoder_seq_length=tokens.shape[-1], - forward_only=True, - collect_non_loss_data=True, - ) if inference_context is not None: - inference_context.sequence_len_offset += tokens.shape[-1] + inference_context.sequence_len_offset += seq_len - if mpu.is_pipeline_last_stage(): - eager_ids = ( - list_of_logits[0][:, -(num_pad_tokens + 1), :].argmax(dim=-1, keepdim=True).detach() - ) + # For PP non-last stages, forward activations to the next stage. 
+ if is_pp and not pp_last: + send_to_next_pipeline_rank(output.to(dtype=pp_dtype)) + + if pp_last: + eager_ids = output[:, -(num_pad_tokens + 1), :].argmax(dim=-1, keepdim=True).detach() else: eager_ids = None diff --git a/modelopt/torch/utils/plugins/megatron_mmlu.py b/modelopt/torch/utils/plugins/megatron_mmlu.py index fe71bc6ecc..4a07405caf 100644 --- a/modelopt/torch/utils/plugins/megatron_mmlu.py +++ b/modelopt/torch/utils/plugins/megatron_mmlu.py @@ -40,62 +40,60 @@ """A simple MMLU evaluation for Megatron LM models.""" -import requests import torch -import transformers from datasets import load_dataset +from tqdm import tqdm +from transformers import PreTrainedTokenizer -from .megatron_generate import megatron_generate +from .. import distributed as dist +from .. import print_rank_0 +from .megatron_generate import megatron_prefill __all__ = ["megatron_mmlu"] - -def _get_all_subjects(): - """All subjects (anatomy, ...) can be acquired from querying all subsets and splits.""" - response = requests.get( - "https://datasets-server.huggingface.co/splits?dataset=cais/mmlu", timeout=10 - ) - data = response.json() - all_subjects = set() - for split in data["splits"]: - all_subjects.add(split["config"]) - for name in ["all", "auxiliary_train"]: - all_subjects.discard(name) - return sorted(all_subjects) +_CHOICES = ["A", "B", "C", "D"] def megatron_mmlu( model, - tokenizer: transformers.PreTrainedTokenizer, + tokenizer: PreTrainedTokenizer, few_shots: int = 0, - percentage: float = 0.05, - enable_kv_cache: bool = False, + fraction: float = 0.05, + batch_size: int = 1, ) -> float: - """Evaluate the model on MMLU. + """Evaluate the model on MMLU using log-likelihood scoring over batched prefill passes. + + Instead of autoregressively generating tokens, a single prefill forward pass is run per + batch and the answer is selected as argmax over the four choice token logits at the last + prompt position. This is the same approach used by lm-evaluation-harness. 
Args: model: The model to evaluate. tokenizer: The tokenizer to use. few_shots: The number of few-shot examples to use. - percentage: The percentage of the test set to evaluate on. - enable_kv_cache: Whether to disable KV-cache. + fraction: The fraction of the test set to evaluate on. + batch_size: Number of examples to process in one forward pass. """ - all_correct = {} - all_subjects = _get_all_subjects() + print_rank_0( + f"\nMMLU ({fraction * 100}%, {few_shots}-shot, Batch Size: {batch_size}) evaluation started...\n" + "First batch may take longer to evaluate for Pipeline Parallel models." + ) + assert 0 < fraction <= 1, "Fraction must be between 0 and 1" + + # Token IDs for " A", " B", " C", " D" — the last subword handles edge cases. + choice_ids = [tokenizer.encode(f" {c}", add_special_tokens=False)[-1] for c in _CHOICES] def _format_example(example, include_answer: bool = True): - """Format an example into a multi-choices problem.""" prompt = example["question"] - for choice, answer in zip(["A", "B", "C", "D"], example["choices"]): + for choice, answer in zip(_CHOICES, example["choices"]): prompt += f"\n{choice}. 
{answer}" if include_answer: - prompt += "Answer: {}\n\n".format(example["answer"]) + prompt += "Answer: {}\n\n".format(_CHOICES[example["answer"]]) else: prompt += "\nAnswer:" return prompt def _generate_prompt(test_example, dev_examples, few_shots=0): - """Generating few-shot prompts.""" prompt = "The following are multiple choice questions (with answers) about {}.\n\n".format( " ".join(test_example["subject"].split("_")) ) @@ -104,51 +102,92 @@ def _generate_prompt(test_example, dev_examples, few_shots=0): prompt += _format_example(test_example, include_answer=False) return prompt - if torch.distributed.get_rank() == 0: - print(f"\nMMLU ({percentage * 100}%, {few_shots}-shot) evaluation started...\n", flush=True) - print("{:48} | (ACC) | Count/Total".format("Subject"), flush=True) - print("{:48} | {:5} | {:11}".format("-" * 48, "-" * 5, "-" * 11), flush=True) - - for subject in all_subjects: - test_data = load_dataset("cais/mmlu", subject, split="test") - dev_data = load_dataset("cais/mmlu", subject, split="dev") - - correct = [] - for idx, test_example in enumerate(test_data): - if idx > percentage * len(test_data): - break - prompt = _generate_prompt(test_example, dev_data, few_shots=few_shots) - label = ["A", "B", "C", "D"][test_example["answer"]] - tokens = tokenizer(prompt, return_tensors="pt") - generated_ids = megatron_generate( - model, - tokens.input_ids.cuda(), - osl=2, - disable_tqdm=True, - enable_kv_cache=enable_kv_cache, - ) - predict = tokenizer.batch_decode(generated_ids)[0].strip() - correct += [True] if predict.startswith(label) else [False] - all_correct[subject] = correct - - if torch.distributed.get_rank() == 0: - print( - f"{subject:48} | {sum(correct) / len(correct):.3f} | {sum(correct):5}/{len(correct):5}", - flush=True, - ) - - avg_correct = [] - - for subject, correct in all_correct.items(): - avg_correct += correct - - if torch.distributed.get_rank() == 0: - print("{:48} | {:5} | {:11}".format("-" * 48, "-" * 5, "-" * 11), flush=True) 
- print( - "{:48} | {:.3f} | {:5}/{:5}".format( - "average", sum(avg_correct) / len(avg_correct), sum(avg_correct), len(avg_correct) - ), - flush=True, - ) - - return sum(avg_correct) / len(avg_correct) + # Load all subjects in two dataset calls instead of 2x num_subjects calls. + # The "all" config includes a "subject" field for per-subject reporting. + test_dataset = load_dataset("cais/mmlu", "all", split="test") + dev_dataset = load_dataset("cais/mmlu", "all", split="dev") if few_shots > 0 else None + + # Group dev examples by subject for few-shot prompt construction. + dev_by_subject: dict = {} + if dev_dataset is not None: + for ex in dev_dataset: + dev_by_subject.setdefault(ex["subject"], []).append(ex) + + # Collect all examples, tracking subject membership for per-subject reporting. + all_subjects_seen: list[str] = [] + all_prompts: list[str] = [] + all_labels: list[str] = [] + + # Count test examples per subject to apply the fraction cutoff correctly. + subject_counts: dict[str, int] = {} + for ex in test_dataset: + subject_counts[ex["subject"]] = subject_counts.get(ex["subject"], 0) + 1 + + subject_idx: dict[str, int] = {} + for ex in test_dataset: + subj = ex["subject"] + idx = subject_idx.get(subj, 0) + if idx >= fraction * subject_counts[subj]: + continue + subject_idx[subj] = idx + 1 + prompt = _generate_prompt(ex, dev_by_subject.get(subj, []), few_shots=few_shots) + all_prompts.append(prompt) + all_labels.append(_CHOICES[ex["answer"]]) + all_subjects_seen.append(subj) + + # Tokenize all prompts and sort by length to minimise padding waste within batches. + encoded = [tokenizer(p, return_tensors="pt").input_ids[0] for p in all_prompts] + lengths = [e.shape[0] for e in encoded] + order = sorted(range(len(encoded)), key=lambda i: lengths[i], reverse=True) + + sorted_encoded = [encoded[i] for i in order] + sorted_lengths = [lengths[i] for i in order] + + # Run inference in global batches. 
+ predictions: list[str] = [""] * len(encoded) + n_batches = (len(sorted_encoded) + batch_size - 1) // batch_size + pbar = tqdm( + range(0, len(sorted_encoded), batch_size), + total=n_batches, + desc="MMLU", + unit="batch", + disable=not dist.is_master(), + ) + for batch_start in pbar: + batch_enc = sorted_encoded[batch_start : batch_start + batch_size] + batch_len = sorted_lengths[batch_start : batch_start + batch_size] + max_len = max(batch_len) + + # Right-pad to max_len; causal mask means the last real token is unaffected by padding. + padded = torch.zeros(len(batch_enc), max_len, dtype=torch.long) + for i, (e, seq_len) in enumerate(zip(batch_enc, batch_len)): + padded[i, :seq_len] = e + + logits = megatron_prefill(model, padded.cuda()) # [B, max_len, vocab] + + for i, seq_len in enumerate(batch_len): + answer_logits = logits[i, seq_len - 1, choice_ids] + predictions[order[batch_start + i]] = _CHOICES[answer_logits.argmax().item()] + + examples_done = min(batch_start + batch_size, len(sorted_encoded)) + pbar.set_postfix(examples=f"{examples_done}/{len(sorted_encoded)}") + + # Compute per-subject accuracy and overall average. 
+ subject_correct: dict[str, list[bool]] = {} + for pred, label, subj in zip(predictions, all_labels, all_subjects_seen): + subject_correct.setdefault(subj, []).append(pred == label) + + all_correct = [pred == label for pred, label in zip(predictions, all_labels)] + n_total = len(all_correct) + avg = sum(all_correct) / n_total + + print_rank_0("{:48} | (ACC) | Count/Total".format("Subject")) + print_rank_0("{:48} | {:5} | {:11}".format("-" * 48, "-" * 5, "-" * 11)) + for subj in sorted(subject_correct): + correct = subject_correct[subj] + n = len(correct) + print_rank_0(f"{subj:48} | {sum(correct) / n:.3f} | {sum(correct):5}/{n:5}") + print_rank_0("{:48} | {:5} | {:11}".format("-" * 48, "-" * 5, "-" * 11)) + print_rank_0("{:48} | {:.3f} | {:5}/{:5}".format("average", avg, sum(all_correct), n_total)) + + return avg diff --git a/tests/gpu_megatron/torch/utils/plugins/test_utils_megatron.py b/tests/gpu_megatron/torch/utils/plugins/test_utils_megatron.py index 63abe1723f..81fca8ed96 100644 --- a/tests/gpu_megatron/torch/utils/plugins/test_utils_megatron.py +++ b/tests/gpu_megatron/torch/utils/plugins/test_utils_megatron.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- +import pytest from _test_utils.torch.megatron.models import get_mcore_qwen3_600m from _test_utils.torch.megatron.utils import initialize_for_megatron from transformers import AutoTokenizer @@ -22,12 +22,18 @@ SEED = 1234 +# TODO: move to regression test folder -def _test_megatron_generate_and_mmlu(rank, size): - initialize_for_megatron(tensor_model_parallel_size=size, seed=SEED) - - model = get_mcore_qwen3_600m(tensor_model_parallel_size=size).cuda().eval() +def _test_megatron_generate_and_mmlu(rank, size, parallelism): + if parallelism == "tp": + initialize_for_megatron(tensor_model_parallel_size=size, seed=SEED) + model = get_mcore_qwen3_600m(tensor_model_parallel_size=size).cuda().eval() + elif parallelism == "pp": + initialize_for_megatron(pipeline_model_parallel_size=size, seed=SEED) + model = get_mcore_qwen3_600m(pipeline_model_parallel_size=size).cuda().eval() + else: + raise ValueError(f"Invalid parallelism: {parallelism}") tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") messages = [ @@ -42,10 +48,13 @@ def _test_megatron_generate_and_mmlu(rank, size): model_inputs = tokenizer([text], return_tensors="pt").to(device="cuda") output_ids = megatron_generate(model, model_inputs["input_ids"]) output_text = tokenizer.batch_decode(output_ids) - print(output_text) + print(rank, output_text) - assert megatron_mmlu(model, tokenizer) > 0.24 + assert 0.36 < megatron_mmlu(model, tokenizer, fraction=0.1, batch_size=16) < 0.39 -def test_megatron_generate_and_mmlu(dist_workers): - dist_workers.run(_test_megatron_generate_and_mmlu) +@pytest.mark.parametrize("parallelism", ["tp", "pp"]) +def test_megatron_generate_and_mmlu(dist_workers, parallelism, num_gpus): + if num_gpus == 1 and parallelism == "pp": + pytest.skip("Skipping as redundant test on 1 GPU") + dist_workers.run(_test_megatron_generate_and_mmlu, parallelism=parallelism)