Skip to content

Commit 312fca1

Browse files
committed
run plugin validation concurrently
1 parent 096813c commit 312fca1

2 files changed

Lines changed: 105 additions & 19 deletions

File tree

scripts/validate_plugins/run.py

Lines changed: 53 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import argparse
66
import asyncio
7+
import concurrent.futures
78
import hashlib
89
import json
910
import os
@@ -29,6 +30,7 @@
2930

3031
REQUIRED_METADATA_FIELDS = ("name", "desc", "version", "author")
3132
DEFAULT_CLONE_TIMEOUT = 120
33+
DEFAULT_MAX_WORKERS = 8
3234
CONFLICT_MARKERS = ("<<<<<<<", "=======", ">>>>>>>")
3335

3436

@@ -489,29 +491,60 @@ def validate_selected_plugins(
489491
work_dir: Path,
490492
clone_timeout: int,
491493
load_timeout: int,
494+
max_workers: int,
492495
) -> list[dict]:
493-
results = []
494496
total = len(selected)
495-
496-
for index, (plugin, plugin_data) in enumerate(selected, start=1):
497-
print(f"[{index}/{total}] Validating {plugin}", flush=True)
498-
result = validate_plugin(
499-
plugin=plugin,
500-
plugin_data=plugin_data,
501-
astrbot_path=astrbot_path,
502-
script_path=script_path,
503-
work_dir=work_dir,
504-
clone_timeout=clone_timeout,
505-
load_timeout=load_timeout,
497+
results: list[dict | None] = [None] * total
498+
499+
def task(index: int, plugin: str, plugin_data: dict) -> tuple[int, dict]:
500+
return (
501+
index,
502+
validate_plugin(
503+
plugin=plugin,
504+
plugin_data=plugin_data,
505+
astrbot_path=astrbot_path,
506+
script_path=script_path,
507+
work_dir=work_dir,
508+
clone_timeout=clone_timeout,
509+
load_timeout=load_timeout,
510+
),
506511
)
507-
results.append(result)
508512

509-
status = "PASS" if result.get("ok") else "FAIL"
510-
stage = result.get("stage", "unknown")
511-
message = result.get("message", "")
512-
print(f"[{index}/{total}] {status} {plugin} [{stage}] {message}", flush=True)
513+
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
514+
future_to_context: dict[concurrent.futures.Future, tuple[int, str]] = {}
513515

514-
return results
516+
for index, (plugin, plugin_data) in enumerate(selected, start=1):
517+
print(f"[{index}/{total}] Queued {plugin}", flush=True)
518+
future = executor.submit(task, index, plugin, plugin_data)
519+
future_to_context[future] = (index, plugin)
520+
521+
for future in concurrent.futures.as_completed(future_to_context):
522+
index, plugin = future_to_context[future]
523+
try:
524+
original_index, result = future.result()
525+
except Exception as exc:
526+
original_index = index
527+
result = build_result(
528+
plugin=plugin,
529+
repo="",
530+
normalized_repo_url=None,
531+
ok=False,
532+
stage="threadpool",
533+
message=str(exc),
534+
details=traceback.format_exc(),
535+
)
536+
537+
results[original_index - 1] = result
538+
status = "PASS" if result.get("ok") else "FAIL"
539+
stage = result.get("stage", "unknown")
540+
message = result.get("message", "")
541+
print(f"[{original_index}/{total}] {status} {plugin} [{stage}] {message}", flush=True)
542+
543+
finalized = [result for result in results if result is not None]
544+
if len(finalized) != total:
545+
raise RuntimeError("parallel validation finished with missing results")
546+
547+
return finalized
515548

516549

517550
class NullStub:
@@ -684,6 +717,7 @@ def build_parser() -> argparse.ArgumentParser:
684717
parser.add_argument("--work-dir")
685718
parser.add_argument("--clone-timeout", type=int, default=DEFAULT_CLONE_TIMEOUT)
686719
parser.add_argument("--load-timeout", type=int, default=300)
720+
parser.add_argument("--max-workers", type=int, default=DEFAULT_MAX_WORKERS)
687721
parser.add_argument("--worker", action="store_true")
688722
parser.add_argument("--plugin-source-dir")
689723
parser.add_argument("--plugin-dir-name")
@@ -736,6 +770,7 @@ def main() -> int:
736770
work_dir=work_dir,
737771
clone_timeout=args.clone_timeout,
738772
load_timeout=args.load_timeout,
773+
max_workers=args.max_workers,
739774
)
740775
finally:
741776
if temp_dir is not None:

tests/test_validate_plugins.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,13 @@ def test_load_plugins_index_rejects_non_dict_values(self):
386386

387387

388388
class ValidationProgressTests(unittest.TestCase):
389+
def test_build_parser_defaults_max_workers_to_eight(self):
390+
module = load_validator_module()
391+
392+
args = module.build_parser().parse_args(["--astrbot-path", "/tmp/AstrBot"])
393+
394+
self.assertEqual(args.max_workers, 8)
395+
389396
def test_validate_selected_plugins_emits_progress_and_result_lines(self):
390397
module = load_validator_module()
391398
selected = [
@@ -406,14 +413,58 @@ def test_validate_selected_plugins_emits_progress_and_result_lines(self):
406413
work_dir=Path("/tmp/work"),
407414
clone_timeout=60,
408415
load_timeout=300,
416+
max_workers=8,
409417
)
410418

411419
self.assertEqual(results, fake_results)
412420
self.assertEqual(validate_mock.call_count, 2)
413-
print_mock.assert_any_call("[1/2] Validating plugin-a", flush=True)
421+
print_mock.assert_any_call("[1/2] Queued plugin-a", flush=True)
414422
print_mock.assert_any_call("[1/2] PASS plugin-a [load] ok", flush=True)
415423
print_mock.assert_any_call("[2/2] FAIL plugin-b [metadata] invalid metadata.yaml", flush=True)
416424

425+
def test_validate_selected_plugins_preserves_result_order_with_out_of_order_completion(self):
426+
module = load_validator_module()
427+
selected = [
428+
("plugin-a", {"repo": "https://github.com/example/plugin-a"}),
429+
("plugin-b", {"repo": "https://github.com/example/plugin-b"}),
430+
("plugin-c", {"repo": "https://github.com/example/plugin-c"}),
431+
]
432+
futures = [mock.Mock(name="future-a"), mock.Mock(name="future-b"), mock.Mock(name="future-c")]
433+
future_to_result = {
434+
futures[0]: (1, {"plugin": "plugin-a", "ok": True, "stage": "load", "message": "a"}),
435+
futures[1]: (2, {"plugin": "plugin-b", "ok": False, "stage": "metadata", "message": "b"}),
436+
futures[2]: (3, {"plugin": "plugin-c", "ok": True, "stage": "load", "message": "c"}),
437+
}
438+
439+
executor = mock.MagicMock()
440+
executor.__enter__.return_value = executor
441+
executor.__exit__.return_value = False
442+
executor.submit.side_effect = futures
443+
444+
def future_result(future):
445+
return future_to_result[future]
446+
447+
for future in futures:
448+
future.result.side_effect = lambda _timeout=None, future=future: future_result(future)
449+
450+
with mock.patch.object(module.concurrent.futures, "ThreadPoolExecutor", return_value=executor) as pool_mock:
451+
with mock.patch.object(module.concurrent.futures, "as_completed", return_value=[futures[2], futures[0], futures[1]]):
452+
with mock.patch("builtins.print") as print_mock:
453+
results = module.validate_selected_plugins(
454+
selected=selected,
455+
astrbot_path=Path("/tmp/AstrBot"),
456+
script_path=Path("/tmp/run.py"),
457+
work_dir=Path("/tmp/work"),
458+
clone_timeout=60,
459+
load_timeout=300,
460+
max_workers=8,
461+
)
462+
463+
pool_mock.assert_called_once_with(max_workers=8)
464+
self.assertEqual([item["plugin"] for item in results], ["plugin-a", "plugin-b", "plugin-c"])
465+
print_mock.assert_any_call("[1/3] Queued plugin-a", flush=True)
466+
print_mock.assert_any_call("[3/3] PASS plugin-c [load] c", flush=True)
467+
417468

418469
class ValidatePluginTests(unittest.TestCase):
419470
def setUp(self):

0 commit comments

Comments
 (0)