Skip to content

Commit f70d232

Browse files
committed
Align final check with step feedback flow
1 parent 7c829ef commit f70d232

File tree

3 files changed

+161
-33
lines changed

3 files changed

+161
-33
lines changed

evaluation_function/evaluation.py

Lines changed: 66 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -817,27 +817,66 @@ def advice_text() -> str:
817817
for e in selected_errors
818818
if isinstance(e, dict)
819819
}
820+
driving_gear, smallgear, biggear = _get_gear_counts(out)
821+
shaft_long, shaft_short, shaft_total = _get_shaft_counts(out)
822+
spacer_long, spacer_short, spacer_total = _get_spacer_counts(out)
820823

821824
# Highest-priority prerequisite errors
822825
if "E_NO_GEARS" in codes:
823-
return False, MESSAGE_POLICY["mesh_ratio"]["fail"]
826+
return False, MESSAGE_POLICY["gear_inventory"]["no_gears"].format(
827+
driving_gear=driving_gear,
828+
smallgear=smallgear,
829+
biggear=biggear,
830+
)
831+
832+
if "E_NO_SHAFTS" in codes:
833+
return False, MESSAGE_POLICY["shaft"]["none_detected"]
824834

825-
if "E_NO_SHAFTS" in codes or "E_NO_GEAR11" in codes:
835+
if "E_NO_GEAR11" in codes:
826836
return False, MESSAGE_POLICY["mesh_ratio"]["fail"]
827837

838+
# Shaft checks
839+
if (
840+
"E_SHAFT_COUNT_MISMATCH" in codes
841+
or "E_SHAFT2_NOT_FOUND" in codes
842+
):
843+
if shaft_total == 0:
844+
return False, MESSAGE_POLICY["shaft"]["none_detected"]
845+
if shaft_total > 2:
846+
return False, MESSAGE_POLICY["shaft"]["too_many"]
847+
if shaft_total == 1 and shaft_long == 1 and shaft_short == 0:
848+
return False, MESSAGE_POLICY["shaft"]["short_missing"]
849+
if shaft_total == 1 and shaft_short == 1 and shaft_long == 0:
850+
return False, MESSAGE_POLICY["shaft"]["long_missing"]
851+
return False, MESSAGE_POLICY["shaft"]["count_fail"]
852+
853+
if "E_SHAFT_TYPE_CONFUSION" in codes:
854+
return False, MESSAGE_POLICY["shaft"]["type_confusion"]
855+
856+
if (
857+
"E_SHAFT_POSITION_SWAP" in codes
858+
or "E_SHAFT2_CLASS_MISMATCH" in codes
859+
or "E_SHAFT3_CLASS_MISMATCH" in codes
860+
):
861+
return False, MESSAGE_POLICY["shaft"]["position_swap"]
862+
828863
# Spacer checks
829864
if "E_SPACER_SHORT_MISSING" in codes:
830865
return False, MESSAGE_POLICY["spacer"]["short_missing"]
831866

832867
if "E_SPACER_LONG_MISSING" in codes:
833868
return False, MESSAGE_POLICY["spacer"]["long_missing"]
834869

835-
if "E_SPACER_COUNT_MISMATCH" in codes:
836-
return False, MESSAGE_POLICY["spacer"]["count_fail"]
837-
838870
if "E_SPACER_TYPE_CONFUSION" in codes:
839871
return False, MESSAGE_POLICY["spacer"]["type_confusion"]
840872

873+
if "E_SPACER_COUNT_MISMATCH" in codes:
874+
if spacer_total == 0:
875+
return False, MESSAGE_POLICY["spacer"]["none_detected"]
876+
if spacer_total > 2:
877+
return False, MESSAGE_POLICY["spacer"]["too_many"]
878+
return False, MESSAGE_POLICY["spacer"]["count_fail"]
879+
841880
if (
842881
"E_SPACER_ASSIGNMENT_FAIL" in codes
843882
or "E_SPACER2_MISSING" in codes
@@ -855,28 +894,33 @@ def advice_text() -> str:
855894
if "E_SPACER_DISTANCE_ORDER" in codes:
856895
return False, MESSAGE_POLICY["spacer"]["distance_order"]
857896

858-
# Shaft checks
859-
if (
860-
"E_SHAFT_POSITION_SWAP" in codes
861-
or "E_SHAFT2_CLASS_MISMATCH" in codes
862-
or "E_SHAFT3_CLASS_MISMATCH" in codes
863-
):
864-
return False, MESSAGE_POLICY["shaft"]["position_swap"]
865-
866-
if (
867-
"E_SHAFT_COUNT_MISMATCH" in codes
868-
or "E_SHAFT2_NOT_FOUND" in codes
869-
):
870-
return False, MESSAGE_POLICY["shaft"]["count_fail"]
897+
# Mesh and consistency checks
898+
if "E_GEAR_BIG_SMALL_INCONSISTENT" in codes:
899+
return False, MESSAGE_POLICY["gear_inventory"]["big_small_inconsistent"].format(
900+
driving_gear=driving_gear,
901+
smallgear=smallgear,
902+
biggear=biggear,
903+
)
871904

872-
if "E_SHAFT_TYPE_CONFUSION" in codes:
873-
return False, MESSAGE_POLICY["shaft"]["type_confusion"]
905+
if "E_GEAR_CONTACT_INCONSISTENT" in codes:
906+
return False, MESSAGE_POLICY["gear_inventory"]["contact_consistency_fail"].format(
907+
driving_gear=driving_gear,
908+
smallgear=smallgear,
909+
biggear=biggear,
910+
)
874911

875-
# Mesh and consistency checks
876912
if (
877913
"E_MISMESH_DETECTED" in codes
878914
or "E_MESH_MISMATCH" in codes
879-
or "E_CONTACT_COUNT_MISMATCH" in codes
915+
):
916+
return False, MESSAGE_POLICY["gear_inventory"]["mismatch_fail"].format(
917+
driving_gear=driving_gear,
918+
smallgear=smallgear,
919+
biggear=biggear,
920+
)
921+
922+
if (
923+
"E_CONTACT_COUNT_MISMATCH" in codes
880924
or "E_GEAR_COUNT_UNSUPPORTED" in codes
881925
):
882926
return False, MESSAGE_POLICY["mesh_ratio"]["fail"]

evaluation_function/evaluation_test.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,59 @@ def test_spacer_position_checked_before_distance_order(self):
369369

370370
self.assertEqual(errors[0]["code"], "E_SPACER_POSITION_MISMATCH")
371371

372+
def test_mesh_ratio_uses_specific_shaft_count_feedback(self):
373+
errors = [{"code": "E_SHAFT_COUNT_MISMATCH", "message": "Expected 2 shafts."}]
374+
375+
is_correct, message = _build_student_message(
376+
task="mesh_ratio",
377+
img_bgr=np.zeros((10, 10, 3), dtype=np.uint8),
378+
out={"shaft_counts": {"shaft_long": 1, "shaft_short": 0}, "errors": errors},
379+
errors=errors,
380+
selected_errors=errors,
381+
part_type="",
382+
)
383+
384+
self.assertFalse(is_correct)
385+
self.assertIn("short shaft", message.lower())
386+
387+
def test_mesh_ratio_uses_specific_spacer_count_feedback(self):
388+
errors = [{"code": "E_SPACER_COUNT_MISMATCH", "message": "Expected 2 spacers."}]
389+
390+
is_correct, message = _build_student_message(
391+
task="mesh_ratio",
392+
img_bgr=np.zeros((10, 10, 3), dtype=np.uint8),
393+
out={"spacer_counts": {"spacer_long": 2, "spacer_short": 1}, "errors": errors},
394+
errors=errors,
395+
selected_errors=errors,
396+
part_type="",
397+
)
398+
399+
self.assertFalse(is_correct)
400+
self.assertIn("too many spacers", message.lower())
401+
402+
def test_mesh_ratio_prioritizes_shaft_before_spacer(self):
403+
errors = [
404+
{"code": "E_SHAFT_COUNT_MISMATCH", "message": "Expected 2 shafts."},
405+
{"code": "E_SPACER_COUNT_MISMATCH", "message": "Expected 2 spacers."},
406+
]
407+
408+
is_correct, message = _build_student_message(
409+
task="mesh_ratio",
410+
img_bgr=np.zeros((10, 10, 3), dtype=np.uint8),
411+
out={
412+
"shaft_counts": {"shaft_long": 1, "shaft_short": 0},
413+
"spacer_counts": {"spacer_long": 2, "spacer_short": 1},
414+
"errors": errors,
415+
},
416+
errors=errors,
417+
selected_errors=errors,
418+
part_type="",
419+
)
420+
421+
self.assertFalse(is_correct)
422+
self.assertIn("short shaft", message.lower())
423+
self.assertNotIn("too many spacers", message.lower())
424+
372425

373426
if __name__ == "__main__":
374427
unittest.main()

evaluation_function/yolo_pipeline.py

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2663,18 +2663,47 @@ def _need_shafts() -> bool:
26632663
gears, spacers, shaft_obbs
26642664
)
26652665

2666+
shaft_counts = get_shaft_counts(shaft_obbs)
2667+
spacer_counts = get_spacer_counts(spacers)
2668+
26662669
errors: List[Dict[str, str]] = []
2667-
if ENABLE_ERROR_CHECKS and shaft_obbs:
2668-
errors = evaluate_assembly_errors(
2669-
gears=gears,
2670-
spacers=spacers,
2671-
shafts=shaft_obbs,
2672-
mesh_boxes=mesh_boxes,
2673-
mismesh_boxes=mismesh_boxes,
2674-
gear11_gid=int(gear11_gid),
2675-
gear_to_si=gear_to_si,
2676-
spacer_to_si=spacer_to_si,
2677-
)
2670+
if ENABLE_ERROR_CHECKS:
2671+
if not shaft_obbs:
2672+
errors = [{"code": "E_NO_SHAFTS", "message": "No shafts detected."}]
2673+
else:
2674+
errors = evaluate_shaft_step_errors(
2675+
gears=gears,
2676+
shafts=shaft_obbs,
2677+
gear11_gid=int(gear11_gid),
2678+
)
2679+
2680+
if not errors:
2681+
errors = evaluate_spacer_step_errors(
2682+
gears=gears,
2683+
spacers=spacers,
2684+
shafts=shaft_obbs,
2685+
gear11_gid=int(gear11_gid),
2686+
spacer_to_si=spacer_to_si,
2687+
)
2688+
2689+
if not errors:
2690+
errors, _gear_counts = evaluate_gear_inventory_step(
2691+
gears=gears,
2692+
mesh_boxes=mesh_boxes,
2693+
mismesh_boxes=mismesh_boxes,
2694+
)
2695+
2696+
if not errors:
2697+
errors = evaluate_assembly_errors(
2698+
gears=gears,
2699+
spacers=spacers,
2700+
shafts=shaft_obbs,
2701+
mesh_boxes=mesh_boxes,
2702+
mismesh_boxes=mismesh_boxes,
2703+
gear11_gid=int(gear11_gid),
2704+
gear_to_si=gear_to_si,
2705+
spacer_to_si=spacer_to_si,
2706+
)
26782707

26792708
gear_names, gear_stage, chain_pairs = stage_role_naming_chain(
26802709
gears=gears,
@@ -2703,6 +2732,8 @@ def _need_shafts() -> bool:
27032732
"stages": num_stages,
27042733
},
27052734
"counts": get_gear_counts(gears),
2735+
"shaft_counts": shaft_counts,
2736+
"spacer_counts": spacer_counts,
27062737
"detections": {
27072738
"gear_dets": gear_dets,
27082739
"aux_dets": aux_dets,

0 commit comments

Comments
 (0)