From 709a45c36218e3bf17f518cdbd211abc16c72057 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Tue, 9 Jun 2026 21:11:12 +0200 Subject: [PATCH 1/2] [Torch] Lora kernels compilation --- .../torch/quantization/quantize_functions.py | 35 +++++++++ src/nncf/torch/quantization/reference.py | 72 +++++++++++-------- src/nncf/torch/utils.py | 4 ++ 3 files changed, 81 insertions(+), 30 deletions(-) diff --git a/src/nncf/torch/quantization/quantize_functions.py b/src/nncf/torch/quantization/quantize_functions.py index 047b376164a..3ee66a0e35e 100644 --- a/src/nncf/torch/quantization/quantize_functions.py +++ b/src/nncf/torch/quantization/quantize_functions.py @@ -19,6 +19,7 @@ from nncf.torch.quantization.extensions import QuantizedFunctionsCPU from nncf.torch.quantization.extensions import QuantizedFunctionsCUDA from nncf.torch.quantization.reference import ReferenceQuantizedFunctions as RQ +from nncf.torch.utils import CompilationWrapper from nncf.torch.utils import add_ov_domain @@ -302,6 +303,32 @@ def asymmetric_quantize_lora( ) if skip: return input_ + return _asymmetric_quantize_lora( + input_, + input_shape, + A, + B, + input_low_, + input_range_, + level_low, + level_high, + levels, + eps, + ) + + +def _asymmetric_quantize_lora( + input_, + input_shape, + A, + B, + input_low_, + input_range_, + level_low, + level_high, + levels, + eps, +): input_range_safe = abs(input_range_) + eps input_low, input_range = TuneRange.apply(input_low_, input_range_safe, levels) input_ = (input_ + B @ A).type(input_.dtype) # input(float16) + lora(bfloat16) = float32, need a cast to float16 @@ -334,6 +361,10 @@ def symmetric_quantize_lora(input_, input_shape, A, B, scale, level_low, level_h ) if skip: return input_ + return _symmetric_quantize_lora(input_, input_shape, A, B, scale, level_low, level_high, levels, eps) + + +def _symmetric_quantize_lora(input_, input_shape, A, B, scale, level_low, level_high, levels, eps): scale_safe = torch.where(torch.abs(scale) < eps, eps, scale) input_ = (input_ + B @ A).type(input_.dtype) # input(float16) + lora(bfloat16) = float32, need a cast to float16 return QuantizeSymmetricTorch.apply( @@ -471,3 +502,7 @@ def unpack_int4(packed_tensor: torch.Tensor) -> torch.Tensor: """ t = unpack_uint4(packed_tensor) return t.type(torch.int8) - 8 + + +_asymmetric_quantize_lora = CompilationWrapper(_asymmetric_quantize_lora) +_symmetric_quantize_lora = CompilationWrapper(_symmetric_quantize_lora) diff --git a/src/nncf/torch/quantization/reference.py b/src/nncf/torch/quantization/reference.py index cf719bb62f7..49611e45fd2 100644 --- a/src/nncf/torch/quantization/reference.py +++ b/src/nncf/torch/quantization/reference.py @@ -21,34 +21,6 @@ GeneralizedTensor = TypeVar("GeneralizedTensor", torch.Tensor, np.ndarray) -def fp32_accum_wrapper(func): - def wrapper(tensor_to_sum, ret_tensor): - half = tensor_to_sum.dtype == np.float16 - if half: - tensor_to_sum = tensor_to_sum.astype(np.float32) - retval = func(tensor_to_sum, ret_tensor) - if half: - retval = retval.astype(np.float16) - return retval - - return wrapper - - -@fp32_accum_wrapper -def sum_like(tensor_to_sum, ref_tensor): - """Warning: may modify tensor_to_sum""" - if ref_tensor.size == 1: - return tensor_to_sum.sum() - - for dim, size in enumerate(ref_tensor.shape): - if size == 1: - if isinstance(tensor_to_sum, np.ndarray): - tensor_to_sum = tensor_to_sum.sum(dim, keepdims=True) - else: - tensor_to_sum = tensor_to_sum.sum(dim, keepdim=True) - return tensor_to_sum - - class ReferenceBackendType(Enum): NUMPY = "numpy" TORCH = "torch" @@ -79,6 +51,46 @@ def _reciprocal(self, tensor: GeneralizedTensor) -> GeneralizedTensor: return np.reciprocal(tensor) return torch.reciprocal(tensor) + def _sum_like(self, tensor_to_sum: GeneralizedTensor, ref_tensor: GeneralizedTensor): + """Warning: may modify tensor_to_sum""" + if self.backend is np: + half = tensor_to_sum.dtype == np.float16 + if half: + tensor_to_sum = tensor_to_sum.astype(np.float32) + retval = self._sum_like_fp32(tensor_to_sum, ref_tensor) + if half: + retval = retval.astype(np.float16) + return retval + + half = tensor_to_sum.dtype == torch.float16 + if half: + tensor_to_sum = tensor_to_sum.type(torch.float32) + retval = self._sum_like_fp32(tensor_to_sum, ref_tensor) + if half: + retval = retval.type(torch.float16) + return retval + + def _sum_like_fp32(self, tensor_to_sum: GeneralizedTensor, ref_tensor: GeneralizedTensor): + """Warning: may modify tensor_to_sum""" + if self.backend is np: + n_elements = ref_tensor.size + if n_elements == 1: + return tensor_to_sum.sum() + + for dim, size in enumerate(ref_tensor.shape): + if size == 1: + tensor_to_sum = tensor_to_sum.sum(dim, keepdims=True) + return tensor_to_sum + + n_elements = ref_tensor.numel() + if n_elements == 1: + return tensor_to_sum.sum() + + for dim, size in enumerate(ref_tensor.shape): + if size == 1: + tensor_to_sum = tensor_to_sum.sum(dim, keepdim=True) + return tensor_to_sum + def forward( self, input_: GeneralizedTensor, input_low: GeneralizedTensor, input_range: GeneralizedTensor, levels: int ) -> GeneralizedTensor: @@ -114,12 +126,12 @@ def backward( output = self.forward(input_, input_low, input_range, levels) err = (output - input_) * self._reciprocal(input_range * range_sign) grad_range = grad_output * (err * mask_in + range_sign * (level_low / level_high) * mask_lo + mask_hi) - grad_range = sum_like(grad_range, input_range) + grad_range = self._sum_like(grad_range, input_range) grad_input = grad_output * mask_in grad_low = grad_output * (mask_hi + mask_lo) - grad_low = sum_like(grad_low, input_low) + grad_low = self._sum_like(grad_low, input_low) return [grad_input, grad_low, grad_range] def tune_range( diff --git a/src/nncf/torch/utils.py b/src/nncf/torch/utils.py index 1f924f542e4..1c2727a67e9 100644 --- a/src/nncf/torch/utils.py +++ b/src/nncf/torch/utils.py @@ -171,6 +171,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: :return: Result of the function call. """ + # Prevent nested compilation + if torch.compiler.is_compiling(): + return self._func(*args, **kwargs) + if self._compiled_func is None: try: self._compiled_func = torch.compile(self._func) From ad6e3cf3ca2100eeb719a766fa759f0cb133980f Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Tue, 16 Jun 2026 18:26:23 +0200 Subject: [PATCH 2/2] Fix precommit --- src/nncf/torch/quantization/reference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/nncf/torch/quantization/reference.py b/src/nncf/torch/quantization/reference.py index 49611e45fd2..f9c08529618 100644 --- a/src/nncf/torch/quantization/reference.py +++ b/src/nncf/torch/quantization/reference.py @@ -75,7 +75,7 @@ def _sum_like_fp32(self, tensor_to_sum: GeneralizedTensor, ref_tensor: Generaliz if self.backend is np: n_elements = ref_tensor.size if n_elements == 1: - return tensor_to_sum.sum() + return tensor_to_sum.sum().reshape(ref_tensor.shape) for dim, size in enumerate(ref_tensor.shape): if size == 1: @@ -84,7 +84,7 @@ def _sum_like_fp32(self, tensor_to_sum: GeneralizedTensor, ref_tensor: Generaliz n_elements = ref_tensor.numel() if n_elements == 1: - return tensor_to_sum.sum() + return tensor_to_sum.sum().reshape(ref_tensor.shape) for dim, size in enumerate(ref_tensor.shape): if size == 1: