@@ -206,7 +206,12 @@ def should_skip_test(config_data: dict[str, Any], rel_path: str) -> bool:
206206 return bool (config .get ("skip" , False ))
207207
208208
209- def list_tests (ignored_test_files : str | list [str ], test_names : str , test_regex : str , tests_root : str | Path ) -> tuple [list [str ], list [str ], list [str ]]:
209+ def list_tests (
210+ ignored_test_files : str | list [str ],
211+ test_names : str ,
212+ test_regex : str ,
213+ tests_root : str | Path ,
214+ ) -> tuple [list [str ], list [str ], list [str ], list [str ]]:
210215 tests_root = Path (tests_root )
211216 input_tests = [strip_py_suffix (name ) for name in split_csv (test_names )]
212217 ignored_raw = ignored_test_files if isinstance (ignored_test_files , list ) else split_csv (ignored_test_files )
@@ -233,11 +238,24 @@ def list_tests(ignored_test_files: str | list[str], test_names: str, test_regex:
233238 and is_model_compat_test (rel , path )
234239 }
235240
241+ cpu_tests = {
242+ rel
243+ for rel , path in all_tests .items ()
244+ if (not input_tests or rel in input_tests )
245+ and rel not in model_tests
246+ and "mlx" not in rel
247+ and "ipex" not in rel
248+ and "xpu" not in rel
249+ and matches_test_regex (compiled_test_regex , rel )
250+ and has_no_gpu_marker (path )
251+ }
252+
236253 torch_tests = {
237254 rel
238255 for rel in all_tests
239256 if (not input_tests or rel in input_tests )
240257 and rel not in model_tests
258+ and rel not in cpu_tests
241259 and "mlx" not in rel
242260 and "ipex" not in rel
243261 and "xpu" not in rel
@@ -253,34 +271,61 @@ def list_tests(ignored_test_files: str | list[str], test_names: str, test_regex:
253271 }
254272
255273 return (
274+ sorted (cpu_tests , key = lambda rel : sort_key (rel , all_tests [rel ])),
256275 sorted (torch_tests , key = lambda rel : sort_key (rel , all_tests [rel ])),
257276 sorted (model_tests , key = lambda rel : sort_key (rel , all_tests [rel ])),
258277 sorted (mlx_tests , key = lambda rel : sort_key (rel , all_tests [rel ])),
259278 )
260279
261280
262- def build_test_matrix (torch_tests : list [str ], model_tests : list [str ]) -> list [dict [str , str ]]:
263- entries = [
264- TestMatrixEntry (
265- test_script = test_script ,
266- test_group = "torch" ,
267- alloc_gpu_count = "resolved" ,
268- require_single_gpu = "false" ,
269- include_model_test_mode = "false" ,
270- )
271- for test_script in torch_tests
272- ]
273- entries .extend (
274- TestMatrixEntry (
275- test_script = test_script ,
276- test_group = "model" ,
277- alloc_gpu_count = "1" ,
278- require_single_gpu = "true" ,
279- include_model_test_mode = "true" ,
280- )
281- for test_script in model_tests
282- )
283- return [entry .as_dict () for entry in entries ]
281+ def build_group_matrix (group : str , tests : list [str ]) -> list [dict [str , str ]]:
282+ if group == "cpu" :
283+ return [
284+ TestMatrixEntry (
285+ test_script = test_script ,
286+ test_group = "cpu" ,
287+ alloc_gpu_count = "0" ,
288+ require_single_gpu = "false" ,
289+ include_model_test_mode = "false" ,
290+ ).as_dict ()
291+ for test_script in tests
292+ ]
293+ if group == "torch" :
294+ return [
295+ TestMatrixEntry (
296+ test_script = test_script ,
297+ test_group = "torch" ,
298+ alloc_gpu_count = "resolved" ,
299+ require_single_gpu = "false" ,
300+ include_model_test_mode = "false" ,
301+ ).as_dict ()
302+ for test_script in tests
303+ ]
304+ if group == "model" :
305+ return [
306+ TestMatrixEntry (
307+ test_script = test_script ,
308+ test_group = "model" ,
309+ alloc_gpu_count = "1" ,
310+ require_single_gpu = "true" ,
311+ include_model_test_mode = "true" ,
312+ ).as_dict ()
313+ for test_script in tests
314+ ]
315+ raise ValueError (f"unsupported test group: { group } " )
316+
317+
318+ def build_test_matrices (
319+ * ,
320+ cpu_tests : list [str ],
321+ torch_tests : list [str ],
322+ model_tests : list [str ],
323+ ) -> dict [str , list [dict [str , str ]]]:
324+ return {
325+ "cpu_matrix" : build_group_matrix ("cpu" , cpu_tests ),
326+ "torch_matrix" : build_group_matrix ("torch" , torch_tests ),
327+ "model_matrix" : build_group_matrix ("model" , model_tests ),
328+ }
284329
285330
286331def build_test_plan (
@@ -290,17 +335,23 @@ def build_test_plan(
290335 test_regex : str ,
291336 tests_root : str | Path ,
292337) -> dict [str , list [dict [str , str ]] | list [str ]]:
293- torch_tests , model_tests , mlx_tests = list_tests (
338+ cpu_tests , torch_tests , model_tests , mlx_tests = list_tests (
294339 ignored_test_files = ignored_test_files ,
295340 test_names = test_names ,
296341 test_regex = test_regex ,
297342 tests_root = tests_root ,
298343 )
344+ matrices = build_test_matrices (
345+ cpu_tests = cpu_tests ,
346+ torch_tests = torch_tests ,
347+ model_tests = model_tests ,
348+ )
299349 return {
350+ "cpu_files" : cpu_tests ,
300351 "torch_files" : torch_tests ,
301352 "model_files" : model_tests ,
302353 "mlx_files" : mlx_tests ,
303- "test_matrix" : build_test_matrix ( torch_tests , model_tests ) ,
354+ ** matrices ,
304355 }
305356
306357
0 commit comments