Skip to content

Commit 67b1544

Browse files
committed
Extend LoRA adapter tests with transpose scenarios and skip unsupported cases
1 parent c6632b4 commit 67b1544

6 files changed

Lines changed: 251 additions & 46 deletions

File tree

src/nncf/common/tensor_statistics/statistics.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -276,15 +276,13 @@ def __eq__(self, other: Any) -> bool:
276276
return mean_values_equal
277277

278278
def _get_serialized_data(self) -> dict[str, Tensor]:
279-
backend = self.mean_values[0].backend
280-
device = self.mean_values[0].device
281279
return {
282280
self.MEAN_STAT: fns.stack(self.mean_values),
283281
self.SHAPE_STAT: fns.tensor(
284282
self.shape_values,
285-
backend=backend,
283+
backend=self.mean_values[0].backend,
286284
dtype=TensorDataType.int32,
287-
device=device,
285+
device=self.mean_values[0].device,
288286
),
289287
}
290288

src/nncf/quantization/algorithms/weight_compression/activation_stats.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,12 @@
1717
from nncf.tensor import functions as fns
1818

1919

20-
def process_stats(stats: WCTensorStatistic, subset_size: int, act_ch_axis: int = -1) -> tuple[Tensor, Tensor]:
20+
def process_stats(
21+
stats: WCTensorStatistic,
22+
subset_size: int,
23+
act_ch_axis: int = -1,
24+
transpose_a: bool = False,
25+
) -> tuple[Tensor, Tensor]:
2126
"""
2227
A function for processing activations. Shared between AWQ, Scale Estimation and LoRA Correction algorithms.
2328
@@ -37,8 +42,13 @@ def process_stats(stats: WCTensorStatistic, subset_size: int, act_ch_axis: int =
3742
axes = list(range(1, len(X.shape))) + [0]
3843
X_full = fns.transpose(X, axes=axes)
3944

40-
# The sample dimension is always the last axis after transpose
41-
sample_axis = -1
45+
if transpose_a:
46+
axes = list(range(len(X_full.shape)))
47+
axes[-1], axes[-2] = axes[-2], axes[-1]
48+
X_full = fns.transpose(X_full, axes=axes)
49+
50+
# The sample dimension is axis -1 by default, but moves to -2 if transpose_a is True
51+
sample_axis = -2 if transpose_a else -1
4252

4353
# Prevent high memory and time consumption by sampling
4454
if X_full.shape[sample_axis] > subset_size and subset_size > 0:
@@ -47,11 +57,13 @@ def process_stats(stats: WCTensorStatistic, subset_size: int, act_ch_axis: int =
4757
]
4858
step = X_full.shape[sample_axis] // subset_size
4959
idxs = [i[0] for i in sorted(enumerate(lens), key=lambda x: -x[1])][::step]
50-
X = X_full[..., idxs]
60+
if transpose_a:
61+
X = X_full[..., idxs, :]
62+
else:
63+
X = X_full[..., idxs]
5164
else:
5265
X = X_full
5366

54-
# Compute max magnitude along the sample axis (last axis)
55-
# Result: [HiddenDim] or [No. of Experts, HiddenDim]
67+
# Compute max magnitude along the sample axis
5668
s = fns.max(fns.abs(X_full), axis=sample_axis)
5769
return s, X

src/nncf/quantization/algorithms/weight_compression/algorithm.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1152,11 +1152,6 @@ def apply_with_parameters(
11521152
)
11531153

11541154
if self._lora_correction:
1155-
for wc_params in all_weight_params:
1156-
if self._backend_entity.matmul_has_transposed_activations(wc_params.node_with_weight, graph):
1157-
msg = "Transposed activations are not supported yet for the LoRa correction algorithm"
1158-
raise nncf.UnsupportedModelError(msg)
1159-
11601155
lora_correction_params = self._advanced_parameters.lora_correction_params
11611156
lora_correction_algo = LoraCorrectionAlgorithm(statistics, lora_correction_params)
11621157
description += " with correction of low-rank adapters"
@@ -1370,7 +1365,7 @@ def _get_statistics_for_weights_compression(
13701365
# Where mean_value is a 1D tensor representing an activation reduced over batch and sequence length dimensions,
13711366
# shape is an original shape of an activation before reduction, n is the size of the dataset (or subset_size).
13721367
statistics = {}
1373-
for (act_node, output_port_id, _), matmul_nodes in matmul_input_to_output_nodes_map.items():
1368+
for (act_node, output_port_id, _act_channel_axis), matmul_nodes in matmul_input_to_output_nodes_map.items():
13741369
tensor_collectors = list(
13751370
statistic_points.get_algo_statistics_for_node(
13761371
act_node.node_name,

src/nncf/quantization/algorithms/weight_compression/lora_correction.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,6 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111
from pathlib import Path
12-
from typing import Optional
13-
14-
import pandas as pd
1512

1613
import nncf
1714
from nncf.common.logging import nncf_logger
@@ -43,9 +40,10 @@ def __init__(self):
4340
def add_noises(self, layer_name: str, value: float):
4441
self._noise_per_layer[layer_name] = value
4542

46-
@skip_if_dependency_unavailable(dependencies=["matplotlib.pyplot"])
43+
@skip_if_dependency_unavailable(dependencies=["matplotlib.pyplot", "pandas"])
4744
def dump_data(self):
4845
import matplotlib.pyplot as plt
46+
import pandas as pd
4947

5048
if not self._noise_per_layer:
5149
return
@@ -108,27 +106,35 @@ def is_applicable(self, wc_params: WeightCompressionParameters):
108106
return wc_params.compression_config.num_bits == 4
109107

110108
def calculate_adapters(
111-
self, weight: Tensor, compressed_weight: CompressedWeight, wc_params: WeightCompressionParameters
109+
self,
110+
weight: Tensor,
111+
compressed_weight: CompressedWeight,
112+
wc_params: WeightCompressionParameters,
113+
act_ch_axis: int,
112114
) -> tuple[Tensor, Tensor, list[float]]:
113115
"""
114116
Calculates low rank matrices for a given original and compressed weights.
115117
116118
:param weight: original floating-point weight matrix.
117119
:param compressed_weight: compressed weight matrix.
118120
:param wc_params: parameters of weight compression.
121+
:param act_ch_axis: axis number of the activation tensor which correspond to it channel.
119122
:return: two low rank matrices in the order of execution of corresponding linear layers.
120123
"""
121124
layer_name = wc_params.node_with_weight.node_name
122125
layer_statistics = self._statistics[layer_name]
123126
is_debug = self._debug_interface is not None
127+
transpose_a_flag = getattr(wc_params.node_with_weight, "transpose_a", False)
124128
lora_A, lora_B, mean_noises = self.calculate_low_rank_matrices(
125129
weight,
126130
compressed_weight,
127131
wc_params.compression_config,
128132
wc_params.reduction_axes,
129133
self._lora_correction_params,
130134
layer_statistics,
135+
act_ch_axis,
131136
is_debug,
137+
transpose_a=transpose_a_flag,
132138
)
133139
if is_debug:
134140
self._debug_interface.add_noises(layer_name, mean_noises)
@@ -142,7 +148,9 @@ def calculate_low_rank_matrices(
142148
reduction_axes: tuple[int, ...],
143149
lora_correction_params: AdvancedLoraCorrectionParameters,
144150
layer_statistics: WCTensorStatistic,
145-
is_debug: Optional[bool] = False,
151+
act_ch_axis: int,
152+
is_debug: bool | None = False,
153+
transpose_a: bool = False,
146154
):
147155
"""
148156
Calculates low rank matrices for a given original and compressed weights.
@@ -157,6 +165,7 @@ def calculate_low_rank_matrices(
157165
:param reduction_axes: axes along which different statistics reduced.
158166
:param lora_correction_params: parameters to configure the algorithm.
159167
:param layer_statistics: an object containing statistics for the layer.
168+
:param act_ch_axis: axis number of the activation tensor which correspond to it channel.
160169
:param is_debug: whether to collect debug information, defaults to False.
161170
:return: two low rank matrices in the order of execution of corresponding linear layers and list of mean noises.
162171
Noises are collected from each step of the algorithm if debug was enabled.
@@ -170,7 +179,15 @@ def calculate_low_rank_matrices(
170179
)
171180
mode = compression_config.mode
172181
assert len(reduction_axes) == 1, "Assumed a single reduction axis"
173-
reduction_axis = reduction_axes[0] if compression_config.group_size != -1 else -1
182+
183+
if compression_config.group_size != -1:
184+
reduction_axis = reduction_axes[0]
185+
else:
186+
reduction_axis = -1
187+
188+
if transpose_a and reduction_axis != -1:
189+
reduction_axis = 1
190+
174191
if mode in (CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM):
175192
fq_weights = do_integer_dequantization(
176193
compressed_weight.tensor,
@@ -194,16 +211,8 @@ def calculate_low_rank_matrices(
194211
svd_residual = fns.transpose(svd_residual)
195212
residual = svd_residual.clone() # [H, O]
196213

197-
# Get the activation channel axis
198-
act_ch_axis = getattr(layer_statistics, "act_ch_axis", -1) # default to last axis
199-
200-
# Pass it to process_stats
201-
s, X = process_stats(layer_statistics, subset_size, act_ch_axis)
202-
203-
# Conditionally transpose X so samples are rows and channels are columns
204-
if act_ch_axis != 0: # if channel is not already the first axis
205-
X = fns.transpose(X, axes=(1, 0)) # [SS, H]
206-
214+
# Pass it to process_stats with transpose_a=True to get [SS, H] layout
215+
s, X = process_stats(layer_statistics, subset_size, act_ch_axis, transpose_a=True)
207216
if compression_config.group_size > 0:
208217
# Multiply residual of weights by maximum channel magnitude of activations normalized per quantization
209218
# group. As a consequence, weights corresponding to a "noisy" activations has a higher error to correct.

src/nncf/quantization/algorithms/weight_compression/openvino_backend.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,8 @@ def insert_adapters(
206206
A_W = opset.constant(lora_A.data)
207207
B_W = opset.constant(lora_B.data)
208208

209-
A_MM = opset.matmul(input_node, A_W, transpose_a=False, transpose_b=True)
209+
transpose_a = wc_params.node_with_weight.layer_attributes.input_attributes["transpose"]
210+
A_MM = opset.matmul(input_node, A_W, transpose_a=transpose_a, transpose_b=True)
210211
B_MM = opset.matmul(A_MM, B_W, transpose_a=False, transpose_b=True)
211212

212213
node_output_port = mm_node.output(0)
@@ -349,7 +350,15 @@ def transform_model(
349350
compressed_weight.tensor = compressed_weight.tensor.as_numpy_tensor()
350351
if compressed_weight.zero_point is not None:
351352
compressed_weight.zero_point = compressed_weight.zero_point.as_numpy_tensor()
352-
adapters = lora_correction_algo.calculate_adapters(weight, compressed_weight, wc_params)
353+
354+
activation_port_id = self.get_activation_port_id(wc_params.node_with_weight, graph)
355+
activation_edge = graph.get_input_edge_by_port_id(wc_params.node_with_weight, activation_port_id)
356+
activation_shape = activation_edge.tensor_shape
357+
act_ch_axis = self.get_activation_channel_axis(
358+
wc_params.node_with_weight, activation_port_id, activation_shape
359+
)
360+
361+
adapters = lora_correction_algo.calculate_adapters(weight, compressed_weight, wc_params, act_ch_axis)
353362
self.insert_adapters(wc_params, *adapters, int8_lora=lora_correction_algo.use_int8_adapters)
354363
self.name_to_node_mapping = None
355364

0 commit comments

Comments
 (0)