Skip to content

Commit fdc13fe

Browse files
[CI] fix matrix over 255 & make torch tests reusable & separate cpu only tests to a new group. (#2792)
* [CI] make test reusable * [CI] add cpu tests, and wrap matrix with meta * [CI] add cpu tests, and wrap matrix with meta * [CI] print uv envs * [CI] use bash -l as default * [CI] move to top * Revert "[CI] move to top" This reverts commit d04a72a. * [CI] no need this test * ignore Agents.md
1 parent cd6e982 commit fdc13fe

5 files changed

Lines changed: 433 additions & 202 deletions

File tree

.github/scripts/ci_workflow.py

Lines changed: 76 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,12 @@ def should_skip_test(config_data: dict[str, Any], rel_path: str) -> bool:
206206
return bool(config.get("skip", False))
207207

208208

209-
def list_tests(ignored_test_files: str | list[str], test_names: str, test_regex: str, tests_root: str | Path) -> tuple[list[str], list[str], list[str]]:
209+
def list_tests(
210+
ignored_test_files: str | list[str],
211+
test_names: str,
212+
test_regex: str,
213+
tests_root: str | Path,
214+
) -> tuple[list[str], list[str], list[str], list[str]]:
210215
tests_root = Path(tests_root)
211216
input_tests = [strip_py_suffix(name) for name in split_csv(test_names)]
212217
ignored_raw = ignored_test_files if isinstance(ignored_test_files, list) else split_csv(ignored_test_files)
@@ -233,11 +238,24 @@ def list_tests(ignored_test_files: str | list[str], test_names: str, test_regex:
233238
and is_model_compat_test(rel, path)
234239
}
235240

241+
cpu_tests = {
242+
rel
243+
for rel, path in all_tests.items()
244+
if (not input_tests or rel in input_tests)
245+
and rel not in model_tests
246+
and "mlx" not in rel
247+
and "ipex" not in rel
248+
and "xpu" not in rel
249+
and matches_test_regex(compiled_test_regex, rel)
250+
and has_no_gpu_marker(path)
251+
}
252+
236253
torch_tests = {
237254
rel
238255
for rel in all_tests
239256
if (not input_tests or rel in input_tests)
240257
and rel not in model_tests
258+
and rel not in cpu_tests
241259
and "mlx" not in rel
242260
and "ipex" not in rel
243261
and "xpu" not in rel
@@ -253,34 +271,61 @@ def list_tests(ignored_test_files: str | list[str], test_names: str, test_regex:
253271
}
254272

255273
return (
274+
sorted(cpu_tests, key=lambda rel: sort_key(rel, all_tests[rel])),
256275
sorted(torch_tests, key=lambda rel: sort_key(rel, all_tests[rel])),
257276
sorted(model_tests, key=lambda rel: sort_key(rel, all_tests[rel])),
258277
sorted(mlx_tests, key=lambda rel: sort_key(rel, all_tests[rel])),
259278
)
260279

261280

262-
def build_test_matrix(torch_tests: list[str], model_tests: list[str]) -> list[dict[str, str]]:
263-
entries = [
264-
TestMatrixEntry(
265-
test_script=test_script,
266-
test_group="torch",
267-
alloc_gpu_count="resolved",
268-
require_single_gpu="false",
269-
include_model_test_mode="false",
270-
)
271-
for test_script in torch_tests
272-
]
273-
entries.extend(
274-
TestMatrixEntry(
275-
test_script=test_script,
276-
test_group="model",
277-
alloc_gpu_count="1",
278-
require_single_gpu="true",
279-
include_model_test_mode="true",
280-
)
281-
for test_script in model_tests
282-
)
283-
return [entry.as_dict() for entry in entries]
281+
def build_group_matrix(group: str, tests: list[str]) -> list[dict[str, str]]:
282+
if group == "cpu":
283+
return [
284+
TestMatrixEntry(
285+
test_script=test_script,
286+
test_group="cpu",
287+
alloc_gpu_count="0",
288+
require_single_gpu="false",
289+
include_model_test_mode="false",
290+
).as_dict()
291+
for test_script in tests
292+
]
293+
if group == "torch":
294+
return [
295+
TestMatrixEntry(
296+
test_script=test_script,
297+
test_group="torch",
298+
alloc_gpu_count="resolved",
299+
require_single_gpu="false",
300+
include_model_test_mode="false",
301+
).as_dict()
302+
for test_script in tests
303+
]
304+
if group == "model":
305+
return [
306+
TestMatrixEntry(
307+
test_script=test_script,
308+
test_group="model",
309+
alloc_gpu_count="1",
310+
require_single_gpu="true",
311+
include_model_test_mode="true",
312+
).as_dict()
313+
for test_script in tests
314+
]
315+
raise ValueError(f"unsupported test group: {group}")
316+
317+
318+
def build_test_matrices(
319+
*,
320+
cpu_tests: list[str],
321+
torch_tests: list[str],
322+
model_tests: list[str],
323+
) -> dict[str, list[dict[str, str]]]:
324+
return {
325+
"cpu_matrix": build_group_matrix("cpu", cpu_tests),
326+
"torch_matrix": build_group_matrix("torch", torch_tests),
327+
"model_matrix": build_group_matrix("model", model_tests),
328+
}
284329

285330

286331
def build_test_plan(
@@ -290,17 +335,23 @@ def build_test_plan(
290335
test_regex: str,
291336
tests_root: str | Path,
292337
) -> dict[str, list[dict[str, str]] | list[str]]:
293-
torch_tests, model_tests, mlx_tests = list_tests(
338+
cpu_tests, torch_tests, model_tests, mlx_tests = list_tests(
294339
ignored_test_files=ignored_test_files,
295340
test_names=test_names,
296341
test_regex=test_regex,
297342
tests_root=tests_root,
298343
)
344+
matrices = build_test_matrices(
345+
cpu_tests=cpu_tests,
346+
torch_tests=torch_tests,
347+
model_tests=model_tests,
348+
)
299349
return {
350+
"cpu_files": cpu_tests,
300351
"torch_files": torch_tests,
301352
"model_files": model_tests,
302353
"mlx_files": mlx_tests,
303-
"test_matrix": build_test_matrix(torch_tests, model_tests),
354+
**matrices,
304355
}
305356

306357

0 commit comments

Comments
 (0)