36 changes: 26 additions & 10 deletions fms_mo/custom_ext_kernels/triton_kernels.py
@@ -101,6 +101,7 @@ def matmul_kernel(
stride_cm,
stride_cn,
chunk_trun_bits,
max_acc_bits, # pylint: disable=unused-argument
truncate_then_accumulate,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
@@ -212,6 +213,7 @@ def imatmul_kernel(
stride_cm,
stride_cn,
chunk_trun_bits,
max_acc_bits,
truncate_then_accumulate,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
@@ -220,8 +222,8 @@
GROUP_SIZE_M: tl.constexpr,
ACTIVATION: tl.constexpr,
):
"""Kernel for computing the INT matmul C = A x B that include LSB truncation. A and B should be
INT8, C should be INT32. (Pretty much the same code as float version.)
"""Kernel for computing the INT matmul D = A x B + C that include LSB truncation and MSB
clamping. A and B should be INT8, C/D should be INT32. (similar to the float version.)
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
Args:
chunk_trun_bits (int): number of LSBs to truncate/round.
@@ -238,14 +240,20 @@

offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)
## ------ prepare LSB rounding/truncation masks -------
# accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)
accumulator = tl.load(c_ptrs, mask=c_mask, other=0)
## ------ prepare MSB/LSB rounding/truncation masks -------
round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0
# msb_mask = 0x00FFFFFF # only needed when simulating truncation on MSB
acc_min = -(1 << (max_acc_bits - 1))
acc_max = -acc_min - 1
## ---------------------------------------------------------

for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
@@ -256,7 +264,12 @@
else:
accumulator_inner = tl.dot(a, b, accumulator, input_precision="ieee")

## ------ add chunky LSB rounding/masking --------
## ------ INT MSB truncation is simulated by clamping,
# "special" INT LSB truncation by right and left shift --------
if max_acc_bits < 32:
accumulator_inner = tl.maximum(
tl.minimum(accumulator_inner, acc_max), acc_min
)
if chunk_trun_bits != 0:
accumulator_inner = (accumulator_inner + round_bit) >> chunk_trun_bits
accumulator_inner = accumulator_inner << chunk_trun_bits
@@ -275,8 +288,6 @@

offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
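# --- Illustrative sketch, not part of this diff: the per-chunk MSB clamping and
# LSB rounding that the updated imatmul_kernel applies to each partial sum,
# reproduced in plain PyTorch for reference. The function name and sample values
# are hypothetical; only the clamp/shift arithmetic mirrors the kernel above.
import torch

def emulate_chunk_truncation(partial_sum, chunk_trun_bits=4, max_acc_bits=24):
    """partial_sum: INT32 tensor holding one chunk's a*b (plus running accumulator)."""
    out = partial_sum
    # MSB truncation is simulated by clamping to the reduced accumulator range,
    # e.g. INT24 -> [-2**23, 2**23 - 1]
    if max_acc_bits < 32:
        acc_min = -(1 << (max_acc_bits - 1))
        acc_max = -acc_min - 1
        out = out.clamp(acc_min, acc_max)
    # LSB truncation: round to nearest by adding half an LSB, then zero the low bits
    if chunk_trun_bits != 0:
        round_bit = 1 << (chunk_trun_bits - 1)
        out = ((out + round_bit) >> chunk_trun_bits) << chunk_trun_bits
    return out

x = torch.tensor([12345678, -9876543, 1000], dtype=torch.int32)
print(emulate_chunk_truncation(x))  # saturated to the INT24 range, 4 LSBs cleared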


@@ -300,6 +311,7 @@ def matmul_kernel_DABC(
stride_cm,
stride_cn,
chunk_trun_bits,
max_acc_bits, # pylint: disable=unused-argument
truncate_then_accumulate,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
@@ -421,6 +433,7 @@ def tl_matmul_chunk_truncate(
activation="",
chunk_trun_bits=0,
chunk_size=16,
max_acc_bits=32,
truncate_then_accumulate=True,
cast_output_to_input_dtype=None,
):
@@ -434,6 +447,9 @@
activation (str, optional): activation func to be fused, see relu example.
chunk_trun_bits (int, optional): number of LSBs to be truncated/rounded.
chunk_size (int, optional): BLOCK_SIZE_K, some HW has specific chunk size. must >= 16.
max_acc_bits (int, optional): num of bits for the accumulator, e.g. if INT24 is used, will
clamp each chunk of a*b to [-2**23, 2**23 - 1]
(i.e. overflow saturates at the clamp bounds instead of producing inf).
truncate_then_accumulate (bool, optional): if True, c = truncate(a*b) + c, otherwise
c = truncate(a*b+c)
cast_output_to_input_dtype (bool, optional): accumulator has higher prec than input, usually
@@ -472,9 +488,9 @@ def isPowerofTwo(x):

# because min k (chunk size in this case) for fp16/bf16 is 16, if smaller is needed, we could
# insert 0s in between elements, e.g. pad [m,k] -> [m,2k], [k,n]->[2k,n], out=[m,n] unchanged.
# Do not support INT8 for now.
if chunk_size == 8 and a.dtype in [
torch.float8_e4m3fn,
torch.int8,
torch.float16,
torch.bfloat16,
]:
@@ -515,7 +531,6 @@ def isPowerofTwo(x):
c_org_dtype = c.dtype
c = c.to(acc_dtype)
assert c.shape[0] == M and c.shape[1] == N, "C shape is inconsistent with A B."
assert acc_dtype == torch.float32, "INT truncation is not yet supported."

# 1D launch kernel where each block gets its own program.
def grid(META):
@@ -556,6 +571,7 @@ def grid(META):
c.stride(0),
c.stride(1),
chunk_trun_bits=chunk_trun_bits,
max_acc_bits=max_acc_bits,
truncate_then_accumulate=truncate_then_accumulate,
ACTIVATION=activation,
**kernel_config, # if using auto-tune, comment this line out.
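# --- Illustrative usage sketch, not part of this diff: calling
# tl_matmul_chunk_truncate with the new max_acc_bits argument to emulate an INT24
# accumulator with 4 LSBs truncated per 32-element chunk. The positional (a, b)
# arguments, tensor shapes, and the CUDA device are assumptions for illustration;
# the keyword arguments match the signature shown above.
import torch
from fms_mo.custom_ext_kernels.triton_kernels import tl_matmul_chunk_truncate

a = torch.randint(-128, 128, (256, 512), dtype=torch.int8, device="cuda")
b = torch.randint(-128, 128, (512, 256), dtype=torch.int8, device="cuda")

c = tl_matmul_chunk_truncate(
    a,
    b,
    chunk_trun_bits=4,            # zero out the 4 LSBs of each chunk's partial sum
    chunk_size=32,                # BLOCK_SIZE_K, i.e. elements accumulated per chunk
    max_acc_bits=24,              # clamp each partial sum to the INT24 range
    truncate_then_accumulate=True,
)
print(c.shape)                    # (256, 256), in accumulator precision (INT32 here)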
75 changes: 68 additions & 7 deletions fms_mo/custom_ext_kernels/utils.py
@@ -633,14 +633,17 @@ def exv2_i4f16_fxinputs_abstract(


def imatmul_ops_reg(
useCUTLASS=True, mm_func=torch.matmul, AB_dtype=torch.float, D_dtype=torch.float
useINTkernel="triton",
mm_func=torch.matmul,
AB_dtype=torch.float,
D_dtype=torch.float,
):
"""This function will register a dummy Q_imatmul Op for better "graph representation".
Args:
useCUTLASS: bool. choose to use a) real INT matmul using cutlass kernel or b) "simulated"
imatmul using torch.matmul.
useINTkernel: str|bool. ["cutlass", "triton", False]. choose to use a) real INT matmul, e.g.
cutlass or triton kernel or b) "simulated" imatmul using torch.matmul.
For b), could use D_dtype to select fp16 or fp32 accumulation
mm_func: matmul func to be used when useCUTLASS is True, should be a real callable kernel
mm_func: matmul func to be used when useINTkernel is not False, should be a real callable kernel
from cutlass, but for debug purpose, could use torch.matmul as well.
AB_dtype: datatype for input tensors
D_dtype: datatype for accumulation and output tensor
@@ -697,10 +700,10 @@ def imatmul(m1, m2):
tar_shape = tuple(m1.shape[:-1]) + (m2.shape[1],)
m1 = m1.view(re_shape)

if useCUTLASS:
if useINTkernel in ["triton", "cutlass"]:
assert (
m1.dtype == torch.int8 and m2.dtype == torch.int8
), "When using cutlass int matmul, inputs must be 2D INT8"
), "When using int matmul, inputs must be 2D and INT8."
return mm_func(m1, m2).reshape(tar_shape)

outf32_or_f16 = torch.empty(
@@ -759,7 +762,7 @@ def q_iaddmm_dq(bias, m1, m2, scale_i, zp_i, scale_w):
assert m2.dtype == torch.int8, f"weight tensor is of incorrect dtype {m2.dtype}"
m1 = torch.clamp((m1 / scale_i + zp_i - 128).round(), -128, 127).to(torch.int8)

if useCUTLASS:
if useINTkernel:
mm_i32 = mm_func(m1, m2)
else:
outf32_or_f16 = torch.empty(
@@ -856,6 +859,64 @@ def lower_qmodel_cutlass(
return mod


def lower_qmodel_triton(
model: torch.nn.Module,
use_dyn_max_act=False,
max_acc_bits=32,
num_lsb_to_truncate=0,
chunk_size=32,
):
"""
Example GPU lowering function using triton. Only swaps QLinear layers in transformers, nothing else.
Triton kernel can be used to:
1. test INT8 or FP8 HW performance (kernel is not optimized)
2. simulate MSB/LSB truncation effect

Args:
model: nn.Module. should be an fms_mo Qmodel; layers are swapped in place, no deepcopy
use_dyn_max_act: bool or int. can be False, -1 for per-token, or -2 for per-channel. a
dynamic max quantizer will be used for activations if not False.
max_acc_bits: max bits for the accumulator, typically FP32 for all FP matmuls and INT32 for
all INT matmuls. But some HW could use fewer bits to trade off power
efficiency at the expense of a higher chance of accumulation "overflow".
For example, an INT24 accumulator can only hold values ranging from -2^23 to
2^23 - 1, as opposed to the typical range -2^31 to 2^31 - 1.
num_lsb_to_truncate: number of bits to truncate from the LSB side. For example, given fp32
is s1e8m23, truncating 13 mantissa bits from the rightmost side, i.e. LSB,
yields s1e8m10, which is TF32.
chunk_size: given a matmul of (m, k) @ (k, n), the inner product is "accumulated" along the
k-dim. Since the matrices are partitioned into smaller tiles during computation,
the accumulator only adds a certain number of elements at a time. This "chunk
size" along the k-dim affects the overflow/underflow behavior of the accumulator.
"""
# Third Party
from torch.ao.quantization.utils import _parent_name

# Local
from fms_mo.modules.linear import QLinear, QLinearINT8Deploy

for name, m in model.named_modules():
if not isinstance(m, QLinear):
continue
parent_name, module_name = _parent_name(name)
parent_mod = model.get_submodule(parent_name)
qmod = getattr(parent_mod, module_name)
setattr(
parent_mod,
module_name,
QLinearINT8Deploy.from_fms_mo(
qmod,
use_int_kernel="triton",
use_dynamic_max_act_Qfunc=use_dyn_max_act,
max_acc_bits=max_acc_bits,
truncate_lsb=num_lsb_to_truncate,
chunk_size=chunk_size,
),
)

logger.info(f"\nModel lowering with triton kernel is done.\n{model}")
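# --- Illustrative usage sketch, not part of this diff: lowering a quantized fms_mo
# model with the new triton path to study reduced-precision accumulation. `qmodel`
# is a hypothetical model already prepared via fms_mo qmodel_prep(); the argument
# values are examples only.
from fms_mo.custom_ext_kernels.utils import lower_qmodel_triton

lower_qmodel_triton(
    qmodel,                   # swapped in place: QLinear -> QLinearINT8Deploy
    use_dyn_max_act=-1,       # per-token dynamic max quantizer for activations
    max_acc_bits=24,          # emulate an INT24 accumulator (clamp partial sums)
    num_lsb_to_truncate=4,    # drop 4 LSBs from each chunk's partial sum
    chunk_size=32,            # accumulate 32 elements along the k-dim per chunk
)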


### -------------------------------------------------------------
# GPTQ tensor packing functions for Exllama kernel
### -------------------------------------------------------------
11 changes: 6 additions & 5 deletions fms_mo/dq.py
@@ -38,7 +38,7 @@
from fms_mo import qconfig_init, qmodel_prep
from fms_mo.fx.utils import model_size_Wb
from fms_mo.quant.ptq import (
calibration_llm_1GPU,
calibration_llm_1GPU_v2,
dq_llm,
get_act_scales,
get_act_scales_1gpu,
@@ -232,9 +232,9 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
if qcfg["qmodel_calibration_new"] > 0:
logger.info("Starting to calibrate activation clip_val")
if qcfg["large_model"]:
calibration_llm_1GPU(qcfg, model, dq_dataloader)
calibration_llm_1GPU_v2(qcfg, model, dq_dataloader)
else:
model.to("cuda:0")
model.to("cuda")
pbar = tqdm(
dq_dataloader,
desc=" calibration after applying smoothq scale and before inference",
@@ -263,7 +263,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
test_dataset = load_from_disk(data_args.test_data_path)
test_dataset = test_dataset.with_format("torch")
elif len(pt_files) > 0:
test_dataset = torch.load(pt_files[0])
test_dataset = torch.load(pt_files[0], weights_only=False)

logger.info(f"Model for evaluation: {model}")
if qcfg["large_model"]:
@@ -272,7 +272,8 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
model.to(torch.device("cuda:0"))
n_samples = int(test_dataset.input_ids.shape[1] / block_size)
evaluator = Evaluator(test_dataset, "cuda", n_samples=n_samples)
ppl = evaluator.evaluate(model, block_size=block_size)
with patch_torch_bmm(qcfg):
ppl = evaluator.evaluate(model, block_size=block_size)
logger.info(f"Model perplexity: {ppl}")
logger.info("-" * 50)
logger.info("Finished evaluation")