1- import os
2- import os .path as op
31import datetime
2+ from pathlib import Path
3+ from typing import Literal
4+ from collections .abc import Sequence
5+
46import numpy as np
7+ import numpy .typing as npt
58import nibabel as nib
6- from scipy .ndimage import center_of_mass
79import plotly .graph_objects as go
810import plotly .io as pio
9- import matplotlib .colors as mcolors
10- import matplotlib . cm as cm
11+ from matplotlib .colors import LinearSegmentedColormap , to_hex
12+ from scipy . ndimage import center_of_mass
1113from tqdm import tqdm
12- from pathlib import Path
13- import numpy . typing as npt
14+
15+ from loguru import logger
1416
1517Num = int | float
1618
1921 "save_brain_connection_frames" ,
2022]
2123
22- def _load_surface (file : str | Path ):
23- '''加载 .surf.gii 文件,提取顶点和面'''
24- gii = nib .load (file )
25- vertices = gii .darrays [0 ].data
26- faces = gii .darrays [1 ].data
27- return vertices , faces
24+
25+ def _validate_connectome (connectome ):
26+ """检测数据是否为有效的对称且对角线为0的连接矩阵"""
27+ # 1. 判断是否二维方阵
28+ if connectome .ndim != 2 or connectome .shape [0 ] != connectome .shape [1 ]:
29+ raise ValueError ("connectome 必须是二维方阵" )
30+ # 2. 判断是否对称矩阵
31+ if not np .allclose (connectome , connectome .T , atol = 1e-8 ):
32+ raise ValueError ("connectome 必须是对称矩阵" )
33+ # 3. 判断对角线是否全为0
34+ if not np .allclose (np .diag (connectome ), 0 , atol = 1e-8 ):
35+ raise ValueError ("connectome 对角线必须全部为0" )
36+ # 4. 判断是否全0矩阵,警告但不抛异常
37+ if np .allclose (connectome , 0 , atol = 1e-8 ):
38+ logger .warning ("connectome 矩阵所有元素均为0,可能没有有效连接数据" )
39+
40+
41+ def _load_surface (file ):
42+ """加载 .surf.gii 文件,提取顶点和面"""
43+ return nib .load (file ).darrays [0 ].data , nib .load (file ).darrays [1 ].data
44+
2845
2946def _create_mesh (vertices , faces , name ):
30- ''' 创建 plotly 的 Mesh3d 图层'''
47+ """ 创建 plotly 的 Mesh3d 图层"""
3148 return go .Mesh3d (
3249 x = vertices [:, 0 ],
3350 y = vertices [:, 1 ],
@@ -42,29 +59,33 @@ def _create_mesh(vertices, faces, name):
4259 name = name ,
4360 )
4461
62+
4563def _get_node_indices (connectome , show_all_nodes ):
46- ''' 判断哪些节点需要显示'''
64+ """判断是否显示无任何连接的节点"""
4765 if not show_all_nodes :
4866 row_is_zero = np .any (connectome != 0 , axis = 1 )
4967 return np .where (row_is_zero )[0 ]
5068 else :
5169 return np .arange (connectome .shape [0 ])
5270
53- def _get_centroids_real (niigz_file : str | Path ):
54- '''读取 NIfTI 图集并计算ROI质心'''
55- img = nib .load (niigz_file )
56- atlas_data = img .get_fdata ()
57- affine = img .affine
5871
72+ def _get_centroids_real (niigz_file ):
73+ """读取 NIfTI 图集并计算ROI质心"""
74+ atlas_data = nib .load (niigz_file ).get_fdata ()
75+ affine = nib .load (niigz_file ).affine
5976 roi_labels = np .unique (atlas_data )
6077 roi_labels = roi_labels [roi_labels != 0 ]
61-
62- centroids_voxel = [center_of_mass ((atlas_data == label ).astype (int )) for label in roi_labels ]
78+ centroids_voxel = [
79+ center_of_mass ((atlas_data == label ).astype (int )) for label in roi_labels
80+ ]
6381 centroids_real = [np .dot (affine , [* coord , 1 ])[:3 ] for coord in centroids_voxel ]
6482 return np .array (centroids_real )
6583
66- def _add_nodes_to_fig (fig , centroids_real , node_indices , nodes_name , nodes_size , nodes_color ):
67- '''将节点(球)添加到图中'''
84+
85+ def _add_nodes_to_fig (
86+ fig , centroids_real , node_indices , nodes_name , nodes_size , nodes_color
87+ ):
88+ """将节点(球)添加到图中"""
6889 for i in node_indices :
6990 fig .add_trace (
7091 go .Scatter3d (
@@ -74,7 +95,7 @@ def _add_nodes_to_fig(fig, centroids_real, node_indices, nodes_name, nodes_size,
7495 mode = "markers+text" ,
7596 marker = {
7697 "size" : nodes_size [i ],
77- "color" : nodes_color ,
98+ "color" : nodes_color [ i ] ,
7899 "colorscale" : "Rainbow" ,
79100 "opacity" : 0.8 ,
80101 "line" : {"width" : 2 , "color" : "black" },
@@ -85,8 +106,25 @@ def _add_nodes_to_fig(fig, centroids_real, node_indices, nodes_name, nodes_size,
85106 )
86107 )
87108
88- def _add_edges_to_fig (fig , connectome , centroids_real , nodes_name , scale_method , line_width , line_color = "#ff0000" ):
89- '''将连接线绘制到图中'''
109+
110+ def _add_edges_to_fig (
111+ fig ,
112+ connectome ,
113+ centroids_real ,
114+ nodes_name ,
115+ scale_method ,
116+ line_width ,
117+ line_color ,
118+ ):
119+ """将连接线绘制到图中"""
120+
121+ def _get_gradient_color (value , color ):
122+ """获取渐变色"""
123+ assert 0 <= value <= 1 , "value 必须在0和1之间"
124+ cmap = LinearSegmentedColormap .from_list ("grad_cmap" , ["white" , color ])
125+ rgba = cmap (value )
126+ return to_hex (rgba [:3 ])
127+
90128 nodes_num = connectome .shape [0 ]
91129 if np .all (connectome == 0 ):
92130 return
@@ -100,41 +138,50 @@ def _add_edges_to_fig(fig, connectome, centroids_real, nodes_name, scale_method,
100138 continue
101139
102140 match scale_method :
141+ case "" :
142+ each_line_color = line_color if value > 0 else "#0000ff"
143+ each_line_width = line_width
103144 case "width" :
104145 each_line_color = line_color if value > 0 else "#0000ff"
105146 each_line_width = abs (value / max_strength ) * line_width
106147 case "color" :
107148 norm_value = value / max_strength
108- each_line_color = mcolors . to_hex ( cm . bwr ( mcolors . Normalize ( vmin = - 1 , vmax = 1 )( norm_value )) )
149+ each_line_color = _get_gradient_color ( norm_value , line_color )
109150 each_line_width = line_width
110151 case "width_color" | "color_width" :
111152 norm_value = value / max_strength
112153 each_line_width = abs (norm_value ) * line_width
113- each_line_color = mcolors .to_hex (cm .bwr (mcolors .Normalize (vmin = - 1 , vmax = 1 )(norm_value )))
114- case "" :
115- each_line_color = "#ff0000" if value > 0 else "#0000ff"
116- each_line_width = line_width
154+ each_line_color = _get_gradient_color (norm_value , line_color )
117155 case _:
118- raise ValueError ("scale_method must be '', 'width', 'color', 'width_color', or 'color_width'" )
156+ raise ValueError (
157+ "scale_method 必须为 '', 'width', 'color', 'width_color', or 'color_width'中的一种"
158+ )
119159
120- connection_line = np .array ([centroids_real [i ], centroids_real [j ], [None ] * 3 ])
160+ connection_line = np .array (
161+ [centroids_real [i ], centroids_real [j ], [None ] * 3 ]
162+ )
121163 fig .add_trace (
122164 go .Scatter3d (
123165 x = connection_line [:, 0 ],
124166 y = connection_line [:, 1 ],
125167 z = connection_line [:, 2 ],
126168 mode = "lines" ,
127169 line = {"color" : each_line_color , "width" : each_line_width },
128- hoverinfo = "none " ,
170+ hoverinfo = "name " ,
129171 name = f"{ nodes_name [i ]} -{ nodes_name [j ]} " ,
130172 )
131173 )
132174
175+
133176def _finalize_figure (fig ):
134- ''' 调整图形布局与视觉样式'''
177+ """ 调整图形布局与视觉样式"""
135178 fig .update_traces (
136179 selector = {"mode" : "markers" },
137- marker = {"size" : 10 , "colorscale" : "Viridis" , "line" : {"width" : 3 , "color" : "black" }},
180+ marker = {
181+ "size" : 10 ,
182+ "colorscale" : "Viridis" ,
183+ "line" : {"width" : 3 , "color" : "black" },
184+ },
138185 )
139186 fig .update_layout (
140187 title = "Connection" ,
@@ -147,80 +194,116 @@ def _finalize_figure(fig):
147194 margin = {"l" : 0 , "r" : 0 , "b" : 0 , "t" : 30 },
148195 )
149196
197+
150198def plot_brain_connection_figure (
151199 connectome : npt .NDArray ,
152200 lh_surfgii_file : str | Path ,
153201 rh_surfgii_file : str | Path ,
154202 niigz_file : str | Path ,
203+ output_file : str | Path | None = None ,
204+ show_all_nodes : bool = False ,
205+ nodes_size : Sequence [Num ] | npt .NDArray | None = None ,
155206 nodes_name : list [str ] | None = None ,
156- nodes_size = None ,
157207 nodes_color : list [str ] | None = None ,
158- output_file : str | Path | None = None ,
159- scale_method : str = "" ,
208+ scale_method : Literal ["" , "width" , "color" , "width_color" , "color_width" ] = "" ,
160209 line_width : Num = 10 ,
161- show_all_nodes : bool = False ,
162- line_color : str = "#ff0000" ,
163- ) -> None :
210+ line_color : str = "red" ,
211+ ) -> go .Figure :
164212 """绘制大脑连接图,保存在指定的html文件中
165213
166214 Args:
167- connectome (npt.NDArray): 连接矩阵
168- lh_surfgii_file (str | Path): 左脑surf.gii文件.
169- rh_surfgii_file (str | Path): 右脑surf.gii文件.
170- niigz_file (str | Path): 图集nii文件.
171- nodes_name (List[str] | None, optional): 节点名称. Defaults to None.
172- nodes_size (Num, optional): 节点大小. Defaults to 5.
173- nodes_color (List[str] | None, optional): 节点颜色. Defaults to None.
174- output_file (str | Path | None, optional): 保存的完整路径及文件名. Defaults to None.
175- scale_method (str, optional): 连接scale的形式. Defaults to "".
176- line_width (Num, optional): 连接粗细. Defaults to 10.
177-
178- Raises:
179- ValueError: 参数参数取值不合法时抛出.
215+ connectome (npt.NDArray):
216+ 大脑连接矩阵,形状为 (n, n),其中 n 是脑区数量。
217+ 矩阵中的值表示脑区之间的连接强度,正值表示正相关连接,负值表示负相关连接,0表示无连接。
218+ lh_surfgii_file (str | Path):
219+ 左半脑表面几何文件路径 (.surf.gii 格式),用于绘制左半脑表面
220+ rh_surfgii_file (str | Path):
221+ 右半脑表面几何文件路径 (.surf.gii 格式),用于绘制右半脑表面
222+ niigz_file (str | Path):
223+ NIfTI格式的脑区图谱文件路径 (.nii.gz 格式),用于定位脑区节点的三维坐标
224+ output_file (str | Path | None, optional):
225+ 输出HTML文件路径。如果未指定,则使用当前时间戳生成文件名。默认为None
226+ show_all_nodes (bool, optional):
227+ 是否显示所有脑区节点。如果为False,则只显示有连接的节点。默认为False
228+ nodes_size (Sequence[Num] | npt.NDArray | None, optional):
229+ 每个节点的大小,长度应与脑区数量一致。默认为None,即所有节点大小为5
230+ nodes_name (list[str] | None, optional):
231+ 每个节点的名称标签,长度应与脑区数量一致。默认为None,即不显示名称
232+ nodes_color (list[str] | None, optional):
233+ 每个节点的颜色,长度应与脑区数量一致。默认为None,即所有节点为白色
234+ scale_method (Literal["", "width", "color", "width_color", "color_width"], optional):
235+ 连接线的缩放方法:
236+ - "" : 所有连接线宽度和颜色固定
237+ - "width" : 根据连接强度调整线宽,正连接为红色,负连接为蓝色
238+ - "color" : 根据连接强度调整颜色(使用蓝白红颜色映射),线宽固定
239+ - "width_color" or "color_width" : 同时根据连接强度调整线宽和颜色
240+ 默认为 ""
241+ line_width (Num, optional):
242+ 连接线的基本宽度。当scale_method包含"width"时,此值作为最大宽度参考。默认为10
243+ line_color (str, optional):
244+ 连接线的基本颜色。当scale_method不包含"color"时生效。默认为"#ff0000"(红色)
245+
246+ Returns:
247+ go.Figure: Plotly图形对象,包含绘制的大脑连接图
180248 """
249+ _validate_connectome (connectome )
250+
251+ if np .any (connectome < 0 ):
252+ logger .warning (
253+ "由于 connectome 存在负值,连线颜色无法自定义,只能正值显示红色,负值显示蓝色"
254+ )
255+ line_color = "#ff0000"
256+
181257 nodes_num = connectome .shape [0 ]
182- if nodes_name is None :
183- nodes_name = [f"ROI-{ i } " for i in range (nodes_num )]
184- if nodes_color is None :
185- nodes_color = ["white" ] * nodes_num
186- if nodes_size is None :
187- nodes_size = [5 ] * nodes_num
258+ nodes_name = nodes_name or ["" ] * nodes_num
259+ nodes_color = nodes_color or ["white" ] * nodes_num
260+ nodes_size = nodes_size or [5 ] * nodes_num
261+
188262 if output_file is None :
189263 timestamp = datetime .datetime .now ().strftime ("%Y-%m-%d_%H-%M-%S" )
190- output_file = op . join (f"./ { timestamp } .html" )
191- print (f"未指定保存路径,默认保存为: { output_file } " )
264+ output_file = Path (f"{ timestamp } .html" )
265+ logger . info (f"未指定保存路径,默认保存在当前文件夹下的 { output_file } 中。 " )
192266
193267 node_indices = _get_node_indices (connectome , show_all_nodes )
194268 vertices_L , faces_L = _load_surface (lh_surfgii_file )
195269 vertices_R , faces_R = _load_surface (rh_surfgii_file )
196270
197271 mesh_L = _create_mesh (vertices_L , faces_L , "Left Hemisphere" )
198272 mesh_R = _create_mesh (vertices_R , faces_R , "Right Hemisphere" )
273+
199274 fig = go .Figure (data = [mesh_L , mesh_R ])
200275
201276 centroids_real = _get_centroids_real (niigz_file )
202- _add_nodes_to_fig (fig , centroids_real , node_indices , nodes_name , nodes_size , nodes_color )
203- _add_edges_to_fig (fig , connectome , centroids_real , nodes_name , scale_method , line_width , line_color )
277+ _add_nodes_to_fig (
278+ fig , centroids_real , node_indices , nodes_name , nodes_size , nodes_color
279+ )
280+ _add_edges_to_fig (
281+ fig ,
282+ connectome ,
283+ centroids_real ,
284+ nodes_name ,
285+ scale_method ,
286+ line_width ,
287+ line_color ,
288+ )
204289 _finalize_figure (fig )
205290
206291 fig .write_html (output_file )
207292 return fig
208293
209294
210295def save_brain_connection_frames (
211- fig : go .Figure ,
212- output_dir : str ,
213- n_frames : int = 36
296+ fig : go .Figure , output_dir : str | Path , n_frames : int = 36
214297) -> None :
215298 """
216- 生成不同角度的静态图片帧,用于制作旋转大脑连接图的 GIF 或视频 。
299+ 生成不同角度的静态图片帧,可用于制作旋转大脑连接图的 GIF。
217300
218301 Args:
219302 fig (go.Figure): Plotly 的 Figure 对象,包含大脑表面和连接图。
220- output_dir (str): 图片保存的文件夹路径,会自动创建文件夹 。
303+ output_dir (str): 图片保存的文件夹路径,若文件夹不存在则自动创建 。
221304 n_frames (int, optional): 旋转帧的数量。默认 36,即每 10 度一帧。
222305 """
223- os . makedirs (output_dir , exist_ok = True )
306+ Path (output_dir ). mkdir ( parents = True , exist_ok = True )
224307 angles = np .linspace (0 , 360 , n_frames , endpoint = False )
225308 for i , angle in tqdm (enumerate (angles ), total = len (angles )):
226309 camera = dict (
@@ -230,12 +313,4 @@ def save_brain_connection_frames(
230313 )
231314 fig .update_layout (scene_camera = camera )
232315 pio .write_image (fig , f"{ output_dir } /frame_{ i :03d} .png" , width = 800 , height = 800 )
233- print (f"保存了 { n_frames } 张图片在 { output_dir } " )
234-
235-
236- def main ():
237- pass
238-
239-
240- if __name__ == "__main__" :
241- main ()
316+ logger .info (f"保存了 { n_frames } 张图片在 { output_dir } " )
0 commit comments