# feat: mx integration #110
**Merged:** chichun-charlie-liu merged 47 commits into `foundation-model-stack:main` from `chichun-charlie-liu:mx_impl` on May 28, 2025.
## Commits (47)
- `53f9857` toy exam under examples/MX (chichun-charlie-liu)
- `a59010e` include mx Linear class directly for convenience (chichun-charlie-liu)
- `b6b1d4f` minor clean up (chichun-charlie-liu)
- `027eb2f` [pyproject] Added MX optional-dependency (BrandonGroth)
- `02a9c2d` [qcfg] Changed mapping init and added mx path in qconfig_init (BrandonGroth)
- `2bd90e9` [QLinear] Added QLinearMX stub (BrandonGroth)
- `cc681ee` [qcfg] Changed qcfg[mapping] structure (BrandonGroth)
- `62962a3` [mx] Updates to ffn_tmp.py (BrandonGroth)
- `a79ebbf` clean-up, now simply use qa_mode and qw_mode to trigger mx (chichun-charlie-liu)
- `8bf30ac` first mx test with "dq" using qa_/qw_mode=mx_fp8_e4m3 (chichun-charlie-liu)
- `ac03469` add QbmmMX support (chichun-charlie-liu)
- `7f6ea36` draft example readme for mx (chichun-charlie-liu)
- `613e523` adjust mapping def, allow qbmm to use non-mx while linear is using mx… (chichun-charlie-liu)
- `9ea77da` bug fix, QBmmMX init was calling QBmmMX init which tries to get a non… (chichun-charlie-liu)
- `91c20cd` feat: Added mx_specs to qconfig_init and check_config (BrandonGroth)
- `8b1d026` feat: Added skippable mx_specs unit tests (BrandonGroth)
- `a161b9a` fix: Added QLinearMX and QBmmMX to quantized_modules (BrandonGroth)
- `9ffa0b1` chore: format qconfig_utils (BrandonGroth)
- `a9a57b3` test: minor fixes to test_mx and supporting functions (BrandonGroth)
- `764d7bc` Merge pull request #1 from BrandonGroth/mx_impl_brandon (chichun-charlie-liu)
- `45dd501` Merge branch 'main' into mx_impl (chichun-charlie-liu)
- `9bab835` lint and ruff fix (chichun-charlie-liu)
- `40a5c0c` chore: Added license to ffn_tmp.py and other minor fixes (BrandonGroth)
- `9ba1a42` feat: Added install_patches.py for fms_mo patch installs (BrandonGroth)
- `ac0d2ed` feat: Added patches directory and microxcaling.patch (BrandonGroth)
- `9f3713e` feat: Added README.md for patching instructions (BrandonGroth)
- `352e77d` doc: Added instructions to create a patch file (BrandonGroth)
- `303454c` Merge pull request #2 from BrandonGroth/mx_impl_patch (chichun-charlie-liu)
- `d7dfc27` rename mx example py, markdown update is still WIP (chichun-charlie-liu)
- `3446413` markdown update is still WIP (chichun-charlie-liu)
- `dd17104` fix bug in config check as QBmmMX is optional (chichun-charlie-liu)
- `6e4431c` feat: Added qmodel_prep mx_specs hook for default config with MX sett… (BrandonGroth)
- `83d2b06` test: Added MX test for default config + qmodel_prep (BrandonGroth)
- `d05401f` fix: Added MX vocabulary for spellcheck util (BrandonGroth)
- `7daae3b` chore: Fixed MX vocab errors (BrandonGroth)
- `5694a5a` fix: Changed check_config check for MX mapping (BrandonGroth)
- `308e5e2` Merge pull request #3 from BrandonGroth/mx_impl_brandon (chichun-charlie-liu)
- `e5eb0ed` doc: Updated MX examples/readme for patch install (BrandonGroth)
- `ab060ca` update mx readme.md, fix simple_mx_exam.py together with a few minor … (chichun-charlie-liu)
- `9b9b777` Merge branch 'main' into mx_impl (chichun-charlie-liu)
- `2cb9aca` add a microscaling ref in readme (chichun-charlie-liu)
- `7e7a0fa` fix: Patched qconfig_init and get_unwanted_defaults QBmm (BrandonGroth)
- `f220b17` fix: Updated spellcheck list for examples/mx readme (BrandonGroth)
- `026a3ea` fix: Fixed additional Qbmm mapping in test_saveconfig (BrandonGroth)
- `b00287f` minor updates on README (chichun-charlie-liu)
- `b5f0eaa` chore: Cleanup of extra comments (BrandonGroth)
- `4e0de3f` adjust gitignore to allow changes in py and md files (chichun-charlie-liu)
**File: `examples/MX/README.md`** (new file, +98 lines)

# `microscaling` Examples Using a Toy Model and Direct Quantization (DQ)

Microscaling ("MX") formats, such as `MXFP8`, differ from the commonly used FP8 formats. PyTorch, for example, provides two FP8 formats: 1 sign bit, 4 exponent bits, and 3 mantissa bits (denoted `e4m3`), or 1 sign bit, 5 exponent bits, and 2 mantissa bits (`e5m2`); see our other [FP8 example](../FP8_QUANT/README.md) for more details. The MX formats, by contrast, are group-based data structures: each member of a group uses the specified element format (e.g., FP8 for MXFP8), while the whole group shares a single (usually 8-bit) scale. The group size can be as small as 32 or 16, depending on the hardware design. One may therefore consider each MXFP8 number to require 8.25 bits rather than 8 when the group size is 32 (8 bits per element plus 8/32 of a bit for the shared scale). More details about microscaling can be found in [this OCP document](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).
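To make the shared-scale idea concrete, below is a minimal NumPy sketch of the MX concept. This is illustrative only, not the `fms-mo` or `microxcaling` API: each group of 32 values shares one power-of-two scale, and a toy `int8` cast stands in for the FP8/FP4 element formats that the real package implements.

```python
import numpy as np

def mx_quantize_sketch(x, group_size=32):
    """Illustrative only: one shared power-of-two scale per group,
    int8 elements standing in for the real FP8/FP4 element formats."""
    groups = x.reshape(-1, group_size)
    max_abs = np.abs(groups).max(axis=1, keepdims=True)
    # Shared scale per group, rounded up to a power of two
    scale = 2.0 ** np.ceil(np.log2(max_abs / 127.0 + 1e-30))
    # Each member of the group is quantized with the shared scale
    q = np.clip(np.round(groups / scale), -127, 127).astype(np.int8)
    return q, scale

def mx_dequantize_sketch(q, scale):
    return q.astype(np.float32) * scale

x = np.random.randn(4, 32).astype(np.float32)
q, scale = mx_quantize_sketch(x)
x_hat = mx_dequantize_sketch(q, scale)
print("max abs error:", np.abs(x.reshape(-1, 32) - x_hat).max())
```

Storing 8-bit elements plus one 8-bit scale per group of 32 is where the 8.25 bits per value mentioned above comes from.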
Here, we provide two simple examples of using the MX format in `fms-mo`.

> [!NOTE]
> Keep in mind that `mx` is not yet natively supported by Hopper GPUs (some formats will be supported by Blackwell), which means the quantization configurations and the corresponding behavior are simulated. Hence, no real speed-up should be expected.
## Requirements
- [FMS Model Optimizer requirements](../../README.md#requirements)
- Microsoft's `microxcaling` Python package, available [here](https://github.com/microsoft/microxcaling.git)

> [!TIP]
> `FMS-Model-Optimizer` and `microxcaling` have clashing dependency requirements for their `PyTorch` packages. We have created a patching solution to resolve this; run the following on the command line:

```bash
python3 ../install_patches.py
```

This patching script will either download the repo for you or look for an already installed copy in `$HOME` or the current working directory, and then install the patch. For more information, see `patches/README.md`.
## QuickStart

### Example 1
The first example is based on a toy model with only a few Linear layers, in which only one Linear layer will be quantized with the MX versions of `int8`, `int4`, `fp8`, and `fp4`. The example can be run as follows:

```bash
python simple_mx_example.py
```

A comparison between the formats, including the first three elements of the output tensors and the norm of the difference from the FP32 reference, is shown below.

| dtype      | output[0, 0] | output[0, 1] | output[0, 2] | \|\|ref - out_dtype\|\|<sub>2</sub> |
|:-----------|-------------:|-------------:|-------------:|------------------------------------:|
| fp32       |      -1.0491 |       0.5312 |      -1.6387 |                               0.0000 |
| fmsmo_int8 |      -1.0577 |       0.5346 |      -1.6508 |                               0.4937 |
| fmsmo_int4 |      -0.5885 |       0.5831 |      -1.7976 |                               8.2927 |
| mxint8     |      -0.6444 |       0.6828 |      -1.8626 |                               8.3305 |
| mxint4     |      -0.9089 |       0.6141 |      -1.7630 |                               8.0692 |
| mxfp8_e4m3 |      -0.8031 |       0.7262 |      -1.9581 |                               7.8554 |
| mxfp8_e5m2 |      -0.8471 |       0.7319 |      -1.7458 |                               8.1838 |
| mxfp4_e2m1 |      -0.7506 |       0.6123 |      -1.9311 |                               7.9936 |
### Example 2
The second example is the same as the [DQ example](../DQ_SQ/README.md), except that it uses the [microscaling](https://arxiv.org/abs/2310.10537) format. We only demonstrate `mxfp8` and `mxfp4` here, but MXINT8, MXFP8, MXFP6, and MXFP4 are also available for weights, activations, and/or the KV-cache.

**1. Prepare Data** for the calibration process by converting it into tokenized form. An example of tokenization using `LLAMA-3-8B`'s tokenizer is below.

```python
from transformers import AutoTokenizer
from fms_mo.utils.calib_data import get_tokenized_data

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", use_fast=True)
num_samples = 128
seq_len = 2048
get_tokenized_data("wiki", num_samples, seq_len, tokenizer, path_to_save='data')
```

> [!NOTE]
> - Users should provide a tokenized data file based on their needs. This is just one example to demonstrate the data format `fms_mo` expects.
> - Tokenized data will be saved in `<path_to_save>_train` and `<path_to_save>_test`.
> - If you have trouble downloading the Llama family of models from Hugging Face ([Llama models require access](https://www.llama.com/docs/getting-the-models/hugging-face/)), you can use `ibm-granite/granite-8b-code` instead.
**2. Apply DQ** by providing specific hyper-parameters, such as `quant_method`, the weight quantizer (`qw_mode`), and the activation quantizer (`qa_mode`). An example using `Meta-Llama-3-8B` and the tokenized training and test data is provided below.

```bash
python -m fms_mo.run_quant \
    --model_name_or_path "meta-llama/Meta-Llama-3-8B" \
    --training_data_path data_train \
    --test_data_path data_test \
    --torch_dtype "float16" \
    --quant_method dq \
    --nbits_w 8 \
    --nbits_a 8 \
    --nbits_kvcache 32 \
    --qa_mode "mx_fp8_e4m3" \
    --qw_mode "mx_fp8_e4m3" \
    --output_dir "dq_test" \
    --eval_ppl
```

> [!NOTE]
> To use an MX format, simply assign the `qa_mode` and `qw_mode` arguments a value of the form `mx_<dtype supported by the mx package>`, e.g. `mx_fp8_e4m3` as in the example above. The corresponding `QLinearMX` wrappers will then be used in place of the `QLinear` wrappers seen in the other examples.
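The same switch can also be flipped programmatically, which is what the toy script in this PR does. A minimal sketch along those lines is below; the `Sequential` stack and its layer sizes are illustrative placeholders, and `qmodel_prep` takes a sample input for tracing, as in the example script.

```python
import torch
from fms_mo import qconfig_init, qmodel_prep

# Toy stack of Linear layers (illustrative). By default the first and
# last layers are skipped, so the middle layer is the one quantized.
model = torch.nn.Sequential(
    torch.nn.Linear(128, 256, device="cuda"),
    torch.nn.Linear(256, 256, device="cuda"),
    torch.nn.Linear(256, 128, device="cuda"),
)
x = torch.randn(16, 128, device="cuda")

qcfg = qconfig_init()
qcfg["qa_mode"] = "mx_fp8_e4m3"  # the "mx_" prefix triggers the MX path
qcfg["qw_mode"] = "mx_fp8_e4m3"  # for both activations and weights
qmodel_prep(model, x, qcfg)      # eligible Linear layers become QLinearMX
print(model)
```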
**3. Compare the Perplexity score.** For convenience, the code prints the perplexity at the end of the run (controlled by the `eval_ppl` flag), so no additional steps are needed if the logging level is set to `INFO` in the terminal. You can also check the output in the log file `./fms_mo.log`.

## Example Test Results
The perplexity of the INT8-, FP8-, and MX-quantized models on the `wikitext` dataset is shown below:

| Model      | Type   | QA             | QW             | DQ  | SQ     | Perplexity |
|:----------:|:------:|:--------------:|:--------------:|:---:|:------:|:----------:|
|`Llama3-8b` | INT8   | maxpertoken    | maxperCh       | yes | yes    | 6.22       |
|            | FP8    | fp8_e4m3_scale | fp8_e4m3_scale | yes | yes    | 6.19       |
|            | **MX** | mx_fp8_e4m3    | mx_fp8_e4m3    | yes | **no** | 6.23       |
|            | **MX** | mx_fp4_e2m1    | mx_fp4_e2m1    | yes | **no** | 8.22       |

> [!NOTE]
> SmoothQuant is disabled when `mx` is being used. See `dq.py` for more details.
**File: `examples/MX/simple_mx_example.py`** (new file, +123 lines)
```python
# Copyright The FMS Model Optimizer Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Simple example using a toy model to demo how to trigger mx in fms-mo."""

# Third Party
import numpy as np
import torch
import torch.nn.functional as F


class ResidualMLP(torch.nn.Module):
    def __init__(self, hidden_size, device="cuda"):
        super(ResidualMLP, self).__init__()

        self.layernorm = torch.nn.LayerNorm(hidden_size, device=device)
        self.dense_4h = torch.nn.Linear(hidden_size, 4 * hidden_size, device=device)
        self.dense_h = torch.nn.Linear(4 * hidden_size, hidden_size, device=device)
        self.dummy = torch.nn.Linear(hidden_size, hidden_size, device=device)
        # add a dummy layer because by default we skip 1st/last; if there are
        # only 2 layers, all will be skipped

    def forward(self, inputs):
        norm_outputs = self.layernorm(inputs)

        # MLP
        proj_outputs = self.dense_4h(norm_outputs)
        proj_outputs = F.gelu(proj_outputs)
        mlp_outputs = self.dense_h(proj_outputs)
        mlp_outputs = self.dummy(mlp_outputs)

        # Residual Connection
        outputs = inputs + mlp_outputs

        return outputs


if __name__ == "__main__":
    # Third Party
    from tabulate import tabulate

    # Local
    from fms_mo import qconfig_init, qmodel_prep

    HIDDEN_DIM = 128
    x = np.random.randn(16, HIDDEN_DIM)
    x = torch.tensor(x, dtype=torch.float32, device="cuda")
    results = {
        "dtype": [],
        "output[0, 0]": [],
        "output[0, 1]": [],
        "output[0, 2]": [],
        "||ref - out_dtype||_2": [],
    }

    # --- Test 0. Run MLP as is
    model = ResidualMLP(HIDDEN_DIM)
    with torch.no_grad():
        out = model(x)
    results["dtype"].append("fp32")
    results["output[0, 0]"].append(out[0, 0])
    results["output[0, 1]"].append(out[0, 1])
    results["output[0, 2]"].append(out[0, 2])
    results["||ref - out_dtype||_2"].append(0)
    print(model)

    # --- Test 1. fms-mo qmodel_prep, replace Linear with our QLinear
    qcfg = qconfig_init()
    qcfg["nbits_a"] = 8
    qcfg["nbits_w"] = 8
    qmodel_prep(model, x, qcfg)
    with torch.no_grad():
        out_dtype = model(x)
    results["dtype"].append("fmsmo_int8")
    results["output[0, 0]"].append(out_dtype[0, 0])
    results["output[0, 1]"].append(out_dtype[0, 1])
    results["output[0, 2]"].append(out_dtype[0, 2])
    results["||ref - out_dtype||_2"].append(torch.norm(out - out_dtype).item())
    print(model)

    qcfg["nbits_a"] = 4
    qcfg["nbits_w"] = 4
    model = ResidualMLP(HIDDEN_DIM)
    qmodel_prep(model, x, qcfg)
    with torch.no_grad():
        out_dtype = model(x)
    results["dtype"].append("fmsmo_int4")
    results["output[0, 0]"].append(out_dtype[0, 0])
    results["output[0, 1]"].append(out_dtype[0, 1])
    results["output[0, 2]"].append(out_dtype[0, 2])
    results["||ref - out_dtype||_2"].append(torch.norm(out - out_dtype).item())
    print(model)

    # --- Test 2. now change mapping to MX
    # NOTE simply use qa_mode or qw_mode to trigger the use of mx, e.g. use
    # "mx_" prefixed mode; qcfg["mapping"] and other qcfg["mx_specs"] content
    # will be updated automatically

    for dtype_to_test in ["int8", "int4", "fp8_e4m3", "fp8_e5m2", "fp4_e2m1"]:
        qcfg["qw_mode"] = f"mx_{dtype_to_test}"
        qcfg["qa_mode"] = f"mx_{dtype_to_test}"
        model = ResidualMLP(HIDDEN_DIM)  # fresh model
        qmodel_prep(model, x, qcfg)
        with torch.no_grad():
            out_dtype = model(x)
        results["dtype"].append(f"mx{dtype_to_test}")
        results["output[0, 0]"].append(out_dtype[0, 0])
        results["output[0, 1]"].append(out_dtype[0, 1])
        results["output[0, 2]"].append(out_dtype[0, 2])
        results["||ref - out_dtype||_2"].append(torch.norm(out - out_dtype).item())
        print(model)

    print(tabulate(results, headers="keys", tablefmt="pipe", floatfmt=".4f"))

    print("DONE!")
```
**Modified file** (hunk at line 1218, near `call_seq_hook`; the old import is replaced by a lookup into the quantizer mapping):

```diff
@@ -1218,7 +1218,7 @@ def call_seq_hook(mod, *_args, **_kwargs):
     # b) qbmm creation and attaching to model
     if qcfg.get("QBmm"):  # see Note 4
         # Local
-        from fms_mo.modules import QBmm
+        QBmm = qcfg["mapping"]["matmul_or_bmm"]

         qcfg["which2patch_contextmanager"] = qcfg["bmm_prep"][
             "which2patch_contextmanager"
```

**Collaborator:** Should this be a lower-case variable name?

**Collaborator:** This is a class pointer disguised as a variable and used as …
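For readers unfamiliar with the pattern in the reply above: in Python a class is itself a first-class object, so a mapping can hold the class and the call site instantiates whatever was looked up. A small illustrative sketch (the names here are made up, not the actual `fms_mo` classes):

```python
class QBmmPlain:
    """Stand-in for a regular bmm wrapper class."""
    def __init__(self, nbits):
        self.nbits = nbits

class QBmmMXToy:
    """Stand-in for an MX bmm wrapper class."""
    def __init__(self, nbits):
        self.nbits = nbits

# The mapping stores the class itself ("a class pointer disguised as a
# variable"); the capitalized name signals it will be called like a class.
mapping = {"matmul_or_bmm": QBmmMXToy}
QBmm = mapping["matmul_or_bmm"]
layer = QBmm(nbits=8)          # instantiated exactly like the class name
print(type(layer).__name__)    # -> QBmmMXToy
```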