Skip to content

Commit 8caee1c

Browse files
jambayk and Copilot authored
Fix optimize CLI to set system EP and device (microsoft#2418)
## Describe your changes Fix the `olive optimize` CLI to properly configure the system execution provider and device in the generated workflow config. - `_update_system_config` now creates the `local_system` with the specified execution provider and optional device. Previously only the QNN AOT case was handled, leaving the system config empty. - When model builder is used as the exporter, the `OnnxFloatToFloat16` pass is skipped since model builder already produces the model in fp16. - `test_optimize_cli_pass_list` now verifies the system EP and device are correctly set in the generated config for all test cases. ## Checklist before requesting a review - [x] Add unit tests for this change. - [x] Make sure all tests can pass. - [ ] Update documents if necessary. - [x] Lint and apply fixes to your code by running `lintrunner -a` - [x] Is this a user-facing change? If yes, give a description of this change to be included in the release notes. - `olive optimize` now correctly sets the target system execution provider and device. ## (Optional) Issue link --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 40d860e commit 8caee1c

2 files changed

Lines changed: 51 additions & 1 deletion

File tree

olive/cli/optimize.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,19 @@ def _update_system_config(self, config: dict[str, Any]):
294294
"""Update system configuration based on provider and device."""
295295
provider = ExecutionProvider(self.args.provider)
296296

297+
accelerator = {"execution_providers": [provider.value]}
298+
if self.args.device:
299+
accelerator["device"] = self.args.device
300+
if self.args.memory is not None:
301+
accelerator["memory"] = self.args.memory
302+
303+
config["systems"]["local_system"] = {
304+
"type": "LocalSystem",
305+
"accelerators": [accelerator],
306+
}
307+
308+
config["target"] = "local_system"
309+
297310
if provider == ExecutionProvider.QNNExecutionProvider and self.args.enable_aot:
298311
config["systems"]["qnn_system"] = {
299312
"type": "PythonEnvironment",
@@ -622,7 +635,7 @@ def _get_onnx_blockwise_rtn_quantization_pass_config(self) -> dict[str, Any]:
622635
def _enable_onnx_float_to_float16_pass(self) -> bool:
623636
"""Return true if condition to add OnnxFloatToFloat16 pass is met."""
624637
precision = Precision(self.args.precision)
625-
return precision == Precision.FP16
638+
return precision == Precision.FP16 and not self.enable_model_builder
626639

627640
def _get_onnx_float_to_float16_pass_config(self) -> dict[str, Any]:
628641
"""Return pass dictionary for OnnxFloatToFloat16 pass."""

test/cli/test_cli.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,8 @@ def test_optimize_cli_pass_list(mock_repo_exists, mock_run, tmp_path):
493493
# setup
494494
output_dir = "output_dir"
495495

496+
# Each entry: [command, args, expected_passes, expected_device, expected_ep]
497+
# expected_device is None when --device is not specified (olive infers it at runtime)
496498
test_list = [
497499
[
498500
"optimize",
@@ -504,6 +506,8 @@ def test_optimize_cli_pass_list(mock_repo_exists, mock_run, tmp_path):
504506
"QuaRot, Gptq, CaptureSplitInfo, ModelBuilder, MatMulNBitsToQDQ, GraphSurgeries, "
505507
"OnnxStaticQuantization, SplitModel, StaticLLM"
506508
),
509+
None,
510+
"QNNExecutionProvider",
507511
],
508512
[
509513
"optimize",
@@ -515,16 +519,22 @@ def test_optimize_cli_pass_list(mock_repo_exists, mock_run, tmp_path):
515519
"QuaRot, Gptq, CaptureSplitInfo, ModelBuilder, MatMulNBitsToQDQ, GraphSurgeries, "
516520
"OnnxStaticQuantization, VitisAIAddMetaData, SplitModel, StaticLLM"
517521
),
522+
None,
523+
"VitisAIExecutionProvider",
518524
],
519525
[
520526
"optimize",
521527
"--precision int4 --act_precision int16 --provider OpenVINOExecutionProvider --device gpu",
522528
"OpenVINOOptimumConversion, OpenVINOIoUpdate, OpenVINOEncapsulation",
529+
"gpu",
530+
"OpenVINOExecutionProvider",
523531
],
524532
[
525533
"optimize",
526534
"-t text-classification --precision int8 --exporter torchscript_exporter",
527535
"OnnxConversion, OnnxPeepholeOptimizer, OrtTransformersOptimization, OnnxStaticQuantization",
536+
None,
537+
"CPUExecutionProvider",
528538
],
529539
[
530540
"optimize",
@@ -536,11 +546,22 @@ def test_optimize_cli_pass_list(mock_repo_exists, mock_run, tmp_path):
536546
"OnnxConversion, DynamicToFixedShape, OnnxPeepholeOptimizer, OrtTransformersOptimization, "
537547
"OnnxStaticQuantization, StaticLLM"
538548
),
549+
"npu",
550+
"QNNExecutionProvider",
539551
],
540552
[
541553
"optimize",
542554
"-t text-classification --precision fp16 --exporter torchscript_exporter --provider CUDAExecutionProvider",
543555
"OnnxConversion, OnnxPeepholeOptimizer, OrtTransformersOptimization, OnnxFloatToFloat16",
556+
None,
557+
"CUDAExecutionProvider",
558+
],
559+
[
560+
"optimize",
561+
"--precision fp16 --provider CUDAExecutionProvider",
562+
"ModelBuilder",
563+
None,
564+
"CUDAExecutionProvider",
544565
],
545566
[
546567
"optimize",
@@ -549,6 +570,8 @@ def test_optimize_cli_pass_list(mock_repo_exists, mock_run, tmp_path):
549570
" NvTensorRTRTXExecutionProvider --device gpu"
550571
),
551572
"OnnxConversion, OnnxPeepholeOptimizer, OnnxFloatToFloat16",
573+
"gpu",
574+
"NvTensorRTRTXExecutionProvider",
552575
],
553576
]
554577

@@ -577,6 +600,20 @@ def test_optimize_cli_pass_list(mock_repo_exists, mock_run, tmp_path):
577600

578601
assert pass_list == [item.strip() for item in t[2].split(",")]
579602

603+
# Verify system config has correct device and execution provider
604+
accelerator = data["systems"]["local_system"]["accelerators"][0]
605+
expected_device = t[3]
606+
expected_ep = t[4]
607+
if expected_device is None:
608+
assert "device" not in accelerator, f"Expected no device but got '{accelerator.get('device')}'"
609+
else:
610+
assert accelerator["device"] == expected_device, (
611+
f"Expected device '{expected_device}' but got '{accelerator.get('device')}'"
612+
)
613+
assert accelerator["execution_providers"] == [expected_ep], (
614+
f"Expected EP '{expected_ep}' but got '{accelerator['execution_providers']}'"
615+
)
616+
580617

581618
@patch("olive.workflows.run")
582619
@patch("huggingface_hub.repo_exists", return_value=True)

0 commit comments

Comments (0)