# Patch summary (commentary only; it sits before the first "diff --git" header).
#
# lightx2v/common/modules/weight_module.py
#   add_module(): when the parent has a `config` attribute and the child module
#   has a `set_config` method, the parent's config is passed to the child
#   before the child is registered.
#
# lightx2v_platform/ops/mm/hygon_dcu/mm_weight.py
#   - Imports `os` and `torch.nn.functional as F`; tries to import
#     `hipblaslt_w8a8_channelwise_gemm` from `lmslim.quantize.quant_ops` and
#     binds it to None on ImportError.
#   - Adds module-level helpers for bias loading, contiguity, flattening and
#     restoring leading dimensions, and environment-variable parsing.
#   - LIGHTX2V_INT8_SELECTIVE (accepted truthy values: 1/true/yes/on) turns on
#     a per-weight BF16 fallback. LIGHTX2V_INT8_SELECTIVE_INCLUDE and
#     LIGHTX2V_INT8_SELECTIVE_EXCLUDE are comma- or semicolon-separated
#     substrings matched against the weight name; EXCLUDE defaults to
#     "txt_branch". A weight falls back when an include list is set and the
#     name matches none of it, or when the name matches the exclude list.
#   - The fallback path requires config["weight_auto_quant"] and raises
#     RuntimeError otherwise.
#   - apply() is split into prepare_quantized_input(), apply_quantized_input(),
#     _apply_quantized_2d() and _apply_bf16(). The INT8 path calls
#     hipblaslt_w8a8_channelwise_gemm and raises RuntimeError when that symbol
#     is unavailable; the previous ops.blaslt_scaled_mm call and the manual
#     dequantize-and-matmul fallback are removed.
#
# lightx2v_platform/ops/mm/template.py
#   load(): calls self.load_func when it is not None, otherwise
#   self.load_quantized.
#
# NOTE(review): this patch was recovered from a copy whose line breaks and
# indentation had been lost. Indentation was re-derived from Python syntax and
# blank context lines from the @@ hunk line counts -- confirm with
# `git apply --check` against the target tree before use.
# NOTE(review): blank-line separation between the added module-level helpers
# was normalised to two lines each; the number of added lines is unchanged.
diff --git a/lightx2v/common/modules/weight_module.py b/lightx2v/common/modules/weight_module.py
index 8c07bd267..66b4b9598 100755
--- a/lightx2v/common/modules/weight_module.py
+++ b/lightx2v/common/modules/weight_module.py
@@ -10,6 +10,8 @@ def is_empty(self):
         return len(self._modules) == 0 and len(self._parameters) == 0
 
     def add_module(self, name, module):
+        if hasattr(self, "config") and hasattr(module, "set_config"):
+            module.set_config(self.config)
         self._modules[name] = module
         setattr(self, name, module)
 
diff --git a/lightx2v_platform/ops/mm/hygon_dcu/mm_weight.py b/lightx2v_platform/ops/mm/hygon_dcu/mm_weight.py
index dbe2036b4..3816872f4 100644
--- a/lightx2v_platform/ops/mm/hygon_dcu/mm_weight.py
+++ b/lightx2v_platform/ops/mm/hygon_dcu/mm_weight.py
@@ -1,4 +1,7 @@
+import os
+
 import torch
+import torch.nn.functional as F
 
 from lightx2v_platform.ops.mm.template import MMWeightQuantTemplate
 from lightx2v_platform.registry_factory import PLATFORM_MM_WEIGHT_REGISTER
@@ -13,6 +16,90 @@
 except ImportError:
     IntegerQuantizer = None
 
+try:
+    from lmslim.quantize.quant_ops import hipblaslt_w8a8_channelwise_gemm
+except ImportError:
+    hipblaslt_w8a8_channelwise_gemm = None
+
+
+def _load_auto_quant_bias(module, weight_dict):
+    module.bias = None
+    module.pin_bias = None
+    if module.bias_name is None or module.bias_name not in weight_dict:
+        return
+    bias = weight_dict[module.bias_name]
+    if module.bias_force_fp32:
+        bias = bias.to(torch.float32)
+    elif hasattr(module, "infer_dtype"):
+        bias = bias.to(module.infer_dtype)
+    weight = getattr(module, "weight", None)
+    target_device = weight.device if weight is not None else bias.device
+    module.bias = bias.to(target_device)
+
+
+def _make_weight_contiguous(module):
+    if hasattr(module, "weight") and module.weight is not None:
+        module.weight = module.weight.contiguous()
+    if hasattr(module, "weight_scale") and module.weight_scale is not None:
+        module.weight_scale = module.weight_scale.contiguous()
+
+
+def _flatten_last_dim(input_tensor):
+    if input_tensor.dim() == 2:
+        return input_tensor, None
+    original_shape = input_tensor.shape
+    return input_tensor.reshape(-1, original_shape[-1]), original_shape[:-1]
+
+
+def _restore_last_dim(output_tensor, prefix_shape):
+    if prefix_shape is None:
+        return output_tensor
+    return output_tensor.reshape(*prefix_shape, output_tensor.shape[-1])
+
+
+def _bias_or_none(module, out_dtype=None):
+    if hasattr(module, "bias") and module.bias is not None:
+        bias = module.bias
+        if out_dtype is not None and bias.dtype != out_dtype:
+            bias = bias.to(out_dtype)
+        return bias
+    return None
+
+
+def _env_flag(name, default="0"):
+    return os.getenv(name, default).strip().lower() in {"1", "true", "yes", "on"}
+
+
+def _env_patterns(name, default=""):
+    raw = os.getenv(name, default)
+    return tuple(item.strip() for item in raw.replace(";", ",").split(",") if item.strip())
+
+
+def _matches_any(name, patterns):
+    return any(pattern in name for pattern in patterns)
+
+
+def _use_selective_bf16_fallback(weight_name):
+    if not _env_flag("LIGHTX2V_INT8_SELECTIVE"):
+        return False
+    include = _env_patterns("LIGHTX2V_INT8_SELECTIVE_INCLUDE")
+    exclude = _env_patterns(
+        "LIGHTX2V_INT8_SELECTIVE_EXCLUDE",
+        "txt_branch",
+    )
+    if include and not _matches_any(weight_name, include):
+        return True
+    return _matches_any(weight_name, exclude)
+
+
+def _require_hipblaslt_w8a8_channelwise_gemm():
+    if hipblaslt_w8a8_channelwise_gemm is None:
+        raise RuntimeError(
+            "int8-vllm-hygon-dcu requires lmslim.quantize.quant_ops."
+            "hipblaslt_w8a8_channelwise_gemm on Hygon DCU."
+        )
+    return hipblaslt_w8a8_channelwise_gemm
+
 
 @PLATFORM_MM_WEIGHT_REGISTER("int8-vllm-hygon-dcu")
 class MMWeightWint8channelAint8channeldynamicVllmHygonDcu(MMWeightQuantTemplate):
@@ -41,9 +128,22 @@ def __init__(
         self.load_func = self.load_int8_perchannel_sym
         self.weight_need_transpose = False
         self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
+        self.use_bf16_fallback = _use_selective_bf16_fallback(weight_name)
+
+    def load(self, weight_dict):
+        super().load(weight_dict)
+        _make_weight_contiguous(self)
 
     def load_int8_perchannel_sym(self, weight_dict):
         """Load INT8 per-channel symmetric quantized weights."""
+        if self.use_bf16_fallback:
+            if not self.config.get("weight_auto_quant", False):
+                raise RuntimeError("Selective BF16 fallback requires weight_auto_quant=1 so original BF16 weights are available.")
+            self.weight = weight_dict[self.weight_name].to(self.infer_dtype)
+            self.weight_scale = None
+            _load_auto_quant_bias(self, weight_dict)
+            return
+
         if self.config.get("weight_auto_quant", False):
             if IntegerQuantizer is None:
                 raise ImportError("IntegerQuantizer not available. Please ensure lightx2v.utils.quant_utils is available.")
@@ -52,6 +152,7 @@ def load_int8_perchannel_sym(self, weight_dict):
             self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
             self.weight = self.weight.to(torch.int8)
             self.weight_scale = self.weight_scale.to(torch.float32)
+            _load_auto_quant_bias(self, weight_dict)
         else:
             self.load_quantized(weight_dict)
 
@@ -62,39 +163,51 @@ def act_quant_int8_perchannel_sym_vllm(self, x):
         input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
         return input_tensor_quant, input_tensor_scale
 
-    def apply(self, input_tensor):
-        shape = (input_tensor.shape[0], self.weight.shape[1])
+    def prepare_quantized_input(self, input_tensor):
+        if self.use_bf16_fallback:
+            raise RuntimeError("BF16 fallback weights do not support shared INT8 activation quantization.")
+        input_2d, prefix_shape = _flatten_last_dim(input_tensor)
+        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_2d)
+        return input_2d, prefix_shape, input_tensor_quant, input_tensor_scale
+
+    def _apply_quantized_2d(self, input_2d, input_tensor_quant, input_tensor_scale, out_dtype):
+        m, k = input_tensor_quant.shape
+        n = self.weight.shape[0]
+        hipblaslt_gemm = _require_hipblaslt_w8a8_channelwise_gemm()
+        _, output_tensor = hipblaslt_gemm(
+            a=input_tensor_quant.contiguous(),
+            b=self.weight,
+            scale_a=input_tensor_scale.contiguous(),
+            scale_b=self.weight_scale,
+            m=m,
+            n=n,
+            k=k,
+            transpose_flag="NT",
+            out_dtype=out_dtype,
+            bias=_bias_or_none(self, out_dtype),
+        )
+        output_tensor = output_tensor.reshape(-1, n).narrow(0, 0, input_2d.shape[0])
+        return output_tensor
+
+    def apply_quantized_input(self, input_tensor, quantized_input):
+        if self.use_bf16_fallback:
+            return self.apply(input_tensor)
         dtype = input_tensor.dtype
-        device = input_tensor.device
-
-        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
-
-        # Use ops.blaslt_scaled_mm from vllm for ROCm/DCU instead of torch.ops._C.cutlass_scaled_mm
-        if ops is not None and hasattr(ops, "blaslt_scaled_mm"):
-            # Ensure out_dtype is bfloat16 or float16 as required by blaslt_scaled_mm
-            out_dtype = dtype if dtype in (torch.bfloat16, torch.float16) else torch.bfloat16
-
-            # Ensure input tensor is contiguous for optimal performance
-            input_tensor_quant = input_tensor_quant.contiguous()
-
-            output_tensor = ops.blaslt_scaled_mm(
-                input_tensor_quant,
-                self.weight,
-                input_tensor_scale,
-                self.weight_scale,
-                out_dtype,
-                self.bias if self.bias is not None else None,
-            )
-
-            # Convert back to original dtype if needed
-            if output_tensor.dtype != dtype:
-                output_tensor = output_tensor.to(dtype)
-        else:
-            # Fallback: use manual dequantization and matmul
-            input_dequant = input_tensor_quant.to(dtype) * input_tensor_scale.to(dtype)
-            weight_dequant = self.weight.to(dtype) * self.weight_scale.to(dtype)
-            output_tensor = torch.matmul(input_dequant, weight_dequant)
-            if self.bias is not None:
-                output_tensor = output_tensor + self.bias
+        out_dtype = dtype if dtype in (torch.bfloat16, torch.float16) else torch.bfloat16
+        input_2d, prefix_shape, input_tensor_quant, input_tensor_scale = quantized_input
+        output_tensor = self._apply_quantized_2d(input_2d, input_tensor_quant, input_tensor_scale, out_dtype)
 
-        return output_tensor
+        if output_tensor.dtype != dtype:
+            output_tensor = output_tensor.to(dtype)
+        return _restore_last_dim(output_tensor, prefix_shape)
+
+    def _apply_bf16(self, input_tensor):
+        if self.weight.dtype != input_tensor.dtype:
+            self.weight = self.weight.to(input_tensor.dtype)
+        bias = _bias_or_none(self, input_tensor.dtype)
+        return F.linear(input_tensor, self.weight, bias)
+
+    def apply(self, input_tensor):
+        if self.use_bf16_fallback:
+            return self._apply_bf16(input_tensor)
+        return self.apply_quantized_input(input_tensor, self.prepare_quantized_input(input_tensor))
diff --git a/lightx2v_platform/ops/mm/template.py b/lightx2v_platform/ops/mm/template.py
index 28f519aa2..033663e5e 100644
--- a/lightx2v_platform/ops/mm/template.py
+++ b/lightx2v_platform/ops/mm/template.py
@@ -70,7 +70,10 @@ def __init__(
     # weight load functions
     # =========================
     def load(self, weight_dict):
-        self.load_quantized(weight_dict)
+        if self.load_func is None:
+            self.load_quantized(weight_dict)
+        else:
+            self.load_func(weight_dict)
         if self.weight_need_transpose:
             if hasattr(self, "weight") and self.weight is not None:
                 self.weight = self.weight.t()