Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lightx2v/common/modules/weight_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ def is_empty(self):
return len(self._modules) == 0 and len(self._parameters) == 0

def add_module(self, name, module):
if hasattr(self, "config") and hasattr(module, "set_config"):
module.set_config(self.config)
self._modules[name] = module
setattr(self, name, module)

Expand Down
181 changes: 147 additions & 34 deletions lightx2v_platform/ops/mm/hygon_dcu/mm_weight.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import os

import torch
import torch.nn.functional as F

from lightx2v_platform.ops.mm.template import MMWeightQuantTemplate
from lightx2v_platform.registry_factory import PLATFORM_MM_WEIGHT_REGISTER
Expand All @@ -13,6 +16,90 @@
except ImportError:
IntegerQuantizer = None

try:
from lmslim.quantize.quant_ops import hipblaslt_w8a8_channelwise_gemm
except ImportError:
hipblaslt_w8a8_channelwise_gemm = None


def _load_auto_quant_bias(module, weight_dict):
module.bias = None
module.pin_bias = None
if module.bias_name is None or module.bias_name not in weight_dict:
return
bias = weight_dict[module.bias_name]
if module.bias_force_fp32:
bias = bias.to(torch.float32)
elif hasattr(module, "infer_dtype"):
bias = bias.to(module.infer_dtype)
weight = getattr(module, "weight", None)
target_device = weight.device if weight is not None else bias.device
module.bias = bias.to(target_device)


def _make_weight_contiguous(module):
if hasattr(module, "weight") and module.weight is not None:
module.weight = module.weight.contiguous()
if hasattr(module, "weight_scale") and module.weight_scale is not None:
module.weight_scale = module.weight_scale.contiguous()


def _flatten_last_dim(input_tensor):
if input_tensor.dim() == 2:
return input_tensor, None
original_shape = input_tensor.shape
return input_tensor.reshape(-1, original_shape[-1]), original_shape[:-1]


def _restore_last_dim(output_tensor, prefix_shape):
if prefix_shape is None:
return output_tensor
return output_tensor.reshape(*prefix_shape, output_tensor.shape[-1])


def _bias_or_none(module, out_dtype=None):
if hasattr(module, "bias") and module.bias is not None:
bias = module.bias
if out_dtype is not None and bias.dtype != out_dtype:
bias = bias.to(out_dtype)
return bias
return None



def _env_flag(name, default="0"):
return os.getenv(name, default).strip().lower() in {"1", "true", "yes", "on"}


def _env_patterns(name, default=""):
raw = os.getenv(name, default)
return tuple(item.strip() for item in raw.replace(";", ",").split(",") if item.strip())


def _matches_any(name, patterns):
return any(pattern in name for pattern in patterns)


def _use_selective_bf16_fallback(weight_name):
if not _env_flag("LIGHTX2V_INT8_SELECTIVE"):
return False
include = _env_patterns("LIGHTX2V_INT8_SELECTIVE_INCLUDE")
exclude = _env_patterns(
"LIGHTX2V_INT8_SELECTIVE_EXCLUDE",
"txt_branch",
)
if include and not _matches_any(weight_name, include):
return True
return _matches_any(weight_name, exclude)

def _require_hipblaslt_w8a8_channelwise_gemm():
if hipblaslt_w8a8_channelwise_gemm is None:
raise RuntimeError(
"int8-vllm-hygon-dcu requires lmslim.quantize.quant_ops."
"hipblaslt_w8a8_channelwise_gemm on Hygon DCU."
)
return hipblaslt_w8a8_channelwise_gemm


@PLATFORM_MM_WEIGHT_REGISTER("int8-vllm-hygon-dcu")
class MMWeightWint8channelAint8channeldynamicVllmHygonDcu(MMWeightQuantTemplate):
Expand Down Expand Up @@ -41,9 +128,22 @@ def __init__(
self.load_func = self.load_int8_perchannel_sym
self.weight_need_transpose = False
self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
self.use_bf16_fallback = _use_selective_bf16_fallback(weight_name)

def load(self, weight_dict):
super().load(weight_dict)
_make_weight_contiguous(self)

def load_int8_perchannel_sym(self, weight_dict):
"""Load INT8 per-channel symmetric quantized weights."""
if self.use_bf16_fallback:
if not self.config.get("weight_auto_quant", False):
raise RuntimeError("Selective BF16 fallback requires weight_auto_quant=1 so original BF16 weights are available.")
self.weight = weight_dict[self.weight_name].to(self.infer_dtype)
self.weight_scale = None
_load_auto_quant_bias(self, weight_dict)
return

if self.config.get("weight_auto_quant", False):
if IntegerQuantizer is None:
raise ImportError("IntegerQuantizer not available. Please ensure lightx2v.utils.quant_utils is available.")
Expand All @@ -52,6 +152,7 @@ def load_int8_perchannel_sym(self, weight_dict):
self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
self.weight = self.weight.to(torch.int8)
self.weight_scale = self.weight_scale.to(torch.float32)
_load_auto_quant_bias(self, weight_dict)
else:
self.load_quantized(weight_dict)

Expand All @@ -62,39 +163,51 @@ def act_quant_int8_perchannel_sym_vllm(self, x):
input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
return input_tensor_quant, input_tensor_scale

def apply(self, input_tensor):
shape = (input_tensor.shape[0], self.weight.shape[1])
def prepare_quantized_input(self, input_tensor):
if self.use_bf16_fallback:
raise RuntimeError("BF16 fallback weights do not support shared INT8 activation quantization.")
input_2d, prefix_shape = _flatten_last_dim(input_tensor)
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_2d)
return input_2d, prefix_shape, input_tensor_quant, input_tensor_scale

def _apply_quantized_2d(self, input_2d, input_tensor_quant, input_tensor_scale, out_dtype):
m, k = input_tensor_quant.shape
n = self.weight.shape[0]
hipblaslt_gemm = _require_hipblaslt_w8a8_channelwise_gemm()
_, output_tensor = hipblaslt_gemm(
a=input_tensor_quant.contiguous(),
b=self.weight,
scale_a=input_tensor_scale.contiguous(),
scale_b=self.weight_scale,
m=m,
n=n,
k=k,
transpose_flag="NT",
out_dtype=out_dtype,
bias=_bias_or_none(self, out_dtype),
)
output_tensor = output_tensor.reshape(-1, n).narrow(0, 0, input_2d.shape[0])
return output_tensor

def apply_quantized_input(self, input_tensor, quantized_input):
if self.use_bf16_fallback:
return self.apply(input_tensor)
dtype = input_tensor.dtype
device = input_tensor.device

input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)

# Use ops.blaslt_scaled_mm from vllm for ROCm/DCU instead of torch.ops._C.cutlass_scaled_mm
if ops is not None and hasattr(ops, "blaslt_scaled_mm"):
# Ensure out_dtype is bfloat16 or float16 as required by blaslt_scaled_mm
out_dtype = dtype if dtype in (torch.bfloat16, torch.float16) else torch.bfloat16

# Ensure input tensor is contiguous for optimal performance
input_tensor_quant = input_tensor_quant.contiguous()

output_tensor = ops.blaslt_scaled_mm(
input_tensor_quant,
self.weight,
input_tensor_scale,
self.weight_scale,
out_dtype,
self.bias if self.bias is not None else None,
)

# Convert back to original dtype if needed
if output_tensor.dtype != dtype:
output_tensor = output_tensor.to(dtype)
else:
# Fallback: use manual dequantization and matmul
input_dequant = input_tensor_quant.to(dtype) * input_tensor_scale.to(dtype)
weight_dequant = self.weight.to(dtype) * self.weight_scale.to(dtype)
output_tensor = torch.matmul(input_dequant, weight_dequant)
if self.bias is not None:
output_tensor = output_tensor + self.bias
out_dtype = dtype if dtype in (torch.bfloat16, torch.float16) else torch.bfloat16
input_2d, prefix_shape, input_tensor_quant, input_tensor_scale = quantized_input
output_tensor = self._apply_quantized_2d(input_2d, input_tensor_quant, input_tensor_scale, out_dtype)

return output_tensor
if output_tensor.dtype != dtype:
output_tensor = output_tensor.to(dtype)
return _restore_last_dim(output_tensor, prefix_shape)

def _apply_bf16(self, input_tensor):
if self.weight.dtype != input_tensor.dtype:
self.weight = self.weight.to(input_tensor.dtype)
bias = _bias_or_none(self, input_tensor.dtype)
return F.linear(input_tensor, self.weight, bias)

def apply(self, input_tensor):
if self.use_bf16_fallback:
return self._apply_bf16(input_tensor)
return self.apply_quantized_input(input_tensor, self.prepare_quantized_input(input_tensor))
5 changes: 4 additions & 1 deletion lightx2v_platform/ops/mm/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,10 @@ def __init__(
# weight load functions
# =========================
def load(self, weight_dict):
self.load_quantized(weight_dict)
if self.load_func is None:
self.load_quantized(weight_dict)
else:
self.load_func(weight_dict)
if self.weight_need_transpose:
if hasattr(self, "weight") and self.weight is not None:
self.weight = self.weight.t()
Expand Down