Skip to content
Merged

Feat #13

Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 94 additions & 39 deletions src/plotfig/bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from matplotlib.patches import Polygon
from scipy import stats

# 设置警告过滤器,显示所有警告
warnings.simplefilter("always")

# 类型别名
Num = int | float # 可同时接受int和float的类型
NumArray = list[Num] | npt.NDArray[np.float64] # 数字数组类型
Expand Down Expand Up @@ -106,6 +109,7 @@ def set_yaxis(
ax.yaxis.set_major_formatter(FuncFormatter(lambda x, pos: f"{x:.0%}"))


# 统计相关
def perform_stat_test(
data1: NumArray | None = None,
data2: NumArray | None = None,
Expand All @@ -130,6 +134,29 @@ def perform_stat_test(
return stat, p


def determine_test_modle(data, method, p_list=None, popmean=0):
comparisons = []
idx = 0
if method != "ttest_1samp":
for i in range(len(data)):
for j in range(i + 1, len(data)):
if method == "external":
p = p_list[idx]
idx += 1
else:
_, p = perform_stat_test(
data1=data[i], data2=data[j], method=method
)
if p <= 0.05:
comparisons.append((i, j, p))
else:
for i in range(len(data)):
_, p = perform_stat_test(data1=data[i], popmean=popmean, method=method)
if p <= 0.05:
comparisons.append((i, p))
return comparisons


def annotate_significance(
ax: Axes,
comparisons: list[tuple[int, int, float]],
Expand Down Expand Up @@ -168,7 +195,7 @@ def _stars(pval, i, y, color, fontsize):
_stars(pval, (i + j) / 2, y + star_offset, color, fontsize)
elif len(comparisons[0]) == 2:
for i, pval in comparisons:
y = y_base + interval
y = y_base
_stars(pval, i, y + star_offset, color, fontsize)


Expand All @@ -183,42 +210,63 @@ def statistics(
asterisk_fontsize,
asterisk_color,
):
comparisons = []
idx = 0
if test_method != "ttest_1samp":
for i in range(len(data)):
for j in range(i + 1, len(data)):
if test_method == "external":
p = p_list[idx]
idx += 1
else:
_, p = perform_stat_test(
data1=data[i], data2=data[j], method=test_method
)
if p <= 0.05:
comparisons.append((i, j, p))
if isinstance(test_method, list):
if len(test_method) > 2 or (
len(test_method) == 2 and "ttest_1samp" not in test_method
):
raise ValueError(
"test_method 最多只能有2个元素。且当元素数量为2时,其中之一必须是 'ttest_1samp'。"
)

for method in test_method:
comparisons = determine_test_modle(data, method, p_list, popmean)
if not comparisons:
return

y_max = ax.get_ylim()[1]
interval = (y_max - np.max(all_values)) / (len(comparisons) + 1)

color = (
"b"
if len(test_method) > 1 and method == "ttest_1samp"
else asterisk_color
)

annotate_significance(
ax,
comparisons,
np.max(all_values),
interval,
line_color=statistical_line_color,
star_offset=interval / 5,
fontsize=asterisk_fontsize,
color=color,
)
else:
for i in range(len(data)):
_, p = perform_stat_test(data1=data[i], popmean=popmean, method=test_method)
if p <= 0.05:
comparisons.append((i, p))
if not comparisons:
return

y_max = ax.get_ylim()[1]
interval = (y_max - np.max(all_values)) / (len(comparisons) + 1)
annotate_significance(
ax,
comparisons,
np.max(all_values),
interval,
line_color=statistical_line_color,
star_offset=interval / 5,
fontsize=asterisk_fontsize,
color=asterisk_color,
)
warnings.warn(
"请使用列表形式传递 test_method 参数,例如 test_method=['ttest_ind']。字符串形式 test_method='ttest_ind' 将在后续版本中弃用。",
DeprecationWarning,
stacklevel=1,
)
comparisons = determine_test_modle(data, test_method, p_list, popmean)
if not comparisons:
return

y_max = ax.get_ylim()[1]
interval = (y_max - np.max(all_values)) / (len(comparisons) + 1)
annotate_significance(
ax,
comparisons,
np.max(all_values),
interval,
line_color=statistical_line_color,
star_offset=interval / 5,
fontsize=asterisk_fontsize,
color=asterisk_color,
)


# 可调用接口函数
def plot_one_group_bar_figure(
data: list[NumArray],
ax: Axes | None = None,
Expand All @@ -245,8 +293,6 @@ def plot_one_group_bar_figure(
asterisk_color: str = "k",
**kwargs: Any,
) -> None:


"""绘制单组柱状图,包含散点、误差条和统计显著性标记。

Args:
Expand Down Expand Up @@ -319,7 +365,12 @@ def plot_one_group_bar_figure(
ax.imshow(gradient, aspect="auto", cmap=cmap, extent=extent, zorder=0)
else:
ax.bar(
x_positions, means, width=width, color=colors, alpha=color_alpha, edgecolor=edgecolor
x_positions,
means,
width=width,
color=colors,
alpha=color_alpha,
edgecolor=edgecolor,
)

ax.errorbar(
Expand Down Expand Up @@ -442,7 +493,9 @@ def plot_one_group_violin_figure(
labels_name = labels_name or [str(i) for i in range(len(data))]
colors = colors or ["gray"] * len(data)

def _draw_gradient_violin(ax, data, pos, width=width, c1="red", c2="blue", color_alpha=1):
def _draw_gradient_violin(
ax, data, pos, width=width, c1="red", c2="blue", color_alpha=1
):
# KDE估计
kde = stats.gaussian_kde(data)
buffer = (max(data) - min(data)) / 5
Expand Down Expand Up @@ -514,7 +567,9 @@ def _draw_gradient_violin(ax, data, pos, width=width, c1="red", c2="blue", color
c2 = colors_end[i]
else:
c1 = c2 = colors[i]
ymax, ymin = _draw_gradient_violin(ax, d, pos=i, c1=c1, c2=c2, color_alpha=color_alpha)
ymax, ymin = _draw_gradient_violin(
ax, d, pos=i, c1=c1, c2=c2, color_alpha=color_alpha
)
ymax_lst.append(ymax)
ymin_lst.append(ymin)
ymax = max(ymax_lst)
Expand Down