Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
53f9857
toy exam under examples/MX
chichun-charlie-liu Mar 27, 2025
a59010e
include mx Linear class directly for convenience
chichun-charlie-liu Mar 27, 2025
b6b1d4f
minor clean up
chichun-charlie-liu Mar 28, 2025
027eb2f
[pyproject] Added MX optional-dependency
BrandonGroth Mar 26, 2025
02a9c2d
[qcfg] Changed mapping init and added mx path in qconfig_init
BrandonGroth Mar 27, 2025
2bd90e9
[QLinear] Added QLinearMX stub
BrandonGroth Mar 27, 2025
cc681ee
[qcfg] Changed qcfg[mapping] structure
BrandonGroth Mar 28, 2025
62962a3
[mx] Updates to ffn_tmp.py
BrandonGroth Mar 28, 2025
a79ebbf
clean-up, now simply use qa_mode and qw_mode to trigger mx
chichun-charlie-liu Mar 29, 2025
8bf30ac
first mx test with "dq" using qa_/qw_mode=mx_fp8_e4m3
chichun-charlie-liu Mar 30, 2025
ac03469
add QbmmMX support
chichun-charlie-liu Apr 1, 2025
7f6ea36
draft example readme for mx
chichun-charlie-liu Apr 1, 2025
613e523
adjust mapping def, allow qbmm to use non-mx while linear is using mx…
chichun-charlie-liu Apr 1, 2025
9ea77da
bug fix, QBmmMX init was calling QBmmMX init which tries to get a non…
chichun-charlie-liu Apr 1, 2025
91c20cd
feat: Added mx_specs to qconfig_init and check_config
BrandonGroth Apr 22, 2025
8b1d026
feat: Added skippable mx_specs unit tests
BrandonGroth Apr 22, 2025
a161b9a
fix: Added QLinearMX and QBmmMX to quantized_modules
BrandonGroth Apr 29, 2025
9ffa0b1
chore: format qconfig_utils
BrandonGroth Apr 29, 2025
a9a57b3
test: minor fixes to test_mx and supporting functions
BrandonGroth Apr 29, 2025
764d7bc
Merge pull request #1 from BrandonGroth/mx_impl_brandon
chichun-charlie-liu Apr 29, 2025
45dd501
Merge branch 'main' into mx_impl
chichun-charlie-liu Apr 29, 2025
9bab835
lint and ruff fix
chichun-charlie-liu Apr 29, 2025
40a5c0c
chore: Added license to ffn_tmp.py and other minor fixes
BrandonGroth May 5, 2025
9ba1a42
feat: Added install_patches.py for fms_mo patch installs
BrandonGroth May 2, 2025
ac0d2ed
feat: Added patches directory and microxcaling.patch
BrandonGroth May 2, 2025
9f3713e
feat: Added README.md for patching instructions
BrandonGroth May 2, 2025
352e77d
doc: Added instructions to create a patch file
BrandonGroth May 6, 2025
303454c
Merge pull request #2 from BrandonGroth/mx_impl_patch
chichun-charlie-liu May 6, 2025
d7dfc27
rename mx example py, markdown update is still WIP
chichun-charlie-liu May 14, 2025
3446413
markdown update is still WIP
chichun-charlie-liu May 14, 2025
dd17104
fix bug in config check as QBmmMX is optional
chichun-charlie-liu May 21, 2025
6e4431c
feat: Added qmodel_prep mx_specs hook for default config with MX sett…
BrandonGroth May 21, 2025
83d2b06
test: Added MX test for default config + qmodel_prep
BrandonGroth May 21, 2025
d05401f
fix: Added MX vocabulary for spellcheck util
BrandonGroth May 21, 2025
7daae3b
chore: Fixed MX vocab errors
BrandonGroth May 21, 2025
5694a5a
fix: Changed check_config check for MX mapping
BrandonGroth May 22, 2025
308e5e2
Merge pull request #3 from BrandonGroth/mx_impl_brandon
chichun-charlie-liu May 22, 2025
e5eb0ed
doc: Updated MX examples/readme for patch install
BrandonGroth May 23, 2025
ab060ca
update mx readme.md, fix simple_mx_exam.py together with a few minor …
chichun-charlie-liu May 27, 2025
9b9b777
Merge branch 'main' into mx_impl
chichun-charlie-liu May 27, 2025
2cb9aca
add a microscaling ref in readme
chichun-charlie-liu May 27, 2025
7e7a0fa
fix: Patched qconfig_init and get_unwanted_defaults QBmm
BrandonGroth May 27, 2025
f220b17
fix: Updated spellcheck list for examples/mx readme
BrandonGroth May 27, 2025
026a3ea
fix: Fixed additional Qbmm mapping in test_saveconfig
BrandonGroth May 27, 2025
b00287f
minor updates on README
chichun-charlie-liu May 27, 2025
b5f0eaa
chore: Cleanup of extra comments
BrandonGroth May 28, 2025
4e0de3f
adjust gitignore to allow changes in py and md files
chichun-charlie-liu May 28, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,9 @@ fms_mo.log
data*_train/
data*_test/
act_scales/
examples/
examples/**/*.json
examples/**/*.safetensors
examples/**/*.log
examples/**/*.sh
examples/**/*.pt
examples/**/*.arrow
12 changes: 12 additions & 0 deletions .spellcheck-en-custom.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ dequantization
dq
DQ
dev
dtype
eval
fms
fmsmo
fp
FP
FP8Arguments
Expand Down Expand Up @@ -125,3 +127,13 @@ venv
vllm
xs
zp
microxcaling
Microscaling
microscaling
MX
mx
MXINT
mxint
MXFP
mxfp
OCP
98 changes: 98 additions & 0 deletions examples/MX/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# `microscaling` Examples Using a Toy Model and Direct Quantization (DQ)
Comment thread
chichun-charlie-liu marked this conversation as resolved.
Microscaling, or "MX", format, such as `MXFP8`, is a different numeric format compared to commonly used FP8 formats. For example, PyTorch provides two FP8 formats, which are 1 sign bit, 4 exponent bits, and 3 mantissa bits (denoted as `e4m3`) or 1 sign bit, 5 exponent bits, and 2 mantissa bits (`e5m2`), see our other [FP8 example](../FP8_QUANT/README.md) for more details. On the other hand, all the `mx` formats are group-based data structure where each member of the group is using the specified format, e.g. FP8 for MXFP8, while each group has a shared (usually 8-bit) "scale". Group size could be as small as 32 or 16, depending on hardware design. One may consider each MXFP8 number actually requires 8.25 bits (when group size is 32) instead of 8 bits. More details about microscaling can be found in [this OCP document](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).

Here, we provide two simple examples of using MX format in `fms-mo`.

> [!NOTE]
It is important to keep in mind that `mx` is not natively supported by Hopper GPUs yet (some will be supported by Blackwell), which means the quantization configurations and corresponding behavior are simulated. Hence, no real "speed up" should be expected.


## Requirements
- [FMS Model Optimizer requirements](../../README.md#requirements)
- Microsoft `microxcaling` python package, download [here](https://github.com/microsoft/microxcaling.git).
> [!TIP]
> `FMS-Model-Optimizer` and `microxcaling` have clashing dependency requirements for `PyTorch` packages. We have created a patching solution to resolve this, run the following in command line:
``` bash
python3 ../install_patches.py
```
This patching file will either download the repo for you, or look for an already installed version in `$HOME` or the current working directory, then install the patch.
For more information, see `patches/README.md`.

## QuickStart

### Example 1
First example is based on a toy model with only a few Linear layers, in which only one Linear layer will be quantized with MX version of `int8`, `int4`, `fp8`, and `fp4`. The example can simply be run as follow

```bash
>>> python simple_mx_example.py
```

Comparison between different formats, including the first 3 elements from output tensors and the norm compared to FP32 reference, is shown below.

| dtype | output[0, 0] | output[0, 1] | output[0, 2] | \|\|ref - out_dtype\|\|<sub>2</sub> |
|:-----------|---------------:|---------------:|---------------:|------------------------:|
| fp32 | -1.0491 | 0.5312 | -1.6387 | 0.0000 |
| fmsmo_int8 | -1.0577 | 0.5346 | -1.6508 | 0.4937 |
| fmsmo_int4 | -0.5885 | 0.5831 | -1.7976 | 8.2927 |
| mxint8 | -0.6444 | 0.6828 | -1.8626 | 8.3305 |
| mxint4 | -0.9089 | 0.6141 | -1.7630 | 8.0692 |
| mxfp8_e4m3 | -0.8031 | 0.7262 | -1.9581 | 7.8554 |
| mxfp8_e5m2 | -0.8471 | 0.7319 | -1.7458 | 8.1838 |
| mxfp4_e2m1 | -0.7506 | 0.6123 | -1.9311 | 7.9936 |


### Example 2
The second example is the same as the [DQ example](../DQ_SQ/README.md), except using [microxcaling](https://arxiv.org/abs/2310.10537) format. We only demonstrate `mxfp8` and `mxfp4` here, but MXINT8, MXFP8, MXFP6, MXFP4 are also available for weights, activations, and/or KV-cache.

**1. Prepare Data** for calibration process by converting into its tokenized form. An example of tokenization using `LLAMA-3-8B`'s tokenizer is below.

```python
from transformers import AutoTokenizer
from fms_mo.utils.calib_data import get_tokenized_data

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", use_fast=True)
num_samples = 128
seq_len = 2048
get_tokenized_data("wiki", num_samples, seq_len, tokenizer, path_to_save='data')
```
> [!NOTE]
> - Users should provide a tokenized data file based on their need. This is just one example to demonstrate what data format `fms_mo` is expecting.
> - Tokenized data will be saved in `<path_to_save>_train` and `<path_to_save>_test`
> - If you have trouble downloading Llama family of models from Hugging Face ([LLama models require access](https://www.llama.com/docs/getting-the-models/hugging-face/)), you can use `ibm-granite/granite-8b-code` instead

**2. Apply DQ** by providing specific hyper-parameters such as `quant_method`, weight quantizers (`qw_mode`) and activation quantizers (`qa_mode`) etc. An example using `Meta-Llama-3-8B` and the tokenized training and test data is provided below.
```bash
python -m fms_mo.run_quant \
--model_name_or_path "meta-llama/Meta-Llama-3-8B" \
--training_data_path data_train \
--test_data_path data_test \
--torch_dtype "float16" \
--quant_method dq \
--nbits_w 8 \
--nbits_a 8 \
--nbits_kvcache 32 \
--qa_mode "mx_fp8_e4m3"\
--qw_mode "mx_fp8_e4m3" \
--output_dir "dq_test" \
--eval_ppl
```
> [!NOTE]
> To use MX format, simply assign `qa_mode` and `qw_mode` argument with a `mx_<dtype supported by mx package>`, e.g. `mx_fp8_e4m3` as in the above example. Corresponding `QLinearMX` wrappers will be used in place of `QLinear` as in other examples.

**3. Compare the Perplexity score** For user convenience, the code will print out perplexity (controlled by `eval_ppl` flag) at the end of the run, so no additional steps needed (if the logging level is set to `INFO` in terminal). You can check output in the logging file. `./fms_mo.log`.


## Example Test Results
The perplexity of the INT8 and FP8 quantized models on the `wikitext` dataset is shown below:

| Model |Type |QA |QW |DQ |SQ |Perplexity|
|:---------:|:---:|:------------:|:------------:|:--:|:--:|:--------:|
|`Llama3-8b`|INT8 |maxpertoken |maxperCh |yes |yes |6.22 |
| |FP8 |fp8_e4m3_scale|fp8_e4m3_scale|yes |yes |6.19 |
| |**MX**|mx_fp8_e4m3 |mx_fp8_e4m3 |yes |**no** |6.23 |
| |**MX**|mx_fp4_e2m1 |mx_fp4_e2m1 |yes |**no** |8.22 |


> [!NOTE]
> SmoothQuant is disabled when `mx` is being used. See `dq.py` for more details.

123 changes: 123 additions & 0 deletions examples/MX/simple_mx_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright The FMS Model Optimizer Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Simple example using a toy model to demo how to trigger mx in fms-mo."""

# Third Party
import numpy as np
import torch
import torch.nn.functional as F


class ResidualMLP(torch.nn.Module):
def __init__(self, hidden_size, device="cuda"):
super(ResidualMLP, self).__init__()

self.layernorm = torch.nn.LayerNorm(hidden_size, device=device)
self.dense_4h = torch.nn.Linear(hidden_size, 4 * hidden_size, device=device)
self.dense_h = torch.nn.Linear(4 * hidden_size, hidden_size, device=device)
self.dummy = torch.nn.Linear(hidden_size, hidden_size, device=device)
# add a dummy layer because by default we skip 1st/last, if there are only 2 layers, all will be skipped

def forward(self, inputs):
norm_outputs = self.layernorm(inputs)

# MLP
proj_outputs = self.dense_4h(norm_outputs)
proj_outputs = F.gelu(proj_outputs)
mlp_outputs = self.dense_h(proj_outputs)
mlp_outputs = self.dummy(mlp_outputs)

# Residual Connection
outputs = inputs + mlp_outputs

return outputs


if __name__ == "__main__":
# Third Party
from tabulate import tabulate

# Local
from fms_mo import qconfig_init, qmodel_prep

HIDDEN_DIM = 128
x = np.random.randn(16, HIDDEN_DIM)
x = torch.tensor(x, dtype=torch.float32, device="cuda")
results = {
"dtype": [],
"output[0, 0]": [],
"output[0, 1]": [],
"output[0, 2]": [],
"||ref - out_dtype||_2": [],
}

# --- Test 0. Run MLP as is
model = ResidualMLP(HIDDEN_DIM)
with torch.no_grad():
out = model(x)
results["dtype"].append("fp32")
results["output[0, 0]"].append(out[0, 0])
results["output[0, 1]"].append(out[0, 1])
results["output[0, 2]"].append(out[0, 2])
results["||ref - out_dtype||_2"].append(0)
print(model)

# --- Test 1. fms-mo qmodel_prep, replace Linear with our QLinear
qcfg = qconfig_init()
qcfg["nbits_a"] = 8
qcfg["nbits_w"] = 8
qmodel_prep(model, x, qcfg)
with torch.no_grad():
out_dtype = model(x)
results["dtype"].append("fmsmo_int8")
results["output[0, 0]"].append(out_dtype[0, 0])
results["output[0, 1]"].append(out_dtype[0, 1])
results["output[0, 2]"].append(out_dtype[0, 2])
results["||ref - out_dtype||_2"].append(torch.norm(out - out_dtype).item())
print(model)

qcfg["nbits_a"] = 4
qcfg["nbits_w"] = 4
model = ResidualMLP(HIDDEN_DIM)
qmodel_prep(model, x, qcfg)
with torch.no_grad():
out_dtype = model(x)
results["dtype"].append("fmsmo_int4")
results["output[0, 0]"].append(out_dtype[0, 0])
results["output[0, 1]"].append(out_dtype[0, 1])
results["output[0, 2]"].append(out_dtype[0, 2])
results["||ref - out_dtype||_2"].append(torch.norm(out - out_dtype).item())
print(model)

# --- Test 2. now change mapping to MX
# NOTE simply use qa_mode or qw_mode to trigger the use of mx, e.g. use "mx_" prefixed mode,
# qcfg["mapping"] and other qcfg["mx_specs"] content will be updated automatically

for dtype_to_test in ["int8", "int4", "fp8_e4m3", "fp8_e5m2", "fp4_e2m1"]:
qcfg["qw_mode"] = f"mx_{dtype_to_test}"
qcfg["qa_mode"] = f"mx_{dtype_to_test}"
model = ResidualMLP(HIDDEN_DIM) # fresh model
qmodel_prep(model, x, qcfg)
with torch.no_grad():
out_dtype = model(x)
results["dtype"].append(f"mx{dtype_to_test}")
results["output[0, 0]"].append(out_dtype[0, 0])
results["output[0, 1]"].append(out_dtype[0, 1])
results["output[0, 2]"].append(out_dtype[0, 2])
results["||ref - out_dtype||_2"].append(torch.norm(out - out_dtype).item())
print(model)

print(tabulate(results, headers="keys", tablefmt="pipe", floatfmt=".4f"))

print("DONE!")
55 changes: 28 additions & 27 deletions fms_mo/dq.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,21 +159,21 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
config_quantize_smooth_layers(qcfg)

if any(x != 32 for x in attn_bits):
logger.info("Quantize attention bmms or kvcache, use dynamo for prep")
logger.info("Quantize attention bmms or kvcache, will use dynamo for prep")
use_layer_name_pattern_matching = False
qcfg["qlayer_name_pattern"] = []
assert (
qcfg["qlayer_name_pattern"] == []
), "ensure nothing in qlayer_name_pattern when use dynamo"
use_dynamo = True
else:
logger.info("Do not quantize attention bmms")
logger.info("Attention bmms will not be quantized.")
use_layer_name_pattern_matching = True
use_dynamo = False

qcfg["seq_len"] = block_size
qcfg["model"] = model_args.model_name_or_path
qcfg["smoothq"] = qcfg.get("smoothq_alpha", -1) >= 0
qcfg["smoothq"] = qcfg.get("smoothq_alpha", -1) >= 0 and "mx_specs" not in qcfg
qcfg["plotsvg"] = False

calibration_dataset = load_from_disk(data_args.training_data_path)
Expand All @@ -187,31 +187,32 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
)

# For loading or creating smoothquant scale. Sometimes we may include scales in ckpt as well.
scale_file = Path(f"./act_scales/{qcfg['model'].replace('/', '-')}.pt")
if qcfg.get("act_scale_path", None):
# user provided a scale file (or a dir)
scale_file_or_dir = Path(qcfg["act_scale_path"])
if scale_file_or_dir.is_dir():
scale_file = scale_file_or_dir / f"{qcfg['model'].replace('/', '-')}.pt"
elif scale_file_or_dir.is_file():
scale_file = scale_file_or_dir

if not scale_file.parent.exists():
scale_file.parent.mkdir(exist_ok=False)

if scale_file.exists():
act_scales = torch.load(
scale_file,
map_location=getattr(model, "device", dev),
weights_only=True,
)
else:
logger.info("Generate activation scales")
if qcfg["large_model"]:
act_scales = get_act_scales_1gpu(model, dq_dataloader, qcfg)
if qcfg["smoothq"]:
scale_file = Path(f"./act_scales/{qcfg['model'].replace('/', '-')}.pt")
if qcfg.get("act_scale_path", None):
# user provided a scale file (or a dir)
scale_file_or_dir = Path(qcfg["act_scale_path"])
if scale_file_or_dir.is_dir():
scale_file = scale_file_or_dir / f"{qcfg['model'].replace('/', '-')}.pt"
elif scale_file_or_dir.is_file():
scale_file = scale_file_or_dir

if not scale_file.parent.exists():
scale_file.parent.mkdir(exist_ok=False)

if scale_file.exists():
act_scales = torch.load(
scale_file, map_location=getattr(model, "device", dev)
)

else:
act_scales = get_act_scales(model, dq_dataloader, qcfg)
torch.save(act_scales, scale_file)
logger.info("Generate activation scales")
if qcfg["large_model"]:
act_scales = get_act_scales_1gpu(model, dq_dataloader, qcfg)
else:
act_scales = get_act_scales(model, dq_dataloader, qcfg)
torch.save(act_scales, scale_file)

qmodel_prep(
model,
dq_dataloader,
Expand Down
2 changes: 1 addition & 1 deletion fms_mo/fx/dynamo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1218,7 +1218,7 @@ def call_seq_hook(mod, *_args, **_kwargs):
# b) qbmm creation and attaching to model
if qcfg.get("QBmm"): # see Note 4
# Local
from fms_mo.modules import QBmm
QBmm = qcfg["mapping"]["matmul_or_bmm"]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be lower case variable name?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a class pointer disguised as a variable and used as QBmm(args) below. I think it is fine as is.


qcfg["which2patch_contextmanager"] = qcfg["bmm_prep"][
"which2patch_contextmanager"
Expand Down
6 changes: 0 additions & 6 deletions fms_mo/fx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,6 @@ def lower_qmodel_to_ext_kernels(
QLinearExllamaV2,
)

qclass_accepted = []
for map_dict in qcfg["mapping"].values():
qclass_accepted.append(map_dict["to"])
qclass_accepted.append(map_dict.get("otherwise", None))

mod2swap = {
n: m
for n, m in mod.named_modules()
Expand Down Expand Up @@ -498,7 +493,6 @@ def model_size_Wb(mod, unit="MB", print_to_file=True, show_details=False):
)
),
)

if show_details:
logger_or_print(df_summary_weights.to_markdown())

Expand Down
Loading
Loading