From 211c6b3413c52b2fd7a671eda3ce8b2139a4fedd Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Thu, 14 May 2026 13:08:04 -0700 Subject: [PATCH 1/7] support NVFP4 MSE and mixed precision in mcore; super nvfp4 recipe Signed-off-by: Jennifer Chen --- CHANGELOG.rst | 6 + .../specdec_bench/datasets/speed.py | 6 +- .../export/plugins/hf_checkpoint_utils.py | 45 ++- .../torch/export/plugins/mcore_nemotron.py | 5 +- modelopt/torch/export/quant_utils.py | 20 +- .../torch/export/unified_export_megatron.py | 99 +++++- modelopt/torch/quantization/config.py | 10 +- modelopt/torch/quantization/conversion.py | 15 +- modelopt/torch/quantization/model_calib.py | 124 +++++-- .../nn/modules/tensor_quantizer.py | 12 +- modelopt/torch/quantization/plugins/custom.py | 24 +- .../torch/quantization/plugins/megatron.py | 42 ++- .../quantization/qtensor/nvfp4_tensor.py | 30 +- .../super-nvfp4-max-calib.yaml | 131 +++++++ .../super-nvfp4.yaml | 134 ++++++++ .../test_nvfp4_static_quantizer_cuda.py | 17 + .../export/test_unified_export_megatron.py | 59 +++- .../torch/export/test_hf_checkpoint_utils.py | 57 +++- .../test_nvfp4_static_export_cpu.py | 320 ++++++++++++++++++ 19 files changed, 1045 insertions(+), 111 deletions(-) create mode 100644 modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-max-calib.yaml create mode 100644 modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml create mode 100644 tests/unit/torch/quantization/test_nvfp4_static_export_cpu.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 62f2b0041cb..d213757cb6a 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -24,6 +24,12 @@ Changelog - Add support for ``active_params`` (for MoE models) and ``memory_mb`` constraints in Minitron pruning on top of existing ``params`` constraint. You can also provide multiple constraints. See `examples/pruning/README.md `_ for more details. The underlying utility functions ``mcore_param_count``, ``mcore_memory_footprint_mb``, and ``print_mcore_model_stats`` in ``modelopt.torch.nas.plugins.megatron_model_stats`` are also available for standalone use to compute parameter counts and memory footprints (weights + KV-cache + Mamba state) for any Megatron-Core model. - Add ``--cast_mxfp4_to_nvfp4`` flag to ``examples/llm_ptq/hf_ptq.py`` for closed-form, bit-exact MXFP4 → NVFP4 weight conversion. Supports the GPT-OSS family (``openai/gpt-oss-20b``, ``openai/gpt-oss-120b``). See `examples/llm_ptq/README.md `__ for usage. - DeepSeek PTQ (``examples/deepseek/ptq.py``) now defaults to native top-k calibration with post-hoc per-layer peer-max sync of expert ``input_quantizer.amax``; the all-experts path is preserved behind ``--calib_all_experts``. +- Support Megatron-Core checkpoint restore and export for MSE ``NVFP4StaticQuantizer``. +- Add mixed-precision FP8 + NVFP4 export for Megatron-Core: per-layer ``quant_algo`` recorded under ``quantized_layers`` in ``hf_quant_config.json``, PP-aware ``kv_cache_dtype`` gather, fused-QKV exclude split into per-HF-name ``q/k/v_proj`` entries. +- Add Nemotron-3-Super-120B-A12B PTQ recipes ``modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml`` (MSE-mixed) and ``super-nvfp4-max-calib.yaml`` (max-calib mixed): NVFP4 W4A4 routed experts + FP8 per-tensor shared experts / Mamba in/out_proj + FP8 KV cache. + +**Bug Fixes** +- In Megatron-Core only do EP amax sync for routed expert weights if ``sync_expert_weight_amax=True``. Previously EP amax sync would sync routed expert weights across EP ranks even when ``sync_expert_weight_amax`` was False 0.44 (2026-05-18) ^^^^^^^^^^^^^^^^^ diff --git a/examples/specdec_bench/specdec_bench/datasets/speed.py b/examples/specdec_bench/specdec_bench/datasets/speed.py index 3552d71a1ad..651ee6b1e73 100644 --- a/examples/specdec_bench/specdec_bench/datasets/speed.py +++ b/examples/specdec_bench/specdec_bench/datasets/speed.py @@ -730,11 +730,7 @@ def _load_dataset(self, config_name_or_dataset_path: config_type | str) -> "Data # Strip HF metadata from the schema to avoid Feature parsing errors schema = table.schema if schema.metadata and b"huggingface" in schema.metadata: - new_meta = { - k: v - for k, v in schema.metadata.items() - if k != b"huggingface" - } + new_meta = {k: v for k, v in schema.metadata.items() if k != b"huggingface"} table = table.replace_schema_metadata(new_meta or None) dataset = HFDataset(table) if self.num_samples is not None and self.num_samples < len(dataset): diff --git a/modelopt/torch/export/plugins/hf_checkpoint_utils.py b/modelopt/torch/export/plugins/hf_checkpoint_utils.py index 4d9bc6fc292..63ac3393486 100644 --- a/modelopt/torch/export/plugins/hf_checkpoint_utils.py +++ b/modelopt/torch/export/plugins/hf_checkpoint_utils.py @@ -22,9 +22,21 @@ import torch from huggingface_hub import snapshot_download +from huggingface_hub.errors import LocalEntryNotFoundError from safetensors.torch import safe_open from tqdm import tqdm +_HF_HUB_OFFLINE_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} + + +def _is_hf_hub_offline() -> bool: + return os.environ.get("HF_HUB_OFFLINE", "").strip().upper() in _HF_HUB_OFFLINE_TRUE_VALUES + + +def _copy_python_files(source_dir: Path, save_dir: Path) -> None: + for py_file in source_dir.glob("*.py"): + shutil.copy2(py_file, save_dir / py_file.name) + def copy_hf_ckpt_remote_code( pretrained_model_path: str | os.PathLike, save_directory: str | os.PathLike @@ -36,7 +48,10 @@ def copy_hf_ckpt_remote_code( frameworks. If ``pretrained_model_path`` is a local directory, Python files are copied directly. - If it's a HF Hub model ID (e.g. ``nvidia/NVIDIA-Nemotron-Nano-12B-v2``), files are downloaded from the Hub. + If it's a HF Hub model ID (e.g. ``nvidia/NVIDIA-Nemotron-Nano-12B-v2``), the Hub + snapshot is resolved first and Python files are copied from that snapshot. When + ``HF_HUB_OFFLINE`` is set, the snapshot must already be available in the local + Hugging Face cache. Args: pretrained_model_path: Local path to the pretrained model or HuggingFace Hub model ID. @@ -47,14 +62,28 @@ def copy_hf_ckpt_remote_code( save_dir.mkdir(parents=True, exist_ok=True) if hf_checkpoint_path.is_dir(): - for py_file in hf_checkpoint_path.glob("*.py"): - shutil.copy2(py_file, save_dir / py_file.name) + _copy_python_files(hf_checkpoint_path, save_dir) else: - snapshot_download( - repo_id=str(pretrained_model_path), - local_dir=str(save_dir), - allow_patterns=["*.py"], - ) + local_files_only = _is_hf_hub_offline() + try: + source_dir = Path( + snapshot_download( + repo_id=str(pretrained_model_path), + allow_patterns=["*.py"], + local_files_only=local_files_only, + ) + ) + except LocalEntryNotFoundError as exc: + if local_files_only: + raise RuntimeError( + f"Could not copy Python sidecar files for {pretrained_model_path!r} because " + "HF_HUB_OFFLINE is enabled and the files are not available in the local " + "Hugging Face cache. Populate the cache with the model's *.py files or pass " + "a local pretrained model directory." + ) from exc + raise + + _copy_python_files(source_dir, save_dir) def load_multimodal_components( diff --git a/modelopt/torch/export/plugins/mcore_nemotron.py b/modelopt/torch/export/plugins/mcore_nemotron.py index e9d1b4e1e9b..24bd8144055 100644 --- a/modelopt/torch/export/plugins/mcore_nemotron.py +++ b/modelopt/torch/export/plugins/mcore_nemotron.py @@ -131,7 +131,10 @@ "input_layernorm": NameRemapping("backbone.layers.{}.norm."), "linear_qkv": QKVSlicing("backbone.layers.{}.mixer."), "linear_proj": NameRemapping("backbone.layers.{}.mixer.o_proj."), - "core_attention": SelfAttentionScaling("backbone.layers.{}.mixer."), + "core_attention": SelfAttentionScaling( + "backbone.layers.{}.mixer.", + func_kwargs={"k_scale_name": "k_proj.k_scale", "v_scale_name": "v_proj.v_scale"}, + ), # MLP "pre_mlp_layernorm": NameRemapping("backbone.layers.{}.norm."), "linear_fc1": NameRemapping("backbone.layers.{}.mixer.up_proj."), diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index 76f304a478a..3e488c821f8 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -287,9 +287,25 @@ def _ensure_weight_quantizer_calibrated( module_name: Optional module name for better warning messages """ if isinstance(weight_quantizer, NVFP4StaticQuantizer): - need_per_block = not hasattr(weight_quantizer, "_amax") or weight_quantizer._amax is None + + def _amax_is_invalid(t: torch.Tensor | None) -> bool: + # MCore distcp may register but not fill amax — treat missing/non-finite/negative as recompute. + if t is None: + return True + t = t.detach() + if not torch.is_floating_point(t): + return False + return bool(torch.any(~torch.isfinite(t)).item() or torch.any(t < 0).item()) + + need_per_block = ( + not hasattr(weight_quantizer, "_amax") + or weight_quantizer._amax is None + or _amax_is_invalid(weight_quantizer._amax) + ) need_global = ( - not hasattr(weight_quantizer, "_global_amax") or weight_quantizer.global_amax is None + not hasattr(weight_quantizer, "_global_amax") + or weight_quantizer.global_amax is None + or _amax_is_invalid(weight_quantizer.global_amax) ) if not (need_per_block or need_global): return diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index 23b8cfd1630..f7c227055d0 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -61,6 +61,7 @@ get_weight_block_size, get_weight_scaling_factor, get_weight_scaling_factor_2, + process_layer_quant_config, to_quantized_weight, ) @@ -169,6 +170,7 @@ def __init__( self.all_rules = self._populate_rule_book() self.rules = self.all_rules[self.arch] self.exclude_modules = [] + self.layer_config_dict = {} if not hasattr(model, "_modelopt_state"): return @@ -324,22 +326,32 @@ def save_pretrained( print(f"Successfully loaded {len(mtp_state_dict)} MTP tensors") combined_exclude_modules = self._gather_exclude_modules() + combined_layer_config_dict = self._gather_layer_config_dict() + # kv_cache_dtype is only set on attention-owning ranks; writer rank may not be one. + gathered_kv_cache_dtype = self._gather_kv_cache_dtype() if is_last_stage_main_rank and quantization is not None: + if combined_layer_config_dict: + quantization_config = process_layer_quant_config(combined_layer_config_dict) + quantization_config["exclude_modules"] = combined_exclude_modules + else: + quantization_config = { + "quant_algo": quantization, + "exclude_modules": combined_exclude_modules, + } + if quantization == "NVFP4": # update block size + quantization_config["group_size"] = 16 + + if gathered_kv_cache_dtype is not None: + quantization_config["kv_cache_quant_algo"] = gathered_kv_cache_dtype + self._hf_quant_config = { "producer": { "name": "modelopt", "version": __version__, }, - "quantization": { - "quant_algo": quantization, - "exclude_modules": combined_exclude_modules, - }, + "quantization": quantization_config, } - if quantization == "NVFP4": # update block size - self._hf_quant_config["quantization"]["group_size"] = 16 - if hasattr(self, "kv_cache_dtype"): - self._hf_quant_config["quantization"]["kv_cache_quant_algo"] = self.kv_cache_dtype with open(save_directory + "/hf_quant_config.json", "w") as f: json.dump(self._hf_quant_config, f, indent=4) @@ -359,10 +371,11 @@ def save_pretrained( # Newer versions of VLLM expect config.json with hf_quant_config config_json_file = save_directory + "/config.json" if self._hf_quant_config and os.path.exists(config_json_file): - converted_quant_config = convert_hf_quant_config_format(self._hf_quant_config) with open(config_json_file) as f: config_dict = json.load(f) - config_dict["quantization_config"] = converted_quant_config + config_dict["quantization_config"] = convert_hf_quant_config_format( + self._hf_quant_config + ) with open(config_json_file, "w") as f: json.dump(config_dict, f, indent=4) @@ -803,9 +816,7 @@ def _get_quantized_state( name_to_value = {} qformat: str = self._get_quantization_format(module) if qformat is None and "norm" not in prefix: - # Add exclude layers for hf_quant_config. Note that if the prefix is not an empty - # string then it usually ends with "." which needs to be removed. - self.exclude_modules.append(prefix.removesuffix(".")) + self._record_excluded_module(prefix) block_size = get_weight_block_size(module) name_to_value = self._get_weight_bias(module, dtype, name_to_value) @@ -850,6 +861,27 @@ def _get_weight_scales(self, quantized_state: dict[str, Any], qformat: str): return weight_scale, weight_scale_2 + def _record_layer_quant_config(self, prefix: str, qformat: str | None, block_size: int): + """Record per-HF-layer quantization metadata for mixed precision exports.""" + if qformat in (None, QUANTIZATION_NONE): + return + + layer_name = prefix.removesuffix(".") + if "{" in layer_name or not layer_name: + return + + self.layer_config_dict[layer_name + ".quantization"] = qformat + self.layer_config_dict[layer_name + ".awq_block_size"] = block_size + + def _record_excluded_module(self, prefix: str): + """Record an unquantized HF module prefix for hf_quant_config.""" + layer_name = prefix.removesuffix(".") + if "{" in layer_name or not layer_name: + return + + if layer_name not in self.exclude_modules: + self.exclude_modules.append(layer_name) + def _name_remapping( self, module: torch.nn.Module | torch.Tensor, @@ -866,6 +898,7 @@ def _name_remapping( return name_to_value, qformat, block_size = self._get_quantized_state(module, dtype, prefix=prefix) + self._record_layer_quant_config(prefix, qformat, block_size) weight = name_to_value.pop("weight") weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat) @@ -906,6 +939,8 @@ def _gated_mlp_slicing( gate_proj_prefix = prefix + gate_proj_name + "." up_proj_prefix = prefix + up_proj_name + "." + self._record_layer_quant_config(gate_proj_prefix, qformat, block_size) + self._record_layer_quant_config(up_proj_prefix, qformat, block_size) ffn_hidden_size = module.config.ffn_hidden_size gate_proj_weight = weight[:ffn_hidden_size, :] @@ -986,6 +1021,7 @@ def _grouped_mlp_slicing(self, module, prefix, parallel_config=None): for expert_id in range(num_experts): expert_prefix = prefix.format(expert_id) + "." + self._record_layer_quant_config(expert_prefix, qformat, block_size) weight_key = f"weight{expert_id}" if weight_key not in state_dict: @@ -1030,6 +1066,16 @@ def _qkv_slicing( q_proj_prefix = prefix + q_proj_name + "." k_proj_prefix = prefix + k_proj_name + "." v_proj_prefix = prefix + v_proj_name + "." + self._record_layer_quant_config(q_proj_prefix, qformat, block_size) + self._record_layer_quant_config(k_proj_prefix, qformat, block_size) + self._record_layer_quant_config(v_proj_prefix, qformat, block_size) + if qformat in (None, QUANTIZATION_NONE): + # Split fused linear_qkv exclude into per-HF-name q/k/v_proj entries. + fused_prefix = prefix.removesuffix(".") + self.exclude_modules = [m for m in self.exclude_modules if m != fused_prefix] + self._record_excluded_module(q_proj_prefix) + self._record_excluded_module(k_proj_prefix) + self._record_excluded_module(v_proj_prefix) config = module.config hidden_size = config.hidden_size @@ -1179,6 +1225,7 @@ def _pack_name_remapping(self, module, prefix, layer_type=None): weight_scale_list.append(weight_scale) weight_scale_2_list.append(weight_scale_2) input_scale_list.append(input_scale) + self._record_layer_quant_config(prefix, qformat, block_size) merged_weight = torch.stack(weight_list, dim=0) @@ -1247,6 +1294,7 @@ def _pack_name_remapping_gpt_oss(self, module, prefix, layer_type=None): weight_scale_2_list.append(weight_scale_2) input_scale_list.append(input_scale) bias_list.append(bias) + self._record_layer_quant_config(prefix, qformat, block_size) merged_weight = torch.stack(weight_list, dim=0) @@ -1349,6 +1397,31 @@ def _gather_exclude_modules(self): combined_exclude_modules.update(modules) return sorted(combined_exclude_modules) + def _gather_layer_config_dict(self): + """Get per-layer quantization metadata from all ranks for hf_quant_config.""" + if not torch.distributed.is_initialized(): + return dict(sorted(self.layer_config_dict.items())) + + all_layer_config_dicts = [None] * torch.distributed.get_world_size() + torch.distributed.all_gather_object(all_layer_config_dicts, self.layer_config_dict) + combined_layer_config_dict = {} + for layer_config_dict in all_layer_config_dicts: + if layer_config_dict: + combined_layer_config_dict.update(layer_config_dict) + return dict(sorted(combined_layer_config_dict.items())) + + def _gather_kv_cache_dtype(self): + """Return first non-None kv_cache_dtype across ranks (only attention ranks set it).""" + local = getattr(self, "kv_cache_dtype", None) + if not torch.distributed.is_initialized(): + return local + all_dtypes = [None] * torch.distributed.get_world_size() + torch.distributed.all_gather_object(all_dtypes, local) + for dt in all_dtypes: + if dt is not None: + return dt + return None + def export_mcore_gpt_to_hf( model: torch.nn.Module, diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index dfed54cc991..c4c20139052 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -636,12 +636,12 @@ class MaxCalibConfig(QuantizeAlgorithmConfig): sync_expert_weight_amax: bool = ModeloptField( default=False, - title="Sync weight quantizer amax across MoE experts", + title="Share one weight amax across local experts in a SequentialMLP MoE layer.", description=( - "If True, the weight quantizer amax values are synchronized (max) across " - "local experts in SequentialMLP layers during calibration. This matches " - "TEGroupedMLP behavior where all experts share a single weight quantizer. " - "Only affects MoE models with SequentialMLP experts." + "If True, max-calibration synchronizes the weight quantizer amax across local " + "experts within each SequentialMLP layer, so all experts in that layer share " + "one effective weight amax. TEGroupedMLP already fuses experts into a single " + "GEMM with one weight quantizer, so this flag is irrelevant there." ), ) diff --git a/modelopt/torch/quantization/conversion.py b/modelopt/torch/quantization/conversion.py index 3f97f8380be..190138f0971 100644 --- a/modelopt/torch/quantization/conversion.py +++ b/modelopt/torch/quantization/conversion.py @@ -98,6 +98,13 @@ def restore_quantized_model( return restore_quantizer_state(model, config, metadata) +def maybe_promote_nvfp4_static_quantizer(module: nn.Module, quantizer_state: dict) -> None: + if quantizer_state.get("_is_nvfp4_static_quantizer") and not isinstance( + module, NVFP4StaticQuantizer + ): + NVFP4StaticQuantizer.from_tensor_quantizer(module) + + def restore_quantizer_state(model: nn.Module, config: QuantizeConfig, metadata: MetadataDict): """Restore the quantizer states from the given state dict. @@ -131,12 +138,8 @@ def restore_quantizer_state(model: nn.Module, config: QuantizeConfig, metadata: if isinstance(module, TensorQuantizer): name = get_unwrapped_name(name, model) state = quantizer_state_dict[name] - # TODO: Add a registry for TensorQuantizers and avoid this manual conversion. - if state.get("_is_nvfp4_static_quantizer") and not isinstance( - module, NVFP4StaticQuantizer - ): - NVFP4StaticQuantizer.from_tensor_quantizer(module) - module.set_from_modelopt_state(quantizer_state_dict[name]) + maybe_promote_nvfp4_static_quantizer(module, state) + module.set_from_modelopt_state(state) for name, module in model.named_modules(): if isinstance(module, QuantModule): diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index bce49786077..36bac966291 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -55,6 +55,14 @@ ) from .utils.calib_utils import _GPTQ_HELPER_REGISTRY, GPTQHelper +try: + from .plugins.megatron import _check_static_block_tp_supported +except ImportError: + + def _check_static_block_tp_supported(model: nn.Module) -> None: # no-op without megatron + return + + __all__ = [ "CalibratorFactory", "awq", @@ -99,7 +107,7 @@ def _collect_grouped_linears(model: nn.Module) -> list[list[nn.Module]]: @torch.no_grad() -def _bootstrap_uncalibrated_weight_quantizers(model: nn.Module) -> int: +def _bootstrap_uncalibrated_static_weight_quantizers(model: nn.Module) -> int: """Re-run weight calibration on the weight tensor for quantizers missing ``_amax``. Covers MoE experts that ``max_calibrate`` skipped (no routed tokens) so MSE @@ -210,26 +218,54 @@ def _has_expert_parallelism(module: nn.Module) -> bool: return ps is not None and ps.expert_model_parallel_group.is_initialized() -def _check_moe_calibration_complete(quantizer, parallel_state): - """Raise error if MoE calibration is incomplete (some ranks have amax, others don't).""" +def _iter_leaf_quantizers(quantizer): if isinstance(quantizer, SequentialQuantizer): for _q in quantizer: - _check_moe_calibration_complete(_q, parallel_state) + yield from _iter_leaf_quantizers(_q) return - for group in [ - parallel_state.data_parallel_group, - parallel_state.expert_model_parallel_group, - parallel_state.tensor_parallel_group, - ]: - if not group.is_initialized(): + yield quantizer + + +def _check_moe_calibration_complete(quantizer, parallel_state): + """Raise error if MoE calibration is incomplete across distributed MoE ranks.""" + for leaf_quantizer in _iter_leaf_quantizers(quantizer): + if leaf_quantizer.is_block_quant: continue - has_amax = getattr(quantizer, "_amax", None) is not None - amax_states = DistributedProcessGroup.get_dist_syncd_obj(has_amax, group, lambda objs: objs) - if any(amax_states) and not all(amax_states): - raise RuntimeError( - "MoE calibration incomplete: some experts received no tokens during calibration. " - "Increase --calib-size to ensure all experts see calibration data." + + has_amax = getattr(leaf_quantizer, "_amax", None) is not None + for group in [ + parallel_state.data_parallel_group, + parallel_state.expert_model_parallel_group, + parallel_state.tensor_parallel_group, + ]: + if not group.is_initialized(): + continue + amax_states = DistributedProcessGroup.get_dist_syncd_obj( + has_amax, group, lambda objs: objs ) + if any(amax_states) and not all(amax_states): + raise RuntimeError( + "MoE calibration incomplete: some experts received no tokens during " + "calibration. Increase --calib-size to ensure all experts see calibration " + "data." + ) + + +def _is_routed_expert(parent_name: str) -> bool: + """Routed-expert FQN contains ``experts`` but not ``shared_experts`` (covers SequentialMLP and TEGroupedMLP).""" + return "experts" in parent_name and "shared_experts" not in parent_name + + +def _should_sync_amax_across_ep( + parent_name: str, child_name: str, sync_expert_weight_amax: bool +) -> bool: + """Skip EP sync for routed-expert weights (per-rank shards differ). + + SequentialMLP opts in via sync_expert_weight_amax. + """ + if "weight_quantizer" in child_name and _is_routed_expert(parent_name): + return sync_expert_weight_amax + return True @torch.no_grad() @@ -246,7 +282,8 @@ def max_calibrate( forward_loop: A callable which takes the model as argument and forwards calibration data through the model. distributed_sync: Whether to sync input_quantizer amax across distributed processes. - sync_expert_weight_amax: Whether to sync weight quantizer amax across MoE experts. + sync_expert_weight_amax: SequentialMLP only — share one weight amax across all experts + in a MoE layer (within-rank sync + EP all-reduce when EP>1). See :class:`MaxCalibConfig ` for details on the remaining arguments. @@ -263,13 +300,14 @@ def max_calibrate( if hasattr(module, "layer_sync_moe_local_experts_amax"): module.layer_sync_moe_local_experts_amax(sync_weight_amax=sync_expert_weight_amax) - # Promote eligible static-block NVFP4 weight quantizers to NVFP4StaticQuantizer - # so the static blockwise fake-quant path is used in forward and the export - # picks up the two-level (per-block + global) scaling. Run before the - # ``distributed_sync`` early return so single-process callers also get the - # promotion. ``promote_nvfp4_static_quantizers`` only promotes when - # ``is_static_block_quant`` is True and the per-block ``_amax`` buffer is - # populated, so it's a no-op for dynamic-block / non-NVFP4 configs. + # Fail fast on static-block under TP>1 (sharded_state_dict treats _amax as replicated). + _check_static_block_tp_supported(model) + + # Promote eligible static-block NVFP4 weight quantizers to NVFP4StaticQuantizer so + # the static blockwise fake-quant path is used in forward and export picks up the + # two-level (per-block + global) scaling. Run before the ``distributed_sync`` early + # return so single-process callers also get the promotion. No-op for dynamic-block + # / non-NVFP4 configs. promote_nvfp4_static_quantizers(model) if not distributed_sync: @@ -282,23 +320,24 @@ def max_calibrate( if isinstance(child, (TensorQuantizer, SequentialQuantizer)): _check_moe_calibration_complete(child, module.parallel_state) - def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state): - """Synchronize the amax across all ranks in the data parallel and expert parallel groups.""" + def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state, parent_name, child_name): + """Sync amax across DP (always) and EP (filtered — see _should_sync_amax_across_ep).""" if isinstance(quantizer, SequentialQuantizer): for _q in quantizer: - sync_quantizer_amax_across_dp_ep(_q, parallel_state) + sync_quantizer_amax_across_dp_ep(_q, parallel_state, parent_name, child_name) + return + if getattr(quantizer, "_amax", None) is None: return - if getattr(quantizer, "_amax", None) is not None: - quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group) + quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group) + if _should_sync_amax_across_ep(parent_name, child_name, sync_expert_weight_amax): quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group) - # TODO: create sync_bias_across_distributed_group # Step 2:Sync amax across data parallelism for name, module in model.named_modules(): if isinstance(module, QuantModule): - for child in module.children(): + for child_name, child in module.named_children(): if isinstance(child, (TensorQuantizer, SequentialQuantizer)): - sync_quantizer_amax_across_dp_ep(child, module.parallel_state) + sync_quantizer_amax_across_dp_ep(child, module.parallel_state, name, child_name) # Step 3: TP sync # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same @@ -321,7 +360,6 @@ def sync_quantizer_amax_across_tp( # Syncing amax across TP for sequential quantizer if isinstance(quantizer, SequentialQuantizer): for _q in quantizer: - # Syncing amax across TP for sequential quantizer sync_quantizer_amax_across_tp( _q, linear_name, quantizer_type, axes_for_sync, parallel_state ) @@ -449,13 +487,23 @@ def mse_calibrate( # Step 1: max calibrate, bootstrap dead-expert weight quantizers, # unify grouped NVFP4 global_amax so MSE sees a consistent FP8 grid. max_calibrate(model, forward_loop, distributed_sync) - _bootstrap_uncalibrated_weight_quantizers(model) + _bootstrap_uncalibrated_static_weight_quantizers(model) _sync_grouped_weight_global_amax(model) # Step 2: replace calibrators with MseCalibrator for enabled quantizers. + skipped_non_nvfp4: dict[str, int] = {} for name, module in list(model.named_modules()): if isinstance(module, TensorQuantizer) and not module._disabled: if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): + bs = getattr(module, "block_sizes", None) + if not ( + getattr(module, "_num_bits", None) == (2, 1) + and bs is not None + and bs.get("scale_bits") == (4, 3) + ): + fmt = f"num_bits={module._num_bits} block_sizes={bs}" + skipped_non_nvfp4[fmt] = skipped_non_nvfp4.get(fmt, 0) + 1 + continue initial_amax = module._amax.clone().detach() is_nvfp4_static = module.is_nvfp4_static @@ -503,6 +551,14 @@ def mse_calibrate( quant_func=partial(_mse_quant_func, quantizer=module), ) + if skipped_non_nvfp4: + formats = ", ".join(f"{n}x [{fmt}]" for fmt, n in skipped_non_nvfp4.items()) + warnings.warn( + f"MSE calibration only meaningful for NVFP4; skipped {sum(skipped_non_nvfp4.values())} " + f"non-NVFP4 quantizer(s) — keeping max-calibrated amax: {formats}", + stacklevel=2, + ) + # Step 3: calibrate weight quantizers via iter_weights_for_calibration. name_to_module = dict(model.named_modules()) seen_modules: set[int] = set() diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index fa540b8fdf5..aadb5ccfdc7 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -547,15 +547,25 @@ def is_mxfp(self, bits): else: raise NotImplementedError() + @property + def is_block_quant(self): + """Check if is block quantization (static or dynamic).""" + return self.block_sizes is not None + @property def is_static_block_quant(self): """Check if is static block quantization.""" return ( - self.block_sizes is not None + self.is_block_quant and self.block_sizes.get("type", None) != "dynamic" and self._fake_quant ) + @property + def is_dynamic_block_quant(self): + """Check if is dynamic block quantization.""" + return self.is_block_quant and self.block_sizes.get("type", None) == "dynamic" + @property def rotate_is_enabled(self): """Check if rotate is enabled in quant config.""" diff --git a/modelopt/torch/quantization/plugins/custom.py b/modelopt/torch/quantization/plugins/custom.py index 4200aadc73a..f480d245daa 100644 --- a/modelopt/torch/quantization/plugins/custom.py +++ b/modelopt/torch/quantization/plugins/custom.py @@ -24,7 +24,7 @@ from modelopt.torch.utils.distributed import ParallelState -from ..nn import QuantModule, SequentialQuantizer, TensorQuantizer +from ..nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer from ..nn.modules.quant_linear import _QuantLinear from ..utils import multi_context, replace_function @@ -126,7 +126,7 @@ def modelopt_post_restore(self, prefix: str = ""): def _check_unsupported_states(quantizer: TensorQuantizer): for k in quantizer.state_dict(): - if k not in ["_amax", "_pre_quant_scale"]: + if k not in ["_amax", "_pre_quant_scale", "_global_amax"]: warnings.warn( f"Restore of {k} for {prefix} is not supported. The restore of this layer might be " f"incorrect. Please implement a custom restore for {k}." @@ -137,6 +137,21 @@ def _has_state(quantizer, name): quantizer = quantizer[0] if isinstance(quantizer, SequentialQuantizer) else quantizer return hasattr(quantizer, name) + def _has_complete_static_nvfp4_weight_state(quantizer, weight): + quantizer = quantizer[0] if isinstance(quantizer, SequentialQuantizer) else quantizer + if not isinstance(quantizer, NVFP4StaticQuantizer): + return False + amax = getattr(quantizer, "_amax", None) + global_amax = getattr(quantizer, "global_amax", None) + if amax is None or global_amax is None: + return False + block_sizes = getattr(quantizer, "block_sizes", None) + block_size = block_sizes.get(-1) if isinstance(block_sizes, dict) else None + if block_size is None or weight.shape[-1] % block_size != 0: + return False + expected_blocks = weight.numel() // block_size + return amax.numel() == expected_blocks and global_amax.numel() == 1 + if self.weight is None: return @@ -144,7 +159,10 @@ def _has_state(quantizer, name): _check_unsupported_states( quantizer if isinstance(quantizer, TensorQuantizer) else quantizer[0] ) - if _has_state(self.weight_quantizer, "_amax"): + # Skip max_calibrate when saved static NVFP4 state is intact; else MSE scales get overwritten. + if _has_state( + self.weight_quantizer, "_amax" + ) and not _has_complete_static_nvfp4_weight_state(self.weight_quantizer, self.weight): self.weight_quantizer.reset_amax() max_calibrate(self.weight_quantizer, lambda wq: wq(self.weight), distributed_sync=False) if _has_state(self.input_quantizer, "_pre_quant_scale"): diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 0b50fd937ae..33f11a4491b 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -40,7 +40,8 @@ ) from modelopt.torch.utils.distributed import ParallelState -from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer +from ..conversion import maybe_promote_nvfp4_static_quantizer +from ..nn import QuantModule, QuantModuleRegistry, SequentialQuantizer, TensorQuantizer from ..nn.modules.quant_linear import RealQuantLinear from ..qtensor import QTensorWrapper from ..utils import sync_moe_expert_amax @@ -68,6 +69,35 @@ __all__ = [] +def _check_static_block_tp_supported(model: torch.nn.Module) -> None: + """Raise under TP>1: static-block _amax is shard-local but sharded_state_dict treats it as replicated.""" + offending = [] + for name, module in model.named_modules(): + if not isinstance(module, QuantModule): + continue + parallel_state = getattr(module, "parallel_state", None) + if parallel_state is None: + continue + tp_group = getattr(parallel_state, "tensor_parallel_group", None) + if tp_group is None or not tp_group.is_initialized() or tp_group.world_size() <= 1: + continue + weight_quantizer = getattr(module, "weight_quantizer", None) + if weight_quantizer is None: + continue + leaves = ( + list(weight_quantizer) + if isinstance(weight_quantizer, SequentialQuantizer) + else [weight_quantizer] + ) + if any(leaf.is_static_block_quant for leaf in leaves): + offending.append((name, tp_group.world_size())) + if offending: + raise NotImplementedError( + "Static-block NVFP4 weight quantization (e.g. MSE) is not supported with TP > 1. Please re-run with TP=1. " + f"Offending modules (showing first 5 of {len(offending)}): {offending[:5]}" + ) + + def real_quant_module_get_extra_state(self) -> dict: """Populating real_quantizer_state and q_tensor_state.""" extra_state = {} @@ -190,7 +220,9 @@ def quant_module_set_extra_state(self, state: Any): if quantizer_state is not None: for name, module in self.named_modules(): if isinstance(module, TensorQuantizer): - module.set_from_modelopt_state(quantizer_state[name], properties_only=False) + quantizer_substate = quantizer_state[name] + maybe_promote_nvfp4_static_quantizer(module, quantizer_substate) + module.set_from_modelopt_state(quantizer_substate, properties_only=False) self.modelopt_post_restore() # Handle real_quantizer_state and q_tensor_state @@ -399,6 +431,9 @@ def _get_shard_axis_dict(self, state_dict): """ shard_axis_dict = {} for k in state_dict: + # Static NVFP4 _global_amax is a replicated scalar; only per-block _amax shards. + if k.endswith("_global_amax"): + continue if "weight_quantizer." in k: weight_quantizer_axis = self.get_submodule(k.rsplit(".", 1)[0]).axis if weight_quantizer_axis is not None: @@ -427,6 +462,9 @@ def _get_shard_axis_dict(self, state_dict): """ shard_axis_dict = {} for k in state_dict: + # Static NVFP4 _global_amax is a replicated scalar; only per-block _amax shards. + if k.endswith("_global_amax"): + continue if "weight_quantizer." in k: weight_quantizer_axis = None if isinstance(self.weight_quantizer, TensorQuantizer): diff --git a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py index bb39c8a81e3..15dbd8e2c1a 100644 --- a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py +++ b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py @@ -55,7 +55,16 @@ def get_e2m1_bounds(cls, device): @classmethod def _is_static_quantizer(cls, weight_quantizer) -> bool: """Check if the weight quantizer is a static NVFP4 quantizer with pre-computed amax.""" - return hasattr(weight_quantizer, "global_amax") and weight_quantizer.global_amax is not None + global_amax = cls._get_static_global_amax(weight_quantizer) + return global_amax is not None + + @classmethod + def _get_static_global_amax(cls, weight_quantizer): + """Return global amax from live or restored static NVFP4 quantizers.""" + global_amax = getattr(weight_quantizer, "global_amax", None) + if global_amax is None: + global_amax = getattr(weight_quantizer, "_global_amax", None) + return global_amax @classmethod def get_weights_scaling_factor_2_from_quantizer(cls, weight_quantizer): @@ -70,8 +79,9 @@ def get_weights_scaling_factor_2_from_quantizer(cls, weight_quantizer): Returns: The global scaling factor as a float tensor. """ - if cls._is_static_quantizer(weight_quantizer): - return weight_quantizer.global_amax.float() / (6.0 * 448.0) + global_amax = cls._get_static_global_amax(weight_quantizer) + if global_amax is not None: + return global_amax.float() / (6.0 * 448.0) else: assert hasattr(weight_quantizer, "_amax"), ( "Weight quantizer does not have attribute amax" @@ -109,7 +119,7 @@ def get_weights_scaling_factor_from_quantizer( if cls._is_static_quantizer(weight_quantizer): # Static path: use pre-computed per-block amax values from quantizer - global_amax = weight_quantizer.global_amax.float() + global_amax = cls._get_static_global_amax(weight_quantizer).float() per_block_amax = weight_quantizer._amax.float() # Compute scales in float @@ -122,15 +132,11 @@ def get_weights_scaling_factor_from_quantizer( expected_shape = (*weight.shape[:-1], num_blocks_per_row) per_block_scale = per_block_scale.view(expected_shape) - # Quantize scales to FP8. Saturate to the fp8_e4m3fn max (448) before the - # cast: when the [==0]=1.0 safety net above fires (per_block_amax was zero - # for an all-zero weight block) and global_amax is small, the pre-cast value - # explodes to ``1.0 * 448 / (global_amax/6)``. fp8_e4m3fn has no Inf, so any - # value >= 480 casts to NaN — clamp first to keep the stored byte finite. + # Clamp to fp8_e4m3fn range: upper avoids NaN cast, lower avoids 0x00 underflow. if not keep_high_precision: per_block_scale = ( (per_block_scale * 448.0 / per_block_scale_max) - .clamp_(max=448.0) + .clamp_(min=2**-9, max=448.0) .to(torch.float8_e4m3fn) ) return per_block_scale, weights_scaling_factor_2 @@ -171,9 +177,9 @@ def get_weights_scaling_factor( ) # Set all zero values in scale to 1.0 per_block_scale[per_block_scale == 0] = 1.0 - # Convert to torch.float8_e4m3fn + # Clamp at fp8_e4m3fn subnormal min so tiny-amax blocks don't underflow to 0. if not keep_high_precision: - per_block_scale = per_block_scale.to(torch.float8_e4m3fn) + per_block_scale = per_block_scale.clamp_(min=2**-9).to(torch.float8_e4m3fn) return per_block_scale, weights_scaling_factor_2 @classmethod diff --git a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-max-calib.yaml b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-max-calib.yaml new file mode 100644 index 00000000000..df2b30b8ed8 --- /dev/null +++ b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-max-calib.yaml @@ -0,0 +1,131 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Mirrors the published nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 hf_quant_config.json: +# but with ONE major difference: use max calibration instead of MSE +# - MoE routed experts: NVFP4 W4A4 weight, group_size 16 +# HF names: mixer.experts..{up,down}_proj +# Megatron-Core names: mlp.experts.local_experts..linear_fc{1,2} +# - MoE shared experts: FP8 per-tensor +# HF names: mixer.shared_experts.{up,down}_proj +# Megatron-Core names: mlp.shared_experts.linear_fc{1,2} +# - Mamba mixer linears (mixer.{in,out}_proj): FP8 per-tensor +# - KV cache: FP8 +# - Attention linears ({q,k,v}_proj): BF16 (not quantized) +# - MTP head, lm_head, output, mamba conv1d: BF16 (not quantized) +# - Latent MOE (fc1_latent_proj, fc2_latent_proj): BF16 (not quantized) +# - SSM cache: FP32 (can be set to FP16 in VLLM) +# +# Calibration: amax/max calibration comparison variant +metadata: + recipe_type: ptq + description: Super NVFP4 mixed precision — sparse MoE experts NVFP4 (W4A4, group_size 16); shared experts, mamba in/out_proj FP8 per-tensor; FP8 KV cache; + everything else(lm_head/MTP/Latent MOE) stay BF16. Amax calibration comparison variant. +quantize: + algorithm: + method: max + quant_cfg: + # Disable all layers by default so that these layers stay in original BF16 precision: + # lm_head, output projection, MoE routers/gates, Latent MOE, MTP head, mamba conv1d. + - quantizer_name: '*' + enable: false + + # MoE routed experts -> NVFP4 W4A4, block_size 16, e4m3 scale. + # HF/export names: backbone.layers.*.mixer.experts.*.{up,down}_proj. + - quantizer_name: '*mixer.experts.*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_name: '*mixer.experts.*input_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + # Megatron-Core/PTQ names: decoder.layers.*.mlp.experts.local_experts.*.linear_fc{1,2}. + - quantizer_name: '*mlp.experts*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_name: '*mlp.experts*input_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + + # MoE shared experts -> FP8 per-tensor. + # HF/export names: backbone.layers.*.mixer.shared_experts.{up,down}_proj. + - quantizer_name: '*mixer.shared_experts.*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.shared_experts.*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + # Megatron-Core/PTQ names: decoder.layers.*.mlp.shared_experts.linear_fc{1,2}. + - quantizer_name: '*mlp.shared_experts*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mlp.shared_experts*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + + # Mamba mixer linears -> FP8 per-tensor. + - quantizer_name: '*mixer.in_proj*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.in_proj*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.out_proj*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.out_proj*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + + # KV cache -> FP8. + - quantizer_name: '*[kv]_bmm_quantizer' + enable: true + cfg: + num_bits: e4m3 diff --git a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml new file mode 100644 index 00000000000..729dcd12d52 --- /dev/null +++ b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Mirrors the published nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 hf_quant_config.json: +# - MoE routed experts: NVFP4 W4A4 weight MSE, group_size 16 +# HF names: mixer.experts..{up,down}_proj +# Megatron-Core names: mlp.experts.local_experts..linear_fc{1,2} +# - MoE shared experts: FP8 per-tensor +# HF names: mixer.shared_experts.{up,down}_proj +# Megatron-Core names: mlp.shared_experts.linear_fc{1,2} +# - Mamba mixer linears (mixer.{in,out}_proj): FP8 per-tensor +# - KV cache: FP8 +# - Attention linears ({q,k,v}_proj): BF16 (not quantized) +# - MTP head, lm_head, output, mamba conv1d: BF16 (not quantized) +# - Latent MOE (fc1_latent_proj, fc2_latent_proj): BF16 (not quantized) +# - SSM cache: FP32 (can be set to FP16 in VLLM) +# +# Calibration: weight MSE with FP8-scale sweep over the 128 e4m3 scale values +# (NVFP4 weights use static block scales selected by MSE; FP8 per-tensor scales +# are also chosen via MSE search instead of plain amax). +metadata: + recipe_type: ptq + description: Super NVFP4 mixed precision — sparse MoE experts NVFP4 (W4A4, group_size 16); shared experts, mamba in/out_proj FP8 per-tensor; FP8 KV cache; + everything else(lm_head/MTP/latent MOE) stay BF16. Weight-MSE calibration with FP8 scale sweep. +quantize: + algorithm: + method: mse + fp8_scale_sweep: true + quant_cfg: + # Disable all layers by default so that these layers stay in original BF16 precision: + # lm_head, output projection, MoE routers/gates, Latent MOE, MTP head, mamba conv1d. + - quantizer_name: '*' + enable: false + + # MoE routed experts -> NVFP4 W4A4, block_size 16, e4m3 scale. + # Weight uses static block scales (chosen by MSE); activations stay dynamic. + # HF/export names: backbone.layers.*.mixer.experts.*.{up,down}_proj. + - quantizer_name: '*mixer.experts.*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: static + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_name: '*mixer.experts.*input_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + # Megatron-Core/PTQ names: decoder.layers.*.mlp.experts.local_experts.*.linear_fc{1,2}. + - quantizer_name: '*mlp.experts*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: static + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_name: '*mlp.experts*input_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + + # MoE shared experts -> FP8 per-tensor. + # HF/export names: backbone.layers.*.mixer.shared_experts.{up,down}_proj. + - quantizer_name: '*mixer.shared_experts.*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.shared_experts.*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + # Megatron-Core/PTQ names: decoder.layers.*.mlp.shared_experts.linear_fc{1,2}. + - quantizer_name: '*mlp.shared_experts*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mlp.shared_experts*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + + # Mamba mixer linears -> FP8 per-tensor. + - quantizer_name: '*mixer.in_proj*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.in_proj*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.out_proj*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.out_proj*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + + # KV cache -> FP8. + - quantizer_name: '*[kv]_bmm_quantizer' + enable: true + cfg: + num_bits: e4m3 diff --git a/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py b/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py index 2b5caea16f0..773a2a47405 100644 --- a/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py +++ b/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py @@ -138,6 +138,23 @@ def test_fake_quantize_with_both_amaxs(self, device): assert torch.allclose(output, expected) + def test_static_export_clamps_overflowing_fp8_block_scales(self, device): + """Static export should match fake quant clamping and never write NaN FP8 scales.""" + cfg = QuantizerAttributeConfig( + num_bits=(2, 1), + block_sizes={-1: 16, "type": "static", "scale_bits": (4, 3)}, + ) + quantizer = NVFP4StaticQuantizer(quant_attribute_cfg=cfg).to(device) + quantizer.amax = torch.tensor([8.0], device=device) + quantizer.global_amax = torch.tensor(1.0, device=device) + weight = torch.ones(1, 16, device=device) + + weight_scale, _ = NVFP4QTensor.get_weights_scaling_factor_from_quantizer(quantizer, weight) + + assert weight_scale.dtype == torch.float8_e4m3fn + assert torch.isfinite(weight_scale.float()).all() + assert torch.equal(weight_scale.float(), torch.full_like(weight_scale.float(), 448.0)) + @pytest.mark.parametrize("device", ["cuda"]) class TestNVFP4MSECalibrator: diff --git a/tests/gpu_megatron/torch/export/test_unified_export_megatron.py b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py index 3fac8269ccd..8dfbc0323c2 100644 --- a/tests/gpu_megatron/torch/export/test_unified_export_megatron.py +++ b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py @@ -29,7 +29,7 @@ import modelopt.torch.quantization as mtq import modelopt.torch.speculative as mtsp -from modelopt.torch.export import KV_CACHE_FP8, export_mcore_gpt_to_hf, import_mcore_gpt_from_hf +from modelopt.torch.export import export_mcore_gpt_to_hf, import_mcore_gpt_from_hf from modelopt.torch.export.unified_export_megatron import GPTModelExporter from modelopt.torch.speculative.eagle.default_config import default_eagle_config from modelopt.torch.speculative.plugins.megatron_eagle import _DynamicEagleGPTModel @@ -42,15 +42,8 @@ def _verify_model_quant_config( """Verify config.json and hf_quant_config.json""" config_dict = json.load(open(export_dir / "config.json")) hf_quant_config_dict = json.load(open(export_dir / "hf_quant_config.json")) - # Make sure config.json and hf_quant_config.json are consistent - assert ( - config_dict["quantization_config"]["quant_algo"] - == hf_quant_config_dict["quantization"]["quant_algo"] - ) - assert ( - config_dict["quantization_config"]["ignore"] - == hf_quant_config_dict["quantization"]["exclude_modules"] - ) + # Make sure config.json and hf_quant_config.json use the same serving config. + assert config_dict["quantization_config"] == hf_quant_config_dict # Verify config.json if kv_cache_quant_cfg: @@ -58,17 +51,17 @@ def _verify_model_quant_config( # Verify hf_quant_config.json if quant_config: - quant_config_dict = hf_quant_config_dict["quantization"] + quant_config_dict = hf_quant_config_dict quant_type = quant_config_dict["quant_algo"] assert ( quant_type in quant_config ) # quant config str is subset of quant config e.g. NVFP4 -> NVFP4_DEFAULT_CFG - assert len(quant_config_dict["exclude_modules"]) > 1 # Dynamically added exclude modules + assert len(quant_config_dict["ignore"]) > 1 # Dynamically added exclude modules if quant_type == "NVFP4": - assert quant_config_dict["group_size"] == 16 + assert quant_config_dict["config_groups"]["group_0"]["weights"]["group_size"] == 16 if kv_cache_quant_cfg: - assert quant_config_dict["kv_cache_quant_algo"] == KV_CACHE_FP8 + assert quant_config_dict["kv_cache_scheme"]["num_bits"] == 8 def _test_unified_export_megatron( @@ -295,6 +288,44 @@ def test_qkv_slicing_gqa_tp2(dist_workers_size_2, tmp_path): dist_workers_size_2.run(partial(_test_qkv_slicing_gqa_tp2, tmp_path)) +def test_qkv_slicing_records_hf_excludes_for_unquantized_fused_qkv(): + """Unquantized fused MCore linear_qkv should become HF q/k/v excludes.""" + exporter = object.__new__(GPTModelExporter) + exporter.dtype = torch.bfloat16 + exporter.exclude_modules = ["backbone.layers.0.mixer"] + exporter.layer_config_dict = {} + exporter._state_dict = {} + + hidden_size = 8 + head_size = 4 + num_attention_heads = 2 + num_query_groups = 1 + qkv_dim = num_attention_heads + 2 * num_query_groups + weight = torch.arange(qkv_dim * head_size * hidden_size, dtype=torch.bfloat16).reshape( + qkv_dim * head_size, hidden_size + ) + + module = torch.nn.Module() + module.config = type( + "Config", + (), + { + "hidden_size": hidden_size, + "num_query_groups": num_query_groups, + "num_attention_heads": num_attention_heads, + "kv_channels": head_size, + }, + )() + exporter._get_quantized_state = lambda *args, **kwargs: ({"weight": weight}, None, 0) + + exporter._qkv_slicing(module, "backbone.layers.0.mixer.") + + assert "backbone.layers.0.mixer" not in exporter.exclude_modules + assert "backbone.layers.0.mixer.q_proj" in exporter.exclude_modules + assert "backbone.layers.0.mixer.k_proj" in exporter.exclude_modules + assert "backbone.layers.0.mixer.v_proj" in exporter.exclude_modules + + def _make_exporter_for_mtp(model_dir: Path) -> GPTModelExporter: """Create a minimal GPTModelExporter instance for testing _get_mtp_state_dict.""" exporter = object.__new__(GPTModelExporter) diff --git a/tests/unit/torch/export/test_hf_checkpoint_utils.py b/tests/unit/torch/export/test_hf_checkpoint_utils.py index f83cb355749..33d17eebb3d 100644 --- a/tests/unit/torch/export/test_hf_checkpoint_utils.py +++ b/tests/unit/torch/export/test_hf_checkpoint_utils.py @@ -20,6 +20,8 @@ import pytest pytest.importorskip("huggingface_hub") +hf_hub_errors = pytest.importorskip("huggingface_hub.errors") +LocalEntryNotFoundError = hf_hub_errors.LocalEntryNotFoundError from modelopt.torch.export import copy_hf_ckpt_remote_code @@ -59,15 +61,60 @@ def test_copy_hf_ckpt_remote_code_local_dir_no_py_files(tmp_path): assert list(dst_dir.iterdir()) == [], "no files should be copied" -def test_copy_hf_ckpt_remote_code_hub_id(tmp_path): - """copy_hf_ckpt_remote_code delegates to snapshot_download for a Hub model ID.""" +def test_copy_hf_ckpt_remote_code_hub_id(tmp_path, monkeypatch): + """copy_hf_ckpt_remote_code copies .py files from the resolved Hub snapshot.""" dst_dir = tmp_path / "dst" - - with patch("modelopt.torch.export.plugins.hf_checkpoint_utils.snapshot_download") as mock_sd: + snapshot_dir = tmp_path / "snapshot" + snapshot_dir.mkdir() + (snapshot_dir / "modeling_custom.py").write_text("# custom model") + (snapshot_dir / "not_python.txt").write_text("not python") + + monkeypatch.delenv("HF_HUB_OFFLINE", raising=False) + with patch( + "modelopt.torch.export.plugins.hf_checkpoint_utils.snapshot_download", + return_value=str(snapshot_dir), + ) as mock_sd: copy_hf_ckpt_remote_code("nvidia/NVIDIA-Nemotron-Nano-12B-v2", dst_dir) mock_sd.assert_called_once_with( repo_id="nvidia/NVIDIA-Nemotron-Nano-12B-v2", - local_dir=str(dst_dir), allow_patterns=["*.py"], + local_files_only=False, + ) + assert (dst_dir / "modeling_custom.py").read_text() == "# custom model" + assert not (dst_dir / "not_python.txt").exists(), "non-.py files should not be copied" + + +def test_copy_hf_ckpt_remote_code_hub_id_offline_uses_cache(tmp_path, monkeypatch): + """copy_hf_ckpt_remote_code resolves cached Hub snapshots when HF_HUB_OFFLINE is set.""" + dst_dir = tmp_path / "dst" + snapshot_dir = tmp_path / "snapshot" + snapshot_dir.mkdir() + (snapshot_dir / "nemotron_reasoning_parser.py").write_text("# parser") + + monkeypatch.setenv("HF_HUB_OFFLINE", "1") + with patch( + "modelopt.torch.export.plugins.hf_checkpoint_utils.snapshot_download", + return_value=str(snapshot_dir), + ) as mock_sd: + copy_hf_ckpt_remote_code("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", dst_dir) + + mock_sd.assert_called_once_with( + repo_id="nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", + allow_patterns=["*.py"], + local_files_only=True, ) + assert (dst_dir / "nemotron_reasoning_parser.py").read_text() == "# parser" + + +def test_copy_hf_ckpt_remote_code_hub_id_offline_missing_cache_raises(tmp_path, monkeypatch): + """copy_hf_ckpt_remote_code raises a clear error when offline cache is missing.""" + monkeypatch.setenv("HF_HUB_OFFLINE", "1") + with ( + patch( + "modelopt.torch.export.plugins.hf_checkpoint_utils.snapshot_download", + side_effect=LocalEntryNotFoundError("missing"), + ), + pytest.raises(RuntimeError, match="HF_HUB_OFFLINE"), + ): + copy_hf_ckpt_remote_code("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", tmp_path / "dst") diff --git a/tests/unit/torch/quantization/test_nvfp4_static_export_cpu.py b/tests/unit/torch/quantization/test_nvfp4_static_export_cpu.py new file mode 100644 index 00000000000..dfb776a0484 --- /dev/null +++ b/tests/unit/torch/quantization/test_nvfp4_static_export_cpu.py @@ -0,0 +1,320 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CPU round-trip tests for NVFP4 static export with extreme per-block amax (underflow/overflow).""" + +from __future__ import annotations + +import pytest +import torch + +from modelopt.torch.export.quant_utils import QUANTIZATION_NVFP4, to_quantized_weight +from modelopt.torch.quantization.config import QuantizerAttributeConfig +from modelopt.torch.quantization.nn import NVFP4StaticQuantizer +from modelopt.torch.quantization.qtensor import NVFP4QTensor + +BLOCK_SIZE = 16 +FP4_VALUES = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6]) + + +def _make_static_quantizer( + per_block_amax: torch.Tensor, global_amax: torch.Tensor +) -> NVFP4StaticQuantizer: + cfg = QuantizerAttributeConfig( + num_bits=(2, 1), + block_sizes={-1: BLOCK_SIZE, "type": "static", "scale_bits": (4, 3)}, + ) + q = NVFP4StaticQuantizer(quant_attribute_cfg=cfg) + q.amax = per_block_amax.clone() + q.global_amax = global_amax.clone() + return q + + +def _export_round_trip( + weight: torch.Tensor, quantizer: NVFP4StaticQuantizer +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Run the full export path and dequantize it back, mimicking vLLM serving. + + Returns (weight_scale_fp8, weight_scale_2_fp32, dequantized_weight_bf16). + """ + weight_scale_2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(quantizer) + weight_scale, _ = NVFP4QTensor.get_weights_scaling_factor_from_quantizer( + quantizer, weight, weight_scale_2.to(weight.device) + ) + packed = to_quantized_weight( + weight, + weight_scale, + QUANTIZATION_NVFP4, + weight_scale_2, + BLOCK_SIZE, + ) + qtensor = NVFP4QTensor(weight.shape, weight.dtype, packed) + dequant = qtensor.dequantize( + scale=weight_scale, + double_scale=weight_scale_2, + block_sizes={-1: BLOCK_SIZE}, + dtype=weight.dtype, + ) + return weight_scale, weight_scale_2, dequant + + +def _layer1_routed_expert_like( + out_dim: int, in_dim: int, *, n_outliers: int, seed: int +) -> torch.Tensor: + """Synthesize a tensor whose block-amax distribution matches the failure case. + + The vast majority of blocks have tiny absolute values (~1e-7), and a handful + of rows carry an outlier magnitude (~1e-1) that drives the global tensor + amax to the FP8-normal regime. This is the distribution that produces 81% + raw-1 FP8 block scales in the broken Ultra V3 MSE export. + """ + g = torch.Generator().manual_seed(seed) + weight = torch.randn(out_dim, in_dim, generator=g, dtype=torch.bfloat16) * 1e-7 + # Inject outliers in n_outliers distinct (row, col) positions. + n = max(1, n_outliers) + rows = torch.randint(0, out_dim, (n,), generator=g) + cols = torch.randint(0, in_dim, (n,), generator=g) + weight[rows, cols] = torch.randn(n, generator=g, dtype=torch.bfloat16) * 0.1 + return weight + + +def _per_block_max(weight: torch.Tensor) -> torch.Tensor: + """Per-(out, num_blocks) absolute max, mirroring NVFP4 block-amax reduction.""" + blocks = weight.float().view(*weight.shape[:-1], -1, BLOCK_SIZE) + return blocks.abs().amax(dim=-1) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestNVFP4StaticExportFiniteAndBounded: + """The static export path must never emit NaN/Inf scales or dequant values. + + These tests fail fast if any saved bit pattern (weight_scale, packed weight, + dequantized result) contains a NaN or Inf, no matter what amax the caller + set on the static quantizer. + """ + + def test_layer1_like_distribution_no_nan(self): + weight = _layer1_routed_expert_like(64, 256, n_outliers=4, seed=0) + block_max = _per_block_max(weight) + global_amax = block_max.max() + # MSE-style: a mix of "shrunk" (multiplier 0.5) and "expanded" + # (multiplier 2.5) per-block amax around the actual block maxima. + mult = torch.where( + torch.arange(block_max.numel()) % 2 == 0, + torch.tensor(0.5), + torch.tensor(2.5), + ).view_as(block_max) + amax = (block_max * mult).clamp(min=1e-30) + q = _make_static_quantizer(amax, global_amax) + + ws, ws2, deq = _export_round_trip(weight, q) + + assert torch.isfinite(ws.float()).all(), "weight_scale (FP8) must be finite" + assert torch.isfinite(ws2).all(), "weight_scale_2 (FP32) must be finite" + assert torch.isfinite(deq.float()).all(), "dequantized weight must be finite" + + def test_overflow_amax_saturates_no_nan(self): + """When _amax > _global_amax (MSE multiplier > 1 on the global-max + block), the FP8 cast must saturate to 448, not produce NaN.""" + weight = torch.randn(8, 16, dtype=torch.bfloat16) * 1e-2 + block_max = _per_block_max(weight) + global_amax = block_max.max() + # Force the first block to have amax 4x the global max. + amax = block_max.clone() + amax[0] = global_amax * 4.0 + q = _make_static_quantizer(amax, global_amax) + + ws, ws2, deq = _export_round_trip(weight, q) + + assert torch.isfinite(ws.float()).all(), "FP8 weight_scale must saturate, not NaN" + # FP8 e4m3fn max is 448. The byte for the overflowing block should be 448. + assert ws[0].float().max().item() == pytest.approx(448.0) + assert torch.isfinite(deq.float()).all() + + def test_underflow_amax_no_inf_in_dequant(self): + """When per_block_amax / global_amax is below FP8 representable range, + the static export must not emit Inf or NaN in the *dequantized* tensor. + Whether the FP8 byte is 0 (natural underflow) or a clamped subnormal, + the dequant of every weight in the affected blocks must be finite.""" + weight = torch.randn(8, 16, dtype=torch.bfloat16) * 1e-7 + block_max = _per_block_max(weight) + global_amax = block_max.max() * 1e6 # _global_amax much larger than blocks + amax = block_max.clamp(min=1e-30) + q = _make_static_quantizer(amax, torch.tensor(global_amax)) + + ws, ws2, deq = _export_round_trip(weight, q) + + assert torch.isfinite(ws.float()).all() + assert torch.isfinite(ws2).all() + assert torch.isfinite(deq.float()).all(), ( + "dequantized weight has NaN/Inf in underflow regime — this is the " + "failure pattern that breaks vLLM serving" + ) + + +class TestNVFP4StaticExportRoundTripBound: + """Round-trip dequant magnitude must stay bounded by the encoded amax. + + For every block, ``|dequant| <= 6 * weight_scale_FP8 * weight_scale_2``, + and that product must be bounded above by ``max(_amax_block, 448 * scale_2)`` + (FP8 saturation). If any block's dequant exceeds that bound, an + out-of-distribution outlier was synthesized by the export path itself. + """ + + def test_dequant_magnitude_within_amax(self): + weight = _layer1_routed_expert_like(32, 128, n_outliers=4, seed=1) + block_max = _per_block_max(weight) + global_amax = block_max.max() + # Use _amax = block_max (no shrinking, no expansion). + amax = block_max.clamp(min=1e-30) + q = _make_static_quantizer(amax, global_amax) + + ws, ws2, deq = _export_round_trip(weight, q) + + # The maximum representable magnitude per block is 6 * dequant_scale. + # Allow a small relative tolerance for the bf16 quantization that + # NVFP4QTensor.dequantize applies to its output (~0.4% per element). + dequant_scale_per_block = ws.float() * ws2.float() + expected_block_bound = 6.0 * dequant_scale_per_block # shape (out, num_blocks) + deq_block_max = deq.float().view(*deq.shape[:-1], -1, BLOCK_SIZE).abs().amax(dim=-1) + violation = (deq_block_max - expected_block_bound).clamp(min=0) + # Reject any per-block dequant magnitude that exceeds the FP4 saturation + # bound by more than 1% (well above bf16 round-up noise) — that would + # indicate the export synthesized out-of-distribution outliers. + relative = violation / expected_block_bound.clamp(min=1e-30) + max_relative = relative.max().item() + assert max_relative <= 1e-2, ( + f"dequant block max exceeds the FP4 saturation bound by " + f"{max_relative:.2%}. Worst block index: " + f"{tuple(int(i) for i in (relative == relative.max()).nonzero()[0].tolist())}" + ) + + +class TestNVFP4StaticVsDynamicEquivalence: + """When _amax = per_block_amax (no MSE shrink/expand), the static and + dynamic export paths must produce bit-identical FP8 weight_scale bytes. + Both paths apply the same lower clamp at the fp8 subnormal min (2**-9) + so tiny-amax blocks land on 0x01 instead of underflowing to 0x00.""" + + def test_static_matches_dynamic_when_amax_is_block_max(self): + weight = _layer1_routed_expert_like(16, 64, n_outliers=2, seed=2) + block_max = _per_block_max(weight) + global_amax = block_max.max() + + # Static path + q = _make_static_quantizer(block_max.clamp(min=1e-30), global_amax) + static_ws_2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(q) + static_ws, _ = NVFP4QTensor.get_weights_scaling_factor_from_quantizer( + q, weight, static_ws_2 + ) + + # Dynamic path + dynamic_ws, dynamic_ws_2 = NVFP4QTensor.get_weights_scaling_factor( + weight, BLOCK_SIZE, weights_scaling_factor_2=static_ws_2.clone() + ) + + # Both paths should yield identical FP8 byte patterns when amax matches. + assert torch.equal(static_ws.view(torch.uint8), dynamic_ws.view(torch.uint8)), ( + "static and dynamic export paths produced different FP8 " + "weight_scale bytes for the same per-block amax — this means " + "the static path's scale computation diverges from the dynamic path" + ) + assert torch.allclose(static_ws_2, dynamic_ws_2) + + +class TestNVFP4StaticManualRoundTrip: + """Cross-check the export path against a manual per-block computation. + + For each block: ``dequant_scale = FP8(amax * 448 / global_amax) * (global_amax / (6 * 448))``, + and each FP4 value lies in {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6}. The + dequantized weight for any element should be the FP4 value at the rounded + ordinal, multiplied by the block's dequant_scale. + """ + + def test_single_block_matches_manual(self): + # One block of 16 elements, mid-magnitude (FP8-normal regime). + torch.manual_seed(3) + weight = (torch.rand(1, BLOCK_SIZE, dtype=torch.bfloat16) - 0.5) * 0.05 + block_max = _per_block_max(weight) + global_amax = block_max.max() + amax = block_max.clamp(min=1e-30) + q = _make_static_quantizer(amax, global_amax) + + ws, ws2, deq = _export_round_trip(weight, q) + + # Manual per-block dequant scale: FP8-quantized ratio * scale_2. + dequant_scale = ws.float()[0, 0].item() * ws2.float().item() + + # Every dequant element must be FP4 value * dequant_scale, allowing for + # the bf16 round-trip applied by NVFP4QTensor.dequantize on its output + # (~0.4% relative). Use a tolerance that's loose enough for bf16 and + # tight enough to catch a real off-grid value. + deq_vals = deq.float().reshape(-1) + grid = torch.tensor( + [v.item() * dequant_scale for v in FP4_VALUES.float()], + dtype=torch.float32, + ) + for v in deq_vals.tolist(): + distance = (grid - v).abs().min().item() + tolerance = max(abs(v) * 1e-2, 1e-12) + assert distance <= tolerance, ( + f"dequant value {v} is not on the FP4 grid (min distance {distance:g}, " + f"tolerance {tolerance:g}); grid = {sorted(grid.tolist())}" + ) + + +class TestNVFP4StaticCornerCases: + """Edge cases that have historically caused trouble in MSE static export.""" + + def test_zero_amax_block_does_not_explode(self): + """If MSE selects amax=0 for a block (e.g., dead expert), the export + must not emit NaN/Inf or amplify dequant magnitude.""" + weight = torch.zeros(2, BLOCK_SIZE, dtype=torch.bfloat16) + weight[1, :] = 0.05 # one block with real values + block_max = _per_block_max(weight) + global_amax = block_max.max() + amax = block_max.clone() + amax[0] = 0.0 # explicit zero amax for the dead block + q = _make_static_quantizer(amax, global_amax) + + ws, ws2, deq = _export_round_trip(weight, q) + + assert torch.isfinite(ws.float()).all() + assert torch.isfinite(deq.float()).all() + # The dead block must dequantize to all zeros (no leakage from the + # special amax==0 substitution). + assert torch.equal(deq[0].float(), torch.zeros_like(deq[0].float())), ( + "amax==0 block leaked nonzero dequant values" + ) + + def test_ultra_v3_layer1_distribution_byte_distribution_sane(self): + """Sanity: in the Ultra-V3 layer-1-like regime, the export's FP8 byte + distribution does not contain raw NaN bytes (0x7F or 0xFF).""" + weight = _layer1_routed_expert_like(128, 512, n_outliers=8, seed=4) + block_max = _per_block_max(weight) + global_amax = block_max.max() + amax = block_max.clamp(min=1e-30) + q = _make_static_quantizer(amax, global_amax) + + ws, _, _ = _export_round_trip(weight, q) + ws_bytes = ws.view(torch.uint8).reshape(-1) + + # FP8 e4m3fn NaN bytes are 0x7F (127) and 0xFF (255). + nan_count = int(((ws_bytes == 127) | (ws_bytes == 255)).sum().item()) + assert nan_count == 0, f"static export emitted {nan_count} NaN FP8 weight_scale bytes" From 8dff087201b7887c2677778575e39d9c89bdd773 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Fri, 15 May 2026 12:21:11 -0700 Subject: [PATCH 2/7] minor fixes; check block quant moe completeness Signed-off-by: Jennifer Chen --- modelopt/torch/quantization/backends/utils.py | 4 ++++ modelopt/torch/quantization/model_calib.py | 3 --- .../torch/quantization/plugins/test_fused_experts.py | 4 ++-- tests/unit/torch/quantization/test_mse_calibrator.py | 11 ++++++++--- 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/modelopt/torch/quantization/backends/utils.py b/modelopt/torch/quantization/backends/utils.py index 838951bef9d..6ed133f5d6c 100644 --- a/modelopt/torch/quantization/backends/utils.py +++ b/modelopt/torch/quantization/backends/utils.py @@ -20,9 +20,13 @@ def fp8_compatible(): """Check if the current device supports FP8.""" + if not torch.cuda.is_available(): + return False return torch.cuda.get_device_capability(0) >= (8, 9) def fp4_compatible(): """Check if the current device supports FP4.""" + if not torch.cuda.is_available(): + return False return torch.cuda.get_device_capability(0) >= (10, 0) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 36bac966291..01c6be15308 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -229,9 +229,6 @@ def _iter_leaf_quantizers(quantizer): def _check_moe_calibration_complete(quantizer, parallel_state): """Raise error if MoE calibration is incomplete across distributed MoE ranks.""" for leaf_quantizer in _iter_leaf_quantizers(quantizer): - if leaf_quantizer.is_block_quant: - continue - has_amax = getattr(leaf_quantizer, "_amax", None) is not None for group in [ parallel_state.data_parallel_group, diff --git a/tests/unit/torch/quantization/plugins/test_fused_experts.py b/tests/unit/torch/quantization/plugins/test_fused_experts.py index 19e1ed49197..2f64dab6cea 100644 --- a/tests/unit/torch/quantization/plugins/test_fused_experts.py +++ b/tests/unit/torch/quantization/plugins/test_fused_experts.py @@ -645,7 +645,7 @@ def test_bootstrap_populates_dead_expert_quantizers(self): """ import modelopt.torch.quantization as mtq from modelopt.torch.quantization.model_calib import ( - _bootstrap_uncalibrated_weight_quantizers, + _bootstrap_uncalibrated_static_weight_quantizers, ) model = _TinyMoEModel() @@ -699,7 +699,7 @@ def partial_forward(m): f"Dead expert {idx} down_proj should be uncalibrated pre-bootstrap" ) - n_bootstrapped = _bootstrap_uncalibrated_weight_quantizers(model) + n_bootstrapped = _bootstrap_uncalibrated_static_weight_quantizers(model) assert n_bootstrapped >= 2 * len(dead), ( f"Expected ≥{2 * len(dead)} bootstrapped (gate_up + down per dead expert), " f"got {n_bootstrapped}" diff --git a/tests/unit/torch/quantization/test_mse_calibrator.py b/tests/unit/torch/quantization/test_mse_calibrator.py index 4332b093861..04266447a61 100644 --- a/tests/unit/torch/quantization/test_mse_calibrator.py +++ b/tests/unit/torch/quantization/test_mse_calibrator.py @@ -558,14 +558,19 @@ def _quantize_and_calibrate(self, backend_name, fp8_scale_sweep=True): from modelopt.torch.quantization.nn.modules.tensor_quantizer import register_quant_backend register_quant_backend(backend_name, lambda x, tq: x) - model = torch.nn.Linear(8, 8, bias=False) - inputs = torch.randn(1, 8) + model = torch.nn.Linear(16, 8, bias=False) + inputs = torch.randn(1, 16) config = { "quant_cfg": [ {"quantizer_name": "*", "enable": False}, { "quantizer_name": "*weight_quantizer", - "cfg": {"num_bits": 8, "axis": None, "backend": backend_name}, + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "backend": backend_name, + }, }, ], "algorithm": "max", From cc4a57023ac5a08c08668ffadc8e3ed17fdb5a09 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Fri, 15 May 2026 14:01:15 -0700 Subject: [PATCH 3/7] add mcore autoquant MOE rule Signed-off-by: Jennifer Chen --- modelopt/torch/quantization/algorithms.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modelopt/torch/quantization/algorithms.py b/modelopt/torch/quantization/algorithms.py index 992717983db..28086a1e96e 100644 --- a/modelopt/torch/quantization/algorithms.py +++ b/modelopt/torch/quantization/algorithms.py @@ -362,6 +362,8 @@ class _AutoQuantizeBaseSearcher(BaseSearcher, ABC): # gate_proj, up_proj, down_proj for Qwen3 like MoE models r"^(.*?\.mlp\.experts)\.\d+\.(gate_proj|up_proj|down_proj)$", r"^(.*?\.mixer\.experts)\.\d+\.(up_proj|down_proj)$", # NemotronH MoE experts + # NemotronH MoE experts in MCore naming (linear_fc1=gate+up fused, linear_fc2=down) + r"^(.*?\.mlp\.experts\.local_experts)\.\d+\.(linear_fc1|linear_fc2)$", r"^(.*?)\.(gate_proj|up_proj)$", # gate_proj, up_proj for llama like models r"^(.*?)\.(\d+\.(w1|w2|w3))$", # mixtral experts r"^(.*?)\.((w1_linear|w2_linear|w3_linear)\.\d+)$", # dbrx experts From 60ca49cda7cfc5b54a561eba83f3deae7c74cc97 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Mon, 18 May 2026 09:12:00 -0700 Subject: [PATCH 4/7] support EP in autoquant scoring Signed-off-by: Jennifer Chen --- modelopt/torch/quantization/algorithms.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/modelopt/torch/quantization/algorithms.py b/modelopt/torch/quantization/algorithms.py index 28086a1e96e..094b7a6d5de 100644 --- a/modelopt/torch/quantization/algorithms.py +++ b/modelopt/torch/quantization/algorithms.py @@ -291,13 +291,14 @@ def get_score(self, recipe: QuantRecipe) -> float: total_score += importance.cpu().item() continue - if parallel_state.expert_model_parallel_group.is_initialized(): - # TODO: Support expert model parallelism for score estimation - warnings.warn("AutoQuantize does not support expert model parallelism yet.") importance = importance.cpu() importance = DistributedProcessGroup.get_dist_syncd_obj( importance, - [parallel_state.tensor_parallel_group, parallel_state.data_parallel_group], + [ + parallel_state.tensor_parallel_group, + parallel_state.expert_model_parallel_group, + parallel_state.data_parallel_group, + ], sum, ) total_score += importance.item() @@ -318,13 +319,12 @@ def get_cost(self, recipe: QuantRecipe) -> float: cost += weight_size * recipe.compression continue - if parallel_state.expert_model_parallel_group.is_initialized(): - # TODO: Support expert model parallelism - warnings.warn("AutoQuantize does not support expert model parallelism yet.") - weight_size = DistributedProcessGroup.get_dist_syncd_obj( weight_size, - [parallel_state.tensor_parallel_group], + [ + parallel_state.tensor_parallel_group, + parallel_state.expert_model_parallel_group, + ], sum, ) From bd2e8e9a10506ccb2ae0a452d002c7b2d3e125a8 Mon Sep 17 00:00:00 2001 From: realAsma <86726418+realAsma@users.noreply.github.com> Date: Mon, 18 May 2026 11:43:12 -0700 Subject: [PATCH 5/7] Support auto_quantize for Megatron expert parallelism (#1513) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What does this PR do? Type of change: Bug fix This PR enables `auto_quantize` for Megatron expert parallel MoE flows by including the expert model parallel group when aggregating scores and costs and when synchronizing selected recipes. It also derives the search budget from the no-quant candidate costs in `candidate_stats`, so sharded expert layers use global candidate costs instead of local module weights. ### Usage ```python model, search_state = mtq.auto_quantize( model, constraints={"effective_bits": 8.0}, quantization_formats=[mtq.NVFP4_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG], data_loader=data_loader, forward_step=forward_step, ) ``` ### Testing - Focused Megatron EP test from local log: `python -m pytest tests/gpu_megatron/torch/quantization/plugins/test_megatron.py::test_auto_quantize_moe_ep -xvs` in NGC PyTorch 26.01 (`1 passed` in 134.37s). - Added unit coverage for deriving the auto_quantize budget from no-quant candidate costs. ### Before your PR is "*Ready for review*" Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.). - Is this change backward compatible?: ✅ - If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: N/A - Did you write any new necessary tests?: ✅ - Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: N/A - Did you get Claude approval on this PR?: N/A ### Additional Information Base branch: `jennifchen/super_nvfp4_recipe`. Signed-off-by: realAsma Signed-off-by: Jenny Chen Co-authored-by: Jenny Chen --- modelopt/torch/quantization/algorithms.py | 27 ++++++++++--- .../torch/quantization/plugins/megatron.py | 31 ++++++++++++++ .../torch/quantization/quantize_common.py | 24 ++++++++--- .../quantization/plugins/test_megatron.py | 40 +++++++++++++++++++ .../unit/torch/quantization/test_autoquant.py | 35 ++++++++++++++++ 5 files changed, 146 insertions(+), 11 deletions(-) diff --git a/modelopt/torch/quantization/algorithms.py b/modelopt/torch/quantization/algorithms.py index 094b7a6d5de..99cee86dcb2 100644 --- a/modelopt/torch/quantization/algorithms.py +++ b/modelopt/torch/quantization/algorithms.py @@ -296,8 +296,8 @@ def get_score(self, recipe: QuantRecipe) -> float: importance, [ parallel_state.tensor_parallel_group, - parallel_state.expert_model_parallel_group, parallel_state.data_parallel_group, + parallel_state.expert_model_parallel_group, ], sum, ) @@ -722,6 +722,15 @@ def _get_total_weight_size(modules): for module in modules ) + @staticmethod + def _get_total_weight_size_from_candidate_stats(candidate_stats): + no_quant_recipe = QuantRecipe(quant_cfg=None) + total_weight_size = 0 + for candidate_stat in candidate_stats.values(): + no_quant_idx = candidate_stat["formats"].index(no_quant_recipe) + total_weight_size += candidate_stat["costs"][no_quant_idx] + return total_weight_size + def _get_constraints_for_search(self, max_weight_size, lower_bound=None): constraints = { "weight_size_after_compression": ( @@ -744,7 +753,7 @@ def run_search(self): ) compression = self._get_formatted_weight_compression_constraint() - total_weight_size = self._get_total_weight_size(self.model.modules()) + total_weight_size = self._get_total_weight_size_from_candidate_stats(self.candidate_stats) max_weight_size = total_weight_size * compression # Run the search with stats to get the best recipe and whether the constraints are satisfied @@ -754,12 +763,16 @@ def run_search(self): best_recipe = {} best_constraints, best_scores = 0, 0 for name, best_hparam_recipe_info in best_recipe_info.items(): - # Solvers could give different solutions for the same layer across DP/TP groups even though - # the scores and costs are the same. Lets make sure the same recipe is selected across DP/TP + # Solvers could give different solutions for the same layer across DP/TP/EP groups even though + # the scores and costs are the same. Lets make sure the same recipe is selected across DP/TP/EP _ps = self.model.get_submodule(name.split(".quant_recipe")[0]).parallel_state best_format = DistributedProcessGroup.get_dist_syncd_obj( best_hparam_recipe_info["format"], - [_ps.data_parallel_group, _ps.tensor_parallel_group], + [ + _ps.data_parallel_group, + _ps.tensor_parallel_group, + _ps.expert_model_parallel_group, + ], lambda a: a[0], ) @@ -1379,7 +1392,9 @@ def _resolve_best_recipe(search_state, constraints, verbose=False): effective_bits = constraints["effective_bits"] compression = effective_bits / 16.0 candidate_stats = search_state["candidate_stats"] - total_weight_size = sum(s["costs"][-1] for s in candidate_stats.values()) + total_weight_size = _AutoQuantizeBaseSearcher._get_total_weight_size_from_candidate_stats( + candidate_stats + ) max_weight_size = total_weight_size * compression method = search_state["method"] diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 33f11a4491b..f90d2862aef 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -18,6 +18,7 @@ import logging import types import warnings +from contextlib import contextmanager from typing import Any import megatron.core.parallel_state as mcore_parallel @@ -775,3 +776,33 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): # Affine KVCache Quant bias vector. state_dict = self.state_dict(prefix="", keep_vars=True) return make_sharded_tensors_for_checkpoint(state_dict, prefix, {}, sharded_offsets) + + +def _is_supported_megatron_model(model: torch.nn.Module) -> bool: + return isinstance(model, MegatronModule) + + +@contextmanager +def _megatron_grad_ckpt_context(model: torch.nn.Module): + # Megatron configures activation recompute at model build time via TransformerConfig, + # so there is no runtime flag to flip here. + yield + + +def _is_param_grad_enabled_for_megatron(pname: str, model: torch.nn.Module) -> bool: + return "embed" in pname + + +def _register_auto_quantize_support() -> None: + # Local import breaks the circular path where algorithms imports model_calib, + # which imports _check_static_block_tp_supported from this plugin. + from ..algorithms import AutoQuantizeGradientSearcher + + AutoQuantizeGradientSearcher.register_custom_support( + _is_supported_megatron_model, + _megatron_grad_ckpt_context, + _is_param_grad_enabled_for_megatron, + ) + + +_register_auto_quantize_support() diff --git a/tests/_test_utils/torch/quantization/quantize_common.py b/tests/_test_utils/torch/quantization/quantize_common.py index 7af33fa599f..46259203b24 100644 --- a/tests/_test_utils/torch/quantization/quantize_common.py +++ b/tests/_test_utils/torch/quantization/quantize_common.py @@ -249,14 +249,28 @@ def forward_loop(model): ) -def auto_quantize_helper(model): +def auto_quantize_helper( + model, + data_loader=None, + forward_step=None, + forward_backward_step=None, + quantization_formats=None, +): + if data_loader is None: + data_loader = [model.get_dummy_input().cuda() for _ in range(2)] + if forward_step is None: + forward_step = lambda model, batch: model(batch) # noqa: E731 + if forward_backward_step is None: + forward_backward_step = lambda model, batch: model(batch).sum().backward() # noqa: E731 + if quantization_formats is None: + quantization_formats = [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT8_DEFAULT_CFG] model, search_state = mtq.auto_quantize( model, constraints={"effective_bits": 8.0}, - quantization_formats=[mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT8_DEFAULT_CFG], - data_loader=[model.get_dummy_input().cuda() for _ in range(2)], - forward_step=lambda model, batch: model(batch), - forward_backward_step=lambda model, batch: model(batch).sum().backward(), + quantization_formats=quantization_formats, + data_loader=data_loader, + forward_step=forward_step, + forward_backward_step=forward_backward_step, num_calib_steps=2, num_score_steps=2, verbose=True, diff --git a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py index c34fb2df376..3b7985c817b 100644 --- a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py @@ -26,6 +26,7 @@ from _test_utils.torch.megatron.utils import ( compare_amax_sync_across_expert_parallel, copy_weights_from_grouped_to_non_grouped, + get_batch, get_forward, initialize_for_megatron, run_mcore_inference, @@ -694,6 +695,45 @@ def test_te_grouped_vs_sequential_quantize(dist_workers_size_4, quant_cfg): ) +def _test_auto_quantize_moe_ep_helper(rank, size): + initialize_for_megatron( + tensor_model_parallel_size=1, + expert_model_parallel_size=size, + seed=SEED, + ) + model = _gpt_model_provider( + tp_size=1, + ep_size=size, + hidden_size=32, + num_moe_experts=4, + moe_grouped_gemm=False, + transformer_impl="modelopt", + ) + + def forward_step(model, batch): + input_ids, labels, position_ids, attention_mask, loss_mask = batch + return model.forward( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + labels=labels, + loss_mask=loss_mask, + ) + + auto_quantize_helper( + model, + data_loader=[get_batch(model, batch_size=2) for _ in range(2)], + forward_step=forward_step, + forward_backward_step=lambda m, b: forward_step(m, b).mean().backward(), + quantization_formats=[mtq.NVFP4_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG], + ) + + +def test_auto_quantize_moe_ep(dist_workers_size_2): + """auto_quantize must sum score/cost across EP ranks and pick a consistent recipe.""" + dist_workers_size_2.run(_test_auto_quantize_moe_ep_helper) + + @pytest.mark.parametrize("ep_size", [1, 2]) @pytest.mark.parametrize("sync_weight_amax", [True, False]) def test_layer_sync_moe_local_experts_amax(dist_workers, ep_size, sync_weight_amax): diff --git a/tests/unit/torch/quantization/test_autoquant.py b/tests/unit/torch/quantization/test_autoquant.py index 87ec73291e7..f7b4965cf8f 100644 --- a/tests/unit/torch/quantization/test_autoquant.py +++ b/tests/unit/torch/quantization/test_autoquant.py @@ -24,8 +24,10 @@ import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq from modelopt.torch.quantization.algorithms import ( + AutoQuantizeGradientSearcher, QuantRecipe, QuantRecipeHparam, + _AutoQuantizeBaseSearcher, estimate_quant_compression, ) from modelopt.torch.quantization.config import _base_disable_all, _default_disabled_quantizer_cfg @@ -305,6 +307,39 @@ def test_data_parallel_auto_quantize(skip_on_windows): spawn_multiprocess_job(4, _test_data_parallel_auto_quantize, backend="gloo") +def test_auto_quantize_budget_uses_no_quant_candidate_cost(monkeypatch): + class _BudgetCaptureSearcher(AutoQuantizeGradientSearcher): + def run_search_with_stats(self, max_weight_size, verbose=False): + self.max_weight_size = max_weight_size + return {}, True + + def _raise_local_total_weight_size(modules): + pytest.fail("run_search should derive total weight size from candidate costs") + + monkeypatch.setattr( + _AutoQuantizeBaseSearcher, + "_get_total_weight_size", + staticmethod(_raise_local_total_weight_size), + ) + + searcher = _BudgetCaptureSearcher() + searcher.reset_search() + searcher.model = torch.nn.Module() + searcher.config = {"verbose": False} + searcher.constraints = {"effective_bits": 8.0} + searcher.candidate_stats = { + "local_expert.quant_recipe": { + "formats": [QuantRecipe(mtq.NVFP4_DEFAULT_CFG), QuantRecipe(None)], + "scores": [1.0, 0.0], + "costs": [25.0, 100.0], + } + } + + searcher.run_search() + + assert searcher.max_weight_size == 50.0 + + def test_estimate_quant_compression(): nvfp4_affine_kv_cfg = mtq.config.QuantizeConfig(**mtq.NVFP4_AFFINE_KV_CFG) assert estimate_quant_compression(nvfp4_affine_kv_cfg) == 0.25 From 3e85e70111b88555746924cd5f417e91402827f9 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Tue, 19 May 2026 22:38:41 -0700 Subject: [PATCH 6/7] merge main in, lm head quant in mcore, w4a16 mcore Signed-off-by: Jennifer Chen --- .agents/developer-guidelines.md | 58 -- .coderabbit.yaml | 17 + .github/workflows/claude_review.yml | 44 +- AGENTS.md | 3 +- CHANGELOG.rst | 7 +- CONTRIBUTING.md | 70 ++ README.md | 1 - docs/source/guides/10_recipes.rst | 36 +- docs/source/guides/11_config_system.rst | 573 +++++++++++++ docs/source/guides/_quant_cfg.rst | 25 +- examples/llm_ptq/hf_ptq.py | 16 +- examples/specdec_bench/run.py | 2 +- .../specdec_bench/models/sglang.py | 83 +- .../specdec_bench/models/vllm.py | 6 + examples/speculative_decoding/main.py | 347 +++----- .../llm_export_utils/quantization_utils.py | 4 +- modelopt/recipe/config.py | 172 +++- modelopt/recipe/loader.py | 171 ++-- modelopt/torch/export/convert_hf_config.py | 13 + modelopt/torch/export/model_config.py | 1 + .../torch/export/plugins/mcore_deepseek.py | 9 + modelopt/torch/export/plugins/mcore_gptoss.py | 6 + modelopt/torch/export/plugins/mcore_llama.py | 11 + modelopt/torch/export/plugins/mcore_qwen.py | 11 + .../torch/export/plugins/megatron_importer.py | 74 +- modelopt/torch/export/quant_utils.py | 12 + modelopt/torch/export/unified_export_hf.py | 3 + .../torch/export/unified_export_megatron.py | 58 +- modelopt/torch/nas/plugins/megatron.py | 17 + modelopt/torch/opt/config.py | 52 +- modelopt/torch/opt/config_loader.py | 75 +- .../torch/prune/plugins/mcore_minitron.py | 59 +- modelopt/torch/quantization/algorithms.py | 6 +- modelopt/torch/quantization/config.py | 279 ++++--- modelopt/torch/quantization/conversion.py | 6 +- modelopt/torch/quantization/model_calib.py | 2 +- modelopt/torch/quantization/model_quant.py | 1 + .../nn/modules/tensor_quantizer.py | 4 +- .../torch/quantization/plugins/megatron.py | 14 +- .../quantization/qtensor/nvfp4_tensor.py | 26 +- modelopt/torch/speculative/config.py | 82 +- .../torch/speculative/plugins/hf_eagle.py | 32 +- .../speculative/plugins/hf_training_args.py | 82 ++ modelopt/torch/utils/dataset_utils.py | 168 +++- modelopt/torch/utils/logging.py | 8 +- .../torch/utils/plugins/megatron_generate.py | 4 +- modelopt/torch/utils/plugins/megatron_mmlu.py | 7 +- .../configs/ptq/units/w4_nvfp4.yaml | 24 + .../ptq/nvfp4_weight_only-kv_fp16.yaml | 29 + .../general/speculative_decoding/dflash.yaml | 8 +- .../general/speculative_decoding/eagle3.yaml | 8 +- ...unified_hf_export_and_check_safetensors.py | 1 + tests/unit/recipe/test_loader.py | 337 +++++++- tests/unit/torch/opt/test_config.py | 2 +- .../plugins/test_fused_experts.py | 57 +- .../quantization/test_config_validation.py | 88 +- .../torch/quantization/test_nvfp4_tensor.py | 112 +++ .../speculative/plugins/test_hf_dflash.py | 8 - .../plugins/test_hf_dflash_offline.py | 30 +- .../torch/speculative/test_eagle_config.py | 67 +- tests/unit/torch/utils/test_dataset_utils.py | 138 ++++ .../Qwen/Qwen3-8B/megatron_lm_ptq.yaml | 2 +- uv.lock | 759 +++++++++--------- 63 files changed, 3179 insertions(+), 1248 deletions(-) delete mode 100644 .agents/developer-guidelines.md create mode 100644 docs/source/guides/11_config_system.rst create mode 100644 modelopt/torch/speculative/plugins/hf_training_args.py create mode 100644 modelopt_recipes/configs/ptq/units/w4_nvfp4.yaml create mode 100644 modelopt_recipes/general/ptq/nvfp4_weight_only-kv_fp16.yaml create mode 100644 tests/unit/torch/quantization/test_nvfp4_tensor.py diff --git a/.agents/developer-guidelines.md b/.agents/developer-guidelines.md deleted file mode 100644 index 10b3d901499..00000000000 --- a/.agents/developer-guidelines.md +++ /dev/null @@ -1,58 +0,0 @@ -# Coding Principles - -Guidelines for production code in ModelOpt. Key values: simplicity, modularity, -and conciseness. - -## Principles - -- **Prefer simple, surgical changes.** Touch only what the task requires. Avoid speculative - refactors, broad rewrites, and "while we're here" cleanups. -- **Design for simplicity and readability.** Choose the design that is easiest to understand and maintain. - Code is read top to bottom: put high-level behavior first, hide lower-level details behind well-named helpers, - and treat heavy branching as a signal to reconsider the design. -- **Prefer modular, composable solutions.** Avoid input-specific or case-specific hard-coding. - Use existing extension points when they fit. If none fit, add a simple, focused helper, - class, or plugin that cleanly captures the new behavior. Keep scope limited to known cases. -- **Respect inheritance boundaries.** Parent abstractions should define shared contracts and - shared behavior, not child-specific special cases. -- **Don't repeat yourself; keep a single source of truth.** Consolidate repeated logic or intent with a shared helper, API, - or abstraction when doing so keeps the design simpler. Avoid duplication that can drift out of sync. -- **Comment cautiously.** Comments should add context, not translate code into English. - Prefer making the code self-explanatory first. Use comments only for non-obvious - intent or constraints that remain unclear from the code. Apply this guidance to new - comments only; do not rewrite or delete existing comments just for style. -- **Document public APIs.** Public and higher-level APIs should have docstrings, including examples when useful. - Internal helpers should usually be self-documenting through clear names and structure. -- **Fix the bug cause, not the side effect.** For bug fixes, find the root cause instead of patching for its side effect. -- **Validate external input once.** Check types and values at the interface boundary. Internal code can trust those - checks and avoid redundant assertions. -- **Remove dead code.** Delete unused imports, unreachable branches, and obsolete helpers. -- **Use relative paths** from the repo root in commands and file references. - -## Testing - -- **Develop with focused tests.** During development, write as many focused - tests as needed, including lower-level unit tests or internal probes, to - understand and harden behavior. -- **Curate production tests and keep them lean.** Before staging or committing, - decide which tests should be checked in. Checked-in tests should document - expected behavior, protect against regressions, or flag backward-incompatible - behavior changes. Remove redundant lower-level tests when a higher-level test - already covers the same behavior, keeping CI/CD fast and lean. - -## Performant AI Code - -- **Keep tensor work on the GPU and avoid unnecessary CPU-GPU syncs.** Reading metadata such as `tensor.shape` is fine. - Avoid Python scalar extraction and operators such as `tensor.item()`, `float(tensor)`, or `min(tensor)` because they - can trigger CPU-GPU syncs. Use PyTorch tensor ops such as `tensor.min()` by default, and only extract Python scalars - when the CPU needs the value. Tensor-value-based Python branching can also break CUDA graphs. -- **Develop with distributed processing in mind.** Examples: Use `print_rank_0` or `warn_rank_0` - when possible to avoid noisy logs. Guard shared side effects, such as - file writes or shared state updates, against race conditions between ranks. - -## Compatibility - -- **Preserve config and checkpoint backward compatibility.** ModelOpt checkpoints include serialized - `ModeloptBaseConfig` instances such as `QuantizeConfig`. If these Pydantic-based configs change - without backward compatibility handling, older checkpoints may no longer load. Make breaking changes - explicit and intentional. diff --git a/.coderabbit.yaml b/.coderabbit.yaml index 825f3ec04bf..dfffd116b35 100644 --- a/.coderabbit.yaml +++ b/.coderabbit.yaml @@ -4,6 +4,8 @@ reviews: profile: chill collapse_walkthrough: true poem: false + # Allow CodeRabbit to formally approve once its comments are resolved and pre-merge checks pass + request_changes_workflow: true path_instructions: - path: "modelopt/**/*.py" instructions: &security_instructions | @@ -25,6 +27,21 @@ reviews: @NVIDIA/modelopt-setup-codeowners with an explicit justification in the PR description. - path: "examples/**/*.py" instructions: *security_instructions + - path: "tests/**/*.py" + instructions: | + Verify tests follow the conventions in CONTRIBUTING.md. Flag the following as + IMPORTANT issues: + 1. Imports inside functions or test methods without explicit justification. + Imports belong at the top of the file so import errors surface at collection + time, not mid-test. The only acceptable in-function imports are for circular + imports or optional dependencies (e.g., TensorRT-LLM, Megatron-Core), and + those should carry a brief comment naming the reason. + 2. Redundant lower-level tests that duplicate behavior already covered by a + higher-level test — checked-in tests should be lean and document expected + behavior, protect against regressions, or flag backward-incompatible changes. + 3. Tests placed in the wrong directory for their cost profile (e.g., multi-minute + tests under tests/unit, which targets a few-seconds budget; GPU-requiring + tests under tests/unit instead of tests/gpu*). auto_review: auto_incremental_review: true drafts: false diff --git a/.github/workflows/claude_review.yml b/.github/workflows/claude_review.yml index a41c7571acc..38ef03effc2 100644 --- a/.github/workflows/claude_review.yml +++ b/.github/workflows/claude_review.yml @@ -23,7 +23,7 @@ jobs: contains(github.event.comment.body, '/claude review') && contains(fromJson('["OWNER", "MEMBER", "COLLABORATOR"]'), github.event.comment.author_association) runs-on: ubuntu-latest - timeout-minutes: 10 + timeout-minutes: 15 permissions: contents: read pull-requests: write @@ -80,13 +80,18 @@ jobs: BASE REF: origin/${{ steps.pr-info.outputs.base_ref }} Mandatory workflow — never skip or reorder: - 1. Read the PR diff first (gh pr diff). - 2. Read AGENTS.md, .agents/developer-guidelines.md, - and CONTRIBUTING.md for project conventions, coding principles, and architecture. - 3. For changed files under `modelopt/torch//`, read the sub-package's + 1. Read prior Claude activity on the PR so you don't duplicate already-raised + comments and can track which prior issues are now resolved: + `gh pr view $PR_NUMBER --repo $REPO --json comments,reviews` + Treat prior findings as context, not a ceiling — if you spot a genuinely new + issue this round, flag it. + 2. Read the PR diff (gh pr diff). + 3. Read AGENTS.md and CONTRIBUTING.md (including the Coding standards section) + for project conventions, coding principles, and architecture. + 4. For changed files under `modelopt/torch//`, read the sub-package's `__init__.py` plus any `mode.py` / `config.py` to understand mode registration and config schema. - 4. Only then perform the review using that context. + 5. Only then perform the review using that context. You are performing a deep code review on a **NVIDIA Model Optimizer (ModelOpt)** PR. ModelOpt is NVIDIA's open-source library for model optimization (quantization, pruning, @@ -118,6 +123,19 @@ jobs: ## Review Procedure + **Aim for one pass.** Surface meaningful issues in this review so the author gets + one consolidated set of fixes. + + **Cover each changed file across categories.** For each non-trivial changed file, + consider the categories below (Algorithm Correctness, Mode/State, Export, Backward + Compatibility, Performance) before moving on. + + **Trace public symbols across files.** For new or modified public symbols + (functions, arguments, config fields, exported names), grep call sites in + `modelopt/`, `tests/`, and `examples/` before commenting. Many bugs here only + surface where the symbol meets its caller — mode registration, export paths, + restore logic. + 1. Get PR metadata: `gh pr view $PR_NUMBER --repo $REPO --json title,body,baseRefName,headRefName,files,additions,deletions,changedFiles,author` 2. Get the full diff: `gh pr diff $PR_NUMBER --repo $REPO` - For large PRs (>50 files), prioritize source code over config/lock/auto-generated files. @@ -182,6 +200,10 @@ jobs: - Memory regressions: double-allocating weights, holding tensors past their lifetime. ## Suggestions (Nice to Have) + + SUGGESTIONs document non-blocking improvements and never block approval (see + Completion below). Raise them when genuinely useful; skip nits that aren't. + - Stale, imprecise, or misleading comments/docstrings — a wrong docstring is worse than none. - Missing shape/dtype assertions at module/parallelism boundaries where they would @@ -208,5 +230,11 @@ jobs: - Highlight the most impactful findings - Overall assessment of the PR's risk level - If no significant issues are found, approve the PR: - gh pr review $PR_NUMBER --repo $REPO --approve --body "Claude review passed — no significant issues found. LGTM" + **Approval decision — use exact counts from your findings (no other thresholds):** + + - `0 CRITICAL AND 0 IMPORTANT` → **approve**, regardless of SUGGESTION count. + SUGGESTIONs never block approval. + `gh pr review $PR_NUMBER --repo $REPO --approve --body "Claude review passed — no blocking issues found. LGTM"` + - `≥1 CRITICAL OR ≥1 IMPORTANT` → **post a comment review** summarizing the + issues found. + `gh pr review $PR_NUMBER --repo $REPO --comment --body ""` diff --git a/AGENTS.md b/AGENTS.md index 3000fce922b..5e1d4eb20a1 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -11,8 +11,9 @@ These instructions apply to AI-assisted work in this repository. ## Coding guidelines - **Coding guide:** Code development and review require reading and following - [.agents/developer-guidelines.md](.agents/developer-guidelines.md); + the [coding standards in CONTRIBUTING.md](CONTRIBUTING.md#-coding-standards); do not skip this step. +- **Use relative paths** from the repo root in commands and file references. ## Iterative development diff --git a/CHANGELOG.rst b/CHANGELOG.rst index d213757cb6a..5fd1341ddd0 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -27,11 +27,16 @@ Changelog - Support Megatron-Core checkpoint restore and export for MSE ``NVFP4StaticQuantizer``. - Add mixed-precision FP8 + NVFP4 export for Megatron-Core: per-layer ``quant_algo`` recorded under ``quantized_layers`` in ``hf_quant_config.json``, PP-aware ``kv_cache_dtype`` gather, fused-QKV exclude split into per-HF-name ``q/k/v_proj`` entries. - Add Nemotron-3-Super-120B-A12B PTQ recipes ``modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml`` (MSE-mixed) and ``super-nvfp4-max-calib.yaml`` (max-calib mixed): NVFP4 W4A4 routed experts + FP8 per-tensor shared experts / Mamba in/out_proj + FP8 KV cache. +- Add NVFP4 W4A16 weight-only quantization (``w4a16_nvfp4``): FP4 weights with group_size=16, BF16 activations, no calibration forward pass required. Use ``mtq.W4A16_NVFP4_CFG`` or ``--qformat w4a16_nvfp4`` in ``hf_ptq.py``. vLLM deployment support is in progress. +- Add ``DATASET_COMBOS`` to ``modelopt.torch.utils.dataset_utils`` — single ``--dataset`` tokens that fan out to multiple registered datasets; per-entry ``num_samples`` is split evenly across the members. Initial combos: ``cnn_nemotron_v2_mix`` (``cnn_dailymail`` + ``nemotron-post-training-dataset-v2``, used by ``hf_ptq.py`` when no ``--dataset`` is provided) and ``nemotron-post-training-v3`` (the seven ``nvidia/Nemotron-*`` SFT datasets added in #1498, mirroring the `nemotron-post-training-v3 collection `_). Combo names are listed by ``get_supported_datasets()`` and surfaced in ``--dataset`` help. ``get_dataset_dataloader`` rejects inputs that mix a combo with one of its member datasets (e.g. ``cnn_dailymail,cnn_nemotron_v2_mix``) to avoid double-sampling, and ``get_dataset_samples`` rejects combo names so callers route through the dataloader. ``hf_ptq.py`` default ``--calib_size`` is bumped from ``512`` to ``1024`` so the total calibration sample count under the new default combo matches the previous two-dataset fallback. +- The ``nemotron-sft-agentic-v2`` registered dataset (added in #1498) now uses only the ``search`` split. The previously configured ``interactive_agent`` and ``tool_calling`` splits contain content-level defects (heterogeneous schema and a malformed JSON row, respectively) that cause pyarrow's streaming JSON reader to fail deterministically. **Bug Fixes** + - In Megatron-Core only do EP amax sync for routed expert weights if ``sync_expert_weight_amax=True``. Previously EP amax sync would sync routed expert weights across EP ranks even when ``sync_expert_weight_amax`` was False +- Fix Megatron-Core HF importer to load fused ``TELayerNormColumnParallelLinear.layer_norm_weight`` from HF for GPT-family models (Qwen3 etc.) under ``--export-default-te-spec``. Importer now prefers per-context keys ``fused_input_layernorm`` / ``fused_pre_mlp_layernorm`` (fallback ``fused_norm`` for Nemotron-H backward compatibility); ``mcore_qwen.py`` provides the new rules. Without this fix, post-prune MMLU sat at chance. -0.44 (2026-05-18) +0.44 (2026-05-14) ^^^^^^^^^^^^^^^^^ **New Features** diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f7debbbc6ee..bcc70b64d65 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -42,6 +42,65 @@ To run the pre-commit hooks without committing, use: pre-commit run --all-files ``` +## 📐 Coding standards + +Guidelines for production code in ModelOpt. Key values: simplicity, modularity, +and conciseness. + +### Principles + +- **Prefer simple, surgical changes.** Touch only what the task requires. Avoid speculative + refactors, broad rewrites, and "while we're here" cleanups. +- **Design for simplicity and readability.** Choose the design that is easiest to understand and maintain. + Code is read top to bottom: put high-level behavior first, hide lower-level details behind well-named helpers, + and treat heavy branching as a signal to reconsider the design. +- **Prefer modular, composable solutions.** Avoid input-specific or case-specific hard-coding. + Use existing extension points when they fit. If none fit, add a simple, focused helper, + class, or plugin that cleanly captures the new behavior. Keep scope limited to known cases. +- **Respect inheritance boundaries.** Parent abstractions should define shared contracts and + shared behavior, not child-specific special cases. +- **Don't repeat yourself; keep a single source of truth.** Consolidate repeated logic or intent with a shared helper, API, + or abstraction when doing so keeps the design simpler. Avoid duplication that can drift out of sync. +- **Comment cautiously.** Comments should add context, not translate code into English. + Prefer making the code self-explanatory first. Use comments only for non-obvious + intent or constraints that remain unclear from the code. Apply this guidance to new + comments only; do not rewrite or delete existing comments just for style. +- **Document public APIs.** Public and higher-level APIs should have docstrings, including examples when useful. + Internal helpers should usually be self-documenting through clear names and structure. +- **Fix the bug cause, not the side effect.** For bug fixes, find the root cause instead of patching for its side effect. +- **Validate external input once.** Check types and values at the interface boundary. Internal code can trust those + checks and avoid redundant assertions. +- **Remove dead code.** Delete unused imports, unreachable branches, and obsolete helpers. +- **Keep imports at the top of the file.** Place all imports at module top in both source + and test files so import errors surface at module load time rather than at runtime or + during a specific test. Put an import inside a function only when there is a concrete + reason: resolving a circular import that cannot be restructured, guarding an optional + dependency (e.g., TensorRT-LLM, Megatron-Core), or deferring an unusually heavy import + with explicit justification. Add a brief comment in those cases naming the reason. +- **Define the public API with `__all__` and re-export via `from .module import *`.** + Each module declares its public surface with `__all__ = [...]` at the top of the file. + Package `__init__.py` files re-export submodules with `from .module import *`. This + keeps the public API explicit at the source (next to the definitions), avoids + hand-maintained import lists in `__init__.py` drifting out of sync, and makes + star-imports safe by limiting them to the curated `__all__` names. + +### Performant AI code + +- **Keep tensor work on the GPU and avoid unnecessary CPU-GPU syncs.** Reading metadata such as `tensor.shape` is fine. + Avoid Python scalar extraction and operators such as `tensor.item()`, `float(tensor)`, or `min(tensor)` because they + can trigger CPU-GPU syncs. Use PyTorch tensor ops such as `tensor.min()` by default, and only extract Python scalars + when the CPU needs the value. Tensor-value-based Python branching can also break CUDA graphs. +- **Develop with distributed processing in mind.** Examples: Use `print_rank_0` or `warn_rank_0` + when possible to avoid noisy logs. Guard shared side effects, such as + file writes or shared state updates, against race conditions between ranks. + +### Compatibility + +- **Preserve config and checkpoint backward compatibility.** ModelOpt checkpoints include serialized + `ModeloptBaseConfig` instances such as `QuantizeConfig`. If these Pydantic-based configs change + without backward compatibility handling, older checkpoints may no longer load. Make breaking changes + explicit and intentional. + ## Adding a new PIP dependency Currently we have 2 places where we mention pip dependencies: [pyproject.toml](./pyproject.toml) for dependencies that are required for the ModelOpt library and `examples//requirements.txt` for dependencies that are required for the specific examples. @@ -101,6 +160,17 @@ For broader repo validation and dependency setup, use [noxfile.py](./noxfile.py) nox -s "unit-3.12(torch_211, tf_latest)" ``` +### Test design principles + +- **Develop with focused tests.** During development, write as many focused + tests as needed, including lower-level unit tests or internal probes, to + understand and harden behavior. +- **Curate production tests and keep them lean.** Before staging or committing, + decide which tests should be checked in. Checked-in tests should document + expected behavior, protect against regressions, or flag backward-incompatible + behavior changes. Remove redundant lower-level tests when a higher-level test + already covers the same behavior, keeping CI/CD fast and lean. + ## ✍️ Signing your work - We require that all contributors "sign-off" on their commits. This certifies that the contribution is your original diff --git a/README.md b/README.md index 6a4f023e4f0..c15a4e80872 100644 --- a/README.md +++ b/README.md @@ -86,7 +86,6 @@ You can also directly use NVIDIA container images, which have Model Optimizer pr - `nvcr.io/nvidia/pytorch:-py3` - `nvcr.io/nvidia/nemo:` - `nvcr.io/nvidia/tensorrt-llm/release:` -- `nvcr.io/nvidia/tensorrt:-py3` Before pulling and using the container images, please review their respective license terms. Make sure to upgrade Model Optimizer to the latest version as described above. diff --git a/docs/source/guides/10_recipes.rst b/docs/source/guides/10_recipes.rst index 4a1da0a3150..b8f1aa50bce 100644 --- a/docs/source/guides/10_recipes.rst +++ b/docs/source/guides/10_recipes.rst @@ -316,8 +316,13 @@ snippet payload after any imports have been expanded: type: dynamic scale_bits: e4m3 -The schema comment is metadata only; it is not returned as part of the loaded -config, and validation does not expand Pydantic defaults into the snippet. +The schema comment itself is not returned as part of the loaded config. The +declared schema is the validation contract: after imports are resolved, the +loader validates the payload against that schema and returns the result -- +a Pydantic model instance for ``BaseModel`` schemas (with defaults populated) +or a validated ``dict``/``list`` for ``TypedDict`` schemas. The schema can +also be supplied at the call site via ``load_config(path, schema_type=...)``, +which takes precedence over an in-file comment when both are present. Top-level recipe files are validated by :func:`~modelopt.recipe.load_recipe`; they do not need ``modelopt-schema`` comments. The comments are the contract @@ -334,9 +339,10 @@ List imports are schema-driven. When a typed list field such as ``quant_cfg: list[QuantizerCfgEntry]`` contains a bare import entry, the imported snippet must declare its own ``modelopt-schema``: -* If the snippet schema is the same list type, e.g. ``QuantizerCfgListConfig``, - the imported entries are spliced into the containing list. -* If the snippet schema is the element type, e.g. ``QuantizerCfgEntry``, the +* If the snippet schema matches the containing list type + (``QuantizerCfgListConfig``, i.e. ``list[QuantizerCfgEntry]``), the imported + entries are spliced into the containing list. +* If the snippet schema matches the element type (``QuantizerCfgEntry``), the imported entry is appended as a single list item. * If the containing list or imported snippet has no schema, or the snippet schema is neither the list type nor the element type, loading raises @@ -413,9 +419,11 @@ PTQ recipes contain a ``quantize`` mapping with: - Description * - ``quant_cfg`` - Yes - - An ordered list of :class:`~modelopt.torch.quantization.config.QuantizerCfgEntry` - dicts. See :ref:`quant-cfg` for the full specification of entries, ordering - semantics, and atomicity rules. + - An ordered list of + :class:`~modelopt.torch.quantization.config.QuantizerCfgEntry` entries. + In YAML each entry is authored as a mapping; after loading they are + validated Pydantic instances. See :ref:`quant-cfg` for the full + specification of entries, ordering semantics, and atomicity rules. * - ``algorithm`` - No - The calibration algorithm: ``"max"`` (default), ``"mse"``, ``"smoothquant"``, @@ -691,13 +699,15 @@ Recipe data model Recipes are validated at load time using Pydantic models: :class:`~modelopt.recipe.config.ModelOptRecipeBase` - Base class for all recipe types. Contains ``metadata`` as a - :class:`~modelopt.recipe.config.RecipeMetadataConfig` mapping, with - ``recipe_type`` and ``description`` convenience properties. + Base class for all recipe types. Contains a required ``metadata`` field + typed as :class:`~modelopt.recipe.config.RecipeMetadataConfig` -- a + :class:`~modelopt.torch.opt.config.ModeloptBaseConfig` subclass exposing + ``recipe_type`` and ``description`` as Pydantic fields. :class:`~modelopt.recipe.config.ModelOptPTQRecipe` - PTQ-specific recipe. Adds the ``quantize`` field (a dict with ``quant_cfg`` and - ``algorithm``). + PTQ-specific recipe. Adds a required ``quantize`` field typed as + :class:`~modelopt.torch.quantization.config.QuantizeConfig` (also a + ``ModeloptBaseConfig`` subclass, containing ``quant_cfg`` and ``algorithm``). :class:`~modelopt.recipe.config.RecipeType` Enum of supported recipe types. diff --git a/docs/source/guides/11_config_system.rst b/docs/source/guides/11_config_system.rst new file mode 100644 index 00000000000..cb659c2482a --- /dev/null +++ b/docs/source/guides/11_config_system.rst @@ -0,0 +1,573 @@ +.. _modelopt-config-system: + +ModelOpt Config System +###################### + +ModelOpt configs use Python types as the contract and YAML as the portable data +representation. A YAML file is loaded into ordinary Python ``dict``/``list`` +data, optional YAML composition is resolved, and the result is validated by the +owning Pydantic-compatible schema. + +The config system is intentionally general. Quantization configs, reusable YAML +snippets, and recipes are all consumers of the same lower-level semantics. +Recipes are one of the main applications; for the recipe-specific authoring +workflow, see :ref:`recipes`. + +.. contents:: On this page + :local: + :depth: 2 + + +Requirements +============ + +The core configuration system has four required properties and one optional +authoring feature: + +* **Typed / schematized**: each config surface has an explicit Python type + contract. Concrete model configs inherit from + :class:`~modelopt.torch.opt.config.ModeloptBaseConfig`; reusable container + shapes can use Pydantic-compatible type aliases such as + ``list[QuantizerCfgEntry]``. +* **Validated**: invalid values fail at load or schema-construction time. + Type errors, range violations, and unknown fields surface as Pydantic + validation errors instead of being silently ignored. +* **Persistent**: a resolved config can be serialized as plain YAML/JSON data, + and the same plain data can be embedded in a PyTorch checkpoint and restored + against the schema. +* **Composable YAML**: shared fragments such as numeric formats and list units + can be defined once and referenced from multiple YAML files. This is optional + authoring convenience, not a correctness requirement. + +These requirements split the system into three layers: + +* Python/Pydantic-compatible schemas define what is valid. +* YAML stores the user-facing config data. +* The loader resolves YAML conveniences, returns plain data, and invokes schema + validation where the file itself declares a schema. + + +Schema layer +============ + +``ModeloptBaseConfig`` is the common base class for structured ModelOpt config +objects: + +.. code-block:: python + + class ModeloptBaseConfig(BaseModel): + model_config = PyDanticConfigDict(extra="forbid", validate_assignment=True) + +The base class adds ModelOpt-specific behavior on top of Pydantic: + +* ``extra="forbid"`` rejects unknown keys by default. +* ``validate_assignment=True`` revalidates field updates after construction. +* ``ModeloptField(...)`` is a thin wrapper over Pydantic ``Field`` that asserts a + default value is supplied, so every config field is constructible without + explicit arguments. +* ``model_dump()`` and ``model_dump_json()`` default to ``by_alias=True`` and + ``warnings=False``, so serialized output uses the documented field aliases + and Pydantic serializer warnings are suppressed. +* ``ModeloptBaseConfig`` inherits from ``collections.abc.MutableMapping``, so + config objects can be used wherever dict-style access is expected: + ``cfg["field"]`` / ``cfg["field"] = value``, ``cfg.get("field")``, + ``key in cfg``, ``len(cfg)``, ``iter(cfg)``, ``cfg.keys()``, ``cfg.values()``, + ``cfg.items()``, ``cfg.update({...})``, and ``cfg.setdefault("field", ...)`` + all work. Keys use aliases when defined. Schema fields are not removable, so + ``del cfg["field"]`` raises ``TypeError`` and the ``MutableMapping`` mixins + that delete (``pop(existing_key)``, ``popitem``, ``clear``) inherit that + failure mode. ``cfg["unknown"] = ...`` raises ``KeyError`` rather than + silently adding a new key. +* ``__init_subclass__`` registers each config subclass with PyTorch safe + globals so config objects can be deserialized by ``torch.load`` with + ``weights_only=True``. + +A typical config schema is a regular Pydantic model with field validators: + +.. code-block:: python + + class QuantizeConfig(ModeloptBaseConfig): + quant_cfg: QuantizeQuantCfgType = ModeloptField( + default=[{"quantizer_name": "*", "cfg": {"num_bits": 8, "axis": None}}], + title="Quantization configuration", + validate_default=True, + ) + algorithm: QuantizeAlgoCfgType = ModeloptField( + default="max", + title="Calibration algorithm", + validate_default=True, + ) + + @field_validator("quant_cfg", mode="before") + @classmethod + def normalize_quant_cfg(cls, v): + return normalize_quant_cfg_list(v) if isinstance(v, (list, dict)) else v + +Not every reusable config shape needs its own top-level config class. Any +type that Pydantic's ``TypeAdapter`` can validate is acceptable as a snippet +schema: + +* Pydantic model classes (``ModeloptBaseConfig`` subclasses or other + ``BaseModel`` subclasses) for object snippets such as a single quantizer + rule (``QuantizerCfgEntry``) or a numeric format + (``QuantizerAttributeConfig``). +* ``list[T]`` aliases for list snippets. For example, + ``QuantizerCfgListConfig`` is defined as ``list[QuantizerCfgEntry]``. +* ``TypedDict`` and ``list[TypedDict]`` shapes when a plain dict layout is the + natural representation. These return validated dict/list data rather than + model instances. +* Unions and other ``TypeAdapter``-compatible annotations when the reusable + data shape is a typed container rather than a standalone model class. + +The important invariant is that the schema lives in Python, while YAML remains +data. Snippet schemas are validation contracts; they are not arbitrary Python +execution hooks. + + +Validation model +================ + +Validation happens at two boundaries. + +Imported snippets +----------------- + +Every file referenced by a YAML ``imports`` block is a reusable snippet. It must +include a ``# modelopt-schema: ...`` comment in the initial comment preamble: + +.. code-block:: yaml + + # modelopt-schema: modelopt.torch.quantization.config.QuantizerAttributeConfig + num_bits: e4m3 + axis: + +The loader resolves the schema path, validates the resolved snippet payload with +Pydantic ``TypeAdapter``, and only then exposes that snippet to the importing +file. This makes snippets independently reviewable and prevents a malformed +shared fragment from being copied into many configs silently. + +Schema paths are intentionally restricted: + +* they must resolve under the ``modelopt.`` package; +* they must point at a Pydantic-compatible type; +* they are validation contracts, not arbitrary Python execution hooks. + +Top-level configs +----------------- + +Top-level user configs do not always need a ``modelopt-schema`` comment. The +owning API often supplies schema context directly through ``schema_type=``: + +.. code-block:: python + + from modelopt.recipe import load_config + from modelopt.torch.quantization.config import QuantizeConfig + + cfg = load_config("configs/ptq/presets/model/fp8", schema_type=QuantizeConfig) + # cfg is a validated QuantizeConfig instance. + +An *effective schema* is selected from the explicit ``schema_type`` argument +and the file's ``# modelopt-schema: ...`` comment, with ``schema_type`` +winning when both are present. When an effective schema exists, it serves +two purposes: + +* It guides import resolution, especially deciding whether a list import + should append one element or splice several elements. +* It validates the resolved payload and returns it as an instance of that + schema — a Pydantic model instance for ``BaseModel`` schemas, or a + validated ``dict`` / ``list`` for ``TypedDict`` and ``list[TypedDict]`` + schemas. + +When neither a ``schema_type`` argument nor a schema comment is supplied, +``load_config()`` returns the resolved payload as plain ``dict`` or ``list`` +data without validation. + + +YAML loading +============ + +The general loader lives in ``modelopt.torch.opt.config_loader`` and is exported +through ``modelopt.recipe.load_config``. It is intentionally below the recipe +layer so quantization and other core config modules can use it without depending +on recipes. + +``load_config(path, schema_type=...)`` performs this flow: + +1. Locate the YAML file. Filesystem paths are checked first; if the path is + relative and not found locally, the built-in ``modelopt_recipes`` package is + checked. ``.yml`` and ``.yaml`` suffixes may be omitted. +2. Read the optional ``# modelopt-schema: ...`` comment preamble. +3. Parse one YAML document, or two documents when a list-valued snippet also + needs an ``imports`` declaration. +4. Convert ``eXmY`` strings in ``num_bits`` and ``scale_bits`` fields into + ``(X, Y)`` tuples. +5. Resolve a file-local ``imports`` mapping. +6. Recursively resolve nested imports, detect circular imports, and validate + imported snippets against their declared schemas. +7. Walk the YAML tree and replace ``$import`` references. +8. Select the effective top-level schema (``schema_type=`` argument wins over + ``# modelopt-schema:`` comment when both are present). +9. If an effective schema exists, validate the resolved payload and return a + schema instance (a Pydantic model, or a validated ``dict`` / ``list`` for + ``TypedDict``-shaped schemas); otherwise return the plain resolved data. + +The loader is not a general templating engine. It only understands YAML data, +``imports``, ``$import``, schema comments, and the ``eXmY`` numeric shorthand. +``load_config()`` itself does not apply CLI or environment overrides; +higher-level wrappers may add them on top (for example, ``load_recipe()`` +accepts an ``overrides=`` dotlist that is merged before final validation). + + +Self-contained YAML +=================== + +The simplest YAML config is self-contained and has no cross-file composition: + +.. code-block:: yaml + + algorithm: max + quant_cfg: + - quantizer_name: '*' + enable: false + - quantizer_name: '*weight_quantizer' + cfg: + num_bits: e2m1 + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + +This is the baseline format. YAML stores values; Python schemas define and +validate the allowed shape. + +Self-contained YAML is the right choice when a config is small, used once, or +clearer without indirection. Composable YAML is for repeated fragments and large +families of related configs. + + +YAML persistence +================ + +A loaded config should round-trip through plain data. After loading and +validation, serialize the resolved config rather than the authoring-time YAML: + +.. code-block:: python + + import yaml + + from modelopt.recipe import load_config + from modelopt.torch.quantization.config import QuantizeConfig + + cfg = load_config("configs/ptq/presets/model/fp8", schema_type=QuantizeConfig) + + with open("resolved_quantize.yaml", "w", encoding="utf-8") as f: + yaml.safe_dump(cfg.model_dump(), f) + +The output is fully materialized plain data. YAML comments, ``imports`` blocks, +``$import`` markers, and schema comments are authoring metadata; they do not +survive in the resolved dump. This is intentional. Resolved dumps are suitable +for bug reports, reproducibility artifacts, and diffs across runs. + +Reloading a resolved dump is the same operation as any other load: parse plain +YAML data and validate it against the schema. + + +Checkpoint persistence +====================== + +Configs embedded in checkpoints should use the same plain-data contract. Store +``cfg.model_dump()`` in the checkpoint and restore it with the owning schema: + +.. code-block:: python + + import torch + + state = { + "model": model.state_dict(), + "modelopt_state": { + "quantize_config": cfg.model_dump(), + }, + } + torch.save(state, "checkpoint.pt") + + loaded = torch.load("checkpoint.pt", weights_only=True) + restored_cfg = QuantizeConfig.model_validate( + loaded["modelopt_state"]["quantize_config"] + ) + +Persisting plain data keeps checkpoints independent of the original YAML files +and of the authoring-time import graph. Future readers need the schema, not the +source snippets. + +``ModeloptBaseConfig`` also registers subclasses as PyTorch safe globals, which +allows config objects to participate in safe deserialization. Plain-data +persistence remains the most portable form because it is easy to inspect, diff, +and migrate. + + +Composable YAML +=============== + +Python already has composition through variables, functions, imports, and +mutation. YAML does not. ModelOpt's YAML composition layer exists so repeated +YAML fragments can be shared without moving the canonical config into Python. + +Typical repeated fragments include: + +* one numeric format used by several quantizer entries; +* one complete quantizer-entry snippet reused in many configs; +* a list of quantizer entries reused as a unit; +* a snippet that depends on another snippet; +* related variants such as dynamic and static numeric formats. + +The chosen design is a small YAML-native DSL: a file-local ``imports`` mapping +binds names to YAML files, and inline ``$import`` references insert those +resolved snippets into the data tree. Python remains responsible for schema +validation; YAML remains data. + + +Alternatives considered +----------------------- + +Several other approaches can give YAML configs some form of composability. +Each was considered and rejected for ModelOpt's library-of-configs use case: + +* **Plain YAML anchors and aliases** reuse data inside one file but do not + compose across files and do not validate fragments independently. +* **Hard-coded Python registries** map well-known names like ``nvfp4`` to + Python-side constants. Adding a new fragment requires a Python edit, and + YAML can only reference what Python has pre-declared. +* **YAML files with Python-side name-to-file mappings** keep fragment data in + YAML, but the registration of each fragment still lives in Python. Adding a + new fragment requires both a YAML file and a Python edit. +* **General config frameworks such as OmegaConf and Hydra** provide deep merge + and ``${...}`` interpolation, but there is no native cross-file include + keyword, no native list-concatenation primitive, and the list + append-vs-splice rule must still come from somewhere ModelOpt-specific. + OmegaConf can be useful at the edges (for example for CLI dotted overrides + or environment-variable substitution applied after import resolution) but + is not sufficient as the composition primitive. +* **Python factory systems such as Fiddle or nemo_run** ``_factory_`` make + Python callables the canonical config representation. They are a good fit + when the audience is exclusively Python engineers and configs primarily + build runnable objects. They are a poor fit for ModelOpt because reusable + fragments are typically small typed values (numeric formats, quantizer-list + entries), persisting a factory-based config loses provenance unless the + on-disk format ties to Python qualified names, and Fiddle-style + ``@auto_config`` cannot return bare ``dict`` or ``list`` values without a + wrapper class that duplicates the Pydantic schema. + +ModelOpt uses a small YAML DSL instead: each file declares its own imports, +references them with ``$import``, and resolves to plain data before validation. +This keeps the import graph self-describing, lets config authors add reusable +fragments as YAML without Python edits, and still validates every resolved +value against Python schemas. The on-disk representation is plain YAML data, +so persisted configs do not depend on Python qualified names. + + +Import declarations +------------------- + +Imports are declared once per YAML file: + +.. code-block:: yaml + + imports: + nvfp4: configs/numerics/nvfp4 + kv_fp8: configs/ptq/units/kv_fp8 + +The names are scoped to that file. An imported snippet may declare its own +``imports`` block, and those names are scoped to the snippet file. Recursive +imports are resolved depth-first. Circular imports are detected using canonical +resolved paths and fail with ``ValueError``. + +A file that declares no ``imports`` may not contain ``$import`` markers. This +keeps authoring mistakes explicit: an unknown reference fails instead of being +left as literal data. + + +Dict imports +------------ + +When ``$import`` appears inside a mapping, the imported mapping is copied into +the current mapping. Inline keys override imported keys at that same mapping +level: + +.. code-block:: yaml + + cfg: + $import: nvfp4 + block_sizes: + -1: 16 + type: static + scale_bits: e4m3 + +Multiple imports are applied in order, then inline keys are applied last: + +.. code-block:: yaml + + cfg: + $import: [base_format, override_format] + axis: 0 + +The merge is shallow at the mapping where ``$import`` appears. If one nested +leaf changes, provide the complete nested value inline or define a named snippet +for that variant. This avoids hidden deep-merge rules that are hard to review. + + +List imports +------------ + +List imports are type-directed. For a containing list with schema ``list[T]``: + +* importing a snippet with schema ``list[T]`` splices all imported entries into + the containing list; +* importing a snippet with schema ``T`` appends the imported object as a single + list element; +* importing any other schema raises an error; +* importing into an untyped list raises an error. + +Example: + +.. code-block:: yaml + + quant_cfg: + - $import: base_disable_all # QuantizerCfgEntry, appended + - quantizer_name: '*weight_quantizer' + cfg: + $import: nvfp4 # QuantizerAttributeConfig, dict import + - $import: kv_fp8 # QuantizerCfgListConfig, spliced + +A list-entry import must be a mapping whose only key is ``$import``. If an entry +needs local changes, either write that entry inline or create a snippet for the +variant. + + +Multi-document list snippets +---------------------------- + +A YAML file has one root node per document. A list-valued snippet that also +needs an ``imports`` block therefore uses two YAML documents: the first document +holds import declarations, and the second document holds the list payload. + +.. code-block:: yaml + + # modelopt-schema: modelopt.torch.quantization.config.QuantizerCfgListConfig + imports: + fp8: configs/numerics/fp8 + --- + - quantizer_name: '*[kv]_bmm_quantizer' + cfg: + $import: fp8 + +Only ``imports`` from the first document is meaningful for a list snippet. The +loader resolves imports in the second document and returns the resolved list. + + +Composition error model +----------------------- + +The loader raises ``ValueError`` for invalid input. The full set of conditions +covers file-shape, schema declaration, and composition rules: + +File-shape errors: + +* the YAML file cannot be located on the filesystem or in built-in + ``modelopt_recipes``; +* a YAML file contains more than two documents; +* the root of a single-document file is not a mapping or a list; +* in a two-document file, the first document is not a mapping or the second + document is neither a mapping nor a list; +* multiple ``# modelopt-schema:`` comments are present in the preamble. + +Schema-declaration errors: + +* a schema path does not start with ``modelopt.``; +* a schema path is missing a module or attribute component, or it fails to + resolve to a real Python object; +* an imported snippet does not declare ``modelopt-schema``; +* an imported snippet does not validate against its declared schema. + +Composition errors: + +* ``imports`` is present but is not a mapping; +* an import path is empty; +* a ``$import`` reference appears in a file that declares no ``imports``; +* a ``$import`` name is not listed in the file-local ``imports`` mapping; +* a dict-form ``$import`` resolves to something other than a dict; +* a list import is used without a typed containing list; +* a list import schema is neither the containing list schema nor its element + schema; +* a circular import is detected (reported with the import chain). + +These failures are load-time errors by design. A composed config should either +resolve to valid plain data or fail before the owning optimization pass starts. + + +Consumers of the config system +============================== + +The config system is shared infrastructure. Current consumers include: + +* lower-level optimization configs such as PTQ ``QuantizeConfig``; +* built-in YAML config snippets under ``modelopt_recipes/configs`` (numeric + formats, reusable quantizer-entry units, model-level presets); +* higher-level recipes under ``modelopt_recipes/general`` and + ``modelopt_recipes/models``, which package metadata together with one or + more type-specific config sections. + +Recipes do not define separate config semantics. ``load_recipe()`` is a +consumer-specific wrapper that uses ``load_config()`` to resolve YAML, dispatches +on ``metadata.recipe_type`` to select the right recipe schema (PTQ today, plus +Eagle / DFlash / Medusa speculative-decoding variants), and returns a validated +``ModelOptRecipeBase`` subclass instance. The required body section depends on +the recipe type (``quantize`` for PTQ, ``eagle`` / ``dflash`` / ``medusa`` for +the speculative-decoding variants); ``metadata`` is required for all types. + +* A **file recipe** is a single YAML file with ``metadata`` and the + algorithm-specific body section. ``load_recipe()`` peeks at + ``metadata.recipe_type``, picks the matching recipe schema, and calls + ``load_config(file, schema_type=schema)`` so list-typed ``$import`` resolution + knows the element types. The returned object is a validated recipe instance + (for example a ``ModelOptPTQRecipe``). +* A **directory recipe** is a directory containing ``metadata.yml`` / + ``metadata.yaml`` and ``quantize.yml`` / ``quantize.yaml``. Each file is + loaded with its own schema (``RecipeMetadataConfig`` and ``QuantizeConfig``, + both ``ModeloptBaseConfig`` subclasses), and the recipe is assembled from the + validated sections. The directory form is currently PTQ-only; + speculative-decoding recipes use the single-file form. + +``load_recipe()`` also accepts an optional ``overrides`` argument: a list of +``key.path=value`` dotlist strings applied on top of the resolved YAML before +final Pydantic validation. Values are parsed with ``yaml.safe_load`` so +``foo.bar=true`` becomes a ``bool`` and ``axis=[0,1]`` becomes a ``list``. The +merge uses OmegaConf and is supported only for single-file recipes. + +The general contract remains the same: YAML authoring data resolves to plain +Python data, Python schemas validate the result, and validated configs are +returned as schema instances. Callers can move between dict and model views +through ``cfg.model_dump()`` and ``Schema.model_validate(data)``. + + +Authoring guidelines +==================== + +When adding config schemas or YAML files: + +* Put the canonical schema in Python, not in YAML comments or loader logic. +* Use ``ModeloptBaseConfig`` for structured config objects that need methods, + defaults, and validators. +* Use ``ModeloptBaseConfig`` subclasses or typed aliases for reusable snippets. +* Prefer self-contained YAML unless a fragment is reused or factoring materially + improves reviewability. +* Add ``# modelopt-schema: ...`` to every file that can be referenced from an + ``imports`` block. +* Keep top-level user config files free of schema comments unless they are also + intended to be imported as snippets. +* Use a concrete typed list schema for list snippets so append-vs-splice + behavior is unambiguous. +* Serialize resolved configs with ``model_dump()`` for long-term artifacts. +* Store plain config data, not authoring-time YAML paths, in checkpoints. +* Do not parse ModelOpt config YAML with raw YAML APIs in application code. Use + ``load_config()`` or a higher-level API built on it so imports, schema checks, + and ``eXmY`` conversion are applied consistently. diff --git a/docs/source/guides/_quant_cfg.rst b/docs/source/guides/_quant_cfg.rst index 6a027b74b00..5c740c453ab 100644 --- a/docs/source/guides/_quant_cfg.rst +++ b/docs/source/guides/_quant_cfg.rst @@ -18,25 +18,32 @@ patterns for composing quantization configurations. Overview ======== -A quantization config is a Python dictionary with two top-level keys: +A quantization config is a :class:`QuantizeConfig +` Pydantic model with two +top-level fields, typically authored as YAML or a Python ``dict``: .. code-block:: python config = { - "quant_cfg": [...], # ordered list of QuantizerCfgEntry dicts + "quant_cfg": [...], # ordered list of QuantizerCfgEntry entries "algorithm": "max", # calibration algorithm } The ``quant_cfg`` value is an **ordered list** of :class:`QuantizerCfgEntry -` dicts. Each entry targets a set of -quantizer modules in the model and specifies their configuration. +` entries. Each entry +targets a set of quantizer modules in the model and specifies their +configuration. Dict input is accepted and normalized to ``QuantizerCfgEntry`` +instances during validation; the result -- whether the input was YAML, dicts, +or instances -- is always a list of validated ``QuantizerCfgEntry`` objects. ---------- Entry Format ============ -Each entry in the list is a dictionary with the following fields: +Each entry is a :class:`QuantizerCfgEntry +` with the following +fields (authored as a YAML/dict mapping; validated into a Pydantic instance): .. list-table:: :header-rows: 1 @@ -55,9 +62,11 @@ Each entry in the list is a dictionary with the following fields: (e.g. ``"nn.Linear"``). If omitted, all modules are targeted regardless of class. * - ``cfg`` - No - - A dict of quantizer attributes as defined by :class:`QuantizerAttributeConfig - `, or a list of such dicts - for sequential quantization (see :ref:`sequential-quantizers`). + - Quantizer attributes typed as :class:`QuantizerAttributeConfig + `, or a + list of such for sequential quantization (see + :ref:`sequential-quantizers`). Authored as a mapping; validated into a + ``QuantizerAttributeConfig`` instance. * - ``enable`` - No - ``True`` or ``False``. Toggles matched quantizers on or off, independently of ``cfg``. diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 875e78ceea6..0c17558d2e3 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -113,6 +113,7 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None: "fp8_pb_wo": mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, "fp8_pc_pt": mtq.FP8_PER_CHANNEL_PER_TOKEN_CFG, "w4a8_nvfp4_fp8": mtq.W4A8_NVFP4_FP8_CFG, + "w4a16_nvfp4": mtq.W4A16_NVFP4_CFG, "w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG, "nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG, "nvfp4_experts_only": mtq.NVFP4_EXPERTS_ONLY_CFG, @@ -530,9 +531,10 @@ def load_model(args: argparse.Namespace): language_model = full_model else: if args.dataset is None: - args.dataset = ["cnn_dailymail", "nemotron-post-training-dataset-v2"] + args.dataset = ["cnn_nemotron_v2_mix"] warnings.warn( - "No dataset specified. Defaulting to cnn_dailymail and nemotron-post-training-dataset-v2." + "No dataset specified. Defaulting to the 'cnn_nemotron_v2_mix' combo " + "(cnn_dailymail + nemotron-post-training-dataset-v2)." ) # Adjust calib_size to match dataset length by extending or truncating as needed args.calib_size = (args.calib_size + [args.calib_size[-1]] * len(args.dataset))[ @@ -785,6 +787,12 @@ def export_quantized( extra_state_dict=mtp_state_dict, ) + if args.qformat == "w4a16_nvfp4": + warnings.warn( + "TensorRT-LLM and SGLang do not support this format. " + "vLLM deployment support is in progress." + ) + # Restore default padding and export the tokenizer as well. if tokenizer is not None: tokenizer.padding_side = default_padding_side @@ -1147,7 +1155,7 @@ def _is_layerwise(obj): quant_cfg = copy.deepcopy(quant_cfg) force_weight_quantizers_static(quant_cfg["quant_cfg"]) - if args.qformat in QUANT_CFG_CHOICES: + if quant_cfg: mono_quantize( args, quant_cfg, @@ -1231,7 +1239,7 @@ def parse_args() -> argparse.Namespace: "This argument will be parsed and converted as a list of ints." ), type=str, - default="512", + default="1024", ) parser.add_argument( "--calib_seq", diff --git a/examples/specdec_bench/run.py b/examples/specdec_bench/run.py index f4fbf06c0e8..94932c787b8 100644 --- a/examples/specdec_bench/run.py +++ b/examples/specdec_bench/run.py @@ -265,7 +265,7 @@ def run_simple(args): type=str, required=False, default="EAGLE3", - choices=["EAGLE3", "EAGLE", "DRAFT_TARGET", "NGRAM", "MTP", "NONE"], + choices=["EAGLE3", "EAGLE", "DRAFT_TARGET", "NGRAM", "MTP", "DFLASH", "NONE"], help="Speculative algorithm to use", ) parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory") diff --git a/examples/specdec_bench/specdec_bench/models/sglang.py b/examples/specdec_bench/specdec_bench/models/sglang.py index d5ff890ffd7..2133ec937eb 100644 --- a/examples/specdec_bench/specdec_bench/models/sglang.py +++ b/examples/specdec_bench/specdec_bench/models/sglang.py @@ -43,44 +43,55 @@ def __init__( speculative_algorithm = "LOOKAHEAD" elif speculative_algorithm == "NONE": speculative_algorithm = None + + engine_kwargs = { + "model_path": model_dir, + "skip_tokenizer_init": True, + "trust_remote_code": kwargs.get("trust_remote_code", False), + "mem_fraction_static": kwargs.get("mem_fraction_static", 0.8), + "disable_overlap_schedule": kwargs.get("disable_overlap_schedule", False), + "tp_size": kwargs.get("tensor_parallel_size", 1), + "ep_size": kwargs.get("moe_expert_parallel_size", 1), + "torch_compile_max_bs": max_concurrent_requests, + "max_running_requests": max_concurrent_requests, + "attention_backend": kwargs.get("attention_backend"), + "enable_torch_compile": kwargs.get("enable_torch_compile", False), + "cuda_graph_max_bs": max_concurrent_requests, + "disable_cuda_graph": False, + } if speculative_algorithm is not None: # https://github.com/sgl-project/sglang/pull/3582 - self.model = sgl.Engine( - model_path=model_dir, - skip_tokenizer_init=True, - trust_remote_code=kwargs.get("trust_remote_code", False), - mem_fraction_static=0.8, - disable_overlap_schedule=kwargs.get("disable_overlap_schedule", False), - tp_size=kwargs.get("tensor_parallel_size", 1), - ep_size=kwargs.get("moe_expert_parallel_size", 1), - speculative_algorithm=speculative_algorithm, - speculative_num_steps=kwargs.get("speculative_num_steps", 3), - speculative_eagle_topk=kwargs.get("speculative_eagle_topk", 1), - speculative_num_draft_tokens=kwargs.get("speculative_num_draft_tokens", 4), - speculative_draft_model_path=kwargs.get("draft_model_dir"), - torch_compile_max_bs=max_concurrent_requests, - max_running_requests=max_concurrent_requests, - attention_backend=kwargs.get("attention_backend"), - enable_torch_compile=kwargs.get("enable_torch_compile", False), - cuda_graph_max_bs=max_concurrent_requests, - disable_cuda_graph=False, - ) - else: - self.model = sgl.Engine( - model_path=model_dir, - skip_tokenizer_init=True, - trust_remote_code=kwargs.get("trust_remote_code", False), - mem_fraction_static=0.8, - disable_overlap_schedule=kwargs.get("disable_overlap_schedule", False), - tp_size=kwargs.get("tensor_parallel_size", 1), - ep_size=kwargs.get("moe_expert_parallel_size", 1), - torch_compile_max_bs=max_concurrent_requests, - max_running_requests=max_concurrent_requests, - attention_backend=kwargs.get("attention_backend"), - enable_torch_compile=kwargs.get("enable_torch_compile", False), - cuda_graph_max_bs=max_concurrent_requests, - disable_cuda_graph=False, - ) + engine_kwargs["speculative_algorithm"] = speculative_algorithm + engine_kwargs["speculative_draft_model_path"] = kwargs.get("draft_model_dir") + if speculative_algorithm == "DFLASH": + # Avoid CUDA-graph bucket-padding mismatches during DFLASH replay. + engine_kwargs["disable_cuda_graph_padding"] = True + engine_kwargs["speculative_num_draft_tokens"] = kwargs.get( + "speculative_num_draft_tokens", 8 + ) + if "speculative_dflash_draft_window_size" in kwargs: + engine_kwargs["speculative_dflash_draft_window_size"] = kwargs[ + "speculative_dflash_draft_window_size" + ] + print( + f"[specdec_bench] DFLASH ignores --draft_length / speculative_num_steps / " + f"speculative_eagle_topk; effective draft block = " + f"speculative_num_draft_tokens={engine_kwargs['speculative_num_draft_tokens']}. " + f"To override, set `speculative_num_draft_tokens` under engine_args in the " + f"--runtime_params YAML (no CLI flag)." + ) + else: + engine_kwargs["speculative_num_draft_tokens"] = kwargs.get( + "speculative_num_draft_tokens", 4 + ) + engine_kwargs["speculative_num_steps"] = kwargs.get("speculative_num_steps", 3) + engine_kwargs["speculative_eagle_topk"] = kwargs.get("speculative_eagle_topk", 1) + + # extra engine arg needed for qwen3.5 + if "mamba_scheduler_strategy" in kwargs: + engine_kwargs["mamba_scheduler_strategy"] = kwargs["mamba_scheduler_strategy"] + + self.model = sgl.Engine(**engine_kwargs) self.sampling_config = sampling_kwargs diff --git a/examples/specdec_bench/specdec_bench/models/vllm.py b/examples/specdec_bench/specdec_bench/models/vllm.py index 2e312e7aec8..fc595c1d579 100644 --- a/examples/specdec_bench/specdec_bench/models/vllm.py +++ b/examples/specdec_bench/specdec_bench/models/vllm.py @@ -63,6 +63,12 @@ def __init__(self, model_dir, max_concurrent_requests, sampling_kwargs, **kwargs "method": "mtp", "num_speculative_tokens": kwargs.get("speculative_num_steps", 3), } + elif kwargs.get("speculative_algorithm") == "DFLASH": + specdec = { + "method": "dflash", + "model": kwargs.get("draft_model_dir"), + "num_speculative_tokens": kwargs.get("speculative_num_draft_tokens", 8), + } elif kwargs.get("speculative_algorithm") == "NONE": specdec = None diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 7e399cf9603..1a23bb56f10 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -30,13 +30,11 @@ # limitations under the License. import argparse +import dataclasses import os -from dataclasses import dataclass, field -from typing import Literal import torch import transformers -from accelerate import ParallelismConfig from eagle_utils import ( EagleTrainerWithAccLog, EagleTrainingPlot, @@ -44,200 +42,126 @@ make_speculative_data_module, patch_ring_attention_for_ttt, ) -from omegaconf import OmegaConf +from rich.pretty import pprint from transformers.trainer_utils import get_last_checkpoint import modelopt.torch.opt as mto import modelopt.torch.speculative as mtsp -from modelopt.torch.speculative.config import DFlashConfig, EagleConfig +from modelopt.recipe import load_recipe +from modelopt.recipe.config import ( + ModelOptDFlashRecipe, + ModelOptEagleRecipe, + ModelOptMedusaRecipe, + ModelOptSpeculativeRecipeBase, +) +from modelopt.torch.speculative.plugins.hf_training_args import ( + TrainingArguments as SpecTrainingArgs, +) from modelopt.torch.speculative.utils import load_vlm_or_llm, patch_transformers5_params_loading from modelopt.torch.utils import print_rank_0 +from modelopt.torch.utils.distributed import is_master torch.manual_seed(0) mto.enable_huggingface_checkpointing() -@dataclass -class ModelArguments: - model_name_or_path: str | None = field( - default="TinyLlama/TinyLlama-1.1B-Chat-v1.0", - metadata={"help": "HuggingFace model ID or local path to the base model."}, - ) - use_fake_base_for_offline: bool = field( - default=False, - metadata={ - "help": "Load model architecture without real base weights. Offline training only." - }, - ) - trust_remote_code: bool = field( - default=False, metadata={"help": "Trust remote code when loading model."} - ) - - -@dataclass -class DataArguments: - data_path: str = field( - default=None, - metadata={"help": "Path to the online training data."}, - ) - offline_data_path: str = field( - default=None, - metadata={ - "help": "Path to offline training data directory (.pt files). This argument enables offline mode.", - }, - ) - lazy_preprocess: bool = True - draft_vocab_cache: str | None = field( - default=None, - metadata={"help": "Path to draft vocabulary cache file."}, - ) - chat_template: str = field( - default=None, - metadata={ - "help": "Jinja chat template with {% generation %} tags for answer_only_loss. " - "If not set, the tokenizer's built-in template is used (must already have generation tags)." - }, - ) - vlm_img_dir: str = field(default=None, metadata={"help": "Path to the VLM image directory."}) - vlm_processor: str = field(default=None, metadata={"help": "Path to the VLM processor."}) - sample_size: int = field( - default=-1, - metadata={"help": "Number of samples to use for training. Use -1 to use all samples."}, - ) - - def __post_init__(self): - if self.sample_size == 0 or self.sample_size < -1: - raise ValueError("sample_size must be -1 (use all samples) or a positive integer") - - -@dataclass -class TrainingArguments(transformers.TrainingArguments): - training_seq_len: int = field( - default=2048, - metadata={ - "help": ( - "Training sequence length. Sequences will be right padded or truncated to this length." - ) - }, - ) - mode: Literal["eagle3", "medusa", "dflash"] = "eagle3" - estimate_ar: bool = field( - default=False, metadata={"help": "Whether to estimate AR using training accuracy to log."} - ) - ar_validate_steps: int = field(default=1000, metadata={"help": "AR validation interval."}) - answer_only_loss: bool = field( - default=False, - metadata={ - "help": "Mask loss on non-assistant tokens. Requires a chat_template with generation tags." - }, - ) - cp_size: int = field(default=1, metadata={"help": "Context parallelism size."}) - dp_shard_size: int | None = field( - default=None, - metadata={"help": "Data parallelism shard size. None = auto (total_gpu / cp_size)."}, - ) - - -@dataclass -class MedusaArguments: - medusa_num_heads: int | None = field(default=1) - medusa_num_layers: int | None = field(default=1) +# HF-compatible TrainingArguments with our speculative-decoding extensions, auto-derived +# from :class:`SpecTrainingArgs` so its field set can't drift from the Pydantic recipe schema. +# Used at runtime as ``HfTrainingArguments(**recipe.training.model_dump())`` to obtain a +# ``transformers.Trainer``-compatible dataclass. +HfTrainingArguments = dataclasses.make_dataclass( + "HfTrainingArguments", + [ + (name, fi.annotation, dataclasses.field(default=fi.default)) + for name, fi in SpecTrainingArgs.model_fields.items() + ], + bases=(transformers.TrainingArguments,), +) def _parse_cli() -> tuple[str, list[str]]: - """Parse --config (required) from argv; return remaining args as config overrides. + """Parse --config (required) from argv; return remaining args as dotlist overrides. - Extra arguments use OmegaConf dotlist syntax, e.g. + Extra positional args use dotlist syntax, e.g. ``model.model_name_or_path=meta-llama/Llama-3.2-1B training.output_dir=ckpts/test``. """ p = argparse.ArgumentParser(add_help=False) - p.add_argument("--config", required=True, help="Path to the YAML config file.") + p.add_argument( + "--config", + required=True, + help=( + "Path to a modelopt speculative-decoding recipe YAML " + "(speculative_eagle / speculative_dflash / speculative_medusa)." + ), + ) args, overrides = p.parse_known_args() return args.config, overrides -def _load_config(config_path: str, overrides: list[str] = ()) -> tuple[dict, dict, dict]: - """Load training config from a YAML file with sections: model, data, training, eagle/dflash. - - *overrides* are OmegaConf dotlist entries (e.g. ``["model.model_name_or_path=xxx"]``) - applied on top of the YAML. +def init_distributed_env(training_args: transformers.TrainingArguments) -> None: + """Resolve dp_shard_size from the live env and attach a ParallelismConfig in-place. - Returns: - hf_cfg: Flat dict from model/data/training sections, for HfArgumentParser.parse_dict() - eagle_cfg: Eagle section dict (EagleConfig fields), passed directly to mtsp.convert() - dflash_cfg: DFlash section dict (DFlashConfig fields), passed directly to mtsp.convert() + Reads ``WORLD_SIZE`` / ``torch.cuda.device_count()`` and (when actually distributed) + builds an ``accelerate.ParallelismConfig`` on ``training_args``. Kept out of the + Pydantic schema so the recipe stays a pure declarative spec. """ - merged = OmegaConf.load(config_path) - if overrides: - merged = OmegaConf.merge(merged, OmegaConf.from_dotlist(list(overrides))) - cfg = OmegaConf.to_container(merged, resolve=True) - - # Eagle/DFlash sections map directly to config fields — no field enumeration needed. - eagle_cfg = cfg.get("eagle", {}) - dflash_cfg = cfg.get("dflash", {}) - - hf_cfg = { - **cfg.get("model", {}), - **cfg.get("data", {}), - **cfg.get("training", {}), - } - - if hf_cfg.get("dp_shard_size") is None: - cp_size = hf_cfg.get("cp_size", 1) - # Use WORLD_SIZE (total GPUs across all nodes) when available, else local GPU count. - world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count())) - hf_cfg["dp_shard_size"] = world_size // cp_size - - return hf_cfg, eagle_cfg, dflash_cfg - - -def train(): - config_path, overrides = _parse_cli() - hf_cfg, eagle_cfg, dflash_cfg = _load_config(config_path, overrides) - - parser = transformers.HfArgumentParser( - ( - ModelArguments, - DataArguments, - TrainingArguments, - MedusaArguments, - ) - ) - model_args, data_args, training_args, medusa_args = parser.parse_dict( - hf_cfg, allow_extra_keys=True - ) - - if not data_args.data_path and not data_args.offline_data_path: + if training_args.cp_size < 1: + raise ValueError(f"cp_size must be >= 1, got {training_args.cp_size}.") + world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count() or 1)) + if training_args.dp_shard_size is None: + training_args.dp_shard_size = world_size // training_args.cp_size + if training_args.dp_shard_size < 1: raise ValueError( - "Either data.data_path or data.offline_data_path must be set in the config." + f"dp_shard_size resolved to {training_args.dp_shard_size}; " + f"WORLD_SIZE ({world_size}) must be >= cp_size ({training_args.cp_size})." ) + if training_args.cp_size > 1 or training_args.dp_shard_size > 1: - # Auto-compute dp_replicate_size so that - # dp_replicate_size * dp_shard_size * cp_size == world_size. - # Note: torch.cuda.device_count() returns per-node GPU count, not world_size. - # WORLD_SIZE (set by torchrun/accelerate) gives the correct multi-node total. - world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count())) parallel_size = training_args.dp_shard_size * training_args.cp_size if world_size % parallel_size != 0: raise ValueError( f"world_size ({world_size}) must be divisible by " - f"dp_shard_size ({training_args.dp_shard_size}) * cp_size ({training_args.cp_size}) " - f"= {parallel_size}" + f"dp_shard_size ({training_args.dp_shard_size}) * " + f"cp_size ({training_args.cp_size}) = {parallel_size}" ) - dp_replicate_size = world_size // parallel_size + try: + from accelerate import ParallelismConfig + except ImportError as e: + raise ImportError( + "cp_size>1 or dp_shard_size>1 requires `accelerate` for ParallelismConfig. " + "Install it via `pip install accelerate`." + ) from e training_args.parallelism_config = ParallelismConfig( cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size, - dp_replicate_size=dp_replicate_size, + dp_replicate_size=world_size // parallel_size, + ) + + +def train(): + config_path, overrides = _parse_cli() + recipe = load_recipe(config_path, overrides=overrides) + if not isinstance(recipe, ModelOptSpeculativeRecipeBase): + raise ValueError( + f"main.py expects a speculative-decoding recipe (eagle / dflash / medusa); " + f"got {type(recipe).__name__} from {config_path!r}." + ) + + # Pydantic-typed sections flow straight through as *_args; only TrainingArguments is + # reconstructed as an HF dataclass so it can be handed to transformers.Trainer. + training_args = HfTrainingArguments(**recipe.training.model_dump()) + init_distributed_env(training_args) + + if not recipe.data.data_path and not recipe.data.offline_data_path: + raise ValueError( + "Either data.data_path or data.offline_data_path must be set in the config." ) if training_args.cp_size > 1: patch_ring_attention_for_ttt() # Specific patch to accelerate 1.12.0. Removable after move to 1.13.0 training_args.parallelism_config.sp_backend = None - print_rank_0( - f"arguments: {model_args}, {training_args}, {medusa_args}, " - f"eagle_cfg={eagle_cfg}, dflash_cfg={dflash_cfg}" - ) + if is_master(): + pprint(recipe) # Detect checkpoint to resume from last_checkpoint = ( @@ -250,80 +174,58 @@ def train(): checkpoint = training_args.resume_from_checkpoint or last_checkpoint - use_offline_training = data_args.offline_data_path is not None + use_offline_training = recipe.data.offline_data_path is not None if checkpoint: with patch_transformers5_params_loading(): model = load_vlm_or_llm( - checkpoint, dtype="auto", trust_remote_code=model_args.trust_remote_code + checkpoint, dtype="auto", trust_remote_code=recipe.model.trust_remote_code ) tokenizer = transformers.AutoTokenizer.from_pretrained( - checkpoint, trust_remote_code=model_args.trust_remote_code + checkpoint, trust_remote_code=recipe.model.trust_remote_code ) else: + model_name_or_path = recipe.model.model_name_or_path + if model_name_or_path is None: + raise ValueError( + "model.model_name_or_path must be set in the recipe YAML or via a dotlist override." + ) # To avoid OOM for large models, we load and convert model on CPU first. # Model will be moved to GPU during HF trainer.init(). - if use_offline_training: - # Load config first to preserve original num_hidden_layers before - # load_vlm_or_llm may reduce layers for offline space savings. - model_config = transformers.AutoConfig.from_pretrained( - model_args.model_name_or_path, - trust_remote_code=model_args.trust_remote_code, - ) model = load_vlm_or_llm( - model_args.model_name_or_path, - use_fake_base=model_args.use_fake_base_for_offline, + model_name_or_path, + use_fake_base=recipe.model.use_fake_base_for_offline, use_offline_training=use_offline_training, dtype="auto", device_map="cpu", - trust_remote_code=model_args.trust_remote_code, + trust_remote_code=recipe.model.trust_remote_code, ) - if use_offline_training: - # When doing offline training, we need to set num_hidden_layers - # since we override it when loading the model for space savings. - # Some models (e.g. Kimi-K2.5) use non-standard config attributes, - # so fall back to the model's own config if the attribute is missing. - model.config.num_orig_hidden_layers = getattr( - model_config, "num_hidden_layers", model.config.num_hidden_layers - ) - if hasattr(model.config, "layer_types"): - del ( - model.config.layer_types - ) # remove layer_types to avoid mismatch with the modified model tokenizer = transformers.AutoTokenizer.from_pretrained( - model_args.model_name_or_path, + model_name_or_path, model_max_length=training_args.training_seq_len, - trust_remote_code=model_args.trust_remote_code, + trust_remote_code=recipe.model.trust_remote_code, ) - if training_args.mode == "medusa": - config = { - "medusa_num_heads": medusa_args.medusa_num_heads, - "medusa_num_layers": medusa_args.medusa_num_layers, - } - mtsp.convert(model, [("medusa", config)]) - elif training_args.mode == "eagle3": - # Validate and rewrite eagle config fields - eagle_cfg = EagleConfig.model_validate( - eagle_cfg, - context={"training_args": training_args, "data_args": data_args}, - ).model_dump() + if isinstance(recipe, ModelOptMedusaRecipe): + medusa_cfg: dict = recipe.medusa.model_dump() + mtsp.convert(model, [("medusa", medusa_cfg)]) + elif isinstance(recipe, ModelOptEagleRecipe): + eagle_cfg: dict = recipe.eagle.model_dump() mtsp.convert(model, [("eagle", eagle_cfg)]) - - # Load draft vocab cache if the draft model uses a compressed vocabulary - if model.eagle_config.draft_vocab_size < model.eagle_config.vocab_size: - if not os.path.isfile(data_args.draft_vocab_cache): - raise FileNotFoundError( - f"Draft vocab cache provided but not found: {data_args.draft_vocab_cache}" - ) - model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache, weights_only=True) - print_rank_0(f"Loaded draft vocab cache from {data_args.draft_vocab_cache}.") - elif training_args.mode == "dflash": - dflash_cfg = DFlashConfig.model_validate( - dflash_cfg, context={"tokenizer": tokenizer, "data_args": data_args} - ).model_dump() + # Load draft vocab cache + mtsp.plugins.HFEagleModel.load_draft_vocab_cache(model, recipe.data.draft_vocab_cache) + elif isinstance(recipe, ModelOptDFlashRecipe): + # Fall back to tokenizer.mask_token_id when not set in the recipe; require one of the two. + if recipe.dflash.dflash_mask_token_id is None: + recipe.dflash.dflash_mask_token_id = getattr(tokenizer, "mask_token_id", None) + if recipe.dflash.dflash_mask_token_id is None: + raise ValueError( + "dflash.dflash_mask_token_id is required: set it in the recipe YAML " + "or use a tokenizer that defines mask_token_id." + ) + dflash_cfg: dict = recipe.dflash.model_dump() mtsp.convert(model, [("dflash", dflash_cfg)]) else: - raise Exception(f"{training_args.mode} is not supported!") + raise ValueError(f"Unsupported speculative recipe type: {type(recipe).__name__}") # Move any remaining CPU buffers to CUDA so DDP (NCCL-only) can broadcast # them. We iterate named_buffers and reassign via the owning module to @@ -340,19 +242,22 @@ def train(): setattr(mod, parts[-1], buf.to(_target_dev)) print_rank_0("Loading dataset...") - is_dflash = training_args.mode == "dflash" - if training_args.mode in ("eagle3", "medusa", "dflash"): - data_module = make_speculative_data_module( - tokenizer, - data_args, - train_len=training_args.training_seq_len, - answer_only_loss=training_args.answer_only_loss, - shift_labels=not is_dflash, - ) + is_dflash = isinstance(recipe, ModelOptDFlashRecipe) + data_module = make_speculative_data_module( + tokenizer, + recipe.data, + train_len=training_args.training_seq_len, + answer_only_loss=training_args.answer_only_loss, + shift_labels=not is_dflash, + ) callbacks = [EagleTrainingPlot(training_args.ar_validate_steps, training_args.estimate_ar)] - if eagle_cfg.get("eagle_base_lora") and eagle_cfg.get("eagle_base_lora_warmup_steps", 0) > 0: - callbacks.append(LoRAWarmupCallback(eagle_cfg["eagle_base_lora_warmup_steps"])) + if ( + isinstance(recipe, ModelOptEagleRecipe) + and recipe.eagle.eagle_base_lora + and recipe.eagle.eagle_base_lora_warmup_steps > 0 + ): + callbacks.append(LoRAWarmupCallback(recipe.eagle.eagle_base_lora_warmup_steps)) trainer = EagleTrainerWithAccLog( model=model, diff --git a/modelopt/onnx/llm_export_utils/quantization_utils.py b/modelopt/onnx/llm_export_utils/quantization_utils.py index 54ca93d5388..d6c8c4c1e9a 100644 --- a/modelopt/onnx/llm_export_utils/quantization_utils.py +++ b/modelopt/onnx/llm_export_utils/quantization_utils.py @@ -69,9 +69,7 @@ def get_quant_config(precision, lm_head_precision="fp16"): else: raise ValueError(f"Unsupported precision: {precision}") - quant_cfg_list: list = [ - e for e in quant_cfg["quant_cfg"] if isinstance(e, dict) and "quantizer_name" in e - ] + quant_cfg_list: list = [e for e in quant_cfg["quant_cfg"] if "quantizer_name" in e] if lm_head_precision == "fp8": quant_cfg_list.append( diff --git a/modelopt/recipe/config.py b/modelopt/recipe/config.py index 96f33012afd..749d80a933d 100644 --- a/modelopt/recipe/config.py +++ b/modelopt/recipe/config.py @@ -17,30 +17,56 @@ from __future__ import annotations +import warnings from enum import Enum -from pydantic import field_validator -from typing_extensions import NotRequired, TypedDict +from pydantic import Field, model_validator from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField -from modelopt.torch.quantization.config import QuantizeConfig +from modelopt.torch.quantization.config import QuantizeConfig # noqa: TC001 +from modelopt.torch.speculative.config import DFlashConfig, EagleConfig, MedusaConfig +from modelopt.torch.speculative.plugins.hf_training_args import DataArguments as SpecDataArgs +from modelopt.torch.speculative.plugins.hf_training_args import ModelArguments as SpecModelArgs +from modelopt.torch.speculative.plugins.hf_training_args import ( + TrainingArguments as SpecTrainingArgs, +) class RecipeType(str, Enum): - """List of recipe types.""" + """List of recipe types. See ``RECIPE_TYPE_TO_CLASS`` at the bottom for the schema mapping.""" PTQ = "ptq" + SPECULATIVE_EAGLE = "speculative_eagle" + SPECULATIVE_DFLASH = "speculative_dflash" + SPECULATIVE_MEDUSA = "speculative_medusa" # QAT = "qat" # Not implemented yet, will be added in the future. -class RecipeMetadataConfig(TypedDict): +_DEFAULT_RECIPE_DESCRIPTION = "Model optimization recipe." + + +class RecipeMetadataConfig(ModeloptBaseConfig): """YAML shape of the recipe metadata section.""" - recipe_type: RecipeType - description: NotRequired[str] + recipe_type: RecipeType = Field( + title="Recipe type", + description="The type of the recipe (e.g. PTQ).", + ) + description: str = ModeloptField( + default=_DEFAULT_RECIPE_DESCRIPTION, + title="Description", + description="Human-readable description of the recipe.", + ) -_DEFAULT_RECIPE_DESCRIPTION = "Model optimization recipe." +def _metadata_field(recipe_type: RecipeType): + """Build the metadata Pydantic field with the recipe_type baked into the default.""" + return ModeloptField( + default={"recipe_type": recipe_type, "description": _DEFAULT_RECIPE_DESCRIPTION}, + title="Metadata", + description="Recipe metadata containing the recipe type and description.", + validate_default=True, + ) class ModelOptRecipeBase(ModeloptBaseConfig): @@ -49,41 +75,131 @@ class ModelOptRecipeBase(ModeloptBaseConfig): If a layer name matches ``"*output_layer*"``, the attributes will be replaced with ``{"enable": False}``. """ - metadata: RecipeMetadataConfig = ModeloptField( - default={"recipe_type": RecipeType.PTQ, "description": _DEFAULT_RECIPE_DESCRIPTION}, + metadata: RecipeMetadataConfig = Field( title="Metadata", - description="Recipe metadata containing the recipe type and description.", - validate_default=True, + description="Recipe metadata containing the recipe type and description. " + "Required: a recipe without a ``metadata`` section is rejected so that a " + "missing section can't silently fall back to a default recipe type.", ) - @field_validator("metadata") - @classmethod - def validate_metadata(cls, metadata: RecipeMetadataConfig) -> RecipeMetadataConfig: - """Validate recipe metadata and fill defaults for optional fields.""" - if metadata["recipe_type"] not in RecipeType: - raise ValueError( - f"Unsupported recipe type: {metadata['recipe_type']}. " - f"Only {list(RecipeType)} are currently supported." - ) - return {"description": _DEFAULT_RECIPE_DESCRIPTION, **metadata} - @property def recipe_type(self) -> RecipeType: """Return the recipe type from metadata.""" - return self.metadata["recipe_type"] + return self.metadata.recipe_type @property def description(self) -> str: """Return the recipe description from metadata.""" - return self.metadata.get("description", _DEFAULT_RECIPE_DESCRIPTION) + return self.metadata.description class ModelOptPTQRecipe(ModelOptRecipeBase): """Our config class for PTQ recipes.""" - quantize: QuantizeConfig = ModeloptField( - default=QuantizeConfig(), + quantize: QuantizeConfig = Field( title="PTQ config", - description="PTQ config containing quant_cfg and algorithm.", + description="PTQ config containing quant_cfg and algorithm. Required: a PTQ " + "recipe without a ``quantize`` section is rejected so that a missing section " + "can't silently fall back to the default INT8 config.", + ) + + +class ModelOptSpeculativeRecipeBase(ModelOptRecipeBase): + """Base class for speculative-decoding recipes. + + Unlike PTQ, speculative-decoding is a training-time optimization: the draft head is trained + with HF Trainer. We therefore bundle ``model`` / ``data`` / ``training`` sections into the + recipe so a single YAML is the full experiment spec. Each section is a typed Pydantic model + (see :mod:`modelopt.torch.speculative.plugins.hf_training_args`) so field typos and bad + values are caught at recipe-load time; HF trainer fields pass through + ``TrainingArguments`` via ``extra='allow'``. + """ + + model: SpecModelArgs = ModeloptField( + default=SpecModelArgs(), + title="HF model args", + description="ModelArguments for the base HF model to train a draft head against.", + validate_default=True, + ) + data: SpecDataArgs = ModeloptField( + default=SpecDataArgs(), + title="HF data args", + description="DataArguments for the training/offline dataset.", + validate_default=True, + ) + training: SpecTrainingArgs = ModeloptField( + default=SpecTrainingArgs(), + title="HF training args", + description="Speculative-decoding extensions; HF trainer fields flow through as extras.", validate_default=True, ) + + +class ModelOptEagleRecipe(ModelOptSpeculativeRecipeBase): + """Our config class for EAGLE speculative decoding recipes.""" + + metadata: RecipeMetadataConfig = _metadata_field(RecipeType.SPECULATIVE_EAGLE) + + eagle: EagleConfig = ModeloptField( + default=EagleConfig(), + title="EAGLE config", + description="EAGLE speculative decoding configuration.", + validate_default=True, + ) + + @model_validator(mode="after") + def _derive_eagle_offline(self) -> ModelOptEagleRecipe: + self.eagle.eagle_offline = self.data.offline_data_path is not None + return self + + @model_validator(mode="after") + def _warn_rope_vs_training_seq_len(self) -> ModelOptEagleRecipe: + orig_max_pos = self.eagle.eagle_export_rope_scaling.get("original_max_position_embeddings") + if orig_max_pos is not None and orig_max_pos != self.training.training_seq_len: + warnings.warn( + f"eagle.eagle_export_rope_scaling.original_max_position_embeddings ({orig_max_pos}) " + f"differs from training.training_seq_len ({self.training.training_seq_len}). " + f"This may affect long-context inference quality." + ) + return self + + +class ModelOptDFlashRecipe(ModelOptSpeculativeRecipeBase): + """Our config class for DFlash speculative decoding recipes.""" + + metadata: RecipeMetadataConfig = _metadata_field(RecipeType.SPECULATIVE_DFLASH) + + dflash: DFlashConfig = ModeloptField( + default=DFlashConfig(), + title="DFlash config", + description="DFlash speculative decoding configuration.", + validate_default=True, + ) + + @model_validator(mode="after") + def _derive_dflash_offline(self) -> ModelOptDFlashRecipe: + self.dflash.dflash_offline = self.data.offline_data_path is not None + return self + + +class ModelOptMedusaRecipe(ModelOptSpeculativeRecipeBase): + """Our config class for Medusa speculative decoding recipes.""" + + metadata: RecipeMetadataConfig = _metadata_field(RecipeType.SPECULATIVE_MEDUSA) + + medusa: MedusaConfig = ModeloptField( + default=MedusaConfig(), + title="Medusa config", + description="Medusa speculative decoding configuration.", + validate_default=True, + ) + + +# Single source of truth mapping YAML ``metadata.recipe_type`` to its schema class. The loader +# uses this for typed-list ``$import`` resolution; add a new entry when introducing a recipe. +RECIPE_TYPE_TO_CLASS: dict[RecipeType, type[ModelOptRecipeBase]] = { + RecipeType.PTQ: ModelOptPTQRecipe, + RecipeType.SPECULATIVE_EAGLE: ModelOptEagleRecipe, + RecipeType.SPECULATIVE_DFLASH: ModelOptDFlashRecipe, + RecipeType.SPECULATIVE_MEDUSA: ModelOptMedusaRecipe, +} diff --git a/modelopt/recipe/loader.py b/modelopt/recipe/loader.py index 9c3c40856d2..0a9218ff7d0 100644 --- a/modelopt/recipe/loader.py +++ b/modelopt/recipe/loader.py @@ -21,14 +21,32 @@ from importlib.abc import Traversable from pathlib import Path +from omegaconf import OmegaConf + from modelopt.torch.opt.config_loader import BUILTIN_CONFIG_ROOT as BUILTIN_RECIPES_LIB from modelopt.torch.opt.config_loader import load_config from modelopt.torch.quantization.config import QuantizeConfig -from .config import ModelOptPTQRecipe, ModelOptRecipeBase, RecipeMetadataConfig, RecipeType +from .config import ( + RECIPE_TYPE_TO_CLASS, + ModelOptPTQRecipe, + ModelOptRecipeBase, + RecipeMetadataConfig, + RecipeType, +) __all__ = ["load_config", "load_recipe"] +# Each recipe type's mandatory top-level body section. Checked at the loader level (on the +# raw YAML, before pydantic fills in defaults) so the user sees a clear "PTQ recipe file X +# must contain 'quantize'" instead of pydantic's generic missing-field error. +_REQUIRED_SECTION_PER_RECIPE_TYPE: dict[RecipeType, str] = { + RecipeType.PTQ: "quantize", + RecipeType.SPECULATIVE_EAGLE: "eagle", + RecipeType.SPECULATIVE_DFLASH: "dflash", + RecipeType.SPECULATIVE_MEDUSA: "medusa", +} + def _resolve_recipe_path(recipe_path: str | Path | Traversable) -> Path | Traversable: """Resolve a recipe path, checking the built-in library first then the filesystem. @@ -52,17 +70,29 @@ def _resolve_recipe_path(recipe_path: str | Path | Traversable) -> Path | Traver return recipe_path -def load_recipe(recipe_path: str | Path | Traversable) -> ModelOptRecipeBase: - """Load a recipe from a YAML file or directory. +def load_recipe( + recipe_path: str | Path | Traversable, + overrides: list[str] | None = None, +) -> ModelOptRecipeBase: + """Load a recipe from a YAML file or directory, with optional CLI-style overrides. ``recipe_path`` can be: - * A ``.yml`` / ``.yaml`` file with ``metadata`` and ``quantize`` sections. - The suffix may be omitted and will be probed automatically. - * A directory containing ``metadata.yml`` and ``quantize.yml``. + * A ``.yml`` / ``.yaml`` file with ``metadata`` and one of ``quantize`` (PTQ), + ``eagle`` (EAGLE speculative decoding), ``dflash`` (DFlash speculative + decoding) or ``medusa`` (Medusa speculative decoding) sections. The suffix + may be omitted and will be probed automatically. + * A directory containing ``metadata.yml`` and ``quantize.yml`` — + **PTQ recipes only**. Speculative-decoding recipes are always single YAML files. The path may be relative to the built-in recipes library or an absolute / relative filesystem path. + + ``overrides`` is an optional list of ``key.path=value`` dotlist entries applied + on top of the YAML before Pydantic validation. Values are parsed with + ``yaml.safe_load`` so they get proper types (``foo.bar=true`` → bool, ``foo=1`` + → int, ``foo=[1,2]`` → list, etc.). Only supported when *recipe_path* is a + single YAML file. """ resolved = _resolve_recipe_path(recipe_path) @@ -75,44 +105,98 @@ def load_recipe(recipe_path: str | Path | Traversable) -> ModelOptRecipeBase: print(f"[load_recipe] loading: {_display}") if resolved.is_file(): - return _load_recipe_from_file(resolved) + return _load_recipe_from_file(resolved, overrides=overrides) if resolved.is_dir(): + if overrides: + raise ValueError( + "overrides are not supported for directory-format recipes; " + "use the single-YAML-file form instead." + ) return _load_recipe_from_dir(resolved) raise ValueError(f"Recipe path {recipe_path!r} is not a valid YAML file or directory.") -def _load_recipe_from_file(recipe_file: Path | Traversable) -> ModelOptRecipeBase: - """Load a recipe from a YAML file. +def _apply_dotlist(data: dict, overrides: list[str]) -> dict: + """Merge ``a.b.c=value`` command line overrides on top of ``data`` via OmegaConf.""" + for entry in overrides: + if "=" not in entry: + raise ValueError(f"Invalid override (missing '='): {entry!r}") + merged = OmegaConf.merge( + OmegaConf.create(data), + OmegaConf.from_dotlist(list(overrides)), + ) + return OmegaConf.to_container(merged, resolve=False) - The file must contain a ``metadata`` section with at least ``recipe_type``, - plus a ``quant_cfg`` mapping and an optional ``algorithm`` for PTQ recipes. + +def _peek_recipe_type(recipe_file: Path | Traversable) -> RecipeType | None: + """Extract ``metadata.recipe_type`` from a recipe YAML without resolving $imports. + + Needed so :func:`load_config` can be called with the correct ``schema_type`` for + typed-list ``$import`` resolution before the full recipe is constructed. """ - data = load_config(recipe_file, schema_type=ModelOptPTQRecipe) - if not isinstance(data, dict): - raise ValueError( - f"Recipe file {recipe_file} must be a YAML mapping, got {type(data).__name__}." - ) + import yaml - metadata = data.get("metadata", {}) - if not isinstance(metadata, dict): - raise ValueError( - f"Recipe file {recipe_file} field 'metadata' must be a mapping, " - f"got {type(metadata).__name__}." - ) - recipe_type = metadata.get("recipe_type") - if recipe_type is None: - raise ValueError(f"Recipe file {recipe_file} must contain a 'metadata.recipe_type' field.") + try: + raw = yaml.safe_load(recipe_file.read_text()) + return RecipeType(raw["metadata"]["recipe_type"]) + except (TypeError, KeyError, ValueError): + return None - if recipe_type == RecipeType.PTQ: - if "quantize" not in data: - raise ValueError(f"PTQ recipe file {recipe_file} must contain 'quantize'.") - return ModelOptPTQRecipe( - metadata=metadata, - quantize=data["quantize"], + +def _load_recipe_from_file( + recipe_file: Path | Traversable, + overrides: list[str] | None = None, +) -> ModelOptRecipeBase: + """Load a recipe from a YAML file, optionally applying dotlist overrides. + + The file must contain a ``metadata`` section with at least ``recipe_type``, + plus the algorithm-specific section (``quantize`` / ``eagle`` / ``dflash`` / ``medusa``). + """ + rtype = _peek_recipe_type(recipe_file) + if rtype is None: + raise ValueError(f"Recipe file {recipe_file} must contain a 'metadata.recipe_type' field.") + schema_class = RECIPE_TYPE_TO_CLASS.get(rtype) + if schema_class is None: + raise ValueError(f"Unsupported recipe type: {rtype!r}") + + # Pre-flight check on the *raw* YAML so the user sees a clear loader-level error + # rather than a generic pydantic missing-field error. Speculative recipes' body + # sections have field-level defaults, so this check is what keeps their loader + # semantics consistent with PTQ. + required_section = _REQUIRED_SECTION_PER_RECIPE_TYPE.get(rtype) + if required_section is not None: + import yaml + + raw = yaml.safe_load(recipe_file.read_text()) or {} + if not isinstance(raw, dict) or required_section not in raw: + kind = ( + rtype.value.split("_", 1)[-1].upper() if "_" in rtype.value else rtype.value.upper() + ) + raise ValueError(f"{kind} recipe file {recipe_file} must contain {required_section!r}.") + + # Passing ``schema_type=schema_class`` to ``load_config`` enables typed-list + # ``$import`` resolution (e.g. ``$import: disable_all`` spliced into + # ``quantize.quant_cfg`` needs to know the list's element schema is + # :class:`QuantizerCfgEntry`). The return value is already a validated schema + # instance. + if overrides: + # Overrides have to be applied before pydantic validation. Round-trip through + # ``model_dump()`` so $imports are resolved and the dict has the resolved shape; + # then splice the dotlist values and re-validate. + recipe = load_config(recipe_file, schema_type=schema_class) + data = recipe.model_dump() + data = _apply_dotlist(data, overrides) + return schema_class.model_validate(data) + + recipe = load_config(recipe_file, schema_type=schema_class) + if not isinstance(recipe, schema_class): + raise ValueError( + f"Recipe file {recipe_file} must produce a {schema_class.__name__}, " + f"got {type(recipe).__name__}." ) - raise ValueError(f"Unsupported recipe type: {recipe_type!r}") + return recipe def _find_recipe_section_file( @@ -137,25 +221,10 @@ def _load_recipe_from_dir(recipe_dir: Path | Traversable) -> ModelOptRecipeBase: quantize. """ metadata_file = _find_recipe_section_file(recipe_dir, "metadata") - metadata = load_config(metadata_file, schema_type=RecipeMetadataConfig) - if not isinstance(metadata, dict): - raise ValueError( - f"Metadata file {metadata_file} must be a YAML mapping, got {type(metadata).__name__}." - ) - recipe_type = metadata.get("recipe_type") - if recipe_type is None: - raise ValueError(f"Metadata file {metadata_file} must contain a 'recipe_type' field.") - if recipe_type == RecipeType.PTQ: + if metadata.recipe_type == RecipeType.PTQ: quantize_file = _find_recipe_section_file(recipe_dir, "quantize") - quantize_data = load_config(quantize_file, schema_type=QuantizeConfig) - if not isinstance(quantize_data, dict): - raise ValueError( - f"{quantize_file} must be a YAML mapping, got {type(quantize_data).__name__}." - ) - return ModelOptPTQRecipe( - metadata=metadata, - quantize=quantize_data, - ) - raise ValueError(f"Unsupported recipe type: {recipe_type!r}") + quantize_cfg = load_config(quantize_file, schema_type=QuantizeConfig) + return ModelOptPTQRecipe(metadata=metadata, quantize=quantize_cfg) + raise ValueError(f"Unsupported recipe type: {metadata.recipe_type!r}") diff --git a/modelopt/torch/export/convert_hf_config.py b/modelopt/torch/export/convert_hf_config.py index 5f8c3f3b55c..06e5923a30f 100644 --- a/modelopt/torch/export/convert_hf_config.py +++ b/modelopt/torch/export/convert_hf_config.py @@ -57,6 +57,11 @@ def _quant_algo_to_group_config(quant_algo: str, group_size: int | None = None) return { "weights": {"dynamic": False, "num_bits": 4, "type": "int", "group_size": gs}, } + elif quant_algo == "W4A16_NVFP4": + gs = group_size or 16 + return { + "weights": {"dynamic": False, "num_bits": 4, "type": "float", "group_size": gs}, + } elif quant_algo in ("NVFP4_AWQ", "W4A8_AWQ"): gs = group_size or 128 return { @@ -183,6 +188,14 @@ def convert_hf_quant_config_format(input_config: dict[str, Any]) -> dict[str, An "targets": ["Linear"], } new_config["config_groups"] = {"group_0": config_group_details} + elif quant_algo_value == "W4A16_NVFP4": + # Weight-only FP4 + group_size = original_quantization_details.get("group_size", 16) + config_group_details = { + "weights": {"dynamic": False, "num_bits": 4, "type": "float", "group_size": group_size}, + "targets": ["Linear"], + } + new_config["config_groups"] = {"group_0": config_group_details} elif quant_algo_value == "MIXED_PRECISION": quantized_layers = original_quantization_details.get("quantized_layers", {}) diff --git a/modelopt/torch/export/model_config.py b/modelopt/torch/export/model_config.py index dce39767c76..5f92cc2e5dc 100755 --- a/modelopt/torch/export/model_config.py +++ b/modelopt/torch/export/model_config.py @@ -38,6 +38,7 @@ QUANTIZATION_MXFP4 = "mxfp4" QUANTIZATION_MXFP8 = "mxfp8" QUANTIZATION_W4A8_MXFP4_FP8 = "w4a8_mxfp4_fp8" +QUANTIZATION_W4A16_NVFP4 = "w4a16_nvfp4" QUANTIZATION_NVFP4_AWQ = "nvfp4_awq" QUANTIZATION_FP8_PB_REAL = "fp8_pb_real" QUANTIZATION_FP8_PB_WO = "fp8_pb_wo" diff --git a/modelopt/torch/export/plugins/mcore_deepseek.py b/modelopt/torch/export/plugins/mcore_deepseek.py index d02259e3530..2f9ef40f08e 100644 --- a/modelopt/torch/export/plugins/mcore_deepseek.py +++ b/modelopt/torch/export/plugins/mcore_deepseek.py @@ -43,6 +43,10 @@ "linear_kv_up_proj": NameRemapping("model.layers.{}.self_attn.kv_b_proj."), "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj."), "pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm."), + # Fused TE spec (mirrors the import side). MLA has no linear_qkv so + # fused_input_layernorm is inert today; fused_pre_mlp_layernorm reaches dense layers. + "fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"), + "fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"), # MLP for dense layers "linear_fc1": GatedMLPSlicing("model.layers.{}.mlp."), "linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj."), @@ -88,6 +92,11 @@ "output_layer": NameRemapping("lm_head.", COL_TP), # Per-layer "input_layernorm": NameRemapping("model.layers.{}.input_layernorm.", REPLICATE), + # Fused TE spec (TELayerNormColumnParallelLinear) — see mcore_qwen.py for rationale. + # MLA has no linear_qkv so fused_input_layernorm is inert for DeepSeek today; included + # for parity in case a future spec fuses the layernorm into a Q/KV projection. + "fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"), + "fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"), "linear_q_proj": NameRemapping("model.layers.{}.self_attn.q_proj.", COL_TP), "linear_q_down_proj": NameRemapping("model.layers.{}.self_attn.q_a_proj.", REPLICATE), "linear_q_layernorm": NameRemapping("model.layers.{}.self_attn.q_a_layernorm.", REPLICATE), diff --git a/modelopt/torch/export/plugins/mcore_gptoss.py b/modelopt/torch/export/plugins/mcore_gptoss.py index c16347fbf0b..989aa7e67d7 100644 --- a/modelopt/torch/export/plugins/mcore_gptoss.py +++ b/modelopt/torch/export/plugins/mcore_gptoss.py @@ -31,6 +31,8 @@ gptoss_causal_lm_export: dict[str, CustomModuleMapping | bool] = { "word_embeddings": NameRemapping("model.embed_tokens."), "input_layernorm": NameRemapping("model.layers.{}.input_layernorm."), + # MoE-only on MLP side, so fused_pre_mlp_layernorm path is unreachable. + "fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"), "linear_qkv": QKVSlicing("model.layers.{}.self_attn."), "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj."), "softmax_offset": NameRemapping("model.layers.{}.self_attn.sinks"), @@ -52,6 +54,10 @@ gptoss_causal_lm_import: dict[str, CustomModuleMapping | bool] = { "word_embeddings": NameRemapping("model.embed_tokens.", COL_TP), "input_layernorm": NameRemapping("model.layers.{}.input_layernorm.", REPLICATE), + # Fused TE spec (TELayerNormColumnParallelLinear) — see mcore_qwen.py for rationale. + # gpt-oss is MoE-only on the MLP side (no layer.mlp.linear_fc1), so the importer's + # fused_pre_mlp_layernorm path is unreachable; only fused_input_layernorm is wired. + "fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"), "linear_qkv": QKVMerging("model.layers.{}.self_attn.", COL_TP), "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj.", ROW_TP), "softmax_offset": NameRemapping("model.layers.{}.self_attn.sinks", COL_TP), diff --git a/modelopt/torch/export/plugins/mcore_llama.py b/modelopt/torch/export/plugins/mcore_llama.py index 7fb8ec76acf..80a5d9146a9 100644 --- a/modelopt/torch/export/plugins/mcore_llama.py +++ b/modelopt/torch/export/plugins/mcore_llama.py @@ -37,11 +37,13 @@ llama_causal_lm_export: dict[str, CustomModuleMapping] = { "word_embeddings": NameRemapping("model.embed_tokens."), "input_layernorm": NameRemapping("model.layers.{}.input_layernorm."), + "fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"), "linear_qkv": QKVSlicing("model.layers.{}.self_attn."), "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj."), # KV cache quant export "core_attention": SelfAttentionScaling("model.layers.{}.self_attn."), "pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm."), + "fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"), "linear_fc1": GatedMLPSlicing("model.layers.{}.mlp."), "linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj."), "final_layernorm": NameRemapping("model.norm."), @@ -51,6 +53,8 @@ llama4_causal_lm_export: dict[str, CustomModuleMapping | bool] = { "word_embeddings": NameRemapping("language_model.model.embed_tokens."), "input_layernorm": NameRemapping("language_model.model.layers.{}.input_layernorm."), + # MoE-only on MLP side, so fused_pre_mlp_layernorm path is unreachable. + "fused_input_layernorm": NameRemapping("language_model.model.layers.{}.input_layernorm.weight"), # self_attn "linear_qkv": QKVSlicing("language_model.model.layers.{}.self_attn."), "linear_proj": NameRemapping("language_model.model.layers.{}.self_attn.o_proj."), @@ -150,9 +154,12 @@ llama_causal_lm_import: dict[str, CustomModuleMapping] = { "word_embeddings": NameRemapping("model.embed_tokens.", COL_TP), "input_layernorm": NameRemapping("model.layers.{}.input_layernorm.", REPLICATE), + # Fused TE spec (TELayerNormColumnParallelLinear) — see mcore_qwen.py for rationale. + "fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"), "linear_qkv": QKVMerging("model.layers.{}.self_attn.", COL_TP), "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj.", ROW_TP), "pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.", REPLICATE), + "fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"), "linear_fc1": GatedMLPMerging("model.layers.{}.mlp.", COL_TP), "linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj.", ROW_TP), "final_layernorm": NameRemapping("model.norm.", REPLICATE), @@ -162,6 +169,10 @@ llama4_causal_lm_import: dict[str, CustomModuleMapping | bool] = { "word_embeddings": NameRemapping("language_model.model.embed_tokens.", COL_TP), "input_layernorm": NameRemapping("language_model.model.layers.{}.input_layernorm.", REPLICATE), + # Fused TE spec (TELayerNormColumnParallelLinear) — see mcore_qwen.py for rationale. + # Llama4 is MoE-only on the MLP side (no layer.mlp.linear_fc1), so the importer's + # fused_pre_mlp_layernorm path is unreachable; only fused_input_layernorm is wired. + "fused_input_layernorm": NameRemapping("language_model.model.layers.{}.input_layernorm.weight"), "linear_qkv": QKVMerging("language_model.model.layers.{}.self_attn.", COL_TP), "linear_proj": NameRemapping("language_model.model.layers.{}.self_attn.o_proj.", ROW_TP), "pre_mlp_layernorm": NameRemapping( diff --git a/modelopt/torch/export/plugins/mcore_qwen.py b/modelopt/torch/export/plugins/mcore_qwen.py index 5c4ae0647d8..4120a9a36d9 100644 --- a/modelopt/torch/export/plugins/mcore_qwen.py +++ b/modelopt/torch/export/plugins/mcore_qwen.py @@ -35,12 +35,17 @@ "output_layer": NameRemapping("lm_head.", COL_TP), # Attention "input_layernorm": NameRemapping("model.layers.{}.input_layernorm.", REPLICATE), + # Fused TE spec (TELayerNormColumnParallelLinear): the LayerNorm weight lives on + # linear_qkv.layer_norm_weight, loaded directly from the HF norm tensor (no `.weight` suffix + # appended since the value is a Parameter, not a sub-module). + "fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"), "linear_qkv": QKVMerging("model.layers.{}.self_attn.", COL_TP), "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj.", ROW_TP), "q_layernorm": NameRemapping("model.layers.{}.self_attn.q_norm.", REPLICATE), "k_layernorm": NameRemapping("model.layers.{}.self_attn.k_norm.", REPLICATE), # MLP "pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.", REPLICATE), + "fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"), "linear_fc1": GatedMLPMerging("model.layers.{}.mlp.", COL_TP), "linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj.", ROW_TP), # MoE @@ -56,12 +61,14 @@ "output_layer": NameRemapping("lm_head."), # Attention "input_layernorm": NameRemapping("model.layers.{}.input_layernorm."), + "fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"), "linear_qkv": QKVSlicing("model.layers.{}.self_attn."), "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj."), "q_layernorm": NameRemapping("model.layers.{}.self_attn.q_norm."), "k_layernorm": NameRemapping("model.layers.{}.self_attn.k_norm."), # MLP "pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm."), + "fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"), "linear_fc1": GatedMLPSlicing("model.layers.{}.mlp."), "linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj."), # MoE @@ -76,10 +83,12 @@ "output_layer": NameRemapping("lm_head.", COL_TP), # Attention "input_layernorm": NameRemapping("model.layers.{}.input_layernorm.", REPLICATE), + "fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"), "linear_qkv": QKVMerging("model.layers.{}.self_attn.", COL_TP), "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj.", ROW_TP), # MLP "pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.", REPLICATE), + "fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"), "linear_fc1": GatedMLPMerging("model.layers.{}.mlp.", COL_TP), "linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj.", ROW_TP), } @@ -90,10 +99,12 @@ "output_layer": NameRemapping("lm_head."), # Attention "input_layernorm": NameRemapping("model.layers.{}.input_layernorm."), + "fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"), "linear_qkv": QKVSlicing("model.layers.{}.self_attn."), "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj."), # MLP "pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm."), + "fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"), "linear_fc1": GatedMLPSlicing("model.layers.{}.mlp."), "linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj."), } diff --git a/modelopt/torch/export/plugins/megatron_importer.py b/modelopt/torch/export/plugins/megatron_importer.py index e485731b3d8..7be98e6416b 100644 --- a/modelopt/torch/export/plugins/megatron_importer.py +++ b/modelopt/torch/export/plugins/megatron_importer.py @@ -238,8 +238,9 @@ def _gated_mlp_merging( else: prefix = prefix.replace("model", "mtp") - weight = module.state_dict().get("weight", None) - weight_scale = module.state_dict().get("weight_quantizer._scale", None) + module_state_dict = module.state_dict() + weight = module_state_dict.get("weight", None) + weight_scale = module_state_dict.get("weight_quantizer._scale", None) state_dict = {} @@ -273,6 +274,16 @@ def _gated_mlp_merging( else: state_dict["weight"] = tensor.to(self.dtype).to(device=weight.device) + # Preserve the fused LayerNorm weight + TE _extra_state already on the module so + # the strict load_state_dict below doesn't fail for TELayerNormColumnParallelLinear + # (fused under --export-default-te-spec). The actual HF norm tensor is loaded + # separately via the `fused_pre_mlp_layernorm` rule. + layer_norm_weight = module_state_dict.get("layer_norm_weight", None) + if layer_norm_weight is not None: + state_dict["layer_norm_weight"] = layer_norm_weight + if "_extra_state" in module_state_dict: + state_dict["_extra_state"] = module_state_dict["_extra_state"] + module.load_state_dict(state_dict) def _grouped_mlp_merging( @@ -433,7 +444,13 @@ def _qkv_merging( layer_norm_weight = module_state_dict.get("layer_norm_weight", None) if layer_norm_weight is not None: state_dict["layer_norm_weight"] = layer_norm_weight - state_dict["_extra_state"] = None # for TE modules require _extra_state key + # Preserve the TE metadata struct (FP8 amax history, recipe version, etc.) — + # `load_state_dict(..., strict=True)` requires the key, but blanking it could + # zero out per-module FP8 bookkeeping on TE versions that populate it. Only + # forward through when the source actually has it, to avoid adding an + # unexpected `_extra_state=None` to TE variants that don't. + if "_extra_state" in module_state_dict: + state_dict["_extra_state"] = module_state_dict["_extra_state"] module.load_state_dict(state_dict) @@ -599,14 +616,32 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = ) # TE spec: input_layernorm is fused into linear_qkv (TELayerNormColumnParallelLinear). - # Load the fused layer_norm_weight from the HF norm path. + # Prefer the per-context key (`fused_input_layernorm`); fall back to the legacy + # single-key `fused_norm` for Nemotron-H style (one norm shared across slots). + # Missing both is a plugin misconfig — raise rather than silently random-init. if ( isinstance(layer.input_layernorm, IdentityOp) and hasattr(attention, "linear_qkv") and hasattr(attention.linear_qkv, "layer_norm_weight") - and "fused_norm" in self.rules ): - self.rules["fused_norm"]( + fused_key = ( + "fused_input_layernorm" + if "fused_input_layernorm" in self.rules + else "fused_norm" + ) + if fused_key not in self.rules: + # Branch only fires when model uses fused TELayerNormColumnParallelLinear, + # so missing rule is unambiguously a plugin misconfiguration; raise so it + # doesn't silently ship a chance-accuracy checkpoint. + raise KeyError( + f"{self.arch} uses fused TELayerNormColumnParallelLinear for " + "attention but neither `fused_input_layernorm` nor legacy " + "`fused_norm` is in its import mapping; `linear_qkv.layer_norm_weight` " + "would be left at random init. Add " + '`fused_input_layernorm: NameRemapping("...input_layernorm.weight")` ' + f"to the {self.arch} import mapping." + ) + self.rules[fused_key]( attention.linear_qkv.layer_norm_weight, layer_id, is_mtp=is_mtp ) @@ -707,14 +742,27 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id, is_mtp=is_mtp) # TE spec: pre_mlp_layernorm is fused into linear_fc1 - # (TELayerNormColumnParallelLinear). - # Load the fused layer_norm_weight from the HF norm path. - if ( - isinstance(layer.pre_mlp_layernorm, IdentityOp) - and hasattr(layer.mlp.linear_fc1, "layer_norm_weight") - and "fused_norm" in self.rules + # (TELayerNormColumnParallelLinear). See input_layernorm path above for the + # rule-key fallback rationale. + if isinstance(layer.pre_mlp_layernorm, IdentityOp) and hasattr( + layer.mlp.linear_fc1, "layer_norm_weight" ): - self.rules["fused_norm"]( + fused_key = ( + "fused_pre_mlp_layernorm" + if "fused_pre_mlp_layernorm" in self.rules + else "fused_norm" + ) + if fused_key not in self.rules: + raise KeyError( + f"{self.arch} uses fused TELayerNormColumnParallelLinear for " + "MLP but neither `fused_pre_mlp_layernorm` nor legacy " + "`fused_norm` is in its import mapping; " + "`linear_fc1.layer_norm_weight` would be left at random init. " + "Add `fused_pre_mlp_layernorm: NameRemapping(" + '"...post_attention_layernorm.weight")` ' + f"to the {self.arch} import mapping." + ) + self.rules[fused_key]( layer.mlp.linear_fc1.layer_norm_weight, layer_id, is_mtp=is_mtp ) diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index 3e488c821f8..b3173706b44 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -69,6 +69,7 @@ QUANTIZATION_W4A8_AWQ, QUANTIZATION_W4A8_MXFP4_FP8, QUANTIZATION_W4A8_NVFP4_FP8, + QUANTIZATION_W4A16_NVFP4, ) logger = logging.getLogger(__name__) @@ -375,6 +376,7 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") -> QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ, QUANTIZATION_NVFP4_SVDQUANT, + QUANTIZATION_W4A16_NVFP4, QUANTIZATION_W4A8_NVFP4_FP8, ]: # Calibrate weight quantizer if amax is not set @@ -419,6 +421,7 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight") QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ, QUANTIZATION_NVFP4_SVDQUANT, + QUANTIZATION_W4A16_NVFP4, QUANTIZATION_W4A8_NVFP4_FP8, ]: # Calibrate weight quantizer if amax is not set @@ -657,6 +660,9 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames return QUANTIZATION_NVFP4_AWQ if getattr(layer, "fused_with_prequant", False): return QUANTIZATION_NVFP4_AWQ + if input_quantizer is None or not input_quantizer.is_enabled: + if scale_bits == (4, 3): + return QUANTIZATION_W4A16_NVFP4 assert input_quantizer is not None, ( f"input_quantizer is None for {quantizer_attr_names}" ) @@ -824,6 +830,11 @@ def process_layer_quant_config(layer_config_dict): "quant_algo": "NVFP4", "group_size": block_size_value, } + elif v == "w4a16_nvfp4": + layer_config = { + "quant_algo": "W4A16_NVFP4", + "group_size": block_size_value, + } elif v == "nvfp4_awq": layer_config = { "quant_algo": "NVFP4_AWQ", @@ -1001,6 +1012,7 @@ def to_quantized_weight( if quantization in [ QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ, + QUANTIZATION_W4A16_NVFP4, QUANTIZATION_W4A8_NVFP4_FP8, QUANTIZATION_NVFP4_SVDQUANT, ]: diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 73ae63a5a56..0626d0a8fd5 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -83,6 +83,7 @@ QUANTIZATION_NVFP4_SVDQUANT, QUANTIZATION_W4A8_AWQ, QUANTIZATION_W4A8_NVFP4_FP8, + QUANTIZATION_W4A16_NVFP4, ) from .model_utils import get_language_model_from_vl, is_multimodal_model from .moe_utils import _export_fused_experts @@ -521,6 +522,7 @@ def _export_quantized_weight( QUANTIZATION_NVFP4_AWQ, QUANTIZATION_NVFP4_SVDQUANT, QUANTIZATION_NVFP4, + QUANTIZATION_W4A16_NVFP4, QUANTIZATION_W4A8_AWQ, QUANTIZATION_W4A8_NVFP4_FP8, ]: @@ -550,6 +552,7 @@ def _export_quantized_weight( QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ, QUANTIZATION_NVFP4_SVDQUANT, + QUANTIZATION_W4A16_NVFP4, ]: # Transpose weight from (num_experts, input_dim, output_dim) to (num_experts, output_dim, input_dim) # for NVFP4 quantization functions that expect input_dim as the last dimension for block quantization diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index f7c227055d0..183dd4cb8bc 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -44,6 +44,7 @@ QUANTIZATION_FP8_PB_WO, QUANTIZATION_NONE, QUANTIZATION_NVFP4, + QUANTIZATION_W4A16_NVFP4, ) from .plugins.hf_checkpoint_utils import copy_hf_ckpt_remote_code, load_multimodal_components from .plugins.mcore_common import all_mcore_hf_export_mapping @@ -289,6 +290,8 @@ def save_pretrained( quantization = "FP8" elif quantization_format == QUANTIZATION_NVFP4: quantization = "NVFP4" + elif quantization_format == QUANTIZATION_W4A16_NVFP4: + quantization = "W4A16_NVFP4" # We use the last PP rank and the 1st EP rank to write the config because # medusa_heads and eagle_module only exist in the last stage. @@ -339,7 +342,7 @@ def save_pretrained( "quant_algo": quantization, "exclude_modules": combined_exclude_modules, } - if quantization == "NVFP4": # update block size + if quantization in ("NVFP4", "W4A16_NVFP4"): # update block size quantization_config["group_size"] = 16 if gathered_kv_cache_dtype is not None: @@ -439,25 +442,33 @@ def _get_state_dict(self): if hasattr(model, "output_layer") and not model.share_embeddings_and_output_weights: self.rules["output_layer"](model.output_layer) - def _get_fused_norm_weight(self, module): - """Return ``module.layer_norm_weight`` when TE fuses the norm into a linear layer. + def _get_fused_norm_weight(self, module, primary_key: str = "fused_norm"): + """Return ``(rule_key, layer_norm_weight)`` when TE fuses the norm into a linear layer. - Returns ``None`` when the ``"fused_norm"`` rule is absent or the module has no - ``layer_norm_weight`` attribute (or its value is ``None``). + Mirrors the importer-side fallback chain: prefer the per-context key + (``fused_input_layernorm`` for attention, ``fused_pre_mlp_layernorm`` for MLP) and + fall back to the legacy ``fused_norm`` rule (Nemotron-H style, one norm shared + across attention/mlp/mamba slots). Returns ``(None, None)`` when no rule is + defined or the module has no ``layer_norm_weight``. """ - if "fused_norm" not in self.rules: - return None - return getattr(module, "layer_norm_weight", None) + fused_key = primary_key if primary_key in self.rules else "fused_norm" + if fused_key not in self.rules: + return None, None + weight = getattr(module, "layer_norm_weight", None) + if weight is None: + return None, None + return fused_key, weight def _get_transformer_layer_state_dict(self, layer, layer_id): if not isinstance(layer.input_layernorm, IdentityOp): self.rules["input_layernorm"](layer.input_layernorm, layer_id) - elif ( - norm_weight := self._get_fused_norm_weight( - getattr(layer.self_attention, "linear_qkv", None) + else: + fused_key, norm_weight = self._get_fused_norm_weight( + getattr(layer.self_attention, "linear_qkv", None), + primary_key="fused_input_layernorm", ) - ) is not None: - self.rules["fused_norm"](norm_weight, layer_id) + if norm_weight is not None: + self.rules[fused_key](norm_weight, layer_id) if not isinstance(layer.self_attention, IdentityOp): if "MLASelfAttention" in str(type(layer.self_attention)): @@ -496,13 +507,13 @@ def _get_transformer_layer_state_dict(self, layer, layer_id): if not isinstance(layer.pre_mlp_layernorm, IdentityOp): self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id) - elif ( - not isinstance(layer.mlp, IdentityOp) - and "MoE" not in str(type(layer.mlp)) - and (norm_weight := self._get_fused_norm_weight(getattr(layer.mlp, "linear_fc1", None))) - is not None - ): - self.rules["fused_norm"](norm_weight, layer_id) + elif not isinstance(layer.mlp, IdentityOp) and "MoE" not in str(type(layer.mlp)): + fused_key, norm_weight = self._get_fused_norm_weight( + getattr(layer.mlp, "linear_fc1", None), + primary_key="fused_pre_mlp_layernorm", + ) + if norm_weight is not None: + self.rules[fused_key](norm_weight, layer_id) if not isinstance(layer.mlp, IdentityOp): if "MoE" in str(type(layer.mlp)): @@ -610,9 +621,12 @@ def _get_mtp_state_dict(self) -> dict[str, torch.Tensor]: def _get_mamba_layer_state_dict(self, layer, layer_id): if not isinstance(layer.norm, IdentityOp): self.rules["norm"](layer.norm, layer_id) - elif (norm_weight := self._get_fused_norm_weight(layer.mixer.in_proj)) is not None: + else: # TE spec: norm is fused into in_proj (QuantTELayerNormColumnParallelLinear). - self.rules["fused_norm"](norm_weight, layer_id) + # Mamba uses the legacy single-key `fused_norm` rule (Nemotron-H style). + fused_key, norm_weight = self._get_fused_norm_weight(layer.mixer.in_proj) + if norm_weight is not None: + self.rules[fused_key](norm_weight, layer_id) self.rules["mixer_norm"](layer.mixer.norm, layer_id) self.rules["A_log"](layer.mixer.A_log, layer_id) diff --git a/modelopt/torch/nas/plugins/megatron.py b/modelopt/torch/nas/plugins/megatron.py index 19f836f38dc..b993060c20e 100644 --- a/modelopt/torch/nas/plugins/megatron.py +++ b/modelopt/torch/nas/plugins/megatron.py @@ -79,6 +79,20 @@ except ImportError: HAS_MAMBA = False +# Newer Megatron-LM instantiates Nemotron-H et al. as plain HybridModel (MambaModel split +# out as a subclass). Register HybridModel so the dynamic-space converter sees them. +# DMRegistry._get_registered_nn_class filters by `nn_cls.forward is nn_cls_.forward` and +# returns the first match in insertion order: MambaModel is registered first, so +# MambaModel instances dispatch to MambaModel whether or not MambaModel overrides forward. +try: + from megatron.core.models.hybrid.hybrid_model import HybridModel + + SUPPORTED_MODELS[HybridModel] = "megatron.core.models.hybrid.HybridModel" + + HAS_HYBRID = True +except ImportError: + HAS_HYBRID = False + __all__ = ["get_te_mamba_stack_spec"] @@ -394,6 +408,9 @@ def _setup(self, *, num_attention_heads: NumAttentionHeadsHp, hidden_size: Trace lambda mod, val: (num_attention_heads.active + 2 * mod.config.num_query_groups) * mod.config.kv_channels, ) + # in_features must track input_size so TE's forward-time inp_shape[-1] == in_features + # assertion holds when hidden_size is pruned. + self._register_dynamic_attribute("in_features", lambda mod, val: mod.input_size) self._register_dynamic_attribute("weight", self._get_weight) # TE stores a zero-length tensor (not None) when bias=False; only register if non-empty if hasattr(self, "bias") and self.bias is not None and self.bias.numel() > 0: diff --git a/modelopt/torch/opt/config.py b/modelopt/torch/opt/config.py index 62f7b7e16a2..fce2eb36f6b 100644 --- a/modelopt/torch/opt/config.py +++ b/modelopt/torch/opt/config.py @@ -17,7 +17,7 @@ import fnmatch import json -from collections.abc import Callable, ItemsView, Iterator, KeysView, ValuesView +from collections.abc import Callable, ItemsView, Iterator, KeysView, MutableMapping, ValuesView from typing import Any, TypeAlias import torch @@ -57,11 +57,18 @@ def ModeloptField(default: Any = PydanticUndefined, **kwargs): # noqa: N802 # TODO: expand config classes to searcher -class ModeloptBaseConfig(BaseModel): +class ModeloptBaseConfig(BaseModel, MutableMapping): """Our config base class for mode configuration. The base class extends the capabilities of pydantic's BaseModel to provide additional methods and properties for easier access and manipulation of the configuration. + + Inherits from :class:`collections.abc.MutableMapping` so instances satisfy + ``isinstance(cfg, Mapping)`` / ``isinstance(cfg, MutableMapping)`` checks and pick up the + mixin methods (``pop``, ``popitem``, ``setdefault``, ``clear``). Schema fields are fixed, + so ``__delitem__`` raises :class:`TypeError`; the inherited ``pop`` / ``clear`` / + ``popitem`` therefore also raise on any existing key, while ``pop(key, default)`` for a + missing key still returns the default normally. """ model_config = PyDanticConfigDict(extra="forbid", validate_assignment=True) @@ -110,18 +117,49 @@ def __contains__(self, key: str) -> bool: return False def __getitem__(self, key: str) -> Any: - """Get the value for the given key (can be name or alias of field).""" - return getattr(self, self.get_field_name_from_key(key)) + """Get the value for the given key (can be name or alias of field). + + Raises :class:`KeyError` for missing keys so the class behaves like a regular + :class:`Mapping` — required for the inherited ``MutableMapping`` mixin methods + (``pop``, ``setdefault``, ...) to dispatch correctly. + """ + try: + return getattr(self, self.get_field_name_from_key(key)) + except AttributeError: + raise KeyError(key) from None def __setitem__(self, key: str, value: Any) -> None: - """Set the value for the given key (can be name or alias of field).""" - setattr(self, self.get_field_name_from_key(key), value) + """Set the value for the given key (can be name or alias of field). + + Raises :class:`KeyError` (not :class:`AttributeError`) for unknown keys so the + class matches the :class:`MutableMapping` protocol — both for direct + ``cfg["unknown"] = value`` writes and for inherited mixin helpers like + ``setdefault`` that write through ``__setitem__``. + """ + try: + setattr(self, self.get_field_name_from_key(key), value) + except AttributeError: + raise KeyError(key) from None + + def __delitem__(self, key: str) -> None: + """Reject key deletion. + + ``ModeloptBaseConfig`` exposes a fixed pydantic schema, so removing a key is + ill-defined: schema fields can't disappear, and silently resetting them to their + defaults would surprise callers. Raise ``TypeError`` instead. Defined so the + class fully satisfies the ``MutableMapping`` protocol (``__delitem__`` is + required), without committing to actual deletion semantics. + """ + raise TypeError( + f"{type(self).__name__} does not support key deletion; schema fields are " + f"fixed (attempted to delete {key!r})." + ) def get(self, key: str, default: Any = None) -> Any: """Get the value for the given key (can be name or alias) or default if not found.""" try: return self[key] - except AttributeError: + except KeyError: return default def __len__(self) -> int: diff --git a/modelopt/torch/opt/config_loader.py b/modelopt/torch/opt/config_loader.py index 43231c90995..76ed2bb6503 100644 --- a/modelopt/torch/opt/config_loader.py +++ b/modelopt/torch/opt/config_loader.py @@ -33,12 +33,14 @@ import re import sys from pathlib import Path -from typing import Any, Union, get_args, get_origin, get_type_hints +from typing import Any, TypeVar, Union, get_args, get_origin, get_type_hints, overload import yaml from pydantic import TypeAdapter from typing_extensions import NotRequired, Required, is_typeddict +from modelopt.torch.opt.config import ModeloptBaseConfig + @dataclass class _ListSnippet: @@ -592,29 +594,74 @@ def _find_import_marker(obj: Any, context: str = "root") -> tuple[Any, str] | No return None +_SchemaT = TypeVar("_SchemaT", bound=ModeloptBaseConfig) + + +@overload +def load_config( + config_path: str | Path | Traversable, + *, + schema_type: type[_SchemaT], +) -> _SchemaT: ... + + +@overload +def load_config( + config_path: str | Path | Traversable, + *, + schema_type: type[list[_SchemaT]], +) -> list[_SchemaT]: ... + + +@overload +def load_config( + config_path: str | Path | Traversable, + *, + schema_type: None = None, +) -> Any: ... + + def load_config( config_path: str | Path | Traversable, *, schema_type: Any | None = None, -) -> dict[str, Any] | list[Any]: +) -> Any: """Load a YAML config and resolve all ``$import`` references. This is the primary config loading entry point. It loads the YAML file, - resolves any ``imports`` / ``$import`` directives, and returns the final - config dict or list. - - ``schema_type`` supplies a typing context for import resolution when the - file itself has no ``modelopt-schema`` comment. It is intentionally not a - request to validate the top-level file. Top-level files are validated only - when they declare ``modelopt-schema``; imported snippets are stricter and - must always declare ``modelopt-schema``. + resolves any ``imports`` / ``$import`` directives, and returns either a + validated instance of the schema (when one is known) or the raw resolved + payload. + + The effective schema is selected as follows: + + 1. If ``schema_type`` is provided, it is used. + 2. Otherwise, the schema declared by the file's ``# modelopt-schema:`` + comment (if any) is used. + + When an effective schema is selected, the resolved payload is validated + and returned as an instance of that schema — e.g., a Pydantic model + instance for ``BaseModel`` schemas, or a validated dict / list for + ``TypedDict`` / ``list[TypedDict]`` schemas. If neither source supplies a + schema, the raw resolved dict or list is returned unchanged. + + Imported snippets are stricter and must always declare ``modelopt-schema``; + they are validated during import resolution regardless of the top-level + selection above. """ raw = _load_raw_config_with_schema(config_path) data = raw.data declared_schema_type = _schema_type(raw.schema) if raw.schema else None - resolver_schema_type = declared_schema_type or schema_type + effective_schema_type = schema_type if schema_type is not None else declared_schema_type if isinstance(data, (_ListSnippet, dict)): - data = _resolve_imports(data, schema_type=resolver_schema_type) - _validate_modelopt_schema(raw.schema, data, raw.path, schema_type=declared_schema_type) - return data + data = _resolve_imports(data, schema_type=effective_schema_type) + if effective_schema_type is None: + return data + try: + return TypeAdapter(effective_schema_type).validate_python(data) + except Exception as exc: + raise ValueError( + f"Config file {raw.path} does not match modelopt-schema " + f"{_schema_label(effective_schema_type, raw.schema)!r}: {exc}" + ) from exc diff --git a/modelopt/torch/prune/plugins/mcore_minitron.py b/modelopt/torch/prune/plugins/mcore_minitron.py index 6a4e3828750..6fa90f96af4 100644 --- a/modelopt/torch/prune/plugins/mcore_minitron.py +++ b/modelopt/torch/prune/plugins/mcore_minitron.py @@ -37,7 +37,6 @@ import torch.nn as nn import torch.nn.functional as F from megatron.core.extensions.transformer_engine import TELayerNormColumnParallelLinear -from megatron.core.models.mamba.mamba_model import MambaModel from megatron.core.parallel_state import ( get_pipeline_model_parallel_group, get_pipeline_model_parallel_rank, @@ -56,6 +55,7 @@ from modelopt.torch.nas.conversion import NASModeRegistry from modelopt.torch.nas.plugins.megatron import ( + HAS_HYBRID, HAS_MAMBA, SUPPORTED_MODELS, _DynamicMambaLayer, @@ -173,6 +173,20 @@ def drop_mcore_language_model_layers(model: nn.Module, *, layers_to_drop: list[i model.config.num_layers = new_num_layers +def _get_hybrid_pattern_key(model: nn.Module) -> str | None: + """Return the attribute name carrying the hybrid block pattern for hybrid models, else None. + + Handles both ``MambaModel`` (which still uses ``hybrid_override_pattern``) and plain + ``HybridModel`` (the parent class introduced in modern Megatron-LM, which carries + ``hybrid_layer_pattern``). Detecting by attribute presence avoids fragile isinstance + checks against a class hierarchy that may shift across MCore versions. + """ + for attr in ("hybrid_override_pattern", "hybrid_layer_pattern"): + if getattr(model, attr, None): + return attr + return None + + def _rprint(*renderables: Any) -> None: """Render rich renderables and print on rank 0 only.""" buf = io.StringIO() @@ -366,14 +380,9 @@ def run_search(self) -> None: # Prune homogeneously self._prune(export_config, prune_depth=True) - # TODO: Rename to hybrid_layer_pattern after MCore 0.17 and nemo:26.04 is released (for M-LM PR #3377) - # Update hybrid_override_pattern if pruning is done on a hybrid model - if isinstance(self.model, MambaModel): - hybrid_key = ( - "hybrid_override_pattern" - if hasattr(self.model, "hybrid_override_pattern") - else "hybrid_layer_pattern" - ) + # Update the hybrid block-type pattern if pruning a hybrid model. + hybrid_key = _get_hybrid_pattern_key(self.model) + if hybrid_key is not None: print_rank_0(f"Original {hybrid_key}: {getattr(self.model, hybrid_key)}") new_num_layers = self.model.config.num_layers assert self.sorted_layers is not None @@ -683,14 +692,9 @@ def _compute_candidate_metrics(self, ss_config: dict, max_num_layers: int) -> di model = self.model active_metric_keys = self.constraints.keys() & _METRIC_CONSTRAINTS - # Get hybrid layer pattern for MambaModel (None for pure GPT) hybrid_layer_pattern: str | None = None - if isinstance(model, MambaModel): - hybrid_key = ( - "hybrid_override_pattern" - if hasattr(self.model, "hybrid_override_pattern") - else "hybrid_layer_pattern" - ) + hybrid_key = _get_hybrid_pattern_key(model) + if hybrid_key is not None: hybrid_layer_pattern = getattr(model, hybrid_key) # If depth pruning on a hybrid model, filter the pattern to only the kept layers. @@ -732,6 +736,14 @@ def _compute_candidate_metrics(self, ss_config: dict, max_num_layers: int) -> di return metrics +_HYBRID_DIVISORS = { + "hidden_size_divisor": 256, + "ffn_hidden_size_divisor": 512, + "mamba_head_dim_divisor": 8, + "num_moe_experts_divisor": 8, + "num_layers_divisor": 2, +} + MCoreMinitronConfig: type[ModeloptBaseConfig] = create_model( "MCoreMinitronConfig", **get_kwargs_for_create_model_with_rules( @@ -743,19 +755,8 @@ def _compute_candidate_metrics(self, ss_config: dict, max_num_layers: int) -> di "num_moe_experts_divisor": 8, "num_layers_divisor": 2, }, - **( - { - "megatron.core.models.mamba.MambaModel": { - "hidden_size_divisor": 256, - "ffn_hidden_size_divisor": 512, - "mamba_head_dim_divisor": 8, - "num_moe_experts_divisor": 8, - "num_layers_divisor": 2, - } - } - if HAS_MAMBA - else {} - ), + **({"megatron.core.models.mamba.MambaModel": _HYBRID_DIVISORS} if HAS_MAMBA else {}), + **({"megatron.core.models.hybrid.HybridModel": _HYBRID_DIVISORS} if HAS_HYBRID else {}), }, doc='Configuration for the ``"mcore_minitron"`` mode.', ), diff --git a/modelopt/torch/quantization/algorithms.py b/modelopt/torch/quantization/algorithms.py index 99cee86dcb2..d96ef4593d2 100644 --- a/modelopt/torch/quantization/algorithms.py +++ b/modelopt/torch/quantization/algorithms.py @@ -40,7 +40,7 @@ from . import config as mtq_config from . import model_calib -from .config import QuantizeConfig, QuantizerAttributeConfig +from .config import QuantizeConfig, QuantizerAttributeConfig, QuantizerCfgEntry from .conversion import set_quantizer_by_cfg from .nn import QuantLinearConvBase, QuantModule, SequentialQuantizer, TensorQuantizer from .utils import is_quantized_linear @@ -129,7 +129,9 @@ def __init__(self, quant_cfg: str | dict[str, Any] | None = None, name: str | No # Disable KV Cache quantization # Currently KV Cache quantization is enabled for some quantization formats and disabled for others # This breaks the monotonicity of the quantization formats in terms of weight compression Vs accuracy - self.config.quant_cfg.append({"quantizer_name": "*output_quantizer", "enable": False}) + self.config.quant_cfg.append( + QuantizerCfgEntry(quantizer_name="*output_quantizer", enable=False) + ) self.compression = estimate_quant_compression(self.config) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index c4c20139052..949b4642aff 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -152,23 +152,94 @@ import copy import warnings -from typing import Any, Literal, cast +from collections.abc import Mapping, Sequence +from typing import Any, Literal from pydantic import AliasChoices, ValidationInfo, field_validator, model_validator -from typing_extensions import Required, TypedDict from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField from modelopt.torch.opt.config_loader import load_config from modelopt.torch.utils.network import ConstructorLike -class QuantizerCfgEntry(TypedDict, total=False): +class QuantizerCfgEntry(ModeloptBaseConfig): """A single entry in a ``quant_cfg`` list.""" - quantizer_name: Required[str] # matched against quantizer module names - parent_class: str | None # optional; filters by pytorch module class name (e.g. "nn.Linear") - cfg: dict[str, Any] | list[dict[str, Any]] | None # quantizer attribute config(s) - enable: bool | None # toggles matched quantizers on/off; independent of cfg + quantizer_name: str = ModeloptField( + default=..., + title="Quantizer name pattern.", + description="Glob pattern matched against quantizer module names.", + ) + parent_class: str | None = ModeloptField( + default=None, + title="Optional parent-class filter.", + description="If provided, only quantizers whose parent module matches this PyTorch class " + "name (e.g. ``'nn.Linear'``) are affected.", + ) + cfg: "QuantizerAttributeConfig | list[QuantizerAttributeConfig] | None" = ModeloptField( + default=None, + title="Quantizer attribute config.", + description="A :class:`QuantizerAttributeConfig` (or a mapping that validates as one), " + "or a list of such for sequential quantizers. ``None`` leaves the existing attribute " + "config untouched.", + ) + enable: bool = ModeloptField( + default=True, + title="Enable the quantizer.", + description="Toggle matched quantizers on/off; independent of ``cfg``.", + ) + + @model_validator(mode="before") + @classmethod + def _normalize_cfg_shape(cls, values): + """Pre-validation shape rules for ``cfg``. + + Runs against the raw input mapping, before pydantic coerces ``cfg`` into a + :class:`QuantizerAttributeConfig` (which would fill in schema defaults and erase the + distinction between "user typed nothing" and "user typed `{}`"). Two rules: + + 1. ``enable=False`` with an empty ``cfg`` — empty dict, empty list, or list of empty + dicts — is normalized to ``cfg=None``. Downstream applies any non-``None`` ``cfg`` + as a full quantizer-attribute replacement, so without this an entry like + ``{cfg: {}, enable: False}`` would reset attributes to schema defaults and a later + re-enable would bring the quantizer back with defaults instead of its original config. + + 2. ``enable=True`` (explicit or implicit) with an empty ``cfg`` — same shapes — is + rejected. Pydantic would otherwise coerce ``{}`` into ``QuantizerAttributeConfig()`` + with all defaults, silently turning a likely typo (``cfg: {}``) into "quantize with + schema defaults." Callers who really want defaults should drop ``cfg`` entirely and + rely on ``enable=True``; an empty ``cfg`` always indicates missing input. + """ + if not isinstance(values, dict): + return values + cfg = values.get("cfg") + cfg_is_empty = (isinstance(cfg, dict) and len(cfg) == 0) or ( + isinstance(cfg, list) + and (len(cfg) == 0 or all(isinstance(item, dict) and len(item) == 0 for item in cfg)) + ) + if cfg_is_empty: + if values.get("enable") is False: + values = {**values, "cfg": None} + else: + raise ValueError( + f"QuantizerCfgEntry 'cfg' must specify at least one quantizer attribute; " + f"got an empty mapping/list for quantizer " + f"{values.get('quantizer_name')!r}. To keep existing attributes, drop " + f"'cfg' and rely on 'enable=True'; to disable, set 'enable=False'." + ) + return values + + @model_validator(mode="after") + def _validate_instruction(self): + """Reject entries that carry no instruction beyond the path selector.""" + fields_set = self.model_fields_set + if "cfg" not in fields_set and "enable" not in fields_set: + raise ValueError( + f"QuantizerCfgEntry must specify 'cfg', 'enable', or both. An entry with only " + f"'quantizer_name'={self.quantizer_name!r} has no effect (implicit enable=True " + "is not allowed; set it explicitly)." + ) + return self def find_quant_cfg_entry_by_path( @@ -197,7 +268,7 @@ def find_quant_cfg_entry_by_path( """ result = None for entry in quant_cfg_list: - if isinstance(entry, dict) and entry.get("quantizer_name") == quantizer_name: + if entry.get("quantizer_name") == quantizer_name: result = entry if result is None: raise KeyError(f"No quant_cfg entry with quantizer_name={quantizer_name!r}") @@ -930,13 +1001,28 @@ class GPTQCalibConfig(QuantizeAlgorithmConfig): QuantizeQuantCfgType = list[QuantizerCfgEntry] QuantizerCfgListConfig = QuantizeQuantCfgType +# Pre-normalization input shape: either a sequence of already-validated +# :class:`QuantizerCfgEntry` instances, or a sequence of raw mappings (any of the legacy / +# new dict forms). Splitting the union into two ``Sequence[...]`` arms — rather than +# ``Sequence[QuantizerCfgEntry | Mapping[str, Any]]`` — keeps each arm covariant in its +# element type, so callers can pass ``list[QuantizerCfgEntry]`` or ``list[dict]`` without +# tripping invariance. +RawQuantizeQuantCfgType = Sequence[QuantizerCfgEntry] | Sequence[Mapping[str, Any]] + +# Legacy flat-dict input shape (``{"*": ..., "*weight_quantizer": ...}``). Accepted by +# ``normalize_quant_cfg_list`` for backward compatibility but emits a DeprecationWarning; +# new code should use a list of :class:`QuantizerCfgEntry`-shaped entries instead. +DeprecatedQuantCfgType = Mapping[str, Any] + _QuantizeAlgoCfgType = str | dict | QuantizeAlgorithmConfig | None QuantizeAlgoCfgType = _QuantizeAlgoCfgType | list[_QuantizeAlgoCfgType] | None -def normalize_quant_cfg_list(v: dict | list) -> list[QuantizerCfgEntry]: - """Normalize a raw quant_cfg into a list of :class:`QuantizerCfgEntry` dicts. +def normalize_quant_cfg_list( + v: RawQuantizeQuantCfgType | DeprecatedQuantCfgType, +) -> list[QuantizerCfgEntry]: + """Normalize a raw quant_cfg into a list of :class:`QuantizerCfgEntry` instances. Supports the following input forms: @@ -951,35 +1037,19 @@ def normalize_quant_cfg_list(v: dict | list) -> list[QuantizerCfgEntry]: - Legacy ``nn.*``-scoped format: ``{"nn.": {"": }}`` — converted to a new-format entry with ``parent_class`` set. - **Validation** — an entry is rejected if it carries no instruction, i.e. it specifies neither - ``cfg`` nor ``enable``. Concretely, the following are invalid: - - - An empty entry ``{}``. - - An entry with only ``quantizer_name`` and no other keys — the only effect would be an - implicit ``enable=True``, which must be stated explicitly. - - An entry with ``enable=True`` (explicit or implicit) whose ``cfg`` is not a non-empty - ``dict`` or ``list`` — e.g. ``{"quantizer_name": "*", "cfg": {}}`` or - ``{"quantizer_name": "*", "cfg": 42}``. An enabled quantizer must have a valid - configuration. - - **Normalization** — after conversion and validation every entry is put into canonical form: - - - ``enable`` is set to ``True`` if not explicitly specified. - - ``cfg`` is set to ``None`` if not present in the entry. - - Every returned entry is therefore guaranteed to have the keys ``quantizer_name``, ``enable``, - and ``cfg`` (plus optionally ``parent_class``). + Each normalized dict is then constructed into a :class:`QuantizerCfgEntry`, whose own + validator enforces that every entry specifies ``cfg``, ``enable``, or both, and that any + ``cfg`` for an enabled quantizer is a non-empty dict or non-empty list of non-empty dicts. Args: v: A list of raw quant_cfg entries in any supported format, or a legacy flat dict. Returns: - A list of :class:`QuantizerCfgEntry` dicts in canonical normalized form. + A list of validated :class:`QuantizerCfgEntry` instances. Raises: - ValueError: If any entry has only ``quantizer_name`` with neither ``cfg`` nor ``enable``, - if ``enable=True`` with an empty or non-dict/list ``cfg``, or if the entry format - is not recognized. + ValueError: If any entry's shape is not recognized, or if it fails + :class:`QuantizerCfgEntry` validation (missing instruction or invalid ``cfg``). """ def _warn_legacy(): @@ -993,26 +1063,33 @@ def _warn_legacy(): ) # Legacy flat-dict format: {"*": {...}, "*weight_quantizer": {...}} → list of single-key dicts. - if isinstance(v, dict): + if isinstance(v, Mapping): _warn_legacy() v = [{k: val} for k, val in v.items()] + elif not isinstance(v, Sequence) or isinstance(v, (str, bytes)): + raise ValueError( + f"quant_cfg must be a sequence of entries (or a legacy flat mapping), got " + f"{type(v).__name__}: {v!r}." + ) - def _dict_to_entry(key: str, value) -> list[QuantizerCfgEntry]: - """Convert a single legacy key-value pair to one or more QuantizerCfgEntry dicts.""" + def _dict_to_entry(key: str, value) -> list[dict[str, Any]]: + """Convert a single legacy key-value pair to one or more entry dicts.""" # Legacy "default" key was a catch-all applied as "*" in the old conversion code. if key == "default": key = "*" if isinstance(key, str) and key.startswith("nn."): - if not isinstance(value, dict): - raise ValueError(f"For 'nn.*' scoped format, value must be a dict, got {value!r}") + if not isinstance(value, Mapping): + raise ValueError( + f"For 'nn.*' scoped format, value must be a mapping, got {value!r}" + ) # Support multi-key nn.*-scoped dicts by emitting one entry per sub-key. - entries: list[QuantizerCfgEntry] = [] + entries: list[dict[str, Any]] = [] for q_path, sub_cfg in value.items(): sub_cfg = dict(sub_cfg) enable = sub_cfg.pop("enable", None) cfg = sub_cfg or None - entry: QuantizerCfgEntry = { + entry: dict[str, Any] = { "parent_class": key, "quantizer_name": q_path, "cfg": cfg, @@ -1022,7 +1099,7 @@ def _dict_to_entry(key: str, value) -> list[QuantizerCfgEntry]: entries.append(entry) return entries else: - if isinstance(value, dict): + if isinstance(value, Mapping): cfg = {k: val for k, val in value.items() if k != "enable"} or None enable = value.get("enable") else: @@ -1036,15 +1113,21 @@ def _dict_to_entry(key: str, value) -> list[QuantizerCfgEntry]: result: list[QuantizerCfgEntry] = [] _warned_legacy = False for raw in v: - if isinstance(raw, dict) and "quantizer_name" in raw: - entries = [dict(raw)] # copy to avoid mutating caller's data - elif isinstance(raw, dict) and len(raw) == 1: + # Already-validated QuantizerCfgEntry instances (e.g. produced by load_config on a + # snippet schematized with `# modelopt-schema: QuantizerCfgEntry`, then spread into + # a quant_cfg list) are passed through unchanged. + if isinstance(raw, QuantizerCfgEntry): + result.append(raw) + continue + if isinstance(raw, Mapping) and "quantizer_name" in raw: + entries: list[dict[str, Any]] = [dict(raw)] # copy to avoid mutating caller's data + elif isinstance(raw, Mapping) and len(raw) == 1: key, val = next(iter(raw.items())) entries = [dict(e) for e in _dict_to_entry(key, val)] if not _warned_legacy: _warn_legacy() _warned_legacy = True - elif isinstance(raw, dict) and len(raw) > 1 and any(k.startswith("nn.") for k in raw): + elif isinstance(raw, Mapping) and len(raw) > 1 and any(k.startswith("nn.") for k in raw): # Legacy flat dict with nn.*-scoped keys mixed with other keys — expand all pairs. entries = [] for k, val in raw.items(): @@ -1055,42 +1138,10 @@ def _dict_to_entry(key: str, value) -> list[QuantizerCfgEntry]: else: raise ValueError(f"Invalid quant_cfg entry: {raw!r}.") - for entry in entries: - # Validate: must carry at least one instruction beyond the path selector. - if "cfg" not in entry and "enable" not in entry: - raise ValueError( - f"Invalid quant_cfg entry: {raw!r} — each entry must specify 'cfg', 'enable', " - "or both. An entry with only 'quantizer_name' has no effect (implicit " - "enable=True is not allowed; set it explicitly)." - ) - - # Validate: when cfg is present and enable=True, cfg must be a non-empty - # dict or list. An empty cfg would attempt to create a - # QuantizerAttributeConfig with no actual configuration. - cfg = entry.get("cfg") - enable = entry.get("enable", True) - if enable and cfg is not None: - if isinstance(cfg, dict): - is_invalid = len(cfg) == 0 - elif isinstance(cfg, list): - is_invalid = len(cfg) == 0 or any( - not isinstance(item, dict) or len(item) == 0 for item in cfg - ) - else: - is_invalid = True - if is_invalid: - raise ValueError( - f"Invalid quant_cfg entry: {raw!r} — 'cfg' must be a non-empty dict " - f"or a non-empty list of non-empty dicts when enabling a quantizer " - f"(got {type(cfg).__name__}: {cfg!r}). Either provide quantizer " - "attributes in 'cfg' or remove 'cfg' and set 'enable' explicitly." - ) - - # Normalize: make enable and cfg always explicit. - entry.setdefault("enable", True) - entry.setdefault("cfg", None) - - result.append(cast("QuantizerCfgEntry", entry)) + # Constructing each QuantizerCfgEntry runs its model_validator, which enforces the + # at-least-one-of('cfg', 'enable') and cfg-shape constraints. Defaults for absent + # 'cfg' / 'enable' are filled by the pydantic field defaults. + result.extend(QuantizerCfgEntry(**entry) for entry in entries) return result @@ -1112,27 +1163,18 @@ class QuantizeConfig(ModeloptBaseConfig): @field_validator("quant_cfg", mode="before") @classmethod - def normalize_quant_cfg(cls, v): - """Normalize quant_cfg entries: convert dict and tuple forms to QuantizerCfgEntry dicts.""" - if not isinstance(v, (list, dict)): - return v + def normalize_quant_cfg( + cls, v: RawQuantizeQuantCfgType | DeprecatedQuantCfgType + ) -> QuantizeQuantCfgType: + """Normalize raw quant_cfg input into a ``list[QuantizerCfgEntry]``. + + Delegates to :func:`normalize_quant_cfg_list`, which accepts every supported input + shape (new-format list, legacy single-key-dict list, legacy flat dict, and lists + containing already-validated ``QuantizerCfgEntry`` instances) and rejects anything + else with a clear ``ValueError`` before pydantic's field-type check would see it. + """ return normalize_quant_cfg_list(v) - @field_validator("quant_cfg", mode="after") - @classmethod - def validate_quant_cfg_entries(cls, v): - """Validate quantizer attribute configs to surface errors (e.g. invalid axis/block_sizes).""" - qac_fields = set(QuantizerAttributeConfig.model_fields.keys()) - for entry in v: - cfg = entry.get("cfg") - if cfg is None: - continue - cfgs = cfg if isinstance(cfg, list) else [cfg] - for c in cfgs: - if isinstance(c, dict) and qac_fields & set(c.keys()): - QuantizerAttributeConfig.model_validate(c) - return v - class CompressConfig(ModeloptBaseConfig): """Default configuration for ``compress`` mode.""" @@ -1157,15 +1199,24 @@ class _QuantizeExportConfig(ModeloptBaseConfig): """An empty config.""" -_base_disable_all: list[QuantizerCfgEntry] = [ - cast("QuantizerCfgEntry", load_config("configs/ptq/units/base_disable_all")) +# Shared snippet constants are dumped back to plain dicts before being spliced into +# the public quant config constants below. ``load_config`` returns validated +# ``QuantizerCfgEntry`` instances for schema-tagged files, but the public constants +# (``INT4_AWQ_CFG``, ``NVFP4_DEFAULT_CFG``, etc.) have always been raw dict/list trees; +# splatting schema instances into them would surprise callers that serialise the +# constants or do ``isinstance(entry, dict)`` checks. ``exclude_unset=True`` keeps the +# sparse YAML shape (only the explicitly set fields) so the dumped dicts are +# byte-identical to what authors wrote in the YAML snippets. +_base_disable_all: list[dict[str, Any]] = [ + load_config("configs/ptq/units/base_disable_all").model_dump(exclude_unset=True) ] -_default_disabled_quantizer_cfg: list[QuantizerCfgEntry] = load_config( - "configs/ptq/units/default_disabled_quantizers" -) +_default_disabled_quantizer_cfg: list[dict[str, Any]] = [ + entry.model_dump(exclude_unset=True) + for entry in load_config("configs/ptq/units/default_disabled_quantizers") +] -_mamba_moe_disabled_quantizer_cfg: list[QuantizerCfgEntry] = [ +_mamba_moe_disabled_quantizer_cfg: list[dict[str, Any]] = [ {"quantizer_name": "*fc1_latent_proj*", "enable": False}, # Skip Latent MOE {"quantizer_name": "*fc2_latent_proj*", "enable": False}, # Skip Latent MOE {"quantizer_name": "*q_proj*", "enable": False}, # Skip QKV Linear (HF naming) @@ -1212,7 +1263,9 @@ class _QuantizeExportConfig(ModeloptBaseConfig): "algorithm": "max", } -FP8_DEFAULT_CFG: dict[str, Any] = load_config("configs/ptq/presets/model/fp8") +FP8_DEFAULT_CFG: dict[str, Any] = load_config("configs/ptq/presets/model/fp8").model_dump( + exclude_unset=True +) MAMBA_MOE_FP8_AGGRESSIVE_CFG = { "quant_cfg": [ @@ -1457,7 +1510,9 @@ class _QuantizeExportConfig(ModeloptBaseConfig): # KV-cache configs are designed to be merged with a primary quantization config (e.g. # FP8_DEFAULT_CFG) that already contains _base_disable_all. They intentionally omit both # _base_disable_all and "algorithm" because these are provided by the primary config. -FP8_KV_CFG: dict[str, Any] = load_config("configs/ptq/presets/kv/fp8") +FP8_KV_CFG: dict[str, Any] = load_config("configs/ptq/presets/kv/fp8").model_dump( + exclude_unset=True +) FP8_AFFINE_KV_CFG = { "quant_cfg": [ @@ -1490,7 +1545,7 @@ def _nvfp4_selective_quant_cfg( algorithm: str | dict = "max", ) -> dict: """Build an NVFP4 config that quantizes only the specified layer patterns.""" - quant_cfg: list[QuantizerCfgEntry] = [] + quant_cfg: list[dict[str, Any]] = [] quant_cfg.extend(_base_disable_all) for pattern in layer_patterns: # Deep-copy the quantizer dict so each config constant gets its own instance. @@ -1684,6 +1739,7 @@ def _nvfp4_selective_quant_cfg( ], "algorithm": "max", } +W4A16_NVFP4_CFG = _nvfp4_selective_quant_cfg(["*"], weight_only=True) MXFP4_MLP_WEIGHT_ONLY_CFG = { "quant_cfg": [ @@ -1740,6 +1796,7 @@ def _nvfp4_selective_quant_cfg( "NVFP4_FP8_MHA_CONFIG", "NVFP4_KV_CFG", "NVFP4_KV_ROTATE_CFG", + "W4A16_NVFP4_CFG", "W4A8_NVFP4_FP8_CFG", "NVFP4_SVDQUANT_DEFAULT_CFG", "W4A8_AWQ_BETA_CFG", @@ -1757,7 +1814,7 @@ def _nvfp4_selective_quant_cfg( } -def need_calibration(config): +def need_calibration(config: QuantizeConfig | Mapping[str, Any]) -> bool: """Check if calibration is needed for the given config.""" if config["algorithm"] is not None and config["algorithm"] != "max": return True @@ -1765,8 +1822,8 @@ def need_calibration(config): def _not_dynamic(cfg): return cfg.get("enable", True) and cfg.get("type", "") != "dynamic" - quant_cfg: list = config.get("quant_cfg") or [] - quant_cfg = normalize_quant_cfg_list(quant_cfg) + raw_quant_cfg: RawQuantizeQuantCfgType | DeprecatedQuantCfgType = config.get("quant_cfg") or [] + quant_cfg: list[QuantizerCfgEntry] = normalize_quant_cfg_list(raw_quant_cfg) for entry in quant_cfg: name = entry["quantizer_name"] raw_cfg = entry.get("cfg") diff --git a/modelopt/torch/quantization/conversion.py b/modelopt/torch/quantization/conversion.py index 190138f0971..4a79d1a6507 100644 --- a/modelopt/torch/quantization/conversion.py +++ b/modelopt/torch/quantization/conversion.py @@ -31,8 +31,8 @@ from .config import ( QuantizeConfig, - QuantizeQuantCfgType, QuantizerAttributeConfig, + RawQuantizeQuantCfgType, _QuantizeExportConfig, normalize_quant_cfg_list, ) @@ -218,7 +218,7 @@ def _replace_quant_module(model: nn.Module, version=None, registry=QuantModuleRe _replace_quant_module(getattr(model, name), version=version, registry=registry) -def set_quantizer_by_cfg(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgType): +def set_quantizer_by_cfg(quant_model: nn.Module, quant_cfg: RawQuantizeQuantCfgType): """Apply a quantization config list to the quantizers in ``quant_model``. ``quant_cfg`` is an **ordered list** of :class:`QuantizerCfgEntry <.config.QuantizerCfgEntry>` @@ -480,7 +480,7 @@ def set_quantizer_attributes_partial( @contextmanager -def set_quantizer_by_cfg_context(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgType): +def set_quantizer_by_cfg_context(quant_model: nn.Module, quant_cfg: RawQuantizeQuantCfgType): """Context manager that temporarily applies a quantization config and restores the original state on exit. Calls :func:`set_quantizer_by_cfg` on entry and reverts every diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 01c6be15308..fa7ca1cedec 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -917,7 +917,7 @@ def finish_stats_collection(model: nn.Module, method: str | None = None, **kwarg cal = getattr(module, "_calibrator", None) if cal and not getattr(module, "_dynamic", False): - if method in {"entropy"}: + if method == "entropy": if cal.compute_amax(method) is not None: module.load_calib_amax("entropy", **kwargs) elif cal.compute_amax(**kwargs) is not None: diff --git a/modelopt/torch/quantization/model_quant.py b/modelopt/torch/quantization/model_quant.py index 5e65f9cc1d4..3582223c4d3 100644 --- a/modelopt/torch/quantization/model_quant.py +++ b/modelopt/torch/quantization/model_quant.py @@ -595,6 +595,7 @@ def print_quant_summary(model: nn.Module, output_dir: str | None = None): lines.append(f"{len(lines)} TensorQuantizers found in model") if output_dir: + os.makedirs(output_dir, exist_ok=True) path = os.path.join(output_dir, ".quant_summary.txt") with open(path, "w", encoding="utf-8") as f: f.write("\n".join(lines) + "\n") diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index aadb5ccfdc7..0dff5d79aed 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -1132,7 +1132,7 @@ def forward(self, inputs): return outputs - def _short_amax(self, fmt=".4f"): + def _short_amax(self, fmt=".2e"): """Short description of amax. Returns: @@ -1150,7 +1150,7 @@ def _short_amax(self, fmt=".4f"): return "meta" return self._short_tensor(self._amax, fmt) - def _short_tensor(self, tensor: torch.Tensor, fmt=".4f"): + def _short_tensor(self, tensor: torch.Tensor, fmt=".2e"): """Short description of tensor.""" if tensor.numel() == 1: return f"{tensor.item():{fmt}}" diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index f90d2862aef..bec9ab8e081 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -283,11 +283,6 @@ def _configure_attention_for_kv_cache_quant(module: Attention): def _register_extra_state_callbacks(model: torch.nn.Module): for name, module in model.named_modules(): - if name.endswith("output_layer"): - # output_layer is not quantized, - # hence we don't need to register extra state callbacks for it - continue - if type(module) in QuantModuleRegistry: # This module will be replaced as a QuantModule register_modelopt_extra_state_callbacks( @@ -377,8 +372,13 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): # output_layer.input_quantizer._amax but TP-only does not. This lead to # state_dict mismatch. if prefix.endswith("output_layer."): - # assert not any("_quantizer" in k for k in self.state_dict()), "quantized output_layer" - return super().sharded_state_dict(prefix, sharded_offsets, metadata) + try: + from megatron.training import get_args as _mlm_get_args + _untied = bool(getattr(_mlm_get_args(), "untie_embeddings_and_output_weights", False)) + except Exception: + _untied = False + if not _untied: + return super().sharded_state_dict(prefix, sharded_offsets, metadata) quantizer_state_dict = {} for k, v in self.state_dict(prefix="", keep_vars=True).items(): diff --git a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py index 15dbd8e2c1a..0f84cdea3c1 100644 --- a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py +++ b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py @@ -28,6 +28,22 @@ __all__ = ["NVFP4QTensor"] +def _cast_per_block_scale_to_fp8( + per_block_scale: torch.Tensor, + per_block_scale_max: torch.Tensor | None = None, +) -> torch.Tensor: + """Clamp to FP8 E4M3FN range [2**-9, 448] and cast — avoids underflow→0 / overflow→NaN. + + When ``per_block_scale_max`` is provided, first rescales as + ``per_block_scale.float() * 448 / per_block_scale_max`` — the static-export + path needs this because the ``[==0]=1.0`` safety net combined with a small + ``global_amax`` can drive the rescaled value above 448 (see PR #1397). + """ + if per_block_scale_max is not None: + per_block_scale = per_block_scale.float() * 448.0 / per_block_scale_max + return per_block_scale.clamp(min=2**-9, max=448.0).to(torch.float8_e4m3fn) + + class NVFP4QTensor(BaseQuantizedTensor): """Implements the INT4 quantization on tensors for more efficient storage or computation. @@ -132,13 +148,8 @@ def get_weights_scaling_factor_from_quantizer( expected_shape = (*weight.shape[:-1], num_blocks_per_row) per_block_scale = per_block_scale.view(expected_shape) - # Clamp to fp8_e4m3fn range: upper avoids NaN cast, lower avoids 0x00 underflow. if not keep_high_precision: - per_block_scale = ( - (per_block_scale * 448.0 / per_block_scale_max) - .clamp_(min=2**-9, max=448.0) - .to(torch.float8_e4m3fn) - ) + per_block_scale = _cast_per_block_scale_to_fp8(per_block_scale, per_block_scale_max) return per_block_scale, weights_scaling_factor_2 else: # Dynamic path: compute from weight tensor @@ -177,9 +188,8 @@ def get_weights_scaling_factor( ) # Set all zero values in scale to 1.0 per_block_scale[per_block_scale == 0] = 1.0 - # Clamp at fp8_e4m3fn subnormal min so tiny-amax blocks don't underflow to 0. if not keep_high_precision: - per_block_scale = per_block_scale.clamp_(min=2**-9).to(torch.float8_e4m3fn) + per_block_scale = _cast_per_block_scale_to_fp8(per_block_scale) return per_block_scale, weights_scaling_factor_2 @classmethod diff --git a/modelopt/torch/speculative/config.py b/modelopt/torch/speculative/config.py index d82d471d01a..230ed4d9134 100644 --- a/modelopt/torch/speculative/config.py +++ b/modelopt/torch/speculative/config.py @@ -15,11 +15,9 @@ """Configurations for speculative decoding modes.""" -import warnings from copy import deepcopy -from typing import Any -from pydantic import ValidationInfo, model_validator +from pydantic import model_validator from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField @@ -71,7 +69,7 @@ class DFlashConfig(ModeloptBaseConfig): default=False, description=( "Whether to use detached DFlash (offline training from pre-computed hidden states). " - "Auto-derived from data_args.offline_data_path during validation — not user-configurable." + "Derived by ModelOptDFlashRecipe from data.offline_data_path; not user-configurable." ), ) @@ -103,10 +101,12 @@ class DFlashConfig(ModeloptBaseConfig): default=True, description="Whether to report eval accuracy." ) - dflash_mask_token_id: int = ModeloptField( + dflash_mask_token_id: int | None = ModeloptField( default=None, - description="Token ID used for masked (unknown) positions. " - "Set explicitly or auto-detected from tokenizer.mask_token_id in main.py.", + description=( + "Token ID used for masked (unknown) positions. Set explicitly in the recipe YAML, " + "or left unset to fall back to ``tokenizer.mask_token_id`` at training time." + ), ) dflash_architecture_config: dict = ModeloptField( @@ -118,43 +118,6 @@ class DFlashConfig(ModeloptBaseConfig): description="Whether to use torch.compile on DFlash forward/loss methods.", ) - @model_validator(mode="before") - @classmethod - def _derive_dflash_offline(cls, data: Any, info: ValidationInfo) -> Any: - """Derive ``dflash_offline`` from ``data_args.offline_data_path``. - - This field is auto-derived, not user-configurable: when context provides - ``data_args``, the derived value overrides any user-supplied value. - """ - ctx = info.context if info.context else {} - data_args = ctx.get("data_args") - if data_args is not None and isinstance(data, dict): - data["dflash_offline"] = getattr(data_args, "offline_data_path", None) is not None - return data - - @model_validator(mode="before") - @classmethod - def _resolve_mask_token_id(cls, data: Any, info: ValidationInfo) -> Any: - """Auto-detect ``dflash_mask_token_id`` from tokenizer when provided in context.""" - if not isinstance(data, dict) or data.get("dflash_mask_token_id") is not None: - return data - ctx = info.context if info.context else {} - tokenizer = ctx.get("tokenizer") - if tokenizer is not None and getattr(tokenizer, "mask_token_id", None) is not None: - data["dflash_mask_token_id"] = tokenizer.mask_token_id - return data - - @model_validator(mode="after") - def _check_mask_token_id(self) -> "DFlashConfig": - """Validate that mask_token_id is set after all resolution attempts.""" - if self.dflash_mask_token_id is None: - raise ValueError( - "dflash_mask_token_id is required. Set it in the config YAML " - "(dflash.dflash_mask_token_id=TOKEN_ID) or ensure the tokenizer " - "has a mask_token_id attribute." - ) - return self - class MedusaConfig(ModeloptBaseConfig): """Medusa config.""" @@ -174,7 +137,11 @@ class EagleConfig(ModeloptBaseConfig): """Eagle config.""" eagle_offline: bool = ModeloptField( - default=False, description=("Whether to use detached Eagle.") + default=False, + description=( + "Whether to use detached Eagle. Derived by ModelOptEagleRecipe from " + "data.offline_data_path; not user-configurable." + ), ) eagle_hidden_state_distillation: bool = ModeloptField( @@ -292,16 +259,6 @@ class EagleConfig(ModeloptBaseConfig): ), ) - @model_validator(mode="before") - @classmethod - def _derive_eagle_offline(cls, data: Any, info: ValidationInfo) -> Any: - """Derive ``eagle_offline`` from ``data_args.offline_data_path`` when provided in context.""" - ctx = info.context if info.context else {} - data_args = ctx.get("data_args") - if data_args is not None and isinstance(data, dict): - data["eagle_offline"] = data_args.offline_data_path is not None - return data - @model_validator(mode="after") def _check_rope_scaling_consistency(self) -> "EagleConfig": if not self.eagle_export_rope_scaling: @@ -315,18 +272,3 @@ def _check_rope_scaling_consistency(self) -> "EagleConfig": f"training rope_type is 'default' (no scaling)." ) return self - - @model_validator(mode="after") - def _warn_rope_vs_training_seq_len(self, info: ValidationInfo) -> "EagleConfig": - ctx = info.context if info.context else {} - training_args = ctx.get("training_args") - if training_args is None: - return self - orig_max_pos = self.eagle_export_rope_scaling.get("original_max_position_embeddings") - if orig_max_pos is not None and orig_max_pos != training_args.training_seq_len: - warnings.warn( - f"eagle_export_rope_scaling.original_max_position_embeddings ({orig_max_pos}) " - f"differs from training_seq_len ({training_args.training_seq_len}). " - f"This may affect long-context inference quality." - ) - return self diff --git a/modelopt/torch/speculative/plugins/hf_eagle.py b/modelopt/torch/speculative/plugins/hf_eagle.py index f2040d9d960..d6e76d80c42 100644 --- a/modelopt/torch/speculative/plugins/hf_eagle.py +++ b/modelopt/torch/speculative/plugins/hf_eagle.py @@ -17,6 +17,7 @@ import contextlib import copy +import os from typing import Any import torch @@ -25,6 +26,8 @@ from transformers.models.llama.modeling_llama import LlamaDecoderLayer from transformers.utils import ModelOutput +from modelopt.torch.utils import print_rank_0 + from ...export.plugins.hf_spec_export import EagleExporter, SpeculativeDecodingExporter from ..eagle.conversion import EagleDMRegistry from ..eagle.eagle_model import EagleModel @@ -88,7 +91,7 @@ def _nvtx_range(self, name): return nvtx.range(name) except Exception as e: - print(f"Failed to create NVTX range {name}: {e}") + print_rank_0(f"Failed to create NVTX range {name}: {e}") return contextlib.nullcontext() def _find_base_model_parts(self): @@ -105,7 +108,7 @@ def _find_base_model_parts(self): try: submodule = self.get_submodule(path) assert isinstance(submodule, torch.nn.Module) - print(f"Found {name} at {path}") + print_rank_0(f"Found {name} at {path}") found_submodule = True setattr(self, name, path) break @@ -128,7 +131,7 @@ def _activate_torch_compile(self): try: setattr(self, name, torch.compile(getattr(self, name), dynamic=False, **kwargs)) except Exception: # noqa: PERF203 - print(f"Disabling torch.compile for {name} due to compilation error.") + print_rank_0(f"Disabling torch.compile for {name} due to compilation error.") def get_dummy_inputs(self) -> dict: """Construct dummy inputs for export forward pass.""" @@ -250,6 +253,29 @@ def _preservation_loss( ) return -loss.sum(dim=-1).mean() * self.eagle_base_lora_preservation_loss_weight + @staticmethod + def load_draft_vocab_cache(model, d2t_path: str | None) -> None: + """Load the draft-to-target token-id mapping; required iff the draft vocab is compressed.""" + if model.eagle_config.draft_vocab_size >= model.eagle_config.vocab_size: + return + if d2t_path is None or not os.path.isfile(d2t_path): + raise FileNotFoundError( + f"Draft vocab cache is required when draft_vocab_size " + f"({model.eagle_config.draft_vocab_size}) < vocab_size " + f"({model.eagle_config.vocab_size}); got d2t_path={d2t_path!r}. " + f"Set data.draft_vocab_cache in the recipe YAML." + ) + d2t = model.eagle_module.d2t + loaded = torch.load(d2t_path, map_location=d2t.device, weights_only=True) + if loaded.shape != d2t.shape or loaded.dtype != d2t.dtype: + raise ValueError( + f"Draft vocab cache mismatch at {d2t_path}: " + f"got shape={tuple(loaded.shape)} dtype={loaded.dtype}, " + f"expected shape={tuple(d2t.shape)} dtype={d2t.dtype}." + ) + d2t.copy_(loaded) + print_rank_0(f"Loaded draft vocab cache from {d2t_path}.") + def modify( self, config, diff --git a/modelopt/torch/speculative/plugins/hf_training_args.py b/modelopt/torch/speculative/plugins/hf_training_args.py new file mode 100644 index 00000000000..4801452d980 --- /dev/null +++ b/modelopt/torch/speculative/plugins/hf_training_args.py @@ -0,0 +1,82 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pydantic schemas for HF-trainer-based speculative-decoding experiments. + +These are the typed section models used inside speculative-decoding recipes +(:class:`modelopt.recipe.config.ModelOptEagleRecipe` / +:class:`modelopt.recipe.config.ModelOptDFlashRecipe`). They mirror the HF dataclasses used +by :mod:`examples/speculative_decoding/main.py` so that recipe YAMLs are Pydantic-validated +at load time. + +The module is pure Pydantic schema with no runtime dependencies on ``transformers``, +``torch``, or ``accelerate`` — distributed-environment resolution (``WORLD_SIZE`` lookup, +``ParallelismConfig`` construction) is the caller's responsibility, see +``init_distributed_env`` in ``examples/speculative_decoding/main.py``. +""" + +from __future__ import annotations + +from pydantic import BaseModel, ConfigDict, field_validator + + +class ModelArguments(BaseModel): + """Arguments for loading the base HF model.""" + + model_config = ConfigDict(extra="forbid", protected_namespaces=()) + + model_name_or_path: str | None = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" + use_fake_base_for_offline: bool = False + trust_remote_code: bool = False + + +class DataArguments(BaseModel): + """Arguments for the training dataset.""" + + model_config = ConfigDict(extra="forbid") + + data_path: str | None = None + offline_data_path: str | None = None + lazy_preprocess: bool = True + draft_vocab_cache: str | None = None + chat_template: str | None = None + vlm_img_dir: str | None = None + vlm_processor: str | None = None + sample_size: int = -1 + + @field_validator("sample_size") + @classmethod + def _check_sample_size(cls, v: int) -> int: + if v == 0 or v < -1: + raise ValueError("sample_size must be -1 (use all samples) or a positive integer") + return v + + +class TrainingArguments(BaseModel): + """Speculative-decoding extensions on top of ``transformers.TrainingArguments``. + + HF trainer fields (``learning_rate``, ``num_train_epochs``, ...) flow through as extras + via ``extra='allow'`` — they're re-validated later when the dict is passed to + ``HfTrainingArguments(**recipe.training.model_dump())`` in main.py. + """ + + model_config = ConfigDict(extra="allow") + + training_seq_len: int = 2048 + estimate_ar: bool = False + ar_validate_steps: int = 1000 + answer_only_loss: bool = False + cp_size: int = 1 + dp_shard_size: int | None = None diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index 129052e9ae0..80ed8f9abdd 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -32,6 +32,11 @@ if TYPE_CHECKING: from transformers import PreTrainedTokenizerBase + +def _join_messages_content(sample: dict) -> str: + return "\n".join(turn["content"] for turn in sample["messages"]) + + # Use dict to store the config for each dataset. # If we want to export more options to user like target languages, we need more standardized approach like dataclass. SUPPORTED_DATASET_CONFIG: dict[str, Any] = { @@ -61,7 +66,7 @@ "path": "nvidia/Nemotron-Post-Training-Dataset-v2", "split": ["stem", "chat", "math", "code"], }, - "preprocess": lambda sample: "\n".join(turn["content"] for turn in sample["messages"]), + "preprocess": _join_messages_content, "chat_key": "messages", }, "nemotron-post-training-dataset-v1": { @@ -69,7 +74,93 @@ "path": "nvidia/Nemotron-Post-Training-Dataset-v1", "split": ["stem", "chat", "math", "code", "tool_calling"], }, - "preprocess": lambda sample: "\n".join(turn["content"] for turn in sample["messages"]), + "preprocess": _join_messages_content, + "chat_key": "messages", + }, + "nemotron-sft-instruction-following-chat-v2": { + # Skips ``reasoning_on`` split: heterogeneous messages schema fails streaming cast. + "config": { + "path": "nvidia/Nemotron-SFT-Instruction-Following-Chat-v2", + "split": ["reasoning_off"], + }, + "preprocess": _join_messages_content, + "chat_key": "messages", + }, + "nemotron-science-v1": { + "config": { + "path": "nvidia/Nemotron-Science-v1", + "split": ["MCQ", "RQA"], + }, + "preprocess": _join_messages_content, + "chat_key": "messages", + }, + "nemotron-competitive-programming-v1": { + # Skips ``infinibyte_part0[0|1]``: heterogeneous schema fails streaming cast. + "config": { + "path": "nvidia/Nemotron-Competitive-Programming-v1", + "split": [ + "competitive_coding_cpp_part00", + "competitive_coding_cpp_part01", + "competitive_coding_python_part00", + "competitive_coding_python_part01", + ], + }, + "preprocess": _join_messages_content, + "chat_key": "messages", + }, + "nemotron-sft-agentic-v2": { + # Only ``search`` streams cleanly: ``interactive_agent`` has a heterogeneous + # tools schema (string vs list) that breaks pyarrow JSON inference, and + # ``tool_calling`` contains at least one malformed JSON row in a later shard. + "config": { + "path": "nvidia/Nemotron-SFT-Agentic-v2", + "split": ["search"], + }, + "preprocess": _join_messages_content, + "chat_key": "messages", + }, + "nemotron-math-v2": { + "config": { + "path": "nvidia/Nemotron-Math-v2", + "split": ["high_part00", "high_part01", "high_part02", "medium", "low"], + }, + "preprocess": _join_messages_content, + "chat_key": "messages", + }, + "nemotron-sft-swe-v2": { + # Skips ``openhands_swe`` split: heterogeneous schema fails streaming cast. + "config": { + "path": "nvidia/Nemotron-SFT-SWE-v2", + "split": ["agentless"], + }, + "preprocess": _join_messages_content, + "chat_key": "messages", + }, + "nemotron-sft-multilingual-v1": { + "config": { + "path": "nvidia/Nemotron-SFT-Multilingual-v1", + "split": [ + "code_de", + "code_es", + "code_fr", + "code_it", + "code_ja", + "code_zh", + "math_de", + "math_es", + "math_fr", + "math_it", + "math_ja", + "math_zh", + "stem_de", + "stem_es", + "stem_fr", + "stem_it", + "stem_ja", + "stem_zh", + ], + }, + "preprocess": _join_messages_content, "chat_key": "messages", }, "magpie": { @@ -106,6 +197,42 @@ }, } +# Named groups of registered datasets, expanded in ``get_dataset_dataloader``. +# Useful when callers want a single ``--dataset`` token that fans out to several +# entries; per-dataset ``num_samples`` is split evenly across the members. +DATASET_COMBOS: dict[str, list[str]] = { + "cnn_nemotron_v2_mix": ["cnn_dailymail", "nemotron-post-training-dataset-v2"], + "nemotron-post-training-v3": [ + "nemotron-sft-instruction-following-chat-v2", + "nemotron-science-v1", + "nemotron-competitive-programming-v1", + "nemotron-sft-agentic-v2", + "nemotron-math-v2", + "nemotron-sft-swe-v2", + "nemotron-sft-multilingual-v1", + ], +} + + +def _validate_dataset_combos() -> None: + """Validate DATASET_COMBOS at import time: fail loud on typos / collisions.""" + overlap = set(DATASET_COMBOS) & set(SUPPORTED_DATASET_CONFIG) + if overlap: + raise ValueError( + f"DATASET_COMBOS keys collide with SUPPORTED_DATASET_CONFIG: {sorted(overlap)}" + ) + for combo_name, members in DATASET_COMBOS.items(): + if not members: + raise ValueError(f"DATASET_COMBOS['{combo_name}'] must contain at least one dataset.") + unknown = [m for m in members if m not in SUPPORTED_DATASET_CONFIG] + if unknown: + raise ValueError( + f"DATASET_COMBOS['{combo_name}'] references unknown datasets: {unknown}" + ) + + +_validate_dataset_combos() + __all__ = [ "create_forward_loop", "download_hf_dataset_as_jsonl", @@ -260,6 +387,13 @@ def get_dataset_samples( Returns: Samples: The list of samples. """ + if dataset_name in DATASET_COMBOS: + raise ValueError( + f"'{dataset_name}' is a DATASET_COMBOS entry, not a single dataset. " + "Use ``get_dataset_dataloader`` to expand combos, or pass one of " + f"its members: {DATASET_COMBOS[dataset_name]}" + ) + # Local JSONL: load via HF's ``json`` builder and route through the same # auto-preprocess path as unregistered HF datasets so chat / prompt / text # columns are handled consistently with a downloaded HF dataset. Never @@ -471,6 +605,34 @@ def get_dataset_dataloader( "dataset_name and num_samples must be the same length" ) + # Reject inputs that include both a combo and one of its member datasets + # (e.g. ``["cnn_dailymail", "cnn_nemotron_v2_mix"]``), since the combo would sample the + # plain entry a second time with a smaller per-member quota. + plain_inputs = {n for n in dataset_name if n not in DATASET_COMBOS} + for ds_name in dataset_name: + if ds_name in DATASET_COMBOS: + overlap = plain_inputs & set(DATASET_COMBOS[ds_name]) + if overlap: + raise ValueError( + f"--dataset includes both combo '{ds_name}' and its " + f"member(s) {sorted(overlap)}; remove one to avoid " + "double-sampling." + ) + + expanded_names: list[str] = [] + expanded_num_samples: list[int] = [] + for ds_name, n in zip(dataset_name, num_samples): + if ds_name in DATASET_COMBOS: + members = DATASET_COMBOS[ds_name] + base, remainder = divmod(n, len(members)) + for i, member in enumerate(members): + expanded_names.append(member) + expanded_num_samples.append(base + (1 if i < remainder else 0)) + else: + expanded_names.append(ds_name) + expanded_num_samples.append(n) + dataset_name, num_samples = expanded_names, expanded_num_samples + all_samples = [] for ds_name, num_sample in zip(dataset_name, num_samples): samples = get_dataset_samples( @@ -528,7 +690,7 @@ def get_supported_datasets() -> list[str]: print("Supported datasets:", get_supported_datasets()) """ - return list(SUPPORTED_DATASET_CONFIG.keys()) + return list(SUPPORTED_DATASET_CONFIG.keys()) + list(DATASET_COMBOS.keys()) @contextmanager diff --git a/modelopt/torch/utils/logging.py b/modelopt/torch/utils/logging.py index f5aba0d1a1f..724dda14434 100644 --- a/modelopt/torch/utils/logging.py +++ b/modelopt/torch/utils/logging.py @@ -111,8 +111,14 @@ def print_rank_0(*args, **kwargs): def warn_rank_0(message, *args, **kwargs): - """Issues a warning only on the master process.""" + """Issues a warning only on the master process. + + Auto-bumps ``stacklevel`` by 1 to skip this wrapper frame, so callers can pass the + same stacklevel they would to ``warnings.warn`` directly and the warning still + points at the user's call site. + """ if dist.is_master(): + kwargs["stacklevel"] = kwargs.get("stacklevel", 1) + 1 warnings.warn(message, *args, **kwargs) diff --git a/modelopt/torch/utils/plugins/megatron_generate.py b/modelopt/torch/utils/plugins/megatron_generate.py index 5625013bb44..1c4c5d54647 100644 --- a/modelopt/torch/utils/plugins/megatron_generate.py +++ b/modelopt/torch/utils/plugins/megatron_generate.py @@ -150,7 +150,9 @@ def megatron_prefill( ) send_to_next_pipeline_rank(output.to(dtype=pp_dtype)) - logits = output[:, :seq_length, :].detach() if pp_last else None + # .contiguous() is required because the slice is a view with the padded stride; the broadcast + # below asserts contiguity when SP pads seq_length up to a multiple of TP. + logits = output[:, :seq_length, :].detach().contiguous() if pp_last else None if model.config.bf16: logits_dtype = torch.bfloat16 diff --git a/modelopt/torch/utils/plugins/megatron_mmlu.py b/modelopt/torch/utils/plugins/megatron_mmlu.py index 4a07405caff..6c70c5aee48 100644 --- a/modelopt/torch/utils/plugins/megatron_mmlu.py +++ b/modelopt/torch/utils/plugins/megatron_mmlu.py @@ -60,6 +60,7 @@ def megatron_mmlu( few_shots: int = 0, fraction: float = 0.05, batch_size: int = 1, + mmlu_dataset: str = "cais/mmlu", ) -> float: """Evaluate the model on MMLU using log-likelihood scoring over batched prefill passes. @@ -73,6 +74,8 @@ def megatron_mmlu( few_shots: The number of few-shot examples to use. fraction: The fraction of the test set to evaluate on. batch_size: Number of examples to process in one forward pass. + mmlu_dataset: HF dataset name or local MMLU dataset path passed to `datasets.load_dataset`. + Defaults to ``cais/mmlu``. """ print_rank_0( f"\nMMLU ({fraction * 100}%, {few_shots}-shot, Batch Size: {batch_size}) evaluation started...\n" @@ -104,8 +107,8 @@ def _generate_prompt(test_example, dev_examples, few_shots=0): # Load all subjects in two dataset calls instead of 2x num_subjects calls. # The "all" config includes a "subject" field for per-subject reporting. - test_dataset = load_dataset("cais/mmlu", "all", split="test") - dev_dataset = load_dataset("cais/mmlu", "all", split="dev") if few_shots > 0 else None + test_dataset = load_dataset(mmlu_dataset, "all", split="test") + dev_dataset = load_dataset(mmlu_dataset, "all", split="dev") if few_shots > 0 else None # Group dev examples by subject for few-shot prompt construction. dev_by_subject: dict = {} diff --git a/modelopt_recipes/configs/ptq/units/w4_nvfp4.yaml b/modelopt_recipes/configs/ptq/units/w4_nvfp4.yaml new file mode 100644 index 00000000000..b4676dbff34 --- /dev/null +++ b/modelopt_recipes/configs/ptq/units/w4_nvfp4.yaml @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# W4A16 NVFP4: NVFP4 E2M1 dynamic weight quantizer only; activations remain in BF16. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizerCfgListConfig +imports: + nvfp4: configs/numerics/nvfp4 +--- + - quantizer_name: '*weight_quantizer' + cfg: + $import: nvfp4 diff --git a/modelopt_recipes/general/ptq/nvfp4_weight_only-kv_fp16.yaml b/modelopt_recipes/general/ptq/nvfp4_weight_only-kv_fp16.yaml new file mode 100644 index 00000000000..416572e0f80 --- /dev/null +++ b/modelopt_recipes/general/ptq/nvfp4_weight_only-kv_fp16.yaml @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + w4a16_nvfp4: configs/ptq/units/w4_nvfp4 + +metadata: + recipe_type: ptq + description: NVFP4 W4A16 weight-only, BF16 activations, max calibration. No calibration forward pass required. +quantize: + algorithm: max + quant_cfg: + - $import: base_disable_all + - $import: w4a16_nvfp4 + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/general/speculative_decoding/dflash.yaml b/modelopt_recipes/general/speculative_decoding/dflash.yaml index 3d43e0fe1d4..a38b24d05d6 100644 --- a/modelopt_recipes/general/speculative_decoding/dflash.yaml +++ b/modelopt_recipes/general/speculative_decoding/dflash.yaml @@ -1,4 +1,9 @@ -# Base config for DFlash training. Override fields via OmegaConf dotlist on the CLI. +# Base config for DFlash training. A full modelopt recipe; override fields via +# OmegaConf dotlist on the CLI (e.g. `model.model_name_or_path=...`). + +metadata: + recipe_type: speculative_dflash + description: DFlash training recipe (model/data/training/dflash bundled). # maps to ModelArguments (main.py) model: @@ -18,7 +23,6 @@ data: # maps to TrainingArguments (main.py) training: # --- commonly modified --- - mode: dflash output_dir: num_train_epochs: 10 per_device_train_batch_size: 1 diff --git a/modelopt_recipes/general/speculative_decoding/eagle3.yaml b/modelopt_recipes/general/speculative_decoding/eagle3.yaml index a1b7ff77708..78767ad1ebb 100644 --- a/modelopt_recipes/general/speculative_decoding/eagle3.yaml +++ b/modelopt_recipes/general/speculative_decoding/eagle3.yaml @@ -1,4 +1,9 @@ -# Base config for EAGLE3 training. Override fields via OmegaConf dotlist on the CLI. +# Base config for EAGLE3 training. A full modelopt recipe; override fields via +# OmegaConf dotlist on the CLI (e.g. `model.model_name_or_path=...`). + +metadata: + recipe_type: speculative_eagle + description: EAGLE3 training recipe (model/data/training/eagle bundled). # maps to ModelArguments (main.py) model: @@ -17,7 +22,6 @@ data: # maps to TrainingArguments (main.py) training: # --- commonly modified --- - mode: eagle3 output_dir: num_train_epochs: 1 per_device_train_batch_size: 1 diff --git a/tests/gpu/torch/export/test_unified_hf_export_and_check_safetensors.py b/tests/gpu/torch/export/test_unified_hf_export_and_check_safetensors.py index 8bdf3f5e659..6e0c56bfd1d 100644 --- a/tests/gpu/torch/export/test_unified_hf_export_and_check_safetensors.py +++ b/tests/gpu/torch/export/test_unified_hf_export_and_check_safetensors.py @@ -47,6 +47,7 @@ ("w4a8_awq", "tiny_llama-w4a8-awq", True, False, True, True, False), ("int8_wo", "tiny_llama-int8-wo", False, False, False, False, False), ("nvfp4_svdquant", "tiny_llama-nvfp4-svdquant", True, False, True, True, True), + ("w4a16_nvfp4", "tiny_llama-w4a16-nvfp4", False, False, False, False, False), # MoE models (fused experts: Qwen3 MoE, GPT-OSS) ("nvfp4", "tiny_qwen3_moe-nvfp4", True, False, True, True, False), ("fp8", "tiny_gpt_oss-fp8", True, False, True, True, False), diff --git a/tests/unit/recipe/test_loader.py b/tests/unit/recipe/test_loader.py index 738dfc268ca..ce241150a3b 100644 --- a/tests/unit/recipe/test_loader.py +++ b/tests/unit/recipe/test_loader.py @@ -19,8 +19,13 @@ import pytest -from modelopt.recipe.config import ModelOptPTQRecipe, RecipeType -from modelopt.recipe.loader import load_config, load_recipe +from modelopt.recipe.config import ( + ModelOptDFlashRecipe, + ModelOptEagleRecipe, + ModelOptPTQRecipe, + RecipeType, +) +from modelopt.recipe.loader import _apply_dotlist, load_config, load_recipe # --------------------------------------------------------------------------- # Static YAML fixtures @@ -41,6 +46,10 @@ quantize: {} """ +CFG_RECIPE_MISSING_METADATA = """\ +quantize: {} +""" + CFG_RECIPE_MISSING_quantize = """\ metadata: recipe_type: ptq @@ -49,6 +58,7 @@ CFG_RECIPE_UNSUPPORTED_TYPE = """\ metadata: recipe_type: unknown_type +quantize: {} """ QUANTIZER_ATTRIBUTE_SCHEMA = ( @@ -171,18 +181,27 @@ def test_load_recipe_missing_recipe_type_raises(tmp_path): def test_load_recipe_missing_quantize_raises(tmp_path): - """load_recipe raises ValueError when quantize is absent for a PTQ recipe.""" + """A PTQ recipe missing the ``quantize`` section is rejected (no silent default).""" bad = tmp_path / "bad.yml" bad.write_text(CFG_RECIPE_MISSING_quantize) with pytest.raises(ValueError, match="quantize"): load_recipe(bad) +def test_load_recipe_missing_metadata_raises(tmp_path): + """A recipe missing the ``metadata`` section is rejected (no silent default).""" + bad = tmp_path / "bad.yml" + bad.write_text(CFG_RECIPE_MISSING_METADATA) + with pytest.raises(ValueError, match="metadata"): + load_recipe(bad) + + def test_load_recipe_unsupported_type_raises(tmp_path): """load_recipe raises ValueError for an unknown recipe_type.""" bad = tmp_path / "bad.yml" bad.write_text(CFG_RECIPE_UNSUPPORTED_TYPE) - with pytest.raises(ValueError, match="Unsupported recipe type"): + # Schema-driven validation reports the failure via the metadata schema's enum check. + with pytest.raises(ValueError, match="recipe_type"): load_recipe(bad) @@ -216,6 +235,223 @@ def test_load_recipe_dir_missing_quantize_raises(tmp_path): load_recipe(tmp_path) +# --------------------------------------------------------------------------- +# load_recipe — EAGLE speculative decoding +# --------------------------------------------------------------------------- + + +def test_load_recipe_eagle_builtin(): + """load_recipe loads the built-in EAGLE recipe and returns a ModelOptEagleRecipe.""" + recipe = load_recipe("general/speculative_decoding/eagle3") + assert recipe.recipe_type == RecipeType.SPECULATIVE_EAGLE + assert isinstance(recipe, ModelOptEagleRecipe) + assert recipe.eagle.eagle_decoder_type == "llama" + assert recipe.eagle.eagle_ttt_steps == 3 + # Full-pipeline recipe also carries typed HF trainer sections. + assert recipe.training.training_seq_len == 2048 + + +def test_load_recipe_eagle_missing_section_raises(tmp_path): + """load_recipe raises ValueError when 'eagle' is absent for a SPECULATIVE_EAGLE recipe.""" + bad = tmp_path / "bad.yml" + bad.write_text("metadata:\n recipe_type: speculative_eagle\n") + with pytest.raises(ValueError, match="eagle"): + load_recipe(bad) + + +def test_load_recipe_eagle_field_validation_raises(tmp_path): + """Invalid EAGLE field values must fail Pydantic validation at load time.""" + bad = tmp_path / "bad.yml" + bad.write_text( + "metadata:\n recipe_type: speculative_eagle\neagle:\n eagle_ttt_steps: not_an_int\n" + ) + with pytest.raises(Exception): # pydantic.ValidationError + load_recipe(bad) + + +# --------------------------------------------------------------------------- +# load_recipe — DFlash speculative decoding +# --------------------------------------------------------------------------- + + +def test_load_recipe_dflash_builtin(): + """load_recipe loads the built-in DFlash recipe and returns a ModelOptDFlashRecipe.""" + recipe = load_recipe("general/speculative_decoding/dflash") + assert recipe.recipe_type == RecipeType.SPECULATIVE_DFLASH + assert isinstance(recipe, ModelOptDFlashRecipe) + assert recipe.dflash.dflash_block_size == 8 + assert recipe.dflash.dflash_num_anchors == 512 + # Full-pipeline recipe also carries typed HF trainer sections. + assert recipe.training.training_seq_len == 4096 + + +def test_load_recipe_dflash_missing_section_raises(tmp_path): + """load_recipe raises ValueError when 'dflash' is absent for a SPECULATIVE_DFLASH recipe.""" + bad = tmp_path / "bad.yml" + bad.write_text("metadata:\n recipe_type: speculative_dflash\n") + with pytest.raises(ValueError, match="dflash"): + load_recipe(bad) + + +def test_load_recipe_eagle_with_training_sections(tmp_path): + """load_recipe populates typed HF trainer sections from all four YAML segments.""" + recipe_path = tmp_path / "eagle.yml" + recipe_path.write_text( + "metadata:\n recipe_type: speculative_eagle\n" + "model:\n model_name_or_path: TinyLlama/TinyLlama-1.1B-Chat-v1.0\n" + "data:\n data_path: train.jsonl\n" + "training:\n output_dir: ckpts/test\n" + "eagle:\n eagle_decoder_type: llama\n eagle_ttt_steps: 2\n" + ) + recipe = load_recipe(recipe_path) + assert isinstance(recipe, ModelOptEagleRecipe) + assert recipe.model.model_name_or_path == "TinyLlama/TinyLlama-1.1B-Chat-v1.0" + assert recipe.data.data_path == "train.jsonl" + # output_dir is an HF-trainer extra; flows through extras. + assert recipe.training.model_dump()["output_dir"] == "ckpts/test" + assert recipe.eagle.eagle_ttt_steps == 2 + + +def test_typed_model_section_rejects_unknown_field(tmp_path): + """model section has extra='forbid'; unknown keys raise ValidationError at load time.""" + recipe_path = tmp_path / "bad.yml" + recipe_path.write_text( + "metadata:\n recipe_type: speculative_eagle\n" + "model:\n typo_name: oops\n" + "eagle:\n eagle_decoder_type: llama\n" + ) + with pytest.raises(Exception): # pydantic.ValidationError + load_recipe(recipe_path) + + +def test_typed_training_section_accepts_hf_extras(tmp_path): + """training section has extra='allow'; HF trainer fields flow through without validation.""" + recipe_path = tmp_path / "eagle.yml" + recipe_path.write_text( + "metadata:\n recipe_type: speculative_eagle\n" + "training:\n" + " num_train_epochs: 3\n" # HF field — accepted as extra + " learning_rate: 1.0e-4\n" # HF field — accepted as extra + " training_seq_len: 4096\n" # our extension field — validated + "eagle:\n eagle_decoder_type: llama\n" + ) + recipe = load_recipe(recipe_path) + assert isinstance(recipe, ModelOptEagleRecipe) + assert recipe.training.training_seq_len == 4096 + dumped = recipe.training.model_dump() + assert dumped["num_train_epochs"] == 3 + assert dumped["learning_rate"] == 1e-4 + + +# --------------------------------------------------------------------------- +# CLI-style dotlist overrides +# --------------------------------------------------------------------------- + + +def test_apply_dotlist_flat(): + """_apply_dotlist sets a top-level key and parses the value with yaml.safe_load.""" + result = _apply_dotlist({"a": 1}, ["b=2"]) + assert result == {"a": 1, "b": 2} + + +def test_apply_dotlist_nested_overwrite(): + """_apply_dotlist overwrites a nested key without mutating input.""" + original = {"model": {"trust_remote_code": False}} + result = _apply_dotlist(original, ["model.trust_remote_code=true"]) + assert result["model"]["trust_remote_code"] is True + assert original["model"]["trust_remote_code"] is False # input untouched + + +def test_apply_dotlist_creates_missing_path(): + """_apply_dotlist creates intermediate dicts when the path doesn't exist.""" + result = _apply_dotlist({}, ["a.b.c=42"]) + assert result == {"a": {"b": {"c": 42}}} + + +def test_apply_dotlist_parses_typed_values(): + """_apply_dotlist preserves yaml.safe_load's type inference.""" + result = _apply_dotlist( + {}, + [ + "int_v=7", + "float_v=1.5", + "bool_v=true", + "null_v=null", + "list_v=[1, 2, 3]", + "str_v=hello", + ], + ) + assert result == { + "int_v": 7, + "float_v": 1.5, + "bool_v": True, + "null_v": None, + "list_v": [1, 2, 3], + "str_v": "hello", + } + + +def test_apply_dotlist_scientific_notation(): + """OmegaConf parses ``1e-4`` as float natively (unlike yaml.safe_load in YAML 1.1 mode).""" + result = _apply_dotlist({}, ["lr=5e-5", "decay=1e-10", "still_str=hello"]) + assert result["lr"] == 5e-5 and isinstance(result["lr"], float) + assert result["decay"] == 1e-10 and isinstance(result["decay"], float) + assert result["still_str"] == "hello" # non-numeric strings stay as strings + + +def test_apply_dotlist_malformed_raises(): + """_apply_dotlist rejects entries missing the '=' separator.""" + with pytest.raises(ValueError, match="missing '='"): + _apply_dotlist({}, ["foo_no_equals"]) + + +def test_load_recipe_with_overrides(tmp_path): + """load_recipe(path, overrides=...) merges dotlist entries before Pydantic validation.""" + recipe_path = tmp_path / "recipe.yml" + recipe_path.write_text( + "metadata:\n recipe_type: speculative_eagle\n" + "model:\n trust_remote_code: false\n" + "eagle:\n eagle_ttt_steps: 3\n" + ) + recipe = load_recipe( + recipe_path, + overrides=["model.trust_remote_code=true", "eagle.eagle_ttt_steps=7"], + ) + assert isinstance(recipe, ModelOptEagleRecipe) + assert recipe.model.trust_remote_code is True + assert recipe.eagle.eagle_ttt_steps == 7 + + +def test_load_recipe_overrides_rejected_for_dir(tmp_path): + """Overrides are not allowed for directory-format recipes.""" + (tmp_path / "recipe.yml").write_text("metadata:\n recipe_type: ptq\n") + (tmp_path / "quantize.yml").write_text("algorithm: max\nquant_cfg: []\n") + with pytest.raises(ValueError, match="directory-format"): + load_recipe(tmp_path, overrides=["quantize.algorithm=gptq"]) + + +def test_typed_data_sample_size_validator(tmp_path): + """DataArguments rejects sample_size=0 via field_validator.""" + recipe_path = tmp_path / "bad.yml" + recipe_path.write_text( + "metadata:\n recipe_type: speculative_eagle\n" + "data:\n sample_size: 0\n" + "eagle:\n eagle_decoder_type: llama\n" + ) + with pytest.raises(Exception, match="sample_size"): # pydantic.ValidationError + load_recipe(recipe_path) + + +def test_load_recipe_dflash_field_validation_raises(tmp_path): + """Invalid DFlash field values must fail Pydantic validation at load time.""" + bad = tmp_path / "bad.yml" + bad.write_text( + "metadata:\n recipe_type: speculative_dflash\ndflash:\n dflash_block_size: not_an_int\n" + ) + with pytest.raises(Exception): # pydantic.ValidationError + load_recipe(bad) + + # --------------------------------------------------------------------------- # YAML recipe consistency — built-in general/ptq files match config.py dicts # --------------------------------------------------------------------------- @@ -303,7 +539,7 @@ def test_import_resolves_cfg_reference(tmp_path): ) recipe = load_recipe(recipe_file) entry = recipe.quantize["quant_cfg"][0] - assert entry["cfg"] == {"num_bits": (4, 3), "axis": None} + assert entry["cfg"].model_dump(exclude_unset=True) == {"num_bits": (4, 3), "axis": None} def test_import_same_name_used_twice(tmp_path): @@ -376,7 +612,10 @@ def test_import_inline_cfg_not_affected(tmp_path): f" axis: 0\n" ) recipe = load_recipe(recipe_file) - assert recipe.quantize["quant_cfg"][1]["cfg"] == {"num_bits": 8, "axis": 0} + assert recipe.quantize["quant_cfg"][1]["cfg"].model_dump(exclude_unset=True) == { + "num_bits": 8, + "axis": 0, + } def test_import_unknown_reference_raises(tmp_path): @@ -512,7 +751,15 @@ def test_import_entry_element_schema_appends(tmp_path): f" - $import: disable_all\n" ) recipe = load_recipe(recipe_file) - assert recipe.quantize["quant_cfg"] == [{"quantizer_name": "*", "cfg": None, "enable": False}] + # Entry was loaded against the QuantizerCfgEntry pydantic schema, so it is now a + # model instance — compare via model_dump for the dict-shape check. + assert len(recipe.quantize["quant_cfg"]) == 1 + assert recipe.quantize["quant_cfg"][0].model_dump() == { + "quantizer_name": "*", + "parent_class": None, + "cfg": None, + "enable": False, + } def test_import_entry_wrong_schema_raises(tmp_path): @@ -597,7 +844,7 @@ def test_import_cfg_extend(tmp_path): ) recipe = load_recipe(recipe_file) cfg = recipe.quantize["quant_cfg"][0]["cfg"] - assert cfg == {"num_bits": (4, 3), "axis": 0} + assert cfg.model_dump(exclude_unset=True) == {"num_bits": (4, 3), "axis": 0} def test_import_cfg_inline_overrides_import(tmp_path): @@ -660,6 +907,7 @@ def test_import_in_multiple_dict_values(tmp_path): ) data = load_config(config_file) entry = data["quant_cfg"][0] + # load_config has no schema here — data is a raw dict tree, so entry["cfg"] is a dict. assert entry["cfg"] == {"num_bits": (4, 3)} assert entry["my_field"] == {"fake_quant": False} @@ -684,7 +932,7 @@ def test_import_cfg_multi_import(tmp_path): ) recipe = load_recipe(recipe_file) cfg = recipe.quantize["quant_cfg"][0]["cfg"] - assert cfg == {"num_bits": (4, 3), "axis": 0} + assert cfg.model_dump(exclude_unset=True) == {"num_bits": (4, 3), "axis": 0} def test_import_cfg_multi_import_later_overrides_earlier(tmp_path): @@ -733,7 +981,11 @@ def test_import_cfg_multi_import_with_extend(tmp_path): ) recipe = load_recipe(recipe_file) cfg = recipe.quantize["quant_cfg"][0]["cfg"] - assert cfg == {"num_bits": (4, 3), "fake_quant": False, "axis": 0} + assert cfg.model_dump(exclude_unset=True) == { + "num_bits": (4, 3), + "fake_quant": False, + "axis": 0, + } def test_import_dir_format(tmp_path): @@ -750,7 +1002,10 @@ def test_import_dir_format(tmp_path): " $import: fp8\n" ) recipe = load_recipe(tmp_path) - assert recipe.quantize["quant_cfg"][0]["cfg"] == {"num_bits": (4, 3), "axis": None} + assert recipe.quantize["quant_cfg"][0]["cfg"].model_dump(exclude_unset=True) == { + "num_bits": (4, 3), + "axis": None, + } def test_import_dir_format_metadata_imports_do_not_apply_to_quantize(tmp_path): @@ -804,7 +1059,9 @@ def test_import_multi_document_list_snippet(tmp_path): recipe = load_recipe(recipe_file) assert len(recipe.quantize["quant_cfg"]) == 1 assert recipe.quantize["quant_cfg"][0]["quantizer_name"] == "*[kv]_bmm_quantizer" - assert recipe.quantize["quant_cfg"][0]["cfg"] == {"num_bits": (4, 3)} + assert recipe.quantize["quant_cfg"][0]["cfg"].model_dump(exclude_unset=True) == { + "num_bits": (4, 3) + } def test_import_builtin_kv_fp8_snippet(): @@ -853,7 +1110,8 @@ def test_import_list_splice_outside_typed_list_raises(tmp_path): """A bare $import in an untyped list is rejected.""" _write_quantizer_cfg_list( tmp_path / "extra_tasks.yml", - "- quantizer_name: '*weight_quantizer'\n- quantizer_name: '*input_quantizer'\n", + "- quantizer_name: '*weight_quantizer'\n enable: false\n" + "- quantizer_name: '*input_quantizer'\n enable: false\n", ) config_file = tmp_path / "config.yml" config_file.write_text( @@ -915,9 +1173,16 @@ def test_import_mixed_tree(tmp_path): ) data = load_config(config_file) # Dict import inside list entry - assert data["quant_cfg"][0]["cfg"] == {"num_bits": (4, 3)} - # List splice - assert data["quant_cfg"][1] == {"quantizer_name": "*lm_head*", "enable": False} + assert data["quant_cfg"][0]["cfg"].model_dump(exclude_unset=True) == {"num_bits": (4, 3)} + # List splice — entries are normalized by QuantizeConfig.quant_cfg's validator, + # which fills in defaults for missing ``enable`` / ``cfg`` keys. Entries are now + # QuantizerCfgEntry pydantic instances, so compare via model_dump. + assert data["quant_cfg"][1].model_dump() == { + "quantizer_name": "*lm_head*", + "parent_class": None, + "enable": False, + "cfg": None, + } # --------------------------------------------------------------------------- @@ -956,7 +1221,7 @@ def test_import_recursive(tmp_path): ) recipe = load_recipe(recipe_file) cfg = recipe.quantize["quant_cfg"][0]["cfg"] - assert cfg == {"num_bits": (4, 3)} + assert cfg.model_dump(exclude_unset=True) == {"num_bits": (4, 3)} def test_import_circular_raises(tmp_path): @@ -1056,9 +1321,14 @@ def test_import_cross_file_same_name_no_conflict(tmp_path): ) recipe = load_recipe(recipe_file) # Parent's "fmt" resolves to fp8 (e4m3), not child's nvfp4. - assert recipe.quantize["quant_cfg"][0]["cfg"] == {"num_bits": (4, 3)} + assert recipe.quantize["quant_cfg"][0]["cfg"].model_dump(exclude_unset=True) == { + "num_bits": (4, 3) + } # Child's "fmt" resolves to nvfp4 (e2m1), not parent's fp8. - assert recipe.quantize["quant_cfg"][1]["cfg"] == {"num_bits": (2, 1), "axis": 0} + assert recipe.quantize["quant_cfg"][1]["cfg"].model_dump(exclude_unset=True) == { + "num_bits": (2, 1), + "axis": 0, + } # --------------------------------------------------------------------------- @@ -1089,8 +1359,10 @@ def test_builtin_config_snippets_with_modelopt_schema(config_path): assert data -def test_modelopt_schema_comment_validates_without_changing_payload(tmp_path): - """modelopt-schema validates the resolved payload but load_config still returns a plain dict.""" +def test_modelopt_schema_comment_returns_instance(tmp_path): + """A ``modelopt-schema`` comment makes load_config return an instance of that schema.""" + from modelopt.torch.quantization.config import QuantizerAttributeConfig + config_file = tmp_path / "fp8.yaml" config_file.write_text( "# modelopt-schema: modelopt.torch.quantization.config.QuantizerAttributeConfig\n" @@ -1098,7 +1370,9 @@ def test_modelopt_schema_comment_validates_without_changing_payload(tmp_path): "axis:\n" ) data = load_config(config_file) - assert data == {"num_bits": (4, 3), "axis": None} + assert isinstance(data, QuantizerAttributeConfig) + assert data.num_bits == (4, 3) + assert data.axis is None def test_modelopt_schema_comment_validation_error(tmp_path): @@ -1145,7 +1419,13 @@ def test_modelopt_schema_comment_validates_after_import_resolution(tmp_path): f" $import: fp8\n" ) data = load_config(config_file) - assert data == [{"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": (4, 3)}}] + # data is a list of QuantizerCfgEntry pydantic instances, not raw dicts. Dump with + # exclude_unset=True so the inner QuantizerAttributeConfig stays sparse (cascades). + assert len(data) == 1 + assert data[0].model_dump(exclude_unset=True) == { + "quantizer_name": "*weight_quantizer", + "cfg": {"num_bits": (4, 3)}, + } # --------------------------------------------------------------------------- @@ -1250,7 +1530,13 @@ def test_load_config_list_valued_yaml(tmp_path): data = load_config(cfg_file) assert isinstance(data, list) assert len(data) == 2 - assert data[0] == {"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 8}} + # Entries are QuantizerCfgEntry pydantic instances after schema validation; dump + # with exclude_unset=True so the inner QuantizerAttributeConfig stays in sparse + # form (pydantic cascades exclude_unset to nested models). + assert data[0].model_dump(exclude_unset=True) == { + "quantizer_name": "*weight_quantizer", + "cfg": {"num_bits": 8}, + } # --------------------------------------------------------------------------- @@ -1262,7 +1548,8 @@ def test_import_dict_value_resolves_to_list_raises(tmp_path): """$import in dict value position raises when snippet is a list.""" _write_quantizer_cfg_list( tmp_path / "entries.yml", - "- quantizer_name: '*weight_quantizer'\n- quantizer_name: '*input_quantizer'\n", + "- quantizer_name: '*weight_quantizer'\n enable: false\n" + "- quantizer_name: '*input_quantizer'\n enable: false\n", ) config_file = tmp_path / "config.yml" config_file.write_text( diff --git a/tests/unit/torch/opt/test_config.py b/tests/unit/torch/opt/test_config.py index b2ffadb1a78..e0c5993a51a 100644 --- a/tests/unit/torch/opt/test_config.py +++ b/tests/unit/torch/opt/test_config.py @@ -72,7 +72,7 @@ def _run_test(is_new_registered): assert config[lin_name] == lin_expected_value assert config[lin_alias] == lin_expected_value assert getattr(config, lin_name) == lin_expected_value - with nullcontext() if is_new_registered else pytest.raises(AttributeError): + with nullcontext() if is_new_registered else pytest.raises(KeyError): config[new_name] # get diff --git a/tests/unit/torch/quantization/plugins/test_fused_experts.py b/tests/unit/torch/quantization/plugins/test_fused_experts.py index 2f64dab6cea..833ee277211 100644 --- a/tests/unit/torch/quantization/plugins/test_fused_experts.py +++ b/tests/unit/torch/quantization/plugins/test_fused_experts.py @@ -256,27 +256,51 @@ def test_expert_index_recovery(self): # Tests for export # --------------------------------------------------------------------------- class TestExportFusedExperts: + @staticmethod + def _cleanup_registry(mod_type): + if QuantModuleRegistry.get(mod_type) is not None: + QuantModuleRegistry.unregister(mod_type) + def test_export_creates_per_expert_submodules(self): """_export_fused_experts should create per-expert submodules with standard naming.""" + import modelopt.torch.quantization as mtq from modelopt.torch.export.moe_utils import _export_fused_experts - experts = _SyntheticFusedExperts() - expert_type = type(experts) + model = _TinyMoEModel() + expert_type = type(model.moe.experts) + self._cleanup_registry(expert_type) - # Manually register and convert - if QuantModuleRegistry.get(expert_type) is None: - QuantModuleRegistry.register({expert_type: "test.SyntheticFusedExperts"})( - _QuantFusedExperts - ) - converted = QuantModuleRegistry.convert(experts) + quant_cfg = { + "quant_cfg": [ + {"quantizer_name": "*", "enable": False}, + { + "quantizer_name": "*gate_up_proj_input_quantizer", + "cfg": {"num_bits": 8, "axis": None}, + }, + { + "quantizer_name": "*down_proj_input_quantizer", + "cfg": {"num_bits": 8, "axis": None}, + }, + { + "quantizer_name": "*gate_up_proj_weight_quantizer", + "cfg": {"num_bits": 8, "axis": 0}, + }, + { + "quantizer_name": "*down_proj_weight_quantizer", + "cfg": {"num_bits": 8, "axis": 0}, + }, + ], + "algorithm": "max", + } - # Run a forward pass to calibrate (set amaxes) - seq_len = 16 - hidden_states = torch.randn(seq_len, HIDDEN_DIM) - top_k_index = torch.randint(0, NUM_EXPERTS, (seq_len, TOP_K)) - top_k_weights = torch.softmax(torch.randn(seq_len, TOP_K), dim=-1) - with torch.no_grad(): - converted(hidden_states, top_k_index, top_k_weights) + def forward_loop(m): + torch.manual_seed(0) + for _ in range(2): + x = torch.randn(1, 4, HIDDEN_DIM) + m(x) + + mtq.quantize(model, quant_cfg, forward_loop=forward_loop) + converted = model.moe.experts _export_fused_experts(converted, torch.float16) @@ -297,8 +321,7 @@ def test_export_creates_per_expert_submodules(self): assert not hasattr(converted, "down_proj") assert not hasattr(converted, "gate_up_proj_weight_quantizers") - if QuantModuleRegistry.get(expert_type) is not None: - QuantModuleRegistry.unregister(expert_type) + self._cleanup_registry(expert_type) def test_uncalibrated_expert_gate_up_share_amax(self, monkeypatch): """gate_proj and up_proj must share weight_scale_2 even when an expert diff --git a/tests/unit/torch/quantization/test_config_validation.py b/tests/unit/torch/quantization/test_config_validation.py index 84306dc5116..ce98f989f51 100644 --- a/tests/unit/torch/quantization/test_config_validation.py +++ b/tests/unit/torch/quantization/test_config_validation.py @@ -81,7 +81,7 @@ def test_new_format_passthrough(self): result = normalize_quant_cfg_list(raw) assert len(result) == 1 assert result[0]["quantizer_name"] == "*weight_quantizer" - assert result[0]["cfg"] == {"num_bits": 8, "axis": 0} + assert result[0]["cfg"].model_dump(exclude_unset=True) == {"num_bits": 8, "axis": 0} assert result[0]["enable"] is True # defaulted def test_new_format_enable_false(self): @@ -103,7 +103,7 @@ def test_legacy_single_key_dict(self): raw = [{"*weight_quantizer": {"num_bits": 8, "axis": 0}}] result = normalize_quant_cfg_list(raw) assert result[0]["quantizer_name"] == "*weight_quantizer" - assert result[0]["cfg"] == {"num_bits": 8, "axis": 0} + assert result[0]["cfg"].model_dump(exclude_unset=True) == {"num_bits": 8, "axis": 0} assert result[0]["enable"] is True # defaulted def test_legacy_single_key_dict_with_enable(self): @@ -166,57 +166,101 @@ def test_error_on_multi_key_legacy_dict(self): def test_error_on_empty_cfg_dict_implicit_enable(self): """Entry with cfg={} and implicit enable=True is rejected.""" - with pytest.raises(ValueError, match="non-empty dict"): + with pytest.raises(ValueError, match=r"at least one quantizer attribute"): normalize_quant_cfg_list([{"quantizer_name": "*weight_quantizer", "cfg": {}}]) def test_error_on_empty_cfg_dict_explicit_enable_true(self): """Entry with cfg={} and explicit enable=True is rejected.""" - with pytest.raises(ValueError, match="non-empty dict"): + with pytest.raises(ValueError, match=r"at least one quantizer attribute"): normalize_quant_cfg_list( [{"quantizer_name": "*weight_quantizer", "cfg": {}, "enable": True}] ) def test_error_on_empty_cfg_list_enable_true(self): """Entry with cfg=[] and enable=True is rejected.""" - with pytest.raises(ValueError, match="non-empty dict"): + with pytest.raises(ValueError, match=r"at least one quantizer attribute"): normalize_quant_cfg_list( [{"quantizer_name": "*weight_quantizer", "cfg": [], "enable": True}] ) def test_error_on_non_dict_non_list_cfg_enable_true(self): - """Entry with cfg of invalid type (e.g. int) and enable=True is rejected.""" - with pytest.raises(ValueError, match="non-empty dict"): + """Entry with cfg of invalid type (e.g. int) and enable=True is rejected. + + Two error paths are acceptable here, and the assertion accepts either: + pydantic's field-type check (``cfg`` must be a dict or list) fires first when + ``cfg`` is the wrong python type, while ``QuantizerCfgEntry``'s model validator + emits the "non-empty dict" message when ``cfg`` is the right type but empty. + Either way the message must implicate the ``cfg`` field, not just any + ``ValueError``. + """ + with pytest.raises(ValueError, match=r"(?s)cfg.*(non-empty|valid dictionary|valid list)"): normalize_quant_cfg_list( [{"quantizer_name": "*weight_quantizer", "cfg": 42, "enable": True}] ) def test_error_on_cfg_list_with_empty_dict_enable_true(self): """Entry with cfg=[{}] and enable=True is rejected (empty dict element).""" - with pytest.raises(ValueError, match="non-empty dict"): + with pytest.raises(ValueError, match=r"at least one quantizer attribute"): normalize_quant_cfg_list( [{"quantizer_name": "*weight_quantizer", "cfg": [{}], "enable": True}] ) def test_error_on_cfg_list_with_non_dict_element_enable_true(self): - """Entry with cfg=[42] and enable=True is rejected (non-dict element).""" - with pytest.raises(ValueError, match="non-empty dict"): + """Entry with cfg=[42] and enable=True is rejected. + + Same dual-path acceptance as :meth:`test_error_on_non_dict_non_list_cfg_enable_true`: + pydantic may report a list-element type error, or the model validator may report + "non-empty dict"; the assertion accepts either as long as the message names the + ``cfg`` field. + """ + with pytest.raises(ValueError, match=r"(?s)cfg.*(non-empty|valid dictionary|valid list)"): normalize_quant_cfg_list( [{"quantizer_name": "*weight_quantizer", "cfg": [42], "enable": True}] ) - def test_empty_cfg_dict_enable_false_accepted(self): - """Entry with cfg={} and enable=False is allowed (disable-only entry).""" + def test_empty_cfg_dict_enable_false_normalized_to_none(self): + """Entry with cfg={} and enable=False is normalised to cfg=None (disable-only). + + A non-``None`` cfg is applied as a full quantizer-attribute replacement, so an + empty cfg paired with enable=False would silently reset the quantizer's + attributes. Normalisation to ``None`` makes the entry behave like a pure + disable, preserving the existing attribute config. + """ result = normalize_quant_cfg_list( [{"quantizer_name": "*input_quantizer", "cfg": {}, "enable": False}] ) assert result[0]["enable"] is False + assert result[0]["cfg"] is None - def test_empty_cfg_list_enable_false_accepted(self): - """Entry with cfg=[] and enable=False is allowed (disable-only entry).""" + def test_empty_cfg_list_enable_false_normalized_to_none(self): + """Entry with cfg=[] and enable=False is normalised to cfg=None.""" result = normalize_quant_cfg_list( [{"quantizer_name": "*input_quantizer", "cfg": [], "enable": False}] ) assert result[0]["enable"] is False + assert result[0]["cfg"] is None + + def test_cfg_list_of_empty_dicts_enable_false_normalized_to_none(self): + """Entry with cfg=[{}] and enable=False is normalised to cfg=None.""" + result = normalize_quant_cfg_list( + [{"quantizer_name": "*input_quantizer", "cfg": [{}], "enable": False}] + ) + assert result[0]["enable"] is False + assert result[0]["cfg"] is None + + def test_nonempty_cfg_enable_false_preserved(self): + """Entry with a non-empty cfg and enable=False keeps the cfg (disable+replace).""" + result = normalize_quant_cfg_list( + [ + { + "quantizer_name": "*input_quantizer", + "cfg": {"num_bits": 4}, + "enable": False, + } + ] + ) + assert result[0]["enable"] is False + assert result[0]["cfg"].model_dump(exclude_unset=True) == {"num_bits": 4} def test_new_format_with_list_cfg(self): """cfg can be a list of dicts for SequentialQuantizer.""" @@ -231,7 +275,7 @@ def test_new_format_with_list_cfg(self): ] result = normalize_quant_cfg_list(raw) assert len(result) == 1 - assert result[0]["cfg"] == raw[0]["cfg"] + assert [c.model_dump(exclude_unset=True) for c in result[0]["cfg"]] == raw[0]["cfg"] assert result[0]["enable"] is True def test_legacy_flat_dict_conversion(self): @@ -243,7 +287,7 @@ def test_legacy_flat_dict_conversion(self): assert result[0]["enable"] is False assert result[0]["cfg"] is None assert result[1]["quantizer_name"] == "*weight_quantizer" - assert result[1]["cfg"] == {"num_bits": 8, "axis": 0} + assert result[1]["cfg"].model_dump(exclude_unset=True) == {"num_bits": 8, "axis": 0} assert result[1]["enable"] is True def test_legacy_enable_only_produces_cfg_none(self): @@ -274,7 +318,7 @@ def test_legacy_default_key_with_cfg(self): raw = [{"default": {"num_bits": 8, "axis": None}}] result = normalize_quant_cfg_list(raw) assert result[0]["quantizer_name"] == "*" - assert result[0]["cfg"] == {"num_bits": 8, "axis": None} + assert result[0]["cfg"].model_dump(exclude_unset=True) == {"num_bits": 8, "axis": None} assert result[0]["enable"] is True def test_legacy_flat_dict_with_default_key(self): @@ -309,7 +353,7 @@ def test_legacy_nn_class_with_cfg(self): assert len(result) == 1 assert result[0]["parent_class"] == "nn.Linear" assert result[0]["quantizer_name"] == "*weight_quantizer" - assert result[0]["cfg"] == {"num_bits": 4, "axis": 0} + assert result[0]["cfg"].model_dump(exclude_unset=True) == {"num_bits": 4, "axis": 0} assert result[0]["enable"] is True def test_legacy_list_valued_cfg(self): @@ -343,7 +387,7 @@ def test_finds_last_match(self): ] ) result = find_quant_cfg_entry_by_path(entries, "*weight_quantizer") - assert result["cfg"] == {"num_bits": 4} + assert result["cfg"].model_dump(exclude_unset=True) == {"num_bits": 4} def test_exact_match_only(self): """Does not do fnmatch — only exact string equality on quantizer_name.""" @@ -400,7 +444,7 @@ def test_wildcard_matches_bare_name(self): [{"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 8}}] ) matched, enable = _match_quantizer_cfg(quant_cfg, "weight_quantizer") - assert matched == {"num_bits": 8} + assert matched.model_dump(exclude_unset=True) == {"num_bits": 8} assert enable is True def test_star_matches_any_bare_name(self): @@ -420,7 +464,7 @@ def test_path_scoped_pattern_matches_matching_suffix(self): [{"quantizer_name": "*mlp*weight_quantizer", "cfg": {"num_bits": 4}}] ) matched, enable = _match_quantizer_cfg(quant_cfg, "weight_quantizer") - assert matched == {"num_bits": 4} + assert matched.model_dump(exclude_unset=True) == {"num_bits": 4} def test_path_scoped_pattern_does_not_match_different_suffix(self): """'*mlp*weight_quantizer' does NOT match bare 'input_quantizer'.""" @@ -444,7 +488,7 @@ def test_last_match_wins(self): ] ) matched, _ = _match_quantizer_cfg(quant_cfg, "weight_quantizer") - assert matched == {"num_bits": 4} + assert matched.model_dump(exclude_unset=True) == {"num_bits": 4} def test_no_match_returns_none(self): """No matching entry returns (None, None).""" diff --git a/tests/unit/torch/quantization/test_nvfp4_tensor.py b/tests/unit/torch/quantization/test_nvfp4_tensor.py new file mode 100644 index 00000000000..8523edc4052 --- /dev/null +++ b/tests/unit/torch/quantization/test_nvfp4_tensor.py @@ -0,0 +1,112 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for NVFP4QTensor per-block FP8 scale clamping (underflow + overflow).""" + +from types import SimpleNamespace + +import torch + +from modelopt.torch.quantization.qtensor.nvfp4_tensor import ( + NVFP4QTensor, + _cast_per_block_scale_to_fp8, +) + +_FP8_E4M3FN_MIN = 2**-9 # 0.001953125 — smallest positive FP8 E4M3FN subnormal +_FP8_E4M3FN_MAX = 448.0 + + +class TestNVFP4ScaleClamping: + """Per-block weight scales outside the FP8 E4M3FN range must be clamped, not turned into 0/NaN.""" + + def test_no_zero_scales_for_tiny_weights(self): + """Tiny per-block amax (< 0).all(), ( + f"Zero per-block scales found after FP8 cast: {per_block_scale_f32.tolist()}. " + "FP8 scale underflow clamping likely regressed." + ) + assert (per_block_scale_f32 >= _FP8_E4M3FN_MIN).all(), ( + "Per-block scales with zero values found after FP8 cast " + "(below the FP8 E4M3FN subnormal minimum — clamp would have prevented this)." + ) + + def test_normal_weights_unaffected_by_clamp(self): + """Weights with typical magnitudes must not be affected by the underflow clamp.""" + block_size = 16 + torch.manual_seed(42) + normal_weight = torch.randn(8, block_size) + + per_block_scale, _ = NVFP4QTensor.get_weights_scaling_factor(normal_weight, block_size) + assert (per_block_scale.float() > 0).all(), "Normal weights produced zero scales." + + def test_mixed_weight_no_zeros(self): + """Mixed-magnitude tensor (normal + tiny blocks) must have no zero scales.""" + block_size = 16 + weight = torch.cat( + [ + torch.randn(4, block_size), + torch.full((4, block_size), 1e-12), + ], + dim=0, + ) + + per_block_scale, _ = NVFP4QTensor.get_weights_scaling_factor(weight, block_size) + assert (per_block_scale.float() > 0).all(), ( + "Zero scales in mixed-magnitude tensor after FP8 cast." + ) + + def test_helper_clamps_overflow_to_max(self): + """Values above 448 must saturate to 448, not cast to NaN (fp8_e4m3fn has no Inf).""" + oversized = torch.tensor([100.0, 448.0, 1e3, 1e6]) + out = _cast_per_block_scale_to_fp8(oversized).float() + assert torch.isfinite(out).all(), f"FP8 cast produced non-finite values: {out.tolist()}" + assert (out <= _FP8_E4M3FN_MAX).all(), f"FP8 cast values exceed 448: {out.tolist()}" + + def test_helper_clamps_underflow_to_min(self): + """Values below the FP8 subnormal must clamp up, not collapse to 0.""" + tiny = torch.tensor([0.0, 1e-12, 1e-6, _FP8_E4M3FN_MIN / 2]) + out = _cast_per_block_scale_to_fp8(tiny).float() + assert (out > 0).all(), f"FP8 cast produced zero scales: {out.tolist()}" + + def test_static_path_no_nan_when_block_amax_zero(self): + """Static path: zero-amax block + small global_amax must clamp to 448, not cast to NaN.""" + block_size = 16 + # global_amax small enough that 1.0 * 448 / (global_amax/6) >> 448. + global_amax = torch.tensor(0.01) + # One block with amax=0 (triggers safety net), three normal blocks. + per_block_amax = torch.tensor([[0.0, 0.005, 0.008, 0.01]]) + weight = torch.randn(1, 4 * block_size) + q = SimpleNamespace( + global_amax=global_amax, + _amax=per_block_amax, + block_sizes={-1: block_size}, + ) + + per_block_scale, _ = NVFP4QTensor.get_weights_scaling_factor_from_quantizer(q, weight) + per_block_scale_f32 = per_block_scale.float() + assert torch.isfinite(per_block_scale_f32).all(), ( + f"NaN/Inf in exported static per-block scale: {per_block_scale_f32.tolist()}" + ) + assert (per_block_scale_f32 <= _FP8_E4M3FN_MAX).all(), ( + f"Static per-block scale exceeds FP8 max 448: {per_block_scale_f32.tolist()}" + ) diff --git a/tests/unit/torch/speculative/plugins/test_hf_dflash.py b/tests/unit/torch/speculative/plugins/test_hf_dflash.py index 560c5712125..36975e8c08f 100644 --- a/tests/unit/torch/speculative/plugins/test_hf_dflash.py +++ b/tests/unit/torch/speculative/plugins/test_hf_dflash.py @@ -113,14 +113,6 @@ def test_convert_sets_mask_token_id(self): assert hasattr(model, "mask_token_id") assert model.mask_token_id == 0 - def test_convert_missing_mask_token_id_errors(self): - """Test that missing mask_token_id raises ValueError.""" - model = get_tiny_llama(num_hidden_layers=4) - config = _get_dflash_config() - del config["dflash_mask_token_id"] - with pytest.raises(ValueError, match="dflash_mask_token_id is required"): - mtsp.convert(model, [("dflash", config)]) - class TestDFlashSaveRestore: """Test DFlash model save and restore.""" diff --git a/tests/unit/torch/speculative/plugins/test_hf_dflash_offline.py b/tests/unit/torch/speculative/plugins/test_hf_dflash_offline.py index 2dbb704344b..d4e1c97b6b4 100644 --- a/tests/unit/torch/speculative/plugins/test_hf_dflash_offline.py +++ b/tests/unit/torch/speculative/plugins/test_hf_dflash_offline.py @@ -16,12 +16,12 @@ """CPU unit tests for DFlash offline training support.""" from copy import deepcopy -from types import SimpleNamespace from _test_utils.torch.transformers_models import get_tiny_llama import modelopt.torch.speculative as mtsp -from modelopt.torch.speculative.config import DFLASH_DEFAULT_CFG, DFlashConfig +from modelopt.recipe.config import ModelOptDFlashRecipe +from modelopt.torch.speculative.config import DFLASH_DEFAULT_CFG NUM_BASE_LAYERS = 4 NUM_DRAFT_LAYERS = 2 @@ -73,22 +73,14 @@ def test_convert_offline_target_layer_ids_from_orig(): assert all(0 <= lid < num_orig for lid in model.target_layer_ids) -def test_dflash_config_derives_offline_from_data_args(): - """DFlashConfig._derive_dflash_offline flips the flag when data_args.offline_data_path is set.""" - data = {"dflash_mask_token_id": 0} +def test_dflash_recipe_derives_offline_from_data(): + """ModelOptDFlashRecipe._derive_dflash_offline flips dflash_offline based on data.offline_data_path.""" + dflash_section = {"dflash_mask_token_id": 0} # offline_data_path set → offline=True - cfg = DFlashConfig.model_validate( - data, context={"data_args": SimpleNamespace(offline_data_path="/fake/path")} - ) - assert cfg.dflash_offline is True - - # offline_data_path=None → offline=False - cfg = DFlashConfig.model_validate( - data, context={"data_args": SimpleNamespace(offline_data_path=None)} - ) - assert cfg.dflash_offline is False - - # No data_args in context → default (False) - cfg = DFlashConfig.model_validate(data) - assert cfg.dflash_offline is False + recipe = ModelOptDFlashRecipe(data={"offline_data_path": "/fake/path"}, dflash=dflash_section) + assert recipe.dflash.dflash_offline is True + + # offline_data_path absent → offline=False + recipe = ModelOptDFlashRecipe(dflash=dflash_section) + assert recipe.dflash.dflash_offline is False diff --git a/tests/unit/torch/speculative/test_eagle_config.py b/tests/unit/torch/speculative/test_eagle_config.py index c27074f3e37..2a9fab7a638 100644 --- a/tests/unit/torch/speculative/test_eagle_config.py +++ b/tests/unit/torch/speculative/test_eagle_config.py @@ -15,12 +15,12 @@ """Tests for EagleConfig model validators.""" -import types import warnings import pytest from pydantic import ValidationError +from modelopt.recipe.config import ModelOptEagleRecipe from modelopt.torch.speculative.config import EagleConfig # --- rope scaling consistency validator tests --- @@ -73,57 +73,44 @@ def test_rope_consistency_ok_empty_export_rope(): EagleConfig.model_validate(cfg) -# --- rope vs training_seq_len warning tests --- +# --- rope vs training_seq_len warning tests (on ModelOptEagleRecipe, where the validator lives) --- -def _make_training_args(training_seq_len: int): - return types.SimpleNamespace(training_seq_len=training_seq_len) +_RopeMismatchMsg = "differs from training" -def test_warn_rope_mismatch(): - """Warning should fire when original_max_position_embeddings != training_seq_len.""" - cfg = { - "eagle_export_rope_scaling": { - "rope_type": "yarn", - "factor": 32.0, - "original_max_position_embeddings": 2048, - }, +def _yarn_rope(orig_max_pos: int) -> dict: + return { + "rope_type": "yarn", + "factor": 32.0, + "original_max_position_embeddings": orig_max_pos, } - with pytest.warns(UserWarning, match="differs from training_seq_len"): - EagleConfig.model_validate(cfg, context={"training_args": _make_training_args(4096)}) -def test_no_warn_rope_match(): - """No warning when original_max_position_embeddings == training_seq_len.""" - cfg = { - "eagle_export_rope_scaling": { - "rope_type": "yarn", - "factor": 32.0, - "original_max_position_embeddings": 2048, - }, - } - with warnings.catch_warnings(): - warnings.simplefilter("error") - EagleConfig.model_validate(cfg, context={"training_args": _make_training_args(2048)}) +def test_warn_rope_mismatch(): + """Warning fires when original_max_position_embeddings != training.training_seq_len.""" + with pytest.warns(UserWarning, match=_RopeMismatchMsg): + ModelOptEagleRecipe( + eagle={"eagle_export_rope_scaling": _yarn_rope(2048)}, + training={"training_seq_len": 4096}, + ) -def test_no_warn_without_context(): - """No warning when context is not provided (e.g. inside convert chain).""" +def test_no_warn_rope_match(): + """No warning when original_max_position_embeddings == training.training_seq_len.""" with warnings.catch_warnings(): - warnings.simplefilter("error") - EagleConfig.model_validate({}) + warnings.simplefilter("error", UserWarning) + ModelOptEagleRecipe( + eagle={"eagle_export_rope_scaling": _yarn_rope(2048)}, + training={"training_seq_len": 2048}, + ) def test_no_warn_missing_orig_max_pos(): """No warning when original_max_position_embeddings is absent from rope scaling config.""" - cfg = {"eagle_export_rope_scaling": {}} - with warnings.catch_warnings(): - warnings.simplefilter("error") - EagleConfig.model_validate(cfg, context={"training_args": _make_training_args(4096)}) - - -def test_no_warn_empty_context(): - """No warning when context dict has no training_args key.""" with warnings.catch_warnings(): - warnings.simplefilter("error") - EagleConfig.model_validate({}, context={}) + warnings.simplefilter("error", UserWarning) + ModelOptEagleRecipe( + eagle={"eagle_export_rope_scaling": {}}, + training={"training_seq_len": 4096}, + ) diff --git a/tests/unit/torch/utils/test_dataset_utils.py b/tests/unit/torch/utils/test_dataset_utils.py index f89663d89b5..812d2cd9c3b 100644 --- a/tests/unit/torch/utils/test_dataset_utils.py +++ b/tests/unit/torch/utils/test_dataset_utils.py @@ -17,6 +17,7 @@ import pytest import torch +from huggingface_hub import get_token from torch.utils.data import DataLoader from modelopt.torch.utils.dataset_utils import ( @@ -609,6 +610,89 @@ def test_length_mismatch_raises(self, tmp_path, pad_tokenizer): ) +class TestDatasetCombosExpansion: + """Combo names in ``--dataset`` fan out to their registered members. + + The combo branch in ``get_dataset_dataloader`` is exercised by mocking + ``get_dataset_samples`` so we can assert the post-expansion (name, count) + sequence without hitting HF. + """ + + @staticmethod + def _record_calls(monkeypatch): + calls: list[tuple[str, int]] = [] + + def _fake(name, num_sample, **_kwargs): + calls.append((name, num_sample)) + return [f"{name}-{i}" for i in range(num_sample)] + + from modelopt.torch.utils import dataset_utils + + monkeypatch.setattr(dataset_utils, "get_dataset_samples", _fake) + return calls + + def test_combo_expands_evenly(self, monkeypatch, pad_tokenizer): + from modelopt.torch.utils.dataset_utils import DATASET_COMBOS + + calls = self._record_calls(monkeypatch) + get_dataset_dataloader( + dataset_name="cnn_nemotron_v2_mix", + tokenizer=pad_tokenizer, + num_samples=8, + batch_size=1, + max_sample_length=16, + ) + members = DATASET_COMBOS["cnn_nemotron_v2_mix"] + assert calls == [(members[0], 4), (members[1], 4)] + + def test_combo_remainder_distributed_to_earlier_members(self, monkeypatch, pad_tokenizer): + from modelopt.torch.utils.dataset_utils import DATASET_COMBOS + + calls = self._record_calls(monkeypatch) + get_dataset_dataloader( + dataset_name="nemotron-post-training-v3", + tokenizer=pad_tokenizer, + num_samples=10, + batch_size=1, + max_sample_length=16, + ) + members = DATASET_COMBOS["nemotron-post-training-v3"] + # 10 / 7 = 1 base, remainder 3 -> first 3 get +1 + expected_counts = [2, 2, 2, 1, 1, 1, 1] + assert calls == list(zip(members, expected_counts)) + + def test_plain_and_combo_compose(self, monkeypatch, pad_tokenizer): + from modelopt.torch.utils.dataset_utils import DATASET_COMBOS + + calls = self._record_calls(monkeypatch) + get_dataset_dataloader( + dataset_name=["cnn_dailymail", "nemotron-post-training-v3"], + tokenizer=pad_tokenizer, + num_samples=[3, 7], + batch_size=1, + max_sample_length=16, + ) + members = DATASET_COMBOS["nemotron-post-training-v3"] + assert calls == [("cnn_dailymail", 3)] + [(m, 1) for m in members] + + def test_combo_overlapping_with_member_raises(self, monkeypatch, pad_tokenizer): + self._record_calls(monkeypatch) + with pytest.raises(ValueError, match="combo 'cnn_nemotron_v2_mix'"): + get_dataset_dataloader( + dataset_name=["cnn_dailymail", "cnn_nemotron_v2_mix"], + tokenizer=pad_tokenizer, + num_samples=[2, 4], + batch_size=1, + max_sample_length=16, + ) + + def test_get_dataset_samples_rejects_combo_name(self): + from modelopt.torch.utils.dataset_utils import get_dataset_samples + + with pytest.raises(ValueError, match="DATASET_COMBOS"): + get_dataset_samples("cnn_nemotron_v2_mix", num_samples=1) + + # --------------------------------------------------------------------------- # Live HF dataset round-trips. ``hf-internal-testing/dataset_with_data_files`` # is a 10-row x {train,test} fixture maintained by HF for their own CI — tiny @@ -689,3 +773,57 @@ def test_dataloader_mixing_hf_and_local_jsonl(self, tmp_path, pad_tokenizer): ) batches = list(loader) assert sum(b["input_ids"].shape[0] for b in batches) == 5 + + +_NEW_NEMOTRON_KEYS = [ + "nemotron-sft-instruction-following-chat-v2", + "nemotron-science-v1", + "nemotron-competitive-programming-v1", + "nemotron-sft-agentic-v2", + "nemotron-math-v2", + "nemotron-sft-swe-v2", + "nemotron-sft-multilingual-v1", +] + + +@pytest.mark.parametrize("dataset_key", _NEW_NEMOTRON_KEYS) +def test_new_nemotron_registry_shape(dataset_key): + """Always-on shape check on the 7 newly registered nvidia/Nemotron-* entries. + + Complements the gated smoke test below — catches typos in dataset paths or + split names even when the runner has no HF credentials. + """ + from modelopt.torch.utils.dataset_utils import SUPPORTED_DATASET_CONFIG + + assert dataset_key in SUPPORTED_DATASET_CONFIG + entry = SUPPORTED_DATASET_CONFIG[dataset_key] + config = entry["config"] + assert config["path"].startswith("nvidia/Nemotron-") + splits = config["split"] + assert isinstance(splits, list) and splits + assert all(isinstance(s, str) and s for s in splits) + assert len(set(splits)) == len(splits) + assert callable(entry["preprocess"]) + assert entry["chat_key"] == "messages" + + +@pytest.mark.integration +@pytest.mark.parametrize("dataset_key", _NEW_NEMOTRON_KEYS) +def test_get_dataset_samples_new_nemotron(dataset_key): + """Smoke-test the 7 newly registered nvidia/Nemotron-* calibration datasets. + + Skipped when no HF token is available because these datasets live behind the HF Hub. + ``huggingface_hub.get_token()`` covers both the ``HF_TOKEN`` env var and tokens + cached by ``hf auth login``. + """ + pytest.importorskip("datasets") + if not get_token(): + pytest.skip( + "No HF token (env HF_TOKEN or `hf auth login`); skipping gated Nemotron smoke test" + ) + + samples = get_dataset_samples(dataset_key, num_samples=2) + + assert isinstance(samples, list) + assert len(samples) == 2 + assert all(isinstance(s, str) and len(s) > 0 for s in samples) diff --git a/tools/launcher/examples/Qwen/Qwen3-8B/megatron_lm_ptq.yaml b/tools/launcher/examples/Qwen/Qwen3-8B/megatron_lm_ptq.yaml index ff55a92e39f..6ae64fc1ff4 100644 --- a/tools/launcher/examples/Qwen/Qwen3-8B/megatron_lm_ptq.yaml +++ b/tools/launcher/examples/Qwen/Qwen3-8B/megatron_lm_ptq.yaml @@ -28,7 +28,7 @@ pipeline: calib_dataset: abisee/cnn_dailymail calib_size: 32 mmlu_dataset: cais/mmlu - mmlu_lower_bound: 0.68 + mmlu_lower_bound: 0.75 hf_local: /hf-local/ slurm_config: _factory_: "slurm_factory" diff --git a/uv.lock b/uv.lock index 0f6d66eb661..3f5ec268f35 100644 --- a/uv.lock +++ b/uv.lock @@ -49,7 +49,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "huggingface-hub" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "numpy", version = "2.4.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "packaging" }, { name = "psutil" }, { name = "pyyaml" }, @@ -444,14 +444,14 @@ wheels = [ [[package]] name = "click" -version = "8.3.3" +version = "8.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "colorama", marker = "sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/bb/63/f9e1ea081ce35720d8b92acde70daaedace594dc93b693c869e0d5910718/click-8.3.3.tar.gz", hash = "sha256:398329ad4837b2ff7cbe1dd166a4c0f8900c3ca3a218de04466f38f6497f18a2", size = 328061, upload-time = "2026-04-22T15:11:27.506Z" } +sdist = { url = "https://files.pythonhosted.org/packages/23/e4/796662cd90cf80e3a363c99db2b88e0e394b988a575f60a17e16440cd011/click-8.4.0.tar.gz", hash = "sha256:638f1338fe1235c8f4e008e4a8a254fb5c5fbdcbb40ece3c9142ebb78e792973", size = 350843, upload-time = "2026-05-17T00:47:58.425Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ae/44/c1221527f6a71a01ec6fbad7fa78f1d50dfa02217385cf0fa3eec7087d59/click-8.3.3-py3-none-any.whl", hash = "sha256:a2bf429bb3033c89fa4936ffb35d5cb471e3719e1f3c8a7c3fff0b8314305613", size = 110502, upload-time = "2026-04-22T15:11:25.044Z" }, + { url = "https://files.pythonhosted.org/packages/ee/ae/8e92f8058baf87f6c7d86ee7e457668690195cc77efedb8d3797a06e3940/click-8.4.0-py3-none-any.whl", hash = "sha256:40c50b7c6c6adac2823d411041ec84f3f103f1b280d5e9ce0d7f998995832f81", size = 116147, upload-time = "2026-05-17T00:47:56.842Z" }, ] [[package]] @@ -632,7 +632,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cuda-pathfinder", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin'" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform != 'darwin'" }, - { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' and platform_machine != 'aarch64' and sys_platform != 'darwin'" }, + { name = "numpy", version = "2.4.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' and platform_machine != 'aarch64' and sys_platform != 'darwin'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/ef/8f/43961a56021be9e211d359524582b10d3e618d1e821942fc19530addd0a8/cupy_cuda12x-14.0.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:b42da54c9da0d5a7748e4120f13c47594d3e1fc2741b712591aa915517741096", size = 144959483, upload-time = "2026-02-20T10:22:13.493Z" }, @@ -664,10 +664,10 @@ dependencies = [ { name = "huggingface-hub" }, { name = "multiprocess" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "numpy", version = "2.4.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "packaging" }, { name = "pandas", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "pandas", version = "3.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "pandas", version = "3.0.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "pyarrow" }, { name = "pyyaml" }, { name = "requests" }, @@ -689,7 +689,7 @@ dependencies = [ { name = "msgpack", marker = "sys_platform != 'win32'" }, { name = "ninja", marker = "sys_platform != 'win32'" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' and sys_platform != 'win32'" }, - { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' and sys_platform != 'win32'" }, + { name = "numpy", version = "2.4.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' and sys_platform != 'win32'" }, { name = "packaging", marker = "sys_platform != 'win32'" }, { name = "psutil", marker = "sys_platform != 'win32'" }, { name = "py-cpuinfo", marker = "sys_platform != 'win32'" }, @@ -722,7 +722,7 @@ dependencies = [ { name = "huggingface-hub" }, { name = "importlib-metadata" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "numpy", version = "2.4.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "pillow" }, { name = "regex" }, { name = "requests" }, @@ -1025,7 +1025,7 @@ wheels = [ [[package]] name = "huggingface-hub" -version = "1.14.0" +version = "1.15.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, @@ -1038,9 +1038,9 @@ dependencies = [ { name = "typer" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/39/40/43109e943fd718b0ccd0cd61eb4f1c347df22bf81f5874c6f22adf44bcff/huggingface_hub-1.14.0.tar.gz", hash = "sha256:d6d2c9cd6be1d02ae9ec6672d5587d10a427f377db688e82528f426a041622c2", size = 782365, upload-time = "2026-05-06T14:14:34.278Z" } +sdist = { url = "https://files.pythonhosted.org/packages/bb/b6/e22bd20a25299c34b8c5922c1545a6320825b13906eb0f7298edfd034a0b/huggingface_hub-1.15.0.tar.gz", hash = "sha256:28abfdddda3927fd4de6a63cf26ab012498a2c24dae52baf150c5c6edf98a1d5", size = 784100, upload-time = "2026-05-15T11:42:52.149Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/89/a5/33b49ba7bea7c41bb37f74ec0f8beea0831e052330196633fe2c77516ea6/huggingface_hub-1.14.0-py3-none-any.whl", hash = "sha256:efe075535c62e130b30e836b138e13785f6f043d1f0539e0a39aa411a99e90b8", size = 661479, upload-time = "2026-05-06T14:14:32.029Z" }, + { url = "https://files.pythonhosted.org/packages/6e/11/0b64cc9024329b76d7547c19a67604a61d21d3ba678a69d1b220c29d5112/huggingface_hub-1.15.0-py3-none-any.whl", hash = "sha256:a4a59af04cbc41a3fe3fec429b171ef994ef8c971eda10136746f408dd4e3744", size = 663602, upload-time = "2026-05-15T11:42:50.487Z" }, ] [[package]] @@ -1089,11 +1089,11 @@ wheels = [ [[package]] name = "idna" -version = "3.14" +version = "3.15" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/05/b1/efac073e0c297ecf2fb33c346989a529d4e19164f1759102dee5953ee17e/idna-3.14.tar.gz", hash = "sha256:466d810d7a2cc1022bea9b037c39728d51ae7dad40d480fc9b7d7ecf98ba8ee3", size = 198272, upload-time = "2026-05-10T20:32:15.935Z" } +sdist = { url = "https://files.pythonhosted.org/packages/82/77/7b3966d0b9d1d31a36ddf1746926a11dface89a83409bf1483f0237aa758/idna-3.15.tar.gz", hash = "sha256:ca962446ea538f7092a95e057da437618e886f4d349216d2b1e294abfdb65fdc", size = 199245, upload-time = "2026-05-12T22:45:57.011Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/6c/3c/3f62dee257eb3d6b2c1ef2a09d36d9793c7111156a73b5654d2c2305e5ce/idna-3.14-py3-none-any.whl", hash = "sha256:e677eaf072e290f7b725f9acf0b3a2bd55f9fd6f7c70abe5f0e34823d0accf69", size = 72184, upload-time = "2026-05-10T20:32:14.295Z" }, + { url = "https://files.pythonhosted.org/packages/d2/23/408243171aa9aaba178d3e2559159c24c1171a641aa83b67bdd3394ead8e/idna-3.15-py3-none-any.whl", hash = "sha256:048adeaf8c2d788c40fee287673ccaa74c24ffd8dcf09ffa555a2fbb59f10ac8", size = 72340, upload-time = "2026-05-12T22:45:55.733Z" }, ] [[package]] @@ -1434,7 +1434,7 @@ version = "0.5.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "numpy", version = "2.4.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/0e/4a/c27b42ed9b1c7d13d9ba8b6905dece787d6259152f2309338aed29b2447b/ml_dtypes-0.5.4.tar.gz", hash = "sha256:8ab06a50fb9bf9666dd0fe5dfb4676fa2b0ac0f31ecff72a6c3af8e22c063453", size = 692314, upload-time = "2025-11-17T22:32:31.031Z" } wheels = [ @@ -1909,7 +1909,7 @@ wheels = [ [[package]] name = "numpy" -version = "2.4.4" +version = "2.4.5" source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version >= '3.14' and platform_machine == 'aarch64' and sys_platform == 'win32'", @@ -1941,79 +1941,79 @@ resolution-markers = [ "python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_machine != 's390x' and sys_platform == 'win32'", "python_full_version == '3.11.*' and platform_machine == 's390x' and sys_platform == 'win32'", ] -sdist = { url = "https://files.pythonhosted.org/packages/d7/9f/b8cef5bffa569759033adda9481211426f12f53299629b410340795c2514/numpy-2.4.4.tar.gz", hash = "sha256:2d390634c5182175533585cc89f3608a4682ccb173cc9bb940b2881c8d6f8fa0", size = 20731587, upload-time = "2026-03-29T13:22:01.298Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ef/c6/4218570d8c8ecc9704b5157a3348e486e84ef4be0ed3e38218ab473c83d2/numpy-2.4.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f983334aea213c99992053ede6168500e5f086ce74fbc4acc3f2b00f5762e9db", size = 16976799, upload-time = "2026-03-29T13:18:15.438Z" }, - { url = "https://files.pythonhosted.org/packages/dd/92/b4d922c4a5f5dab9ed44e6153908a5c665b71acf183a83b93b690996e39b/numpy-2.4.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:72944b19f2324114e9dc86a159787333b77874143efcf89a5167ef83cfee8af0", size = 14971552, upload-time = "2026-03-29T13:18:18.606Z" }, - { url = "https://files.pythonhosted.org/packages/8a/dc/df98c095978fa6ee7b9a9387d1d58cbb3d232d0e69ad169a4ce784bde4fd/numpy-2.4.4-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:86b6f55f5a352b48d7fbfd2dbc3d5b780b2d79f4d3c121f33eb6efb22e9a2015", size = 5476566, upload-time = "2026-03-29T13:18:21.532Z" }, - { url = "https://files.pythonhosted.org/packages/28/34/b3fdcec6e725409223dd27356bdf5a3c2cc2282e428218ecc9cb7acc9763/numpy-2.4.4-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:ba1f4fc670ed79f876f70082eff4f9583c15fb9a4b89d6188412de4d18ae2f40", size = 6806482, upload-time = "2026-03-29T13:18:23.634Z" }, - { url = "https://files.pythonhosted.org/packages/68/62/63417c13aa35d57bee1337c67446761dc25ea6543130cf868eace6e8157b/numpy-2.4.4-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8a87ec22c87be071b6bdbd27920b129b94f2fc964358ce38f3822635a3e2e03d", size = 15973376, upload-time = "2026-03-29T13:18:26.677Z" }, - { url = "https://files.pythonhosted.org/packages/cf/c5/9fcb7e0e69cef59cf10c746b84f7d58b08bc66a6b7d459783c5a4f6101a6/numpy-2.4.4-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:df3775294accfdd75f32c74ae39fcba920c9a378a2fc18a12b6820aa8c1fb502", size = 16925137, upload-time = "2026-03-29T13:18:30.14Z" }, - { url = "https://files.pythonhosted.org/packages/7e/43/80020edacb3f84b9efdd1591120a4296462c23fd8db0dde1666f6ef66f13/numpy-2.4.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0d4e437e295f18ec29bc79daf55e8a47a9113df44d66f702f02a293d93a2d6dd", size = 17329414, upload-time = "2026-03-29T13:18:33.733Z" }, - { url = "https://files.pythonhosted.org/packages/fd/06/af0658593b18a5f73532d377188b964f239eb0894e664a6c12f484472f97/numpy-2.4.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:6aa3236c78803afbcb255045fbef97a9e25a1f6c9888357d205ddc42f4d6eba5", size = 18658397, upload-time = "2026-03-29T13:18:37.511Z" }, - { url = "https://files.pythonhosted.org/packages/e6/ce/13a09ed65f5d0ce5c7dd0669250374c6e379910f97af2c08c57b0608eee4/numpy-2.4.4-cp311-cp311-win32.whl", hash = "sha256:30caa73029a225b2d40d9fae193e008e24b2026b7ee1a867b7ee8d96ca1a448e", size = 6239499, upload-time = "2026-03-29T13:18:40.372Z" }, - { url = "https://files.pythonhosted.org/packages/bd/63/05d193dbb4b5eec1eca73822d80da98b511f8328ad4ae3ca4caf0f4db91d/numpy-2.4.4-cp311-cp311-win_amd64.whl", hash = "sha256:6bbe4eb67390b0a0265a2c25458f6b90a409d5d069f1041e6aff1e27e3d9a79e", size = 12614257, upload-time = "2026-03-29T13:18:42.95Z" }, - { url = "https://files.pythonhosted.org/packages/87/c5/8168052f080c26fa984c413305012be54741c9d0d74abd7fbeeccae3889f/numpy-2.4.4-cp311-cp311-win_arm64.whl", hash = "sha256:fcfe2045fd2e8f3cb0ce9d4ba6dba6333b8fa05bb8a4939c908cd43322d14c7e", size = 10486775, upload-time = "2026-03-29T13:18:45.835Z" }, - { url = "https://files.pythonhosted.org/packages/28/05/32396bec30fb2263770ee910142f49c1476d08e8ad41abf8403806b520ce/numpy-2.4.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:15716cfef24d3a9762e3acdf87e27f58dc823d1348f765bbea6bef8c639bfa1b", size = 16689272, upload-time = "2026-03-29T13:18:49.223Z" }, - { url = "https://files.pythonhosted.org/packages/c5/f3/a983d28637bfcd763a9c7aafdb6d5c0ebf3d487d1e1459ffdb57e2f01117/numpy-2.4.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:23cbfd4c17357c81021f21540da84ee282b9c8fba38a03b7b9d09ba6b951421e", size = 14699573, upload-time = "2026-03-29T13:18:52.629Z" }, - { url = "https://files.pythonhosted.org/packages/9b/fd/e5ecca1e78c05106d98028114f5c00d3eddb41207686b2b7de3e477b0e22/numpy-2.4.4-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:8b3b60bb7cba2c8c81837661c488637eee696f59a877788a396d33150c35d842", size = 5204782, upload-time = "2026-03-29T13:18:55.579Z" }, - { url = "https://files.pythonhosted.org/packages/de/2f/702a4594413c1a8632092beae8aba00f1d67947389369b3777aed783fdca/numpy-2.4.4-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:e4a010c27ff6f210ff4c6ef34394cd61470d01014439b192ec22552ee867f2a8", size = 6552038, upload-time = "2026-03-29T13:18:57.769Z" }, - { url = "https://files.pythonhosted.org/packages/7f/37/eed308a8f56cba4d1fdf467a4fc67ef4ff4bf1c888f5fc980481890104b1/numpy-2.4.4-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f9e75681b59ddaa5e659898085ae0eaea229d054f2ac0c7e563a62205a700121", size = 15670666, upload-time = "2026-03-29T13:19:00.341Z" }, - { url = "https://files.pythonhosted.org/packages/0a/0d/0e3ecece05b7a7e87ab9fb587855548da437a061326fff64a223b6dcb78a/numpy-2.4.4-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:81f4a14bee47aec54f883e0cad2d73986640c1590eb9bfaaba7ad17394481e6e", size = 16645480, upload-time = "2026-03-29T13:19:03.63Z" }, - { url = "https://files.pythonhosted.org/packages/34/49/f2312c154b82a286758ee2f1743336d50651f8b5195db18cdb63675ff649/numpy-2.4.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:62d6b0f03b694173f9fcb1fb317f7222fd0b0b103e784c6549f5e53a27718c44", size = 17020036, upload-time = "2026-03-29T13:19:07.428Z" }, - { url = "https://files.pythonhosted.org/packages/7b/e9/736d17bd77f1b0ec4f9901aaec129c00d59f5d84d5e79bba540ef12c2330/numpy-2.4.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fbc356aae7adf9e6336d336b9c8111d390a05df88f1805573ebb0807bd06fd1d", size = 18368643, upload-time = "2026-03-29T13:19:10.775Z" }, - { url = "https://files.pythonhosted.org/packages/63/f6/d417977c5f519b17c8a5c3bc9e8304b0908b0e21136fe43bf628a1343914/numpy-2.4.4-cp312-cp312-win32.whl", hash = "sha256:0d35aea54ad1d420c812bfa0385c71cd7cc5bcf7c65fed95fc2cd02fe8c79827", size = 5961117, upload-time = "2026-03-29T13:19:13.464Z" }, - { url = "https://files.pythonhosted.org/packages/2d/5b/e1deebf88ff431b01b7406ca3583ab2bbb90972bbe1c568732e49c844f7e/numpy-2.4.4-cp312-cp312-win_amd64.whl", hash = "sha256:b5f0362dc928a6ecd9db58868fca5e48485205e3855957bdedea308f8672ea4a", size = 12320584, upload-time = "2026-03-29T13:19:16.155Z" }, - { url = "https://files.pythonhosted.org/packages/58/89/e4e856ac82a68c3ed64486a544977d0e7bdd18b8da75b78a577ca31c4395/numpy-2.4.4-cp312-cp312-win_arm64.whl", hash = "sha256:846300f379b5b12cc769334464656bc882e0735d27d9726568bc932fdc49d5ec", size = 10221450, upload-time = "2026-03-29T13:19:18.994Z" }, - { url = "https://files.pythonhosted.org/packages/14/1d/d0a583ce4fefcc3308806a749a536c201ed6b5ad6e1322e227ee4848979d/numpy-2.4.4-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:08f2e31ed5e6f04b118e49821397f12767934cfdd12a1ce86a058f91e004ee50", size = 16684933, upload-time = "2026-03-29T13:19:22.47Z" }, - { url = "https://files.pythonhosted.org/packages/c1/62/2b7a48fbb745d344742c0277f01286dead15f3f68e4f359fbfcf7b48f70f/numpy-2.4.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e823b8b6edc81e747526f70f71a9c0a07ac4e7ad13020aa736bb7c9d67196115", size = 14694532, upload-time = "2026-03-29T13:19:25.581Z" }, - { url = "https://files.pythonhosted.org/packages/e5/87/499737bfba066b4a3bebff24a8f1c5b2dee410b209bc6668c9be692580f0/numpy-2.4.4-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:4a19d9dba1a76618dd86b164d608566f393f8ec6ac7c44f0cc879011c45e65af", size = 5199661, upload-time = "2026-03-29T13:19:28.31Z" }, - { url = "https://files.pythonhosted.org/packages/cd/da/464d551604320d1491bc345efed99b4b7034143a85787aab78d5691d5a0e/numpy-2.4.4-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:d2a8490669bfe99a233298348acc2d824d496dee0e66e31b66a6022c2ad74a5c", size = 6547539, upload-time = "2026-03-29T13:19:30.97Z" }, - { url = "https://files.pythonhosted.org/packages/7d/90/8d23e3b0dafd024bf31bdec225b3bb5c2dbfa6912f8a53b8659f21216cbf/numpy-2.4.4-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:45dbed2ab436a9e826e302fcdcbe9133f9b0006e5af7168afb8963a6520da103", size = 15668806, upload-time = "2026-03-29T13:19:33.887Z" }, - { url = "https://files.pythonhosted.org/packages/d1/73/a9d864e42a01896bb5974475438f16086be9ba1f0d19d0bb7a07427c4a8b/numpy-2.4.4-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c901b15172510173f5cb310eae652908340f8dede90fff9e3bf6c0d8dfd92f83", size = 16632682, upload-time = "2026-03-29T13:19:37.336Z" }, - { url = "https://files.pythonhosted.org/packages/34/fb/14570d65c3bde4e202a031210475ae9cde9b7686a2e7dc97ee67d2833b35/numpy-2.4.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:99d838547ace2c4aace6c4f76e879ddfe02bb58a80c1549928477862b7a6d6ed", size = 17019810, upload-time = "2026-03-29T13:19:40.963Z" }, - { url = "https://files.pythonhosted.org/packages/8a/77/2ba9d87081fd41f6d640c83f26fb7351e536b7ce6dd9061b6af5904e8e46/numpy-2.4.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:0aec54fd785890ecca25a6003fd9a5aed47ad607bbac5cd64f836ad8666f4959", size = 18357394, upload-time = "2026-03-29T13:19:44.859Z" }, - { url = "https://files.pythonhosted.org/packages/a2/23/52666c9a41708b0853fa3b1a12c90da38c507a3074883823126d4e9d5b30/numpy-2.4.4-cp313-cp313-win32.whl", hash = "sha256:07077278157d02f65c43b1b26a3886bce886f95d20aabd11f87932750dfb14ed", size = 5959556, upload-time = "2026-03-29T13:19:47.661Z" }, - { url = "https://files.pythonhosted.org/packages/57/fb/48649b4971cde70d817cf97a2a2fdc0b4d8308569f1dd2f2611959d2e0cf/numpy-2.4.4-cp313-cp313-win_amd64.whl", hash = "sha256:5c70f1cc1c4efbe316a572e2d8b9b9cc44e89b95f79ca3331553fbb63716e2bf", size = 12317311, upload-time = "2026-03-29T13:19:50.67Z" }, - { url = "https://files.pythonhosted.org/packages/ba/d8/11490cddd564eb4de97b4579ef6bfe6a736cc07e94c1598590ae25415e01/numpy-2.4.4-cp313-cp313-win_arm64.whl", hash = "sha256:ef4059d6e5152fa1a39f888e344c73fdc926e1b2dd58c771d67b0acfbf2aa67d", size = 10222060, upload-time = "2026-03-29T13:19:54.229Z" }, - { url = "https://files.pythonhosted.org/packages/99/5d/dab4339177a905aad3e2221c915b35202f1ec30d750dd2e5e9d9a72b804b/numpy-2.4.4-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:4bbc7f303d125971f60ec0aaad5e12c62d0d2c925f0ab1273debd0e4ba37aba5", size = 14822302, upload-time = "2026-03-29T13:19:57.585Z" }, - { url = "https://files.pythonhosted.org/packages/eb/e4/0564a65e7d3d97562ed6f9b0fd0fb0a6f559ee444092f105938b50043876/numpy-2.4.4-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:4d6d57903571f86180eb98f8f0c839fa9ebbfb031356d87f1361be91e433f5b7", size = 5327407, upload-time = "2026-03-29T13:20:00.601Z" }, - { url = "https://files.pythonhosted.org/packages/29/8d/35a3a6ce5ad371afa58b4700f1c820f8f279948cca32524e0a695b0ded83/numpy-2.4.4-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:4636de7fd195197b7535f231b5de9e4b36d2c440b6e566d2e4e4746e6af0ca93", size = 6647631, upload-time = "2026-03-29T13:20:02.855Z" }, - { url = "https://files.pythonhosted.org/packages/f4/da/477731acbd5a58a946c736edfdabb2ac5b34c3d08d1ba1a7b437fa0884df/numpy-2.4.4-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ad2e2ef14e0b04e544ea2fa0a36463f847f113d314aa02e5b402fdf910ef309e", size = 15727691, upload-time = "2026-03-29T13:20:06.004Z" }, - { url = "https://files.pythonhosted.org/packages/e6/db/338535d9b152beabeb511579598418ba0212ce77cf9718edd70262cc4370/numpy-2.4.4-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5a285b3b96f951841799528cd1f4f01cd70e7e0204b4abebac9463eecfcf2a40", size = 16681241, upload-time = "2026-03-29T13:20:09.417Z" }, - { url = "https://files.pythonhosted.org/packages/e2/a9/ad248e8f58beb7a0219b413c9c7d8151c5d285f7f946c3e26695bdbbe2df/numpy-2.4.4-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:f8474c4241bc18b750be2abea9d7a9ec84f46ef861dbacf86a4f6e043401f79e", size = 17085767, upload-time = "2026-03-29T13:20:13.126Z" }, - { url = "https://files.pythonhosted.org/packages/b5/1a/3b88ccd3694681356f70da841630e4725a7264d6a885c8d442a697e1146b/numpy-2.4.4-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:4e874c976154687c1f71715b034739b45c7711bec81db01914770373d125e392", size = 18403169, upload-time = "2026-03-29T13:20:17.096Z" }, - { url = "https://files.pythonhosted.org/packages/c2/c9/fcfd5d0639222c6eac7f304829b04892ef51c96a75d479214d77e3ce6e33/numpy-2.4.4-cp313-cp313t-win32.whl", hash = "sha256:9c585a1790d5436a5374bac930dad6ed244c046ed91b2b2a3634eb2971d21008", size = 6083477, upload-time = "2026-03-29T13:20:20.195Z" }, - { url = "https://files.pythonhosted.org/packages/d5/e3/3938a61d1c538aaec8ed6fd6323f57b0c2d2d2219512434c5c878db76553/numpy-2.4.4-cp313-cp313t-win_amd64.whl", hash = "sha256:93e15038125dc1e5345d9b5b68aa7f996ec33b98118d18c6ca0d0b7d6198b7e8", size = 12457487, upload-time = "2026-03-29T13:20:22.946Z" }, - { url = "https://files.pythonhosted.org/packages/97/6a/7e345032cc60501721ef94e0e30b60f6b0bd601f9174ebd36389a2b86d40/numpy-2.4.4-cp313-cp313t-win_arm64.whl", hash = "sha256:0dfd3f9d3adbe2920b68b5cd3d51444e13a10792ec7154cd0a2f6e74d4ab3233", size = 10292002, upload-time = "2026-03-29T13:20:25.909Z" }, - { url = "https://files.pythonhosted.org/packages/6e/06/c54062f85f673dd5c04cbe2f14c3acb8c8b95e3384869bb8cc9bff8cb9df/numpy-2.4.4-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:f169b9a863d34f5d11b8698ead99febeaa17a13ca044961aa8e2662a6c7766a0", size = 16684353, upload-time = "2026-03-29T13:20:29.504Z" }, - { url = "https://files.pythonhosted.org/packages/4c/39/8a320264a84404c74cc7e79715de85d6130fa07a0898f67fb5cd5bd79908/numpy-2.4.4-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:2483e4584a1cb3092da4470b38866634bafb223cbcd551ee047633fd2584599a", size = 14704914, upload-time = "2026-03-29T13:20:33.547Z" }, - { url = "https://files.pythonhosted.org/packages/91/fb/287076b2614e1d1044235f50f03748f31fa287e3dbe6abeb35cdfa351eca/numpy-2.4.4-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:2d19e6e2095506d1736b7d80595e0f252d76b89f5e715c35e06e937679ea7d7a", size = 5210005, upload-time = "2026-03-29T13:20:36.45Z" }, - { url = "https://files.pythonhosted.org/packages/63/eb/fcc338595309910de6ecabfcef2419a9ce24399680bfb149421fa2df1280/numpy-2.4.4-cp314-cp314-macosx_14_0_x86_64.whl", hash = "sha256:6a246d5914aa1c820c9443ddcee9c02bec3e203b0c080349533fae17727dfd1b", size = 6544974, upload-time = "2026-03-29T13:20:39.014Z" }, - { url = "https://files.pythonhosted.org/packages/44/5d/e7e9044032a716cdfaa3fba27a8e874bf1c5f1912a1ddd4ed071bf8a14a6/numpy-2.4.4-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:989824e9faf85f96ec9c7761cd8d29c531ad857bfa1daa930cba85baaecf1a9a", size = 15684591, upload-time = "2026-03-29T13:20:42.146Z" }, - { url = "https://files.pythonhosted.org/packages/98/7c/21252050676612625449b4807d6b695b9ce8a7c9e1c197ee6216c8a65c7c/numpy-2.4.4-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:27a8d92cd10f1382a67d7cf4db7ce18341b66438bdd9f691d7b0e48d104c2a9d", size = 16637700, upload-time = "2026-03-29T13:20:46.204Z" }, - { url = "https://files.pythonhosted.org/packages/b1/29/56d2bbef9465db24ef25393383d761a1af4f446a1df9b8cded4fe3a5a5d7/numpy-2.4.4-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:e44319a2953c738205bf3354537979eaa3998ed673395b964c1176083dd46252", size = 17035781, upload-time = "2026-03-29T13:20:50.242Z" }, - { url = "https://files.pythonhosted.org/packages/e3/2b/a35a6d7589d21f44cea7d0a98de5ddcbb3d421b2622a5c96b1edf18707c3/numpy-2.4.4-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:e892aff75639bbef0d2a2cfd55535510df26ff92f63c92cd84ef8d4ba5a5557f", size = 18362959, upload-time = "2026-03-29T13:20:54.019Z" }, - { url = "https://files.pythonhosted.org/packages/64/c9/d52ec581f2390e0f5f85cbfd80fb83d965fc15e9f0e1aec2195faa142cde/numpy-2.4.4-cp314-cp314-win32.whl", hash = "sha256:1378871da56ca8943c2ba674530924bb8ca40cd228358a3b5f302ad60cf875fc", size = 6008768, upload-time = "2026-03-29T13:20:56.912Z" }, - { url = "https://files.pythonhosted.org/packages/fa/22/4cc31a62a6c7b74a8730e31a4274c5dc80e005751e277a2ce38e675e4923/numpy-2.4.4-cp314-cp314-win_amd64.whl", hash = "sha256:715d1c092715954784bc79e1174fc2a90093dc4dc84ea15eb14dad8abdcdeb74", size = 12449181, upload-time = "2026-03-29T13:20:59.548Z" }, - { url = "https://files.pythonhosted.org/packages/70/2e/14cda6f4d8e396c612d1bf97f22958e92148801d7e4f110cabebdc0eef4b/numpy-2.4.4-cp314-cp314-win_arm64.whl", hash = "sha256:2c194dd721e54ecad9ad387c1d35e63dce5c4450c6dc7dd5611283dda239aabb", size = 10496035, upload-time = "2026-03-29T13:21:02.524Z" }, - { url = "https://files.pythonhosted.org/packages/b1/e8/8fed8c8d848d7ecea092dc3469643f9d10bc3a134a815a3b033da1d2039b/numpy-2.4.4-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:2aa0613a5177c264ff5921051a5719d20095ea586ca88cc802c5c218d1c67d3e", size = 14824958, upload-time = "2026-03-29T13:21:05.671Z" }, - { url = "https://files.pythonhosted.org/packages/05/1a/d8007a5138c179c2bf33ef44503e83d70434d2642877ee8fbb230e7c0548/numpy-2.4.4-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:42c16925aa5a02362f986765f9ebabf20de75cdefdca827d14315c568dcab113", size = 5330020, upload-time = "2026-03-29T13:21:08.635Z" }, - { url = "https://files.pythonhosted.org/packages/99/64/ffb99ac6ae93faf117bcbd5c7ba48a7f45364a33e8e458545d3633615dda/numpy-2.4.4-cp314-cp314t-macosx_14_0_x86_64.whl", hash = "sha256:874f200b2a981c647340f841730fc3a2b54c9d940566a3c4149099591e2c4c3d", size = 6650758, upload-time = "2026-03-29T13:21:10.949Z" }, - { url = "https://files.pythonhosted.org/packages/6e/6e/795cc078b78a384052e73b2f6281ff7a700e9bf53bcce2ee579d4f6dd879/numpy-2.4.4-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c9b39d38a9bd2ae1becd7eac1303d031c5c110ad31f2b319c6e7d98b135c934d", size = 15729948, upload-time = "2026-03-29T13:21:14.047Z" }, - { url = "https://files.pythonhosted.org/packages/5f/86/2acbda8cc2af5f3d7bfc791192863b9e3e19674da7b5e533fded124d1299/numpy-2.4.4-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b268594bccac7d7cf5844c7732e3f20c50921d94e36d7ec9b79e9857694b1b2f", size = 16679325, upload-time = "2026-03-29T13:21:17.561Z" }, - { url = "https://files.pythonhosted.org/packages/bc/59/cafd83018f4aa55e0ac6fa92aa066c0a1877b77a615ceff1711c260ffae8/numpy-2.4.4-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:ac6b31e35612a26483e20750126d30d0941f949426974cace8e6b5c58a3657b0", size = 17084883, upload-time = "2026-03-29T13:21:21.106Z" }, - { url = "https://files.pythonhosted.org/packages/f0/85/a42548db84e65ece46ab2caea3d3f78b416a47af387fcbb47ec28e660dc2/numpy-2.4.4-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:8e3ed142f2728df44263aaf5fb1f5b0b99f4070c553a0d7f033be65338329150", size = 18403474, upload-time = "2026-03-29T13:21:24.828Z" }, - { url = "https://files.pythonhosted.org/packages/ed/ad/483d9e262f4b831000062e5d8a45e342166ec8aaa1195264982bca267e62/numpy-2.4.4-cp314-cp314t-win32.whl", hash = "sha256:dddbbd259598d7240b18c9d87c56a9d2fb3b02fe266f49a7c101532e78c1d871", size = 6155500, upload-time = "2026-03-29T13:21:28.205Z" }, - { url = "https://files.pythonhosted.org/packages/c7/03/2fc4e14c7bd4ff2964b74ba90ecb8552540b6315f201df70f137faa5c589/numpy-2.4.4-cp314-cp314t-win_amd64.whl", hash = "sha256:a7164afb23be6e37ad90b2f10426149fd75aee07ca55653d2aa41e66c4ef697e", size = 12637755, upload-time = "2026-03-29T13:21:31.107Z" }, - { url = "https://files.pythonhosted.org/packages/58/78/548fb8e07b1a341746bfbecb32f2c268470f45fa028aacdbd10d9bc73aab/numpy-2.4.4-cp314-cp314t-win_arm64.whl", hash = "sha256:ba203255017337d39f89bdd58417f03c4426f12beed0440cfd933cb15f8669c7", size = 10566643, upload-time = "2026-03-29T13:21:34.339Z" }, - { url = "https://files.pythonhosted.org/packages/6b/33/8fae8f964a4f63ed528264ddf25d2b683d0b663e3cba26961eb838a7c1bd/numpy-2.4.4-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:58c8b5929fcb8287cbd6f0a3fae19c6e03a5c48402ae792962ac465224a629a4", size = 16854491, upload-time = "2026-03-29T13:21:38.03Z" }, - { url = "https://files.pythonhosted.org/packages/bc/d0/1aabee441380b981cf8cdda3ae7a46aa827d1b5a8cce84d14598bc94d6d9/numpy-2.4.4-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:eea7ac5d2dce4189771cedb559c738a71512768210dc4e4753b107a2048b3d0e", size = 14895830, upload-time = "2026-03-29T13:21:41.509Z" }, - { url = "https://files.pythonhosted.org/packages/a5/b8/aafb0d1065416894fccf4df6b49ef22b8db045187949545bced89c034b8e/numpy-2.4.4-pp311-pypy311_pp73-macosx_14_0_arm64.whl", hash = "sha256:51fc224f7ca4d92656d5a5eb315f12eb5fe2c97a66249aa7b5f562528a3be38c", size = 5400927, upload-time = "2026-03-29T13:21:44.747Z" }, - { url = "https://files.pythonhosted.org/packages/d6/77/063baa20b08b431038c7f9ff5435540c7b7265c78cf56012a483019ca72d/numpy-2.4.4-pp311-pypy311_pp73-macosx_14_0_x86_64.whl", hash = "sha256:28a650663f7314afc3e6ec620f44f333c386aad9f6fc472030865dc0ebb26ee3", size = 6715557, upload-time = "2026-03-29T13:21:47.406Z" }, - { url = "https://files.pythonhosted.org/packages/c7/a8/379542d45a14f149444c5c4c4e7714707239ce9cc1de8c2803958889da14/numpy-2.4.4-pp311-pypy311_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:19710a9ca9992d7174e9c52f643d4272dcd1558c5f7af7f6f8190f633bd651a7", size = 15804253, upload-time = "2026-03-29T13:21:50.753Z" }, - { url = "https://files.pythonhosted.org/packages/a2/c8/f0a45426d6d21e7ea3310a15cf90c43a14d9232c31a837702dba437f3373/numpy-2.4.4-pp311-pypy311_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9b2aec6af35c113b05695ebb5749a787acd63cafc83086a05771d1e1cd1e555f", size = 16753552, upload-time = "2026-03-29T13:21:54.344Z" }, - { url = "https://files.pythonhosted.org/packages/04/74/f4c001f4714c3ad9ce037e18cf2b9c64871a84951eaa0baf683a9ca9301c/numpy-2.4.4-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:f2cf083b324a467e1ab358c105f6cad5ea950f50524668a80c486ff1db24e119", size = 12509075, upload-time = "2026-03-29T13:21:57.644Z" }, +sdist = { url = "https://files.pythonhosted.org/packages/50/8e/b8041bc719f056afd864478029d52214789341ac6583437b0ee5031e9530/numpy-2.4.5.tar.gz", hash = "sha256:ca670567a5683b7c1670ec03e0ddd5862e10934e92a70751d68d7b7b74ca7f9f", size = 20735669, upload-time = "2026-05-15T20:25:19.492Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e1/44/1383ee4d1e916a9e610e46c876b5c83ea023526117d23cd911983929ec34/numpy-2.4.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3176dc8ff71dbb593606f91a69ad0c3cd3303c7eb546af477370ab9edf760288", size = 16969261, upload-time = "2026-05-15T20:22:23.036Z" }, + { url = "https://files.pythonhosted.org/packages/3d/61/54bacfbec7550bc398e6b6d9a861db35d64f75844e1d7920f5722c3cd5e7/numpy-2.4.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1811150e5148f5a01a7cc282cb2f489b4a3050a773e173adb480e507bad3a3d7", size = 14964009, upload-time = "2026-05-15T20:22:25.819Z" }, + { url = "https://files.pythonhosted.org/packages/7a/55/fe86c64561761f185339c26001164a2687bd4787af681e961431abd2d534/numpy-2.4.5-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:0d63a780070871210853ba01e90b88f9b85cf2abf63a7f143d5127189265ddf6", size = 5469106, upload-time = "2026-05-15T20:22:28.13Z" }, + { url = "https://files.pythonhosted.org/packages/2f/74/cf29b8317627f0e3aa2c9fb332d386bd734308cecd9e07da9f407d9ce0c3/numpy-2.4.5-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:0c6919cefafb3b76cd46a89dbb203bf1dd95529d2a6d09fef2d325d95d6a79d8", size = 6798945, upload-time = "2026-05-15T20:22:30.061Z" }, + { url = "https://files.pythonhosted.org/packages/80/a9/b61730a17fa87d5abb13ce560a1b4ce3485d37a13e03eb7b414e598e72f8/numpy-2.4.5-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d51efede1e58e8b11877536a5518f60e318d8ff69b89ad7b38ee5e431b24d772", size = 15967025, upload-time = "2026-05-15T20:22:32.328Z" }, + { url = "https://files.pythonhosted.org/packages/03/39/70bcd187eb4d223c21fde02c2bdfbffbffef3288cbb3947c04c74ae39a08/numpy-2.4.5-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:07ce7e74da92d7c71b5df157b9758bcdd53d7fea10602154de3afd2b3ddc34dd", size = 16918685, upload-time = "2026-05-15T20:22:34.759Z" }, + { url = "https://files.pythonhosted.org/packages/ab/31/400fd1315bbe228af3937cf8a74e32023df6217af36077919d00adc382e4/numpy-2.4.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d7828234a13185effb34979e146f9921f2a65dfbbe215e6dbb57d6478fc8e059", size = 17322963, upload-time = "2026-05-15T20:22:37.557Z" }, + { url = "https://files.pythonhosted.org/packages/18/6a/bbbafb657e6f6ee826b4ecdb8722a2e0aae4a981888eaf59eae6a535cc13/numpy-2.4.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f96083adc3dfc1bbf778f2c79654d88115fa07074c97cb724fe9508f12d91c55", size = 18651594, upload-time = "2026-05-15T20:22:40.449Z" }, + { url = "https://files.pythonhosted.org/packages/de/0c/857a515154a2a18b0dfae04089600d166d352d473ec17a0680d879582d06/numpy-2.4.5-cp311-cp311-win32.whl", hash = "sha256:4ed78c904a638b6e5d7cd4db90c06fca5fc6ec2f28d258305368f454a50e79cf", size = 6233849, upload-time = "2026-05-15T20:22:43.139Z" }, + { url = "https://files.pythonhosted.org/packages/f0/66/d215f3fb93541617adb5d58b3b9508e8a6413e499711e0adc0b80bcb445d/numpy-2.4.5-cp311-cp311-win_amd64.whl", hash = "sha256:079b0fad6f2899b23c5da89792b5409d2d83fc83e8bd5c2299cc9c397a264864", size = 12608238, upload-time = "2026-05-15T20:22:45.229Z" }, + { url = "https://files.pythonhosted.org/packages/cb/c4/611d66d3fcfa931954d37a19ce5575f3283d023e89ff0df6ad43b334ae9c/numpy-2.4.5-cp311-cp311-win_arm64.whl", hash = "sha256:d6c78e260b53affe9b395a9d54fc61f101f9521c4d9452c7e9e3718b19e2215b", size = 10479452, upload-time = "2026-05-15T20:22:47.962Z" }, + { url = "https://files.pythonhosted.org/packages/6c/18/3275231e98620002681c922e792db04d72c356e9d8073c387344fc0e4ff1/numpy-2.4.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:654fb8674b61b1c4bd568f944d13a908566fdcb0d797303521d4149d16da05ef", size = 16689166, upload-time = "2026-05-15T20:22:50.761Z" }, + { url = "https://files.pythonhosted.org/packages/db/23/000aab6a16bdec53307f0f72546b57a3ac9266a62d8c257bee97d85fd078/numpy-2.4.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4cd9f6fa7ce10dc4627f2bb81dd9075dab67e94632e04c2b638e12575ddaa862", size = 14699514, upload-time = "2026-05-15T20:22:53.678Z" }, + { url = "https://files.pythonhosted.org/packages/47/cc/ddaf3af9c46966fef5be879256f213d85a0c56c75d07a3b7defec7cf6b4c/numpy-2.4.5-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:4f5bc96d35d94e4ceab8b38a92241b4611e95dc44e63b9f1fa2a331858ee3507", size = 5204601, upload-time = "2026-05-15T20:22:56.257Z" }, + { url = "https://files.pythonhosted.org/packages/07/ea/627fadd11959b3c7759008f34c92a35af8ff942dd8284a66ced648bbe516/numpy-2.4.5-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:4bb33e900ee81730ad77a258965134aa8ceac805124f7e5229347beda4b8d0aa", size = 6551360, upload-time = "2026-05-15T20:22:58.334Z" }, + { url = "https://files.pythonhosted.org/packages/a1/47/0728b986b8682d742ff68c16baa5af9d185484abfc635c5cc700f44e62be/numpy-2.4.5-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:32f8f852273ef32b291201ac2a2c97629c4a1ee8632bb670e3443eaa09fc2e72", size = 15671157, upload-time = "2026-05-15T20:23:01.081Z" }, + { url = "https://files.pythonhosted.org/packages/d1/0b/b905ae82d9419dc38123523862db64978ca2954b69609c3ae8fdaca1084c/numpy-2.4.5-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:685681e956fc8dcb75adc6ff26694e1dfd738b24bd8d4696c51ca0110157f912", size = 16645703, upload-time = "2026-05-15T20:23:04.358Z" }, + { url = "https://files.pythonhosted.org/packages/5f/24/e27fc3f5236b4118ed9eed67111675f5c61a07ea333acec87c869c3b359d/numpy-2.4.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6f64dd84b277a737eb59513f6b9bb6195bf41ab11941ef15b2562dbab43fa8ef", size = 17021018, upload-time = "2026-05-15T20:23:07.021Z" }, + { url = "https://files.pythonhosted.org/packages/d3/a7/9041af38d527ab80a06a93570a77e29425b41507ad41f6acf5da78cfb4a4/numpy-2.4.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b42d9496f79e3a728192f05a42d86e36163217b7cdecb3813d0028a0aa6b72d7", size = 18368768, upload-time = "2026-05-15T20:23:09.44Z" }, + { url = "https://files.pythonhosted.org/packages/49/82/326a014442f32c2663434fd424d9298791f47f8a0f17585ad60519a5606e/numpy-2.4.5-cp312-cp312-win32.whl", hash = "sha256:86d980970f5110595ca14855768073b08585fc1acc36895de303e039e7dee4a5", size = 5962819, upload-time = "2026-05-15T20:23:11.631Z" }, + { url = "https://files.pythonhosted.org/packages/3c/f0/cbf5d391b0b3a5e8cad264603e2fae256b0bde8ce43566b13b78faedc659/numpy-2.4.5-cp312-cp312-win_amd64.whl", hash = "sha256:3333dba6a4e611d666f69e177ba8fe4140366ff681a5feb2374d3fd4fff3acb6", size = 12321621, upload-time = "2026-05-15T20:23:14.305Z" }, + { url = "https://files.pythonhosted.org/packages/3c/d0/0f18909d9bc37a5f3f969fc737d2bb5df9f2ff295f71b467e6f52a0d6c4e/numpy-2.4.5-cp312-cp312-win_arm64.whl", hash = "sha256:4593d197270b894efeb538dcbe227e4bcf1c77f88c4c6bf933ead812cfaa4453", size = 10221430, upload-time = "2026-05-15T20:23:16.887Z" }, + { url = "https://files.pythonhosted.org/packages/e3/a4/fb50657c7cab297bf34edcd60a074cb0647f61771430d6363575274160fe/numpy-2.4.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:1ef248460b645c102026b82337cc4e88231909c66dd77b59ec6d6cac7e44f277", size = 16684760, upload-time = "2026-05-15T20:23:19.436Z" }, + { url = "https://files.pythonhosted.org/packages/3e/43/87e731299b9408eda705b3b9cb31c7bceb9347d2af9cbb16b2b1e4b5bc0f/numpy-2.4.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4603622bdcdbf8dccb1d9d5b21d16a7aa4e473ae6c8e14048d846fd4ca2907a0", size = 14694117, upload-time = "2026-05-15T20:23:21.832Z" }, + { url = "https://files.pythonhosted.org/packages/a9/c7/0b2bb8acea222e9dd6e582afc2bc553b89b8833cbdccc68e68f050fb31f8/numpy-2.4.5-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:6c18d49c67689c562854b53fdc433b93e47c12952aa6fa6d59f185e1a5992419", size = 5199141, upload-time = "2026-05-15T20:23:24.066Z" }, + { url = "https://files.pythonhosted.org/packages/39/60/b6972b5d47033d90000f0097c81a98b9486589a2d7003bf725bff275cb0d/numpy-2.4.5-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:b1c663ddc641f4192e90511bec61a09bc231e3bbdb996cdc6edbcaa0e528d685", size = 6546954, upload-time = "2026-05-15T20:23:26.099Z" }, + { url = "https://files.pythonhosted.org/packages/c1/e9/ed667cb12c11ca0adde431f685d3a5dd78e6f78b27228c581c8415198e9e/numpy-2.4.5-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:93793222b524f692f12b2f8752ce8b1d9d9125b2bfd5dbf0fb69c92c5e1ce86c", size = 15669430, upload-time = "2026-05-15T20:23:28.147Z" }, + { url = "https://files.pythonhosted.org/packages/44/e5/679f6ffeb01294b0008e5ada4a113cb47617bc0e1819a529fd7973c6d7f4/numpy-2.4.5-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1616bde34b2bcba2fa9bde06217ce00da4f3d1bdfb264d54525a99e8fe170d83", size = 16633390, upload-time = "2026-05-15T20:23:31.622Z" }, + { url = "https://files.pythonhosted.org/packages/36/46/42bfffc9a780ec902ccd7470d3219192ee82b7b442710307dd85b4d121b0/numpy-2.4.5-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:09d7d97da1c2c62f4818b3e150a57572ff8dcf1cf5ac501aac832ffd4ebd9566", size = 17020709, upload-time = "2026-05-15T20:23:34.08Z" }, + { url = "https://files.pythonhosted.org/packages/44/00/3e840bfee0cc6cec22209f2c97057f26eeb30de031e4933b4dfc0395416c/numpy-2.4.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:2d68d0b355ab2e39fe0de59001d7151dfdbbb880ef67baeed806661e03df5097", size = 18357818, upload-time = "2026-05-15T20:23:36.965Z" }, + { url = "https://files.pythonhosted.org/packages/72/cb/3447b400b9da84134575486f0f656541559b00d4b262477bce9b678bbca8/numpy-2.4.5-cp313-cp313-win32.whl", hash = "sha256:fe28b64777ddfa0eca9b5f51474034ebe3dcb8324f48f27b28f479085673ae33", size = 5961114, upload-time = "2026-05-15T20:23:39.586Z" }, + { url = "https://files.pythonhosted.org/packages/28/f9/a90d2220ffcdc0798f5d55bb5d5463cd6254ec9ef43f384dae80217d7a2f/numpy-2.4.5-cp313-cp313-win_amd64.whl", hash = "sha256:fb4a6c9c537d6ccec9cc4aeae4261bd3cc79b070c67ddc0646f5b1c07fddde42", size = 12318553, upload-time = "2026-05-15T20:23:41.436Z" }, + { url = "https://files.pythonhosted.org/packages/b8/c9/96f531fb3234545315152d34efdf3de7daee81254448447eb619e8d16967/numpy-2.4.5-cp313-cp313-win_arm64.whl", hash = "sha256:6d7df2da2e7ea0624a43aa368104b3a3ce14aae98ad4bb2c9a93fecef76f1c97", size = 10222200, upload-time = "2026-05-15T20:23:43.681Z" }, + { url = "https://files.pythonhosted.org/packages/e1/f4/a291caab5a3c520babf93ff77c54fd5fdb1ebbc3296cee2eb2146ce773b1/numpy-2.4.5-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:2a235607a18df941760a695927051af4b1cd5d3ee85840d0e2af816785771feb", size = 14821438, upload-time = "2026-05-15T20:23:45.911Z" }, + { url = "https://files.pythonhosted.org/packages/85/26/13dbb1159b864370568e7309063fd72667984df89db74e9caeb175d067c7/numpy-2.4.5-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:58dcf64969d870f36bc7fbd557d2617e997db7dc06261b6e3327148ea460d0a4", size = 5326663, upload-time = "2026-05-15T20:23:48.18Z" }, + { url = "https://files.pythonhosted.org/packages/7c/99/d233408072a0e019e2288e27edd23f7d572ccd4a73d1539baa3270ede85d/numpy-2.4.5-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:235f54b0156274d8fa3155db3ed6d2f401c7e8f3367c90db0a12f02a58fde6ed", size = 6646874, upload-time = "2026-05-15T20:23:49.856Z" }, + { url = "https://files.pythonhosted.org/packages/c5/00/eeb6f193dfe767725e952e0464f3e51f44145c5dd261cd7389aa36ac0713/numpy-2.4.5-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ef3b5bb65437a3555c648e706475db01c645559ca80dc8b03e4f202ea757e0d6", size = 15728147, upload-time = "2026-05-15T20:23:51.655Z" }, + { url = "https://files.pythonhosted.org/packages/e5/c9/b8ed039f1fde1b13a8807c893e7e2f9432a379f4d6401edecf0028da5b2c/numpy-2.4.5-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7f09a7e5f017d7098c66522097c96257411c9620c0926212200d66bc8cee3976", size = 16681770, upload-time = "2026-05-15T20:23:53.933Z" }, + { url = "https://files.pythonhosted.org/packages/11/5b/0198ef6cb7016eca6d895d392106012138127fab23f46637e76d5e25c9f5/numpy-2.4.5-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:993a88d8fdd8554466a8765cd8bacd97ba56b70ca6b0a04bcdca77f5afed4222", size = 17086218, upload-time = "2026-05-15T20:23:56.646Z" }, + { url = "https://files.pythonhosted.org/packages/f0/fe/8821f3cfc660ae84c92ee158505941874b62c56a42e035a41425228cd8cf/numpy-2.4.5-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:84f58bed609b5669f5ad3d597901a4f1f86ee5b3c3708aaa55f05b4fe6e0f656", size = 18403542, upload-time = "2026-05-15T20:23:59.173Z" }, + { url = "https://files.pythonhosted.org/packages/0e/00/e64ecaf498865e7b091f57658b2c522503e5d1b70e43b807f5f8247e1d88/numpy-2.4.5-cp313-cp313t-win32.whl", hash = "sha256:7200c58f3f933ca61e66346667dcc8510bb111995e9ce15398a731e6a4afa4bb", size = 6084903, upload-time = "2026-05-15T20:24:01.506Z" }, + { url = "https://files.pythonhosted.org/packages/20/c0/354997dedaf74e8311c2cf9a6027b476fd8d424cb92189cc0ae2b25f501c/numpy-2.4.5-cp313-cp313t-win_amd64.whl", hash = "sha256:c26c71080d35db5002102f5d9ff614d45de02aa1f7802943e691e063e5ee93bc", size = 12458420, upload-time = "2026-05-15T20:24:03.735Z" }, + { url = "https://files.pythonhosted.org/packages/66/dc/917ee5ea4a31ca1a6e4c9a85386477efa318dcc60db257c5ef4adda096c1/numpy-2.4.5-cp313-cp313t-win_arm64.whl", hash = "sha256:2caa576d1707b275cba1aeb60a5c50daa6fa2a3f28ecb08123bc05fd439005db", size = 10291826, upload-time = "2026-05-15T20:24:06.535Z" }, + { url = "https://files.pythonhosted.org/packages/ca/c1/3be0bf102fc17cff5bd142e3be0bfffabec6fa46da0a462396c76b0765d0/numpy-2.4.5-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:889ca2c072315de638a5194a772aa1fa2df92bdd6175f6a222d4784040424b61", size = 16683455, upload-time = "2026-05-15T20:24:08.988Z" }, + { url = "https://files.pythonhosted.org/packages/e8/3e/0742d724901fa36bc54b338c6e62e463a7601180da896aa44978f0adf004/numpy-2.4.5-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:89e89304fb1f8c3f0ecfa4a7d48f311dd79771336a940e920159d643d1307e77", size = 14704577, upload-time = "2026-05-15T20:24:11.542Z" }, + { url = "https://files.pythonhosted.org/packages/25/1c/196c610ff4c6782d697ba780ebdc1616be143213701bf22c1a270f3bf7dd/numpy-2.4.5-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:144fcc5a3a17679b2b82543b4a2d8dd29937230a7af13232b5f753872feb6361", size = 5209756, upload-time = "2026-05-15T20:24:14.091Z" }, + { url = "https://files.pythonhosted.org/packages/52/c0/23fb1bc506f774e03db66219a2830e720f4d3dbcaaddf855a7ff7bb6d96f/numpy-2.4.5-cp314-cp314-macosx_14_0_x86_64.whl", hash = "sha256:398bb16772b265b9fa5c07b07072646ea97137c10ffb62a9a087b277fc825c29", size = 6543937, upload-time = "2026-05-15T20:24:16.223Z" }, + { url = "https://files.pythonhosted.org/packages/9f/49/db4662c26e68520afcc84d672a6f9f5294063dee0e57a46d61afdaa7f9ed/numpy-2.4.5-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fb352e7b8876da1249e72254736d6c58c505fa4e58a3d7e30efca241ca9ca9ce", size = 15685292, upload-time = "2026-05-15T20:24:17.978Z" }, + { url = "https://files.pythonhosted.org/packages/43/80/1315439acedd8398319bac177d6de3d48ab39c62cc0c810f74f0a9a73996/numpy-2.4.5-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7341b08ff8124d7353939778e2707b8732d03c78c1c30e0815aba2dacbe1245a", size = 16638528, upload-time = "2026-05-15T20:24:20.478Z" }, + { url = "https://files.pythonhosted.org/packages/56/81/364388600932618fe735d97fdd2437cb8dd87a23377ac11d8b9d5db098b7/numpy-2.4.5-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:deb01226f012539f3945261ffe1c10aec081a0fa0a5c925419933c70f3ae2d23", size = 17036709, upload-time = "2026-05-15T20:24:22.949Z" }, + { url = "https://files.pythonhosted.org/packages/32/4a/a1185b18a94a6d9587e54b437e7d0ba36ecf6e614f1bea03f5249912c64e/numpy-2.4.5-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:d888bdf7335f76878c3c7b264ac1ff089863e211ec81249f9fb5795c2183dc25", size = 18363254, upload-time = "2026-05-15T20:24:25.402Z" }, + { url = "https://files.pythonhosted.org/packages/b9/8e/95c1d2ed15ae97750ede8c8a0ac487c9c01207afff430f47078b1d9d7dc5/numpy-2.4.5-cp314-cp314-win32.whl", hash = "sha256:15f90d1256e9b2320aff24fde44815b787ab6d7c49a1a11bfd8138b321c5f080", size = 6010184, upload-time = "2026-05-15T20:24:27.852Z" }, + { url = "https://files.pythonhosted.org/packages/aa/92/d063df4d63d988b20d881856c74df76c0c1786229bb870f3a52af0981d4d/numpy-2.4.5-cp314-cp314-win_amd64.whl", hash = "sha256:4bd2cd4ef9c0afa87de73723c0a33c0edff62143e1432917458e26d3d195d87f", size = 12450344, upload-time = "2026-05-15T20:24:29.856Z" }, + { url = "https://files.pythonhosted.org/packages/3d/64/c0ae481f7c3b2f85869bcd8fc5d30aa7c96b394162eef9c9315957f115c5/numpy-2.4.5-cp314-cp314-win_arm64.whl", hash = "sha256:db304568c650e9d7039744d3575d0d287754debb2057d7c7b8cdfdc2c487a957", size = 10495674, upload-time = "2026-05-15T20:24:32.352Z" }, + { url = "https://files.pythonhosted.org/packages/57/89/c5a4c677acf17aa50ba09a15e61812f90baac42bb6ca38d112e005858351/numpy-2.4.5-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:6de2883e0d2c63eae1bab1a84b390dca74aabb3d20ea1f5d58f360853c83abf3", size = 14824078, upload-time = "2026-05-15T20:24:34.669Z" }, + { url = "https://files.pythonhosted.org/packages/e7/52/57e7144284f6b51ba93523e495ff239260b1ecd5257e3700a436332e5688/numpy-2.4.5-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:06760fe73ae5005008748d182de612c733542af3cde063d532cd2127561b27be", size = 5329246, upload-time = "2026-05-15T20:24:36.957Z" }, + { url = "https://files.pythonhosted.org/packages/b9/b3/09dbce80fd4a7db4318f2fc01eec0ae76f29306442b5a32d4b811d082cdf/numpy-2.4.5-cp314-cp314t-macosx_14_0_x86_64.whl", hash = "sha256:4b51a01745cb04cc19278482207444b4d30728ce91c28d27a3bfae5fc6ff24c7", size = 6649877, upload-time = "2026-05-15T20:24:38.861Z" }, + { url = "https://files.pythonhosted.org/packages/30/c2/dbdb23e82d540b757690ef13f011c386fca6a63848eec6136baf8ce7cbed/numpy-2.4.5-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9a05636d7937d0936f271e5ba957fa8d746b5be3c2025caa1a2508f4fe521d40", size = 15730534, upload-time = "2026-05-15T20:24:41.168Z" }, + { url = "https://files.pythonhosted.org/packages/c4/bd/68f6e9b3c20decf40ac06708a7b506757e3a8588efed32988d1b747316be/numpy-2.4.5-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:14b86f56048ed09c3bbe48962a7dff077c2fd3274f8cf981800f3b38eac49cc3", size = 16679741, upload-time = "2026-05-15T20:24:44.874Z" }, + { url = "https://files.pythonhosted.org/packages/39/1d/0fcac0b6b4ea1b50ca8fca05a34bed5c8d56e34c1cb5ffb04cf76109ac3c/numpy-2.4.5-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:130d58151c4db23e9fa860b84784e219a3aa3e030acc88a493ea37006c4dfd4c", size = 17085598, upload-time = "2026-05-15T20:24:47.603Z" }, + { url = "https://files.pythonhosted.org/packages/0b/e8/a472b2564cf6cc498ad7aa9741d9832648221b8ab8cc0dbef41faa248ede/numpy-2.4.5-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:d475afc8cbe935ff5944f753d863bba774d7f4e1feaaa4102901e3e053ca5963", size = 18403855, upload-time = "2026-05-15T20:24:50.474Z" }, + { url = "https://files.pythonhosted.org/packages/b9/a4/da82196f8cc4bd28ecf17bd57008c84f3d4696caf06753d9bad45e4ad749/numpy-2.4.5-cp314-cp314t-win32.whl", hash = "sha256:27f4a6dc26353a860b348961b9aa9e009835688b435cfa105e873b8dc2c726f5", size = 6156900, upload-time = "2026-05-15T20:24:53.134Z" }, + { url = "https://files.pythonhosted.org/packages/98/31/860959b91a73d9a085006554fa3850da51a7ffab64599bac5097243438ab/numpy-2.4.5-cp314-cp314t-win_amd64.whl", hash = "sha256:76ac6e90f5e226011c88f9b7040a4bcae612518bc7e9adc127e697a13b28ad1a", size = 12638906, upload-time = "2026-05-15T20:24:55.009Z" }, + { url = "https://files.pythonhosted.org/packages/9e/2a/bbd3097913083ad07c0f28fc9629666221fc18923e17ce97ae22a5dccdd6/numpy-2.4.5-cp314-cp314t-win_arm64.whl", hash = "sha256:7c392e2c1bf596701d3c6832be7567eab5d5b0a13865036c33365ee097d37f8b", size = 10565875, upload-time = "2026-05-15T20:24:57.425Z" }, + { url = "https://files.pythonhosted.org/packages/fc/5d/9a644cfb841bc76b584afc3af1708b3bf6c5cb51fc84a7008246cd93b7b7/numpy-2.4.5-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:6bf0bfc1c2e1db972e30b6cd3d4861f477f3af908b27799b239dc3cbe3eb4b95", size = 16847544, upload-time = "2026-05-15T20:24:59.746Z" }, + { url = "https://files.pythonhosted.org/packages/56/8f/4fe5e3ba76d858dae1fe79078818c0520447335be0082c0dedf82719cc08/numpy-2.4.5-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:73d664413fb97229149c4711ef56531a6fe8c15c1c2626b0bbe497b84c287e70", size = 14889039, upload-time = "2026-05-15T20:25:03.179Z" }, + { url = "https://files.pythonhosted.org/packages/8e/6f/79f195abf922ecc43e7d0eb6cc969462a71b524a35bcd1fa26b4a1d7406a/numpy-2.4.5-pp311-pypy311_pp73-macosx_14_0_arm64.whl", hash = "sha256:b35bee5ef99e8d227a07829bee2e864fcb65f7c157646fcd8ec8b4b45dd8b88f", size = 5394106, upload-time = "2026-05-15T20:25:05.659Z" }, + { url = "https://files.pythonhosted.org/packages/58/6f/79cd6247205802bcbd10b40ea087e20ded526e10e9be224d34de832b216e/numpy-2.4.5-pp311-pypy311_pp73-macosx_14_0_x86_64.whl", hash = "sha256:02981d0fc9f9ce147643d552966d47f329a02f7ecb3b113e84207242f20dfa83", size = 6708718, upload-time = "2026-05-15T20:25:08.071Z" }, + { url = "https://files.pythonhosted.org/packages/d7/22/5f378a9d4633c98f28c4709d4144b1a4630c5c09e109d2e781e2d26c8fe1/numpy-2.4.5-pp311-pypy311_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0e63caf31a1df06338ae63d999f7a33a675ced62eea9c9b02db4b1c1f45cff38", size = 15798292, upload-time = "2026-05-15T20:25:10.689Z" }, + { url = "https://files.pythonhosted.org/packages/63/1c/cec582febef798c99888892d92dc1d28dfe29cb427c41f44d13d0dec208f/numpy-2.4.5-pp311-pypy311_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d8fc52b85a7b45e474be53eddf08e006d22e381a4e41bcde8e4aa08da0e7d198", size = 16747406, upload-time = "2026-05-15T20:25:13.879Z" }, + { url = "https://files.pythonhosted.org/packages/b1/dc/d358a16a6fec86cf736b8fbe67386044b3fa2aded1a80cff90e836799301/numpy-2.4.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:40c71d50a4da1a7c317af419461052d3911a5770bfc5fd55baf52cc45e7a2c20", size = 12504085, upload-time = "2026-05-15T20:25:16.667Z" }, ] [[package]] @@ -2031,7 +2031,7 @@ source = { editable = "." } dependencies = [ { name = "ninja" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "numpy", version = "2.4.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "nvidia-ml-py" }, { name = "omegaconf" }, { name = "packaging" }, @@ -2074,7 +2074,7 @@ all = [ { name = "onnxscript" }, { name = "onnxslim" }, { name = "pandas", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "pandas", version = "3.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "pandas", version = "3.0.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "peft" }, { name = "polygraphy" }, { name = "sentencepiece" }, @@ -2113,7 +2113,7 @@ dev = [ { name = "onnxscript" }, { name = "onnxslim" }, { name = "pandas", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "pandas", version = "3.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "pandas", version = "3.0.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "peft" }, { name = "polygraphy" }, { name = "pre-commit" }, @@ -2206,7 +2206,7 @@ puzzletron = [ { name = "immutabledict" }, { name = "lru-dict" }, { name = "pandas", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "pandas", version = "3.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "pandas", version = "3.0.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "typeguard" }, ] @@ -2307,7 +2307,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "ml-dtypes" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "numpy", version = "2.4.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "protobuf" }, { name = "typing-extensions" }, ] @@ -2349,7 +2349,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "ml-dtypes" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "numpy", version = "2.4.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "onnx" }, ] wheels = [ @@ -2363,7 +2363,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "ml-dtypes" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "numpy", version = "2.4.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "onnx" }, { name = "sympy" }, { name = "typing-extensions" }, @@ -2379,7 +2379,7 @@ version = "1.16.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "numpy", version = "2.4.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "onnx" }, { name = "packaging" }, { name = "protobuf" }, @@ -2437,7 +2437,7 @@ resolution-markers = [ ] dependencies = [ { name = "flatbuffers", marker = "(python_full_version >= '3.11' and platform_machine == 'aarch64') or (python_full_version >= '3.11' and sys_platform == 'darwin')" }, - { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and platform_machine == 'aarch64') or (python_full_version >= '3.11' and sys_platform == 'darwin')" }, + { name = "numpy", version = "2.4.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and platform_machine == 'aarch64') or (python_full_version >= '3.11' and sys_platform == 'darwin')" }, { name = "packaging", marker = "(python_full_version >= '3.11' and platform_machine == 'aarch64') or (python_full_version >= '3.11' and sys_platform == 'darwin')" }, { name = "protobuf", marker = "(python_full_version >= '3.11' and platform_machine == 'aarch64') or (python_full_version >= '3.11' and sys_platform == 'darwin')" }, { name = "sympy", marker = "(python_full_version >= '3.11' and platform_machine == 'aarch64') or (python_full_version >= '3.11' and sys_platform == 'darwin')" }, @@ -2486,7 +2486,7 @@ dependencies = [ { name = "coloredlogs", marker = "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform != 'darwin') or sys_platform == 'win32'" }, { name = "flatbuffers", marker = "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform != 'darwin') or sys_platform == 'win32'" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform != 'darwin') or (python_full_version < '3.11' and sys_platform == 'win32')" }, - { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' and sys_platform == 'win32'" }, + { name = "numpy", version = "2.4.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' and sys_platform == 'win32'" }, { name = "packaging", marker = "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform != 'darwin') or sys_platform == 'win32'" }, { name = "protobuf", marker = "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform != 'darwin') or sys_platform == 'win32'" }, { name = "sympy", marker = "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform != 'darwin') or sys_platform == 'win32'" }, @@ -2519,7 +2519,7 @@ resolution-markers = [ ] dependencies = [ { name = "flatbuffers", marker = "python_full_version >= '3.11' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, - { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, + { name = "numpy", version = "2.4.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, { name = "packaging", marker = "python_full_version >= '3.11' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, { name = "protobuf", marker = "python_full_version >= '3.11' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, { name = "sympy", marker = "python_full_version >= '3.11' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, @@ -2540,7 +2540,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "ml-dtypes" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "numpy", version = "2.4.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "onnx" }, { name = "onnx-ir" }, { name = "packaging" }, @@ -2648,7 +2648,7 @@ wheels = [ [[package]] name = "pandas" -version = "3.0.2" +version = "3.0.3" source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version >= '3.14' and platform_machine == 'aarch64' and sys_platform == 'win32'", @@ -2681,59 +2681,59 @@ resolution-markers = [ "python_full_version == '3.11.*' and platform_machine == 's390x' and sys_platform == 'win32'", ] dependencies = [ - { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "numpy", version = "2.4.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "python-dateutil", marker = "python_full_version >= '3.11'" }, { name = "tzdata", marker = "(python_full_version >= '3.11' and sys_platform == 'emscripten') or (python_full_version >= '3.11' and sys_platform == 'win32')" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/da/99/b342345300f13440fe9fe385c3c481e2d9a595ee3bab4d3219247ac94e9a/pandas-3.0.2.tar.gz", hash = "sha256:f4753e73e34c8d83221ba58f232433fca2748be8b18dbca02d242ed153945043", size = 4645855, upload-time = "2026-03-31T06:48:30.816Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/97/35/6411db530c618e0e0005187e35aa02ce60ae4c4c4d206964a2f978217c27/pandas-3.0.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a727a73cbdba2f7458dc82449e2315899d5140b449015d822f515749a46cbbe0", size = 10326926, upload-time = "2026-03-31T06:46:08.29Z" }, - { url = "https://files.pythonhosted.org/packages/c4/d3/b7da1d5d7dbdc5ef52ed7debd2b484313b832982266905315dad5a0bf0b1/pandas-3.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:dbbd4aa20ca51e63b53bbde6a0fa4254b1aaabb74d2f542df7a7959feb1d760c", size = 9926987, upload-time = "2026-03-31T06:46:11.724Z" }, - { url = "https://files.pythonhosted.org/packages/52/77/9b1c2d6070b5dbe239a7bc889e21bfa58720793fb902d1e070695d87c6d0/pandas-3.0.2-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:339dda302bd8369dedeae979cb750e484d549b563c3f54f3922cb8ff4978c5eb", size = 10757067, upload-time = "2026-03-31T06:46:14.903Z" }, - { url = "https://files.pythonhosted.org/packages/20/17/ec40d981705654853726e7ac9aea9ddbb4a5d9cf54d8472222f4f3de06c2/pandas-3.0.2-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:61c2fd96d72b983a9891b2598f286befd4ad262161a609c92dc1652544b46b76", size = 11258787, upload-time = "2026-03-31T06:46:17.683Z" }, - { url = "https://files.pythonhosted.org/packages/90/e3/3f1126d43d3702ca8773871a81c9f15122a1f412342cc56284ffda5b1f70/pandas-3.0.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c934008c733b8bbea273ea308b73b3156f0181e5b72960790b09c18a2794fe1e", size = 11771616, upload-time = "2026-03-31T06:46:20.532Z" }, - { url = "https://files.pythonhosted.org/packages/2e/cf/0f4e268e1f5062e44a6bda9f925806721cd4c95c2b808a4c82ebe914f96b/pandas-3.0.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:60a80bb4feacbef5e1447a3f82c33209c8b7e07f28d805cfd1fb951e5cb443aa", size = 12337623, upload-time = "2026-03-31T06:46:23.754Z" }, - { url = "https://files.pythonhosted.org/packages/44/a0/97a6339859d4acb2536efb24feb6708e82f7d33b2ed7e036f2983fcced82/pandas-3.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:ed72cb3f45190874eb579c64fa92d9df74e98fd63e2be7f62bce5ace0ade61df", size = 9897372, upload-time = "2026-03-31T06:46:26.703Z" }, - { url = "https://files.pythonhosted.org/packages/8f/eb/781516b808a99ddf288143cec46b342b3016c3414d137da1fdc3290d8860/pandas-3.0.2-cp311-cp311-win_arm64.whl", hash = "sha256:f12b1a9e332c01e09510586f8ca9b108fd631fd656af82e452d7315ef6df5f9f", size = 9154922, upload-time = "2026-03-31T06:46:30.284Z" }, - { url = "https://files.pythonhosted.org/packages/f3/b0/c20bd4d6d3f736e6bd6b55794e9cd0a617b858eaad27c8f410ea05d953b7/pandas-3.0.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:232a70ebb568c0c4d2db4584f338c1577d81e3af63292208d615907b698a0f18", size = 10347921, upload-time = "2026-03-31T06:46:33.36Z" }, - { url = "https://files.pythonhosted.org/packages/35/d0/4831af68ce30cc2d03c697bea8450e3225a835ef497d0d70f31b8cdde965/pandas-3.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:970762605cff1ca0d3f71ed4f3a769ea8f85fc8e6348f6e110b8fea7e6eb5a14", size = 9888127, upload-time = "2026-03-31T06:46:36.253Z" }, - { url = "https://files.pythonhosted.org/packages/61/a9/16ea9346e1fc4a96e2896242d9bc674764fb9049b0044c0132502f7a771e/pandas-3.0.2-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:aff4e6f4d722e0652707d7bcb190c445fe58428500c6d16005b02401764b1b3d", size = 10399577, upload-time = "2026-03-31T06:46:39.224Z" }, - { url = "https://files.pythonhosted.org/packages/c4/a8/3a61a721472959ab0ce865ef05d10b0d6bfe27ce8801c99f33d4fa996e65/pandas-3.0.2-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ef8b27695c3d3dc78403c9a7d5e59a62d5464a7e1123b4e0042763f7104dc74f", size = 10880030, upload-time = "2026-03-31T06:46:42.412Z" }, - { url = "https://files.pythonhosted.org/packages/da/65/7225c0ea4d6ce9cb2160a7fb7f39804871049f016e74782e5dade4d14109/pandas-3.0.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f8d68083e49e16b84734eb1a4dcae4259a75c90fb6e2251ab9a00b61120c06ab", size = 11409468, upload-time = "2026-03-31T06:46:45.2Z" }, - { url = "https://files.pythonhosted.org/packages/fa/5b/46e7c76032639f2132359b5cf4c785dd8cf9aea5ea64699eac752f02b9db/pandas-3.0.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:32cc41f310ebd4a296d93515fcac312216adfedb1894e879303987b8f1e2b97d", size = 11936381, upload-time = "2026-03-31T06:46:48.293Z" }, - { url = "https://files.pythonhosted.org/packages/7b/8b/721a9cff6fa6a91b162eb51019c6243b82b3226c71bb6c8ef4a9bd65cbc6/pandas-3.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:a4785e1d6547d8427c5208b748ae2efb64659a21bd82bf440d4262d02bfa02a4", size = 9744993, upload-time = "2026-03-31T06:46:51.488Z" }, - { url = "https://files.pythonhosted.org/packages/d5/18/7f0bd34ae27b28159aa80f2a6799f47fda34f7fb938a76e20c7b7fe3b200/pandas-3.0.2-cp312-cp312-win_arm64.whl", hash = "sha256:08504503f7101300107ecdc8df73658e4347586db5cfdadabc1592e9d7e7a0fd", size = 9056118, upload-time = "2026-03-31T06:46:54.548Z" }, - { url = "https://files.pythonhosted.org/packages/bf/ca/3e639a1ea6fcd0617ca4e8ca45f62a74de33a56ae6cd552735470b22c8d3/pandas-3.0.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b5918ba197c951dec132b0c5929a00c0bf05d5942f590d3c10a807f6e15a57d3", size = 10321105, upload-time = "2026-03-31T06:46:57.327Z" }, - { url = "https://files.pythonhosted.org/packages/0b/77/dbc82ff2fb0e63c6564356682bf201edff0ba16c98630d21a1fb312a8182/pandas-3.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:d606a041c89c0a474a4702d532ab7e73a14fe35c8d427b972a625c8e46373668", size = 9864088, upload-time = "2026-03-31T06:46:59.935Z" }, - { url = "https://files.pythonhosted.org/packages/5c/2b/341f1b04bbca2e17e13cd3f08c215b70ef2c60c5356ef1e8c6857449edc7/pandas-3.0.2-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:710246ba0616e86891b58ab95f2495143bb2bc83ab6b06747c74216f583a6ac9", size = 10369066, upload-time = "2026-03-31T06:47:02.792Z" }, - { url = "https://files.pythonhosted.org/packages/12/c5/cbb1ffefb20a93d3f0e1fdcda699fb84976210d411b008f97f48bf6ce27e/pandas-3.0.2-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5d3cfe227c725b1f3dff4278b43d8c784656a42a9325b63af6b1492a8232209e", size = 10876780, upload-time = "2026-03-31T06:47:06.205Z" }, - { url = "https://files.pythonhosted.org/packages/98/fe/2249ae5e0a69bd0ddf17353d0a5d26611d70970111f5b3600cdc8be883e7/pandas-3.0.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:c3b723df9087a9a9a840e263ebd9f88b64a12075d1bf2ea401a5a42f254f084d", size = 11375181, upload-time = "2026-03-31T06:47:09.383Z" }, - { url = "https://files.pythonhosted.org/packages/de/64/77a38b09e70b6464883b8d7584ab543e748e42c1b5d337a2ee088e0df741/pandas-3.0.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a3096110bf9eac0070b7208465f2740e2d8a670d5cb6530b5bb884eca495fd39", size = 11928899, upload-time = "2026-03-31T06:47:12.686Z" }, - { url = "https://files.pythonhosted.org/packages/5e/52/42855bf626868413f761addd574acc6195880ae247a5346477a4361c3acb/pandas-3.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:07a10f5c36512eead51bc578eb3354ad17578b22c013d89a796ab5eee90cd991", size = 9746574, upload-time = "2026-03-31T06:47:15.64Z" }, - { url = "https://files.pythonhosted.org/packages/88/39/21304ae06a25e8bf9fc820d69b29b2c495b2ae580d1e143146c309941760/pandas-3.0.2-cp313-cp313-win_arm64.whl", hash = "sha256:5fdbfa05931071aba28b408e59226186b01eb5e92bea2ab78b65863ca3228d84", size = 9047156, upload-time = "2026-03-31T06:47:18.595Z" }, - { url = "https://files.pythonhosted.org/packages/72/20/7defa8b27d4f330a903bb68eea33be07d839c5ea6bdda54174efcec0e1d2/pandas-3.0.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:dbc20dea3b9e27d0e66d74c42b2d0c1bed9c2ffe92adea33633e3bedeb5ac235", size = 10756238, upload-time = "2026-03-31T06:47:22.012Z" }, - { url = "https://files.pythonhosted.org/packages/e9/95/49433c14862c636afc0e9b2db83ff16b3ad92959364e52b2955e44c8e94c/pandas-3.0.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b75c347eff42497452116ce05ef461822d97ce5b9ff8df6edacb8076092c855d", size = 10408520, upload-time = "2026-03-31T06:47:25.197Z" }, - { url = "https://files.pythonhosted.org/packages/3b/f8/462ad2b5881d6b8ec8e5f7ed2ea1893faa02290d13870a1600fe72ad8efc/pandas-3.0.2-cp313-cp313t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d1478075142e83a5571782ad007fb201ed074bdeac7ebcc8890c71442e96adf7", size = 10324154, upload-time = "2026-03-31T06:47:28.097Z" }, - { url = "https://files.pythonhosted.org/packages/0a/65/d1e69b649cbcddda23ad6e4c40ef935340f6f652a006e5cbc3555ac8adb3/pandas-3.0.2-cp313-cp313t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5880314e69e763d4c8b27937090de570f1fb8d027059a7ada3f7f8e98bdcb677", size = 10714449, upload-time = "2026-03-31T06:47:30.85Z" }, - { url = "https://files.pythonhosted.org/packages/47/a4/85b59bc65b8190ea3689882db6cdf32a5003c0ccd5a586c30fdcc3ffc4fc/pandas-3.0.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:b5329e26898896f06035241a626d7c335daa479b9bbc82be7c2742d048e41172", size = 11338475, upload-time = "2026-03-31T06:47:34.026Z" }, - { url = "https://files.pythonhosted.org/packages/1e/c4/bc6966c6e38e5d9478b935272d124d80a589511ed1612a5d21d36f664c68/pandas-3.0.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:81526c4afd31971f8b62671442a4b2b51e0aa9acc3819c9f0f12a28b6fcf85f1", size = 11786568, upload-time = "2026-03-31T06:47:36.941Z" }, - { url = "https://files.pythonhosted.org/packages/e8/74/09298ca9740beed1d3504e073d67e128aa07e5ca5ca2824b0c674c0b8676/pandas-3.0.2-cp313-cp313t-win_amd64.whl", hash = "sha256:7cadd7e9a44ec13b621aec60f9150e744cfc7a3dd32924a7e2f45edff31823b0", size = 10488652, upload-time = "2026-03-31T06:47:40.612Z" }, - { url = "https://files.pythonhosted.org/packages/bb/40/c6ea527147c73b24fc15c891c3fcffe9c019793119c5742b8784a062c7db/pandas-3.0.2-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:db0dbfd2a6cdf3770aa60464d50333d8f3d9165b2f2671bcc299b72de5a6677b", size = 10326084, upload-time = "2026-03-31T06:47:43.834Z" }, - { url = "https://files.pythonhosted.org/packages/95/25/bdb9326c3b5455f8d4d3549fce7abcf967259de146fe2cf7a82368141948/pandas-3.0.2-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:0555c5882688a39317179ab4a0ed41d3ebc8812ab14c69364bbee8fb7a3f6288", size = 9914146, upload-time = "2026-03-31T06:47:46.67Z" }, - { url = "https://files.pythonhosted.org/packages/8d/77/3a227ff3337aa376c60d288e1d61c5d097131d0ac71f954d90a8f369e422/pandas-3.0.2-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:01f31a546acd5574ef77fe199bc90b55527c225c20ccda6601cf6b0fd5ed597c", size = 10444081, upload-time = "2026-03-31T06:47:49.681Z" }, - { url = "https://files.pythonhosted.org/packages/15/88/3cdd54fa279341afa10acf8d2b503556b1375245dccc9315659f795dd2e9/pandas-3.0.2-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:deeca1b5a931fdf0c2212c8a659ade6d3b1edc21f0914ce71ef24456ca7a6535", size = 10897535, upload-time = "2026-03-31T06:47:53.033Z" }, - { url = "https://files.pythonhosted.org/packages/06/9d/98cc7a7624f7932e40f434299260e2917b090a579d75937cb8a57b9d2de3/pandas-3.0.2-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:0f48afd9bb13300ffb5a3316973324c787054ba6665cda0da3fbd67f451995db", size = 11446992, upload-time = "2026-03-31T06:47:56.193Z" }, - { url = "https://files.pythonhosted.org/packages/9a/cd/19ff605cc3760e80602e6826ddef2824d8e7050ed80f2e11c4b079741dc3/pandas-3.0.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:6c4d8458b97a35717b62469a4ea0e85abd5ed8687277f5ccfc67f8a5126f8c53", size = 11968257, upload-time = "2026-03-31T06:47:59.137Z" }, - { url = "https://files.pythonhosted.org/packages/db/60/aba6a38de456e7341285102bede27514795c1eaa353bc0e7638b6b785356/pandas-3.0.2-cp314-cp314-win_amd64.whl", hash = "sha256:b35d14bb5d8285d9494fe93815a9e9307c0876e10f1e8e89ac5b88f728ec8dcf", size = 9865893, upload-time = "2026-03-31T06:48:02.038Z" }, - { url = "https://files.pythonhosted.org/packages/08/71/e5ec979dd2e8a093dacb8864598c0ff59a0cee0bbcdc0bfec16a51684d4f/pandas-3.0.2-cp314-cp314-win_arm64.whl", hash = "sha256:63d141b56ef686f7f0d714cfb8de4e320475b86bf4b620aa0b7da89af8cbdbbb", size = 9188644, upload-time = "2026-03-31T06:48:05.045Z" }, - { url = "https://files.pythonhosted.org/packages/f1/6c/7b45d85db19cae1eb524f2418ceaa9d85965dcf7b764ed151386b7c540f0/pandas-3.0.2-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:140f0cffb1fa2524e874dde5b477d9defe10780d8e9e220d259b2c0874c89d9d", size = 10776246, upload-time = "2026-03-31T06:48:07.789Z" }, - { url = "https://files.pythonhosted.org/packages/a8/3e/7b00648b086c106e81766f25322b48aa8dfa95b55e621dbdf2fdd413a117/pandas-3.0.2-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:ae37e833ff4fed0ba352f6bdd8b73ba3ab3256a85e54edfd1ab51ae40cca0af8", size = 10424801, upload-time = "2026-03-31T06:48:10.897Z" }, - { url = "https://files.pythonhosted.org/packages/da/6e/558dd09a71b53b4008e7fc8a98ec6d447e9bfb63cdaeea10e5eb9b2dabe8/pandas-3.0.2-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4d888a5c678a419a5bb41a2a93818e8ed9fd3172246555c0b37b7cc27027effd", size = 10345643, upload-time = "2026-03-31T06:48:13.7Z" }, - { url = "https://files.pythonhosted.org/packages/be/e3/921c93b4d9a280409451dc8d07b062b503bbec0531d2627e73a756e99a82/pandas-3.0.2-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b444dc64c079e84df91baa8bf613d58405645461cabca929d9178f2cd392398d", size = 10743641, upload-time = "2026-03-31T06:48:16.659Z" }, - { url = "https://files.pythonhosted.org/packages/56/ca/fd17286f24fa3b4d067965d8d5d7e14fe557dd4f979a0b068ac0deaf8228/pandas-3.0.2-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:4544c7a54920de8eeacaa1466a6b7268ecfbc9bc64ab4dbb89c6bbe94d5e0660", size = 11361993, upload-time = "2026-03-31T06:48:19.475Z" }, - { url = "https://files.pythonhosted.org/packages/e4/a5/2f6ed612056819de445a433ca1f2821ac3dab7f150d569a59e9cc105de1d/pandas-3.0.2-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:734be7551687c00fbd760dc0522ed974f82ad230d4a10f54bf51b80d44a08702", size = 11815274, upload-time = "2026-03-31T06:48:22.695Z" }, - { url = "https://files.pythonhosted.org/packages/00/2f/b622683e99ec3ce00b0854bac9e80868592c5b051733f2cf3a868e5fea26/pandas-3.0.2-cp314-cp314t-win_amd64.whl", hash = "sha256:57a07209bebcbcf768d2d13c9b78b852f9a15978dac41b9e6421a81ad4cdd276", size = 10888530, upload-time = "2026-03-31T06:48:25.806Z" }, - { url = "https://files.pythonhosted.org/packages/cb/2b/f8434233fab2bd66a02ec014febe4e5adced20e2693e0e90a07d118ed30e/pandas-3.0.2-cp314-cp314t-win_arm64.whl", hash = "sha256:5371b72c2d4d415d08765f32d689217a43227484e81b2305b52076e328f6f482", size = 9455341, upload-time = "2026-03-31T06:48:28.418Z" }, +sdist = { url = "https://files.pythonhosted.org/packages/f8/87/4341c6252d1c47b08768c3d25ac487362bf403f0313ddae4a2a26c9b1b4c/pandas-3.0.3.tar.gz", hash = "sha256:696a4a00a2a2a35d4e5deb3fc946641b96c944f02230e4f76137fe35d806c4fc", size = 4651414, upload-time = "2026-05-11T18:54:29.21Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/42/16/b5c76b838fd9bf6ce84d3a53346b8874ec05c5f0040d75ef2c320100cd2a/pandas-3.0.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:455f6f8139d4282188f526868dbc3c828470e88a3d9d59a891bd46a455f21b98", size = 10338495, upload-time = "2026-05-11T18:52:11.558Z" }, + { url = "https://files.pythonhosted.org/packages/5a/b0/a4ffc4ae74d2d822200dcc46898987d8eb6032d1e2b219cae39da6f5cbcc/pandas-3.0.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4e15135e2ee5df1063313e2425ceef8ac0f4ae775893815b0923651b806a5639", size = 9938250, upload-time = "2026-05-11T18:52:17.005Z" }, + { url = "https://files.pythonhosted.org/packages/2e/b2/3323601a52caee42c019e370090ca4544b241437240ca04f786cce82b0cf/pandas-3.0.3-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:05f1f1752b8533ea03f7f39a9c15b1a058d067bb48f4748948e7a8691e0510f2", size = 10770558, upload-time = "2026-05-11T18:52:19.865Z" }, + { url = "https://files.pythonhosted.org/packages/32/f1/bbecd2f867b97abebe0f9b53d750f862251b40337e061b36676ded3d920f/pandas-3.0.3-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8a1e45c80cceb3b4a21bc5939d52e8cbd8d9b7305309219d59e9754d9ce09e27", size = 11274611, upload-time = "2026-05-11T18:52:22.622Z" }, + { url = "https://files.pythonhosted.org/packages/7f/4f/eafabf2d5fae5adf143b4d18d3706c5efdc368a7c4eb1ee8a3eddabbd0f6/pandas-3.0.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:14da8316da4d0c5a77618425996bfb1248ca87fc2c1486e6fde4652bd18b5824", size = 11784670, upload-time = "2026-05-11T18:52:25.4Z" }, + { url = "https://files.pythonhosted.org/packages/49/44/1eb20389301b57b19cc099a1c2f662501f72f08a65f912d05822613c1532/pandas-3.0.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a55066a0505dae0ba2b50a46637db34b46f9094c65c5d4800794ef6335010938", size = 12353708, upload-time = "2026-05-11T18:52:28.139Z" }, + { url = "https://files.pythonhosted.org/packages/eb/62/c321f13b5ba1819fc8dca456c7fce578da2dcfecff1abbf0eaddf8406c0f/pandas-3.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:6674ab18ad8c57802867264b00e15e7bb904700cdd9046e3b2fa1fce237439ea", size = 9907609, upload-time = "2026-05-11T18:52:30.982Z" }, + { url = "https://files.pythonhosted.org/packages/53/85/1b7f563ebc6357c27233a02a96b589bcce1fa9c6eb89fb4f0e56421d277e/pandas-3.0.3-cp311-cp311-win_arm64.whl", hash = "sha256:5cc09a68b3120e0f54870dede8287a7bb1fa463907e4fcec1ea77cab6179bf7a", size = 9165596, upload-time = "2026-05-11T18:52:33.334Z" }, + { url = "https://files.pythonhosted.org/packages/24/f1/392f8c5bfc16f66a0d2d41561c01627c228fe7ed2a0d056ef11315042570/pandas-3.0.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:fed2ff7fd9779120e388e285fc029bd5cf9490cdd2e4166a9ee22c0e49a9ab09", size = 10357846, upload-time = "2026-05-11T18:52:36.143Z" }, + { url = "https://files.pythonhosted.org/packages/cf/3d/b16412745651e855f357e5e66930248688378853a6e2698a214e331fba1f/pandas-3.0.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b168fc218fd80a6cbdbdbc1a97ddc7889ed057d7eb45f50d866ceab5f39904c4", size = 9899550, upload-time = "2026-05-11T18:52:38.976Z" }, + { url = "https://files.pythonhosted.org/packages/31/a8/fa2535168fffcedf67f4f6de28d2dd903a747ca7c8ea6989451aaeb3a92f/pandas-3.0.3-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0383c72c75cdcca61a9e116e611143902dbfd08bff356829c2f6d1cf40a9ca8c", size = 10412965, upload-time = "2026-05-11T18:52:41.915Z" }, + { url = "https://files.pythonhosted.org/packages/65/b6/09b01cdbc15224e2850365192d17b7bdebb8bdbd8780ed221fcdf0d9a515/pandas-3.0.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6dc0b3fd2169c9157deed50b4d519553a3655c8c6a96027136d654592be973a9", size = 10894600, upload-time = "2026-05-11T18:52:45.02Z" }, + { url = "https://files.pythonhosted.org/packages/c9/a4/2eb28f2fccb4ced4a2c79ab2a5dee9ade1ebf44922ebad6fea158c9f95d4/pandas-3.0.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:7e65d5407dc0b394f509699650e4a2ec01c0514f21850f453fa60f3be79a5dbf", size = 11422824, upload-time = "2026-05-11T18:52:48.058Z" }, + { url = "https://files.pythonhosted.org/packages/f8/45/830bb57f533a4604b355e07edcb8ea18cf88b5f94e5fca92f27052d7c597/pandas-3.0.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:f8894dc474d648fe7b6ff0ca9b0bd73950d19952bc1a6534540762c5d79d305c", size = 11950889, upload-time = "2026-05-11T18:52:50.905Z" }, + { url = "https://files.pythonhosted.org/packages/b9/c5/fc1b368f303087d20e8c9bf3d6ceb186263cfac0ade735cd938538bea839/pandas-3.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:c7be265b62cef88e253a941e4698604973736dcfe242fdb5198f0f7bc473cdcc", size = 9755463, upload-time = "2026-05-11T18:52:53.386Z" }, + { url = "https://files.pythonhosted.org/packages/86/bd/fda8f9705b1b09c6ebe14bfc0fa0e4ec8584d54ea673628f157ff55131af/pandas-3.0.3-cp312-cp312-win_arm64.whl", hash = "sha256:557409bc4178e70ee8d9ddb494798e51ebf6ea59330f6be22c51bab2a7db6c49", size = 9066158, upload-time = "2026-05-11T18:52:56.038Z" }, + { url = "https://files.pythonhosted.org/packages/c5/90/62d8302883c44308c477e222c3daf7c813a34c8e96985882fbd53d964352/pandas-3.0.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:67b3b64c11910cfa29f4e94a14d3bff9ee693b6fc76055e7cad549cee0aec5fa", size = 10331071, upload-time = "2026-05-11T18:52:58.838Z" }, + { url = "https://files.pythonhosted.org/packages/7f/ae/6a6493c783a101f165e4356953ba3c74d6f77f0042fa7d753da9dfbb640c/pandas-3.0.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:39436b377d56d2a2e52d0395bdbee171f01068e99af5250509aceeb929f765c7", size = 9875690, upload-time = "2026-05-11T18:53:01.431Z" }, + { url = "https://files.pythonhosted.org/packages/62/7c/5df8e9f56c69a2769fbe9382a5ef8f2658c007e376434e1e2cbb57ad895f/pandas-3.0.3-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d4be06d68f9ddcfc645b87534911da79a8fbffc7573c80e0edcf42a5020624d8", size = 10381634, upload-time = "2026-05-11T18:53:04.393Z" }, + { url = "https://files.pythonhosted.org/packages/99/68/1237369725aa617bb358263d535803e3053fdbc593513ec5ed9c9896b5b6/pandas-3.0.3-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a4eeb6830daf35a71cc09649bd823e2b542dac246cdee9614c6e4bd65028cd6a", size = 10891243, upload-time = "2026-05-11T18:53:07.643Z" }, + { url = "https://files.pythonhosted.org/packages/25/93/77d108e8af7222b4a503ebde0e30215b1c2e4f8e53a526431890f22d5586/pandas-3.0.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:1928e07221f82db493cd4af1e23c1bfca524a19a4699887975bff68f49a72bfb", size = 11388659, upload-time = "2026-05-11T18:53:10.634Z" }, + { url = "https://files.pythonhosted.org/packages/d0/bd/eff5b4399f332ac386c853f6cd2bd3fa2ca0061b9f36ecd9c4d7c4265649/pandas-3.0.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:51b1fe551acb77dac643c6fda86084d8d446c10fe64b06a9cc29c4cc8540e7f2", size = 11942880, upload-time = "2026-05-11T18:53:13.536Z" }, + { url = "https://files.pythonhosted.org/packages/2c/20/559ace4200982c3887d0b86bfd0d856a2143ef8ddab63cc07934951a964c/pandas-3.0.3-cp313-cp313-win_amd64.whl", hash = "sha256:a82d532a3351d435432cd913edbccaf8b8e01d4dd0e5ced5a8d2e8ecd94c7e44", size = 9757091, upload-time = "2026-05-11T18:53:16.306Z" }, + { url = "https://files.pythonhosted.org/packages/3a/66/69055a09fe200f29f922a3eeec4804611900b95f52d932ece3393c3c0c19/pandas-3.0.3-cp313-cp313-win_arm64.whl", hash = "sha256:275c14e0fce14a2ec20eee474aecd305478ea3c1e6f6a9d8fe219a165542717e", size = 9057282, upload-time = "2026-05-11T18:53:18.768Z" }, + { url = "https://files.pythonhosted.org/packages/57/0e/efe801b0e6811e8e650cd21b7f2608e30f08a7067e2bf6e8752b0d56ee3c/pandas-3.0.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:46997386d528eb40376ecd6b033cf4a8a1e5282580f68f43de875b78cba2199d", size = 10767016, upload-time = "2026-05-11T18:53:21.227Z" }, + { url = "https://files.pythonhosted.org/packages/ea/dc/eb55135a1d5f0f0519f28da1f609a206d2cad1f9c35c32d51e38dd7261ae/pandas-3.0.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:261e308dfb22448384b7580cf719d2f998fe2966c92893c3e77d14008af1f066", size = 10420210, upload-time = "2026-05-11T18:53:23.982Z" }, + { url = "https://files.pythonhosted.org/packages/c6/3e/b1d5d955ce33ffecb407465a60bc32769d74fcf68224b7ae67ae11d4dea4/pandas-3.0.3-cp313-cp313t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:dd1a5d1def6a46002e964510bdc67c368aa0951df5d1d9f8365336f5a1f490cd", size = 10336126, upload-time = "2026-05-11T18:53:26.731Z" }, + { url = "https://files.pythonhosted.org/packages/f5/76/a01261711ab60a22d71b862f0de20e4c504bf80457270ad8cb42110f6abc/pandas-3.0.3-cp313-cp313t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d72828c20c6d6e83e1e22a6a3b47b326b71664112fa9705dcbccfd7a39b62085", size = 10728051, upload-time = "2026-05-11T18:53:29.125Z" }, + { url = "https://files.pythonhosted.org/packages/e9/21/ea191195e587b18cf682e97f433f81b2d0fbe341380e80a3e0d6e4403c8e/pandas-3.0.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:d26cbe1fcfc12e8fd900e2454163e466b2d3af84f7c75481df7683ffc073d870", size = 11350796, upload-time = "2026-05-11T18:53:32.056Z" }, + { url = "https://files.pythonhosted.org/packages/64/69/f0eaaf54939f0e8c6768fd06be9af2cef9b36048b96dfb9e1b2c685a807e/pandas-3.0.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:3e91cec1879ada0624fc3dc9953c5cbd60208e59c0db28f540c5d6d47502422f", size = 11799741, upload-time = "2026-05-11T18:53:34.985Z" }, + { url = "https://files.pythonhosted.org/packages/45/a4/865e0e510cae5fc2194de4db28be638952de942571ba9125934fd9c01d47/pandas-3.0.3-cp313-cp313t-win_amd64.whl", hash = "sha256:08d789b41f87e0905880e293cedf6197ce71fe67cc081358b1e148a491b9bd13", size = 10499958, upload-time = "2026-05-11T18:53:37.857Z" }, + { url = "https://files.pythonhosted.org/packages/86/54/effdcc3c0ff7a08037889200e148ebe94c16c4f653be078c7b3675955df1/pandas-3.0.3-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:3650109c0f22879df8bd6179ab9ee3d7f1d1d4e7e0094a3f0032d9f51e2e64ac", size = 10336065, upload-time = "2026-05-11T18:53:41.099Z" }, + { url = "https://files.pythonhosted.org/packages/68/10/bf2d6738d72748b961a3751ab89522d58c54efc36a8e1a12161216cd45cf/pandas-3.0.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:bab900348131a7db1f69a7309ef141fd5680f1487094193bcbbb61791573bf8f", size = 9926101, upload-time = "2026-05-11T18:53:43.515Z" }, + { url = "https://files.pythonhosted.org/packages/ae/e9/e35cf11c8a136e757b956f5f0efdcaa50aecde85ea055f1898dfc68262f3/pandas-3.0.3-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ba7e08b9ac1d54569cd1e256e3668975ed624d6826f7b68df0342b012007bddb", size = 10457553, upload-time = "2026-05-11T18:53:46.394Z" }, + { url = "https://files.pythonhosted.org/packages/58/3b/1cdec6772bdbaf7b25dab360c59f03cadf05492dd724c6540af905389b07/pandas-3.0.3-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9d71c63ae4ebdbf70209742096f1fc46a83a0613c99d4b23766cced9ff8cd62a", size = 10914065, upload-time = "2026-05-11T18:53:49.134Z" }, + { url = "https://files.pythonhosted.org/packages/c4/c2/1ef644445fcd72e3627bceec77e3560636f87ddce4ed841afe76b83b5bf9/pandas-3.0.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:e3a2ec42c98ffa2565a67e08e218d06d72576d758d90facb7c00805194d8f360", size = 11459188, upload-time = "2026-05-11T18:53:52.527Z" }, + { url = "https://files.pythonhosted.org/packages/7e/49/4d8d4f42cbc9c4adc7a1870f269c02cbd6cd40d059622c06fb298addcbad/pandas-3.0.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:335f62418ed562cfc3c49e9e196375c28b729dcef8543abf4f9438e381bf3c76", size = 11982966, upload-time = "2026-05-11T18:53:55.043Z" }, + { url = "https://files.pythonhosted.org/packages/38/55/792619469bab9882d8bbd5865d45a72f6478762d04a9af4bf0d08c503e95/pandas-3.0.3-cp314-cp314-win_amd64.whl", hash = "sha256:3c20a521bbb85902f79f7270c80a59e1b5452d96d170c034f207181870f97ac5", size = 9876755, upload-time = "2026-05-11T18:53:58.067Z" }, + { url = "https://files.pythonhosted.org/packages/2a/af/33c469653b0ba03b50c3a98192d4c07f0c75c66b263ceb097fce0ee97d31/pandas-3.0.3-cp314-cp314-win_arm64.whl", hash = "sha256:a2d2dff8a04f3917b55ab3910c32990f8ddf7eceba114947838cefa976a68977", size = 9198658, upload-time = "2026-05-11T18:54:00.733Z" }, + { url = "https://files.pythonhosted.org/packages/a2/fa/b8c257bd76b8bd060c3a9151c1fca05e9b9c5e3af5d0f549c0356f6d143d/pandas-3.0.3-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:0d589105b3c14645af1738ff279b2995102d8f7a03b0a66dc8d95550eb513e04", size = 10787242, upload-time = "2026-05-11T18:54:03.564Z" }, + { url = "https://files.pythonhosted.org/packages/54/eb/f19206ffb0bf1919002969aa448b4702c6594845156a6f8050674855aac3/pandas-3.0.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:13fc1e853d9e04743d11ba75a985ccbc2a317fe07d8af61e445a6fd24dacd6a6", size = 10436369, upload-time = "2026-05-11T18:54:06.311Z" }, + { url = "https://files.pythonhosted.org/packages/fd/24/c7c39fb4fe22b71a0c2d78bf0c585c600092d85f94f086d2b3b2f6ca27e2/pandas-3.0.3-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:819959dab7bbd0049c15623fbac4e29a191b9528160a61fb1032242d8ced2d9c", size = 10358306, upload-time = "2026-05-11T18:54:09.085Z" }, + { url = "https://files.pythonhosted.org/packages/16/ec/dd2a9eb7fa1204df88c0864164e35b228ac581062ac612ba0a67fd812e4c/pandas-3.0.3-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:60ae316d3fd75d1858d450d0db0103ea2be3e7d4a95ec2f064f7e2ae63f7b028", size = 10758394, upload-time = "2026-05-11T18:54:11.956Z" }, + { url = "https://files.pythonhosted.org/packages/95/6e/00c61ea8e85b4f6d8d35e11852a1a4998fc7fafc91c6a602d1cc9c972d64/pandas-3.0.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:bd3a518890b400d32f9023722dc9a9a5c969f00b415419a3c06c043f09bb5d7d", size = 11375717, upload-time = "2026-05-11T18:54:14.539Z" }, + { url = "https://files.pythonhosted.org/packages/31/89/8fc1c268969fac43688d65fd92e67df24bd128d53cb4d2eee534cd307399/pandas-3.0.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:9c39be2d709d01fa972a0cabc522389fceca4f3969332ba25a7d6c5802cf976a", size = 11828897, upload-time = "2026-05-11T18:54:17.146Z" }, + { url = "https://files.pythonhosted.org/packages/56/3b/e7d20dea247a3e6dc0bd8a6953854afbedc03951def4e7371e05e7263e25/pandas-3.0.3-cp314-cp314t-win_amd64.whl", hash = "sha256:4db8c527972a821cf5286b40ccc57642a39bc62e62022b42f99f8a67fca8c3a1", size = 10900855, upload-time = "2026-05-11T18:54:19.72Z" }, + { url = "https://files.pythonhosted.org/packages/0f/54/68a0978d1ef8502b8492099beaa6e7a0c1b32e3b5d4f677f5810cb08711c/pandas-3.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:b2c95f8bfc1ee412bf482605d7bfd30c12d1d26bd59fdd91efeef1d4718decb1", size = 9466464, upload-time = "2026-05-11T18:54:22.754Z" }, ] [[package]] @@ -2753,7 +2753,7 @@ dependencies = [ { name = "accelerate" }, { name = "huggingface-hub" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "numpy", version = "2.4.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "packaging" }, { name = "psutil" }, { name = "pyyaml" }, @@ -3327,11 +3327,11 @@ wheels = [ [[package]] name = "pyreadline3" -version = "3.5.4" +version = "3.5.6" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/0f/49/4cea918a08f02817aabae639e3d0ac046fef9f9180518a3ad394e22da148/pyreadline3-3.5.4.tar.gz", hash = "sha256:8d57d53039a1c75adba8e50dd3d992b28143480816187ea5efbd5c78e6c885b7", size = 99839, upload-time = "2024-09-19T02:40:10.062Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b6/6d/f94028646d7bbe6d9d873c47ee7c246f2d29129d253f0d96cb6fcab70733/pyreadline3-3.5.6.tar.gz", hash = "sha256:61e53218b99656091ddb077df9e71f25850e72e030b6183b39c9b7e6e4f4a9bf", size = 100368, upload-time = "2026-05-14T17:55:04.471Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl", hash = "sha256:eaf8e6cc3c49bcccf145fc6067ba8643d1df34d604a1ec0eccbf7a18e6d3fae6", size = 83178, upload-time = "2024-09-19T02:40:08.598Z" }, + { url = "https://files.pythonhosted.org/packages/f7/5e/35c856e186b74678c24927847ad9895a51f1bc02a0c6126477a6c6040064/pyreadline3-3.5.6-py3-none-any.whl", hash = "sha256:8449b734232e42a5dcd74048e39b60db2839a4c38cf3ae2bf7707d58b5389c0d", size = 85243, upload-time = "2026-05-14T17:55:03.262Z" }, ] [[package]] @@ -3404,15 +3404,15 @@ wheels = [ [[package]] name = "python-discovery" -version = "1.3.0" +version = "1.3.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, { name = "platformdirs" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ae/e0/cc5a8653e9a24f6cf84768f05064aa8ed5a83dcefd5e2a043db14a1c5f44/python_discovery-1.3.0.tar.gz", hash = "sha256:d098f1e86be5d45fe4d14bf1029294aabbd332f4321179dec85e76cddce834b0", size = 63925, upload-time = "2026-05-05T14:38:39.769Z" } +sdist = { url = "https://files.pythonhosted.org/packages/48/60/e88788207d81e46362cfbef0d4aaf4c0f49efc3c12d4c3fa3f542c34ebec/python_discovery-1.3.1.tar.gz", hash = "sha256:62f6db28064c9613e7ca76cb3f00c38c839a07c31c00dfe7ed0986493d2150a6", size = 68011, upload-time = "2026-05-12T20:53:36.336Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/30/d4/24d543ab8b8158b7f5a97113c831205f5c900c92c8762b1e7f44b7ea0405/python_discovery-1.3.0-py3-none-any.whl", hash = "sha256:441d9ced3dfce36e113beb35ca302c71c7ef06f3c0f9c227a0b9bb3bd49b9e9f", size = 33124, upload-time = "2026-05-05T14:38:38.539Z" }, + { url = "https://files.pythonhosted.org/packages/b7/6f/a05a317a66fee0aad270011461f1a63a453ed12471249f172f7d2e2bc7b4/python_discovery-1.3.1-py3-none-any.whl", hash = "sha256:ed188687ebb3b82c01a17cd5ac62fc94d9f6487a7f1a0f9dfe89753fec91039c", size = 33185, upload-time = "2026-05-12T20:53:34.969Z" }, ] [[package]] @@ -3620,7 +3620,7 @@ wheels = [ [[package]] name = "requests" -version = "2.33.1" +version = "2.34.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "certifi" }, @@ -3628,9 +3628,9 @@ dependencies = [ { name = "idna" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/5f/a4/98b9c7c6428a668bf7e42ebb7c79d576a1c3c1e3ae2d47e674b468388871/requests-2.33.1.tar.gz", hash = "sha256:18817f8c57c6263968bc123d237e3b8b08ac046f5456bd1e307ee8f4250d3517", size = 134120, upload-time = "2026-03-30T16:09:15.531Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ac/c3/e2a2b89f2d3e2179abd6d00ebd70bff6273f37fb3e0cc209f48b39d00cbf/requests-2.34.2.tar.gz", hash = "sha256:f288924cae4e29463698d6d60bc6a4da69c89185ad1e0bcc4104f584e960b9ed", size = 142856, upload-time = "2026-05-14T19:25:27.735Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d7/8e/7540e8a2036f79a125c1d2ebadf69ed7901608859186c856fa0388ef4197/requests-2.33.1-py3-none-any.whl", hash = "sha256:4e6d1ef462f3626a1f0a0a9c42dd93c63bad33f9f1c1937509b8c5c8718ab56a", size = 64947, upload-time = "2026-03-30T16:09:13.83Z" }, + { url = "https://files.pythonhosted.org/packages/a0/f4/c67b0b3f1b9245e8d266f0f112c500d50e5b4e83cb6f3b71b6528104182a/requests-2.34.2-py3-none-any.whl", hash = "sha256:2a0d60c172f83ac6ab31e4554906c0f3b3588d37b5cb939b1c061f4907e278e0", size = 73075, upload-time = "2026-05-14T19:25:26.443Z" }, ] [[package]] @@ -3798,7 +3798,7 @@ resolution-markers = [ "python_full_version == '3.11.*' and platform_machine == 's390x' and sys_platform == 'win32'", ] dependencies = [ - { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "numpy", version = "2.4.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/7a/97/5a3609c4f8d58b039179648e62dd220f89864f56f7357f5d4f45c29eb2cc/scipy-1.17.1.tar.gz", hash = "sha256:95d8e012d8cb8816c226aef832200b1d45109ed4464303e997c5b13122b297c0", size = 30573822, upload-time = "2026-02-23T00:26:24.851Z" } wheels = [ @@ -4212,11 +4212,11 @@ wheels = [ [[package]] name = "stevedore" -version = "5.7.0" +version = "5.8.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a2/6d/90764092216fa560f6587f83bb70113a8ba510ba436c6476a2b47359057c/stevedore-5.7.0.tar.gz", hash = "sha256:31dd6fe6b3cbe921e21dcefabc9a5f1cf848cf538a1f27543721b8ca09948aa3", size = 516200, upload-time = "2026-02-20T13:27:06.765Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e9/88/35e4d27d9177d7df76d060e0a18f69c6c5794c96960c94042e20a12c8ba2/stevedore-5.8.0.tar.gz", hash = "sha256:b49867b32ca3016e94100e68dbf26e72aa7b8708d0a3f73c08aeb220370ac715", size = 514710, upload-time = "2026-05-18T09:15:27.731Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/69/06/36d260a695f383345ab5bbc3fd447249594ae2fa8dfd19c533d5ae23f46b/stevedore-5.7.0-py3-none-any.whl", hash = "sha256:fd25efbb32f1abb4c9e502f385f0018632baac11f9ee5d1b70f88cc5e22ad4ed", size = 54483, upload-time = "2026-02-20T13:27:05.561Z" }, + { url = "https://files.pythonhosted.org/packages/f5/ac/19f9941c74add59d17694930ec8105d5eddeee4ce56dd8632b765ca16d6c/stevedore-5.8.0-py3-none-any.whl", hash = "sha256:88eede9e66ca80e34085b9174e2327da2c61ac91f24f70e41c3ad76e4bb4872b", size = 54553, upload-time = "2026-05-18T09:15:25.82Z" }, ] [[package]] @@ -4242,63 +4242,63 @@ wheels = [ [[package]] name = "tiktoken" -version = "0.12.0" +version = "0.13.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "regex" }, { name = "requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/7d/ab/4d017d0f76ec3171d469d80fc03dfbb4e48a4bcaddaa831b31d526f05edc/tiktoken-0.12.0.tar.gz", hash = "sha256:b18ba7ee2b093863978fcb14f74b3707cdc8d4d4d3836853ce7ec60772139931", size = 37806, upload-time = "2025-10-06T20:22:45.419Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/89/b3/2cb7c17b6c4cf8ca983204255d3f1d95eda7213e247e6947a0ee2c747a2c/tiktoken-0.12.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:3de02f5a491cfd179aec916eddb70331814bd6bf764075d39e21d5862e533970", size = 1051991, upload-time = "2025-10-06T20:21:34.098Z" }, - { url = "https://files.pythonhosted.org/packages/27/0f/df139f1df5f6167194ee5ab24634582ba9a1b62c6b996472b0277ec80f66/tiktoken-0.12.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b6cfb6d9b7b54d20af21a912bfe63a2727d9cfa8fbda642fd8322c70340aad16", size = 995798, upload-time = "2025-10-06T20:21:35.579Z" }, - { url = "https://files.pythonhosted.org/packages/ef/5d/26a691f28ab220d5edc09b9b787399b130f24327ef824de15e5d85ef21aa/tiktoken-0.12.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:cde24cdb1b8a08368f709124f15b36ab5524aac5fa830cc3fdce9c03d4fb8030", size = 1129865, upload-time = "2025-10-06T20:21:36.675Z" }, - { url = "https://files.pythonhosted.org/packages/b2/94/443fab3d4e5ebecac895712abd3849b8da93b7b7dec61c7db5c9c7ebe40c/tiktoken-0.12.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:6de0da39f605992649b9cfa6f84071e3f9ef2cec458d08c5feb1b6f0ff62e134", size = 1152856, upload-time = "2025-10-06T20:21:37.873Z" }, - { url = "https://files.pythonhosted.org/packages/54/35/388f941251b2521c70dd4c5958e598ea6d2c88e28445d2fb8189eecc1dfc/tiktoken-0.12.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:6faa0534e0eefbcafaccb75927a4a380463a2eaa7e26000f0173b920e98b720a", size = 1195308, upload-time = "2025-10-06T20:21:39.577Z" }, - { url = "https://files.pythonhosted.org/packages/f8/00/c6681c7f833dd410576183715a530437a9873fa910265817081f65f9105f/tiktoken-0.12.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:82991e04fc860afb933efb63957affc7ad54f83e2216fe7d319007dab1ba5892", size = 1255697, upload-time = "2025-10-06T20:21:41.154Z" }, - { url = "https://files.pythonhosted.org/packages/5f/d2/82e795a6a9bafa034bf26a58e68fe9a89eeaaa610d51dbeb22106ba04f0a/tiktoken-0.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:6fb2995b487c2e31acf0a9e17647e3b242235a20832642bb7a9d1a181c0c1bb1", size = 879375, upload-time = "2025-10-06T20:21:43.201Z" }, - { url = "https://files.pythonhosted.org/packages/de/46/21ea696b21f1d6d1efec8639c204bdf20fde8bafb351e1355c72c5d7de52/tiktoken-0.12.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:6e227c7f96925003487c33b1b32265fad2fbcec2b7cf4817afb76d416f40f6bb", size = 1051565, upload-time = "2025-10-06T20:21:44.566Z" }, - { url = "https://files.pythonhosted.org/packages/c9/d9/35c5d2d9e22bb2a5f74ba48266fb56c63d76ae6f66e02feb628671c0283e/tiktoken-0.12.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c06cf0fcc24c2cb2adb5e185c7082a82cba29c17575e828518c2f11a01f445aa", size = 995284, upload-time = "2025-10-06T20:21:45.622Z" }, - { url = "https://files.pythonhosted.org/packages/01/84/961106c37b8e49b9fdcf33fe007bb3a8fdcc380c528b20cc7fbba80578b8/tiktoken-0.12.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:f18f249b041851954217e9fd8e5c00b024ab2315ffda5ed77665a05fa91f42dc", size = 1129201, upload-time = "2025-10-06T20:21:47.074Z" }, - { url = "https://files.pythonhosted.org/packages/6a/d0/3d9275198e067f8b65076a68894bb52fd253875f3644f0a321a720277b8a/tiktoken-0.12.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:47a5bc270b8c3db00bb46ece01ef34ad050e364b51d406b6f9730b64ac28eded", size = 1152444, upload-time = "2025-10-06T20:21:48.139Z" }, - { url = "https://files.pythonhosted.org/packages/78/db/a58e09687c1698a7c592e1038e01c206569b86a0377828d51635561f8ebf/tiktoken-0.12.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:508fa71810c0efdcd1b898fda574889ee62852989f7c1667414736bcb2b9a4bd", size = 1195080, upload-time = "2025-10-06T20:21:49.246Z" }, - { url = "https://files.pythonhosted.org/packages/9e/1b/a9e4d2bf91d515c0f74afc526fd773a812232dd6cda33ebea7f531202325/tiktoken-0.12.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a1af81a6c44f008cba48494089dd98cccb8b313f55e961a52f5b222d1e507967", size = 1255240, upload-time = "2025-10-06T20:21:50.274Z" }, - { url = "https://files.pythonhosted.org/packages/9d/15/963819345f1b1fb0809070a79e9dd96938d4ca41297367d471733e79c76c/tiktoken-0.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:3e68e3e593637b53e56f7237be560f7a394451cb8c11079755e80ae64b9e6def", size = 879422, upload-time = "2025-10-06T20:21:51.734Z" }, - { url = "https://files.pythonhosted.org/packages/a4/85/be65d39d6b647c79800fd9d29241d081d4eeb06271f383bb87200d74cf76/tiktoken-0.12.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b97f74aca0d78a1ff21b8cd9e9925714c15a9236d6ceacf5c7327c117e6e21e8", size = 1050728, upload-time = "2025-10-06T20:21:52.756Z" }, - { url = "https://files.pythonhosted.org/packages/4a/42/6573e9129bc55c9bf7300b3a35bef2c6b9117018acca0dc760ac2d93dffe/tiktoken-0.12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2b90f5ad190a4bb7c3eb30c5fa32e1e182ca1ca79f05e49b448438c3e225a49b", size = 994049, upload-time = "2025-10-06T20:21:53.782Z" }, - { url = "https://files.pythonhosted.org/packages/66/c5/ed88504d2f4a5fd6856990b230b56d85a777feab84e6129af0822f5d0f70/tiktoken-0.12.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:65b26c7a780e2139e73acc193e5c63ac754021f160df919add909c1492c0fb37", size = 1129008, upload-time = "2025-10-06T20:21:54.832Z" }, - { url = "https://files.pythonhosted.org/packages/f4/90/3dae6cc5436137ebd38944d396b5849e167896fc2073da643a49f372dc4f/tiktoken-0.12.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:edde1ec917dfd21c1f2f8046b86348b0f54a2c0547f68149d8600859598769ad", size = 1152665, upload-time = "2025-10-06T20:21:56.129Z" }, - { url = "https://files.pythonhosted.org/packages/a3/fe/26df24ce53ffde419a42f5f53d755b995c9318908288c17ec3f3448313a3/tiktoken-0.12.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:35a2f8ddd3824608b3d650a000c1ef71f730d0c56486845705a8248da00f9fe5", size = 1194230, upload-time = "2025-10-06T20:21:57.546Z" }, - { url = "https://files.pythonhosted.org/packages/20/cc/b064cae1a0e9fac84b0d2c46b89f4e57051a5f41324e385d10225a984c24/tiktoken-0.12.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:83d16643edb7fa2c99eff2ab7733508aae1eebb03d5dfc46f5565862810f24e3", size = 1254688, upload-time = "2025-10-06T20:21:58.619Z" }, - { url = "https://files.pythonhosted.org/packages/81/10/b8523105c590c5b8349f2587e2fdfe51a69544bd5a76295fc20f2374f470/tiktoken-0.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:ffc5288f34a8bc02e1ea7047b8d041104791d2ddbf42d1e5fa07822cbffe16bd", size = 878694, upload-time = "2025-10-06T20:21:59.876Z" }, - { url = "https://files.pythonhosted.org/packages/00/61/441588ee21e6b5cdf59d6870f86beb9789e532ee9718c251b391b70c68d6/tiktoken-0.12.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:775c2c55de2310cc1bc9a3ad8826761cbdc87770e586fd7b6da7d4589e13dab3", size = 1050802, upload-time = "2025-10-06T20:22:00.96Z" }, - { url = "https://files.pythonhosted.org/packages/1f/05/dcf94486d5c5c8d34496abe271ac76c5b785507c8eae71b3708f1ad9b45a/tiktoken-0.12.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a01b12f69052fbe4b080a2cfb867c4de12c704b56178edf1d1d7b273561db160", size = 993995, upload-time = "2025-10-06T20:22:02.788Z" }, - { url = "https://files.pythonhosted.org/packages/a0/70/5163fe5359b943f8db9946b62f19be2305de8c3d78a16f629d4165e2f40e/tiktoken-0.12.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:01d99484dc93b129cd0964f9d34eee953f2737301f18b3c7257bf368d7615baa", size = 1128948, upload-time = "2025-10-06T20:22:03.814Z" }, - { url = "https://files.pythonhosted.org/packages/0c/da/c028aa0babf77315e1cef357d4d768800c5f8a6de04d0eac0f377cb619fa/tiktoken-0.12.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:4a1a4fcd021f022bfc81904a911d3df0f6543b9e7627b51411da75ff2fe7a1be", size = 1151986, upload-time = "2025-10-06T20:22:05.173Z" }, - { url = "https://files.pythonhosted.org/packages/a0/5a/886b108b766aa53e295f7216b509be95eb7d60b166049ce2c58416b25f2a/tiktoken-0.12.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:981a81e39812d57031efdc9ec59fa32b2a5a5524d20d4776574c4b4bd2e9014a", size = 1194222, upload-time = "2025-10-06T20:22:06.265Z" }, - { url = "https://files.pythonhosted.org/packages/f4/f8/4db272048397636ac7a078d22773dd2795b1becee7bc4922fe6207288d57/tiktoken-0.12.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9baf52f84a3f42eef3ff4e754a0db79a13a27921b457ca9832cf944c6be4f8f3", size = 1255097, upload-time = "2025-10-06T20:22:07.403Z" }, - { url = "https://files.pythonhosted.org/packages/8e/32/45d02e2e0ea2be3a9ed22afc47d93741247e75018aac967b713b2941f8ea/tiktoken-0.12.0-cp313-cp313-win_amd64.whl", hash = "sha256:b8a0cd0c789a61f31bf44851defbd609e8dd1e2c8589c614cc1060940ef1f697", size = 879117, upload-time = "2025-10-06T20:22:08.418Z" }, - { url = "https://files.pythonhosted.org/packages/ce/76/994fc868f88e016e6d05b0da5ac24582a14c47893f4474c3e9744283f1d5/tiktoken-0.12.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:d5f89ea5680066b68bcb797ae85219c72916c922ef0fcdd3480c7d2315ffff16", size = 1050309, upload-time = "2025-10-06T20:22:10.939Z" }, - { url = "https://files.pythonhosted.org/packages/f6/b8/57ef1456504c43a849821920d582a738a461b76a047f352f18c0b26c6516/tiktoken-0.12.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b4e7ed1c6a7a8a60a3230965bdedba8cc58f68926b835e519341413370e0399a", size = 993712, upload-time = "2025-10-06T20:22:12.115Z" }, - { url = "https://files.pythonhosted.org/packages/72/90/13da56f664286ffbae9dbcfadcc625439142675845baa62715e49b87b68b/tiktoken-0.12.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:fc530a28591a2d74bce821d10b418b26a094bf33839e69042a6e86ddb7a7fb27", size = 1128725, upload-time = "2025-10-06T20:22:13.541Z" }, - { url = "https://files.pythonhosted.org/packages/05/df/4f80030d44682235bdaecd7346c90f67ae87ec8f3df4a3442cb53834f7e4/tiktoken-0.12.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:06a9f4f49884139013b138920a4c393aa6556b2f8f536345f11819389c703ebb", size = 1151875, upload-time = "2025-10-06T20:22:14.559Z" }, - { url = "https://files.pythonhosted.org/packages/22/1f/ae535223a8c4ef4c0c1192e3f9b82da660be9eb66b9279e95c99288e9dab/tiktoken-0.12.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:04f0e6a985d95913cabc96a741c5ffec525a2c72e9df086ff17ebe35985c800e", size = 1194451, upload-time = "2025-10-06T20:22:15.545Z" }, - { url = "https://files.pythonhosted.org/packages/78/a7/f8ead382fce0243cb625c4f266e66c27f65ae65ee9e77f59ea1653b6d730/tiktoken-0.12.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:0ee8f9ae00c41770b5f9b0bb1235474768884ae157de3beb5439ca0fd70f3e25", size = 1253794, upload-time = "2025-10-06T20:22:16.624Z" }, - { url = "https://files.pythonhosted.org/packages/93/e0/6cc82a562bc6365785a3ff0af27a2a092d57c47d7a81d9e2295d8c36f011/tiktoken-0.12.0-cp313-cp313t-win_amd64.whl", hash = "sha256:dc2dd125a62cb2b3d858484d6c614d136b5b848976794edfb63688d539b8b93f", size = 878777, upload-time = "2025-10-06T20:22:18.036Z" }, - { url = "https://files.pythonhosted.org/packages/72/05/3abc1db5d2c9aadc4d2c76fa5640134e475e58d9fbb82b5c535dc0de9b01/tiktoken-0.12.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:a90388128df3b3abeb2bfd1895b0681412a8d7dc644142519e6f0a97c2111646", size = 1050188, upload-time = "2025-10-06T20:22:19.563Z" }, - { url = "https://files.pythonhosted.org/packages/e3/7b/50c2f060412202d6c95f32b20755c7a6273543b125c0985d6fa9465105af/tiktoken-0.12.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:da900aa0ad52247d8794e307d6446bd3cdea8e192769b56276695d34d2c9aa88", size = 993978, upload-time = "2025-10-06T20:22:20.702Z" }, - { url = "https://files.pythonhosted.org/packages/14/27/bf795595a2b897e271771cd31cb847d479073497344c637966bdf2853da1/tiktoken-0.12.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:285ba9d73ea0d6171e7f9407039a290ca77efcdb026be7769dccc01d2c8d7fff", size = 1129271, upload-time = "2025-10-06T20:22:22.06Z" }, - { url = "https://files.pythonhosted.org/packages/f5/de/9341a6d7a8f1b448573bbf3425fa57669ac58258a667eb48a25dfe916d70/tiktoken-0.12.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:d186a5c60c6a0213f04a7a802264083dea1bbde92a2d4c7069e1a56630aef830", size = 1151216, upload-time = "2025-10-06T20:22:23.085Z" }, - { url = "https://files.pythonhosted.org/packages/75/0d/881866647b8d1be4d67cb24e50d0c26f9f807f994aa1510cb9ba2fe5f612/tiktoken-0.12.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:604831189bd05480f2b885ecd2d1986dc7686f609de48208ebbbddeea071fc0b", size = 1194860, upload-time = "2025-10-06T20:22:24.602Z" }, - { url = "https://files.pythonhosted.org/packages/b3/1e/b651ec3059474dab649b8d5b69f5c65cd8fcd8918568c1935bd4136c9392/tiktoken-0.12.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:8f317e8530bb3a222547b85a58583238c8f74fd7a7408305f9f63246d1a0958b", size = 1254567, upload-time = "2025-10-06T20:22:25.671Z" }, - { url = "https://files.pythonhosted.org/packages/80/57/ce64fd16ac390fafde001268c364d559447ba09b509181b2808622420eec/tiktoken-0.12.0-cp314-cp314-win_amd64.whl", hash = "sha256:399c3dd672a6406719d84442299a490420b458c44d3ae65516302a99675888f3", size = 921067, upload-time = "2025-10-06T20:22:26.753Z" }, - { url = "https://files.pythonhosted.org/packages/ac/a4/72eed53e8976a099539cdd5eb36f241987212c29629d0a52c305173e0a68/tiktoken-0.12.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:c2c714c72bc00a38ca969dae79e8266ddec999c7ceccd603cc4f0d04ccd76365", size = 1050473, upload-time = "2025-10-06T20:22:27.775Z" }, - { url = "https://files.pythonhosted.org/packages/e6/d7/0110b8f54c008466b19672c615f2168896b83706a6611ba6e47313dbc6e9/tiktoken-0.12.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:cbb9a3ba275165a2cb0f9a83f5d7025afe6b9d0ab01a22b50f0e74fee2ad253e", size = 993855, upload-time = "2025-10-06T20:22:28.799Z" }, - { url = "https://files.pythonhosted.org/packages/5f/77/4f268c41a3957c418b084dd576ea2fad2e95da0d8e1ab705372892c2ca22/tiktoken-0.12.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:dfdfaa5ffff8993a3af94d1125870b1d27aed7cb97aa7eb8c1cefdbc87dbee63", size = 1129022, upload-time = "2025-10-06T20:22:29.981Z" }, - { url = "https://files.pythonhosted.org/packages/4e/2b/fc46c90fe5028bd094cd6ee25a7db321cb91d45dc87531e2bdbb26b4867a/tiktoken-0.12.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:584c3ad3d0c74f5269906eb8a659c8bfc6144a52895d9261cdaf90a0ae5f4de0", size = 1150736, upload-time = "2025-10-06T20:22:30.996Z" }, - { url = "https://files.pythonhosted.org/packages/28/c0/3c7a39ff68022ddfd7d93f3337ad90389a342f761c4d71de99a3ccc57857/tiktoken-0.12.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:54c891b416a0e36b8e2045b12b33dd66fb34a4fe7965565f1b482da50da3e86a", size = 1194908, upload-time = "2025-10-06T20:22:32.073Z" }, - { url = "https://files.pythonhosted.org/packages/ab/0d/c1ad6f4016a3968c048545f5d9b8ffebf577774b2ede3e2e352553b685fe/tiktoken-0.12.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:5edb8743b88d5be814b1a8a8854494719080c28faaa1ccbef02e87354fe71ef0", size = 1253706, upload-time = "2025-10-06T20:22:33.385Z" }, - { url = "https://files.pythonhosted.org/packages/af/df/c7891ef9d2712ad774777271d39fdef63941ffba0a9d59b7ad1fd2765e57/tiktoken-0.12.0-cp314-cp314t-win_amd64.whl", hash = "sha256:f61c0aea5565ac82e2ec50a05e02a6c44734e91b51c10510b084ea1b8e633a71", size = 920667, upload-time = "2025-10-06T20:22:34.444Z" }, +sdist = { url = "https://files.pythonhosted.org/packages/e4/e5/5f3cb2159769d0f4324c0e9e87f9de3c4b1cd45848a96b2eb3566ad5ca77/tiktoken-0.13.0.tar.gz", hash = "sha256:c9435714c3a84c2319499de9a300c0e604449dd0799ff246458b3bb6a7f433c1", size = 38986, upload-time = "2026-05-15T04:51:27.153Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/38/e3/03c90dadcf5b3f82b83cee9adee60ef666b329c654f58c066af44eae0287/tiktoken-0.13.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:47b1df8d73390a24f94980c75158cdd5c56d256f16d55f30cb49c230caba9ba4", size = 1036627, upload-time = "2026-05-15T04:50:11.229Z" }, + { url = "https://files.pythonhosted.org/packages/5e/30/760463e5b2e8ad2bc229ae0a17ecb06727b6cbc094f08d8f65844315632e/tiktoken-0.13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:7d40c6c5aab171dcd6eb8455bc567bde404bb9def60cdb8c1299cc782b242bb9", size = 984699, upload-time = "2026-05-15T04:50:12.874Z" }, + { url = "https://files.pythonhosted.org/packages/de/8a/8895f342a6b6aabd1a358e672f6f077b3ae51d0c63ca605d142db3bcd8ab/tiktoken-0.13.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:9b842981fa91accdffd48ff6408a977b7a91c3fbda55d353c3c68114d5c9d69e", size = 1118690, upload-time = "2026-05-15T04:50:14.234Z" }, + { url = "https://files.pythonhosted.org/packages/51/e0/92557768fb0801f0d9dd9243cb9b6d342900b05e4b1006d4771f49ce233e/tiktoken-0.13.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:ed5a30027cb4d8c7ca8b273d4766f3db3cf58fad9e9f3b1a68a351ffb54873d5", size = 1138423, upload-time = "2026-05-15T04:50:15.668Z" }, + { url = "https://files.pythonhosted.org/packages/8f/b9/a3d99feeedb032ffd09cd6652077f86bdee9a70dd0b990b2b272b445d4c3/tiktoken-0.13.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:7ab10f4a21c2999846940113f6dbd72e0fa06a24119feddd74cc47e85818e06d", size = 1185077, upload-time = "2026-05-15T04:50:17.19Z" }, + { url = "https://files.pythonhosted.org/packages/cc/93/bab868277d475dc6d2aaacd34cdd239c282f4908dcc8702e0a3311a8e032/tiktoken-0.13.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:a2937ad042d49d50eac6e1ba07c5661d4bd3942a5b1e0c0d08475c4df83676e1", size = 1241702, upload-time = "2026-05-15T04:50:18.772Z" }, + { url = "https://files.pythonhosted.org/packages/c3/16/27e9f7e0ed76e501cfefc9fb2112df4c7bf70ca96945b15ecb7615aac860/tiktoken-0.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:44733b99bfd72b590cd0936b1c01b3b4dd73122db2d544bc1ceeb18a7678c910", size = 876565, upload-time = "2026-05-15T04:50:20.268Z" }, + { url = "https://files.pythonhosted.org/packages/1a/4c/1bc81f4cd53e827c4ee67ca951b5935724716049452d8dfa09b8b82372bb/tiktoken-0.13.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:7bfe1849caa65d1e1d9871817170ec497bbb7984e182012e1bdce72f66608cdb", size = 1036353, upload-time = "2026-05-15T04:50:21.757Z" }, + { url = "https://files.pythonhosted.org/packages/75/91/10b9c7076bc02c246c853201fdbbe300a4b8c5ed7b84c25f7403f4e32655/tiktoken-0.13.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:91c180fe255bd5a86d8316210d2833a1d4d33d026cd86a67812f4773743c8d26", size = 984644, upload-time = "2026-05-15T04:50:23.256Z" }, + { url = "https://files.pythonhosted.org/packages/4e/e4/fceae98015fab47fcd49b8bd7f46145bcd187a47e0add1e5378ed67ef980/tiktoken-0.13.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:059c8ecf554eb5b41e6e054ba467b871b03277d267dee7244380aca4359747d4", size = 1119261, upload-time = "2026-05-15T04:50:24.348Z" }, + { url = "https://files.pythonhosted.org/packages/f9/39/fe42ad00de01a8c4a49ad8649a2c8a316835a9cad5961b11d21eac0020a5/tiktoken-0.13.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:36217497eaffc158607a3b26f065300db2aefd43b115263f3b9688ce38146173", size = 1138253, upload-time = "2026-05-15T04:50:25.505Z" }, + { url = "https://files.pythonhosted.org/packages/03/c4/ccee1ecccca107e9a16efcecdeeb964c325305038554d466ece65b42338f/tiktoken-0.13.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:303f7d91b4fce3baddbcde05c139091d4caa5026ac7214c1dc7ff7a71ee429ff", size = 1185747, upload-time = "2026-05-15T04:50:27.02Z" }, + { url = "https://files.pythonhosted.org/packages/9d/03/cd0cba295522b91eb55c6b2704f1df895f8226cfe60ab10d4d51d0cc9e69/tiktoken-0.13.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:5d48843bee149630eb735a99e1f4a85b47308d21868ea63163f6e87768d3cfed", size = 1241265, upload-time = "2026-05-15T04:50:28.815Z" }, + { url = "https://files.pythonhosted.org/packages/7e/25/a10efd564402d82c2ff50d12057353ace447aa8007deceaa48641f63d35c/tiktoken-0.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:fc1c44cd37b43fc46bae593129164f4f281e82ea116b57a85aa81bda57eafc94", size = 876509, upload-time = "2026-05-15T04:50:30.026Z" }, + { url = "https://files.pythonhosted.org/packages/85/8e/144bde4e01df66b34bb865557c7cd754ed08b036217ebd79c9db5e9048a9/tiktoken-0.13.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:32ac870a806cfb260a02d0cb70426aef02e038297f8ad50df5040bb5af360791", size = 1034888, upload-time = "2026-05-15T04:50:31.579Z" }, + { url = "https://files.pythonhosted.org/packages/36/18/d4ac9d20956cdebca04841316660ed584c2fecdc2b81722a28bc7ad3b1e4/tiktoken-0.13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4d9980f11429ed2d737c463bb1fb78cf330caa026adf002f714aced7849a687b", size = 982970, upload-time = "2026-05-15T04:50:32.961Z" }, + { url = "https://files.pythonhosted.org/packages/74/ed/6bb8d05b9f731f749fee5c6f5ca63e981143c826a5985877330507bd13b7/tiktoken-0.13.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:3f277ebea5edd7b8bf03c6f9431e1d67d517530115572b2dc1d465326e8f88c7", size = 1115741, upload-time = "2026-05-15T04:50:34.475Z" }, + { url = "https://files.pythonhosted.org/packages/34/de/2ca96b07a82d972b74fe4b46de055b79c904e45c7eab699354a0bfa697dc/tiktoken-0.13.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:a116178fa7e1b4065bff05214360373a65cac22f965be7b3f73d00a0dbfe7649", size = 1136523, upload-time = "2026-05-15T04:50:35.782Z" }, + { url = "https://files.pythonhosted.org/packages/ee/dc/9dafec002c2d4424378563cf4cf5c7fb93631d2a55013c8b87554ee4012c/tiktoken-0.13.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2c397ddda233208345b01bd30f2fca79ff730e55731d0108a603f9bc57f6af3b", size = 1181954, upload-time = "2026-05-15T04:50:36.99Z" }, + { url = "https://files.pythonhosted.org/packages/a1/d0/1f8578c45b2f24759b46f0b50d31878c63c73e6bf0f2227e10ec5c5408dc/tiktoken-0.13.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:95097e4f89b06403976e498abf61a0ee73a7497e73fb599cb211d8197a054d91", size = 1240069, upload-time = "2026-05-15T04:50:38.221Z" }, + { url = "https://files.pythonhosted.org/packages/aa/90/28d7f154888610aa9237e541986beb62b479df29d193a5a0617dbb1514d0/tiktoken-0.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:8f2d16e7a7c783ad81f36e457d046d1f1c8af70b22aec8a13238efe531977c41", size = 874748, upload-time = "2026-05-15T04:50:39.587Z" }, + { url = "https://files.pythonhosted.org/packages/9c/83/b096c859c2a47c11731bf2f5885f4028b809dfe2396582883eed9cae372f/tiktoken-0.13.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5df5d1507bd245f1ccad4a074698240021239e455eb0bb4ced4e3d7181872154", size = 1034228, upload-time = "2026-05-15T04:50:40.988Z" }, + { url = "https://files.pythonhosted.org/packages/53/61/c68e123b6d753e3fc2751e9b18e732c9d8bf1e1926762e736eee935d931c/tiktoken-0.13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:8fe806a50664e83a6ffd56cbd1e4f5dcc6cd32a3e7538f70dc38b1a271384545", size = 982978, upload-time = "2026-05-15T04:50:42.195Z" }, + { url = "https://files.pythonhosted.org/packages/ef/8b/96cc178cc584e65d363134500f297790b06cd48cdeb1e8fcf7bbe60f4715/tiktoken-0.13.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:125bc05005e747f993a83dc67934249932d6e4209854452cd4c0b1d53fba3ba2", size = 1116355, upload-time = "2026-05-15T04:50:43.564Z" }, + { url = "https://files.pythonhosted.org/packages/86/f5/bab735d2c72ea55404b295d02d092644eb5f7cc6205e34d35eb9abfb9ab2/tiktoken-0.13.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:5e6358911cab4adee6712da27d65573496a4f68cf8a2b5fca6a4ad10fc5748cf", size = 1135772, upload-time = "2026-05-15T04:50:44.782Z" }, + { url = "https://files.pythonhosted.org/packages/4e/b9/6de04ebdf904edfaad87788011b3735087a0c9ea671b9027e1e4e965e8c8/tiktoken-0.13.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:975cbd78d085d75d26b59660e262736dcaed1e35f8f142cd6291025c01d25486", size = 1182415, upload-time = "2026-05-15T04:50:46.422Z" }, + { url = "https://files.pythonhosted.org/packages/0d/9c/470a05f3b1caf038f44880e334d47ab674e0c80d514c66b375d14d5afa10/tiktoken-0.13.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:75ab9bc99fa020a4c283424590ecd7f3afd70c1c281cb3fa3192a6c3af9f9615", size = 1239879, upload-time = "2026-05-15T04:50:48.052Z" }, + { url = "https://files.pythonhosted.org/packages/42/a6/c1936d16055436cb32e6c6128d68629622e00f4768562f55653752d34768/tiktoken-0.13.0-cp313-cp313-win_amd64.whl", hash = "sha256:6b1615f0ff71953d19729ceb18865429c185b0a23c5353f1bbca34a394bf60f7", size = 874829, upload-time = "2026-05-15T04:50:49.202Z" }, + { url = "https://files.pythonhosted.org/packages/d6/07/acb5992c3772b5a36284f742cfb7a5895aa4471d1848ac31464ad50d7fdf/tiktoken-0.13.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:6eb4a5bfbc6426938026b1a334e898ac53541360d62d8c689870160cc80abd67", size = 1033600, upload-time = "2026-05-15T04:50:50.4Z" }, + { url = "https://files.pythonhosted.org/packages/14/e9/742e9aec30f59b9f161f7ff7cd072e02ea836c9e1c0854a8076dfcd40d5c/tiktoken-0.13.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:43cee3e5400573b2046fbf092cc7a5bc30164f9e4c95ce20714da929df48737a", size = 982516, upload-time = "2026-05-15T04:50:52.03Z" }, + { url = "https://files.pythonhosted.org/packages/72/74/ca1541b053e7648254d2e4b42a253e1bb4359f2c91a0a8d49228c794e1a0/tiktoken-0.13.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:7de52e3f566d19b3b11bd37eea552c6c305ad74081f736882bd44d148ed4c48d", size = 1115518, upload-time = "2026-05-15T04:50:53.543Z" }, + { url = "https://files.pythonhosted.org/packages/46/e3/93825eaf5a4a504795b787e5d5dea07fbeb3dabf97aa7b450be8bde59c89/tiktoken-0.13.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:51384448aa508e4df84c0f7c1dc3211c7f7b8096325660ee5fc82f3e11b381ce", size = 1136867, upload-time = "2026-05-15T04:50:55.191Z" }, + { url = "https://files.pythonhosted.org/packages/8c/46/002b68de6827091d5ae90b048f326e8aad8d953520950e5ce1508879414f/tiktoken-0.13.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:e28157350f7ebf35008dd8e9e0fdb621f976e4230c881099c85e8cf07eaa50e2", size = 1181826, upload-time = "2026-05-15T04:50:56.296Z" }, + { url = "https://files.pythonhosted.org/packages/db/c6/d393e3185a276505182f7abd93fe714f3c444a2be9180798fa052347504e/tiktoken-0.13.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:165cf1820ea4a354985c2490a5205d4cc74661c934aca79dd0368232fff94e0f", size = 1239489, upload-time = "2026-05-15T04:50:57.918Z" }, + { url = "https://files.pythonhosted.org/packages/b7/4d/bc07d1f1635d4897a202acc0ae11c2886eaa7325c359ba4741b47bf8e225/tiktoken-0.13.0-cp313-cp313t-win_amd64.whl", hash = "sha256:6c43a675ca14f6f2749ba7f12075d37456015a24b859f2517b9beb4ef30807ec", size = 873820, upload-time = "2026-05-15T04:50:59.528Z" }, + { url = "https://files.pythonhosted.org/packages/8c/93/0dd6adca026a616c3a92974566b43381eea4b475ce1f36c062b8271a9ac5/tiktoken-0.13.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:eaaaef47c2406277181d2086484c317bf7fc433e2d5d03ff94f56b0dcec87471", size = 1034977, upload-time = "2026-05-15T04:51:00.957Z" }, + { url = "https://files.pythonhosted.org/packages/d9/77/5ec6e6bc5b30bed6d93f7f2162d8f6b32437b3ba27cb527cfe004f6109c9/tiktoken-0.13.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:ca8b310bd93b3772cb1b7922d915446864860f562bdfe4825c63a0aed3fb28cd", size = 983635, upload-time = "2026-05-15T04:51:02.629Z" }, + { url = "https://files.pythonhosted.org/packages/94/b0/c8ae9aff00d625c50659b4513e707a0462c4bf5d4d6cc1b802103225c02e/tiktoken-0.13.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:32e0c12305105002c047b3bb1070b0dd9a73b0cb3b2856a8972b810e7a4f5881", size = 1116036, upload-time = "2026-05-15T04:51:04.082Z" }, + { url = "https://files.pythonhosted.org/packages/1b/ac/6a5dddd1d0a6018ecb389bd0353e6b4a515eb4d2286611bd0ace1937b9e1/tiktoken-0.13.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:5ba5fd62507a932d1241346179e3b39bc7bf7408f03c272652d93b3bedf5db24", size = 1135544, upload-time = "2026-05-15T04:51:05.229Z" }, + { url = "https://files.pythonhosted.org/packages/f4/b8/585032b4384b2f7dcdaddcb52865c83a701a420d09e3c2b4a2be1c450c57/tiktoken-0.13.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:d108bc2d470fc53c8ecd24f2c0fd2b5f98c33e87cdb6aa2e9b8c5dced703d273", size = 1182217, upload-time = "2026-05-15T04:51:06.517Z" }, + { url = "https://files.pythonhosted.org/packages/cd/b6/993ff1ded3958215fd341a847b8e5ffeb5de473f435296870d314fc91ac4/tiktoken-0.13.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:cb99cb5127449f58d0a2d5f5ccfb390d8dbdfd919c221246caaee29d8725ed51", size = 1239404, upload-time = "2026-05-15T04:51:07.843Z" }, + { url = "https://files.pythonhosted.org/packages/dd/3d/fef7e06e3b33e7538db0ced734cf9fe23b6832d2ac4990c119c377aec55e/tiktoken-0.13.0-cp314-cp314-win_amd64.whl", hash = "sha256:115c4f26ffa11caac8b54eea35c2ad38c612c20a48d35dd15d70a02ac6f51f58", size = 918686, upload-time = "2026-05-15T04:51:08.925Z" }, + { url = "https://files.pythonhosted.org/packages/c1/82/a7fc44582bc32ab00de988a2299bf77c077f59068b233109e34b7d6ca7e6/tiktoken-0.13.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:472527e9132952f2fbf77cd290658bacf003d4d5a3fabc18e5fbd407cbae4d9b", size = 1034454, upload-time = "2026-05-15T04:51:10.035Z" }, + { url = "https://files.pythonhosted.org/packages/37/d0/24d8a890c14f432a05cea669c17bebeaa99f96a7c79523b590f564246411/tiktoken-0.13.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:4e2f67d27c9626cdd25fe33d9313c5cdb3d8d82da646b68d6eb8e7e9c20e6448", size = 982976, upload-time = "2026-05-15T04:51:11.23Z" }, + { url = "https://files.pythonhosted.org/packages/49/b7/2ab43f62788a9266187a9bfc1d3af99ad83e5eaa25fbef168a69cd5ad14f/tiktoken-0.13.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:2b920b35805cd64585a37c3dc7ce65fba4d2d36016be01e1d7942482ca29093a", size = 1115526, upload-time = "2026-05-15T04:51:12.608Z" }, + { url = "https://files.pythonhosted.org/packages/64/39/1494321ed323ce7a14d88e3cd6cb9058625977df1c6961ddc492bd10a9f3/tiktoken-0.13.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:493af3aa28a4aaf2e3d2600a2ee717252c9bf5ab38fff94eb5a02db5ab77e5ad", size = 1136466, upload-time = "2026-05-15T04:51:13.926Z" }, + { url = "https://files.pythonhosted.org/packages/96/d9/dfd086aa2d918c563a140720e0ce296cada1634efd2783d5cf51e05f984e/tiktoken-0.13.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:6644c9c2b5cf3916f5a3641d7d12fdb3f006a7b3d9ff6acdaec44e29ab1ff91e", size = 1181863, upload-time = "2026-05-15T04:51:15.025Z" }, + { url = "https://files.pythonhosted.org/packages/2f/68/a18b4f307086954fdae32714cb4f85562e34f9d34ab206e61f1816aa6018/tiktoken-0.13.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:5cb65b60b9408563676d874a3a4ee573370066f0dc4e29d84e82e989c6517424", size = 1239218, upload-time = "2026-05-15T04:51:16.103Z" }, + { url = "https://files.pythonhosted.org/packages/16/5b/f2aa703a4fc5d2dff73460a7d46cc2f3f44aa0f3dd8eeb20d2a0ecf68862/tiktoken-0.13.0-cp314-cp314t-win_amd64.whl", hash = "sha256:85b78cc3a2c3d48723ca751fa981f1fedccd54194ca0471b957364353a898b07", size = 918110, upload-time = "2026-05-15T04:51:17.237Z" }, ] [[package]] @@ -4425,7 +4425,7 @@ dependencies = [ { name = "fsspec" }, { name = "jinja2" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "numpy", version = "2.4.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "psutil" }, { name = "pyparsing" }, { name = "requests" }, @@ -4443,7 +4443,7 @@ version = "0.0.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "numpy", version = "2.4.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "torch", marker = "sys_platform == 'never'" }, { name = "torchvision" }, ] @@ -4454,43 +4454,43 @@ wheels = [ [[package]] name = "torchvision" -version = "0.26.0" +version = "0.27.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "numpy", version = "2.4.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "pillow" }, { name = "torch", marker = "sys_platform == 'never'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/74/b4/cdfee31e0402ea035135462cb0ab496e974d56fab6b4e7a1f0cbccb8cd28/torchvision-0.26.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a06d4772a8e13e772906ed736cc53ec6639e5e60554f8e5fa6ca165aabebc464", size = 1863503, upload-time = "2026-03-23T18:13:01.384Z" }, - { url = "https://files.pythonhosted.org/packages/e4/74/11fee109841e80ad14e5ca2d80bff6b10eb11b7838ff06f35bfeaa9f7251/torchvision-0.26.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:2adfbe438473236191ff077a4a9a0c767436879c89628aa97137e959b0c11a94", size = 7766423, upload-time = "2026-03-23T18:12:56.049Z" }, - { url = "https://files.pythonhosted.org/packages/5e/00/24d8c7845c3f270153fb81395a5135b2778e2538e81d14c6aea5106c689c/torchvision-0.26.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:b6f9ad1ecc0eab52647298b379ee9426845f8903703e6127973f8f3d049a798b", size = 7518249, upload-time = "2026-03-23T18:12:51.743Z" }, - { url = "https://files.pythonhosted.org/packages/d7/ed/e53cd7c0da7ae002e5e929c1796ebbe7ec0c700c29f7a0a6696497fb3d8b/torchvision-0.26.0-cp310-cp310-win_amd64.whl", hash = "sha256:f13f12b3791a266de2d599cb8162925261622a037d87fc03132848343cf68f75", size = 3669784, upload-time = "2026-03-23T18:12:49.949Z" }, - { url = "https://files.pythonhosted.org/packages/b4/bd/d552a2521bade3295b2c6e7a4a0d1022261cab7ca7011f4e2a330dbb3caa/torchvision-0.26.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:55bd6ad4ae77be01ba67a410b05b51f53b0d0ee45f146eb6a0dfb9007e70ab3c", size = 1863499, upload-time = "2026-03-23T18:12:58.696Z" }, - { url = "https://files.pythonhosted.org/packages/33/bf/21b899792b08cae7a298551c68398a79e333697479ed311b3b067aab4bdc/torchvision-0.26.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:1c55dc8affbcc0eb2060fbabbe996ae9e5839b24bb6419777f17848945a411b1", size = 7767527, upload-time = "2026-03-23T18:12:44.348Z" }, - { url = "https://files.pythonhosted.org/packages/9a/45/57bbf9e216850d065e66dd31a50f57424b607f1d878ab8956e56a1f4e36b/torchvision-0.26.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:fd10b5f994c210f4f6d6761cf686f82d748554adf486cb0979770c3252868c8f", size = 7519925, upload-time = "2026-03-23T18:12:53.283Z" }, - { url = "https://files.pythonhosted.org/packages/10/58/ed8f7754299f3e91d6414b6dc09f62b3fa7c6e5d63dfe48d69ab81498a37/torchvision-0.26.0-cp311-cp311-win_amd64.whl", hash = "sha256:de6424b12887ad884f39a0ee446994ae3cd3b6a00a9cafe1bead85a031132af0", size = 3983834, upload-time = "2026-03-23T18:13:00.224Z" }, - { url = "https://files.pythonhosted.org/packages/ae/e7/56b47cc3b132aea90ccce22bcb8975dec688b002150012acc842846039d0/torchvision-0.26.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c409e1c3fdebec7a3834465086dbda8bf7680eff79abf7fd2f10c6b59520a7a4", size = 1863502, upload-time = "2026-03-23T18:12:57.326Z" }, - { url = "https://files.pythonhosted.org/packages/f4/ec/5c31c92c08b65662fe9604a4067ae8232582805949f11ddc042cebe818ed/torchvision-0.26.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:406557718e62fdf10f5706e88d8a5ec000f872da913bf629aab9297622585547", size = 7767944, upload-time = "2026-03-23T18:12:42.805Z" }, - { url = "https://files.pythonhosted.org/packages/f5/d8/cb6ccda1a1f35a6597645818641701207b3e8e13553e75fce5d86bac74b2/torchvision-0.26.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:d61a5abb6b42a0c0c311996c2ac4b83a94418a97182c83b055a2a4ae985e05aa", size = 7522205, upload-time = "2026-03-23T18:12:54.654Z" }, - { url = "https://files.pythonhosted.org/packages/1c/a9/c272623a0f735c35f0f6cd6dc74784d4f970e800cf063bb76687895a2ab9/torchvision-0.26.0-cp312-cp312-win_amd64.whl", hash = "sha256:7993c01648e7c61d191b018e84d38fe0825c8fcb2720cd0f37caf7ba14404aa1", size = 4255155, upload-time = "2026-03-23T18:12:32.652Z" }, - { url = "https://files.pythonhosted.org/packages/da/80/0762f77f53605d10c9477be39bb47722cc8e383bbbc2531471ce0e396c07/torchvision-0.26.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:5d63dd43162691258b1b3529b9041bac7d54caa37eae0925f997108268cbf7c4", size = 1860809, upload-time = "2026-03-23T18:12:47.629Z" }, - { url = "https://files.pythonhosted.org/packages/e6/81/0b3e58d1478c660a5af4268713486b2df7203f35abd9195fea87348a5178/torchvision-0.26.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:a39c7a26538c41fda453f9a9692b5ff9b35a5437db1d94f3027f6f509c160eac", size = 7727494, upload-time = "2026-03-23T18:12:46.062Z" }, - { url = "https://files.pythonhosted.org/packages/b6/dc/d9ab5d29115aa05e12e30f1397a3eeae1d88a511241dc3bce48dc4342675/torchvision-0.26.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:b7e6213620bbf97742e5f79832f9e9d769e6cf0f744c5b53dad80b76db633691", size = 7521747, upload-time = "2026-03-23T18:12:36.815Z" }, - { url = "https://files.pythonhosted.org/packages/a9/1b/f1bc86a918c5f6feab1eeff11982e2060f4704332e96185463d27855bdf5/torchvision-0.26.0-cp313-cp313-win_amd64.whl", hash = "sha256:4280c35ec8cba1fcc8294fb87e136924708726864c379e4c54494797d86bc474", size = 4319880, upload-time = "2026-03-23T18:12:38.168Z" }, - { url = "https://files.pythonhosted.org/packages/66/28/b4ad0a723ed95b003454caffcc41894b34bd8379df340848cae2c33871de/torchvision-0.26.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:358fc4726d0c08615b6d83b3149854f11efb2a564ed1acb6fce882e151412d23", size = 1951973, upload-time = "2026-03-23T18:12:48.781Z" }, - { url = "https://files.pythonhosted.org/packages/71/e2/7a89096e6cf2f3336353b5338ba925e0addf9d8601920340e6bdf47e8eb3/torchvision-0.26.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:3daf9cc149cf3cdcbd4df9c59dae69ffca86c6823250442c3bbfd63fc2e26c61", size = 7728679, upload-time = "2026-03-23T18:12:26.196Z" }, - { url = "https://files.pythonhosted.org/packages/69/1d/4e1eebc17d18ce080a11dcf3df3f8f717f0efdfa00983f06e8ba79259f61/torchvision-0.26.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:82c3965eca27e86a316e31e4c3e5a16d353e0bcbe0ef8efa2e66502c54493c4b", size = 7609138, upload-time = "2026-03-23T18:12:35.327Z" }, - { url = "https://files.pythonhosted.org/packages/f3/a4/f1155e943ae5b32400d7000adc81c79bb0392b16ceb33bcf13e02e48cced/torchvision-0.26.0-cp313-cp313t-win_amd64.whl", hash = "sha256:ebc043cc5a4f0bf22e7680806dbba37ffb19e70f6953bbb44ed1a90aeb5c9bea", size = 4248202, upload-time = "2026-03-23T18:12:41.423Z" }, - { url = "https://files.pythonhosted.org/packages/7f/c8/9bffa9c7f7bdf95b2a0a2dc535c290b9f1cc580c3fb3033ab1246ffffdeb/torchvision-0.26.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:eb61804eb9dbe88c5a2a6c4da8dec1d80d2d0a6f18c999c524e32266cb1ebcd3", size = 1860813, upload-time = "2026-03-23T18:12:39.636Z" }, - { url = "https://files.pythonhosted.org/packages/7b/ac/48f28ffd227991f2e14f4392dde7e8dc14352bb9428c1ef4a4bbf5f7ed85/torchvision-0.26.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:9a904f2131cbfadab4df828088a9f66291ad33f49ff853872aed1f86848ef776", size = 7727777, upload-time = "2026-03-23T18:12:22.549Z" }, - { url = "https://files.pythonhosted.org/packages/a4/21/a2266f7f1b0e58e624ff15fd6f01041f59182c49551ece0db9a183071329/torchvision-0.26.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:0f3e572efe62ad645017ea847e0b5e4f2f638d4e39f05bc011d1eb9ac68d4806", size = 7522174, upload-time = "2026-03-23T18:12:29.565Z" }, - { url = "https://files.pythonhosted.org/packages/fc/ba/1666f90bc0bdd77aaa11dcc42bb9f621a9c3668819c32430452e3d404730/torchvision-0.26.0-cp314-cp314-win_amd64.whl", hash = "sha256:114bec0c0e98aa4ba446f63e2fe7a2cbca37b39ac933987ee4804f65de121800", size = 4348469, upload-time = "2026-03-23T18:12:24.44Z" }, - { url = "https://files.pythonhosted.org/packages/45/8f/1f0402ac55c2ae15651ff831957d083fe70b2d12282e72612a30ba601512/torchvision-0.26.0-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:b7d3e295624a28b3b1769228ce1345d94cf4d390dd31136766f76f2d20f718da", size = 1860826, upload-time = "2026-03-23T18:12:34.1Z" }, - { url = "https://files.pythonhosted.org/packages/d2/6a/18a582fe3c5ee26f49b5c9fb21ad8016b4d1c06d10178894a58653946fda/torchvision-0.26.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:7058c5878262937e876f20c25867b33724586aa4499e2853b2d52b99a5e51953", size = 7729089, upload-time = "2026-03-23T18:12:31.394Z" }, - { url = "https://files.pythonhosted.org/packages/c5/9b/f7e119b59499edc00c55c03adc9ec3bd96144d9b81c46852c431f9c64a9a/torchvision-0.26.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:8008474855623c6ba52876589dc52df0aa66e518c25eca841445348e5f79844c", size = 7522704, upload-time = "2026-03-23T18:12:20.301Z" }, - { url = "https://files.pythonhosted.org/packages/d0/6a/09f3844c10643f6c0de5d95abc863420cfaf194c88c7dffd0ac523e2015f/torchvision-0.26.0-cp314-cp314t-win_amd64.whl", hash = "sha256:e9d0e022c19a78552fb055d0414d47fecb4a649309b9968573daea160ba6869c", size = 4454275, upload-time = "2026-03-23T18:12:27.487Z" }, + { url = "https://files.pythonhosted.org/packages/13/15/2df874db140bbfe42f377e05e2dd38f2b9dc88414a6607eecc42073b2baa/torchvision-0.27.0-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:0822b58d2c5d325cd0c7152b744acbd15f898c07572e2cfb70b075a865a4f6f9", size = 1758817, upload-time = "2026-05-13T14:57:20.113Z" }, + { url = "https://files.pythonhosted.org/packages/f7/32/10b1ff4087d35b7af7bd85ccb85fbc2573c6f1c2008cf8abfcaf605a10fc/torchvision-0.27.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:c9f44e35e6ec01caedacce9e941a5bf21fe424403321efac2507a201273653c5", size = 7830083, upload-time = "2026-05-13T14:57:18.336Z" }, + { url = "https://files.pythonhosted.org/packages/57/20/97dca91770235028ba7e9c598ca1fc48c297f1843af8102430f2adcd4335/torchvision-0.27.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:419c98a9275b27660cdce6d09080fd5974d1ec1d4a225f71439ebacb3b0c4e64", size = 7573816, upload-time = "2026-05-13T14:57:12.327Z" }, + { url = "https://files.pythonhosted.org/packages/37/a5/66fbf7f21f292d095a153ee142806646813e2055a69efe5854c28e7c3fb9/torchvision-0.27.0-cp310-cp310-win_amd64.whl", hash = "sha256:2664d06acd64d328aa7689b0d0c81ee31e240e9977d8768816b4be7c66c03211", size = 3435489, upload-time = "2026-05-13T14:57:13.716Z" }, + { url = "https://files.pythonhosted.org/packages/cf/d6/a7e71e981042d5c573e2e61891b9023b190c88adb75b18bed8594371250c/torchvision-0.27.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:df0c166b6bdf7c47f88e81e8b43bc085451d5c50d0c5d1691bc474c1227d6fed", size = 1758812, upload-time = "2026-05-13T14:57:16.662Z" }, + { url = "https://files.pythonhosted.org/packages/93/f9/f542fb7e4476603fb237ebdc64369a7d11f18eb5a129aa2559cbdb710aee/torchvision-0.27.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:9bb9251f64b854124efed95d02953a89f7e2726c3ca662d7ea0151129157297f", size = 7831148, upload-time = "2026-05-13T14:57:08.37Z" }, + { url = "https://files.pythonhosted.org/packages/f6/61/7aa7cc2c9e8750027f6fb9ae3a7393ef43860bcdfe3966e2f71fee800e31/torchvision-0.27.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:f44453f107c296d5446a79f7ac59733ad8bf5ddfa04c53805dfbae298a42a798", size = 7575519, upload-time = "2026-05-13T14:56:50.552Z" }, + { url = "https://files.pythonhosted.org/packages/19/aa/929b358b1a643849b81ec95569938044cc37dc65ab10c84eb6d82fe1bfbb/torchvision-0.27.0-cp311-cp311-win_amd64.whl", hash = "sha256:b4aacff70ea4b7377f996f9048989c850d221fef33658ddbcae42aa5bd4ca11a", size = 3749475, upload-time = "2026-05-13T14:57:11.007Z" }, + { url = "https://files.pythonhosted.org/packages/9e/c8/5cd91932f7f3671b0743dc4ae1a4c16b1d0b45bf4087976277d325bda718/torchvision-0.27.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:1a6dd742a150645126df9e0b2e449874c1d635897c773b322c2e067e98382dfe", size = 1758824, upload-time = "2026-05-13T14:57:15.227Z" }, + { url = "https://files.pythonhosted.org/packages/d9/36/7fb7d19477b3d93283b52fea11fa8ee30ab9064a08c97b4a6b91445e26cb/torchvision-0.27.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:65772ff3ec4f4f5d680e30019835555dd239e7fefee4b0a846375fe1cb1592ef", size = 7831034, upload-time = "2026-05-13T14:57:06.483Z" }, + { url = "https://files.pythonhosted.org/packages/62/43/dfd894c3f8b01b5b33fde990f0159c1926ebc7b6e2c4193e2efb7da3c4cb/torchvision-0.27.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:7a9966a088d06b4cf6c610e03be62de469efa6f2cd2e7c7eed8e925ed6af59ac", size = 7579774, upload-time = "2026-05-13T14:56:59.337Z" }, + { url = "https://files.pythonhosted.org/packages/ff/0c/722e989f9cf026e97ef7cb24a9bb1859e099f72d247ae35388fb89729f73/torchvision-0.27.0-cp312-cp312-win_amd64.whl", hash = "sha256:2c037709072ca9b19750c0cbe9e8bb6f91c9a1be1befa26df33e281deccbd8c7", size = 4021073, upload-time = "2026-05-13T14:57:00.848Z" }, + { url = "https://files.pythonhosted.org/packages/d8/ae/36547812e6e047c1d80bcacd1b17a340612b08a6e876e0aabf3d0b9228b0/torchvision-0.27.0-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:41d6dae73e1af09fa82ded597ae57f2a2314285acde54b25890a8f8e51b999d7", size = 1758826, upload-time = "2026-05-13T14:57:05.262Z" }, + { url = "https://files.pythonhosted.org/packages/ae/30/32c4ea842738728a14e3df8c576c62dedcf5ae5cb6a5c984c6429ebe7524/torchvision-0.27.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:70f071c6f74b60d5fe8851636d8d4cd5f4fa29d57fd9348a87a6f17b990b95ba", size = 7789501, upload-time = "2026-05-13T14:56:57.786Z" }, + { url = "https://files.pythonhosted.org/packages/f6/24/4d0d48684251bd0673f87d633d5d88ab00227983b00591156eed2f86c8d5/torchvision-0.27.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:aaafa6962c9d91f42503de1957d6fa349907d028c06f335bd95da7a5bc57147d", size = 7579868, upload-time = "2026-05-13T14:56:41.618Z" }, + { url = "https://files.pythonhosted.org/packages/ba/da/e6edd051d2ba25adf23b120fa97f458dff888d098c51e84724f17d2d1470/torchvision-0.27.0-cp313-cp313-win_amd64.whl", hash = "sha256:aee384a2782c89517c4ab9061d2720ba59fd2ffe5ef89d0a149cc2d43abdf521", size = 4092700, upload-time = "2026-05-13T14:57:09.729Z" }, + { url = "https://files.pythonhosted.org/packages/fa/23/95dfa40431360f42ca949bf861434bed51164adfa8fb9801e05bf3194f50/torchvision-0.27.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:c5121f1b9ab09a7f73e837871deb8321551f7eaeb19d87aa00de9191968eae44", size = 1845008, upload-time = "2026-05-13T14:57:03.768Z" }, + { url = "https://files.pythonhosted.org/packages/23/b9/9dbdf76b2b49a75ba8088df6f7c755bdb520afb6c6dbac0102b46cde5e99/torchvision-0.27.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:1c01f0d1091ae22b9dfc082b0a0fe5faaf053686a29b4fb082ba7691375c73cf", size = 7791430, upload-time = "2026-05-13T14:56:56.206Z" }, + { url = "https://files.pythonhosted.org/packages/5c/6a/e4a16cf2f3310c2ea7760dc5d9054496844391e0f4c1fae87fefac2f3d9e/torchvision-0.27.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:dadea3c5ecfd05bbb2a3312ab0374f213c58bf6459cb059122e2f4dfe13d10ed", size = 7668441, upload-time = "2026-05-13T14:57:02.127Z" }, + { url = "https://files.pythonhosted.org/packages/00/70/01b6461117a6a94b5af3f8ee166bb0f045056f3cf187750c110dabfdfffa/torchvision-0.27.0-cp313-cp313t-win_amd64.whl", hash = "sha256:a49e55055a39a8506fe7e59850522cab004efb2c3839f6057658889c1d69c815", size = 4141602, upload-time = "2026-05-13T14:56:53.449Z" }, + { url = "https://files.pythonhosted.org/packages/92/22/c0633677b3b3f3e69554a21ac087bf705f829c40cd5e3783507b8c006681/torchvision-0.27.0-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:c1fac0fc2a7adf29481fc1938a0e7845c57ba1147a986784109c4d98f434ea8c", size = 1758818, upload-time = "2026-05-13T14:56:54.988Z" }, + { url = "https://files.pythonhosted.org/packages/48/e8/55f9d9667b56dae470e69e31beac9b00d458ea393feec1aae95cc4f3f1c9/torchvision-0.27.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:cbf89764fc76f3f17fbf80c12d5a89c691e91cb9d82c38412aaf0568655ffb19", size = 7789667, upload-time = "2026-05-13T14:56:48.858Z" }, + { url = "https://files.pythonhosted.org/packages/00/bc/6f8681daf3bbc4c315bb0005110f99d28e3ecd675bf9c8f2c0d393fbac7a/torchvision-0.27.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:91f61b9865423037c327eb56afa207cc72de874e458c361840db9dcf5ce0c0eb", size = 7579848, upload-time = "2026-05-13T14:56:38.209Z" }, + { url = "https://files.pythonhosted.org/packages/19/6c/8d8020e6bd1e46c53e487c9c4e9457a07f2ee28931028fb5d71e2da40adc/torchvision-0.27.0-cp314-cp314-win_amd64.whl", hash = "sha256:5bb82fc3c55daf1788621e504310b0a286f1069627a8742f692aebb075ef25a7", size = 4119284, upload-time = "2026-05-13T14:56:46.625Z" }, + { url = "https://files.pythonhosted.org/packages/8d/7e/e78c48662a8d551606efdbe11c6b9c1d6d2391b92cd0e4591b9e6a2412b8/torchvision-0.27.0-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:2c4099a15150143b9b034730b404a56d572efe0b79489b4c765d929cb4eac7f3", size = 1758828, upload-time = "2026-05-13T14:56:52.293Z" }, + { url = "https://files.pythonhosted.org/packages/21/dd/d03ee9f9ee7bf11a8c7c776fb8e7fd6102f59c013791a2a4e5175bd6cba7/torchvision-0.27.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:b4c6bb0a670dcba017b3643e21902c9b8a1cc1c127d602f1488fa29ec3c6e865", size = 7790618, upload-time = "2026-05-13T14:56:44.721Z" }, + { url = "https://files.pythonhosted.org/packages/39/08/4002336a74742be70728603ec1769feb2b55e0d19c532c9ec9f92008de76/torchvision-0.27.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:1c2db4bde82bc48ebff73436a6adf34d4f809448268a70d9a1285f5c8f92313d", size = 7580217, upload-time = "2026-05-13T14:56:43.274Z" }, + { url = "https://files.pythonhosted.org/packages/ed/cb/4dd4783eb3565f526ba6e64b6f6ca26c00eacc924cdfe60455db9d91b84b/torchvision-0.27.0-cp314-cp314t-win_amd64.whl", hash = "sha256:72bf547e58ddb948689734eed6f4b6a2031f979dba4fb08e3690688b392e929f", size = 4226392, upload-time = "2026-05-13T14:56:40.235Z" }, ] [[package]] @@ -4512,7 +4512,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "huggingface-hub" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "numpy", version = "2.4.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "packaging" }, { name = "pyyaml" }, { name = "regex" }, @@ -4528,14 +4528,14 @@ wheels = [ [[package]] name = "typeguard" -version = "4.5.1" +version = "4.5.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/2b/e8/66e25efcc18542d58706ce4e50415710593721aae26e794ab1dec34fb66f/typeguard-4.5.1.tar.gz", hash = "sha256:f6f8ecbbc819c9bc749983cc67c02391e16a9b43b8b27f15dc70ed7c4a007274", size = 80121, upload-time = "2026-02-19T16:09:03.392Z" } +sdist = { url = "https://files.pythonhosted.org/packages/67/1c/dfba5c4633cafc4c701f237d2ba63b416805047fd6d96aab4cfc40969f98/typeguard-4.5.2.tar.gz", hash = "sha256:5a16dcac23502039299c97c8941651bc33d7ea8cc4b2f7d6bbb1b528f6eea423", size = 80240, upload-time = "2026-05-14T12:59:40.857Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/91/88/b55b3117287a8540b76dbdd87733808d4d01c8067a3b339408c250bb3600/typeguard-4.5.1-py3-none-any.whl", hash = "sha256:44d2bf329d49a244110a090b55f5f91aa82d9a9834ebfd30bcc73651e4a8cc40", size = 36745, upload-time = "2026-02-19T16:09:01.6Z" }, + { url = "https://files.pythonhosted.org/packages/5b/29/74eeb4d3f3ae61ca096b018ad486b3b3c74b17bec09ab4edab721cbefec3/typeguard-4.5.2-py3-none-any.whl", hash = "sha256:fcf9de18bd945cdb4c7b996e12b4c51ce83f92f191314a6d7cf1739586ec98cf", size = 36748, upload-time = "2026-05-14T12:59:39.473Z" }, ] [[package]] @@ -4594,47 +4594,47 @@ wheels = [ [[package]] name = "uv" -version = "0.11.13" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/4d/3c/463dc85baffc8dda4183b31ba2546204740c0cbac5c01d3671c4eb52819c/uv-0.11.13.tar.gz", hash = "sha256:c30889b6a4417f94a0315371ec5bf8af151f062406ad3fb4b2cbf13d645d825c", size = 4124451, upload-time = "2026-05-11T01:37:54.367Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/82/e6/78a0092e303dd8edf5a3ea74442b17b2ed8c1e9f82e97c7359045cefccdc/uv-0.11.13-py3-none-linux_armv6l.whl", hash = "sha256:4e56623a9ff6d7372290963cd21777bcb52aacbff6619d58a2659ee8240f8fed", size = 23545030, upload-time = "2026-05-11T01:38:23.367Z" }, - { url = "https://files.pythonhosted.org/packages/60/7e/e48c24814e5a2cbf2bb9ccf55d9327813fe3074ada9526851914663dc380/uv-0.11.13-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:72ad50ae5ce446f6887be842adffd1770b8e138caccc972f333915e524b323ac", size = 23076867, upload-time = "2026-05-11T01:38:02.308Z" }, - { url = "https://files.pythonhosted.org/packages/66/f6/0dcbc43f83e90626981a10b179769b25c0a218717a4331f928c26b6e13a2/uv-0.11.13-py3-none-macosx_11_0_arm64.whl", hash = "sha256:e5913805ee60b4e331dd7322ae95e18ceb110f6a5baae608d71a532ed1115e75", size = 21710719, upload-time = "2026-05-11T01:37:47.115Z" }, - { url = "https://files.pythonhosted.org/packages/12/c7/348575ae1ea6f312860915a60c1c7c4cf591339164ee321824ba9143a2c4/uv-0.11.13-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.musllinux_1_1_aarch64.whl", hash = "sha256:eb4fe81624bc92894c59aaf88a57cb1fcaf7da95dc3cf2ef1ed86847f0a7e9f6", size = 23300489, upload-time = "2026-05-11T01:37:57.718Z" }, - { url = "https://files.pythonhosted.org/packages/31/3c/78a8afbb98a50db65f4096025bbeff7aac67af8a4d3329f4f9bd8b5acc42/uv-0.11.13-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.musllinux_1_1_armv7l.whl", hash = "sha256:64dd1c36893d0da363d4c8e91c5d554d01a30061c83302eb93c75ca91b0f7eb3", size = 23077624, upload-time = "2026-05-11T01:37:43.197Z" }, - { url = "https://files.pythonhosted.org/packages/aa/30/d68cfdcaa88ad5a2bd1b149818ef51d970518ddd39001dd62ff5e4709d11/uv-0.11.13-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a3551cb18aaa75f50153877a4988e718ba365ba998563c390a99e207aeeadd0d", size = 23107411, upload-time = "2026-05-11T01:38:06.838Z" }, - { url = "https://files.pythonhosted.org/packages/02/2c/2311a29f32e1d404dc2fbc516e5febdf4567fcc3cfdd94e398bf5566b515/uv-0.11.13-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fd454d4f40e232355fa96937bae41e91b16279e2526034050576da5a2d8a7f40", size = 24551248, upload-time = "2026-05-11T01:37:23.403Z" }, - { url = "https://files.pythonhosted.org/packages/15/cc/ecb7174b11f64079ab9ec8ec0443aeaf69b86c6e6ad213b094d61ac71205/uv-0.11.13-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e4356e97a0e3e4d3ab53fd15415af12764a979759e37a3124372e3e6755e9a0c", size = 25455493, upload-time = "2026-05-11T01:38:10.814Z" }, - { url = "https://files.pythonhosted.org/packages/eb/0d/44031030724a5128efb06be62a701fe36a1664f91aee346ffaf6f0432d39/uv-0.11.13-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4a714e866853c72cb2b7a18187cf3db4a1475a2032f3bd00e1c98ccf214c31d0", size = 24562712, upload-time = "2026-05-11T01:38:14.979Z" }, - { url = "https://files.pythonhosted.org/packages/e3/ad/dfb82224e73031c71dc70eb4513a6f4f6af66da35c3c955e28d75fe03d1c/uv-0.11.13-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e23ddb47f7a317979cf1945cf9ed89d2639f60f7d06164f9ff1ad292c4cc5b3c", size = 24662925, upload-time = "2026-05-11T01:37:31.646Z" }, - { url = "https://files.pythonhosted.org/packages/41/3d/6cd9b920dcc83f0866e842caad5575cc3d5ca6604facbf5582950bbfc68c/uv-0.11.13-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:00efc945dd0392d7ac571cde402936e13ffba855121de79f42b3de9ee2f6a69a", size = 23398601, upload-time = "2026-05-11T01:37:18.832Z" }, - { url = "https://files.pythonhosted.org/packages/de/a9/291ff99b1dce9ec14b5a0358ad7d384485471d6ee4ecd7d98e05ef570da5/uv-0.11.13-py3-none-manylinux_2_31_riscv64.musllinux_1_1_riscv64.whl", hash = "sha256:b0807d1e9bc84c902cba9bb0b23627f6c980c54167c999e502571974fcfe2d6e", size = 24138999, upload-time = "2026-05-11T01:38:19.294Z" }, - { url = "https://files.pythonhosted.org/packages/9c/88/8eabfbe745371696d09d08e47e637d567413071eb02c7e2324a919ba4f87/uv-0.11.13-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:bea50519b30c1bc4e4a331dcc1d55253cd8d886d243d3506ec00f34cdf030eff", size = 24196974, upload-time = "2026-05-11T01:37:35.742Z" }, - { url = "https://files.pythonhosted.org/packages/ea/bf/9ab0db9d7f8d7b52382d70eb26bbd9e84dbe6cfce709ec7bf31895991a0a/uv-0.11.13-py3-none-musllinux_1_1_i686.whl", hash = "sha256:d714e4a09e28198664758576542c7cedb054677ab3cdec60207a75ed74f82235", size = 23822126, upload-time = "2026-05-11T01:38:31.829Z" }, - { url = "https://files.pythonhosted.org/packages/72/d9/dc2d1eb6b4181e5485cd36ecdb1c2f4fbec9b4078bb2b7266ef5481d2433/uv-0.11.13-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:bf067fc357e1f75783c343e731c3bf4f8ca531917eafd6d9f18cd477ddaee158", size = 24868862, upload-time = "2026-05-11T01:38:27.595Z" }, - { url = "https://files.pythonhosted.org/packages/94/94/de37ee6b07459780de695e6c57e158ce1307de075f40718740a981132d9e/uv-0.11.13-py3-none-win32.whl", hash = "sha256:79c3f501bbf849bc566e108545891abfbc15e4e85c22d8875bfe405c1e2efc42", size = 22581531, upload-time = "2026-05-11T01:37:39.382Z" }, - { url = "https://files.pythonhosted.org/packages/d9/89/01f90839cd1204e7a328cc36da27a09bc8b1a9692d3f9b79cee0a0945e1b/uv-0.11.13-py3-none-win_amd64.whl", hash = "sha256:974ec55646a7e680f91cdf4f77fbc6e2a71157240cd0efa387d458709b63ab04", size = 25194788, upload-time = "2026-05-11T01:37:51.372Z" }, - { url = "https://files.pythonhosted.org/packages/63/99/4d75ad86221363a277c3be4e36e928e84f0dff256413e83e58d8af8c0e2c/uv-0.11.13-py3-none-win_arm64.whl", hash = "sha256:35aaca82115b8dc747f22b8c76b1026e707f4c9a59fe39ab3c21be111a65fa44", size = 23589361, upload-time = "2026-05-11T01:37:27.755Z" }, +version = "0.11.14" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/30/a3/be4a946c7c2fc4094c020c8f7d8bd0a739bad55ebe4e2817d6e2b1bc6bff/uv-0.11.14.tar.gz", hash = "sha256:0ea006a117b586b2681b6dfd9703a540d2ad2a136ec0f48d272767e599cc3dfb", size = 4130699, upload-time = "2026-05-12T18:00:37.321Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f7/15/9b2138b16eb1fa8c2cd84b1037ad10c38b3acc36ce96c6d27000bfb7e716/uv-0.11.14-py3-none-linux_armv6l.whl", hash = "sha256:78411a883f230a710af19f2ac6e6f0ba8eae90f0e5af4605f923fd367539fff4", size = 23545199, upload-time = "2026-05-12T18:01:34.526Z" }, + { url = "https://files.pythonhosted.org/packages/75/81/c678e8b9a8e624f9c338c66cd57dd9cfc6b5a0501ad3c87fd0cc0bf8850a/uv-0.11.14-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:078f2e63da89c8fcf6d578f02156045c5990c57d76464aab3f3f798d3fff95cd", size = 22957064, upload-time = "2026-05-12T18:00:54.225Z" }, + { url = "https://files.pythonhosted.org/packages/f7/ad/95fbd15b23f26f36d0cfb0ddf159b9602a1b1c0feced60a7f98385e919f1/uv-0.11.14-py3-none-macosx_11_0_arm64.whl", hash = "sha256:dcdad43d52c130e3159e84ab1844e04d819d2c4a2495a687d27f80d560a3650e", size = 21678307, upload-time = "2026-05-12T18:00:57.132Z" }, + { url = "https://files.pythonhosted.org/packages/8b/cb/b3da1c4d95d6dd507896bca16dbd643118013b2b151f5f35a08d3391728c/uv-0.11.14-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.musllinux_1_1_aarch64.whl", hash = "sha256:9923da7c63d70de9fe71829503d7e7ebfd6304e804d7232aad5f716e190db25b", size = 23353409, upload-time = "2026-05-12T18:01:27.512Z" }, + { url = "https://files.pythonhosted.org/packages/51/ad/78c6b8d6bcc04c5043b50631e9b413422a03a0bd7c4a997748f8e9cbac25/uv-0.11.14-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.musllinux_1_1_armv7l.whl", hash = "sha256:3b0759ca504e48dcd4fafb1a61ef69aeb24c5a60fbf5f504a7873c8db1b24718", size = 23103964, upload-time = "2026-05-12T18:01:31.094Z" }, + { url = "https://files.pythonhosted.org/packages/0f/7d/acb66e09bc54a74e4288e996d841af04d88588fd6bdbfbab2468ab7169a7/uv-0.11.14-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:78b51b117549ee4db7197ea5ece0848cecd443e464fb9dff9f254cdc1e4ed96f", size = 23104638, upload-time = "2026-05-12T18:01:10.093Z" }, + { url = "https://files.pythonhosted.org/packages/31/0a/8497be61accdb8e56d02e11edd3ac471466259420e0bd9c05c1966df134a/uv-0.11.14-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a1ddbe8a2ab160affc179e9c3a40913b23a08cdf55254e1f3829cc22a51a0d8d", size = 24625888, upload-time = "2026-05-12T18:01:17.192Z" }, + { url = "https://files.pythonhosted.org/packages/95/91/f730799fd20a45777b255e20cf9f648a4e4e0979bf65e87a8633197cf7d9/uv-0.11.14-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f3005a2db1e8d72e125630d4f22ac4ceddb2c033e1f9b94b7f3ea38ebac46dd6", size = 25445231, upload-time = "2026-05-12T18:00:40.012Z" }, + { url = "https://files.pythonhosted.org/packages/f5/4d/106463fc27e63e402aec2e791774dac2db5bd5e1c36cdcf38125aa97ab1c/uv-0.11.14-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5c8f9ea36274ef2f9d24f0522085e280844172e901d9213f66a21b212266706", size = 24571961, upload-time = "2026-05-12T18:00:43.713Z" }, + { url = "https://files.pythonhosted.org/packages/12/4d/163fe746b97bd1129627e8b1f943e17583ddc143eaab532d56a799a9ba5a/uv-0.11.14-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:379e64b236cf55f762a8308d7efe4365d5296ba29f3a4868761bc45b4e915a71", size = 24718523, upload-time = "2026-05-12T18:01:06.587Z" }, + { url = "https://files.pythonhosted.org/packages/19/fb/7a3673494a0cf70267559166398f9c50c4925ff20122f99a28d6c5a80d83/uv-0.11.14-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:29c12a562441fc2d604e6920c558cacce74a55f889468708683a79b35a6e18a1", size = 23454821, upload-time = "2026-05-12T18:00:51.166Z" }, + { url = "https://files.pythonhosted.org/packages/bb/43/6358394a567d865f3a5ce27b1e0d939549911e36d9b59f0c545a167f92f7/uv-0.11.14-py3-none-manylinux_2_31_riscv64.musllinux_1_1_riscv64.whl", hash = "sha256:e84069681c0334e07cbc7f114eb09d7fe1335e1db0297a66dbca80a1b393fe6d", size = 24087843, upload-time = "2026-05-12T18:00:47.272Z" }, + { url = "https://files.pythonhosted.org/packages/ef/f6/7d0ae1e1f52b85057ca24d8876d6a4cc87b541ea6aca627fe36594c06099/uv-0.11.14-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:b15bf7c146e38d7c938d3a207115d5fdd8ef764fe1f866c225b1bed27e88da1e", size = 24147611, upload-time = "2026-05-12T18:01:20.499Z" }, + { url = "https://files.pythonhosted.org/packages/5a/a2/511ad0c5da5697fd990b99569425b62b81cbc3458c35acc845211b55d6b5/uv-0.11.14-py3-none-musllinux_1_1_i686.whl", hash = "sha256:ddda5c5e41097814adac535c74851bae55e8097b9afc79aeae7fcffd8d86c06d", size = 23920348, upload-time = "2026-05-12T18:01:24.033Z" }, + { url = "https://files.pythonhosted.org/packages/6b/b6/7084e3401b1f1020f215a125136eec1ed2bd541e10a5fea1625515579599/uv-0.11.14-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:e54326703f1eca83a6fd73275e0f398b16b7d3f81531bf58899c2869bc403f6c", size = 24928981, upload-time = "2026-05-12T18:01:13.961Z" }, + { url = "https://files.pythonhosted.org/packages/4d/6a/7e81729fe729889c8cc63bbf64291734359bd7f6ba84852dc0504453511d/uv-0.11.14-py3-none-win32.whl", hash = "sha256:b384d873d0d18552c7524226125efd3965d921b7134c2f476c333771beb733e1", size = 22573503, upload-time = "2026-05-12T18:00:34.36Z" }, + { url = "https://files.pythonhosted.org/packages/94/5d/f8905f9af5cd46af2a688b2246dbb5a4d95b8557eeffd7f241e037659d9e/uv-0.11.14-py3-none-win_amd64.whl", hash = "sha256:f0a8b58b38e984241bca5d7a5a47bf9ffe1ca2ab392a640887db8a04c4a9ec95", size = 25175590, upload-time = "2026-05-12T18:01:00.38Z" }, + { url = "https://files.pythonhosted.org/packages/04/cb/7333d08d944f3018eb89242cd5e646e7b37faa1b567faeaf9254a8b59d53/uv-0.11.14-py3-none-win_arm64.whl", hash = "sha256:6a13e7e064563050c6606b3fd77091d427cdbdc5938b6f134baf8d8ec79bfdb7", size = 23594775, upload-time = "2026-05-12T18:01:03.55Z" }, ] [[package]] name = "uvicorn" -version = "0.46.0" +version = "0.47.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, { name = "h11" }, { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/1f/93/041fca8274050e40e6791f267d82e0e2e27dd165627bd640d3e0e378d877/uvicorn-0.46.0.tar.gz", hash = "sha256:fb9da0926999cc6cb22dc7cd71a94a632f078e6ae47ff683c5c420750fb7413d", size = 88758, upload-time = "2026-04-23T07:16:00.151Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f6/b1/8e7077a8641086aea449e1b5752a570f1b5906c64e0a33cd6d93b63a066b/uvicorn-0.47.0.tar.gz", hash = "sha256:7c9a0ea1a9414106bbab7324609c162d8fa0cdcdcb703060987269d77c7bb533", size = 90582, upload-time = "2026-05-14T18:16:54.455Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/31/a3/5b1562db76a5a488274b2332a97199b32d0442aca0ed193697fd47786316/uvicorn-0.46.0-py3-none-any.whl", hash = "sha256:bbebbcbed972d162afca128605223022bedd345b7bc7855ce66deb31487a9048", size = 70926, upload-time = "2026-04-23T07:15:58.355Z" }, + { url = "https://files.pythonhosted.org/packages/15/41/ac2dfdbc1f60c7af4f994c7a335cfa7040c01642b605d65f611cecc2a1e4/uvicorn-0.47.0-py3-none-any.whl", hash = "sha256:2c5715bc12d1892d84752049f400cd1c3cb018514967fdfeb97640443a6a9432", size = 71301, upload-time = "2026-05-14T18:16:51.762Z" }, ] [[package]] name = "virtualenv" -version = "21.3.1" +version = "21.3.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "distlib" }, @@ -4643,112 +4643,115 @@ dependencies = [ { name = "python-discovery" }, { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ec/0d/915c02c94d207b85580eb09bffab54438a709e7288524094fe781da526c2/virtualenv-21.3.1.tar.gz", hash = "sha256:c2305bc1fddeec40699b8370d13f8d431b0701f00ce895061ce493aeded4426b", size = 7613791, upload-time = "2026-05-05T01:34:31.402Z" } +sdist = { url = "https://files.pythonhosted.org/packages/15/ba/1f6e8c957e4932be060dcdc482d339c12e0216351478add3645cdaa53c05/virtualenv-21.3.3.tar.gz", hash = "sha256:f5bda277e553b1c2b3c1a8debfc30496e1288cc93ce6b7b71b3280047e317328", size = 7613784, upload-time = "2026-05-13T18:01:30.19Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b1/4f/f71e641e504111a5a74e3a20bc52d01bd86788b22699dd3fee1c63253cf6/virtualenv-21.3.1-py3-none-any.whl", hash = "sha256:d1a71cf58f2f9228fff23a1f6ec15d39785c6b32e03658d104974247145edd35", size = 7594539, upload-time = "2026-05-05T01:34:28.98Z" }, + { url = "https://files.pythonhosted.org/packages/f4/34/a9dbe051de88a63eb7408ea66630bac38e72f7f6077d4be58737106860d9/virtualenv-21.3.3-py3-none-any.whl", hash = "sha256:7d5987d8369e098e41406efb780a3d4ca79280097293899e351a6407ee153ab3", size = 7594554, upload-time = "2026-05-13T18:01:27.815Z" }, ] [[package]] name = "watchfiles" -version = "1.1.1" +version = "1.2.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c2/c9/8869df9b2a2d6c59d79220a4db37679e74f807c559ffe5265e08b227a210/watchfiles-1.1.1.tar.gz", hash = "sha256:a173cb5c16c4f40ab19cecf48a534c409f7ea983ab8fed0741304a1c0a31b3f2", size = 94440, upload-time = "2025-10-14T15:06:21.08Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a7/1a/206e8cf2dd86fddf939165a57b4df61607a1e0add2785f170a3f616b7d9f/watchfiles-1.1.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:eef58232d32daf2ac67f42dea51a2c80f0d03379075d44a587051e63cc2e368c", size = 407318, upload-time = "2025-10-14T15:04:18.753Z" }, - { url = "https://files.pythonhosted.org/packages/b3/0f/abaf5262b9c496b5dad4ed3c0e799cbecb1f8ea512ecb6ddd46646a9fca3/watchfiles-1.1.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:03fa0f5237118a0c5e496185cafa92878568b652a2e9a9382a5151b1a0380a43", size = 394478, upload-time = "2025-10-14T15:04:20.297Z" }, - { url = "https://files.pythonhosted.org/packages/b1/04/9cc0ba88697b34b755371f5ace8d3a4d9a15719c07bdc7bd13d7d8c6a341/watchfiles-1.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8ca65483439f9c791897f7db49202301deb6e15fe9f8fe2fed555bf986d10c31", size = 449894, upload-time = "2025-10-14T15:04:21.527Z" }, - { url = "https://files.pythonhosted.org/packages/d2/9c/eda4615863cd8621e89aed4df680d8c3ec3da6a4cf1da113c17decd87c7f/watchfiles-1.1.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f0ab1c1af0cb38e3f598244c17919fb1a84d1629cc08355b0074b6d7f53138ac", size = 459065, upload-time = "2025-10-14T15:04:22.795Z" }, - { url = "https://files.pythonhosted.org/packages/84/13/f28b3f340157d03cbc8197629bc109d1098764abe1e60874622a0be5c112/watchfiles-1.1.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3bc570d6c01c206c46deb6e935a260be44f186a2f05179f52f7fcd2be086a94d", size = 488377, upload-time = "2025-10-14T15:04:24.138Z" }, - { url = "https://files.pythonhosted.org/packages/86/93/cfa597fa9389e122488f7ffdbd6db505b3b915ca7435ecd7542e855898c2/watchfiles-1.1.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e84087b432b6ac94778de547e08611266f1f8ffad28c0ee4c82e028b0fc5966d", size = 595837, upload-time = "2025-10-14T15:04:25.057Z" }, - { url = "https://files.pythonhosted.org/packages/57/1e/68c1ed5652b48d89fc24d6af905d88ee4f82fa8bc491e2666004e307ded1/watchfiles-1.1.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:620bae625f4cb18427b1bb1a2d9426dc0dd5a5ba74c7c2cdb9de405f7b129863", size = 473456, upload-time = "2025-10-14T15:04:26.497Z" }, - { url = "https://files.pythonhosted.org/packages/d5/dc/1a680b7458ffa3b14bb64878112aefc8f2e4f73c5af763cbf0bd43100658/watchfiles-1.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:544364b2b51a9b0c7000a4b4b02f90e9423d97fbbf7e06689236443ebcad81ab", size = 455614, upload-time = "2025-10-14T15:04:27.539Z" }, - { url = "https://files.pythonhosted.org/packages/61/a5/3d782a666512e01eaa6541a72ebac1d3aae191ff4a31274a66b8dd85760c/watchfiles-1.1.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:bbe1ef33d45bc71cf21364df962af171f96ecaeca06bd9e3d0b583efb12aec82", size = 630690, upload-time = "2025-10-14T15:04:28.495Z" }, - { url = "https://files.pythonhosted.org/packages/9b/73/bb5f38590e34687b2a9c47a244aa4dd50c56a825969c92c9c5fc7387cea1/watchfiles-1.1.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:1a0bb430adb19ef49389e1ad368450193a90038b5b752f4ac089ec6942c4dff4", size = 622459, upload-time = "2025-10-14T15:04:29.491Z" }, - { url = "https://files.pythonhosted.org/packages/f1/ac/c9bb0ec696e07a20bd58af5399aeadaef195fb2c73d26baf55180fe4a942/watchfiles-1.1.1-cp310-cp310-win32.whl", hash = "sha256:3f6d37644155fb5beca5378feb8c1708d5783145f2a0f1c4d5a061a210254844", size = 272663, upload-time = "2025-10-14T15:04:30.435Z" }, - { url = "https://files.pythonhosted.org/packages/11/a0/a60c5a7c2ec59fa062d9a9c61d02e3b6abd94d32aac2d8344c4bdd033326/watchfiles-1.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:a36d8efe0f290835fd0f33da35042a1bb5dc0e83cbc092dcf69bce442579e88e", size = 287453, upload-time = "2025-10-14T15:04:31.53Z" }, - { url = "https://files.pythonhosted.org/packages/1f/f8/2c5f479fb531ce2f0564eda479faecf253d886b1ab3630a39b7bf7362d46/watchfiles-1.1.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:f57b396167a2565a4e8b5e56a5a1c537571733992b226f4f1197d79e94cf0ae5", size = 406529, upload-time = "2025-10-14T15:04:32.899Z" }, - { url = "https://files.pythonhosted.org/packages/fe/cd/f515660b1f32f65df671ddf6f85bfaca621aee177712874dc30a97397977/watchfiles-1.1.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:421e29339983e1bebc281fab40d812742268ad057db4aee8c4d2bce0af43b741", size = 394384, upload-time = "2025-10-14T15:04:33.761Z" }, - { url = "https://files.pythonhosted.org/packages/7b/c3/28b7dc99733eab43fca2d10f55c86e03bd6ab11ca31b802abac26b23d161/watchfiles-1.1.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6e43d39a741e972bab5d8100b5cdacf69db64e34eb19b6e9af162bccf63c5cc6", size = 448789, upload-time = "2025-10-14T15:04:34.679Z" }, - { url = "https://files.pythonhosted.org/packages/4a/24/33e71113b320030011c8e4316ccca04194bf0cbbaeee207f00cbc7d6b9f5/watchfiles-1.1.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f537afb3276d12814082a2e9b242bdcf416c2e8fd9f799a737990a1dbe906e5b", size = 460521, upload-time = "2025-10-14T15:04:35.963Z" }, - { url = "https://files.pythonhosted.org/packages/f4/c3/3c9a55f255aa57b91579ae9e98c88704955fa9dac3e5614fb378291155df/watchfiles-1.1.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b2cd9e04277e756a2e2d2543d65d1e2166d6fd4c9b183f8808634fda23f17b14", size = 488722, upload-time = "2025-10-14T15:04:37.091Z" }, - { url = "https://files.pythonhosted.org/packages/49/36/506447b73eb46c120169dc1717fe2eff07c234bb3232a7200b5f5bd816e9/watchfiles-1.1.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5f3f58818dc0b07f7d9aa7fe9eb1037aecb9700e63e1f6acfed13e9fef648f5d", size = 596088, upload-time = "2025-10-14T15:04:38.39Z" }, - { url = "https://files.pythonhosted.org/packages/82/ab/5f39e752a9838ec4d52e9b87c1e80f1ee3ccdbe92e183c15b6577ab9de16/watchfiles-1.1.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9bb9f66367023ae783551042d31b1d7fd422e8289eedd91f26754a66f44d5cff", size = 472923, upload-time = "2025-10-14T15:04:39.666Z" }, - { url = "https://files.pythonhosted.org/packages/af/b9/a419292f05e302dea372fa7e6fda5178a92998411f8581b9830d28fb9edb/watchfiles-1.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aebfd0861a83e6c3d1110b78ad54704486555246e542be3e2bb94195eabb2606", size = 456080, upload-time = "2025-10-14T15:04:40.643Z" }, - { url = "https://files.pythonhosted.org/packages/b0/c3/d5932fd62bde1a30c36e10c409dc5d54506726f08cb3e1d8d0ba5e2bc8db/watchfiles-1.1.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:5fac835b4ab3c6487b5dbad78c4b3724e26bcc468e886f8ba8cc4306f68f6701", size = 629432, upload-time = "2025-10-14T15:04:41.789Z" }, - { url = "https://files.pythonhosted.org/packages/f7/77/16bddd9779fafb795f1a94319dc965209c5641db5bf1edbbccace6d1b3c0/watchfiles-1.1.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:399600947b170270e80134ac854e21b3ccdefa11a9529a3decc1327088180f10", size = 623046, upload-time = "2025-10-14T15:04:42.718Z" }, - { url = "https://files.pythonhosted.org/packages/46/ef/f2ecb9a0f342b4bfad13a2787155c6ee7ce792140eac63a34676a2feeef2/watchfiles-1.1.1-cp311-cp311-win32.whl", hash = "sha256:de6da501c883f58ad50db3a32ad397b09ad29865b5f26f64c24d3e3281685849", size = 271473, upload-time = "2025-10-14T15:04:43.624Z" }, - { url = "https://files.pythonhosted.org/packages/94/bc/f42d71125f19731ea435c3948cad148d31a64fccde3867e5ba4edee901f9/watchfiles-1.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:35c53bd62a0b885bf653ebf6b700d1bf05debb78ad9292cf2a942b23513dc4c4", size = 287598, upload-time = "2025-10-14T15:04:44.516Z" }, - { url = "https://files.pythonhosted.org/packages/57/c9/a30f897351f95bbbfb6abcadafbaca711ce1162f4db95fc908c98a9165f3/watchfiles-1.1.1-cp311-cp311-win_arm64.whl", hash = "sha256:57ca5281a8b5e27593cb7d82c2ac927ad88a96ed406aa446f6344e4328208e9e", size = 277210, upload-time = "2025-10-14T15:04:45.883Z" }, - { url = "https://files.pythonhosted.org/packages/74/d5/f039e7e3c639d9b1d09b07ea412a6806d38123f0508e5f9b48a87b0a76cc/watchfiles-1.1.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:8c89f9f2f740a6b7dcc753140dd5e1ab9215966f7a3530d0c0705c83b401bd7d", size = 404745, upload-time = "2025-10-14T15:04:46.731Z" }, - { url = "https://files.pythonhosted.org/packages/a5/96/a881a13aa1349827490dab2d363c8039527060cfcc2c92cc6d13d1b1049e/watchfiles-1.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:bd404be08018c37350f0d6e34676bd1e2889990117a2b90070b3007f172d0610", size = 391769, upload-time = "2025-10-14T15:04:48.003Z" }, - { url = "https://files.pythonhosted.org/packages/4b/5b/d3b460364aeb8da471c1989238ea0e56bec24b6042a68046adf3d9ddb01c/watchfiles-1.1.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8526e8f916bb5b9a0a777c8317c23ce65de259422bba5b31325a6fa6029d33af", size = 449374, upload-time = "2025-10-14T15:04:49.179Z" }, - { url = "https://files.pythonhosted.org/packages/b9/44/5769cb62d4ed055cb17417c0a109a92f007114a4e07f30812a73a4efdb11/watchfiles-1.1.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2edc3553362b1c38d9f06242416a5d8e9fe235c204a4072e988ce2e5bb1f69f6", size = 459485, upload-time = "2025-10-14T15:04:50.155Z" }, - { url = "https://files.pythonhosted.org/packages/19/0c/286b6301ded2eccd4ffd0041a1b726afda999926cf720aab63adb68a1e36/watchfiles-1.1.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:30f7da3fb3f2844259cba4720c3fc7138eb0f7b659c38f3bfa65084c7fc7abce", size = 488813, upload-time = "2025-10-14T15:04:51.059Z" }, - { url = "https://files.pythonhosted.org/packages/c7/2b/8530ed41112dd4a22f4dcfdb5ccf6a1baad1ff6eed8dc5a5f09e7e8c41c7/watchfiles-1.1.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f8979280bdafff686ba5e4d8f97840f929a87ed9cdf133cbbd42f7766774d2aa", size = 594816, upload-time = "2025-10-14T15:04:52.031Z" }, - { url = "https://files.pythonhosted.org/packages/ce/d2/f5f9fb49489f184f18470d4f99f4e862a4b3e9ac2865688eb2099e3d837a/watchfiles-1.1.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dcc5c24523771db3a294c77d94771abcfcb82a0e0ee8efd910c37c59ec1b31bb", size = 475186, upload-time = "2025-10-14T15:04:53.064Z" }, - { url = "https://files.pythonhosted.org/packages/cf/68/5707da262a119fb06fbe214d82dd1fe4a6f4af32d2d14de368d0349eb52a/watchfiles-1.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1db5d7ae38ff20153d542460752ff397fcf5c96090c1230803713cf3147a6803", size = 456812, upload-time = "2025-10-14T15:04:55.174Z" }, - { url = "https://files.pythonhosted.org/packages/66/ab/3cbb8756323e8f9b6f9acb9ef4ec26d42b2109bce830cc1f3468df20511d/watchfiles-1.1.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:28475ddbde92df1874b6c5c8aaeb24ad5be47a11f87cde5a28ef3835932e3e94", size = 630196, upload-time = "2025-10-14T15:04:56.22Z" }, - { url = "https://files.pythonhosted.org/packages/78/46/7152ec29b8335f80167928944a94955015a345440f524d2dfe63fc2f437b/watchfiles-1.1.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:36193ed342f5b9842edd3532729a2ad55c4160ffcfa3700e0d54be496b70dd43", size = 622657, upload-time = "2025-10-14T15:04:57.521Z" }, - { url = "https://files.pythonhosted.org/packages/0a/bf/95895e78dd75efe9a7f31733607f384b42eb5feb54bd2eb6ed57cc2e94f4/watchfiles-1.1.1-cp312-cp312-win32.whl", hash = "sha256:859e43a1951717cc8de7f4c77674a6d389b106361585951d9e69572823f311d9", size = 272042, upload-time = "2025-10-14T15:04:59.046Z" }, - { url = "https://files.pythonhosted.org/packages/87/0a/90eb755f568de2688cb220171c4191df932232c20946966c27a59c400850/watchfiles-1.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:91d4c9a823a8c987cce8fa2690923b069966dabb196dd8d137ea2cede885fde9", size = 288410, upload-time = "2025-10-14T15:05:00.081Z" }, - { url = "https://files.pythonhosted.org/packages/36/76/f322701530586922fbd6723c4f91ace21364924822a8772c549483abed13/watchfiles-1.1.1-cp312-cp312-win_arm64.whl", hash = "sha256:a625815d4a2bdca61953dbba5a39d60164451ef34c88d751f6c368c3ea73d404", size = 278209, upload-time = "2025-10-14T15:05:01.168Z" }, - { url = "https://files.pythonhosted.org/packages/bb/f4/f750b29225fe77139f7ae5de89d4949f5a99f934c65a1f1c0b248f26f747/watchfiles-1.1.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:130e4876309e8686a5e37dba7d5e9bc77e6ed908266996ca26572437a5271e18", size = 404321, upload-time = "2025-10-14T15:05:02.063Z" }, - { url = "https://files.pythonhosted.org/packages/2b/f9/f07a295cde762644aa4c4bb0f88921d2d141af45e735b965fb2e87858328/watchfiles-1.1.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5f3bde70f157f84ece3765b42b4a52c6ac1a50334903c6eaf765362f6ccca88a", size = 391783, upload-time = "2025-10-14T15:05:03.052Z" }, - { url = "https://files.pythonhosted.org/packages/bc/11/fc2502457e0bea39a5c958d86d2cb69e407a4d00b85735ca724bfa6e0d1a/watchfiles-1.1.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:14e0b1fe858430fc0251737ef3824c54027bedb8c37c38114488b8e131cf8219", size = 449279, upload-time = "2025-10-14T15:05:04.004Z" }, - { url = "https://files.pythonhosted.org/packages/e3/1f/d66bc15ea0b728df3ed96a539c777acfcad0eb78555ad9efcaa1274688f0/watchfiles-1.1.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f27db948078f3823a6bb3b465180db8ebecf26dd5dae6f6180bd87383b6b4428", size = 459405, upload-time = "2025-10-14T15:05:04.942Z" }, - { url = "https://files.pythonhosted.org/packages/be/90/9f4a65c0aec3ccf032703e6db02d89a157462fbb2cf20dd415128251cac0/watchfiles-1.1.1-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:059098c3a429f62fc98e8ec62b982230ef2c8df68c79e826e37b895bc359a9c0", size = 488976, upload-time = "2025-10-14T15:05:05.905Z" }, - { url = "https://files.pythonhosted.org/packages/37/57/ee347af605d867f712be7029bb94c8c071732a4b44792e3176fa3c612d39/watchfiles-1.1.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bfb5862016acc9b869bb57284e6cb35fdf8e22fe59f7548858e2f971d045f150", size = 595506, upload-time = "2025-10-14T15:05:06.906Z" }, - { url = "https://files.pythonhosted.org/packages/a8/78/cc5ab0b86c122047f75e8fc471c67a04dee395daf847d3e59381996c8707/watchfiles-1.1.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:319b27255aacd9923b8a276bb14d21a5f7ff82564c744235fc5eae58d95422ae", size = 474936, upload-time = "2025-10-14T15:05:07.906Z" }, - { url = "https://files.pythonhosted.org/packages/62/da/def65b170a3815af7bd40a3e7010bf6ab53089ef1b75d05dd5385b87cf08/watchfiles-1.1.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c755367e51db90e75b19454b680903631d41f9e3607fbd941d296a020c2d752d", size = 456147, upload-time = "2025-10-14T15:05:09.138Z" }, - { url = "https://files.pythonhosted.org/packages/57/99/da6573ba71166e82d288d4df0839128004c67d2778d3b566c138695f5c0b/watchfiles-1.1.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:c22c776292a23bfc7237a98f791b9ad3144b02116ff10d820829ce62dff46d0b", size = 630007, upload-time = "2025-10-14T15:05:10.117Z" }, - { url = "https://files.pythonhosted.org/packages/a8/51/7439c4dd39511368849eb1e53279cd3454b4a4dbace80bab88feeb83c6b5/watchfiles-1.1.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:3a476189be23c3686bc2f4321dd501cb329c0a0469e77b7b534ee10129ae6374", size = 622280, upload-time = "2025-10-14T15:05:11.146Z" }, - { url = "https://files.pythonhosted.org/packages/95/9c/8ed97d4bba5db6fdcdb2b298d3898f2dd5c20f6b73aee04eabe56c59677e/watchfiles-1.1.1-cp313-cp313-win32.whl", hash = "sha256:bf0a91bfb5574a2f7fc223cf95eeea79abfefa404bf1ea5e339c0c1560ae99a0", size = 272056, upload-time = "2025-10-14T15:05:12.156Z" }, - { url = "https://files.pythonhosted.org/packages/1f/f3/c14e28429f744a260d8ceae18bf58c1d5fa56b50d006a7a9f80e1882cb0d/watchfiles-1.1.1-cp313-cp313-win_amd64.whl", hash = "sha256:52e06553899e11e8074503c8e716d574adeeb7e68913115c4b3653c53f9bae42", size = 288162, upload-time = "2025-10-14T15:05:13.208Z" }, - { url = "https://files.pythonhosted.org/packages/dc/61/fe0e56c40d5cd29523e398d31153218718c5786b5e636d9ae8ae79453d27/watchfiles-1.1.1-cp313-cp313-win_arm64.whl", hash = "sha256:ac3cc5759570cd02662b15fbcd9d917f7ecd47efe0d6b40474eafd246f91ea18", size = 277909, upload-time = "2025-10-14T15:05:14.49Z" }, - { url = "https://files.pythonhosted.org/packages/79/42/e0a7d749626f1e28c7108a99fb9bf524b501bbbeb9b261ceecde644d5a07/watchfiles-1.1.1-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:563b116874a9a7ce6f96f87cd0b94f7faf92d08d0021e837796f0a14318ef8da", size = 403389, upload-time = "2025-10-14T15:05:15.777Z" }, - { url = "https://files.pythonhosted.org/packages/15/49/08732f90ce0fbbc13913f9f215c689cfc9ced345fb1bcd8829a50007cc8d/watchfiles-1.1.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:3ad9fe1dae4ab4212d8c91e80b832425e24f421703b5a42ef2e4a1e215aff051", size = 389964, upload-time = "2025-10-14T15:05:16.85Z" }, - { url = "https://files.pythonhosted.org/packages/27/0d/7c315d4bd5f2538910491a0393c56bf70d333d51bc5b34bee8e68e8cea19/watchfiles-1.1.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce70f96a46b894b36eba678f153f052967a0d06d5b5a19b336ab0dbbd029f73e", size = 448114, upload-time = "2025-10-14T15:05:17.876Z" }, - { url = "https://files.pythonhosted.org/packages/c3/24/9e096de47a4d11bc4df41e9d1e61776393eac4cb6eb11b3e23315b78b2cc/watchfiles-1.1.1-cp313-cp313t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cb467c999c2eff23a6417e58d75e5828716f42ed8289fe6b77a7e5a91036ca70", size = 460264, upload-time = "2025-10-14T15:05:18.962Z" }, - { url = "https://files.pythonhosted.org/packages/cc/0f/e8dea6375f1d3ba5fcb0b3583e2b493e77379834c74fd5a22d66d85d6540/watchfiles-1.1.1-cp313-cp313t-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:836398932192dae4146c8f6f737d74baeac8b70ce14831a239bdb1ca882fc261", size = 487877, upload-time = "2025-10-14T15:05:20.094Z" }, - { url = "https://files.pythonhosted.org/packages/ac/5b/df24cfc6424a12deb41503b64d42fbea6b8cb357ec62ca84a5a3476f654a/watchfiles-1.1.1-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:743185e7372b7bc7c389e1badcc606931a827112fbbd37f14c537320fca08620", size = 595176, upload-time = "2025-10-14T15:05:21.134Z" }, - { url = "https://files.pythonhosted.org/packages/8f/b5/853b6757f7347de4e9b37e8cc3289283fb983cba1ab4d2d7144694871d9c/watchfiles-1.1.1-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:afaeff7696e0ad9f02cbb8f56365ff4686ab205fcf9c4c5b6fdfaaa16549dd04", size = 473577, upload-time = "2025-10-14T15:05:22.306Z" }, - { url = "https://files.pythonhosted.org/packages/e1/f7/0a4467be0a56e80447c8529c9fce5b38eab4f513cb3d9bf82e7392a5696b/watchfiles-1.1.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f7eb7da0eb23aa2ba036d4f616d46906013a68caf61b7fdbe42fc8b25132e77", size = 455425, upload-time = "2025-10-14T15:05:23.348Z" }, - { url = "https://files.pythonhosted.org/packages/8e/e0/82583485ea00137ddf69bc84a2db88bd92ab4a6e3c405e5fb878ead8d0e7/watchfiles-1.1.1-cp313-cp313t-musllinux_1_1_aarch64.whl", hash = "sha256:831a62658609f0e5c64178211c942ace999517f5770fe9436be4c2faeba0c0ef", size = 628826, upload-time = "2025-10-14T15:05:24.398Z" }, - { url = "https://files.pythonhosted.org/packages/28/9a/a785356fccf9fae84c0cc90570f11702ae9571036fb25932f1242c82191c/watchfiles-1.1.1-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:f9a2ae5c91cecc9edd47e041a930490c31c3afb1f5e6d71de3dc671bfaca02bf", size = 622208, upload-time = "2025-10-14T15:05:25.45Z" }, - { url = "https://files.pythonhosted.org/packages/c3/f4/0872229324ef69b2c3edec35e84bd57a1289e7d3fe74588048ed8947a323/watchfiles-1.1.1-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:d1715143123baeeaeadec0528bb7441103979a1d5f6fd0e1f915383fea7ea6d5", size = 404315, upload-time = "2025-10-14T15:05:26.501Z" }, - { url = "https://files.pythonhosted.org/packages/7b/22/16d5331eaed1cb107b873f6ae1b69e9ced582fcf0c59a50cd84f403b1c32/watchfiles-1.1.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:39574d6370c4579d7f5d0ad940ce5b20db0e4117444e39b6d8f99db5676c52fd", size = 390869, upload-time = "2025-10-14T15:05:27.649Z" }, - { url = "https://files.pythonhosted.org/packages/b2/7e/5643bfff5acb6539b18483128fdc0ef2cccc94a5b8fbda130c823e8ed636/watchfiles-1.1.1-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7365b92c2e69ee952902e8f70f3ba6360d0d596d9299d55d7d386df84b6941fb", size = 449919, upload-time = "2025-10-14T15:05:28.701Z" }, - { url = "https://files.pythonhosted.org/packages/51/2e/c410993ba5025a9f9357c376f48976ef0e1b1aefb73b97a5ae01a5972755/watchfiles-1.1.1-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bfff9740c69c0e4ed32416f013f3c45e2ae42ccedd1167ef2d805c000b6c71a5", size = 460845, upload-time = "2025-10-14T15:05:30.064Z" }, - { url = "https://files.pythonhosted.org/packages/8e/a4/2df3b404469122e8680f0fcd06079317e48db58a2da2950fb45020947734/watchfiles-1.1.1-cp314-cp314-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b27cf2eb1dda37b2089e3907d8ea92922b673c0c427886d4edc6b94d8dfe5db3", size = 489027, upload-time = "2025-10-14T15:05:31.064Z" }, - { url = "https://files.pythonhosted.org/packages/ea/84/4587ba5b1f267167ee715b7f66e6382cca6938e0a4b870adad93e44747e6/watchfiles-1.1.1-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:526e86aced14a65a5b0ec50827c745597c782ff46b571dbfe46192ab9e0b3c33", size = 595615, upload-time = "2025-10-14T15:05:32.074Z" }, - { url = "https://files.pythonhosted.org/packages/6a/0f/c6988c91d06e93cd0bb3d4a808bcf32375ca1904609835c3031799e3ecae/watchfiles-1.1.1-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:04e78dd0b6352db95507fd8cb46f39d185cf8c74e4cf1e4fbad1d3df96faf510", size = 474836, upload-time = "2025-10-14T15:05:33.209Z" }, - { url = "https://files.pythonhosted.org/packages/b4/36/ded8aebea91919485b7bbabbd14f5f359326cb5ec218cd67074d1e426d74/watchfiles-1.1.1-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c85794a4cfa094714fb9c08d4a218375b2b95b8ed1666e8677c349906246c05", size = 455099, upload-time = "2025-10-14T15:05:34.189Z" }, - { url = "https://files.pythonhosted.org/packages/98/e0/8c9bdba88af756a2fce230dd365fab2baf927ba42cd47521ee7498fd5211/watchfiles-1.1.1-cp314-cp314-musllinux_1_1_aarch64.whl", hash = "sha256:74d5012b7630714b66be7b7b7a78855ef7ad58e8650c73afc4c076a1f480a8d6", size = 630626, upload-time = "2025-10-14T15:05:35.216Z" }, - { url = "https://files.pythonhosted.org/packages/2a/84/a95db05354bf2d19e438520d92a8ca475e578c647f78f53197f5a2f17aaf/watchfiles-1.1.1-cp314-cp314-musllinux_1_1_x86_64.whl", hash = "sha256:8fbe85cb3201c7d380d3d0b90e63d520f15d6afe217165d7f98c9c649654db81", size = 622519, upload-time = "2025-10-14T15:05:36.259Z" }, - { url = "https://files.pythonhosted.org/packages/1d/ce/d8acdc8de545de995c339be67711e474c77d643555a9bb74a9334252bd55/watchfiles-1.1.1-cp314-cp314-win32.whl", hash = "sha256:3fa0b59c92278b5a7800d3ee7733da9d096d4aabcfabb9a928918bd276ef9b9b", size = 272078, upload-time = "2025-10-14T15:05:37.63Z" }, - { url = "https://files.pythonhosted.org/packages/c4/c9/a74487f72d0451524be827e8edec251da0cc1fcf111646a511ae752e1a3d/watchfiles-1.1.1-cp314-cp314-win_amd64.whl", hash = "sha256:c2047d0b6cea13b3316bdbafbfa0c4228ae593d995030fda39089d36e64fc03a", size = 287664, upload-time = "2025-10-14T15:05:38.95Z" }, - { url = "https://files.pythonhosted.org/packages/df/b8/8ac000702cdd496cdce998c6f4ee0ca1f15977bba51bdf07d872ebdfc34c/watchfiles-1.1.1-cp314-cp314-win_arm64.whl", hash = "sha256:842178b126593addc05acf6fce960d28bc5fae7afbaa2c6c1b3a7b9460e5be02", size = 277154, upload-time = "2025-10-14T15:05:39.954Z" }, - { url = "https://files.pythonhosted.org/packages/47/a8/e3af2184707c29f0f14b1963c0aace6529f9d1b8582d5b99f31bbf42f59e/watchfiles-1.1.1-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:88863fbbc1a7312972f1c511f202eb30866370ebb8493aef2812b9ff28156a21", size = 403820, upload-time = "2025-10-14T15:05:40.932Z" }, - { url = "https://files.pythonhosted.org/packages/c0/ec/e47e307c2f4bd75f9f9e8afbe3876679b18e1bcec449beca132a1c5ffb2d/watchfiles-1.1.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:55c7475190662e202c08c6c0f4d9e345a29367438cf8e8037f3155e10a88d5a5", size = 390510, upload-time = "2025-10-14T15:05:41.945Z" }, - { url = "https://files.pythonhosted.org/packages/d5/a0/ad235642118090f66e7b2f18fd5c42082418404a79205cdfca50b6309c13/watchfiles-1.1.1-cp314-cp314t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f53fa183d53a1d7a8852277c92b967ae99c2d4dcee2bfacff8868e6e30b15f7", size = 448408, upload-time = "2025-10-14T15:05:43.385Z" }, - { url = "https://files.pythonhosted.org/packages/df/85/97fa10fd5ff3332ae17e7e40e20784e419e28521549780869f1413742e9d/watchfiles-1.1.1-cp314-cp314t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6aae418a8b323732fa89721d86f39ec8f092fc2af67f4217a2b07fd3e93c6101", size = 458968, upload-time = "2025-10-14T15:05:44.404Z" }, - { url = "https://files.pythonhosted.org/packages/47/c2/9059c2e8966ea5ce678166617a7f75ecba6164375f3b288e50a40dc6d489/watchfiles-1.1.1-cp314-cp314t-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f096076119da54a6080e8920cbdaac3dbee667eb91dcc5e5b78840b87415bd44", size = 488096, upload-time = "2025-10-14T15:05:45.398Z" }, - { url = "https://files.pythonhosted.org/packages/94/44/d90a9ec8ac309bc26db808a13e7bfc0e4e78b6fc051078a554e132e80160/watchfiles-1.1.1-cp314-cp314t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:00485f441d183717038ed2e887a7c868154f216877653121068107b227a2f64c", size = 596040, upload-time = "2025-10-14T15:05:46.502Z" }, - { url = "https://files.pythonhosted.org/packages/95/68/4e3479b20ca305cfc561db3ed207a8a1c745ee32bf24f2026a129d0ddb6e/watchfiles-1.1.1-cp314-cp314t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a55f3e9e493158d7bfdb60a1165035f1cf7d320914e7b7ea83fe22c6023b58fc", size = 473847, upload-time = "2025-10-14T15:05:47.484Z" }, - { url = "https://files.pythonhosted.org/packages/4f/55/2af26693fd15165c4ff7857e38330e1b61ab8c37d15dc79118cdba115b7a/watchfiles-1.1.1-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c91ed27800188c2ae96d16e3149f199d62f86c7af5f5f4d2c61a3ed8cd3666c", size = 455072, upload-time = "2025-10-14T15:05:48.928Z" }, - { url = "https://files.pythonhosted.org/packages/66/1d/d0d200b10c9311ec25d2273f8aad8c3ef7cc7ea11808022501811208a750/watchfiles-1.1.1-cp314-cp314t-musllinux_1_1_aarch64.whl", hash = "sha256:311ff15a0bae3714ffb603e6ba6dbfba4065ab60865d15a6ec544133bdb21099", size = 629104, upload-time = "2025-10-14T15:05:49.908Z" }, - { url = "https://files.pythonhosted.org/packages/e3/bd/fa9bb053192491b3867ba07d2343d9f2252e00811567d30ae8d0f78136fe/watchfiles-1.1.1-cp314-cp314t-musllinux_1_1_x86_64.whl", hash = "sha256:a916a2932da8f8ab582f242c065f5c81bed3462849ca79ee357dd9551b0e9b01", size = 622112, upload-time = "2025-10-14T15:05:50.941Z" }, - { url = "https://files.pythonhosted.org/packages/ba/4c/a888c91e2e326872fa4705095d64acd8aa2fb9c1f7b9bd0588f33850516c/watchfiles-1.1.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:17ef139237dfced9da49fb7f2232c86ca9421f666d78c264c7ffca6601d154c3", size = 409611, upload-time = "2025-10-14T15:06:05.809Z" }, - { url = "https://files.pythonhosted.org/packages/1e/c7/5420d1943c8e3ce1a21c0a9330bcf7edafb6aa65d26b21dbb3267c9e8112/watchfiles-1.1.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:672b8adf25b1a0d35c96b5888b7b18699d27d4194bac8beeae75be4b7a3fc9b2", size = 396889, upload-time = "2025-10-14T15:06:07.035Z" }, - { url = "https://files.pythonhosted.org/packages/0c/e5/0072cef3804ce8d3aaddbfe7788aadff6b3d3f98a286fdbee9fd74ca59a7/watchfiles-1.1.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77a13aea58bc2b90173bc69f2a90de8e282648939a00a602e1dc4ee23e26b66d", size = 451616, upload-time = "2025-10-14T15:06:08.072Z" }, - { url = "https://files.pythonhosted.org/packages/83/4e/b87b71cbdfad81ad7e83358b3e447fedd281b880a03d64a760fe0a11fc2e/watchfiles-1.1.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b495de0bb386df6a12b18335a0285dda90260f51bdb505503c02bcd1ce27a8b", size = 458413, upload-time = "2025-10-14T15:06:09.209Z" }, - { url = "https://files.pythonhosted.org/packages/d3/8e/e500f8b0b77be4ff753ac94dc06b33d8f0d839377fee1b78e8c8d8f031bf/watchfiles-1.1.1-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:db476ab59b6765134de1d4fe96a1a9c96ddf091683599be0f26147ea1b2e4b88", size = 408250, upload-time = "2025-10-14T15:06:10.264Z" }, - { url = "https://files.pythonhosted.org/packages/bd/95/615e72cd27b85b61eec764a5ca51bd94d40b5adea5ff47567d9ebc4d275a/watchfiles-1.1.1-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:89eef07eee5e9d1fda06e38822ad167a044153457e6fd997f8a858ab7564a336", size = 396117, upload-time = "2025-10-14T15:06:11.28Z" }, - { url = "https://files.pythonhosted.org/packages/c9/81/e7fe958ce8a7fb5c73cc9fb07f5aeaf755e6aa72498c57d760af760c91f8/watchfiles-1.1.1-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce19e06cbda693e9e7686358af9cd6f5d61312ab8b00488bc36f5aabbaf77e24", size = 450493, upload-time = "2025-10-14T15:06:12.321Z" }, - { url = "https://files.pythonhosted.org/packages/6e/d4/ed38dd3b1767193de971e694aa544356e63353c33a85d948166b5ff58b9e/watchfiles-1.1.1-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e6f39af2eab0118338902798b5aa6664f46ff66bc0280de76fca67a7f262a49", size = 457546, upload-time = "2025-10-14T15:06:13.372Z" }, +sdist = { url = "https://files.pythonhosted.org/packages/cd/41/5e1a4bb12aac5f1493fa1bdc11154eca3b258ca4eba65d39c473fe19d8e9/watchfiles-1.2.0.tar.gz", hash = "sha256:c995fba777f1ea992f090f9236e9284cf7a5d1a0130dd5a3d82c598cacd76838", size = 108252, upload-time = "2026-05-18T04:32:04.251Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0d/5a/2bf22ecb24916983bf1cc0095e7dea2741d14d6553b0d6a2ac8bc96eca93/watchfiles-1.2.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:bb68bf4df85abebe5efddc53cf2075520f243a59868d9b3973278b23e76962a9", size = 400471, upload-time = "2026-05-18T04:31:08.908Z" }, + { url = "https://files.pythonhosted.org/packages/55/70/dea1f6a0e76607841a60fb51af150e70124864673f61704abb62b90cdcc7/watchfiles-1.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c16cb06dd17d43b9d185094268459eac92c9538356f050e55b54e82cf700e1d4", size = 394599, upload-time = "2026-05-18T04:30:19.845Z" }, + { url = "https://files.pythonhosted.org/packages/18/52/752dcc7dc817baef5e89518732925795ce52e36a683a9a3c9fb68b21504e/watchfiles-1.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77a0feab9af4c021c581f695258c642b3d10c5fd4c676e33a0d8606425d82631", size = 455458, upload-time = "2026-05-18T04:30:29.126Z" }, + { url = "https://files.pythonhosted.org/packages/12/48/366ebbb22fcc504c2f72b45f0b7e72f40a18795cc01752c16066d597b67a/watchfiles-1.2.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a16ffe19bf5cf9f5edaa1ad1dd830c5a816e8feec430c522302ab55483a4b994", size = 460513, upload-time = "2026-05-18T04:31:40.85Z" }, + { url = "https://files.pythonhosted.org/packages/ad/44/1f9e1b15e7a729062e0d0c3d0d7225ea4ab98b2267ef87287153be2495fc/watchfiles-1.2.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:204f299afcbd65918ab78dbc52626b0ae45e9d8cef403fdbf33ecf9e40eac66e", size = 493616, upload-time = "2026-05-18T04:30:58.47Z" }, + { url = "https://files.pythonhosted.org/packages/7e/55/8b1086dcc8a1d6a697a62767bd7ea368e74c61c6fd171683cfe24a3fe5d2/watchfiles-1.2.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:11743adfa510bfffebe97659fb280182b5c9b238708f667e866f308c3430dc19", size = 573154, upload-time = "2026-05-18T04:30:37.903Z" }, + { url = "https://files.pythonhosted.org/packages/14/7a/242f400cc77fafa7b18d53d19d9cb64fc6a6f61f28c55913bae7c674d92a/watchfiles-1.2.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eb72919d93e3a16fc451d3aa3d4b1698423daca1b382d3d959c9ac51297c12a8", size = 467046, upload-time = "2026-05-18T04:30:41.869Z" }, + { url = "https://files.pythonhosted.org/packages/02/c8/79eee650c62d2c186598489814468e389b5def0ebe755399ff645b35b1b2/watchfiles-1.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b62f042afde2dde21ec1d2c1a74361e804673df86f51e418a999c9acfe671b07", size = 457100, upload-time = "2026-05-18T04:31:13.064Z" }, + { url = "https://files.pythonhosted.org/packages/81/36/519f6dbb7a95e4fe7c1513ed25b1520295ef9905a27f1f2226a73892bfb7/watchfiles-1.2.0-cp310-cp310-manylinux_2_31_riscv64.whl", hash = "sha256:027ae72bfdfd254862065d8b3e2a815c6ab9b1853ce41e6648ece84afd34a551", size = 467038, upload-time = "2026-05-18T04:30:32.915Z" }, + { url = "https://files.pythonhosted.org/packages/2f/12/951af6b9f89097e02511122258402cb3578443021930b70cf968d6310dc0/watchfiles-1.2.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e1cfd51e97e13ff3bd047c140764d277fc9b95b7cb5da59e46a47d167adab310", size = 632563, upload-time = "2026-05-18T04:30:11.539Z" }, + { url = "https://files.pythonhosted.org/packages/28/cc/0cba1f0a6117b7ec117271bdc3cb3a5a252005959755a2c09a745e0942cc/watchfiles-1.2.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:24b2405c0a46738dd9e1cf7135aa5dbdb9d42d024628651b3b13d5117e99f8df", size = 660851, upload-time = "2026-05-18T04:31:53.186Z" }, + { url = "https://files.pythonhosted.org/packages/d0/f2/26347558cc8bf6877845e66b315f644d03c173906aa09e233a3f4fd23928/watchfiles-1.2.0-cp310-cp310-win32.whl", hash = "sha256:8c520725602756229f045b032a1ff33d7ef0f7404189d62f6c2438cb6d8ef6a1", size = 277023, upload-time = "2026-05-18T04:30:18.825Z" }, + { url = "https://files.pythonhosted.org/packages/6d/68/a5e67b6b68e94f4c1511d61c46c55eba0737583620b6febf194c7b9cc23f/watchfiles-1.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:03b14855c6f35539e2d95c442ae9530a75762f1e26567152b9ed05f96534a74d", size = 290107, upload-time = "2026-05-18T04:32:09.677Z" }, + { url = "https://files.pythonhosted.org/packages/fc/3d/8024c801df84d1587740d0359e7fdd80afeae3d159011f3d5376dd82f18e/watchfiles-1.2.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:704fd259e332e01f9b9c178f4bce9e49027e5587cc2600eeeaf8e76e1c846201", size = 400242, upload-time = "2026-05-18T04:31:19.014Z" }, + { url = "https://files.pythonhosted.org/packages/87/5b/f4dfd45323e949984a3a7f9dc31d1cbb049921e7d98253488dda72ccdaa9/watchfiles-1.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6543cf55d170003296d185c0af981f3e1311564907e1f4e08671fc7693a890a5", size = 394562, upload-time = "2026-05-18T04:30:08.46Z" }, + { url = "https://files.pythonhosted.org/packages/98/d8/19483ef075d601c409bce8bcbb5c0f81a10876fff870400568f08ce484a1/watchfiles-1.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:89d8c2394a065ca86f5d2910ff263ae67c127e1376ccc4f9fc35c71db879f80a", size = 456611, upload-time = "2026-05-18T04:30:45.723Z" }, + { url = "https://files.pythonhosted.org/packages/b1/6a/cc81fbe7ee42f2f22e661a6e12def7807e01b14b2f39e0ff83fd373fd307/watchfiles-1.2.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:772b80df316480d894a0e3165fdd19cf77f5d17f9a787f94029465ad0e3529d1", size = 461379, upload-time = "2026-05-18T04:31:29.292Z" }, + { url = "https://files.pythonhosted.org/packages/b1/57/7e669002082c0a0f4fb5113bb70125f7110124b846b0a11bc5ae8e90eac1/watchfiles-1.2.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d158cd89df6053823533e06fb1d73c549133bff5f0396170c0e53d9559340717", size = 493556, upload-time = "2026-05-18T04:30:05.44Z" }, + { url = "https://files.pythonhosted.org/packages/45/7d/f60a2b19807b21fe8281f3a8da4f59eef0d5f96825ac4680ba2d4f2ebf91/watchfiles-1.2.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d516b3283a758e087841aedb8031549fb41ced08f3db10aa6d2bf32dc042525b", size = 575255, upload-time = "2026-05-18T04:30:40.568Z" }, + { url = "https://files.pythonhosted.org/packages/bd/49/77f5b5e6efbcd57482f74948ebb1b97e5c0046d6b61475042d830c84b3ff/watchfiles-1.2.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:53b2290c92e0506d102cd448fbc610d87079553f86caa39d67440856a8b8bba5", size = 467052, upload-time = "2026-05-18T04:31:17.942Z" }, + { url = "https://files.pythonhosted.org/packages/ee/5a/73e2959af1b97fd5d556f9a8bdba017be23ceeef731869d5eaa0a753d5a3/watchfiles-1.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a711b51aec4370d0dcda5b6c09463206f133a5759341d7744b953a7b62e1100e", size = 456858, upload-time = "2026-05-18T04:30:30.182Z" }, + { url = "https://files.pythonhosted.org/packages/50/57/1bc8c27fad7e6c19bddee15d276dbb6ab72480ec01c127afff1673aee417/watchfiles-1.2.0-cp311-cp311-manylinux_2_31_riscv64.whl", hash = "sha256:e2ca07fa7d89195ec0865d3d285666286740bfa83d83e5cee204043a31ecc165", size = 467579, upload-time = "2026-05-18T04:32:15.897Z" }, + { url = "https://files.pythonhosted.org/packages/09/6c/3c2e44edba3553c5e3c3b8c8a2a6dee6b9e12ae2cf4bd2378bebf9dc3038/watchfiles-1.2.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e0618518f282c4ebff60f5e5b1247b6d91bb8b9f4476947563a1e74acc66f3c6", size = 633253, upload-time = "2026-05-18T04:31:37.123Z" }, + { url = "https://files.pythonhosted.org/packages/30/c2/d8c84a882ab39bbefcc4915ab3e91830b7a7e990c5570b0b69075aba3faf/watchfiles-1.2.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0d191c054d0715c3c95c99df9b8dbf6fd096d8c1e021e8f212e1bd8bc444ccb5", size = 660713, upload-time = "2026-05-18T04:31:24.62Z" }, + { url = "https://files.pythonhosted.org/packages/a9/07/f97736a5fc605364fe67b25e9fa4a6965dfd4840d50c406ada507e9d735f/watchfiles-1.2.0-cp311-cp311-win32.whl", hash = "sha256:9342472aff9b093c5acd4f6d8f70ae0937964ab56542502bcf5579782da69ae8", size = 277222, upload-time = "2026-05-18T04:31:21.131Z" }, + { url = "https://files.pythonhosted.org/packages/cf/99/2b04981977fc2608afd60360d928c6aecf6b950292ca221d98f4005f6694/watchfiles-1.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:dbd6c97045dad81227c8d040173da044c1de08de64a5ea8b555da4aee1d5fa22", size = 290274, upload-time = "2026-05-18T04:31:45.966Z" }, + { url = "https://files.pythonhosted.org/packages/3c/74/f7f58a7075ee9cf612b0cfcddb78b8cd8234f0742d6f0075cf0da2dde1c6/watchfiles-1.2.0-cp311-cp311-win_arm64.whl", hash = "sha256:57a2d9fa4fb4c2ecae57b13dfff2c7ab53e21a2ba674fe9f05506680fcdcc0d7", size = 283460, upload-time = "2026-05-18T04:31:39.126Z" }, + { url = "https://files.pythonhosted.org/packages/b8/2f/e42c992d2afda3108ea1c02acecc991b9f31d05c14adc2a7cee9ee211fc4/watchfiles-1.2.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:bc13eb17538be00c874699dc0abe4ee2bc8d50bb1166a6b9e175ef3fd7eb8f26", size = 400115, upload-time = "2026-05-18T04:32:02.06Z" }, + { url = "https://files.pythonhosted.org/packages/5f/8f/6af2ea19065c91d8b0ea3516fdfc8c0d349f407e8e9fbf4e5a17360de8ad/watchfiles-1.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2d95ddc1eb6914154253d239089900813f6a767e174b8e6a50e7fdacb7e4236c", size = 393659, upload-time = "2026-05-18T04:30:50.951Z" }, + { url = "https://files.pythonhosted.org/packages/13/01/b32a967c56fb3e3e5be3db52c3d3b87fa4513aa367d8ed1ad96d42952e5f/watchfiles-1.2.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f70d8b291ef6e88d19b1f297a6905ddb978888d9272b0d05e6f53309856bcfc", size = 453207, upload-time = "2026-05-18T04:31:04.231Z" }, + { url = "https://files.pythonhosted.org/packages/04/98/97557a812180338cb1abd32e1cffcc4588f59b5f23e0cb006b2ba95ba64a/watchfiles-1.2.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:56d8641cf834c2836922899105bd3ce3d0dfc69291d52edf0b4d0436829b34c0", size = 459273, upload-time = "2026-05-18T04:31:50.377Z" }, + { url = "https://files.pythonhosted.org/packages/e8/a8/b4b08dcb7653b8087c6586f7ce649505900e866bbcfe40dc9587af02e686/watchfiles-1.2.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2581a94056e55d7d0a31a823ea92bf73749c489ca2285bfdc0fbe6b2bb49d50c", size = 489927, upload-time = "2026-05-18T04:31:42.485Z" }, + { url = "https://files.pythonhosted.org/packages/50/94/3dceea03545d2e5ddfd839f0ddd5e1cecbf1697b5a428d5ba11cef6af95d/watchfiles-1.2.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:41bc1199f7523b3f82843c88cbb979180c949caef0342cf90968f178e5d49b01", size = 570476, upload-time = "2026-05-18T04:31:03.071Z" }, + { url = "https://files.pythonhosted.org/packages/cc/f2/d39a5450c3532092b91f81d274360e613c2371bc874a89c7a1a3c5e8d138/watchfiles-1.2.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7571e4464cb6e434958f867f7f730b8ab0b75e3f8e5eac0499168486ab3c33a8", size = 465650, upload-time = "2026-05-18T04:30:12.701Z" }, + { url = "https://files.pythonhosted.org/packages/22/24/ed72f68cbc1333ca9b9f2200aa048bb6658ae41709bc1caad4310f4bdffd/watchfiles-1.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e53a384f76b631c3ae5334ce6a52f0baa3a911eb94a4eac7f160079868b716d5", size = 456398, upload-time = "2026-05-18T04:30:13.784Z" }, + { url = "https://files.pythonhosted.org/packages/0d/64/982ef4a4e5bab5b6e5b6becc8cd5e732f6130a78b855f0abec6439a9a135/watchfiles-1.2.0-cp312-cp312-manylinux_2_31_riscv64.whl", hash = "sha256:d20029a60a71a052a24c4db7673bc4de39ab89adbaccbfb5d67987c5d73f424d", size = 465140, upload-time = "2026-05-18T04:31:52.111Z" }, + { url = "https://files.pythonhosted.org/packages/a0/0c/95282abf4ed680b6096010bcfc30c5fa7a041fc5aa5a2ad17a2cc6c75bba/watchfiles-1.2.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:2cb93af48550faf1cea04c303107c8b75833de7013e57ce27d3b8d21d8d0f58c", size = 630259, upload-time = "2026-05-18T04:31:25.676Z" }, + { url = "https://files.pythonhosted.org/packages/30/45/607c1de1530c4bdcf2cf1d1ecc2505ddba5d96bd43ba9f2b0e79876f850f/watchfiles-1.2.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2995c176de7692b86a2e4c58d9ec718f753150a979cb4a754e2b4ffa38e70906", size = 659859, upload-time = "2026-05-18T04:30:24.333Z" }, + { url = "https://files.pythonhosted.org/packages/fa/08/d9e2e0f9e8e6791d33aefc694ad7eefa7f901f63caff84a81ded38692f9c/watchfiles-1.2.0-cp312-cp312-win32.whl", hash = "sha256:7a2cffd17d27d2ecbb310c2b1d8174f222a5495b1a721894afa88ec11e25b898", size = 275480, upload-time = "2026-05-18T04:30:31.307Z" }, + { url = "https://files.pythonhosted.org/packages/1c/e6/9d42569c0102645cc8cea5d8c7d8a1e9d4ada2cb7f05f75e554b8aa2202a/watchfiles-1.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:f155b3a1b2a5fc89cdc70d47ee5d54e3b75e88efa34982028a35daef9ba00379", size = 288718, upload-time = "2026-05-18T04:32:10.745Z" }, + { url = "https://files.pythonhosted.org/packages/0a/26/88e0dc6ee3898169d7fa22bb6a69cabf2502d2ee25cb8c876d1262d204f8/watchfiles-1.2.0-cp312-cp312-win_arm64.whl", hash = "sha256:8fa585ede612ee9f9e91b18bebf9ba11b9ae29a4e3a0d0cf6fca3e382133f0d5", size = 281026, upload-time = "2026-05-18T04:30:22.23Z" }, + { url = "https://files.pythonhosted.org/packages/d1/4d/70a7feced9f87e2ff26dba42667290f41694fc64646c67261fbb8cab5d5c/watchfiles-1.2.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:01ea8d66f0693b9b60a6541c8d10263091ca9a9060d242f3c1f3143f9aad2c98", size = 399730, upload-time = "2026-05-18T04:31:38.162Z" }, + { url = "https://files.pythonhosted.org/packages/31/3a/0da302f2307aee316922806ebd5726c542cbd787c938271cf14a074c7daf/watchfiles-1.2.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7ba0480b9a74af058f43b337e937a451e109295c420916d68ad24e3dc02f5e44", size = 392842, upload-time = "2026-05-18T04:30:27.051Z" }, + { url = "https://files.pythonhosted.org/packages/db/ef/d5bdb705c224dbc256aa0c1ec47bf4e61ec52558f2afb44a71a1fe4d7015/watchfiles-1.2.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f34e26a19f91f710c08e0183429f0d1d15df734e6bc78c31e77b9ea9c433658", size = 452989, upload-time = "2026-05-18T04:31:11.945Z" }, + { url = "https://files.pythonhosted.org/packages/71/29/5495f2c1661949ef7a35e4d71111d129cfe7606414a26887a919d0a55406/watchfiles-1.2.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b4e77f6a55f858504069abd35d336a637555c09bca453dde1ee1e5ada8a6a1fb", size = 458978, upload-time = "2026-05-18T04:30:52.606Z" }, + { url = "https://files.pythonhosted.org/packages/d5/8c/7f9c07c433811c2fffd93e13fdfb7135de9aab5f2ae41be08960fa0047dc/watchfiles-1.2.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0cb4d80e212f116474a545c21c912b445f16bb0cef9e6a73a498164223e14e2f", size = 490248, upload-time = "2026-05-18T04:31:36.003Z" }, + { url = "https://files.pythonhosted.org/packages/3c/11/d93632febc52fbc21be90231bb7c17fd5387f46c9076fd40a5f9c2ae6910/watchfiles-1.2.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b974946a10af379d425e2eef5b62f5c6ebeaccf91d45eaad6f5b27ecd4f91aa0", size = 571847, upload-time = "2026-05-18T04:31:10.862Z" }, + { url = "https://files.pythonhosted.org/packages/55/b4/383173e73aabb07ad1d9c7aa859d95437ac46a6d6a1e11005facda0c9d19/watchfiles-1.2.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:86bc13c25a8d1fcd70b51d0ce7c9b65e90de5666fcbfd3e34957cc73ee19aeb5", size = 465974, upload-time = "2026-05-18T04:30:17.006Z" }, + { url = "https://files.pythonhosted.org/packages/a7/6c/89b1a230a78f57c52dd8893adb1f92f94411721b6ec12596c56d98c74356/watchfiles-1.2.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca148d73dea36c9763aaa351e4d7a51780ec1584217c45276f4fe8239c768b71", size = 454782, upload-time = "2026-05-18T04:30:35.656Z" }, + { url = "https://files.pythonhosted.org/packages/24/62/1732118367cfff0a9fce3bf62ff4bfded09ef5df21d9d446b858b3f70a96/watchfiles-1.2.0-cp313-cp313-manylinux_2_31_riscv64.whl", hash = "sha256:c525543d91961c6955b2636b308569e84a1d1c5f5f2932041ab9ef46422f43e3", size = 465182, upload-time = "2026-05-18T04:30:20.846Z" }, + { url = "https://files.pythonhosted.org/packages/28/96/716f7e5f51339bf22963f3345f9f27d7f3b30e2eadc597e257c881dd3c53/watchfiles-1.2.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:a204794696ffb8f9b10fba6f7cb5216d42f3b2b71860ccac6b6e42f5f10973b0", size = 629841, upload-time = "2026-05-18T04:31:05.397Z" }, + { url = "https://files.pythonhosted.org/packages/4c/fe/c40783950fd771ccf66ab3ec2722d188a9af1c7f96c6e811f36e40c6e03f/watchfiles-1.2.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:10d86db20695afe7997ac9e1717637d6714a8d0220458c33f3d2061f54cec427", size = 658028, upload-time = "2026-05-18T04:31:48.22Z" }, + { url = "https://files.pythonhosted.org/packages/71/72/4508db1856d1d87fcbb3b63f4839bab1b5682cb0e8d224d122263c09654a/watchfiles-1.2.0-cp313-cp313-win32.whl", hash = "sha256:eb283ee99e21ad6443c8cdb06ac5b34b1308c329cbdf03fa02b445363714c799", size = 275183, upload-time = "2026-05-18T04:30:59.57Z" }, + { url = "https://files.pythonhosted.org/packages/f9/36/14b76ca57652e5cc5fd1c11f32a261292c08a0d19a00351013c2549cbfb2/watchfiles-1.2.0-cp313-cp313-win_amd64.whl", hash = "sha256:a0f27f01bee51861392bb6b7c4fdb290b27d1eb194e9e28788d68102a0e898d9", size = 288059, upload-time = "2026-05-18T04:32:07.937Z" }, + { url = "https://files.pythonhosted.org/packages/1b/8d/0a85e395398d8d20fadfe5c5d32c726eee17a519e78fb356f2cf7531bffe/watchfiles-1.2.0-cp313-cp313-win_arm64.whl", hash = "sha256:3651aa7058595e9cfb75d35dd5ada2bf9f48a5b8a0f3562821d3e210c507e077", size = 280186, upload-time = "2026-05-18T04:31:54.484Z" }, + { url = "https://files.pythonhosted.org/packages/37/68/36db056f1fdcc5f07302f56e631774d6835bcd6fa3ace402304621d5f9e5/watchfiles-1.2.0-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:faea288b6f0ab1902ef08f4ca6de005dccf856c4e0c4f21b8c5fce02d90a1b08", size = 399031, upload-time = "2026-05-18T04:30:44.576Z" }, + { url = "https://files.pythonhosted.org/packages/c1/64/01a9d6f66a82a5c101ce939274106cc72759d62427e153f01edd2b9f87c2/watchfiles-1.2.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:01859b11fd9fbca670f4d5da00fbac282cfea9bd67a2125d8b2833a3b5617ea9", size = 391205, upload-time = "2026-05-18T04:30:25.413Z" }, + { url = "https://files.pythonhosted.org/packages/84/2c/0a44fe058cb4bb7b8ede6b6670698bbb7c0400740e378d00022189b7b31d/watchfiles-1.2.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fff610d7bb2256a317bb1e96f0d7862c7aa8076733ee5df0fd41bbe76a24a4f4", size = 451892, upload-time = "2026-05-18T04:32:14.005Z" }, + { url = "https://files.pythonhosted.org/packages/67/a1/351e0d56cd35e6488b5c8b4fb11a809a5bc923e8fe8fed9faf8920be0c89/watchfiles-1.2.0-cp313-cp313t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b141a4891c995a039cd89e9a49e62df1dc8a559a5d1a6e4c7106d16c12777a55", size = 458867, upload-time = "2026-05-18T04:31:22.279Z" }, + { url = "https://files.pythonhosted.org/packages/d5/7d/9d09605187f1b838998624049fcf8bf47b73c1a3b76901fcac1782f62277/watchfiles-1.2.0-cp313-cp313t-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f22943b7770483f6ea0721c6b11d022947a98eb0acae14694de034f4d0d38925", size = 490217, upload-time = "2026-05-18T04:31:43.657Z" }, + { url = "https://files.pythonhosted.org/packages/60/5d/a17a16eccb182f04188cd308ec24b1a71a9b5c4e7098269cf35d9fa56d02/watchfiles-1.2.0-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1bc6195825b7dcd217968bb1f801a60fd4c16e8eeab5bedc7fe917d7d5995ab4", size = 571458, upload-time = "2026-05-18T04:32:11.875Z" }, + { url = "https://files.pythonhosted.org/packages/d3/3d/4dd457062083ab1938e5dfd45032eb425cee2ac817287ca8ff4356183e5d/watchfiles-1.2.0-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d4a4b147f5dca2a5d325a06a832fb43f345751adfbc63204aec30e0d9ca965a2", size = 464707, upload-time = "2026-05-18T04:30:43.492Z" }, + { url = "https://files.pythonhosted.org/packages/c6/71/ea8c57b128f5383de74d0c7d2d9c57ad7c9a65a930c451bd25d524b295b7/watchfiles-1.2.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4543579a9bdb0c9560039b4ffddbdb39545707659fbc430ce4c10f3f68d557f9", size = 454663, upload-time = "2026-05-18T04:30:16.061Z" }, + { url = "https://files.pythonhosted.org/packages/53/fd/2e812bf938406d7db351f0703ddd3fc6c061cf30d96153a77bc79a943a44/watchfiles-1.2.0-cp313-cp313t-manylinux_2_31_riscv64.whl", hash = "sha256:20aa0e708b920bde876a4aa82dc7dd6ebea228a63a67cda6632c2fc87b787efa", size = 463537, upload-time = "2026-05-18T04:31:44.9Z" }, + { url = "https://files.pythonhosted.org/packages/86/56/d17a7f1dd1bc3035f1072694a551301272f1739c2d8e319c927cb9e29b38/watchfiles-1.2.0-cp313-cp313t-musllinux_1_1_aarch64.whl", hash = "sha256:d413349d565dab74297f2a63e84a097936be69bf8f3b3801f27f380e32040f44", size = 629194, upload-time = "2026-05-18T04:31:14.141Z" }, + { url = "https://files.pythonhosted.org/packages/be/06/f1ff66bf5cae50aa4062779a0ecd0bbaf15e466195719074078947d9a17d/watchfiles-1.2.0-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:f28b2725eb8cce327b9b3ab02415c853011dc55c95832fe90de6bc56f5315f72", size = 656194, upload-time = "2026-05-18T04:31:47.14Z" }, + { url = "https://files.pythonhosted.org/packages/e7/54/a9c7ea9a82a4ac65e7004c0a03920b5cdd2f9c3b678757d9cd425aa51d53/watchfiles-1.2.0-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:b8c8358484d5fa12ef34f05b7f4168eaf1932f408725ff6d023c33ec17bd79d4", size = 400205, upload-time = "2026-05-18T04:32:05.153Z" }, + { url = "https://files.pythonhosted.org/packages/aa/5d/c9ab3534374a4a67450696905d6ef16a04405448b8dc52bd752ae50423d4/watchfiles-1.2.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:9f04b092229ad2c50126dd3c922c8822e51e605993764a33058d4a791ab42281", size = 392508, upload-time = "2026-05-18T04:30:54.849Z" }, + { url = "https://files.pythonhosted.org/packages/26/ca/1ad30103535cf0cecd7b993e8d50edc5351b1820e38f2d22e3df58962feb/watchfiles-1.2.0-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7a7ce236284f002a156f70add88efe5c70879cccbb658be0822c54b1306fc09d", size = 452448, upload-time = "2026-05-18T04:30:53.727Z" }, + { url = "https://files.pythonhosted.org/packages/37/a1/ceee2cdf2afbd715fa07758d39c9859513eae411b23196f7fd039e5feedd/watchfiles-1.2.0-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b9909cc2b48468b575eefa944919e1fe8a36c5849d5c7c168f80a8c1db69398e", size = 459605, upload-time = "2026-05-18T04:30:23.312Z" }, + { url = "https://files.pythonhosted.org/packages/e8/f6/421e30fd1cb3907a84ed92ab3f1983e37ba2dca015e9a894a048418417a2/watchfiles-1.2.0-cp314-cp314-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a37faaed405c67e28e6be45a1fa4f206ef5a2860f27c237db9fa30704c38242", size = 490757, upload-time = "2026-05-18T04:30:47.358Z" }, + { url = "https://files.pythonhosted.org/packages/41/b0/55ed1b97ed08be7bba6f9a541cac15f2a858e1d74d2b07b6da70a82aab00/watchfiles-1.2.0-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9649193aa27bd9ff2e80ff29bfaa93085496c7a3a377592823cc58b77ee88add", size = 568672, upload-time = "2026-05-18T04:30:38.915Z" }, + { url = "https://files.pythonhosted.org/packages/d1/cf/d8ae8a80dd7bafab395ea7681c10237311bbf34d37704a8c744e7cf31fc7/watchfiles-1.2.0-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4e4ff8e37f99cf1da89e255e07c9c4b37c214038c4283707bdec308cb1b0ea1f", size = 464197, upload-time = "2026-05-18T04:30:09.914Z" }, + { url = "https://files.pythonhosted.org/packages/7c/8a/3076c496ca8dafe0e8cd03fcebdfc47be4b1174b4e5b24ff6e396e6b3af2/watchfiles-1.2.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:054dc20fd2e3132b4c3883b4a00d72fd6e1f56fdaf89fccd12e8057d74cd74d7", size = 453181, upload-time = "2026-05-18T04:30:14.829Z" }, + { url = "https://files.pythonhosted.org/packages/e5/10/9745e17c98e7b8a86454df0a3c7b5686bd650383f1e9f26e4ebcbd6cc0c0/watchfiles-1.2.0-cp314-cp314-manylinux_2_31_riscv64.whl", hash = "sha256:e140ed30ebde76796b686e67c182cff10ea2fbab186fafd1560f74bb5a473a6e", size = 465109, upload-time = "2026-05-18T04:30:28.123Z" }, + { url = "https://files.pythonhosted.org/packages/8f/95/8ef4a95481d3e0cb52d62a06fa6e972e81424be2d9698b91a2fecca9904c/watchfiles-1.2.0-cp314-cp314-musllinux_1_1_aarch64.whl", hash = "sha256:bb7e52ecf68ba46d22df23467b87cffeb2146908aa523ebfe803019618cfda06", size = 630653, upload-time = "2026-05-18T04:31:49.304Z" }, + { url = "https://files.pythonhosted.org/packages/fd/e4/3b3bf36b0f829b50c6ebcb8d031583863c59f923d6a6af3d485e470d0fac/watchfiles-1.2.0-cp314-cp314-musllinux_1_1_x86_64.whl", hash = "sha256:23282a321c8baf9b3a3c4afff673f9fe65eb7fdc2338d765ccad9d3d1916a5ba", size = 657838, upload-time = "2026-05-18T04:31:06.497Z" }, + { url = "https://files.pythonhosted.org/packages/21/b1/6cbbb50c1f3002ab568777d44aa21206dfb8807a840990c4037523b51812/watchfiles-1.2.0-cp314-cp314-win32.whl", hash = "sha256:c0db965c5f79aa49fe672d297cf1febc5ad149b658594944f49a54a2b96270a7", size = 275108, upload-time = "2026-05-18T04:30:06.891Z" }, + { url = "https://files.pythonhosted.org/packages/92/45/190ce6db8dcb4536682cf75d3889ff1a27182a58cb519d343cb6d9ea63d8/watchfiles-1.2.0-cp314-cp314-win_amd64.whl", hash = "sha256:71283b39fd17e5408eb123bd37aeecfd9d54c81fc184421943208aadb879d103", size = 288441, upload-time = "2026-05-18T04:32:12.901Z" }, + { url = "https://files.pythonhosted.org/packages/74/0d/3eae1c2313ab08378431d907c3f8095ecca00f3eda33111cf4f0f2591799/watchfiles-1.2.0-cp314-cp314-win_arm64.whl", hash = "sha256:c5c19526f4e54a00f2666a6c0e9e40d582c09e865055ea7378bf0009aab857b3", size = 280684, upload-time = "2026-05-18T04:31:26.902Z" }, + { url = "https://files.pythonhosted.org/packages/b1/75/fb64e6c25d6b5ca636d03df34ffb1c6e9873303e76d27967e045f8df088f/watchfiles-1.2.0-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:d73a585accffa5ae39c17264c36ec3166d2fad7000c780f5ef83b2722afb9dd2", size = 398857, upload-time = "2026-05-18T04:32:17.108Z" }, + { url = "https://files.pythonhosted.org/packages/73/4e/9f7adf01754cbf81843722ccfec169d8f26c69778281a302855cecd2ee08/watchfiles-1.2.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:ae99b14c5f21e026e0e9d96f40e07d8570ebee6cafd9d8fc318354606daa7a28", size = 392413, upload-time = "2026-05-18T04:31:07.911Z" }, + { url = "https://files.pythonhosted.org/packages/47/c8/bec626bcc2d69f44b9acb24ce7d60ed7b16b73628eea747fcbd169d8edda/watchfiles-1.2.0-cp314-cp314t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4429f3b105524a10b72c3a819b091c495d2811d419c1e1e8df773a5a5974f831", size = 452409, upload-time = "2026-05-18T04:31:20.142Z" }, + { url = "https://files.pythonhosted.org/packages/00/b7/b6362068e81e7c556d155a34c35d40ac3ef42d747b06d7f6e5bf58e359c2/watchfiles-1.2.0-cp314-cp314t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:43d818978d06062d9b22c4fab2ebe44cf5213d42dc8e62bda8c2760cfa2eeb33", size = 458827, upload-time = "2026-05-18T04:32:06.219Z" }, + { url = "https://files.pythonhosted.org/packages/67/f8/9a813fa42afb1e0b4625e75f0479826644d3ee8dc287e093799bc01f390c/watchfiles-1.2.0-cp314-cp314t-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b9f732dc58b2dbe69e464ccf8fff7a03b0dd0be439da4c0720d3558527d3d6b4", size = 490104, upload-time = "2026-05-18T04:31:56.034Z" }, + { url = "https://files.pythonhosted.org/packages/2f/bf/27dfb6094ca4c9aad21298b5525b6c53cb36121ee454331d05161e58d130/watchfiles-1.2.0-cp314-cp314t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8f200104103feb097de4cab8fe4f5dd18a2026934c7dea98c55a2f5fd6d5a33b", size = 571360, upload-time = "2026-05-18T04:31:57.133Z" }, + { url = "https://files.pythonhosted.org/packages/fb/39/44a096d67270ea93df91d33877dbe91fbda3aa4f8ec2edf799d93eda8736/watchfiles-1.2.0-cp314-cp314t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:63ac26eefbf4af1741247d6fb68b11c49a25b2f7413fbd318a83a12aaa9cf666", size = 464644, upload-time = "2026-05-18T04:30:57.33Z" }, + { url = "https://files.pythonhosted.org/packages/0e/80/c7472203bad6268e3ef1ad260739704847898938ad7ea8b63a5131f46b50/watchfiles-1.2.0-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c4997d4e4a55f0d02b6cde327322daf3a0400e5df6c6b15948994bf72497925", size = 454771, upload-time = "2026-05-18T04:30:48.736Z" }, + { url = "https://files.pythonhosted.org/packages/51/cf/3b10b268b4b7f0fc26e9debb5eef1998b515887840f444cd3ec80c688755/watchfiles-1.2.0-cp314-cp314t-manylinux_2_31_riscv64.whl", hash = "sha256:4c887eba18b7945ac73067a8b4a66f21cd46c2539b2bc68588f7be6c7eb6d26b", size = 463494, upload-time = "2026-05-18T04:31:33.826Z" }, + { url = "https://files.pythonhosted.org/packages/3d/3e/a4302545cd589262a0dc7d140e86f7688eba3f9c72776c27f7e23b8864c4/watchfiles-1.2.0-cp314-cp314t-musllinux_1_1_aarch64.whl", hash = "sha256:3416ff151bb6b5a8d8d11664974fbef4d9305b9b2957839ab5a270468fd8df30", size = 629383, upload-time = "2026-05-18T04:31:15.596Z" }, + { url = "https://files.pythonhosted.org/packages/db/99/d5649df0a9a410d45b7c882304d0b790903ac9b6e8f2cfd12114e0c6b9f2/watchfiles-1.2.0-cp314-cp314t-musllinux_1_1_x86_64.whl", hash = "sha256:0e831a271c035d89789cffc386b6aa1375f39f1cd25eb7ca0997e4970d152fc5", size = 656093, upload-time = "2026-05-18T04:31:58.707Z" }, + { url = "https://files.pythonhosted.org/packages/23/f4/7513ef1e85fc4c6331b59479d6d72661fc391fbe543678052ac72c8b6c19/watchfiles-1.2.0-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:4674d49eb94706dfe666c069fc0a1b646ffcf920473492e209f6d5f60d3f0cc2", size = 403050, upload-time = "2026-05-18T04:30:36.753Z" }, + { url = "https://files.pythonhosted.org/packages/27/0b/a54103cfd732bb703c7a749222011a0483ef3705948dae3b203158601119/watchfiles-1.2.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:094b9b70103d4e963499bdea001ee3c2697b144cd9ae6218a62c0f89ec9e31db", size = 396629, upload-time = "2026-05-18T04:32:03.268Z" }, + { url = "https://files.pythonhosted.org/packages/5e/2c/73f31a3b893886206c3f54d73e8ad8dee58cdb2f69ad2622e0a8a9e07f4e/watchfiles-1.2.0-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b0ef001f8c25ad0fa9529f914c1600647ecd0f542d11c19b7894768c67b6acb7", size = 457318, upload-time = "2026-05-18T04:31:01.932Z" }, + { url = "https://files.pythonhosted.org/packages/e9/f9/45d021e4a5cc7b9dd567f7cbb06d3b75f751a690063fb6cc7ec60f4e46b7/watchfiles-1.2.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a88fc94e647bc4eec523f1caa540258eb71d14278b9daf72fa1e2658a98df0f0", size = 457771, upload-time = "2026-05-18T04:30:56.331Z" }, ] [[package]] From 6b1965a4555686f5efb18a6317d9a0bda5413fc5 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Wed, 20 May 2026 17:46:15 -0700 Subject: [PATCH 7/7] dataset util for mcore quantize Signed-off-by: Jennifer Chen --- modelopt/torch/utils/dataset_utils.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index 80ed8f9abdd..800140acd99 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -557,6 +557,22 @@ def __len__(self): return len(next(iter(self.encodings.values()))) +def get_dataloader_from_dataset( + dataset, + batch_size: int = 1, + distributed: bool = False, + sampler_kwargs: dict | None = None, + shuffle: bool = False, +) -> DataLoader: + """Wrap a pre-tokenized torch Dataset in a DataLoader, with optional DistributedSampler.""" + if distributed: + from torch.utils.data.distributed import DistributedSampler + + sampler = DistributedSampler(dataset, **(sampler_kwargs or {})) + return DataLoader(dataset, batch_size=batch_size, sampler=sampler) + return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle) + + def get_dataset_dataloader( dataset_name: str | list[str] = "cnn_dailymail", tokenizer: "PreTrainedTokenizerBase | None" = None,