From 197934a3862bb2a118c53124eb7f5f0a1b170015 Mon Sep 17 00:00:00 2001 From: zhenggf Date: Thu, 25 Jun 2026 17:56:41 +0800 Subject: [PATCH 1/4] Enable-Hygon-DCU-INT8-hipBLASLt-GEMM (cherry picked from commit a3a1a1f870b768929d8ca073f0c74added572087) --- lightx2v/common/modules/weight_module.py | 2 + .../ops/mm/hygon_dcu/mm_weight.py | 118 +++++++++++++----- lightx2v_platform/ops/mm/template.py | 5 +- 3 files changed, 90 insertions(+), 35 deletions(-) diff --git a/lightx2v/common/modules/weight_module.py b/lightx2v/common/modules/weight_module.py index 8c07bd267..66b4b9598 100755 --- a/lightx2v/common/modules/weight_module.py +++ b/lightx2v/common/modules/weight_module.py @@ -10,6 +10,8 @@ def is_empty(self): return len(self._modules) == 0 and len(self._parameters) == 0 def add_module(self, name, module): + if hasattr(self, "config") and hasattr(module, "set_config"): + module.set_config(self.config) self._modules[name] = module setattr(self, name, module) diff --git a/lightx2v_platform/ops/mm/hygon_dcu/mm_weight.py b/lightx2v_platform/ops/mm/hygon_dcu/mm_weight.py index dbe2036b4..ca97b08b6 100644 --- a/lightx2v_platform/ops/mm/hygon_dcu/mm_weight.py +++ b/lightx2v_platform/ops/mm/hygon_dcu/mm_weight.py @@ -13,6 +13,62 @@ except ImportError: IntegerQuantizer = None +try: + from lmslim.quantize.quant_ops import hipblaslt_w8a8_channelwise_gemm +except ImportError: + hipblaslt_w8a8_channelwise_gemm = None + + +def _load_auto_quant_bias(module, weight_dict): + module.bias = None + module.pin_bias = None + if module.bias_name is None or module.bias_name not in weight_dict: + return + bias = weight_dict[module.bias_name] + if module.bias_force_fp32: + bias = bias.to(torch.float32) + elif hasattr(module, "infer_dtype"): + bias = bias.to(module.infer_dtype) + module.bias = bias.to(module.weight.device) + + +def _make_weight_contiguous(module): + if hasattr(module, "weight") and module.weight is not None: + module.weight = module.weight.contiguous() + if hasattr(module, "weight_scale") and module.weight_scale is not None: + module.weight_scale = module.weight_scale.contiguous() + + +def _flatten_last_dim(input_tensor): + if input_tensor.dim() == 2: + return input_tensor, None + original_shape = input_tensor.shape + return input_tensor.reshape(-1, original_shape[-1]), original_shape[:-1] + + +def _restore_last_dim(output_tensor, prefix_shape): + if prefix_shape is None: + return output_tensor + return output_tensor.reshape(*prefix_shape, output_tensor.shape[-1]) + + +def _bias_or_none(module, out_dtype=None): + if hasattr(module, "bias") and module.bias is not None: + bias = module.bias + if out_dtype is not None and bias.dtype != out_dtype: + bias = bias.to(out_dtype) + return bias + return None + + +def _require_hipblaslt_w8a8_channelwise_gemm(): + if hipblaslt_w8a8_channelwise_gemm is None: + raise RuntimeError( + "int8-vllm-hygon-dcu requires lmslim.quantize.quant_ops." + "hipblaslt_w8a8_channelwise_gemm on Hygon DCU." + ) + return hipblaslt_w8a8_channelwise_gemm + @PLATFORM_MM_WEIGHT_REGISTER("int8-vllm-hygon-dcu") class MMWeightWint8channelAint8channeldynamicVllmHygonDcu(MMWeightQuantTemplate): @@ -42,6 +98,10 @@ def __init__( self.weight_need_transpose = False self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm + def load(self, weight_dict): + super().load(weight_dict) + _make_weight_contiguous(self) + def load_int8_perchannel_sym(self, weight_dict): """Load INT8 per-channel symmetric quantized weights.""" if self.config.get("weight_auto_quant", False): @@ -52,6 +112,7 @@ def load_int8_perchannel_sym(self, weight_dict): self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight) self.weight = self.weight.to(torch.int8) self.weight_scale = self.weight_scale.to(torch.float32) + _load_auto_quant_bias(self, weight_dict) else: self.load_quantized(weight_dict) @@ -63,38 +124,27 @@ def act_quant_int8_perchannel_sym_vllm(self, x): return input_tensor_quant, input_tensor_scale def apply(self, input_tensor): - shape = (input_tensor.shape[0], self.weight.shape[1]) dtype = input_tensor.dtype - device = input_tensor.device - - input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor) - - # Use ops.blaslt_scaled_mm from vllm for ROCm/DCU instead of torch.ops._C.cutlass_scaled_mm - if ops is not None and hasattr(ops, "blaslt_scaled_mm"): - # Ensure out_dtype is bfloat16 or float16 as required by blaslt_scaled_mm - out_dtype = dtype if dtype in (torch.bfloat16, torch.float16) else torch.bfloat16 - - # Ensure input tensor is contiguous for optimal performance - input_tensor_quant = input_tensor_quant.contiguous() - - output_tensor = ops.blaslt_scaled_mm( - input_tensor_quant, - self.weight, - input_tensor_scale, - self.weight_scale, - out_dtype, - self.bias if self.bias is not None else None, - ) - - # Convert back to original dtype if needed - if output_tensor.dtype != dtype: - output_tensor = output_tensor.to(dtype) - else: - # Fallback: use manual dequantization and matmul - input_dequant = input_tensor_quant.to(dtype) * input_tensor_scale.to(dtype) - weight_dequant = self.weight.to(dtype) * self.weight_scale.to(dtype) - output_tensor = torch.matmul(input_dequant, weight_dequant) - if self.bias is not None: - output_tensor = output_tensor + self.bias - - return output_tensor + input_2d, prefix_shape = _flatten_last_dim(input_tensor) + input_tensor_quant, input_tensor_scale = self.act_quant_func(input_2d) + out_dtype = dtype if dtype in (torch.bfloat16, torch.float16) else torch.bfloat16 + m, k = input_tensor_quant.shape + n = self.weight.shape[0] + hipblaslt_gemm = _require_hipblaslt_w8a8_channelwise_gemm() + _, output_tensor = hipblaslt_gemm( + a=input_tensor_quant.contiguous(), + b=self.weight.contiguous(), + scale_a=input_tensor_scale.contiguous(), + scale_b=self.weight_scale.contiguous(), + m=m, + n=n, + k=k, + transpose_flag="NT", + out_dtype=out_dtype, + bias=_bias_or_none(self, out_dtype), + ) + output_tensor = output_tensor.reshape(-1, n).narrow(0, 0, m) + + if output_tensor.dtype != dtype: + output_tensor = output_tensor.to(dtype) + return _restore_last_dim(output_tensor, prefix_shape) diff --git a/lightx2v_platform/ops/mm/template.py b/lightx2v_platform/ops/mm/template.py index 28f519aa2..033663e5e 100644 --- a/lightx2v_platform/ops/mm/template.py +++ b/lightx2v_platform/ops/mm/template.py @@ -70,7 +70,10 @@ def __init__( # weight load functions # ========================= def load(self, weight_dict): - self.load_quantized(weight_dict) + if self.load_func is None: + self.load_quantized(weight_dict) + else: + self.load_func(weight_dict) if self.weight_need_transpose: if hasattr(self, "weight") and self.weight is not None: self.weight = self.weight.t() From 0cc460505960e7a273c33a86692bdf0fcd70882b Mon Sep 17 00:00:00 2001 From: zhenggf Date: Mon, 29 Jun 2026 09:32:09 +0800 Subject: [PATCH 2/4] hygon-dcu: add shared int8 activation helpers Add reusable quantized-input helpers for Hygon DCU W8A8 dynamic activation GEMMs, and support selective BF16 fallback for configured INT8 weights. (cherry picked from commit 58dab25b69c41c6ec9a24df0fe584ca93534eacc) --- .../ops/mm/hygon_dcu/mm_weight.py | 70 +++++++++++++++++-- 1 file changed, 66 insertions(+), 4 deletions(-) diff --git a/lightx2v_platform/ops/mm/hygon_dcu/mm_weight.py b/lightx2v_platform/ops/mm/hygon_dcu/mm_weight.py index ca97b08b6..209d0ee2e 100644 --- a/lightx2v_platform/ops/mm/hygon_dcu/mm_weight.py +++ b/lightx2v_platform/ops/mm/hygon_dcu/mm_weight.py @@ -1,4 +1,7 @@ +import os + import torch +import torch.nn.functional as F from lightx2v_platform.ops.mm.template import MMWeightQuantTemplate from lightx2v_platform.registry_factory import PLATFORM_MM_WEIGHT_REGISTER @@ -61,6 +64,32 @@ def _bias_or_none(module, out_dtype=None): return None + +def _env_flag(name, default="0"): + return os.getenv(name, default).strip().lower() in {"1", "true", "yes", "on"} + + +def _env_patterns(name, default=""): + raw = os.getenv(name, default) + return tuple(item.strip() for item in raw.replace(";", ",").split(",") if item.strip()) + + +def _matches_any(name, patterns): + return any(pattern in name for pattern in patterns) + + +def _use_selective_bf16_fallback(weight_name): + if not _env_flag("LIGHTX2V_INT8_SELECTIVE"): + return False + include = _env_patterns("LIGHTX2V_INT8_SELECTIVE_INCLUDE") + exclude = _env_patterns( + "LIGHTX2V_INT8_SELECTIVE_EXCLUDE", + "txt_branch", + ) + if include and not _matches_any(weight_name, include): + return True + return _matches_any(weight_name, exclude) + def _require_hipblaslt_w8a8_channelwise_gemm(): if hipblaslt_w8a8_channelwise_gemm is None: raise RuntimeError( @@ -97,6 +126,7 @@ def __init__( self.load_func = self.load_int8_perchannel_sym self.weight_need_transpose = False self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm + self.use_bf16_fallback = _use_selective_bf16_fallback(weight_name) def load(self, weight_dict): super().load(weight_dict) @@ -104,6 +134,14 @@ def load(self, weight_dict): def load_int8_perchannel_sym(self, weight_dict): """Load INT8 per-channel symmetric quantized weights.""" + if self.use_bf16_fallback: + if not self.config.get("weight_auto_quant", False): + raise RuntimeError("Selective BF16 fallback requires weight_auto_quant=1 so original BF16 weights are available.") + self.weight = weight_dict[self.weight_name].to(self.infer_dtype) + self.weight_scale = None + _load_auto_quant_bias(self, weight_dict) + return + if self.config.get("weight_auto_quant", False): if IntegerQuantizer is None: raise ImportError("IntegerQuantizer not available. Please ensure lightx2v.utils.quant_utils is available.") @@ -123,11 +161,14 @@ def act_quant_int8_perchannel_sym_vllm(self, x): input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True) return input_tensor_quant, input_tensor_scale - def apply(self, input_tensor): - dtype = input_tensor.dtype + def prepare_quantized_input(self, input_tensor): + if self.use_bf16_fallback: + raise RuntimeError("BF16 fallback weights do not support shared INT8 activation quantization.") input_2d, prefix_shape = _flatten_last_dim(input_tensor) input_tensor_quant, input_tensor_scale = self.act_quant_func(input_2d) - out_dtype = dtype if dtype in (torch.bfloat16, torch.float16) else torch.bfloat16 + return input_2d, prefix_shape, input_tensor_quant, input_tensor_scale + + def _apply_quantized_2d(self, input_2d, input_tensor_quant, input_tensor_scale, out_dtype): m, k = input_tensor_quant.shape n = self.weight.shape[0] hipblaslt_gemm = _require_hipblaslt_w8a8_channelwise_gemm() @@ -143,8 +184,29 @@ def apply(self, input_tensor): out_dtype=out_dtype, bias=_bias_or_none(self, out_dtype), ) - output_tensor = output_tensor.reshape(-1, n).narrow(0, 0, m) + output_tensor = output_tensor.reshape(-1, n).narrow(0, 0, input_2d.shape[0]) + return output_tensor + + def apply_quantized_input(self, input_tensor, quantized_input): + if self.use_bf16_fallback: + return self.apply(input_tensor) + dtype = input_tensor.dtype + out_dtype = dtype if dtype in (torch.bfloat16, torch.float16) else torch.bfloat16 + input_2d, prefix_shape, input_tensor_quant, input_tensor_scale = quantized_input + output_tensor = self._apply_quantized_2d(input_2d, input_tensor_quant, input_tensor_scale, out_dtype) if output_tensor.dtype != dtype: output_tensor = output_tensor.to(dtype) return _restore_last_dim(output_tensor, prefix_shape) + + def _apply_bf16(self, input_tensor): + weight = self.weight + if weight.dtype != input_tensor.dtype: + weight = weight.to(input_tensor.dtype) + bias = _bias_or_none(self, input_tensor.dtype) + return F.linear(input_tensor, weight, bias) + + def apply(self, input_tensor): + if self.use_bf16_fallback: + return self._apply_bf16(input_tensor) + return self.apply_quantized_input(input_tensor, self.prepare_quantized_input(input_tensor)) From 1213895e0d2c5087f706c0f3af8194a1fa62decf Mon Sep 17 00:00:00 2001 From: zhenggf Date: Tue, 30 Jun 2026 14:31:18 +0800 Subject: [PATCH 3/4] fix: tighten Hygon INT8 GEMM fallback handling --- lightx2v_platform/ops/mm/hygon_dcu/mm_weight.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/lightx2v_platform/ops/mm/hygon_dcu/mm_weight.py b/lightx2v_platform/ops/mm/hygon_dcu/mm_weight.py index 209d0ee2e..3816872f4 100644 --- a/lightx2v_platform/ops/mm/hygon_dcu/mm_weight.py +++ b/lightx2v_platform/ops/mm/hygon_dcu/mm_weight.py @@ -32,7 +32,9 @@ def _load_auto_quant_bias(module, weight_dict): bias = bias.to(torch.float32) elif hasattr(module, "infer_dtype"): bias = bias.to(module.infer_dtype) - module.bias = bias.to(module.weight.device) + weight = getattr(module, "weight", None) + target_device = weight.device if weight is not None else bias.device + module.bias = bias.to(target_device) def _make_weight_contiguous(module): @@ -174,9 +176,9 @@ def _apply_quantized_2d(self, input_2d, input_tensor_quant, input_tensor_scale, hipblaslt_gemm = _require_hipblaslt_w8a8_channelwise_gemm() _, output_tensor = hipblaslt_gemm( a=input_tensor_quant.contiguous(), - b=self.weight.contiguous(), + b=self.weight, scale_a=input_tensor_scale.contiguous(), - scale_b=self.weight_scale.contiguous(), + scale_b=self.weight_scale, m=m, n=n, k=k, @@ -200,11 +202,10 @@ def apply_quantized_input(self, input_tensor, quantized_input): return _restore_last_dim(output_tensor, prefix_shape) def _apply_bf16(self, input_tensor): - weight = self.weight - if weight.dtype != input_tensor.dtype: - weight = weight.to(input_tensor.dtype) + if self.weight.dtype != input_tensor.dtype: + self.weight = self.weight.to(input_tensor.dtype) bias = _bias_or_none(self, input_tensor.dtype) - return F.linear(input_tensor, weight, bias) + return F.linear(input_tensor, self.weight, bias) def apply(self, input_tensor): if self.use_bf16_fallback: From d8b5456d69ec0a0409c67214de0b8d7d6b9850a5 Mon Sep 17 00:00:00 2001 From: zhenggf Date: Wed, 1 Jul 2026 14:59:32 +0800 Subject: [PATCH 4/4] style: format Hygon INT8 GEMM changes --- lightx2v_platform/ops/mm/hygon_dcu/mm_weight.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/lightx2v_platform/ops/mm/hygon_dcu/mm_weight.py b/lightx2v_platform/ops/mm/hygon_dcu/mm_weight.py index 3816872f4..ed481c060 100644 --- a/lightx2v_platform/ops/mm/hygon_dcu/mm_weight.py +++ b/lightx2v_platform/ops/mm/hygon_dcu/mm_weight.py @@ -66,7 +66,6 @@ def _bias_or_none(module, out_dtype=None): return None - def _env_flag(name, default="0"): return os.getenv(name, default).strip().lower() in {"1", "true", "yes", "on"} @@ -92,12 +91,10 @@ def _use_selective_bf16_fallback(weight_name): return True return _matches_any(weight_name, exclude) + def _require_hipblaslt_w8a8_channelwise_gemm(): if hipblaslt_w8a8_channelwise_gemm is None: - raise RuntimeError( - "int8-vllm-hygon-dcu requires lmslim.quantize.quant_ops." - "hipblaslt_w8a8_channelwise_gemm on Hygon DCU." - ) + raise RuntimeError("int8-vllm-hygon-dcu requires lmslim.quantize.quant_ops.hipblaslt_w8a8_channelwise_gemm on Hygon DCU.") return hipblaslt_w8a8_channelwise_gemm