Skip to content

Commit 98dc4bb

Browse files
committed
Add MobiusBuilder support to the capture-onnx-graph CLI
1 parent af26cc0 commit 98dc4bb

2 files changed

Lines changed: 127 additions & 2 deletions

File tree

olive/cli/capture_onnx.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,16 @@ def register_subcommand(parser: ArgumentParser):
113113
action="store_true",
114114
help="Whether to use Model Builder to capture ONNX model.",
115115
)
116+
mb_group.add_argument(
117+
"--use_mobius_builder",
118+
action="store_true",
119+
help=(
120+
"Whether to use MobiusBuilder (mobius-ai) to capture ONNX model. "
121+
"Supports multi-component multimodal models (VLMs). "
122+
"Requires 'pip install mobius-ai'. "
123+
"Mutually exclusive with --use_model_builder and --use_dynamo_exporter."
124+
),
125+
)
116126
mb_group.add_argument(
117127
"--precision",
118128
type=str,
@@ -197,8 +207,14 @@ def _get_run_config(self, tempdir: str) -> dict:
197207
is_diffusers_model = input_model_config["type"].lower() == "diffusersmodel"
198208

199209
# whether model is in fp16 or bf16 (currently not supported by CPU EP)
200-
is_fp16_or_bf16 = (not self.args.use_model_builder and self.args.torch_dtype == "float16") or (
201-
self.args.use_model_builder and self.args.precision in ("fp16", "bf16")
210+
is_fp16_or_bf16 = (
211+
(
212+
not self.args.use_model_builder
213+
and not self.args.use_mobius_builder
214+
and self.args.torch_dtype == "float16"
215+
)
216+
or (self.args.use_model_builder and self.args.precision in ("fp16", "bf16"))
217+
or (self.args.use_mobius_builder and self.args.precision in ("fp16", "bf16"))
202218
)
203219
to_replace = [
204220
("input_model", input_model_config),
@@ -213,6 +229,7 @@ def _get_run_config(self, tempdir: str) -> dict:
213229

214230
if is_diffusers_model:
215231
del config["passes"]["m"]
232+
del config["passes"]["b"]
216233
to_replace.extend(
217234
[
218235
(
@@ -223,8 +240,30 @@ def _get_run_config(self, tempdir: str) -> dict:
223240
(("passes", "c", "target_opset"), self.args.target_opset),
224241
]
225242
)
243+
elif self.args.use_mobius_builder:
244+
if self.args.use_model_builder or self.args.use_dynamo_exporter:
245+
raise ValueError(
246+
"--use_mobius_builder cannot be combined with --use_model_builder or --use_dynamo_exporter."
247+
)
248+
if self.args.precision not in ("fp32", "fp16", "bf16"):
249+
raise ValueError(
250+
f"MobiusBuilder supports precisions fp32/fp16/bf16; got '{self.args.precision}'. "
251+
"For INT4, capture in fp32/fp16/bf16 first and run a quantization pass afterwards."
252+
)
253+
del config["passes"]["c"]
254+
del config["passes"]["m"]
255+
to_replace.extend(
256+
[
257+
(("passes", "b", "precision"), self.args.precision),
258+
(
259+
("passes", "b", "runtime"),
260+
"ort-genai" if self.args.use_ort_genai else "none",
261+
),
262+
]
263+
)
226264
elif self.args.use_model_builder:
227265
del config["passes"]["c"]
266+
del config["passes"]["b"]
228267
to_replace.extend(
229268
[
230269
(("passes", "m", "precision"), self.args.precision),
@@ -245,6 +284,7 @@ def _get_run_config(self, tempdir: str) -> dict:
245284
if self.args.int4_accuracy_level is not None:
246285
to_replace.append((("passes", "m", "int4_accuracy_level"), self.args.int4_accuracy_level))
247286
else:
287+
del config["passes"]["b"]
248288
to_replace.extend(
249289
[
250290
(
@@ -300,6 +340,7 @@ def _get_run_config(self, tempdir: str) -> dict:
300340
"type": "OnnxConversion",
301341
},
302342
"m": {"type": "ModelBuilder", "metadata_only": False},
343+
"b": {"type": "MobiusBuilder"},
303344
"f": {"type": "DynamicToFixedShape"},
304345
},
305346
"host": "local_system",

test/cli/test_cli.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,90 @@ def test_capture_onnx_command_fix_shape(_, mock_run, use_model_builder, tmp_path
321321
assert mock_run.call_count == 1
322322

323323

324+
@patch("olive.workflows.run")
@patch("huggingface_hub.repo_exists", return_value=True)
@pytest.mark.parametrize(
    ("precision", "use_ort_genai"),
    [
        ("fp16", True),
        ("fp32", False),
        ("bf16", True),
    ],
)
def test_capture_onnx_command_use_mobius_builder(_, mock_run, precision, use_ort_genai, tmp_path):
    """Verify the run config produced by `capture-onnx-graph --use_mobius_builder`.

    The workflow runner is mocked, so the test only inspects the config dict handed
    to it: pass "b" (MobiusBuilder) must be present with the requested precision and
    runtime, while passes "c" and "m" must have been removed.
    """
    # Arrange: build the CLI invocation, adding --use_ort_genai only when requested.
    target_dir = tmp_path / "output_dir"
    hf_model_id = "dummy-model-id"
    optional_flags = ["--use_ort_genai"] if use_ort_genai else []
    argv = [
        "capture-onnx-graph",
        "-m",
        hf_model_id,
        "-o",
        str(target_dir),
        "--use_mobius_builder",
        "--precision",
        precision,
        *optional_flags,
    ]
    expected_runtime = "ort-genai" if use_ort_genai else "none"

    # Act
    cli_main(argv)

    # Assert: the first positional argument of the mocked run() is the config dict.
    run_config = mock_run.call_args[0][0]
    assert run_config["input_model"]["model_path"] == hf_model_id

    # Only the MobiusBuilder pass ("b") may remain as a conversion pass;
    # OnnxConversion ("c") and ModelBuilder ("m") must be gone.
    passes = run_config["passes"]
    assert "b" in passes
    assert "c" not in passes
    assert "m" not in passes

    mobius_pass = passes["b"]
    assert mobius_pass["type"] == "MobiusBuilder"
    assert mobius_pass["precision"] == precision
    assert mobius_pass["runtime"] == expected_runtime
    assert mock_run.call_count == 1
364+
365+
366+
@patch("olive.workflows.run")
@patch("huggingface_hub.repo_exists", return_value=True)
def test_capture_onnx_command_use_mobius_builder_rejects_int4(_, __, tmp_path):
    """Verify that `--use_mobius_builder --precision int4` is rejected with a ValueError."""
    # Arrange: same invocation shape as the happy-path test, but with an int4 precision.
    target_dir = tmp_path / "output_dir"
    io_args = ["-m", "dummy-model-id", "-o", str(target_dir)]
    argv = ["capture-onnx-graph", *io_args, "--use_mobius_builder", "--precision", "int4"]

    # Act / Assert: the error message must name the supported precisions.
    with pytest.raises(ValueError, match="MobiusBuilder supports precisions fp32/fp16/bf16"):
        cli_main(argv)
385+
386+
387+
@patch("olive.workflows.run")
@patch("huggingface_hub.repo_exists", return_value=True)
@pytest.mark.parametrize("conflicting_flag", ["--use_model_builder", "--use_dynamo_exporter"])
def test_capture_onnx_command_use_mobius_builder_rejects_conflicts(_, __, conflicting_flag, tmp_path):
    """Verify that combining --use_mobius_builder with another exporter flag raises a ValueError."""
    # Arrange: pair --use_mobius_builder with the flag under test.
    target_dir = tmp_path / "output_dir"
    io_args = ["-m", "dummy-model-id", "-o", str(target_dir)]
    argv = ["capture-onnx-graph", *io_args, "--use_mobius_builder", conflicting_flag]

    # Act / Assert
    with pytest.raises(ValueError, match="cannot be combined"):
        cli_main(argv)
406+
407+
324408
@patch("olive.cli.shared_cache.AzureContainerClientFactory")
325409
def test_shared_cache_command(mock_AzureContainerClientFactory):
326410
# setup

0 commit comments

Comments
 (0)