Skip to content

Commit 6275b4b

Browse files
committed
refactor(bar): mark string input for test_method as deprecated
Passing `test_method="ttest_ind"` is now discouraged and will be deprecated in a future release. Users should use `test_method=["ttest_ind"]` instead. Refactored the parameter validation logic in the `statistics` function. Added `warnings.simplefilter('always')` to ensure deprecation warnings are always shown.
1 parent 736dc68 commit 6275b4b

1 file changed

Lines changed: 70 additions & 39 deletions

File tree

src/plotfig/bar.py

Lines changed: 70 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from matplotlib.patches import Polygon
1111
from scipy import stats
1212

13+
# 设置警告过滤器,显示所有警告
14+
warnings.simplefilter("always")
15+
1316
# 类型别名
1417
Num = int | float # 可同时接受int和float的类型
1518
NumArray = list[Num] | npt.NDArray[np.float64] # 数字数组类型
@@ -106,6 +109,7 @@ def set_yaxis(
106109
ax.yaxis.set_major_formatter(FuncFormatter(lambda x, pos: f"{x:.0%}"))
107110

108111

112+
# 统计相关
109113
def perform_stat_test(
110114
data1: NumArray | None = None,
111115
data2: NumArray | None = None,
@@ -130,6 +134,29 @@ def perform_stat_test(
130134
return stat, p
131135

132136

137+
def determine_test_modle(data, method, p_list=None, popmean=0):
138+
comparisons = []
139+
idx = 0
140+
if method != "ttest_1samp":
141+
for i in range(len(data)):
142+
for j in range(i + 1, len(data)):
143+
if method == "external":
144+
p = p_list[idx]
145+
idx += 1
146+
else:
147+
_, p = perform_stat_test(
148+
data1=data[i], data2=data[j], method=method
149+
)
150+
if p <= 0.05:
151+
comparisons.append((i, j, p))
152+
else:
153+
for i in range(len(data)):
154+
_, p = perform_stat_test(data1=data[i], popmean=popmean, method=method)
155+
if p <= 0.05:
156+
comparisons.append((i, p))
157+
return comparisons
158+
159+
133160
def annotate_significance(
134161
ax: Axes,
135162
comparisons: list[tuple[int, int, float]],
@@ -168,7 +195,7 @@ def _stars(pval, i, y, color, fontsize):
168195
_stars(pval, (i + j) / 2, y + star_offset, color, fontsize)
169196
elif len(comparisons[0]) == 2:
170197
for i, pval in comparisons:
171-
y = y_base + interval
198+
y = y_base
172199
_stars(pval, i, y + star_offset, color, fontsize)
173200

174201

@@ -183,42 +210,39 @@ def statistics(
183210
asterisk_fontsize,
184211
asterisk_color,
185212
):
186-
comparisons = []
187-
idx = 0
188-
if test_method != "ttest_1samp":
189-
for i in range(len(data)):
190-
for j in range(i + 1, len(data)):
191-
if test_method == "external":
192-
p = p_list[idx]
193-
idx += 1
194-
else:
195-
_, p = perform_stat_test(
196-
data1=data[i], data2=data[j], method=test_method
197-
)
198-
if p <= 0.05:
199-
comparisons.append((i, j, p))
213+
if isinstance(test_method, list):
214+
if len(test_method) > 2 or (
215+
len(test_method) == 2 and "ttest_1samp" not in test_method
216+
):
217+
raise ValueError(
218+
"test_method 最多只能有2个元素。且当元素数量为2时,其中之一必须是 'ttest_1samp'。"
219+
)
220+
200221
else:
201-
for i in range(len(data)):
202-
_, p = perform_stat_test(data1=data[i], popmean=popmean, method=test_method)
203-
if p <= 0.05:
204-
comparisons.append((i, p))
205-
if not comparisons:
206-
return
207-
208-
y_max = ax.get_ylim()[1]
209-
interval = (y_max - np.max(all_values)) / (len(comparisons) + 1)
210-
annotate_significance(
211-
ax,
212-
comparisons,
213-
np.max(all_values),
214-
interval,
215-
line_color=statistical_line_color,
216-
star_offset=interval / 5,
217-
fontsize=asterisk_fontsize,
218-
color=asterisk_color,
219-
)
222+
warnings.warn(
223+
"请使用列表形式传递 test_method 参数,例如 test_method=['ttest_ind']。字符串形式 test_method='ttest_ind' 将在后续版本中弃用。",
224+
DeprecationWarning,
225+
stacklevel=1,
226+
)
227+
comparisons = determine_test_modle(data, test_method, p_list, popmean)
228+
if not comparisons:
229+
return
230+
231+
y_max = ax.get_ylim()[1]
232+
interval = (y_max - np.max(all_values)) / (len(comparisons) + 1)
233+
annotate_significance(
234+
ax,
235+
comparisons,
236+
np.max(all_values),
237+
interval,
238+
line_color=statistical_line_color,
239+
star_offset=interval / 5,
240+
fontsize=asterisk_fontsize,
241+
color=asterisk_color,
242+
)
220243

221244

245+
# 可调用接口函数
222246
def plot_one_group_bar_figure(
223247
data: list[NumArray],
224248
ax: Axes | None = None,
@@ -245,8 +269,6 @@ def plot_one_group_bar_figure(
245269
asterisk_color: str = "k",
246270
**kwargs: Any,
247271
) -> None:
248-
249-
250272
"""绘制单组柱状图,包含散点、误差条和统计显著性标记。
251273
252274
Args:
@@ -319,7 +341,12 @@ def plot_one_group_bar_figure(
319341
ax.imshow(gradient, aspect="auto", cmap=cmap, extent=extent, zorder=0)
320342
else:
321343
ax.bar(
322-
x_positions, means, width=width, color=colors, alpha=color_alpha, edgecolor=edgecolor
344+
x_positions,
345+
means,
346+
width=width,
347+
color=colors,
348+
alpha=color_alpha,
349+
edgecolor=edgecolor,
323350
)
324351

325352
ax.errorbar(
@@ -442,7 +469,9 @@ def plot_one_group_violin_figure(
442469
labels_name = labels_name or [str(i) for i in range(len(data))]
443470
colors = colors or ["gray"] * len(data)
444471

445-
def _draw_gradient_violin(ax, data, pos, width=width, c1="red", c2="blue", color_alpha=1):
472+
def _draw_gradient_violin(
473+
ax, data, pos, width=width, c1="red", c2="blue", color_alpha=1
474+
):
446475
# KDE估计
447476
kde = stats.gaussian_kde(data)
448477
buffer = (max(data) - min(data)) / 5
@@ -514,7 +543,9 @@ def _draw_gradient_violin(ax, data, pos, width=width, c1="red", c2="blue", color
514543
c2 = colors_end[i]
515544
else:
516545
c1 = c2 = colors[i]
517-
ymax, ymin = _draw_gradient_violin(ax, d, pos=i, c1=c1, c2=c2, color_alpha=color_alpha)
546+
ymax, ymin = _draw_gradient_violin(
547+
ax, d, pos=i, c1=c1, c2=c2, color_alpha=color_alpha
548+
)
518549
ymax_lst.append(ymax)
519550
ymin_lst.append(ymin)
520551
ymax = max(ymax_lst)

0 commit comments

Comments
 (0)