diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index 99150778..a2d19bfb 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -114,7 +114,7 @@ jobs: DATASET_PATH: ${{ github.workspace }}/ShareGPT_V3_unfiltered_cleaned_split.json run: | pytest -s -v -rA tests/utils - pytest -s -v -rA tests/testing + pytest -s -v -rA tests/testing -m "not dpp" # -------------------- Save venv cache -------------------- - name: Save Virtualenv diff --git a/aiu_fms_testing_utils/scripts/drive_paged_programs.py b/aiu_fms_testing_utils/scripts/drive_paged_programs.py index 0e343aeb..d1304e88 100644 --- a/aiu_fms_testing_utils/scripts/drive_paged_programs.py +++ b/aiu_fms_testing_utils/scripts/drive_paged_programs.py @@ -1,143 +1,16 @@ import argparse -from dataclasses import dataclass -import datetime -import itertools -import json -import os -from pathlib import Path -import random -import time -from itertools import dropwhile -import re -from typing import Any, Dict, Iterable, List, Literal, NamedTuple, Optional, Tuple -import torch -from fms.models import get_model -from fms.utils.generation import pad_input_ids -from torch import distributed as dist -from torch.fx.experimental import _config as fx_config -from transformers import AutoTokenizer -from aiu_fms_testing_utils.utils.dpp_config import DPPRunnerConfig -from aiu_fms_testing_utils.utils.env_utils import scoped_environ -from aiu_fms_testing_utils.testing.validation import ( - GoldenTokenHook, - LogitsExtractorHook, - ValidationInfo, - capture_level_1_metrics, - extract_validation_information, - filter_failed_level_1_cases, - find_validation_info_path, - get_validation_info_path, - load_validation_information, - top_k_loss_calculator, +from aiu_fms_testing_utils.testing.dpp.run_drive_paged_programs import run_dpp +from aiu_fms_testing_utils.testing.dpp.program_models import ( + AttnType, + DatasetType, + TestType, ) -from aiu_fms_testing_utils.utils import ( - get_pad_size, - sample_rag_factoid_requests, - sample_sharegpt_requests, - stagger_region, - warmup_model, -) -from aiu_fms_testing_utils.utils.aiu_setup import aiu_dist_setup, dprint, local_rank -from aiu_fms_testing_utils.utils.paged import ( - ProgramCriteria, - get_programs_prompts, -) -from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string - -# Constants -PAD_MULTIPLE = 64 - - -@dataclass -class ProgramInfo: - """Encapsulates program execution criteria. - - Attributes: - program_id: Unique identifier for the program being tested. - batch_size_limit: Numeric threshold for batch size constraint. - batch_size_limit_type: Comparison operator for batch size (e.g., ">=", "<=", "=="). - prompt_length_limit: Numeric threshold for prompt length constraint. - prompt_length_limit_type: Comparison operator for prompt length (e.g., ">=", "<=", "=="). - """ - - program_id: str - batch_size_limit: int - batch_size_limit_type: str - prompt_length_limit: int - prompt_length_limit_type: str - - -class EnvConfig(NamedTuple): - """Represents global configuration derived from environment and CLI. - - Attributes: - attn_name: The internal name of the attention algorithm (e.g., 'spyre_paged_attn'). - cpu_dtype: Data type for CPU validation ('fp8' or 'fp32'). - max_batch_size: Maximum batch size. - max_tkv: Maximum total key-value (context) length. - """ - - attn_name: str - cpu_dtype: str - max_batch_size: int - max_tkv: int - - -class MetricResult(NamedTuple): - """Result of comparing AIU and CPU logit distributions. 
- - Attributes: - cross_entropy_loss: Cross-entropy loss between the distributions. - mean_abs_diff: Mean absolute difference of softmax probabilities. - """ - - cross_entropy_loss: float - mean_abs_diff: float - - def __str__(self) -> str: - return f"cross_entropy_loss: {self.cross_entropy_loss:.6f}, mean_abs_diff: {self.mean_abs_diff:.6f}" - - -class PreparedInputs(NamedTuple): - """Represents prepared model inputs from dataset sampling. - - Attributes: - input_ids: Padded tensor of tokenized input IDs with shape (batch_size, seq_length). - extra_kwargs: Dictionary with attention mask and other model inputs. - sample_key: String identifier for the sampled prompts. - """ - - input_ids: torch.Tensor - extra_kwargs: Dict[str, Any] - sample_key: str - - -class ValidPrompt(NamedTuple): - """Represents a valid prompt configuration for program execution. - - Attributes: - program_id: ID of the program this prompt will execute. - shape: Tuple of (batch_size, seq_length) for this prompt. - input_ids: Tokenized and padded input tensor. - extra_kwargs: Dictionary with attention mask and other model inputs. - sample_key: String identifier for the sampled prompts. - """ - - program_id: str - shape: Tuple[int, int] - input_ids: torch.Tensor - extra_kwargs: Dict[str, Any] - sample_key: str +from aiu_fms_testing_utils.utils.model_setup import Timing def parse_cli_args() -> argparse.Namespace: - """ - Initializes the argument parser and parses command-line arguments. - - Returns: - argparse.Namespace: An object containing the parsed arguments. - """ + """Initializes the argument parser and parses command-line arguments.""" parser = argparse.ArgumentParser( description="Script which will drive paged programs for debugging" @@ -161,29 +34,27 @@ def parse_cli_args() -> argparse.Namespace: type=int, default=8, help="set this if you want to change the number of tokens generated per sequence (1 prefill + max_new_tokens-1 decodes). Note: If this value is larger than 64, this may result in switching decode programs mid generation", - ) - parser.add_argument( - "--distributed", - action="store_true", - help="This is a distributed job (multiple instances run with RANK+WORLD_SIZE)", + required=True, ) parser.add_argument( "--model_variant", type=str, default="ibm-ai-platform/micro-g3.3-8b-instruct-1b", help="The model id or path to use for this test. Note: must be a huggingface format", + required=True, ) parser.add_argument( "--timing", type=str, - choices=["e2e", "per-token"], - default="", + choices=["none", "e2e", "per-token"], + default="none", help="if set, how to time the generation of tokens, e2e or per-token", ) parser.add_argument( "--program_criteria_json_path", type=str, help="path to json file containing the program criteria list", + required=True, ) parser.add_argument( "--dataset_path", @@ -196,6 +67,7 @@ def parse_cli_args() -> argparse.Namespace: choices=["rag_factoid", "sharegpt", "custom"], default="sharegpt", help="selects the correct dataset type for sampling. Must be one of rag_factoid or sharegpt or custom", + required=True, ) parser.add_argument( "--test_type", @@ -280,879 +152,9 @@ def parse_cli_args() -> argparse.Namespace: return parser.parse_args() -def _prepare_inputs( - batch_size: int, - seq_length: int, - tokenizer: AutoTokenizer, - sampler, - dataset_path: str, - allow_truncation: bool, - enforce_sizes: List[int] = [], - seed: int = 0, -) -> PreparedInputs: - """Prepares and tokenizes input prompts for model inference. 
- - Samples prompts from a dataset using the provided sampler, tokenizes them, - and pads them to the specified sequence length. Handles cases where fewer - prompts are available than requested by repeating the first prompt. - - Args: - batch_size: Number of prompts to sample for the batch. - seq_length: Target sequence length for padding. - tokenizer: HuggingFace tokenizer for encoding prompts. - sampler: Callable that samples prompts from the dataset. - dataset_path: Path to the dataset file. - allow_truncation: If True, allows truncating prompts longer than seq_length. - enforce_sizes: List of specific sequence lengths to enforce for sampling. - seed: Random seed for reproducible sampling. - - Returns: - Tuple containing: - - input_ids: Padded tensor of tokenized input IDs with shape (batch_size, seq_length). - - extra_kwargs: Dictionary with additional model inputs including attention mask. - - sample_key: String identifier for the sampled prompts. - - Raises: - ValueError: If no valid prompts exist in the dataset for the requested shape. - """ - start = time.time() - prompts_and_sizes, sample_key = sampler( - dataset_path, - batch_size, - tokenizer, - 32, - seq_length * 2 if allow_truncation else seq_length, - seed, - enforce_sizes=enforce_sizes, - truncation=allow_truncation, - return_key=True, - ) - end = time.time() - if local_rank == 0: - dprint(f"extracted prompts in {(end - start):.4f} seconds") - prompt_list = [] - for prompt, size in prompts_and_sizes: - encoded = tokenizer.encode(prompt, return_tensors="pt").squeeze(0) - if size > seq_length: - assert allow_truncation - encoded = encoded[:seq_length] - prompt_list.append(encoded) - - if not prompt_list: - raise ValueError( - f"No valid prompt sample exists in dataset for input shape (Batch Size={batch_size}, Seq Length={seq_length})" - ) - if len(prompt_list) < batch_size: - dprint( - f"You requested {batch_size} prompts but we were only able to get {len(prompt_list)} valid prompts. We will be repeating the first prompt." - ) - prompt_list = [prompt_list[0]] * (batch_size - len(prompt_list)) + prompt_list - - input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length) - extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16) - - return PreparedInputs( - input_ids=input_ids, extra_kwargs=extra_kwargs, sample_key=sample_key - ) - - -def _maybe_prepare_fp8_weights(model: torch.nn.Module, is_fp8: bool) -> None: - """Converts model weights from bfloat16 to float16 for FP8 attention. - - When using FP8 attention variants, this function converts all bfloat16 parameters - to float16. Issues a warning if any parameter values exceed the float16 range, - which may cause accuracy loss. - - Args: - model: PyTorch model whose weights may need conversion. - is_fp8: If True, performs the weight conversion. - """ - if is_fp8: - for name, param in model.named_parameters(): - if param.dtype == torch.bfloat16: - if param.max() > torch.finfo(torch.float16).max: - dprint( - f"[WARNING] You are casting param {name} to fp16, which will cause loss of accuracy. You can ignore this warning if this is intended." - ) - param.data = param.data.to(dtype=torch.float16) - - -def _load_validation_info( - model_variant, - batch_size, - seq_length, - max_new_tokens, - tokenizer, - seed, - cpu_dtype: str, - attn_type: str, - validation_info_outputs_dir: str, - sample_key: str | None = None, -) -> ValidationInfo | None: - """Loads pre-computed CPU validation information from disk if available. 
- - Searches for a previously saved validation info file matching the specified - parameters. If found, loads and returns the validation information to avoid - redundant CPU computation. - - Args: - model_variant: Model identifier or path (HuggingFace format). - batch_size: Batch size used for validation. - seq_length: Sequence length used for validation. - max_new_tokens: Number of tokens to generate during validation. - tokenizer: HuggingFace tokenizer for the model. - seed: Random seed used for validation. - cpu_dtype: Data type string for CPU validation ("fp8" or "fp32"). - attn_type: Attention algorithm type used. - validation_info_outputs_dir: Directory containing saved validation outputs. - sample_key: Optional identifier for the specific prompt sample used. - - Returns: - ValidationInfo object if a matching file is found, None otherwise. - """ - full_path = find_validation_info_path( - validation_info_dir=validation_info_outputs_dir, - model_variant=model_variant, - batch_size=batch_size, - seq_length=seq_length, - max_new_tokens=max_new_tokens, - seed=seed, - attn_type=attn_type, - version_allow_decrement=True, - dtype=cpu_dtype, - sample_key=sample_key, - ) - if full_path is not None: - dprint(f"cpu validation info found for seed={seed} -- loading it") - return load_validation_information(full_path, "logits", batch_size, tokenizer) - else: - return None - - -def parse_program_limit(limit_str: str) -> tuple[int, str | None]: - """Parses a program limit string into a numeric value and comparison operator. - - Accepts either a plain integer (defaults to ">=" for backward compatibility) - or a string with a comparison operator prefix (e.g., ">=10", "<5", "==8"). - - Args: - limit_str: String representation of the limit, either a number or - operator+number (e.g., "10", ">=10", "<5"). - - Returns: - Tuple containing: - - limit_val: The numeric limit value. - - limit_type: The comparison operator string (">=", "<=", "<", ">", "=="). - - Raises: - ValueError: If the limit string format is invalid. - """ - matcher = re.compile(r"^(<|>|<=|>=|==)(\d+)") - - # Default limit to min to maintain backwards compat - try: - limit_type = ">=" - limit_val = int(limit_str) - except ValueError: - limit_type = None - match = matcher.fullmatch(limit_str) - if match is None: - raise ValueError("Program not well formatted, wrong limit type") - limit_type = match.group(1) - limit_val = int(match.group(2)) - return limit_val, limit_type - - -def _metric_calculator(r: torch.Tensor, t: torch.Tensor): - """Calculates cross-entropy and mean absolute difference between logit distributions. - - Args: - r: Reference logits tensor from CPU validation. - t: Test logits tensor from AIU inference. - - Returns: - MetricResult: A named tuple containing the calculated metrics. - """ - cross_entropy_loss = torch.nn.CrossEntropyLoss()( - r, t.softmax(dim=1).to(dtype=torch.float32) - ) - mean_abs_diff = torch.mean( - torch.abs( - r.softmax(dim=1).to(dtype=torch.float32) - - t.softmax(dim=1).to(dtype=torch.float32) - ) - ) - return MetricResult( - cross_entropy_loss=cross_entropy_loss.item(), mean_abs_diff=mean_abs_diff.item() - ) - - -def _get_model_kwargs(model_variant: str) -> Dict[str, Any]: - """Constructs model loading kwargs based on whether variant is a path or ID. - - Determines if the model_variant is a local filesystem path or a HuggingFace - model identifier, and returns the appropriate keyword arguments for model loading. - - Args: - model_variant: Either a local path to model files or a HuggingFace model ID. 
- - Returns: - Dictionary with either "model_path" (for local paths) or "variant" - (for HuggingFace IDs) as the key. - """ - model_kwargs = {} - if os.path.exists(model_variant): - model_kwargs["model_path"] = model_variant - else: - model_kwargs["variant"] = model_variant - - return model_kwargs - - -def _get_distributed_kwargs( - is_distributed: bool, - dist_timeout: str, -) -> Dict[str, Any]: - """Initializes distributed training configuration and returns kwargs. - - Sets up PyTorch distributed process group with tensor parallelism strategy - if distributed mode is enabled. Configures custom timeout if specified. - - Args: - is_distributed: If True, initializes distributed training setup. - dist_timeout: Timeout in minutes for distributed operations (0 uses default). - - Returns: - Dictionary containing distributed configuration with keys: - - "distributed_strategy": Set to "tp" (tensor parallelism) if distributed. - - "group": PyTorch distributed group (WORLD) if distributed. - Returns empty dict if not distributed. - """ - distributed_kwargs = {} - if is_distributed: - if dist_timeout > 0: - # Default timeout: - # https://docs.pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group - dist.init_process_group(timeout=datetime.timedelta(minutes=dist_timeout)) - dprint(f"NOTICE: init_process_group timeout set to {dist_timeout} minutes") - else: - dist.init_process_group() - - aiu_dist_setup(dist.get_rank(), dist.get_world_size()) - distributed_kwargs["distributed_strategy"] = "tp" - distributed_kwargs["group"] = dist.group.WORLD - - return distributed_kwargs - - -def get_sampler(dataset_type: str, dataset_path: str, tokenizer: AutoTokenizer): - """Selects and configures the sampler based on type. - - Returns a sampler function and configuration for the specified dataset type. - - Args: - dataset_type: Type of dataset ("custom", "rag_factoid", or "sharegpt"). - dataset_path: Path to the dataset file or directory. - tokenizer: HuggingFace tokenizer for encoding prompts. - - Returns: - Tuple containing: - - sampler: Callable function for sampling prompts from the dataset. - - allow_truncation: Boolean indicating if prompt truncation is allowed. - - custom_shape: Tuple of (batch_size, max_seq_length) for custom datasets, - None for other dataset types. - - Raises: - ValueError: If dataset_type is not one of the supported types. - SystemExit: If custom dataset path is not a directory or file reading fails. - """ - custom_shape = None - if dataset_type == "custom": - if local_rank == 0: - dprint( - "Using custom prompts from user, programs parameter will be ignored as it will be determined by user prompt" - ) - directory = Path(dataset_path) - if not directory.is_dir(): - dprint("when using a custom dataset, you must provide a directory") - exit() - - result = [] - for fp in directory.iterdir(): - if fp.is_file(): - try: - content = fp.read_text() - result.append( - (content, get_pad_size(len(tokenizer.encode(content)))) - ) - except Exception as e: - print(f"Error while reading {fp} for custom dataset: {e}") - exit() - - custom_shape = (len(result), max([_[1] for _ in result])) - - def _custom_line_sampler(**kwargs): - """Custom sampler for user-provided text files. - - Returns pre-loaded prompts from custom dataset files without - additional sampling logic. Supports optional sample key return. - - Args: - **kwargs: Keyword arguments, supports "return_key" flag. - - Returns: - List of (prompt, padded_size) tuples, or tuple of (list, sample_key) - if return_key=True. 
- """ - return_key = kwargs.get("return_key", False) - sample_key = format_kwargs_to_string(**kwargs) - if return_key: - return result, sample_key - return result - - sampler = _custom_line_sampler - allow_truncation = False - elif dataset_type == "rag_factoid": - sampler = sample_rag_factoid_requests - allow_truncation = False - elif dataset_type == "sharegpt": - sampler = sample_sharegpt_requests - allow_truncation = True - else: - raise ValueError("dataset_type must be one of rag_factoid or sharegpt") - - return sampler, allow_truncation, custom_shape - - -def load_model( - device_type: Literal["cpu", "spyre"], - is_fp8: bool, - model_kwargs: Dict[str, Any], - distributed_kwargs: Dict[str, Any], - stagger_load: int, - model_config: DPPRunnerConfig, -): - """Loads and optionally compiles a model for inference or validation. - - Loads a model with the specified configuration. For Spyre/AIU models, - compiles the model using the sendnn backend with dynamic compilation enabled. - The scoped_environ context manager temporarily sets environment variables - from model_config during compilation to configure the compiler's behavior (e.g., - program criteria, batch sizes, context lengths). - - Args: - device_type: Target device for model execution. Options: - - "cpu": Load on CPU for validation (fp32, no compilation) - - "spyre": Load on CPU, compile for Spyre/AIU execution (fp16, with sendnn compilation) - is_fp8: If True, uses FP8 quantization (dtype=None for auto-detection). - model_kwargs: Dictionary with model loading parameters (variant or path). - distributed_kwargs: Dictionary with distributed training configuration. - stagger_load: Number of concurrent processes allowed during loading (0=unlimited). - model_config: DPPRunnerConfig instance with environment variable updates. - - Returns: - torch.nn.Module: Loaded model in evaluation mode. Spyre models are compiled - with sendnn backend and may have FP8 weight conversion applied. - """ - - if device_type not in ["cpu", "spyre"]: - raise ValueError( - f"device_type must be 'cpu' or 'spyre' for DPP, got '{device_type}'" - ) - - dtype = ( - (torch.float32 if device_type == "cpu" else torch.float16) - if not is_fp8 - else None - ) - - with stagger_region(stagger_load): - model = get_model( - architecture="hf_pretrained", - device_type="cpu", - data_type=dtype, - fused_weights=False, - **model_kwargs, - **distributed_kwargs, - ) - - model.eval() - - if device_type == "spyre": - with scoped_environ(model_config.env_updates()): - # Temporarily set environment variables needed for compile - model.compile(backend="sendnn", options={"sendnn.dynamic": True}) - - _maybe_prepare_fp8_weights(model, is_fp8) - - return model - - -def get_programs_to_test(programs, program_criteria_list) -> list[ProgramInfo]: - """Parses program specifications into ProgramInfo objects for testing. - - Converts command-line program specifications into structured ProgramInfo objects. - Supports three formats: - - Empty list: Tests all programs with any valid prompt. - - "program_id": Tests specific program with any valid prompt. - - "program_id:batch_constraint,prompt_constraint": Tests program with specific constraints. - - Args: - programs: List of program specification strings from command line. - program_criteria_list: List of ProgramCriteria objects defining available programs. - - Returns: - List of ProgramInfo objects representing programs to test with their constraints. 
- """ - programs_to_test = [] - for program_str in programs: - enforce_prompt_split = program_str.split(":") - program_id = enforce_prompt_split[0] - if len(enforce_prompt_split) == 1: - programs_to_test.append( - ProgramInfo(program_id, 0, ">=", 0, ">=") - ) # this will always satisfy - else: - enforce_batch_size, enforce_prompt_length = ( - _ for _ in enforce_prompt_split[1].split(",") - ) - - # Default limit to min to maintain backwards compat - enforce_batch_size_val, enforce_batch_size_type = parse_program_limit( - enforce_batch_size - ) - enforce_prompt_length_val, enforce_prompt_length_type = parse_program_limit( - enforce_prompt_length - ) - - programs_to_test.append( - ProgramInfo( - program_id, - enforce_batch_size_val, - enforce_batch_size_type, - enforce_prompt_length_val, - enforce_prompt_length_type, - ) - ) - - if len(programs_to_test) == 0: - programs_to_test = [ - ProgramInfo(str(p.program_id), 0, ">=", 0, ">=") - for p in program_criteria_list - ] - - return programs_to_test - - -def get_valid_prompts( - program_map, - dataset_path: str, - enforce_homogeneous_prompt_programs: bool, - programs_to_test: List[ProgramInfo], - program_criteria_list: List[ProgramCriteria], - tokenizer: AutoTokenizer, - sampler, - allow_truncation: bool, - custom_shape: Optional[Tuple[int, int]], - pad_multiple: int, -): - """Generator that yields valid prompts matching program criteria and constraints. - - Iterates through programs to test and finds prompts from the dataset that satisfy - the program's batch size and prompt length constraints. For custom datasets, uses - the provided shape directly. For other datasets, samples prompts matching the - program criteria. When enforce_homogeneous_prompt_programs is True, generates - multiple sequence lengths within a batch to ensure all prompts hit the same program. - - Args: - program_map: Dictionary mapping program sequences to valid prompt shapes. - dataset_path: Path to the dataset for sampling prompts. - enforce_homogeneous_prompt_programs: If True, ensures all prompts in a batch - use the same decode program. - programs_to_test: List of ProgramInfo objects specifying programs and constraints. - program_criteria_list: List of ProgramCriteria defining program boundaries. - tokenizer: HuggingFace tokenizer for encoding prompts. - sampler: Callable for sampling prompts from the dataset. - allow_truncation: If True, allows truncating prompts exceeding max length. - custom_shape: Optional tuple of (batch_size, seq_length) for custom datasets. - pad_multiple: Padding granularity for sequence lengths (typically 64). - - Yields: - ValidPrompt: A named tuple containing program_id, shape, input_ids, - extra_kwargs, and sample_key. 
- """ - # select prompts that fit the batch size criteria - if custom_shape: - prompt_found = 0 - for program_criteria_seq, valid_prompt_shapes in program_map.items(): - for valid_prompt_shape in valid_prompt_shapes: - if valid_prompt_shape == custom_shape: - enforce_sizes = [valid_prompt_shape[1]] - input_ids, extra_kwargs, sample_key = _prepare_inputs( - batch_size=valid_prompt_shape[0], - seq_length=valid_prompt_shape[1], - tokenizer=tokenizer, - sampler=sampler, - dataset_path=dataset_path, - allow_truncation=allow_truncation, - enforce_sizes=enforce_sizes, - ) - prompt_found = 1 - yield ValidPrompt( - program_id=program_criteria_seq[0].program_id, - shape=custom_shape, - input_ids=input_ids, - extra_kwargs=extra_kwargs, - sample_key=sample_key, - ) - break - if prompt_found: - break - else: - for program_info in programs_to_test: - program_id = program_info.program_id - - filtered_program_map = program_map - if program_id.isnumeric(): - filtered_program_map = { - k: v - for k, v in program_map.items() - if k[0] == program_criteria_list[int(program_id)] - } - used_keys = set() - # for each program, we need to check if we have a shape that satisfies the --programs request - for program_seq_key, valid_prompt_shapes in filtered_program_map.items(): - # if ? or numeric => we need to check if we have found at least one valid key to stop - if (program_id == "?" or program_id.isnumeric()) and len(used_keys) > 0: - break - # if * => we need to see if we have found the first key to see if we should skip - elif program_id == "*" and program_seq_key[0] in used_keys: - continue - - for valid_prompt_shape in valid_prompt_shapes: - # make sure the criteria for batch limit and prompt limit is satisfied - # eval is safe here because we have limited what type and limit can be before - - batch_check = eval( - f"valid_prompt_shape[0] {program_info.batch_size_limit_type} {program_info.batch_size_limit}" - ) - prompt_check = eval( - f"valid_prompt_shape[1] {program_info.prompt_length_limit_type} {program_info.prompt_length_limit}" - ) - if batch_check and prompt_check: - # when we enforce homogeneous prompt programs, we will cycle through all sizes between the min of a program and the valid prompt sequence length - # if there does not exist enough sequence sizes between this range, we will cycle back to the beginning - # in the event we don't have enough sequences that satisfy the enforce_sizes, we will repeat sequences and warn the user - enforce_sizes = [valid_prompt_shape[1]] - if enforce_homogeneous_prompt_programs: - # this will get the number of bits for the sequence length and shift to get the power of 2 that is less than or equal to the sequence length - tkv_cutoff = 1 << (valid_prompt_shape[1].bit_length() - 1) - possible_seq_lengths = [ - _ - for _ in range( - tkv_cutoff, valid_prompt_shape[1], pad_multiple - ) - ] - # favor sequences that are close to the valid prompt length - possible_seq_lengths.reverse() - enforce_sizes = enforce_sizes + list( - itertools.islice( - itertools.cycle(possible_seq_lengths), - valid_prompt_shape[0] - 1, - ) - ) - try: - input_ids, extra_kwargs, sample_key = _prepare_inputs( - batch_size=valid_prompt_shape[0], - seq_length=valid_prompt_shape[1], - tokenizer=tokenizer, - sampler=sampler, - dataset_path=dataset_path, - allow_truncation=allow_truncation, - enforce_sizes=enforce_sizes, - ) - used_keys.add(program_seq_key[0]) - yield ValidPrompt( - program_id=program_seq_key[0], - shape=valid_prompt_shape, - input_ids=input_ids, - extra_kwargs=extra_kwargs, - 
sample_key=sample_key, - ) - break - except ValueError: - dprint( - f"No valid sample exists in dataset for this input shape {valid_prompt_shape}" - ) - - if len(used_keys) == 0 and local_rank == 0: - dprint( - f"no valid prompt shape was found which would result in program {program_id} that satisfied batch{program_info.batch_size_limit_type}{program_info.batch_size_limit} and prompt_length{program_info.prompt_length_limit_type}{program_info.prompt_length_limit}" - ) - - -def generate_cpu_validation( - model_variant: str, - max_new_tokens: int, - validation_info_outputs_dir: str, - save_validation_info_outputs: bool, - validation_model: torch.nn.Module, - valid_prompt, - input_ids: torch.Tensor, - extra_kwargs: Dict[str, Any], - sample_key: str, - attn_name: str, - cpu_dtype: str, - tokenizer: AutoTokenizer, -) -> ValidationInfo: - """Generates or loads CPU validation information for reference comparison. - - Attempts to load pre-computed CPU validation data from disk. If not found, - runs CPU inference to generate reference outputs (tokens and logits). - Optionally saves the validation info for future use. - - Args: - model_variant: Model identifier or path. - max_new_tokens: Maximum number of tokens to generate. - validation_info_outputs_dir: Directory for validation info outputs. - save_validation_info_outputs: Whether to save validation info to disk. - validation_model: CPU model for generating validation data. - valid_prompt: Tuple of (batch_size, seq_length) for the prompt shape. - input_ids: Tokenized input tensor. - extra_kwargs: Dictionary with attention mask and other model inputs. - sample_key: String identifier for the sampled prompts. - attn_name: Name of the attention algorithm used. - cpu_dtype: Data type string for CPU validation ("fp8" or "fp32"). - tokenizer: HuggingFace tokenizer for the model. - - Returns: - ValidationInfo: ValidationInfo object containing CPU reference outputs - (tokens and logits). - """ - # attempt to load the cpu validation info if it is already computed - cpu_validation_info = _load_validation_info( - model_variant=model_variant, - batch_size=valid_prompt[0], - seq_length=valid_prompt[1], - max_new_tokens=max_new_tokens, - tokenizer=tokenizer, - seed=0, - cpu_dtype=cpu_dtype, - attn_type=attn_name, - validation_info_outputs_dir=validation_info_outputs_dir, - sample_key=sample_key, - ) - if cpu_validation_info is None: - cpu_validation_info = extract_validation_information( - model=validation_model, - input_ids=input_ids, - max_new_tokens=max_new_tokens, - post_iteration_hook=LogitsExtractorHook(), - attn_algorithm="math", - **extra_kwargs, - ) - if save_validation_info_outputs: - cpu_validation_info.save( - get_validation_info_path( - validation_info_dir=validation_info_outputs_dir, - model_variant=model_variant, - batch_size=valid_prompt[0], - seq_length=valid_prompt[1], - max_new_tokens=max_new_tokens, - seed=0, - attn_type=attn_name, - dtype=cpu_dtype, - sample_key=sample_key, - ) - ) - - return cpu_validation_info - - -def generate_aiu_validation( - test_type: str, - max_new_tokens: int, - timing: str, - prefill_chunk_size: int, - model: torch.nn.Module, - input_ids: torch.Tensor, - cpu_validation_info: Optional[ValidationInfo], - extra_kwargs: Dict[str, Any], -) -> ValidationInfo: - """Generates AIU validation information by running inference on the compiled model. - - Executes the AIU-compiled model to generate tokens and extract logits. 
If CPU - validation info is available and test_type is "metrics", injects golden tokens - from CPU validation to ensure consistent decode paths for metric comparison. - - Args: - test_type: Type of test being run ("metrics" or "tokens"). - max_new_tokens: Maximum number of tokens to generate. - timing: Whether to collect timing information. - prefill_chunk_size: Chunk size for prefill operations. - model: Compiled AIU model for inference. - input_ids: Tokenized input tensor. - cpu_validation_info: Optional CPU validation data for golden token injection. - extra_kwargs: Dictionary with attention mask and other model inputs. - - Returns: - ValidationInfo: ValidationInfo object containing AIU outputs (tokens, logits, - and optional timing information). - """ - golden_hook = None - if test_type == "metrics" and cpu_validation_info: - golden_hook = GoldenTokenHook(cpu_validation_info.get_info("tokens")) - - aiu_validation_info = extract_validation_information( - model=model, - input_ids=input_ids, - max_new_tokens=max_new_tokens, - post_iteration_hook=golden_hook, - last_n_tokens=64, - timing=timing, - prefill_chunk_size=prefill_chunk_size, - **extra_kwargs, - ) - - return aiu_validation_info - - -def evaluate_cross_entropy_metrics( - cross_entropy_threshold: float, - aiu_validation_info: ValidationInfo, - cpu_validation_info: ValidationInfo, - program_id: str, - prompt_shape: Tuple[int, int], - tokenizer: AutoTokenizer, -) -> float: - """Evaluates cross-entropy metrics between AIU and CPU outputs. - - Computes cross-entropy and mean difference metrics between AIU and CPU logits - for each generated token. Prints detailed comparison including token IDs and - decoded strings. Calculates failure rate based on cross-entropy threshold. - - Args: - cross_entropy_threshold: Maximum acceptable cross-entropy for a passing token. - aiu_validation_info: ValidationInfo from AIU inference. - cpu_validation_info: ValidationInfo from CPU reference. - program_id: ID of the program being tested. - prompt_shape: Tuple of (batch_size, seq_length). - tokenizer: HuggingFace tokenizer for decoding tokens. - - Returns: - float: Failure rate (number of failed tokens / total tokens). 
- """ - level_1_metrics = capture_level_1_metrics( - cpu_validation_info.get_info("logits"), - aiu_validation_info.get_info("logits"), - top_k_loss_calculator(20, _metric_calculator), - ) - - if local_rank == 0: - cpu_tokens = cpu_validation_info.get_info("tokens") - for sentence_idx, token_idx, metrics_value in level_1_metrics: - aiu_token = torch.argmax( - aiu_validation_info.get_info("logits")[sentence_idx][token_idx], dim=-1 - ) - cpu_token = cpu_tokens[sentence_idx][prompt_shape[1] + token_idx] - aiu_str = tokenizer.decode(aiu_token).replace( - "\n", "" - ) # remove newlines for readability - cpu_str = tokenizer.decode(cpu_token).replace( - "\n", "" - ) # remove newlines for readability - dprint( - f'For Program {program_id} in sentence {sentence_idx + 1}: the metric for token {token_idx} is {metrics_value}, AIU ID="{aiu_token.item()}" | STR="{aiu_str}" -- CPU ID="{cpu_token.item()}" | CPU STR="{cpu_str}"' - ) - - ce_fail_responses = filter_failed_level_1_cases( - level_1_metrics, lambda m: m[0] >= cross_entropy_threshold - ) - failure_rate = len(ce_fail_responses) / len(level_1_metrics) - - return failure_rate - - -def report_token_comparison( - max_new_tokens: int, - aiu_validation_info: ValidationInfo, - cpu_validation_info: ValidationInfo, - program_id: str, - tokenizer: AutoTokenizer, -) -> None: - """Reports side-by-side comparison of AIU and CPU generated token sequences. - - Prints detailed comparison of generated tokens between AIU and CPU models, - including the original prompt, token IDs, and decoded text. Only executes - on rank 0 in distributed settings. Used for qualitative analysis of model - outputs rather than quantitative metrics. - - Args: - max_new_tokens: Number of tokens generated after the prompt. - aiu_validation_info: ValidationInfo from AIU inference. - cpu_validation_info: ValidationInfo from CPU reference. - program_id: ID of the program being tested. - tokenizer: HuggingFace tokenizer for decoding tokens. - """ - if local_rank != 0: - return - - for sentence_idx, (reference_sentence, test_sentence) in enumerate( - zip( - cpu_validation_info.get_info("tokens"), - aiu_validation_info.get_info("tokens"), - ) - ): - tokens_prompt = [t.item() for t in reference_sentence[:-max_new_tokens]] - cpu_tokens_generated = [t.item() for t in reference_sentence[-max_new_tokens:]] - aiu_tokens_generated = [t.item() for t in test_sentence[-max_new_tokens:]] - tokens_prompt_without_pad = list( - dropwhile(lambda x: x == tokenizer.pad_token_id, tokens_prompt) - ) - prompt_length = len([token_id for token_id in tokens_prompt_without_pad]) - dprint(f"Prompt Length: {prompt_length}") - dprint(f"For Program {program_id} in sentence {sentence_idx + 1}:") - dprint(f"Prompt:\n{tokenizer.decode(tokens_prompt_without_pad)}") - dprint(f"CPU tokens:\n{cpu_tokens_generated}") - dprint(f"AIU tokens:\n{aiu_tokens_generated}") - dprint(f"CPU output:\n{tokenizer.decode(cpu_tokens_generated)}") - dprint(f"AIU output:\n{tokenizer.decode(aiu_tokens_generated)}") - - -def setup_environment( - program_criteria_json_path: str, attention_type: str -) -> EnvConfig: - """Set up global process state and environment variables. - - Args: - program_criteria_json_path: Path to the JSON file containing program criteria definitions. - attention_type: Type of attention mechanism to use. Must be one of sdpa, paged, math_fp8, paged_fp8. 
- - Returns: - EnvConfig: Immutable configuration containing: - - attn_name: Mapped attention implementation name - - cpu_dtype: Data type for CPU operations ("fp8" or "fp32") - - max_batch_size: Maximum batch size from VLLM_DT_MAX_BATCH_SIZE - - max_tkv: Maximum token-key-value context length from VLLM_DT_MAX_CONTEXT_LEN - - Raises: - SystemExit: If required environment variables VLLM_DT_MAX_CONTEXT_LEN or - VLLM_DT_MAX_BATCH_SIZE are not set. - """ - os.environ["COMPILATION_MODE"] = "offline_decoder" - os.environ["DT_PROG_CRITERIA_FILEPATH"] = program_criteria_json_path - - if ( - "VLLM_DT_MAX_CONTEXT_LEN" not in os.environ - or "VLLM_DT_MAX_BATCH_SIZE" not in os.environ - ): - if local_rank == 0: - dprint("Missing required VLLM environment variables.") - exit(1) - - torch.manual_seed(42) - torch.set_grad_enabled(False) - fx_config.backed_size_oblivious = True +def main() -> None: + # Environment Setup + args = parse_cli_args() attention_map = { "sdpa": "sdpa_causal", @@ -1161,347 +163,36 @@ def setup_environment( "paged_fp8": "spyre_paged_attn_fp8", } - return EnvConfig( - attn_name=attention_map[attention_type], - cpu_dtype="fp8" if "fp8" in attention_type else "fp32", - max_batch_size=int(os.environ["VLLM_DT_MAX_BATCH_SIZE"]), - max_tkv=int(os.environ["VLLM_DT_MAX_CONTEXT_LEN"]), - ) - - -def prepare_test_prompts( - program_criteria_json_path: str, - programs: List[str], - max_new_tokens: int, - prioritize_large_batch_sizes: bool, - enforce_homogeneous_prompt_programs: bool, - max_batch_size: int, - max_tkv: int, - tkv_limit: int | None, - tokenizer: AutoTokenizer, - sampler: Any, - allow_truncation: bool, - custom_shape: Optional[Tuple[int, int]], - dataset_path: str, -): - """Parses program criteria and generates the sequence of valid test prompts. - - This function unrolls the necessary arguments to decouple the logic from - the argparse namespace. 
- """ - - with open(program_criteria_json_path, "r") as f: - program_criteria_json_list = json.load(f)["programs"] - program_criteria_list = [] - for i, d in enumerate(program_criteria_json_list): - program_criteria_list.append( - ProgramCriteria( - i, - d["max_batch"], - d["max_tkv"], - d["batch_granularity"], - d["tkv_granularity"], - ) - ) - - programs_to_test = get_programs_to_test(programs, program_criteria_list) - - # FIXME: filter condition for this on prompt and batch - program_map = get_programs_prompts( - program_criteria_list=program_criteria_list, - multiple=PAD_MULTIPLE, - max_batch_size=max_batch_size, - max_tkv=max_tkv, - program_cycles=max_new_tokens, - tkv_limit=tkv_limit, - prioritize_large_batch_sizes=prioritize_large_batch_sizes, - ) - for v in program_map.values(): - random.Random(42).shuffle(v) - - # Select concrete prompts and program associations - return get_valid_prompts( - program_map=program_map, - dataset_path=dataset_path, - enforce_homogeneous_prompt_programs=enforce_homogeneous_prompt_programs, - programs_to_test=programs_to_test, - program_criteria_list=program_criteria_list, - tokenizer=tokenizer, - sampler=sampler, - allow_truncation=allow_truncation, - custom_shape=custom_shape, - pad_multiple=PAD_MULTIPLE, - ) - - -def generate_validation_info_and_test( - valid_prompts: Iterable[ValidPrompt], - model: torch.nn.Module, - validation_model: Optional[torch.nn.Module], - tokenizer: AutoTokenizer, - env_config: EnvConfig, - model_config: DPPRunnerConfig, - test_type: str, - max_new_tokens: int, - skip_validation: bool, - save_validation_info_outputs: bool, - validation_info_outputs_dir: str, - cross_entropy_threshold: float, - failure_rate_threshold: float, - timing: str, - prefill_chunk_size: int, - model_variant: str, -) -> list[Any]: - """Generates tokens using AIU and CPU models and validates the results. - - This function iterates through prepared prompts, executes the generation - cycle for both hardware targets, and evaluates whether the AIU outputs - match the golden reference. 
- """ - - failed_cases = [] - # for each program and valid prompt (batch size, sequence length) - for valid_prompt in valid_prompts: - valid_prompt.extra_kwargs["attn_name"] = env_config.attn_name - valid_prompt.extra_kwargs["_kvcache_num_blocks_hint"] = model_config.num_blocks - - if local_rank == 0: - dprint(f"*** testing program {valid_prompt.program_id} ***") - dprint( - f"program id: {valid_prompt.program_id}, valid prompt: {valid_prompt.shape}, input shape: {valid_prompt.input_ids.shape}" - ) - - if not skip_validation: - # Generate or load CPU validation info - cpu_validation_info = generate_cpu_validation( - model_variant=model_variant, - max_new_tokens=max_new_tokens, - validation_info_outputs_dir=validation_info_outputs_dir, - save_validation_info_outputs=save_validation_info_outputs, - validation_model=validation_model, - valid_prompt=valid_prompt.shape, - input_ids=valid_prompt.input_ids, - extra_kwargs=valid_prompt.extra_kwargs, - sample_key=valid_prompt.sample_key, - attn_name=env_config.attn_name, - cpu_dtype=env_config.cpu_dtype, - tokenizer=tokenizer, - ) - - aiu_validation_info = generate_aiu_validation( - test_type=test_type, - max_new_tokens=max_new_tokens, - timing=timing, - prefill_chunk_size=prefill_chunk_size, - model=model, - input_ids=valid_prompt.input_ids, - cpu_validation_info=cpu_validation_info, - extra_kwargs=valid_prompt.extra_kwargs, - ) - - if test_type == "metrics": - failure_rate = evaluate_cross_entropy_metrics( - cross_entropy_threshold=cross_entropy_threshold, - aiu_validation_info=aiu_validation_info, - cpu_validation_info=cpu_validation_info, - program_id=valid_prompt.program_id, - prompt_shape=valid_prompt.shape, - tokenizer=tokenizer, - ) - if failure_rate > failure_rate_threshold: - failed_cases.append( - (valid_prompt.program_id, valid_prompt.shape, failure_rate) - ) - - elif test_type == "tokens": - report_token_comparison( - max_new_tokens=max_new_tokens, - aiu_validation_info=aiu_validation_info, - cpu_validation_info=cpu_validation_info, - program_id=valid_prompt.program_id, - tokenizer=tokenizer, - ) - - else: - raise ValueError("test type must be one of metrics or tokens") - else: - aiu_validation_info = generate_aiu_validation( - test_type=test_type, - max_new_tokens=max_new_tokens, - timing=timing, - prefill_chunk_size=prefill_chunk_size, - model=model, - input_ids=valid_prompt.input_ids, - cpu_validation_info=None, - extra_kwargs=valid_prompt.extra_kwargs, - ) - - if local_rank == 0: - for sentence_idx, test_sentence in enumerate( - aiu_validation_info.get_info("tokens") - ): - tokens_prompt = [t.item() for t in test_sentence[:-max_new_tokens]] - aiu_tokens_generated = [ - t.item() for t in test_sentence[-max_new_tokens:] - ] - dprint( - f"For Program {valid_prompt.program_id} in sentence {sentence_idx + 1}:" - ) - dprint(f"Prompt:\n{tokenizer.decode(tokens_prompt)}") - dprint(f"AIU tokens:\n{aiu_tokens_generated}") - dprint(f"AIU output:\n{tokenizer.decode(aiu_tokens_generated)}") - - return failed_cases - - -def main() -> None: - """Main execution function for driving paged program validation tests. - - Workflow: - 1. Sets and configures environment. - 2. Loads models (both AIU-compiled and CPU validation). - 3. Warms up the model. - 4. Selects programs and prompts to test based on criteria. - 5. For each program/prompt combination: - - Generates CPU validation data (or loads from cache). - - Runs AIU inference. - - Compares outputs using metrics or token-based validation. - 6. Prints results and failure cases. 
- - Raises: - Various exceptions may be raised during: - - Model loading (e.g., OOM, invalid model variant) - - Distributed initialization (e.g., timeout, network issues) - - File I/O (e.g., missing program criteria JSON) - - Validation (e.g., shape mismatches) - """ - - # Environment Setup - args = parse_cli_args() - is_fp8: bool = "fp8" in args.attention_type - if args.skip_validation and args.test_type == "metrics": - dprint("When skipping validation, only test_type will be ignored") - env_config: EnvConfig = setup_environment( - program_criteria_json_path=args.program_criteria_json_path, - attention_type=args.attention_type, - ) - tokenizer = AutoTokenizer.from_pretrained(args.model_variant) - sampler, allow_truncation, custom_shape = get_sampler( - dataset_type=args.dataset_type, - dataset_path=args.dataset_path, - tokenizer=tokenizer, - ) - - # Model Loading - model_kwargs: Dict[str, Any] = _get_model_kwargs(model_variant=args.model_variant) - distributed_kwargs: Dict[str, Any] = _get_distributed_kwargs( - is_distributed=args.distributed, dist_timeout=args.dist_timeout - ) - args.save_validation_info_outputs = args.save_validation_info_outputs and ( - dist.get_rank() == 0 - ) - model_config: DPPRunnerConfig = DPPRunnerConfig() - world_size = ( - dist.get_world_size() if args.distributed and dist.is_initialized() else 1 - ) - model_config.setup_config( - model_variant=args.model_variant, - use_distributed=args.distributed, - world_size=world_size, - prefill_chunk_size=args.prefill_chunk_size, - ) - model = load_model( - device_type="spyre", - is_fp8=is_fp8, - model_kwargs=model_kwargs, - distributed_kwargs=distributed_kwargs, - stagger_load=args.stagger_load, - model_config=model_config, - ) - validation_model = None - if not args.skip_validation: - validation_model = load_model( - device_type="cpu", - is_fp8=is_fp8, - model_kwargs=model_kwargs, - distributed_kwargs=distributed_kwargs, - stagger_load=args.stagger_load, - model_config=model_config, + try: + attention_type = attention_map[args.attention_type] + except KeyError: + raise ValueError( + f"Invalid attention type: {args.attention_type}. Expected one of {list(attention_map.keys())}." 
) - # Model Warmup - ## warmup with any input so compiler produces criteria json - ## TODO: Swap this with _prepare_inputs once fix for shape_id is available - ## input_ids, extra_kwargs, sample_key = _prepare_inputs(2, max_tkv, tokenizer) - prompt_list = [torch.arange(0, PAD_MULTIPLE, dtype=torch.int64)] - # matching vllm warmup to pad to 2 on fp8, and no pad for fp16 - if is_fp8: - prompt_list = prompt_list * 2 - input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=64) - extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16) - extra_kwargs["attn_name"] = env_config.attn_name - extra_kwargs["_kvcache_num_blocks_hint"] = model_config.num_blocks - warmup_model( - model=model, - input_ids=input_ids, - max_new_tokens=args.max_new_tokens, - compile_dynamic_sendnn=True, - stagger_update_lazyhandle=args.stagger_update_lazyhandle, - prefill_chunk_size=args.prefill_chunk_size, - **extra_kwargs, - ) - if args.distributed: - # wait for rank0 to be finished as it is the only one generating the criteria json - # this is needed since otherwise we may run into a race condition - torch.distributed.barrier() - - # Prompt Preparation - valid_prompts = prepare_test_prompts( + run_dpp( program_criteria_json_path=args.program_criteria_json_path, - programs=args.programs, + dataset_type=DatasetType(args.dataset_type), max_new_tokens=args.max_new_tokens, - prioritize_large_batch_sizes=args.prioritize_large_batch_sizes, - enforce_homogeneous_prompt_programs=args.enforce_homogeneous_prompt_programs, - max_batch_size=env_config.max_batch_size, - max_tkv=env_config.max_tkv, - tkv_limit=model_config.tkv_limit, - tokenizer=tokenizer, - sampler=sampler, - allow_truncation=allow_truncation, - custom_shape=custom_shape, + model_variant=args.model_variant, dataset_path=args.dataset_path, - ) - - # Validation and Testing - failed_cases = generate_validation_info_and_test( - valid_prompts=valid_prompts, - model=model, - validation_model=validation_model, - tokenizer=tokenizer, - env_config=env_config, - model_config=model_config, - test_type=args.test_type, - max_new_tokens=args.max_new_tokens, - skip_validation=args.skip_validation, - save_validation_info_outputs=args.save_validation_info_outputs, - validation_info_outputs_dir=args.validation_info_outputs_dir, + programs=args.programs, + timing=Timing(args.timing), + test_type=TestType(args.test_type), cross_entropy_threshold=args.cross_entropy_threshold, failure_rate_threshold=args.failure_rate_threshold, - timing=args.timing, + attention_type=AttnType(attention_type), prefill_chunk_size=args.prefill_chunk_size, - model_variant=args.model_variant, + stagger_load=args.stagger_load, + stagger_update_lazyhandle=args.stagger_update_lazyhandle, + dist_timeout=args.dist_timeout, + run_cpu_validation=not args.skip_validation, + validation_info_outputs_dir=args.validation_info_outputs_dir, + save_validation_info_outputs=args.save_validation_info_outputs, + prioritize_large_batch_sizes=args.prioritize_large_batch_sizes, + enforce_homogeneous_prompt_programs=args.enforce_homogeneous_prompt_programs, ) - if not args.skip_validation and local_rank == 0: - if len(failed_cases) != 0: - dprint("The test failed with the following cases:") - for failed_case in failed_cases: - dprint( - f"Program ID: {failed_case[0]}, Prompt Shape: {failed_case[1]}, Failure Rate: {failed_case[2]}" - ) - else: - dprint("all tests passed") - if __name__ == "__main__": main() diff --git a/aiu_fms_testing_utils/scripts/generate_metrics.py 
b/aiu_fms_testing_utils/scripts/generate_metrics.py index f65149fa..df11aae6 100644 --- a/aiu_fms_testing_utils/scripts/generate_metrics.py +++ b/aiu_fms_testing_utils/scripts/generate_metrics.py @@ -19,6 +19,7 @@ from fms.models import get_model from fms.utils.generation import pad_input_ids from transformers import AutoTokenizer +from aiu_fms_testing_utils.testing.dpp.program_models import AttnType parser = argparse.ArgumentParser( description="Script to determine a reasonable logits loss threshold when testing with aiu" @@ -244,7 +245,7 @@ def write_csv(metrics, path, metric_name): ids, args.max_new_tokens, LogitsExtractorHook(), - attn_algorithm="math", + attn_algorithm=AttnType.MATH, **padding_kwargs, ) cpu_static_tokens = cpu_validation_info.get_info("tokens") @@ -324,7 +325,7 @@ def write_csv(metrics, path, metric_name): ids, args.max_new_tokens, LogitsExtractorHook(), - attn_algorithm="math", + attn_algorithm=AttnType.MATH, **padding_kwargs, ) diff --git a/aiu_fms_testing_utils/scripts/save_cpu_data.py b/aiu_fms_testing_utils/scripts/save_cpu_data.py index 0df0ecba..15420934 100644 --- a/aiu_fms_testing_utils/scripts/save_cpu_data.py +++ b/aiu_fms_testing_utils/scripts/save_cpu_data.py @@ -1,4 +1,5 @@ import json +from aiu_fms_testing_utils.testing.dpp.program_models import AttnType from aiu_fms_testing_utils.testing.validation import ( LogitsExtractorHook, extract_validation_information, @@ -96,7 +97,7 @@ def process_row(row): torch.tensor(input_ids).unsqueeze(0), max_new_tokens, LogitsExtractorHook(), - attn_algorithm="math", + attn_algorithm=AttnType.MATH, ) return {"id": id, "input_ids": input_ids, "validation": cpu_validation_info} diff --git a/aiu_fms_testing_utils/scripts/validation.py b/aiu_fms_testing_utils/scripts/validation.py index ce55953d..c55155ce 100644 --- a/aiu_fms_testing_utils/scripts/validation.py +++ b/aiu_fms_testing_utils/scripts/validation.py @@ -14,6 +14,7 @@ from fms.utils import generation from fms.utils.generation import pad_input_ids from torch import distributed as dist +from aiu_fms_testing_utils.testing.dpp.program_models import AttnType from aiu_fms_testing_utils.utils import warmup_model from aiu_fms_testing_utils.testing.validation import ( LogitsExtractorHook, @@ -690,7 +691,7 @@ def print_result(result, result_idx: int = 0, file_prefix: str = ""): ids.to(validation_device), args.max_new_tokens, LogitsExtractorHook(), - attn_algorithm="math", + attn_algorithm=AttnType.MATH, **padding_kwargs, ) diff --git a/aiu_fms_testing_utils/testing/dpp/generation.py b/aiu_fms_testing_utils/testing/dpp/generation.py new file mode 100644 index 00000000..3eb9de36 --- /dev/null +++ b/aiu_fms_testing_utils/testing/dpp/generation.py @@ -0,0 +1,322 @@ +from typing import Any, Iterable, Optional + +import torch +from transformers import AutoTokenizer + +from aiu_fms_testing_utils.testing.dpp.metrics_validation import ( + evaluate_cross_entropy_metrics, + evaluate_token_accuracy, +) +from aiu_fms_testing_utils.testing.dpp.program_models import ( + AttnType, + EnvConfig, + TestType, + ValidPrompt, +) +from aiu_fms_testing_utils.testing.validation import ( + GoldenTokenHook, + LogitsExtractorHook, + ValidationInfo, + extract_validation_information, + find_validation_info_path, + get_validation_info_path, + load_validation_information, +) +from aiu_fms_testing_utils.utils.aiu_setup import dprint, local_rank, r0dprint +from aiu_fms_testing_utils.utils.dpp_config import DPPRunnerConfig +from aiu_fms_testing_utils.utils.model_setup import Timing + + +def 
_generate_aiu_validation( + test_type: TestType, + max_new_tokens: int, + timing: Timing, + prefill_chunk_size: int, + model: torch.nn.Module, + valid_prompt: ValidPrompt, + cpu_validation_info: Optional[ValidationInfo] = None, +) -> ValidationInfo: + """Generates AIU validation information by running inference on the compiled model. + + Executes the AIU-compiled model to generate tokens and extract logits. If CPU + validation info is available and test_type is "metrics", injects golden tokens + from CPU validation to ensure consistent decode paths for metric comparison. + + Args: + test_type: Type of test being run. + max_new_tokens: Maximum number of tokens to generate. + timing: Whether to collect timing information. + prefill_chunk_size: Chunk size for prefill operations. + model: Compiled AIU model for inference. + valid_prompt: ValidPrompt object containing prompt input IDs and extra kwargs for model execution. + cpu_validation_info: Optional CPU validation data for golden token injection. + + Returns: + ValidationInfo: ValidationInfo object containing AIU outputs (tokens, logits, + and optional timing information).""" + + golden_hook = None + if test_type == TestType.METRICS and cpu_validation_info is not None: + golden_hook = GoldenTokenHook(cpu_validation_info.get_info("tokens")) + + aiu_validation_info = extract_validation_information( + model=model, + input_ids=valid_prompt.input_ids, + max_new_tokens=max_new_tokens, + post_iteration_hook=golden_hook, + last_n_tokens=64, + timing=timing, + prefill_chunk_size=prefill_chunk_size, + **valid_prompt.extra_kwargs, + ) + + return aiu_validation_info + + +def _generate_cpu_validation( + model_variant: str, + max_new_tokens: int, + validation_info_outputs_dir: str, + save_validation_info_outputs: bool, + validation_model: torch.nn.Module, + valid_prompt: ValidPrompt, + env_config: EnvConfig, + tokenizer: AutoTokenizer, +) -> ValidationInfo: + """Generates or loads CPU validation information for reference comparison. + + Attempts to load pre-computed CPU validation data from disk. If not found, + runs CPU inference to generate reference outputs (tokens and logits). + Optionally saves the validation info for future use. + + Args: + model_variant: Model identifier or path. + max_new_tokens: Maximum number of tokens to generate. + validation_info_outputs_dir: Directory for validation info outputs. + save_validation_info_outputs: Whether to save validation info to disk. + validation_model: CPU model for generating validation data. + valid_prompt: ValidPrompt object containing prompt input IDs and extra kwargs for model execution. + env_config: Environment configuration with attention settings and CPU dtype. + tokenizer: HuggingFace tokenizer for the model. 
+ + Returns: + ValidationInfo: ValidationInfo object containing CPU reference outputs (tokens and logits).""" + + # attempt to load the cpu validation info if it is already computed + cpu_validation_path = find_validation_info_path( + validation_info_dir=validation_info_outputs_dir, + model_variant=model_variant, + batch_size=valid_prompt.shape[0], + seq_length=valid_prompt.shape[1], + max_new_tokens=max_new_tokens, + seed=0, + attn_type=env_config.attn_type, + version_allow_decrement=True, + dtype=env_config.cpu_dtype, + sample_key=valid_prompt.sample_key, + ) + + if cpu_validation_path is not None: + # Skip CPU generation if validation info is already available + dprint( + f"Loaded CPU validation info for program {model_variant} with prompt shape {valid_prompt.shape} and sample key {valid_prompt.sample_key}" + ) + return load_validation_information( + cpu_validation_path, "logits", valid_prompt.shape[0], tokenizer + ) + + dprint( + f"No pre-computed CPU validation info found for program {model_variant} with prompt shape {valid_prompt.shape} and sample key {valid_prompt.sample_key}" + ) + + cpu_validation_info = extract_validation_information( + model=validation_model, + input_ids=valid_prompt.input_ids, + max_new_tokens=max_new_tokens, + post_iteration_hook=LogitsExtractorHook(), + attn_algorithm=AttnType.MATH, + **valid_prompt.extra_kwargs, + ) + + if save_validation_info_outputs: + validation_info_path = get_validation_info_path( + validation_info_dir=validation_info_outputs_dir, + model_variant=model_variant, + batch_size=valid_prompt.shape[0], + seq_length=valid_prompt.shape[1], + max_new_tokens=max_new_tokens, + seed=0, + attn_type=env_config.attn_type, + dtype=env_config.cpu_dtype, + sample_key=valid_prompt.sample_key, + ) + cpu_validation_info.save(validation_info_path) + + return cpu_validation_info + + +def generate_aiu_cpu_test( + valid_prompts: Iterable[ValidPrompt], + model: torch.nn.Module, + validation_model: torch.nn.Module, + tokenizer: AutoTokenizer, + env_config: EnvConfig, + model_config: DPPRunnerConfig, + test_type: TestType, + max_new_tokens: int, + save_validation_info_outputs: bool, + validation_info_outputs_dir: str, + cross_entropy_threshold: float, + failure_rate_threshold: float, + timing: Timing, + prefill_chunk_size: int, + model_variant: str, +) -> list[Any]: + """Generates tokens using AIU and CPU models and validates the results. + + This function iterates through prepared prompts, executes the generation + cycle for both hardware targets, and evaluates whether the AIU outputs + match the golden reference. + + Args: + valid_prompts: Iterable of ValidPrompt objects containing input prompts and metadata. + model: Compiled AIU model for inference. + validation_model: CPU model for generating reference validation data. + tokenizer: HuggingFace tokenizer for decoding token outputs. + env_config: Environment configuration with attention settings. + model_config: Model configuration with architecture details. + test_type: Type of test being run. + max_new_tokens: Maximum number of tokens to generate. + save_validation_info_outputs: Whether to save CPU validation info to disk. + validation_info_outputs_dir: Directory for saving/loading CPU validation info. + cross_entropy_threshold: Threshold for cross-entropy difference to consider a token generation as failed. + failure_rate_threshold: Threshold for the failure rate to consider the test case as failed. + timing: Whether to collect timing information. + prefill_chunk_size: Chunk size for prefill operations. 
+        model_variant: Model identifier or path for naming validation info files.
+
+    Returns:
+        List of failed cases with program ID, prompt shape, and failure rate."""
+
+    failed_cases = []
+    # for each program and valid prompt (batch size, sequence length)
+    for valid_prompt in valid_prompts:
+        valid_prompt.extra_kwargs["attn_name"] = env_config.attn_type.value
+        valid_prompt.extra_kwargs["_kvcache_num_blocks_hint"] = model_config.num_blocks
+
+        r0dprint(f"*** testing program {valid_prompt.program_id} ***")
+        r0dprint(
+            f"program id: {valid_prompt.program_id}, valid prompt: {valid_prompt.shape}, input shape: {valid_prompt.input_ids.shape}"
+        )
+
+        # Generate or load CPU validation info
+        cpu_validation_info = _generate_cpu_validation(
+            model_variant,
+            max_new_tokens,
+            validation_info_outputs_dir,
+            save_validation_info_outputs,
+            validation_model,
+            valid_prompt,
+            env_config,
+            tokenizer,
+        )
+
+        aiu_validation_info = _generate_aiu_validation(
+            test_type,
+            max_new_tokens,
+            timing,
+            prefill_chunk_size,
+            model,
+            valid_prompt,
+            cpu_validation_info=cpu_validation_info,
+        )
+
+        if test_type == TestType.METRICS:
+            failure_rate = evaluate_cross_entropy_metrics(
+                cross_entropy_threshold,
+                aiu_validation_info,
+                cpu_validation_info,
+                valid_prompt.program_id,
+                valid_prompt.shape,
+                tokenizer,
+            )
+            if failure_rate > failure_rate_threshold:
+                failed_cases.append(
+                    (valid_prompt.program_id, valid_prompt.shape, failure_rate)
+                )
+
+        elif test_type == TestType.TOKENS:
+            failure_rate = evaluate_token_accuracy(
+                max_new_tokens,
+                aiu_validation_info,
+                cpu_validation_info,
+                valid_prompt.program_id,
+                tokenizer,
+            )
+            if failure_rate > failure_rate_threshold:
+                failed_cases.append(
+                    (valid_prompt.program_id, valid_prompt.shape, failure_rate)
+                )
+
+    return failed_cases
+
+
+def generate_aiu_test(
+    valid_prompts: Iterable[ValidPrompt],
+    model: torch.nn.Module,
+    tokenizer: AutoTokenizer,
+    env_config: EnvConfig,
+    model_config: DPPRunnerConfig,
+    test_type: TestType,
+    max_new_tokens: int,
+    timing: Timing,
+    prefill_chunk_size: int,
+) -> None:
+    """Generates tokens using the AIU model and prints the outputs.
+
+    Args:
+        valid_prompts: Iterable of ValidPrompt objects containing input prompts and metadata.
+        model: Compiled AIU model for inference.
+        tokenizer: HuggingFace tokenizer for decoding token outputs.
+        env_config: Environment configuration with attention settings.
+        model_config: Model configuration with architecture details.
+        test_type: Type of test being run.
+        max_new_tokens: Maximum number of tokens to generate.
+        timing: Whether to collect timing information.
+ prefill_chunk_size: Chunk size for prefill operations.""" + + # for each program and valid prompt (batch size, sequence length) + for valid_prompt in valid_prompts: + valid_prompt.extra_kwargs["attn_name"] = env_config.attn_type.value + valid_prompt.extra_kwargs["_kvcache_num_blocks_hint"] = model_config.num_blocks + + r0dprint(f"*** testing program {valid_prompt.program_id} ***") + r0dprint( + f"program id: {valid_prompt.program_id}, valid prompt: {valid_prompt.shape}, input shape: {valid_prompt.input_ids.shape}" + ) + + aiu_tokens = _generate_aiu_validation( + test_type, + max_new_tokens, + timing, + prefill_chunk_size, + model, + valid_prompt, + ).get_info("tokens") + + if local_rank != 0: + return + + for sentence_idx, test_sentence in enumerate(aiu_tokens): + tokens_prompt = [t.item() for t in test_sentence[:-max_new_tokens]] + aiu_tokens_generated = [t.item() for t in test_sentence[-max_new_tokens:]] + dprint( + f"For Program {valid_prompt.program_id} in sentence {sentence_idx + 1}:" + ) + dprint(f"Prompt:\n{tokenizer.decode(tokens_prompt)}") + dprint(f"AIU tokens:\n{aiu_tokens_generated}") + dprint(f"AIU output:\n{tokenizer.decode(aiu_tokens_generated)}") diff --git a/aiu_fms_testing_utils/testing/dpp/metrics_validation.py b/aiu_fms_testing_utils/testing/dpp/metrics_validation.py new file mode 100644 index 00000000..1c68e37e --- /dev/null +++ b/aiu_fms_testing_utils/testing/dpp/metrics_validation.py @@ -0,0 +1,162 @@ +from typing import Tuple + +import torch +from transformers import AutoTokenizer + +from aiu_fms_testing_utils.testing.dpp.program_models import MetricResult +from aiu_fms_testing_utils.testing.validation import ( + ValidationInfo, + capture_level_1_metrics, + filter_failed_level_1_cases, + top_k_loss_calculator, +) +from aiu_fms_testing_utils.utils.aiu_setup import local_rank, r0dprint + + +def _metric_calculator(reference_tensor: torch.Tensor, test_tensor: torch.Tensor): + """Calculates cross-entropy and mean absolute difference between logit distributions. + + Args: + reference_tensor: Reference logits tensor from CPU validation. + test_tensor: Test logits tensor from AIU inference. + + Returns: + MetricResult: A named tuple containing the calculated metrics.""" + + cross_entropy_loss = torch.nn.CrossEntropyLoss()( + reference_tensor, test_tensor.softmax(dim=1).to(dtype=torch.float32) + ) + mean_abs_diff = torch.mean( + torch.abs( + reference_tensor.softmax(dim=1).to(dtype=torch.float32) + - test_tensor.softmax(dim=1).to(dtype=torch.float32) + ) + ) + return MetricResult( + cross_entropy_loss=cross_entropy_loss.item(), mean_abs_diff=mean_abs_diff.item() + ) + + +def evaluate_cross_entropy_metrics( + cross_entropy_threshold: float, + aiu_validation_info: ValidationInfo, + cpu_validation_info: ValidationInfo, + program_id: str, + prompt_shape: Tuple[int, int], + tokenizer: AutoTokenizer, +) -> float: + """Evaluates cross-entropy metrics between AIU and CPU outputs. + + Computes cross-entropy and mean difference metrics between AIU and CPU logits + for each generated token. Prints detailed comparison including token IDs and + decoded strings. Calculates failure rate based on cross-entropy threshold. + + Args: + cross_entropy_threshold: Maximum acceptable cross-entropy for a passing token. + aiu_validation_info: ValidationInfo from AIU inference. + cpu_validation_info: ValidationInfo from CPU reference. + program_id: ID of the program being tested. + prompt_shape: Tuple of (batch_size, seq_length). + tokenizer: HuggingFace tokenizer for decoding tokens. 
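
The level-1 metric above combines a cross-entropy term (reference logits scored against the AIU softmax distribution) with the mean absolute difference of the two softmax distributions. The same arithmetic can be checked in isolation on toy tensors (the shapes here are assumptions, not the top-k slices used by the real pipeline):

import torch

torch.manual_seed(0)
top_k, vocab = 8, 32
reference = torch.randn(top_k, vocab)                   # stand-in CPU reference logits
test = reference + 0.01 * torch.randn(top_k, vocab)     # slightly perturbed "AIU" logits

cross_entropy = torch.nn.CrossEntropyLoss()(
    reference, test.softmax(dim=1).to(dtype=torch.float32)
)
mean_abs_diff = torch.mean(
    torch.abs(reference.softmax(dim=1) - test.softmax(dim=1))
)
print(f"cross_entropy_loss: {cross_entropy.item():.6f}, mean_abs_diff: {mean_abs_diff.item():.6f}")
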
+ + Returns: + float: Failure rate (number of failed tokens / total tokens).""" + + level_1_metrics = capture_level_1_metrics( + cpu_validation_info.get_info("logits"), + aiu_validation_info.get_info("logits"), + top_k_loss_calculator(20, _metric_calculator), + ) + + if local_rank == 0: + cpu_tokens = cpu_validation_info.get_info("tokens") + for sentence_idx, token_idx, metrics_value in level_1_metrics: + aiu_token = torch.argmax( + aiu_validation_info.get_info("logits")[sentence_idx][token_idx], dim=-1 + ) + cpu_token = cpu_tokens[sentence_idx][prompt_shape[1] + token_idx] + aiu_str = tokenizer.decode(aiu_token).replace( + "\n", "" + ) # remove newlines for readability + cpu_str = tokenizer.decode(cpu_token).replace( + "\n", "" + ) # remove newlines for readability + r0dprint( + f'For Program {program_id} in sentence {sentence_idx + 1}: the metric for token {token_idx} is {metrics_value}, AIU ID="{aiu_token.item()}" | STR="{aiu_str}" -- CPU ID="{cpu_token.item()}" | CPU STR="{cpu_str}"' + ) + + ce_fail_responses = filter_failed_level_1_cases( + level_1_metrics, lambda m: m[0] >= cross_entropy_threshold + ) + failure_rate = len(ce_fail_responses) / len(level_1_metrics) + + return failure_rate + + +def evaluate_token_accuracy( + max_new_tokens: int, + aiu_validation_info: ValidationInfo, + cpu_validation_info: ValidationInfo, + program_id: str, + tokenizer: AutoTokenizer, +) -> float: + """Reports side-by-side comparison of AIU and CPU generated token sequences. + + Prints detailed comparison of generated tokens between AIU and CPU models, + including the original prompt, token IDs, and decoded text. Only logs results + on rank 0 in distributed settings. Used for qualitative analysis of model + outputs rather than quantitative metrics. + + Args: + max_new_tokens: Number of tokens generated after the prompt. + aiu_validation_info: ValidationInfo from AIU inference. + cpu_validation_info: ValidationInfo from CPU reference. + program_id: ID of the program being tested. + tokenizer: HuggingFace tokenizer for decoding tokens. 
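
Token-level validation reduces to counting positions in the generated suffix where the AIU and CPU sequences disagree. A toy illustration of the failure-rate calculation, assuming plain tensors of token ids:

import torch

max_new_tokens = 4
cpu = torch.tensor([[5, 6, 7, 8, 11, 12, 13, 14]])   # prompt + 4 generated tokens
aiu = torch.tensor([[5, 6, 7, 8, 11, 12, 99, 14]])   # one generated token differs

mismatches = (cpu[:, -max_new_tokens:] != aiu[:, -max_new_tokens:]).sum().item()
total = cpu[:, -max_new_tokens:].numel()
failure_rate = mismatches / total if total else 0.0
print(f"Token Failure Rate: {failure_rate:.2%} ({mismatches}/{total} mismatches)")
# -> Token Failure Rate: 25.00% (1/4 mismatches)
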
+ + Returns: + float: Failure rate (percentage of mismatched tokens).""" + + total_mismatches = 0 + total_tokens = 0 + for sentence_idx, (reference_sentence, test_sentence) in enumerate( + zip( + cpu_validation_info.get_info("tokens"), + aiu_validation_info.get_info("tokens"), + ) + ): + tokens_prompt = reference_sentence[:-max_new_tokens] + cpu_tokens_generated = reference_sentence[-max_new_tokens:] + aiu_tokens_generated = test_sentence[-max_new_tokens:] + + # Calculate token mismatch failure rate + num_mismatches = (cpu_tokens_generated != aiu_tokens_generated).sum().item() + prompt_tokens = cpu_tokens_generated.size(0) + total_mismatches += num_mismatches + total_tokens += prompt_tokens + + if local_rank != 0: + continue # Only print for rank 0 to avoid clutter in distributed settings + + # Remove leading padding tokens using torch operations + pad_mask = tokens_prompt != tokenizer.pad_token_id + first_non_pad = pad_mask.nonzero(as_tuple=True)[0] + if len(first_non_pad) > 0: + tokens_prompt_without_pad = tokens_prompt[first_non_pad[0] :] + else: + tokens_prompt_without_pad = tokens_prompt + prompt_length = tokens_prompt_without_pad.size(0) + + r0dprint(f"Prompt Length: {prompt_length}") + r0dprint(f"For Program {program_id} in sentence {sentence_idx + 1}:") + r0dprint(f"Prompt:\n{tokenizer.decode(tokens_prompt_without_pad)}") + r0dprint(f"CPU tokens:\n{cpu_tokens_generated}") + r0dprint(f"AIU tokens:\n{aiu_tokens_generated}") + r0dprint(f"CPU output:\n{tokenizer.decode(cpu_tokens_generated)}") + r0dprint(f"AIU output:\n{tokenizer.decode(aiu_tokens_generated)}") + + failure_rate = total_mismatches / total_tokens if total_tokens > 0 else 0.0 + r0dprint( + f"Token Failure Rate: {failure_rate:.2%} ({num_mismatches}/{prompt_tokens} mismatches)" + ) + return failure_rate diff --git a/aiu_fms_testing_utils/testing/dpp/prepare_data.py b/aiu_fms_testing_utils/testing/dpp/prepare_data.py new file mode 100644 index 00000000..423fb90f --- /dev/null +++ b/aiu_fms_testing_utils/testing/dpp/prepare_data.py @@ -0,0 +1,464 @@ +import itertools +import json +import os +import random +import time +from typing import Callable, Generator, List, Optional, Tuple + +import torch +from fms.utils.generation import pad_input_ids +from huggingface_hub import hf_hub_download +from transformers import AutoTokenizer + +from aiu_fms_testing_utils.testing.dpp.prepare_programs import get_programs_to_test +from aiu_fms_testing_utils.testing.dpp.program_models import ( + DatasetType, + PreparedInputs, + ProgramInfo, + ValidPrompt, +) +from aiu_fms_testing_utils.utils.aiu_setup import dprint, r0dprint, local_rank +from aiu_fms_testing_utils.utils.paged import ProgramCriteria, get_programs_prompts + +PAD_MULTIPLE = 64 + +SHARE_GPT_DATASET = ( + "anon8231489123/ShareGPT_Vicuna_unfiltered", + "ShareGPT_V3_unfiltered_cleaned_split.json", +) +RAG_FACTOID_DATASET = ("", "") + + +def _prepare_inputs( + batch_size: int, + seq_length: int, + tokenizer: AutoTokenizer, + sampler: Callable[..., tuple[list[tuple[str, int]], str]], + dataset_path: str, + allow_truncation: bool, + enforce_sizes: List[int] = [], + seed: int = 0, +) -> PreparedInputs: + """Prepares and tokenizes input prompts for model inference. + + Samples prompts from a dataset using the provided sampler, tokenizes them, + and pads them to the specified sequence length. Handles cases where fewer + prompts are available than requested by repeating the first prompt. + + Args: + batch_size: Number of prompts to sample for the batch. 
+ seq_length: Target sequence length for padding. + tokenizer: HuggingFace tokenizer for encoding prompts. + sampler: Callable that samples prompts from the dataset. + dataset_path: Path to the dataset file. + allow_truncation: If True, allows truncating prompts longer than seq_length. + enforce_sizes: List of specific sequence lengths to enforce for sampling. + seed: Random seed for reproducible sampling. + + Returns: + Tuple containing: + - input_ids: Padded tensor of tokenized input IDs with shape (batch_size, seq_length). + - extra_kwargs: Dictionary with additional model inputs including attention mask. + - sample_key: String identifier for the sampled prompts. + + Raises: + ValueError: If no valid prompts exist in the dataset for the requested shape.""" + + start = time.time() + prompts_and_sizes, sample_key = sampler( + dataset_path, + batch_size, + tokenizer, + 32, + seq_length * 2 if allow_truncation else seq_length, + seed, + enforce_sizes=enforce_sizes, + truncation=allow_truncation, + return_key=True, + ) + end = time.time() + + r0dprint( + f"Extracted {len(prompts_and_sizes)} prompts in {(end - start):.4f} seconds" + ) + + prompt_list = [] + for prompt, size in prompts_and_sizes: + encoded = tokenizer.encode(prompt, return_tensors="pt").squeeze(0) + if size > seq_length: + assert allow_truncation + encoded = encoded[:seq_length] + prompt_list.append(encoded) + + if not prompt_list: + raise ValueError( + f"No valid prompt sample exists in dataset for input shape (Batch Size={batch_size}, Seq Length={seq_length})" + ) + if len(prompt_list) < batch_size: + dprint( + f"You requested {batch_size} prompts but we were only able to get {len(prompt_list)} valid prompts. We will be repeating the first prompt." + ) + prompt_list = [prompt_list[0]] * (batch_size - len(prompt_list)) + prompt_list + + input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length) + extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16) + + return PreparedInputs( + input_ids=input_ids, extra_kwargs=extra_kwargs, sample_key=sample_key + ) + + +def _get_valid_prompts_by_custom_shape( + program_map: dict, + custom_shape: Tuple[int, int], + tokenizer: AutoTokenizer, + sampler: Callable[..., tuple[list[tuple[str, int]], str]], + dataset_path: str, + allow_truncation: bool, +) -> Generator[ValidPrompt, None, None]: + """Selects prompts matching a custom shape for user-provided datasets. + + Args: + program_map: Dictionary mapping program sequences to valid prompt shapes. + custom_shape: Tuple of (batch_size, seq_length) specified by the user for custom datasets. + tokenizer: HuggingFace tokenizer for encoding prompts. + sampler: Callable for sampling prompts from the dataset. + dataset_path: Path to the dataset for sampling prompts. + allow_truncation: If True, allows truncating prompts that exceed max sequence length. 
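
Two details of _prepare_inputs are easy to miss: prompt lengths are padded up to a multiple of 64, and when the sampler yields fewer prompts than the requested batch size the first prompt is repeated to fill the gap. A toy illustration of both behaviours (the pad helper below is a stand-in, not the repository's get_pad_size):

import torch

PAD_MULTIPLE = 64

def pad_size(n: int, multiple: int = PAD_MULTIPLE) -> int:
    # Round a raw token count up to the next padding boundary.
    return ((n + multiple - 1) // multiple) * multiple

prompts = [torch.arange(70), torch.arange(10)]   # two tokenized prompts (toy data)
batch_size = 4

# Fill the batch by repeating the first prompt, as _prepare_inputs does.
if len(prompts) < batch_size:
    prompts = [prompts[0]] * (batch_size - len(prompts)) + prompts

print([pad_size(len(p)) for p in prompts])  # -> [128, 128, 128, 64]
print(len(prompts))                         # -> 4
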
+ Yields: + ValidPrompt: A named tuple for prompts matching the custom shape.""" + + found_valid_shape = False + for program_criteria_seq, valid_prompt_shapes in program_map.items(): + for valid_prompt_shape in valid_prompt_shapes: + if valid_prompt_shape != custom_shape: + continue + + input_ids, extra_kwargs, sample_key = _prepare_inputs( + valid_prompt_shape[0], + valid_prompt_shape[1], + tokenizer, + sampler, + dataset_path, + allow_truncation, + enforce_sizes=[valid_prompt_shape[1]], + ) + yield ValidPrompt( + program_id=program_criteria_seq[0].program_id, + shape=custom_shape, + input_ids=input_ids, + extra_kwargs=extra_kwargs, + sample_key=sample_key, + ) + found_valid_shape = True + break + + if found_valid_shape: + break + + if not found_valid_shape: + r0dprint( + f"No valid prompt shape was found which would result in program {program_criteria_seq[0].program_id} that satisfied the custom shape {custom_shape}" + ) + + +def _get_valid_prompts_by_shape( + program_map: dict, + program_info: ProgramInfo, + tokenizer: AutoTokenizer, + sampler: Callable[..., tuple[list[tuple[str, int]], str]], + dataset_path: str, + allow_truncation: bool, + pad_multiple: int, + enforce_homogeneous_prompt_programs: bool, +) -> Generator[ValidPrompt, None, None]: + """Selects valid prompts matching program criteria and constraints. + + Args: + program_map: Dictionary mapping program sequences to valid prompt shapes. + program_info: ProgramInfo object specifying the program and its constraints. + tokenizer: HuggingFace tokenizer for encoding prompts. + sampler: Callable for sampling prompts from the dataset. + dataset_path: Path to the dataset for sampling prompts. + allow_truncation: If True, allows truncating prompts that exceed max sequence length. + pad_multiple: Padding granularity for sequence lengths (typically 64). + enforce_homogeneous_prompt_programs: If True, ensures all prompts in a batch use the same decode program. + Yields: + ValidPrompt: A named tuple matching the program criteria and constraints for testing.""" + + used_keys = set() + # for each program, we need to check if we have a shape that satisfies the --programs request + for program_criteria_seq, valid_prompt_shapes in program_map.items(): + # if ? or numeric => we need to check if we have found at least one valid key to stop + if ( + program_info.program_id == "?" 
or program_info.program_id.isnumeric() + ) and len(used_keys) > 0: + break + # if * => we need to see if we have found the first key to see if we should skip + elif program_info.program_id == "*" and program_criteria_seq[0] in used_keys: + continue + + for valid_prompt_shape in valid_prompt_shapes: + # make sure the criteria for batch limit and prompt limit is satisfied + # eval is safe here because we have limited what type and limit can be before + + batch_check = eval( + f"valid_prompt_shape[0] {program_info.batch_size_limit_type} {program_info.batch_size_limit}" + ) + prompt_check = eval( + f"valid_prompt_shape[1] {program_info.prompt_length_limit_type} {program_info.prompt_length_limit}" + ) + if not batch_check or not prompt_check: + continue + + # when we enforce homogeneous prompt programs, we will cycle through all sizes between the min of a program and the valid prompt sequence length + # if there does not exist enough sequence sizes between this range, we will cycle back to the beginning + # in the event we don't have enough sequences that satisfy the enforce_sizes, we will repeat sequences and warn the user + enforce_sizes = [valid_prompt_shape[1]] + if enforce_homogeneous_prompt_programs: + # this will get the number of bits for the sequence length and shift to get the power of 2 that is less than or equal to the sequence length + tkv_cutoff = 1 << (valid_prompt_shape[1].bit_length() - 1) + possible_seq_lengths = [ + _ for _ in range(tkv_cutoff, valid_prompt_shape[1], pad_multiple) + ] + # favor sequences that are close to the valid prompt length + possible_seq_lengths.reverse() + enforce_sizes = enforce_sizes + list( + itertools.islice( + itertools.cycle(possible_seq_lengths), + valid_prompt_shape[0] - 1, + ) + ) + + try: + input_ids, extra_kwargs, sample_key = _prepare_inputs( + batch_size=valid_prompt_shape[0], + seq_length=valid_prompt_shape[1], + tokenizer=tokenizer, + sampler=sampler, + dataset_path=dataset_path, + allow_truncation=allow_truncation, + enforce_sizes=enforce_sizes, + ) + used_keys.add(program_criteria_seq[0]) + yield ValidPrompt( + program_id=program_criteria_seq[0], + shape=valid_prompt_shape, + input_ids=input_ids, + extra_kwargs=extra_kwargs, + sample_key=sample_key, + ) + break + except ValueError as e: + dprint(f"Failed to prepare inputs for shape {valid_prompt_shape}: {e}") + + if len(used_keys) == 0: + r0dprint( + f"No valid prompt shape was found which would result in program {program_info.program_id} that satisfied batch{program_info.batch_size_limit_type}{program_info.batch_size_limit} and prompt_length{program_info.prompt_length_limit_type}{program_info.prompt_length_limit}" + ) + + +def _get_valid_prompts( + program_map: dict, + dataset_path: str, + enforce_homogeneous_prompt_programs: bool, + programs_to_test: List[ProgramInfo], + program_criteria_list: List[ProgramCriteria], + tokenizer: AutoTokenizer, + sampler: Callable[..., tuple[list[tuple[str, int]], str]], + allow_truncation: bool, + pad_multiple: int, +) -> Generator[ValidPrompt, None, None]: + """Generator that yields valid prompts matching program criteria and constraints. + + Iterates through programs to test and finds prompts from the dataset that satisfy + the program's batch size and prompt length constraints. For custom datasets, uses + the provided shape directly. For other datasets, samples prompts matching the + program criteria. When enforce_homogeneous_prompt_programs is True, generates + multiple sequence lengths within a batch to ensure all prompts hit the same program. 
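
When enforce_homogeneous_prompt_programs is set, the enforced sequence lengths span the range from the largest power of two at or below the prompt length up to the prompt length, in steps of the pad multiple, cycling when the batch needs more entries than the range provides. Worked through for an assumed shape of batch 4 and prompt length 384:

import itertools

pad_multiple = 64
batch_size, seq_length = 4, 384

tkv_cutoff = 1 << (seq_length.bit_length() - 1)               # 256
possible = list(range(tkv_cutoff, seq_length, pad_multiple))  # [256, 320]
possible.reverse()                                            # favor lengths near the prompt

enforce_sizes = [seq_length] + list(
    itertools.islice(itertools.cycle(possible), batch_size - 1)
)
print(enforce_sizes)  # -> [384, 320, 256, 320]
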
+ + Args: + program_map: Dictionary mapping program sequences to valid prompt shapes. + dataset_path: Path to the dataset for sampling prompts. + enforce_homogeneous_prompt_programs: If True, ensures all prompts in a batch + use the same decode program. + programs_to_test: List of ProgramInfo objects specifying programs and constraints. + program_criteria_list: List of ProgramCriteria defining program boundaries. + tokenizer: HuggingFace tokenizer for encoding prompts. + sampler: Callable for sampling prompts from the dataset. + allow_truncation: If True, allows truncating prompts exceeding max length. + pad_multiple: Padding granularity for sequence lengths (typically 64). + + Yields: + ValidPrompt: A named tuple matching the program criteria and constraints for testing.""" + + for program_info in programs_to_test: + program_id = program_info.program_id + + filtered_program_map = program_map + if program_id.isnumeric(): + filtered_program_map = { + k: v + for k, v in program_map.items() + if k[0] == program_criteria_list[int(program_id)] + } + + yield from _get_valid_prompts_by_shape( + filtered_program_map, + program_info, + tokenizer, + sampler, + dataset_path, + allow_truncation, + pad_multiple, + enforce_homogeneous_prompt_programs, + ) + + +def prepare_test_prompts( + program_criteria_json_path: str, + programs: List[str], + max_new_tokens: int, + prioritize_large_batch_sizes: bool, + enforce_homogeneous_prompt_programs: bool, + max_batch_size: int, + max_tkv: int, + tkv_limit: Optional[int], + tokenizer: AutoTokenizer, + sampler: Callable[..., tuple[list[tuple[str, int]], str]], + allow_truncation: bool, + custom_shape: Optional[Tuple[int, int]], + dataset_path: str, +) -> Generator[ValidPrompt, None, None]: + """Parses program criteria and generates the sequence of valid test prompts. + + Args: + program_criteria_json_path: Path to JSON file containing program criteria definitions. + programs: List of program specifications from command line arguments. + max_new_tokens: Maximum number of tokens to generate for each prompt. + prioritize_large_batch_sizes: If True, prioritizes larger batch sizes when selecting prompts. + enforce_homogeneous_prompt_programs: If True, ensures all prompts in a batch use the same decode program. + max_batch_size: Maximum batch size to consider when selecting prompts. + max_tkv: Maximum total key-value size to consider when selecting prompts. + tkv_limit: Optional limit on total key-value size for prompts. + tokenizer: HuggingFace tokenizer for encoding prompts. + sampler: Callable for sampling prompts from the dataset. + allow_truncation: If True, allows truncating prompts that exceed max sequence length. + custom_shape: Optional tuple of (batch_size, seq_length) for custom datasets. 
+ dataset_path: Path to the dataset for sampling prompts.""" + + with open(program_criteria_json_path, "r") as f: + program_criteria_json_list = json.load(f)["programs"] + program_criteria_list = [] + for i, d in enumerate(program_criteria_json_list): + program_criteria_list.append( + ProgramCriteria( + i, + d["max_batch"], + d["max_tkv"], + d["batch_granularity"], + d["tkv_granularity"], + ) + ) + + programs_to_test = get_programs_to_test(programs, program_criteria_list) + + # FIXME: filter condition for this on prompt and batch + program_map = get_programs_prompts( + program_criteria_list=program_criteria_list, + multiple=PAD_MULTIPLE, + max_batch_size=max_batch_size, + max_tkv=max_tkv, + program_cycles=max_new_tokens, + tkv_limit=tkv_limit, + prioritize_large_batch_sizes=prioritize_large_batch_sizes, + ) + + for v in program_map.values(): + random.Random(42).shuffle(v) + + if custom_shape: + # Exit early if the user has selected a custom shape + return _get_valid_prompts_by_custom_shape( + program_map, + custom_shape, + tokenizer, + sampler, + dataset_path, + allow_truncation, + ) + + # Select concrete prompts and program associations + return _get_valid_prompts( + program_map=program_map, + dataset_path=dataset_path, + enforce_homogeneous_prompt_programs=enforce_homogeneous_prompt_programs, + programs_to_test=programs_to_test, + program_criteria_list=program_criteria_list, + tokenizer=tokenizer, + sampler=sampler, + allow_truncation=allow_truncation, + pad_multiple=PAD_MULTIPLE, + ) + + +def resolve_dataset_path( + dataset_type: DatasetType, dataset_path: Optional[str] = None +) -> str: + """Resolves the dataset path based on the specified dataset type and optional user-provided path. + + Args: + dataset_type: The type of dataset to resolve. + dataset_path: Optional path to a custom dataset. If not provided, the default path for the given dataset_type is used. + A dataset path must be provided if dataset_type is CUSTOM. For other dataset types, if dataset_path is not provided, + the function will attempt to download the dataset from HuggingFace or fetch a cached download. + Returns: + The local file path to the dataset.""" + + # If the user manually provided a dataset path to use, use the provided path. 
+ if dataset_path is not None: + if not os.path.exists(dataset_path): + raise FileNotFoundError(f"Dataset file not found at {dataset_path}") + + r0dprint( + f"Using provided dataset path {dataset_path} for dataset type {dataset_type}" + ) + return dataset_path + + # Dataset is not provided + if dataset_type == DatasetType.CUSTOM: + raise ValueError("dataset_path must be provided when dataset_type is CUSTOM") + + def _resolve_remote_dataset(): + if dataset_type == DatasetType.SHAREGPT: + r0dprint("Using ShareGPT dataset from HuggingFace") + # Fetch from HuggingFace or use cached download + hf_dataset_path = hf_hub_download( + repo_id=SHARE_GPT_DATASET[0], + filename=SHARE_GPT_DATASET[1], + repo_type="dataset", + ) + elif dataset_type == DatasetType.RAG_FACTOID: + r0dprint("Using RAG Factoid dataset from HuggingFace") + # Fetch from HuggingFace or use cached download + hf_dataset_path = hf_hub_download( + repo_id=RAG_FACTOID_DATASET[0], + filename=RAG_FACTOID_DATASET[1], + repo_type="dataset", + ) + + return hf_dataset_path + + # Initially download only for rank 0 to avoid redundant downloads in distributed settings + if local_rank == 0: + local_dataset_path = _resolve_remote_dataset() + + # Synchronize all ranks to ensure the dataset is downloaded before any rank attempts to access it + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + # Resolve the paths for non-zero ranks after download is complete + if local_rank != 0: + local_dataset_path = _resolve_remote_dataset() + + return local_dataset_path diff --git a/aiu_fms_testing_utils/testing/dpp/prepare_model.py b/aiu_fms_testing_utils/testing/dpp/prepare_model.py new file mode 100644 index 00000000..62ae8be6 --- /dev/null +++ b/aiu_fms_testing_utils/testing/dpp/prepare_model.py @@ -0,0 +1,112 @@ +import os +from typing import Any, Dict + +import torch +from fms.models import get_model + +from aiu_fms_testing_utils.testing.dpp.program_models import DeviceType +from aiu_fms_testing_utils.utils import stagger_region +from aiu_fms_testing_utils.utils.aiu_setup import dprint +from aiu_fms_testing_utils.utils.dpp_config import DPPRunnerConfig +from aiu_fms_testing_utils.utils.env_utils import scoped_environ + + +def _prepare_fp8_weights(model: torch.nn.Module) -> None: + """Converts model weights from bfloat16 to float16 for FP8 attention. + + When using FP8 attention variants, this function converts all bfloat16 parameters + to float16. Issues a warning if any parameter values exceed the float16 range, + which may cause accuracy loss. + + Args: + model: PyTorch model whose weights may need conversion.""" + + for name, param in model.named_parameters(): + if param.dtype == torch.bfloat16: + if param.max() > torch.finfo(torch.float16).max: + dprint( + f"[WARNING] You are casting param {name} to fp16, which will cause loss of accuracy." + ) + param.data = param.data.to(dtype=torch.float16) + + +def get_model_kwargs(model_variant: str) -> Dict[str, Any]: + """Constructs model loading kwargs based on whether variant is a path or ID. + + Determines if the model_variant is a local filesystem path or a HuggingFace + model identifier, and returns the appropriate keyword arguments for model loading. + + Args: + model_variant: Either a local path to model files or a HuggingFace model ID. 
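
resolve_dataset_path uses the usual rank-0-first pattern: only rank 0 downloads, every rank waits at a barrier, and the remaining ranks then resolve the same file from the now-populated cache. Sketched generically below with a placeholder fetch function:

import torch.distributed as dist

def fetch_once_per_node(fetch_fn, local_rank: int) -> str:
    """Rank 0 downloads (or warms the cache); everyone else waits at the
    barrier and then resolves the cached copy. fetch_fn is a placeholder."""
    path = None
    if local_rank == 0:
        path = fetch_fn()
    if dist.is_initialized():
        dist.barrier()
    if local_rank != 0:
        path = fetch_fn()   # hits the cache populated by rank 0
    return path
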
+ + Returns: + Dictionary with either "model_path" (for local paths) or "variant" + (for HuggingFace IDs) as the key.""" + + model_kwargs = {} + if os.path.exists(model_variant): + model_kwargs["model_path"] = model_variant + else: + model_kwargs["variant"] = model_variant + + return model_kwargs + + +def load_model( + device_type: DeviceType, + is_fp8: bool, + model_kwargs: Dict[str, Any], + distributed_kwargs: Dict[str, Any], + stagger_load: int, + model_config: DPPRunnerConfig, +): + """Loads and optionally compiles a model for inference or validation. + + Loads a model with the specified configuration. For Spyre/AIU models, + compiles the model using the sendnn backend with dynamic compilation enabled. + The scoped_environ context manager temporarily sets environment variables + from model_config during compilation to configure the compiler's behavior (e.g., + program criteria, batch sizes, context lengths). + + Args: + device_type: Target device for model execution. + is_fp8: If True, uses FP8 quantization (dtype=None for auto-detection). + model_kwargs: Dictionary with model loading parameters (variant or path). + distributed_kwargs: Dictionary with distributed training configuration. + stagger_load: Number of concurrent processes allowed during loading (0=unlimited). + model_config: DPPRunnerConfig instance with environment variable updates. + + Returns: + torch.nn.Module: Loaded model in evaluation mode. Spyre models are compiled + with sendnn backend and may have FP8 weight conversion applied.""" + + dtype = torch.float32 if device_type == DeviceType.CPU else torch.float16 + if is_fp8: + dtype = None # Let the model loading logic decide the appropriate FP8 dtype + + with stagger_region(stagger_load): + model = get_model( + architecture="hf_pretrained", + device_type="cpu", + data_type=dtype, + fused_weights=False, + **model_kwargs, + **distributed_kwargs, + ) + + model.eval() + + if device_type != DeviceType.SPYRE: + return model + + with scoped_environ(model_config.env_updates()): + # Temporarily set environment variables needed for compile + dprint(f"Compiling model for Spyre execution with DPP config: {model_config}") + # Ignore autograd warning; this is needed + model.compile(backend="sendnn", options={"sendnn.dynamic": True}) + + if is_fp8: + dprint("Converting model weights for FP8 attention...") + _prepare_fp8_weights(model) + + return model diff --git a/aiu_fms_testing_utils/testing/dpp/prepare_programs.py b/aiu_fms_testing_utils/testing/dpp/prepare_programs.py new file mode 100644 index 00000000..efb1486c --- /dev/null +++ b/aiu_fms_testing_utils/testing/dpp/prepare_programs.py @@ -0,0 +1,106 @@ +import re + +from aiu_fms_testing_utils.testing.dpp.program_models import ProgramInfo +from aiu_fms_testing_utils.utils.paged import ProgramCriteria + + +def _parse_program_limit(limit_str: str) -> tuple[int, str | None]: + """Parses a program limit string into a numeric value and comparison operator. + + Accepts either a plain integer (defaults to ">=" for backward compatibility) + or a string with a comparison operator prefix (e.g., ">=10", "<5", "==8"). + + Args: + limit_str: String representation of the limit, either a number or + operator+number (e.g., "10", ">=10", "<5"). + + Returns: + Tuple containing: + - limit_val: The numeric limit value. + - limit_type: The comparison operator string (">=", "<=", "<", ">", "=="). 
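
load_model only needs the DPP environment variables while the sendnn compile is running, so it applies them through a scoped context manager and lets them fall away afterwards. A generic equivalent of that idea (not the repository's scoped_environ) looks roughly like:

import os
from contextlib import contextmanager

@contextmanager
def scoped_env(updates: dict):
    """Temporarily apply environment variable updates, restoring the previous
    values (or removing the keys) on exit."""
    previous = {key: os.environ.get(key) for key in updates}
    os.environ.update({key: str(value) for key, value in updates.items()})
    try:
        yield
    finally:
        for key, old in previous.items():
            if old is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = old
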
+ + Raises: + ValueError: If the limit string format is invalid.""" + + matcher = re.compile(r"^(<|>|<=|>=|==)(\d+)") + + # Default limit to min to maintain backwards compat + try: + limit_type = ">=" + limit_val = int(limit_str) + except ValueError: + limit_type = None + match = matcher.fullmatch(limit_str) + if match is None: + raise ValueError("Program not well formatted, wrong limit type") + limit_type = match.group(1) + limit_val = int(match.group(2)) + return limit_val, limit_type + + +def get_programs_to_test( + programs: list[str], program_criteria_list: list[ProgramCriteria] +) -> list[ProgramInfo]: + """Parses program specifications into ProgramInfo objects for testing. + + Converts command-line program specifications into structured ProgramInfo objects. + Supports three formats: + - Empty list: Tests all programs with any valid prompt. + - "program_id": Tests specific program with any valid prompt. + - "program_id:batch_constraint,prompt_constraint": Tests program with specific constraints. + + Args: + programs: List of program specification strings from command line. + program_criteria_list: List of ProgramCriteria objects defining available programs. + + Returns: + List of ProgramInfo objects representing programs to test with their constraints.""" + + if not isinstance(programs, list): + raise ValueError( + "Programs argument must be a list of program criteria strings." + ) + + if not isinstance(program_criteria_list, list): + raise ValueError( + "Program criteria list must be a list of ProgramCriteria objects." + ) + + programs_to_test = [] + for program_str in programs: + enforce_prompt_split = program_str.split(":") + program_id = enforce_prompt_split[0] + if len(enforce_prompt_split) == 1: + programs_to_test.append( + ProgramInfo(program_id, 0, ">=", 0, ">=") + ) # this will always satisfy + else: + enforce_batch_size, enforce_prompt_length = ( + _ for _ in enforce_prompt_split[1].split(",") + ) + + # Default limit to min to maintain backwards compat + enforce_batch_size_val, enforce_batch_size_type = _parse_program_limit( + enforce_batch_size + ) + enforce_prompt_length_val, enforce_prompt_length_type = ( + _parse_program_limit(enforce_prompt_length) + ) + + programs_to_test.append( + ProgramInfo( + program_id, + enforce_batch_size_val, + enforce_batch_size_type, + enforce_prompt_length_val, + enforce_prompt_length_type, + ) + ) + + if not programs_to_test: + # If no programs specified, test all programs with any valid prompt + for program_criteria in program_criteria_list: + pid = str(program_criteria.program_id) + programs_to_test.append(ProgramInfo(pid, 0, ">=", 0, ">=")) + + return programs_to_test diff --git a/aiu_fms_testing_utils/testing/dpp/program_models.py b/aiu_fms_testing_utils/testing/dpp/program_models.py new file mode 100644 index 00000000..7e47aca6 --- /dev/null +++ b/aiu_fms_testing_utils/testing/dpp/program_models.py @@ -0,0 +1,106 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, NamedTuple, Tuple + +import torch + + +class DeviceType(Enum): + CPU = "cpu" + SPYRE = "spyre" + + +class TestType(Enum): + METRICS = "metrics" + TOKENS = "tokens" + + +class AttnType(Enum): + SDPA = "sdpa_causal" + MATH = "math" + PAGED = "spyre_paged_attn" + MATH_FP8 = "math_fp8" + PAGED_FP8 = "spyre_paged_attn_fp8" + + +class DatasetType(Enum): + CUSTOM = "custom" + RAG_FACTOID = "rag_factoid" + SHAREGPT = "sharegpt" + + +@dataclass +class ProgramInfo: + """Encapsulates program execution criteria. 
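
A program specification such as "3:>=4,<=512" therefore splits into a program id, a batch-size constraint, and a prompt-length constraint, with a bare number keeping the historical ">=" meaning. A condensed sketch of the same parsing behaviour:

import re

def parse_limit(limit_str: str) -> tuple[int, str]:
    # A bare integer keeps the backward-compatible ">=" meaning; otherwise
    # the operator prefix is parsed explicitly.
    try:
        return int(limit_str), ">="
    except ValueError:
        match = re.fullmatch(r"(<=|>=|==|<|>)(\d+)", limit_str)
        if match is None:
            raise ValueError("Program not well formatted, wrong limit type")
        return int(match.group(2)), match.group(1)

spec = "3:>=4,<=512"
program_id, constraints = spec.split(":")
batch_spec, prompt_spec = constraints.split(",")
print(program_id, parse_limit(batch_spec), parse_limit(prompt_spec))
# -> 3 (4, '>=') (512, '<=')
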
+ + Attributes: + program_id: Unique identifier for the program being tested. + batch_size_limit: Numeric threshold for batch size constraint. + batch_size_limit_type: Comparison operator for batch size (e.g., ">=", "<=", "=="). + prompt_length_limit: Numeric threshold for prompt length constraint. + prompt_length_limit_type: Comparison operator for prompt length (e.g., ">=", "<=", "==").""" + + program_id: str + batch_size_limit: int + batch_size_limit_type: str + prompt_length_limit: int + prompt_length_limit_type: str + + +class EnvConfig(NamedTuple): + """Represents global configuration derived from environment and CLI. + + Attributes: + attn_name: The internal name of the attention algorithm. + cpu_dtype: Data type for CPU validation ('fp8' or 'fp32'). + max_batch_size: Maximum batch size. + max_tkv: Maximum total key-value (context) length.""" + + attn_type: AttnType + cpu_dtype: str + max_batch_size: int + max_tkv: int + + +class MetricResult(NamedTuple): + """Result of comparing AIU and CPU logit distributions. + + Attributes: + cross_entropy_loss: Cross-entropy loss between the distributions. + mean_abs_diff: Mean absolute difference of softmax probabilities.""" + + cross_entropy_loss: float + mean_abs_diff: float + + def __str__(self) -> str: + return f"cross_entropy_loss: {self.cross_entropy_loss:.6f}, mean_abs_diff: {self.mean_abs_diff:.6f}" + + +class PreparedInputs(NamedTuple): + """Represents prepared model inputs from dataset sampling. + + Attributes: + input_ids: Padded tensor of tokenized input IDs with shape (batch_size, seq_length). + extra_kwargs: Dictionary with attention mask and other model inputs. + sample_key: String identifier for the sampled prompts.""" + + input_ids: torch.Tensor + extra_kwargs: Dict[str, Any] + sample_key: str + + +class ValidPrompt(NamedTuple): + """Represents a valid prompt configuration for program execution. + + Attributes: + program_id: ID of the program this prompt will execute. + shape: Tuple of (batch_size, seq_length) for this prompt. + input_ids: Tokenized and padded input tensor. + extra_kwargs: Dictionary with attention mask and other model inputs. 
+ sample_key: String identifier for the sampled prompts.""" + + program_id: str + shape: Tuple[int, int] + input_ids: torch.Tensor + extra_kwargs: Dict[str, Any] + sample_key: str diff --git a/aiu_fms_testing_utils/testing/dpp/run_drive_paged_programs.py b/aiu_fms_testing_utils/testing/dpp/run_drive_paged_programs.py new file mode 100644 index 00000000..6f42c1cf --- /dev/null +++ b/aiu_fms_testing_utils/testing/dpp/run_drive_paged_programs.py @@ -0,0 +1,360 @@ +import datetime +import os +from typing import Any, Dict, Iterable, List, Optional + +import torch +from fms.utils.generation import pad_input_ids +from torch import distributed as dist +from torch.fx.experimental import _config as fx_config +from transformers import AutoTokenizer + +from aiu_fms_testing_utils.testing.dpp.generation import ( + generate_aiu_cpu_test, + generate_aiu_test, +) +from aiu_fms_testing_utils.testing.dpp.prepare_data import ( + prepare_test_prompts, + resolve_dataset_path, + PAD_MULTIPLE, +) +from aiu_fms_testing_utils.testing.dpp.prepare_model import ( + get_model_kwargs, + load_model, +) +from aiu_fms_testing_utils.testing.dpp.program_models import ( + AttnType, + DatasetType, + DeviceType, + EnvConfig, + TestType, + ValidPrompt, +) +from aiu_fms_testing_utils.testing.dpp.sample_prompts import get_sampler +from aiu_fms_testing_utils.utils import warmup_model +from aiu_fms_testing_utils.utils.aiu_setup import ( + aiu_dist_setup, + dprint, + is_distributed, + local_rank, + r0dprint, + world_size, +) +from aiu_fms_testing_utils.utils.dpp_config import DPPRunnerConfig +from aiu_fms_testing_utils.utils.model_setup import Timing + +DEFAULT_CE_THRESHOLD = 2.5 +DEFAULT_FAILURE_RATE_THRESHOLD = 0.1 + + +def _get_distributed_kwargs(dist_timeout: str) -> Dict[str, Any]: + """Initializes distributed training configuration and returns kwargs. + + Sets up PyTorch distributed process group with tensor parallelism strategy. Configures custom timeout if specified. + + Args: + dist_timeout: Timeout in minutes for distributed operations (0 uses default). + + Returns: + Dictionary containing distributed configuration with keys: + - "distributed_strategy": Set to "tp" (tensor parallelism) if distributed. + - "group": PyTorch distributed group (WORLD) if distributed.""" + + if dist_timeout > 0: + # Default timeout: + # https://docs.pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group + dist.init_process_group(timeout=datetime.timedelta(minutes=dist_timeout)) + dprint(f"NOTICE: init_process_group timeout set to {dist_timeout} minutes") + else: + dist.init_process_group() + + aiu_dist_setup(dist.get_rank(), dist.get_world_size()) + return { + "distributed_strategy": "tp", + "group": dist.group.WORLD, + } + + +def _setup_environment( + program_criteria_json_path: str, attention_type: AttnType +) -> EnvConfig: + """Set up global process state and environment variables. + + Args: + program_criteria_json_path: Path to the JSON file containing program criteria definitions. + attention_type: Type of attention mechanism to use. 
+ + Returns: + EnvConfig: Immutable configuration containing: + - attn_type: Mapped attention implementation name + - cpu_dtype: Data type for CPU operations ("fp8" or "fp32") + - max_batch_size: Maximum batch size from VLLM_DT_MAX_BATCH_SIZE + - max_tkv: Maximum token-key-value context length from VLLM_DT_MAX_CONTEXT_LEN + + Raises: + ValueError: If required environment variables VLLM_DT_MAX_CONTEXT_LEN or + VLLM_DT_MAX_BATCH_SIZE are not set.""" + + os.environ["COMPILATION_MODE"] = "offline_decoder" + os.environ["DT_PROG_CRITERIA_FILEPATH"] = program_criteria_json_path + + if ( + "VLLM_DT_MAX_CONTEXT_LEN" not in os.environ + or "VLLM_DT_MAX_BATCH_SIZE" not in os.environ + ): + r0dprint("Missing required VLLM environment variables.") + raise ValueError( + "Environment variables VLLM_DT_MAX_CONTEXT_LEN and VLLM_DT_MAX_BATCH_SIZE must be set before running DPP." + ) + + torch.manual_seed(42) + torch.set_grad_enabled(False) + fx_config.backed_size_oblivious = True + + return EnvConfig( + attn_type=attention_type, + cpu_dtype="fp8" if "fp8" in attention_type.value else "fp32", + max_batch_size=int(os.environ["VLLM_DT_MAX_BATCH_SIZE"]), + max_tkv=int(os.environ["VLLM_DT_MAX_CONTEXT_LEN"]), + ) + + +def _run_aiu_cpu_tests( + model: torch.nn.Module, + validation_model: torch.nn.Module, + tokenizer: AutoTokenizer, + valid_prompts: Iterable[ValidPrompt], + env_config: EnvConfig, + model_config: DPPRunnerConfig, + test_type: TestType, + max_new_tokens: int, + validation_info_outputs_dir: str, + cross_entropy_threshold: float, + failure_rate_threshold: float, + timing: Timing, + prefill_chunk_size: int, + model_variant: str, + save_validation_info_outputs: bool = False, +): + """Runs tests comparing AIU and CPU outputs for given prompts.""" + + # Validation and Testing + failed_cases = generate_aiu_cpu_test( + valid_prompts, + model, + validation_model, + tokenizer, + env_config, + model_config, + test_type, + max_new_tokens, + save_validation_info_outputs and dist.get_rank() == 0, + validation_info_outputs_dir, + cross_entropy_threshold, + failure_rate_threshold, + timing, + prefill_chunk_size, + model_variant, + ) + + if local_rank != 0: + return + + if len(failed_cases) != 0: + dprint("The test failed with the following cases:") + for failed_case in failed_cases: + dprint( + f"Program ID: {failed_case[0]}, Prompt Shape: {failed_case[1]}, Failure Rate: {failed_case[2]}" + ) + else: + dprint("All tests passed") + + +def run_dpp( + program_criteria_json_path: str, + dataset_type: DatasetType, + max_new_tokens: int, + model_variant: str, + dataset_path: Optional[str] = None, + programs: List[str] = None, + timing: Timing = Timing.NONE, + test_type: TestType = TestType.METRICS, + cross_entropy_threshold: float = DEFAULT_CE_THRESHOLD, + failure_rate_threshold: float = DEFAULT_FAILURE_RATE_THRESHOLD, + attention_type: AttnType = AttnType.PAGED, + prefill_chunk_size: int = 0, + stagger_load: int = 0, + stagger_update_lazyhandle: int = 0, + dist_timeout: int = 0, + run_cpu_validation: bool = True, + validation_info_outputs_dir: str = None, + save_validation_info_outputs: bool = False, + prioritize_large_batch_sizes: bool = False, + enforce_homogeneous_prompt_programs: bool = False, +): + """Main execution function for driving paged program validation tests. + + Workflow: + 1. Sets and configures environment. + 2. Loads models (both AIU-compiled and CPU validation). + 3. Warms up the model. + 4. Selects programs and prompts to test based on criteria. + 5. 
For each program/prompt combination: + - Generates CPU validation data (or loads from cache). + - Runs AIU inference. + - Compares outputs using metrics or token-based validation. + 6. Prints results and failure cases.""" + + if programs is None: + programs = [] + + if not validation_info_outputs_dir and run_cpu_validation: + raise ValueError( + "validation_info_outputs_dir must be specified if run_cpu_validation is True" + ) + + if not os.path.exists(program_criteria_json_path): + raise FileNotFoundError( + f"Program criteria JSON file not found at {program_criteria_json_path}" + ) + + if not os.path.exists(validation_info_outputs_dir): + os.makedirs(validation_info_outputs_dir, exist_ok=True) + + if run_cpu_validation: + r0dprint( + f"Running DPP on model {model_variant} and validating against CPU reference outputs in {validation_info_outputs_dir}" + ) + else: + r0dprint( + f"Running DPP on {model_variant} without CPU validation (outputs will not be compared against reference)" + ) + + if is_distributed: + r0dprint(f"Running DPP in distributed mode with {world_size} ranks") + else: + r0dprint("Running DPP in single-process mode") + + dataset_path = resolve_dataset_path(dataset_type, dataset_path) + + is_fp8 = attention_type == AttnType.PAGED_FP8 + if not run_cpu_validation and test_type == TestType.METRICS: + r0dprint("When skipping validation, only test_type will be ignored") + + # Environment Setup + env_config = _setup_environment( + program_criteria_json_path=program_criteria_json_path, + attention_type=attention_type, + ) + tokenizer = AutoTokenizer.from_pretrained(model_variant) + sampler, allow_truncation, custom_shape = get_sampler( + dataset_type=dataset_type, + dataset_path=dataset_path, + tokenizer=tokenizer, + ) + + # Model Loading + model_kwargs = get_model_kwargs(model_variant) + distributed_kwargs = _get_distributed_kwargs(dist_timeout) if is_distributed else {} + + # Setup model config + model_config = DPPRunnerConfig() + model_config.setup_config( + model_variant=model_variant, + use_distributed=is_distributed, + prefill_chunk_size=prefill_chunk_size, + ) + + # Prompt Preparation + valid_prompts = prepare_test_prompts( + program_criteria_json_path, + programs, + max_new_tokens, + prioritize_large_batch_sizes, + enforce_homogeneous_prompt_programs, + env_config.max_batch_size, + env_config.max_tkv, + model_config.tkv_limit, + tokenizer, + sampler, + allow_truncation, + custom_shape, + dataset_path, + ) + + model = load_model( + device_type=DeviceType.SPYRE, + is_fp8=is_fp8, + model_kwargs=model_kwargs, + distributed_kwargs=distributed_kwargs, + stagger_load=stagger_load, + model_config=model_config, + ) + + # Model Warmup + ## warmup with any input so compiler produces criteria json + ## TODO: Swap this with _prepare_inputs once fix for shape_id is available + ## input_ids, extra_kwargs, sample_key = _prepare_inputs(2, max_tkv, tokenizer) + prompt_list = [torch.arange(0, PAD_MULTIPLE, dtype=torch.int64)] + + # matching vllm warmup to pad to 2 on fp8, and no pad for fp16 + if is_fp8: + prompt_list = prompt_list * 2 + + input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=64) + extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16) + extra_kwargs["attn_name"] = env_config.attn_type.value + extra_kwargs["_kvcache_num_blocks_hint"] = model_config.num_blocks + warmup_model( + model, + input_ids, + max_new_tokens, + compile_dynamic_sendnn=True, + stagger_update_lazyhandle=stagger_update_lazyhandle, + prefill_chunk_size=prefill_chunk_size, + **extra_kwargs, 
+ ) + + if is_distributed: + # wait for rank0 to be finished as it is the only one generating the criteria json + # this is needed since otherwise we may run into a race condition + torch.distributed.barrier() + + # Test Execution + if run_cpu_validation: + validation_model = load_model( + device_type=DeviceType.CPU, + is_fp8=is_fp8, + model_kwargs=model_kwargs, + distributed_kwargs=distributed_kwargs, + stagger_load=stagger_load, + model_config=model_config, + ) + _run_aiu_cpu_tests( + model, + validation_model, + tokenizer, + valid_prompts, + env_config, + model_config, + test_type, + max_new_tokens, + validation_info_outputs_dir, + cross_entropy_threshold, + failure_rate_threshold, + timing, + prefill_chunk_size, + model_variant, + save_validation_info_outputs, + ) + else: + generate_aiu_test( + valid_prompts, + model, + tokenizer, + env_config, + model_config, + test_type, + max_new_tokens, + timing, + prefill_chunk_size, + ) diff --git a/aiu_fms_testing_utils/testing/dpp/sample_prompts.py b/aiu_fms_testing_utils/testing/dpp/sample_prompts.py new file mode 100644 index 00000000..790f2317 --- /dev/null +++ b/aiu_fms_testing_utils/testing/dpp/sample_prompts.py @@ -0,0 +1,100 @@ +from pathlib import Path +from typing import Callable, Optional + +from transformers import AutoTokenizer + +from aiu_fms_testing_utils.testing.dpp.program_models import DatasetType +from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string +from aiu_fms_testing_utils.utils import ( + get_pad_size, + sample_rag_factoid_requests, + sample_sharegpt_requests, +) +from aiu_fms_testing_utils.utils.aiu_setup import dprint, r0dprint + + +def _custom_line_sampler(result: list[tuple[str, int]], **kwargs): + """Custom sampler for user-provided text files. + + Returns pre-loaded prompts from custom dataset files without + additional sampling logic. Supports optional sample key return. + + Args: + result: List of (prompt, padded_size) tuples. + **kwargs: Keyword arguments, supports "return_key" flag. + + Returns: + List of (prompt, padded_size) tuples, or tuple of (list, sample_key) + if return_key=True.""" + + return_key = kwargs.get("return_key", False) + sample_key = format_kwargs_to_string(**kwargs) + + if return_key: + return result, sample_key + + return result + + +def get_sampler( + dataset_type: DatasetType, dataset_path: str, tokenizer: AutoTokenizer +) -> tuple[ + Callable[..., tuple[list[tuple[str, int]], str]], bool, Optional[tuple[int, int]] +]: + """Selects and configures the sampler based on type. + + Returns a sampler function and configuration for the specified dataset type. + + Args: + dataset_type: Type of dataset. + dataset_path: Path to the dataset file or directory. + tokenizer: HuggingFace tokenizer for encoding prompts. + + Returns: + Tuple containing: + - sampler: Callable function for sampling prompts from the dataset. + - allow_truncation: Boolean indicating if prompt truncation is allowed. + - custom_shape: Tuple of (batch_size, max_seq_length) for custom datasets, + None for other dataset types. 
+ + Raises: + ValueError: If dataset_type is not one of the supported types.""" + + custom_shape = None + if dataset_type == DatasetType.CUSTOM: + r0dprint( + "Using custom prompts from user, programs parameter will be ignored as it will be determined by user prompt" + ) + directory = Path(dataset_path) + if not directory.is_dir(): + raise NotADirectoryError( + f"Custom dataset path {dataset_path} is not a directory" + ) + + result = [] + for fp in directory.iterdir(): + if not fp.is_file(): + continue + + try: + content = fp.read_text() + pad_size = get_pad_size(len(tokenizer.encode(content))) + result.append((content, pad_size)) + except Exception as e: + dprint(f"Error while reading {fp} for custom dataset: {e}") + raise + + custom_shape = (len(result), max([_[1] for _ in result])) + + sampler = _custom_line_sampler + allow_truncation = False + elif dataset_type == DatasetType.RAG_FACTOID: + sampler = sample_rag_factoid_requests + allow_truncation = False + elif dataset_type == DatasetType.SHAREGPT: + sampler = sample_sharegpt_requests + allow_truncation = True + else: + raise ValueError("dataset_type must be one of rag_factoid or sharegpt") + + return sampler, allow_truncation, custom_shape diff --git a/aiu_fms_testing_utils/testing/validation.py b/aiu_fms_testing_utils/testing/validation.py index 8174eca7..12feaa58 100644 --- a/aiu_fms_testing_utils/testing/validation.py +++ b/aiu_fms_testing_utils/testing/validation.py @@ -2,10 +2,12 @@ from typing import List, Tuple, Callable, MutableMapping, Any, Optional import torch -from aiu_fms_testing_utils.utils.aiu_setup import dprint -from aiu_fms_testing_utils._version import version_tuple +from aiu_fms_testing_utils.utils.aiu_setup import dprint, r0dprint import os from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string +from aiu_fms_testing_utils.utils.model_setup import Timing +from aiu_fms_testing_utils.testing.dpp.program_models import DeviceType, AttnType +from aiu_fms_testing_utils._version import version_tuple import hashlib @@ -42,10 +44,14 @@ class StaticTokenInjectorHook( Tuple[torch.Tensor, MutableMapping[str, Any]], ] ): - def __init__(self, static_tokens: List[torch.Tensor], device_type: str = "cpu"): + def __init__( + self, + static_tokens: List[torch.Tensor], + device_type: DeviceType = DeviceType.CPU, + ): super().__init__() self.static_tokens = torch.tensor( - static_tokens, device=device_type + static_tokens, device=device_type.value ).t() # transposing so batch tokens per token_position def __call__( @@ -61,7 +67,7 @@ class GoldenTokenHook( Tuple[torch.Tensor, MutableMapping[str, Any]], ] ): - def __init__(self, static_tokens: torch.Tensor, device_type: str = "cpu"): + def __init__(self, static_tokens: torch.Tensor, device_type: str = DeviceType.CPU): super().__init__() self.logits_extractor = LogitsExtractorHook() self.extracted_logits = None @@ -260,10 +266,10 @@ def extract_validation_information( input_ids, max_new_tokens, post_iteration_hook, - attn_algorithm=None, + attn_algorithm: Optional[AttnType] = None, eos_token_id=None, last_n_tokens=0, - timing="", + timing=Timing.NONE, prefill_chunk_size=0, **extra_kwargs, ): @@ -284,8 +290,10 @@ def extract_validation_information( if last_n_tokens != 0: extra_generation_kwargs["last_n_tokens"] = last_n_tokens if attn_algorithm is not None: - extra_generation_kwargs["attn_algorithm"] = attn_algorithm + extra_generation_kwargs["attn_algorithm"] = attn_algorithm.value + r0dprint(f"Calling generate with timing: {timing}") + r0dprint(f"Calling generate with 
kwargs: {extra_generation_kwargs}") result = generate( model, input_ids, @@ -294,34 +302,42 @@ def extract_validation_information( do_sample=False, post_iteration_hook=post_iteration_hook, eos_token_id=eos_token_id, - timing=timing, + timing=timing.value if timing != Timing.NONE else "", extra_kwargs=extra_generation_kwargs, **attention_specific_kwargs, ) - if timing != "": + r0dprint(f"Model generate result ({type(result)}): {result}") + + # Split result into model output and timings (empty list if none) + if isinstance(result, tuple): + model_output, timing_results = result + else: + model_output, timing_results = result, [] # If the result is just a tensor + + if timing != Timing.NONE and timing_results: dprint( "=== This timing information might be inaccurate due to extra work being done in generate() for validation" ) - result, timings = result - if timing == "e2e": - dprint(f"E2E timing information: {timings[0]:.3f}s") - elif timing == "per-token": - timings = [f"{t * 1000:.3f}" for t in timings] - dprint(f"Per-token timing information: {', '.join(timings)} ms") + if timing == Timing.E2E: + dprint(f"E2E timing information: {timing_results[0]:.3f}s") + elif timing == Timing.PER_TOKEN: + timing_results = [f"{t * 1000:.3f}" for t in timing_results] + dprint(f"Per-token timing information: {', '.join(timing_results)} ms") - if len(result.shape) == 1: - result = result.unsqueeze(0) + if len(model_output.shape) == 1: + model_output = model_output.unsqueeze(0) if hasattr(post_iteration_hook, "extracted_logits"): validation_info = [ {"tokens": t.to("cpu"), "logits": logits.to("cpu")} for t, logits in zip( - torch.unbind(result), torch.unbind(post_iteration_hook.extracted_logits) + torch.unbind(model_output), + torch.unbind(post_iteration_hook.extracted_logits), ) ] else: - validation_info = [{"tokens": t.to("cpu")} for t in torch.unbind(result)] + validation_info = [{"tokens": t.to("cpu")} for t in torch.unbind(model_output)] return ValidationInfo(validation_info) @@ -409,7 +425,7 @@ def print_failed_cases(failed_cases, aiu_tokens, validation_tokens, tokenizer): aiu_str = tokenizer.decode(aiu_token) validation_str = tokenizer.decode(validation_token) - print( + dprint( f"In sentence {sentence_index + 1}/{len(aiu_tokens)}, token {token_index}, AIU outputs {aiu_token} instead of {validation_token} -- AIU val={aiu_str} -- CPU val={validation_str}" ) @@ -421,9 +437,9 @@ def get_validation_info_path( seq_length: int, max_new_tokens: int, seed: int, - attn_type: str, + attn_type: AttnType, aftu_version: Optional[Tuple[int, int, int]] = None, - device_type: str = "cpu", + device_type: DeviceType = DeviceType.CPU, dtype: str = "fp16", **kwargs, ): @@ -432,7 +448,27 @@ def get_validation_info_path( sample_key = kwargs.get("sample_key", None) - validation_file_name = f"{get_default_validation_prefix(aftu_version='.'.join([str(_) for _ in aftu_version[:3]]), model_id=model_variant, max_new_tokens=max_new_tokens, batch_size=batch_size, seq_length=seq_length, dtype=dtype, attn_type=attn_type, sample_key=sample_key)}.{device_type}_validation_info.{seed}.out" + attn_type_map = { + AttnType.SDPA: "sdpa", + AttnType.PAGED: "paged", + AttnType.MATH: "math", + AttnType.MATH_FP8: "math-fp8", + AttnType.PAGED_FP8: "paged-fp8", + } + + val_prefix = get_default_validation_prefix( + aftu_version=".".join([str(_) for _ in aftu_version[:3]]), + model_id=model_variant, + max_new_tokens=max_new_tokens, + batch_size=batch_size, + seq_length=seq_length, + dtype=dtype, + attn_type=attn_type_map[attn_type], + 
+    )
+    validation_file_name = (
+        f"{val_prefix}.{device_type.value}_validation_info.{seed}.out"
+    )

     full_path = os.path.join(validation_info_dir, validation_file_name)
     return full_path

@@ -459,10 +495,10 @@ def find_validation_info_path(
     seq_length: int,
     max_new_tokens: int,
     seed: int,
-    attn_type: str,
+    attn_type: AttnType,
     aftu_version: Optional[Tuple[int, int, int]] = None,
     version_allow_decrement: bool = False,
-    device_type: str = "cpu",
+    device_type: DeviceType = DeviceType.CPU,
     dtype: str = "fp16",
     **kwargs,
 ):
diff --git a/aiu_fms_testing_utils/utils/__init__.py b/aiu_fms_testing_utils/utils/__init__.py
index 1bbb82ec..92f1d46c 100644
--- a/aiu_fms_testing_utils/utils/__init__.py
+++ b/aiu_fms_testing_utils/utils/__init__.py
@@ -281,9 +281,10 @@ def __sample_requests(
         List[Tuple[str, int]]
     """

-    assert prompt_length_max >= prompt_length_min, (
-        "Please enter valid prompt length max/min values"
-    )
+    if prompt_length_max < prompt_length_min:
+        raise ValueError(
+            f"Max prompt length ({prompt_length_max}) should be larger than min prompt length ({prompt_length_min})"
+        )
     if enforce_sizes is None:
         enforce_sizes = []

@@ -489,7 +490,7 @@ def sample_rag_factoid_requests(
     return_key: bool = False,
 ) -> List[Tuple[str, int]]:
     if not os.path.exists(dataset_path):
-        print("error dataset does not exist")
+        raise FileNotFoundError(f"Dataset path {dataset_path} does not exist")

     dataset = []
     # Load the dataset.
@@ -543,26 +544,16 @@ def sample_sharegpt_requests(
     pad_multiple: int = 64,
     return_key: bool = False,
 ) -> List[Tuple[str, int]]:
-    if not os.path.exists(dataset_path):
-        print("downloading share-gpt dataset as it does not exist")
-        is_distributed_initialized = torch.distributed.is_initialized()
-        if not is_distributed_initialized or rank < 1:
-            __download_file(
-                "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json",
-                dataset_path,
-            )
-        else:
-            print("waiting for rank0 to complete download")
-
-        if is_distributed_initialized:
-            torch.distributed.barrier()
-
     if enforce_sizes is None:
         enforce_sizes = []

+    if not os.path.exists(dataset_path):
+        raise FileNotFoundError(f"Dataset path {dataset_path} does not exist")
+
     # Load the dataset.
     with open(dataset_path, "r", encoding="utf-8") as f:
         dataset = json.load(f)
+
     # Filter out the conversations with less than 2 turns.
     dataset = [data for data in dataset if len(data["conversations"]) >= 2]
     dataset: List[str] = [data["conversations"][0]["value"] for data in dataset]

diff --git a/aiu_fms_testing_utils/utils/aiu_setup.py b/aiu_fms_testing_utils/utils/aiu_setup.py
index fb2dedf2..62a8a065 100644
--- a/aiu_fms_testing_utils/utils/aiu_setup.py
+++ b/aiu_fms_testing_utils/utils/aiu_setup.py
@@ -13,14 +13,20 @@
 rank = int(os.getenv("RANK", 0))
 world_rank = rank
 world_size = int(os.getenv("WORLD_SIZE", 1))
+is_distributed = "LOCAL_RANK" in os.environ and "RANK" in os.environ


-def dprint_str(text):
+def dprint_str(text: str) -> str:
     return f"[{rank:2d}/{world_size:2d}]: {text}"


-def dprint(text):
-    print(dprint_str(text))
+def dprint(*text: str):
+    print(dprint_str(" ".join(text)))
+
+
+def r0dprint(*text: str):
+    if rank == 0:
+        dprint(*text)


 # ==============================================================
diff --git a/aiu_fms_testing_utils/utils/dpp_config.py b/aiu_fms_testing_utils/utils/dpp_config.py
index 815675f8..52ff6468 100644
--- a/aiu_fms_testing_utils/utils/dpp_config.py
+++ b/aiu_fms_testing_utils/utils/dpp_config.py
@@ -2,6 +2,7 @@
 from dataclasses import dataclass

 from aiu_fms_testing_utils.utils.aiu_setup import dprint
+from torch import distributed as dist


 @dataclass
@@ -50,17 +51,29 @@ def _configure_granite_3_8b(self, use_distributed, world_size, prefill_chunk_siz
                 context=context,
             )

-            # these values are to be consistent with vllm for granite 3.3 8b instruct
-            blocks_override = 8192 if prefill_chunk_size > 0 else 2080
-
-            self.num_blocks = self._get_int_env(
-                key="AFTU_PAGED_KVCACHE_NUM_BLOCKS_HINT",
-                default=blocks_override,
+        elif use_distributed and world_size == 1:
+            ##Only set defaults for TP=1
+            context = (
+                "Model granite-3.3-8b (or compatible) "
+                "with tensor parallel size 1 detected"
+            )
+            self.tkv_limit = self._get_int_env(
+                key="VLLM_DT_MAX_BATCH_TKV_LIMIT",
+                default=131072,
                 context=context,
             )

+            # these values are to be consistent with vllm for granite 3.3 8b instruct
+            blocks_override = 8192 if prefill_chunk_size > 0 else 2080
+
+            self.num_blocks = self._get_int_env(
+                key="AFTU_PAGED_KVCACHE_NUM_BLOCKS_HINT",
+                default=blocks_override,
+                context=context,
+            )
+
     def setup_config(
-        self, model_variant, use_distributed, world_size, prefill_chunk_size
+        self, model_variant: str, use_distributed: bool, prefill_chunk_size: int
     ):
         """Set up environment variables and default values if not specified"""

@@ -69,27 +82,32 @@ def setup_config(
             "granite-3.3-8b-instruct" in model_variant
             or "granite-4.0-8b" in model_variant
         ):
+            world_size = (
+                dist.get_world_size()
+                if use_distributed and dist.is_initialized()
+                else 1
+            )
             self._configure_granite_3_8b(
                 use_distributed, world_size, prefill_chunk_size
             )

-        ## global defaults (fallback)
-        ## TODO: IN future we may remove defaults for unknown configurations \
-        ## and require users to set the environment variables
-        ## num_blocks is set in generate if not set here
         if self.tkv_limit is None:
-            self.tkv_limit = self._get_int_env(
-                key="VLLM_DT_MAX_BATCH_TKV_LIMIT",
-                default=524288,
-                context="Unknown model configuration",
+            raise RuntimeError(
+                f"Could not determine tkv_limit for model variant '{model_variant}'. "
+                "Please set the environment variable VLLM_DT_MAX_BATCH_TKV_LIMIT or "
+                "run this program in distributed mode."
             )

     def env_updates(self) -> dict[str, str]:
         """Returns a key/value of environment variables needed for model compile"""
         if self.tkv_limit is None:
             raise RuntimeError(
-                "ModelConfig.env_updates() called before setup_config(). "
-                "Call setup_config(...) first."
+                "ModelConfig.env_updates() called before setup_config(). Call setup_config(...) first."
             )

         return {"VLLM_DT_MAX_BATCH_TKV_LIMIT": str(self.tkv_limit)}
+
+    def __str__(self) -> str:
+        return (
+            f"DPPRunnerConfig(num_blocks={self.num_blocks}, tkv_limit={self.tkv_limit})"
+        )
diff --git a/aiu_fms_testing_utils/utils/model_setup.py b/aiu_fms_testing_utils/utils/model_setup.py
index 669749b3..65b029ed 100644
--- a/aiu_fms_testing_utils/utils/model_setup.py
+++ b/aiu_fms_testing_utils/utils/model_setup.py
@@ -13,6 +13,21 @@
 from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, local_rank, world_size
 from aiu_fms_testing_utils.utils import aiu_setup

+from enum import Enum
+
+
+class Timing(Enum):
+    """Enum for timing options in generation function."""
+
+    NONE = "none"
+    """No timing measurements will be taken."""
+
+    E2E = "e2e"
+    """Measures end-to-end generation time for the entire generation loop."""
+
+    PER_TOKEN = "per-token"
+    """Measures time taken for each token generation step."""
+

 def get_default_dtype(args: argparse.Namespace) -> torch.dtype | None:
     """Return default_dtype for non-quantized models, otherwise None.
diff --git a/aiu_fms_testing_utils/utils/paged.py b/aiu_fms_testing_utils/utils/paged.py
index 49807dc1..843f3b40 100644
--- a/aiu_fms_testing_utils/utils/paged.py
+++ b/aiu_fms_testing_utils/utils/paged.py
@@ -6,6 +6,7 @@
 import torch
 import fms.utils.spyre.paged  # noqa
 from aiu_fms_testing_utils.utils import get_pad_size
+from aiu_fms_testing_utils.utils.model_setup import Timing


 def adjust_inputs_to_batch(input_ids: torch.Tensor, **extra_kwargs):
@@ -39,7 +40,7 @@ def generate(
     use_cache: bool = False,
     prefill_chunk_size: int = 0,
     eos_token_id: Optional[int] = None,
-    timing: str = "",
+    timing: Timing = Timing.NONE,
     post_iteration_hook: Optional[
         Callable[
             [int, torch.Tensor, torch.Tensor, MutableMapping[str, Any]],
@@ -47,7 +48,7 @@ def generate(
         ]
     ] = None,
     extra_kwargs: Optional[MutableMapping[str, Any]] = None,
-):
+) -> Union[torch.Tensor, Tuple[torch.Tensor, List[float]]]:
     """
     A trivial generate function that can be used for validation/testing in
     cases where HF is not available.
@@ -80,7 +81,10 @@ def generate(
         extra_kwargs: an optional mapping of additional kwargs to pass to the model.
             For example: if extra_kwargs contains position_ids and mask keys, these
             model parameters will be updated as-appropriate for each token generated.
-    """
+    Returns:
+        A tensor of shape (batch x seq + max_new_tokens) with the generated tokens.
+        If timing is not Timing.NONE, also returns an array of timing measurements as described above."""
+
     random.seed(0)
     if num_beams != 1:
         raise NotImplementedError("generate() does yet not support beam search")
@@ -259,7 +263,7 @@ def generate(

     prompt_length = input_ids.shape[1]

-    if timing != "":
+    if timing != Timing.NONE:
         times: List[float] = []
         start_time = time.time()

@@ -613,14 +617,14 @@ def generate(
         else:
             next_input = result

-        if timing == "per-token":
+        if timing == Timing.PER_TOKEN:
             if input_ids.device.type == "cuda":
                 torch.cuda.synchronize()
             current_token_time = time.time() - start_time
             times.append(current_token_time)
             start_time = time.time()

-    if timing == "e2e":
+    if timing == Timing.E2E:
         if input_ids.device.type == "cuda":
             torch.cuda.synchronize()
         e2e_time = time.time() - start_time
@@ -629,8 +633,9 @@ def generate(
     if not is_batch:
         result = result[0]

-    if timing != "":
+    if timing != Timing.NONE:
         return result, times
+
     return result


@@ -661,6 +666,9 @@ def calculate_padding(self, batch_size, tkv):
     def __str__(self):
         return f"ProgramCriteria(program_id={self.program_id})"

+    def __repr__(self):
+        return f"ProgramCriteria(program_id={self.program_id}, max_batch={self.max_batch}, max_tkv={self.max_tkv}, batch_granularity={self.batch_granularity}, tkv_granularity={self.tkv_granularity})"
+
     def __eq__(self, other):
         if not isinstance(other, ProgramCriteria):
             return NotImplemented
diff --git a/tests/fixtures/dpp-all-criterion.json b/tests/fixtures/dpp-all-criterion.json
new file mode 100644
index 00000000..59d0ae3e
--- /dev/null
+++ b/tests/fixtures/dpp-all-criterion.json
@@ -0,0 +1,220 @@
+{
+  "programs": [
+    {
+      "max_batch": 4,
+      "max_tkv": 32768,
+      "batch_granularity": 4,
+      "tkv_granularity": 32768
+    },
+    {
+      "max_batch": 8,
+      "max_tkv": 16384,
+      "batch_granularity": 8,
+      "tkv_granularity": 16384
+    },
+    {
+      "max_batch": 16,
+      "max_tkv": 8192,
+      "batch_granularity": 16,
+      "tkv_granularity": 8192
+    },
+    {
+      "max_batch": 32,
+      "max_tkv": 4096,
+      "batch_granularity": 32,
+      "tkv_granularity": 4096
+    },
+    {
+      "max_batch": 2,
+      "max_tkv": 32768,
+      "batch_granularity": 2,
+      "tkv_granularity": 32768
+    },
+    {
+      "max_batch": 4,
+      "max_tkv": 16384,
+      "batch_granularity": 4,
+      "tkv_granularity": 16384
+    },
+    {
+      "max_batch": 8,
+      "max_tkv": 8192,
+      "batch_granularity": 8,
+      "tkv_granularity": 8192
+    },
+    {
+      "max_batch": 16,
+      "max_tkv": 4096,
+      "batch_granularity": 16,
+      "tkv_granularity": 4096
+    },
+    {
+      "max_batch": 32,
+      "max_tkv": 2048,
+      "batch_granularity": 32,
+      "tkv_granularity": 2048
+    },
+    {
+      "max_batch": 1,
+      "max_tkv": 32768,
+      "batch_granularity": 1,
+      "tkv_granularity": 32768
+    },
+    {
+      "max_batch": 2,
+      "max_tkv": 16384,
+      "batch_granularity": 2,
+      "tkv_granularity": 16384
+    },
+    {
+      "max_batch": 4,
+      "max_tkv": 8192,
+      "batch_granularity": 4,
+      "tkv_granularity": 8192
+    },
+    {
+      "max_batch": 8,
+      "max_tkv": 4096,
+      "batch_granularity": 8,
+      "tkv_granularity": 4096
+    },
+    {
+      "max_batch": 16,
+      "max_tkv": 2048,
+      "batch_granularity": 16,
+      "tkv_granularity": 2048
+    },
+    {
+      "max_batch": 1,
+      "max_tkv": 16384,
+      "batch_granularity": 1,
+      "tkv_granularity": 16384
+    },
+    {
+      "max_batch": 2,
+      "max_tkv": 8192,
+      "batch_granularity": 2,
+      "tkv_granularity": 8192
+    },
+    {
+      "max_batch": 4,
+      "max_tkv": 4096,
+      "batch_granularity": 4,
+      "tkv_granularity": 4096
+    },
+    {
+      "max_batch": 8,
+      "max_tkv": 2048,
+      "batch_granularity": 8,
+      "tkv_granularity": 2048
+    },
+    {
+      "max_batch": 1,
+      "max_tkv": 8192,
+      "batch_granularity": 1,
+      "tkv_granularity": 8192
+    },
"max_tkv": 4096, + "batch_granularity": 2, + "tkv_granularity": 4096 + }, + { + "max_batch": 4, + "max_tkv": 2048, + "batch_granularity": 4, + "tkv_granularity": 2048 + }, + { + "max_batch": 1, + "max_tkv": 4096, + "batch_granularity": 1, + "tkv_granularity": 4096 + }, + { + "max_batch": 2, + "max_tkv": 2048, + "batch_granularity": 2, + "tkv_granularity": 2048 + }, + { + "max_batch": 6, + "max_tkv": 32768, + "batch_granularity": 6, + "tkv_granularity": 32768 + }, + { + "max_batch": 12, + "max_tkv": 16384, + "batch_granularity": 12, + "tkv_granularity": 16384 + }, + { + "max_batch": 24, + "max_tkv": 8192, + "batch_granularity": 24, + "tkv_granularity": 8192 + }, + { + "max_batch": 32, + "max_tkv": 6144, + "batch_granularity": 32, + "tkv_granularity": 6144 + }, + { + "max_batch": 16, + "max_tkv": 12288, + "batch_granularity": 16, + "tkv_granularity": 12288 + }, + { + "max_batch": 8, + "max_tkv": 24576, + "batch_granularity": 8, + "tkv_granularity": 24576 + }, + { + "max_batch": 6, + "max_tkv": 16384, + "batch_granularity": 6, + "tkv_granularity": 16384 + }, + { + "max_batch": 12, + "max_tkv": 8192, + "batch_granularity": 12, + "tkv_granularity": 8192 + }, + { + "max_batch": 24, + "max_tkv": 4096, + "batch_granularity": 24, + "tkv_granularity": 4096 + }, + { + "max_batch": 32, + "max_tkv": 3072, + "batch_granularity": 32, + "tkv_granularity": 3072 + }, + { + "max_batch": 16, + "max_tkv": 6144, + "batch_granularity": 16, + "tkv_granularity": 6144 + }, + { + "max_batch": 8, + "max_tkv": 12288, + "batch_granularity": 8, + "tkv_granularity": 12288 + }, + { + "max_batch": 4, + "max_tkv": 24576, + "batch_granularity": 4, + "tkv_granularity": 24576 + } + ] +} diff --git a/tests/models/test_decoders.py b/tests/models/test_decoders.py index 4f95e61e..71a75748 100644 --- a/tests/models/test_decoders.py +++ b/tests/models/test_decoders.py @@ -30,6 +30,7 @@ from aiu_fms_testing_utils.utils.aiu_setup import dprint, aiu_dist_setup import os +from aiu_fms_testing_utils.testing.dpp.program_models import AttnType try: from fms_mo.aiu_addons.gptq import gptq_aiu_adapter, gptq_aiu_linear # noqa: F401 @@ -600,7 +601,7 @@ def _get_device_validation_information( # overrides for validation info that are device specific device_dependent_kwargs = {} if device == "cpu": - device_dependent_kwargs["attn_algorithm"] = "math" + device_dependent_kwargs["attn_algorithm"] = AttnType.MATH if device == "aiu": device_dependent_kwargs["last_n_tokens"] = 64 if "paged" in ATTN_NAME else 1 @@ -630,7 +631,7 @@ def _get_device_validation_information( seq_length, max_new_tokens, token_iter, - ATTN_NAME, + AttnType(ATTN_NAME), device_type=device, **kwargs, ) diff --git a/tests/testing/test_drive_paged_programs.py b/tests/testing/test_drive_paged_programs.py new file mode 100644 index 00000000..fbdff748 --- /dev/null +++ b/tests/testing/test_drive_paged_programs.py @@ -0,0 +1,90 @@ +import os +from pathlib import Path + +import pytest + +from aiu_fms_testing_utils.testing.dpp.program_models import ( + DatasetType, +) +from aiu_fms_testing_utils.testing.dpp.run_drive_paged_programs import run_dpp +from aiu_fms_testing_utils.utils.aiu_setup import r0dprint, world_size + + +@pytest.fixture(scope="module") +def dpp_criterion_json_path(): + test_path = Path(__file__).parent.parent / "fixtures" / "dpp-all-criterion.json" + return str(test_path) + + +def setup_environment(): + """Sets up the testing environment for driving paged programs.""" + + r0dprint( + f"Setting up environment for driving paged programs for world size {world_size}..." 
+    )
+
+    if world_size == 4:
+        os.environ["VLLM_DT_MAX_BATCH_TKV_LIMIT"] = os.environ.get(
+            "VLLM_DT_MAX_BATCH_TKV_LIMIT", "131072"
+        )
+        os.environ["VLLM_DT_MAX_BATCH_SIZE"] = os.environ.get(
+            "VLLM_DT_MAX_BATCH_SIZE", "32"
+        )
+        os.environ["VLLM_DT_MAX_CONTEXT_LEN"] = os.environ.get(
+            "VLLM_DT_MAX_CONTEXT_LEN", "32768"
+        )
+        os.environ["VLLM_DT_CHUNK_LEN"] = os.environ.get("VLLM_DT_CHUNK_LEN", "1024")
+    elif world_size == 1:
+        os.environ["VLLM_DT_MAX_BATCH_TKV_LIMIT"] = os.environ.get(
+            "VLLM_DT_MAX_BATCH_TKV_LIMIT", "131072"
+        )
+        os.environ["VLLM_DT_MAX_BATCH_SIZE"] = os.environ.get(
+            "VLLM_DT_MAX_BATCH_SIZE", "16"
+        )
+        os.environ["VLLM_DT_MAX_CONTEXT_LEN"] = os.environ.get(
+            "VLLM_DT_MAX_CONTEXT_LEN", "3072"
+        )
+        os.environ["VLLM_DT_CHUNK_LEN"] = os.environ.get("VLLM_DT_CHUNK_LEN", "1024")
+    else:
+        r0dprint(
+            f"Non-default world size {world_size} detected. Unable to assume default environment setup"
+        )
+
+    r0dprint("Batch TKV Limit:", os.environ["VLLM_DT_MAX_BATCH_TKV_LIMIT"])
+    r0dprint("Max Batch Size:", os.environ["VLLM_DT_MAX_BATCH_SIZE"])
+    r0dprint("Max Context Length:", os.environ["VLLM_DT_MAX_CONTEXT_LEN"])
+    r0dprint("Chunk Length:", os.environ["VLLM_DT_CHUNK_LEN"])
+
+
+@pytest.mark.dpp
+def test_drive_paged_programs(dpp_criterion_json_path: str):
+    """Test driving paged programs with specified configurations."""
+
+    setup_environment()
+
+    programs = ["2:0,<8192"]
+    max_new_tokens = 32
+    model_variant = "ibm-granite/granite-3.3-8b-instruct"
+    validation_info_outputs_dir = os.getenv(
+        "VALIDATION_INFO_OUTPUTS_DIR", "/home/senuser/models/validation_info"
+    )
+    dataset_type = DatasetType.SHAREGPT
+    cross_entropy_threshold = 2.6
+    failure_rate_threshold = 0.1
+
+    r0dprint(f"Loading criteria from path: {dpp_criterion_json_path}")
+
+    run_dpp(
+        program_criteria_json_path=dpp_criterion_json_path,
+        dataset_type=dataset_type,
+        max_new_tokens=max_new_tokens,
+        model_variant=model_variant,
+        programs=programs,
+        cross_entropy_threshold=cross_entropy_threshold,
+        failure_rate_threshold=failure_rate_threshold,
+        prefill_chunk_size=1024,
+        run_cpu_validation=True,
+        prioritize_large_batch_sizes=True,
+        enforce_homogeneous_prompt_programs=True,
+        validation_info_outputs_dir=validation_info_outputs_dir,
+    )
diff --git a/tests/testing/test_validation.py b/tests/testing/test_validation.py
index dc832055..40b5c98d 100644
--- a/tests/testing/test_validation.py
+++ b/tests/testing/test_validation.py
@@ -15,8 +15,9 @@
 from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string
 from aiu_fms_testing_utils.utils import sample_sharegpt_requests
 from transformers import AutoTokenizer
-
+from aiu_fms_testing_utils.testing.dpp.program_models import AttnType
 from aiu_fms_testing_utils._version import version_tuple
+
 from fms.models import get_model
 from fms.utils.generation import pad_input_ids
 from pathlib import Path
@@ -57,7 +58,7 @@ def test_validation_info_round_trip(validation_type, post_iteration_hook):
         input_ids,
         max_new_tokens,
         post_iteration_hook,
-        attn_algorithm="math",
+        attn_algorithm=AttnType.MATH,
         **padding_kwargs,
     )

@@ -86,7 +87,13 @@ def test_get_validation_info_path(tmp_path):

     assert (
         get_validation_info_path(
-            tmp_path, "ibm-granite/granite-3.3-8b-instruct", 4, 64, 128, 0, "sdpa"
+            tmp_path,
+            "ibm-granite/granite-3.3-8b-instruct",
+            4,
+            64,
+            128,
+            0,
+            AttnType.SDPA,
         )
         == f"{tmp_path}/{hex_digest}_{'.'.join([str(_) for _ in version_tuple[:3]])}.cpu_validation_info.0.out"
     )
@@ -103,7 +110,7 @@ def test_get_validation_info_path(tmp_path):
             64,
             128,
             0,
-            "sdpa",
"sdpa", + AttnType.SDPA, aftu_version=(1, 2, 3), ) == f"{tmp_path}/{hex_digest}_1.2.3.cpu_validation_info.0.out" @@ -187,7 +194,7 @@ def test_find_validation_info_path( 64, 128, 0, - "sdpa", + AttnType.SDPA, (10, 10, 10), ) ) @@ -202,7 +209,7 @@ def test_find_validation_info_path( 64, 128, 0, - "sdpa", + AttnType.SDPA, save_version, ) ) @@ -216,7 +223,7 @@ def test_find_validation_info_path( 64, 128, 0, - "sdpa", + AttnType.SDPA, current_version, version_allow_decrement=version_allow_decrement, ) @@ -225,6 +232,9 @@ def test_find_validation_info_path( assert found_path is None else: match = re.search(r"(\d+)\.(\d+)\.(\d+)", found_path) + assert match is not None, ( + f"Expected to find a version in the found path: {found_path}" + ) found_version = (int(match.group(1)), int(match.group(2)), int(match.group(3))) assert found_version == expected_version