Skip to content

Commit 5cbe607

Browse files
Shehrozkashifclaude
andcommitted
Fix transpose_a support in LoRA Correction: remove getattr bug and wrong reduction_axis override
- Replace `getattr(node, "transpose_a", False)` (always returned False since NNCFNode has no such attribute) with proper access via `layer_attributes.input_attributes["transpose"]`, then remove the now-unused `transpose_a_flag` and the `transpose_a` parameter from `calculate_low_rank_matrices`. - Remove the `if transpose_a and reduction_axis != -1: reduction_axis = 1` block which would have incorrectly overridden the H-axis group-quantization index (e.g. setting it to 1 for a [H, O] weight where H is at axis 0). - Revert the unrelated inlining of `backend`/`device` locals in `WCTensorStatistic._get_serialized_data` to keep the diff focused. - Fix the `process_stats` docstring to accurately describe the new `transpose_a` parameter and the two possible return layouts. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 3803aa6 commit 5cbe607

3 files changed

Lines changed: 8 additions & 10 deletions

File tree

src/nncf/common/tensor_statistics/statistics.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,13 +276,15 @@ def __eq__(self, other: Any) -> bool:
276276
return mean_values_equal
277277

278278
def _get_serialized_data(self) -> dict[str, Tensor]:
279+
backend = self.mean_values[0].backend
280+
device = self.mean_values[0].device
279281
return {
280282
self.MEAN_STAT: fns.stack(self.mean_values),
281283
self.SHAPE_STAT: fns.tensor(
282284
self.shape_values,
283-
backend=self.mean_values[0].backend,
285+
backend=backend,
284286
dtype=TensorDataType.int32,
285-
device=self.mean_values[0].device,
287+
device=device,
286288
),
287289
}
288290

src/nncf/quantization/algorithms/weight_compression/activation_stats.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,11 @@ def process_stats(
2929
:param stats: An object containing statistics for the layer.
3030
:param subset_size: The number of samples for AWQ. If subset_size <= 0, all samples are used.
3131
:param act_ch_axis: The activation channel axis.
32+
:param transpose_a: When True, returns X in [SampleSize, HiddenDim] layout instead of the default
33+
[HiddenDim, SampleSize]. Used by LoRA Correction which requires samples as rows.
3234
:return: tuple of the following tensors:
33-
s - maximum channel magnitude across samples [HiddenDim]
34-
X - average channel magnitude across tokens in the sequence [HiddenDim, min(SampleSize, ~subset_size)]
35+
s - maximum channel magnitude across samples, shape [HiddenDim]
36+
X - activation matrix, shape [HiddenDim, SampleSize] normally or [SampleSize, HiddenDim] if transpose_a=True
3537
"""
3638
X = fns.stack(
3739
stats.mean_values

src/nncf/quantization/algorithms/weight_compression/lora_correction.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,6 @@ def calculate_adapters(
124124
layer_name = wc_params.node_with_weight.node_name
125125
layer_statistics = self._statistics[layer_name]
126126
is_debug = self._debug_interface is not None
127-
transpose_a_flag = getattr(wc_params.node_with_weight, "transpose_a", False)
128127
lora_A, lora_B, mean_noises = self.calculate_low_rank_matrices(
129128
weight,
130129
compressed_weight,
@@ -134,7 +133,6 @@ def calculate_adapters(
134133
layer_statistics,
135134
act_ch_axis,
136135
is_debug,
137-
transpose_a=transpose_a_flag,
138136
)
139137
if is_debug:
140138
self._debug_interface.add_noises(layer_name, mean_noises)
@@ -150,7 +148,6 @@ def calculate_low_rank_matrices(
150148
layer_statistics: WCTensorStatistic,
151149
act_ch_axis: int,
152150
is_debug: bool | None = False,
153-
transpose_a: bool = False,
154151
):
155152
"""
156153
Calculates low rank matrices for a given original and compressed weights.
@@ -185,9 +182,6 @@ def calculate_low_rank_matrices(
185182
else:
186183
reduction_axis = -1
187184

188-
if transpose_a and reduction_axis != -1:
189-
reduction_axis = 1
190-
191185
if mode in (CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM):
192186
fq_weights = do_integer_dequantization(
193187
compressed_weight,

0 commit comments

Comments
 (0)