1010from matplotlib .patches import Polygon
1111from scipy import stats
1212
13+ # 设置警告过滤器,显示所有警告
14+ warnings .simplefilter ("always" )
15+
1316# 类型别名
1417Num = int | float # 可同时接受int和float的类型
1518NumArray = list [Num ] | npt .NDArray [np .float64 ] # 数字数组类型
@@ -106,6 +109,7 @@ def set_yaxis(
106109 ax .yaxis .set_major_formatter (FuncFormatter (lambda x , pos : f"{ x :.0%} " ))
107110
108111
112+ # 统计相关
109113def perform_stat_test (
110114 data1 : NumArray | None = None ,
111115 data2 : NumArray | None = None ,
@@ -130,6 +134,29 @@ def perform_stat_test(
130134 return stat , p
131135
132136
137+ def determine_test_modle (data , method , p_list = None , popmean = 0 ):
138+ comparisons = []
139+ idx = 0
140+ if method != "ttest_1samp" :
141+ for i in range (len (data )):
142+ for j in range (i + 1 , len (data )):
143+ if method == "external" :
144+ p = p_list [idx ]
145+ idx += 1
146+ else :
147+ _ , p = perform_stat_test (
148+ data1 = data [i ], data2 = data [j ], method = method
149+ )
150+ if p <= 0.05 :
151+ comparisons .append ((i , j , p ))
152+ else :
153+ for i in range (len (data )):
154+ _ , p = perform_stat_test (data1 = data [i ], popmean = popmean , method = method )
155+ if p <= 0.05 :
156+ comparisons .append ((i , p ))
157+ return comparisons
158+
159+
133160def annotate_significance (
134161 ax : Axes ,
135162 comparisons : list [tuple [int , int , float ]],
@@ -168,7 +195,7 @@ def _stars(pval, i, y, color, fontsize):
168195 _stars (pval , (i + j ) / 2 , y + star_offset , color , fontsize )
169196 elif len (comparisons [0 ]) == 2 :
170197 for i , pval in comparisons :
171- y = y_base + interval
198+ y = y_base
172199 _stars (pval , i , y + star_offset , color , fontsize )
173200
174201
@@ -183,42 +210,39 @@ def statistics(
183210 asterisk_fontsize ,
184211 asterisk_color ,
185212):
186- comparisons = []
187- idx = 0
188- if test_method != "ttest_1samp" :
189- for i in range (len (data )):
190- for j in range (i + 1 , len (data )):
191- if test_method == "external" :
192- p = p_list [idx ]
193- idx += 1
194- else :
195- _ , p = perform_stat_test (
196- data1 = data [i ], data2 = data [j ], method = test_method
197- )
198- if p <= 0.05 :
199- comparisons .append ((i , j , p ))
213+ if isinstance (test_method , list ):
214+ if len (test_method ) > 2 or (
215+ len (test_method ) == 2 and "ttest_1samp" not in test_method
216+ ):
217+ raise ValueError (
218+ "test_method 最多只能有2个元素。且当元素数量为2时,其中之一必须是 'ttest_1samp'。"
219+ )
220+
200221 else :
201- for i in range (len (data )):
202- _ , p = perform_stat_test (data1 = data [i ], popmean = popmean , method = test_method )
203- if p <= 0.05 :
204- comparisons .append ((i , p ))
205- if not comparisons :
206- return
207-
208- y_max = ax .get_ylim ()[1 ]
209- interval = (y_max - np .max (all_values )) / (len (comparisons ) + 1 )
210- annotate_significance (
211- ax ,
212- comparisons ,
213- np .max (all_values ),
214- interval ,
215- line_color = statistical_line_color ,
216- star_offset = interval / 5 ,
217- fontsize = asterisk_fontsize ,
218- color = asterisk_color ,
219- )
222+ warnings .warn (
223+ "请使用列表形式传递 test_method 参数,例如 test_method=['ttest_ind']。字符串形式 test_method='ttest_ind' 将在后续版本中弃用。" ,
224+ DeprecationWarning ,
225+ stacklevel = 1 ,
226+ )
227+ comparisons = determine_test_modle (data , test_method , p_list , popmean )
228+ if not comparisons :
229+ return
230+
231+ y_max = ax .get_ylim ()[1 ]
232+ interval = (y_max - np .max (all_values )) / (len (comparisons ) + 1 )
233+ annotate_significance (
234+ ax ,
235+ comparisons ,
236+ np .max (all_values ),
237+ interval ,
238+ line_color = statistical_line_color ,
239+ star_offset = interval / 5 ,
240+ fontsize = asterisk_fontsize ,
241+ color = asterisk_color ,
242+ )
220243
221244
245+ # 可调用接口函数
222246def plot_one_group_bar_figure (
223247 data : list [NumArray ],
224248 ax : Axes | None = None ,
@@ -245,8 +269,6 @@ def plot_one_group_bar_figure(
245269 asterisk_color : str = "k" ,
246270 ** kwargs : Any ,
247271) -> None :
248-
249-
250272 """绘制单组柱状图,包含散点、误差条和统计显著性标记。
251273
252274 Args:
@@ -319,7 +341,12 @@ def plot_one_group_bar_figure(
319341 ax .imshow (gradient , aspect = "auto" , cmap = cmap , extent = extent , zorder = 0 )
320342 else :
321343 ax .bar (
322- x_positions , means , width = width , color = colors , alpha = color_alpha , edgecolor = edgecolor
344+ x_positions ,
345+ means ,
346+ width = width ,
347+ color = colors ,
348+ alpha = color_alpha ,
349+ edgecolor = edgecolor ,
323350 )
324351
325352 ax .errorbar (
@@ -442,7 +469,9 @@ def plot_one_group_violin_figure(
442469 labels_name = labels_name or [str (i ) for i in range (len (data ))]
443470 colors = colors or ["gray" ] * len (data )
444471
445- def _draw_gradient_violin (ax , data , pos , width = width , c1 = "red" , c2 = "blue" , color_alpha = 1 ):
472+ def _draw_gradient_violin (
473+ ax , data , pos , width = width , c1 = "red" , c2 = "blue" , color_alpha = 1
474+ ):
446475 # KDE估计
447476 kde = stats .gaussian_kde (data )
448477 buffer = (max (data ) - min (data )) / 5
@@ -514,7 +543,9 @@ def _draw_gradient_violin(ax, data, pos, width=width, c1="red", c2="blue", color
514543 c2 = colors_end [i ]
515544 else :
516545 c1 = c2 = colors [i ]
517- ymax , ymin = _draw_gradient_violin (ax , d , pos = i , c1 = c1 , c2 = c2 , color_alpha = color_alpha )
546+ ymax , ymin = _draw_gradient_violin (
547+ ax , d , pos = i , c1 = c1 , c2 = c2 , color_alpha = color_alpha
548+ )
518549 ymax_lst .append (ymax )
519550 ymin_lst .append (ymin )
520551 ymax = max (ymax_lst )
0 commit comments