diff --git a/.gitignore b/.gitignore
index 7598aa41..7c2ef35d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -45,4 +45,9 @@ fms_mo.log
 data*_train/
 data*_test/
 act_scales/
-examples/
+examples/**/*.json
+examples/**/*.safetensors
+examples/**/*.log
+examples/**/*.sh
+examples/**/*.pt
+examples/**/*.arrow
diff --git a/.spellcheck-en-custom.txt b/.spellcheck-en-custom.txt
index 205ec924..50ba3276 100644
--- a/.spellcheck-en-custom.txt
+++ b/.spellcheck-en-custom.txt
@@ -27,8 +27,10 @@ dequantization
 dq
 DQ
 dev
+dtype
 eval
 fms
+fmsmo
 fp
 FP
 FP8Arguments
@@ -125,3 +127,13 @@ venv
 vllm
 xs
 zp
+microxcaling
+Microscaling
+microscaling
+MX
+mx
+MXINT
+mxint
+MXFP
+mxfp
+OCP
diff --git a/examples/MX/README.md b/examples/MX/README.md
new file mode 100644
index 00000000..9dd094b7
--- /dev/null
+++ b/examples/MX/README.md
@@ -0,0 +1,98 @@
+# `microscaling` Examples Using a Toy Model and Direct Quantization (DQ)
+Microscaling ("MX") formats, such as `MXFP8`, are numeric formats that differ from the commonly used FP8 formats. For example, PyTorch provides two FP8 formats: 1 sign bit, 4 exponent bits, and 3 mantissa bits (denoted `e4m3`), or 1 sign bit, 5 exponent bits, and 2 mantissa bits (`e5m2`); see our other [FP8 example](../FP8_QUANT/README.md) for more details. The `mx` formats, on the other hand, are group-based data structures: each member of a group uses the specified element format (e.g., FP8 for MXFP8), while the whole group shares one (usually 8-bit) "scale". The group size can be as small as 32 or 16, depending on the hardware design. As a result, each MXFP8 number effectively requires 8.25 bits instead of 8 bits when the group size is 32, i.e., (32 * 8 + 8) / 32 = 8.25 bits per element. More details about microscaling can be found in [this OCP document](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).
+
+Here, we provide two simple examples of using MX formats in `fms-mo`.
+
+> [!NOTE]
+> Keep in mind that `mx` is not yet natively supported by Hopper GPUs (some formats will be supported by Blackwell), which means the quantization configurations and the corresponding behavior are simulated. Hence, no real "speed up" should be expected.
+
+
+## Requirements
+- [FMS Model Optimizer requirements](../../README.md#requirements)
+- Microsoft's `microxcaling` Python package, available [here](https://github.com/microsoft/microxcaling.git).
+> [!TIP]
+> `FMS-Model-Optimizer` and `microxcaling` have clashing dependency requirements for `PyTorch` packages. We have created a patching solution to resolve this; run the following from the command line:
+```bash
+python3 ../install_patches.py
+```
+This script will either download the repo for you or look for an already installed copy in `$HOME` or the current working directory, then apply the patch and install the package.
+For more information, see `patches/README.md`.
+
+## QuickStart
+
+### Example 1
+The first example is based on a toy model with only a few Linear layers, in which a single Linear layer is quantized with the MX versions of `int8`, `int4`, `fp8`, and `fp4`. The example can be run as follows:
+
+```bash
+python simple_mx_example.py
+```
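Under the hood, the script only needs to switch `qa_mode` and `qw_mode` to an `mx_`-prefixed mode before calling `qmodel_prep()`. The snippet below is a condensed sketch of what `simple_mx_example.py` (included in this change) does; the `Sequential` toy model and tensor shapes are illustrative stand-ins, not the exact `ResidualMLP` used in the script.

```python
import torch

from fms_mo import qconfig_init, qmodel_prep

# Toy model with three Linear layers: fms-mo skips the first and last
# quantizable layers by default, so only the middle one gets swapped.
model = torch.nn.Sequential(
    torch.nn.Linear(128, 512),
    torch.nn.Linear(512, 128),
    torch.nn.Linear(128, 128),
).to("cuda")
example_input = torch.randn(16, 128, device="cuda")

# Start from the default config, then select an MX format for both weights and
# activations. qcfg["mapping"] and qcfg["mx_specs"] are filled in automatically
# when an "mx_"-prefixed mode is detected; for MX the element format, not nbits,
# determines the simulated precision.
qcfg = qconfig_init()
qcfg["nbits_a"] = 8
qcfg["nbits_w"] = 8
qcfg["qa_mode"] = "mx_fp8_e4m3"
qcfg["qw_mode"] = "mx_fp8_e4m3"

qmodel_prep(model, example_input, qcfg)  # eligible nn.Linear -> QLinearMX

with torch.no_grad():
    out = model(example_input)
```

A comparison between the different formats, including the first 3 elements of the output tensors and the norm of the difference from the FP32 reference, is shown below.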
+ +| dtype | output[0, 0] | output[0, 1] | output[0, 2] | \|\|ref - out_dtype\|\|2 | +|:-----------|---------------:|---------------:|---------------:|------------------------:| +| fp32 | -1.0491 | 0.5312 | -1.6387 | 0.0000 | +| fmsmo_int8 | -1.0577 | 0.5346 | -1.6508 | 0.4937 | +| fmsmo_int4 | -0.5885 | 0.5831 | -1.7976 | 8.2927 | +| mxint8 | -0.6444 | 0.6828 | -1.8626 | 8.3305 | +| mxint4 | -0.9089 | 0.6141 | -1.7630 | 8.0692 | +| mxfp8_e4m3 | -0.8031 | 0.7262 | -1.9581 | 7.8554 | +| mxfp8_e5m2 | -0.8471 | 0.7319 | -1.7458 | 8.1838 | +| mxfp4_e2m1 | -0.7506 | 0.6123 | -1.9311 | 7.9936 | + + +### Example 2 +The second example is the same as the [DQ example](../DQ_SQ/README.md), except using [microxcaling](https://arxiv.org/abs/2310.10537) format. We only demonstrate `mxfp8` and `mxfp4` here, but MXINT8, MXFP8, MXFP6, MXFP4 are also available for weights, activations, and/or KV-cache. + +**1. Prepare Data** for calibration process by converting into its tokenized form. An example of tokenization using `LLAMA-3-8B`'s tokenizer is below. + +```python +from transformers import AutoTokenizer +from fms_mo.utils.calib_data import get_tokenized_data + +tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", use_fast=True) +num_samples = 128 +seq_len = 2048 +get_tokenized_data("wiki", num_samples, seq_len, tokenizer, path_to_save='data') +``` +> [!NOTE] +> - Users should provide a tokenized data file based on their need. This is just one example to demonstrate what data format `fms_mo` is expecting. +> - Tokenized data will be saved in `_train` and `_test` +> - If you have trouble downloading Llama family of models from Hugging Face ([LLama models require access](https://www.llama.com/docs/getting-the-models/hugging-face/)), you can use `ibm-granite/granite-8b-code` instead + +**2. Apply DQ** by providing specific hyper-parameters such as `quant_method`, weight quantizers (`qw_mode`) and activation quantizers (`qa_mode`) etc. An example using `Meta-Llama-3-8B` and the tokenized training and test data is provided below. +```bash +python -m fms_mo.run_quant \ + --model_name_or_path "meta-llama/Meta-Llama-3-8B" \ + --training_data_path data_train \ + --test_data_path data_test \ + --torch_dtype "float16" \ + --quant_method dq \ + --nbits_w 8 \ + --nbits_a 8 \ + --nbits_kvcache 32 \ + --qa_mode "mx_fp8_e4m3"\ + --qw_mode "mx_fp8_e4m3" \ + --output_dir "dq_test" \ + --eval_ppl +``` +> [!NOTE] +> To use MX format, simply assign `qa_mode` and `qw_mode` argument with a `mx_`, e.g. `mx_fp8_e4m3` as in the above example. Corresponding `QLinearMX` wrappers will be used in place of `QLinear` as in other examples. + +**3. Compare the Perplexity score** For user convenience, the code will print out perplexity (controlled by `eval_ppl` flag) at the end of the run, so no additional steps needed (if the logging level is set to `INFO` in terminal). You can check output in the logging file. `./fms_mo.log`. + + +## Example Test Results +The perplexity of the INT8 and FP8 quantized models on the `wikitext` dataset is shown below: + +| Model |Type |QA |QW |DQ |SQ |Perplexity| +|:---------:|:---:|:------------:|:------------:|:--:|:--:|:--------:| +|`Llama3-8b`|INT8 |maxpertoken |maxperCh |yes |yes |6.22 | +| |FP8 |fp8_e4m3_scale|fp8_e4m3_scale|yes |yes |6.19 | +| |**MX**|mx_fp8_e4m3 |mx_fp8_e4m3 |yes |**no** |6.23 | +| |**MX**|mx_fp4_e2m1 |mx_fp4_e2m1 |yes |**no** |8.22 | + + +> [!NOTE] +> SmoothQuant is disabled when `mx` is being used. See `dq.py` for more details. 
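For reference, the switch that disables SmoothQuant is a single line in `dq.py`, reproduced here as it appears in this change: SmoothQuant stays on only when a valid `smoothq_alpha` is given and no MX spec is present, and the activation-scale loading/creation step now runs only in that case.

```python
# From fms_mo/dq.py in this change: SmoothQuant requires a valid alpha AND no MX specs
qcfg["smoothq"] = qcfg.get("smoothq_alpha", -1) >= 0 and "mx_specs" not in qcfg
```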
+ diff --git a/examples/MX/simple_mx_example.py b/examples/MX/simple_mx_example.py new file mode 100644 index 00000000..ebf4a44d --- /dev/null +++ b/examples/MX/simple_mx_example.py @@ -0,0 +1,123 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Simple example using a toy model to demo how to trigger mx in fms-mo.""" + +# Third Party +import numpy as np +import torch +import torch.nn.functional as F + + +class ResidualMLP(torch.nn.Module): + def __init__(self, hidden_size, device="cuda"): + super(ResidualMLP, self).__init__() + + self.layernorm = torch.nn.LayerNorm(hidden_size, device=device) + self.dense_4h = torch.nn.Linear(hidden_size, 4 * hidden_size, device=device) + self.dense_h = torch.nn.Linear(4 * hidden_size, hidden_size, device=device) + self.dummy = torch.nn.Linear(hidden_size, hidden_size, device=device) + # add a dummy layer because by default we skip 1st/last, if there are only 2 layers, all will be skipped + + def forward(self, inputs): + norm_outputs = self.layernorm(inputs) + + # MLP + proj_outputs = self.dense_4h(norm_outputs) + proj_outputs = F.gelu(proj_outputs) + mlp_outputs = self.dense_h(proj_outputs) + mlp_outputs = self.dummy(mlp_outputs) + + # Residual Connection + outputs = inputs + mlp_outputs + + return outputs + + +if __name__ == "__main__": + # Third Party + from tabulate import tabulate + + # Local + from fms_mo import qconfig_init, qmodel_prep + + HIDDEN_DIM = 128 + x = np.random.randn(16, HIDDEN_DIM) + x = torch.tensor(x, dtype=torch.float32, device="cuda") + results = { + "dtype": [], + "output[0, 0]": [], + "output[0, 1]": [], + "output[0, 2]": [], + "||ref - out_dtype||_2": [], + } + + # --- Test 0. Run MLP as is + model = ResidualMLP(HIDDEN_DIM) + with torch.no_grad(): + out = model(x) + results["dtype"].append("fp32") + results["output[0, 0]"].append(out[0, 0]) + results["output[0, 1]"].append(out[0, 1]) + results["output[0, 2]"].append(out[0, 2]) + results["||ref - out_dtype||_2"].append(0) + print(model) + + # --- Test 1. fms-mo qmodel_prep, replace Linear with our QLinear + qcfg = qconfig_init() + qcfg["nbits_a"] = 8 + qcfg["nbits_w"] = 8 + qmodel_prep(model, x, qcfg) + with torch.no_grad(): + out_dtype = model(x) + results["dtype"].append("fmsmo_int8") + results["output[0, 0]"].append(out_dtype[0, 0]) + results["output[0, 1]"].append(out_dtype[0, 1]) + results["output[0, 2]"].append(out_dtype[0, 2]) + results["||ref - out_dtype||_2"].append(torch.norm(out - out_dtype).item()) + print(model) + + qcfg["nbits_a"] = 4 + qcfg["nbits_w"] = 4 + model = ResidualMLP(HIDDEN_DIM) + qmodel_prep(model, x, qcfg) + with torch.no_grad(): + out_dtype = model(x) + results["dtype"].append("fmsmo_int4") + results["output[0, 0]"].append(out_dtype[0, 0]) + results["output[0, 1]"].append(out_dtype[0, 1]) + results["output[0, 2]"].append(out_dtype[0, 2]) + results["||ref - out_dtype||_2"].append(torch.norm(out - out_dtype).item()) + print(model) + + # --- Test 2. 
now change mapping to MX + # NOTE simply use qa_mode or qw_mode to trigger the use of mx, e.g. use "mx_" prefixed mode, + # qcfg["mapping"] and other qcfg["mx_specs"] content will be updated automatically + + for dtype_to_test in ["int8", "int4", "fp8_e4m3", "fp8_e5m2", "fp4_e2m1"]: + qcfg["qw_mode"] = f"mx_{dtype_to_test}" + qcfg["qa_mode"] = f"mx_{dtype_to_test}" + model = ResidualMLP(HIDDEN_DIM) # fresh model + qmodel_prep(model, x, qcfg) + with torch.no_grad(): + out_dtype = model(x) + results["dtype"].append(f"mx{dtype_to_test}") + results["output[0, 0]"].append(out_dtype[0, 0]) + results["output[0, 1]"].append(out_dtype[0, 1]) + results["output[0, 2]"].append(out_dtype[0, 2]) + results["||ref - out_dtype||_2"].append(torch.norm(out - out_dtype).item()) + print(model) + + print(tabulate(results, headers="keys", tablefmt="pipe", floatfmt=".4f")) + + print("DONE!") diff --git a/fms_mo/dq.py b/fms_mo/dq.py index 6763a6d5..ff2720d4 100644 --- a/fms_mo/dq.py +++ b/fms_mo/dq.py @@ -159,7 +159,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): config_quantize_smooth_layers(qcfg) if any(x != 32 for x in attn_bits): - logger.info("Quantize attention bmms or kvcache, use dynamo for prep") + logger.info("Quantize attention bmms or kvcache, will use dynamo for prep") use_layer_name_pattern_matching = False qcfg["qlayer_name_pattern"] = [] assert ( @@ -167,13 +167,13 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): ), "ensure nothing in qlayer_name_pattern when use dynamo" use_dynamo = True else: - logger.info("Do not quantize attention bmms") + logger.info("Attention bmms will not be quantized.") use_layer_name_pattern_matching = True use_dynamo = False qcfg["seq_len"] = block_size qcfg["model"] = model_args.model_name_or_path - qcfg["smoothq"] = qcfg.get("smoothq_alpha", -1) >= 0 + qcfg["smoothq"] = qcfg.get("smoothq_alpha", -1) >= 0 and "mx_specs" not in qcfg qcfg["plotsvg"] = False calibration_dataset = load_from_disk(data_args.training_data_path) @@ -187,31 +187,32 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): ) # For loading or creating smoothquant scale. Sometimes we may include scales in ckpt as well. 
- scale_file = Path(f"./act_scales/{qcfg['model'].replace('/', '-')}.pt") - if qcfg.get("act_scale_path", None): - # user provided a scale file (or a dir) - scale_file_or_dir = Path(qcfg["act_scale_path"]) - if scale_file_or_dir.is_dir(): - scale_file = scale_file_or_dir / f"{qcfg['model'].replace('/', '-')}.pt" - elif scale_file_or_dir.is_file(): - scale_file = scale_file_or_dir - - if not scale_file.parent.exists(): - scale_file.parent.mkdir(exist_ok=False) - - if scale_file.exists(): - act_scales = torch.load( - scale_file, - map_location=getattr(model, "device", dev), - weights_only=True, - ) - else: - logger.info("Generate activation scales") - if qcfg["large_model"]: - act_scales = get_act_scales_1gpu(model, dq_dataloader, qcfg) + if qcfg["smoothq"]: + scale_file = Path(f"./act_scales/{qcfg['model'].replace('/', '-')}.pt") + if qcfg.get("act_scale_path", None): + # user provided a scale file (or a dir) + scale_file_or_dir = Path(qcfg["act_scale_path"]) + if scale_file_or_dir.is_dir(): + scale_file = scale_file_or_dir / f"{qcfg['model'].replace('/', '-')}.pt" + elif scale_file_or_dir.is_file(): + scale_file = scale_file_or_dir + + if not scale_file.parent.exists(): + scale_file.parent.mkdir(exist_ok=False) + + if scale_file.exists(): + act_scales = torch.load( + scale_file, map_location=getattr(model, "device", dev) + ) + else: - act_scales = get_act_scales(model, dq_dataloader, qcfg) - torch.save(act_scales, scale_file) + logger.info("Generate activation scales") + if qcfg["large_model"]: + act_scales = get_act_scales_1gpu(model, dq_dataloader, qcfg) + else: + act_scales = get_act_scales(model, dq_dataloader, qcfg) + torch.save(act_scales, scale_file) + qmodel_prep( model, dq_dataloader, diff --git a/fms_mo/fx/dynamo_utils.py b/fms_mo/fx/dynamo_utils.py index de684da4..71bc069b 100644 --- a/fms_mo/fx/dynamo_utils.py +++ b/fms_mo/fx/dynamo_utils.py @@ -1218,7 +1218,7 @@ def call_seq_hook(mod, *_args, **_kwargs): # b) qbmm creation and attaching to model if qcfg.get("QBmm"): # see Note 4 # Local - from fms_mo.modules import QBmm + QBmm = qcfg["mapping"]["matmul_or_bmm"] qcfg["which2patch_contextmanager"] = qcfg["bmm_prep"][ "which2patch_contextmanager" diff --git a/fms_mo/fx/utils.py b/fms_mo/fx/utils.py index 295da41d..302d9924 100644 --- a/fms_mo/fx/utils.py +++ b/fms_mo/fx/utils.py @@ -172,11 +172,6 @@ def lower_qmodel_to_ext_kernels( QLinearExllamaV2, ) - qclass_accepted = [] - for map_dict in qcfg["mapping"].values(): - qclass_accepted.append(map_dict["to"]) - qclass_accepted.append(map_dict.get("otherwise", None)) - mod2swap = { n: m for n, m in mod.named_modules() @@ -498,7 +493,6 @@ def model_size_Wb(mod, unit="MB", print_to_file=True, show_details=False): ) ), ) - if show_details: logger_or_print(df_summary_weights.to_markdown()) diff --git a/fms_mo/modules/bmm.py b/fms_mo/modules/bmm.py index aa16deff..ff4f026f 100644 --- a/fms_mo/modules/bmm.py +++ b/fms_mo/modules/bmm.py @@ -20,6 +20,7 @@ # Local from fms_mo.quant.quantizers import Qbypass, Qdynamic, get_activation_quantizer +from fms_mo.utils.import_utils import available_packages class QBmm(nn.Module): @@ -245,6 +246,280 @@ def __repr__(self): ) +if available_packages["mx"]: + # pylint: disable = import-error + # Third Party + from mx.elemwise_ops import quantize_elemwise_op + from mx.matmul_precision import set_matmul_precision + from mx.mx_ops import quantize_mx_op + from mx.specs import apply_mx_specs, get_backwards_mx_specs, mx_assert_test + import numpy as np + + class MatMulFunction(torch.autograd.Function): + 
"""Modified from mx.matmul.MatMulFunc. Try to control m1/m2 in a compatible way with fms-mo. + Matches functionality of torch.matmul. Attempts to broadcast + outmost dims if in1 and in2 have the same number of dims. + in1: (..., out_rows, features) + in2: (..., features, out_cols) + out: (..., out_rows, out_cols) + Otherwise, it expects the following shapes: + in1: (..., out_rows, features) + in2: (features, out_cols) + out: (..., out_rows, out_cols) + """ + + @staticmethod + def forward(ctx, in1, in2, bias, mx_specs, qm1_mode, qm2_mode, replaceBmm): + ctx.qm1_mode = qm1_mode + ctx.qm2_mode = qm2_mode + ctx.replaceBmm = replaceBmm + + qin1_elem_format = qm1_mode + qin2_elem_format = qm2_mode + + bf_in1 = quantize_elemwise_op( + in1, mx_specs=mx_specs, round=mx_specs["round_output"] + ) + bf_in2 = quantize_elemwise_op( + in2, mx_specs=mx_specs, round=mx_specs["round_output"] + ) + + if bias is not None: + bf_bias = quantize_elemwise_op( + bias, mx_specs=mx_specs, round=mx_specs["round_weight"] + ) + + ctx.bias_shape = list(bias.shape) + else: + bf_bias = None + ctx.bias_shape = None + + if mx_specs["quantize_backprop"]: + ctx.save_for_backward(bf_in1, bf_in2) + else: + ctx.save_for_backward(in1, in2) + + # quantize along the dot product dimension + qin1 = quantize_mx_op( + bf_in1, + mx_specs, + elem_format=qin1_elem_format, + axes=[-1], + round=mx_specs["round_mx_output"], + ) + qin2 = quantize_mx_op( + bf_in2, + mx_specs, + elem_format=qin2_elem_format, + axes=[-2], + round=mx_specs["round_mx_output"], + ) + + with set_matmul_precision(qin1, qin2, qin1_elem_format, qin2_elem_format): + # If the matmul Op to be replaced is torch.bmm => call torch.matmul(). Or if the Op + # to be replaced is torch.matmul => call torch.bmm instead. (or inf loop will occur) + if replaceBmm: + out = torch.matmul(qin1, qin2) + else: + # BMM only take 3d tensors, reshape 4d if needed + org_batch_header = qin1.shape[:2] + m1_reshape_to = [-1, qin1.shape[-2], qin1.shape[-1]] + m2_reshape_to = [-1, qin2.shape[-2], qin2.shape[-1]] + if len(qin1.shape) > 3: + qin1 = qin1.reshape(m1_reshape_to) + if len(qin2.shape) > 3: + qin2 = qin2.reshape(m2_reshape_to) + out = torch.bmm(qin1, qin2) + out_reshape_to = [*org_batch_header, *out.shape[1:]] + out = out.reshape(out_reshape_to) + ctx.reshape_3d = [m1_reshape_to, m2_reshape_to] + + out = quantize_elemwise_op( + out, mx_specs=mx_specs, round=mx_specs["round_output"] + ) + + if bias is not None: + out = out + bf_bias + out = quantize_elemwise_op( + out, mx_specs=mx_specs, round=mx_specs["round_output"] + ) + + ctx.mx_specs = get_backwards_mx_specs(mx_specs) + return out + + @staticmethod + def backward(ctx, grad_out): + """ + For a matmul in "wa" mode, the fwd and bwd matmuls configs are: + FWD wt x act: w x a + BWD wt x grad: w x a + BWD act x grad: a x a <-- no mixed precision! 
+ """ + qin1_elem_format = ctx.qm1_mode + qin2_elem_format = ctx.qm2_mode + + in1, in2 = ctx.saved_tensors + + grad_out = quantize_elemwise_op( + grad_out, + mx_specs=ctx.mx_specs, + round=ctx.mx_specs["round_grad_input"], + ) + + ##################################################### + # perform madtile operation for grad_in1, grad_in2 + ##################################################### + qin1 = quantize_mx_op( + in1, + ctx.mx_specs, + elem_format=qin1_elem_format, + axes=[-2], + round=ctx.mx_specs["round_mx_input_grad_input"], + ) + qin2 = quantize_mx_op( + in2, + ctx.mx_specs, + elem_format=qin2_elem_format, + axes=[-1], + round=ctx.mx_specs["round_mx_input_grad_input"], + ) + + # quantize along out_cols + qgrad_out1 = quantize_mx_op( + grad_out, + ctx.mx_specs, + elem_format=ctx.mx_specs["a_elem_format_bp_os"], + axes=[-1], + round=ctx.mx_specs["round_mx_grad_output_grad_input"], + ) + # quantize along out_rows + qgrad_out2 = quantize_mx_op( + grad_out, + ctx.mx_specs, + elem_format=ctx.mx_specs["a_elem_format_bp_os"], + axes=[-2], + round=ctx.mx_specs["round_mx_grad_output_grad_input"], + ) + + m1_org_shape = qin1.shape + m2_org_shape = qin2.shape + out_reshape_to = [-1, grad_out.shape[-2], grad_out.shape[-1]] + if ctx.replaceBmm: + # simple case, no reshape needed + torch_matmul = torch.matmul + else: + torch_matmul = torch.bmm + m1_reshape_to, m2_reshape_to = ctx.reshape_3d + # reshape every tensors to 3d so that bmm can work on them + # NOTE input shapes could be 3d 4d or even 2d, focus on dim -1 and -2 + # m1 = [..., out_r, feat] + # m2 = [..., feat, out_c] + # out and grad_out after bmm should be [..., out_r, out_c] + # In backward + # grad_in1 = grad_out@m2.t(-1, -2) = [..., out_r, out_c]@[..., out_c, feat] + # = [..., out_r, feat], i.e. shape of m1 + # grad_in2 = m1.t(-1, -2)@grad_out = [..., feat, out_r]@[..., out_r, out_c] + # = [..., feat, out_c], i.e. shape of m2 + qin1 = qin1.reshape(m1_reshape_to) + qin2 = qin2.reshape(m2_reshape_to) + qgrad_out1 = qgrad_out1.reshape(out_reshape_to) + qgrad_out2 = qgrad_out1.reshape(out_reshape_to) + + # compute grad_in1 and grad_in2 + with set_matmul_precision( + qgrad_out1, qin2, ctx.mx_specs["a_elem_format_bp_os"], qin2_elem_format + ): + grad_in1 = torch_matmul(qgrad_out1, qin2.transpose(-1, -2)) + grad_in1 = grad_in1.to(m1_org_shape) + + with set_matmul_precision( + qin1, qgrad_out2, qin1_elem_format, ctx.mx_specs["a_elem_format_bp_os"] + ): + grad_in2 = torch_matmul(qin1.transpose(-1, -2), qgrad_out2) + grad_in2 = grad_in2.to(m2_org_shape) + + # element-wise quantize for grad_in1 and grad_in2 + grad_in1 = quantize_elemwise_op( + grad_in1, + mx_specs=ctx.mx_specs, + round=ctx.mx_specs["round_grad_input"], + ) + grad_in2 = quantize_elemwise_op( + grad_in2, + mx_specs=ctx.mx_specs, + round=ctx.mx_specs["round_grad_input"], + ) + + ##################################################### + # Compute grad_bias + ##################################################### + if ctx.bias_shape is None: + grad_bias = None + else: + inner_size = grad_out.shape[-1] + assert np.prod(ctx.bias_shape) == inner_size + grad_bias = grad_out.reshape(-1, inner_size).sum(0) + grad_bias = grad_bias.reshape(ctx.bias_shape) + + grad_bias = quantize_elemwise_op( + grad_bias, + mx_specs=ctx.mx_specs, + round=ctx.mx_specs["round_grad_weight"], + ) + + return (grad_in1, grad_in2, grad_bias, None, None, None, None) + + class QBmmMX(QBmm): + """Wrapper for MX. Modified based on mx.matmul.""" + + def __init__( + self, + **kwargs, + ): + # NOTE: 1. 
qm1_mode, qm2_mode will be saved by super().init() + # 2. "mx_" prefix in qmX_mode should stil be in place in kwargs. will remove the + # prefix when stored in class properties + mx_specs = kwargs.pop("mx_specs") + mx_assert_test(mx_specs) + self.mx_none = mx_specs is None + + self.mx_specs = apply_mx_specs(mx_specs) + # always init super as 32b, to avoid QBmm trying to get_w_quantizer of non-exist "mx_xx" + # quantizers. Actual bits is still controlled by MX (based on qm1_ and qm2_mode) + kwargs["num_bits_m1"] = 32 + kwargs["num_bits_m2"] = 32 + kwargs["nbits_kvcache"] = 32 + super().__init__(**kwargs) + + self.qm1_mode = self.qm1_mode.replace("mx_", "") + self.qm2_mode = self.qm2_mode.replace("mx_", "") + + def apply_mx_specs(self, mx_specs): + """Borrow from mx without changes.""" + mx_assert_test(mx_specs) + self.mx_none = mx_specs is None + self.mx_specs = apply_mx_specs(mx_specs) + + def forward(self, m1, m2, bias=None): + """Fwd/Bwd pass of the quantized bmm module using mx. Backward should work, too. + Can be used for addmm() is bias is provided. + """ + mx_specs = self.mx_specs + mx_assert_test(mx_specs) + mx_specs = apply_mx_specs(mx_specs) + + return MatMulFunction.apply( + m1, m2, bias, mx_specs, self.qm1_mode, self.qm2_mode, self.replaceBmm + ) + + def __repr__(self) -> str: + repr_str = ( + f"{self.__class__.__name__}(m1={self.qm1_mode},m2={self.qm2_mode}," + f"blk_size={self.mx_specs['block_size']})" + ) + return repr_str + + # ------------------------------------------------------------------------------ # ----- The following wrappers are for torch FX CPU lowering only (FBGEMM) ----- # ----- NOTE: do not use them directly in QAT, backward is not defined ----- @@ -473,11 +748,14 @@ def forward(self, m1: torch.Tensor, m2: torch.Tensor) -> torch.Tensor: return x.to(m1.dtype) +# KEEP THIS AT END OF FILE - classes must be declared QBmm_modules = ( QBmm, QMatmulDebug, QBmmINT8Deploy, ) +if available_packages["mx"]: + QBmm_modules += (QBmmMX,) def isinstance_qbmm(module): diff --git a/fms_mo/modules/conv.py b/fms_mo/modules/conv.py index 61268277..d231f9e2 100644 --- a/fms_mo/modules/conv.py +++ b/fms_mo/modules/conv.py @@ -1515,6 +1515,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return output.to(input_dtype) +# KEEP THIS AT END OF FILE - classes must be declared QConv2d_modules = ( QConv2d, DetQConv2d, diff --git a/fms_mo/modules/linear.py b/fms_mo/modules/linear.py index 7f14c17d..e8197598 100644 --- a/fms_mo/modules/linear.py +++ b/fms_mo/modules/linear.py @@ -1804,29 +1804,6 @@ def forward(self, x, force_cuda=False): "QLinearExv1WI4AF16 and QLinearExv2WI4AF16 wrappers will not be available." ) -QLinear_modules = ( - QLinear, - QLinearFPout, - QLinearDebug, - QLinearW4A32Debug, - QLinearINT8Deploy, - QLinearCublasI8I32NT, - QLinearCutlassI8I32NT, -) - - -def isinstance_qlinear(module): - """ - Checks if the given module is one of the available quantized linear classes. - - Args: - module (nn.Module): The module to check. - - Returns: - bool: True if the module is a quantized linear class, False otherwise. - """ - return isinstance(module, QLinear_modules) - class LinearFuncFPxFwdBwd(torch.autograd.Function): """Linear function using FP24 accumulation, experimental only. 
@@ -2001,3 +1978,147 @@ def extra_repr(self) -> str: f"in={self.in_features}, out={self.out_features}, bias={self.bias is not None}, " f"trun_bits={self.trun_bits},fp8_dyn={self.fp8_dyn},chunk_size={self.chunk_size}" ) + + +if available_packages["mx"]: + # Third Party + # pylint: disable = import-error + from mx.elemwise_ops import quantize_elemwise_op + from mx.linear import linear as mx_linear + from mx.specs import apply_mx_specs, mx_assert_test + + # import mx # defaults to import all classes + + mx_specs_default = { + "w_elem_format": "fp8_e4m3", + "a_elem_format": "fp8_e4m3", + "block_size": 32, + "bfloat": 16, + "custom_cuda": True, + # For quantization-aware finetuning, do backward pass in FP32 + "quantize_backprop": False, + } + + class QLinearMX(torch.nn.Linear): + """Modified from mx.linear class. Only mildly changed init() and add extra_repr. + 1. Add **kwargs to receive extra (unused) params passed from qmodel_prep + 2. pass device to super.init, i.e. nn.Linear's + """ + + def __init__( + self, + in_features, + out_features, + bias=True, + mx_specs=None, + name=None, + **kwargs, + ): + mx_assert_test(mx_specs) + self.mx_none = mx_specs is None + + self.name = name + self.prequantized_weights = False + self.mx_specs = apply_mx_specs(mx_specs) + super().__init__( + in_features, out_features, bias, device=kwargs.get("device", "cuda") + ) + + def apply_mx_specs(self, mx_specs): + """Unchanged.""" + mx_assert_test(mx_specs) + self.mx_none = mx_specs is None + self.mx_specs = apply_mx_specs(mx_specs) + + def append_name(self, postfix): + """Unchanged.""" + self.name += postfix + + def prequantize_weights(self): + """Unchanged.""" + # Can't prequantize if not using bfloat weights + if self.mx_none: + return + + assert ( + self.mx_specs["round"] == "even" + ), "Bfloat round should be 'even' for prequantizing weights." + assert ( + torch.cuda.is_bf16_supported() + ), "Current device does not support bfloat16" + assert self.mx_specs[ + "bfloat_subnorms" + ], "Bfloat_subnorms should be set to True for prequantizing weights." + assert ( + self.mx_specs["bfloat"] == 16 + ), "Only Bfloat16 is supported for prequantizing weights." + + with torch.no_grad(): + self.weight.data = quantize_elemwise_op( + self.weight.data, + mx_specs=self.mx_specs, + round=self.mx_specs["round_weight"], + ).to(torch.bfloat16) + + if self.bias is not None: + self.bias.data = quantize_elemwise_op( + self.bias.data, + mx_specs=self.mx_specs, + round=self.mx_specs["round_weight"], + ).to(torch.bfloat16) + + self.prequantized_weights = True + + def forward(self, inputs): + """Unchanged.""" + if self.mx_none: + return super().forward(inputs) + + if self.prequantized_weights: + assert ( + not self.training + ), "Cannot use prequantized weights when training!" 
+ + return mx_linear( + input=inputs, + weight=self.weight, + bias=self.bias, + mx_specs=self.mx_specs, + prequantized_weights=self.prequantized_weights, + name=self.name, + ) + + def extra_repr(self) -> str: + repr_str = ( + f"in={self.in_features},out={self.out_features}," + f"w_fmt={self.mx_specs['w_elem_format']},a_fmt={self.mx_specs['a_elem_format']}," + f"blk_size={self.mx_specs['block_size']}" + ) + return repr_str + + +# KEEP THIS AT END OF FILE - classes must be declared +QLinear_modules = ( + QLinear, + QLinearFPout, + QLinearDebug, + QLinearW4A32Debug, + QLinearINT8Deploy, + QLinearCublasI8I32NT, + QLinearCutlassI8I32NT, +) +if available_packages["mx"]: + QLinear_modules += (QLinearMX,) + + +def isinstance_qlinear(module): + """ + Checks if the given module is one of the available quantized linear classes. + + Args: + module (nn.Module): The module to check. + + Returns: + bool: True if the module is a quantized linear class, False otherwise. + """ + return isinstance(module, QLinear_modules) diff --git a/fms_mo/modules/lstm.py b/fms_mo/modules/lstm.py index ff1418c6..2770a7e8 100644 --- a/fms_mo/modules/lstm.py +++ b/fms_mo/modules/lstm.py @@ -508,6 +508,7 @@ def __repr__(self): return expr +# KEEP THIS AT END OF FILE - classes must be declared QLSTM_modules = (QLSTM,) diff --git a/fms_mo/prep.py b/fms_mo/prep.py index cc4ce7bc..42e40b79 100644 --- a/fms_mo/prep.py +++ b/fms_mo/prep.py @@ -28,7 +28,8 @@ from fms_mo.calib import qmodel_calib from fms_mo.modules import QBmm_modules, QConv2d_modules, QLinear_modules, QLSTM_modules from fms_mo.quant.quantizers import Qbypass -from fms_mo.utils.qconfig_utils import check_config, qconfig_save +from fms_mo.utils.import_utils import available_packages +from fms_mo.utils.qconfig_utils import check_config, qconfig_save, set_mx_specs from fms_mo.utils.utils import prepare_inputs # import numpy as np # only used in experimental func @@ -192,6 +193,7 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False): nn.Module: quantized module """ mapping = qcfg.get("mapping") + mappable_classes = [cls for cls in mapping.keys() if not isinstance(cls, str)] # if mapping is not defined, qmodel_prep should raise alarm before entering QAnyNet4 qdw = qcfg.get("qdw", False) nbits_a = qcfg.get("nbits_a", 32) @@ -199,6 +201,17 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False): qa_mode = qcfg.get("qa_mode", "pact+") qw_mode = qcfg.get("qw_mode", "sawb+") + # Check if MX has been set outside of qconfig_init without mx_specs being created + if ( + available_packages["mx"] + and "mx_specs" not in qcfg + and ( + (qcfg["qa_mode"].startswith("mx_") and qcfg["qw_mode"].startswith("mx_")) + or any(key.startswith("mx_") for key in qcfg.keys()) + ) + ): + set_mx_specs(qcfg, use_mx=True) + # check if on "black list" (need to be exact match), can be skipped or quantized those # to slightly higher "default" precision, or use qspecial_layers to have fine control if curr_full_name in qcfg["qskip_layer_name"]: @@ -214,7 +227,7 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False): qw_mode = qdict.get("qw_mode", qw_mode) # NOTE: if any item is not defined, use current default - if isinstance(module, tuple(mapping.keys())): + if isinstance(module, tuple(mappable_classes)): base_params = {} if hasattr(module, "__constants__"): base_params = {k: getattr(module, k) for k in module.__constants__} @@ -223,6 +236,14 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False): module_output = module + # If (W,A) is (32,8) or 
(8,32), one nbits = None ; Do not quantize in this case + if nbits_a is None or nbits_w is None: + if verbose: + logger.info( + f"Skip quantization of {curr_full_name} - nbits_a or nbits_w is None" + ) + return module_output + # For nn.Conv2d if isinstance(module, nn.Conv2d): if ( @@ -234,14 +255,13 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False): "Otherwise please create an equivalen QConv wrapper and change qcfg['mapping']." ) - if mapping[nn.Conv2d]["from"] is None: - return module_output # set from to None means no swap for this type - - QConv = ( - mapping[nn.Conv2d]["to"] - if isinstance(module, mapping[nn.Conv2d]["from"]) - else mapping[nn.Conv2d]["otherwise"] - ) + QConv = mapping.get(nn.Conv2d, None) + if QConv is None: + if verbose: + logger.info( + f"Skip quantization of {curr_full_name} - mapping of Conv2d is None" + ) + return module_output # None means no swap for this type base_params.pop( "output_padding" @@ -313,11 +333,13 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False): "Otherwise please create an equivalen QConvT wrapper and change qcfg['mapping']." ) - QConvT = ( - mapping[nn.ConvTranspose2d]["to"] - if isinstance(module, mapping[nn.ConvTranspose2d]["from"]) - else mapping[nn.ConvTranspose2d]["otherwise"] - ) + QConvT = mapping.get(nn.ConvTranspose2d, None) + if QConvT is None: + if verbose: + logger.info( + f"Skip quantization of {curr_full_name} - mapping of ConvTranspose2d is None" + ) + return module_output # None means no swap for this type if base_params["padding_mode"] != "zeros": logger.warning( @@ -368,20 +390,20 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False): # For nn.Linear elif isinstance(module, nn.Linear): - QLin = ( - None - if mapping[nn.Linear]["from"] is None - else ( - mapping[nn.Linear]["to"] - if isinstance(module, mapping[nn.Linear]["from"]) - else mapping[nn.Linear]["otherwise"] + if module.__class__ != nn.Linear: + logger.warning( + f"{curr_full_name} {type(module)} seems to be a wrapper of Linear." + "Please make sure it doesn't wrap BN and activ func." + "Otherwise please create an equivalen Linear wrapper and change qcfg['mapping']." ) - ) - if not QLin or nbits_a is None or nbits_w is None: - # no swapping + + QLin = mapping.get(nn.Linear, None) + if QLin is None: if verbose: - logger.info(f"Skip quantization of {curr_full_name}") - return module_output + logger.info( + f"Skip quantization of {curr_full_name} - mapping of Linear is None" + ) + return module_output # None means no swap for this type module_output = QLin( **base_params, @@ -452,36 +474,36 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False): # For nn.LSTM elif isinstance(module, nn.LSTM): - Qlstm = ( - None - if mapping[nn.LSTM]["from"] is None - else ( - mapping[nn.LSTM]["to"] - if isinstance(module, mapping[nn.LSTM]["from"]) - else mapping[nn.LSTM]["otherwise"] + if module.__class__ != nn.LSTM: + logger.warning( + f"{curr_full_name} {type(module)} seems to be a wrapper of LSTM." + "Please make sure it doesn't wrap BN and activ func." + "Otherwise please create an equivalen Linear wrapper and change qcfg['mapping']." 
) - ) - if Qlstm: - if nbits_a is None or nbits_w is None: - if verbose: - logger.info(f"Skip quantization of {curr_full_name}") - else: # if globallayerID in qcfg["Qlist"]: - # 2nd safety check, check if on "white list", swap only if globalID is on Qlist - module_output = Qlstm( - **base_params, - num_bits_weight=qcfg["nbits_w_lstm"], - qw_mode=qcfg["qw_mode_lstm"], - num_bits_input=qcfg["nbits_i_lstm"], - qi_mode=qcfg.get("qi_mode_lstm", qcfg["qa_mode_lstm"]), - num_bits_hidden=qcfg["nbits_h_lstm"], - qh_mode=qcfg.get("qh_mode_lstm", qcfg["qa_mode_lstm"]), - align_zero=qcfg["align_zero"], - qcfg=qcfg, + + Qlstm = mapping.get(nn.LSTM, None) + if Qlstm is None: + if verbose: + logger.info( + f"Skip quantization of {curr_full_name} - mapping of LSTM is None" ) - for k, v in module.named_parameters(): - if getattr(module, k, None): - setattr(module_output, k, v) - module_output._all_weights = module._all_weights + return module_output # None means no swap for this type + + module_output = Qlstm( + **base_params, + num_bits_weight=qcfg["nbits_w_lstm"], + qw_mode=qcfg["qw_mode_lstm"], + num_bits_input=qcfg["nbits_i_lstm"], + qi_mode=qcfg.get("qi_mode_lstm", qcfg["qa_mode_lstm"]), + num_bits_hidden=qcfg["nbits_h_lstm"], + qh_mode=qcfg.get("qh_mode_lstm", qcfg["qa_mode_lstm"]), + align_zero=qcfg["align_zero"], + qcfg=qcfg, + ) + for k, v in module.named_parameters(): + if getattr(module, k, None): + setattr(module_output, k, v) + module_output._all_weights = module._all_weights return module_output @@ -681,10 +703,13 @@ def qmodel_prep( import re qskip_layer_name, QsinglesidedConvs = [], [] + mappable_classes = [ + cls for cls in qcfg["mapping"].keys() if not isinstance(cls, str) + ] mappable_layers = [ n for n, m in model.named_modules() - if isinstance(m, tuple(qcfg["mapping"].keys())) + if isinstance(m, tuple(mappable_classes)) ] qskip_layer_name = set(mappable_layers) diff --git a/fms_mo/utils/import_utils.py b/fms_mo/utils/import_utils.py index a5084e69..bb13afff 100644 --- a/fms_mo/utils/import_utils.py +++ b/fms_mo/utils/import_utils.py @@ -25,6 +25,7 @@ "exllama_kernels", "exllamav2_kernels", "llmcompressor", + "mx", "matplotlib", "graphviz", "pygraphviz", diff --git a/fms_mo/utils/qconfig_utils.py b/fms_mo/utils/qconfig_utils.py index 6a9a27eb..1a292062 100644 --- a/fms_mo/utils/qconfig_utils.py +++ b/fms_mo/utils/qconfig_utils.py @@ -19,6 +19,7 @@ from importlib.metadata import version from pathlib import Path from typing import Any, Union +import argparse import json import logging import os @@ -30,7 +31,8 @@ import torch # Local -from fms_mo.modules import QLSTM, QConv2d, QConvTranspose2d, QLinear +from fms_mo.modules import QLSTM, QBmm, QConv2d, QConvTranspose2d, QLinear +from fms_mo.utils.import_utils import available_packages # import numpy as np # only used in experimental func @@ -230,7 +232,7 @@ def get_recipe(recipe: str, subdir: str = None) -> Union[list, dict]: return temp_data -def qconfig_init(recipe: str = None, args: Any = None) -> dict: +def qconfig_init(recipe: str = None, args: Any = None, use_mx: bool = False) -> dict: """Three possible ways to create qcfg: 1. create a default qcfg 2. load from a json @@ -274,17 +276,12 @@ def qconfig_init(recipe: str = None, args: Any = None) -> dict: qcfg = {} # 1. 
create a dict with default values qcfg["mapping"] = { - nn.Conv2d: {"from": nn.Conv2d, "to": QConv2d, "otherwise": QConv2d}, - nn.ConvTranspose2d: { - "from": nn.ConvTranspose2d, - "to": QConvTranspose2d, - "otherwise": QConvTranspose2d, - }, - nn.Linear: {"from": nn.Linear, "to": QLinear, "otherwise": QLinear}, - nn.LSTM: {"from": nn.LSTM, "to": QLSTM, "otherwise": QLSTM}, + nn.Linear: QLinear, + nn.Conv2d: QConv2d, + nn.ConvTranspose2d: QConvTranspose2d, + nn.LSTM: QLSTM, + "matmul_or_bmm": QBmm, } - # TODO: This could be further simplified. e.g. mapping["from class"] = "to class" - # "otherwise" is rarely used, and redundant "from" in the output dict qcfg["pkg_versions"] = get_pkg_info() @@ -417,9 +414,215 @@ def qconfig_init(recipe: str = None, args: Any = None) -> dict: " Default or values from json of the same key will be overwritten." ) + # 4. Check if mapping must change for MX library + # For now, simply use qa_mode or qw_mode to trigger, e.g. "mx_fp4_e2m1" -> "fp4_e2m1" + # user may create qcfg without "mx_fpxxx" then manually changes qw_mode/qa_mode to "mx_fpxxx" + # => need to check again at the beginning of qmodel_prep(), i.e. in check_config() + set_mx_specs(qcfg, args, use_mx) + return qcfg +def set_mx_specs( + config: dict, + args: argparse.ArgumentParser = None, + use_mx: bool = False, +): + """ + Set mx_specs dict in quantized config to be used for MX quantization. + Will use fms_mo default values for variables when none are given. + + Options available: + 1. Pass job args to create mx_specs. Must have --a_elem_format and --w_elem_format set. + 2. Consume a premade mx_specs dict from quantized config if present. + 3. Consume quantized config variables prefixed with "mx_". + + Options 2 and 3 are mutually exclusive with preference for Option 2 if both are given. + + Args: + config (dict): Quantization config dict + args (argparse.ArgumentParser, optional): Job arg parser. Defaults to None. + use_mx (bool): Create default mx_specs when qcfg or args aren't present. + Defaults to False. + """ + mx_prefix = "mx_" + + # MX lib defaults for these values are None, 0, nearest, max, or bool + fms_defaults = get_mx_specs_defaults() + + # Already have a mx_specs saved in config + use_mx_specs_config = "mx_specs" in config + + # Check for any "mx_" vars set in config + use_mx_config = any(key.startswith(mx_prefix) for key in config.keys()) + + # Check args for any mx_specs vars + use_mx_args = args is not None and any( + hasattr(args, key) + for key, _ in fms_defaults.items() + if key != "block_size" + # some items are not unique to mx, add names here if needed + ) + + # Lastly, check for BMM consistency to enable QBmmMX + fms_bmm_modes = [ + config["bmm1_qm1_mode"].startswith(mx_prefix), + config["bmm1_qm2_mode"].startswith(mx_prefix), + config["bmm2_qm1_mode"].startswith(mx_prefix), + config["bmm2_qm2_mode"].startswith(mx_prefix), + ] + # If any mx bmm set, they all must be set for QBmmMX ; will be checked in check_config + use_fms_bmm_modes = all(fms_bmm_modes) + + use_mx = ( + use_mx + or use_mx_specs_config + or use_mx_config + or use_mx_args + or use_fms_bmm_modes + ) + + if use_mx: + # If "mapping" has been removed from qcfg -> chk_cfg is being called by save_config() at + # the end of qmodel_prep() => don't need to update anything. + # NOTE: If "mx_" qa_/qw_mode was used through args, the "mx_" prefix would have been removed + # already in chk_cfg() => "use_mx" flag will be False. 
Keep in mind that THE ONLY WAY TO + # TRIGGER REFRESH of mx_specs AFTER qconfig_init() is to manually set qa_/qw_mode to + # "mx_something"! + + if available_packages["mx"]: + # Standard + from functools import partial + + # Third Party + # pylint: disable = import-error + import mx + + # Local + from fms_mo.modules.bmm import QBmmMX + from fms_mo.modules.linear import QLinearMX + + # Create a MxSpecs object based on input args and overwrite w/ qcfg if provided + mx_specs = mx.get_mx_specs(args) if use_mx_args else mx.MxSpecs() + + # Ensure fms defaults are set assuming job args haven't already changed them + for key, val in fms_defaults.items(): + if mx_specs[key] in [None, 0, False, True, "nearest", "max"]: + mx_specs[key] = val + + # Use config["mx_specs"] settings + if use_mx_specs_config: + mx_specs.update(config["mx_specs"]) + + # Use qcfg mx equivalents + else: + # k_elem_format special case - in q_modes + if config["qw_mode"].startswith(mx_prefix): + mx_specs["w_elem_format"] = config["qw_mode"].replace(mx_prefix, "") + if config["qa_mode"].startswith(mx_prefix): + mx_specs["a_elem_format"] = config["qa_mode"].replace(mx_prefix, "") + + for mx_var, _ in fms_defaults.items(): + fms_var = "mx_" + mx_var + # Only update if its in config; default values already set + if fms_var in config: + mx_specs[mx_var] = config.get(fms_var) + + # Only 1 variable that has "mx_" prefix from MX lib + mx_var = "mx_flush_fp32_subnorms" + if mx_var in config: + mx_specs[mx_var] = config.get(mx_var) + + # Many mx_spec vars are synched with other vars -- may have changed now + mx_specs = mx.finalize_mx_specs(mx_specs) + + # Save finalized mx_spec to config + config["mx_specs"] = mx_specs.data + + # Update mapping for torch.nn and matmul_or_bmm to MX variants + # QLinearMX will be used, but QBmmMX requires bmm specifically + config["mapping"][nn.Linear] = partial( + QLinearMX, mx_specs=config["mx_specs"] + ) + # config["mapping"][nn.Conv2d] = partial( + # QConv2dMX, mx_specs=config["mx_specs"] + # ) + # config["mapping"][nn.ConvTranspose2d] = partial( + # QConvTranspose2dMX, mx_specs=config["mx_specs"] + # ) + if use_fms_bmm_modes: # all bmm_modes are "mx_" prefixed + config["mapping"]["matmul_or_bmm"] = partial( + QBmmMX, mx_specs=config["mx_specs"] + ) + + else: + logger.info("MX variables provided, but MX package is not installed") + + +def is_nvcc_installed(): + """ + Check whether we can call on the NVIDIA CUDA Compiler from the OS level + + Returns: + bool: If nvcc is found and callable at the OS level + """ + # Standard + import subprocess + + try: + # Run the nvcc command to check if it's installed + subprocess.check_output("nvcc --version", shell=True, stderr=subprocess.STDOUT) + logger.info("nvcc is installed and callable") + return True + except subprocess.CalledProcessError: + logger.info("nvcc is installed, but there was an issue running nvcc.") + return False + except FileNotFoundError: + logger.info("nvcc is not installed on the system.") + return False + + +def get_mx_specs_defaults(): + """ + Get key,value pairs for mx_specs defaults for fms_mo + + Returns: + dict: fms_mo defaults of mx_specs + """ + return { + "w_elem_format": "fp8_e4m3", + "a_elem_format": "fp8_e4m3", + "w_elem_format_bp": "fp8_e4m3", + "a_elem_format_bp": "fp8_e4m3", + "a_elem_format_bp_ex": "fp8_e4m3", + "a_elem_format_bp_os": "fp8_e4m3", + "shared_exp_method": "max", + "scale_bits": 8, + "block_size": 32, # this item is not unique to mx + "bfloat": 16, # bfloat and fp cannot be set at the same time + "fp": 0, + 
"round": "nearest", + "round_m": "nearest", + "round_weight": "nearest", + "round_output": "nearest", + "round_grad_weight": "nearest", + "round_grad_input": "nearest", + "round_mx_output": "nearest", + "round_mx_input_grad_input": "nearest", + "round_mx_weight_grad_input": "nearest", + "round_mx_grad_output_grad_input": "nearest", + "round_mx_input_grad_weight": "nearest", + "round_mx_grad_output_grad_weight": "nearest", + "quantize_backprop": True, + "bfloat_subnorms": True, + "mx_flush_fp32_subnorms": False, + "softmax_exp2": False, + "vec_use_exp2": False, + "vec_use_recip": False, + "custom_cuda": torch.cuda.is_available() and is_nvcc_installed(), + } + + def has_non_serializable_object(anything: Any) -> bool: """ Generalized recursive function looking for any non-serializable Python object @@ -518,14 +721,11 @@ def get_unwanted_defaults() -> dict: ( "mapping", { - nn.Conv2d: {"from": nn.Conv2d, "to": QConv2d, "otherwise": QConv2d}, - nn.ConvTranspose2d: { - "from": nn.ConvTranspose2d, - "to": QConvTranspose2d, - "otherwise": QConvTranspose2d, - }, - nn.Linear: {"from": nn.Linear, "to": QLinear, "otherwise": QLinear}, - nn.LSTM: {"from": nn.LSTM, "to": QLSTM, "otherwise": QLSTM}, + nn.Conv2d: QConv2d, + nn.ConvTranspose2d: QConvTranspose2d, + nn.Linear: QLinear, + nn.LSTM: QLSTM, + "matmul_or_bmm": QBmm, }, ), ("checkQerr_frequency", False), @@ -827,30 +1027,50 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None: "bmm2_qm2_mode", ] + # mx related modes for config: + mx_spec_config_modes = [ + "mx_fp8_e5m2", + "mx_fp8_e4m3", + "mx_fp4_e2m1", + "mx_fp4", + "mx_int8", + "mx_int4", + "mx_fp16", + "mx_float16", + "mx_bf16", + "mx_bfloat16", + ] + # Check each for correct ranges for qa_mode_str in qa_modes_str: qa_mode = config.get(qa_mode_str, "pact+") - if not qa_mode in qa_mode_settings: + if not qa_mode in (qa_mode_settings + mx_spec_config_modes): raise ValueError( f"{qa_mode_str} = {qa_mode} is not set to one of the following: " - f"{qa_mode_settings}" + f"{qa_mode_settings + mx_spec_config_modes}" ) for qw_mode_str in qw_modes_str: qw_mode = config.get(qw_mode_str, "sawb+") - if not qw_mode in qw_mode_settings: + if not qw_mode in (qw_mode_settings + mx_spec_config_modes): raise ValueError( f"{qw_mode_str} = {qw_mode} is not set to one of the following: " - f"{qw_mode_settings}" + f"{qw_mode_settings + mx_spec_config_modes}" ) + bmm_mode_consistency = 0 # all or none when using mx for bmm_mode_str in bmm_modes_str: bmm_mode = config.get(bmm_mode_str, "pactsym+") - if not bmm_mode in bmm_mode_settings: + bmm_mode_consistency += bmm_mode.startswith("mx_") + # mx_specs doesn't have 4 individual bmmX_qmY_modes, it re-uses w and a fmt instead. + # We will keep them in qcfg (with "mx_" prefix NOT removed). + if not bmm_mode in (bmm_mode_settings + mx_spec_config_modes): raise ValueError( f"{bmm_mode_str} = {bmm_mode} is not set to one of the following: " - f"{bmm_mode_settings}" + f"{bmm_mode_settings + mx_spec_config_modes}" ) + if bmm_mode_consistency != 0 and bmm_mode_consistency != len(bmm_modes_str): + raise ValueError("bmmX_qmY_modes inconsistent! 
Should be all mx or no mx.") # Check mode calibration and initialization values calib_init_settings = ["percentile", "pact", "sawb", "max"] @@ -887,7 +1107,9 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None: boolean_var = config.get( boolean_var_str, False ) # assume default = False is not specified - if not isinstance(boolean_var, bool): + # Note: bool is a subclass of int, so we can't rely on isinstance + # pylint: disable = unidiomatic-typecheck + if type(boolean_var) is not bool: raise ValueError(f"{boolean_var_str} = {boolean_var} is not a boolean") default_config = config_defaults() @@ -1021,3 +1243,147 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None: smoothq_act_scale_path = config.get("smoothq_act_scale_path", None) if smoothq_act_scale_path and not smoothq_act_scale_path.endswith(".pt"): raise ValueError(f"{smoothq_act_scale_path=} is not a .pt checkpoint") + + # Check MX-related variables in mx_specs + mx_specs = config.get("mx_specs", None) + if mx_specs: + # mx related modes for config: + mx_spec_modes = [ + "fp8_e5m2", + "fp8_e4m3", + "fp4_e2m1", + "fp4", + "int8", + "int4", + "fp16", + "float16", + "bf16", + "bfloat16", + ] + + mx_specs_format_var_strs = { + "w_elem_format", + "a_elem_format", + "w_elem_format_bp", + "a_elem_format_bp", + "a_elem_format_bp_ex", + "a_elem_format_bp_os", + } + + for format_var_str in mx_specs_format_var_strs: + format_var = mx_specs[format_var_str] + if not isinstance(format_var, str): + raise ValueError( + f"mx_specs[{format_var_str}] = {format_var} is not a string" + ) + if format_var not in mx_spec_modes: + raise ValueError( + f"mx_specs[{format_var_str}] = {format_var} is not in one of the following: " + f"{mx_spec_modes}" + ) + + mx_spec_int_var_str_defaults = [ + ("scale_bits", 8), + ("block_size", 32), + ("bfloat", 16), + ] + mx_spec_int_var_values = {2, 4, 6, 8, 16, 32} + + for integer_var_str, integer_var_default in mx_spec_int_var_str_defaults: + integer_var = mx_specs.get(integer_var_str, integer_var_default) + # Check if integer was given as float (1.0 when it should be 1) + if isinstance(integer_var, float) and integer_var.is_integer(): + mx_specs[integer_var_str] = int(integer_var) + integer_var = int(integer_var) + if not isinstance(integer_var, int): + raise ValueError( + f"mx_specs[{integer_var_str}] = {integer_var} is not an integer" + ) + if integer_var not in mx_spec_int_var_values: + raise ValueError( + f"mx_specs[{integer_var_str}] = {integer_var} must be an integer in " + f"{mx_spec_int_var_values}" + ) + + mx_spec_bool_var_strs = { + "mx_flush_fp32_subnorms", + "bfloat_subnorms", + "quantize_backprop", + "softmax_exp2", + "vec_use_exp2", + "vec_use_recip", + "custom_cuda", + } + for boolean_var_str in mx_spec_bool_var_strs: + # assume default = False is not specified + boolean_var = mx_specs.get(boolean_var_str, False) + # Note: bool is a subclass of int, so we can't rely on isinstance + # pylint: disable = unidiomatic-typecheck + if type(boolean_var) is not bool: + raise ValueError( + f"mx_specs[{boolean_var_str}] = {boolean_var} is not a boolean" + ) + + mx_spec_exp_var_strs = { + "shared_exp_method", + } + mx_spec_exp_var_values = {"max", None} + for exp_var_str in mx_spec_exp_var_strs: + exp_var = mx_specs.get(exp_var_str, "max") + if not isinstance(exp_var, str): + raise ValueError(f"mx_specs[{exp_var_str}] = {exp_var} is not a string") + if exp_var not in mx_spec_exp_var_values: + raise ValueError( + f"mx_specs[{exp_var_str}] = {exp_var} is not in " + 
f"{mx_spec_exp_var_values}" + ) + + mx_spec_round_var_strs = { + "round", + "round_m", + "round_weight", + "round_output", + "round_grad_weight", + "round_grad_input", + "round_mx_output", + "round_mx_input_grad_input", + "round_mx_weight_grad_input", + "round_mx_grad_output_grad_input", + "round_mx_input_grad_weight", + "round_mx_grad_output_grad_weight", + } + mx_spec_round_var_values = {"nearest", "floor"} + for round_var_str in mx_spec_round_var_strs: + round_var = mx_specs.get(round_var_str, "nearest") + if not isinstance(round_var, str): + raise ValueError( + f"mx_specs[{round_var_str}] = {round_var} is not a string" + ) + if round_var not in mx_spec_round_var_values: + raise ValueError( + f"mx_specs[{round_var_str}] = {round_var} is not in" + f"{mx_spec_round_var_values}" + ) + + # If mapping is defined, check for MX classes + if available_packages["mx"]: + # Local + from fms_mo.modules.bmm import QBmmMX + from fms_mo.modules.linear import QLinearMX + + mapping = config.get("mapping", None) + + # partial was used to wrap QLinearMX, will be an instance of partial + # 1. can use .func pointer to find the original class + # 2. QBmm is optional, could be partial(QBmmMX,) or QBmm + if mapping is not None: + if mapping[nn.Linear].func is not QLinearMX: + raise ValueError("MX mapping for nn.Linear is not QLinearMX") + + qbmm_map = mapping["matmul_or_bmm"] + if bmm_mode_consistency > 0: + if getattr(qbmm_map, "func", None) is not QBmmMX: + raise ValueError("MX mapping for matmul_or_bmm is not QBmmMX") + else: + if qbmm_map is not QBmm: + raise ValueError("Mapping for matmul_or_bmm is not QBmm") diff --git a/fms_mo/utils/torchscript_utils.py b/fms_mo/utils/torchscript_utils.py index 39025b8d..87c50763 100644 --- a/fms_mo/utils/torchscript_utils.py +++ b/fms_mo/utils/torchscript_utils.py @@ -28,7 +28,6 @@ import torch # Local -from fms_mo.modules import QBmm from fms_mo.quant.quantizers import transformers_prepare_input from fms_mo.utils.import_utils import available_packages from fms_mo.utils.utils import move_to, patch_torch_bmm, prepare_data_4_fwd @@ -1543,6 +1542,7 @@ def model_analyzer_ts( # ['QBmm'] is determined by nbits_bmm[1,2], if using QBertSelfAttn instead of func swapping, # it could still be True # ['which2patch'] == 'off' will forcefully turn off this searching and QBmm attaching + QBmm = quant_config["mapping"]["matmul_or_bmm"] find_single_sided_bmm(reconstructed_graph) # After search, flag "isActOutUnidir" and "isActOutBounded" will be available diff --git a/fms_mo/utils/utils.py b/fms_mo/utils/utils.py index 38e2a1db..8e3947cd 100644 --- a/fms_mo/utils/utils.py +++ b/fms_mo/utils/utils.py @@ -217,6 +217,7 @@ def prepare_input( if isinstance(data, torch.Tensor): kwargs = {"device": device} return data.to(**kwargs) + logger.warning( "data input to prepare_input must be Dict, " "Tuple, List or torch.Tensor and currently is", diff --git a/install_patches.py b/install_patches.py new file mode 100644 index 00000000..76ebcfe5 --- /dev/null +++ b/install_patches.py @@ -0,0 +1,91 @@ +# Standard +import os +import subprocess + +dependencies_with_patch = { + "microxcaling": "https://github.com/microsoft/microxcaling.git", +} + + +def install_with_patch( + pkg_name: str, + repo_url: str, + patch_file: str, + home_dir: str = None, +) -> None: + """ + Install a dependency with a patch file + + Args: + pkg_name (str): Name of package being installed + repo_url (str): Github repo URL + patch_file (str): Patch file in patches/ + home_dir (str): Home directory with fms-model-optimizer and 
other packages.
+            Defaults to None.
+    """
+    # We want to git clone the repo to $HOME/repo_name
+    if home_dir is None:
+        home_dir = os.path.expanduser("~")
+
+    # Current working directory (expected to be the fms-model-optimizer repo root)
+    cwd = os.getcwd()
+
+    # Get patch file location from fms-model-optimizer
+    patch_file = os.path.join(cwd, "patches", patch_file)
+    if not os.path.exists(patch_file):
+        raise FileNotFoundError(f"Can't find {pkg_name} patch file in {cwd}/patches")
+
+    # Check to see if package exists in cwd or home_dir
+    pkg_path_cwd = os.path.join(cwd, pkg_name)
+    pkg_path_home = os.path.join(home_dir, pkg_name)
+    pkg_exists_cwd = os.path.exists(pkg_path_cwd)
+    pkg_exists_home = os.path.exists(pkg_path_home)
+
+    # If pkg already exists in cwd or home_dir, skip clone
+    if pkg_exists_cwd:
+        pkg_dir = pkg_path_cwd
+        print(f"Directory {pkg_dir} already exists. Skipping download.")
+    elif pkg_exists_home:
+        pkg_dir = pkg_path_home
+        print(f"Directory {pkg_dir} already exists. Skipping download.")
+    else:
+        # Clone repo to home directory
+        pkg_dir = pkg_path_home
+        subprocess.run(["git", "clone", repo_url], cwd=home_dir, check=True)
+
+    # Apply patch and pip install package
+    try:
+        subprocess.run(["git", "apply", "--check", patch_file], cwd=pkg_dir, check=True)
+        subprocess.run(["git", "apply", patch_file], cwd=pkg_dir, check=True)
+        print(
+            f"FMS Model Optimizer patch for {pkg_name} applied. Installing package now."
+        )
+        subprocess.run(["pip", "install", "."], cwd=pkg_dir, check=True)
+
+    except subprocess.CalledProcessError as e:
+        print(
+            f"FMS Model Optimizer patch for {pkg_name} is already installed "
+            f"or an error has occurred: \n{e}"
+        )
+
+
+def install_dependencies_with_patch() -> None:
+    """
+    Install dependencies that require a patch prior to pip install.
+
+    To execute, use `python install_patches.py`.
+
+    Requirements:
+    1. The patch file is named `<package_name>.patch`
+    2. The patch file must be located in fms-model-optimizer/patches
+    """
+    for pkg, repo_url in dependencies_with_patch.items():
+        install_with_patch(
+            pkg_name=pkg,
+            repo_url=repo_url,
+            patch_file=pkg + ".patch",
+        )
+
+
+if __name__ == "__main__":
+    install_dependencies_with_patch()
diff --git a/patches/README.md b/patches/README.md
new file mode 100644
index 00000000..fcd8a139
--- /dev/null
+++ b/patches/README.md
@@ -0,0 +1,28 @@
+## Patching Third Party Dependencies
+Some dependencies clash with the current FMS Model Optimizer environment, so a patch needs to be applied before they are installed.
+To do this, we have provided a script in `fms-model-optimizer` named `install_patches.py`.
+
+To run this script:
+```
+python3 install_patches.py
+```
+
+The following optional packages require a patch:
+* `microxcaling`: Pins outdated versions of PyTorch-related packages
+
+## Making a Patch File
+To make a git diff patch file, first make your desired changes to the repository. Then run
+```
+git diff > <package_name>.patch
+```
+Packages may include files that differ only by whitespace even if you didn't change them.
+To address this, add `--ignore-all-space` to the `git diff` command.
+
+To test the patch file, copy the `<package_name>.patch` file to `fms-model-optimizer/patches`.
+Next, add a new entry to the `dependencies_with_patch` dictionary in `install_patches.py` with the package name and repo URL (a concrete example entry is shown below):
+```
+dependencies_with_patch = {
+    <package_name>: <repo_url>,  # for <package_name>.patch
+}
+```
+Lastly, run `python3 install_patches.py` as shown above.
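For instance, the entry that already ships in `install_patches.py` for `microxcaling` (the package needed by the MX examples) follows this pattern; adding a new patched dependency means adding one more key/value line like it:

```python
# Existing entry in install_patches.py: the key is the package/repo directory name,
# the value is its Git URL, and the matching patch lives at patches/microxcaling.patch
dependencies_with_patch = {
    "microxcaling": "https://github.com/microsoft/microxcaling.git",
}
```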
\ No newline at end of file diff --git a/patches/microxcaling.patch b/patches/microxcaling.patch new file mode 100644 index 00000000..8139f3b2 --- /dev/null +++ b/patches/microxcaling.patch @@ -0,0 +1,17 @@ +diff --git a/pyproject.toml b/pyproject.toml +index e80053e..b4ec100 100644 +--- a/pyproject.toml ++++ b/pyproject.toml +@@ -5,9 +5,9 @@ description = 'The Microsoft MX floating point library' + readme = "README.md" + requires-python = ">=3.8" + dependencies = [ +- "torch==2.2.0", +- "torchvision==0.16", +- "torchaudio==2.1.0" ++ "torch", ++ "torchvision", ++ "torchaudio" + ] + license = { file = "LICENSE" } + keywords = ["mx", "floating point", "math", "mathematics", "machine learning", "deep learning", "artificial intelligence", "ai", "ml", "dl", "torch", "torchvision", "torchaudio"] diff --git a/pyproject.toml b/pyproject.toml index 00e88a17..440d0be5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,13 +38,14 @@ dependencies = [ "pandas", "safetensors", "ibm-fms>=0.0.8", -"pkginfo>1.10" +"pkginfo>1.10", ] [project.optional-dependencies] dev = ["pre-commit>=3.0.4,<5.0"] fp8 = ["llmcompressor"] gptq = ["auto_gptq>0.4.2", "optimum>=1.15.0"] +mx = ["microxcaling>=1.1"] visualize = ["matplotlib", "graphviz", "pygraphviz"] flash-attn = ["flash-attn>=2.5.3,<3.0"] opt = ["fms-model-optimizer[fp8, gptq]"] diff --git a/tests/models/conftest.py b/tests/models/conftest.py index 47c42fc0..3f715a8c 100644 --- a/tests/models/conftest.py +++ b/tests/models/conftest.py @@ -25,13 +25,16 @@ from torchvision.io import read_image from torchvision.models import ResNet50_Weights, ViT_B_16_Weights, resnet50, vit_b_16 from transformers import BertModel, BertTokenizer +import numpy as np import pytest import torch +import torch.nn.functional as F # Local # fms_mo imports from fms_mo import qconfig_init -from fms_mo.modules import QLSTM, QConv2d, QConvTranspose2d, QLinear +from fms_mo.modules import QLSTM, QBmm, QConv2d, QConvTranspose2d, QLinear +from fms_mo.utils.qconfig_utils import get_mx_specs_defaults, set_mx_specs ######################## # check_config Fixtures # @@ -310,6 +313,44 @@ def not_which2patch_contextmanager_settings(): return ["torch.vmm", "torch.natnul", "None"] +@pytest.fixture(scope="session") +def bad_mx_specs_settings(): + """ + Get list of invalid mx_spec key,value pairs + + Returns: + list: invalid mx_spec list + """ + return [ + ("w_elem_format", "fp8_e5m3"), + ("a_elem_format", "fp8_m4e3"), + ("scale_bits", False), + ("block_size", "32"), + ("bfloat", [16]), + ("round", "bankers"), + ("custom_cuda", "yes"), + ] + + +@pytest.fixture(scope="session") +def bad_mx_config_settings(): + """ + Get list of invalid mx config key,value pairs for config and mx_specs + + Returns: + list: invalid mx_spec list + """ + return [ + ("qw_mode", "w_elem_format", "mx_fp8_e5m3", "fp8_e5m3"), + ("qa_mode", "a_elem_format", "mx_fp8_m4e3", "fp8_m4e3"), + ("mx_scale_bits", "scale_bits", False, False), + ("mx_block_size", "block_size", "32", "32"), + ("mx_bfloat", "bfloat", {16}, {16}), + ("mx_round", "round", "bankers", "bankers"), + ("mx_custom_cuda", "custom_cuda", "yes", "yes"), + ] + + ################################ # Toy Model Classes + Fixtures # ################################ @@ -765,6 +806,7 @@ def num_bits_weight_fp16(request): # Note: All configs require deepcopy, as we will be modifying them for various tests default_config_params = [qconfig_init()] +mx_config_params = [qconfig_init(use_mx=True)] @pytest.fixture(scope="function", params=default_config_params) @@ -782,6 +824,59 @@ def 
config_fp32(request): return deepcopy(qconfig) +@pytest.fixture(scope="function", params=default_config_params) +def config_fp32_mx(request): + """ + Create fp32 qconfig w/ mx_specs vars set in qconfig. + + Args: + request (dict): qconfig_init + + Returns: + dict: qconfig_init + """ + qconfig = deepcopy(request.param) + mx_specs = get_mx_specs_defaults() + + # Set config vars prefixed w/ "mx_" + for key, val in mx_specs.items(): + qconfig["mx_" + key] = val + + # Only 1 variable that has "mx_" prefix from MX lib + qconfig["mx_flush_fp32_subnorms"] = qconfig["mx_mx_flush_fp32_subnorms"] + del qconfig["mx_mx_flush_fp32_subnorms"] + + # Move x_elem_format to q_modes and delete mx_x_elem_format + # Needs prefix settings to avoid collision w/ fms_mo modes + qconfig["qa_mode"] = "mx_" + qconfig["mx_a_elem_format"] + qconfig["qw_mode"] = "mx_" + qconfig["mx_w_elem_format"] + del qconfig["mx_a_elem_format"] + del qconfig["mx_w_elem_format"] + + return qconfig + + +@pytest.fixture(scope="function", params=mx_config_params) +def config_fp32_mx_specs(request): + """ + Create fp32 qconfig w/ mx_specs. + + + Args: + request (dict): qconfig_init + + Returns: + dict: qconfig_init + """ + qconfig = deepcopy(request.param) + qconfig["mx_specs"] = get_mx_specs_defaults() + + # Set mx_specs as if we ran qconfig_init + set_mx_specs(qconfig) + + return qconfig + + @pytest.fixture(scope="function", params=default_config_params) def config_fp16(request): """ @@ -980,22 +1075,11 @@ def wanted_pair(request): ( "mapping", { - torch.nn.Conv2d: { - "from": torch.nn.Conv2d, - "to": QConv2d, - "otherwise": QConv2d, - }, - torch.nn.ConvTranspose2d: { - "from": torch.nn.ConvTranspose2d, - "to": QConvTranspose2d, - "otherwise": QConvTranspose2d, - }, - torch.nn.Linear: { - "from": torch.nn.Linear, - "to": QLinear, - "otherwise": QLinear, - }, - torch.nn.LSTM: {"from": torch.nn.LSTM, "to": QLSTM, "otherwise": QLSTM}, + torch.nn.Conv2d: QConv2d, + torch.nn.ConvTranspose2d: QConvTranspose2d, + torch.nn.Linear: QLinear, + torch.nn.LSTM: QLSTM, + "matmul_or_bmm": QBmm, }, ), ("checkQerr_frequency", False), @@ -1135,3 +1219,83 @@ def model_bert_eager(): return BertModel.from_pretrained( "google-bert/bert-base-uncased", torchscript=True, attn_implementation="eager" ) + + +# MX reference class for quantization +if torch.cuda.is_available(): + + class ResidualMLP(torch.nn.Module): + """ + Test Linear model for MX library + """ + + def __init__(self, hidden_size, device="cuda"): + super().__init__() + + self.layernorm = torch.nn.LayerNorm(hidden_size, device=device) + self.dense_4h = torch.nn.Linear(hidden_size, 4 * hidden_size, device=device) + self.dense_h = torch.nn.Linear(4 * hidden_size, hidden_size, device=device) + self.dummy = torch.nn.Linear(hidden_size, hidden_size, device=device) + # add a dummy layer because by default we skip 1st/last, + # if there are only 2 layers, all will be skipped + + def forward(self, inputs): + """ + Forward function for Residual MLP + + Args: + inputs (torch.tensor): Input tensor + + Returns: + torch.tensor: Output tensor + """ + norm_outputs = self.layernorm(inputs) + + # MLP + proj_outputs = self.dense_4h(norm_outputs) + # pylint: disable=not-callable + proj_outputs = F.gelu(proj_outputs) + mlp_outputs = self.dense_h(proj_outputs) + mlp_outputs = self.dummy(mlp_outputs) + + # Residual Connection + outputs = inputs + mlp_outputs + + return outputs + + +mx_format_params = ["int8", "int4", "fp8_e4m3", "fp8_e5m2", "fp4_e2m1"] + + +@pytest.fixture(scope="session", params=mx_format_params) +def 
mx_format(request):
+    """
+    Get an MX element format to test
+
+    Returns:
+        str: MX element format name
+    """
+    return request.param
+
+
+@pytest.fixture(scope="function")
+def input_residualMLP():
+    """
+    Get a random input for a residual MLP model
+
+    Returns:
+        torch.FloatTensor: Random 16x128 tensor
+    """
+    x = np.random.randn(16, 128)
+    return torch.tensor(x, dtype=torch.float32, device="cuda")
+
+
+@pytest.fixture(scope="function")
+def model_residualMLP():
+    """
+    Get a ResidualMLP model
+
+    Returns:
+        torch.nn.Module: ResidualMLP model with hidden size 128
+    """
+    return ResidualMLP(128)
diff --git a/tests/models/test_mx.py b/tests/models/test_mx.py
new file mode 100644
index 00000000..fce79463
--- /dev/null
+++ b/tests/models/test_mx.py
@@ -0,0 +1,163 @@
+# Third Party
+import pytest
+import torch
+
+# Local
+from fms_mo import qmodel_prep
+from fms_mo.utils.import_utils import available_packages
+from fms_mo.utils.qconfig_utils import check_config, set_mx_specs
+from tests.models.test_model_utils import delete_config, qmodule_error
+
+if available_packages["mx"]:
+    # Local
+    # pylint: disable=ungrouped-imports
+    from fms_mo.modules.bmm import QBmmMX
+    from fms_mo.modules.linear import QLinearMX
+
+    mx_qmodules = [
+        QLinearMX,
+        QBmmMX,
+    ]
+
+@pytest.mark.skipif(
+    not available_packages["mx"],
+    reason="Skipping mx_specs error test; No package found",
+)
+def test_config_mx_specs_error(
+    model_residualMLP: torch.nn.Module,
+    config_fp32_mx_specs: dict,
+    bad_mx_specs_settings: list,
+):
+    """
+    Check that check_config raises a ValueError when mx_specs has a bad key/value pair
+
+    Args:
+        model_residualMLP (torch.nn.Module): Single fp32 model
+        config_fp32_mx_specs (dict): Config for fp32 quantization w/ mx_specs
+        bad_mx_specs_settings (list):
+            List of invalid values for mx_specs
+    """
+    model_dtype = next(model_residualMLP.parameters()).dtype
+
+    assert "mx_specs" in config_fp32_mx_specs
+    mx_specs_temp = config_fp32_mx_specs.get("mx_specs")
+
+    for key, bad_val in bad_mx_specs_settings:
+        # Every time we change the value, we must reset mx_specs
+        config_fp32_mx_specs["mx_specs"][key] = bad_val
+        set_mx_specs(config_fp32_mx_specs)
+
+        with pytest.raises(ValueError):
+            check_config(config_fp32_mx_specs, model_dtype)
+
+        # Reset to saved value
+        config_fp32_mx_specs["mx_specs"] = mx_specs_temp
+
+@pytest.mark.skipif(
+    not available_packages["mx"],
+    reason="Skipping mx_specs error test; No package found",
+)
+def test_config_mx_error(
+    model_residualMLP: torch.nn.Module,
+    config_fp32_mx: dict,
+    bad_mx_config_settings: list,
+):
+    """
+    Check that check_config raises a ValueError for a bad "mx_"-prefixed config entry
+
+    Args:
+        model_residualMLP (torch.nn.Module): Single fp32 model
+        config_fp32_mx (dict): Config for fp32 quantization w/ "mx_"-prefixed settings
+        bad_mx_config_settings (list):
+            List of invalid values for the config and the resulting mx_specs
+    """
+    model_dtype = next(model_residualMLP.parameters()).dtype
+
+    assert "mx_specs" not in config_fp32_mx
+
+    for config_key, mx_specs_key, config_bad_val, mx_specs_bad_val in bad_mx_config_settings:
+        # Save the current "mx_"-prefixed config value before overwriting it
+        mx_temp = config_fp32_mx[config_key]
+
+        # Need to reset qcfg["mx_specs"] w/ bad val
+        config_fp32_mx[config_key] = config_bad_val
+
+        set_mx_specs(config_fp32_mx)
+        assert "mx_specs" in config_fp32_mx
+        assert config_fp32_mx["mx_specs"][mx_specs_key] == mx_specs_bad_val
+
+        with pytest.raises(ValueError):
+            check_config(config_fp32_mx, model_dtype)
+
+        # Reset value and delete mx_specs
+        config_fp32_mx[config_key] = mx_temp
+        del config_fp32_mx["mx_specs"]
+
+
+@pytest.mark.skipif(
+    not torch.cuda.is_available()
+    or not available_packages["mx"],
+    reason="Skipped because CUDA or MX library was not available",
+)
+def test_residualMLP(
+    model_residualMLP: torch.nn.Module,
+    input_residualMLP: torch.FloatTensor,
+    config_fp32_mx_specs: dict,
+    mx_format: str,
+):
+    """
+    Test residualMLP for qmodel_prep
+
+    Args:
+        model_residualMLP (torch.nn.Module): Single fp32 model.
+        input_residualMLP (torch.FloatTensor): Random 16x128 tensor.
+        config_fp32_mx_specs (dict): Config for fp32 quantization w/ mx_specs.
+        mx_format (str): MX format for quantization.
+    """
+    # Remove any saved qcfg.json
+    delete_config()
+
+    config_fp32_mx_specs["mx_specs"]["w_elem_format"] = mx_format
+    config_fp32_mx_specs["mx_specs"]["a_elem_format"] = mx_format
+    set_mx_specs(config_fp32_mx_specs)
+
+    qmodel_prep(model_residualMLP, input_residualMLP, config_fp32_mx_specs, use_dynamo=True)
+    qmodule_error(model_residualMLP, 2, 1)
+
+    # One layer should be QLinearMX
+    found_qmodule_mx = False
+    for _, module in model_residualMLP.named_modules():
+        if any(isinstance(module, qmodule_mx) for qmodule_mx in mx_qmodules):
+            found_qmodule_mx = True
+            # Check that the desired mx format was propagated to the class
+            assert module.mx_specs["w_elem_format"] == mx_format
+            assert module.mx_specs["a_elem_format"] == mx_format
+
+    assert found_qmodule_mx
+
+
+@pytest.mark.skipif(
+    not available_packages["mx"],
+    reason="Skipping mx_specs error test; No package found",
+)
+def test_mx_specs_after_qconfig_init(
+    model_residualMLP: torch.nn.Module,
+    input_residualMLP: torch.FloatTensor,
+    config_fp32: dict,
+):
+    """
+    Test that a default config w/ MX qmodes triggers setting mx_specs inside qmodel_prep
+
+    Args:
+        model_residualMLP (torch.nn.Module): Single fp32 model.
+        input_residualMLP (torch.FloatTensor): Random 16x128 tensor.
+        config_fp32 (dict): Config w/ fp32 settings.
+    """
+    config_fp32["qa_mode"] = "mx_fp8_e5m2"
+    config_fp32["qw_mode"] = "mx_fp8_e5m2"
+
+    assert "mx_specs" not in config_fp32
+
+    qmodel_prep(model_residualMLP, input_residualMLP, config_fp32, use_dynamo=True)
+
+    assert "mx_specs" in config_fp32