-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathspider_gui.py
More file actions
1467 lines (1165 loc) · 62.5 KB
/
spider_gui.py
File metadata and controls
1467 lines (1165 loc) · 62.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import torch
import torch.nn as nn
from matplotlib import ticker
from torchvision import transforms, models
import tkinter as tk
from tkinter import ttk, filedialog, messagebox
from PIL import Image, ImageTk
import os
import random
import shutil
from pathlib import Path
import matplotlib as mpl
from PIL import Image, ImageTk, ImageDraw, ImageFont # 确保添加了 ImageFont
from datetime import datetime
import time
import glob
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
import matplotlib.font_manager as fm # 确保导入了字体管理器
# --- Matplotlib 中文乱码处理 ---
# 尝试查找并使用支持中文的字体,增加兼容性
font_list = ['SimHei', 'Microsoft YaHei', 'KaiTi', 'FangSong']
chosen_font = None
for font_name in font_list:
try:
# 尝试查找字体,如果找到则停止
if fm.fontManager.findfont(font_name, fallback_to_default=False):
chosen_font = font_name
break
except:
continue # 忽略查找失败
# 如果找到了中文字体,则进行设置
if chosen_font:
# 设置 sans-serif 字体系列的首选字体为找到的中文体
plt.rcParams['font.sans-serif'] = [chosen_font] + font_list
# 调试语句:查看实际使用的字体
print(f"DEBUG: Matplotlib 已配置使用字体: {chosen_font}")
else:
# 如果所有尝试都失败,至少保留默认字体
plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
print("WARNING: 未找到常用中文字体,中文可能显示为方框。")
# 解决保存图像时负号 '-' 显示为方块的问题
plt.rcParams['axes.unicode_minus'] = False
class MultiModelSpiderRecognition:
def __init__(self, root):
"""
初始化多模型蜘蛛识别系统。
Args:
root (tk.Tk): Tkinter根窗口对象。
"""
self.root = root
self.root.title("多模型蜘蛛识别系统")
# 初始窗口大小
window_width = 1200
window_height = 900
# 获取屏幕的宽度和高度
screen_width = self.root.winfo_screenwidth()
screen_height = self.root.winfo_screenheight()
# 计算窗口的起始位置
center_x = int(screen_width / 2 - window_width / 2)
center_y = int(screen_height / 2 - window_height / 2)
# 设置窗口大小和位置 (居中)
self.root.geometry(f"{window_width}x{window_height}+{center_x}+{center_y}")
# 调试语句
# print(f"DEBUG: 窗口大小: {window_width}x{window_height}, 居中位置: +{center_x}+{center_y}")
# 使用新的日志系统记录调试信息
# 注意:log_message 在 setup_gui 之前调用会打印到控制台
# self.log_message(f"窗口大小: {window_width}x{window_height}, 居中位置: +{center_x}+{center_y}", "DEBUG")
self.best_model_files = {} # 记录每个模型的最佳文件
self.available_models = {} # 记录所有可用的模型文件
# 新增:记录已加载模型的 Listbox 显示文本,用于高亮显示
self.loaded_model_display_texts = set()
self.models = {}
self.class_names = []
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.current_model = None
# 支持的模型列表
self.supported_models = {
'ResNet-50': 'resnet50',
'EfficientNet-B0': 'efficientnet_b0',
'DenseNet-121': 'densenet121'
}
# 为识别结果创建字典,用于存储每个模型的结果文本框
self.result_texts = {}
# 图像缩放和平移状态
self.zoom_level = 1.0 # 缩放级别,1.0 表示“基础缩放”
self.min_zoom = 0.005 # 最小缩放限制,允许图片缩得更小
self.max_zoom = 10.0 # 最大缩放限制
self.pan_x = 0 # 平移 X 坐标
self.pan_y = 0 # 平移 Y 坐标
self._last_x = 0 # 鼠标按下时的 X 坐标
self._last_y = 0 # 鼠标按下时的 Y 坐标
self.display_width = 400 # 图像预览区域固定宽度
self.display_height = 400 # 图像预览区域固定高度
# 新增:用于平滑缩放动画的状态
self.target_zoom = 1.0 # 缩放的目标级别
self.is_animating = False # 标记当前是否正在进行缩放动画
# 新增:日志记录 Text 组件引用
self.log_text_widget = None
self.setup_gui()
def setup_gui(self):
"""设置GUI界面,优化布局:图片预览固定且更大,图表区域更小,识别结果按模型分隔"""
# 主框架
main_frame = ttk.Frame(self.root, padding="5")
# 将 main_frame 向上扩展,但不能填满窗口底部,因为底部要留给日志区域
main_frame.pack(fill=tk.BOTH, expand=True, side=tk.TOP)
# ----------------------------------------------------------------------------------
# 日志输出区域 (位于窗口最底部,替换了原有的单行状态栏功能)
log_frame = ttk.LabelFrame(self.root, text="系统日志", padding="0")
# 将日志区域放在最底部,填充宽度
log_frame.pack(side=tk.BOTTOM, fill=tk.X, padx=0, pady=0)
# 日志文本框和滚动条
scrollbar = ttk.Scrollbar(log_frame)
# 设置 Text 组件,高度固定为 5 行,使用 Consolas 字体
self.log_text_widget = tk.Text(log_frame, wrap="word", height=5,
yscrollcommand=scrollbar.set, state='disabled',
font=('Consolas', 8), bg='#f4f4f4')
scrollbar.config(command=self.log_text_widget.yview)
# 布局 Text 和 Scrollbar
scrollbar.pack(side="right", fill="y")
self.log_text_widget.pack(side="left", fill=tk.BOTH, expand=True)
# 新增:状态栏 (位于日志区域上方,占据底部的固定高度)
status_bar = ttk.Frame(self.root, relief=tk.SUNKEN)
status_bar.pack(side=tk.BOTTOM, fill=tk.X, padx=0, pady=0)
self.status_var = tk.StringVar(value=">>>准备就绪") # 移动到这里,并作为实例变量
# 状态栏 Label
status_label = ttk.Label(status_bar, textvariable=self.status_var,
anchor=tk.W, font=('Arial', 10), background='#dcdcdc')
status_label.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=5, pady=2)
# ----------------------------------------------------------------------------------
# 标题
title_label = ttk.Label(main_frame, text="🕷️ 多模型蜘蛛识别系统",
font=("Arial", 20, "bold"))
title_label.pack(pady=10)
# 左侧控制面板 (保持不变)
left_frame = ttk.Frame(main_frame)
left_frame.pack(side=tk.LEFT, fill=tk.Y, padx=5)
# 数据准备区域
data_frame = ttk.LabelFrame(left_frame, text="1. 数据准备", padding="10")
data_frame.pack(fill=tk.X, pady=5)
ttk.Label(data_frame, text="数据文件夹:").pack(anchor=tk.W)
path_frame = ttk.Frame(data_frame)
path_frame.pack(fill=tk.X, pady=5)
self.data_path = tk.StringVar()
ttk.Entry(path_frame, textvariable=self.data_path, width=30).pack(side=tk.LEFT, padx=5)
ttk.Button(path_frame, text="浏览", command=self.browse_folder).pack(side=tk.LEFT, padx=5)
ttk.Button(data_frame, text="自动划分数据集",
command=self.split_dataset).pack(fill=tk.X, pady=5)
# 模型训练区域
train_frame = ttk.LabelFrame(left_frame, text="2. 模型训练", padding="10")
train_frame.pack(fill=tk.X, pady=5)
# 模型选择
ttk.Label(train_frame, text="选择要训练的模型:").pack(anchor=tk.W)
self.model_vars = {}
for model_name in self.supported_models.keys():
var = tk.BooleanVar(value=True)
self.model_vars[model_name] = var
ttk.Checkbutton(train_frame, text=model_name, variable=var).pack(anchor=tk.W)
# 训练参数
param_frame = ttk.Frame(train_frame)
param_frame.pack(fill=tk.X, pady=5)
ttk.Label(param_frame, text="训练轮数:").grid(row=0, column=0, sticky=tk.W)
self.epochs_var = tk.StringVar(value="20")
ttk.Entry(param_frame, textvariable=self.epochs_var, width=10).grid(row=0, column=1, padx=5)
ttk.Label(param_frame, text="批次大小:").grid(row=0, column=2, sticky=tk.W, padx=10)
self.batch_size_var = tk.StringVar(value="16")
ttk.Entry(param_frame, textvariable=self.batch_size_var, width=10).grid(row=0, column=3)
ttk.Button(train_frame, text="开始训练选中模型",
command=self.start_training).pack(fill=tk.X, pady=5)
self.progress_var = tk.DoubleVar()
self.progress = ttk.Progressbar(train_frame, mode='determinate',
maximum=100, variable=self.progress_var)
self.progress.pack(fill=tk.X, pady=5)
# 模型管理区域
model_frame = ttk.LabelFrame(left_frame, text="3. 模型管理", padding="10")
model_frame.pack(fill=tk.X, pady=5)
ttk.Label(model_frame, text="加载已训练模型:").pack(anchor=tk.W)
# 创建一个框架来容纳Listbox和Scrollbar
list_scroll_frame = ttk.Frame(model_frame)
list_scroll_frame.pack(fill=tk.X, pady=5)
# 创建Listbox - 移除 selectmode=tk.SINGLE,因为我们将用双击来控制状态和颜色
self.loaded_models_list = tk.Listbox(list_scroll_frame, height=4)
self.loaded_models_list.pack(side=tk.LEFT, fill=tk.X, expand=True)
# 绑定双击事件:实现选中/取消高亮和加载/取消加载模型
self.loaded_models_list.bind('<Double-1>', self.toggle_model_load_status)
# 创建垂直滚动条
scrollbar = ttk.Scrollbar(list_scroll_frame, orient=tk.VERTICAL,
command=self.loaded_models_list.yview)
scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
# 将Listbox和Scrollbar连接起来
self.loaded_models_list.config(yscrollcommand=scrollbar.set)
# Listbox 不支持 tag_configure,故移除,将在其他函数中直接使用 itemconfigure 设置颜色
load_frame = ttk.Frame(model_frame)
load_frame.pack(fill=tk.X)
ttk.Button(load_frame, text="加载/卸载模型",
command=lambda: self.toggle_model_load_status(None, is_button=True)).pack(side=tk.LEFT, padx=2)
# 删除了 "刷新列表" 按钮
ttk.Button(load_frame, text="删除模型", command=self.delete_selected_models).pack(side=tk.LEFT, padx=2)
# 图片识别区域
predict_frame = ttk.LabelFrame(left_frame, text="4. 图片识别", padding="10")
predict_frame.pack(fill=tk.X, pady=5)
ttk.Button(predict_frame, text="选择图片",
command=self.select_image).pack(fill=tk.X, pady=2)
ttk.Button(predict_frame, text="开始识别",
command=self.predict_image).pack(fill=tk.X, pady=2)
# ----------------------------------------------------------------------------------
# 右侧结果显示区域
right_frame = ttk.Frame(main_frame)
right_frame.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True, padx=5)
# 右侧上部:图片预览 (固定大小) 和 训练图表 (较小)
top_right_frame = ttk.Frame(right_frame)
top_right_frame.pack(fill=tk.X, pady=5, expand=False)
image_width, image_height = 400, 400 # <-- 宽度修改为 600
image_frame = ttk.LabelFrame(top_right_frame, text="图片预览", padding="10")
# 强制设置 frame 的像素尺寸,并禁止其尺寸被子组件内容改变
# 尺寸稍微增加,以预留 LabelFrame 的边框和内部 padding 的空间
image_frame.config(width=image_width + 20, height=image_height + 40)
image_frame.pack_propagate(False) # 阻止子组件(Label)改变父组件(Frame)的大小
image_frame.pack(side=tk.LEFT, padx=5, fill=tk.BOTH, expand=False)
# 强制设置Label的字符宽度,以预留空间
self.image_label = ttk.Label(image_frame, text="请选择图片",
background="white", anchor=tk.CENTER)
# 绑定鼠标事件用于拖动和平移
self.image_label.bind("<Button-1>", self._start_pan) # 鼠标左键按下:记录起始位置
self.image_label.bind("<B1-Motion>", self._pan_image) # 鼠标左键拖动:平移
# 绑定鼠标滚轮事件用于缩放 (Windows/Linux 是 <MouseWheel>,macOS 是 <Button-4>/<Button-5>)
# Tkinter 在 Windows/Linux 上统一将滚轮识别为 <MouseWheel>
self.image_label.bind("<MouseWheel>", self._zoom_image)
# 针对部分 Linux 系统或特定配置,也绑定 Button-4/Button-5
self.image_label.bind("<Button-4>", lambda event: self._zoom_image(event, delta=1)) # 向上
self.image_label.bind("<Button-5>", lambda event: self._zoom_image(event, delta=-1)) # 向下
# 将 Label 的 pack 属性改为不 expand,因为它现在由父 Frame 固定大小
# 同时增加 padx/pady 确保其在 LabelFrame 内部居中或填充
self.image_label.pack(fill=tk.BOTH, expand=False, pady=5, padx=5)
# 训练历史图表 - 占用剩余空间
chart_frame = ttk.LabelFrame(top_right_frame, text="训练历史图表", padding="5")
chart_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=5)
# 设置图表大小 (宽高比约为 1.2:1)
self.fig, self.ax = plt.subplots(2, 1, figsize=(6, 4), dpi=100)
self.canvas = FigureCanvasTkAgg(self.fig, chart_frame)
self.canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True, pady=5, padx=5)
# 初始化图表
self.update_training_chart({})
self.canvas.draw()
# 右侧下部:识别结果 (按模型划分)
result_frame = ttk.LabelFrame(right_frame, text="识别结果", padding="10")
result_frame.pack(fill=tk.BOTH, expand=True, pady=5)
# 使用三个独立的Frame来显示每个模型的结果
model_result_frame = ttk.Frame(result_frame)
model_result_frame.pack(fill=tk.BOTH, expand=True)
# 为每个支持的模型创建一个子区域
for model_name in self.supported_models.keys():
# 创建一个子框架
frame = ttk.LabelFrame(model_result_frame, text=model_name, padding="5")
frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=5, pady=5)
# 创建文本框显示结果
text_widget = tk.Text(frame, wrap=tk.WORD, height=10, width=1) # width=1 允许它在pack时扩展
scrollbar = ttk.Scrollbar(frame, orient=tk.VERTICAL, command=text_widget.yview)
text_widget.configure(yscrollcommand=scrollbar.set)
text_widget.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
self.result_texts[model_name] = text_widget
self.refresh_model_list()
def browse_folder(self):
"""选择数据文件夹"""
folder = filedialog.askdirectory(title="选择包含蜘蛛类别文件夹的目录")
if folder:
self.data_path.set(folder)
self.scan_classes(folder)
def scan_classes(self, folder_path):
"""扫描类别"""
classes = []
for item in Path(folder_path).iterdir():
if item.is_dir():
classes.append(item.name)
if classes:
self.status_var.set(f"发现 {len(classes)} 个类别: {', '.join(classes)}")
self.class_names = classes
else:
self.status_var.set("未找到类别文件夹")
def validate_data_structure(self):
"""验证数据目录结构是否正确"""
data_path = Path(self.data_path.get())
# 检查是否有不应该的文件夹名
invalid_folders = ['train_val', 'train', 'val', 'test']
found_classes = []
for item in data_path.iterdir():
if item.is_dir():
if item.name in invalid_folders:
messagebox.showerror(
"数据目录错误",
f"发现无效文件夹: {item.name}\n\n"
f"请确保{data_path}目录下直接是蜘蛛类别文件夹\n"
f"例如: 黑寡妇/, 跳蛛/, 狼蛛/\n"
f"而不是: train_val/, train/, val/"
)
return False
found_classes.append(item.name)
if len(found_classes) < 2:
messagebox.showerror(
"数据不足",
f"只找到 {len(found_classes)} 个类别: {found_classes}\n"
f"至少需要2个类别才能训练模型"
)
return False
# 调试语句已修改为使用 self.log_message
self.log_message(f"数据目录正确!找到类别: {found_classes}", "DEBUG")
return True
def split_dataset(self):
"""
自动划分数据集:将选定文件夹下的图片按 8:2 比例划分为 train 和 val 文件夹。
Returns:
bool: 划分成功返回 True,失败返回 False。
"""
data_path = self.data_path.get()
if not data_path:
self.root.after(0, lambda: messagebox.showerror("错误", "请先选择数据文件夹!"))
return False
base_path = Path(data_path)
# 目标路径
target_dir = base_path / 'train_val'
# 划分比例
TRAIN_RATIO = 0.8
try:
self.status_var.set(">>>开始划分数据集...")
self.log_message("开始划分数据集...", "INFO")
# 清理旧的划分结果
if target_dir.exists():
shutil.rmtree(target_dir)
# 创建新的 train 和 val 目录
train_dir = target_dir / 'train'
val_dir = target_dir / 'val'
train_dir.mkdir(parents=True, exist_ok=True)
val_dir.mkdir(parents=True, exist_ok=True)
# 获取所有子文件夹(类别)
all_classes = [d for d in base_path.iterdir() if d.is_dir() and d.name != 'train_val']
# 调试语句
# self.log_message(f"发现类别文件夹: {[c.name for c in all_classes]}", "DEBUG")
if not all_classes:
self.root.after(0,
lambda: messagebox.showerror("错误", f"在 {base_path} 中未找到任何类别子文件夹!"))
self.status_var.set(">>>划分失败")
return False
total_files_moved = 0
for class_dir in all_classes:
class_name = class_dir.name
# 确保目标类文件夹存在
(train_dir / class_name).mkdir(exist_ok=True)
(val_dir / class_name).mkdir(exist_ok=True)
# 获取所有图片文件
all_files = list(class_dir.glob('*.jpg')) + list(class_dir.glob('*.png'))
random.shuffle(all_files) # 随机打乱
if not all_files:
# print() 替换为 self.log_message()
self.log_message(f"类别 '{class_name}' 下未找到图片,跳过。", "WARNING")
continue
# 计算划分数量
train_count = int(len(all_files) * TRAIN_RATIO)
train_files = all_files[:train_count]
val_files = all_files[train_count:]
# 移动文件
for f in train_files:
shutil.copy(f, train_dir / class_name / f.name)
for f in val_files:
shutil.copy(f, val_dir / class_name / f.name)
total_files_moved += len(all_files)
if total_files_moved == 0:
self.root.after(0, lambda: messagebox.showerror("错误",
"未移动任何文件,请确保所选文件夹及其子文件夹中有JPG/PNG图片!"))
self.status_var.set(">>>划分失败")
return False
self.root.after(0, lambda: self.status_var.set(f">>>数据集划分完成!共处理 {total_files_moved} 个文件。"))
self.log_message(f"数据集划分完成!共处理 {total_files_moved} 个文件。", "SUCCESS")
return True # 划分成功
except Exception as e:
self.root.after(0, lambda: messagebox.showerror("划分失败", f"数据集划分过程中发生错误: {str(e)}"))
self.status_var.set(">>>划分失败")
self.log_message(f"数据集划分过程中发生错误: {str(e)}", "ERROR")
return False # 划分失败
def log_message(self, message: str, level: str = "INFO"):
"""
统一的日志记录方法,将日志写入 Text 组件并自动滚动到底部。
Args:
message (str): 要记录的日志信息。
level (str): 日志级别 (如 INFO, ERROR, WARNING, SUCCESS, DEBUG)。
"""
# 兼容性:检查日志组件是否已创建 (setup_gui 之前会是 None)
if self.log_text_widget is None:
print(f"[{datetime.now().strftime('%H:%M:%S')}] [{level}] {message}")
return
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
log_entry = f"[{timestamp}] [{level}] {message}\n"
# 1. 解锁并插入日志
self.log_text_widget.config(state='normal')
self.log_text_widget.insert(tk.END, log_entry)
# 2. 自动滚动到最新日志(底部)
self.log_text_widget.see(tk.END)
# 3. 重新锁定 Text 组件,防止用户编辑
self.log_text_widget.config(state='disabled')
# 4. 同时更新兼容性的状态栏
self.status_var.set(f"[{level}] {message}")
def create_model(self, model_name, num_classes):
"""创建指定模型"""
if model_name == 'ResNet-50':
model = models.resnet50(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, num_classes)
elif model_name == 'EfficientNet-B0':
model = models.efficientnet_b0(pretrained=True)
model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
elif model_name == 'DenseNet-121':
model = models.densenet121(pretrained=True)
model.classifier = nn.Linear(model.classifier.in_features, num_classes)
else:
raise ValueError(f"不支持的模型: {model_name}")
return model.to(self.device)
def update_training_chart(self, new_data):
"""
更新训练过程图表 (损失和准确率)。
Args:
new_data (dict): 包含模型名和训练历史数据的字典。
例如: {'model_name_timestamp': {'train_losses': [...], 'val_accuracies': [...]}}
"""
# 确保图表对象已初始化
if self.ax is None:
self.log_message("图表对象未初始化,跳过更新。", "WARNING")
return
# 清除旧图
self.ax[0].clear() # 损失图
self.ax[1].clear() # 准确率图
# ------------------- 绘制损失图 -------------------
self.ax[0].set_title('训练损失 (Loss)', fontsize=10)
self.ax[0].set_xlabel('Epoch', fontsize=8)
self.ax[0].set_ylabel('Loss', fontsize=8)
self.ax[0].tick_params(axis='both', which='major', labelsize=7)
self.ax[0].grid(True, linestyle='--', alpha=0.6)
# ------------------- 绘制准确率图 -------------------
self.ax[1].set_title('验证准确率 (Accuracy)', fontsize=10)
self.ax[1].set_xlabel('Epoch', fontsize=8)
self.ax[1].set_ylabel('Accuracy (%)', fontsize=8)
self.ax[1].tick_params(axis='both', which='major', labelsize=7)
self.ax[1].grid(True, linestyle='--', alpha=0.6)
# 遍历所有模型数据并绘制
for model_label, data in new_data.items():
epochs = range(1, len(data['train_losses']) + 1)
# **修改点 1:为损失曲线添加 label**
self.ax[0].plot(epochs, data['train_losses'], label=f'Loss ({model_label})',
linestyle='-', marker='.', markersize=3)
# **修改点 2:为准确率曲线添加 label**
self.ax[1].plot(epochs, data['val_accuracies'], label=f'Acc ({model_label})',
linestyle='-', marker='.', markersize=3)
# **修改点 3:移动并调用 legend()**
# 只有在绘制完带 label 的曲线后,调用 legend() 才能显示内容。
# 损失图例
# 确保只在有曲线时调用 legend,如果图表清空后立即调用 legend 可能会有空警告。
# 这里使用 has_data 隐式检查,因为 new_data 是非空的
if new_data:
self.ax[0].legend(fontsize=8, loc='upper right') # 第 599 行
self.ax[1].legend(fontsize=8, loc='lower right') # 第 605 行
# 优化 x 轴刻度,避免刻度过多
for ax in self.ax:
if ax.get_lines():
max_epoch = max(len(ax.get_lines()[0].get_xdata()), 1)
if max_epoch > 10:
ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True, nbins=10))
else:
ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
# 刷新画布
self.canvas.draw_idle()
def train_single_model(self, model_name, data_path=None):
"""
训练单个模型。
该函数执行完整的模型训练流程,包括数据加载、模型初始化、
训练循环、验证、模型保存,并实时更新 GUI 上的训练损失、
准确率图表和进度条。
Args:
model_name (str): 要训练的模型名称。
data_path (str, optional): 训练数据的路径。默认为 None,将从 self.data_path 获取。
Returns:
tuple: (best_acc, train_losses, val_accuracies, timestamp)
"""
from torch.utils.data import DataLoader
from torchvision import datasets
# 生成时间戳
timestamp = time.strftime("%Y%m%d_%H%M%S")
# 路径设置
if data_path is None:
data_path = Path(self.data_path.get()) / "train_val"
else:
data_path = Path(data_path)
# 数据预处理
train_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(0.3),
transforms.RandomRotation(10),
transforms.ColorJitter(brightness=0.2, contrast=0.2),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
val_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# 加载数据
train_path = data_path / 'train'
val_path = data_path / 'val'
try:
train_dataset = datasets.ImageFolder(root=train_path, transform=train_transform)
val_dataset = datasets.ImageFolder(root=val_path, transform=val_transform)
except Exception as e:
raise RuntimeError(f"数据加载失败,请检查路径和图片文件: {str(e)}")
# **修复点 1: 在这里获取准确的类别名称和数量**
# 确保模型创建和数据加载使用的类别数量一致 (解决 Target out of bounds 错误)
# 必须先检查数据集是否为空
if len(train_dataset) == 0:
raise RuntimeError(f"训练集为空!请检查 {train_path} 路径下是否有图片。")
if len(val_dataset) == 0:
raise RuntimeError(f"验证集为空!请检查 {val_path} 路径下是否有图片。")
# 获取准确的类别信息
class_names_local = train_dataset.classes
num_classes_local = len(class_names_local)
# **重要:将准确的类别名称更新到实例变量,供预测时使用**
self.class_names = class_names_local
# 调试语句:检查加载数据量
# print() 替换为 self.log_message()
self.log_message(
f"训练集大小: {len(train_dataset)}, 验证集大小: {len(val_dataset)}, 类别数: {num_classes_local}", "DEBUG")
batch_size = int(self.batch_size_var.get())
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
# 创建模型
# **修复点 2: 使用本地准确的 num_classes_local 创建模型**
model = self.create_model(model_name, num_classes_local)
# 训练设置
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
epochs = int(self.epochs_var.get())
best_acc = 0
train_losses = []
val_accuracies = []
for epoch in range(epochs):
# 训练阶段
model.train()
epoch_train_loss = 0
if len(train_loader) == 0:
# print() 替换为 self.log_message()
self.log_message("train_loader 为空,跳过训练阶段。", "WARNING")
break
for batch_idx, (inputs, targets) in enumerate(train_loader):
inputs, targets = inputs.to(self.device), targets.to(self.device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
epoch_train_loss += loss.item()
# --- 进度条更新代码 (保持不变) ---
current_batch_progress = (batch_idx + 1) / len(train_loader)
overall_progress_percent = ((epoch) + current_batch_progress) * 100 / epochs
self.root.after(0, lambda p=overall_progress_percent: self.progress_var.set(p))
# --- 进度条更新代码结束 ---
# 计算平均训练损失
if len(train_loader) > 0:
avg_train_loss = epoch_train_loss / len(train_loader)
else:
avg_train_loss = 0.0
train_losses.append(avg_train_loss)
# 验证阶段
model.eval()
correct = 0
total = 0
# 必须检查验证集是否为空,防止除以零
if len(val_loader) == 0:
acc = 0.0
else:
with torch.no_grad():
for inputs, targets in val_loader:
inputs, targets = inputs.to(self.device), targets.to(self.device)
outputs = model(inputs)
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
acc = 100. * correct / total
val_accuracies.append(acc)
# 实时更新图表
current_data = {f"{model_name}_{timestamp}": {
'train_losses': train_losses,
'val_accuracies': val_accuracies
}}
self.root.after(0, lambda: self.update_training_chart(current_data))
self.root.after(0, lambda: self.status_var.set(
f"{model_name} - Epoch {epoch + 1}/{epochs} - 损失: {avg_train_loss:.4f} - 准确率: {acc:.1f}%"
))
if acc > best_acc:
best_acc = acc
# 保存模型 - 使用时间戳命名
model_filename = f'{model_name}_model_{timestamp}.pth'
model_info = {
'model_state_dict': model.state_dict(),
'class_names': class_names_local, # **修复点 3: 保存准确的类别名称**
'accuracy': acc,
'epoch': epoch,
'train_losses': train_losses.copy(),
'val_accuracies': val_accuracies.copy(),
'timestamp': timestamp,
'model_name': model_name,
'total_epochs': epochs
}
torch.save(model_info, model_filename)
self.log_message(f"模型 {model_filename} 保存成功,准确率: {acc:.1f}%", "INFO")
# 更新最佳模型记录
self.best_model_files[model_name] = model_filename
scheduler.step()
# 训练完成后,确保进度条显示 100%
self.root.after(0, lambda: self.progress_var.set(100))
return best_acc, train_losses, val_accuracies, timestamp
# 完整的子函数 (精简回核心功能)
def start_training(self):
"""
启动选中模型的训练过程。将训练过程放入单独线程,避免阻塞 GUI。
"""
data_path = self.data_path.get()
if not data_path:
messagebox.showerror("错误", "请先选择数据文件夹!")
return
# 检查是否有模型被选中
selected_models = [name for name, var in self.model_vars.items() if var.get()]
if not selected_models:
messagebox.showerror("错误", "请至少选择一个要训练的模型!")
return
self.status_var.set(">>>正在准备训练环境...")
import threading
# 训练逻辑函数
def training_thread_target():
# 1. 确保数据集划分是最新且有效的
# **修改点:直接检查 split_dataset 的返回值**
if not self.split_dataset():
# split_dataset 失败时会自行弹窗提示
self.root.after(0, lambda: self.status_var.set(">>>训练因数据集划分失败而停止。"))
return
try:
# 检查划分结果的文件夹路径
train_val_path = Path(data_path) / "train_val"
# 再次检查 train 文件夹是否确实存在,这能捕捉到 split_dataset 内部的失败
if not (train_val_path / 'train').exists():
self.root.after(0, lambda: messagebox.showerror("划分错误", "数据集划分失败,未找到训练目录。"))
return
# 清空训练历史记录
self.all_training_data = {}
self.root.after(0, lambda: self.update_training_chart({}))
self.root.after(0, lambda: self.progress_var.set(0))
for model_name in selected_models:
self.root.after(0, lambda m=model_name: self.status_var.set(f">>>开始训练模型: {m}..."))
# 调用实际训练函数
best_acc, _, _, _ = self.train_single_model(model_name, data_path=train_val_path)
self.root.after(0, lambda m=model_name, acc=best_acc: self.status_var.set(
f">>>模型 {m} 训练完成,最佳准确率: {acc:.1f}%"
))
except RuntimeError as e:
self.root.after(0, lambda err=str(e): messagebox.showerror("训练失败", f"训练失败: {err}"))
self.root.after(0, lambda: self.status_var.set(">>>训练因错误停止。"))
except Exception as e:
self.root.after(0, lambda err=str(e): messagebox.showerror("未知错误", f"程序遇到未知错误: {err}"))
self.root.after(0, lambda: self.status_var.set(">>>训练因未知错误停止。"))
finally:
# 所有模型训练完毕或发生错误后
self.root.after(0, lambda: self.refresh_model_list())
self.root.after(0, lambda: self.status_var.set(">>>训练任务结束。"))
# 创建并启动线程
training_thread = threading.Thread(target=training_thread_target)
training_thread.start()
def refresh_model_list(self):
"""刷新模型列表 - 增强版本,并保持已加载模型的颜色。
该函数会扫描当前目录下所有模型文件,并将其信息展示在Listbox中。
如果模型已被加载(在self.loaded_model_display_texts中),则高亮显示。
"""
self.loaded_models_list.delete(0, tk.END)
self.available_models.clear()
# 查找所有模型文件
model_files = glob.glob('*_model_*.pth')
self.log_message(f"发现 {len(model_files)} 个模型文件。", "DEBUG")
for model_file in model_files:
try:
# 从文件名解析信息
if 'ResNet-50_model_' in model_file:
model_name = 'ResNet-50'
elif 'EfficientNet-B0_model_' in model_file:
model_name = 'EfficientNet-B0'
elif 'DenseNet-121_model_' in model_file:
model_name = 'DenseNet-121'
else:
continue
# 加载模型信息获取准确率
checkpoint = torch.load(model_file, map_location='cpu')
accuracy = checkpoint.get('accuracy', 0)
timestamp = checkpoint.get('timestamp', '未知时间')
display_text = f"{model_name} | 准确率: {accuracy:.1f}% | 时间: {timestamp}"
# 插入 Listbox
self.loaded_models_list.insert(tk.END, display_text)
self.available_models[display_text] = model_file
# 检查是否是已加载的模型,并应用颜色
if display_text in self.loaded_model_display_texts:
# 获取刚刚插入的项的索引
index = self.loaded_models_list.size() - 1
# 直接设置背景和前景颜色
self.loaded_models_list.itemconfigure(index, {'background': '#3498db', 'foreground': 'white'})
self.log_message(f"加载模型信息成功: {display_text}", "DEBUG")
except Exception as e:
# print() 替换为 self.log_message()
self.log_message(f"加载模型文件 {model_file} 失败: {e}", "ERROR")
def clear_models(self):
"""清除所有加载的模型。
该函数会清空加载的模型字典,重置状态变量,并清除Listbox中的所有高亮显示。
"""
# 清除 Listbox 中的高亮
for i in range(self.loaded_models_list.size()):
# 恢复默认背景和前景
self.loaded_models_list.itemconfigure(i, {'background': '', 'foreground': ''})
self.models = {}
self.current_model = None
self.loaded_model_display_texts.clear() # 清除记录
self.status_var.set("已清除所有加载的模型")
def delete_selected_models(self):
"""
删除Listbox中选中的模型文件,并从磁盘上永久移除。
删除前会弹出确认窗口。
"""
# 获取选中的索引
selection = self.loaded_models_list.curselection()
if not selection:
messagebox.showwarning("警告", "请先在列表中选择一个要删除的模型!")
return
# 选中的是 Listbox 中的索引,通常只有一个,我们只处理第一个
index = selection[0]
display_text = self.loaded_models_list.get(index)
model_file_name = self.available_models.get(display_text)
model_name = self._get_model_name_from_display_text(display_text)
if not model_file_name:
self.log_message(f"模型文件路径丢失: {display_text}", "ERROR")
messagebox.showerror("错误", "无法找到对应的模型文件路径。")
return
# 弹出确认对话框
confirm = messagebox.askyesno(
"确认删除",
f"您确定要永久删除模型文件吗?\n\n模型: {model_name}\n文件: {model_file_name}\n\n此操作不可撤销!"
)
if confirm:
try:
# 1. 从磁盘删除文件
if os.path.exists(model_file_name):
os.remove(model_file_name)
self.log_message(f"已永久删除模型文件: {model_file_name}", "SUCCESS")
else:
self.log_message(f"文件不存在: {model_file_name},但已从列表中移除。", "WARNING")
# 2. 从内部记录和 Listbox 中移除
# 尝试卸载(如果已加载)
if display_text in self.loaded_model_display_texts:
self.loaded_model_display_texts.remove(display_text)
self.models.pop(model_name, None)
# 检查并清除 current_model 引用
if self.current_model == model_name:
self.current_model = None
# 从 Listbox 中删除项
self.loaded_models_list.delete(index)
# 从 available_models 中删除记录
self.available_models.pop(display_text, None)
self.status_var.set(f"已成功删除模型: {model_name}")
self.log_message(f"模型 {model_name} 已从系统和磁盘中移除。", "INFO")
except Exception as e:
messagebox.showerror("删除失败", f"删除文件 {model_file_name} 时发生错误: {str(e)}")
self.log_message(f"删除模型失败: {str(e)}", "ERROR")
def select_image(self):
"""选择图片,并重置视图状态"""
self.current_image_path = filedialog.askopenfilename(
title="选择蜘蛛图片",
filetypes=[("图片文件", "*.jpg *.jpeg *.png *.bmp")]
)
if self.current_image_path:
try:
# 关键:将原始 PIL 图像保存在实例变量中
self.original_pil_image = Image.open(self.current_image_path).convert('RGB')
# 重置视图状态
self.zoom_level = 1.0
self.pan_x = 0
self.pan_y = 0
# 调用统一的图像更新函数
self._update_displayed_image(self.original_pil_image)
self.status_var.set(f"已选择图片: {os.path.basename(self.current_image_path)}")
except Exception as e:
# 错误修正:将 selfbox.showerror 替换为 messagebox.showerror
messagebox.showerror("错误", f"加载图片失败: {str(e)}")
def _start_pan(self, event):
"""
处理鼠标左键按下事件,记录平移起始点。
"""
self._last_x = event.x
self._last_y = event.y