Skip to content

Commit 83d46d7

Browse files
committed
fix(connec): fix issue with line_color display under color scale
When `scale_method` is set to `color` or `width_color` and `line_color` is specified, the color was not applied correctly. This fix ensures the color is displayed as intended. --- 修复(connec): 修复 color scale 下 line_color 显示异常问题 当 `scale_method` 设置为 `color` 或 `width_color` 且同时指定了 `line_color` 时, 颜色未能正确应用。该问题已修复,确保颜色显示符合预期。
1 parent 420e1b8 commit 83d46d7

1 file changed

Lines changed: 156 additions & 81 deletions

File tree

src/plotfig/brain_connection.py

Lines changed: 156 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
1-
import os
2-
import os.path as op
31
import datetime
2+
from pathlib import Path
3+
from typing import Literal
4+
from collections.abc import Sequence
5+
46
import numpy as np
7+
import numpy.typing as npt
58
import nibabel as nib
6-
from scipy.ndimage import center_of_mass
79
import plotly.graph_objects as go
810
import plotly.io as pio
9-
import matplotlib.colors as mcolors
10-
import matplotlib.cm as cm
11+
from matplotlib.colors import LinearSegmentedColormap, to_hex
12+
from scipy.ndimage import center_of_mass
1113
from tqdm import tqdm
12-
from pathlib import Path
13-
import numpy.typing as npt
14+
15+
from loguru import logger
1416

1517
Num = int | float
1618

@@ -19,15 +21,30 @@
1921
"save_brain_connection_frames",
2022
]
2123

22-
def _load_surface(file: str | Path):
23-
'''加载 .surf.gii 文件,提取顶点和面'''
24-
gii = nib.load(file)
25-
vertices = gii.darrays[0].data
26-
faces = gii.darrays[1].data
27-
return vertices, faces
24+
25+
def _validate_connectome(connectome):
26+
"""检测数据是否为有效的对称且对角线为0的连接矩阵"""
27+
# 1. 判断是否二维方阵
28+
if connectome.ndim != 2 or connectome.shape[0] != connectome.shape[1]:
29+
raise ValueError("connectome 必须是二维方阵")
30+
# 2. 判断是否对称矩阵
31+
if not np.allclose(connectome, connectome.T, atol=1e-8):
32+
raise ValueError("connectome 必须是对称矩阵")
33+
# 3. 判断对角线是否全为0
34+
if not np.allclose(np.diag(connectome), 0, atol=1e-8):
35+
raise ValueError("connectome 对角线必须全部为0")
36+
# 4. 判断是否全0矩阵,警告但不抛异常
37+
if np.allclose(connectome, 0, atol=1e-8):
38+
logger.warning("connectome 矩阵所有元素均为0,可能没有有效连接数据")
39+
40+
41+
def _load_surface(file):
42+
"""加载 .surf.gii 文件,提取顶点和面"""
43+
return nib.load(file).darrays[0].data, nib.load(file).darrays[1].data
44+
2845

2946
def _create_mesh(vertices, faces, name):
30-
''' 创建 plotly 的 Mesh3d 图层'''
47+
"""创建 plotly 的 Mesh3d 图层"""
3148
return go.Mesh3d(
3249
x=vertices[:, 0],
3350
y=vertices[:, 1],
@@ -42,29 +59,33 @@ def _create_mesh(vertices, faces, name):
4259
name=name,
4360
)
4461

62+
4563
def _get_node_indices(connectome, show_all_nodes):
46-
''' 判断哪些节点需要显示'''
64+
"""判断是否显示无任何连接的节点"""
4765
if not show_all_nodes:
4866
row_is_zero = np.any(connectome != 0, axis=1)
4967
return np.where(row_is_zero)[0]
5068
else:
5169
return np.arange(connectome.shape[0])
5270

53-
def _get_centroids_real(niigz_file: str | Path):
54-
'''读取 NIfTI 图集并计算ROI质心'''
55-
img = nib.load(niigz_file)
56-
atlas_data = img.get_fdata()
57-
affine = img.affine
5871

72+
def _get_centroids_real(niigz_file):
73+
"""读取 NIfTI 图集并计算ROI质心"""
74+
atlas_data = nib.load(niigz_file).get_fdata()
75+
affine = nib.load(niigz_file).affine
5976
roi_labels = np.unique(atlas_data)
6077
roi_labels = roi_labels[roi_labels != 0]
61-
62-
centroids_voxel = [center_of_mass((atlas_data == label).astype(int)) for label in roi_labels]
78+
centroids_voxel = [
79+
center_of_mass((atlas_data == label).astype(int)) for label in roi_labels
80+
]
6381
centroids_real = [np.dot(affine, [*coord, 1])[:3] for coord in centroids_voxel]
6482
return np.array(centroids_real)
6583

66-
def _add_nodes_to_fig(fig, centroids_real, node_indices, nodes_name, nodes_size, nodes_color):
67-
'''将节点(球)添加到图中'''
84+
85+
def _add_nodes_to_fig(
86+
fig, centroids_real, node_indices, nodes_name, nodes_size, nodes_color
87+
):
88+
"""将节点(球)添加到图中"""
6889
for i in node_indices:
6990
fig.add_trace(
7091
go.Scatter3d(
@@ -74,7 +95,7 @@ def _add_nodes_to_fig(fig, centroids_real, node_indices, nodes_name, nodes_size,
7495
mode="markers+text",
7596
marker={
7697
"size": nodes_size[i],
77-
"color": nodes_color,
98+
"color": nodes_color[i],
7899
"colorscale": "Rainbow",
79100
"opacity": 0.8,
80101
"line": {"width": 2, "color": "black"},
@@ -85,8 +106,25 @@ def _add_nodes_to_fig(fig, centroids_real, node_indices, nodes_name, nodes_size,
85106
)
86107
)
87108

88-
def _add_edges_to_fig(fig, connectome, centroids_real, nodes_name, scale_method, line_width, line_color="#ff0000"):
89-
'''将连接线绘制到图中'''
109+
110+
def _add_edges_to_fig(
111+
fig,
112+
connectome,
113+
centroids_real,
114+
nodes_name,
115+
scale_method,
116+
line_width,
117+
line_color,
118+
):
119+
"""将连接线绘制到图中"""
120+
121+
def _get_gradient_color(value, color):
122+
"""获取渐变色"""
123+
assert 0 <= value <= 1, "value 必须在0和1之间"
124+
cmap = LinearSegmentedColormap.from_list("grad_cmap", ["white", color])
125+
rgba = cmap(value)
126+
return to_hex(rgba[:3])
127+
90128
nodes_num = connectome.shape[0]
91129
if np.all(connectome == 0):
92130
return
@@ -100,41 +138,50 @@ def _add_edges_to_fig(fig, connectome, centroids_real, nodes_name, scale_method,
100138
continue
101139

102140
match scale_method:
141+
case "":
142+
each_line_color = line_color if value > 0 else "#0000ff"
143+
each_line_width = line_width
103144
case "width":
104145
each_line_color = line_color if value > 0 else "#0000ff"
105146
each_line_width = abs(value / max_strength) * line_width
106147
case "color":
107148
norm_value = value / max_strength
108-
each_line_color = mcolors.to_hex(cm.bwr(mcolors.Normalize(vmin=-1, vmax=1)(norm_value)))
149+
each_line_color = _get_gradient_color(norm_value, line_color)
109150
each_line_width = line_width
110151
case "width_color" | "color_width":
111152
norm_value = value / max_strength
112153
each_line_width = abs(norm_value) * line_width
113-
each_line_color = mcolors.to_hex(cm.bwr(mcolors.Normalize(vmin=-1, vmax=1)(norm_value)))
114-
case "":
115-
each_line_color = "#ff0000" if value > 0 else "#0000ff"
116-
each_line_width = line_width
154+
each_line_color = _get_gradient_color(norm_value, line_color)
117155
case _:
118-
raise ValueError("scale_method must be '', 'width', 'color', 'width_color', or 'color_width'")
156+
raise ValueError(
157+
"scale_method 必须为 '', 'width', 'color', 'width_color', or 'color_width'中的一种"
158+
)
119159

120-
connection_line = np.array([centroids_real[i], centroids_real[j], [None] * 3])
160+
connection_line = np.array(
161+
[centroids_real[i], centroids_real[j], [None] * 3]
162+
)
121163
fig.add_trace(
122164
go.Scatter3d(
123165
x=connection_line[:, 0],
124166
y=connection_line[:, 1],
125167
z=connection_line[:, 2],
126168
mode="lines",
127169
line={"color": each_line_color, "width": each_line_width},
128-
hoverinfo="none",
170+
hoverinfo="name",
129171
name=f"{nodes_name[i]}-{nodes_name[j]}",
130172
)
131173
)
132174

175+
133176
def _finalize_figure(fig):
134-
'''调整图形布局与视觉样式'''
177+
"""调整图形布局与视觉样式"""
135178
fig.update_traces(
136179
selector={"mode": "markers"},
137-
marker={"size": 10, "colorscale": "Viridis", "line": {"width": 3, "color": "black"}},
180+
marker={
181+
"size": 10,
182+
"colorscale": "Viridis",
183+
"line": {"width": 3, "color": "black"},
184+
},
138185
)
139186
fig.update_layout(
140187
title="Connection",
@@ -147,80 +194,116 @@ def _finalize_figure(fig):
147194
margin={"l": 0, "r": 0, "b": 0, "t": 30},
148195
)
149196

197+
150198
def plot_brain_connection_figure(
151199
connectome: npt.NDArray,
152200
lh_surfgii_file: str | Path,
153201
rh_surfgii_file: str | Path,
154202
niigz_file: str | Path,
203+
output_file: str | Path | None = None,
204+
show_all_nodes: bool = False,
205+
nodes_size: Sequence[Num] | npt.NDArray | None = None,
155206
nodes_name: list[str] | None = None,
156-
nodes_size=None,
157207
nodes_color: list[str] | None = None,
158-
output_file: str | Path | None = None,
159-
scale_method: str = "",
208+
scale_method: Literal["", "width", "color", "width_color", "color_width"] = "",
160209
line_width: Num = 10,
161-
show_all_nodes: bool = False,
162-
line_color: str = "#ff0000",
163-
) -> None:
210+
line_color: str = "red",
211+
) -> go.Figure:
164212
"""绘制大脑连接图,保存在指定的html文件中
165213
166214
Args:
167-
connectome (npt.NDArray): 连接矩阵
168-
lh_surfgii_file (str | Path): 左脑surf.gii文件.
169-
rh_surfgii_file (str | Path): 右脑surf.gii文件.
170-
niigz_file (str | Path): 图集nii文件.
171-
nodes_name (List[str] | None, optional): 节点名称. Defaults to None.
172-
nodes_size (Num, optional): 节点大小. Defaults to 5.
173-
nodes_color (List[str] | None, optional): 节点颜色. Defaults to None.
174-
output_file (str | Path | None, optional): 保存的完整路径及文件名. Defaults to None.
175-
scale_method (str, optional): 连接scale的形式. Defaults to "".
176-
line_width (Num, optional): 连接粗细. Defaults to 10.
177-
178-
Raises:
179-
ValueError: 参数参数取值不合法时抛出.
215+
connectome (npt.NDArray):
216+
大脑连接矩阵,形状为 (n, n),其中 n 是脑区数量。
217+
矩阵中的值表示脑区之间的连接强度,正值表示正相关连接,负值表示负相关连接,0表示无连接。
218+
lh_surfgii_file (str | Path):
219+
左半脑表面几何文件路径 (.surf.gii 格式),用于绘制左半脑表面
220+
rh_surfgii_file (str | Path):
221+
右半脑表面几何文件路径 (.surf.gii 格式),用于绘制右半脑表面
222+
niigz_file (str | Path):
223+
NIfTI格式的脑区图谱文件路径 (.nii.gz 格式),用于定位脑区节点的三维坐标
224+
output_file (str | Path | None, optional):
225+
输出HTML文件路径。如果未指定,则使用当前时间戳生成文件名。默认为None
226+
show_all_nodes (bool, optional):
227+
是否显示所有脑区节点。如果为False,则只显示有连接的节点。默认为False
228+
nodes_size (Sequence[Num] | npt.NDArray | None, optional):
229+
每个节点的大小,长度应与脑区数量一致。默认为None,即所有节点大小为5
230+
nodes_name (list[str] | None, optional):
231+
每个节点的名称标签,长度应与脑区数量一致。默认为None,即不显示名称
232+
nodes_color (list[str] | None, optional):
233+
每个节点的颜色,长度应与脑区数量一致。默认为None,即所有节点为白色
234+
scale_method (Literal["", "width", "color", "width_color", "color_width"], optional):
235+
连接线的缩放方法:
236+
- "" : 所有连接线宽度和颜色固定
237+
- "width" : 根据连接强度调整线宽,正连接为红色,负连接为蓝色
238+
- "color" : 根据连接强度调整颜色(使用蓝白红颜色映射),线宽固定
239+
- "width_color" or "color_width" : 同时根据连接强度调整线宽和颜色
240+
默认为 ""
241+
line_width (Num, optional):
242+
连接线的基本宽度。当scale_method包含"width"时,此值作为最大宽度参考。默认为10
243+
line_color (str, optional):
244+
连接线的基本颜色。当scale_method不包含"color"时生效。默认为"#ff0000"(红色)
245+
246+
Returns:
247+
go.Figure: Plotly图形对象,包含绘制的大脑连接图
180248
"""
249+
_validate_connectome(connectome)
250+
251+
if np.any(connectome < 0):
252+
logger.warning(
253+
"由于 connectome 存在负值,连线颜色无法自定义,只能正值显示红色,负值显示蓝色"
254+
)
255+
line_color = "#ff0000"
256+
181257
nodes_num = connectome.shape[0]
182-
if nodes_name is None:
183-
nodes_name = [f"ROI-{i}" for i in range(nodes_num)]
184-
if nodes_color is None:
185-
nodes_color = ["white"] * nodes_num
186-
if nodes_size is None:
187-
nodes_size = [5] * nodes_num
258+
nodes_name = nodes_name or [""] * nodes_num
259+
nodes_color = nodes_color or ["white"] * nodes_num
260+
nodes_size = nodes_size or [5] * nodes_num
261+
188262
if output_file is None:
189263
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
190-
output_file = op.join(f"./{timestamp}.html")
191-
print(f"未指定保存路径,默认保存为:{output_file}")
264+
output_file = Path(f"{timestamp}.html")
265+
logger.info(f"未指定保存路径,默认保存在当前文件夹下的{output_file}中。")
192266

193267
node_indices = _get_node_indices(connectome, show_all_nodes)
194268
vertices_L, faces_L = _load_surface(lh_surfgii_file)
195269
vertices_R, faces_R = _load_surface(rh_surfgii_file)
196270

197271
mesh_L = _create_mesh(vertices_L, faces_L, "Left Hemisphere")
198272
mesh_R = _create_mesh(vertices_R, faces_R, "Right Hemisphere")
273+
199274
fig = go.Figure(data=[mesh_L, mesh_R])
200275

201276
centroids_real = _get_centroids_real(niigz_file)
202-
_add_nodes_to_fig(fig, centroids_real, node_indices, nodes_name, nodes_size, nodes_color)
203-
_add_edges_to_fig(fig, connectome, centroids_real, nodes_name, scale_method, line_width, line_color)
277+
_add_nodes_to_fig(
278+
fig, centroids_real, node_indices, nodes_name, nodes_size, nodes_color
279+
)
280+
_add_edges_to_fig(
281+
fig,
282+
connectome,
283+
centroids_real,
284+
nodes_name,
285+
scale_method,
286+
line_width,
287+
line_color,
288+
)
204289
_finalize_figure(fig)
205290

206291
fig.write_html(output_file)
207292
return fig
208293

209294

210295
def save_brain_connection_frames(
211-
fig: go.Figure,
212-
output_dir: str,
213-
n_frames: int = 36
296+
fig: go.Figure, output_dir: str | Path, n_frames: int = 36
214297
) -> None:
215298
"""
216-
生成不同角度的静态图片帧,用于制作旋转大脑连接图的 GIF 或视频
299+
生成不同角度的静态图片帧,可用于制作旋转大脑连接图的 GIF。
217300
218301
Args:
219302
fig (go.Figure): Plotly 的 Figure 对象,包含大脑表面和连接图。
220-
output_dir (str): 图片保存的文件夹路径,会自动创建文件夹
303+
output_dir (str): 图片保存的文件夹路径,若文件夹不存在则自动创建
221304
n_frames (int, optional): 旋转帧的数量。默认 36,即每 10 度一帧。
222305
"""
223-
os.makedirs(output_dir, exist_ok=True)
306+
Path(output_dir).mkdir(parents=True, exist_ok=True)
224307
angles = np.linspace(0, 360, n_frames, endpoint=False)
225308
for i, angle in tqdm(enumerate(angles), total=len(angles)):
226309
camera = dict(
@@ -230,12 +313,4 @@ def save_brain_connection_frames(
230313
)
231314
fig.update_layout(scene_camera=camera)
232315
pio.write_image(fig, f"{output_dir}/frame_{i:03d}.png", width=800, height=800)
233-
print(f"保存了 {n_frames} 张图片在 {output_dir}")
234-
235-
236-
def main():
237-
pass
238-
239-
240-
if __name__ == "__main__":
241-
main()
316+
logger.info(f"保存了 {n_frames} 张图片在 {output_dir}")

0 commit comments

Comments
 (0)