fms_mo/aiu_addons/gptq/gptq_aiu_op.py (69 changes: 59 additions & 10 deletions)
@@ -17,6 +17,7 @@
 import logging
 
 # Third Party
+from packaging.version import Version
 import torch
 
 # pylint: disable=unused-argument
@@ -25,6 +26,36 @@
 logger = logging.getLogger(__name__)
 
 
+def implement_op_decorator(op_namespace_id):
+    """Version-dependent decorator for custom op implementation.
+    Always compares against the PyTorch version in the current environment.
+    """
+
+    torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
+
+    def decorator(func):
+        if torch_version < Version("2.4"):
+            return torch.library.impl(op_namespace_id, "default")(func)
+        return torch.library.custom_op(op_namespace_id, mutates_args=())(func)
+
+    return decorator
+
+
+def register_op_decorator(op_namespace_id):
+    """Version-dependent decorator for custom op registration.
+    Always compares against the PyTorch version in the current environment.
+    """
+
+    torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
+
+    def decorator(func):
+        if torch_version < Version("2.4"):
+            return torch.library.impl_abstract(op_namespace_id)(func)
+        return torch.library.register_fake(op_namespace_id)(func)
+
+    return decorator
+
+
 def register_aiu_gptq_op():
     """Register AIU-specific op to enable torch compile without graph break.
     The op preserves I/O shapes of an `X @ W^T` matmul but performs no operation.
@@ -36,17 +67,33 @@ def register_aiu_gptq_op():
     ):
         logger.warning("AIU op has already been registered")
         return
 
     op_namespace_id = "gptq_gemm::i4f16_fxinputs_aiu"
-    torch.library.define(
-        op_namespace_id,
-        "(Tensor x, Tensor qw, Tensor qzeros, Tensor scales, Tensor g_idx) -> Tensor",
-    )
+    if Version(torch.__version__.split("+", maxsplit=1)[0]) < Version("2.4"):
+        torch.library.define(
+            op_namespace_id,
+            "(Tensor x, Tensor qw, Tensor qzeros, "
+            "Tensor scales, Tensor g_idx) -> Tensor",
+        )
 
-    # Add implementations for the operator
-    @torch.library.impl(op_namespace_id, "default")
-    def i4f16_fxinputs_aiu(x, qw, qzeros, scales, g_idx):
-        # on AIU, GPTQ qw is [out_feat, in_feat]
+    @implement_op_decorator(op_namespace_id)
+    def i4f16_fxinputs_aiu(
+        x: torch.Tensor,
+        qw: torch.Tensor,
+        qzeros: torch.Tensor,
+        scales: torch.Tensor,
+        g_idx: torch.Tensor,
+    ) -> torch.Tensor:
+        """Implement fake processing of GPTQ W4A16 matmul. The purpose is to create a
+        node on the computational graph to be captured during compilation for AIU.
+
+        Instead of computing the weight decompression and matmul, this function returns
+        a zero tensor with the expected shape.
+
+        NOTE: on AIU, GPTQ qw is [out_feat, in_feat], while AutoGPTQ saves the quantized
+        weights as [in_feat, out_feat]
+        """
+
         outshape = x.shape[:-1] + (qw.shape[0],)
         x = x.view(-1, x.shape[-1])
         output = torch.zeros(
@@ -56,8 +103,10 @@ def i4f16_fxinputs_aiu(x, qw, qzeros, scales, g_idx):
         )
         return output.view(outshape)
 
-    @torch.library.impl_abstract(op_namespace_id)
-    def i4f16_fxinputs_aiu_abstract(x, qw, qzeros, scales, g_idx):
+    @register_op_decorator(op_namespace_id)
+    def _(x, qw, qzeros, scales, g_idx):
+        """OP template of I/O sizes"""
+
         outshape = x.shape[:-1] + (qw.shape[0],)
         return torch.empty(
             outshape,
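For context, a minimal usage sketch (not part of the diff): once `register_aiu_gptq_op()` has run, the op is reachable under `torch.ops`, and per the implementation above it returns zeros of shape `x.shape[:-1] + (qw.shape[0],)` rather than a real matmul result. All shapes, dtypes, and packing factors below are placeholder assumptions.

# Hypothetical usage sketch; shapes/dtypes are assumptions, not from the PR.
import torch

from fms_mo.aiu_addons.gptq.gptq_aiu_op import register_aiu_gptq_op

register_aiu_gptq_op()

batch, in_feat, out_feat = 2, 64, 128
x = torch.randn(batch, in_feat, dtype=torch.float16)
# qw follows the AIU layout, [out_feat, in_feat] (packed); only qw.shape[0]
# matters to the fake op, so the int4 packing factor here is a guess.
qw = torch.zeros(out_feat, in_feat // 8, dtype=torch.int32)
qzeros = torch.zeros(1, out_feat // 8, dtype=torch.int32)
scales = torch.ones(1, out_feat, dtype=torch.float16)
g_idx = torch.zeros(in_feat, dtype=torch.int32)

out = torch.ops.gptq_gemm.i4f16_fxinputs_aiu(x, qw, qzeros, scales, g_idx)
print(out.shape)  # torch.Size([2, 128]): all zeros, a graph placeholder only

Because the registered body only allocates zeros, the call is cheap on any device; the real GEMM is expected to be supplied by the AIU backend that consumes the compiled graph.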
fms_mo/aiu_addons/i8i8/i8i8_aiu_op.py (81 changes: 52 additions & 29 deletions)
@@ -17,18 +17,49 @@
 import logging
 
 # Third Party
+from packaging.version import Version
 import torch
 import torch.nn.functional as F
 
-logger = logging.getLogger(__name__)
-
 # pylint: disable=unused-argument
 # i8i8 op must be registered with specific I/O, even if not in use by the op function
 
 # pylint: disable=not-callable
 # torch.nn.functional.linear not recognized as callable
 # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482
 
+logger = logging.getLogger(__name__)
+
 
+def implement_op_decorator(op_namespace_id):
+    """Version-dependent decorator for custom op implementation.
+    Always compares against the PyTorch version in the current environment.
+    """
+
+    torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
+
+    def decorator(func):
+        if torch_version < Version("2.4"):
+            return torch.library.impl(op_namespace_id, "default")(func)
+        return torch.library.custom_op(op_namespace_id, mutates_args=())(func)
+
+    return decorator
+
+
+def register_op_decorator(op_namespace_id):
+    """Version-dependent decorator for custom op registration.
+    Always compares against the PyTorch version in the current environment.
+    """
+
+    torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
+
+    def decorator(func):
+        if torch_version < Version("2.4"):
+            return torch.library.impl_abstract(op_namespace_id)(func)
+        return torch.library.register_fake(op_namespace_id)(func)
+
+    return decorator
+
+
 def register_aiu_i8i8_op():
     """Register AIU-specific op to enable torch compile without graph break.
@@ -41,26 +72,26 @@ def register_aiu_i8i8_op():
     if hasattr(torch.ops, "fms_mo") and hasattr(torch.ops.fms_mo, "i8i8_aiu"):
         logger.warning("AIU op has already been registered")
         return
-
     op_namespace_id = "fms_mo::i8i8_aiu"
-    torch.library.define(
-        op_namespace_id,
-        "(Tensor x, Tensor weight, Tensor bias, Tensor qdata, "
-        "str weight_quant_type, str activ_quant_type, "
-        "bool smoothquant) "
-        "-> Tensor",
-    )
+    if Version(torch.__version__.split("+", maxsplit=1)[0]) < Version("2.4"):
+        torch.library.define(
+            op_namespace_id,
+            "(Tensor x, Tensor weight, Tensor bias, Tensor qdata, "
+            "str weight_quant_type, str activ_quant_type, "
+            "bool smoothquant) "
+            "-> Tensor",
+        )
 
-    @torch.library.impl(op_namespace_id, "default")
+    @implement_op_decorator(op_namespace_id)
     def i8i8_aiu(
-        x,
-        weight,
-        bias,
-        qdata,
-        weight_quant_type,
-        activ_quant_type,
-        smoothquant,
-    ):
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        bias: torch.Tensor,
+        qdata: torch.Tensor,
+        weight_quant_type: str,
+        activ_quant_type: str,
+        smoothquant: bool,
+    ) -> torch.Tensor:
         """Implement addmm of X and W.
         Support various quantization options for weights and activations.
 
@@ -86,16 +117,8 @@
 
         return F.linear(x_dq.to(dtype), w_dq.to(dtype), bias.to(dtype))
 
-    @torch.library.impl_abstract(op_namespace_id)
-    def i8i8_aiu_abstract(
-        x,
-        weight,
-        bias,
-        qdata,
-        weight_quant_type,
-        activ_quant_type,
-        smoothquant,
-    ):
+    @register_op_decorator(op_namespace_id)
+    def _(x, weight, bias, qdata, weight_quant_type, activ_quant_type, smoothquant):
         """OP template of I/O sizes"""
 
         outshape = x.size()[:-1] + (weight.size(0),)
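Both files carry the same version gate, so a condensed standalone sketch may help review. Assumptions: the legacy branch requires PyTorch >= 2.1 (where `torch.library.impl_abstract` exists), and `demo::noop` is a hypothetical op name used only for illustration.

# Condensed sketch of the two registration paths gated on torch 2.4;
# "demo::noop" is a hypothetical op, not part of fms_mo.
from packaging.version import Version
import torch

TORCH_VERSION = Version(torch.__version__.split("+", maxsplit=1)[0])

if TORCH_VERSION < Version("2.4"):
    # Legacy path: declare the schema explicitly, then attach an eager
    # implementation and an abstract (fake) function for shape inference.
    torch.library.define("demo::noop", "(Tensor x) -> Tensor")

    @torch.library.impl("demo::noop", "default")
    def noop(x):
        return torch.zeros_like(x)

    @torch.library.impl_abstract("demo::noop")
    def noop_abstract(x):
        return torch.empty_like(x)
else:
    # 2.4+ path: torch.library.custom_op infers the schema from the type
    # annotations, so no separate torch.library.define call is needed.
    @torch.library.custom_op("demo::noop", mutates_args=())
    def noop(x: torch.Tensor) -> torch.Tensor:
        return torch.zeros_like(x)

    @torch.library.register_fake("demo::noop")
    def _(x):
        return torch.empty_like(x)

# Either way the op is visible under torch.ops and can be traced by
# torch.compile without a graph break.
out = torch.ops.demo.noop(torch.randn(2, 3))

This is presumably why the new code only calls `torch.library.define` below 2.4: `custom_op` registers the schema itself, and a duplicate `define` for the same name would conflict.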