diff --git a/src/plotfig/bar.py b/src/plotfig/bar.py index a2174f5..e116166 100644 --- a/src/plotfig/bar.py +++ b/src/plotfig/bar.py @@ -10,6 +10,9 @@ from matplotlib.patches import Polygon from scipy import stats +# 设置警告过滤器,显示所有警告 +warnings.simplefilter("always") + # 类型别名 Num = int | float # 可同时接受int和float的类型 NumArray = list[Num] | npt.NDArray[np.float64] # 数字数组类型 @@ -106,6 +109,7 @@ def set_yaxis( ax.yaxis.set_major_formatter(FuncFormatter(lambda x, pos: f"{x:.0%}")) +# 统计相关 def perform_stat_test( data1: NumArray | None = None, data2: NumArray | None = None, @@ -130,6 +134,29 @@ def perform_stat_test( return stat, p +def determine_test_modle(data, method, p_list=None, popmean=0): + comparisons = [] + idx = 0 + if method != "ttest_1samp": + for i in range(len(data)): + for j in range(i + 1, len(data)): + if method == "external": + p = p_list[idx] + idx += 1 + else: + _, p = perform_stat_test( + data1=data[i], data2=data[j], method=method + ) + if p <= 0.05: + comparisons.append((i, j, p)) + else: + for i in range(len(data)): + _, p = perform_stat_test(data1=data[i], popmean=popmean, method=method) + if p <= 0.05: + comparisons.append((i, p)) + return comparisons + + def annotate_significance( ax: Axes, comparisons: list[tuple[int, int, float]], @@ -168,7 +195,7 @@ def _stars(pval, i, y, color, fontsize): _stars(pval, (i + j) / 2, y + star_offset, color, fontsize) elif len(comparisons[0]) == 2: for i, pval in comparisons: - y = y_base + interval + y = y_base _stars(pval, i, y + star_offset, color, fontsize) @@ -183,42 +210,63 @@ def statistics( asterisk_fontsize, asterisk_color, ): - comparisons = [] - idx = 0 - if test_method != "ttest_1samp": - for i in range(len(data)): - for j in range(i + 1, len(data)): - if test_method == "external": - p = p_list[idx] - idx += 1 - else: - _, p = perform_stat_test( - data1=data[i], data2=data[j], method=test_method - ) - if p <= 0.05: - comparisons.append((i, j, p)) + if isinstance(test_method, list): + if len(test_method) > 2 or ( + len(test_method) == 2 and "ttest_1samp" not in test_method + ): + raise ValueError( + "test_method 最多只能有2个元素。且当元素数量为2时,其中之一必须是 'ttest_1samp'。" + ) + + for method in test_method: + comparisons = determine_test_modle(data, method, p_list, popmean) + if not comparisons: + return + + y_max = ax.get_ylim()[1] + interval = (y_max - np.max(all_values)) / (len(comparisons) + 1) + + color = ( + "b" + if len(test_method) > 1 and method == "ttest_1samp" + else asterisk_color + ) + + annotate_significance( + ax, + comparisons, + np.max(all_values), + interval, + line_color=statistical_line_color, + star_offset=interval / 5, + fontsize=asterisk_fontsize, + color=color, + ) else: - for i in range(len(data)): - _, p = perform_stat_test(data1=data[i], popmean=popmean, method=test_method) - if p <= 0.05: - comparisons.append((i, p)) - if not comparisons: - return - - y_max = ax.get_ylim()[1] - interval = (y_max - np.max(all_values)) / (len(comparisons) + 1) - annotate_significance( - ax, - comparisons, - np.max(all_values), - interval, - line_color=statistical_line_color, - star_offset=interval / 5, - fontsize=asterisk_fontsize, - color=asterisk_color, - ) + warnings.warn( + "请使用列表形式传递 test_method 参数,例如 test_method=['ttest_ind']。字符串形式 test_method='ttest_ind' 将在后续版本中弃用。", + DeprecationWarning, + stacklevel=1, + ) + comparisons = determine_test_modle(data, test_method, p_list, popmean) + if not comparisons: + return + + y_max = ax.get_ylim()[1] + interval = (y_max - np.max(all_values)) / (len(comparisons) + 1) + annotate_significance( + ax, + comparisons, + np.max(all_values), + interval, + line_color=statistical_line_color, + star_offset=interval / 5, + fontsize=asterisk_fontsize, + color=asterisk_color, + ) +# 可调用接口函数 def plot_one_group_bar_figure( data: list[NumArray], ax: Axes | None = None, @@ -245,8 +293,6 @@ def plot_one_group_bar_figure( asterisk_color: str = "k", **kwargs: Any, ) -> None: - - """绘制单组柱状图,包含散点、误差条和统计显著性标记。 Args: @@ -319,7 +365,12 @@ def plot_one_group_bar_figure( ax.imshow(gradient, aspect="auto", cmap=cmap, extent=extent, zorder=0) else: ax.bar( - x_positions, means, width=width, color=colors, alpha=color_alpha, edgecolor=edgecolor + x_positions, + means, + width=width, + color=colors, + alpha=color_alpha, + edgecolor=edgecolor, ) ax.errorbar( @@ -442,7 +493,9 @@ def plot_one_group_violin_figure( labels_name = labels_name or [str(i) for i in range(len(data))] colors = colors or ["gray"] * len(data) - def _draw_gradient_violin(ax, data, pos, width=width, c1="red", c2="blue", color_alpha=1): + def _draw_gradient_violin( + ax, data, pos, width=width, c1="red", c2="blue", color_alpha=1 + ): # KDE估计 kde = stats.gaussian_kde(data) buffer = (max(data) - min(data)) / 5 @@ -514,7 +567,9 @@ def _draw_gradient_violin(ax, data, pos, width=width, c1="red", c2="blue", color c2 = colors_end[i] else: c1 = c2 = colors[i] - ymax, ymin = _draw_gradient_violin(ax, d, pos=i, c1=c1, c2=c2, color_alpha=color_alpha) + ymax, ymin = _draw_gradient_violin( + ax, d, pos=i, c1=c1, c2=c2, color_alpha=color_alpha + ) ymax_lst.append(ymax) ymin_lst.append(ymin) ymax = max(ymax_lst)