From 016e475c8cbce5b9235988829d213f5d244dbb22 Mon Sep 17 00:00:00 2001 From: Ofir Zafrir Date: Mon, 4 May 2026 02:51:32 -0700 Subject: [PATCH 01/11] [OpenVINO] Add DFlash model support - Introduced `--dflash-target-model` argument for exporting DFlash draft models. - Implemented `update_config_for_dflash` to handle DFlash-specific configurations. - Enhanced model conversion and metadata handling for DFlash models. - Added `DFlashDummyInputGenerator` for generating dummy inputs specific to DFlash. - Updated tests to include DFlash model loading and export functionality. This update enables the export and inference of models utilizing DFlash architecture, enhancing the OpenVINO integration. --- optimum/commands/export/openvino.py | 10 + optimum/exporters/openvino/__main__.py | 38 +++ optimum/exporters/openvino/convert.py | 20 ++ optimum/exporters/openvino/model_configs.py | 86 +++++ optimum/exporters/openvino/model_patcher.py | 312 +++++++++++++++++++ optimum/intel/openvino/modeling_decoder.py | 9 + tests/openvino/test_decoder.py | 42 +++ tests/openvino/test_exporters_cli.py | 17 + tests/openvino/utils_tests.py | 5 + tests/scripts/extract_dflash_debug_bundle.py | 113 +++++++ 10 files changed, 652 insertions(+) create mode 100644 tests/scripts/extract_dflash_debug_bundle.py diff --git a/optimum/commands/export/openvino.py b/optimum/commands/export/openvino.py index cd3280189e..a534c5d4f9 100644 --- a/optimum/commands/export/openvino.py +++ b/optimum/commands/export/openvino.py @@ -294,6 +294,15 @@ def parse_args_openvino(parser: "ArgumentParser"): type=json.loads, help=("Any kwargs passed to the model forward, or used to customize the export for a given model."), ) + optional_group.add_argument( + "--dflash-target-model", + type=str, + default=None, + help=( + "Target model ID or local path used when exporting DFlash draft models. Only the target token " + "embedding and lm_head weights are loaded." + ), + ) def no_compression_parameter_provided(args): @@ -479,6 +488,7 @@ def run(self): library_name=library_name, variant=self.args.variant, model_kwargs=self.args.model_kwargs, + dflash_target_model=self.args.dflash_target_model, # **input_shapes, ) if apply_main_quantize: diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index f8637d2d1e..a8815bc57f 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -152,6 +152,34 @@ def update_config_for_eagle3(config): return config +def update_config_for_dflash( + config, + dflash_target_model: Optional[str], + cache_dir: str = HUGGINGFACE_HUB_CACHE, + revision: str = "main", + token: Optional[Union[bool, str]] = None, + local_files_only: bool = False, +): + if dflash_target_model is None: + raise ValueError("Exporting DFlash draft models requires `dflash_target_model` / `--dflash-target-model`.") + + moduler_name = "optimum.exporters.openvino.model_patcher" + spec = importlib.util.find_spec(moduler_name) + if spec and spec.origin: + moduler_path = os.path.dirname(spec.origin) + config.auto_map = { + "AutoModel": moduler_path + "--model_patcher.Qwen3DFlashDraftModel", + "AutoModelForCausalLM": moduler_path + "--model_patcher.Qwen3DFlashForCausalLM", + } + config.dflash_target_model = dflash_target_model + config.dflash_target_cache_dir = cache_dir + config.dflash_target_revision = revision + config.dflash_target_token = token + config.dflash_target_local_files_only = local_files_only + config.tie_word_embeddings = False + return config + + def infer_library_name( model_name_or_path: str, subfolder: str = "", @@ -198,6 +226,7 @@ def main_export( library_name: Optional[str] = None, model_loading_kwargs: Optional[Dict[str, Any]] = None, variant: Optional[str] = None, + dflash_target_model: Optional[str] = None, **kwargs_shapes, ): """ @@ -320,6 +349,15 @@ def main_export( archs = getattr(config, "architectures", None) if isinstance(archs, list) and len(archs) > 0 and archs[0] == "LlamaForCausalLMEagle3": loading_kwargs["config"] = update_config_for_eagle3(config) + elif isinstance(archs, list) and len(archs) > 0 and archs[0] == "DFlashDraftModel": + loading_kwargs["config"] = update_config_for_dflash( + config, + dflash_target_model=dflash_target_model, + cache_dir=cache_dir, + revision=revision, + token=token, + local_files_only=local_files_only, + ) # mxfp4 quantized model will be dequantized to bf16 if quant_method == "mxfp4" and is_transformers_version(">=", "4.55"): diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index fddd840b7d..a344b44f74 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -136,6 +136,8 @@ def _save_model( if getattr(config, "eagle3", False): model = _add_eagle3_mode_to_rt_info(model) + if getattr(config, "dflash", False): + model = _add_dflash_mode_to_rt_info(model, config._config) save_model(model, path, compress_to_fp16) del model @@ -869,6 +871,24 @@ def _add_eagle3_mode_to_rt_info(model: Model): return model +def _add_dflash_mode_to_rt_info(model: Model, hf_config): + """ + Add DFlash metadata. + """ + try: + model.set_rt_info("True", ["dflash_mode"]) + model.set_rt_info(str(getattr(hf_config, "block_size", "")), ["dflash", "block_size"]) + dflash_config = getattr(hf_config, "dflash_config", {}) + if "mask_token_id" in dflash_config: + model.set_rt_info(str(dflash_config["mask_token_id"]), ["dflash", "mask_token_id"]) + if "target_layer_ids" in dflash_config: + model.set_rt_info(",".join(map(str, dflash_config["target_layer_ids"])), ["dflash", "target_layer_ids"]) + except Exception: + pass + + return model + + def _add_version_info_to_model(model: Model, library_name: Optional[str] = None): """ Add dependency versions to OpenVINO model diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index 36a83adfa7..f3328fd3b0 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -382,6 +382,55 @@ class Qwen2MoEOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): _MODEL_PATCHER = Qwen2MoEPatcher +class DFlashDummyInputGenerator(DummyInputGenerator): + SUPPORTED_INPUT_NAMES = ("input_ids", "target_hidden", "position_ids") + + def __init__( + self, + task: str, + normalized_config: NormalizedTextConfig, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"], + **kwargs, + ): + self.batch_size = batch_size + self.context_length = sequence_length + self.hidden_size = normalized_config.hidden_size + config = normalized_config.config + self.block_size = getattr(config, "block_size", sequence_length) + dflash_config = getattr(config, "dflash_config", {}) + self.num_target_layers = len(dflash_config.get("target_layer_ids", [])) + self.vocab_size = getattr(config, "vocab_size", normalized_config.vocab_size) + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + if input_name == "input_ids": + return self.random_int_tensor( + [self.batch_size, self.block_size], + max_value=self.vocab_size, + framework=framework, + dtype=int_dtype, + ) + if input_name == "target_hidden": + return self.random_float_tensor( + [self.batch_size, self.context_length, self.hidden_size * self.num_target_layers], + framework=framework, + dtype=float_dtype, + ) + if input_name == "position_ids": + position_length = self.context_length + self.block_size + if framework == "pt": + return torch.arange(position_length, dtype=DTYPE_MAPPER.pt(int_dtype)).unsqueeze(0).repeat( + self.batch_size, 1 + ) + return self.random_int_tensor( + [self.batch_size, position_length], + max_value=position_length, + framework=framework, + dtype=int_dtype, + ) + raise ValueError(f"Unsupported DFlash input name: {input_name}") + + @register_in_tasks_manager( "qwen3", *[ @@ -400,8 +449,39 @@ class Qwen3OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedTextConfig _MODEL_PATCHER = OVDecoderModelPatcher + def __init__( + self, + config: PretrainedConfig, + task: str = "feature-extraction", + int_dtype: str = "int64", + float_dtype: str = "fp32", + use_past: bool = False, + use_past_in_inputs: bool = False, + preprocessors: list[Any] | None = None, + ): + super().__init__( + config=config, + task=task, + int_dtype=int_dtype, + float_dtype=float_dtype, + use_past=use_past, + use_past_in_inputs=use_past_in_inputs, + preprocessors=preprocessors, + ) + archs = getattr(config, "architectures", None) + self.dflash = isinstance(archs, list) and len(archs) > 0 and archs[0] == "DFlashDraftModel" + if self.dflash: + self.DUMMY_INPUT_GENERATOR_CLASSES = (DFlashDummyInputGenerator,) + self.MIN_TRANSFORMERS_VERSION = "4.57.0" + @property def inputs(self) -> Dict[str, Dict[int, str]]: + if self.dflash: + return { + "input_ids": {0: "batch_size", 1: "block_size"}, + "target_hidden": {0: "batch_size", 1: "context_length", 2: "target_hidden_size"}, + "position_ids": {0: "batch_size", 1: "position_sequence_length"}, + } if self.task in ["feature-extraction"]: common_inputs = { "input_ids": {0: "batch_size", 1: "sequence_length"}, @@ -411,6 +491,12 @@ def inputs(self) -> Dict[str, Dict[int, str]]: common_inputs = super().inputs return common_inputs + @property + def outputs(self) -> Dict[str, Dict[int, str]]: + if self.dflash: + return {"logits": {0: "batch_size", 1: "draft_sequence_length"}} + return super().outputs + class DummyQwen3VLLMInputGenerator(DummyTextInputGenerator): SUPPORTED_INPUT_NAMES = ( diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 59e40f0e8a..d36c391568 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -14,11 +14,13 @@ import functools import inspect +import json import logging import logging as log import math import types from dataclasses import dataclass +from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import torch @@ -31,6 +33,7 @@ BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling, + CausalLMOutputWithPast, ) from transformers.modeling_utils import PreTrainedModel from transformers.models.llama.configuration_llama import LlamaConfig @@ -71,6 +74,34 @@ from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock if is_transformers_version(">=", "4.54"): from transformers.masking_utils import create_causal_mask +if is_transformers_version(">=", "4.57"): + from transformers.models.qwen3.modeling_qwen3 import ( + Qwen3Config, + Qwen3MLP, + Qwen3PreTrainedModel, + Qwen3RMSNorm, + Qwen3RotaryEmbedding, + eager_attention_forward as qwen3_eager_attention_forward, + rotate_half as qwen3_rotate_half, + ) +else: + Qwen3Config = PretrainedConfig + Qwen3PreTrainedModel = PreTrainedModel + + class _UnavailableQwen3Module(nn.Module): + def __init__(self, *args, **kwargs): + raise ImportError("DFlash export requires transformers >= 4.57.") + + Qwen3MLP = _UnavailableQwen3Module + Qwen3RMSNorm = _UnavailableQwen3Module + Qwen3RotaryEmbedding = _UnavailableQwen3Module + + def qwen3_eager_attention_forward(*args, **kwargs): + raise ImportError("DFlash export requires transformers >= 4.57.") + + def qwen3_rotate_half(*args, **kwargs): + raise ImportError("DFlash export requires transformers >= 4.57.") + if is_transformers_version(">=", "4.56"): import transformers.masking_utils if is_transformers_version(">=", "5"): @@ -8461,6 +8492,287 @@ def forward( ) +def _dflash_apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_len = q.size(-2) + q_embed = (q * cos[..., -q_len:, :]) + (qwen3_rotate_half(q) * sin[..., -q_len:, :]) + k_embed = (k * cos) + (qwen3_rotate_half(k) * sin) + return q_embed, k_embed + + +def _dflash_resolve_tensor_file(config: PretrainedConfig, tensor_name: str) -> str: + target_model = getattr(config, "dflash_target_model", None) + if target_model is None: + raise ValueError("DFlash logits export requires `dflash_target_model` to load target weights.") + + target_path = Path(target_model) + if target_path.is_dir(): + index_path = target_path / "model.safetensors.index.json" + if index_path.exists(): + with index_path.open() as f: + weight_map = json.load(f)["weight_map"] + return str(target_path / weight_map[tensor_name]) + single_file = target_path / "model.safetensors" + if single_file.exists(): + return str(single_file) + raise FileNotFoundError(f"Could not find safetensors weights in {target_path}.") + + from huggingface_hub import hf_hub_download + + cache_dir = getattr(config, "dflash_target_cache_dir", None) + revision = getattr(config, "dflash_target_revision", "main") + token = getattr(config, "dflash_target_token", None) + local_files_only = getattr(config, "dflash_target_local_files_only", False) + + try: + index_path = hf_hub_download( + target_model, + "model.safetensors.index.json", + cache_dir=cache_dir, + revision=revision, + token=token, + local_files_only=local_files_only, + ) + with open(index_path) as f: + weight_map = json.load(f)["weight_map"] + filename = weight_map[tensor_name] + except Exception: + filename = "model.safetensors" + + return hf_hub_download( + target_model, + filename, + cache_dir=cache_dir, + revision=revision, + token=token, + local_files_only=local_files_only, + ) + + +def _dflash_load_target_tensor(config: PretrainedConfig, tensor_names: Tuple[str, ...]) -> torch.Tensor: + from safetensors import safe_open + + last_error = None + for tensor_name in tensor_names: + try: + tensor_file = _dflash_resolve_tensor_file(config, tensor_name) + with safe_open(tensor_file, framework="pt", device="cpu") as f: + if tensor_name in f.keys(): + return f.get_tensor(tensor_name) + except Exception as error: + last_error = error + raise ValueError(f"Could not load any of {tensor_names} from DFlash target model.") from last_error + + +class Qwen3DFlashAttention(nn.Module): + """Qwen3 attention variant used by DFlash, where draft tokens attend over target context and noise tokens.""" + + def __init__(self, config: "Qwen3Config", layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = False + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias) + self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + + def forward( + self, + hidden_states: torch.Tensor, + target_hidden: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + bsz, q_len = hidden_states.shape[:-1] + ctx_len = target_hidden.shape[1] + + query_states = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim) + query_states = self.q_norm(query_states).transpose(1, 2) + + key_context = self.k_proj(target_hidden) + key_noise = self.k_proj(hidden_states) + value_context = self.v_proj(target_hidden) + value_noise = self.v_proj(hidden_states) + + key_states = torch.cat([key_context, key_noise], dim=1).view(bsz, ctx_len + q_len, -1, self.head_dim) + value_states = torch.cat([value_context, value_noise], dim=1).view(bsz, ctx_len + q_len, -1, self.head_dim) + key_states = self.k_norm(key_states).transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = _dflash_apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attn_output, attn_weights = qwen3_eager_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen3DFlashDecoderLayer(nn.Module): + def __init__(self, config: "Qwen3Config", layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Qwen3DFlashAttention(config=config, layer_idx=layer_idx) + self.mlp = Qwen3MLP(config) + self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + target_hidden: torch.Tensor, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.FloatTensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + target_hidden=target_hidden, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + )[0] + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class Qwen3DFlashDraftModel(Qwen3PreTrainedModel): + config_class = Qwen3Config + _no_split_modules = ["Qwen3DFlashDecoderLayer"] + + def __init__(self, config) -> None: + super().__init__(config) + self.layers = nn.ModuleList( + [Qwen3DFlashDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + dflash_config = getattr(config, "dflash_config", {}) + self.target_layer_ids = dflash_config.get("target_layer_ids", []) + self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen3RotaryEmbedding(config) + self.fc = nn.Linear(len(self.target_layer_ids) * config.hidden_size, config.hidden_size, bias=False) + self.hidden_norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.block_size = config.block_size + self.mask_token_id = dflash_config.get("mask_token_id", None) + self.post_init() + + def forward( + self, + position_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + noise_embedding: Optional[torch.Tensor] = None, + target_hidden: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: bool = False, + **kwargs, + ) -> torch.FloatTensor: + hidden_states = noise_embedding + target_hidden = target_hidden.to(hidden_states.dtype) + target_hidden = self.hidden_norm(self.fc(target_hidden)) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + for layer in self.layers: + hidden_states = layer( + hidden_states=hidden_states, + target_hidden=target_hidden, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + use_cache=use_cache, + position_embeddings=position_embeddings, + **kwargs, + ) + return self.norm(hidden_states) + + +class Qwen3DFlashForCausalLM(Qwen3DFlashDraftModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _keys_to_ignore_on_load_missing = [r"embed_tokens.weight", r"lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self._load_target_weights(config) + + def _load_target_weights(self, config): + embed_weight = _dflash_load_target_tensor(config, ("model.embed_tokens.weight", "embed_tokens.weight")) + try: + lm_head_weight = _dflash_load_target_tensor(config, ("lm_head.weight",)) + except ValueError: + lm_head_weight = embed_weight + + self.embed_tokens.weight = nn.Parameter(embed_weight, requires_grad=False) + self.lm_head.weight = nn.Parameter(lm_head_weight, requires_grad=False) + + def forward( + self, + input_ids: torch.LongTensor, + target_hidden: torch.Tensor, + position_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: bool = False, + logits_to_keep: Optional[int] = None, + **kwargs, + ) -> CausalLMOutputWithPast: + noise_embedding = self.embed_tokens(input_ids) + hidden_states = super().forward( + target_hidden=target_hidden, + noise_embedding=noise_embedding, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + **kwargs, + ) + if logits_to_keep is None: + logits_to_keep = self.block_size - 1 + logits = self.lm_head(hidden_states[:, -logits_to_keep:, :]) + return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values) + + # Patched implementation of the gated delta rule in recurrent form. # Adapted from: # https://github.com/huggingface/transformers/blob/v4.57-release/src/transformers/models/qwen3_next/modeling_qwen3_next.py#L522 diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index a74582b2c1..bf5e6c17a1 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -342,6 +342,7 @@ def _export( model_loading_kwargs["torch_dtype"] = torch_dtype variant = kwargs.pop("variant", None) + dflash_target_model = kwargs.pop("dflash_target_model", None) main_export( model_name_or_path=model_id, @@ -359,6 +360,7 @@ def _export( model_loading_kwargs=model_loading_kwargs, library_name=cls._library_name, variant=variant, + dflash_target_model=dflash_target_model, ) if config.model_type == "phi3" and config.max_position_embeddings != getattr( @@ -567,6 +569,13 @@ def prepare_inputs( hidden_states = torch.zeros(hs_shape, device=self.device, dtype=torch.float32) inputs["hidden_states"] = hidden_states + # DFlash draft models consume target-model context features explicitly. + if "target_hidden" in self.input_names: + target_hidden = kwargs.get("target_hidden", None) + if target_hidden is None: + raise ValueError("DFlash draft models require `target_hidden` to be passed to the forward call.") + inputs["target_hidden"] = target_hidden + return inputs def forward( diff --git a/tests/openvino/test_decoder.py b/tests/openvino/test_decoder.py index ba45d393f5..222e42862e 100644 --- a/tests/openvino/test_decoder.py +++ b/tests/openvino/test_decoder.py @@ -11,6 +11,7 @@ from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES from transformers.testing_utils import slow from utils_tests import ( + DFLASH_MODELS, EAGLE3_MODELS, F32_CONFIG, MODEL_NAMES, @@ -939,6 +940,47 @@ def test_load_and_infer_with_eagle3_model(self, model_arch, model_pair): del ov_model gc.collect() + @parameterized.expand(DFLASH_MODELS.items()) + @pytest.mark.skipif(is_transformers_version("<", "4.57"), reason="DFlash requires transformers >= 4.57") + @pytest.mark.skipif("DFLASH_DEBUG_BUNDLE" not in os.environ, reason="Set DFLASH_DEBUG_BUNDLE to run DFlash tests") + @pytest.mark.run_slow + @slow + def test_load_and_infer_with_dflash_debug_bundle(self, model_arch, model_pair): + draft_model_id, target_model_id = model_pair + bundle = torch.load(os.environ["DFLASH_DEBUG_BUNDLE"], map_location="cpu") + + ov_model_path = os.environ.get("DFLASH_OV_MODEL_DIR") + if ov_model_path: + ov_model = OVModelForCausalLM.from_pretrained(ov_model_path, use_cache=False, device=OPENVINO_DEVICE) + else: + ov_model = OVModelForCausalLM.from_pretrained( + draft_model_id, + export=True, + trust_remote_code=True, + dflash_target_model=target_model_id, + use_cache=False, + stateful=False, + device=OPENVINO_DEVICE, + ) + + ov_outputs = ov_model( + input_ids=bundle["input_ids"], + target_hidden=bundle["target_hidden"], + position_ids=bundle["position_ids"], + ) + expected_logits = bundle["expected_logits"] + + self.assertEqual(tuple(ov_outputs.logits.shape), tuple(expected_logits.shape)) + torch.testing.assert_close( + ov_outputs.logits.float(), + expected_logits.float(), + rtol=float(os.environ.get("DFLASH_RTOL", "5e-2")), + atol=float(os.environ.get("DFLASH_ATOL", "5e-2")), + ) + + del ov_model + gc.collect() + HYBRID_ARCHITECTURES = [] if is_transformers_version(">=", "4.53"): HYBRID_ARCHITECTURES.append("granitemoehybrid") diff --git a/tests/openvino/test_exporters_cli.py b/tests/openvino/test_exporters_cli.py index 397e6cc787..9ffaba1e15 100644 --- a/tests/openvino/test_exporters_cli.py +++ b/tests/openvino/test_exporters_cli.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import json +import os import subprocess import unittest from pathlib import Path @@ -28,6 +29,7 @@ ) from utils_tests import ( _ARCHITECTURES_TO_EXPECTED_INT8, + DFLASH_MODELS, MODEL_NAMES, OPENVINO_DEVICE, REMOTE_CODE_MODELS, @@ -855,6 +857,21 @@ def _openvino_export( model_name_or_path=model_name, output=tmpdir, task=task, model_kwargs=model_kwargs, **loading_kwargs ) + @unittest.skipUnless(os.environ.get("RUN_DFLASH_EXPORT_TEST"), "Set RUN_DFLASH_EXPORT_TEST=1 to run DFlash export") + def test_dflash_export_smoke(self): + draft_model_id, target_model_id = DFLASH_MODELS["qwen3_coder_dflash"] + with TemporaryDirectory() as tmpdir: + main_export( + model_name_or_path=draft_model_id, + output=tmpdir, + task="text-generation", + trust_remote_code=True, + dflash_target_model=target_model_id, + convert_tokenizer=False, + ) + self.assertTrue((Path(tmpdir) / "openvino_model.xml").exists()) + self.assertTrue((Path(tmpdir) / "openvino_model.bin").exists()) + def test_filtered_architectures(cls): if is_transformers_version("<", "4.49"): expected = {"qwen3_vl", "llama4", "qwen2_5_vl", "phi4mm"} diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index 3ee3a2035b..7034c2ddd8 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -231,9 +231,13 @@ "ltx-video": "optimum-intel-internal-testing/tiny-random-ltx-video", "zamba2": "optimum-intel-internal-testing/tiny-random-zamba2", "qwen3_eagle3": "AngelSlim/Qwen3-1.7B_eagle3", + "qwen3_coder_dflash": "z-lab/Qwen3-Coder-30B-A3B-DFlash", } EAGLE3_MODELS = {"qwen3_eagle3": ("AngelSlim/Qwen3-1.7B_eagle3", "Qwen/Qwen3-1.7B")} +DFLASH_MODELS = { + "qwen3_coder_dflash": ("z-lab/Qwen3-Coder-30B-A3B-DFlash", "Qwen/Qwen3-Coder-30B-A3B-Instruct") +} _ARCHITECTURES_TO_EXPECTED_INT8 = { "afmoe": {"model": 16}, @@ -414,6 +418,7 @@ "minicpm3", "deepseek", "qwen3_eagle3", + "qwen3_coder_dflash", ) if is_transformers_version("<", "5"): diff --git a/tests/scripts/extract_dflash_debug_bundle.py b/tests/scripts/extract_dflash_debug_bundle.py new file mode 100644 index 0000000000..28bbce3a71 --- /dev/null +++ b/tests/scripts/extract_dflash_debug_bundle.py @@ -0,0 +1,113 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import torch +from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer + + +def sample(logits: torch.Tensor, temperature: float = 0.0) -> torch.Tensor: + if temperature == 0: + return logits.argmax(dim=-1) + return torch.distributions.Categorical(logits=logits / temperature).sample() + + +def extract_context_feature(hidden_states: list[torch.Tensor], layer_ids: list[int]) -> torch.Tensor: + # hidden_states[0] is the embedding output, so model layer ids are offset by one. + return torch.cat([hidden_states[layer_id + 1] for layer_id in layer_ids], dim=-1) + + +def main(): + parser = argparse.ArgumentParser(description="Create a lightweight DFlash correctness fixture.") + parser.add_argument("--draft-model", default="z-lab/Qwen3-Coder-30B-A3B-DFlash") + parser.add_argument("--target-model", default="Qwen/Qwen3-Coder-30B-A3B-Instruct") + parser.add_argument("--prompt", default="Write a quicksort in Python.") + parser.add_argument("--output", required=True) + parser.add_argument("--dtype", default="bfloat16", choices=["float16", "bfloat16", "float32"]) + parser.add_argument("--device-map", default="auto") + parser.add_argument("--temperature", type=float, default=0.0) + args = parser.parse_args() + + dtype = getattr(torch, args.dtype) + tokenizer = AutoTokenizer.from_pretrained(args.target_model, trust_remote_code=True) + target = AutoModelForCausalLM.from_pretrained( + args.target_model, + torch_dtype=dtype, + device_map=args.device_map, + trust_remote_code=True, + ).eval() + draft = AutoModel.from_pretrained( + args.draft_model, + torch_dtype=dtype, + device_map=args.device_map, + trust_remote_code=True, + ).eval() + + device = next(target.parameters()).device + input_ids = tokenizer(args.prompt, return_tensors="pt").input_ids.to(device) + block_size = draft.config.block_size + mask_token_id = draft.config.dflash_config["mask_token_id"] + target_layer_ids = draft.config.dflash_config["target_layer_ids"] + + position_ids = torch.arange(input_ids.shape[1] + block_size, device=device).unsqueeze(0) + with torch.inference_mode(): + target_output = target( + input_ids, + position_ids=position_ids[:, : input_ids.shape[1]], + use_cache=False, + logits_to_keep=1, + output_hidden_states=True, + ) + first_draft_token = sample(target_output.logits[:, -1, :], args.temperature) + + block_input_ids = torch.full((1, block_size), mask_token_id, dtype=torch.long, device=device) + block_input_ids[:, 0] = first_draft_token + + target_hidden = extract_context_feature(target_output.hidden_states, target_layer_ids).to(next(draft.parameters()).device) + noise_embedding = target.model.embed_tokens(block_input_ids.to(device)).to(next(draft.parameters()).device) + draft_position_ids = position_ids.to(next(draft.parameters()).device) + + draft_hidden = draft( + target_hidden=target_hidden, + noise_embedding=noise_embedding, + position_ids=draft_position_ids, + use_cache=False, + is_causal=False, + ) + expected_logits = target.lm_head(draft_hidden[:, -block_size + 1 :, :].to(device)) + sampled_tokens = sample(expected_logits, args.temperature) + + bundle = { + "input_ids": block_input_ids.cpu(), + "target_hidden": target_hidden.cpu(), + "position_ids": position_ids.cpu(), + "expected_logits": expected_logits.cpu(), + "sampled_tokens": sampled_tokens.cpu(), + "metadata": { + "draft_model": args.draft_model, + "target_model": args.target_model, + "prompt": args.prompt, + "block_size": block_size, + "mask_token_id": mask_token_id, + "target_layer_ids": target_layer_ids, + "dtype": args.dtype, + "temperature": args.temperature, + }, + } + torch.save(bundle, args.output) + + +if __name__ == "__main__": + main() From 41050f29cfccd4386579f60fbb9cdbe4d9c23c3e Mon Sep 17 00:00:00 2001 From: Ofir Zafrir Date: Mon, 4 May 2026 05:36:19 -0700 Subject: [PATCH 02/11] [OpenVINO] Enhance Qwen3DFlash model loading and weight handling - Removed the direct call to `_load_target_weights` in the constructor of `Qwen3DFlashForCausalLM`. - Added a class method `from_pretrained` to handle loading weights and configurations more effectively. - Updated weight handling to ensure compatibility with the target data type. - Modified the `extract_dflash_debug_bundle.py` script to use `dtype` instead of `torch_dtype` and added `attn_implementation` parameter for draft model loading. These changes improve the model's initialization process and enhance the flexibility of loading configurations. --- optimum/exporters/openvino/model_patcher.py | 18 +++++++++++++++++- tests/scripts/extract_dflash_debug_bundle.py | 5 +++-- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index d36c391568..ac1820288a 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -8734,7 +8734,6 @@ def __init__(self, config): super().__init__(config) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - self._load_target_weights(config) def _load_target_weights(self, config): embed_weight = _dflash_load_target_tensor(config, ("model.embed_tokens.weight", "embed_tokens.weight")) @@ -8743,9 +8742,26 @@ def _load_target_weights(self, config): except ValueError: lm_head_weight = embed_weight + target_dtype = self.fc.weight.dtype + embed_weight = embed_weight.to(target_dtype) + lm_head_weight = lm_head_weight.to(target_dtype) + self.embed_tokens.weight = nn.Parameter(embed_weight, requires_grad=False) self.lm_head.weight = nn.Parameter(lm_head_weight, requires_grad=False) + @classmethod + def from_pretrained(cls, *model_args, **kwargs): + output_loading_info = kwargs.get("output_loading_info", False) + result = super().from_pretrained(*model_args, **kwargs) + + if output_loading_info: + model, loading_info = result + model._load_target_weights(model.config) + return model, loading_info + + result._load_target_weights(result.config) + return result + def forward( self, input_ids: torch.LongTensor, diff --git a/tests/scripts/extract_dflash_debug_bundle.py b/tests/scripts/extract_dflash_debug_bundle.py index 28bbce3a71..0b0494914c 100644 --- a/tests/scripts/extract_dflash_debug_bundle.py +++ b/tests/scripts/extract_dflash_debug_bundle.py @@ -44,15 +44,16 @@ def main(): tokenizer = AutoTokenizer.from_pretrained(args.target_model, trust_remote_code=True) target = AutoModelForCausalLM.from_pretrained( args.target_model, - torch_dtype=dtype, + dtype=dtype, device_map=args.device_map, trust_remote_code=True, ).eval() draft = AutoModel.from_pretrained( args.draft_model, - torch_dtype=dtype, + dtype=dtype, device_map=args.device_map, trust_remote_code=True, + attn_implementation="eager", ).eval() device = next(target.parameters()).device From 236a81a4893988818fb1b66ea045e6186430a8e7 Mon Sep 17 00:00:00 2001 From: Ofir Zafrir Date: Tue, 5 May 2026 05:07:39 -0700 Subject: [PATCH 03/11] [OpenVINO] Add hidden state annotation support for text generation models - Introduced functions to check and annotate hidden states in models during export. - Enhanced configuration to include hidden state outputs for models with multiple hidden layers. - Implemented a test suite to validate hidden state annotations in exported OpenVINO models. These changes improve the model export process by allowing the inclusion of hidden states, which is essential for certain text generation tasks. --- optimum/exporters/openvino/convert.py | 107 +++++++++++++++++- .../openvino/test_hidden_state_annotations.py | 59 ++++++++++ 2 files changed, 165 insertions(+), 1 deletion(-) create mode 100644 tests/openvino/test_hidden_state_annotations.py diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index a344b44f74..fc167402c3 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -16,6 +16,7 @@ import functools import gc import inspect +import json import logging import os from pathlib import Path @@ -81,6 +82,9 @@ logger = logging.getLogger(__name__) +HIDDEN_STATES_RT_INFO_KEY = "hidden_states_decoder_layers" +HIDDEN_STATE_TENSOR_NAME_TEMPLATE = "ov.hidden_states.decoder_layer_{}" + if is_torch_available(): import torch.nn as nn from transformers.modeling_utils import PreTrainedModel @@ -144,6 +148,86 @@ def _save_model( gc.collect() +def _can_annotate_hidden_states(model, config: "OnnxConfig") -> bool: + if "text-generation" not in getattr(config, "task", ""): + return False + + model_config = getattr(model, "config", None) + if model_config is None or not hasattr(model_config, "num_hidden_layers"): + return False + if not hasattr(model_config, "output_hidden_states"): + return False + + can_record_outputs = getattr(model, "_can_record_outputs", None) + if isinstance(can_record_outputs, dict) and not can_record_outputs.get("hidden_states", False): + return False + + return True + + +def _annotate_hidden_state_outputs(model: Model, public_output_names: List[str], num_hidden_layers: int): + public_output_names = set(public_output_names) + temporary_results = [] + for result in model.get_results(): + result_names = set(result.output(0).get_names()) + if result_names & public_output_names: + continue + temporary_results.append(result) + + if not temporary_results: + return + + # HF hidden_states usually contains embeddings first, then one tensor after each decoder block. + offset = 1 if len(temporary_results) >= num_hidden_layers + 1 else 0 + hidden_layer_results = temporary_results[offset : offset + num_hidden_layers] + if len(hidden_layer_results) == num_hidden_layers: + layers = {} + for layer_idx, result in enumerate(hidden_layer_results): + tensor_name = HIDDEN_STATE_TENSOR_NAME_TEMPLATE.format(layer_idx) + result.input_value(0).get_tensor().add_names({tensor_name}) + layers[str(layer_idx)] = tensor_name + + model.set_rt_info( + json.dumps({"version": 1, "layers": layers}), + HIDDEN_STATES_RT_INFO_KEY, + ) + else: + logger.debug( + "Skipping hidden-state annotation: expected at least %s hidden-state outputs, got %s.", + num_hidden_layers, + len(temporary_results), + ) + + for result in temporary_results: + model.remove_result(result) + model.validate_nodes_and_infer_types() + + +def _enable_hidden_states_in_config_outputs(config: "OnnxConfig", num_hidden_layers: int): + original_class = config.__class__ + + class ConfigWithHiddenStateOutputs(original_class): + @property + def outputs(self): + outputs = original_class.outputs.fget(self).copy() + for idx in range(num_hidden_layers + 1): + outputs[f"hidden_states.{idx}"] = {0: "batch_size", 1: "sequence_length"} + return outputs + + config.__class__ = ConfigWithHiddenStateOutputs + return original_class + + +def _flatten_hidden_states_outputs(outputs: Dict): + hidden_states = outputs.pop("hidden_states", None) + if hidden_states is None: + return outputs + + for idx, hidden_state in enumerate(hidden_states): + outputs[f"hidden_states.{idx}"] = hidden_state + return outputs + + def export( model: Union["PreTrainedModel", "ModelMixin", "DiffusionPipeline"], config: "OnnxConfig", @@ -382,6 +466,17 @@ def export_pytorch( dummy_inputs = config.rename_ambiguous_inputs(dummy_inputs) dummy_inputs, dict_inputs = remove_none_from_dummy_inputs(dummy_inputs) + output_names = list(config.outputs.keys()) + annotate_hidden_states = _can_annotate_hidden_states(model, config) + original_output_hidden_states = None + original_config_class = None + if annotate_hidden_states: + original_output_hidden_states = model.config.output_hidden_states + model.config.output_hidden_states = True + if hasattr(config, "_config") and hasattr(config._config, "output_hidden_states"): + config._config.output_hidden_states = True + original_config_class = _enable_hidden_states_in_config_outputs(config, model.config.num_hidden_layers) + # TorchScript used behind OpenVINO conversion. Optimum supports only return_dict=True models for patching, # while TorchScript do not support dictionary with values of mixed types (e.g. Tensor and None) in model input/output # To handle it, additional wrapper on patcher forward applied. @@ -406,6 +501,8 @@ def ts_patched_forward(*args, **kwargs): input_dict = dict(zip(keys, tuple_input)) kwargs[input_name] = input_dict outputs = patched_forward(**kwargs) + if annotate_hidden_states: + outputs = _flatten_hidden_states_outputs(outputs) return tuple([value if not isinstance(value, list) else tuple(value) for value in outputs.values()]) patcher.patched_forward = ts_patched_forward @@ -449,9 +546,14 @@ def ts_patched_forward(*args, **kwargs): extension=conversion_extensions, ) + if annotate_hidden_states: + config.__class__ = original_config_class + model.config.output_hidden_states = original_output_hidden_states + if hasattr(config, "_config") and hasattr(config._config, "output_hidden_states"): + config._config.output_hidden_states = original_output_hidden_states + ov_model.validate_nodes_and_infer_types() # TODO: remove as unnecessary validation? - output_names = list(config.outputs.keys()) for idx, out_tensor in enumerate(ov_model.outputs): if idx < len(output_names): out_tensor.get_tensor().set_names({output_names[idx]}) @@ -464,6 +566,9 @@ def ts_patched_forward(*args, **kwargs): if stateful: patch_stateful(model.config, ov_model) + if annotate_hidden_states: + _annotate_hidden_state_outputs(ov_model, output_names, model.config.num_hidden_layers) + library_name = _infer_library_from_model_or_model_class(model=model, library_name=library_name) _save_model( diff --git a/tests/openvino/test_hidden_state_annotations.py b/tests/openvino/test_hidden_state_annotations.py new file mode 100644 index 0000000000..b1588b3d4c --- /dev/null +++ b/tests/openvino/test_hidden_state_annotations.py @@ -0,0 +1,59 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import unittest +from pathlib import Path +from tempfile import TemporaryDirectory + +import openvino as ov +from transformers import AutoModelForCausalLM +from utils_tests import MODEL_NAMES + +from optimum.exporters.openvino import export_from_model + + +HIDDEN_STATES_RT_INFO_KEY = "hidden_states_decoder_layers" + + +class HiddenStateAnnotationExportTest(unittest.TestCase): + def test_export_hidden_state_annotations_without_extra_outputs(self): + for task in ("text-generation", "text-generation-with-past"): + model = AutoModelForCausalLM.from_pretrained(MODEL_NAMES["gpt2"]) + with self.subTest(task=task), TemporaryDirectory() as tmpdirname: + export_from_model( + model=model, + output=Path(tmpdirname), + task=task, + preprocessors=None, + stateful=False, + ) + + ov_model = ov.Core().read_model(Path(tmpdirname) / "openvino_model.xml") + output_names = set().union(*(output.get_names() for output in ov_model.outputs)) + self.assertNotIn("last_hidden_state", output_names) + self.assertFalse(any(name.startswith("ov.hidden_states.") for name in output_names)) + + rt_info = ov_model.get_rt_info() + self.assertIn(HIDDEN_STATES_RT_INFO_KEY, rt_info) + annotation = json.loads(rt_info[HIDDEN_STATES_RT_INFO_KEY].value) + self.assertEqual(annotation["version"], 1) + self.assertEqual(len(annotation["layers"]), model.config.num_hidden_layers) + + graph_tensor_names = set() + for op in ov_model.get_ops(): + for output in op.outputs(): + graph_tensor_names.update(output.get_names()) + for tensor_name in annotation["layers"].values(): + self.assertIn(tensor_name, graph_tensor_names) From 0d4d48987d834a4df0f43b2532dde87cc4b0300f Mon Sep 17 00:00:00 2001 From: Ofir Zafrir Date: Tue, 5 May 2026 07:44:33 -0700 Subject: [PATCH 04/11] [OpenVINO] Add tests for hidden state output annotation - Implemented helper functions to find and add model outputs based on tensor names. - Added a new test case to validate that annotated hidden state outputs match those from PyTorch for the GPT-2 model. - Enhanced the export process to include hidden state outputs, ensuring compatibility with text generation tasks. These changes improve the testing framework for OpenVINO model exports, specifically focusing on hidden state annotations. --- .../openvino/test_hidden_state_annotations.py | 74 +++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/tests/openvino/test_hidden_state_annotations.py b/tests/openvino/test_hidden_state_annotations.py index b1588b3d4c..69764a523c 100644 --- a/tests/openvino/test_hidden_state_annotations.py +++ b/tests/openvino/test_hidden_state_annotations.py @@ -17,7 +17,9 @@ from pathlib import Path from tempfile import TemporaryDirectory +import numpy as np import openvino as ov +import torch from transformers import AutoModelForCausalLM from utils_tests import MODEL_NAMES @@ -27,6 +29,22 @@ HIDDEN_STATES_RT_INFO_KEY = "hidden_states_decoder_layers" +def _find_output_by_tensor_name(model, tensor_name): + for op in model.get_ops(): + for output in op.outputs(): + if tensor_name in output.get_names(): + return output + raise AssertionError(f"Tensor {tensor_name} was not found in the OpenVINO graph") + + +def _add_model_output(model, output, output_name): + output.get_tensor().add_names({output_name}) + if hasattr(model, "add_output"): + model.add_output(output) + else: + model.add_outputs([output]) + + class HiddenStateAnnotationExportTest(unittest.TestCase): def test_export_hidden_state_annotations_without_extra_outputs(self): for task in ("text-generation", "text-generation-with-past"): @@ -57,3 +75,59 @@ def test_export_hidden_state_annotations_without_extra_outputs(self): graph_tensor_names.update(output.get_names()) for tensor_name in annotation["layers"].values(): self.assertIn(tensor_name, graph_tensor_names) + + def test_annotated_hidden_state_output_matches_pytorch(self): + model = AutoModelForCausalLM.from_pretrained(MODEL_NAMES["gpt2"]) + model.eval() + + with TemporaryDirectory() as tmpdirname: + export_from_model( + model=model, + output=Path(tmpdirname), + task="text-generation", + preprocessors=None, + stateful=False, + ) + + core = ov.Core() + ov_model = core.read_model(Path(tmpdirname) / "openvino_model.xml") + annotation = json.loads(ov_model.get_rt_info()[HIDDEN_STATES_RT_INFO_KEY].value) + layer_idx = 0 + output_name = "decoder_layer_0_hidden_state" + hidden_state_output = _find_output_by_tensor_name(ov_model, annotation["layers"][str(layer_idx)]) + _add_model_output(ov_model, hidden_state_output, output_name) + + input_ids = torch.tensor([[0, 1, 2, 3]], dtype=torch.long) + attention_mask = torch.ones_like(input_ids) + with torch.no_grad(): + torch_outputs = model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + return_dict=True, + ) + + compiled_model = core.compile_model(ov_model, "CPU") + ov_inputs = {} + for input_port in compiled_model.inputs: + input_name = input_port.get_any_name() + if input_name == "input_ids": + ov_inputs[input_name] = input_ids.numpy() + elif input_name == "attention_mask": + ov_inputs[input_name] = attention_mask.numpy() + elif input_name == "position_ids": + ov_inputs[input_name] = np.arange(input_ids.shape[1], dtype=np.int64).reshape(1, -1) + elif input_name == "token_type_ids": + ov_inputs[input_name] = np.zeros(input_ids.shape, dtype=np.int64) + else: + self.fail(f"Unexpected OpenVINO model input: {input_name}") + + infer_result = compiled_model(ov_inputs) + ov_output_port = next(output for output in compiled_model.outputs if output_name in output.get_names()) + np.testing.assert_allclose( + infer_result[ov_output_port], + torch_outputs.hidden_states[layer_idx + 1].detach().numpy(), + rtol=1e-4, + atol=1e-4, + ) From 1824a2f60763787e4f570be966dfee80ca1b2ce2 Mon Sep 17 00:00:00 2001 From: Ofir Zafrir Date: Sun, 17 May 2026 06:28:33 -0700 Subject: [PATCH 05/11] [OpenVINO] Implement DFlash block size override in configuration - Added support for overriding the DFlash block size via the environment variable `DFLASH_BLOCK_SIZE_OVERRIDE`. - Included error handling to ensure the block size is an integer greater than 1. - This enhancement allows for more flexible configuration of DFlash model exports, improving usability and performance. These changes contribute to the ongoing improvements in the OpenVINO export process for DFlash models. --- optimum/exporters/openvino/__main__.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index a8815bc57f..b28e53ff39 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -171,6 +171,17 @@ def update_config_for_dflash( "AutoModel": moduler_path + "--model_patcher.Qwen3DFlashDraftModel", "AutoModelForCausalLM": moduler_path + "--model_patcher.Qwen3DFlashForCausalLM", } + # TODO ofir: remove this after implementing load time override in openvino.genai + dflash_block_size_override = os.environ.get("DFLASH_BLOCK_SIZE_OVERRIDE") + if dflash_block_size_override: + try: + block_size = int(dflash_block_size_override) + except ValueError as exc: + raise ValueError("DFLASH_BLOCK_SIZE_OVERRIDE must be an integer.") from exc + if block_size <= 1: + raise ValueError("DFLASH_BLOCK_SIZE_OVERRIDE must be greater than 1.") + config.block_size = block_size + config.dflash_target_model = dflash_target_model config.dflash_target_cache_dir = cache_dir config.dflash_target_revision = revision From dfb074c981f5b1c7965875a2cb8b3ce337da5c49 Mon Sep 17 00:00:00 2001 From: Ofir Zafrir Date: Wed, 20 May 2026 01:53:33 -0700 Subject: [PATCH 06/11] [OpenVINO] Enhance DFlash model integration and testing - Added support for committed prefix cache policy in DFlash models by updating runtime information. - Modified `DFlashDummyInputGenerator` to use "hidden_states" instead of "target_hidden" for input names. - Updated Qwen3DFlash model to handle hidden states and past key values more effectively during inference. - Introduced a new script to compare DFlash cache semantics between original and patched models. - Enhanced tests to validate the integration of hidden states and ensure consistency in outputs. These changes improve the functionality and testing of DFlash models within the OpenVINO framework, ensuring better performance and reliability. --- optimum/exporters/openvino/convert.py | 1 + optimum/exporters/openvino/model_configs.py | 29 ++- optimum/exporters/openvino/model_patcher.py | 78 ++++-- optimum/intel/openvino/modeling_decoder.py | 19 +- tests/openvino/test_decoder.py | 95 +++++++- .../scripts/compare_dflash_cache_semantics.py | 161 +++++++++++++ tests/scripts/extract_dflash_debug_bundle.py | 222 +++++++++++++++--- 7 files changed, 519 insertions(+), 86 deletions(-) create mode 100644 tests/scripts/compare_dflash_cache_semantics.py diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index fc167402c3..d1023d5759 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -983,6 +983,7 @@ def _add_dflash_mode_to_rt_info(model: Model, hf_config): try: model.set_rt_info("True", ["dflash_mode"]) model.set_rt_info(str(getattr(hf_config, "block_size", "")), ["dflash", "block_size"]) + model.set_rt_info("committed_prefix", ["dflash", "cache_policy"]) dflash_config = getattr(hf_config, "dflash_config", {}) if "mask_token_id" in dflash_config: model.set_rt_info(str(dflash_config["mask_token_id"]), ["dflash", "mask_token_id"]) diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index f3328fd3b0..33a8cf323a 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -383,7 +383,7 @@ class Qwen2MoEOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): class DFlashDummyInputGenerator(DummyInputGenerator): - SUPPORTED_INPUT_NAMES = ("input_ids", "target_hidden", "position_ids") + SUPPORTED_INPUT_NAMES = ("input_ids", "hidden_states", "position_ids") def __init__( self, @@ -410,7 +410,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int framework=framework, dtype=int_dtype, ) - if input_name == "target_hidden": + if input_name == "hidden_states": return self.random_float_tensor( [self.batch_size, self.context_length, self.hidden_size * self.num_target_layers], framework=framework, @@ -471,17 +471,18 @@ def __init__( archs = getattr(config, "architectures", None) self.dflash = isinstance(archs, list) and len(archs) > 0 and archs[0] == "DFlashDraftModel" if self.dflash: - self.DUMMY_INPUT_GENERATOR_CLASSES = (DFlashDummyInputGenerator,) + self.DUMMY_INPUT_GENERATOR_CLASSES = (DFlashDummyInputGenerator, GemmaDummyPastKeyValuesGenerator) self.MIN_TRANSFORMERS_VERSION = "4.57.0" @property def inputs(self) -> Dict[str, Dict[int, str]]: if self.dflash: - return { - "input_ids": {0: "batch_size", 1: "block_size"}, - "target_hidden": {0: "batch_size", 1: "context_length", 2: "target_hidden_size"}, - "position_ids": {0: "batch_size", 1: "position_sequence_length"}, - } + common_inputs = super().inputs + common_inputs.pop("attention_mask", None) + common_inputs["input_ids"] = {0: "batch_size", 1: "block_size"} + common_inputs["hidden_states"] = {0: "batch_size", 1: "context_length", 2: "target_hidden_size"} + common_inputs["position_ids"] = {0: "batch_size", 1: "position_sequence_length"} + return common_inputs if self.task in ["feature-extraction"]: common_inputs = { "input_ids": {0: "batch_size", 1: "sequence_length"}, @@ -494,9 +495,19 @@ def inputs(self) -> Dict[str, Dict[int, str]]: @property def outputs(self) -> Dict[str, Dict[int, str]]: if self.dflash: - return {"logits": {0: "batch_size", 1: "draft_sequence_length"}} + common_outputs = super().outputs + common_outputs["logits"] = {0: "batch_size", 1: "draft_sequence_length"} + return common_outputs return super().outputs + def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str): + super().add_past_key_values(inputs_or_outputs, direction) + if self.dflash and direction == "outputs": + for axes in inputs_or_outputs.values(): + for axis, name in axes.items(): + if name == "past_sequence_length + sequence_length": + axes[axis] = "past_sequence_length + context_length" + class DummyQwen3VLLMInputGenerator(DummyTextInputGenerator): SUPPORTED_INPUT_NAMES = ( diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index ac1820288a..83e815c4d3 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -8601,22 +8601,31 @@ def forward( query_states = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim) query_states = self.q_norm(query_states).transpose(1, 2) - key_context = self.k_proj(target_hidden) - key_noise = self.k_proj(hidden_states) - value_context = self.v_proj(target_hidden) - value_noise = self.v_proj(hidden_states) - - key_states = torch.cat([key_context, key_noise], dim=1).view(bsz, ctx_len + q_len, -1, self.head_dim) - value_states = torch.cat([value_context, value_noise], dim=1).view(bsz, ctx_len + q_len, -1, self.head_dim) + kv_hidden_states = torch.cat([target_hidden, hidden_states], dim=1) + key_states = self.k_proj(kv_hidden_states).view(bsz, ctx_len + q_len, -1, self.head_dim) + value_states = self.v_proj(kv_hidden_states).view(bsz, ctx_len + q_len, -1, self.head_dim) key_states = self.k_norm(key_states).transpose(1, 2) value_states = value_states.transpose(1, 2) cos, sin = position_embeddings query_states, key_states = _dflash_apply_rotary_pos_emb(query_states, key_states, cos, sin) + target_key_states, block_key_states = key_states.split([ctx_len, q_len], dim=2) + target_value_states, block_value_states = value_states.split([ctx_len, q_len], dim=2) if past_key_values is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + # Persist only committed target-prefix K/V. The speculative block + # stays local to this inference, so rejection never requires cache trim. + target_cache_position = cache_position[:ctx_len] if cache_position is not None else None + cache_kwargs = {"sin": sin[:, :ctx_len], "cos": cos[:, :ctx_len], "cache_position": target_cache_position} + target_key_states, target_value_states = past_key_values.update( + target_key_states, + target_value_states, + self.layer_idx, + cache_kwargs, + ) + + key_states = torch.cat([target_key_states, block_key_states], dim=2) + value_states = torch.cat([target_value_states, block_value_states], dim=2) attn_output, attn_weights = qwen3_eager_attention_forward( self, @@ -8703,27 +8712,48 @@ def forward( position_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, noise_embedding: Optional[torch.Tensor] = None, - target_hidden: Optional[torch.Tensor] = None, + hidden_states: Optional[torch.Tensor] = None, past_key_values: Optional[Cache] = None, - use_cache: bool = False, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, **kwargs, - ) -> torch.FloatTensor: - hidden_states = noise_embedding - target_hidden = target_hidden.to(hidden_states.dtype) + ) -> BaseModelOutputWithPast: + noise_states = noise_embedding + target_hidden = hidden_states.to(noise_states.dtype) target_hidden = self.hidden_norm(self.fc(target_hidden)) - position_embeddings = self.rotary_emb(hidden_states, position_ids) + use_cache = use_cache if use_cache is not None else self.config.use_cache + if use_cache: + if past_key_values is None: + past_key_values = DynamicCache(config=self.config) + elif not isinstance(past_key_values, Cache): + if is_transformers_version("<", "5"): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + else: + past_key_values = DynamicCache(past_key_values) + if use_cache and cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + target_hidden.shape[1] + noise_states.shape[1], + device=noise_states.device, + ) + position_embeddings = self.rotary_emb(noise_states, position_ids) for layer in self.layers: - hidden_states = layer( - hidden_states=hidden_states, + noise_states = layer( + hidden_states=noise_states, target_hidden=target_hidden, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values, use_cache=use_cache, + cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, ) - return self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=self.norm(noise_states), + past_key_values=past_key_values if use_cache else None, + ) class Qwen3DFlashForCausalLM(Qwen3DFlashDraftModel, GenerationMixin): @@ -8765,17 +8795,17 @@ def from_pretrained(cls, *model_args, **kwargs): def forward( self, input_ids: torch.LongTensor, - target_hidden: torch.Tensor, + hidden_states: torch.Tensor, position_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[Cache] = None, - use_cache: bool = False, + use_cache: Optional[bool] = None, logits_to_keep: Optional[int] = None, **kwargs, ) -> CausalLMOutputWithPast: noise_embedding = self.embed_tokens(input_ids) - hidden_states = super().forward( - target_hidden=target_hidden, + outputs = super().forward( + hidden_states=hidden_states, noise_embedding=noise_embedding, position_ids=position_ids, attention_mask=attention_mask, @@ -8785,8 +8815,8 @@ def forward( ) if logits_to_keep is None: logits_to_keep = self.block_size - 1 - logits = self.lm_head(hidden_states[:, -logits_to_keep:, :]) - return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values) + logits = self.lm_head(outputs.last_hidden_state[:, -logits_to_keep:, :]) + return CausalLMOutputWithPast(logits=logits, past_key_values=outputs.past_key_values) # Patched implementation of the gated delta rule in recurrent form. diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index bf5e6c17a1..9abaa67d3f 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -550,7 +550,10 @@ def prepare_inputs( else: position_ids = np.cumsum(attention_mask, axis=1) - 1 position_ids[attention_mask == 0] = 1 - if past_key_values: + dflash_hidden_states = kwargs.get("hidden_states", None) + if past_key_values and dflash_hidden_states is not None and "hidden_states" in self.input_names: + position_ids = position_ids[:, -(input_ids.shape[1] + dflash_hidden_states.shape[1]) :] + elif past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] inputs["position_ids"] = position_ids @@ -565,17 +568,15 @@ def prepare_inputs( if "hidden_states" in self.input_names: hidden_states = kwargs.get("hidden_states", None) if hidden_states is None: - hs_shape = (batch_size, input_ids.shape[1], self.config.hidden_size * 3) + dflash_config = getattr(self.config, "dflash_config", None) + if dflash_config is not None: + raise ValueError("DFlash draft models require `hidden_states` to be passed to the forward call.") + else: + hidden_size = self.config.hidden_size * 3 + hs_shape = (batch_size, input_ids.shape[1], hidden_size) hidden_states = torch.zeros(hs_shape, device=self.device, dtype=torch.float32) inputs["hidden_states"] = hidden_states - # DFlash draft models consume target-model context features explicitly. - if "target_hidden" in self.input_names: - target_hidden = kwargs.get("target_hidden", None) - if target_hidden is None: - raise ValueError("DFlash draft models require `target_hidden` to be passed to the forward call.") - inputs["target_hidden"] = target_hidden - return inputs def forward( diff --git a/tests/openvino/test_decoder.py b/tests/openvino/test_decoder.py index 222e42862e..2f809cf40e 100644 --- a/tests/openvino/test_decoder.py +++ b/tests/openvino/test_decoder.py @@ -945,35 +945,104 @@ def test_load_and_infer_with_eagle3_model(self, model_arch, model_pair): @pytest.mark.skipif("DFLASH_DEBUG_BUNDLE" not in os.environ, reason="Set DFLASH_DEBUG_BUNDLE to run DFlash tests") @pytest.mark.run_slow @slow - def test_load_and_infer_with_dflash_debug_bundle(self, model_arch, model_pair): + def test_load_and_infer_stateful_with_dflash_debug_bundle(self, model_arch, model_pair): draft_model_id, target_model_id = model_pair bundle = torch.load(os.environ["DFLASH_DEBUG_BUNDLE"], map_location="cpu") ov_model_path = os.environ.get("DFLASH_OV_MODEL_DIR") if ov_model_path: - ov_model = OVModelForCausalLM.from_pretrained(ov_model_path, use_cache=False, device=OPENVINO_DEVICE) + ov_model = OVModelForCausalLM.from_pretrained( + ov_model_path, + use_cache=True, + stateful=True, + device=OPENVINO_DEVICE, + ) else: ov_model = OVModelForCausalLM.from_pretrained( draft_model_id, export=True, trust_remote_code=True, dflash_target_model=target_model_id, - use_cache=False, - stateful=False, + use_cache=True, + stateful=True, device=OPENVINO_DEVICE, ) - - ov_outputs = ov_model( - input_ids=bundle["input_ids"], - target_hidden=bundle["target_hidden"], - position_ids=bundle["position_ids"], + self.assertTrue(ov_model.stateful) + if hasattr(ov_model, "model"): + self.assertEqual(ov_model.model.get_rt_info(["dflash", "cache_policy"]).value, "committed_prefix") + input_names = {name for model_input in ov_model.model.inputs for name in model_input.get_names()} + output_names = {name for model_output in ov_model.model.outputs for name in model_output.get_names()} + self.assertFalse(any("past_key_values" in name for name in input_names)) + self.assertFalse(any("present" in name for name in output_names)) + + steps = bundle.get("steps") + if steps is None: + steps = [ + { + "input_ids": bundle["input_ids"], + "hidden_states": bundle["hidden_states"], + "position_ids": bundle["position_ids"], + "expected_logits": bundle["expected_logits"], + } + ] + if "next_input_ids" in bundle: + steps.append( + { + "input_ids": bundle["next_input_ids"], + "hidden_states": bundle["next_hidden_states"], + "position_ids": bundle["next_position_ids"], + "expected_logits": bundle["next_expected_logits"], + } + ) + min_debug_steps = int(os.environ.get("DFLASH_MIN_DEBUG_STEPS", "3")) + self.assertGreaterEqual( + len(steps), + min_debug_steps, + "Regenerate DFLASH_DEBUG_BUNDLE with the updated multi-step extractor.", ) - expected_logits = bundle["expected_logits"] - self.assertEqual(tuple(ov_outputs.logits.shape), tuple(expected_logits.shape)) + past_key_values = None + for idx, step in enumerate(steps): + with self.subTest(step=idx): + ov_outputs = ov_model( + input_ids=step["input_ids"], + hidden_states=step["hidden_states"], + position_ids=step["position_ids"], + past_key_values=past_key_values, + ) + expected_logits = step.get( + "original_cached_logits", step.get("expected_logits_cached", step["expected_logits"]) + ) + if "expected_logits_full_prefix" in step: + torch.testing.assert_close( + step["expected_logits_full_prefix"].float(), + expected_logits.float(), + rtol=float(os.environ.get("DFLASH_PYTORCH_CACHE_RTOL", os.environ.get("DFLASH_RTOL", "5e-2"))), + atol=float(os.environ.get("DFLASH_PYTORCH_CACHE_ATOL", os.environ.get("DFLASH_ATOL", "5e-2"))), + ) + self.assertEqual(tuple(ov_outputs.logits.shape), tuple(expected_logits.shape)) + torch.testing.assert_close( + ov_outputs.logits.float(), + expected_logits.float(), + rtol=float(os.environ.get("DFLASH_RTOL", "5e-2")), + atol=float(os.environ.get("DFLASH_ATOL", "5e-2")), + ) + self.assertTrue("past_key_values" in ov_outputs) + self.assertIsInstance(ov_outputs.past_key_values, tuple) + self.assertTrue(len(ov_outputs.past_key_values) == 1 and len(ov_outputs.past_key_values[0]) == 0) + past_key_values = ov_outputs.past_key_values + + first_step = steps[0] + reset_outputs = ov_model( + input_ids=first_step["input_ids"], + hidden_states=first_step["hidden_states"], + position_ids=first_step["position_ids"], + ) torch.testing.assert_close( - ov_outputs.logits.float(), - expected_logits.float(), + reset_outputs.logits.float(), + first_step.get( + "original_cached_logits", first_step.get("expected_logits_cached", first_step["expected_logits"]) + ).float(), rtol=float(os.environ.get("DFLASH_RTOL", "5e-2")), atol=float(os.environ.get("DFLASH_ATOL", "5e-2")), ) diff --git a/tests/scripts/compare_dflash_cache_semantics.py b/tests/scripts/compare_dflash_cache_semantics.py new file mode 100644 index 0000000000..a561f365b7 --- /dev/null +++ b/tests/scripts/compare_dflash_cache_semantics.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Compare original DFlash bundle references with the Optimum patched export model. + +The bundle is produced by extract_dflash_debug_bundle.py on a machine that can +load the full target model. It contains original DFlash logits and original +DFlash K/V caches cropped to the committed target-backed prefix. This helper +loads only the patched Optimum DFlash model, replays the same steps, and checks +that both logits and committed-prefix caches match. +""" + +import argparse + +import torch +from transformers import AutoConfig + + +def cache_to_layers(cache) -> tuple: + if hasattr(cache, "to_legacy_cache"): + return tuple(cache.to_legacy_cache()) + if hasattr(cache, "layers"): + layers = [] + for layer in cache.layers: + key = getattr(layer, "keys", None) + value = getattr(layer, "values", None) + if key is None: + key = getattr(layer, "key_cache", None) + if value is None: + value = getattr(layer, "value_cache", None) + layers.append((key, value)) + return tuple(layers) + return tuple(cache) + + +def cache_seq_length(cache) -> int: + layers = cache_to_layers(cache) + if not layers: + return 0 + return layers[0][0].shape[-2] + + +def cache_to_cpu(cache) -> tuple: + return tuple( + (key.detach().float().cpu().contiguous(), value.detach().float().cpu().contiguous()) + for key, value in cache_to_layers(cache) + ) + + +def assert_cache_close(actual_cache, expected_cache, *, step_idx: int, rtol: float, atol: float): + actual_layers = cache_to_cpu(actual_cache) + if len(actual_layers) != len(expected_cache): + raise AssertionError(f"Step {step_idx}: layer count mismatch {len(actual_layers)} != {len(expected_cache)}") + for layer_idx, ((actual_key, actual_value), (expected_key, expected_value)) in enumerate( + zip(actual_layers, expected_cache) + ): + torch.testing.assert_close( + actual_key, + expected_key.float(), + rtol=rtol, + atol=atol, + msg=f"Step {step_idx}, layer {layer_idx}: key cache mismatch", + ) + torch.testing.assert_close( + actual_value, + expected_value.float(), + rtol=rtol, + atol=atol, + msg=f"Step {step_idx}, layer {layer_idx}: value cache mismatch", + ) + + +def load_patched_dflash(draft_model: str, target_model: str, dtype: torch.dtype, device_map: str): + from optimum.exporters.openvino.__main__ import update_config_for_dflash + from optimum.exporters.openvino.model_patcher import Qwen3DFlashForCausalLM + + config = AutoConfig.from_pretrained(draft_model, trust_remote_code=True) + config = update_config_for_dflash(config, dflash_target_model=target_model) + return Qwen3DFlashForCausalLM.from_pretrained( + draft_model, + config=config, + dtype=dtype, + device_map=device_map, + trust_remote_code=True, + attn_implementation="eager", + ).eval() + + +def main(): + parser = argparse.ArgumentParser(description="Compare patched DFlash against original committed-cache bundle data.") + parser.add_argument("--bundle", required=True, help="Path to a bundle produced by extract_dflash_debug_bundle.py") + parser.add_argument("--draft-model", default=None, help="Override draft model ID/path from the bundle metadata.") + parser.add_argument("--target-model", default=None, help="Override target model ID/path from the bundle metadata.") + parser.add_argument("--dtype", default=None, choices=["float16", "bfloat16", "float32"]) + parser.add_argument("--device-map", default="auto") + parser.add_argument("--rtol", type=float, default=5e-2) + parser.add_argument("--atol", type=float, default=5e-2) + args = parser.parse_args() + + bundle = torch.load(args.bundle, map_location="cpu") + metadata = bundle.get("metadata", {}) + draft_model = args.draft_model or metadata.get("draft_model", "z-lab/Qwen3-Coder-30B-A3B-DFlash") + target_model = args.target_model or metadata.get("target_model", "Qwen/Qwen3-Coder-30B-A3B-Instruct") + dtype = getattr(torch, args.dtype or metadata.get("dtype", "bfloat16")) + steps = bundle["steps"] + + patched_draft = load_patched_dflash(draft_model, target_model, dtype, args.device_map) + patched_device = next(patched_draft.parameters()).device + patched_past_key_values = None + with torch.inference_mode(): + for step_idx, step in enumerate(steps): + if "original_committed_cache" not in step: + raise ValueError("Bundle is missing original_committed_cache. Regenerate it with the updated extractor.") + original_logits = step.get("original_cached_logits", step["expected_logits"]) + patched_outputs = patched_draft( + input_ids=step["input_ids"].to(patched_device), + hidden_states=step["hidden_states"].to(patched_device), + position_ids=step["position_ids"].to(patched_device), + past_key_values=patched_past_key_values, + use_cache=True, + ) + patched_past_key_values = patched_outputs.past_key_values + torch.testing.assert_close( + patched_outputs.logits.detach().float().cpu(), + original_logits.float(), + rtol=args.rtol, + atol=args.atol, + msg=f"Step {step_idx}: logits mismatch", + ) + assert_cache_close( + patched_past_key_values, + step["original_committed_cache"], + step_idx=step_idx, + rtol=args.rtol, + atol=args.atol, + ) + patched_length = cache_seq_length(patched_past_key_values) + if patched_length != step["expected_present_length"]: + raise AssertionError( + f"Step {step_idx}: patched cache length {patched_length} != {step['expected_present_length']}" + ) + print(f"Step {step_idx}: patched logits and committed-prefix cache match original bundle reference") + + print("DFlash committed-prefix cache semantics match the original implementation.") + + +if __name__ == "__main__": + main() diff --git a/tests/scripts/extract_dflash_debug_bundle.py b/tests/scripts/extract_dflash_debug_bundle.py index 0b0494914c..fe90cea6db 100644 --- a/tests/scripts/extract_dflash_debug_bundle.py +++ b/tests/scripts/extract_dflash_debug_bundle.py @@ -12,9 +12,33 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +Create a portable DFlash correctness fixture. + +This script is standalone: run it on any machine that can load the PyTorch +target and DFlash draft models, then copy the resulting `.pt` bundle to the +machine that runs the OpenVINO export tests. + +Required packages: + - torch + - transformers >= 4.57 + - accelerate, if using --device-map auto + - safetensors + - huggingface_hub + +Example: + python tests/scripts/extract_dflash_debug_bundle.py \ + --draft-model z-lab/Qwen3-Coder-30B-A3B-DFlash \ + --target-model Qwen/Qwen3-Coder-30B-A3B-Instruct \ + --dtype float16 \ + --num-steps 4 \ + --output dflash_debug_bundle_kv.pt +""" + import argparse import torch +from transformers.cache_utils import DynamicCache from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer @@ -29,6 +53,100 @@ def extract_context_feature(hidden_states: list[torch.Tensor], layer_ids: list[i return torch.cat([hidden_states[layer_id + 1] for layer_id in layer_ids], dim=-1) +def cache_to_layers(cache) -> tuple: + if hasattr(cache, "to_legacy_cache"): + return tuple(cache.to_legacy_cache()) + if hasattr(cache, "layers"): + layers = [] + for layer in cache.layers: + key = getattr(layer, "keys", None) + value = getattr(layer, "values", None) + if key is None: + key = getattr(layer, "key_cache", None) + if value is None: + value = getattr(layer, "value_cache", None) + layers.append((key, value)) + return tuple(layers) + return tuple(cache) + + +def crop_tensor(tensor: torch.Tensor, length: int) -> torch.Tensor: + return tensor[..., :length, :].contiguous() + + +def crop_cache(cache, length: int): + if hasattr(cache, "crop"): + cache.crop(length) + return cache + return tuple((crop_tensor(key, length), crop_tensor(value, length)) for key, value in cache_to_layers(cache)) + + +def cache_to_cpu(cache) -> tuple: + return tuple( + (key.detach().cpu().contiguous(), value.detach().cpu().contiguous()) for key, value in cache_to_layers(cache) + ) + + +def original_outputs_to_hidden_and_cache(outputs, fallback_cache=None): + if hasattr(outputs, "last_hidden_state"): + return outputs.last_hidden_state, outputs.past_key_values + if isinstance(outputs, tuple): + if len(outputs) > 1: + return outputs[0], outputs[1] + if fallback_cache is not None: + return outputs[0], fallback_cache + if torch.is_tensor(outputs) and fallback_cache is not None: + return outputs, fallback_cache + raise TypeError("Original DFlash model did not return a cache. Make sure use_cache=True is supported.") + + +def run_dflash_block( + draft, + target, + block_input_ids: torch.Tensor, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + block_size: int, +) -> torch.Tensor: + target_device = block_input_ids.device + noise_embedding = target.model.embed_tokens(block_input_ids.to(target_device)).to(next(draft.parameters()).device) + draft_hidden = draft( + target_hidden=hidden_states.to(next(draft.parameters()).device), + noise_embedding=noise_embedding, + position_ids=position_ids.to(next(draft.parameters()).device), + use_cache=False, + is_causal=False, + ) + return target.lm_head(draft_hidden[:, -block_size + 1 :, :].to(target_device)) + + +def run_dflash_cached_block( + draft, + target, + block_input_ids: torch.Tensor, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + past_key_values, + block_size: int, +) -> tuple[torch.Tensor, object]: + target_device = block_input_ids.device + draft_device = next(draft.parameters()).device + if past_key_values is None: + past_key_values = DynamicCache(config=draft.config) + noise_embedding = target.model.embed_tokens(block_input_ids.to(target_device)).to(draft_device) + outputs = draft( + target_hidden=hidden_states.to(draft_device), + noise_embedding=noise_embedding, + position_ids=position_ids.to(draft_device), + past_key_values=past_key_values, + use_cache=True, + is_causal=False, + ) + draft_hidden, past_key_values = original_outputs_to_hidden_and_cache(outputs, fallback_cache=past_key_values) + logits = target.lm_head(draft_hidden[:, -block_size + 1 :, :].to(target_device)) + return logits, past_key_values + + def main(): parser = argparse.ArgumentParser(description="Create a lightweight DFlash correctness fixture.") parser.add_argument("--draft-model", default="z-lab/Qwen3-Coder-30B-A3B-DFlash") @@ -38,7 +156,10 @@ def main(): parser.add_argument("--dtype", default="bfloat16", choices=["float16", "bfloat16", "float32"]) parser.add_argument("--device-map", default="auto") parser.add_argument("--temperature", type=float, default=0.0) + parser.add_argument("--num-steps", type=int, default=4) args = parser.parse_args() + if args.num_steps < 1: + raise ValueError("--num-steps must be at least 1") dtype = getattr(torch, args.dtype) tokenizer = AutoTokenizer.from_pretrained(args.target_model, trust_remote_code=True) @@ -62,49 +183,88 @@ def main(): mask_token_id = draft.config.dflash_config["mask_token_id"] target_layer_ids = draft.config.dflash_config["target_layer_ids"] - position_ids = torch.arange(input_ids.shape[1] + block_size, device=device).unsqueeze(0) + steps = [] + committed_input_ids = input_ids + committed_hidden_length = 0 + original_past_key_values = None with torch.inference_mode(): - target_output = target( - input_ids, - position_ids=position_ids[:, : input_ids.shape[1]], - use_cache=False, - logits_to_keep=1, - output_hidden_states=True, - ) - first_draft_token = sample(target_output.logits[:, -1, :], args.temperature) - - block_input_ids = torch.full((1, block_size), mask_token_id, dtype=torch.long, device=device) - block_input_ids[:, 0] = first_draft_token - - target_hidden = extract_context_feature(target_output.hidden_states, target_layer_ids).to(next(draft.parameters()).device) - noise_embedding = target.model.embed_tokens(block_input_ids.to(device)).to(next(draft.parameters()).device) - draft_position_ids = position_ids.to(next(draft.parameters()).device) - - draft_hidden = draft( - target_hidden=target_hidden, - noise_embedding=noise_embedding, - position_ids=draft_position_ids, - use_cache=False, - is_causal=False, - ) - expected_logits = target.lm_head(draft_hidden[:, -block_size + 1 :, :].to(device)) - sampled_tokens = sample(expected_logits, args.temperature) + for step_idx in range(args.num_steps): + committed_length = committed_input_ids.shape[1] + target_position_ids = torch.arange(committed_length, device=device).unsqueeze(0) + target_output = target( + committed_input_ids, + position_ids=target_position_ids, + use_cache=False, + logits_to_keep=1, + output_hidden_states=True, + ) + seed_token = sample(target_output.logits[:, -1, :], args.temperature) + + block_input_ids = torch.full((1, block_size), mask_token_id, dtype=torch.long, device=device) + block_input_ids[:, 0] = seed_token + + full_hidden_states = extract_context_feature(target_output.hidden_states, target_layer_ids) + hidden_states = full_hidden_states[:, committed_hidden_length:, :] + position_start = committed_hidden_length + position_ids = torch.arange(position_start, committed_length + block_size, device=device).unsqueeze(0) + full_position_ids = torch.arange(committed_length + block_size, device=device).unsqueeze(0) + expected_logits = run_dflash_block( + draft, + target, + block_input_ids, + full_hidden_states, + full_position_ids, + block_size, + ) + original_cached_logits, original_past_key_values = run_dflash_cached_block( + draft, + target, + block_input_ids, + hidden_states, + position_ids, + original_past_key_values, + block_size, + ) + + committed_hidden_length += hidden_states.shape[1] + original_past_key_values = crop_cache(original_past_key_values, committed_hidden_length) + steps.append( + { + "index": step_idx, + "input_ids": block_input_ids.cpu(), + "hidden_states": hidden_states.cpu(), + "position_ids": position_ids.cpu(), + "expected_logits": original_cached_logits.cpu(), + "expected_logits_full_prefix": expected_logits.cpu(), + "original_cached_logits": original_cached_logits.cpu(), + "original_committed_cache": cache_to_cpu(original_past_key_values), + "seed_token": seed_token.cpu(), + "sampled_tokens": sample(expected_logits, args.temperature).cpu(), + "expected_present_length": committed_hidden_length, + } + ) + committed_input_ids = torch.cat([committed_input_ids, seed_token[:, None]], dim=1) bundle = { - "input_ids": block_input_ids.cpu(), - "target_hidden": target_hidden.cpu(), - "position_ids": position_ids.cpu(), - "expected_logits": expected_logits.cpu(), - "sampled_tokens": sampled_tokens.cpu(), + "steps": steps, + # Keep first-step keys for quick ad-hoc inspection and older local scripts. + "input_ids": steps[0]["input_ids"], + "hidden_states": steps[0]["hidden_states"], + "position_ids": steps[0]["position_ids"], + "expected_logits": steps[0]["expected_logits"], + "sampled_tokens": steps[0]["sampled_tokens"], + "expected_present_length": steps[0]["expected_present_length"], "metadata": { "draft_model": args.draft_model, "target_model": args.target_model, "prompt": args.prompt, + "num_steps": args.num_steps, "block_size": block_size, "mask_token_id": mask_token_id, "target_layer_ids": target_layer_ids, "dtype": args.dtype, "temperature": args.temperature, + "cache_policy": "committed_prefix", }, } torch.save(bundle, args.output) From e919101ed422d15830233c4fd7e838b052034598 Mon Sep 17 00:00:00 2001 From: Ofir Zafrir Date: Tue, 26 May 2026 05:50:16 -0700 Subject: [PATCH 07/11] Fix DFlash export to support dynamic block size (`num_assistant_tokens`) --- optimum/exporters/openvino/model_patcher.py | 6 ++-- tests/openvino/test_exporters_cli.py | 31 +++++++++++++++++++++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 83e815c4d3..011f3c060b 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -8814,8 +8814,10 @@ def forward( **kwargs, ) if logits_to_keep is None: - logits_to_keep = self.block_size - 1 - logits = self.lm_head(outputs.last_hidden_state[:, -logits_to_keep:, :]) + hidden_states = outputs.last_hidden_state[:, 1:, :] + else: + hidden_states = outputs.last_hidden_state[:, -logits_to_keep:, :] + logits = self.lm_head(hidden_states) return CausalLMOutputWithPast(logits=logits, past_key_values=outputs.past_key_values) diff --git a/tests/openvino/test_exporters_cli.py b/tests/openvino/test_exporters_cli.py index 9ffaba1e15..4f5db09937 100644 --- a/tests/openvino/test_exporters_cli.py +++ b/tests/openvino/test_exporters_cli.py @@ -859,6 +859,8 @@ def _openvino_export( @unittest.skipUnless(os.environ.get("RUN_DFLASH_EXPORT_TEST"), "Set RUN_DFLASH_EXPORT_TEST=1 to run DFlash export") def test_dflash_export_smoke(self): + import openvino as ov + draft_model_id, target_model_id = DFLASH_MODELS["qwen3_coder_dflash"] with TemporaryDirectory() as tmpdir: main_export( @@ -872,6 +874,35 @@ def test_dflash_export_smoke(self): self.assertTrue((Path(tmpdir) / "openvino_model.xml").exists()) self.assertTrue((Path(tmpdir) / "openvino_model.bin").exists()) + model_path = Path(tmpdir) / "openvino_model.xml" + core = ov.Core() + dflash_model = core.read_model(model_path) + logits = next(output for output in dflash_model.outputs if "logits" in output.get_names()) + self.assertTrue(dflash_model.input("input_ids").get_partial_shape()[1].is_dynamic) + self.assertTrue(logits.get_partial_shape()[1].is_dynamic) + + for input_length in (2, 5): + with self.subTest(input_length=input_length): + dflash_model = core.read_model(model_path) + hidden_states_shape = dflash_model.input("hidden_states").get_partial_shape() + hidden_states_shape[0] = 1 + hidden_states_shape[1] = 4 + position_ids_shape = dflash_model.input("position_ids").get_partial_shape() + position_ids_shape[0] = 1 + position_ids_shape[1] = hidden_states_shape[1].get_length() + input_length + + dflash_model.reshape( + { + "input_ids": ov.PartialShape([1, input_length]), + "hidden_states": hidden_states_shape, + "position_ids": position_ids_shape, + } + ) + logits = next(output for output in dflash_model.outputs if "logits" in output.get_names()) + logits_sequence_length = logits.get_partial_shape()[1] + self.assertFalse(logits_sequence_length.is_dynamic) + self.assertEqual(logits_sequence_length.get_length(), input_length - 1) + def test_filtered_architectures(cls): if is_transformers_version("<", "4.49"): expected = {"qwen3_vl", "llama4", "qwen2_5_vl", "phi4mm"} From 644f3ab4033845ee2396694f2d09b28127792df4 Mon Sep 17 00:00:00 2001 From: Ofir Zafrir Date: Thu, 28 May 2026 01:48:22 -0700 Subject: [PATCH 08/11] Add support for qwen3.5 hidden_states annotations and relevant tests --- optimum/exporters/openvino/model_patcher.py | 14 +- .../openvino/test_hidden_state_annotations.py | 155 ++++++++++++++++++ 2 files changed, 167 insertions(+), 2 deletions(-) diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 7d364a53d1..48e72d022b 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -9812,6 +9812,9 @@ def patched_forward( text_config = self._text_config num_full_attn_layers = text_config.layer_types.count("full_attention") num_linear_attn_layers = text_config.layer_types.count("linear_attention") + output_hidden_states = getattr(self._model.config, "output_hidden_states", False) or getattr( + text_config, "output_hidden_states", False + ) use_cache = False wrapped_cache_params = None @@ -9843,19 +9846,23 @@ def patched_forward( position_ids=position_ids, past_key_values=wrapped_cache_params, use_cache=use_cache, + output_hidden_states=output_hidden_states, ) - hidden_states = outputs_lm[0] - logits = self._model.lm_head(hidden_states) + last_hidden_state = outputs_lm[0] + logits = self._model.lm_head(last_hidden_state) past_kv = outputs_lm.past_key_values + hidden_states = outputs_lm.hidden_states else: causal_lm_output = self.model_orig_forward( input_ids=input_ids, attention_mask=attention_mask, past_key_values=wrapped_cache_params, use_cache=use_cache, + output_hidden_states=output_hidden_states, ) logits = causal_lm_output.logits past_kv = causal_lm_output.past_key_values + hidden_states = causal_lm_output.hidden_states outputs = { "logits": logits, } @@ -9872,6 +9879,9 @@ def patched_forward( outputs["present_key_values"] = present_key_values + if hidden_states is not None: + outputs["hidden_states"] = hidden_states + return outputs self.patched_forward = patched_forward diff --git a/tests/openvino/test_hidden_state_annotations.py b/tests/openvino/test_hidden_state_annotations.py index 69764a523c..b5a245ee15 100644 --- a/tests/openvino/test_hidden_state_annotations.py +++ b/tests/openvino/test_hidden_state_annotations.py @@ -14,8 +14,10 @@ import json import unittest +from collections import Counter from pathlib import Path from tempfile import TemporaryDirectory +from unittest import mock import numpy as np import openvino as ov @@ -24,9 +26,11 @@ from utils_tests import MODEL_NAMES from optimum.exporters.openvino import export_from_model +from optimum.intel.utils.import_utils import is_transformers_version HIDDEN_STATES_RT_INFO_KEY = "hidden_states_decoder_layers" +QWEN3_5_MOE_TRANSFORMERS_AVAILABLE = is_transformers_version(">=", "5.2.0") and is_transformers_version("<", "5.3.0") def _find_output_by_tensor_name(model, tensor_name): @@ -45,6 +49,47 @@ def _add_model_output(model, output, output_name): model.add_outputs([output]) +def _port_names(ports): + return set().union(*(port.get_names() for port in ports)) + + +def _op_type_arity_signature(model): + return Counter( + (op.get_type_name(), len(op.inputs()), len(op.outputs())) + for op in model.get_ops() + if op.get_type_name() != "Result" + ) + + +def _has_recurrent_attention_cell(model): + return any(op.get_type_name() in {"RecurrentAttentionCell", "RecurrentAttentionCellOp"} for op in model.get_ops()) + + +def _export_qwen3_5_moe_text_model( + output, annotate_hidden_states=True, stateful=True, task="text-generation-with-past" +): + model = AutoModelForCausalLM.from_pretrained(MODEL_NAMES["qwen3_5_moe"]) + if annotate_hidden_states: + export_from_model( + model=model, + output=output, + task=task, + preprocessors=None, + stateful=stateful, + ) + return model.config + + with mock.patch("optimum.exporters.openvino.convert._can_annotate_hidden_states", return_value=False): + export_from_model( + model=model, + output=output, + task=task, + preprocessors=None, + stateful=stateful, + ) + return model.config + + class HiddenStateAnnotationExportTest(unittest.TestCase): def test_export_hidden_state_annotations_without_extra_outputs(self): for task in ("text-generation", "text-generation-with-past"): @@ -131,3 +176,113 @@ def test_annotated_hidden_state_output_matches_pytorch(self): rtol=1e-4, atol=1e-4, ) + + @unittest.skipIf( + not QWEN3_5_MOE_TRANSFORMERS_AVAILABLE, + "Qwen3.5-MoE export requires transformers 5.2.x", + ) + def test_qwen3_5_moe_text_generation_with_past_hidden_state_annotations(self): + with TemporaryDirectory() as tmpdirname: + model_config = _export_qwen3_5_moe_text_model(Path(tmpdirname)) + + ov_model = ov.Core().read_model(Path(tmpdirname) / "openvino_model.xml") + output_names = _port_names(ov_model.outputs) + self.assertFalse(any(name.startswith("ov.hidden_states.") for name in output_names)) + + rt_info = ov_model.get_rt_info() + self.assertIn(HIDDEN_STATES_RT_INFO_KEY, rt_info) + annotation = json.loads(rt_info[HIDDEN_STATES_RT_INFO_KEY].value) + self.assertEqual(annotation["version"], 1) + self.assertEqual(len(annotation["layers"]), model_config.num_hidden_layers) + + graph_tensor_names = set() + for op in ov_model.get_ops(): + for output in op.outputs(): + graph_tensor_names.update(output.get_names()) + for tensor_name in annotation["layers"].values(): + self.assertIn(tensor_name, graph_tensor_names) + + @unittest.skipIf( + not QWEN3_5_MOE_TRANSFORMERS_AVAILABLE, + "Qwen3.5-MoE export requires transformers 5.2.x", + ) + def test_qwen3_5_moe_hidden_state_annotation_preserves_graph_signature(self): + with TemporaryDirectory() as tmpdirname: + baseline_dir = Path(tmpdirname) / "baseline" + annotated_dir = Path(tmpdirname) / "annotated" + _export_qwen3_5_moe_text_model(baseline_dir, annotate_hidden_states=False) + _export_qwen3_5_moe_text_model(annotated_dir) + + core = ov.Core() + baseline_model = core.read_model(baseline_dir / "openvino_model.xml") + annotated_model = core.read_model(annotated_dir / "openvino_model.xml") + + self.assertEqual(_port_names(baseline_model.inputs), _port_names(annotated_model.inputs)) + self.assertEqual(_port_names(baseline_model.outputs), _port_names(annotated_model.outputs)) + self.assertEqual(_op_type_arity_signature(baseline_model), _op_type_arity_signature(annotated_model)) + self.assertEqual( + _has_recurrent_attention_cell(baseline_model), + _has_recurrent_attention_cell(annotated_model), + ) + + @unittest.skipIf( + not QWEN3_5_MOE_TRANSFORMERS_AVAILABLE, + "Qwen3.5-MoE export requires transformers 5.2.x", + ) + def test_qwen3_5_moe_annotated_hidden_state_outputs_match_pytorch(self): + model = AutoModelForCausalLM.from_pretrained(MODEL_NAMES["qwen3_5_moe"]) + model.eval() + + with TemporaryDirectory() as tmpdirname: + export_from_model( + model=model, + output=Path(tmpdirname), + task="text-generation", + preprocessors=None, + stateful=False, + ) + + core = ov.Core() + ov_model = core.read_model(Path(tmpdirname) / "openvino_model.xml") + annotation = json.loads(ov_model.get_rt_info()[HIDDEN_STATES_RT_INFO_KEY].value) + layer_indices = [0, model.config.num_hidden_layers - 1] + output_names = [] + for layer_idx in layer_indices: + output_name = f"decoder_layer_{layer_idx}_hidden_state" + hidden_state_output = _find_output_by_tensor_name(ov_model, annotation["layers"][str(layer_idx)]) + _add_model_output(ov_model, hidden_state_output, output_name) + output_names.append(output_name) + + input_ids = torch.tensor([[0, 1, 2, 3]], dtype=torch.long) + attention_mask = torch.ones_like(input_ids) + with torch.no_grad(): + torch_outputs = model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + return_dict=True, + ) + + compiled_model = core.compile_model(ov_model, "CPU") + ov_inputs = {} + for input_port in compiled_model.inputs: + input_name = input_port.get_any_name() + if input_name == "input_ids": + ov_inputs[input_name] = input_ids.numpy() + elif input_name == "attention_mask": + ov_inputs[input_name] = attention_mask.numpy() + elif input_name == "position_ids": + ov_inputs[input_name] = np.arange(input_ids.shape[1], dtype=np.int64).reshape(1, -1) + else: + self.fail(f"Unexpected OpenVINO model input: {input_name}") + + infer_result = compiled_model(ov_inputs) + for layer_idx, output_name in zip(layer_indices, output_names): + ov_output_port = next(output for output in compiled_model.outputs if output_name in output.get_names()) + np.testing.assert_allclose( + infer_result[ov_output_port], + torch_outputs.hidden_states[layer_idx + 1].detach().numpy(), + rtol=5e-3, + atol=5e-3, + ) From e8b4dfb660df1a7995d01c28919ccf26c861bddb Mon Sep 17 00:00:00 2001 From: Ofir Zafrir Date: Thu, 28 May 2026 10:50:18 -0700 Subject: [PATCH 09/11] Fix dflash export where target model is a text model nested in a VLM --- optimum/exporters/openvino/model_patcher.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 48e72d022b..a0805d9aec 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -8717,11 +8717,12 @@ def _dflash_resolve_tensor_file(config: PretrainedConfig, tensor_name: str) -> s token=token, local_files_only=local_files_only, ) + except Exception: + filename = "model.safetensors" + else: with open(index_path) as f: weight_map = json.load(f)["weight_map"] filename = weight_map[tensor_name] - except Exception: - filename = "model.safetensors" return hf_hub_download( target_model, @@ -8748,6 +8749,13 @@ def _dflash_load_target_tensor(config: PretrainedConfig, tensor_names: Tuple[str raise ValueError(f"Could not load any of {tensor_names} from DFlash target model.") from last_error +_DFLASH_EMBED_TENSOR_NAMES = ( + "model.language_model.embed_tokens.weight", + "model.embed_tokens.weight", + "embed_tokens.weight", +) + + class Qwen3DFlashAttention(nn.Module): """Qwen3 attention variant used by DFlash, where draft tokens attend over target context and noise tokens.""" @@ -8949,7 +8957,7 @@ def __init__(self, config): self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) def _load_target_weights(self, config): - embed_weight = _dflash_load_target_tensor(config, ("model.embed_tokens.weight", "embed_tokens.weight")) + embed_weight = _dflash_load_target_tensor(config, _DFLASH_EMBED_TENSOR_NAMES) try: lm_head_weight = _dflash_load_target_tensor(config, ("lm_head.weight",)) except ValueError: From ff7c41f40bbdfb102ca11662c8b8ba6a2946d829 Mon Sep 17 00:00:00 2001 From: Ofir Zafrir Date: Sat, 30 May 2026 11:03:43 -0700 Subject: [PATCH 10/11] Test cleanup --- tests/openvino/test_decoder.py | 111 ------- tests/openvino/test_export.py | 109 ++++++- tests/openvino/test_exporters_cli.py | 48 --- .../openvino/test_hidden_state_annotations.py | 288 ------------------ tests/openvino/utils_tests.py | 5 - .../scripts/compare_dflash_cache_semantics.py | 161 ---------- 6 files changed, 108 insertions(+), 614 deletions(-) delete mode 100644 tests/openvino/test_hidden_state_annotations.py delete mode 100644 tests/scripts/compare_dflash_cache_semantics.py diff --git a/tests/openvino/test_decoder.py b/tests/openvino/test_decoder.py index 2f809cf40e..ba45d393f5 100644 --- a/tests/openvino/test_decoder.py +++ b/tests/openvino/test_decoder.py @@ -11,7 +11,6 @@ from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES from transformers.testing_utils import slow from utils_tests import ( - DFLASH_MODELS, EAGLE3_MODELS, F32_CONFIG, MODEL_NAMES, @@ -940,116 +939,6 @@ def test_load_and_infer_with_eagle3_model(self, model_arch, model_pair): del ov_model gc.collect() - @parameterized.expand(DFLASH_MODELS.items()) - @pytest.mark.skipif(is_transformers_version("<", "4.57"), reason="DFlash requires transformers >= 4.57") - @pytest.mark.skipif("DFLASH_DEBUG_BUNDLE" not in os.environ, reason="Set DFLASH_DEBUG_BUNDLE to run DFlash tests") - @pytest.mark.run_slow - @slow - def test_load_and_infer_stateful_with_dflash_debug_bundle(self, model_arch, model_pair): - draft_model_id, target_model_id = model_pair - bundle = torch.load(os.environ["DFLASH_DEBUG_BUNDLE"], map_location="cpu") - - ov_model_path = os.environ.get("DFLASH_OV_MODEL_DIR") - if ov_model_path: - ov_model = OVModelForCausalLM.from_pretrained( - ov_model_path, - use_cache=True, - stateful=True, - device=OPENVINO_DEVICE, - ) - else: - ov_model = OVModelForCausalLM.from_pretrained( - draft_model_id, - export=True, - trust_remote_code=True, - dflash_target_model=target_model_id, - use_cache=True, - stateful=True, - device=OPENVINO_DEVICE, - ) - self.assertTrue(ov_model.stateful) - if hasattr(ov_model, "model"): - self.assertEqual(ov_model.model.get_rt_info(["dflash", "cache_policy"]).value, "committed_prefix") - input_names = {name for model_input in ov_model.model.inputs for name in model_input.get_names()} - output_names = {name for model_output in ov_model.model.outputs for name in model_output.get_names()} - self.assertFalse(any("past_key_values" in name for name in input_names)) - self.assertFalse(any("present" in name for name in output_names)) - - steps = bundle.get("steps") - if steps is None: - steps = [ - { - "input_ids": bundle["input_ids"], - "hidden_states": bundle["hidden_states"], - "position_ids": bundle["position_ids"], - "expected_logits": bundle["expected_logits"], - } - ] - if "next_input_ids" in bundle: - steps.append( - { - "input_ids": bundle["next_input_ids"], - "hidden_states": bundle["next_hidden_states"], - "position_ids": bundle["next_position_ids"], - "expected_logits": bundle["next_expected_logits"], - } - ) - min_debug_steps = int(os.environ.get("DFLASH_MIN_DEBUG_STEPS", "3")) - self.assertGreaterEqual( - len(steps), - min_debug_steps, - "Regenerate DFLASH_DEBUG_BUNDLE with the updated multi-step extractor.", - ) - - past_key_values = None - for idx, step in enumerate(steps): - with self.subTest(step=idx): - ov_outputs = ov_model( - input_ids=step["input_ids"], - hidden_states=step["hidden_states"], - position_ids=step["position_ids"], - past_key_values=past_key_values, - ) - expected_logits = step.get( - "original_cached_logits", step.get("expected_logits_cached", step["expected_logits"]) - ) - if "expected_logits_full_prefix" in step: - torch.testing.assert_close( - step["expected_logits_full_prefix"].float(), - expected_logits.float(), - rtol=float(os.environ.get("DFLASH_PYTORCH_CACHE_RTOL", os.environ.get("DFLASH_RTOL", "5e-2"))), - atol=float(os.environ.get("DFLASH_PYTORCH_CACHE_ATOL", os.environ.get("DFLASH_ATOL", "5e-2"))), - ) - self.assertEqual(tuple(ov_outputs.logits.shape), tuple(expected_logits.shape)) - torch.testing.assert_close( - ov_outputs.logits.float(), - expected_logits.float(), - rtol=float(os.environ.get("DFLASH_RTOL", "5e-2")), - atol=float(os.environ.get("DFLASH_ATOL", "5e-2")), - ) - self.assertTrue("past_key_values" in ov_outputs) - self.assertIsInstance(ov_outputs.past_key_values, tuple) - self.assertTrue(len(ov_outputs.past_key_values) == 1 and len(ov_outputs.past_key_values[0]) == 0) - past_key_values = ov_outputs.past_key_values - - first_step = steps[0] - reset_outputs = ov_model( - input_ids=first_step["input_ids"], - hidden_states=first_step["hidden_states"], - position_ids=first_step["position_ids"], - ) - torch.testing.assert_close( - reset_outputs.logits.float(), - first_step.get( - "original_cached_logits", first_step.get("expected_logits_cached", first_step["expected_logits"]) - ).float(), - rtol=float(os.environ.get("DFLASH_RTOL", "5e-2")), - atol=float(os.environ.get("DFLASH_ATOL", "5e-2")), - ) - - del ov_model - gc.collect() - HYBRID_ARCHITECTURES = [] if is_transformers_version(">=", "4.53"): HYBRID_ARCHITECTURES.append("granitemoehybrid") diff --git a/tests/openvino/test_export.py b/tests/openvino/test_export.py index fbca6ace56..120d4d5d5b 100644 --- a/tests/openvino/test_export.py +++ b/tests/openvino/test_export.py @@ -13,13 +13,16 @@ # limitations under the License. +import json import unittest from pathlib import Path +import numpy as np +import openvino as ov import torch from parameterized import parameterized from sentence_transformers import SentenceTransformer, models -from transformers import AutoConfig, AutoTokenizer, GenerationConfig +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig from utils_tests import ( MODEL_NAMES, OPENVINO_DEVICE, @@ -65,6 +68,24 @@ logger = logging.get_logger() +HIDDEN_STATES_RT_INFO_KEY = "hidden_states_decoder_layers" + + +def _find_output_by_tensor_name(model, tensor_name): + for op in model.get_ops(): + for output in op.outputs(): + if tensor_name in output.get_names(): + return output + raise AssertionError(f"Tensor {tensor_name} was not found in the OpenVINO graph") + + +def _add_model_output(model, output, output_name): + output.get_tensor().add_names({output_name}) + if hasattr(model, "add_output"): + model.add_output(output) + else: + model.add_outputs([output]) + class ExportModelTest(unittest.TestCase): SUPPORTED_ARCHITECTURES = { @@ -263,6 +284,92 @@ def test_export(self, model_type: str): model_kwargs = {"vocoder": "fxmarty/speecht5-hifigan-tiny"} self._openvino_export(model_type, model_kwargs=model_kwargs) + def test_export_hidden_state_annotations_without_extra_outputs(self): + for task in ("text-generation", "text-generation-with-past"): + model = AutoModelForCausalLM.from_pretrained(MODEL_NAMES["gpt2"]) + with self.subTest(task=task), TemporaryDirectory() as tmpdirname: + export_from_model( + model=model, + output=Path(tmpdirname), + task=task, + preprocessors=None, + stateful=False, + ) + + ov_model = ov.Core().read_model(Path(tmpdirname) / "openvino_model.xml") + output_names = set().union(*(output.get_names() for output in ov_model.outputs)) + self.assertNotIn("last_hidden_state", output_names) + self.assertFalse(any(name.startswith("ov.hidden_states.") for name in output_names)) + + rt_info = ov_model.get_rt_info() + self.assertIn(HIDDEN_STATES_RT_INFO_KEY, rt_info) + annotation = json.loads(rt_info[HIDDEN_STATES_RT_INFO_KEY].value) + self.assertEqual(annotation["version"], 1) + self.assertEqual(len(annotation["layers"]), model.config.num_hidden_layers) + + graph_tensor_names = set() + for op in ov_model.get_ops(): + for output in op.outputs(): + graph_tensor_names.update(output.get_names()) + for tensor_name in annotation["layers"].values(): + self.assertIn(tensor_name, graph_tensor_names) + + def test_annotated_hidden_state_output_matches_pytorch(self): + model = AutoModelForCausalLM.from_pretrained(MODEL_NAMES["gpt2"]) + model.eval() + + with TemporaryDirectory() as tmpdirname: + export_from_model( + model=model, + output=Path(tmpdirname), + task="text-generation", + preprocessors=None, + stateful=False, + ) + + core = ov.Core() + ov_model = core.read_model(Path(tmpdirname) / "openvino_model.xml") + annotation = json.loads(ov_model.get_rt_info()[HIDDEN_STATES_RT_INFO_KEY].value) + layer_idx = 0 + output_name = "decoder_layer_0_hidden_state" + hidden_state_output = _find_output_by_tensor_name(ov_model, annotation["layers"][str(layer_idx)]) + _add_model_output(ov_model, hidden_state_output, output_name) + + input_ids = torch.tensor([[0, 1, 2, 3]], dtype=torch.long) + attention_mask = torch.ones_like(input_ids) + with torch.no_grad(): + torch_outputs = model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + return_dict=True, + ) + + compiled_model = core.compile_model(ov_model, "CPU") + ov_inputs = {} + for input_port in compiled_model.inputs: + input_name = input_port.get_any_name() + if input_name == "input_ids": + ov_inputs[input_name] = input_ids.numpy() + elif input_name == "attention_mask": + ov_inputs[input_name] = attention_mask.numpy() + elif input_name == "position_ids": + ov_inputs[input_name] = np.arange(input_ids.shape[1], dtype=np.int64).reshape(1, -1) + elif input_name == "token_type_ids": + ov_inputs[input_name] = np.zeros(input_ids.shape, dtype=np.int64) + else: + self.fail(f"Unexpected OpenVINO model input: {input_name}") + + infer_result = compiled_model(ov_inputs) + ov_output_port = next(output for output in compiled_model.outputs if output_name in output.get_names()) + np.testing.assert_allclose( + infer_result[ov_output_port], + torch_outputs.hidden_states[layer_idx + 1].detach().numpy(), + rtol=1e-4, + atol=1e-4, + ) + @parameterized.expand(GENERATIVE_MODELS) def test_export_with_custom_gen_config(self, model_type): auto_model = self.SUPPORTED_ARCHITECTURES[model_type] diff --git a/tests/openvino/test_exporters_cli.py b/tests/openvino/test_exporters_cli.py index 662a95c5d2..a45c3de6bc 100644 --- a/tests/openvino/test_exporters_cli.py +++ b/tests/openvino/test_exporters_cli.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import json -import os import subprocess import unittest from pathlib import Path @@ -29,7 +28,6 @@ ) from utils_tests import ( _ARCHITECTURES_TO_EXPECTED_INT8, - DFLASH_MODELS, MODEL_NAMES, OPENVINO_DEVICE, REMOTE_CODE_MODELS, @@ -878,52 +876,6 @@ def _openvino_export( model_name_or_path=model_name, output=tmpdir, task=task, model_kwargs=model_kwargs, **loading_kwargs ) - @unittest.skipUnless(os.environ.get("RUN_DFLASH_EXPORT_TEST"), "Set RUN_DFLASH_EXPORT_TEST=1 to run DFlash export") - def test_dflash_export_smoke(self): - import openvino as ov - - draft_model_id, target_model_id = DFLASH_MODELS["qwen3_coder_dflash"] - with TemporaryDirectory() as tmpdir: - main_export( - model_name_or_path=draft_model_id, - output=tmpdir, - task="text-generation", - trust_remote_code=True, - dflash_target_model=target_model_id, - convert_tokenizer=False, - ) - self.assertTrue((Path(tmpdir) / "openvino_model.xml").exists()) - self.assertTrue((Path(tmpdir) / "openvino_model.bin").exists()) - - model_path = Path(tmpdir) / "openvino_model.xml" - core = ov.Core() - dflash_model = core.read_model(model_path) - logits = next(output for output in dflash_model.outputs if "logits" in output.get_names()) - self.assertTrue(dflash_model.input("input_ids").get_partial_shape()[1].is_dynamic) - self.assertTrue(logits.get_partial_shape()[1].is_dynamic) - - for input_length in (2, 5): - with self.subTest(input_length=input_length): - dflash_model = core.read_model(model_path) - hidden_states_shape = dflash_model.input("hidden_states").get_partial_shape() - hidden_states_shape[0] = 1 - hidden_states_shape[1] = 4 - position_ids_shape = dflash_model.input("position_ids").get_partial_shape() - position_ids_shape[0] = 1 - position_ids_shape[1] = hidden_states_shape[1].get_length() + input_length - - dflash_model.reshape( - { - "input_ids": ov.PartialShape([1, input_length]), - "hidden_states": hidden_states_shape, - "position_ids": position_ids_shape, - } - ) - logits = next(output for output in dflash_model.outputs if "logits" in output.get_names()) - logits_sequence_length = logits.get_partial_shape()[1] - self.assertFalse(logits_sequence_length.is_dynamic) - self.assertEqual(logits_sequence_length.get_length(), input_length - 1) - def test_filtered_architectures(cls): if is_transformers_version("<", "4.49"): expected = {"qwen3_vl", "llama4", "qwen2_5_vl", "phi4mm"} diff --git a/tests/openvino/test_hidden_state_annotations.py b/tests/openvino/test_hidden_state_annotations.py deleted file mode 100644 index b5a245ee15..0000000000 --- a/tests/openvino/test_hidden_state_annotations.py +++ /dev/null @@ -1,288 +0,0 @@ -# Copyright 2026 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import unittest -from collections import Counter -from pathlib import Path -from tempfile import TemporaryDirectory -from unittest import mock - -import numpy as np -import openvino as ov -import torch -from transformers import AutoModelForCausalLM -from utils_tests import MODEL_NAMES - -from optimum.exporters.openvino import export_from_model -from optimum.intel.utils.import_utils import is_transformers_version - - -HIDDEN_STATES_RT_INFO_KEY = "hidden_states_decoder_layers" -QWEN3_5_MOE_TRANSFORMERS_AVAILABLE = is_transformers_version(">=", "5.2.0") and is_transformers_version("<", "5.3.0") - - -def _find_output_by_tensor_name(model, tensor_name): - for op in model.get_ops(): - for output in op.outputs(): - if tensor_name in output.get_names(): - return output - raise AssertionError(f"Tensor {tensor_name} was not found in the OpenVINO graph") - - -def _add_model_output(model, output, output_name): - output.get_tensor().add_names({output_name}) - if hasattr(model, "add_output"): - model.add_output(output) - else: - model.add_outputs([output]) - - -def _port_names(ports): - return set().union(*(port.get_names() for port in ports)) - - -def _op_type_arity_signature(model): - return Counter( - (op.get_type_name(), len(op.inputs()), len(op.outputs())) - for op in model.get_ops() - if op.get_type_name() != "Result" - ) - - -def _has_recurrent_attention_cell(model): - return any(op.get_type_name() in {"RecurrentAttentionCell", "RecurrentAttentionCellOp"} for op in model.get_ops()) - - -def _export_qwen3_5_moe_text_model( - output, annotate_hidden_states=True, stateful=True, task="text-generation-with-past" -): - model = AutoModelForCausalLM.from_pretrained(MODEL_NAMES["qwen3_5_moe"]) - if annotate_hidden_states: - export_from_model( - model=model, - output=output, - task=task, - preprocessors=None, - stateful=stateful, - ) - return model.config - - with mock.patch("optimum.exporters.openvino.convert._can_annotate_hidden_states", return_value=False): - export_from_model( - model=model, - output=output, - task=task, - preprocessors=None, - stateful=stateful, - ) - return model.config - - -class HiddenStateAnnotationExportTest(unittest.TestCase): - def test_export_hidden_state_annotations_without_extra_outputs(self): - for task in ("text-generation", "text-generation-with-past"): - model = AutoModelForCausalLM.from_pretrained(MODEL_NAMES["gpt2"]) - with self.subTest(task=task), TemporaryDirectory() as tmpdirname: - export_from_model( - model=model, - output=Path(tmpdirname), - task=task, - preprocessors=None, - stateful=False, - ) - - ov_model = ov.Core().read_model(Path(tmpdirname) / "openvino_model.xml") - output_names = set().union(*(output.get_names() for output in ov_model.outputs)) - self.assertNotIn("last_hidden_state", output_names) - self.assertFalse(any(name.startswith("ov.hidden_states.") for name in output_names)) - - rt_info = ov_model.get_rt_info() - self.assertIn(HIDDEN_STATES_RT_INFO_KEY, rt_info) - annotation = json.loads(rt_info[HIDDEN_STATES_RT_INFO_KEY].value) - self.assertEqual(annotation["version"], 1) - self.assertEqual(len(annotation["layers"]), model.config.num_hidden_layers) - - graph_tensor_names = set() - for op in ov_model.get_ops(): - for output in op.outputs(): - graph_tensor_names.update(output.get_names()) - for tensor_name in annotation["layers"].values(): - self.assertIn(tensor_name, graph_tensor_names) - - def test_annotated_hidden_state_output_matches_pytorch(self): - model = AutoModelForCausalLM.from_pretrained(MODEL_NAMES["gpt2"]) - model.eval() - - with TemporaryDirectory() as tmpdirname: - export_from_model( - model=model, - output=Path(tmpdirname), - task="text-generation", - preprocessors=None, - stateful=False, - ) - - core = ov.Core() - ov_model = core.read_model(Path(tmpdirname) / "openvino_model.xml") - annotation = json.loads(ov_model.get_rt_info()[HIDDEN_STATES_RT_INFO_KEY].value) - layer_idx = 0 - output_name = "decoder_layer_0_hidden_state" - hidden_state_output = _find_output_by_tensor_name(ov_model, annotation["layers"][str(layer_idx)]) - _add_model_output(ov_model, hidden_state_output, output_name) - - input_ids = torch.tensor([[0, 1, 2, 3]], dtype=torch.long) - attention_mask = torch.ones_like(input_ids) - with torch.no_grad(): - torch_outputs = model( - input_ids=input_ids, - attention_mask=attention_mask, - output_hidden_states=True, - use_cache=False, - return_dict=True, - ) - - compiled_model = core.compile_model(ov_model, "CPU") - ov_inputs = {} - for input_port in compiled_model.inputs: - input_name = input_port.get_any_name() - if input_name == "input_ids": - ov_inputs[input_name] = input_ids.numpy() - elif input_name == "attention_mask": - ov_inputs[input_name] = attention_mask.numpy() - elif input_name == "position_ids": - ov_inputs[input_name] = np.arange(input_ids.shape[1], dtype=np.int64).reshape(1, -1) - elif input_name == "token_type_ids": - ov_inputs[input_name] = np.zeros(input_ids.shape, dtype=np.int64) - else: - self.fail(f"Unexpected OpenVINO model input: {input_name}") - - infer_result = compiled_model(ov_inputs) - ov_output_port = next(output for output in compiled_model.outputs if output_name in output.get_names()) - np.testing.assert_allclose( - infer_result[ov_output_port], - torch_outputs.hidden_states[layer_idx + 1].detach().numpy(), - rtol=1e-4, - atol=1e-4, - ) - - @unittest.skipIf( - not QWEN3_5_MOE_TRANSFORMERS_AVAILABLE, - "Qwen3.5-MoE export requires transformers 5.2.x", - ) - def test_qwen3_5_moe_text_generation_with_past_hidden_state_annotations(self): - with TemporaryDirectory() as tmpdirname: - model_config = _export_qwen3_5_moe_text_model(Path(tmpdirname)) - - ov_model = ov.Core().read_model(Path(tmpdirname) / "openvino_model.xml") - output_names = _port_names(ov_model.outputs) - self.assertFalse(any(name.startswith("ov.hidden_states.") for name in output_names)) - - rt_info = ov_model.get_rt_info() - self.assertIn(HIDDEN_STATES_RT_INFO_KEY, rt_info) - annotation = json.loads(rt_info[HIDDEN_STATES_RT_INFO_KEY].value) - self.assertEqual(annotation["version"], 1) - self.assertEqual(len(annotation["layers"]), model_config.num_hidden_layers) - - graph_tensor_names = set() - for op in ov_model.get_ops(): - for output in op.outputs(): - graph_tensor_names.update(output.get_names()) - for tensor_name in annotation["layers"].values(): - self.assertIn(tensor_name, graph_tensor_names) - - @unittest.skipIf( - not QWEN3_5_MOE_TRANSFORMERS_AVAILABLE, - "Qwen3.5-MoE export requires transformers 5.2.x", - ) - def test_qwen3_5_moe_hidden_state_annotation_preserves_graph_signature(self): - with TemporaryDirectory() as tmpdirname: - baseline_dir = Path(tmpdirname) / "baseline" - annotated_dir = Path(tmpdirname) / "annotated" - _export_qwen3_5_moe_text_model(baseline_dir, annotate_hidden_states=False) - _export_qwen3_5_moe_text_model(annotated_dir) - - core = ov.Core() - baseline_model = core.read_model(baseline_dir / "openvino_model.xml") - annotated_model = core.read_model(annotated_dir / "openvino_model.xml") - - self.assertEqual(_port_names(baseline_model.inputs), _port_names(annotated_model.inputs)) - self.assertEqual(_port_names(baseline_model.outputs), _port_names(annotated_model.outputs)) - self.assertEqual(_op_type_arity_signature(baseline_model), _op_type_arity_signature(annotated_model)) - self.assertEqual( - _has_recurrent_attention_cell(baseline_model), - _has_recurrent_attention_cell(annotated_model), - ) - - @unittest.skipIf( - not QWEN3_5_MOE_TRANSFORMERS_AVAILABLE, - "Qwen3.5-MoE export requires transformers 5.2.x", - ) - def test_qwen3_5_moe_annotated_hidden_state_outputs_match_pytorch(self): - model = AutoModelForCausalLM.from_pretrained(MODEL_NAMES["qwen3_5_moe"]) - model.eval() - - with TemporaryDirectory() as tmpdirname: - export_from_model( - model=model, - output=Path(tmpdirname), - task="text-generation", - preprocessors=None, - stateful=False, - ) - - core = ov.Core() - ov_model = core.read_model(Path(tmpdirname) / "openvino_model.xml") - annotation = json.loads(ov_model.get_rt_info()[HIDDEN_STATES_RT_INFO_KEY].value) - layer_indices = [0, model.config.num_hidden_layers - 1] - output_names = [] - for layer_idx in layer_indices: - output_name = f"decoder_layer_{layer_idx}_hidden_state" - hidden_state_output = _find_output_by_tensor_name(ov_model, annotation["layers"][str(layer_idx)]) - _add_model_output(ov_model, hidden_state_output, output_name) - output_names.append(output_name) - - input_ids = torch.tensor([[0, 1, 2, 3]], dtype=torch.long) - attention_mask = torch.ones_like(input_ids) - with torch.no_grad(): - torch_outputs = model( - input_ids=input_ids, - attention_mask=attention_mask, - output_hidden_states=True, - use_cache=False, - return_dict=True, - ) - - compiled_model = core.compile_model(ov_model, "CPU") - ov_inputs = {} - for input_port in compiled_model.inputs: - input_name = input_port.get_any_name() - if input_name == "input_ids": - ov_inputs[input_name] = input_ids.numpy() - elif input_name == "attention_mask": - ov_inputs[input_name] = attention_mask.numpy() - elif input_name == "position_ids": - ov_inputs[input_name] = np.arange(input_ids.shape[1], dtype=np.int64).reshape(1, -1) - else: - self.fail(f"Unexpected OpenVINO model input: {input_name}") - - infer_result = compiled_model(ov_inputs) - for layer_idx, output_name in zip(layer_indices, output_names): - ov_output_port = next(output for output in compiled_model.outputs if output_name in output.get_names()) - np.testing.assert_allclose( - infer_result[ov_output_port], - torch_outputs.hidden_states[layer_idx + 1].detach().numpy(), - rtol=5e-3, - atol=5e-3, - ) diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index d8d09cd600..8a842c3159 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -352,15 +352,11 @@ def _create_tiny_kokoro_model(): "ltx-video": "optimum-intel-internal-testing/tiny-random-ltx-video", "zamba2": "optimum-intel-internal-testing/tiny-random-zamba2", "qwen3_eagle3": "AngelSlim/Qwen3-1.7B_eagle3", - "qwen3_coder_dflash": "z-lab/Qwen3-Coder-30B-A3B-DFlash", "qwen3_vl_eagle3": "optimum-intel-internal-testing/tiny-random-qwen3-vl-eagle3", "videochat_flash_qwen": "optimum-intel-internal-testing/tiny-videochat-flash-qwen", } EAGLE3_MODELS = {"qwen3_eagle3": ("AngelSlim/Qwen3-1.7B_eagle3", "Qwen/Qwen3-1.7B")} -DFLASH_MODELS = { - "qwen3_coder_dflash": ("z-lab/Qwen3-Coder-30B-A3B-DFlash", "Qwen/Qwen3-Coder-30B-A3B-Instruct") -} # VLM-based Eagle3 draft models (AngelSlim Eagle3LlamaForCausalLM architecture). # These use Qwen3-VL MRoPE and target VLM models for speculative decoding. @@ -578,7 +574,6 @@ def _create_tiny_kokoro_model(): "minicpm3", "deepseek", "qwen3_eagle3", - "qwen3_coder_dflash", "qwen3_vl_eagle3", "qwen3_asr", "videochat_flash_qwen", diff --git a/tests/scripts/compare_dflash_cache_semantics.py b/tests/scripts/compare_dflash_cache_semantics.py deleted file mode 100644 index a561f365b7..0000000000 --- a/tests/scripts/compare_dflash_cache_semantics.py +++ /dev/null @@ -1,161 +0,0 @@ -#!/usr/bin/env python -# Copyright 2026 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Compare original DFlash bundle references with the Optimum patched export model. - -The bundle is produced by extract_dflash_debug_bundle.py on a machine that can -load the full target model. It contains original DFlash logits and original -DFlash K/V caches cropped to the committed target-backed prefix. This helper -loads only the patched Optimum DFlash model, replays the same steps, and checks -that both logits and committed-prefix caches match. -""" - -import argparse - -import torch -from transformers import AutoConfig - - -def cache_to_layers(cache) -> tuple: - if hasattr(cache, "to_legacy_cache"): - return tuple(cache.to_legacy_cache()) - if hasattr(cache, "layers"): - layers = [] - for layer in cache.layers: - key = getattr(layer, "keys", None) - value = getattr(layer, "values", None) - if key is None: - key = getattr(layer, "key_cache", None) - if value is None: - value = getattr(layer, "value_cache", None) - layers.append((key, value)) - return tuple(layers) - return tuple(cache) - - -def cache_seq_length(cache) -> int: - layers = cache_to_layers(cache) - if not layers: - return 0 - return layers[0][0].shape[-2] - - -def cache_to_cpu(cache) -> tuple: - return tuple( - (key.detach().float().cpu().contiguous(), value.detach().float().cpu().contiguous()) - for key, value in cache_to_layers(cache) - ) - - -def assert_cache_close(actual_cache, expected_cache, *, step_idx: int, rtol: float, atol: float): - actual_layers = cache_to_cpu(actual_cache) - if len(actual_layers) != len(expected_cache): - raise AssertionError(f"Step {step_idx}: layer count mismatch {len(actual_layers)} != {len(expected_cache)}") - for layer_idx, ((actual_key, actual_value), (expected_key, expected_value)) in enumerate( - zip(actual_layers, expected_cache) - ): - torch.testing.assert_close( - actual_key, - expected_key.float(), - rtol=rtol, - atol=atol, - msg=f"Step {step_idx}, layer {layer_idx}: key cache mismatch", - ) - torch.testing.assert_close( - actual_value, - expected_value.float(), - rtol=rtol, - atol=atol, - msg=f"Step {step_idx}, layer {layer_idx}: value cache mismatch", - ) - - -def load_patched_dflash(draft_model: str, target_model: str, dtype: torch.dtype, device_map: str): - from optimum.exporters.openvino.__main__ import update_config_for_dflash - from optimum.exporters.openvino.model_patcher import Qwen3DFlashForCausalLM - - config = AutoConfig.from_pretrained(draft_model, trust_remote_code=True) - config = update_config_for_dflash(config, dflash_target_model=target_model) - return Qwen3DFlashForCausalLM.from_pretrained( - draft_model, - config=config, - dtype=dtype, - device_map=device_map, - trust_remote_code=True, - attn_implementation="eager", - ).eval() - - -def main(): - parser = argparse.ArgumentParser(description="Compare patched DFlash against original committed-cache bundle data.") - parser.add_argument("--bundle", required=True, help="Path to a bundle produced by extract_dflash_debug_bundle.py") - parser.add_argument("--draft-model", default=None, help="Override draft model ID/path from the bundle metadata.") - parser.add_argument("--target-model", default=None, help="Override target model ID/path from the bundle metadata.") - parser.add_argument("--dtype", default=None, choices=["float16", "bfloat16", "float32"]) - parser.add_argument("--device-map", default="auto") - parser.add_argument("--rtol", type=float, default=5e-2) - parser.add_argument("--atol", type=float, default=5e-2) - args = parser.parse_args() - - bundle = torch.load(args.bundle, map_location="cpu") - metadata = bundle.get("metadata", {}) - draft_model = args.draft_model or metadata.get("draft_model", "z-lab/Qwen3-Coder-30B-A3B-DFlash") - target_model = args.target_model or metadata.get("target_model", "Qwen/Qwen3-Coder-30B-A3B-Instruct") - dtype = getattr(torch, args.dtype or metadata.get("dtype", "bfloat16")) - steps = bundle["steps"] - - patched_draft = load_patched_dflash(draft_model, target_model, dtype, args.device_map) - patched_device = next(patched_draft.parameters()).device - patched_past_key_values = None - with torch.inference_mode(): - for step_idx, step in enumerate(steps): - if "original_committed_cache" not in step: - raise ValueError("Bundle is missing original_committed_cache. Regenerate it with the updated extractor.") - original_logits = step.get("original_cached_logits", step["expected_logits"]) - patched_outputs = patched_draft( - input_ids=step["input_ids"].to(patched_device), - hidden_states=step["hidden_states"].to(patched_device), - position_ids=step["position_ids"].to(patched_device), - past_key_values=patched_past_key_values, - use_cache=True, - ) - patched_past_key_values = patched_outputs.past_key_values - torch.testing.assert_close( - patched_outputs.logits.detach().float().cpu(), - original_logits.float(), - rtol=args.rtol, - atol=args.atol, - msg=f"Step {step_idx}: logits mismatch", - ) - assert_cache_close( - patched_past_key_values, - step["original_committed_cache"], - step_idx=step_idx, - rtol=args.rtol, - atol=args.atol, - ) - patched_length = cache_seq_length(patched_past_key_values) - if patched_length != step["expected_present_length"]: - raise AssertionError( - f"Step {step_idx}: patched cache length {patched_length} != {step['expected_present_length']}" - ) - print(f"Step {step_idx}: patched logits and committed-prefix cache match original bundle reference") - - print("DFlash committed-prefix cache semantics match the original implementation.") - - -if __name__ == "__main__": - main() From 115b3434a7eec2485d43e9bc57cc907ac20a5843 Mon Sep 17 00:00:00 2001 From: Ofir Zafrir Date: Sat, 30 May 2026 12:20:06 -0700 Subject: [PATCH 11/11] Remove finished todo and left over testing not needed --- optimum/exporters/openvino/__main__.py | 10 - tests/scripts/extract_dflash_debug_bundle.py | 274 ------------------- 2 files changed, 284 deletions(-) delete mode 100644 tests/scripts/extract_dflash_debug_bundle.py diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index 7113dadc2b..c6ddfeb860 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -208,16 +208,6 @@ def update_config_for_dflash( "AutoModel": moduler_path + "--model_patcher.Qwen3DFlashDraftModel", "AutoModelForCausalLM": moduler_path + "--model_patcher.Qwen3DFlashForCausalLM", } - # TODO ofir: remove this after implementing load time override in openvino.genai - dflash_block_size_override = os.environ.get("DFLASH_BLOCK_SIZE_OVERRIDE") - if dflash_block_size_override: - try: - block_size = int(dflash_block_size_override) - except ValueError as exc: - raise ValueError("DFLASH_BLOCK_SIZE_OVERRIDE must be an integer.") from exc - if block_size <= 1: - raise ValueError("DFLASH_BLOCK_SIZE_OVERRIDE must be greater than 1.") - config.block_size = block_size config.dflash_target_model = dflash_target_model config.dflash_target_cache_dir = cache_dir diff --git a/tests/scripts/extract_dflash_debug_bundle.py b/tests/scripts/extract_dflash_debug_bundle.py deleted file mode 100644 index fe90cea6db..0000000000 --- a/tests/scripts/extract_dflash_debug_bundle.py +++ /dev/null @@ -1,274 +0,0 @@ -# Copyright 2026 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Create a portable DFlash correctness fixture. - -This script is standalone: run it on any machine that can load the PyTorch -target and DFlash draft models, then copy the resulting `.pt` bundle to the -machine that runs the OpenVINO export tests. - -Required packages: - - torch - - transformers >= 4.57 - - accelerate, if using --device-map auto - - safetensors - - huggingface_hub - -Example: - python tests/scripts/extract_dflash_debug_bundle.py \ - --draft-model z-lab/Qwen3-Coder-30B-A3B-DFlash \ - --target-model Qwen/Qwen3-Coder-30B-A3B-Instruct \ - --dtype float16 \ - --num-steps 4 \ - --output dflash_debug_bundle_kv.pt -""" - -import argparse - -import torch -from transformers.cache_utils import DynamicCache -from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer - - -def sample(logits: torch.Tensor, temperature: float = 0.0) -> torch.Tensor: - if temperature == 0: - return logits.argmax(dim=-1) - return torch.distributions.Categorical(logits=logits / temperature).sample() - - -def extract_context_feature(hidden_states: list[torch.Tensor], layer_ids: list[int]) -> torch.Tensor: - # hidden_states[0] is the embedding output, so model layer ids are offset by one. - return torch.cat([hidden_states[layer_id + 1] for layer_id in layer_ids], dim=-1) - - -def cache_to_layers(cache) -> tuple: - if hasattr(cache, "to_legacy_cache"): - return tuple(cache.to_legacy_cache()) - if hasattr(cache, "layers"): - layers = [] - for layer in cache.layers: - key = getattr(layer, "keys", None) - value = getattr(layer, "values", None) - if key is None: - key = getattr(layer, "key_cache", None) - if value is None: - value = getattr(layer, "value_cache", None) - layers.append((key, value)) - return tuple(layers) - return tuple(cache) - - -def crop_tensor(tensor: torch.Tensor, length: int) -> torch.Tensor: - return tensor[..., :length, :].contiguous() - - -def crop_cache(cache, length: int): - if hasattr(cache, "crop"): - cache.crop(length) - return cache - return tuple((crop_tensor(key, length), crop_tensor(value, length)) for key, value in cache_to_layers(cache)) - - -def cache_to_cpu(cache) -> tuple: - return tuple( - (key.detach().cpu().contiguous(), value.detach().cpu().contiguous()) for key, value in cache_to_layers(cache) - ) - - -def original_outputs_to_hidden_and_cache(outputs, fallback_cache=None): - if hasattr(outputs, "last_hidden_state"): - return outputs.last_hidden_state, outputs.past_key_values - if isinstance(outputs, tuple): - if len(outputs) > 1: - return outputs[0], outputs[1] - if fallback_cache is not None: - return outputs[0], fallback_cache - if torch.is_tensor(outputs) and fallback_cache is not None: - return outputs, fallback_cache - raise TypeError("Original DFlash model did not return a cache. Make sure use_cache=True is supported.") - - -def run_dflash_block( - draft, - target, - block_input_ids: torch.Tensor, - hidden_states: torch.Tensor, - position_ids: torch.Tensor, - block_size: int, -) -> torch.Tensor: - target_device = block_input_ids.device - noise_embedding = target.model.embed_tokens(block_input_ids.to(target_device)).to(next(draft.parameters()).device) - draft_hidden = draft( - target_hidden=hidden_states.to(next(draft.parameters()).device), - noise_embedding=noise_embedding, - position_ids=position_ids.to(next(draft.parameters()).device), - use_cache=False, - is_causal=False, - ) - return target.lm_head(draft_hidden[:, -block_size + 1 :, :].to(target_device)) - - -def run_dflash_cached_block( - draft, - target, - block_input_ids: torch.Tensor, - hidden_states: torch.Tensor, - position_ids: torch.Tensor, - past_key_values, - block_size: int, -) -> tuple[torch.Tensor, object]: - target_device = block_input_ids.device - draft_device = next(draft.parameters()).device - if past_key_values is None: - past_key_values = DynamicCache(config=draft.config) - noise_embedding = target.model.embed_tokens(block_input_ids.to(target_device)).to(draft_device) - outputs = draft( - target_hidden=hidden_states.to(draft_device), - noise_embedding=noise_embedding, - position_ids=position_ids.to(draft_device), - past_key_values=past_key_values, - use_cache=True, - is_causal=False, - ) - draft_hidden, past_key_values = original_outputs_to_hidden_and_cache(outputs, fallback_cache=past_key_values) - logits = target.lm_head(draft_hidden[:, -block_size + 1 :, :].to(target_device)) - return logits, past_key_values - - -def main(): - parser = argparse.ArgumentParser(description="Create a lightweight DFlash correctness fixture.") - parser.add_argument("--draft-model", default="z-lab/Qwen3-Coder-30B-A3B-DFlash") - parser.add_argument("--target-model", default="Qwen/Qwen3-Coder-30B-A3B-Instruct") - parser.add_argument("--prompt", default="Write a quicksort in Python.") - parser.add_argument("--output", required=True) - parser.add_argument("--dtype", default="bfloat16", choices=["float16", "bfloat16", "float32"]) - parser.add_argument("--device-map", default="auto") - parser.add_argument("--temperature", type=float, default=0.0) - parser.add_argument("--num-steps", type=int, default=4) - args = parser.parse_args() - if args.num_steps < 1: - raise ValueError("--num-steps must be at least 1") - - dtype = getattr(torch, args.dtype) - tokenizer = AutoTokenizer.from_pretrained(args.target_model, trust_remote_code=True) - target = AutoModelForCausalLM.from_pretrained( - args.target_model, - dtype=dtype, - device_map=args.device_map, - trust_remote_code=True, - ).eval() - draft = AutoModel.from_pretrained( - args.draft_model, - dtype=dtype, - device_map=args.device_map, - trust_remote_code=True, - attn_implementation="eager", - ).eval() - - device = next(target.parameters()).device - input_ids = tokenizer(args.prompt, return_tensors="pt").input_ids.to(device) - block_size = draft.config.block_size - mask_token_id = draft.config.dflash_config["mask_token_id"] - target_layer_ids = draft.config.dflash_config["target_layer_ids"] - - steps = [] - committed_input_ids = input_ids - committed_hidden_length = 0 - original_past_key_values = None - with torch.inference_mode(): - for step_idx in range(args.num_steps): - committed_length = committed_input_ids.shape[1] - target_position_ids = torch.arange(committed_length, device=device).unsqueeze(0) - target_output = target( - committed_input_ids, - position_ids=target_position_ids, - use_cache=False, - logits_to_keep=1, - output_hidden_states=True, - ) - seed_token = sample(target_output.logits[:, -1, :], args.temperature) - - block_input_ids = torch.full((1, block_size), mask_token_id, dtype=torch.long, device=device) - block_input_ids[:, 0] = seed_token - - full_hidden_states = extract_context_feature(target_output.hidden_states, target_layer_ids) - hidden_states = full_hidden_states[:, committed_hidden_length:, :] - position_start = committed_hidden_length - position_ids = torch.arange(position_start, committed_length + block_size, device=device).unsqueeze(0) - full_position_ids = torch.arange(committed_length + block_size, device=device).unsqueeze(0) - expected_logits = run_dflash_block( - draft, - target, - block_input_ids, - full_hidden_states, - full_position_ids, - block_size, - ) - original_cached_logits, original_past_key_values = run_dflash_cached_block( - draft, - target, - block_input_ids, - hidden_states, - position_ids, - original_past_key_values, - block_size, - ) - - committed_hidden_length += hidden_states.shape[1] - original_past_key_values = crop_cache(original_past_key_values, committed_hidden_length) - steps.append( - { - "index": step_idx, - "input_ids": block_input_ids.cpu(), - "hidden_states": hidden_states.cpu(), - "position_ids": position_ids.cpu(), - "expected_logits": original_cached_logits.cpu(), - "expected_logits_full_prefix": expected_logits.cpu(), - "original_cached_logits": original_cached_logits.cpu(), - "original_committed_cache": cache_to_cpu(original_past_key_values), - "seed_token": seed_token.cpu(), - "sampled_tokens": sample(expected_logits, args.temperature).cpu(), - "expected_present_length": committed_hidden_length, - } - ) - committed_input_ids = torch.cat([committed_input_ids, seed_token[:, None]], dim=1) - - bundle = { - "steps": steps, - # Keep first-step keys for quick ad-hoc inspection and older local scripts. - "input_ids": steps[0]["input_ids"], - "hidden_states": steps[0]["hidden_states"], - "position_ids": steps[0]["position_ids"], - "expected_logits": steps[0]["expected_logits"], - "sampled_tokens": steps[0]["sampled_tokens"], - "expected_present_length": steps[0]["expected_present_length"], - "metadata": { - "draft_model": args.draft_model, - "target_model": args.target_model, - "prompt": args.prompt, - "num_steps": args.num_steps, - "block_size": block_size, - "mask_token_id": mask_token_id, - "target_layer_ids": target_layer_ids, - "dtype": args.dtype, - "temperature": args.temperature, - "cache_policy": "committed_prefix", - }, - } - torch.save(bundle, args.output) - - -if __name__ == "__main__": - main()