Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions olive/passes/qairt/encapsulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,21 @@
MAX_GENIE_CONTEXT_LENGTH = 4096


def _deep_merge(base: dict, overrides: dict) -> dict:
"""Recursively merge *overrides* into *base*, returning a new dict.

Nested dicts are merged rather than replaced, so only the keys present in
*overrides* are changed; all other keys from *base* are preserved.
"""
result = dict(base)
for k, v in overrides.items():
if k in result and isinstance(result[k], dict) and isinstance(v, dict):
result[k] = _deep_merge(result[k], v)
else:
result[k] = v
return result


class QairtEncapsulation(Pass):
"""Encapsulates a QAIRT DLC model with an onnx protobuf."""

Expand All @@ -49,6 +64,34 @@
required=False,
description="Opset name and version to be added in the generated context model",
),
"genie_overrides": PassConfigParam(
type_=dict,
default_value=None,
required=False,
description=(
"Deep-merged into the GenAIConfig before the Genie DLC is produced. "
"Use Python field names (underscores). Nested dicts are merged recursively — "
"only the specified keys are overridden; all other GenAIBuilder defaults are "
"preserved. Any field on GenAIConfig is valid: kv_dim, rope_theta, n_heads, "
"n_layer, n_embd, allow_async_init, enable_graph_switching, "
"positional_encoding (nested dict), etc. Note: top-level rope_theta and "
"rope_scaling are not forwarded by the Genie factory — use "
"positional_encoding.rope_theta to override RoPE theta in the DLC."
),
),
"backend_extensions_override": PassConfigParam(
type_=dict,
default_value=None,
required=False,
description=(
"Deep-merged into the backend extensions config before the Genie DLC is "
"produced. Use the raw JSON key names (hyphens) as they appear in "
"backend_extensions.json. Nested dicts are merged recursively — only the "
"specified keys are overridden; all other backend extension defaults set "
"by the builder are preserved. If the container has no existing backend "
"extensions config, the override is used as the entire config."
),
),
}

def _run_for_config(
Expand All @@ -64,10 +107,10 @@
raise ImportError(
"Failed to import QAIRT GenAIBuilder API - please install olive-ai[qairt] to use QAIRT passes."
"If already installed, please run `qairt-vm -i` for help troubleshooting issues."
) from exc

Check warning

Code scanning / lintrunner

PYLINT/W0212 Warning

Access to a protected member _gen_ai_config of a client class (protected-access)
See protected-access.

from qairt import __sdk_version__ as sdk_version

Check warning

Code scanning / lintrunner

PYLINT/W0212 Warning

Access to a protected member _gen_ai_config of a client class (protected-access)
See protected-access.
if Version(sdk_version) < Version("2.45.0"):
raise OSError("QairtGenAIBuilder pass is unsupported for QAIRT versions < 2.45.0")

Expand All @@ -76,6 +119,19 @@

container: qairt_genai.LLMContainer = qairt_genai.LLMContainer.load(model.model_path)

if config.genie_overrides:
gen_ai_cfg = container._gen_ai_config
current = gen_ai_cfg.model_dump(mode="json", by_alias=False, exclude_none=True)
merged = _deep_merge(current, config.genie_overrides)
container._gen_ai_config = gen_ai_cfg.model_validate(merged)
logger.info("Applied genie_overrides to GenAIConfig: %s", list(config.genie_overrides.keys()))

if config.backend_extensions_override:
container._backend_extensions_config = _deep_merge(

Check warning

Code scanning / lintrunner

PYLINT/W0212 Warning

Access to a protected member _backend_extensions_config of a client class (protected-access)
See protected-access.
container._backend_extensions_config or {}, config.backend_extensions_override

Check warning

Code scanning / lintrunner

PYLINT/W0212 Warning

Access to a protected member _backend_extensions_config of a client class (protected-access)
See protected-access.
)
logger.info("Applied backend_extensions_override: %s", list(config.backend_extensions_override.keys()))

# Input/Output metadata
container.inputs = [("input_ids", TensorProto.INT32, ["batch_size", "sequence_length"])]
container.outputs = [("logits", TensorProto.FLOAT, ["batch_size", 1, "vocab_size"])]
Expand Down
Loading
Loading