Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions optimum/commands/export/openvino.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,15 @@ def parse_args_openvino(parser: "ArgumentParser"):
type=json.loads,
help=("Any kwargs passed to the model forward, or used to customize the export for a given model."),
)
optional_group.add_argument(
"--dflash-target-model",
type=str,
default=None,
help=(
"Target model ID or local path used when exporting DFlash draft models. Only the target token "
"embedding and lm_head weights are loaded."
),
)


def no_compression_parameter_provided(args):
Expand Down Expand Up @@ -479,6 +488,7 @@ def run(self):
library_name=library_name,
variant=self.args.variant,
model_kwargs=self.args.model_kwargs,
dflash_target_model=self.args.dflash_target_model,
# **input_shapes,
)
if apply_main_quantize:
Expand Down
39 changes: 39 additions & 0 deletions optimum/exporters/openvino/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,35 @@ def update_config_for_eagle3(config):
return config


def update_config_for_dflash(
config,
dflash_target_model: Optional[str],
cache_dir: str = HUGGINGFACE_HUB_CACHE,
revision: str = "main",
token: Optional[Union[bool, str]] = None,
local_files_only: bool = False,
):
if dflash_target_model is None:
raise ValueError("Exporting DFlash draft models requires `dflash_target_model` / `--dflash-target-model`.")

moduler_name = "optimum.exporters.openvino.model_patcher"
spec = importlib.util.find_spec(moduler_name)
if spec and spec.origin:
moduler_path = os.path.dirname(spec.origin)
config.auto_map = {
"AutoModel": moduler_path + "--model_patcher.Qwen3DFlashDraftModel",
"AutoModelForCausalLM": moduler_path + "--model_patcher.Qwen3DFlashForCausalLM",
}

config.dflash_target_model = dflash_target_model
config.dflash_target_cache_dir = cache_dir
config.dflash_target_revision = revision
config.dflash_target_token = token
config.dflash_target_local_files_only = local_files_only
config.tie_word_embeddings = False
return config


def infer_library_name(
model_name_or_path: str,
subfolder: str = "",
Expand Down Expand Up @@ -235,6 +264,7 @@ def main_export(
library_name: Optional[str] = None,
model_loading_kwargs: Optional[Dict[str, Any]] = None,
variant: Optional[str] = None,
dflash_target_model: Optional[str] = None,
**kwargs_shapes,
):
"""
Expand Down Expand Up @@ -358,6 +388,15 @@ def main_export(
_eagle3_archs = {"LlamaForCausalLMEagle3", "Eagle3LlamaForCausalLM"}
if isinstance(archs, list) and len(archs) > 0 and archs[0] in _eagle3_archs:
loading_kwargs["config"] = update_config_for_eagle3(config)
elif isinstance(archs, list) and len(archs) > 0 and archs[0] == "DFlashDraftModel":
loading_kwargs["config"] = update_config_for_dflash(
config,
dflash_target_model=dflash_target_model,
cache_dir=cache_dir,
revision=revision,
token=token,
local_files_only=local_files_only,
)

# mxfp4 quantized model will be dequantized to bf16
if quant_method == "mxfp4" and is_transformers_version(">=", "4.55"):
Expand Down
128 changes: 127 additions & 1 deletion optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import functools
import gc
import inspect
import json
import logging
import os
from pathlib import Path
Expand Down Expand Up @@ -83,6 +84,9 @@

logger = logging.getLogger(__name__)

HIDDEN_STATES_RT_INFO_KEY = "hidden_states_decoder_layers"
HIDDEN_STATE_TENSOR_NAME_TEMPLATE = "ov.hidden_states.decoder_layer_{}"

if is_torch_available():
import torch.nn as nn
from transformers.modeling_utils import PreTrainedModel
Expand Down Expand Up @@ -138,12 +142,94 @@ def _save_model(

if getattr(config, "eagle3", False):
model = _add_eagle3_mode_to_rt_info(model)
if getattr(config, "dflash", False):
model = _add_dflash_mode_to_rt_info(model, config._config)

save_model(model, path, compress_to_fp16)
del model
gc.collect()


def _can_annotate_hidden_states(model, config: "OnnxConfig") -> bool:
if "text-generation" not in getattr(config, "task", ""):
return False

model_config = getattr(model, "config", None)
if model_config is None or not hasattr(model_config, "num_hidden_layers"):
return False
if not hasattr(model_config, "output_hidden_states"):
return False

can_record_outputs = getattr(model, "_can_record_outputs", None)
if isinstance(can_record_outputs, dict) and not can_record_outputs.get("hidden_states", False):
return False

return True


def _annotate_hidden_state_outputs(model: Model, public_output_names: List[str], num_hidden_layers: int):
public_output_names = set(public_output_names)
temporary_results = []
for result in model.get_results():
result_names = set(result.output(0).get_names())
if result_names & public_output_names:
continue
temporary_results.append(result)

if not temporary_results:
return

# HF hidden_states usually contains embeddings first, then one tensor after each decoder block.
offset = 1 if len(temporary_results) >= num_hidden_layers + 1 else 0
hidden_layer_results = temporary_results[offset : offset + num_hidden_layers]
if len(hidden_layer_results) == num_hidden_layers:
layers = {}
for layer_idx, result in enumerate(hidden_layer_results):
tensor_name = HIDDEN_STATE_TENSOR_NAME_TEMPLATE.format(layer_idx)
result.input_value(0).get_tensor().add_names({tensor_name})
layers[str(layer_idx)] = tensor_name

model.set_rt_info(
json.dumps({"version": 1, "layers": layers}),
HIDDEN_STATES_RT_INFO_KEY,
)
else:
logger.debug(
"Skipping hidden-state annotation: expected at least %s hidden-state outputs, got %s.",
num_hidden_layers,
len(temporary_results),
)

for result in temporary_results:
model.remove_result(result)
model.validate_nodes_and_infer_types()


def _enable_hidden_states_in_config_outputs(config: "OnnxConfig", num_hidden_layers: int):
original_class = config.__class__

class ConfigWithHiddenStateOutputs(original_class):
@property
def outputs(self):
outputs = original_class.outputs.fget(self).copy()
for idx in range(num_hidden_layers + 1):
outputs[f"hidden_states.{idx}"] = {0: "batch_size", 1: "sequence_length"}
return outputs

config.__class__ = ConfigWithHiddenStateOutputs
return original_class


def _flatten_hidden_states_outputs(outputs: Dict):
hidden_states = outputs.pop("hidden_states", None)
if hidden_states is None:
return outputs

for idx, hidden_state in enumerate(hidden_states):
outputs[f"hidden_states.{idx}"] = hidden_state
return outputs


def export(
model: Union["PreTrainedModel", "ModelMixin", "DiffusionPipeline"],
config: "OnnxConfig",
Expand Down Expand Up @@ -382,6 +468,17 @@ def export_pytorch(

dummy_inputs = config.rename_ambiguous_inputs(dummy_inputs)
dummy_inputs, dict_inputs = remove_none_from_dummy_inputs(dummy_inputs)
output_names = list(config.outputs.keys())
annotate_hidden_states = _can_annotate_hidden_states(model, config)
original_output_hidden_states = None
original_config_class = None
if annotate_hidden_states:
original_output_hidden_states = model.config.output_hidden_states
model.config.output_hidden_states = True
if hasattr(config, "_config") and hasattr(config._config, "output_hidden_states"):
config._config.output_hidden_states = True
original_config_class = _enable_hidden_states_in_config_outputs(config, model.config.num_hidden_layers)

# TorchScript used behind OpenVINO conversion. Optimum supports only return_dict=True models for patching,
# while TorchScript do not support dictionary with values of mixed types (e.g. Tensor and None) in model input/output
# To handle it, additional wrapper on patcher forward applied.
Expand All @@ -406,6 +503,8 @@ def ts_patched_forward(*args, **kwargs):
input_dict = dict(zip(keys, tuple_input))
kwargs[input_name] = input_dict
outputs = patched_forward(**kwargs)
if annotate_hidden_states:
outputs = _flatten_hidden_states_outputs(outputs)
return tuple([value if not isinstance(value, list) else tuple(value) for value in outputs.values()])

patcher.patched_forward = ts_patched_forward
Expand Down Expand Up @@ -449,9 +548,14 @@ def ts_patched_forward(*args, **kwargs):
extension=conversion_extensions,
)

if annotate_hidden_states:
config.__class__ = original_config_class
model.config.output_hidden_states = original_output_hidden_states
if hasattr(config, "_config") and hasattr(config._config, "output_hidden_states"):
config._config.output_hidden_states = original_output_hidden_states

ov_model.validate_nodes_and_infer_types() # TODO: remove as unnecessary validation?

output_names = list(config.outputs.keys())
for idx, out_tensor in enumerate(ov_model.outputs):
if idx < len(output_names):
out_tensor.get_tensor().set_names({output_names[idx]})
Expand All @@ -464,6 +568,9 @@ def ts_patched_forward(*args, **kwargs):
if stateful:
patch_stateful(model.config, ov_model)

if annotate_hidden_states:
_annotate_hidden_state_outputs(ov_model, output_names, model.config.num_hidden_layers)

library_name = _infer_library_from_model_or_model_class(model=model, library_name=library_name)

_save_model(
Expand Down Expand Up @@ -979,6 +1086,25 @@ def _add_eagle3_mode_to_rt_info(model: Model):
return model


def _add_dflash_mode_to_rt_info(model: Model, hf_config):
"""
Add DFlash metadata.
"""
try:
model.set_rt_info("True", ["dflash_mode"])
model.set_rt_info(str(getattr(hf_config, "block_size", "")), ["dflash", "block_size"])
model.set_rt_info("committed_prefix", ["dflash", "cache_policy"])
dflash_config = getattr(hf_config, "dflash_config", {})
if "mask_token_id" in dflash_config:
model.set_rt_info(str(dflash_config["mask_token_id"]), ["dflash", "mask_token_id"])
if "target_layer_ids" in dflash_config:
model.set_rt_info(",".join(map(str, dflash_config["target_layer_ids"])), ["dflash", "target_layer_ids"])
except Exception:
pass

return model


def _add_version_info_to_model(model: Model, library_name: Optional[str] = None):
"""
Add dependency versions to OpenVINO model
Expand Down
97 changes: 97 additions & 0 deletions optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,55 @@ class Qwen2MoEOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
_MODEL_PATCHER = Qwen2MoEPatcher


class DFlashDummyInputGenerator(DummyInputGenerator):
SUPPORTED_INPUT_NAMES = ("input_ids", "hidden_states", "position_ids")

def __init__(
self,
task: str,
normalized_config: NormalizedTextConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
**kwargs,
):
self.batch_size = batch_size
self.context_length = sequence_length
self.hidden_size = normalized_config.hidden_size
config = normalized_config.config
self.block_size = getattr(config, "block_size", sequence_length)
dflash_config = getattr(config, "dflash_config", {})
self.num_target_layers = len(dflash_config.get("target_layer_ids", []))
self.vocab_size = getattr(config, "vocab_size", normalized_config.vocab_size)

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
if input_name == "input_ids":
return self.random_int_tensor(
[self.batch_size, self.block_size],
max_value=self.vocab_size,
framework=framework,
dtype=int_dtype,
)
if input_name == "hidden_states":
return self.random_float_tensor(
[self.batch_size, self.context_length, self.hidden_size * self.num_target_layers],
framework=framework,
dtype=float_dtype,
)
if input_name == "position_ids":
position_length = self.context_length + self.block_size
if framework == "pt":
return torch.arange(position_length, dtype=DTYPE_MAPPER.pt(int_dtype)).unsqueeze(0).repeat(
self.batch_size, 1
)
return self.random_int_tensor(
[self.batch_size, position_length],
max_value=position_length,
framework=framework,
dtype=int_dtype,
)
raise ValueError(f"Unsupported DFlash input name: {input_name}")


@register_in_tasks_manager(
"qwen3",
*[
Expand All @@ -439,8 +488,40 @@ class Qwen3OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
_MODEL_PATCHER = OVDecoderModelPatcher

def __init__(
self,
config: PretrainedConfig,
task: str = "feature-extraction",
int_dtype: str = "int64",
float_dtype: str = "fp32",
use_past: bool = False,
use_past_in_inputs: bool = False,
preprocessors: list[Any] | None = None,
):
super().__init__(
config=config,
task=task,
int_dtype=int_dtype,
float_dtype=float_dtype,
use_past=use_past,
use_past_in_inputs=use_past_in_inputs,
preprocessors=preprocessors,
)
archs = getattr(config, "architectures", None)
self.dflash = isinstance(archs, list) and len(archs) > 0 and archs[0] == "DFlashDraftModel"
if self.dflash:
self.DUMMY_INPUT_GENERATOR_CLASSES = (DFlashDummyInputGenerator, GemmaDummyPastKeyValuesGenerator)
self.MIN_TRANSFORMERS_VERSION = "4.57.0"

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
if self.dflash:
common_inputs = super().inputs
common_inputs.pop("attention_mask", None)
common_inputs["input_ids"] = {0: "batch_size", 1: "block_size"}
common_inputs["hidden_states"] = {0: "batch_size", 1: "context_length", 2: "target_hidden_size"}
common_inputs["position_ids"] = {0: "batch_size", 1: "position_sequence_length"}
return common_inputs
if self.task in ["feature-extraction"]:
common_inputs = {
"input_ids": {0: "batch_size", 1: "sequence_length"},
Expand All @@ -450,6 +531,22 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs = super().inputs
return common_inputs

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
if self.dflash:
common_outputs = super().outputs
common_outputs["logits"] = {0: "batch_size", 1: "draft_sequence_length"}
return common_outputs
return super().outputs

def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
super().add_past_key_values(inputs_or_outputs, direction)
if self.dflash and direction == "outputs":
for axes in inputs_or_outputs.values():
for axis, name in axes.items():
if name == "past_sequence_length + sequence_length":
axes[axis] = "past_sequence_length + context_length"


class DummyQwen3VLLMInputGenerator(DummyTextInputGenerator):
SUPPORTED_INPUT_NAMES = (
Expand Down
Loading