diff --git a/optimum/commands/export/openvino.py b/optimum/commands/export/openvino.py index bc63bb760f..acfe970e04 100644 --- a/optimum/commands/export/openvino.py +++ b/optimum/commands/export/openvino.py @@ -294,6 +294,15 @@ def parse_args_openvino(parser: "ArgumentParser"): type=json.loads, help=("Any kwargs passed to the model forward, or used to customize the export for a given model."), ) + optional_group.add_argument( + "--dflash-target-model", + type=str, + default=None, + help=( + "Target model ID or local path used when exporting DFlash draft models. Only the target token " + "embedding and lm_head weights are loaded." + ), + ) def no_compression_parameter_provided(args): @@ -479,6 +488,7 @@ def run(self): library_name=library_name, variant=self.args.variant, model_kwargs=self.args.model_kwargs, + dflash_target_model=self.args.dflash_target_model, # **input_shapes, ) if apply_main_quantize: diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index 1fb3f5d43c..c6ddfeb860 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -189,6 +189,35 @@ def update_config_for_eagle3(config): return config +def update_config_for_dflash( + config, + dflash_target_model: Optional[str], + cache_dir: str = HUGGINGFACE_HUB_CACHE, + revision: str = "main", + token: Optional[Union[bool, str]] = None, + local_files_only: bool = False, +): + if dflash_target_model is None: + raise ValueError("Exporting DFlash draft models requires `dflash_target_model` / `--dflash-target-model`.") + + moduler_name = "optimum.exporters.openvino.model_patcher" + spec = importlib.util.find_spec(moduler_name) + if spec and spec.origin: + moduler_path = os.path.dirname(spec.origin) + config.auto_map = { + "AutoModel": moduler_path + "--model_patcher.Qwen3DFlashDraftModel", + "AutoModelForCausalLM": moduler_path + "--model_patcher.Qwen3DFlashForCausalLM", + } + + config.dflash_target_model = dflash_target_model + config.dflash_target_cache_dir = cache_dir + config.dflash_target_revision = revision + config.dflash_target_token = token + config.dflash_target_local_files_only = local_files_only + config.tie_word_embeddings = False + return config + + def infer_library_name( model_name_or_path: str, subfolder: str = "", @@ -235,6 +264,7 @@ def main_export( library_name: Optional[str] = None, model_loading_kwargs: Optional[Dict[str, Any]] = None, variant: Optional[str] = None, + dflash_target_model: Optional[str] = None, **kwargs_shapes, ): """ @@ -358,6 +388,15 @@ def main_export( _eagle3_archs = {"LlamaForCausalLMEagle3", "Eagle3LlamaForCausalLM"} if isinstance(archs, list) and len(archs) > 0 and archs[0] in _eagle3_archs: loading_kwargs["config"] = update_config_for_eagle3(config) + elif isinstance(archs, list) and len(archs) > 0 and archs[0] == "DFlashDraftModel": + loading_kwargs["config"] = update_config_for_dflash( + config, + dflash_target_model=dflash_target_model, + cache_dir=cache_dir, + revision=revision, + token=token, + local_files_only=local_files_only, + ) # mxfp4 quantized model will be dequantized to bf16 if quant_method == "mxfp4" and is_transformers_version(">=", "4.55"): diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index a1a928501d..278af7a522 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -16,6 +16,7 @@ import functools import gc import inspect +import json import logging import os from pathlib import Path @@ -83,6 +84,9 @@ logger = logging.getLogger(__name__) +HIDDEN_STATES_RT_INFO_KEY = "hidden_states_decoder_layers" +HIDDEN_STATE_TENSOR_NAME_TEMPLATE = "ov.hidden_states.decoder_layer_{}" + if is_torch_available(): import torch.nn as nn from transformers.modeling_utils import PreTrainedModel @@ -138,12 +142,94 @@ def _save_model( if getattr(config, "eagle3", False): model = _add_eagle3_mode_to_rt_info(model) + if getattr(config, "dflash", False): + model = _add_dflash_mode_to_rt_info(model, config._config) save_model(model, path, compress_to_fp16) del model gc.collect() +def _can_annotate_hidden_states(model, config: "OnnxConfig") -> bool: + if "text-generation" not in getattr(config, "task", ""): + return False + + model_config = getattr(model, "config", None) + if model_config is None or not hasattr(model_config, "num_hidden_layers"): + return False + if not hasattr(model_config, "output_hidden_states"): + return False + + can_record_outputs = getattr(model, "_can_record_outputs", None) + if isinstance(can_record_outputs, dict) and not can_record_outputs.get("hidden_states", False): + return False + + return True + + +def _annotate_hidden_state_outputs(model: Model, public_output_names: List[str], num_hidden_layers: int): + public_output_names = set(public_output_names) + temporary_results = [] + for result in model.get_results(): + result_names = set(result.output(0).get_names()) + if result_names & public_output_names: + continue + temporary_results.append(result) + + if not temporary_results: + return + + # HF hidden_states usually contains embeddings first, then one tensor after each decoder block. + offset = 1 if len(temporary_results) >= num_hidden_layers + 1 else 0 + hidden_layer_results = temporary_results[offset : offset + num_hidden_layers] + if len(hidden_layer_results) == num_hidden_layers: + layers = {} + for layer_idx, result in enumerate(hidden_layer_results): + tensor_name = HIDDEN_STATE_TENSOR_NAME_TEMPLATE.format(layer_idx) + result.input_value(0).get_tensor().add_names({tensor_name}) + layers[str(layer_idx)] = tensor_name + + model.set_rt_info( + json.dumps({"version": 1, "layers": layers}), + HIDDEN_STATES_RT_INFO_KEY, + ) + else: + logger.debug( + "Skipping hidden-state annotation: expected at least %s hidden-state outputs, got %s.", + num_hidden_layers, + len(temporary_results), + ) + + for result in temporary_results: + model.remove_result(result) + model.validate_nodes_and_infer_types() + + +def _enable_hidden_states_in_config_outputs(config: "OnnxConfig", num_hidden_layers: int): + original_class = config.__class__ + + class ConfigWithHiddenStateOutputs(original_class): + @property + def outputs(self): + outputs = original_class.outputs.fget(self).copy() + for idx in range(num_hidden_layers + 1): + outputs[f"hidden_states.{idx}"] = {0: "batch_size", 1: "sequence_length"} + return outputs + + config.__class__ = ConfigWithHiddenStateOutputs + return original_class + + +def _flatten_hidden_states_outputs(outputs: Dict): + hidden_states = outputs.pop("hidden_states", None) + if hidden_states is None: + return outputs + + for idx, hidden_state in enumerate(hidden_states): + outputs[f"hidden_states.{idx}"] = hidden_state + return outputs + + def export( model: Union["PreTrainedModel", "ModelMixin", "DiffusionPipeline"], config: "OnnxConfig", @@ -382,6 +468,17 @@ def export_pytorch( dummy_inputs = config.rename_ambiguous_inputs(dummy_inputs) dummy_inputs, dict_inputs = remove_none_from_dummy_inputs(dummy_inputs) + output_names = list(config.outputs.keys()) + annotate_hidden_states = _can_annotate_hidden_states(model, config) + original_output_hidden_states = None + original_config_class = None + if annotate_hidden_states: + original_output_hidden_states = model.config.output_hidden_states + model.config.output_hidden_states = True + if hasattr(config, "_config") and hasattr(config._config, "output_hidden_states"): + config._config.output_hidden_states = True + original_config_class = _enable_hidden_states_in_config_outputs(config, model.config.num_hidden_layers) + # TorchScript used behind OpenVINO conversion. Optimum supports only return_dict=True models for patching, # while TorchScript do not support dictionary with values of mixed types (e.g. Tensor and None) in model input/output # To handle it, additional wrapper on patcher forward applied. @@ -406,6 +503,8 @@ def ts_patched_forward(*args, **kwargs): input_dict = dict(zip(keys, tuple_input)) kwargs[input_name] = input_dict outputs = patched_forward(**kwargs) + if annotate_hidden_states: + outputs = _flatten_hidden_states_outputs(outputs) return tuple([value if not isinstance(value, list) else tuple(value) for value in outputs.values()]) patcher.patched_forward = ts_patched_forward @@ -449,9 +548,14 @@ def ts_patched_forward(*args, **kwargs): extension=conversion_extensions, ) + if annotate_hidden_states: + config.__class__ = original_config_class + model.config.output_hidden_states = original_output_hidden_states + if hasattr(config, "_config") and hasattr(config._config, "output_hidden_states"): + config._config.output_hidden_states = original_output_hidden_states + ov_model.validate_nodes_and_infer_types() # TODO: remove as unnecessary validation? - output_names = list(config.outputs.keys()) for idx, out_tensor in enumerate(ov_model.outputs): if idx < len(output_names): out_tensor.get_tensor().set_names({output_names[idx]}) @@ -464,6 +568,9 @@ def ts_patched_forward(*args, **kwargs): if stateful: patch_stateful(model.config, ov_model) + if annotate_hidden_states: + _annotate_hidden_state_outputs(ov_model, output_names, model.config.num_hidden_layers) + library_name = _infer_library_from_model_or_model_class(model=model, library_name=library_name) _save_model( @@ -979,6 +1086,25 @@ def _add_eagle3_mode_to_rt_info(model: Model): return model +def _add_dflash_mode_to_rt_info(model: Model, hf_config): + """ + Add DFlash metadata. + """ + try: + model.set_rt_info("True", ["dflash_mode"]) + model.set_rt_info(str(getattr(hf_config, "block_size", "")), ["dflash", "block_size"]) + model.set_rt_info("committed_prefix", ["dflash", "cache_policy"]) + dflash_config = getattr(hf_config, "dflash_config", {}) + if "mask_token_id" in dflash_config: + model.set_rt_info(str(dflash_config["mask_token_id"]), ["dflash", "mask_token_id"]) + if "target_layer_ids" in dflash_config: + model.set_rt_info(",".join(map(str, dflash_config["target_layer_ids"])), ["dflash", "target_layer_ids"]) + except Exception: + pass + + return model + + def _add_version_info_to_model(model: Model, library_name: Optional[str] = None): """ Add dependency versions to OpenVINO model diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index 19d954b62a..2d23520079 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -421,6 +421,55 @@ class Qwen2MoEOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): _MODEL_PATCHER = Qwen2MoEPatcher +class DFlashDummyInputGenerator(DummyInputGenerator): + SUPPORTED_INPUT_NAMES = ("input_ids", "hidden_states", "position_ids") + + def __init__( + self, + task: str, + normalized_config: NormalizedTextConfig, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"], + **kwargs, + ): + self.batch_size = batch_size + self.context_length = sequence_length + self.hidden_size = normalized_config.hidden_size + config = normalized_config.config + self.block_size = getattr(config, "block_size", sequence_length) + dflash_config = getattr(config, "dflash_config", {}) + self.num_target_layers = len(dflash_config.get("target_layer_ids", [])) + self.vocab_size = getattr(config, "vocab_size", normalized_config.vocab_size) + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + if input_name == "input_ids": + return self.random_int_tensor( + [self.batch_size, self.block_size], + max_value=self.vocab_size, + framework=framework, + dtype=int_dtype, + ) + if input_name == "hidden_states": + return self.random_float_tensor( + [self.batch_size, self.context_length, self.hidden_size * self.num_target_layers], + framework=framework, + dtype=float_dtype, + ) + if input_name == "position_ids": + position_length = self.context_length + self.block_size + if framework == "pt": + return torch.arange(position_length, dtype=DTYPE_MAPPER.pt(int_dtype)).unsqueeze(0).repeat( + self.batch_size, 1 + ) + return self.random_int_tensor( + [self.batch_size, position_length], + max_value=position_length, + framework=framework, + dtype=int_dtype, + ) + raise ValueError(f"Unsupported DFlash input name: {input_name}") + + @register_in_tasks_manager( "qwen3", *[ @@ -439,8 +488,40 @@ class Qwen3OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedTextConfig _MODEL_PATCHER = OVDecoderModelPatcher + def __init__( + self, + config: PretrainedConfig, + task: str = "feature-extraction", + int_dtype: str = "int64", + float_dtype: str = "fp32", + use_past: bool = False, + use_past_in_inputs: bool = False, + preprocessors: list[Any] | None = None, + ): + super().__init__( + config=config, + task=task, + int_dtype=int_dtype, + float_dtype=float_dtype, + use_past=use_past, + use_past_in_inputs=use_past_in_inputs, + preprocessors=preprocessors, + ) + archs = getattr(config, "architectures", None) + self.dflash = isinstance(archs, list) and len(archs) > 0 and archs[0] == "DFlashDraftModel" + if self.dflash: + self.DUMMY_INPUT_GENERATOR_CLASSES = (DFlashDummyInputGenerator, GemmaDummyPastKeyValuesGenerator) + self.MIN_TRANSFORMERS_VERSION = "4.57.0" + @property def inputs(self) -> Dict[str, Dict[int, str]]: + if self.dflash: + common_inputs = super().inputs + common_inputs.pop("attention_mask", None) + common_inputs["input_ids"] = {0: "batch_size", 1: "block_size"} + common_inputs["hidden_states"] = {0: "batch_size", 1: "context_length", 2: "target_hidden_size"} + common_inputs["position_ids"] = {0: "batch_size", 1: "position_sequence_length"} + return common_inputs if self.task in ["feature-extraction"]: common_inputs = { "input_ids": {0: "batch_size", 1: "sequence_length"}, @@ -450,6 +531,22 @@ def inputs(self) -> Dict[str, Dict[int, str]]: common_inputs = super().inputs return common_inputs + @property + def outputs(self) -> Dict[str, Dict[int, str]]: + if self.dflash: + common_outputs = super().outputs + common_outputs["logits"] = {0: "batch_size", 1: "draft_sequence_length"} + return common_outputs + return super().outputs + + def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str): + super().add_past_key_values(inputs_or_outputs, direction) + if self.dflash and direction == "outputs": + for axes in inputs_or_outputs.values(): + for axis, name in axes.items(): + if name == "past_sequence_length + sequence_length": + axes[axis] = "past_sequence_length + context_length" + class DummyQwen3VLLMInputGenerator(DummyTextInputGenerator): SUPPORTED_INPUT_NAMES = ( diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 6fcc9de4de..a0805d9aec 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -14,11 +14,13 @@ import functools import inspect +import json import logging import logging as log import math import types from dataclasses import dataclass +from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import torch @@ -31,6 +33,7 @@ BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling, + CausalLMOutputWithPast, ) from transformers.modeling_utils import PreTrainedModel from transformers.models.llama.configuration_llama import LlamaConfig @@ -71,6 +74,34 @@ from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock if is_transformers_version(">=", "4.54"): from transformers.masking_utils import create_causal_mask +if is_transformers_version(">=", "4.57"): + from transformers.models.qwen3.modeling_qwen3 import ( + Qwen3Config, + Qwen3MLP, + Qwen3PreTrainedModel, + Qwen3RMSNorm, + Qwen3RotaryEmbedding, + eager_attention_forward as qwen3_eager_attention_forward, + rotate_half as qwen3_rotate_half, + ) +else: + Qwen3Config = PretrainedConfig + Qwen3PreTrainedModel = PreTrainedModel + + class _UnavailableQwen3Module(nn.Module): + def __init__(self, *args, **kwargs): + raise ImportError("DFlash export requires transformers >= 4.57.") + + Qwen3MLP = _UnavailableQwen3Module + Qwen3RMSNorm = _UnavailableQwen3Module + Qwen3RotaryEmbedding = _UnavailableQwen3Module + + def qwen3_eager_attention_forward(*args, **kwargs): + raise ImportError("DFlash export requires transformers >= 4.57.") + + def qwen3_rotate_half(*args, **kwargs): + raise ImportError("DFlash export requires transformers >= 4.57.") + if is_transformers_version(">=", "4.56"): import transformers.masking_utils if is_transformers_version(">=", "4.57"): @@ -8644,6 +8675,343 @@ def forward( ) +def _dflash_apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_len = q.size(-2) + q_embed = (q * cos[..., -q_len:, :]) + (qwen3_rotate_half(q) * sin[..., -q_len:, :]) + k_embed = (k * cos) + (qwen3_rotate_half(k) * sin) + return q_embed, k_embed + + +def _dflash_resolve_tensor_file(config: PretrainedConfig, tensor_name: str) -> str: + target_model = getattr(config, "dflash_target_model", None) + if target_model is None: + raise ValueError("DFlash logits export requires `dflash_target_model` to load target weights.") + + target_path = Path(target_model) + if target_path.is_dir(): + index_path = target_path / "model.safetensors.index.json" + if index_path.exists(): + with index_path.open() as f: + weight_map = json.load(f)["weight_map"] + return str(target_path / weight_map[tensor_name]) + single_file = target_path / "model.safetensors" + if single_file.exists(): + return str(single_file) + raise FileNotFoundError(f"Could not find safetensors weights in {target_path}.") + + from huggingface_hub import hf_hub_download + + cache_dir = getattr(config, "dflash_target_cache_dir", None) + revision = getattr(config, "dflash_target_revision", "main") + token = getattr(config, "dflash_target_token", None) + local_files_only = getattr(config, "dflash_target_local_files_only", False) + + try: + index_path = hf_hub_download( + target_model, + "model.safetensors.index.json", + cache_dir=cache_dir, + revision=revision, + token=token, + local_files_only=local_files_only, + ) + except Exception: + filename = "model.safetensors" + else: + with open(index_path) as f: + weight_map = json.load(f)["weight_map"] + filename = weight_map[tensor_name] + + return hf_hub_download( + target_model, + filename, + cache_dir=cache_dir, + revision=revision, + token=token, + local_files_only=local_files_only, + ) + + +def _dflash_load_target_tensor(config: PretrainedConfig, tensor_names: Tuple[str, ...]) -> torch.Tensor: + from safetensors import safe_open + + last_error = None + for tensor_name in tensor_names: + try: + tensor_file = _dflash_resolve_tensor_file(config, tensor_name) + with safe_open(tensor_file, framework="pt", device="cpu") as f: + if tensor_name in f.keys(): + return f.get_tensor(tensor_name) + except Exception as error: + last_error = error + raise ValueError(f"Could not load any of {tensor_names} from DFlash target model.") from last_error + + +_DFLASH_EMBED_TENSOR_NAMES = ( + "model.language_model.embed_tokens.weight", + "model.embed_tokens.weight", + "embed_tokens.weight", +) + + +class Qwen3DFlashAttention(nn.Module): + """Qwen3 attention variant used by DFlash, where draft tokens attend over target context and noise tokens.""" + + def __init__(self, config: "Qwen3Config", layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = False + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias) + self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + + def forward( + self, + hidden_states: torch.Tensor, + target_hidden: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + bsz, q_len = hidden_states.shape[:-1] + ctx_len = target_hidden.shape[1] + + query_states = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim) + query_states = self.q_norm(query_states).transpose(1, 2) + + kv_hidden_states = torch.cat([target_hidden, hidden_states], dim=1) + key_states = self.k_proj(kv_hidden_states).view(bsz, ctx_len + q_len, -1, self.head_dim) + value_states = self.v_proj(kv_hidden_states).view(bsz, ctx_len + q_len, -1, self.head_dim) + key_states = self.k_norm(key_states).transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = _dflash_apply_rotary_pos_emb(query_states, key_states, cos, sin) + target_key_states, block_key_states = key_states.split([ctx_len, q_len], dim=2) + target_value_states, block_value_states = value_states.split([ctx_len, q_len], dim=2) + + if past_key_values is not None: + # Persist only committed target-prefix K/V. The speculative block + # stays local to this inference, so rejection never requires cache trim. + target_cache_position = cache_position[:ctx_len] if cache_position is not None else None + cache_kwargs = {"sin": sin[:, :ctx_len], "cos": cos[:, :ctx_len], "cache_position": target_cache_position} + target_key_states, target_value_states = past_key_values.update( + target_key_states, + target_value_states, + self.layer_idx, + cache_kwargs, + ) + + key_states = torch.cat([target_key_states, block_key_states], dim=2) + value_states = torch.cat([target_value_states, block_value_states], dim=2) + + attn_output, attn_weights = qwen3_eager_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen3DFlashDecoderLayer(nn.Module): + def __init__(self, config: "Qwen3Config", layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Qwen3DFlashAttention(config=config, layer_idx=layer_idx) + self.mlp = Qwen3MLP(config) + self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + target_hidden: torch.Tensor, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.FloatTensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + target_hidden=target_hidden, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + )[0] + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class Qwen3DFlashDraftModel(Qwen3PreTrainedModel): + config_class = Qwen3Config + _no_split_modules = ["Qwen3DFlashDecoderLayer"] + + def __init__(self, config) -> None: + super().__init__(config) + self.layers = nn.ModuleList( + [Qwen3DFlashDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + dflash_config = getattr(config, "dflash_config", {}) + self.target_layer_ids = dflash_config.get("target_layer_ids", []) + self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen3RotaryEmbedding(config) + self.fc = nn.Linear(len(self.target_layer_ids) * config.hidden_size, config.hidden_size, bias=False) + self.hidden_norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.block_size = config.block_size + self.mask_token_id = dflash_config.get("mask_token_id", None) + self.post_init() + + def forward( + self, + position_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + noise_embedding: Optional[torch.Tensor] = None, + hidden_states: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> BaseModelOutputWithPast: + noise_states = noise_embedding + target_hidden = hidden_states.to(noise_states.dtype) + target_hidden = self.hidden_norm(self.fc(target_hidden)) + use_cache = use_cache if use_cache is not None else self.config.use_cache + if use_cache: + if past_key_values is None: + past_key_values = DynamicCache(config=self.config) + elif not isinstance(past_key_values, Cache): + if is_transformers_version("<", "5"): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + else: + past_key_values = DynamicCache(past_key_values) + if use_cache and cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + target_hidden.shape[1] + noise_states.shape[1], + device=noise_states.device, + ) + position_embeddings = self.rotary_emb(noise_states, position_ids) + for layer in self.layers: + noise_states = layer( + hidden_states=noise_states, + target_hidden=target_hidden, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + return BaseModelOutputWithPast( + last_hidden_state=self.norm(noise_states), + past_key_values=past_key_values if use_cache else None, + ) + + +class Qwen3DFlashForCausalLM(Qwen3DFlashDraftModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _keys_to_ignore_on_load_missing = [r"embed_tokens.weight", r"lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + def _load_target_weights(self, config): + embed_weight = _dflash_load_target_tensor(config, _DFLASH_EMBED_TENSOR_NAMES) + try: + lm_head_weight = _dflash_load_target_tensor(config, ("lm_head.weight",)) + except ValueError: + lm_head_weight = embed_weight + + target_dtype = self.fc.weight.dtype + embed_weight = embed_weight.to(target_dtype) + lm_head_weight = lm_head_weight.to(target_dtype) + + self.embed_tokens.weight = nn.Parameter(embed_weight, requires_grad=False) + self.lm_head.weight = nn.Parameter(lm_head_weight, requires_grad=False) + + @classmethod + def from_pretrained(cls, *model_args, **kwargs): + output_loading_info = kwargs.get("output_loading_info", False) + result = super().from_pretrained(*model_args, **kwargs) + + if output_loading_info: + model, loading_info = result + model._load_target_weights(model.config) + return model, loading_info + + result._load_target_weights(result.config) + return result + + def forward( + self, + input_ids: torch.LongTensor, + hidden_states: torch.Tensor, + position_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + logits_to_keep: Optional[int] = None, + **kwargs, + ) -> CausalLMOutputWithPast: + noise_embedding = self.embed_tokens(input_ids) + outputs = super().forward( + hidden_states=hidden_states, + noise_embedding=noise_embedding, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + **kwargs, + ) + if logits_to_keep is None: + hidden_states = outputs.last_hidden_state[:, 1:, :] + else: + hidden_states = outputs.last_hidden_state[:, -logits_to_keep:, :] + logits = self.lm_head(hidden_states) + return CausalLMOutputWithPast(logits=logits, past_key_values=outputs.past_key_values) + + # Patched implementation of the gated delta rule in recurrent form. # Adapted from: # https://github.com/huggingface/transformers/blob/v4.57-release/src/transformers/models/qwen3_next/modeling_qwen3_next.py#L522 @@ -9452,6 +9820,9 @@ def patched_forward( text_config = self._text_config num_full_attn_layers = text_config.layer_types.count("full_attention") num_linear_attn_layers = text_config.layer_types.count("linear_attention") + output_hidden_states = getattr(self._model.config, "output_hidden_states", False) or getattr( + text_config, "output_hidden_states", False + ) use_cache = False wrapped_cache_params = None @@ -9483,19 +9854,23 @@ def patched_forward( position_ids=position_ids, past_key_values=wrapped_cache_params, use_cache=use_cache, + output_hidden_states=output_hidden_states, ) - hidden_states = outputs_lm[0] - logits = self._model.lm_head(hidden_states) + last_hidden_state = outputs_lm[0] + logits = self._model.lm_head(last_hidden_state) past_kv = outputs_lm.past_key_values + hidden_states = outputs_lm.hidden_states else: causal_lm_output = self.model_orig_forward( input_ids=input_ids, attention_mask=attention_mask, past_key_values=wrapped_cache_params, use_cache=use_cache, + output_hidden_states=output_hidden_states, ) logits = causal_lm_output.logits past_kv = causal_lm_output.past_key_values + hidden_states = causal_lm_output.hidden_states outputs = { "logits": logits, } @@ -9512,6 +9887,9 @@ def patched_forward( outputs["present_key_values"] = present_key_values + if hidden_states is not None: + outputs["hidden_states"] = hidden_states + return outputs self.patched_forward = patched_forward diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index 69f7b83ee0..a66b0c80ff 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -347,6 +347,7 @@ def _export( model_loading_kwargs["torch_dtype"] = torch_dtype variant = kwargs.pop("variant", None) + dflash_target_model = kwargs.pop("dflash_target_model", None) main_export( model_name_or_path=model_id, @@ -364,6 +365,7 @@ def _export( model_loading_kwargs=model_loading_kwargs, library_name=cls._library_name, variant=variant, + dflash_target_model=dflash_target_model, ) if config.model_type == "phi3" and config.max_position_embeddings != getattr( @@ -553,7 +555,10 @@ def prepare_inputs( else: position_ids = np.cumsum(attention_mask, axis=1) - 1 position_ids[attention_mask == 0] = 1 - if past_key_values: + dflash_hidden_states = kwargs.get("hidden_states", None) + if past_key_values and dflash_hidden_states is not None and "hidden_states" in self.input_names: + position_ids = position_ids[:, -(input_ids.shape[1] + dflash_hidden_states.shape[1]) :] + elif past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] inputs["position_ids"] = position_ids @@ -568,7 +573,12 @@ def prepare_inputs( if "hidden_states" in self.input_names: hidden_states = kwargs.get("hidden_states", None) if hidden_states is None: - hs_shape = (batch_size, input_ids.shape[1], self.config.hidden_size * 3) + dflash_config = getattr(self.config, "dflash_config", None) + if dflash_config is not None: + raise ValueError("DFlash draft models require `hidden_states` to be passed to the forward call.") + else: + hidden_size = self.config.hidden_size * 3 + hs_shape = (batch_size, input_ids.shape[1], hidden_size) hidden_states = torch.zeros(hs_shape, device=self.device, dtype=torch.float32) inputs["hidden_states"] = hidden_states diff --git a/tests/openvino/test_export.py b/tests/openvino/test_export.py index fbca6ace56..120d4d5d5b 100644 --- a/tests/openvino/test_export.py +++ b/tests/openvino/test_export.py @@ -13,13 +13,16 @@ # limitations under the License. +import json import unittest from pathlib import Path +import numpy as np +import openvino as ov import torch from parameterized import parameterized from sentence_transformers import SentenceTransformer, models -from transformers import AutoConfig, AutoTokenizer, GenerationConfig +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig from utils_tests import ( MODEL_NAMES, OPENVINO_DEVICE, @@ -65,6 +68,24 @@ logger = logging.get_logger() +HIDDEN_STATES_RT_INFO_KEY = "hidden_states_decoder_layers" + + +def _find_output_by_tensor_name(model, tensor_name): + for op in model.get_ops(): + for output in op.outputs(): + if tensor_name in output.get_names(): + return output + raise AssertionError(f"Tensor {tensor_name} was not found in the OpenVINO graph") + + +def _add_model_output(model, output, output_name): + output.get_tensor().add_names({output_name}) + if hasattr(model, "add_output"): + model.add_output(output) + else: + model.add_outputs([output]) + class ExportModelTest(unittest.TestCase): SUPPORTED_ARCHITECTURES = { @@ -263,6 +284,92 @@ def test_export(self, model_type: str): model_kwargs = {"vocoder": "fxmarty/speecht5-hifigan-tiny"} self._openvino_export(model_type, model_kwargs=model_kwargs) + def test_export_hidden_state_annotations_without_extra_outputs(self): + for task in ("text-generation", "text-generation-with-past"): + model = AutoModelForCausalLM.from_pretrained(MODEL_NAMES["gpt2"]) + with self.subTest(task=task), TemporaryDirectory() as tmpdirname: + export_from_model( + model=model, + output=Path(tmpdirname), + task=task, + preprocessors=None, + stateful=False, + ) + + ov_model = ov.Core().read_model(Path(tmpdirname) / "openvino_model.xml") + output_names = set().union(*(output.get_names() for output in ov_model.outputs)) + self.assertNotIn("last_hidden_state", output_names) + self.assertFalse(any(name.startswith("ov.hidden_states.") for name in output_names)) + + rt_info = ov_model.get_rt_info() + self.assertIn(HIDDEN_STATES_RT_INFO_KEY, rt_info) + annotation = json.loads(rt_info[HIDDEN_STATES_RT_INFO_KEY].value) + self.assertEqual(annotation["version"], 1) + self.assertEqual(len(annotation["layers"]), model.config.num_hidden_layers) + + graph_tensor_names = set() + for op in ov_model.get_ops(): + for output in op.outputs(): + graph_tensor_names.update(output.get_names()) + for tensor_name in annotation["layers"].values(): + self.assertIn(tensor_name, graph_tensor_names) + + def test_annotated_hidden_state_output_matches_pytorch(self): + model = AutoModelForCausalLM.from_pretrained(MODEL_NAMES["gpt2"]) + model.eval() + + with TemporaryDirectory() as tmpdirname: + export_from_model( + model=model, + output=Path(tmpdirname), + task="text-generation", + preprocessors=None, + stateful=False, + ) + + core = ov.Core() + ov_model = core.read_model(Path(tmpdirname) / "openvino_model.xml") + annotation = json.loads(ov_model.get_rt_info()[HIDDEN_STATES_RT_INFO_KEY].value) + layer_idx = 0 + output_name = "decoder_layer_0_hidden_state" + hidden_state_output = _find_output_by_tensor_name(ov_model, annotation["layers"][str(layer_idx)]) + _add_model_output(ov_model, hidden_state_output, output_name) + + input_ids = torch.tensor([[0, 1, 2, 3]], dtype=torch.long) + attention_mask = torch.ones_like(input_ids) + with torch.no_grad(): + torch_outputs = model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + return_dict=True, + ) + + compiled_model = core.compile_model(ov_model, "CPU") + ov_inputs = {} + for input_port in compiled_model.inputs: + input_name = input_port.get_any_name() + if input_name == "input_ids": + ov_inputs[input_name] = input_ids.numpy() + elif input_name == "attention_mask": + ov_inputs[input_name] = attention_mask.numpy() + elif input_name == "position_ids": + ov_inputs[input_name] = np.arange(input_ids.shape[1], dtype=np.int64).reshape(1, -1) + elif input_name == "token_type_ids": + ov_inputs[input_name] = np.zeros(input_ids.shape, dtype=np.int64) + else: + self.fail(f"Unexpected OpenVINO model input: {input_name}") + + infer_result = compiled_model(ov_inputs) + ov_output_port = next(output for output in compiled_model.outputs if output_name in output.get_names()) + np.testing.assert_allclose( + infer_result[ov_output_port], + torch_outputs.hidden_states[layer_idx + 1].detach().numpy(), + rtol=1e-4, + atol=1e-4, + ) + @parameterized.expand(GENERATIVE_MODELS) def test_export_with_custom_gen_config(self, model_type): auto_model = self.SUPPORTED_ARCHITECTURES[model_type]