Skip to content

Commit 5d2c2fe

Browse files
authored
Fix inc quantization bugs (microsoft#2174)
## Describe your changes - Unify the logic when bits input is either int or enum. - Use IntEnum by default since IntEnum is available since python 3.4 ## Checklist before requesting a review - [ ] Add unit tests for this change. - [ ] Make sure all tests can pass. - [ ] Update documents if necessary. - [ ] Lint and apply fixes to your code by running `lintrunner -a` - [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes. ## (Optional) Issue link
1 parent fa04308 commit 5d2c2fe

5 files changed

Lines changed: 15 additions & 18 deletions

File tree

olive/cli/capture_onnx.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import argparse
66
from argparse import ArgumentParser
77
from copy import deepcopy
8+
from enum import IntEnum
89

910
from olive.cli.base import (
1011
BaseOliveCLICommand,
@@ -15,10 +16,10 @@
1516
get_input_model_config,
1617
update_shared_cache_options,
1718
)
18-
from olive.common.utils import IntEnumBase, set_nested_dict_value
19+
from olive.common.utils import set_nested_dict_value
1920

2021

21-
class ModelBuilderAccuracyLevel(IntEnumBase):
22+
class ModelBuilderAccuracyLevel(IntEnum):
2223
fp32 = 1
2324
fp16 = 2
2425
bf16 = 3

olive/common/utils.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,24 +25,18 @@
2525

2626

2727
if sys.version_info >= (3, 11):
28-
from enum import IntEnum, StrEnum
28+
from enum import StrEnum
2929

3030
class StrEnumBase(StrEnum):
3131
pass
3232

33-
class IntEnumBase(IntEnum):
34-
pass
35-
3633
else:
3734
from enum import Enum
3835

3936
class StrEnumBase(str, Enum):
4037
def __str__(self) -> str:
4138
return self.value
4239

43-
class IntEnumBase(int, Enum):
44-
pass
45-
4640

4741
def run_subprocess(cmd, env=None, cwd=None, check=False):
4842
logger.debug("Running command: %s", cmd)

olive/constants.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# Licensed under the MIT License.
44
# --------------------------------------------------------------------------
5+
from enum import IntEnum
6+
57
from olive.common.config_utils import CaseInsensitiveEnum
6-
from olive.common.utils import IntEnumBase, StrEnumBase
8+
from olive.common.utils import StrEnumBase
79

810
MSFT_DOMAIN = "com.microsoft"
911

@@ -52,7 +54,7 @@ class Precision(StrEnumBase):
5254
BF16 = "bf16"
5355

5456

55-
class PrecisionBits(IntEnumBase):
57+
class PrecisionBits(IntEnum):
5658
BITS4 = 4
5759
BITS8 = 8
5860
BITS16 = 16
@@ -103,7 +105,7 @@ class OpType(StrEnumBase):
103105
Constant = "Constant"
104106

105107

106-
class AccuracyLevel(IntEnumBase):
108+
class AccuracyLevel(IntEnum):
107109
unset = 0
108110
fp32 = 1
109111
fp16 = 2

olive/passes/onnx/inc_quantization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -440,10 +440,10 @@ def _set_tuning_config(self, run_config):
440440
def _set_woq_config(self, run_config):
441441
# set weight only quantization config for INC API
442442
weight_only_config = run_config["weight_only_config"]
443-
bits = weight_only_config.get("bits", PrecisionBits.BITS4).value
443+
bits = int(weight_only_config.get("bits") or PrecisionBits.BITS4)
444444
group_size = weight_only_config.get("group_size", 32)
445445
scheme = weight_only_config.get("scheme", "asym")
446-
algo = (weight_only_config.get("algorithm") or QuantAlgorithm.RTN).value.upper()
446+
algo = (weight_only_config.get("algorithm") or QuantAlgorithm.RTN.value).upper()
447447
return {"bits": bits, "group_size": group_size, "scheme": scheme, "algorithm": algo}
448448

449449
def _run_for_config(

olive/passes/onnx/model_builder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77
import copy
88
import json
99
import logging
10+
from enum import IntEnum
1011
from pathlib import Path
1112
from typing import Any, ClassVar, Union
1213

1314
import onnx
1415
import transformers
1516
from packaging import version
1617

17-
from olive.common.utils import IntEnumBase
1818
from olive.constants import Precision
1919
from olive.hardware.accelerator import AcceleratorSpec, Device
2020
from olive.hardware.constants import ExecutionProvider
@@ -33,14 +33,14 @@ class ModelBuilder(Pass):
3333
See https://github.com/microsoft/onnxruntime-genai
3434
"""
3535

36-
class BlockSize(IntEnumBase):
36+
class BlockSize(IntEnum):
3737
B16 = 16
3838
B32 = 32
3939
B64 = 64
4040
B128 = 128
4141
B256 = 256
4242

43-
class AccuracyLevel(IntEnumBase):
43+
class AccuracyLevel(IntEnum):
4444
fp32 = 1
4545
fp16 = 2
4646
bf16 = 3
@@ -239,7 +239,7 @@ def _run_for_config(
239239

240240
extra_args.update(
241241
{
242-
key: value.value if isinstance(value, IntEnumBase) else value
242+
key: value.value if isinstance(value, IntEnum) else value
243243
for key, value in config.dict().items()
244244
if value is not None and key not in {"precision", "metadata_only", "search", "extra_options"}
245245
}

0 commit comments

Comments
 (0)