Skip to content

Commit aa8634f

Browse files
committed
fix spacer task outputs and layered validation
1 parent c0ee791 commit aa8634f

File tree

2 files changed

+260
-130
lines changed

2 files changed

+260
-130
lines changed

evaluation_function/evaluation.py

Lines changed: 80 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
"long_missing": "The long spacer may be missing or not detected. Please check whether the long spacer is installed and retake the photo if needed.",
7070
"count_fail": "A spacer may be missing or not detected. Please check whether both spacers are installed and retake the photo if needed.",
7171
"type_confusion": "The spacer types could not be identified reliably. Please retake the photo from a clearer angle and make sure both spacers are fully visible.",
72+
"assignment_fail": "The spacers were detected, but their positions could not be determined reliably. Please retake the photo from a clearer top view and make sure both spacers are fully visible.",
7273
"position_mismatch": "The spacer positions appear to be incorrect. Please make sure the short spacer is on the short shaft and the long spacer is on the long shaft.",
7374
"distance_order": "The spacer order appears to be incorrect. Please check whether the short spacer is closer to gear 1 than the long spacer.",
7475
"fail": "Please check the spacer setup again.",
@@ -191,6 +192,16 @@ def _result_minimal(is_correct: bool, message: str, *, max_chars: int = _MAX_FEE
191192
return Result(is_correct=is_correct)
192193

193194

195+
def _safe_int(x: Any, default: int = 0) -> int:
196+
try:
197+
return int(x)
198+
except Exception:
199+
try:
200+
return int(float(x))
201+
except Exception:
202+
return default
203+
204+
194205
def _select_errors_by_task(errors: List[Dict[str, Any]], task: str) -> List[Dict[str, Any]]:
195206
task = (task or "full").strip().lower()
196207

@@ -218,14 +229,23 @@ def keep(e: Dict[str, Any]) -> bool:
218229
return code.startswith("E_SINGLE_STAGE") or code == "E_NO_GEARS"
219230

220231
if task == "shaft":
221-
return code.startswith("E_SHAFT") or code == "E_NO_SHAFTS" or code == "E_NO_GEAR11" or code == "E_NO_GEARS"
232+
return (
233+
code.startswith("E_SHAFT")
234+
or code == "E_NO_SHAFTS"
235+
or code == "E_NO_GEAR11"
236+
or code == "E_NO_GEARS"
237+
)
222238

223239
if task == "spacer":
224240
return (
225241
code.startswith("E_SPACER")
242+
or code.startswith("E_ASSIGN")
243+
or code.startswith("E_PARTS")
244+
or code == "E_BAD_PART_TYPE"
245+
or code == "E_NO_TARGET_PARTS"
226246
or code == "E_NO_SHAFTS"
227-
or code == "E_NO_GEARS"
228247
or code == "E_NO_GEAR11"
248+
or code == "E_NO_GEARS"
229249
)
230250

231251
if task == "gear_inventory":
@@ -284,16 +304,6 @@ def _format_rpm_value(out_rpm: Any) -> str:
284304
return str(out_rpm)
285305

286306

287-
def _safe_int(x: Any, default: int = 0) -> int:
288-
try:
289-
return int(x)
290-
except Exception:
291-
try:
292-
return int(float(x))
293-
except Exception:
294-
return default
295-
296-
297307
def _get_counts_dict(out: Dict[str, Any]) -> Dict[str, Any]:
298308
return out.get("counts", {}) if isinstance(out.get("counts"), dict) else {}
299309

@@ -310,6 +320,26 @@ def _get_gear_counts(out: Dict[str, Any]) -> Tuple[int, int, int]:
310320
return driving_gear, smallgear, biggear
311321

312322

323+
def _get_shaft_counts(out: Dict[str, Any]) -> Tuple[int, int, int]:
    """Read the shaft counts from *out*.

    The ``"shaft_long"`` and ``"shaft_short"`` entries of the counts
    mapping are converted with ``_safe_int``; a missing entry is read as 0.

    Args:
        out: Dictionary whose counts mapping is obtained via
            ``_get_counts_dict``.

    Returns:
        A ``(long, short, total)`` tuple, where ``total`` is ``long + short``.
    """
    tally = _get_counts_dict(out)
    long_count, short_count = (
        _safe_int(tally.get(key, 0)) for key in ("shaft_long", "shaft_short")
    )
    return long_count, short_count, long_count + short_count
331+
332+
333+
def _get_spacer_counts(out: Dict[str, Any]) -> Tuple[int, int, int]:
    """Read the spacer counts from *out*.

    The ``"spacer_long"`` and ``"spacer_short"`` entries of the counts
    mapping are converted with ``_safe_int``; a missing entry is read as 0.

    Args:
        out: Dictionary whose counts mapping is obtained via
            ``_get_counts_dict``.

    Returns:
        A ``(long, short, total)`` tuple, where ``total`` is ``long + short``.
    """
    tally = _get_counts_dict(out)
    long_count, short_count = (
        _safe_int(tally.get(key, 0)) for key in ("spacer_long", "spacer_short")
    )
    return long_count, short_count, long_count + short_count
341+
342+
313343
def _build_parts_inventory_message(out: Dict[str, Any], part_type: str) -> Tuple[bool, str]:
314344
counts = _get_counts_dict(out)
315345
errors = out.get("errors", []) if isinstance(out.get("errors"), list) else []
@@ -419,6 +449,16 @@ def _build_student_message(
419449
if isinstance(e, dict)
420450
}
421451

452+
n_long, n_short, n_total = _get_shaft_counts(out)
453+
454+
# ---- hard fallback by counts first ----
455+
if n_total > 0:
456+
if n_total != 2:
457+
return False, MESSAGE_POLICY["shaft"]["count_fail"]
458+
459+
if n_short != 1 or n_long != 1:
460+
return False, MESSAGE_POLICY["shaft"]["type_confusion"]
461+
422462
if (
423463
"E_SHAFT_COUNT_MISMATCH" in codes
424464
or "E_NO_SHAFTS" in codes
@@ -438,6 +478,7 @@ def _build_student_message(
438478
if task_has_error:
439479
return False, MESSAGE_POLICY["shaft"]["fail"]
440480

481+
# If no shaft counts are available, still allow pass only when pipeline reports no error.
441482
return True, MESSAGE_POLICY["shaft"]["pass"]
442483

443484
if task == "spacer":
@@ -447,6 +488,24 @@ def _build_student_message(
447488
if isinstance(e, dict)
448489
}
449490

491+
n_long, n_short, n_total = _get_spacer_counts(out)
492+
493+
# ---- Layer 1: hard fallback by counts first ----
494+
# Only activate when spacer counts are actually provided by the pipeline.
495+
if ("spacer_long" in _get_counts_dict(out)) or ("spacer_short" in _get_counts_dict(out)):
496+
if n_total == 0:
497+
return False, MESSAGE_POLICY["spacer"]["count_fail"]
498+
499+
if n_short == 0 and n_long >= 1:
500+
return False, MESSAGE_POLICY["spacer"]["short_missing"]
501+
502+
if n_long == 0 and n_short >= 1:
503+
return False, MESSAGE_POLICY["spacer"]["long_missing"]
504+
505+
if n_short != 1 or n_long != 1 or n_total != 2:
506+
return False, MESSAGE_POLICY["spacer"]["count_fail"]
507+
508+
# ---- Layer 2: explicit pipeline error codes ----
450509
if "E_SPACER_SHORT_MISSING" in codes:
451510
return False, MESSAGE_POLICY["spacer"]["short_missing"]
452511

@@ -459,6 +518,15 @@ def _build_student_message(
459518
if "E_SPACER_TYPE_CONFUSION" in codes:
460519
return False, MESSAGE_POLICY["spacer"]["type_confusion"]
461520

521+
# Assignment / visibility ambiguity should not be mapped to position mismatch.
522+
if (
523+
"E_SPACER_ASSIGNMENT_FAIL" in codes
524+
or "E_SPACER2_MISSING" in codes
525+
or "E_SPACER3_MISSING" in codes
526+
):
527+
return False, MESSAGE_POLICY["spacer"]["assignment_fail"]
528+
529+
# True geometric mismatch after assignment succeeded.
462530
if (
463531
"E_SPACER_POSITION_MISMATCH" in codes
464532
or "E_SPACER2_TYPE_MISMATCH" in codes

0 commit comments

Comments
 (0)