6060 ),
6161}
6262
63+ test_data_rank2_FP = {
64+ # test_name: (test_data, out_features, has_bias)
65+ "model_linear_rank2_zeros" : lambda : (
66+ torch .zeros (10 , 20 ),
67+ 15 ,
68+ True ,
69+ ),
70+ "model_linear_rank2_ones" : lambda : (
71+ torch .ones (2 , 240 ),
72+ 960 ,
73+ False ,
74+ ),
75+ "model_linear_rank2_negative_ones" : lambda : (
76+ torch .ones (10 , 20 ) * (- 1 ),
77+ 20 ,
78+ True ,
79+ ),
80+ "model_linear_rank2_rand" : lambda : (
81+ torch .rand (2 , 240 ),
82+ 960 ,
83+ True ,
84+ ),
85+ "model_linear_rank2_negative_large_rand" : lambda : (
86+ torch .rand (10 , 20 ) * (- 100 ),
87+ 30 ,
88+ False ,
89+ ),
90+ "model_linear_rank2_large_randn" : lambda : (
91+ torch .randn (15 , 20 ) * 100 ,
92+ 20 ,
93+ True ,
94+ ),
95+ }
96+
6397test_data_rank4_FP = {
6498 # test_name: (test_data, out_features, has_bias)
6599 "model_linear_rank4_zeros" : lambda : (
101135 for q in [True , False ]
102136}
103137
138+ # Generate a new test set paired with per_channel_quant=True/False.
139+ test_data_rank2_INT = {
140+ f"{ k } ,per_channel_quant={ q } " : (lambda v = v , q = q : (* v (), q ))
141+ for (k , v ) in test_data_rank2_FP .items ()
142+ for q in [True , False ]
143+ }
144+
104145# Generate a new test set paired with per_channel_quant=True/False.
105146test_data_rank4_INT = {
106147 f"{ k } ,per_channel_quant={ q } " : (lambda v = v , q = q : (* v (), q ))
@@ -192,7 +233,10 @@ def test_linear_tosa_INT_a8w4(test_data: torch.Tensor):
192233 pipeline .run ()
193234
194235
195- @common .parametrize ("test_data" , test_data_rank1_INT )
236+ @common .parametrize (
237+ "test_data" ,
238+ test_data_rank1_INT | test_data_rank2_INT | test_data_rank4_INT ,
239+ )
196240@common .XfailIfNoCorstone300
197241def test_linear_u55_INT (test_data : torch .Tensor ):
198242 test_data , out_features , has_bias , per_channel_quantization = test_data ()
@@ -213,7 +257,7 @@ def test_linear_u55_INT(test_data: torch.Tensor):
213257
214258@common .parametrize (
215259 "test_data" ,
216- test_data_rank1_INT | test_data_rank4_INT ,
260+ test_data_rank1_INT | test_data_rank2_INT | test_data_rank4_INT ,
217261)
218262@common .XfailIfNoCorstone320
219263def test_linear_u85_INT (test_data : torch .Tensor ):
@@ -281,7 +325,7 @@ def test_linear_vgf_quant_a8w4(test_data: torch.Tensor):
281325 pipeline .run ()
282326
283327
284- test_data_all_16a8w = test_data_rank1_INT | test_data_rank4_INT
328+ test_data_all_16a8w = test_data_rank1_INT | test_data_rank2_INT | test_data_rank4_INT
285329
286330
287331@common .parametrize ("test_data" , test_data_all_16a8w )
0 commit comments