Skip to content

Commit 56b030f

Browse files
committed
refactor pipeline to support three-model inference (modelABC)
1 parent aa8634f commit 56b030f

File tree

6 files changed

+251
-95
lines changed

6 files changed

+251
-95
lines changed

evaluation_function/evaluation.py

Lines changed: 108 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,8 @@ def _build_parts_inventory_message(out: Dict[str, Any], part_type: str) -> Tuple
351351
msg = (
352352
"Detected gears:\n"
353353
f"- biggear: {counts.get('biggear', 0)}\n"
354-
f"- smallgear: {counts.get('smallgear', 0)}"
354+
f"- smallgear: {counts.get('smallgear', 0)}\n"
355+
f"- driving gear: {counts.get('driving_gear', counts.get('drivinggear', 0))}"
355356
)
356357
if not is_correct:
357358
msg += "\nPlease spread the gears out clearly and retake the photo."
@@ -451,7 +452,6 @@ def _build_student_message(
451452

452453
n_long, n_short, n_total = _get_shaft_counts(out)
453454

454-
# ---- hard fallback by counts first ----
455455
if n_total > 0:
456456
if n_total != 2:
457457
return False, MESSAGE_POLICY["shaft"]["count_fail"]
@@ -478,7 +478,6 @@ def _build_student_message(
478478
if task_has_error:
479479
return False, MESSAGE_POLICY["shaft"]["fail"]
480480

481-
# If no shaft counts are available, still allow pass only when pipeline reports no error.
482481
return True, MESSAGE_POLICY["shaft"]["pass"]
483482

484483
if task == "spacer":
@@ -490,9 +489,9 @@ def _build_student_message(
490489

491490
n_long, n_short, n_total = _get_spacer_counts(out)
492491

493-
# ---- Layer 1: hard fallback by counts first ----
494-
# Only activate when spacer counts are actually provided by the pipeline.
495-
if ("spacer_long" in _get_counts_dict(out)) or ("spacer_short" in _get_counts_dict(out)):
492+
counts_dict = _get_counts_dict(out)
493+
494+
if ("spacer_long" in counts_dict) or ("spacer_short" in counts_dict):
496495
if n_total == 0:
497496
return False, MESSAGE_POLICY["spacer"]["count_fail"]
498497

@@ -505,7 +504,6 @@ def _build_student_message(
505504
if n_short != 1 or n_long != 1 or n_total != 2:
506505
return False, MESSAGE_POLICY["spacer"]["count_fail"]
507506

508-
# ---- Layer 2: explicit pipeline error codes ----
509507
if "E_SPACER_SHORT_MISSING" in codes:
510508
return False, MESSAGE_POLICY["spacer"]["short_missing"]
511509

@@ -518,15 +516,13 @@ def _build_student_message(
518516
if "E_SPACER_TYPE_CONFUSION" in codes:
519517
return False, MESSAGE_POLICY["spacer"]["type_confusion"]
520518

521-
# Assignment / visibility ambiguity should not be mapped to position mismatch.
522519
if (
523520
"E_SPACER_ASSIGNMENT_FAIL" in codes
524521
or "E_SPACER2_MISSING" in codes
525522
or "E_SPACER3_MISSING" in codes
526523
):
527524
return False, MESSAGE_POLICY["spacer"]["assignment_fail"]
528525

529-
# True geometric mismatch after assignment succeeded.
530526
if (
531527
"E_SPACER_POSITION_MISMATCH" in codes
532528
or "E_SPACER2_TYPE_MISMATCH" in codes
@@ -618,6 +614,97 @@ def _build_student_message(
618614
return False, "Unsupported task."
619615

620616

617+
def _call_pipeline_with_fallbacks(
618+
*,
619+
img_bgr: np.ndarray,
620+
model_a_rel: str,
621+
model_b_rel: str,
622+
model_c_rel: str,
623+
return_images: bool,
624+
task: str,
625+
part_type: str,
626+
expected_gears: Any,
627+
) -> Dict[str, Any]:
628+
"""
629+
Preferred new signature:
630+
run_yolo_pipeline(
631+
img_bgr=...,
632+
model_a_rel=...,
633+
model_b_rel=...,
634+
model_c_rel=...,
635+
return_images=...,
636+
task=...,
637+
part_type=...,
638+
expected_gears=...,
639+
)
640+
641+
Backward-compatibility fallback:
642+
run_yolo_pipeline(
643+
img_bgr=...,
644+
gear_model_rel=...,
645+
shaft_model_rel=...,
646+
return_images=...,
647+
task=...,
648+
part_type=...,
649+
expected_gears=...,
650+
)
651+
"""
652+
# New 3-model API
653+
try:
654+
return run_yolo_pipeline( # type: ignore[misc]
655+
img_bgr=img_bgr,
656+
model_a_rel=model_a_rel,
657+
model_b_rel=model_b_rel,
658+
model_c_rel=model_c_rel,
659+
return_images=return_images,
660+
task=task,
661+
part_type=part_type,
662+
expected_gears=expected_gears,
663+
)
664+
except TypeError:
665+
pass
666+
667+
# New 3-model API without part_type
668+
try:
669+
return run_yolo_pipeline( # type: ignore[misc]
670+
img_bgr=img_bgr,
671+
model_a_rel=model_a_rel,
672+
model_b_rel=model_b_rel,
673+
model_c_rel=model_c_rel,
674+
return_images=return_images,
675+
task=task,
676+
expected_gears=expected_gears,
677+
)
678+
except TypeError:
679+
pass
680+
681+
# Old 2-model fallback:
682+
# model_a_rel -> gear_model_rel
683+
# model_b_rel -> shaft_model_rel
684+
try:
685+
return run_yolo_pipeline( # type: ignore[misc]
686+
img_bgr=img_bgr,
687+
gear_model_rel=model_a_rel,
688+
shaft_model_rel=model_b_rel,
689+
return_images=return_images,
690+
task=task,
691+
part_type=part_type,
692+
expected_gears=expected_gears,
693+
)
694+
except TypeError:
695+
pass
696+
697+
# Old 2-model fallback without part_type
698+
return run_yolo_pipeline( # type: ignore[misc]
699+
img_bgr=img_bgr,
700+
gear_model_rel=model_a_rel,
701+
shaft_model_rel=model_b_rel,
702+
return_images=return_images,
703+
task=task,
704+
expected_gears=expected_gears,
705+
)
706+
707+
621708
def evaluation_function(response: Any, answer: Any, params: Params) -> Result:
622709
task = str(_pget(params, "task", "full") or "full").strip().lower()
623710
part_type = str(_pget(params, "part_type", "") or "").strip().lower()
@@ -632,8 +719,14 @@ def evaluation_function(response: Any, answer: Any, params: Params) -> Result:
632719
return_images: bool = bool(
633720
_pget(params, "return_images", pipeline_task not in ("precheck", "parts_inventory", "single_stage"))
634721
)
635-
gear_model_rel = str(_pget(params, "gear_model_rel", "gear_model.pt"))
636-
shaft_model_rel = str(_pget(params, "shaft_model_rel", "shaft_model.pt"))
722+
723+
# ----------------------------
724+
# New 3-model params
725+
# ----------------------------
726+
model_a_rel = str(_pget(params, "model_a_rel", _pget(params, "gear_model_rel", "modelA.pt")))
727+
model_b_rel = str(_pget(params, "model_b_rel", _pget(params, "shaft_model_rel", "modelB.pt")))
728+
model_c_rel = str(_pget(params, "model_c_rel", "modelC.pt"))
729+
637730
expected_gears = _pget(params, "expected_gears", None)
638731

639732
if not isinstance(response, list) or len(response) == 0:
@@ -655,24 +748,16 @@ def evaluation_function(response: Any, answer: Any, params: Params) -> Result:
655748
)
656749

657750
try:
658-
out: Dict[str, Any] = run_yolo_pipeline( # type: ignore[misc]
751+
out: Dict[str, Any] = _call_pipeline_with_fallbacks(
659752
img_bgr=img_bgr,
660-
gear_model_rel=gear_model_rel,
661-
shaft_model_rel=shaft_model_rel,
753+
model_a_rel=model_a_rel,
754+
model_b_rel=model_b_rel,
755+
model_c_rel=model_c_rel,
662756
return_images=return_images,
663757
task=pipeline_task,
664758
part_type=part_type,
665759
expected_gears=expected_gears,
666760
)
667-
except TypeError:
668-
out = run_yolo_pipeline( # type: ignore[misc]
669-
img_bgr=img_bgr,
670-
gear_model_rel=gear_model_rel,
671-
shaft_model_rel=shaft_model_rel,
672-
return_images=return_images,
673-
task=pipeline_task,
674-
expected_gears=expected_gears,
675-
)
676761
except Exception:
677762
return _result_minimal(
678763
False,

evaluation_function/local_run.py

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,32 +8,32 @@
88
# ---- local test image ----
99
IMAGE_PATH = r"C:\Users\sheng\Desktop\Test.jpg"
1010

11-
# ---- task controls (edit these) ----
12-
TASK = "shaft" # "shaft" | "spacer" | "gear_inventory" | "mesh_ratio" | "full"
13-
EXPECTED_GEARS = None # e.g. 7 (only used if TASK=="gear_inventory")
14-
RUN_ALL_TASKS = False # True to run all tasks in sequence
11+
# ---- model files ----
12+
MODEL_A_REL = "modelA.pt"
13+
MODEL_B_REL = "modelB.pt"
14+
MODEL_C_REL = "modelC.pt"
1515

16-
RETURN_IMAGES = True # save det/label images if produced
16+
# ---- task controls ----
17+
TASK = "full" # "parts_inventory" | "shaft" | "spacer" | "gear_inventory" | "mesh_ratio" | "full"
18+
PART_TYPE = "gear" # "gear" | "shaft" | "spacer" (only used when TASK=="parts_inventory")
19+
EXPECTED_GEARS = None
20+
RUN_ALL_TASKS = False
21+
22+
RETURN_IMAGES = False
1723

1824
OUT_DIR = os.path.join(os.path.dirname(__file__), "_local_out")
1925
os.makedirs(OUT_DIR, exist_ok=True)
2026

2127

2228
def load_bgr_image_from_url(url: str):
23-
"""
24-
Correctly load image from:
25-
- file:// URL (Windows / Linux / macOS)
26-
"""
2729
if not isinstance(url, str) or not url:
2830
return None, "URL is empty or not a string."
2931

3032
parsed = urlparse(url)
3133

3234
if parsed.scheme == "file":
33-
# parsed.path on Windows looks like: /C:/Users/...
3435
path = unquote(parsed.path)
3536

36-
# Fix Windows leading slash: /C:/... -> C:/...
3737
if os.name == "nt" and path.startswith("/"):
3838
path = path[1:]
3939

@@ -73,6 +73,9 @@ def _print_result(task_name: str, result: dict):
7373
print("\n===== SUMMARY =====")
7474
print(result.get("summary", {}))
7575

76+
print("\n===== COUNTS =====")
77+
print(result.get("counts", {}))
78+
7679
print("\n===== RATIO =====")
7780
print(result.get("ratio", {}))
7881

@@ -114,28 +117,32 @@ def _save_images(task_name: str, result: dict):
114117
print("[WARN] label_img not found in result['images'].")
115118

116119

117-
def _run_one(task_name: str, img_bgr: np.ndarray):
118-
# NOTE: This assumes your yolo_pipeline.run_yolo_pipeline supports task=... and expected_gears=...
119-
# If not yet pushed/updated, you can temporarily remove these kwargs.
120+
def _run_one(task_name: str, img_bgr: np.ndarray, part_type: str | None = None):
120121
kwargs = {
122+
"model_a_rel": MODEL_A_REL,
123+
"model_b_rel": MODEL_B_REL,
124+
"model_c_rel": MODEL_C_REL,
121125
"return_images": RETURN_IMAGES,
122126
"task": task_name,
123127
}
128+
129+
if task_name == "parts_inventory" and part_type:
130+
kwargs["part_type"] = part_type
131+
124132
if task_name == "gear_inventory" and EXPECTED_GEARS is not None:
125133
kwargs["expected_gears"] = EXPECTED_GEARS
126134

127135
result = run_yolo_pipeline(img_bgr, **kwargs)
128136
_print_result(task_name, result)
137+
129138
if RETURN_IMAGES:
130139
_save_images(task_name, result)
131140

132141

133142
def main():
134-
# 1) Build a Lambda-like response payload
135143
abs_path = os.path.abspath(IMAGE_PATH).replace("\\", "/")
136144
response = [{"url": f"file:///{abs_path}"}]
137145

138-
# 2) Load image via URL (simulate platform URL flow)
139146
url = response[0].get("url")
140147
img, err = load_bgr_image_from_url(url)
141148
if err:
@@ -144,13 +151,20 @@ def main():
144151
print("\n===== RESPONSE (URL) =====")
145152
print(url)
146153

147-
# 3) Run pipeline by task(s)
148154
if RUN_ALL_TASKS:
149-
tasks = ["shaft", "spacer", "gear_inventory", "mesh_ratio", "full"]
150-
for t in tasks:
151-
_run_one(t, img)
155+
_run_one("parts_inventory", img, "gear")
156+
_run_one("parts_inventory", img, "shaft")
157+
_run_one("parts_inventory", img, "spacer")
158+
_run_one("shaft", img)
159+
_run_one("spacer", img)
160+
_run_one("gear_inventory", img)
161+
_run_one("mesh_ratio", img)
162+
_run_one("full", img)
152163
else:
153-
_run_one(TASK, img)
164+
if TASK == "parts_inventory":
165+
_run_one(TASK, img, PART_TYPE)
166+
else:
167+
_run_one(TASK, img)
154168

155169

156170
if __name__ == "__main__":

evaluation_function/modelA.pt

18.3 MB
Binary file not shown.

evaluation_function/modelB.pt

19.1 MB
Binary file not shown.

evaluation_function/modelC.pt

18.3 MB
Binary file not shown.

0 commit comments

Comments
 (0)