Merged
Commits
47 commits
53f9857
toy exam under examples/MX
chichun-charlie-liu Mar 27, 2025
a59010e
include mx Linear class directly for convenience
chichun-charlie-liu Mar 27, 2025
b6b1d4f
minor clean up
chichun-charlie-liu Mar 28, 2025
027eb2f
[pyproject] Added MX optional-dependency
BrandonGroth Mar 26, 2025
02a9c2d
[qcfg] Changed mapping init and added mx path in qconfig_init
BrandonGroth Mar 27, 2025
2bd90e9
[QLinear] Added QLinearMX stub
BrandonGroth Mar 27, 2025
cc681ee
[qcfg] Changed qcfg[mapping] structure
BrandonGroth Mar 28, 2025
62962a3
[mx] Updates to ffn_tmp.py
BrandonGroth Mar 28, 2025
a79ebbf
clean-up, now simply use qa_mode and qw_mode to trigger mx
chichun-charlie-liu Mar 29, 2025
8bf30ac
first mx test with "dq" using qa_/qw_mode=mx_fp8_e4m3
chichun-charlie-liu Mar 30, 2025
ac03469
add QbmmMX support
chichun-charlie-liu Apr 1, 2025
7f6ea36
draft example readme for mx
chichun-charlie-liu Apr 1, 2025
613e523
adjust mapping def, allow qbmm to use non-mx while linear is using mx…
chichun-charlie-liu Apr 1, 2025
9ea77da
bug fix, QBmmMX init was calling QBmmMX init which tries to get a non…
chichun-charlie-liu Apr 1, 2025
91c20cd
feat: Added mx_specs to qconfig_init and check_config
BrandonGroth Apr 22, 2025
8b1d026
feat: Added skippable mx_specs unit tests
BrandonGroth Apr 22, 2025
a161b9a
fix: Added QLinearMX and QBmmMX to quantized_modules
BrandonGroth Apr 29, 2025
9ffa0b1
chore: format qconfig_utils
BrandonGroth Apr 29, 2025
a9a57b3
test: minor fixes to test_mx and supporting functions
BrandonGroth Apr 29, 2025
764d7bc
Merge pull request #1 from BrandonGroth/mx_impl_brandon
chichun-charlie-liu Apr 29, 2025
45dd501
Merge branch 'main' into mx_impl
chichun-charlie-liu Apr 29, 2025
9bab835
lint and ruff fix
chichun-charlie-liu Apr 29, 2025
40a5c0c
chore: Added license to ffn_tmp.py and other minor fixes
BrandonGroth May 5, 2025
9ba1a42
feat: Added install_patches.py for fms_mo patch installs
BrandonGroth May 2, 2025
ac0d2ed
feat: Added patches directory and microxcaling.patch
BrandonGroth May 2, 2025
9f3713e
feat: Added README.md for patching instructions
BrandonGroth May 2, 2025
352e77d
doc: Added instructions to create a patch file
BrandonGroth May 6, 2025
303454c
Merge pull request #2 from BrandonGroth/mx_impl_patch
chichun-charlie-liu May 6, 2025
d7dfc27
rename mx example py, markdown update is still WIP
chichun-charlie-liu May 14, 2025
3446413
markdown update is still WIP
chichun-charlie-liu May 14, 2025
dd17104
fix bug in config check as QBmmMX is optional
chichun-charlie-liu May 21, 2025
6e4431c
feat: Added qmodel_prep mx_specs hook for default config with MX sett…
BrandonGroth May 21, 2025
83d2b06
test: Added MX test for default config + qmodel_prep
BrandonGroth May 21, 2025
d05401f
fix: Added MX vocabulary for spellcheck util
BrandonGroth May 21, 2025
7daae3b
chore: Fixed MX vocab errors
BrandonGroth May 21, 2025
5694a5a
fix: Changed check_config check for MX mapping
BrandonGroth May 22, 2025
308e5e2
Merge pull request #3 from BrandonGroth/mx_impl_brandon
chichun-charlie-liu May 22, 2025
e5eb0ed
doc: Updated MX examples/readme for patch install
BrandonGroth May 23, 2025
ab060ca
update mx readme.md, fix simple_mx_exam.py together with a few minor …
chichun-charlie-liu May 27, 2025
9b9b777
Merge branch 'main' into mx_impl
chichun-charlie-liu May 27, 2025
2cb9aca
add a microscaling ref in readme
chichun-charlie-liu May 27, 2025
7e7a0fa
fix: Patched qconfig_init and get_unwanted_defaults QBmm
BrandonGroth May 27, 2025
f220b17
fix: Updated spellcheck list for examples/mx readme
BrandonGroth May 27, 2025
026a3ea
fix: Fixed additional Qbmm mapping in test_saveconfig
BrandonGroth May 27, 2025
b00287f
minor updates on README
chichun-charlie-liu May 27, 2025
b5f0eaa
chore: Cleanup of extra comments
BrandonGroth May 28, 2025
4e0de3f
adjust gitignore to allow changes in py and md files
chichun-charlie-liu May 28, 2025
5 changes: 4 additions & 1 deletion .spellcheck-en-custom.txt
@@ -117,4 +117,7 @@ venv
vllm
xs
zp

microxcaling
MX
MXINT
MXFP
190 changes: 190 additions & 0 deletions examples/MX/README.md
@@ -0,0 +1,190 @@
# `microscaling` Examples Using a Toy Model and Direct Quantization (DQ)
Here, we provide two simple examples of using the MX format in `fms-mo`.
An "MX format", such as `MXFP8`, differs from typical IEEE formats, e.g. PyTorch's FP8s (`e4m3` or `e5m2`; see our other [FP8 example](../FP8_QUANT/README.md)). All `mx` formats are group-based: each member of a group uses the specified element format (e.g. FP8 for MXFP8), while the whole group shares one (usually 8-bit) scale. Group sizes can be as small as 32 or 16, depending on the hardware design.
> [!NOTE]
> It is important to keep in mind that `mx` is not natively supported by Hopper GPUs (some formats will be supported by Blackwell), which means the quantization configurations and corresponding behavior here are simulated, i.e. no real speed-up should be expected.
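
For intuition, below is a minimal, hypothetical sketch of the "group with a shared scale" idea. This is **not** the `microxcaling` implementation; `fake_mx_quant`, `group_size`, and `elem_max` are illustrative names only, and a real kernel would additionally round each element to the target grid.

```python
import torch


def fake_mx_quant(x: torch.Tensor, group_size: int = 32, elem_max: float = 448.0) -> torch.Tensor:
    """Conceptual MX-style 'fake quantization': each group of `group_size` elements
    shares one power-of-two scale; `elem_max` is the element format's largest
    magnitude (~448 for fp8 e4m3)."""
    assert x.numel() % group_size == 0, "pad or reshape so groups divide the tensor evenly"
    xg = x.reshape(-1, group_size)
    # shared exponent per group, taken from the largest magnitude in that group
    shared_exp = torch.floor(torch.log2(xg.abs().amax(dim=1, keepdim=True) + 1e-30))
    scale = 2.0 ** (shared_exp - 8)  # 8 ~ floor(log2(448)), the element format's max exponent
    xq = torch.clamp(xg / scale, -elem_max, elem_max)  # elements now live in the fp8 e4m3 range
    return (xq * scale).reshape(x.shape)  # de-quantize again, since MX here is only simulated
```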


## Requirements
- [FMS Model Optimizer requirements](../../README.md#requirements)
- Microsoft's `microxcaling` Python package, available [here](https://github.com/microsoft/microxcaling.git).
> [!TIP]
> After `git clone` and BEFORE installation, first update `pyproject.toml` to remove the version constraints under `[dependencies]` (simply comment those three lines out), then install with `pip install -e .` as usual.
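
For reference, the install steps look roughly as follows (a sketch of the tip above, not an official script):

```bash
git clone https://github.com/microsoft/microxcaling.git
cd microxcaling
# edit pyproject.toml: comment out the version-pinned lines under [dependencies]
pip install -e .
```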

## QuickStart

The first example is based on a toy model with only a few Linear layers, in which one Linear layer will be quantized with the MX versions of `int8`, `int4`, `fp8`, and `fp4`. The example can simply be run as follows:

```bash
python simple_mx_example.py
```
Expected output includes:
```bash

```

The second example is the same as the one in the [DQ](../DQ_SQ/README.md) folder, except it uses the [microxcaling](https://arxiv.org/abs/2310.10537) format. We demonstrate the effect of MXINT8, MXFP8, MXFP6, and MXFP4 on weights, activations, and/or the KV-cache.

**1. Prepare Data** for the calibration process by converting it into its tokenized form. An example of tokenization using `LLAMA-3-8B`'s tokenizer is shown below.

```python
from transformers import AutoTokenizer
from fms_mo.utils.calib_data import get_tokenized_data

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", use_fast=True)
num_samples = 128
seq_len = 2048
get_tokenized_data("wiki", num_samples, seq_len, tokenizer, path_to_save='data')
```
> [!NOTE]
> - Users should provide a tokenized data file based on their needs. This is just one example to demonstrate what data format `fms_mo` is expecting.
> - Tokenized data will be saved in `<path_to_save>_train` and `<path_to_save>_test`.
> - If you have trouble downloading the Llama family of models from Hugging Face ([Llama models require access](https://www.llama.com/docs/getting-the-models/hugging-face/)), you can use `ibm-granite/granite-8b-code` instead.

**2. Apply DQ** by providing specific hyper-parameters such as `quant_method`, the weight quantizer (`qw_mode`), and the activation quantizer (`qa_mode`). An example using `Meta-Llama-3-8B` and the tokenized training and test data is provided below.
```bash
python -m fms_mo.run_quant \
--model_name_or_path "meta-llama/Meta-Llama-3-8B" \
--training_data_path data_train \
--test_data_path data_test \
--torch_dtype "float16" \
--quant_method dq \
--nbits_w 8 \
--nbits_a 8 \
--nbits_kvcache 32 \
--qa_mode "mx_fp8_e4m3"\
--qw_mode "mx_fp8_e4m3" \
--output_dir "dq_test" \
--eval_ppl
```
> [!TIP]
> To use an MX format, simply set the `qa_mode` and `qw_mode` arguments to `mx_<dtype supported by the mx package>`, e.g. `mx_fp8_e4m3` as in the example above. The corresponding `QLinearMX` wrappers will then be used in place of the `QLinear` wrappers used in other examples.
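
When driving `fms-mo` from Python instead of the CLI, the equivalent is roughly the sketch below (the `qconfig_init(recipe="dq")` call and the `qmodel_prep` step mirror the toy example and the walk-through further down; treat this as illustrative rather than the exact API contract):

```python
from fms_mo import qconfig_init, qmodel_prep

qcfg = qconfig_init(recipe="dq")
qcfg["qw_mode"] = "mx_fp8_e4m3"  # any "mx_"-prefixed mode triggers the MX path
qcfg["qa_mode"] = "mx_fp8_e4m3"  # qcfg["mapping"] and qcfg["mx_specs"] are filled in automatically
# qmodel_prep(model, sample_input, qcfg) will then insert QLinearMX in place of QLinear
```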

**3. Compare the Perplexity score.** For user convenience, the code prints the perplexity at the end of the run (controlled by the `eval_ppl` flag), so no additional steps are needed if the logging level is set to `INFO` in the terminal. You can also check the output in the log file `./fms_mo.log`.

# *TO BE UPDATED BELOW THIS LINE*


## Example Test Results
The perplexity of the INT8 and FP8 quantized models on the `wikitext` dataset is shown below:

| Model |Type |QA |QW |DQ |SQ |Perplexity|
|:---------:|:---:|:------------:|:------------:|:--:|:--:|:--------:|
|`Llama3-8b`|INT8 |maxpertoken |maxperCh |yes |yes |6.21 |
| |FP8 |fp8_e4m3_scale|fp8_e4m3_scale|yes |yes |6.19 |

## Code Walk-through

**1. KV caching**

In large language models (LLMs), key/value pairs are frequently cached during token generation (a process known as KV caching) to avoid redundant computation caused by the autoregressive nature of generation. However, the size of the KV cache grows with both batch size and context length, which can slow down inference because a large amount of data must be read from memory. Quantizing the KV cache reduces this memory-bandwidth bottleneck and improves inference speed. To study the quantization behavior of the KV cache, simply set the `nbits_kvcache` argument to 8; the KV cache will then be quantized together with the weights and activations. In addition, the `bmm1_qm1_mode`, `bmm1_qm2_mode`, and `bmm2_qm2_mode` [arguments](../../fms_mo/training_args.py) must be set to the same quantizer mode as `qa_mode`. **NOTE**: `bmm2_qm1_mode` should be kept as `minmax`.
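
For example, the step-2 command can be extended as sketched below to also quantize the KV cache with MX. This is only a sketch: the bmm-mode flag names are assumed to mirror the argument names above, so check [training_args](../../fms_mo/training_args.py) for the authoritative definitions.

```bash
python -m fms_mo.run_quant \
    --model_name_or_path "meta-llama/Meta-Llama-3-8B" \
    --training_data_path data_train \
    --test_data_path data_test \
    --torch_dtype "float16" \
    --quant_method dq \
    --nbits_w 8 \
    --nbits_a 8 \
    --nbits_kvcache 8 \
    --qa_mode "mx_fp8_e4m3" \
    --qw_mode "mx_fp8_e4m3" \
    --bmm1_qm1_mode "mx_fp8_e4m3" \
    --bmm1_qm2_mode "mx_fp8_e4m3" \
    --bmm2_qm2_mode "mx_fp8_e4m3" \
    --output_dir "dq_kvcache_test" \
    --eval_ppl
```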

The effects of setting `nbits_kvcache` to 8, and the relevant code sections, are:

- Enables eager attention for the quantization of attention operations, including KV cache.
```python
# For attention or kv-cache quantization, need to use eager attention
attn_bits = [fms_mo_args.nbits_bmm1, fms_mo_args.nbits_bmm2, fms_mo_args.nbits_kvcache]
if any(x != 32 for x in attn_bits):
    attn_implementation = "eager"
else:
    attn_implementation = None
```
- Enables Dynamo for quantized model preparation. We use PyTorch's Dynamo tracer to identify the bmm and KV cache inside the attention block.
```python
if any(x != 32 for x in attn_bits):
    logger.info("Quantize attention bmms or kvcache, use dynamo for prep")
    use_layer_name_pattern_matching = False
    qcfg["qlayer_name_pattern"] = []
    assert (
        qcfg["qlayer_name_pattern"] == []
    ), "ensure nothing in qlayer_name_pattern when use dynamo"
    use_dynamo = True
else:
    logger.info("Do not quantize attention bmms")
    use_layer_name_pattern_matching = True
    use_dynamo = False
```

**2. Define quantization config** including quantizers and hyperparameters. Here we simply use the default [dq recipe](../../fms_mo/recipies/dq.json).

```python
qcfg = qconfig_init(recipe="dq", args=fms_mo_args)
```

**3. Obtain activation scales for SmoothQuant (SQ)**

```python
# For loading or creating smoothquant scale.
act_scale_directory = "./act_scales"
if not os.path.exists(act_scale_directory):
    os.makedirs(act_scale_directory)

if qcfg["act_scale_path"] is not None:
    act_scales = torch.load(qcfg["act_scale_path"], map_location="cpu")
else:
    logger.info("Generate activation scales")
    if qcfg["large_model"]:
        act_scales = get_act_scales_1gpu(model, dq_dataloader, qcfg)
    else:
        act_scales = get_act_scales(model, dq_dataloader, qcfg)
    scale_file = f"{act_scale_directory}/{qcfg['model'].replace('/', '-')}" + ".pt"
    torch.save(act_scales, scale_file)
```

**4. Prepare the quantized model and attach activation scales** to quantized modules

```python
qmodel_prep(
    model,
    dq_dataloader,
    qcfg,
    use_layer_name_pattern_matching=use_layer_name_pattern_matching,
    use_dynamo=use_dynamo,
    dev=dev,
    save_fname='test'
)

dq_llm(model, act_scales, qcfg)
```

**5. Perform direct quantization** by calibrating quantizers (clip_vals)

```python
if qcfg["qmodel_calibration_new"] > 0:
    logger.info("Starting to calibrate activation clip_val")
    if qcfg["large_model"]:
        calibration_llm_1GPU(qcfg, model, calibration_dataset)
    else:
        model.to("cuda:0")
        pbar = tqdm(
            dq_dataloader,
            desc=" calibration after applying smoothq scale and before inference",
            total=qcfg["qmodel_calibration_new"],
        )
        for data_mb, _ in zip(pbar, range(qcfg["qmodel_calibration_new"])):
            data_mb = prepare_input(model.device, data_mb)
            with patch_torch_bmm(qcfg):
                model(**data_mb)

logger.info(f"Saving quantized model and tokenizer to {output_dir}")
model.save_pretrained(output_dir, use_safetensors=True)
tokenizer.save_pretrained(output_dir)
```

**6. Check perplexity** (simple method to evaluate the model quality)

```python
if fms_mo_args.eval_ppl:
    logger.info(f"Model for evaluation: {model}")
    if qcfg["large_model"]:
        eval_llm_1GPU(qcfg, model, test_dataset)
    else:
        model.to(torch.device("cuda:0"))
        n_samples = int(test_dataset.input_ids.shape[1] / block_size)
        evaluator = Evaluator(test_dataset, "cuda", n_samples=n_samples)
        ppl = evaluator.evaluate(model, block_size=block_size)
        logger.info(f"Model perplexity: {ppl}")
logger.info("-" * 50)
logger.info("Finished evaluation")
```
110 changes: 110 additions & 0 deletions examples/MX/simple_mx_example.py
@@ -0,0 +1,110 @@
# Copyright The FMS Model Optimizer Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Simple example using a toy model to demo how to trigger mx in fms-mo."""

# Third Party
import numpy as np
import torch
import torch.nn.functional as F


class ResidualMLP(torch.nn.Module):
    def __init__(self, hidden_size, device="cuda"):
        super(ResidualMLP, self).__init__()

        self.layernorm = torch.nn.LayerNorm(hidden_size, device=device)
        self.dense_4h = torch.nn.Linear(hidden_size, 4 * hidden_size, device=device)
        self.dense_h = torch.nn.Linear(4 * hidden_size, hidden_size, device=device)
        self.dummy = torch.nn.Linear(hidden_size, hidden_size, device=device)
        # add a dummy layer because by default we skip 1st/last, if there are only 2 layers, all will be skipped

    def forward(self, inputs):
        norm_outputs = self.layernorm(inputs)

        # MLP
        proj_outputs = self.dense_4h(norm_outputs)
        proj_outputs = F.gelu(proj_outputs)
        mlp_outputs = self.dense_h(proj_outputs)
        mlp_outputs = self.dummy(mlp_outputs)

        # Residual Connection
        outputs = inputs + mlp_outputs

        return outputs


if __name__ == "__main__":
# Third Party
from tabulate import tabulate

# Local
from fms_mo import qconfig_init, qmodel_prep

HIDDEN_DIM = 128
x = np.random.randn(16, HIDDEN_DIM)
x = torch.tensor(x, dtype=torch.float32, device="cuda")
results = {"dtype": [], "output[0, :5]": [], "||ref - out_dtype||_2": []}

# --- Test 0. Run MLP as is
mlp = ResidualMLP(HIDDEN_DIM)
# mlp.to("cuda")
with torch.no_grad():
out = mlp(x)
results["dtype"].append("fp32")
results["output[0, :5]"].append(out[0, :5].tolist())
results["||ref - out_dtype||_2"].append("-")
print(mlp)

# --- Test 1. fms-mo qmodel_prep, replace Linear with our QLinear
qcfg = qconfig_init()
qcfg["nbits_a"] = 8
qcfg["nbits_w"] = 8
model = qmodel_prep(mlp, x, qcfg)
with torch.no_grad():
out_dtype = model(x)
results["dtype"].append("fms_int8")
results["output[0, :5]"].append(out_dtype[0, :5].tolist())
results["||ref - out_dtype||_2"].append(torch.norm(out - out_dtype).item())
# print(model)

qcfg["nbits_a"] = 4
qcfg["nbits_w"] = 4
mlp = ResidualMLP(HIDDEN_DIM)
model = qmodel_prep(mlp, x, qcfg)
with torch.no_grad():
out_dtype = model(x)
results["dtype"].append("fms_int4")
results["output[0, :5]"].append(out_dtype[0, :5].tolist())
results["||ref - out_dtype||_2"].append(torch.norm(out - out_dtype).item())
print(model)

# --- Test 2. now change mapping to MX
# NOTE simply use qa_mode or qw_mode to trigger the use of mx, e.g. use "mx_" prefixed mode,
# qcfg["mapping"] and other qcfg["mx_specs"] content will be updated automatically

for dtype_to_test in ["int8", "int4", "fp8_e4m3", "fp8_e5m2", "fp4_e2m1"]:
qcfg["qw_mode"] = f"mx_{dtype_to_test}"
qcfg["qa_mode"] = f"mx_{dtype_to_test}"
mlp = ResidualMLP(HIDDEN_DIM) # fresh model
model = qmodel_prep(mlp, x, qcfg)
with torch.no_grad():
out_dtype = model(x)
results["dtype"].append(f"mx{dtype_to_test}")
results["output[0, :5]"].append(out_dtype[0, :5].tolist())
results["||ref - out_dtype||_2"].append(torch.norm(out - out_dtype).item())
print(model)

print(tabulate(results, headers="keys"))

print("DONE!")