From 365551846166b0a93fa9ac10632b94a0a4b9aa0b Mon Sep 17 00:00:00 2001 From: cliu-us Date: Thu, 15 May 2025 14:47:33 +0000 Subject: [PATCH 1/7] adjust int8 triton to enable msb/lsb truncation Signed-off-by: cliu-us --- fms_mo/custom_ext_kernels/triton_kernels.py | 36 +++++++++++++++------ tests/triton_kernels/test_triton_mm.py | 18 ++++++++--- 2 files changed, 40 insertions(+), 14 deletions(-) diff --git a/fms_mo/custom_ext_kernels/triton_kernels.py b/fms_mo/custom_ext_kernels/triton_kernels.py index 3cc15abf..a52113b3 100644 --- a/fms_mo/custom_ext_kernels/triton_kernels.py +++ b/fms_mo/custom_ext_kernels/triton_kernels.py @@ -101,6 +101,7 @@ def matmul_kernel( stride_cm, stride_cn, chunk_trun_bits, + max_acc_bits, truncate_then_accumulate, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, @@ -212,6 +213,7 @@ def imatmul_kernel( stride_cm, stride_cn, chunk_trun_bits, + max_acc_bits, truncate_then_accumulate, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, @@ -220,8 +222,8 @@ def imatmul_kernel( GROUP_SIZE_M: tl.constexpr, ACTIVATION: tl.constexpr, ): - """Kernel for computing the INT matmul C = A x B that include LSB truncation. A and B should be - INT8, C should be INT32. (Pretty much the same code as float version.) + """Kernel for computing the INT matmul D = A x B + C that include LSB truncation and MSB + clamping. A and B should be INT8, C/D should be INT32. (similar to the float version.) A has shape (M, K), B has shape (K, N) and C has shape (M, N) Args: chunk_trun_bits (int): number of LSBs to truncate/round. 
@@ -238,14 +240,20 @@ def imatmul_kernel( offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32) - ## ------ prepare LSB rounding/truncation masks ------- + # accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32) + accumulator = tl.load(c_ptrs, mask=c_mask, other=0.0) + ## ------ prepare MSB/LSB rounding/truncation masks ------- round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0 - # msb_mask = 0x00FFFFFF # only needed when simulating truncation on MSB + acc_min = -(1 << (max_acc_bits - 1)) + acc_max = -acc_min - 1 ## --------------------------------------------------------- for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): @@ -256,7 +264,11 @@ def imatmul_kernel( else: accumulator_inner = tl.dot(a, b, accumulator, input_precision="ieee") - ## ------ add chunky LSB rounding/masking -------- + ## ------ MSB truncation by clamp, chunky LSB truncation by rounding/masking -------- + if max_acc_bits < 32: + accumulator_inner = tl.maximum( + tl.minimum(accumulator_inner, acc_max), acc_min + ) if chunk_trun_bits != 0: accumulator_inner = (accumulator_inner + round_bit) >> chunk_trun_bits accumulator_inner = accumulator_inner << chunk_trun_bits @@ -275,8 +287,6 @@ def imatmul_kernel( offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + 
stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) tl.store(c_ptrs, c, mask=c_mask) @@ -300,6 +310,7 @@ def matmul_kernel_DABC( stride_cm, stride_cn, chunk_trun_bits, + max_acc_bits, truncate_then_accumulate, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, @@ -421,6 +432,7 @@ def tl_matmul_chunk_truncate( activation="", chunk_trun_bits=0, chunk_size=16, + max_acc_bits=32, truncate_then_accumulate=True, cast_output_to_input_dtype=None, ): @@ -434,6 +446,9 @@ def tl_matmul_chunk_truncate( activation (str, optional): activation func to be fused, see relu example. chunk_trun_bits (int, optional): number of LSBs to be truncated/rounded. chunk_size (int, optional): BLOCK_SIZE_K, some HW has specific chunk size. must >= 16. + max_acc_bits (int, optional): num of bits for the accumulator, e.g. if INT24 is used, will + clamp each chunk of a*b to [-2**23, 2**23-1]. + (assuming no inf when overflow) truncate_then_accumulate (bool, optional): if True, c = truncate(a*b) + c, otherwise c = truncate(a*b+c) cast_output_to_input_dtype (bool, optional): accumulator has higher prec than input, usually @@ -472,9 +487,9 @@ def isPowerofTwo(x): # because min k (chunk size in this case) for fp16/bf16 is 16, if smaller is needed, we could # insert 0s in between elements, e.g. pad [m,k] -> [m,2k], [k,n]->[2k,n], out=[m,n] unchanged. - # Do not support INT8 for now. if chunk_size == 8 and a.dtype in [ torch.float8_e4m3fn, + torch.int8, torch.float16, torch.bfloat16, ]: @@ -515,7 +530,7 @@ def isPowerofTwo(x): c_org_dtype = c.dtype c = c.to(acc_dtype) assert c.shape[0] == M and c.shape[1] == N, "C shape is inconsistent with A B." - assert acc_dtype == torch.float32, "INT truncation is not yet supported." + # assert acc_dtype == torch.float32, "INT truncation is not yet supported." # 1D launch kernel where each block gets its own program. 
def grid(META): @@ -556,6 +571,7 @@ def grid(META): c.stride(0), c.stride(1), chunk_trun_bits=chunk_trun_bits, + max_acc_bits=max_acc_bits, truncate_then_accumulate=truncate_then_accumulate, ACTIVATION=activation, **kernel_config, # if using auto-tune, comment this line out. diff --git a/tests/triton_kernels/test_triton_mm.py b/tests/triton_kernels/test_triton_mm.py index cf86e553..086ea89c 100644 --- a/tests/triton_kernels/test_triton_mm.py +++ b/tests/triton_kernels/test_triton_mm.py @@ -94,13 +94,23 @@ def test_triton_matmul_int8(mkn): torch_output = torch.matmul(a.to(torch.float), b.to(torch.float)) # cast tl_matmul results to float because torch.norm only supports float tl_output_no_trun = tl_matmul(a, b).to(torch.float) + # check LSB truncation effect tl_output_trun_8b = tl_matmul(a, b, chunk_trun_bits=8).to(torch.float) + # check MSB truncation effect + # max(1 int8 * 1 int8) ~ 2^17 -> each chunk acc 32 elem, possible max ~ 2^22 + # -> truncate to 18b -> should see large err than LSB-only case + tl_output_trun_18b8b = tl_matmul(a, b, max_acc_bits=18, chunk_trun_bits=8).to( + torch.float + ) - diff_no_trun = torch_output - tl_output_no_trun - diff_trun_8b = torch_output - tl_output_trun_8b + ref = torch.norm(torch_output) + rel_err_no_trun = torch.norm(torch_output - tl_output_no_trun) / ref + rel_err_trun_8b = torch.norm(torch_output - tl_output_trun_8b) / ref + rel_err_trun_18b8b = torch.norm(torch_output - tl_output_trun_18b8b) / ref - assert torch.norm(diff_no_trun) / torch.norm(torch_output) < 1e-5 - assert torch.norm(diff_trun_8b) / torch.norm(torch_output) < 1e-2 + assert rel_err_no_trun < 1e-5 + assert rel_err_trun_8b < 1e-2 + assert rel_err_trun_18b8b < 1e-2 @pytest.mark.parametrize("feat_in_out", [(64, 128), (256, 1024), (1024, 4096)]) From 8ffad1bc530652193753acd7664098268ff84526 Mon Sep 17 00:00:00 2001 From: cliu-us Date: Thu, 15 May 2025 15:22:40 +0000 Subject: [PATCH 2/7] minor adjustment Signed-off-by: cliu-us --- 
fms_mo/custom_ext_kernels/triton_kernels.py | 4 ++-- tests/triton_kernels/test_triton_mm.py | 14 ++++++++++---- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/fms_mo/custom_ext_kernels/triton_kernels.py b/fms_mo/custom_ext_kernels/triton_kernels.py index a52113b3..8ecadd06 100644 --- a/fms_mo/custom_ext_kernels/triton_kernels.py +++ b/fms_mo/custom_ext_kernels/triton_kernels.py @@ -264,7 +264,8 @@ def imatmul_kernel( else: accumulator_inner = tl.dot(a, b, accumulator, input_precision="ieee") - ## ------ MSB truncation by clamp, chunky LSB truncation by rounding/masking -------- + ## ------ INT MSB truncation is simulated by clamping, + # "special" INT LSB truncation by right and left shift -------- if max_acc_bits < 32: accumulator_inner = tl.maximum( tl.minimum(accumulator_inner, acc_max), acc_min @@ -530,7 +531,6 @@ def isPowerofTwo(x): c_org_dtype = c.dtype c = c.to(acc_dtype) assert c.shape[0] == M and c.shape[1] == N, "C shape is inconsistent with A B." - # assert acc_dtype == torch.float32, "INT truncation is not yet supported." # 1D launch kernel where each block gets its own program. 
def grid(META): diff --git a/tests/triton_kernels/test_triton_mm.py b/tests/triton_kernels/test_triton_mm.py index 086ea89c..07328888 100644 --- a/tests/triton_kernels/test_triton_mm.py +++ b/tests/triton_kernels/test_triton_mm.py @@ -94,23 +94,29 @@ def test_triton_matmul_int8(mkn): torch_output = torch.matmul(a.to(torch.float), b.to(torch.float)) # cast tl_matmul results to float because torch.norm only supports float tl_output_no_trun = tl_matmul(a, b).to(torch.float) - # check LSB truncation effect + # check LSB truncation effect (underflow) tl_output_trun_8b = tl_matmul(a, b, chunk_trun_bits=8).to(torch.float) - # check MSB truncation effect - # max(1 int8 * 1 int8) ~ 2^17 -> each chunk acc 32 elem, possible max ~ 2^22 - # -> truncate to 18b -> should see large err than LSB-only case + # check MSB truncation effect (overflow) + # max(1 int8 * 1 int8) ~ 2^14 -> each chunk acc 32 elem only, achievable max ~ 2^19 + # -> truncate to 18b -> should see slightly large err than LSB-only case tl_output_trun_18b8b = tl_matmul(a, b, max_acc_bits=18, chunk_trun_bits=8).to( torch.float ) + # use larger chunk size to accumulate more elem, MSB truncation (overflow) issue should worsen + tl_output_trun_18b8b_128 = tl_matmul( + a, b, max_acc_bits=18, chunk_trun_bits=8, chunk_size=min(128, k) + ).to(torch.float) ref = torch.norm(torch_output) rel_err_no_trun = torch.norm(torch_output - tl_output_no_trun) / ref rel_err_trun_8b = torch.norm(torch_output - tl_output_trun_8b) / ref rel_err_trun_18b8b = torch.norm(torch_output - tl_output_trun_18b8b) / ref + rel_err_trun_18b8b_128 = torch.norm(torch_output - tl_output_trun_18b8b_128) / ref assert rel_err_no_trun < 1e-5 assert rel_err_trun_8b < 1e-2 assert rel_err_trun_18b8b < 1e-2 + assert rel_err_trun_18b8b_128 >= rel_err_trun_18b8b @pytest.mark.parametrize("feat_in_out", [(64, 128), (256, 1024), (1024, 4096)]) From 8b570a744b04eb8415885100abd154f7c3911dad Mon Sep 17 00:00:00 2001 From: cliu-us Date: Thu, 15 May 2025 15:32:29 +0000 
Subject: [PATCH 3/7] linting Signed-off-by: cliu-us --- fms_mo/custom_ext_kernels/triton_kernels.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fms_mo/custom_ext_kernels/triton_kernels.py b/fms_mo/custom_ext_kernels/triton_kernels.py index 8ecadd06..24fbde89 100644 --- a/fms_mo/custom_ext_kernels/triton_kernels.py +++ b/fms_mo/custom_ext_kernels/triton_kernels.py @@ -101,7 +101,7 @@ def matmul_kernel( stride_cm, stride_cn, chunk_trun_bits, - max_acc_bits, + max_acc_bits, # pylint: disable=unused-argument truncate_then_accumulate, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, @@ -311,7 +311,7 @@ def matmul_kernel_DABC( stride_cm, stride_cn, chunk_trun_bits, - max_acc_bits, + max_acc_bits, # pylint: disable=unused-argument truncate_then_accumulate, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, From 8c7a4e83f6de5c7b2cbd81c44bcd8b45cee16558 Mon Sep 17 00:00:00 2001 From: cliu-us Date: Fri, 16 May 2025 04:06:50 +0000 Subject: [PATCH 4/7] add dynamic act quantizer option (something like pertokenmax) for QLinearINT8Deploy Signed-off-by: cliu-us --- fms_mo/custom_ext_kernels/utils.py | 17 +++--- fms_mo/modules/linear.py | 88 ++++++++++++++++++------------ 2 files changed, 62 insertions(+), 43 deletions(-) diff --git a/fms_mo/custom_ext_kernels/utils.py b/fms_mo/custom_ext_kernels/utils.py index b3c60c4d..4560adab 100644 --- a/fms_mo/custom_ext_kernels/utils.py +++ b/fms_mo/custom_ext_kernels/utils.py @@ -633,14 +633,17 @@ def exv2_i4f16_fxinputs_abstract( def imatmul_ops_reg( - useCUTLASS=True, mm_func=torch.matmul, AB_dtype=torch.float, D_dtype=torch.float + useINTkernel="triton", + mm_func=torch.matmul, + AB_dtype=torch.float, + D_dtype=torch.float, ): """This function will register a dummy Q_imatmul Op for better "graph representation". Args: - useCUTLASS: bool. choose to use a) real INT matmul using cutlass kernel or b) "simulated" - imatmul using torch.matmul. + useINTkernel: str|bool. ["cutlass", "triton", False]. 
choose to use a) real INT matmul, e.g. + cutlass or triton kernel or b) "simulated" imatmul using torch.matmul. For b), could use D_dtype to select fp16 or fp32 accumulation - mm_func: matmul func to be used when useCUTLASS is True, should be a real callable kernel + mm_func: matmul func to be used when useINTkernel is True, should be a real callable kernel from cutlass, but for debug purpose, could use torch.matmul as well. AB_dtype: datatype for input tensors D_dtype: datatype for accumulation and output tensor @@ -697,10 +700,10 @@ def imatmul(m1, m2): tar_shape = tuple(m1.shape[:-1]) + (m2.shape[1],) m1 = m1.view(re_shape) - if useCUTLASS: + if useINTkernel: assert ( m1.dtype == torch.int8 and m2.dtype == torch.int8 - ), "When using cutlass int matmul, inputs must be 2D INT8" + ), "When using int matmul, inputs must be 2D and INT8." return mm_func(m1, m2).reshape(tar_shape) outf32_or_f16 = torch.empty( @@ -759,7 +762,7 @@ def q_iaddmm_dq(bias, m1, m2, scale_i, zp_i, scale_w): assert m2.dtype == torch.int8, f"weight tensor is of incorrect dtype {m2.dtype}" m1 = torch.clamp((m1 / scale_i + zp_i - 128).round(), -128, 127).to(torch.int8) - if useCUTLASS: + if useINTkernel: mm_i32 = mm_func(m1, m2) else: outf32_or_f16 = torch.empty( diff --git a/fms_mo/modules/linear.py b/fms_mo/modules/linear.py index 26b383c6..42cf7b52 100644 --- a/fms_mo/modules/linear.py +++ b/fms_mo/modules/linear.py @@ -742,7 +742,7 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs): for a_or_w in ["num_bits_feature", "num_bits_weight"] ), "Please check nbits setting!" 
- target_device = kwargs.get( + tar_dev = kwargs.get( "target_device", kwargs.get("device", next(fms_mo_qlinear.parameters()).device), ) @@ -751,7 +751,7 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs): fms_mo_qlinear.in_features, fms_mo_qlinear.out_features, bias=fms_mo_qlinear.bias is not None, - device=target_device, + device=tar_dev, ) # Make sure to register an Op for integer matmul, could be real INT matmul or emulation qcfg = getattr(fms_mo_qlinear, "qcfg", {}) @@ -759,6 +759,7 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs): "use_int_kernel", qcfg.get("use_int_kernel", "cutlass") ) qlin_int.usePTnativeQfunc = kwargs.get("use_PT_native_Qfunc", False) + qlin_int.useDynMaxQfunc = kwargs.get("use_dynamic_max_act_Qfunc", False) qlin_int.max_acc_bits = kwargs.get("max_acc_bits", 32) qlin_int.accminmax = ( -(1 << (qlin_int.max_acc_bits - 1)), @@ -773,34 +774,48 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs): with torch.no_grad(): Qa = fms_mo_qlinear.quantize_feature Qw = fms_mo_qlinear.quantize_weight - a_cv, a_cvn = Qa.clip_val.item(), Qa.clip_valn.item() w_cv = Qw.clip_val.item() + if qlin_int.useDynMaxQfunc in [-1, -2]: # [-1, -2] indicates reduce_dim + # dynamic Qmax has no clipvals, reg fake ones, won't be used in real calc + Qa.register_buffer("clip_val", torch.tensor(8.0, device=tar_dev)) + Qa.register_buffer("clip_valn", torch.tensor(-8.0, device=tar_dev)) + a_cv, a_cvn = Qa.clip_val.item(), Qa.clip_valn.item() + # Store original cv_a and cv_w (in python floats, not tensors), and sq scales + # for later use (probably not necessary) + qlin_int.cvs = [a_cv, a_cvn, w_cv] # NOTE: Keep w transposed to prevent confusion Qw.dequantize = False - w_int8 = Qw( - fms_mo_qlinear.weight.float() - ) # Qw.clipval should have been updated after this + # trigger Qw.clipval re-calc for SAWB (if needed) + w_int8 = Qw(fms_mo_qlinear.weight.float()) qlin_int.weight = nn.Parameter( w_int8.to(torch.int8), requires_grad=False ) # NOTE: may need INT W stored as FP in some 
cases - if qlin_int.usePTnativeQfunc: + if qlin_int.useDynMaxQfunc in [-1, -2]: + input_scale = torch.tensor(1.0, device=tar_dev) + input_zero_point = torch.tensor(128, dtype=torch.int, device=tar_dev) + w_scale = torch.tensor( + [w_cv * 2 / (2**qlin_int.nbits_w - 2)], device=tar_dev + ) + elif qlin_int.usePTnativeQfunc: input_scale = torch.tensor( - [(a_cv - a_cvn) / (2**qlin_int.nbits_a - 1)], device=target_device + [(a_cv - a_cvn) / (2**qlin_int.nbits_a - 1)], device=tar_dev ) input_zero_point = torch.round(-a_cvn / input_scale).to(torch.int) - w_scale = torch.tensor([w_cv * 2 / (2**qlin_int.nbits_w - 2)]) + w_scale = torch.tensor( + [w_cv * 2 / (2**qlin_int.nbits_w - 2)], device=tar_dev + ) else: # fms_mo formula is a bit different from conventional PT formula quant_scale = (2**qlin_int.nbits_a - 1) / torch.tensor( - [a_cv - a_cvn], device=target_device + [a_cv - a_cvn], device=tar_dev ) quant_stepsize = 1.0 / quant_scale quant_zero_point = torch.round(a_cvn * quant_scale) input_scale = quant_stepsize input_zero_point = -quant_zero_point quant_w_scale = (2**qlin_int.nbits_a - 2) / torch.tensor( - [w_cv * 2], device=target_device + [w_cv * 2], device=tar_dev ) w_scale = 1.0 / quant_w_scale qlin_int.register_buffer("quant_scale", quant_scale) @@ -812,9 +827,6 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs): qlin_int.register_buffer("input_zp", input_zero_point) qlin_int.register_buffer("w_scale", w_scale) qlin_int.register_buffer("w_zp", w_zp) - # Store original cv_a and cv_w (in python floats, not tensors), and sq scales - # for later verification - qlin_int.cvs = [Qa.clip_val.item(), Qa.clip_valn.item(), Qw.clip_val.item()] corr_term = ( (input_zero_point - 128) @@ -836,17 +848,14 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs): qlin_int.register_buffer("bias", -corr_term.to(fms_mo_w_dtype)) qlin_int.org_model_has_bias = False - qlin_int.register_buffer("Qa_clip_val", Qa.clip_val.detach()) - qlin_int.register_buffer( - "Qa_clip_valn", Qa.clip_valn.detach() 
- ) # TODO: case for PACT? - qlin_int.register_buffer( - "Qw_clip_val", Qw.clip_val.detach() - ) # asym W quantizer may have clipvaln + # redundant variables to be cleaned up + # qlin_int.register_buffer("Qa_clip_val", Qa.clip_val.detach()) + # qlin_int.register_buffer("Qa_clip_valn", Qa.clip_valn.detach()) + # qlin_int.register_buffer("Qw_clip_val", Qw.clip_val.detach()) qlin_int.set_matmul_op() - return qlin_int.to(target_device) + return qlin_int.to(tar_dev) @classmethod def from_torch_iW(cls, nnlin_iW, prec, a_cv, a_cvn, w_cv, zero_shift, **kwargs): @@ -988,25 +997,15 @@ def qa_raw_qfunc(self, x): """ Quantizes the input tensor x to 8-bit integer values using raw formula, slower if not torch.compiled - - Args: - x (Tensor): Input tensor to be quantized. - - Returns: - Tensor: Quantized tensor with values in the range [-128, 127]. """ x = torch.clamp((x / self.input_scale + self.input_zp - 128).round(), -128, 127) return x.to(torch.int8) def qa_fmo_mo_qfunc(self, x): """ - Quantizes the input tensor x to 8-bit integer values. - - Args: - x (Tensor): Input tensor to be quantized. - - Returns: - Tensor: Quantized tensor with values in the range [-128, 127]. + Quantizes the input tensor x to 8-bit integer values. Note that old fms-mo formula clamps + before rounds, as opposed to typical torch formula that rounds before clamps. + (See qa_raw_qfunc() above.) """ x = ( torch.round( @@ -1017,6 +1016,21 @@ def qa_fmo_mo_qfunc(self, x): ) return x.to(torch.int8) + def qa_dynamic_max_qfunc(self, x): + """ + Symmetric dynamic quantizer, same as QDynMax, which allows per-token or per-channel. + This quantizer will not use self.input_scale but instead will update it every time. + NOTE + 1. self.input_scale keeps the reduced dim (keepdim=True), so its shape is (..., x.shape[-2], 1) + if reduce_dim == -1 and (..., 1, x.shape[-1]) if reduce_dim == -2. + 2. input_scale should be broadcast correctly together with W_scale (e.g. if per-Ch) at + final output step, i.e. imm_out*(a_scale*w_scale)*... 
+ """ + amax = x.abs().max(dim=self.useDynMaxQfunc, keepdim=True)[0] + levels = 2 ** (self.nbits_a - 1) - 1 + self.input_scale = amax.clamp(min=1e-5).div(levels) + return torch.round(x / self.input_scale).to(torch.int8) + def iaddmm_int(self, bias, m1, m2): """ Performs integer matrix multiplication with optional addition of a bias term. @@ -1034,7 +1048,9 @@ def iaddmm_int(self, bias, m1, m2): The result of the integer matrix multiplication with the bias added. """ - if self.usePTnativeQfunc: + if self.useDynMaxQfunc in [-1, -2]: + m1 = self.qa_dynamic_max_qfunc(m1) + elif self.usePTnativeQfunc: m1 = self.qa_raw_qfunc(m1) else: m1 = self.qa_fmo_mo_qfunc(m1) From 8e3f16afbbc1d3fd5ea0c0fc632f4067f17c9c62 Mon Sep 17 00:00:00 2001 From: cliu-us Date: Sun, 18 May 2025 18:45:58 +0000 Subject: [PATCH 5/7] add dynamic symmetric activation option to QLinearINT8Deploy Signed-off-by: cliu-us --- fms_mo/custom_ext_kernels/utils.py | 2 +- fms_mo/modules/linear.py | 65 ++++++++++++++++-------------- 2 files changed, 36 insertions(+), 31 deletions(-) diff --git a/fms_mo/custom_ext_kernels/utils.py b/fms_mo/custom_ext_kernels/utils.py index 4560adab..47ef7956 100644 --- a/fms_mo/custom_ext_kernels/utils.py +++ b/fms_mo/custom_ext_kernels/utils.py @@ -700,7 +700,7 @@ def imatmul(m1, m2): tar_shape = tuple(m1.shape[:-1]) + (m2.shape[1],) m1 = m1.view(re_shape) - if useINTkernel: + if useINTkernel in ["triton", "cutlass"]: assert ( m1.dtype == torch.int8 and m2.dtype == torch.int8 ), "When using int matmul, inputs must be 2D and INT8." 
diff --git a/fms_mo/modules/linear.py b/fms_mo/modules/linear.py index 42cf7b52..1bb7e305 100644 --- a/fms_mo/modules/linear.py +++ b/fms_mo/modules/linear.py @@ -760,6 +760,7 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs): ) qlin_int.usePTnativeQfunc = kwargs.get("use_PT_native_Qfunc", False) qlin_int.useDynMaxQfunc = kwargs.get("use_dynamic_max_act_Qfunc", False) + qlin_int.useSymAct = "sym" in fms_mo_qlinear.qa_mode qlin_int.max_acc_bits = kwargs.get("max_acc_bits", 32) qlin_int.accminmax = ( -(1 << (qlin_int.max_acc_bits - 1)), @@ -770,6 +771,8 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs): qlin_int.acc_dtype = torch.float16 qlin_int.nbits_a = fms_mo_qlinear.num_bits_feature # only support INT8 for now qlin_int.nbits_w = fms_mo_qlinear.num_bits_weight + w_levels = 2**qlin_int.nbits_w - 2 + a_levels = 2**qlin_int.nbits_a - 1 - qlin_int.useSymAct with torch.no_grad(): Qa = fms_mo_qlinear.quantize_feature @@ -794,29 +797,19 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs): if qlin_int.useDynMaxQfunc in [-1, -2]: input_scale = torch.tensor(1.0, device=tar_dev) input_zero_point = torch.tensor(128, dtype=torch.int, device=tar_dev) - w_scale = torch.tensor( - [w_cv * 2 / (2**qlin_int.nbits_w - 2)], device=tar_dev - ) + w_scale = torch.tensor([w_cv * 2 / w_levels], device=tar_dev) elif qlin_int.usePTnativeQfunc: - input_scale = torch.tensor( - [(a_cv - a_cvn) / (2**qlin_int.nbits_a - 1)], device=tar_dev - ) + input_scale = torch.tensor([(a_cv - a_cvn) / a_levels], device=tar_dev) input_zero_point = torch.round(-a_cvn / input_scale).to(torch.int) - w_scale = torch.tensor( - [w_cv * 2 / (2**qlin_int.nbits_w - 2)], device=tar_dev - ) + w_scale = torch.tensor([w_cv * 2 / w_levels], device=tar_dev) else: # fms_mo formula is a bit different from conventional PT formula - quant_scale = (2**qlin_int.nbits_a - 1) / torch.tensor( - [a_cv - a_cvn], device=tar_dev - ) + quant_scale = a_levels / torch.tensor([a_cv - a_cvn], device=tar_dev) quant_stepsize = 1.0 / quant_scale 
quant_zero_point = torch.round(a_cvn * quant_scale) input_scale = quant_stepsize input_zero_point = -quant_zero_point - quant_w_scale = (2**qlin_int.nbits_a - 2) / torch.tensor( - [w_cv * 2], device=tar_dev - ) + quant_w_scale = w_levels / torch.tensor([w_cv * 2], device=tar_dev) w_scale = 1.0 / quant_w_scale qlin_int.register_buffer("quant_scale", quant_scale) qlin_int.register_buffer("quant_stepsize", quant_stepsize) @@ -829,7 +822,7 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs): qlin_int.register_buffer("w_zp", w_zp) corr_term = ( - (input_zero_point - 128) + (input_zero_point - 128 + qlin_int.useSymAct) * (w_int8.sum(dim=1)) * w_scale.float() * input_scale.float() @@ -975,7 +968,7 @@ def qa_pt_qfunc_wrapped(self, x): Tensor: Quantized tensor with values in the range [-128, 127]. """ return torch.ops.fms_mo.q_per_t_sym( - x.float(), self.input_scale, self.input_zp - 128 + x.float(), self.input_scale, self.input_zp - 128 + self.useSymAct ) def qa_pt_quant_func(self, x): @@ -990,7 +983,10 @@ def qa_pt_quant_func(self, x): Tensor: Quantized tensor with values in the range [-128, 127]. """ return torch.quantize_per_tensor( - x.float(), self.input_scale, self.input_zp - 128, torch.qint8 + x.float(), + self.input_scale, + self.input_zp - 128 + self.useSymAct, + torch.qint8, ).int_repr() def qa_raw_qfunc(self, x): @@ -998,7 +994,11 @@ def qa_raw_qfunc(self, x): Quantizes the input tensor x to 8-bit integer values using raw formula, slower if not torch.compiled """ - x = torch.clamp((x / self.input_scale + self.input_zp - 128).round(), -128, 127) + x = torch.clamp( + (x / self.input_scale + self.input_zp - 128 + self.useSymAct).round(), + -128, + 127, + ) return x.to(torch.int8) def qa_fmo_mo_qfunc(self, x): @@ -1007,13 +1007,10 @@ def qa_fmo_mo_qfunc(self, x): before rounds, as opposed to typical torch formula that rounds before clamps. (See qa_raw_qfunc() above.) 
""" - x = ( - torch.round( - x.clamp(self.cvs[1], self.cvs[0]) / self.quant_stepsize - - self.quant_zero_point - ) - - 128 - ) + x = torch.round( + x.clamp(self.cvs[1], self.cvs[0]) / self.quant_stepsize + - self.quant_zero_point + ) - (128 - self.useSymAct) return x.to(torch.int8) def qa_dynamic_max_qfunc(self, x): @@ -1060,7 +1057,9 @@ def iaddmm_int(self, bias, m1, m2): Nchunk = len(idx) idx.append(m1.shape[1]) accumulator = torch.zeros( - (m1.shape[0], m2.shape[1]), dtype=torch.float16, device=m1.device + (m1.shape[0], m2.shape[1]), + dtype=torch.int, + device=m1.device, # cast float16 if needed ) trun_scale = 1 if self.truncate_lsb > 0: @@ -1080,7 +1079,7 @@ def iaddmm_int(self, bias, m1, m2): # could cast to smaller data type to further simulate HW behavior, for example, # if HW truncates 8b from both sides of i32 accumulator, the remaining data can # be cast to i16 to be more realistic. pay attention to overflow handling - accumulator += imm_out.to(torch.float16) + accumulator += imm_out # .to(torch.float16) if needed return ( accumulator @@ -1107,8 +1106,14 @@ def iaddmm_FP(self, bias, m1, m2): Returns: Tensor: the result of the matrix multiplication with addition of bias """ - m2 = m2.to(m1.dtype) - return torch.addmm(bias, m1, m2) + if self.useDynMaxQfunc in [-1, -2]: + m1 = self.qa_dynamic_max_qfunc(m1) + elif self.usePTnativeQfunc: + m1 = self.qa_raw_qfunc(m1) + else: + m1 = self.qa_fmo_mo_qfunc(m1) + + return torch.matmul(m1 * self.input_scale, m2 * self.w_scale) + bias def set_matmul_op(self): """ From 362d52128b0536b0bb9800961d6d280347525fd6 Mon Sep 17 00:00:00 2001 From: cliu-us Date: Wed, 21 May 2025 14:35:27 +0000 Subject: [PATCH 6/7] fix dq issue with llama3-70b on single gpu Signed-off-by: cliu-us --- fms_mo/dq.py | 6 +- fms_mo/prep.py | 19 +++++- fms_mo/quant/ptq.py | 129 +++++++++++++++++++++++++++++++------ fms_mo/utils/eval_utils.py | 31 +++++---- 4 files changed, 148 insertions(+), 37 deletions(-) diff --git a/fms_mo/dq.py b/fms_mo/dq.py 
index 15e80a1a..d56e5a90 100644 --- a/fms_mo/dq.py +++ b/fms_mo/dq.py @@ -38,7 +38,7 @@ from fms_mo import qconfig_init, qmodel_prep from fms_mo.fx.utils import model_size_Wb from fms_mo.quant.ptq import ( - calibration_llm_1GPU, + calibration_llm_1GPU_v2, dq_llm, get_act_scales, get_act_scales_1gpu, @@ -224,9 +224,9 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): if qcfg["qmodel_calibration_new"] > 0: logger.info("Starting to calibrate activation clip_val") if qcfg["large_model"]: - calibration_llm_1GPU(qcfg, model, dq_dataloader) + calibration_llm_1GPU_v2(qcfg, model, dq_dataloader) else: - model.to("cuda:0") + model.to("cuda") pbar = tqdm( dq_dataloader, desc=" calibration after applying smoothq scale and before inference", diff --git a/fms_mo/prep.py b/fms_mo/prep.py index 2fa55756..cc4ce7bc 100644 --- a/fms_mo/prep.py +++ b/fms_mo/prep.py @@ -177,7 +177,10 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False): is mappable, create a Qmodule and return, otherwise, return the original module. In the future, Qmodules need to have a .from_torch() or .from_nn() classmethod, and then this function will be greatly simplified. - NOTE: This func will check qskip_layer_name before creating the Qmodule + NOTE: + 1. This func will check qskip_layer_name before creating the Qmodule + 2. 
Qmodule will be created on "meta device" as a placeholder, which will skip params init and + mem alloc, as weights and bias will be reassigned to module.weight/.bias right after Args: module (nn.Module): the module which Qmodule will be based on @@ -216,7 +219,7 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False): if hasattr(module, "__constants__"): base_params = {k: getattr(module, k) for k in module.__constants__} base_params["bias"] = module.bias is not None - base_params["device"] = next(module.parameters()).device # usually cuda + base_params["device"] = "meta" module_output = module @@ -499,8 +502,17 @@ def q_any_net_5(model: nn.Module, qcfg: dict, verbose: bool = False): """ # Third Party from torch.ao.quantization.utils import _parent_name + from tqdm import tqdm + + total_modules = len(list(model.named_modules())) + pbar = tqdm( + model.named_modules(), + total=total_modules, + desc="Mapping modules to target Qmodules.", + ) + for name, module in pbar: + pbar.set_description(f"processing {name}") - for name, module in model.named_modules(): parent_module_name, curr_mod_name = _parent_name(name) new_module = make_quant_module(module, name, qcfg) parent_module = model.get_submodule(parent_module_name) @@ -525,6 +537,7 @@ def q_any_net_5(model: nn.Module, qcfg: dict, verbose: bool = False): if verbose: logger.info(f"Swap ({name}) from {type(module)} to {type(new_module)}") + pbar.close() return model diff --git a/fms_mo/quant/ptq.py b/fms_mo/quant/ptq.py index 482058e6..dca1a9ef 100644 --- a/fms_mo/quant/ptq.py +++ b/fms_mo/quant/ptq.py @@ -419,15 +419,25 @@ class PTQHookRecInOutLMv2(nn.Module): leave the special handling, e.g. 
reshape/cat/shuffling...etc, for later """ - def __init__(self, qcfg, name=None, cls2rec=(nn.Conv2d,), recInOnly=False): + def __init__( + self, + qcfg, + name=None, + cls2rec=(nn.Conv2d, nn.Linear), + recInOnly=False, + stop_after_rec=False, + cache_dev="cuda", + ): super().__init__() self.name = name self.qcfg = qcfg self.cls2rec = cls2rec self.rec_input_only = recInOnly self.num_valid_input = -1 + self.stop_after_rec = stop_after_rec + self.cache_dev = cache_dev - def __call__(self, mod, inputs, output): + def __call__(self, mod, inputs, *args, **_kwargs): # make sure this module/block's ptqmode is not 'q_out' submods = [m for m in mod.modules() if isinstance(m, self.cls2rec)] if any(sm.ptqmode == "q_out" for sm in submods): @@ -448,7 +458,7 @@ def __call__(self, mod, inputs, output): # check available GPU memory, cache on GPU if possible: GPUmem_available, _GPUmem_total = torch.cuda.mem_get_info() # 1 block for SQUAD/BERT 500 batches*12/batch = ~10G - if GPUmem_available / 1e9 > 20: + if self.cache_dev == "cuda" and GPUmem_available / 1e9 > 20: cache_device = "cuda" else: cache_device = "cpu" @@ -461,13 +471,15 @@ def __call__(self, mod, inputs, output): ) # output could be a tuple of a single tensor or simply a tensor ? 
- assert isinstance(output, (torch.Tensor, tuple)) - if not self.rec_input_only: + if not self.rec_input_only and "output" in args: + output = args["output"] + assert isinstance(output, (torch.Tensor, tuple)) self.qcfg["cached_output"].append( output[0].detach().to(cache_device) if isinstance(output, tuple) else output.detach().to(cache_device) ) + assert not self.stop_after_rec # this hook is meant for ptq_loss_func == 'fisher_diag' and to temp hold the "Q_out" of the module @@ -2021,7 +2033,7 @@ def get_blocks(model, model_type=None): "llama": ( "model.layers", "model.embed_tokens", - None, + "model.rotary_emb", None, "model.norm", "lm_head", @@ -2111,7 +2123,9 @@ def cache_block0_inputs( model, dloader, qcfg, blocks, emb=None, emb_pos=None, emb_ln=None, dev="cpu" ): """ - To cache the input to the first transformer block. + To cache the input to the first transformer block. Basically a "forward_pre_hook" + NOTE, change caching from tensor to list to allow varying input length, slightly + increase memeory due to mask and alibi. """ emb = emb.to(dev) if emb_pos is not None: @@ -2119,12 +2133,6 @@ def cache_block0_inputs( if emb_ln is not None: emb_ln = emb_ln.to(dev) blocks[0] = blocks[0].to(dev) - # NOTE, change caching from tensor to list to allow varying input length, slightly - # increase memeory due to mask and alibi. - qcfg["cached_block0_input"] = [] - qcfg["cache_id"] = 0 - qcfg["cached_mask"] = [] - qcfg["cached_alibi"] = [] # move block0 to GPU and excuting fwd() until finish block0 if "fms" in qcfg["model_type"]: qcfg["kw_to_cache"] = { @@ -2142,9 +2150,16 @@ def cache_block0_inputs( } blocks[0] = RunModule(blocks[0], qcfg) + # clear up old cache, if exists. 
+ qcfg["cached_block0_input"] = [] + qcfg["cache_id"] = 0 + for kw in qcfg["kw_to_cache"].values(): + if kw in qcfg: + qcfg[kw] = [] + if isinstance(dloader, torch.utils.data.DataLoader): pbar = tqdm( - dloader, desc="Phase 0: PTQ caching block0 input", total=qcfg["ptq_nbatch"] + dloader, desc="Phase 0: Caching block0 inputs", total=qcfg["ptq_nbatch"] ) for data_mb, _ in zip(pbar, range(qcfg["ptq_nbatch"])): try: @@ -2310,9 +2325,8 @@ def freeze_layers(m, layer_list): @torch.no_grad() def calibration_llm_1GPU(qcfg, model, dloader): - """ - calibration for large models that can not fit the whole model on 1 GPU. - """ + """Calibration for large models that can not fit on 1 GPU.""" + model.train() dev = "cuda" qcfg["batch_size"] = 1 @@ -2365,6 +2379,83 @@ def calibration_llm_1GPU(qcfg, model, dloader): logger.info("All blocks are calibrated") +@torch.no_grad() +def calibration_llm_1GPU_v2(qcfg, model, dloader): + """ + Improved version of Calibration for large language models that can not fit on 1 GPU with new + (built-in) calibration mechanism. + NOTE: + 1. Calibration only, NO update to weights! + 2. Rely on a alternative "pre fwd hook" to cache all possible inputs. + 3. As calibration usually cache a small number of data only, no need to move each batch back and + forth between GPU and CPU. + """ + + model.train() + dev = "cuda" + qcfg["batch_size"] = 1 + qcfg["dtype"] = next(iter(model.parameters())).dtype + qcfg["n_samples"] = min(qcfg["ptq_nbatch"], qcfg["qmodel_calibration_new"]) + + assert "model_type" in qcfg, "Unknown model type. please check before proceed." + assert isinstance( + dloader, torch.utils.data.DataLoader + ), "Please provide a valid dataloader." 
+ # --- Phase 0 cache the inputs of the block0--- + model.config.use_cache = False + blocks, emb, emb_pos, emb_ln, _, _ = get_blocks(model, qcfg["model_type"]) + + cache_block0_inputs( + model, + dloader, + qcfg, + blocks, + emb=emb, + emb_pos=emb_pos, + emb_ln=emb_ln, + dev="cpu", + ) + logger.info("Done, caching inputs to block0 for calibration") + + # --- Phase 1 --- compute blocks and last linear layer + pbar = tqdm( + blocks, desc="Phase 1: Calibration for each block", position=0, leave=True + ) + qcfg["cached_input"] = [ + inp.clone().detach().to(dev) for inp in qcfg["cached_block0_input"] + ] + kw_to_use = { + kw_org: kw_new + for kw_org, kw_new in qcfg["kw_to_cache"].items() + if len(qcfg[kw_new]) == len(qcfg["cached_input"]) + } + for _num_block, m in enumerate(pbar): + m.to(dev) + for i in tqdm( + range(qcfg["n_samples"]), desc="number of samples", position=1, leave=False + ): + if qcfg["cached_alibi"]: + cached_inp_prev_lay = qcfg["cached_input"][i].unsqueeze(0).to(dev) + data_mb = { + "attention_mask": qcfg["cached_mask"][i].unsqueeze(0).to(dev), + "alibi": qcfg["cached_alibi"][i].unsqueeze(0).to(dev), + } + else: + cached_inp_prev_lay = qcfg["cached_input"][i] + data_mb = { + kw_org: move_to(qcfg[kw_new][i], dev) + for kw_org, kw_new in kw_to_use.items() + } + + with patch_torch_bmm(qcfg): + qcfg["cached_input"][i] = m(cached_inp_prev_lay, **data_mb)[0] + + m.cpu() + torch.cuda.empty_cache() + + logger.info("All blocks are calibrated") + + @torch.no_grad() def activation_stats(name, tensor, act_scales): # TODO if 'QBmm' in name: reshape the tensor. @@ -2498,8 +2589,8 @@ def get_act_scales_1gpu(model, dloader, qcfg): assert "model_type" in qcfg, "Unknown model type. please check before proceed." assert ( - qcfg["loader_len"] == qcfg["ptq_nbatch"] - ), "set batch_size=1 and PTQ samples== Nbatches" + qcfg["loader_len"] >= qcfg["ptq_nbatch"] + ), "Please make sure dataloader has enough data needed for PTQ (ie. check qcfg['ptq_nbatch'])." 
# --- Phase 0 cache the inputs of the block0--- blocks, emb, emb_pos, emb_ln, _, _ = get_blocks(model, qcfg["model_type"]) cache_block0_inputs( diff --git a/fms_mo/utils/eval_utils.py b/fms_mo/utils/eval_utils.py index 498d8472..6c309507 100644 --- a/fms_mo/utils/eval_utils.py +++ b/fms_mo/utils/eval_utils.py @@ -26,7 +26,7 @@ # Local from fms_mo.quant.ptq import cache_block0_inputs, get_blocks -from fms_mo.utils.utils import patch_torch_bmm +from fms_mo.utils.utils import move_to, patch_torch_bmm logger = logging.getLogger(__name__) @@ -35,11 +35,13 @@ def eval_llm_1GPU(qcfg, model, test_dataset, pre_cache_func=None, **kwargs): # pylint: disable=unused-argument """ Evaluate causal LLM with 1GPU, return perplexity - Note: currently taking test_dataset as dict (instead of dataloader) - Used for models that cannot fit into a 1 GPU. + Note: + 1. currently taking test_dataset as dict (instead of dataloader) + 2. Used for models that cannot fit into a 1 GPU. Will need to move modules back and forth. + 3. Keep hid_state on device to reduce uncessary data transfer. 
""" model.eval() - dev = "cuda:0" # cuda:0 is used for PTQ + dev = "cuda" qcfg["batch_size"] = 1 # for dataloading, always use batch_size of 1 qcfg["dtype"] = next(iter(model.parameters())).dtype seq_len = qcfg["seq_len"] @@ -63,7 +65,14 @@ def eval_llm_1GPU(qcfg, model, test_dataset, pre_cache_func=None, **kwargs): # # Phase 1: compute blocks and last linear layer pbar = tqdm(blocks, desc="evaluation: compute blocks") - qcfg["cached_input"] = [inp.clone().detach() for inp in qcfg["cached_block0_input"]] + qcfg["cached_input"] = [ + inp.clone().detach().to(dev) for inp in qcfg["cached_block0_input"] + ] + kw_to_use = { + kw_org: kw_new + for kw_org, kw_new in qcfg["kw_to_cache"].items() + if len(qcfg[kw_new]) == len(qcfg["cached_input"]) + } for block_id, m in enumerate(pbar): # pylint: disable=unused-variable m.to(dev) for i in range(qcfg["n_samples"]): @@ -74,16 +83,14 @@ def eval_llm_1GPU(qcfg, model, test_dataset, pre_cache_func=None, **kwargs): # "alibi": qcfg["cached_alibi"][i].unsqueeze(0).to(dev), } else: - cached_inp_prev_lay = qcfg["cached_input"][i].to(dev) + cached_inp_prev_lay = qcfg["cached_input"][i] data_mb = { - "attention_mask": qcfg["cached_mask"][i].to(dev) - if len(qcfg["cached_mask"]) > 0 - else None, - "position_ids": qcfg["position_ids"][i].to(dev), + kw_org: move_to(qcfg[kw_new][i], dev) + for kw_org, kw_new in kw_to_use.items() } - with torch.no_grad(), patch_torch_bmm(qcfg): - qcfg["cached_input"][i] = m(cached_inp_prev_lay, **data_mb)[0].cpu() + with patch_torch_bmm(qcfg): + qcfg["cached_input"][i] = m(cached_inp_prev_lay, **data_mb)[0] m.cpu() torch.cuda.empty_cache() From 553c7a6733ebbe8e51841b663a96ab17c034028b Mon Sep 17 00:00:00 2001 From: cliu-us Date: Fri, 23 May 2025 16:22:16 +0000 Subject: [PATCH 7/7] 1. temp enables Qmax.dequant=False, bmgroth will officially enable it later, 2. add util func to lower qmodel to triton kernel, 3. additional fix for dq, e.g. 
torch.load Signed-off-by: cliu-us --- fms_mo/custom_ext_kernels/utils.py | 58 ++++++++++++++++++++++++++++++ fms_mo/dq.py | 14 ++++---- fms_mo/modules/linear.py | 46 +++++++++++++++--------- fms_mo/quant/quantizers.py | 7 ++-- fms_mo/training_args.py | 3 ++ 5 files changed, 103 insertions(+), 25 deletions(-) diff --git a/fms_mo/custom_ext_kernels/utils.py b/fms_mo/custom_ext_kernels/utils.py index 47ef7956..76898443 100644 --- a/fms_mo/custom_ext_kernels/utils.py +++ b/fms_mo/custom_ext_kernels/utils.py @@ -859,6 +859,64 @@ def lower_qmodel_cutlass( return mod +def lower_qmodel_triton( + model: torch.nn.Module, + use_dyn_max_act=False, + max_acc_bits=32, + num_lsb_to_truncate=0, + chunk_size=32, +): + """ + Examplar GPU lowering function using triton. Only swap Qlinears in transformers, nothing else. + Triton kernel can be used to: + 1. test INT8 or FP8 HW performance (kernel is not optimized) + 2. simulate MSB/LSB truncation effect + + Args: + model: nn.Module. should be a fms_mo Qmodel, will do inplace layer swapping, no deepcopy + use_dyn_max_act: bool or int, can be False or -1 for per-token, or -2 for perCh. will use + dynamic max quantizer for activation if not False. + max_acc_bits: max bits for accumulator, typically FP32 for all FP matmuls and INT32 for all + INT matmuls. But some HW could use fewer bits to trade-off power + efficiency at the expense of higher chance of accumulation "overflow". + For example, an INT24 accumulator can only hold values ranged from -2^23 to + 2^23 -1, as opposed to typical range -2^31 to -2^31 -1. + num_lsb_to_truncate: number of bits to truncate from LSB side. For example, given fp32 is + s1e8m23, if we choose to truncate 13 mantissa bits from right most side, + i.e. LSB, the resulting number will be s1e8m10, which is TF32. + chunk_size: given a matmul of (m, k) @ (k, n), the inner product will be "accumulated" along + k-dim. 
Since the entire matrix will be partitioned into smaller tiles when being + computed, accumulator will only add a certain num of elements in one shot. This + "chunk size" in k-dim will affect the overflow/underflow of accumulator. + """ + # Third Party + from torch.ao.quantization.utils import _parent_name + + # Local + from fms_mo.modules.linear import QLinear, QLinearINT8Deploy + + for name, m in model.named_modules(): + if not isinstance(m, QLinear): + continue + parent_name, module_name = _parent_name(name) + parent_mod = model.get_submodule(parent_name) + qmod = getattr(parent_mod, module_name) + setattr( + parent_mod, + module_name, + QLinearINT8Deploy.from_fms_mo( + qmod, + use_int_kernel="triton", + use_dynamic_max_act_Qfunc=use_dyn_max_act, + max_acc_bits=max_acc_bits, + truncate_lsb=num_lsb_to_truncate, + chunk_size=chunk_size, + ), + ) + + logger.info(f"\nModel lowering with triton kernel is done.\n{model}") + + ### ------------------------------------------------------------- # GPTQ tensor packing functions for Exllama kernel ### ------------------------------------------------------------- diff --git a/fms_mo/dq.py b/fms_mo/dq.py index d56e5a90..34587fb3 100644 --- a/fms_mo/dq.py +++ b/fms_mo/dq.py @@ -172,7 +172,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): qcfg["seq_len"] = block_size qcfg["model"] = model_args.model_name_or_path - qcfg["smoothq"] = True + qcfg["smoothq"] = fms_mo_args.smoothq_alpha != -1 qcfg["plotsvg"] = False calibration_dataset = load_from_disk(data_args.training_data_path) @@ -217,9 +217,10 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): save_fname="dq", ) logger.info(f"Quantized model {model}") - logger.info("Starting to apply smooth scale") - dq_llm(model, act_scales, qcfg) - logger.info("Finished applying smooth scale") + if qcfg["smoothq"]: + logger.info("Starting to apply smooth scale") + dq_llm(model, act_scales, qcfg) + logger.info("Finished applying smooth scale") logger.info("==" * 20) if 
qcfg["qmodel_calibration_new"] > 0: logger.info("Starting to calibrate activation clip_val") @@ -249,7 +250,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): test_dataset = load_from_disk(data_args.test_data_path) test_dataset = test_dataset.with_format("torch") elif len(pt_files) > 0: - test_dataset = torch.load(pt_files[0]) + test_dataset = torch.load(pt_files[0], weights_only=False) logger.info(f"Model for evaluation: {model}") if qcfg["large_model"]: @@ -258,7 +259,8 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): model.to(torch.device("cuda:0")) n_samples = int(test_dataset.input_ids.shape[1] / block_size) evaluator = Evaluator(test_dataset, "cuda", n_samples=n_samples) - ppl = evaluator.evaluate(model, block_size=block_size) + with patch_torch_bmm(qcfg): + ppl = evaluator.evaluate(model, block_size=block_size) logger.info(f"Model perplexity: {ppl}") logger.info("-" * 50) logger.info("Finished evaluation") diff --git a/fms_mo/modules/linear.py b/fms_mo/modules/linear.py index 1bb7e305..7f14c17d 100644 --- a/fms_mo/modules/linear.py +++ b/fms_mo/modules/linear.py @@ -29,6 +29,7 @@ # Local from fms_mo.custom_ext_kernels.utils import pack_vectorized from fms_mo.quant.quantizers import ( + SAWB, HardPrune, Qbypass, Qdynamic, @@ -751,7 +752,7 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs): fms_mo_qlinear.in_features, fms_mo_qlinear.out_features, bias=fms_mo_qlinear.bias is not None, - device=tar_dev, + device="meta", # init on tar_dev is unnecessary ) # Make sure to register an Op for integer matmul, could be real INT matmul or emulation qcfg = getattr(fms_mo_qlinear, "qcfg", {}) @@ -777,31 +778,26 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs): with torch.no_grad(): Qa = fms_mo_qlinear.quantize_feature Qw = fms_mo_qlinear.quantize_weight - w_cv = Qw.clip_val.item() + w_cv = Qw.clip_val if qlin_int.useDynMaxQfunc in [-1, -2]: # [-1, -2] indicates reduce_dim # dynamic Qmax has no clipvals, reg fake ones, won't be used in real calc 
Qa.register_buffer("clip_val", torch.tensor(8.0, device=tar_dev)) Qa.register_buffer("clip_valn", torch.tensor(-8.0, device=tar_dev)) - a_cv, a_cvn = Qa.clip_val.item(), Qa.clip_valn.item() - # Store original cv_a and cv_w (in python floats, not tensors), and sq scales - # for later use (probably not necessary) - qlin_int.cvs = [a_cv, a_cvn, w_cv] - # NOTE: Keep w transposed to prevent confusion - Qw.dequantize = False - # trigger Qw.clipval re-calc for SAWB (if needed) - w_int8 = Qw(fms_mo_qlinear.weight.float()) - qlin_int.weight = nn.Parameter( - w_int8.to(torch.int8), requires_grad=False - ) # NOTE: may need INT W stored as FP in some cases + a_cv = Qa.clip_val + a_cvn = Qa.clip_valn + # Store original cv_a and cv_w in python floats (instead of tensors) will be more + # accurate, but not compatible for per-ch and per-token. + qlin_int.cvs = [a_cv, a_cvn, w_cv] # TODO remove the need of this. + # may need to trigger Qw.clipval re-calc for SAWB here, (if needed?) if qlin_int.useDynMaxQfunc in [-1, -2]: input_scale = torch.tensor(1.0, device=tar_dev) input_zero_point = torch.tensor(128, dtype=torch.int, device=tar_dev) - w_scale = torch.tensor([w_cv * 2 / w_levels], device=tar_dev) + w_scale = w_cv * 2 / w_levels elif qlin_int.usePTnativeQfunc: input_scale = torch.tensor([(a_cv - a_cvn) / a_levels], device=tar_dev) input_zero_point = torch.round(-a_cvn / input_scale).to(torch.int) - w_scale = torch.tensor([w_cv * 2 / w_levels], device=tar_dev) + w_scale = w_cv * 2 / w_levels else: # fms_mo formula is a bit different from conventional PT formula quant_scale = a_levels / torch.tensor([a_cv - a_cvn], device=tar_dev) @@ -809,7 +805,7 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs): quant_zero_point = torch.round(a_cvn * quant_scale) input_scale = quant_stepsize input_zero_point = -quant_zero_point - quant_w_scale = w_levels / torch.tensor([w_cv * 2], device=tar_dev) + quant_w_scale = w_levels / (w_cv * 2) w_scale = 1.0 / quant_w_scale 
qlin_int.register_buffer("quant_scale", quant_scale) qlin_int.register_buffer("quant_stepsize", quant_stepsize) @@ -821,6 +817,21 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs): qlin_int.register_buffer("w_scale", w_scale) qlin_int.register_buffer("w_zp", w_zp) + # NOTE: + # 1. Keep W transposed to prevent confusion, hence (W.t()/scale).t() + # 2. only a few quantizer have .dequantize working correctly + if isinstance(Qw, SAWB): + Qw.dequantize = False + w_int8 = Qw(fms_mo_qlinear.weight.float()) + else: + w_int8 = ( + torch.round(fms_mo_qlinear.weight.t() / w_scale) + .clamp(-w_levels / 2, w_levels / 2) + .t() + ) + + qlin_int.weight = nn.Parameter(w_int8.to(torch.int8), requires_grad=False) + corr_term = ( (input_zero_point - 128 + qlin_int.useSymAct) * (w_int8.sum(dim=1)) @@ -836,8 +847,11 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs): (fms_mo_qlinear.bias - corr_term).to(fms_mo_w_dtype), requires_grad=False, ) + qlin_int.org_model_has_bias = True else: + delattr(qlin_int, "bias") + # even if bias is None, reg_buffer() is still unhappy about it qlin_int.register_buffer("bias", -corr_term.to(fms_mo_w_dtype)) qlin_int.org_model_has_bias = False diff --git a/fms_mo/quant/quantizers.py b/fms_mo/quant/quantizers.py index d3340c94..74a83b26 100644 --- a/fms_mo/quant/quantizers.py +++ b/fms_mo/quant/quantizers.py @@ -3183,9 +3183,7 @@ class QmaxPerChSTE(torch.autograd.Function): """ @staticmethod - def forward( - ctx, input_tensor, num_bits, _dequantize, inplace, cv, _cvn, align_zero - ): + def forward(ctx, input_tensor, num_bits, dequantize, inplace, cv, _cvn, align_zero): if inplace: ctx.mark_dirty(input_tensor) scale = (2**num_bits - 2) if align_zero else (2**num_bits - 1) @@ -3206,6 +3204,9 @@ def forward( quant_min=int_l, quant_max=int_u, ).to(input_tensor.dtype) + + if not dequantize: + return (output.t() / scale).t() return output @staticmethod diff --git a/fms_mo/training_args.py b/fms_mo/training_args.py index 46011e59..627e4803 100644 --- 
a/fms_mo/training_args.py +++ b/fms_mo/training_args.py @@ -173,6 +173,9 @@ class FMSMOArguments(TypeChecker): default=2048, metadata={"help": "input sequence length after tokenization"} ) eval_ppl: bool = field(default=False) + aiu_sim_triton: bool = field( + default=False, metadata={"help": ("AIU simulation with triton kernel")} + ) @dataclass