diff --git a/fms_mo/aiu_addons/gptq/gptq_aiu_op.py b/fms_mo/aiu_addons/gptq/gptq_aiu_op.py
index d958f38e..b9199f6b 100644
--- a/fms_mo/aiu_addons/gptq/gptq_aiu_op.py
+++ b/fms_mo/aiu_addons/gptq/gptq_aiu_op.py
@@ -17,6 +17,7 @@
 import logging

 # Third Party
+from packaging.version import Version
 import torch

 # pylint: disable=unused-argument
@@ -25,6 +26,36 @@
 logger = logging.getLogger(__name__)


+def implement_op_decorator(op_namespace_id):
+    """Version-dependent decorator for custom op implementation.
+    Always compare against the PyTorch version in the current environment.
+    """
+
+    torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
+
+    def decorator(func):
+        if torch_version < Version("2.4"):
+            return torch.library.impl(op_namespace_id, "default")(func)
+        return torch.library.custom_op(op_namespace_id, mutates_args=())(func)
+
+    return decorator
+
+
+def register_op_decorator(op_namespace_id):
+    """Version-dependent decorator for custom op registration.
+    Always compare against the PyTorch version in the current environment.
+    """
+
+    torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
+
+    def decorator(func):
+        if torch_version < Version("2.4"):
+            return torch.library.impl_abstract(op_namespace_id)(func)
+        return torch.library.register_fake(op_namespace_id)(func)
+
+    return decorator
+
+
 def register_aiu_gptq_op():
     """Register AIU-specific op to enable torch compile without graph break.
     The op preserves I/O shapes of a `X @ W^T` matmul but performs no operation.
@@ -36,17 +67,33 @@ def register_aiu_gptq_op():
     ):
         logger.warning("AIU op has already been registered")
         return
     op_namespace_id = "gptq_gemm::i4f16_fxinputs_aiu"
-    torch.library.define(
-        op_namespace_id,
-        "(Tensor x, Tensor qw, Tensor qzeros, Tensor scales, Tensor g_idx) -> Tensor",
-    )
+    if Version(torch.__version__.split("+", maxsplit=1)[0]) < Version("2.4"):
+        torch.library.define(
+            op_namespace_id,
+            "(Tensor x, Tensor qw, Tensor qzeros, "
+            "Tensor scales, Tensor g_idx) -> Tensor",
+        )

     # Add implementations for the operator
-    @torch.library.impl(op_namespace_id, "default")
-    def i4f16_fxinputs_aiu(x, qw, qzeros, scales, g_idx):
-        # on AIU, GPTQ qw is [out_feat, in_feat]
+    @implement_op_decorator(op_namespace_id)
+    def i4f16_fxinputs_aiu(
+        x: torch.Tensor,
+        qw: torch.Tensor,
+        qzeros: torch.Tensor,
+        scales: torch.Tensor,
+        g_idx: torch.Tensor,
+    ) -> torch.Tensor:
+        """Implement fake processing of GPTQ W4A16 matmul. The purpose is to create
+        a node on the computational graph to be captured when compiling for AIU.
+
+        Instead of computing the weight decompression and matmul, this function
+        returns a zero tensor with the expected shape.
+
+        NOTE: on AIU, GPTQ qw is [out_feat, in_feat], while AutoGPTQ saves the
+        quantized weights as [in_feat, out_feat].
+        """
+
         outshape = x.shape[:-1] + (qw.shape[0],)
         x = x.view(-1, x.shape[-1])
         output = torch.zeros(
@@ -56,8 +103,10 @@
         )
         return output.view(outshape)

-    @torch.library.impl_abstract(op_namespace_id)
-    def i4f16_fxinputs_aiu_abstract(x, qw, qzeros, scales, g_idx):
+    @register_op_decorator(op_namespace_id)
+    def _(x, qw, qzeros, scales, g_idx):
+        """OP template of I/O sizes"""
+
         outshape = x.shape[:-1] + (qw.shape[0],)
         return torch.empty(
             outshape,
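The decorator pair above bridges two torch.library APIs: before PyTorch 2.4, a custom op is declared with an explicit `torch.library.define` schema plus `torch.library.impl`, and its shape function with `torch.library.impl_abstract`; from 2.4 on, `torch.library.custom_op` declares and implements the op in one step, and `torch.library.register_fake` replaces `impl_abstract`. A minimal standalone sketch of the same dispatch, using a hypothetical `mylib::scale` op that is not part of fms_mo:

# Sketch of the version-dependent registration pattern used above.
# "mylib::scale" is a hypothetical namespace for illustration only.
from packaging.version import Version
import torch

TORCH_VERSION = Version(torch.__version__.split("+", maxsplit=1)[0])
OP_ID = "mylib::scale"

if TORCH_VERSION < Version("2.4"):
    # pre-2.4: the schema must be declared before any implementation
    torch.library.define(OP_ID, "(Tensor x, float s) -> Tensor")

def scale_impl(x: torch.Tensor, s: float) -> torch.Tensor:
    # annotations matter on >= 2.4: custom_op infers the op schema from them
    return x * s

def scale_fake(x: torch.Tensor, s: float) -> torch.Tensor:
    # shape/dtype template used when tracing with FakeTensors
    return torch.empty_like(x)

if TORCH_VERSION < Version("2.4"):
    torch.library.impl(OP_ID, "default")(scale_impl)
    torch.library.impl_abstract(OP_ID)(scale_fake)
else:
    torch.library.custom_op(OP_ID, mutates_args=())(scale_impl)
    torch.library.register_fake(OP_ID)(scale_fake)

print(torch.ops.mylib.scale(torch.ones(2, 2), 3.0))

Factoring the version check into `implement_op_decorator`/`register_op_decorator` keeps each op registration down to a single decorator line instead of repeating this branching at every registration site.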
diff --git a/fms_mo/aiu_addons/i8i8/i8i8_aiu_op.py b/fms_mo/aiu_addons/i8i8/i8i8_aiu_op.py
index 0d38836c..41aa896f 100644
--- a/fms_mo/aiu_addons/i8i8/i8i8_aiu_op.py
+++ b/fms_mo/aiu_addons/i8i8/i8i8_aiu_op.py
@@ -17,11 +17,10 @@
 import logging

 # Third Party
+from packaging.version import Version
 import torch
 import torch.nn.functional as F

-logger = logging.getLogger(__name__)
-
 # pylint: disable=unused-argument
 # i8i8 op must be registered with specific I/O, even if not in use by the op function
@@ -29,6 +28,38 @@
 # torch.nn.functional.linear not recognized as callable
 # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482

+logger = logging.getLogger(__name__)
+
+
+def implement_op_decorator(op_namespace_id):
+    """Version-dependent decorator for custom op implementation.
+    Always compare against the PyTorch version in the current environment.
+    """
+
+    torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
+
+    def decorator(func):
+        if torch_version < Version("2.4"):
+            return torch.library.impl(op_namespace_id, "default")(func)
+        return torch.library.custom_op(op_namespace_id, mutates_args=())(func)
+
+    return decorator
+
+
+def register_op_decorator(op_namespace_id):
+    """Version-dependent decorator for custom op registration.
+    Always compare against the PyTorch version in the current environment.
+    """
+
+    torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
+
+    def decorator(func):
+        if torch_version < Version("2.4"):
+            return torch.library.impl_abstract(op_namespace_id)(func)
+        return torch.library.register_fake(op_namespace_id)(func)
+
+    return decorator
+

 def register_aiu_i8i8_op():
     """Register AIU-specific op to enable torch compile without graph break.
@@ -41,26 +72,26 @@
     if hasattr(torch.ops, "fms_mo") and hasattr(torch.ops.fms_mo, "i8i8_aiu"):
         logger.warning("AIU op has already been registered")
         return
     op_namespace_id = "fms_mo::i8i8_aiu"
-    torch.library.define(
-        op_namespace_id,
-        "(Tensor x, Tensor weight, Tensor bias, Tensor qdata, "
-        "str weight_quant_type, str activ_quant_type, "
-        "bool smoothquant) "
-        "-> Tensor",
-    )
+    if Version(torch.__version__.split("+", maxsplit=1)[0]) < Version("2.4"):
+        torch.library.define(
+            op_namespace_id,
+            "(Tensor x, Tensor weight, Tensor bias, Tensor qdata, "
+            "str weight_quant_type, str activ_quant_type, "
+            "bool smoothquant) "
+            "-> Tensor",
+        )

-    @torch.library.impl(op_namespace_id, "default")
+    @implement_op_decorator(op_namespace_id)
     def i8i8_aiu(
-        x,
-        weight,
-        bias,
-        qdata,
-        weight_quant_type,
-        activ_quant_type,
-        smoothquant,
-    ):
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        bias: torch.Tensor,
+        qdata: torch.Tensor,
+        weight_quant_type: str,
+        activ_quant_type: str,
+        smoothquant: bool,
+    ) -> torch.Tensor:
         """Implement addmm of X and W.
         Support various quantization options for weights and activations.
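The typed signature on `i8i8_aiu` above is not cosmetic: on PyTorch >= 2.4 the schema string formerly passed to `torch.library.define` is derived from these annotations by `torch.library.custom_op`, so the `str` and `bool` arguments must be spelled out in the Python signature. A small sketch of the inference, assuming torch >= 2.4 and a hypothetical `mylib::q_linear` namespace:

# Sketch (torch >= 2.4 only); "mylib::q_linear" is hypothetical.
import torch
import torch.nn.functional as F

@torch.library.custom_op("mylib::q_linear", mutates_args=())
def q_linear(
    x: torch.Tensor, weight: torch.Tensor, tag: str, smooth: bool
) -> torch.Tensor:
    # the annotations above yield the inferred schema
    # "(Tensor x, Tensor weight, str tag, bool smooth) -> Tensor",
    # i.e. what define() spelled out by hand on the pre-2.4 path
    return F.linear(x, weight)

out = torch.ops.mylib.q_linear(torch.randn(2, 4), torch.randn(3, 4), "w8a8", True)
print(out.shape)  # torch.Size([2, 3])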
@@ -86,16 +117,8 @@ def i8i8_aiu(

         return F.linear(x_dq.to(dtype), w_dq.to(dtype), bias.to(dtype))

-    @torch.library.impl_abstract(op_namespace_id)
-    def i8i8_aiu_abstract(
-        x,
-        weight,
-        bias,
-        qdata,
-        weight_quant_type,
-        activ_quant_type,
-        smoothquant,
-    ):
+    @register_op_decorator(op_namespace_id)
+    def _(x, weight, bias, qdata, weight_quant_type, activ_quant_type, smoothquant):
         """OP template of I/O sizes"""

         outshape = x.size()[:-1] + (weight.size(0),)
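In both files, the function registered through `register_op_decorator` is what actually runs while `torch.compile` traces the graph: FakeTensor propagation executes only the registered shape template, never a real kernel, which is how the op avoids a graph break. A sketch of how that can be checked for the GPTQ op; the import path follows the file above, the tensor shapes are illustrative placeholders, and only `x` and `qw` affect the template's output:

# Sketch: the fake/abstract registration propagates shapes without any compute.
# FakeTensorMode is an internal torch API, used here only for illustration.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

from fms_mo.aiu_addons.gptq.gptq_aiu_op import register_aiu_gptq_op

register_aiu_gptq_op()

with FakeTensorMode():
    x = torch.empty(2, 16, 4096, dtype=torch.float16)
    qw = torch.empty(11008, 512, dtype=torch.int32)  # [out_feat, packed in_feat]
    qzeros = torch.empty(32, 1376, dtype=torch.int32)
    scales = torch.empty(32, 11008, dtype=torch.float16)
    g_idx = torch.empty(4096, dtype=torch.int32)
    out = torch.ops.gptq_gemm.i4f16_fxinputs_aiu(x, qw, qzeros, scales, g_idx)
    print(out.shape)  # torch.Size([2, 16, 11008]) == x.shape[:-1] + (qw.shape[0],)

This is the property `register_aiu_gptq_op` and `register_aiu_i8i8_op` exist to provide: the compiler can capture the node for the AIU backend without executing a matmul on the host.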