Merged
47 commits
53f9857
toy exam under examples/MX
chichun-charlie-liu Mar 27, 2025
a59010e
include mx Linear class directly for convenience
chichun-charlie-liu Mar 27, 2025
b6b1d4f
minor clean up
chichun-charlie-liu Mar 28, 2025
027eb2f
[pyproject] Added MX optional-dependency
BrandonGroth Mar 26, 2025
02a9c2d
[qcfg] Changed mapping init and added mx path in qconfig_init
BrandonGroth Mar 27, 2025
2bd90e9
[QLinear] Added QLinearMX stub
BrandonGroth Mar 27, 2025
cc681ee
[qcfg] Changed qcfg[mapping] structure
BrandonGroth Mar 28, 2025
62962a3
[mx] Updates to ffn_tmp.py
BrandonGroth Mar 28, 2025
a79ebbf
clean-up, now simply use qa_mode and qw_mode to trigger mx
chichun-charlie-liu Mar 29, 2025
8bf30ac
first mx test with "dq" using qa_/qw_mode=mx_fp8_e4m3
chichun-charlie-liu Mar 30, 2025
ac03469
add QbmmMX support
chichun-charlie-liu Apr 1, 2025
7f6ea36
draft example readme for mx
chichun-charlie-liu Apr 1, 2025
613e523
adjust mapping def, allow qbmm to use non-mx while linear is using mx…
chichun-charlie-liu Apr 1, 2025
9ea77da
bug fix, QBmmMX init was calling QBmmMX init which tries to get a non…
chichun-charlie-liu Apr 1, 2025
91c20cd
feat: Added mx_specs to qconfig_init and check_config
BrandonGroth Apr 22, 2025
8b1d026
feat: Added skippable mx_specs unit tests
BrandonGroth Apr 22, 2025
a161b9a
fix: Added QLinearMX and QBmmMX to quantized_modules
BrandonGroth Apr 29, 2025
9ffa0b1
chore: format qconfig_utils
BrandonGroth Apr 29, 2025
a9a57b3
test: minor fixes to test_mx and supporting functions
BrandonGroth Apr 29, 2025
764d7bc
Merge pull request #1 from BrandonGroth/mx_impl_brandon
chichun-charlie-liu Apr 29, 2025
45dd501
Merge branch 'main' into mx_impl
chichun-charlie-liu Apr 29, 2025
9bab835
lint and ruff fix
chichun-charlie-liu Apr 29, 2025
40a5c0c
chore: Added license to ffn_tmp.py and other minor fixes
BrandonGroth May 5, 2025
9ba1a42
feat: Added install_patches.py for fms_mo patch installs
BrandonGroth May 2, 2025
ac0d2ed
feat: Added patches directory and microxcaling.patch
BrandonGroth May 2, 2025
9f3713e
feat: Added README.md for patching instructions
BrandonGroth May 2, 2025
352e77d
doc: Added instructions to create a patch file
BrandonGroth May 6, 2025
303454c
Merge pull request #2 from BrandonGroth/mx_impl_patch
chichun-charlie-liu May 6, 2025
d7dfc27
rename mx example py, markdown update is still WIP
chichun-charlie-liu May 14, 2025
3446413
markdown update is still WIP
chichun-charlie-liu May 14, 2025
dd17104
fix bug in config check as QBmmMX is optional
chichun-charlie-liu May 21, 2025
6e4431c
feat: Added qmodel_prep mx_specs hook for default config with MX sett…
BrandonGroth May 21, 2025
83d2b06
test: Added MX test for default config + qmodel_prep
BrandonGroth May 21, 2025
d05401f
fix: Added MX vocabulary for spellcheck util
BrandonGroth May 21, 2025
7daae3b
chore: Fixed MX vocab errors
BrandonGroth May 21, 2025
5694a5a
fix: Changed check_config check for MX mapping
BrandonGroth May 22, 2025
308e5e2
Merge pull request #3 from BrandonGroth/mx_impl_brandon
chichun-charlie-liu May 22, 2025
e5eb0ed
doc: Updated MX examples/readme for patch install
BrandonGroth May 23, 2025
ab060ca
update mx readme.md, fix simple_mx_exam.py together with a few minor …
chichun-charlie-liu May 27, 2025
9b9b777
Merge branch 'main' into mx_impl
chichun-charlie-liu May 27, 2025
2cb9aca
add a microscaling ref in readme
chichun-charlie-liu May 27, 2025
7e7a0fa
fix: Patched qconfig_init and get_unwanted_defaults QBmm
BrandonGroth May 27, 2025
f220b17
fix: Updated spellcheck list for examples/mx readme
BrandonGroth May 27, 2025
026a3ea
fix: Fixed additional Qbmm mapping in test_saveconfig
BrandonGroth May 27, 2025
b00287f
minor updates on README
chichun-charlie-liu May 27, 2025
b5f0eaa
chore: Cleanup of extra comments
BrandonGroth May 28, 2025
4e0de3f
adjust gitignore to allow changes in py and md files
chichun-charlie-liu May 28, 2025
176 changes: 176 additions & 0 deletions examples/MX/README.md
@@ -0,0 +1,176 @@
# Direct Quantization (DQ) Using `microscaling`
This is the same example as in the [DQ](../DQ_SQ/README.md) folder, except it uses the [microscaling](https://arxiv.org/abs/2310.10537) format.

Here, we provide an example of direct quantization. In this case, we demonstrate DQ of the `llama3-8b` model into MXINT8, MXFP8, MXFP6, or MXFP4 for weights, activations, and/or the KV cache. Note that `MXFP8` is a different format from the typical PyTorch FP8 types (e4m3 or e5m2); see our other [FP8 example](../FP8_QUANT/README.md). Most importantly, the `mx` formats are not natively supported by Hopper GPUs yet (some will be supported by Blackwell), which means the quantization configurations and corresponding behavior are simulated, and no speed-up should be expected.
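For intuition, the core microscaling idea, i.e. a small block of elements sharing one power-of-two scale with low-precision elements, can be sketched as a toy MXINT8-style quantize/dequantize round trip. This is an illustration only; the block size, rounding, and scale selection below are assumptions, not the `mx` package's implementation.

```python
import math

def toy_mx_int8(block):
    """Quantize a block of floats with ONE shared power-of-two scale and
    int8 elements, then dequantize. Toy illustration of the MX idea only."""
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return [0.0] * len(block)
    # Pick the smallest power-of-two scale such that amax/scale <= 127,
    # so every element fits in the signed int8 range.
    scale = 2.0 ** math.ceil(math.log2(amax / 127.0))
    q = [max(-128, min(127, round(v / scale))) for v in block]
    return [qi * scale for qi in q]
```

Per the microscaling paper, real MX formats use blocks of 32 elements sharing an 8-bit exponent (E8M0) scale; the toy above only captures the shared power-of-two scaling.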

## Requirements
- [FMS Model Optimizer requirements](../../README.md#requirements)
- Microsoft's `microxcaling` Python package, available [here](https://github.com/microsoft/microxcaling.git).
> [!TIP]
> After cloning and BEFORE installing, update `pyproject.toml` to remove the version constraints under `[dependencies]` (simply comment those three lines out). Then install with `pip install -e .` as usual.

## QuickStart

**1. Prepare Data** for the calibration process by converting it into tokenized form. An example of tokenization using the `LLAMA-3-8B` tokenizer is below.

```python
from transformers import AutoTokenizer
from fms_mo.utils.calib_data import get_tokenized_data

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", use_fast=True)
num_samples = 128
seq_len = 2048
get_tokenized_data("wiki", num_samples, seq_len, tokenizer, path_to_save='data')
```
> [!NOTE]
> - Users should provide a tokenized data file based on their needs. This is just one example to demonstrate the data format `fms_mo` expects.
> - Tokenized data will be saved in `<path_to_save>_train` and `<path_to_save>_test`.
> - If you have trouble downloading the Llama family of models from Hugging Face ([Llama models require access](https://www.llama.com/docs/getting-the-models/hugging-face/)), you can use `ibm-granite/granite-8b-code` instead.

**2. Apply DQ** by providing specific hyper-parameters such as `quant_method`, the weight quantizer (`qw_mode`), the activation quantizer (`qa_mode`), etc. An example using `Meta-Llama-3-8B` and the tokenized training and test data is provided below.
```bash
python -m fms_mo.run_quant \
    --model_name_or_path "meta-llama/Meta-Llama-3-8B" \
    --training_data_path data_train \
    --test_data_path data_test \
    --torch_dtype "float16" \
    --quant_method dq \
    --nbits_w 8 \
    --nbits_a 8 \
    --nbits_kvcache 32 \
    --qa_mode "mx_fp8_e4m3" \
    --qw_mode "mx_fp8_e4m3" \
    --output_dir "dq_test" \
    --eval_ppl
```
> [!TIP]
> To use an MX format, simply assign the `qa_mode` and `qw_mode` arguments a value of the form `mx_<dtype supported by the mx package>`, e.g. `mx_fp8_e4m3` as in the example above. The corresponding `QLinearMX` wrappers will then be used in place of the `QLinear` used in the other examples.
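As a sketch of this naming convention, a hypothetical helper (illustration only; the actual detection logic inside `fms_mo` may differ) could look like:

```python
def is_mx_mode(mode: str):
    """Return (is_mx, element_dtype) for a qa_mode/qw_mode string.
    Hypothetical helper illustrating the "mx_" prefix convention only."""
    prefix = "mx_"
    if mode.startswith(prefix):
        # e.g. "mx_fp8_e4m3" -> (True, "fp8_e4m3")
        return True, mode[len(prefix):]
    return False, mode
```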

**3. Compare the Perplexity score.** For user convenience, the code will print perplexity (controlled by the `eval_ppl` flag) at the end of the run, so no additional steps are needed (if the logging level is set to `INFO` in the terminal). You can also check the output in the log file `./fms_mo.log`.

# *TO BE UPDATED BELOW THIS LINE*


## Example Test Results
The perplexity of the INT8 and FP8 quantized models on the `wikitext` dataset is shown below:

| Model |Type |QA |QW |DQ |SQ |Perplexity|
|:---------:|:---:|:------------:|:------------:|:--:|:--:|:--------:|
|`Llama3-8b`|INT8 |maxpertoken |maxperCh |yes |yes |6.21 |
| |FP8 |fp8_e4m3_scale|fp8_e4m3_scale|yes |yes |6.19 |

## Code Walk-through

**1. KV caching**

In large language models (LLMs), key/value pairs are cached during token generation, a process known as KV caching, to avoid redundant computation arising from the autoregressive nature of generation. However, the KV cache grows with both batch size and context length, which can slow down inference because a large amount of data must be read from memory. Quantizing the KV cache reduces this memory-bandwidth bottleneck and improves inference speed. To study the quantization behavior of the KV cache, simply set the `nbits_kvcache` argument to 8; the KV cache will then be quantized together with the weights and activations. In addition, the `bmm1_qm1_mode`, `bmm1_qm2_mode`, and `bmm2_qm2_mode` [arguments](../../fms_mo/training_args.py) must be set to the same quantizer mode as `qa_mode`. **NOTE**: `bmm2_qm1_mode` should be kept as `minmax`.
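The settings described above can be collected into a sketch (argument names follow `fms_mo`'s training args; the specific mode value is illustrative):

```python
qa_mode = "mx_fp8_e4m3"  # illustrative; any supported activation mode works

# KV-cache quantization settings per the note above: the three bmm modes
# must match qa_mode, while bmm2_qm1_mode stays at "minmax".
kv_quant_args = {
    "nbits_kvcache": 8,        # quantize the KV cache to 8 bits
    "bmm1_qm1_mode": qa_mode,
    "bmm1_qm2_mode": qa_mode,
    "bmm2_qm2_mode": qa_mode,
    "bmm2_qm1_mode": "minmax",
}
```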

The effects of setting `nbits_kvcache` to 8, and the relevant code sections, are:

- Enables eager attention for the quantization of attention operations, including KV cache.
```python
# For attention or kv-cache quantization, need to use eager attention
attn_bits = [fms_mo_args.nbits_bmm1, fms_mo_args.nbits_bmm2, fms_mo_args.nbits_kvcache]
if any(x != 32 for x in attn_bits):
    attn_implementation = "eager"
else:
    attn_implementation = None
```
- Enables Dynamo for quantized model preparation. We use PyTorch's Dynamo tracer to identify the bmm and KV cache inside the attention block.
```python
if any(x != 32 for x in attn_bits):
    logger.info("Quantize attention bmms or kvcache, use dynamo for prep")
    use_layer_name_pattern_matching = False
    qcfg["qlayer_name_pattern"] = []
    assert (
        qcfg["qlayer_name_pattern"] == []
    ), "ensure nothing in qlayer_name_pattern when use dynamo"
    use_dynamo = True
else:
    logger.info("Do not quantize attention bmms")
    use_layer_name_pattern_matching = True
    use_dynamo = False
```

**2. Define the quantization config**, including quantizers and hyperparameters. Here we simply use the default [dq recipe](../../fms_mo/recipies/dq.json).

```python
qcfg = qconfig_init(recipe="dq", args=fms_mo_args)
```

**3. Obtain activation scales for SmoothQuant (SQ)**

```python
# For loading or creating smoothquant scale.
act_scale_directory = "./act_scales"
if not os.path.exists(act_scale_directory):
    os.makedirs(act_scale_directory)

if qcfg["act_scale_path"] is not None:
    act_scales = torch.load(qcfg["act_scale_path"], map_location="cpu")
else:
    logger.info("Generate activation scales")
    if qcfg["large_model"]:
        act_scales = get_act_scales_1gpu(model, dq_dataloader, qcfg)
    else:
        act_scales = get_act_scales(model, dq_dataloader, qcfg)
    scale_file = f"{act_scale_directory}/{qcfg['model'].replace('/', '-')}" + ".pt"
    torch.save(act_scales, scale_file)
```

**4. Prepare the quantized model and attach activation scales** to quantized modules

```python
qmodel_prep(
    model,
    dq_dataloader,
    qcfg,
    use_layer_name_pattern_matching=use_layer_name_pattern_matching,
    use_dynamo=use_dynamo,
    dev=dev,
    save_fname="test",
)

dq_llm(model, act_scales, qcfg)
```

**5. Perform direct quantization** by calibrating quantizers (clip_vals)

```python
if qcfg["qmodel_calibration_new"] > 0:
    logger.info("Starting to calibrate activation clip_val")
    if qcfg["large_model"]:
        calibration_llm_1GPU(qcfg, model, calibration_dataset)
    else:
        model.to("cuda:0")
        pbar = tqdm(
            dq_dataloader,
            desc="calibration after applying smoothq scale and before inference",
            total=qcfg["qmodel_calibration_new"],
        )
        for data_mb, _ in zip(pbar, range(qcfg["qmodel_calibration_new"])):
            data_mb = prepare_input(model.device, data_mb)
            with patch_torch_bmm(qcfg):
                model(**data_mb)

logger.info(f"Saving quantized model and tokenizer to {output_dir}")
model.save_pretrained(output_dir, use_safetensors=True)
tokenizer.save_pretrained(output_dir)
```

**6. Check perplexity** (simple method to evaluate the model quality)

```python
if fms_mo_args.eval_ppl:
    logger.info(f"Model for evaluation: {model}")
    if qcfg["large_model"]:
        eval_llm_1GPU(qcfg, model, test_dataset)
    else:
        model.to(torch.device("cuda:0"))
        n_samples = int(test_dataset.input_ids.shape[1] / block_size)
        evaluator = Evaluator(test_dataset, "cuda", n_samples=n_samples)
        ppl = evaluator.evaluate(model, block_size=block_size)
        logger.info(f"Model perplexity: {ppl}")
    logger.info("-" * 50)
    logger.info("Finished evaluation")
```
106 changes: 106 additions & 0 deletions examples/MX/ffn_tmp.py
@@ -0,0 +1,106 @@
# from mx import Linear as Linear_mx  # Need to amend mx's Linear class

# Third Party
import numpy as np
import torch
import torch.nn.functional as F


class ResidualMLP(torch.nn.Module):
    def __init__(self, hidden_size, device="cuda"):
        super().__init__()

        self.layernorm = torch.nn.LayerNorm(hidden_size, device=device)
        self.dense_4h = torch.nn.Linear(hidden_size, 4 * hidden_size, device=device)
        self.dense_h = torch.nn.Linear(4 * hidden_size, hidden_size, device=device)
        # Add a dummy layer: by default the 1st/last layers are skipped, so with
        # only 2 layers everything would be skipped.
        self.dummy = torch.nn.Linear(hidden_size, hidden_size, device=device)

    def forward(self, inputs):
        norm_outputs = self.layernorm(inputs)

        # MLP
        proj_outputs = self.dense_4h(norm_outputs)
        proj_outputs = F.gelu(proj_outputs)
        mlp_outputs = self.dense_h(proj_outputs)
        mlp_outputs = self.dummy(mlp_outputs)

        # Residual connection
        outputs = inputs + mlp_outputs

        return outputs


if __name__ == "__main__":
    # Add config arguments
    # parser = argparse.ArgumentParser()
    # parser.add_argument("--hidden_size", default=128)
    # parser.add_argument("--device", default='cuda')
    # args = parser.parse_args()

    # Standard
    from functools import partial

    # Third Party
    from mx import MxSpecs
    from tabulate import tabulate

    # Local
    from fms_mo import qconfig_init, qmodel_prep

    x = np.random.randn(16, 128)
    x = torch.tensor(x, dtype=torch.float32, device="cuda")
    results = {"dtype": [], "output[0, :5]": [], "||ref - out_dtype||_2": []}

    # --- Test 0. Run the MLP as is
    mlp = ResidualMLP(128)
    # mlp.to("cuda")
    with torch.no_grad():
        out = mlp(x)
    results["dtype"].append("fp32")
    results["output[0, :5]"].append(out[0, :5].tolist())
    results["||ref - out_dtype||_2"].append("-")
    print(mlp)

    # --- Test 1. fms-mo qmodel_prep, replace Linear with our QLinear
    qcfg = qconfig_init()
    qcfg["nbits_a"] = 8
    qcfg["nbits_w"] = 8
    model = qmodel_prep(mlp, x, qcfg)
    with torch.no_grad():
        out_dtype = model(x)
    results["dtype"].append("fms_int8")
    results["output[0, :5]"].append(out_dtype[0, :5].tolist())
    results["||ref - out_dtype||_2"].append(torch.norm(out - out_dtype).item())
    # print(model)

    qcfg["nbits_a"] = 4
    qcfg["nbits_w"] = 4
    mlp = ResidualMLP(128)
    model = qmodel_prep(mlp, x, qcfg)
    with torch.no_grad():
        out_dtype = model(x)
    results["dtype"].append("fms_int4")
    results["output[0, :5]"].append(out_dtype[0, :5].tolist())
    results["||ref - out_dtype||_2"].append(torch.norm(out - out_dtype).item())
    print(model)

    # --- Test 2. Now change the mapping to MX.
    # NOTE: simply use an "mx_"-prefixed qa_mode/qw_mode to trigger MX;
    # qcfg["mapping"] and qcfg["mx_specs"] will be updated automatically.
    for dtype_to_test in ["int8", "int4", "fp8_e4m3", "fp8_e5m2", "fp4_e2m1"]:
        qcfg["qw_mode"] = f"mx_{dtype_to_test}"
        qcfg["qa_mode"] = f"mx_{dtype_to_test}"
        mlp = ResidualMLP(128)  # fresh model
        model = qmodel_prep(mlp, x, qcfg)
        with torch.no_grad():
            out_dtype = model(x)
        results["dtype"].append(f"mx{dtype_to_test}")
        results["output[0, :5]"].append(out_dtype[0, :5].tolist())
        results["||ref - out_dtype||_2"].append(torch.norm(out - out_dtype).item())
        print(model)

    print(tabulate(results, headers="keys"))

    print("DONE!")
63 changes: 36 additions & 27 deletions fms_mo/dq.py
@@ -45,6 +45,8 @@
)
from fms_mo.utils.dq_utils import config_quantize_smooth_layers
from fms_mo.utils.eval_utils import Evaluator, eval_llm_1GPU

# from fms_mo.utils.import_utils import available_packages
from fms_mo.utils.utils import patch_torch_bmm, prepare_input

logger = logging.getLogger(__name__)
@@ -158,21 +160,21 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
config_quantize_smooth_layers(qcfg)

if any(x != 32 for x in attn_bits):
logger.info("Quantize attention bmms or kvcache, use dynamo for prep")
logger.info("Quantize attention bmms or kvcache, will use dynamo for prep")
use_layer_name_pattern_matching = False
qcfg["qlayer_name_pattern"] = []
assert (
qcfg["qlayer_name_pattern"] == []
), "ensure nothing in qlayer_name_pattern when use dynamo"
use_dynamo = True
else:
logger.info("Do not quantize attention bmms")
logger.info("Attention bmms will not be quantized.")
use_layer_name_pattern_matching = True
use_dynamo = False

qcfg["seq_len"] = block_size
qcfg["model"] = model_args.model_name_or_path
qcfg["smoothq"] = True
qcfg["smoothq"] = "mx_specs" not in qcfg # NOTE no SQ for mx for now
qcfg["plotsvg"] = False

calibration_dataset = load_from_disk(data_args.training_data_path)
@@ -186,27 +188,31 @@
)

# For loading or creating smoothquant scale. Sometimes we may include scales in ckpt as well.
scale_file = Path(f"./act_scales/{qcfg['model'].replace('/', '-')}.pt")
if qcfg.get("act_scale_path", None):
# user provided a scale file (or a dir)
scale_file_or_dir = Path(qcfg["act_scale_path"])
if scale_file_or_dir.is_dir():
scale_file = scale_file_or_dir / f"{qcfg['model'].replace('/', '-')}.pt"
elif scale_file_or_dir.is_file():
scale_file = scale_file_or_dir

if not scale_file.parent.exists():
scale_file.parent.mkdir(exist_ok=False)

if scale_file.exists():
act_scales = torch.load(scale_file, map_location=getattr(model, "device", dev))
else:
logger.info("Generate activation scales")
if qcfg["large_model"]:
act_scales = get_act_scales_1gpu(model, dq_dataloader, qcfg)
if qcfg["smoothq"]:
scale_file = Path(f"./act_scales/{qcfg['model'].replace('/', '-')}.pt")
if qcfg.get("act_scale_path", None):
# user provided a scale file (or a dir)
scale_file_or_dir = Path(qcfg["act_scale_path"])
if scale_file_or_dir.is_dir():
scale_file = scale_file_or_dir / f"{qcfg['model'].replace('/', '-')}.pt"
elif scale_file_or_dir.is_file():
scale_file = scale_file_or_dir

if not scale_file.parent.exists():
scale_file.parent.mkdir(exist_ok=False)

if scale_file.exists():
act_scales = torch.load(
scale_file, map_location=getattr(model, "device", dev)
)
else:
act_scales = get_act_scales(model, dq_dataloader, qcfg)
torch.save(act_scales, scale_file)
logger.info("Generate activation scales")
if qcfg["large_model"]:
act_scales = get_act_scales_1gpu(model, dq_dataloader, qcfg)
else:
act_scales = get_act_scales(model, dq_dataloader, qcfg)
torch.save(act_scales, scale_file)

qmodel_prep(
model,
dq_dataloader,
Expand All @@ -217,10 +223,13 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
save_fname="dq",
)
logger.info(f"Quantized model {model}")
logger.info("Starting to apply smooth scale")
dq_llm(model, act_scales, qcfg)
logger.info("Finished applying smooth scale")
logger.info("==" * 20)

if qcfg["smoothq"]:
logger.info("Starting to apply smooth scale")
dq_llm(model, act_scales, qcfg)
logger.info("Finished applying smooth scale")
logger.info("==" * 20)

if qcfg["qmodel_calibration_new"] > 0:
logger.info("Starting to calibrate activation clip_val")
if qcfg["large_model"]:
2 changes: 1 addition & 1 deletion fms_mo/fx/dynamo_utils.py
@@ -1218,7 +1218,7 @@ def call_seq_hook(mod, *_args, **_kwargs):
# b) qbmm creation and attaching to model
if qcfg.get("QBmm"): # see Note 4
# Local
from fms_mo.modules import QBmm
QBmm = qcfg["mapping"]["matmul_or_bmm"]
> **Collaborator:** Should this be a lower-case variable name?
>
> **Collaborator:** This is a class pointer disguised as a variable and is used as `QBmm(args)` below. I think it is fine as is.
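The pattern under discussion, storing a class object in a mapping and instantiating it later, can be sketched with a stub class (illustration only; `QBmmStub` is hypothetical, not the real `QBmm`):

```python
class QBmmStub:
    """Stand-in for the real QBmm module class (illustration only)."""
    def __init__(self, num_bits):
        self.num_bits = num_bits

# The mapping holds the class itself, not an instance...
mapping = {"matmul_or_bmm": QBmmStub}
QBmm = mapping["matmul_or_bmm"]   # a class object, hence the class-style name
layer = QBmm(num_bits=8)          # ...and it is later called like a constructor
```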


qcfg["which2patch_contextmanager"] = qcfg["bmm_prep"][
"which2patch_contextmanager"