@@ -351,7 +351,8 @@ def _build_parts_inventory_message(out: Dict[str, Any], part_type: str) -> Tuple
351351 msg = (
352352 "Detected gears:\n "
353353 f"- biggear: { counts .get ('biggear' , 0 )} \n "
354- f"- smallgear: { counts .get ('smallgear' , 0 )} "
354+ f"- smallgear: { counts .get ('smallgear' , 0 )} \n "
355+ f"- driving gear: { counts .get ('driving_gear' , counts .get ('drivinggear' , 0 ))} "
355356 )
356357 if not is_correct :
357358 msg += "\n Please spread the gears out clearly and retake the photo."
@@ -451,7 +452,6 @@ def _build_student_message(
451452
452453 n_long , n_short , n_total = _get_shaft_counts (out )
453454
454- # ---- hard fallback by counts first ----
455455 if n_total > 0 :
456456 if n_total != 2 :
457457 return False , MESSAGE_POLICY ["shaft" ]["count_fail" ]
@@ -478,7 +478,6 @@ def _build_student_message(
478478 if task_has_error :
479479 return False , MESSAGE_POLICY ["shaft" ]["fail" ]
480480
481- # If no shaft counts are available, still allow pass only when pipeline reports no error.
482481 return True , MESSAGE_POLICY ["shaft" ]["pass" ]
483482
484483 if task == "spacer" :
@@ -490,9 +489,9 @@ def _build_student_message(
490489
491490 n_long , n_short , n_total = _get_spacer_counts (out )
492491
493- # ---- Layer 1: hard fallback by counts first ----
494- # Only activate when spacer counts are actually provided by the pipeline.
495- if ("spacer_long" in _get_counts_dict ( out )) or ("spacer_short" in _get_counts_dict ( out ) ):
492+ counts_dict = _get_counts_dict ( out )
493+
494+ if ("spacer_long" in counts_dict ) or ("spacer_short" in counts_dict ):
496495 if n_total == 0 :
497496 return False , MESSAGE_POLICY ["spacer" ]["count_fail" ]
498497
@@ -505,7 +504,6 @@ def _build_student_message(
505504 if n_short != 1 or n_long != 1 or n_total != 2 :
506505 return False , MESSAGE_POLICY ["spacer" ]["count_fail" ]
507506
508- # ---- Layer 2: explicit pipeline error codes ----
509507 if "E_SPACER_SHORT_MISSING" in codes :
510508 return False , MESSAGE_POLICY ["spacer" ]["short_missing" ]
511509
@@ -518,15 +516,13 @@ def _build_student_message(
518516 if "E_SPACER_TYPE_CONFUSION" in codes :
519517 return False , MESSAGE_POLICY ["spacer" ]["type_confusion" ]
520518
521- # Assignment / visibility ambiguity should not be mapped to position mismatch.
522519 if (
523520 "E_SPACER_ASSIGNMENT_FAIL" in codes
524521 or "E_SPACER2_MISSING" in codes
525522 or "E_SPACER3_MISSING" in codes
526523 ):
527524 return False , MESSAGE_POLICY ["spacer" ]["assignment_fail" ]
528525
529- # True geometric mismatch after assignment succeeded.
530526 if (
531527 "E_SPACER_POSITION_MISMATCH" in codes
532528 or "E_SPACER2_TYPE_MISMATCH" in codes
@@ -618,6 +614,97 @@ def _build_student_message(
618614 return False , "Unsupported task."
619615
620616
617+ def _call_pipeline_with_fallbacks (
618+ * ,
619+ img_bgr : np .ndarray ,
620+ model_a_rel : str ,
621+ model_b_rel : str ,
622+ model_c_rel : str ,
623+ return_images : bool ,
624+ task : str ,
625+ part_type : str ,
626+ expected_gears : Any ,
627+ ) -> Dict [str , Any ]:
628+ """
629+ Preferred new signature:
630+ run_yolo_pipeline(
631+ img_bgr=...,
632+ model_a_rel=...,
633+ model_b_rel=...,
634+ model_c_rel=...,
635+ return_images=...,
636+ task=...,
637+ part_type=...,
638+ expected_gears=...,
639+ )
640+
641+ Backward-compatibility fallback:
642+ run_yolo_pipeline(
643+ img_bgr=...,
644+ gear_model_rel=...,
645+ shaft_model_rel=...,
646+ return_images=...,
647+ task=...,
648+ part_type=...,
649+ expected_gears=...,
650+ )
651+ """
652+ # New 3-model API
653+ try :
654+ return run_yolo_pipeline ( # type: ignore[misc]
655+ img_bgr = img_bgr ,
656+ model_a_rel = model_a_rel ,
657+ model_b_rel = model_b_rel ,
658+ model_c_rel = model_c_rel ,
659+ return_images = return_images ,
660+ task = task ,
661+ part_type = part_type ,
662+ expected_gears = expected_gears ,
663+ )
664+ except TypeError :
665+ pass
666+
667+ # New 3-model API without part_type
668+ try :
669+ return run_yolo_pipeline ( # type: ignore[misc]
670+ img_bgr = img_bgr ,
671+ model_a_rel = model_a_rel ,
672+ model_b_rel = model_b_rel ,
673+ model_c_rel = model_c_rel ,
674+ return_images = return_images ,
675+ task = task ,
676+ expected_gears = expected_gears ,
677+ )
678+ except TypeError :
679+ pass
680+
681+ # Old 2-model fallback:
682+ # model_a_rel -> gear_model_rel
683+ # model_b_rel -> shaft_model_rel
684+ try :
685+ return run_yolo_pipeline ( # type: ignore[misc]
686+ img_bgr = img_bgr ,
687+ gear_model_rel = model_a_rel ,
688+ shaft_model_rel = model_b_rel ,
689+ return_images = return_images ,
690+ task = task ,
691+ part_type = part_type ,
692+ expected_gears = expected_gears ,
693+ )
694+ except TypeError :
695+ pass
696+
697+ # Old 2-model fallback without part_type
698+ return run_yolo_pipeline ( # type: ignore[misc]
699+ img_bgr = img_bgr ,
700+ gear_model_rel = model_a_rel ,
701+ shaft_model_rel = model_b_rel ,
702+ return_images = return_images ,
703+ task = task ,
704+ expected_gears = expected_gears ,
705+ )
706+
707+
621708def evaluation_function (response : Any , answer : Any , params : Params ) -> Result :
622709 task = str (_pget (params , "task" , "full" ) or "full" ).strip ().lower ()
623710 part_type = str (_pget (params , "part_type" , "" ) or "" ).strip ().lower ()
@@ -632,8 +719,14 @@ def evaluation_function(response: Any, answer: Any, params: Params) -> Result:
632719 return_images : bool = bool (
633720 _pget (params , "return_images" , pipeline_task not in ("precheck" , "parts_inventory" , "single_stage" ))
634721 )
635- gear_model_rel = str (_pget (params , "gear_model_rel" , "gear_model.pt" ))
636- shaft_model_rel = str (_pget (params , "shaft_model_rel" , "shaft_model.pt" ))
722+
723+ # ----------------------------
724+ # New 3-model params
725+ # ----------------------------
726+ model_a_rel = str (_pget (params , "model_a_rel" , _pget (params , "gear_model_rel" , "modelA.pt" )))
727+ model_b_rel = str (_pget (params , "model_b_rel" , _pget (params , "shaft_model_rel" , "modelB.pt" )))
728+ model_c_rel = str (_pget (params , "model_c_rel" , "modelC.pt" ))
729+
637730 expected_gears = _pget (params , "expected_gears" , None )
638731
639732 if not isinstance (response , list ) or len (response ) == 0 :
@@ -655,24 +748,16 @@ def evaluation_function(response: Any, answer: Any, params: Params) -> Result:
655748 )
656749
657750 try :
658- out : Dict [str , Any ] = run_yolo_pipeline ( # type: ignore[misc]
751+ out : Dict [str , Any ] = _call_pipeline_with_fallbacks (
659752 img_bgr = img_bgr ,
660- gear_model_rel = gear_model_rel ,
661- shaft_model_rel = shaft_model_rel ,
753+ model_a_rel = model_a_rel ,
754+ model_b_rel = model_b_rel ,
755+ model_c_rel = model_c_rel ,
662756 return_images = return_images ,
663757 task = pipeline_task ,
664758 part_type = part_type ,
665759 expected_gears = expected_gears ,
666760 )
667- except TypeError :
668- out = run_yolo_pipeline ( # type: ignore[misc]
669- img_bgr = img_bgr ,
670- gear_model_rel = gear_model_rel ,
671- shaft_model_rel = shaft_model_rel ,
672- return_images = return_images ,
673- task = pipeline_task ,
674- expected_gears = expected_gears ,
675- )
676761 except Exception :
677762 return _result_minimal (
678763 False ,
0 commit comments