feat: add Hygon DCU INT8 hipBLASLt GEMM path #1199
(cherry picked from commit a3a1a1f870b768929d8ca073f0c74added572087)
Add reusable quantized-input helpers for Hygon DCU W8A8 dynamic activation GEMMs, and support selective BF16 fallback for configured INT8 weights. (cherry picked from commit 58dab25b69c41c6ec9a24df0fe584ca93534eacc)
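For readers unfamiliar with W8A8 dynamic activation GEMMs, the arithmetic can be sketched in NumPy. This is a reference for the math only, not the hipBLASLt kernel or this PR's helpers. The function names are hypothetical, and the sketch assumes symmetric INT8 quantization with one activation scale per token (row) and one weight scale per output channel.

```python
import numpy as np

def quantize_per_row_int8(x):
    """Symmetric INT8 quantization with one scale per row (hypothetical helper)."""
    # Largest magnitude per row; the floor avoids division by zero on all-zero rows.
    absmax = np.maximum(np.abs(x).max(axis=1, keepdims=True), 1e-8)
    scale = (absmax / 127.0).astype(np.float32)
    q = np.clip(np.rint(x / scale), -127, 127).astype(np.int8)
    return q, scale

def w8a8_channelwise_gemm_reference(x, w_q, w_scale, bias=None):
    """x: (M, K) float; w_q: (N, K) int8; w_scale: (N, 1) float32."""
    # "Dynamic" means the activation scales are computed on every call.
    x_q, x_scale = quantize_per_row_int8(x)
    # Integer matmul with INT32 accumulation, shape (M, N).
    acc = x_q.astype(np.int32) @ w_q.astype(np.int32).T
    # Dequantize: row i is scaled by x_scale[i], column j by w_scale[j].
    out = acc.astype(np.float32) * x_scale * w_scale.T
    if bias is not None:
        out = out + bias
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 64)).astype(np.float32)
w = rng.standard_normal((8, 64)).astype(np.float32)
w_q, w_scale = quantize_per_row_int8(w)  # weights are quantized once, offline
out = w8a8_channelwise_gemm_reference(x, w_q, w_scale)
ref = x @ w.T  # full-precision result for comparison
```

Comparing `out` with `ref` gives a quick sense of the quantization error for a given shape and data distribution.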
Code Review
This pull request introduces a selective BF16 fallback mechanism and integrates hipblaslt_w8a8_channelwise_gemm for Hygon DCU in the quantization pipeline, along with helper functions for managing weights, biases, and environment flags. It also updates module loading logic to support custom load functions. The review feedback highlights opportunities to optimize performance and robustness, specifically by caching the casted weight tensor in the BF16 fallback path to avoid redundant casting, removing redundant .contiguous() calls on already contiguous weights during GEMM execution, and safely checking for the existence of module.weight to prevent potential AttributeErrors when loading auto-quantized biases.
    def _apply_bf16(self, input_tensor):
        weight = self.weight
        if weight.dtype != input_tensor.dtype:
            weight = weight.to(input_tensor.dtype)
        bias = _bias_or_none(self, input_tensor.dtype)
        return F.linear(input_tensor, weight, bias)
Casting a large weight tensor on every single forward pass is extremely inefficient and will cause significant GPU memory churn and latency overhead. We should cache the casted weight back to self.weight so that subsequent forward passes can reuse it directly.
    - def _apply_bf16(self, input_tensor):
    -     weight = self.weight
    -     if weight.dtype != input_tensor.dtype:
    -         weight = weight.to(input_tensor.dtype)
    -     bias = _bias_or_none(self, input_tensor.dtype)
    -     return F.linear(input_tensor, weight, bias)
    + def _apply_bf16(self, input_tensor):
    +     if self.weight.dtype != input_tensor.dtype:
    +         self.weight = self.weight.to(input_tensor.dtype)
    +     bias = _bias_or_none(self, input_tensor.dtype)
    +     return F.linear(input_tensor, self.weight, bias)
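An alternative to overwriting `self.weight` is to keep the stored weight untouched and cache the cast copy separately. The sketch below is a minimal illustration under that assumption. `Bf16FallbackLinear`, `_cast_weight`, and `_weight_as` are hypothetical names, not part of this PR.

```python
import torch
import torch.nn.functional as F

class Bf16FallbackLinear:
    """Hypothetical stand-in for the weight wrapper under review."""

    def __init__(self, weight, bias=None):
        self.weight = weight
        self.bias = bias
        self._cast_weight = None  # cached copy in the activation dtype

    def _weight_as(self, dtype):
        # No cast needed when the stored weight already matches.
        if self.weight.dtype == dtype:
            return self.weight
        # Cast once per dtype and reuse the result on later calls.
        if self._cast_weight is None or self._cast_weight.dtype != dtype:
            self._cast_weight = self.weight.to(dtype)
        return self._cast_weight

    def apply(self, x):
        bias = None if self.bias is None else self.bias.to(x.dtype)
        return F.linear(x, self._weight_as(x.dtype), bias)

layer = Bf16FallbackLinear(torch.randn(8, 16).to(torch.bfloat16))
x = torch.randn(2, 16, dtype=torch.float32)
y1 = layer.apply(x)
cached = layer._cast_weight  # created by the first call
y2 = layer.apply(x)          # second call reuses the cached tensor
```

The trade-off is memory: a separate cache holds two copies of the weight, whereas overwriting `self.weight` holds one but discards the original dtype. Note also that if the wrapper were an `nn.Module` with `weight` registered as an `nn.Parameter`, assigning a plain tensor to `self.weight` would raise a `TypeError`; whether that applies here depends on the class, which the excerpt does not show.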
        bias = bias.to(torch.float32)
    elif hasattr(module, "infer_dtype"):
        bias = bias.to(module.infer_dtype)
    module.bias = bias.to(module.weight.device)
To prevent a potential AttributeError if module.weight is None or not yet initialized, we should check that it exists and otherwise fall back to the bias tensor's own device.
    - module.bias = bias.to(module.weight.device)
    + device = module.weight.device if getattr(module, "weight", None) is not None else bias.device
    + module.bias = bias.to(device)
    _, output_tensor = hipblaslt_gemm(
        a=input_tensor_quant.contiguous(),
        b=self.weight.contiguous(),
        scale_a=input_tensor_scale.contiguous(),
        scale_b=self.weight_scale.contiguous(),
Since self.weight and self.weight_scale are already made contiguous during the load() phase (via _make_weight_contiguous), calling .contiguous() on them during every forward pass is redundant and adds unnecessary overhead.
    - _, output_tensor = hipblaslt_gemm(
    -     a=input_tensor_quant.contiguous(),
    -     b=self.weight.contiguous(),
    -     scale_a=input_tensor_scale.contiguous(),
    -     scale_b=self.weight_scale.contiguous(),
    + _, output_tensor = hipblaslt_gemm(
    +     a=input_tensor_quant.contiguous(),
    +     b=self.weight,
    +     scale_a=input_tensor_scale.contiguous(),
    +     scale_b=self.weight_scale,
Summary
int8-vllm-hygon-dcu MM weight backend using Hygon DCU hipBLASLt W8A8 channelwise GEMM
Why
This enables faster INT8 GEMM execution on Hygon DCU while keeping the path gated by the quantization scheme and clear dependency checks.
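One way to picture gating by quantization scheme with a clear dependency check is the sketch below. It is an illustration only: `require_module`, `select_mm_backend`, the returned labels, and the default module name `hypothetical_hygon_kernels` are assumptions, not names from this PR. Only the scheme string `int8-vllm-hygon-dcu` comes from the summary.

```python
import importlib.util

def require_module(module_name, feature):
    """Raise a descriptive ImportError if a top-level module is missing."""
    if importlib.util.find_spec(module_name) is None:
        raise ImportError(
            f"{feature} requires the '{module_name}' package, which is not installed."
        )

def select_mm_backend(quant_scheme, kernel_module="hypothetical_hygon_kernels"):
    """Pick a backend label; the DCU path is used only when explicitly requested."""
    if quant_scheme == "int8-vllm-hygon-dcu":
        # Fail early with a clear message instead of at the first GEMM call.
        require_module(kernel_module, quant_scheme)
        return "hipblaslt_w8a8"
    return "default"
```

With this shape, other quantization schemes never import the DCU kernels, and a missing dependency surfaces at configuration time.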
Validation
ModelTC/LightX2V:main (89dfa833)
git diff --check passed for the PR branch