Skip to content

Commit 93d586e

Browse files
committed
Update local_run and "gears on shaft"logic
1 parent cd0bd06 commit 93d586e

File tree

2 files changed

+713
-127
lines changed

2 files changed

+713
-127
lines changed

evaluation_function/local_run.py

Lines changed: 184 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,33 @@
11
from urllib.parse import urlparse, unquote
22
import os
33
import cv2
4+
import json
5+
import shutil
46
import numpy as np
57

68
from evaluation_function.yolo_pipeline import run_yolo_pipeline
79

8-
# ---- local test image ----
9-
IMAGE_PATH = r"C:\Users\sheng\Desktop\Test.jpg"
10+
# ---- local test image or folder ----
11+
IMAGE_PATH = r"C:\Users\sheng\Desktop\Eastern week 3"
1012

1113
# ---- model files ----
1214
MODEL_A_REL = "modelA.pt"
1315
MODEL_B_REL = "modelB.pt"
1416
MODEL_C_REL = "modelC.pt"
1517

18+
# ---- where you want to save outputs ----
19+
SAVE_ROOT = r"C:\Users\sheng\Desktop\Pipeline_saved_results"
20+
1621
# ---- task controls ----
1722
TASK = "full" # "parts_inventory" | "shaft" | "spacer" | "gear_inventory" | "mesh_ratio" | "full"
18-
PART_TYPE = "gear" # "gear" | "shaft" | "spacer" (only used when TASK=="parts_inventory")
23+
#PART_TYPE = "gear" # "gear" | "shaft" | "spacer" (only used when TASK=="parts_inventory")
1924
EXPECTED_GEARS = None
2025
RUN_ALL_TASKS = False
2126

22-
RETURN_IMAGES = False
27+
# must be True if you want det/label images from pipeline
28+
RETURN_IMAGES = True
2329

24-
OUT_DIR = os.path.join(os.path.dirname(__file__), "_local_out")
25-
os.makedirs(OUT_DIR, exist_ok=True)
30+
os.makedirs(SAVE_ROOT, exist_ok=True)
2631

2732

2833
def load_bgr_image_from_url(url: str):
@@ -46,14 +51,147 @@ def load_bgr_image_from_url(url: str):
4651
return None, f"Unsupported URL scheme: {parsed.scheme}"
4752

4853

49-
def _safe_task_dir(task_name: str) -> str:
50-
t = (task_name or "full").strip().lower()
51-
d = os.path.join(OUT_DIR, t)
52-
os.makedirs(d, exist_ok=True)
53-
return d
54+
def find_image_files(path):
55+
exts = {".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff"}
56+
if os.path.isfile(path):
57+
return [path]
58+
59+
if os.path.isdir(path):
60+
files = []
61+
for name in os.listdir(path):
62+
full = os.path.join(path, name)
63+
if os.path.isfile(full) and os.path.splitext(name)[1].lower() in exts:
64+
files.append(full)
65+
return sorted(files)
66+
67+
return []
68+
69+
70+
def safe_name(name: str) -> str:
71+
bad = '\\/:*?"<>|'
72+
for ch in bad:
73+
name = name.replace(ch, "_")
74+
return name
75+
76+
77+
def make_output_dir(image_path: str, task_name: str) -> str:
78+
image_stem = os.path.splitext(os.path.basename(image_path))[0]
79+
folder_name = f"{safe_name(image_stem)}__{safe_name(task_name)}"
80+
out_dir = os.path.join(SAVE_ROOT, folder_name)
81+
os.makedirs(out_dir, exist_ok=True)
82+
return out_dir
5483

5584

56-
def _print_result(task_name: str, result: dict):
85+
def to_jsonable(obj):
86+
if isinstance(obj, np.ndarray):
87+
return obj.tolist()
88+
if isinstance(obj, (np.float32, np.float64)):
89+
return float(obj)
90+
if isinstance(obj, (np.int32, np.int64)):
91+
return int(obj)
92+
if isinstance(obj, dict):
93+
return {k: to_jsonable(v) for k, v in obj.items() if k != "images"}
94+
if isinstance(obj, list):
95+
return [to_jsonable(v) for v in obj]
96+
return obj
97+
98+
99+
def format_result_text(task_name: str, result: dict) -> str:
100+
lines = []
101+
lines.append("==============================")
102+
lines.append(f"===== TASK: {task_name} =====")
103+
lines.append("==============================")
104+
lines.append("")
105+
106+
tr = result.get("task_result")
107+
if isinstance(tr, dict):
108+
lines.append("===== TASK_RESULT =====")
109+
for k in ["task", "status", "is_ready_for_next", "recommended_next_task", "focus"]:
110+
if k in tr:
111+
lines.append(f"{k}: {tr.get(k)}")
112+
msgs = tr.get("messages")
113+
if isinstance(msgs, list) and msgs:
114+
lines.append("messages:")
115+
for m in msgs:
116+
lines.append(f" - {m}")
117+
lines.append("")
118+
119+
lines.append("===== SUMMARY =====")
120+
lines.append(str(result.get("summary", {})))
121+
lines.append("")
122+
123+
lines.append("===== COUNTS =====")
124+
lines.append(str(result.get("counts", {})))
125+
lines.append("")
126+
127+
lines.append("===== RATIO =====")
128+
lines.append(str(result.get("ratio", {})))
129+
lines.append("")
130+
131+
lines.append("===== ERRORS =====")
132+
errs = result.get("errors", [])
133+
if not errs:
134+
lines.append("(none)")
135+
else:
136+
for e in errs:
137+
if isinstance(e, dict):
138+
lines.append(f"- {e.get('code')}: {e.get('message')}")
139+
else:
140+
lines.append(f"- {e}")
141+
lines.append("")
142+
143+
lines.append("===== TIMING =====")
144+
lines.append(str(result.get("timing", {})))
145+
lines.append("")
146+
147+
return "\n".join(lines)
148+
149+
150+
def save_outputs(image_path: str, task_name: str, img_bgr: np.ndarray, result: dict):
151+
out_dir = make_output_dir(image_path, task_name)
152+
153+
# 1) save original image
154+
original_path = os.path.join(out_dir, "original.jpg")
155+
cv2.imwrite(original_path, img_bgr)
156+
157+
# 2) save pipeline output images
158+
imgs = result.get("images", {})
159+
det_path = os.path.join(out_dir, "det.jpg")
160+
label_path = os.path.join(out_dir, "labels.jpg")
161+
162+
if isinstance(imgs, dict):
163+
if "det_img" in imgs and isinstance(imgs["det_img"], np.ndarray):
164+
cv2.imwrite(det_path, imgs["det_img"])
165+
if "label_img" in imgs and isinstance(imgs["label_img"], np.ndarray):
166+
cv2.imwrite(label_path, imgs["label_img"])
167+
168+
# 3) save staging / task result content as json
169+
json_result = to_jsonable(result)
170+
json_path = os.path.join(out_dir, "result.json")
171+
with open(json_path, "w", encoding="utf-8") as f:
172+
json.dump(json_result, f, ensure_ascii=False, indent=2)
173+
174+
# 4) save readable text report
175+
txt_path = os.path.join(out_dir, "report.txt")
176+
with open(txt_path, "w", encoding="utf-8") as f:
177+
f.write(format_result_text(task_name, result))
178+
179+
# 5) optional: save a copy of the source file with original filename
180+
src_copy_path = os.path.join(out_dir, os.path.basename(image_path))
181+
if os.path.abspath(src_copy_path) != os.path.abspath(image_path):
182+
shutil.copy2(image_path, src_copy_path)
183+
184+
print(f"\nSaved outputs to: {out_dir}")
185+
print(f" - original image: {original_path}")
186+
if os.path.exists(det_path):
187+
print(f" - det image: {det_path}")
188+
if os.path.exists(label_path):
189+
print(f" - label image: {label_path}")
190+
print(f" - report text: {txt_path}")
191+
print(f" - result json: {json_path}")
192+
193+
194+
def print_result(task_name: str, result: dict):
57195
print("\n==============================")
58196
print(f"===== TASK: {task_name} =====")
59197
print("==============================")
@@ -67,7 +205,7 @@ def _print_result(task_name: str, result: dict):
67205
msgs = tr.get("messages")
68206
if isinstance(msgs, list) and msgs:
69207
print("messages:")
70-
for m in msgs[:12]:
208+
for m in msgs:
71209
print(f" - {m}")
72210

73211
print("\n===== SUMMARY =====")
@@ -94,30 +232,7 @@ def _print_result(task_name: str, result: dict):
94232
print(result.get("timing", {}))
95233

96234

97-
def _save_images(task_name: str, result: dict):
98-
imgs = result.get("images", None)
99-
if not imgs:
100-
print("\n[INFO] No images returned (result['images'] missing).")
101-
return
102-
103-
task_dir = _safe_task_dir(task_name)
104-
det_path = os.path.join(task_dir, "det.jpg")
105-
lab_path = os.path.join(task_dir, "labels.jpg")
106-
107-
if "det_img" in imgs and isinstance(imgs["det_img"], np.ndarray):
108-
cv2.imwrite(det_path, imgs["det_img"])
109-
print(f"\nSaved: {det_path}")
110-
else:
111-
print("\n[WARN] det_img not found in result['images'].")
112-
113-
if "label_img" in imgs and isinstance(imgs["label_img"], np.ndarray):
114-
cv2.imwrite(lab_path, imgs["label_img"])
115-
print(f"Saved: {lab_path}")
116-
else:
117-
print("[WARN] label_img not found in result['images'].")
118-
119-
120-
def _run_one(task_name: str, img_bgr: np.ndarray, part_type: str | None = None):
235+
def run_one(task_name: str, image_path: str, img_bgr: np.ndarray, part_type: str | None = None):
121236
kwargs = {
122237
"model_a_rel": MODEL_A_REL,
123238
"model_b_rel": MODEL_B_REL,
@@ -133,38 +248,42 @@ def _run_one(task_name: str, img_bgr: np.ndarray, part_type: str | None = None):
133248
kwargs["expected_gears"] = EXPECTED_GEARS
134249

135250
result = run_yolo_pipeline(img_bgr, **kwargs)
136-
_print_result(task_name, result)
137-
138-
if RETURN_IMAGES:
139-
_save_images(task_name, result)
251+
print_result(task_name, result)
252+
save_outputs(image_path, task_name, img_bgr, result)
140253

141254

142255
def main():
143-
abs_path = os.path.abspath(IMAGE_PATH).replace("\\", "/")
144-
response = [{"url": f"file:///{abs_path}"}]
145-
146-
url = response[0].get("url")
147-
img, err = load_bgr_image_from_url(url)
148-
if err:
149-
raise SystemExit(f"[ERROR] {err}")
150-
151-
print("\n===== RESPONSE (URL) =====")
152-
print(url)
153-
154-
if RUN_ALL_TASKS:
155-
_run_one("parts_inventory", img, "gear")
156-
_run_one("parts_inventory", img, "shaft")
157-
_run_one("parts_inventory", img, "spacer")
158-
_run_one("shaft", img)
159-
_run_one("spacer", img)
160-
_run_one("gear_inventory", img)
161-
_run_one("mesh_ratio", img)
162-
_run_one("full", img)
163-
else:
164-
if TASK == "parts_inventory":
165-
_run_one(TASK, img, PART_TYPE)
256+
image_files = find_image_files(IMAGE_PATH)
257+
if not image_files:
258+
raise SystemExit(f"[ERROR] No image file found in: {IMAGE_PATH}")
259+
260+
for image_path in image_files:
261+
abs_path = os.path.abspath(image_path).replace("\\", "/")
262+
response = [{"url": f"file:///{abs_path}"}]
263+
264+
url = response[0].get("url")
265+
img, err = load_bgr_image_from_url(url)
266+
if err:
267+
print(f"[ERROR] {err}")
268+
continue
269+
270+
print("\n===== RESPONSE (URL) =====")
271+
print(url)
272+
273+
if RUN_ALL_TASKS:
274+
run_one("parts_inventory", image_path, img, "gear")
275+
run_one("parts_inventory", image_path, img, "shaft")
276+
run_one("parts_inventory", image_path, img, "spacer")
277+
run_one("shaft", image_path, img)
278+
run_one("spacer", image_path, img)
279+
run_one("gear_inventory", image_path, img)
280+
run_one("mesh_ratio", image_path, img)
281+
run_one("full", image_path, img)
166282
else:
167-
_run_one(TASK, img)
283+
if TASK == "parts_inventory":
284+
run_one(TASK, image_path, img, PART_TYPE)
285+
else:
286+
run_one(TASK, image_path, img)
168287

169288

170289
if __name__ == "__main__":

0 commit comments

Comments
 (0)