|
2 | 2 | # Copyright (c) Intel Corporation. All rights reserved. |
3 | 3 | # Licensed under the MIT License. |
4 | 4 | # -------------------------------------------------------------------------- |
| 5 | +import logging |
| 6 | +import numbers |
5 | 7 | import os |
| 8 | +from collections.abc import Mapping, MutableMapping |
6 | 9 | from pathlib import Path |
7 | | -from typing import ClassVar, Union |
| 10 | +from typing import Any, ClassVar, Union |
8 | 11 |
|
9 | 12 | import onnx.helper as helper |
10 | 13 | from onnx import TensorProto, save |
|
15 | 18 | from olive.passes import Pass |
16 | 19 | from olive.passes.pass_config import BasePassConfig, PassConfigParam |
17 | 20 |
|
| 21 | +logger = logging.getLogger(__name__) |
| 22 | + |
18 | 23 |
|
19 | 24 | class OpenVINOEncapsulation(Pass): |
20 | 25 | """Encapsulates OpenVINO models with onnx context nodes.""" |
@@ -97,6 +102,12 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon |
97 | 102 | required=False, |
98 | 103 | description=("Reuse cache of previous passes to reduce storage footprint."), |
99 | 104 | ), |
| 105 | + "genai_config_override": PassConfigParam( |
| 106 | + type_=dict, |
| 107 | + default_value=None, |
| 108 | + required=False, |
| 109 | + description=("Configuration overrides for genai_config.json generation. "), |
| 110 | + ), |
100 | 111 | } |
101 | 112 |
|
102 | 113 | def _run_for_config( |
@@ -252,6 +263,46 @@ def extract_shape_list(shape, config, prefix: str = "input_0_") -> list: |
252 | 263 | return shape_list |
253 | 264 |
|
254 | 265 |
|
| 266 | +def _compatible_type(default_val: Any, new_val: Any) -> bool: |
| 267 | + """Loose type check: allow ints for floats, bool as bool, etc.""" |
| 268 | + if default_val is None: |
| 269 | + return True |
| 270 | + if isinstance(default_val, bool): |
| 271 | + return isinstance(new_val, bool) |
| 272 | + if isinstance(default_val, numbers.Real) and not isinstance(default_val, bool): |
| 273 | + return isinstance(new_val, numbers.Real) and not isinstance(new_val, bool) |
| 274 | + if isinstance(default_val, str): |
| 275 | + return isinstance(new_val, str) |
| 276 | + if isinstance(default_val, (list, tuple)): |
| 277 | + return isinstance(new_val, (list, tuple)) |
| 278 | + if isinstance(default_val, Mapping): |
| 279 | + return isinstance(new_val, Mapping) |
| 280 | + return True # fall back to permissive |
| 281 | + |
| 282 | + |
def apply_genai_overrides(
    defaults: MutableMapping[str, Any], overrides: Mapping[str, Any], *, path: str = ""
) -> MutableMapping[str, Any]:
    """Recursively merge `overrides` into `defaults`, modifying `defaults` in place.

    Only keys that already exist in `defaults` are overridden. Override keys that
    are not present are skipped, and a warning naming the dotted key path is
    logged so that typos do not pass unnoticed. When both the existing value and
    the override are mappings, the merge descends into them; any other value
    (scalar, list, tuple) is replaced wholesale. A value whose type family differs
    from the existing one triggers a warning but is still applied.

    Args:
        defaults: Mapping to update. It is mutated and also returned.
        overrides: Values to merge into `defaults`.
        path: Dotted path of the current nesting level, used only in log messages.

    Returns:
        The same `defaults` object, after the merge.
    """
    for key, new_val in overrides.items():
        here = f"{path}.{key}" if path else key
        if key not in defaults:
            # Unknown keys are never added; report them instead of dropping silently.
            logger.warning("Ignoring genai_config override for unknown key %s", here)
            continue

        default_val = defaults[key]

        # Recurse for dicts so sibling keys not mentioned in the override survive.
        if isinstance(default_val, Mapping) and isinstance(new_val, Mapping):
            apply_genai_overrides(default_val, new_val, path=here)
            continue

        # Replace lists/tuples and scalars; a mismatch is reported but not blocked.
        if not _compatible_type(default_val, new_val):
            logger.warning(
                "Type mismatch at %s: expected %s, got %s",
                here,
                type(default_val).__name__,
                type(new_val).__name__,
            )
        defaults[key] = new_val
    return defaults
| 304 | + |
| 305 | + |
255 | 306 | def create_genai_config(model_name: str, output_path: str, config: type[BasePassConfig]) -> None: |
256 | 307 | """Generate the genai_config.json from the model config files. |
257 | 308 |
|
@@ -371,6 +422,9 @@ def create_genai_config(model_name: str, output_path: str, config: type[BasePass |
371 | 422 |
|
372 | 423 | genai_config["search"]["max_length"] = src_config.get("max_position_embeddings", -1) |
373 | 424 |
|
| 425 | + if isinstance(config.genai_config_override, dict): |
| 426 | + apply_genai_overrides(genai_config, config.genai_config_override) |
| 427 | + |
374 | 428 | # Step 2: Write to JSON file |
375 | 429 | output_genai_config = Path(output_path) / "genai_config.json" |
376 | 430 | with open(output_genai_config, "w") as f: |
|
0 commit comments