Merged

48 commits
3778355
feat: Added fms_mo/quant_refactor files
BrandonGroth Dec 3, 2024
dde71a2
feat: Added quant_refactor tests
BrandonGroth Dec 3, 2024
0a4f88a
fix: Added log file to ignore list
BrandonGroth Dec 3, 2024
beee057
fix: Reverted diff == 0.0 to index lookup for per_channel/group
BrandonGroth Jan 2, 2025
0a55afc
fix: Changed plural scale/zp to singular to match per_tensor
BrandonGroth Jan 2, 2025
c66b5a2
fix: Added perTok qunit and separated Nch, Ngrp to be their own variables
BrandonGroth Jan 3, 2025
c0aa6ed
fix: Changed sqQscheme to Qscheme
BrandonGroth Jan 6, 2025
2466ba4
chore: Change return hints to List[] from typing
BrandonGroth Jan 10, 2025
722e413
fix: Change List to Tuple and move default func values outside init
BrandonGroth Jan 14, 2025
dda7b3e
fix: Removed Base from Quantizer, PerTensorSTE and renamed base files
BrandonGroth Feb 13, 2025
a1c2851
fix: Fixed unknown option value for clip_vals, qscheme
BrandonGroth Feb 13, 2025
6b7b738
fix: Added _ste to per_tensor, per_channel
BrandonGroth Feb 19, 2025
689c887
fix: Fixed perCh usage in Quantizer
BrandonGroth Feb 19, 2025
9695bfe
fix: Fixed qmaxnew fixture qscheme being commented out
BrandonGroth Feb 19, 2025
ff5c375
fix: Fixed logger.error string formatting in quantizer_error
BrandonGroth Feb 19, 2025
dd1d8f7
fix: Minor changes to quantizer, qmax, sawb
BrandonGroth Feb 19, 2025
3bd9401
fix: Fixed perCh path in sawb_params_code
BrandonGroth Feb 27, 2025
fe8cd9d
feat: Added per_channel_axis
BrandonGroth Feb 27, 2025
4bc34f3
feat: Added NperGrp, axis to Qscheme
BrandonGroth Feb 27, 2025
e0d1a0a
feat: Added per_channel_axis support to per_channel_ste
BrandonGroth Feb 27, 2025
603f060
fix: Commented SAWBplusZeroperCh_new forward
BrandonGroth Feb 27, 2025
b17a25e
feat: Added perCh fixtures for sawb to conftest
BrandonGroth Mar 6, 2025
b77d0bf
fix: quantizer_error perCh changes
BrandonGroth Mar 6, 2025
3e8dbce
fix: torch_quantizer perCh changes
BrandonGroth Mar 6, 2025
bfadd63
fix: test_sawb perCh update
BrandonGroth Mar 6, 2025
33a899c
fix: num_bit_int fix for SAWBPlusZeroPerChSTE
BrandonGroth Mar 6, 2025
f02a5d6
fix: Fixed sawb_utils perCh tensor creation
BrandonGroth Mar 6, 2025
277228b
fix: Added sawb_utils neg clip_val to abs.max
BrandonGroth Mar 14, 2025
d141bcd
fix: Updates to Qscheme for perCh
BrandonGroth Mar 6, 2025
d66340c
fix: Fixed qint_bounds zero_point == 0 for perCh
BrandonGroth Mar 13, 2025
5c9c49d
fix: Removed sawb_new 16bin quantizers
BrandonGroth Mar 14, 2025
87f6907
feat: Added axis to qscheme string
BrandonGroth Mar 18, 2025
86a82ee
fix: Fixed quantizer_new perCh and updated negative clip_vals from sa…
BrandonGroth Mar 13, 2025
49ae7d9
fix: Updates for test_sawb perCh
BrandonGroth Mar 13, 2025
a852ca8
feat: Added SAWB perCh PTnative and removed 16 bins
BrandonGroth Mar 18, 2025
389a427
feat: Enabled PTnative SAWB perCh test
BrandonGroth Mar 18, 2025
99fe7c7
fix: Minor changes to SAWB perCh
BrandonGroth Mar 19, 2025
7c87e01
chore: formatted quant_refactor
BrandonGroth May 30, 2025
5d5cbe2
fix: Changed default clips to be 1-d tensors
BrandonGroth Jun 4, 2025
55c10db
fix: added zero_point in [0,nlvl] assert in asymmetric_linear_quanitz…
BrandonGroth Jul 23, 2025
9dbd861
test: Added Qmax perCh testing
BrandonGroth Jul 23, 2025
d2bb34c
feat: Added Base PerChannelSTEQmax and PTnative classes
BrandonGroth Jul 23, 2025
63396cf
feat: Implemented Base PerChannelSTEQmax classes
BrandonGroth Jul 23, 2025
454391f
feat: Implemented dequant for Qmax original perCh STEs
BrandonGroth Jul 23, 2025
fbf94e0
fix: Added recompute_clips var to Qmax for eval mode
BrandonGroth Jul 24, 2025
0173209
fix: QminmaxPerCh_PTnative patch for unstable rounding in torch.quant…
BrandonGroth Jul 29, 2025
aebedb9
fix: Replaced FloatTensor,IntTensor w/ Tensor
BrandonGroth Jul 29, 2025
759dab5
fix: Renamed refactor _new files/funcs to _rc
BrandonGroth Jul 29, 2025
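Many of the commits above refactor straight-through-estimator (STE) quantizers (e.g., `per_tensor_ste`, `PerChannelSTEQmax`). For readers new to the pattern, here is a minimal, hypothetical sketch of a per-tensor STE fake-quantizer; the class and argument names are illustrative and are not the PR's actual implementations.

```python
import torch


class PerTensorSTESketch(torch.autograd.Function):
    """Fake-quantize per tensor; pass gradients straight through the rounding."""

    @staticmethod
    def forward(ctx, x, clip_low, clip_high, num_bits):
        nlevels = 2**num_bits - 1
        scale = (clip_high - clip_low) / nlevels
        # remember which elements were inside the clip range for backward
        ctx.save_for_backward((x >= clip_low) & (x <= clip_high))
        x_c = x.clamp(clip_low, clip_high)
        # round onto the integer grid, then dequantize back to float
        return torch.round((x_c - clip_low) / scale) * scale + clip_low

    @staticmethod
    def backward(ctx, grad_output):
        (in_range,) = ctx.saved_tensors
        # straight-through: identity gradient inside the clip range, zero outside
        return grad_output * in_range, None, None, None


x = torch.randn(8, requires_grad=True)
y = PerTensorSTESketch.apply(x, -1.0, 1.0, 4)
y.sum().backward()  # x.grad is 1 where |x| <= 1, else 0
```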
3 changes: 3 additions & 0 deletions .gitignore
@@ -17,6 +17,9 @@ qcfg.json
 configs
 pytest.out
 
+# Log file
+fms_mo.log
+
 # IDEs
 .vscode/
 .idea/
17 changes: 12 additions & 5 deletions fms_mo/quant/quantizers.py
@@ -510,9 +510,12 @@ def forward(
 
         if istraining:
             # only recalc clipvals under training mode
+            num_bits_int = (
+                num_bits.item() if isinstance(num_bits, torch.Tensor) else num_bits
+            )
             SAWBcode_mapping = {8: 803, 4: 403, 2: 103}
             if num_bits in [2, 4, 8]:
-                sawb_code = SAWBcode_mapping[num_bits]
+                sawb_code = SAWBcode_mapping[num_bits_int]
                 clip_val, _ = sawb_params_code(
                     num_bits, sawb_code, input_tensor, perCh=True
                 )
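One subtlety the `num_bits_int` change above addresses (a sketch of the assumed failure mode, not code from this PR): a 0-d `torch.Tensor` hashes by identity, so it never matches an `int` key in a plain dict.

```python
import torch

SAWBcode_mapping = {8: 803, 4: 403, 2: 103}
num_bits = torch.tensor(4)

# SAWBcode_mapping[num_bits]                 # KeyError: a Tensor never matches the int key
sawb_code = SAWBcode_mapping[num_bits.item()]  # 403: .item() yields a plain Python int
```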
@@ -549,9 +552,13 @@ def forward(
                 clip_val.dtype
             )  # NOTE return will be a fp32 tensor; function only support float()
         else:
-            output = torch.quantize_per_channel(
-                input_tensor, scale, zero_point, 0, torch.qint8
-            ).int_repr()
+            output = (
+                torch.quantize_per_channel(
+                    input_tensor, scale, zero_point, 0, torch.qint8
+                )
+                .int_repr()
+                .clamp(int_l, int_u)
+            )
             # NOTE return will be a torch.int8 tensor
 
         return output
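The added `.clamp(int_l, int_u)` guards against `torch.quantize_per_channel` saturating only at the qint8 bounds [-128, 127], which are wider than the integer range of a sub-8-bit quantizer. A small hypothetical repro; the `int_l`/`int_u` values are assumed for a 4-bit case:

```python
import torch

x = torch.randn(4, 8) * 5
scale = torch.full((4,), 0.1)            # deliberately small scale to force overflow
zero_point = torch.zeros(4, dtype=torch.int64)
int_l, int_u = -8, 7                     # assumed 4-bit integer bounds

q = torch.quantize_per_channel(x, scale, zero_point, 0, torch.qint8)
raw = q.int_repr()                       # saturates at [-128, 127], not [-8, 7]
out = raw.clamp(int_l, int_u)            # enforce the 4-bit range
print(raw.abs().max().item(), out.abs().max().item())
```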
@@ -2537,7 +2544,7 @@ def asymmetric_linear_quantization_params(
     return scale, zero_point
 
 
-def clamp(input_tensor: torch.FloatTensor, clamp_min, clamp_max, inplace=False):
+def clamp(input_tensor: torch.Tensor, clamp_min, clamp_max, inplace=False):
     """
     Returns:
         Clamped Torch Tensor.
114 changes: 114 additions & 0 deletions fms_mo/quant_refactor/base_model.py
@@ -0,0 +1,114 @@
# Copyright The FMS Model Optimizer Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Base QModel Class
"""

# Standard
# pylint: disable=keyword-arg-before-vararg
import logging

# Third Party
import torch

logger = logging.getLogger(__name__)


class Qmodel:  # do not inherit nn.Module, or self.model will not show up in __dict__
    """
    A wrapper for an fms_mo model, mainly for user-API simplification.
    Everything is the same as the original model, but we can add new member functions.
    Make sure the naming is unique enough that we won't override any existing functions
    in Huggingface models.
    """

    def __init__(self, original_model, qcfg=None, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.org_attr = dir(original_model)
        self.model = original_model
        if qcfg:
            self.qcfg = qcfg

    def __getattr__(self, name: str):
        # NOTE: self.model is in __dict__, so the lookup below will not trigger
        # __getattr__ recursively.
        if name in self.org_attr:
            logger.info(f"Trying to access self.{name}, forwarding to self.model.{name}")
            return getattr(self.model, name)
        raise AttributeError(name)

    def __call__(self, *args, **kwargs):
        logger.info(
            "Make this object callable, but actually just calling self.model.__call__()"
        )
        return self.model(*args, **kwargs)

    def __repr__(self):
        OKCYAN = "\033[96m"
        ENDC = "\033[0m"
        rep_txt = f"{OKCYAN}FMSMO_Qmodel_wrapper({ENDC}\n{self.model.__repr__()}{OKCYAN}){ENDC}"
        return rep_txt

    def to(self, tar_dev: torch.device):
        """
        Demonstrates that we can override a function of the original model;
        this call does not go through __getattr__(), i.e. you will not see the
        log message from that function.

        Args:
            tar_dev (torch.device): Target device.

        Returns:
            The underlying model, moved to tar_dev.
        """
        logger.info(
            f"Overriding a function in the original model: moving the model to {tar_dev}"
        )
        return self.model.to(tar_dev)

    def save_model_in_pt_fmt(
        self, filename: str = "model.pt", exam_inp: torch.Tensor = None
    ):
        """
        Save the entire model to file.

        Args:
            filename (str, optional): File path to save model. Defaults to "model.pt".
            exam_inp (torch.Tensor, optional): Example input for model. Defaults to None.
        """
        # NOTE self.qcfg has a lot of info already, like transformers_version
        # NOTE cannot save wrapped self, can only save self.model...
        save_dict = {"model": self.model}
        if exam_inp is not None:  # avoid the ambiguous truth value of a multi-element Tensor
            save_dict["exam_inp"] = exam_inp
        torch.save(save_dict, filename)
        logger.info(f"{filename} saved successfully.")

    def save_statedict_in_pt_fmt(
        self,
        filename: str = "model.pt",
    ):
        """
        Save the model state dict to file.

        Args:
            filename (str, optional): File path to save model state dict. Defaults to "model.pt".
        """
        torch.save(self.model.state_dict(), filename)
        logger.info(f"model.state_dict() is saved to {filename} successfully.")

    def run_gptq(self):
        """
        Check that the model is supported by AutoGPTQ first.
        """
        return
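As a quick illustration of how this wrapper is meant to be used (a hypothetical sketch; the wrapped `torch.nn.Linear` and the `qcfg` contents are placeholders, not values from this PR):

```python
import torch
from fms_mo.quant_refactor.base_model import Qmodel

net = torch.nn.Linear(16, 4)
qmodel = Qmodel(net, qcfg={"nbits_w": 8})  # qcfg keys are illustrative

x = torch.randn(2, 16)
y = qmodel(x)                    # __call__ forwards to net(x)
print(qmodel.in_features)        # __getattr__ forwards to net.in_features -> 16
qmodel.save_statedict_in_pt_fmt("linear_sd.pt")
```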