11"""
2- 猫狗分类推理工具 (带 GUI)
2+ YOLO 分类推理工具 (带 GUI)
33
44功能:
551. 选择 YOLO 分类模型 (.pt / .onnx / .torchscript)
662. 选择单张图片进行推理
773. 选择整个文件夹批量推理
884. 显示分类结果和置信度
9+ 5. 优先读取模型同目录下的 labels.txt 作为类别名
910"""
1011
12+ from __future__ import annotations
13+
1114import os
1215import tkinter as tk
13- from tkinter import ttk , filedialog , messagebox
1416from pathlib import Path
17+ from tkinter import filedialog , messagebox , ttk
18+
1519from PIL import Image , ImageTk
1620from ultralytics import YOLO
1721
1822
23+ IMAGE_SUFFIXES = {".jpg" , ".jpeg" , ".png" , ".bmp" , ".webp" }
24+
25+
1926class InferenceApp :
2027 def __init__ (self , root ):
2128 self .root = root
@@ -25,13 +32,13 @@ def __init__(self, root):
2532
2633 self .model = None
2734 self .model_path = ""
35+ self .class_names : list [str ] = []
2836 self .current_image = None
2937 self .photo = None
3038
3139 self ._build_ui ()
3240
3341 def _build_ui (self ):
34- # ========== 顶部: 模型选择 ==========
3542 model_frame = ttk .LabelFrame (self .root , text = "模型" , padding = 10 )
3643 model_frame .pack (fill = "x" , padx = 10 , pady = 5 )
3744
@@ -43,7 +50,6 @@ def _build_ui(self):
4350 side = "left"
4451 )
4552
46- # ========== 中间: 输入选择 ==========
4753 input_frame = ttk .LabelFrame (self .root , text = "输入" , padding = 10 )
4854 input_frame .pack (fill = "x" , padx = 10 , pady = 5 )
4955
@@ -65,11 +71,9 @@ def _build_ui(self):
6571 fill = "x" , pady = (5 , 0 )
6672 )
6773
68- # ========== 结果区域 ==========
6974 result_frame = ttk .LabelFrame (self .root , text = "推理结果" , padding = 10 )
7075 result_frame .pack (fill = "both" , expand = True , padx = 10 , pady = 5 )
7176
72- # 滚动区域
7377 canvas = tk .Canvas (result_frame )
7478 scrollbar = ttk .Scrollbar (result_frame , orient = "vertical" , command = canvas .yview )
7579 self .result_inner = ttk .Frame (canvas )
@@ -85,7 +89,6 @@ def _build_ui(self):
8589 canvas .pack (side = "left" , fill = "both" , expand = True )
8690 scrollbar .pack (side = "right" , fill = "y" )
8791
88- # 状态栏
8992 self .status_var = tk .StringVar (value = "就绪" )
9093 ttk .Label (
9194 self .root , textvariable = self .status_var , relief = "sunken" , anchor = "w"
@@ -106,16 +109,25 @@ def _select_model(self):
106109 self .root .update ()
107110
108111 try :
109- # Explicitly specify task for ONNX/classification models to avoid guessing issues
110112 if path .lower ().endswith (".onnx" ):
111113 self .model = YOLO (path , task = "classify" )
112114 else :
113115 self .model = YOLO (path )
116+
114117 self .model_path = path
115- self .model_var .set (f"已加载: { os .path .basename (path )} " )
118+ self .class_names = self ._load_class_names (path )
119+
120+ model_name = os .path .basename (path )
121+ if self .class_names :
122+ self .model_var .set (f"已加载: { model_name } ({ len (self .class_names )} 类)" )
123+ else :
124+ self .model_var .set (f"已加载: { model_name } " )
125+
116126 self .status_var .set ("模型加载成功" )
117127 except Exception as e :
118128 self .model = None
129+ self .model_path = ""
130+ self .class_names = []
119131 self .model_var .set ("加载失败" )
120132 self .status_var .set (f"模型加载失败: { e } " )
121133 messagebox .showerror ("错误" , f"无法加载模型:\n { e } " )
@@ -157,12 +169,10 @@ def _select_folder(self):
157169 if not folder :
158170 return
159171
160- # 收集所有图片
161- exts = {".jpg" , ".jpeg" , ".png" , ".bmp" , ".webp" }
162172 images = [
163- os . path . join ( folder , f )
164- for f in os . listdir ( folder )
165- if os . path .splitext ( f )[ 1 ]. lower () in exts
173+ str ( path )
174+ for path in sorted ( Path ( folder ). rglob ( "*" ) )
175+ if path .is_file () and path . suffix . lower () in IMAGE_SUFFIXES
166176 ]
167177
168178 if not images :
@@ -175,7 +185,6 @@ def _select_folder(self):
175185 self .root .update ()
176186
177187 try :
178- # 尝试批量推理,若失败再逐图推理以定位问题
179188 try :
180189 results = self .model (images , imgsz = 224 )
181190 for i , (img_path , result ) in enumerate (zip (images , results )):
@@ -200,11 +209,9 @@ def _select_folder(self):
200209 messagebox .showerror ("错误" , f"推理失败:\n { e } " )
201210
202211 def _show_single_result (self , img_path , result ):
203- """在结果区域显示单张图片的推理结果"""
204212 row = ttk .Frame (self .result_inner )
205213 row .pack (fill = "x" , pady = 5 )
206214
207- # 左侧: 图片缩略图
208215 thumb_frame = ttk .Frame (row , width = 150 , height = 150 )
209216 thumb_frame .pack (side = "left" , padx = (0 , 10 ))
210217 thumb_frame .pack_propagate (False )
@@ -214,12 +221,11 @@ def _show_single_result(self, img_path, result):
214221 img .thumbnail ((150 , 150 ))
215222 photo = ImageTk .PhotoImage (img )
216223 lbl = ttk .Label (thumb_frame , image = photo )
217- lbl .image = photo # 保持引用
224+ lbl .image = photo
218225 lbl .pack ()
219226 except Exception :
220227 ttk .Label (thumb_frame , text = "无法加载" ).pack ()
221228
222- # 右侧: 信息
223229 info_frame = ttk .Frame (row )
224230 info_frame .pack (side = "left" , fill = "x" , expand = True )
225231
@@ -232,22 +238,14 @@ def _show_single_result(self, img_path, result):
232238 info_frame , text = img_path , foreground = "gray" , font = ("Consolas" , 8 )
233239 ).pack (anchor = "w" )
234240
235- # 解析分类结果
236241 probs = result .probs
237242 top1_idx = probs .top1
238243 top1_conf = probs .top1conf
239244 top5 = probs .top5
240-
241245 names = result .names
242- top1_name = (
243- names .get (top1_idx , str (top1_idx ))
244- if isinstance (names , dict )
245- else names [top1_idx ]
246- )
247246
248- # 中文映射
249- label_map = {"cat" : "猫" , "dog" : "狗" }
250- display_name = label_map .get (top1_name .lower (), top1_name )
247+ top1_name = self ._resolve_class_name (top1_idx , names )
248+ display_name = self ._format_label (top1_name )
251249
252250 result_text = f"预测: { display_name } (置信度: { top1_conf :.1%} )"
253251 ttk .Label (
@@ -257,26 +255,53 @@ def _show_single_result(self, img_path, result):
257255 foreground = "#1a73e8" ,
258256 ).pack (anchor = "w" , pady = (5 , 0 ))
259257
260- # Top-5 详情
261258 if len (top5 ) > 1 :
262259 detail_lines = []
263260 for idx , conf in zip (top5 , probs .top5conf ):
264- name = (
265- names .get (idx , str (idx )) if isinstance (names , dict ) else names [idx ]
266- )
267- display = label_map .get (name .lower (), name )
268- detail_lines .append (f" { display } : { conf :.1%} " )
261+ name = self ._resolve_class_name (idx , names )
262+ detail_lines .append (f" { self ._format_label (name )} : { conf :.1%} " )
269263 ttk .Label (
270264 info_frame , text = "\n " .join (detail_lines ), font = ("Consolas" , 9 )
271265 ).pack (anchor = "w" )
272266
273- # 分隔线
274267 ttk .Separator (self .result_inner , orient = "horizontal" ).pack (fill = "x" , pady = 5 )
275268
276269 def _clear_results (self ):
277270 for widget in self .result_inner .winfo_children ():
278271 widget .destroy ()
279272
273+ def _resolve_class_name (self , index , result_names ):
274+ if 0 <= index < len (self .class_names ):
275+ return self .class_names [index ]
276+
277+ if isinstance (result_names , dict ):
278+ return result_names .get (index , str (index ))
279+
280+ if isinstance (result_names , (list , tuple )) and 0 <= index < len (result_names ):
281+ return result_names [index ]
282+
283+ return str (index )
284+
285+ @staticmethod
286+ def _load_class_names (model_path : str ) -> list [str ]:
287+ model_dir = Path (model_path ).resolve ().parent
288+ for filename in ("labels.txt" , "class_names.txt" ):
289+ labels_path = model_dir / filename
290+ if not labels_path .is_file ():
291+ continue
292+ lines = [
293+ line .strip ()
294+ for line in labels_path .read_text (encoding = "utf-8" ).splitlines ()
295+ if line .strip ()
296+ ]
297+ if lines :
298+ return lines
299+ return []
300+
301+ @staticmethod
302+ def _format_label (name ):
303+ return str (name ).strip ()
304+
280305
281306def main ():
282307 root = tk .Tk ()
0 commit comments