99# See the License for the specific language governing permissions and
1010# limitations under the License.
1111from pathlib import Path
12- from typing import Optional
13-
14- import pandas as pd
1512
1613import nncf
1714from nncf .common .logging import nncf_logger
@@ -43,9 +40,10 @@ def __init__(self):
4340 def add_noises (self , layer_name : str , value : float ):
4441 self ._noise_per_layer [layer_name ] = value
4542
46- @skip_if_dependency_unavailable (dependencies = ["matplotlib.pyplot" ])
43+ @skip_if_dependency_unavailable (dependencies = ["matplotlib.pyplot" , "pandas" ])
4744 def dump_data (self ):
4845 import matplotlib .pyplot as plt
46+ import pandas as pd
4947
5048 if not self ._noise_per_layer :
5149 return
@@ -108,27 +106,35 @@ def is_applicable(self, wc_params: WeightCompressionParameters):
108106 return wc_params .compression_config .num_bits == 4
109107
110108 def calculate_adapters (
111- self , weight : Tensor , compressed_weight : CompressedWeight , wc_params : WeightCompressionParameters
109+ self ,
110+ weight : Tensor ,
111+ compressed_weight : CompressedWeight ,
112+ wc_params : WeightCompressionParameters ,
113+ act_ch_axis : int ,
112114 ) -> tuple [Tensor , Tensor , list [float ]]:
113115 """
114116 Calculates low rank matrices for a given original and compressed weights.
115117
116118 :param weight: original floating-point weight matrix.
117119 :param compressed_weight: compressed weight matrix.
118120 :param wc_params: parameters of weight compression.
121+ :param act_ch_axis: axis of the activation tensor that corresponds to its channel dimension.
119122 :return: two low rank matrices in the order of execution of corresponding linear layers.
120123 """
121124 layer_name = wc_params .node_with_weight .node_name
122125 layer_statistics = self ._statistics [layer_name ]
123126 is_debug = self ._debug_interface is not None
127+ transpose_a_flag = getattr (wc_params .node_with_weight , "transpose_a" , False )
124128 lora_A , lora_B , mean_noises = self .calculate_low_rank_matrices (
125129 weight ,
126130 compressed_weight ,
127131 wc_params .compression_config ,
128132 wc_params .reduction_axes ,
129133 self ._lora_correction_params ,
130134 layer_statistics ,
135+ act_ch_axis ,
131136 is_debug ,
137+ transpose_a = transpose_a_flag ,
132138 )
133139 if is_debug :
134140 self ._debug_interface .add_noises (layer_name , mean_noises )
@@ -142,7 +148,9 @@ def calculate_low_rank_matrices(
142148 reduction_axes : tuple [int , ...],
143149 lora_correction_params : AdvancedLoraCorrectionParameters ,
144150 layer_statistics : WCTensorStatistic ,
145- is_debug : Optional [bool ] = False ,
151+ act_ch_axis : int ,
152+ is_debug : bool | None = False ,
153+ transpose_a : bool = False ,
146154 ):
147155 """
148156 Calculates low rank matrices for a given original and compressed weights.
@@ -157,6 +165,7 @@ def calculate_low_rank_matrices(
157165 :param reduction_axes: axes along which different statistics reduced.
158166 :param lora_correction_params: parameters to configure the algorithm.
159167 :param layer_statistics: an object containing statistics for the layer.
168+ :param act_ch_axis: axis of the activation tensor that corresponds to its channel dimension.
160169 :param is_debug: whether to collect debug information, defaults to False.
161170 :return: two low rank matrices in the order of execution of corresponding linear layers and list of mean noises.
162171 Noises are collected from each step of the algorithm if debug was enabled.
@@ -170,7 +179,15 @@ def calculate_low_rank_matrices(
170179 )
171180 mode = compression_config .mode
172181 assert len (reduction_axes ) == 1 , "Assumed a single reduction axis"
173- reduction_axis = reduction_axes [0 ] if compression_config .group_size != - 1 else - 1
182+
183+ if compression_config .group_size != - 1 :
184+ reduction_axis = reduction_axes [0 ]
185+ else :
186+ reduction_axis = - 1
187+
188+ if transpose_a and reduction_axis != - 1 :
189+ reduction_axis = 1
190+
174191 if mode in (CompressWeightsMode .INT4_SYM , CompressWeightsMode .INT4_ASYM ):
175192 fq_weights = do_integer_dequantization (
176193 compressed_weight .tensor ,
@@ -194,16 +211,8 @@ def calculate_low_rank_matrices(
194211 svd_residual = fns .transpose (svd_residual )
195212 residual = svd_residual .clone () # [H, O]
196213
197- # Get the activation channel axis
198- act_ch_axis = getattr (layer_statistics , "act_ch_axis" , - 1 ) # default to last axis
199-
200- # Pass it to process_stats
201- s , X = process_stats (layer_statistics , subset_size , act_ch_axis )
202-
203- # Conditionally transpose X so samples are rows and channels are columns
204- if act_ch_axis != 0 : # if channel is not already the first axis
205- X = fns .transpose (X , axes = (1 , 0 )) # [SS, H]
206-
214+ # Pass act_ch_axis to process_stats with transpose_a=True to get the [SS, H] layout
215+ s , X = process_stats (layer_statistics , subset_size , act_ch_axis , transpose_a = True )
207216 if compression_config .group_size > 0 :
208217 # Multiply residual of weights by maximum channel magnitude of activations normalized per quantization
209218 # group. As a consequence, weights corresponding to "noisy" activations have a higher error to correct.
0 commit comments