Skip to content

Commit 7421f11

Browse files
Enable AIMET in olive quantize CLI (microsoft#2187)
## Describe your changes

Adds support for AIMET algorithms to the quantize CLI.

## Checklist before requesting a review

- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [ ] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes.

## (Optional) Issue link

Signed-off-by: Michael Tuttle <mtuttle@qti.qualcomm.com>
1 parent e3ff856 commit 7421f11

4 files changed

Lines changed: 24 additions & 6 deletions

File tree

olive/cli/quantize.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@ class ImplName(StrEnumBase):
3636
QUAROT = "quarot"
3737
AWQ = "awq"
3838
AUTOGPTQ = "autogptq"
39+
AIMET = "aimet"
3940

4041

4142
class QuantizeCommand(BaseOliveCLICommand):
@@ -165,6 +166,12 @@ def _get_passes_dict(self, pass_list):
165166
"bits": precision_bits_from_precision(self.args.precision),
166167
},
167168
"MatMulNBitsToQDQ": {},
169+
"AimetQuantization": {
170+
"precision": self.args.precision,
171+
"activation_type": self.args.act_precision,
172+
"data_config": "default_data_config",
173+
"techniques": [{"name": self.args.algorithm}],
174+
},
168175
}
169176

170177
passes_dict = {}
@@ -251,4 +258,5 @@ def run(self):
251258
{"impl_name": ImplName.OLIVE, "pass_type": "OnnxHqqQuantization"},
252259
{"impl_name": ImplName.OLIVE, "pass_type": "OnnxBlockWiseRtnQuantization"},
253260
{"impl_name": ImplName.INC, "pass_type": "IncStaticQuantization"},
261+
{"impl_name": ImplName.AIMET, "pass_type": "AimetQuantization"},
254262
]

olive/constants.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -66,6 +66,9 @@ class QuantAlgorithm(CaseInsensitiveEnum):
6666
RTN = "rtn"
6767
SPINQUANT = "spinquant"
6868
QUAROT = "quarot"
69+
LPBQ = "lpbq"
70+
SEQMSE = "seqmse"
71+
ADAROUND = "adaround"
6972

7073

7174
class QuantEncoding(StrEnumBase):

olive/olive_config.json

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -13,8 +13,8 @@
1313
"supported_providers": [ "*" ],
1414
"supported_accelerators": [ "*" ],
1515
"supported_precisions": [ "int4", "int8", "int16" ],
16-
"supported_algorithms": [ ],
17-
"supported_quantization_encodings": [ ],
16+
"supported_algorithms": [ "lpbq", "seqmse", "adaround" ],
17+
"supported_quantization_encodings": [ "qdq" ],
1818
"extra_dependencies": [ "aimet-onnx" ],
1919
"dataset": "dataset_required"
2020
},

test/cli/test_cli.py

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -324,35 +324,42 @@ def test_shared_cache_delete_all_with_confirmation(mock_AzureContainerClientFact
324324
mock_factory_instance.delete_all.assert_called_once()
325325

326326

327-
@pytest.mark.parametrize("algorithm_name", ["awq", "gptq"])
327+
@pytest.mark.parametrize("algorithm_name", ["awq", "gptq", "lpbq", "seqmse", "adaround"])
328328
@patch("olive.workflows.run")
329329
@patch("huggingface_hub.repo_exists")
330330
def test_quantize_command(mock_repo_exists, mock_run, algorithm_name, tmp_path):
331+
from test.utils import ONNX_MODEL_PATH
332+
331333
# setup
332334
output_dir = tmp_path / "output_dir"
333335

334336
# setup
335337
command_args = [
336338
"quantize",
337-
"-m",
338-
"dummy_model",
339339
"--algorithm",
340340
algorithm_name,
341341
"-o",
342342
str(output_dir),
343343
]
344344

345+
model_name = "dummy_model"
345346
if algorithm_name == "gptq":
346347
command_args += ["-d", "dummy_dataset"]
347348
command_args += ["--implementation", "autogptq"]
348349
if algorithm_name == "awq":
349350
command_args += ["--implementation", "awq"]
351+
if algorithm_name in {"lpbq", "seqmse", "adaround"}:
352+
model_name = str(ONNX_MODEL_PATH)
353+
command_args += ["-d", "dummy_dataset"]
354+
command_args += ["--implementation", "aimet"]
355+
356+
command_args += ["-m", model_name]
350357

351358
# execute
352359
cli_main(command_args)
353360

354361
config = mock_run.call_args[0][0]
355-
assert config["input_model"]["model_path"] == "dummy_model"
362+
assert config["input_model"]["model_path"] == model_name
356363
assert mock_run.call_count == 1
357364

358365

0 commit comments

Comments
 (0)