|
| 1 | +import argparse |
| 2 | +import ast |
| 3 | +import json |
| 4 | +from pathlib import Path |
| 5 | + |
# The repository root is the parent of the directory that contains this script.
REPO_ROOT = Path(__file__).resolve().parents[1]
E2E_ROOT = REPO_ROOT.joinpath("tests", "e2e")

# Recorded weights for the known e2e files, keyed by repo-relative POSIX path.
# They keep the shard balance close to the observed runtime split; a file that
# is not listed here is weighted by its own test count instead.
HISTORICAL_WEIGHTS = {
    "tests/e2e/test_batch_evaluation.py": 41,
    "tests/e2e/test_core_sdk.py": 53,
    "tests/e2e/test_datasets.py": 7,
    "tests/e2e/test_decorators.py": 32,
    "tests/e2e/test_experiments.py": 17,
    "tests/e2e/test_media.py": 1,
    "tests/e2e/test_prompt.py": 27,
}
| 20 | + |
| 21 | + |
def relative_test_path(path: Path) -> str:
    """Return ``path`` as a POSIX-style string relative to ``REPO_ROOT``.

    Raises:
        ValueError: If ``path`` is not located under ``REPO_ROOT``.
    """
    inside_repo = path.relative_to(REPO_ROOT)
    return inside_repo.as_posix()
| 24 | + |
| 25 | + |
def discover_e2e_files() -> list[Path]:
    """Return the ``test_*.py`` files directly inside ``E2E_ROOT``, sorted."""
    candidates = list(E2E_ROOT.glob("test_*.py"))
    candidates.sort()
    return candidates
| 28 | + |
| 29 | + |
| 30 | +def count_test_functions(path: Path) -> int: |
| 31 | + module = ast.parse(path.read_text(encoding="utf-8")) |
| 32 | + return sum( |
| 33 | + 1 |
| 34 | + for node in module.body |
| 35 | + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) |
| 36 | + and node.name.startswith("test_") |
| 37 | + ) |
| 38 | + |
| 39 | + |
def estimate_weight(path: Path) -> int:
    """Return the scheduling weight for one test file.

    A file listed in ``HISTORICAL_WEIGHTS`` gets its recorded weight. Any
    other file, including one whose repo-relative path cannot be computed,
    is weighted by its number of test functions, with a floor of 1.
    """
    try:
        key = relative_test_path(path)
    except ValueError:
        # No repo-relative path, so the file cannot appear in the table.
        pass
    else:
        if key in HISTORICAL_WEIGHTS:
            return HISTORICAL_WEIGHTS[key]

    counted = count_test_functions(path)
    # Never report a weight below 1, even for a file with no detected tests.
    return counted if counted > 1 else 1
| 49 | + |
| 50 | + |
def assign_shards(
    paths: list[Path], shard_count: int
) -> tuple[list[list[str]], list[int]]:
    """Distribute test files across ``shard_count`` shards greedily.

    Files are handled heaviest first (equal weights ordered by path), and
    each one goes to the shard with the smallest load so far; when loads tie,
    the lowest shard index wins.

    Args:
        paths: Test files to distribute.
        shard_count: Number of shards to fill.

    Returns:
        A pair ``(shards, shard_loads)``: the sorted repo-relative paths
        assigned to each shard, and the summed weight of each shard.
    """
    buckets: list[list[str]] = [[] for _ in range(shard_count)]
    loads = [0] * shard_count

    queue = [(estimate_weight(p), relative_test_path(p)) for p in paths]
    queue.sort(key=lambda entry: (-entry[0], entry[1]))

    for weight, name in queue:
        # min() keeps the first minimal element, i.e. the lowest index on ties.
        lightest = min(range(shard_count), key=loads.__getitem__)
        buckets[lightest].append(name)
        loads[lightest] += weight

    ordered_buckets = [sorted(bucket) for bucket in buckets]
    return ordered_buckets, loads
| 70 | + |
| 71 | + |
def parse_args() -> argparse.Namespace:
    """Parse the shard-selection options from the command line.

    Options: ``--shard-index`` (int, required), ``--shard-count`` (int,
    default 2), and the ``--json`` flag.
    """
    cli = argparse.ArgumentParser(
        description="Select the files for one e2e CI shard."
    )
    option_specs = (
        ("--shard-index", {"required": True, "type": int}),
        ("--shard-count", {"default": 2, "type": int}),
        ("--json", {"action": "store_true"}),
    )
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
| 80 | + |
| 81 | + |
def main() -> int:
    """Print the test files that belong to the requested shard.

    With ``--json`` a single JSON object is printed that holds the shard
    count, the shard index, the selected files and every shard's load.
    Without it, the selected files are printed one per line.

    Returns:
        0 on success.

    Raises:
        SystemExit: With an explanatory message when ``--shard-count`` is
            below 1 or ``--shard-index`` is outside ``[0, shard_count)``.
    """
    args = parse_args()
    total, wanted = args.shard_count, args.shard_index

    if total < 1:
        raise SystemExit("--shard-count must be at least 1")
    if not 0 <= wanted < total:
        raise SystemExit("--shard-index must be within the configured shard count")

    shards, shard_loads = assign_shards(discover_e2e_files(), total)
    selected_files = shards[wanted]

    if not args.json:
        for path in selected_files:
            print(path)
        return 0

    report = {
        "shard_count": total,
        "shard_index": wanted,
        "selected_files": selected_files,
        "shard_loads": shard_loads,
    }
    print(json.dumps(report))
    return 0
| 111 | + |
| 112 | + |
# Run the CLI only when this file is executed as a script; the value returned
# by main() becomes the process exit status.
if __name__ == "__main__":
    exit_status = main()
    raise SystemExit(exit_status)
0 commit comments