1+ #绘制直方图模块
2+ import numpy as np
3+ import matplotlib .pyplot as plt
4+ import matplotlib .colors as mcolors
5+ class Plotter :
6+ def __init__ (self ):
7+ pass
8+ def plot_histogram (self , data_list , num_bins = 10 , color = 'blue' , title = 'Histogram' ,xlabel = 'Values' ,ylabel = 'Frequency' ,net = 'None' ,
9+ text = 'None' ,need_log_scale = False ,range = [- 1 ,- 1 ],image_info = 'None' ,data_points = '0' ):
10+ # 计算范围
11+ data_min , data_max = np .min (data_list ), np .max (data_list )
12+ if range == [- 1 ,- 1 ]:
13+ data_range = (data_min , data_max )
14+ else :
15+ data_range = range
16+ hist , bin_edges = np .histogram (data_list ,bins = num_bins ,range = data_range )
17+ max_bin_index = np .argmax (hist )
18+ max_bin_range = (bin_edges [max_bin_index ], bin_edges [max_bin_index + 1 ])
19+ max_bin_freq = hist [max_bin_index ] # 该 bin 的频率
20+ print (f"频率最高的 bin 的范围是: { max_bin_range } ,频率为: { max_bin_freq } " )
21+ plt .hist (data_list ,bins = num_bins ,range = data_range ,histtype = 'stepfilled' ,align = 'mid' ,orientation = 'vertical' ,color = color )
22+ if need_log_scale :
23+ plt .yscale ('log' )
24+ plt .title (title ,fontsize = 17 )
25+ plt .xlabel (xlabel ,fontsize = 12 )
26+ plt .tick_params (labelsize = 12 )
27+ plt .ylabel (ylabel ,fontsize = 12 )
28+ #在图中添加文本
29+ plt .text (0.5 , - 0.2 , f"The most frequent data bin: { max_bin_range } \n The std of the data is { np .std (data_list )} " ,
30+ ha = 'center' , va = 'center' , transform = plt .gca ().transAxes ,fontsize = 12 )
31+ if text != 'None' :
32+ plt .text (0.5 , - 0.3 , text , ha = 'center' , va = 'center' , transform = plt .gca ().transAxes ,fontsize = 12 )
33+ if image_info != 'None' :
34+ plt .text (0.5 , - 0.4 , 'Generated from ' + image_info + '. The num of data points is ' + data_points , ha = 'center' , va = 'center' ,
35+ transform = plt .gca ().transAxes ,fontsize = 12 )
36+ if net != 'None' :
37+ plt .text (0.5 , - 0.5 , 'The network is ' + net , ha = 'center' , va = 'center' ,
38+ transform = plt .gca ().transAxes ,fontsize = 12 )
39+
40+ def plot_scatter (self ,list1 ,list2 ,title = 'scatter' ,color = 'r' ,s = 1 ,xlabel = 'xlabel' ,ylabel = 'ylabel' , text = 'None' ,net = 'None' ,
41+ image_info = 'None' ,data_points = '0' ):
42+ fig , ax1 = plt .subplots (figsize = (10 ,10 ))
43+ # 绘制对角线
44+ ax1 .plot ([min (list1 ), max (list1 )], [min (list1 ), max (list1 )], 'k--' , label = 'X=Y' )
45+ # 绘制散点图
46+ ax1 .scatter (list1 , list2 , s = s , color = color )
47+ ax1 .set_xlabel (xlabel )
48+ ax1 .set_ylabel (ylabel )
49+ plt .tick_params (labelsize = 12 )
50+ ax1 .set_title (title )
51+ # 添加文本
52+ plt .text (0.5 , - 0.1 , text , ha = 'center' , va = 'center' , transform = plt .gca ().transAxes )
53+ # 拟合散点图
54+ coefficients = np .polyfit (list1 , list2 , 1 )
55+ polynomial = np .poly1d (coefficients )
56+ x_fit = np .linspace (min (list1 ), max (list1 ), 100 )
57+ y_fit = polynomial (x_fit )
58+ if text != 'None' :
59+ plt .text (0.5 , - 0.3 , text , ha = 'center' , va = 'center' , transform = plt .gca ().transAxes ,fontsize = 12 )
60+ if image_info != 'None' :
61+ plt .text (0.5 , - 0.2 , 'Generated from ' + image_info + '. The num of data points is ' + data_points , ha = 'center' , va = 'center' ,
62+ transform = plt .gca ().transAxes ,fontsize = 12 )
63+ if net != 'None' :
64+ plt .text (0.5 , - 0.4 , 'The network is ' + net , ha = 'center' , va = 'center' , transform = plt .gca ().transAxes ,fontsize = 12 )
65+ # 绘制最小二乘拟合线
66+ ax1 .plot (x_fit , y_fit , 'r-' , label = 'Least Squares Fit' , linewidth = 2 , alpha = 0.7 ,color = 'b' )
67+
68+ def plot_error_pixel_map (self , true_list , predict_list , HIGTHT = 56 , WIDTH = 56 ,text = 'None' ,net = 'None' ,
69+ image_info = 'None' ,data_points = '0' ):
70+ # 计算误差矩阵
71+ Error_pixel_map = np .array (true_list ) - np .array (predict_list )
72+ Error_pixel_map = Error_pixel_map .astype (np .float32 ).reshape (HIGTHT , WIDTH )
73+ # 创建颜色映射
74+ plt .tick_params (labelsize = 12 )
75+ cmap = mcolors .LinearSegmentedColormap .from_list ('custom_cmap' , ['red' , 'white' , 'black' ])
76+ norm = mcolors .TwoSlopeNorm (vmin = Error_pixel_map .min (), vcenter = 0 , vmax = Error_pixel_map .max ())
77+
78+ # 绘制误差像素图
79+ plt .figure (figsize = (10 , 5 ))
80+ plt .imshow (Error_pixel_map , cmap = cmap , norm = norm )
81+ plt .colorbar (label = 'Error Value' )
82+ plt .title ('Error Pixel Map' )
83+ if text != 'None' :
84+ plt .text (0.5 , - 0.1 , text , ha = 'center' , va = 'center' , transform = plt .gca ().transAxes ,fontsize = 12 )
85+ if image_info != 'None' :
86+ plt .text (0.5 , - 0.2 , 'Generated from ' + image_info + '. The num of data points is ' + data_points , ha = 'center' , va = 'center' ,
87+ transform = plt .gca ().transAxes ,fontsize = 12 )
88+ if net != 'None' :
89+ plt .text (0.5 , - 0.3 , 'The network is ' + net , ha = 'center' , va = 'center' ,
90+ transform = plt .gca ().transAxes ,fontsize = 12 )
91+ plt .tight_layout ()
92+ plt .show ()
0 commit comments