Skip to content

Commit f9537b1

Browse files
authored
Add option to override genai_config.json property OV models (microsoft#2191)
## Describe your changes Adds capability in OV encapsulation pass to parse optional configs to use as override in the resulting genai_config.json ## Checklist before requesting a review - [ ] Add unit tests for this change. - [ ] Make sure all tests can pass. - [ ] Update documents if necessary. - [ ] Lint and apply fixes to your code by running `lintrunner -a` - [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes. ## (Optional) Issue link
1 parent 7421f11 commit f9537b1

1 file changed

Lines changed: 55 additions & 1 deletion

File tree

olive/passes/openvino/encapsulation.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@
22
# Copyright (c) Intel Corporation. All rights reserved.
33
# Licensed under the MIT License.
44
# --------------------------------------------------------------------------
5+
import logging
6+
import numbers
57
import os
8+
from collections.abc import Mapping, MutableMapping
69
from pathlib import Path
7-
from typing import ClassVar, Union
10+
from typing import Any, ClassVar, Union
811

912
import onnx.helper as helper
1013
from onnx import TensorProto, save
@@ -15,6 +18,8 @@
1518
from olive.passes import Pass
1619
from olive.passes.pass_config import BasePassConfig, PassConfigParam
1720

21+
logger = logging.getLogger(__name__)
22+
1823

1924
class OpenVINOEncapsulation(Pass):
2025
"""Encapsulates OpenVINO models with onnx context nodes."""
@@ -97,6 +102,12 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon
97102
required=False,
98103
description=("Reuse cache of previous passes to reduce storage footprint."),
99104
),
105+
"genai_config_override": PassConfigParam(
106+
type_=dict,
107+
default_value=None,
108+
required=False,
109+
description=("Configuration overrides for genai_config.json generation. "),
110+
),
100111
}
101112

102113
def _run_for_config(
@@ -252,6 +263,46 @@ def extract_shape_list(shape, config, prefix: str = "input_0_") -> list:
252263
return shape_list
253264

254265

266+
def _compatible_type(default_val: Any, new_val: Any) -> bool:
267+
"""Loose type check: allow ints for floats, bool as bool, etc."""
268+
if default_val is None:
269+
return True
270+
if isinstance(default_val, bool):
271+
return isinstance(new_val, bool)
272+
if isinstance(default_val, numbers.Real) and not isinstance(default_val, bool):
273+
return isinstance(new_val, numbers.Real) and not isinstance(new_val, bool)
274+
if isinstance(default_val, str):
275+
return isinstance(new_val, str)
276+
if isinstance(default_val, (list, tuple)):
277+
return isinstance(new_val, (list, tuple))
278+
if isinstance(default_val, Mapping):
279+
return isinstance(new_val, Mapping)
280+
return True # fall back to permissive
281+
282+
283+
def apply_genai_overrides(
284+
defaults: MutableMapping[str, Any], overrides: Mapping[str, Any], *, path: str = ""
285+
) -> MutableMapping[str, Any]:
286+
"""Recursively merge `overrides` into `defaults`."""
287+
for k, v in overrides.items():
288+
here = f"{path}.{k}" if path else k
289+
if k not in defaults:
290+
continue
291+
292+
dv = defaults[k]
293+
294+
# Recurse for dicts
295+
if isinstance(dv, Mapping) and isinstance(v, Mapping):
296+
apply_genai_overrides(dv, v, path=here)
297+
continue
298+
299+
# Replace lists/tuples and scalars
300+
if not _compatible_type(dv, v):
301+
logger.warning("Type mismatch at %s", here)
302+
defaults[k] = v
303+
return defaults
304+
305+
255306
def create_genai_config(model_name: str, output_path: str, config: type[BasePassConfig]) -> None:
256307
"""Generate the genai_config.json from the model config files.
257308
@@ -371,6 +422,9 @@ def create_genai_config(model_name: str, output_path: str, config: type[BasePass
371422

372423
genai_config["search"]["max_length"] = src_config.get("max_position_embeddings", -1)
373424

425+
if isinstance(config.genai_config_override, dict):
426+
apply_genai_overrides(genai_config, config.genai_config_override)
427+
374428
# Step 2: Write to JSON file
375429
output_genai_config = Path(output_path) / "genai_config.json"
376430
with open(output_genai_config, "w") as f:

0 commit comments

Comments
 (0)