Skip to content

Commit 7b768a9

Browse files
committed
fix: harden workflow gates and quality signal enforcement
Close gate bypass paths in pack/generator/stress flow, enforce audit prerequisites, and standardize structured quality signals for verification and workflow state tracking. Made-with: Cursor
1 parent f856d36 commit 7b768a9

7 files changed

Lines changed: 276 additions & 5 deletions

File tree

scripts/workflow_guard.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import json
44
import sys
5+
from datetime import datetime, timezone
56
from collections.abc import Callable
67
from pathlib import Path
78
from typing import Any
@@ -13,6 +14,9 @@
1314
"require_stress_passed": True,
1415
"require_validation_passed": True,
1516
"require_tests_verified": True,
17+
"require_limit_semantics": True,
18+
"require_wrong_solution_kill": True,
19+
"require_validator_check": True,
1620
"min_limit_case_ratio": 0.5,
1721
}
1822

@@ -93,6 +97,8 @@ def infer_state(problem_dir: str) -> dict[str, Any]:
9397
"sol_built": False,
9498
"brute_built": False,
9599
"solution_analyzed": False,
100+
"std_audited": False,
101+
"brute_audited": False,
96102
"validator_ready": False,
97103
"validator_accuracy": None,
98104
"generator_built": False,
@@ -108,8 +114,10 @@ def infer_state(problem_dir: str) -> dict[str, Any]:
108114
"tests_generated": False,
109115
"generated_test_count": 0,
110116
"tests_verified": False,
117+
"verify_signals": {},
111118
"packaged": (root / "problem.xml").exists(),
112119
"quality_gates": _extract_quality_gates(manifest),
120+
"history": [],
113121
}
114122

115123

@@ -217,6 +225,10 @@ def _tests_verified_required_gate_ok(state: dict[str, Any], _: dict[str, Any]) -
217225
)
218226

219227

228+
def _audit_gate_ok(state: dict[str, Any], _: dict[str, Any]) -> bool:
229+
return bool(state.get("std_audited")) and bool(state.get("brute_audited"))
230+
231+
220232
def _min_limit_ratio_gate_ok(state: dict[str, Any], _: dict[str, Any]) -> bool:
221233
gates = state.get("quality_gates", {})
222234
if not isinstance(gates, dict):
@@ -236,6 +248,53 @@ def _min_limit_ratio_gate_ok(state: dict[str, Any], _: dict[str, Any]) -> bool:
236248
return ratio_val >= required
237249

238250

251+
def _required_verify_signal_ok(state: dict[str, Any], gate_key: str, signal_name: str) -> bool:
252+
if not _quality_gate_enabled(state, gate_key):
253+
return True
254+
verify_signals = state.get("verify_signals", {})
255+
if not isinstance(verify_signals, dict):
256+
return False
257+
signal = verify_signals.get(signal_name, {})
258+
if not isinstance(signal, dict):
259+
return False
260+
return bool(signal.get("executed")) and bool(signal.get("passed"))
261+
262+
263+
def _limit_semantics_gate_ok(state: dict[str, Any], _: dict[str, Any]) -> bool:
264+
return _required_verify_signal_ok(state, "require_limit_semantics", "limit_semantics")
265+
266+
267+
def _wrong_solution_kill_gate_ok(state: dict[str, Any], _: dict[str, Any]) -> bool:
268+
return _required_verify_signal_ok(state, "require_wrong_solution_kill", "wrong_solution_kill")
269+
270+
271+
def _validator_check_gate_ok(state: dict[str, Any], _: dict[str, Any]) -> bool:
272+
return _required_verify_signal_ok(state, "require_validator_check", "validator_check")
273+
274+
275+
def _append_history(
276+
state: dict[str, Any],
277+
*,
278+
tool: str,
279+
success: bool,
280+
key_metrics: dict[str, Any] | None = None,
281+
gate_result: str = "n/a",
282+
) -> None:
283+
history = state.get("history")
284+
if not isinstance(history, list):
285+
history = []
286+
history.append(
287+
{
288+
"tool": tool,
289+
"success": success,
290+
"timestamp": datetime.now(timezone.utc).isoformat(),
291+
"gate_result": gate_result,
292+
"key_metrics": key_metrics or {},
293+
}
294+
)
295+
state["history"] = history[-200:]
296+
297+
239298
PRE_GATES: dict[str, list[Gate]] = {
240299
"solution_build": [
241300
(lambda s, i: bool(s.get("created")), "必须先运行 problem_create 创建题目目录。"),
@@ -247,21 +306,40 @@ def _min_limit_ratio_gate_ok(state: dict[str, Any], _: dict[str, Any]) -> bool:
247306
"solution_analyze": [
248307
(lambda s, i: bool(s.get("sol_built")), "必须先构建标准解 sol,再进行复杂度分析。"),
249308
],
309+
"solution_audit_std": [
310+
(lambda s, i: bool(s.get("sol_built")), "必须先构建标准解 sol。"),
311+
(lambda s, i: bool(s.get("solution_analyzed")), "必须先运行 solution_analyze。"),
312+
],
313+
"solution_audit_brute": [
314+
(lambda s, i: bool(s.get("sol_built")), "必须先构建标准解 sol。"),
315+
(lambda s, i: bool(s.get("brute_built")), "必须先构建 brute。"),
316+
(lambda s, i: bool(s.get("solution_analyzed")), "必须先运行 solution_analyze。"),
317+
(lambda s, i: bool(s.get("std_audited")), "必须先完成 solution_audit_std。"),
318+
],
250319
"validator_select": [
251320
(lambda s, i: bool(s.get("validator_ready")), "必须先完成 validator_build 才能选择校验器版本。"),
252321
],
253322
"validator_build": [
254323
(lambda s, i: bool(s.get("created")), "必须先运行 problem_create 创建题目目录。"),
255324
(lambda s, i: bool(s.get("sol_built")), "必须先构建标准解 sol。"),
256325
(lambda s, i: bool(s.get("solution_analyzed")), "必须先运行 solution_analyze,再构建 validator。"),
326+
(_audit_gate_ok, "必须先完成 solution_audit_std 与 solution_audit_brute。"),
257327
(_is_non_interactive, "交互题不应构建 validator,应改用 interactor_build。"),
258328
(lambda s, i: bool(s.get("brute_built")), "必须先构建 brute,再构建 validator。"),
259329
],
260330
"interactor_build": [
261331
(lambda s, i: bool(s.get("created")), "必须先运行 problem_create 创建题目目录。"),
262332
(_is_interactive, "只有交互题可运行 interactor_build。请在 problem_create 设 interactive=true。"),
333+
(lambda s, i: bool(s.get("sol_built")), "必须先构建标准解 sol。"),
334+
(lambda s, i: bool(s.get("brute_built")), "必须先构建 brute。"),
335+
(lambda s, i: bool(s.get("solution_analyzed")), "必须先运行 solution_analyze。"),
336+
(_audit_gate_ok, "必须先完成 solution_audit_std 与 solution_audit_brute。"),
263337
],
264338
"generator_build": [
339+
(lambda s, i: bool(s.get("sol_built")), "必须先构建标准解 sol。"),
340+
(lambda s, i: bool(s.get("brute_built")), "必须先构建 brute。"),
341+
(lambda s, i: bool(s.get("solution_analyzed")), "必须先运行 solution_analyze。"),
342+
(_audit_gate_ok, "必须先完成 solution_audit_std 与 solution_audit_brute。"),
265343
(
266344
_validator_gate_ok,
267345
"必须先完成 validator_build,并且 validator accuracy >= 0.9。",
@@ -274,6 +352,8 @@ def _min_limit_ratio_gate_ok(state: dict[str, Any], _: dict[str, Any]) -> bool:
274352
"stress_test_run": [
275353
(lambda s, i: bool(s.get("sol_built")), "必须先构建标准解 sol。"),
276354
(lambda s, i: bool(s.get("brute_built")), "必须先构建 brute。"),
355+
(lambda s, i: bool(s.get("solution_analyzed")), "必须先运行 solution_analyze。"),
356+
(_audit_gate_ok, "必须先完成 solution_audit_std 与 solution_audit_brute。"),
277357
(
278358
_validator_gate_ok,
279359
"必须先完成 validator_build(accuracy >= 0.9),再进行 stress_test_run。",
@@ -308,6 +388,9 @@ def _min_limit_ratio_gate_ok(state: dict[str, Any], _: dict[str, Any]) -> bool:
308388
),
309389
(_tests_verified_required_gate_ok, "必须先通过 problem_verify_tests(passed=true),再进行打包。"),
310390
(_min_limit_ratio_gate_ok, "最终测试中的极限样例占比未达到 quality_gates.min_limit_case_ratio。"),
391+
(_limit_semantics_gate_ok, "最终测试未通过 limit_semantics,不能打包。"),
392+
(_wrong_solution_kill_gate_ok, "最终测试未通过 wrong_solution_kill,不能打包。"),
393+
(_validator_check_gate_ok, "最终测试未通过 validator_check,不能打包。"),
311394
],
312395
}
313396

@@ -345,18 +428,39 @@ def post_tool(payload: dict[str, Any]) -> int:
345428
if not success:
346429
state["tests_generated"] = False
347430
state["generated_test_count"] = 0
431+
_append_history(
432+
state,
433+
tool=short_name,
434+
success=False,
435+
gate_result="post",
436+
key_metrics={"generated_test_count": 0},
437+
)
348438
save_state(problem_dir, state)
349439
return 0
350440
generated_tests = data.get("generated_tests", [])
351441
state["tests_generated"] = bool(generated_tests)
352442
state["generated_test_count"] = len(generated_tests)
443+
_append_history(
444+
state,
445+
tool=short_name,
446+
success=True,
447+
gate_result="post",
448+
key_metrics={"generated_test_count": len(generated_tests)},
449+
)
353450
save_state(problem_dir, state)
354451
return 0
355452

356453
if short_name == "problem_validate":
357454
state["statement_validated"] = data.get("statement_samples", {}).get("validated", False)
358455
state["sample_files_validated"] = data.get("sample_files", {}).get("validated", False)
359456
state["validation_passed"] = success
457+
_append_history(
458+
state,
459+
tool=short_name,
460+
success=success,
461+
gate_result="post",
462+
key_metrics={"validation_passed": success},
463+
)
360464
save_state(problem_dir, state)
361465
return 0
362466

@@ -380,15 +484,34 @@ def post_tool(payload: dict[str, Any]) -> int:
380484
ratio_ok = float(state.get("limit_case_ratio")) >= min_limit_ratio
381485
except (TypeError, ValueError):
382486
ratio_ok = False
487+
quality_signals = data.get("quality_signals", {})
488+
if isinstance(quality_signals, dict):
489+
state["verify_signals"] = quality_signals
490+
else:
491+
state["verify_signals"] = {}
383492
state["tests_verified"] = success and bool(data.get("passed", False))
384493
state["tests_verified"] = state["tests_verified"] and ratio_ok
494+
_append_history(
495+
state,
496+
tool=short_name,
497+
success=state["tests_verified"],
498+
gate_result="post",
499+
key_metrics={
500+
"tests_verified": state["tests_verified"],
501+
"limit_case_ratio": state.get("limit_case_ratio"),
502+
},
503+
)
385504
save_state(problem_dir, state)
386505
return 0
387506

388507
if not success:
389508
if short_name == "validator_build":
390509
state["validator_ready"] = False
391510
state["validator_accuracy"] = None
511+
elif short_name == "solution_audit_std":
512+
state["std_audited"] = False
513+
elif short_name == "solution_audit_brute":
514+
state["brute_audited"] = False
392515
elif short_name == "generator_build":
393516
state["generator_built"] = False
394517
elif short_name == "stress_test_run":
@@ -400,6 +523,7 @@ def post_tool(payload: dict[str, Any]) -> int:
400523
state["checker_ready"] = False
401524
elif short_name == "problem_pack_polygon":
402525
state["packaged"] = False
526+
_append_history(state, tool=short_name, success=False, gate_result="post")
403527
save_state(problem_dir, state)
404528
return 0
405529

@@ -410,10 +534,20 @@ def post_tool(payload: dict[str, Any]) -> int:
410534
solution_type = payload.get("tool_input", {}).get("solution_type")
411535
if solution_type == "sol":
412536
state["sol_built"] = True
537+
state["solution_analyzed"] = False
538+
state["std_audited"] = False
539+
state["brute_audited"] = False
413540
elif solution_type == "brute":
414541
state["brute_built"] = True
542+
state["brute_audited"] = False
415543
elif short_name == "solution_analyze":
416544
state["solution_analyzed"] = True
545+
state["std_audited"] = False
546+
state["brute_audited"] = False
547+
elif short_name == "solution_audit_std":
548+
state["std_audited"] = True
549+
elif short_name == "solution_audit_brute":
550+
state["brute_audited"] = True
417551
elif short_name == "validator_build":
418552
accuracy = data.get("accuracy")
419553
state["validator_accuracy"] = accuracy
@@ -437,6 +571,7 @@ def post_tool(payload: dict[str, Any]) -> int:
437571
elif short_name == "problem_pack_polygon":
438572
state["packaged"] = True
439573

574+
_append_history(state, tool=short_name, success=True, gate_result="post")
440575
save_state(problem_dir, state)
441576
return 0
442577

src/autocode_mcp/templates/autocode.json

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
"require_stress_passed": true,
1414
"require_validation_passed": true,
1515
"require_tests_verified": true,
16+
"require_limit_semantics": true,
17+
"require_wrong_solution_kill": true,
18+
"require_validator_check": true,
1619
"min_limit_case_ratio": 0.5
1720
},
1821
"solutions": [

src/autocode_mcp/tools/generator.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,7 @@ async def execute(
301301

302302
generated_inputs = []
303303
signatures = set() # 用于去重
304+
generator_failures: list[dict[str, object]] = []
304305

305306
# 策略映射到 type 参数
306307
strategy_type_map = {
@@ -318,7 +319,7 @@ async def execute(
318319
attempts += 1
319320

320321
# 选择策略
321-
strategy = strategies[attempts % len(strategies)]
322+
strategy = strategies[(attempts - 1) % len(strategies)]
322323
type_param = strategy_type_map.get(strategy, "2")
323324

324325
# 运行 generator
@@ -331,9 +332,30 @@ async def execute(
331332
timeout=10,
332333
)
333334

334-
# 只要有输出就接受(某些 generator 可能返回非零退出码但仍产生有效输出)
335+
if not gen_result.success:
336+
generator_failures.append(
337+
{
338+
"seed": seed,
339+
"strategy": strategy,
340+
"return_code": gen_result.return_code,
341+
"stderr": (gen_result.stderr or "")[:200],
342+
}
343+
)
344+
seed += 1
345+
continue
346+
335347
input_data = gen_result.stdout
336348
if not input_data or not input_data.strip():
349+
generator_failures.append(
350+
{
351+
"seed": seed,
352+
"strategy": strategy,
353+
"return_code": gen_result.return_code,
354+
"stderr": (gen_result.stderr or "")[:200],
355+
"reason": "empty_stdout",
356+
}
357+
)
358+
seed += 1
337359
continue
338360

339361
# 计算 signature 用于去重
@@ -362,5 +384,6 @@ async def execute(
362384
test_count=test_count,
363385
inputs=generated_inputs[:test_count],
364386
strategies_used=strategies,
387+
generator_failures=generator_failures[-20:],
365388
message=f"Generated {len(generated_inputs)} test inputs",
366389
)

0 commit comments

Comments
 (0)