Skip to content

Commit 89d7a2b

Browse files
committed
更新训练脚本和推理工具,调整数据集结构和模型加载逻辑
1 parent e8518c5 commit 89d7a2b

3 files changed

Lines changed: 317 additions & 105 deletions

File tree

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ runs/
55
.vscode/tasks.json
66
*.pt
77
*.cache
8-
*.torchscript
8+
*.torchscript
9+
WellColumnClassification/

py/predict_gui.py

Lines changed: 61 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,28 @@
11
"""
2-
猫狗分类推理工具 (带 GUI)
2+
YOLO 分类推理工具 (带 GUI)
33
44
功能:
55
1. 选择 YOLO 分类模型 (.pt / .onnx / .torchscript)
66
2. 选择单张图片进行推理
77
3. 选择整个文件夹批量推理
88
4. 显示分类结果和置信度
9+
5. 优先读取模型同目录下的 labels.txt 作为类别名
910
"""
1011

12+
from __future__ import annotations
13+
1114
import os
1215
import tkinter as tk
13-
from tkinter import ttk, filedialog, messagebox
1416
from pathlib import Path
17+
from tkinter import filedialog, messagebox, ttk
18+
1519
from PIL import Image, ImageTk
1620
from ultralytics import YOLO
1721

1822

23+
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
24+
25+
1926
class InferenceApp:
2027
def __init__(self, root):
2128
self.root = root
@@ -25,13 +32,13 @@ def __init__(self, root):
2532

2633
self.model = None
2734
self.model_path = ""
35+
self.class_names: list[str] = []
2836
self.current_image = None
2937
self.photo = None
3038

3139
self._build_ui()
3240

3341
def _build_ui(self):
34-
# ========== 顶部: 模型选择 ==========
3542
model_frame = ttk.LabelFrame(self.root, text="模型", padding=10)
3643
model_frame.pack(fill="x", padx=10, pady=5)
3744

@@ -43,7 +50,6 @@ def _build_ui(self):
4350
side="left"
4451
)
4552

46-
# ========== 中间: 输入选择 ==========
4753
input_frame = ttk.LabelFrame(self.root, text="输入", padding=10)
4854
input_frame.pack(fill="x", padx=10, pady=5)
4955

@@ -65,11 +71,9 @@ def _build_ui(self):
6571
fill="x", pady=(5, 0)
6672
)
6773

68-
# ========== 结果区域 ==========
6974
result_frame = ttk.LabelFrame(self.root, text="推理结果", padding=10)
7075
result_frame.pack(fill="both", expand=True, padx=10, pady=5)
7176

72-
# 滚动区域
7377
canvas = tk.Canvas(result_frame)
7478
scrollbar = ttk.Scrollbar(result_frame, orient="vertical", command=canvas.yview)
7579
self.result_inner = ttk.Frame(canvas)
@@ -85,7 +89,6 @@ def _build_ui(self):
8589
canvas.pack(side="left", fill="both", expand=True)
8690
scrollbar.pack(side="right", fill="y")
8791

88-
# 状态栏
8992
self.status_var = tk.StringVar(value="就绪")
9093
ttk.Label(
9194
self.root, textvariable=self.status_var, relief="sunken", anchor="w"
@@ -106,16 +109,25 @@ def _select_model(self):
106109
self.root.update()
107110

108111
try:
109-
# Explicitly specify task for ONNX/classification models to avoid guessing issues
110112
if path.lower().endswith(".onnx"):
111113
self.model = YOLO(path, task="classify")
112114
else:
113115
self.model = YOLO(path)
116+
114117
self.model_path = path
115-
self.model_var.set(f"已加载: {os.path.basename(path)}")
118+
self.class_names = self._load_class_names(path)
119+
120+
model_name = os.path.basename(path)
121+
if self.class_names:
122+
self.model_var.set(f"已加载: {model_name} ({len(self.class_names)} 类)")
123+
else:
124+
self.model_var.set(f"已加载: {model_name}")
125+
116126
self.status_var.set("模型加载成功")
117127
except Exception as e:
118128
self.model = None
129+
self.model_path = ""
130+
self.class_names = []
119131
self.model_var.set("加载失败")
120132
self.status_var.set(f"模型加载失败: {e}")
121133
messagebox.showerror("错误", f"无法加载模型:\n{e}")
@@ -157,12 +169,10 @@ def _select_folder(self):
157169
if not folder:
158170
return
159171

160-
# 收集所有图片
161-
exts = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
162172
images = [
163-
os.path.join(folder, f)
164-
for f in os.listdir(folder)
165-
if os.path.splitext(f)[1].lower() in exts
173+
str(path)
174+
for path in sorted(Path(folder).rglob("*"))
175+
if path.is_file() and path.suffix.lower() in IMAGE_SUFFIXES
166176
]
167177

168178
if not images:
@@ -175,7 +185,6 @@ def _select_folder(self):
175185
self.root.update()
176186

177187
try:
178-
# 尝试批量推理,若失败再逐图推理以定位问题
179188
try:
180189
results = self.model(images, imgsz=224)
181190
for i, (img_path, result) in enumerate(zip(images, results)):
@@ -200,11 +209,9 @@ def _select_folder(self):
200209
messagebox.showerror("错误", f"推理失败:\n{e}")
201210

202211
def _show_single_result(self, img_path, result):
203-
"""在结果区域显示单张图片的推理结果"""
204212
row = ttk.Frame(self.result_inner)
205213
row.pack(fill="x", pady=5)
206214

207-
# 左侧: 图片缩略图
208215
thumb_frame = ttk.Frame(row, width=150, height=150)
209216
thumb_frame.pack(side="left", padx=(0, 10))
210217
thumb_frame.pack_propagate(False)
@@ -214,12 +221,11 @@ def _show_single_result(self, img_path, result):
214221
img.thumbnail((150, 150))
215222
photo = ImageTk.PhotoImage(img)
216223
lbl = ttk.Label(thumb_frame, image=photo)
217-
lbl.image = photo # 保持引用
224+
lbl.image = photo
218225
lbl.pack()
219226
except Exception:
220227
ttk.Label(thumb_frame, text="无法加载").pack()
221228

222-
# 右侧: 信息
223229
info_frame = ttk.Frame(row)
224230
info_frame.pack(side="left", fill="x", expand=True)
225231

@@ -232,22 +238,14 @@ def _show_single_result(self, img_path, result):
232238
info_frame, text=img_path, foreground="gray", font=("Consolas", 8)
233239
).pack(anchor="w")
234240

235-
# 解析分类结果
236241
probs = result.probs
237242
top1_idx = probs.top1
238243
top1_conf = probs.top1conf
239244
top5 = probs.top5
240-
241245
names = result.names
242-
top1_name = (
243-
names.get(top1_idx, str(top1_idx))
244-
if isinstance(names, dict)
245-
else names[top1_idx]
246-
)
247246

248-
# 中文映射
249-
label_map = {"cat": "猫", "dog": "狗"}
250-
display_name = label_map.get(top1_name.lower(), top1_name)
247+
top1_name = self._resolve_class_name(top1_idx, names)
248+
display_name = self._format_label(top1_name)
251249

252250
result_text = f"预测: {display_name} (置信度: {top1_conf:.1%})"
253251
ttk.Label(
@@ -257,26 +255,53 @@ def _show_single_result(self, img_path, result):
257255
foreground="#1a73e8",
258256
).pack(anchor="w", pady=(5, 0))
259257

260-
# Top-5 详情
261258
if len(top5) > 1:
262259
detail_lines = []
263260
for idx, conf in zip(top5, probs.top5conf):
264-
name = (
265-
names.get(idx, str(idx)) if isinstance(names, dict) else names[idx]
266-
)
267-
display = label_map.get(name.lower(), name)
268-
detail_lines.append(f" {display}: {conf:.1%}")
261+
name = self._resolve_class_name(idx, names)
262+
detail_lines.append(f" {self._format_label(name)}: {conf:.1%}")
269263
ttk.Label(
270264
info_frame, text="\n".join(detail_lines), font=("Consolas", 9)
271265
).pack(anchor="w")
272266

273-
# 分隔线
274267
ttk.Separator(self.result_inner, orient="horizontal").pack(fill="x", pady=5)
275268

276269
def _clear_results(self):
277270
for widget in self.result_inner.winfo_children():
278271
widget.destroy()
279272

273+
def _resolve_class_name(self, index, result_names):
274+
if 0 <= index < len(self.class_names):
275+
return self.class_names[index]
276+
277+
if isinstance(result_names, dict):
278+
return result_names.get(index, str(index))
279+
280+
if isinstance(result_names, (list, tuple)) and 0 <= index < len(result_names):
281+
return result_names[index]
282+
283+
return str(index)
284+
285+
@staticmethod
286+
def _load_class_names(model_path: str) -> list[str]:
287+
model_dir = Path(model_path).resolve().parent
288+
for filename in ("labels.txt", "class_names.txt"):
289+
labels_path = model_dir / filename
290+
if not labels_path.is_file():
291+
continue
292+
lines = [
293+
line.strip()
294+
for line in labels_path.read_text(encoding="utf-8").splitlines()
295+
if line.strip()
296+
]
297+
if lines:
298+
return lines
299+
return []
300+
301+
@staticmethod
302+
def _format_label(name):
303+
return str(name).strip()
304+
280305

281306
def main():
282307
root = tk.Tk()

0 commit comments

Comments
 (0)