Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/api/index.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# API

::: plotfig.bar
::: plotfig.single_bar

::: plotfig.multi_bars

::: plotfig.correlation

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ dependencies = [
dev = [
"ipykernel>=6.30.1",
"mkdocstrings-python>=2.0.1",
"zensical>=0.0.11",
"zensical==0.0.13",
]

[build-system]
Expand Down
33 changes: 17 additions & 16 deletions src/plotfig/__init__.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,30 @@
from importlib.metadata import version, PackageNotFoundError
from importlib.metadata import PackageNotFoundError, version

from .bar import (
plot_one_group_bar_figure,
plot_one_group_violin_figure,
plot_multi_group_bar_figure,
)
from .correlation import plot_correlation_figure
from .matrix import plot_matrix_figure
from .brain_surface import plot_brain_surface_figure
from .circos import plot_circos_figure
from .brain_connection import (
plot_brain_connection_figure,
save_brain_connection_frames,
batch_crop_images,
create_gif_from_images,
plot_brain_connection_figure,
save_brain_connection_frames,
)
from .brain_surface import plot_brain_surface_figure
from .circos import plot_circos_figure
from .correlation import plot_correlation_figure
from .matrix import plot_matrix_figure
from .multi_bars import (
plot_multi_group_bar_figure,
)
from .single_bar import (
plot_one_group_bar_figure,
plot_one_group_violin_figure,
)
from .utils import (
gen_hex_colors,
gen_symmetric_matrix,
gen_cmap,
value_to_hex,
gen_white_to_color_cmap,
is_symmetric_square,
value_to_hex,
)


__all__ = [
# bar
"plot_one_group_bar_figure",
Expand All @@ -45,7 +46,7 @@
# utils
"gen_hex_colors",
"gen_symmetric_matrix",
"gen_cmap",
"gen_white_to_color_cmap",
"value_to_hex",
"is_symmetric_square",
]
Expand Down
8 changes: 4 additions & 4 deletions src/plotfig/brain_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@
from typing import Literal

import imageio
from PIL import Image
import nibabel as nib
import numpy as np
import numpy.typing as npt
from scipy.ndimage import center_of_mass
from tqdm import tqdm
import plotly.graph_objects as go
import plotly.io as pio
from matplotlib.colors import LinearSegmentedColormap, to_hex
from loguru import logger
from matplotlib.colors import LinearSegmentedColormap, to_hex
from PIL import Image
from scipy.ndimage import center_of_mass
from tqdm import tqdm

warnings.filterwarnings("ignore", category=DeprecationWarning)

Expand Down
30 changes: 14 additions & 16 deletions src/plotfig/brain_surface.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,26 @@
from collections.abc import Mapping
from pathlib import Path
from typing import TypeAlias
from collections.abc import Mapping

import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
from matplotlib.axes import Axes

from surfplot import Plot

# 类型别名定义
Num: TypeAlias = float | int

__all__ = [
"plot_brain_surface_figure",
]

# 路径常量
NEURODATA = Path(__file__).resolve().parent / "data" / "neurodata"


def _map_labels_to_values(data, gifti_file):
gifti = nib.load(gifti_file)
# 获取顶点标签编号数组,shape=(顶点数,)
labels = gifti.darrays[0].data
labels = gifti.darrays[0].data
# 构建标签编号到脑区名称的映射字典
key_to_label = {label.key: label.label for label in gifti.labeltable.labels}
# 检查数据中是否有在图集中找不到的脑区标签
Expand Down Expand Up @@ -64,7 +61,7 @@ def plot_brain_surface_figure(
title_name: str = "",
title_fontsize: int = 12,
as_outline: bool = False,
) -> Axes:
) -> Axes:
"""在大脑皮层表面绘制数值数据的函数。

Args:
Expand Down Expand Up @@ -134,7 +131,7 @@ def plot_brain_surface_figure(
"sulc": {
"lh": "surfaces/human_fsLR/100206.L.sulc.32k_fs_LR.shape.gii",
"rh": "surfaces/human_fsLR/100206.R.sulc.32k_fs_LR.shape.gii",
}
},
},
"chimpanzee": {
"surf": {
Expand All @@ -146,7 +143,7 @@ def plot_brain_surface_figure(
"lh": "atlases/chimpanzee_BNA/ChimpBNA.L.32k_fs_LR.label.gii",
"rh": "atlases/chimpanzee_BNA/ChimpBNA.R.32k_fs_LR.label.gii",
},
}
},
},
"macaque": {
"surf": {
Expand Down Expand Up @@ -174,14 +171,15 @@ def plot_brain_surface_figure(
"sulc": {
"lh": "surfaces/macaque_BNA/SC_06018.L.sulc.32k_fs_LR.shape.gii",
"rh": "surfaces/macaque_BNA/SC_06018.R.sulc.32k_fs_LR.shape.gii",
}

}
},
},
}

# 检查物种是否支持
if species not in atlas_info:
raise ValueError(f"Unsupported species: {species}. Supported species are: {list(atlas_info.keys())}")
raise ValueError(
f"Unsupported species: {species}. Supported species are: {list(atlas_info.keys())}"
)
else:
# 检查指定物种的图集是否支持
if atlas not in atlas_info[species]["atlas"]:
Expand All @@ -199,7 +197,7 @@ def plot_brain_surface_figure(
NEURODATA / atlas_info[species]["surf"]["lh"],
NEURODATA / atlas_info[species]["surf"]["rh"],
views="dorsal",
zoom = 1.2,
zoom=1.2,
)
lh_sulc_file = NEURODATA / atlas_info[species]["sulc"]["lh"]
rh_sulc_file = NEURODATA / atlas_info[species]["sulc"]["rh"]
Expand All @@ -209,7 +207,7 @@ def plot_brain_surface_figure(
"right": nib.load(rh_sulc_file).darrays[0].data,
},
cmap="Grays_r",
cbar=False
cbar=False,
)

# 分离左半球和右半球的数据
Expand Down
34 changes: 18 additions & 16 deletions src/plotfig/circos.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,23 @@
# 标准库
from typing import Literal, Any
from typing import Any, Literal

# 第三方库
import numpy as np
from numpy.typing import NDArray
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
from loguru import logger
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.projections.polar import PolarAxes
from matplotlib.axes import Axes
from numpy.typing import NDArray
from pycirclize import Circos
from loguru import logger

# 项目内模块
from plotfig.utils.matrix import (
is_symmetric_square,
)
from plotfig.utils.color import (
gen_hex_colors,
gen_cmap,
gen_white_to_color_cmap,
value_to_hex,
)

from plotfig.utils.matrix import (
is_symmetric_square,
)

__all__ = ["plot_circos_figure"]

Expand Down Expand Up @@ -141,7 +137,9 @@ def plot_circos_figure(
logger.warning("connectome 矩阵所有元素均为0,可能没有有效连接数据")
vmax = float(0 if vmax is None else vmax)
vmin = float(0 if vmin is None else vmin)
colormap = gen_cmap(edge_color) if cmap is None else plt.get_cmap(cmap)
colormap = (
gen_white_to_color_cmap(edge_color) if cmap is None else plt.get_cmap(cmap)
)
elif np.any(connectome < 0):
logger.warning(
"由于 connectome 存在负值,连线颜色无法自定义,只能正值显示红色,负值显示蓝色"
Expand All @@ -153,7 +151,9 @@ def plot_circos_figure(
else:
vmin = float(connectome.min() if vmin is None else vmin)
vmax = float(connectome.max() if vmax is None else vmax)
colormap = gen_cmap(edge_color) if cmap is None else plt.get_cmap(cmap)
colormap = (
gen_white_to_color_cmap(edge_color) if cmap is None else plt.get_cmap(cmap)
)
if vmin > vmax:
raise ValueError(f"目前{vmin=},而{vmax=}。但是vmin不得大于vmax,请检查数据")
norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
Expand Down Expand Up @@ -192,7 +192,9 @@ def plot_circos_figure(
for sector in circos.sectors:
if sector.name.startswith("_gap"):
continue
sector.text(sector.name, size=node_label_fontsize, orientation=node_label_orientation)
sector.text(
sector.name, size=node_label_fontsize, orientation=node_label_orientation
)
track = sector.add_track((95, 100))
track.axis(fc=name2color[sector.name])

Expand Down
46 changes: 29 additions & 17 deletions src/plotfig/correlation.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
import numpy as np
from typing import TypeAlias

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.ticker import (
ScalarFormatter,
FormatStrFormatter,
FuncFormatter,
MultipleLocator,
FormatStrFormatter,
ScalarFormatter,
)
from matplotlib.colors import LinearSegmentedColormap
from scipy import stats

from typing import TypeAlias

# 类型别名定义
Num: TypeAlias = float | int # 可同时接受int和float的类型
Num: TypeAlias = float | int

__all__ = ["plot_correlation_figure"]

Expand All @@ -24,26 +23,27 @@ def plot_correlation_figure(
ax: Axes | None = None,
stats_method: str = "spearman",
ci: bool = False,
dots_color: str = "steelblue",
ci_color: str = "gray",
dots_color: str | list[str] = "steelblue",
dots_size: int | float = 10,
line_color: str = "r",
title_name: str = "",
title_fontsize: int = 12,
title_pad: int = 10,
x_label_name: str = "",
x_label_fontsize: int = 10,
x_tick_fontsize: int = 10,
x_tick_fontsize: int = 8,
x_tick_rotation: int = 0,
x_major_locator: float | None = None,
x_max_tick_to_value: float | None = None,
x_format: str = "normal", # 支持 "normal", "sci", "1f", "percent"
y_label_name: str = "",
y_label_fontsize: int = 10,
y_tick_fontsize: int = 10,
y_tick_fontsize: int = 8,
y_tick_rotation: int = 0,
y_major_locator: float | None = None,
y_max_tick_to_value: float | None = None,
y_format: str = "normal", # 支持 "normal", "sci", "1f", "percent"
y_format: str = "sci", # 支持 "normal", "sci", "1f", "percent"
asterisk_fontsize: int = 10,
show_p_value: bool = False,
hexbin: bool = False,
Expand All @@ -61,6 +61,7 @@ def plot_correlation_figure(
ax (plt.Axes | None, optional): matplotlib 的 Axes 对象,用于绘图。默认为 None,使用当前 Axes。
stats_method (str, optional): 相关性统计方法,支持 "spearman" 和 "pearson"。默认为 "spearman"。
ci (bool, optional): 是否绘制置信区间带。默认为 False。
ci_color (str, optional): 置信区间带颜色。默认为 "salmon"。
dots_color (str, optional): 散点的颜色。默认为 "steelblue"。
dots_size (int | float, optional): 散点的大小。默认为 1。
line_color (str, optional): 回归线的颜色。默认为 "r"(红色)。
Expand Down Expand Up @@ -94,7 +95,16 @@ def plot_correlation_figure(
"""

def set_axis(
ax, axis, label, labelsize, ticksize, rotation, locator, max_tick_value, fmt, lim
ax,
axis,
label,
labelsize,
ticksize,
rotation,
locator,
max_tick_value,
fmt,
lim,
):
if axis == "x":
set_label = ax.set_xlabel
Expand Down Expand Up @@ -149,7 +159,7 @@ def set_axis(
)
hb = ax.hexbin(A, B, gridsize=hexbin_gridsize, cmap=hexbin_cmap)
else:
ax.scatter(A, B, c=dots_color, s=dots_size, alpha=0.8)
ax.scatter(A, B, c=dots_color, s=dots_size)
ax.plot(x_seq, y_pred, line_color, lw=1)

if ci:
Expand All @@ -165,7 +175,7 @@ def set_axis(
x_seq,
y_pred - conf_interval,
y_pred + conf_interval,
color="salmon",
color=ci_color,
alpha=0.3,
)

Expand Down Expand Up @@ -210,9 +220,11 @@ def set_axis(
label = r"$\rho$"

if show_p_value:
asterisk = f" p={p:.4f}"
asterisk = f" p={p:.3f}"
else:
asterisk = " ***" if p < 0.001 else " **" if p < 0.01 else " *" if p < 0.05 else ""
asterisk = (
" ***" if p < 0.001 else " **" if p < 0.01 else " *" if p < 0.05 else ""
)
x_start, x_end = ax.get_xlim()
y_start, y_end = ax.get_ylim()
ax.text(
Expand Down
8 changes: 4 additions & 4 deletions src/plotfig/matrix.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from typing import Any, Sequence

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.image import AxesImage
from typing import Sequence, Any
from mpl_toolkits.axes_grid1 import make_axes_locatable

Num = int | float

Expand Down
Loading