Skip to content

Commit 2c3d671

Browse files
committed
fix: align workflow quality gates and pack checks
Synchronize workflow guard and pack tool gate semantics by honoring quality_gates consistently and enforcing limit-ratio thresholds with verification state updates. Made-with: Cursor
1 parent ea188e5 commit 2c3d671

4 files changed

Lines changed: 568 additions & 24 deletions

File tree

scripts/workflow_guard.py

Lines changed: 142 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@
99
STATE_DIR_NAME = ".autocode-workflow"
1010
STATE_FILE_NAME = "state.json"
1111
MANIFEST_FILE_NAME = "autocode.json"
12+
DEFAULT_QUALITY_GATES = {
13+
"require_stress_passed": True,
14+
"require_validation_passed": True,
15+
"require_tests_verified": True,
16+
"min_limit_case_ratio": 0.5,
17+
}
1218

1319

1420
def load_payload() -> dict[str, Any]:
@@ -50,16 +56,43 @@ def load_manifest(problem_dir: str) -> dict[str, Any]:
5056
return {}
5157

5258

59+
def _is_manifest_valid_for_created(manifest: dict[str, Any]) -> bool:
60+
problem_name = manifest.get("problem_name")
61+
return isinstance(problem_name, str) and bool(problem_name.strip())
62+
63+
64+
def _extract_quality_gates(manifest: dict[str, Any]) -> dict[str, Any]:
65+
configured = manifest.get("quality_gates")
66+
if not isinstance(configured, dict):
67+
configured = {}
68+
gates = dict(DEFAULT_QUALITY_GATES)
69+
for key in DEFAULT_QUALITY_GATES:
70+
if key in configured:
71+
gates[key] = configured[key]
72+
try:
73+
ratio = float(gates.get("min_limit_case_ratio", 0.5))
74+
except (TypeError, ValueError):
75+
ratio = 0.5
76+
gates["min_limit_case_ratio"] = min(1.0, max(0.0, ratio))
77+
return gates
78+
79+
5380
def infer_state(problem_dir: str) -> dict[str, Any]:
5481
root = Path(problem_dir)
5582
manifest = load_manifest(problem_dir)
5683
is_interactive = bool(manifest.get("interactive", False))
5784
return {
5885
"problem_dir": str(root),
59-
"created": root.exists() and (root / "files").exists() and (root / "solutions").exists(),
86+
"created": (
87+
root.exists()
88+
and (root / "files").exists()
89+
and (root / "solutions").exists()
90+
and _is_manifest_valid_for_created(manifest)
91+
),
6092
"interactive": is_interactive,
6193
"sol_built": False,
6294
"brute_built": False,
95+
"solution_analyzed": False,
6396
"validator_ready": False,
6497
"validator_accuracy": None,
6598
"generator_built": False,
@@ -76,6 +109,7 @@ def infer_state(problem_dir: str) -> dict[str, Any]:
76109
"generated_test_count": 0,
77110
"tests_verified": False,
78111
"packaged": (root / "problem.xml").exists(),
112+
"quality_gates": _extract_quality_gates(manifest),
79113
}
80114

81115

@@ -90,6 +124,7 @@ def load_state(problem_dir: str) -> dict[str, Any]:
90124
loaded = infer_state(problem_dir)
91125
manifest = load_manifest(problem_dir)
92126
loaded["interactive"] = bool(manifest.get("interactive", loaded.get("interactive", False)))
127+
loaded["quality_gates"] = _extract_quality_gates(manifest)
93128
return loaded
94129

95130

@@ -152,6 +187,55 @@ def _interactor_gate_ok(state: dict[str, Any], _: dict[str, Any]) -> bool:
152187
return not bool(state.get("interactive", False)) or bool(state.get("interactor_ready"))
153188

154189

190+
def _quality_gate_enabled(state: dict[str, Any], key: str, default: bool = True) -> bool:
191+
gates = state.get("quality_gates", {})
192+
if not isinstance(gates, dict):
193+
return default
194+
value = gates.get(key, default)
195+
return bool(value)
196+
197+
198+
def _stress_required_gate_ok(state: dict[str, Any], _: dict[str, Any]) -> bool:
199+
return (not _quality_gate_enabled(state, "require_stress_passed")) or bool(
200+
state.get("stress_passed")
201+
)
202+
203+
204+
def _validation_required_gate_ok(state: dict[str, Any], _: dict[str, Any]) -> bool:
205+
if not _quality_gate_enabled(state, "require_validation_passed"):
206+
return True
207+
return (
208+
bool(state.get("validation_passed"))
209+
and bool(state.get("statement_validated"))
210+
and bool(state.get("sample_files_validated"))
211+
)
212+
213+
214+
def _tests_verified_required_gate_ok(state: dict[str, Any], _: dict[str, Any]) -> bool:
215+
return (not _quality_gate_enabled(state, "require_tests_verified")) or bool(
216+
state.get("tests_verified")
217+
)
218+
219+
220+
def _min_limit_ratio_gate_ok(state: dict[str, Any], _: dict[str, Any]) -> bool:
221+
gates = state.get("quality_gates", {})
222+
if not isinstance(gates, dict):
223+
return True
224+
try:
225+
required = float(gates.get("min_limit_case_ratio", 0.5))
226+
except (TypeError, ValueError):
227+
required = 0.5
228+
required = min(1.0, max(0.0, required))
229+
ratio = state.get("limit_case_ratio")
230+
if ratio is None:
231+
return True
232+
try:
233+
ratio_val = float(ratio)
234+
except (TypeError, ValueError):
235+
return False
236+
return ratio_val >= required
237+
238+
155239
PRE_GATES: dict[str, list[Gate]] = {
156240
"solution_build": [
157241
(lambda s, i: bool(s.get("created")), "必须先运行 problem_create 创建题目目录。"),
@@ -169,6 +253,7 @@ def _interactor_gate_ok(state: dict[str, Any], _: dict[str, Any]) -> bool:
169253
"validator_build": [
170254
(lambda s, i: bool(s.get("created")), "必须先运行 problem_create 创建题目目录。"),
171255
(lambda s, i: bool(s.get("sol_built")), "必须先构建标准解 sol。"),
256+
(lambda s, i: bool(s.get("solution_analyzed")), "必须先运行 solution_analyze,再构建 validator。"),
172257
(_is_non_interactive, "交互题不应构建 validator,应改用 interactor_build。"),
173258
(lambda s, i: bool(s.get("brute_built")), "必须先构建 brute,再构建 validator。"),
174259
],
@@ -200,11 +285,11 @@ def _interactor_gate_ok(state: dict[str, Any], _: dict[str, Any]) -> bool:
200285
(lambda s, i: bool(s.get("stress_passed")), "必须先通过 stress_test_run(completed_rounds == total_rounds)。"),
201286
],
202287
"problem_validate": [
203-
(lambda s, i: bool(s.get("stress_passed")), "必须先通过 stress_test_run,再进行题面与样例验证。"),
288+
(_stress_required_gate_ok, "必须先通过 stress_test_run,再进行题面与样例验证。"),
204289
],
205290
"problem_generate_tests": [
206-
(lambda s, i: bool(s.get("stress_passed")), "必须先通过 stress_test_run。"),
207-
(lambda s, i: bool(s.get("validation_passed")), "必须先通过 problem_validate。"),
291+
(_stress_required_gate_ok, "必须先通过 stress_test_run。"),
292+
(_validation_required_gate_ok, "必须先通过 problem_validate(题面与样例均通过)。"),
208293
(
209294
lambda s, i: not bool(s.get("interactive")) or bool(s.get("interactor_ready")),
210295
"交互题必须先完成 interactor_build 并可用。",
@@ -221,7 +306,8 @@ def _interactor_gate_ok(state: dict[str, Any], _: dict[str, Any]) -> bool:
221306
lambda s, i: bool(s.get("tests_generated")) and int(s.get("generated_test_count", 0)) > 0,
222307
"必须先生成最终测试数据。",
223308
),
224-
(lambda s, i: bool(s.get("tests_verified")), "必须先通过 problem_verify_tests(passed=true),再进行打包。"),
309+
(_tests_verified_required_gate_ok, "必须先通过 problem_verify_tests(passed=true),再进行打包。"),
310+
(_min_limit_ratio_gate_ok, "最终测试中的极限样例占比未达到 quality_gates.min_limit_case_ratio。"),
225311
],
226312
}
227313

@@ -253,6 +339,20 @@ def post_tool(payload: dict[str, Any]) -> int:
253339
success, data = parse_tool_result(payload)
254340
state = load_state(problem_dir)
255341

342+
if short_name == "problem_generate_tests":
343+
# 任何重新生成尝试都会让旧验证失效(无论成功还是失败)
344+
state["tests_verified"] = False
345+
if not success:
346+
state["tests_generated"] = False
347+
state["generated_test_count"] = 0
348+
save_state(problem_dir, state)
349+
return 0
350+
generated_tests = data.get("generated_tests", [])
351+
state["tests_generated"] = bool(generated_tests)
352+
state["generated_test_count"] = len(generated_tests)
353+
save_state(problem_dir, state)
354+
return 0
355+
256356
if short_name == "problem_validate":
257357
state["statement_validated"] = data.get("statement_samples", {}).get("validated", False)
258358
state["sample_files_validated"] = data.get("sample_files", {}).get("validated", False)
@@ -261,11 +361,46 @@ def post_tool(payload: dict[str, Any]) -> int:
261361
return 0
262362

263363
if short_name == "problem_verify_tests":
364+
limit_ratio = data.get("results", {}).get("limit_ratio", {})
365+
if isinstance(limit_ratio, dict):
366+
state["limit_case_ratio"] = limit_ratio.get("limit_case_ratio")
367+
else:
368+
state["limit_case_ratio"] = None
369+
gates = state.get("quality_gates", {})
370+
min_limit_ratio = 0.5
371+
if isinstance(gates, dict):
372+
try:
373+
min_limit_ratio = float(gates.get("min_limit_case_ratio", 0.5))
374+
except (TypeError, ValueError):
375+
min_limit_ratio = 0.5
376+
min_limit_ratio = min(1.0, max(0.0, min_limit_ratio))
377+
ratio_ok = True
378+
try:
379+
if state.get("limit_case_ratio") is not None:
380+
ratio_ok = float(state.get("limit_case_ratio")) >= min_limit_ratio
381+
except (TypeError, ValueError):
382+
ratio_ok = False
264383
state["tests_verified"] = success and bool(data.get("passed", False))
384+
state["tests_verified"] = state["tests_verified"] and ratio_ok
265385
save_state(problem_dir, state)
266386
return 0
267387

268388
if not success:
389+
if short_name == "validator_build":
390+
state["validator_ready"] = False
391+
state["validator_accuracy"] = None
392+
elif short_name == "generator_build":
393+
state["generator_built"] = False
394+
elif short_name == "stress_test_run":
395+
state["stress_completed_rounds"] = 0
396+
state["stress_total_rounds"] = 0
397+
state["stress_passed"] = False
398+
elif short_name == "checker_build":
399+
state["checker_accuracy"] = None
400+
state["checker_ready"] = False
401+
elif short_name == "problem_pack_polygon":
402+
state["packaged"] = False
403+
save_state(problem_dir, state)
269404
return 0
270405

271406
if short_name == "problem_create":
@@ -284,7 +419,7 @@ def post_tool(payload: dict[str, Any]) -> int:
284419
state["validator_accuracy"] = accuracy
285420
state["validator_ready"] = isinstance(accuracy, int | float) and accuracy >= 0.9
286421
elif short_name == "validator_select":
287-
state["validator_selected"] = True
422+
pass
288423
elif short_name == "generator_build":
289424
state["generator_built"] = True
290425
elif short_name == "stress_test_run":
@@ -294,16 +429,11 @@ def post_tool(payload: dict[str, Any]) -> int:
294429
elif short_name == "checker_build":
295430
accuracy = data.get("accuracy")
296431
state["checker_accuracy"] = accuracy
297-
state["checker_ready"] = accuracy is None or accuracy >= 0.9
432+
state["checker_ready"] = isinstance(accuracy, int | float) and accuracy >= 0.9
298433
elif short_name == "interactor_build":
299434
pass_rate = data.get("pass_rate", 0)
300435
fail_rate = data.get("fail_rate", 0)
301436
state["interactor_ready"] = pass_rate == 1.0 and fail_rate >= 0.8
302-
elif short_name == "problem_generate_tests":
303-
generated_tests = data.get("generated_tests", [])
304-
state["tests_generated"] = bool(generated_tests)
305-
state["generated_test_count"] = len(generated_tests)
306-
state["tests_verified"] = False
307437
elif short_name == "problem_pack_polygon":
308438
state["packaged"] = True
309439

src/autocode_mcp/tools/problem.py

Lines changed: 89 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1243,6 +1243,95 @@ async def execute(
12431243
if not os.path.exists(problem_dir):
12441244
return ToolResult.fail(f"Problem directory not found: {problem_dir}")
12451245

1246+
tests_dir = os.path.join(problem_dir, "tests")
1247+
if not os.path.isdir(tests_dir):
1248+
return ToolResult.fail("tests directory not found, run problem_generate_tests first")
1249+
in_files = sorted(f for f in os.listdir(tests_dir) if f.endswith(".in"))
1250+
if not in_files:
1251+
return ToolResult.fail("no test input files found, run problem_generate_tests first")
1252+
1253+
manifest_path = os.path.join(tests_dir, _TEST_MANIFEST_FILENAME)
1254+
answer_ext = ".ans"
1255+
if os.path.exists(manifest_path):
1256+
try:
1257+
with open(manifest_path, encoding="utf-8") as mf:
1258+
manifest = json.load(mf)
1259+
answer_ext = _normalize_answer_ext_value(str(manifest.get("answer_ext", ".ans"))) or ".ans"
1260+
except (OSError, json.JSONDecodeError):
1261+
answer_ext = ".ans"
1262+
1263+
missing_answers = [
1264+
in_file for in_file in in_files if not os.path.exists(
1265+
os.path.join(tests_dir, f"{os.path.splitext(in_file)[0]}{answer_ext}")
1266+
)
1267+
]
1268+
if missing_answers:
1269+
return ToolResult.fail(
1270+
"missing answer files for some tests",
1271+
missing_answer_inputs=missing_answers,
1272+
)
1273+
1274+
statement_path = os.path.join(problem_dir, "statements", "README.md")
1275+
if not os.path.exists(statement_path):
1276+
return ToolResult.fail("statement file missing: statements/README.md")
1277+
1278+
sol_source = os.path.join(problem_dir, "solutions", "sol.cpp")
1279+
root_sol_source = os.path.join(problem_dir, "sol.cpp")
1280+
if not os.path.exists(sol_source) and not os.path.exists(root_sol_source):
1281+
return ToolResult.fail("main solution source missing: solutions/sol.cpp")
1282+
1283+
workflow_state_path = os.path.join(problem_dir, ".autocode-workflow", "state.json")
1284+
require_tests_verified = True
1285+
min_limit_case_ratio = 0.5
1286+
problem_manifest_path = os.path.join(problem_dir, "autocode.json")
1287+
if os.path.exists(problem_manifest_path):
1288+
try:
1289+
with open(problem_manifest_path, encoding="utf-8") as pmf:
1290+
problem_manifest = json.load(pmf)
1291+
quality_gates = problem_manifest.get("quality_gates", {})
1292+
if isinstance(quality_gates, dict):
1293+
require_tests_verified = bool(quality_gates.get("require_tests_verified", True))
1294+
min_limit_case_ratio = float(quality_gates.get("min_limit_case_ratio", 0.5))
1295+
except (OSError, json.JSONDecodeError, TypeError, ValueError):
1296+
require_tests_verified = True
1297+
min_limit_case_ratio = 0.5
1298+
min_limit_case_ratio = min(1.0, max(0.0, min_limit_case_ratio))
1299+
1300+
if os.path.exists(workflow_state_path):
1301+
try:
1302+
with open(workflow_state_path, encoding="utf-8") as sf:
1303+
workflow_state = json.load(sf)
1304+
if require_tests_verified and not bool(workflow_state.get("tests_verified", False)):
1305+
return ToolResult.fail(
1306+
"tests are not verified, run problem_verify_tests first",
1307+
tests_verified=False,
1308+
)
1309+
except (OSError, json.JSONDecodeError):
1310+
return ToolResult.fail("invalid workflow state file, rerun verification steps")
1311+
1312+
limit_ratio_from_manifest = None
1313+
if os.path.exists(manifest_path):
1314+
try:
1315+
with open(manifest_path, encoding="utf-8") as mf:
1316+
tests_manifest = json.load(mf)
1317+
tests = tests_manifest.get("tests", [])
1318+
if isinstance(tests, list) and tests:
1319+
total = len(tests)
1320+
limit_count = sum(
1321+
1
1322+
for item in tests
1323+
if isinstance(item, dict) and str(item.get("type_param")) in _LIMIT_STRATEGY_TYPES
1324+
)
1325+
limit_ratio_from_manifest = limit_count / total
1326+
except (OSError, json.JSONDecodeError):
1327+
limit_ratio_from_manifest = None
1328+
if limit_ratio_from_manifest is not None and limit_ratio_from_manifest < min_limit_case_ratio:
1329+
return ToolResult.fail(
1330+
"limit case ratio is below quality_gates.min_limit_case_ratio",
1331+
limit_case_ratio=limit_ratio_from_manifest,
1332+
min_limit_case_ratio=min_limit_case_ratio,
1333+
)
1334+
12461335
# 转换单位:秒 -> 毫秒,MB -> 字节
12471336
time_limit_ms = time_limit * 1000
12481337
memory_limit_bytes = memory_limit * 1024 * 1024
@@ -1291,22 +1380,11 @@ async def execute(
12911380
problem_xml = os.path.join(problem_dir, "problem.xml")
12921381
if not os.path.exists(problem_xml):
12931382
# 动态计算测试数量
1294-
tests_dir = os.path.join(problem_dir, "tests")
12951383
if os.path.exists(tests_dir):
12961384
test_files = [f for f in os.listdir(tests_dir) if f.endswith(".in")]
12971385
actual_test_count = len(test_files)
1298-
manifest_path = os.path.join(tests_dir, _TEST_MANIFEST_FILENAME)
1299-
answer_ext = ".ans"
1300-
if os.path.exists(manifest_path):
1301-
try:
1302-
with open(manifest_path, encoding="utf-8") as mf:
1303-
manifest = json.load(mf)
1304-
answer_ext = _normalize_answer_ext_value(str(manifest.get("answer_ext", ".ans"))) or ".ans"
1305-
except (OSError, json.JSONDecodeError):
1306-
answer_ext = ".ans"
13071386
else:
13081387
actual_test_count = 0
1309-
answer_ext = ".ans"
13101388

13111389
problem_name = os.path.basename(problem_dir)
13121390
xml_problem_name = escape(problem_name, {'"': "&quot;"})

0 commit comments

Comments
 (0)