diff --git a/.gitignore b/.gitignore index 188005f..85f5e94 100644 --- a/.gitignore +++ b/.gitignore @@ -9,10 +9,16 @@ wheels/ # Virtual environments .venv -# backup -*.bak - -.vscode/ +# Figures notebooks/figures/ tests/figures/ + +# deployment site/ + +# ai agent docs +IFLOW.md + +*.bak +.vscode/ + diff --git a/src/plotfig/matrix.py b/src/plotfig/matrix.py index 70e9b70..ba33709 100644 --- a/src/plotfig/matrix.py +++ b/src/plotfig/matrix.py @@ -33,6 +33,8 @@ def plot_matrix_figure( title_fontsize: Num = 15, title_pad: Num = 20, diag_border: bool = False, + xlabel: str | None = None, + ylabel: str | None = None, **imshow_kwargs: Any, ) -> Axes: """ @@ -60,10 +62,12 @@ def plot_matrix_figure( title_fontsize (Num): 标题的字体大小。 title_pad (Num): 标题上方的间距。 diag_border (bool): 是否绘制对角线单元格边框。 + xlabel (str | None): X轴的整体标签名称。 + ylabel (str | None): Y轴的整体标签名称。 **imshow_kwargs (Any): 传递给 `imshow()` 的其他关键字参数。 Returns: - AxesImage: 由 `imshow()` 创建的图像对象。 + Axes: 绘图的坐标轴对象。 """ ax = ax or plt.gca() @@ -74,6 +78,13 @@ def plot_matrix_figure( data, cmap=cmap, vmin=vmin, vmax=vmax, aspect=aspect, **imshow_kwargs ) ax.set_title(title_name, fontsize=title_fontsize, pad=title_pad) + + # 设置X轴和Y轴标签 + if xlabel is not None: + ax.set_xlabel(xlabel) + if ylabel is not None: + ax.set_ylabel(ylabel) + if diag_border: for i in range(data.shape[0]): ax.add_patch( @@ -116,4 +127,3 @@ def plot_matrix_figure( ) return ax -