Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,16 @@ wheels/
# Virtual environments
.venv

# backup
*.bak

.vscode/
# Figures
notebooks/figures/
tests/figures/

# deployment
site/

# ai agent docs
IFLOW.md

*.bak
.vscode/

14 changes: 12 additions & 2 deletions src/plotfig/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def plot_matrix_figure(
title_fontsize: Num = 15,
title_pad: Num = 20,
diag_border: bool = False,
xlabel: str | None = None,
ylabel: str | None = None,
**imshow_kwargs: Any,
) -> Axes:
"""
Expand Down Expand Up @@ -60,10 +62,12 @@ def plot_matrix_figure(
title_fontsize (Num): 标题的字体大小。
title_pad (Num): 标题上方的间距。
diag_border (bool): 是否绘制对角线单元格边框。
xlabel (str | None): X轴的整体标签名称。
ylabel (str | None): Y轴的整体标签名称。
**imshow_kwargs (Any): 传递给 `imshow()` 的其他关键字参数。

Returns:
AxesImage: 由 `imshow()` 创建的图像对象
Axes: 绘图的坐标轴对象
"""

ax = ax or plt.gca()
Expand All @@ -74,6 +78,13 @@ def plot_matrix_figure(
data, cmap=cmap, vmin=vmin, vmax=vmax, aspect=aspect, **imshow_kwargs
)
ax.set_title(title_name, fontsize=title_fontsize, pad=title_pad)

# 设置X轴和Y轴标签
if xlabel is not None:
ax.set_xlabel(xlabel)
if ylabel is not None:
ax.set_ylabel(ylabel)

if diag_border:
for i in range(data.shape[0]):
ax.add_patch(
Expand Down Expand Up @@ -116,4 +127,3 @@ def plot_matrix_figure(
)

return ax