Skip to content
This repository was archived by the owner on Feb 25, 2026. It is now read-only.

Commit 712de97

Browse files
Add files via upload
1 parent 6229ff8 commit 712de97

4 files changed

Lines changed: 476 additions & 0 deletions

File tree

Key Function/DataSet.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# 导入相关库
2+
import os # 与系统文件交互
3+
import tifffile as tiff #读取tiff文件格式
4+
from PIL import Image #图片处理
5+
#与torch 相关的库
6+
import torch
7+
from torchvision import transforms
8+
from torch.utils.data import Dataset, DataLoader
9+
10+
#from sklearn.preprocessing import MinMaxScaler
11+
import numpy as np
12+
13+
#创建数据集
14+
class myDataSet(Dataset):
15+
def __init__(self,img_dir,group_size=10000,transform=None):
16+
self.img_dir=img_dir
17+
self.images=os.listdir(img_dir)
18+
self.transform=transform
19+
self.all_imgs=[]
20+
self.emcal=[]
21+
self.hcal=[]
22+
self.trkn=[]
23+
self.trkp=[]
24+
self.truth=[]
25+
self.group_size=group_size
26+
self.load_images()
27+
28+
def load_images(self):
29+
all_imgs=[]
30+
for filename in self.images:
31+
if filename.endswith(".tiff"):
32+
img_path=os.path.join(self.img_dir, filename)
33+
img_array=tiff.imread(img_path)
34+
img=Image.fromarray(img_array)
35+
img_tensor=transform(img)
36+
all_imgs.append(img_tensor)
37+
self.emcal=all_imgs[:self.group_size]
38+
self.hcal=all_imgs[self.group_size:2*self.group_size]
39+
self.trkn=all_imgs[2*self.group_size:3*self.group_size]
40+
self.trkp=all_imgs[3*self.group_size:4*self.group_size]
41+
self.truth=all_imgs[4*self.group_size:5*self.group_size]
42+
43+
self.X=[]
44+
self.Y=[]
45+
46+
for emcal, hcal, trkn, trkp in zip(self.emcal,self.hcal,self.trkn, self.trkp):
47+
combined_features=torch.stack((emcal,hcal,trkn,trkp))
48+
self.X.append(combined_features)
49+
self.X=torch.stack(self.X).squeeze()
50+
self.Y=torch.stack(self.truth).squeeze()
51+
52+
def __len__(self):
53+
return len(self.X)
54+
def __getitem__(self,idx):
55+
return self.X[idx], self.Y[idx]
56+
transform=transforms.Compose([
57+
transforms.ToTensor(),
58+
# 数据预处理后期添加
59+
])
60+
61+
62+
63+

Key Function/Plotter.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
#绘制直方图模块
2+
import numpy as np
3+
import matplotlib.pyplot as plt
4+
import matplotlib.colors as mcolors
5+
class Plotter:
6+
def __init__(self):
7+
pass
8+
def plot_histogram(self, data_list, num_bins=10, color='blue', title='Histogram',xlabel='Values',ylabel='Frequency',net = 'None',
9+
text = 'None',need_log_scale=False,range=[-1,-1],image_info = 'None',data_points='0'):
10+
# 计算范围
11+
data_min, data_max = np.min(data_list), np.max(data_list)
12+
if range == [-1,-1]:
13+
data_range = (data_min, data_max)
14+
else :
15+
data_range = range
16+
hist, bin_edges=np.histogram(data_list,bins=num_bins,range=data_range)
17+
max_bin_index = np.argmax(hist)
18+
max_bin_range = (bin_edges[max_bin_index], bin_edges[max_bin_index + 1])
19+
max_bin_freq = hist[max_bin_index] # 该 bin 的频率
20+
print(f"频率最高的 bin 的范围是: {max_bin_range},频率为: {max_bin_freq}")
21+
plt.hist(data_list,bins=num_bins,range=data_range,histtype='stepfilled',align='mid',orientation='vertical',color=color)
22+
if need_log_scale:
23+
plt.yscale('log')
24+
plt.title(title,fontsize=17)
25+
plt.xlabel(xlabel,fontsize=12)
26+
plt.tick_params(labelsize=12)
27+
plt.ylabel(ylabel,fontsize=12)
28+
#在图中添加文本
29+
plt.text(0.5, -0.2, f"The most frequent data bin: {max_bin_range}\nThe std of the data is {np.std(data_list)}",
30+
ha='center', va='center', transform=plt.gca().transAxes,fontsize=12)
31+
if text != 'None':
32+
plt.text(0.5, -0.3, text, ha='center', va='center', transform=plt.gca().transAxes,fontsize=12)
33+
if image_info != 'None':
34+
plt.text(0.5, -0.4, 'Generated from '+image_info+ '. The num of data points is '+ data_points, ha='center', va='center',
35+
transform=plt.gca().transAxes,fontsize=12)
36+
if net != 'None':
37+
plt.text(0.5, -0.5, 'The network is '+net, ha='center', va='center',
38+
transform=plt.gca().transAxes,fontsize=12)
39+
40+
def plot_scatter(self,list1,list2,title='scatter',color='r',s=1,xlabel='xlabel',ylabel='ylabel', text = 'None',net = 'None',
41+
image_info = 'None',data_points='0'):
42+
fig, ax1=plt.subplots(figsize=(10,10))
43+
# 绘制对角线
44+
ax1.plot([min(list1), max(list1)], [min(list1), max(list1)], 'k--', label='X=Y')
45+
# 绘制散点图
46+
ax1.scatter(list1, list2, s=s, color=color)
47+
ax1.set_xlabel(xlabel)
48+
ax1.set_ylabel(ylabel)
49+
plt.tick_params(labelsize=12)
50+
ax1.set_title(title)
51+
# 添加文本
52+
plt.text(0.5, -0.1, text, ha='center', va='center', transform=plt.gca().transAxes)
53+
# 拟合散点图
54+
coefficients = np.polyfit(list1, list2, 1)
55+
polynomial = np.poly1d(coefficients)
56+
x_fit = np.linspace(min(list1), max(list1), 100)
57+
y_fit = polynomial(x_fit)
58+
if text != 'None':
59+
plt.text(0.5, -0.3, text, ha='center', va='center', transform=plt.gca().transAxes,fontsize=12)
60+
if image_info != 'None':
61+
plt.text(0.5, -0.2, 'Generated from '+image_info+ '. The num of data points is '+ data_points, ha='center', va='center',
62+
transform=plt.gca().transAxes,fontsize=12)
63+
if net != 'None':
64+
plt.text(0.5, -0.4, 'The network is '+net, ha='center', va='center', transform=plt.gca().transAxes,fontsize=12)
65+
# 绘制最小二乘拟合线
66+
ax1.plot(x_fit, y_fit, 'r-', label='Least Squares Fit', linewidth=2, alpha=0.7,color='b')
67+
68+
def plot_error_pixel_map(self, true_list, predict_list, HIGTHT=56, WIDTH=56,text = 'None',net='None',
69+
image_info = 'None',data_points='0'):
70+
# 计算误差矩阵
71+
Error_pixel_map = np.array(true_list) - np.array(predict_list)
72+
Error_pixel_map = Error_pixel_map.astype(np.float32).reshape(HIGTHT, WIDTH)
73+
# 创建颜色映射
74+
plt.tick_params(labelsize=12)
75+
cmap = mcolors.LinearSegmentedColormap.from_list('custom_cmap', ['red', 'white', 'black'])
76+
norm = mcolors.TwoSlopeNorm(vmin=Error_pixel_map.min(), vcenter=0, vmax=Error_pixel_map.max())
77+
78+
# 绘制误差像素图
79+
plt.figure(figsize=(10, 5))
80+
plt.imshow(Error_pixel_map, cmap=cmap, norm=norm)
81+
plt.colorbar(label='Error Value')
82+
plt.title('Error Pixel Map')
83+
if text != 'None':
84+
plt.text(0.5, -0.1, text, ha='center', va='center', transform=plt.gca().transAxes,fontsize=12)
85+
if image_info != 'None':
86+
plt.text(0.5, -0.2, 'Generated from '+image_info+ '. The num of data points is '+ data_points, ha='center', va='center',
87+
transform=plt.gca().transAxes,fontsize=12)
88+
if net != 'None':
89+
plt.text(0.5, -0.3, 'The network is '+net, ha='center', va='center',
90+
transform=plt.gca().transAxes,fontsize=12)
91+
plt.tight_layout()
92+
plt.show()

0 commit comments

Comments
 (0)