Skip to content

Commit 3f16e5b

Browse files
authored
Arm backend: Add rank2 unit tests to test_linear (#17254)
Summary: Add rank2 unit tests to test_linear Differential Revision: D92111806
1 parent af22f84 commit 3f16e5b

1 file changed

Lines changed: 47 additions & 3 deletions

File tree

backends/arm/test/ops/test_linear.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,40 @@
6060
),
6161
}
6262

63+
test_data_rank2_FP = {
64+
# test_name: (test_data, out_features, has_bias)
65+
"model_linear_rank2_zeros": lambda: (
66+
torch.zeros(10, 20),
67+
15,
68+
True,
69+
),
70+
"model_linear_rank2_ones": lambda: (
71+
torch.ones(2, 240),
72+
960,
73+
False,
74+
),
75+
"model_linear_rank2_negative_ones": lambda: (
76+
torch.ones(10, 20) * (-1),
77+
20,
78+
True,
79+
),
80+
"model_linear_rank2_rand": lambda: (
81+
torch.rand(2, 240),
82+
960,
83+
True,
84+
),
85+
"model_linear_rank2_negative_large_rand": lambda: (
86+
torch.rand(10, 20) * (-100),
87+
30,
88+
False,
89+
),
90+
"model_linear_rank2_large_randn": lambda: (
91+
torch.randn(15, 20) * 100,
92+
20,
93+
True,
94+
),
95+
}
96+
6397
test_data_rank4_FP = {
6498
# test_name: (test_data, out_features, has_bias)
6599
"model_linear_rank4_zeros": lambda: (
@@ -101,6 +135,13 @@
101135
for q in [True, False]
102136
}
103137

138+
# Generate a new test set paired with per_channel_quant=True/False.
139+
test_data_rank2_INT = {
140+
f"{k},per_channel_quant={q}": (lambda v=v, q=q: (*v(), q))
141+
for (k, v) in test_data_rank2_FP.items()
142+
for q in [True, False]
143+
}
144+
104145
# Generate a new test set paired with per_channel_quant=True/False.
105146
test_data_rank4_INT = {
106147
f"{k},per_channel_quant={q}": (lambda v=v, q=q: (*v(), q))
@@ -192,7 +233,10 @@ def test_linear_tosa_INT_a8w4(test_data: torch.Tensor):
192233
pipeline.run()
193234

194235

195-
@common.parametrize("test_data", test_data_rank1_INT)
236+
@common.parametrize(
237+
"test_data",
238+
test_data_rank1_INT | test_data_rank2_INT | test_data_rank4_INT,
239+
)
196240
@common.XfailIfNoCorstone300
197241
def test_linear_u55_INT(test_data: torch.Tensor):
198242
test_data, out_features, has_bias, per_channel_quantization = test_data()
@@ -213,7 +257,7 @@ def test_linear_u55_INT(test_data: torch.Tensor):
213257

214258
@common.parametrize(
215259
"test_data",
216-
test_data_rank1_INT | test_data_rank4_INT,
260+
test_data_rank1_INT | test_data_rank2_INT | test_data_rank4_INT,
217261
)
218262
@common.XfailIfNoCorstone320
219263
def test_linear_u85_INT(test_data: torch.Tensor):
@@ -281,7 +325,7 @@ def test_linear_vgf_quant_a8w4(test_data: torch.Tensor):
281325
pipeline.run()
282326

283327

284-
test_data_all_16a8w = test_data_rank1_INT | test_data_rank4_INT
328+
test_data_all_16a8w = test_data_rank1_INT | test_data_rank2_INT | test_data_rank4_INT
285329

286330

287331
@common.parametrize("test_data", test_data_all_16a8w)

0 commit comments

Comments
 (0)