@@ -98,6 +98,8 @@ class Segmentation:
9898 mri (MRIData): The MRIData object containing the segmentation volume and affine.
9999 lut (Optional[pd.DataFrame], optional): A pandas DataFrame mapping numerical labels
100100 to their descriptions. If None, a default numerical mapping is generated. Defaults to None.
101+ Assumes that entries are indexed by the "label" column. If there is no "label" column
102+ the current index is renamed to "label"
101103 """
102104
103105 mri : MRIData
@@ -111,16 +113,27 @@ def __init__(self, mri: MRIData, lut: pd.DataFrame | None = None):
111113 self .rois = np .unique (self .mri .data [self .mri .data > 0 ])
112114
113115 if lut is None :
114- self .lut = pd .DataFrame ({"Label" : self .rois }, index = self .rois )
115- else :
116- self .lut = lut
116+ lut = pd .DataFrame (
117+ {
118+ "label" : self .rois .astype (int ),
119+ "description" : self .rois .astype (int ).astype (str ),
120+ }
121+ ).set_index ("label" )
122+
123+ self .set_lut (lut , label_column = "label" if "label" in lut .columns else None )
124+ self ._preprocess_lut ()
117125
118- # Identify the primary label column dynamically
119- self .label_name = "Label" if "Label" in self .lut .columns else self .lut .columns [0 ]
126+ def _preprocess_lut (self ) -> pd .DataFrame :
127+ # dummy function for subclasses to override if they need to preprocess the LUT after loading
128+ pass
120129
121130 @classmethod
122131 def from_file (
123- cls , seg_path : Path , dtype : npt .DTypeLike | None = None , orient : bool = True , lut_path : Path | None = None
132+ cls ,
133+ seg_path : Path ,
134+ dtype : npt .DTypeLike | None = None ,
135+ orient : bool = True ,
136+ lut_path : Path | None = None ,
124137 ) -> "Segmentation" :
125138 """Loads a Segmentation from a NIfTI file.
126139
@@ -136,19 +149,29 @@ def from_file(
136149 logger .info (f"Loading segmentation from { seg_path } ." )
137150 mri = MRIData .from_file (seg_path , dtype = dtype , orient = orient )
138151
139- if lut_path is None and seg_path .with_suffix (".json" ).exists ():
140- lut_path = seg_path .with_suffix (".json" )
152+ if lut_path is None :
153+ if seg_path .with_suffix (".csv" ).exists ():
154+ lut_path = seg_path .with_suffix (".csv" )
155+ lut = pd .read_csv (lut_path )
156+ elif seg_path .with_suffix (".json" ).exists ():
157+ lut_path = seg_path .with_suffix (".json" )
158+ lut = pd .read_json (lut_path )
141159
142160 if lut_path is not None :
143161 logger .info (f"Loading LUT from { lut_path } ." )
144- lut = pd .read_json (lut_path )
145162 else :
146- rois = np .unique (mri .data [mri .data > 0 ])
147- lut = pd .DataFrame ({"Label" : rois }, index = rois )
163+ lut = None
148164
149165 return cls (mri = mri , lut = lut )
150166
151- def save (self , output_path : Path , dtype : npt .DTypeLike | None = None , intent_code : int = 1006 , lut_path : Path | None = None ):
167+ def save (
168+ self ,
169+ output_path : Path ,
170+ dtype : npt .DTypeLike | None = None ,
171+ intent_code : int = 1006 ,
172+ lut_path : Path | None = None ,
173+ lut_suffix = ".csv" ,
174+ ):
152175 """Saves the Segmentation to a NIfTI file.
153176
154177 Args:
@@ -157,25 +180,35 @@ def save(self, output_path: Path, dtype: npt.DTypeLike | None = None, intent_cod
157180 intent_code (int, optional): The NIfTI intent code to set in the header. Defaults to 1006 (NIFTI_INTENT_LABEL).
158181 """
159182 self .mri .save (output_path , dtype = dtype , intent_code = intent_code )
183+
160184 if lut_path is not None :
161- self . lut . to_json (lut_path , orient = "index" )
185+ write_lut (lut_path , self . lut )
162186 else :
163- self .lut .to_json (output_path .with_suffix (".json" ), orient = "index" )
187+ filename = output_path .name .removesuffix ("" .join (output_path .suffixes ))
188+ write_lut (output_path .parent .joinpath (filename ).with_suffix (lut_suffix ), self .lut )
164189
165- def set_lut (self , lut : pd .DataFrame , label_column : str = "Label" ):
190+ def set_lut (self , lut : pd .DataFrame , label_column : str | None = None ):
166191 """Sets the Lookup Table (LUT) for the segmentation, ensuring it matches the present ROIs.
167192
168193 Args:
169194 lut (pd.DataFrame): A pandas DataFrame mapping numerical labels
170195 to their descriptions. If None, a default numerical mapping is generated. Defaults to None.
171196 label_column (str, optional): The name of the column in the LUT that contains the label
172- descriptions. Defaults to "Label".
197+ descriptions which will be used as the index. If None, use the current index. Defaults to None.
198+ If the index is not already named, it is renamed to "label".
173199 """
174200
175201 self .lut = lut
176- self .label_name = label_column
177- if self .label_name not in self .lut .columns :
178- raise ValueError (f"Specified label column '{ self .label_name } ' not found in LUT." )
202+
203+ if label_column is not None :
204+ self .lut = lut .set_index (label_column )
205+ self .label_name = label_column
206+ else :
207+ if lut .index .name is not None : # If lut index already is named, use it
208+ self .label_name = lut .index .name
209+ else : # Use label as default name for axis
210+ self .label_name = "label"
211+ self .lut = lut .rename_axis (self .label_name )
179212
180213 @property
181214 def num_rois (self ) -> int :
@@ -209,8 +242,7 @@ def get_roi_labels(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataFr
209242
210243 if not np .isin (rois , self .rois ).all ():
211244 raise ValueError ("Some of the provided ROIs are not present in the segmentation." )
212-
213- return self .lut .loc [self .lut .index .isin (rois ), [self .label_name ]].rename_axis ("ROI" ).reset_index ()
245+ return self .lut .loc [rois .astype (self .lut .index .dtype )]
214246
215247 def resample_to_reference (self , reference_mri : MRIData ) -> "Segmentation" :
216248 """
@@ -292,7 +324,11 @@ class FreeSurferSegmentation(Segmentation):
292324
293325 @classmethod
294326 def from_file (
295- cls , filepath : Path | str , dtype : npt .DTypeLike | None = None , orient : bool = True , lut_path : Path | None = None
327+ cls ,
328+ filepath : Path | str ,
329+ dtype : npt .DTypeLike | None = None ,
330+ orient : bool = True ,
331+ lut_path : Path | None = None ,
296332 ) -> "FreeSurferSegmentation" :
297333 """
298334 Load a FreeSurfer segmentation from a NIfTI file, automatically resolving the LUT.
@@ -309,13 +345,13 @@ def from_file(
309345 """
310346 resolved_lut_path = resolve_freesurfer_lut_path (lut_path )
311347 lut = read_freesurfer_lut (resolved_lut_path )
312-
313- # FreeSurfer LUTs index by the "label" column
314- lut = lut .set_index ("label" ) if "label" in lut .columns else lut
315-
316348 mri = MRIData .from_file (filepath , dtype = dtype , orient = orient )
317349 return cls (mri = mri , lut = lut )
318350
351+ def _preprocess_lut (self ) -> pd .DataFrame :
352+ # FreeSurfer LUTs index by the "label" column
353+ self .lut = self .lut .query ("label < 10000" ) # Most used FreeSurfer labels
354+
319355
320356class ExtendedFreeSurferSegmentation (FreeSurferSegmentation ):
321357 """
@@ -326,6 +362,22 @@ class ExtendedFreeSurferSegmentation(FreeSurferSegmentation):
326362 the base FreeSurfer anatomical label (modulus 10000).
327363 """
328364
365+ def _preprocess_lut (self ) -> pd .DataFrame :
366+ super ()._preprocess_lut ()
367+
368+ # Add CSF and dura tags
369+ base_lut = self .lut .copy ()
370+ for i , tissue_type in enumerate (["CSF" , "Dura" ]):
371+ tissue_lut = base_lut .copy ()
372+ tissue_lut .index += 10000 if tissue_type == "CSF" else 20000
373+ tissue_lut ["description" ] = tissue_lut ["description" ] + f"-{ tissue_type } "
374+ if np .all (np .isin (["R" , "G" , "B" ], base_lut .columns )):
375+ for col in ["R" , "G" , "B" ]:
376+ tissue_lut [col ] = np .clip (
377+ tissue_lut [col ] * (0.5 + 0.5 * i ), 0 , 1
378+ ) # Shift colors towards blue for CSF and red for Dura
379+ self .lut = pd .concat ([self .lut , tissue_lut ])
380+
329381 def get_roi_labels (self , rois : npt .NDArray [np .int32 ] | None = None ) -> pd .DataFrame :
330382 """
331383 Retrieves descriptive mappings including the augmented tissue type classifications.
@@ -338,21 +390,12 @@ def get_roi_labels(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataFr
338390 pd.DataFrame: A DataFrame mapping the requested ROIs to their base descriptions
339391 and their computed 'tissue_type'.
340392 """
341- rois = self .rois if rois is None else rois
342393
343- # Use modulus 10000 to extract the base anatomical label from the superclass LUT
344- freesurfer_labels = super ().get_roi_labels (rois % 10000 ).rename (columns = {"ROI" : "FreeSurfer_ROI" })
394+ roi_labels = super ().get_roi_labels (rois )
345395
346- # Get the broad tissue categories based on the numerical offsets
396+ # Add column specifying tissue_type:
347397 tissue_type = self .get_tissue_type (rois )
348-
349- # Merge the base anatomical names with the tissue types
350- return freesurfer_labels .merge (
351- tissue_type ,
352- left_on = "FreeSurfer_ROI" ,
353- right_on = "FreeSurfer_ROI" ,
354- how = "outer" ,
355- ).drop (columns = ["FreeSurfer_ROI" ])[["ROI" , self .label_name , "tissue_type" ]]
398+ return pd .merge (roi_labels , tissue_type , on = "label" )
356399
357400 def get_tissue_type (self , rois : npt .NDArray [np .int32 ] | None = None ) -> pd .DataFrame :
358401 """
@@ -372,15 +415,14 @@ def get_tissue_type(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataF
372415 """
373416 rois = self .rois if rois is None else rois
374417
375- tissue_types = pd .Series (
376- data = np .where (rois < 10000 , "Parenchyma" , np .where (rois < 20000 , "CSF" , "Dura" )),
377- index = rois ,
378- name = "tissue_type" ,
379- )
418+ tissue_types = pd .DataFrame (
419+ {
420+ self .label_name : rois ,
421+ "tissue_type" : np .where (rois < 10000 , "Parenchyma" , np .where (rois < 20000 , "CSF" , "Dura" )),
422+ }
423+ ).set_index (self .label_name )
380424
381- ret = pd .DataFrame (tissue_types , columns = ["tissue_type" ]).rename_axis ("ROI" ).reset_index ()
382- ret ["FreeSurfer_ROI" ] = ret ["ROI" ] % 10000
383- return ret
425+ return tissue_types
384426
385427
386428@dataclass
@@ -574,15 +616,63 @@ def write_lut(filename: Path, table: pd.DataFrame):
574616 """
575617 newtable = table .copy ()
576618
577- # Re-scale RGB values to [0, 255]
578- for col in ["R" , "G" , "B" ]:
579- newtable [col ] = (newtable [col ] * 255 ).astype (int )
619+ if np .all (np .isin (["R" , "G" , "B" ], table .columns )):
620+ # Re-scale RGB values to [0, 255]
621+ for col in ["R" , "G" , "B" ]:
622+ newtable [col ] = (newtable [col ] * 255 ).astype (int )
580623
581- # Reverse Alpha inversion and scale to [0, 255]
582- newtable ["A" ] = 255 - (newtable ["A" ] * 255 ).astype (int )
624+ # Reverse Alpha inversion and scale to [0, 255]
625+ newtable ["A" ] = 255 - (newtable ["A" ] * 255 ).astype (int )
583626
584627 # Save as tab-separated values without headers or indices
585- newtable .to_csv (filename , sep = "\t " , index = False , header = False )
628+ if filename .suffix == ".csv" :
629+ newtable .to_csv (filename , sep = "\t " , index = True , header = False )
630+ elif filename .suffix == ".json" :
631+ newtable .to_json (filename , index = True , header = False )
632+ else :
633+ newtable .to_csv (filename , sep = "\t " , index = True , header = False )
634+
635+
636+ def procedural_freesurfer_lut (labels : list , descriptions : list , cmap : str | None = None ) -> pd .DataFrame :
637+ """
638+ Generates a FreeSurfer compatible lut with colors for each label in a procedural manner
639+
640+ Args:
641+ labels (list): list of labels to include in the lut
642+ descriptions (list): list of descriptions associated to each label
643+ cmap (str, optional): Colormap for label regions. Defaults to "hsv".
644+
645+ Returns:
646+ pd.DataFrame: DataFrame indexed by the label, with RGBA columns
647+ """
648+ N = len (labels )
649+ if not N == len (descriptions ):
650+ raise ValueError ("Label and descriptions lists must have same length" )
651+
652+ if cmap is not None : # If a colormap is specified, use cmap from matplotlib
653+ import matplotlib .pyplot as plt
654+
655+ # Get evenly spaced values between 0 and 1 based on the number of labels
656+ color_indices = np .linspace (0 , 0.95 , N )
657+ # Sample a colormap
658+ rgb_float = plt .get_cmap (cmap )(color_indices )
659+ else :
660+ rgb_float = []
661+ import colorsys
662+
663+ for i in range (N ):
664+ h = i / N
665+ rgb = list (colorsys .hsv_to_rgb (h , 1.0 , 1.0 ))
666+ rgb .append (1.0 ) # Add transparency
667+ rgb_float .append (rgb )
668+ rgb_float = np .array (rgb_float )
669+
670+ # Create the DataFrame
671+ df_colors = pd .DataFrame (rgb_float , columns = ["R" , "G" , "B" , "A" ], index = labels )
672+ df_colors .index .name = "label"
673+ df_colors ["description" ] = descriptions
674+ lut = df_colors [["description" , "R" , "G" , "B" , "A" ]]
675+ return lut
586676
587677
588678def add_arguments (
@@ -592,7 +682,9 @@ def add_arguments(
592682 subparser = parser .add_subparsers (dest = "seg-command" , help = "Commands for segmentation processing" )
593683
594684 resample_parser = subparser .add_parser (
595- "resample" , help = "Resample a segmentation to match the space of a reference MRI" , formatter_class = parser .formatter_class
685+ "resample" ,
686+ help = "Resample a segmentation to match the space of a reference MRI" ,
687+ formatter_class = parser .formatter_class ,
596688 )
597689 resample_parser .add_argument ("-i" , "--input" , type = Path , help = "Path to the input segmentation NIfTI file" )
598690 resample_parser .add_argument (
@@ -602,19 +694,43 @@ def add_arguments(
602694 help = "Path to the reference MRI \
603695 - usually a registered T1 weighted anatomical scan" ,
604696 )
605- resample_parser .add_argument ("-o" , "--output" , type = Path , help = "Desired output path for the resampled segmentation" )
697+ resample_parser .add_argument (
698+ "-o" ,
699+ "--output" ,
700+ type = Path ,
701+ help = "Desired output path for the resampled segmentation" ,
702+ )
606703
607704 smooth_parser = subparser .add_parser (
608705 "smooth" ,
609706 help = "Apply Gaussian smoothing to a segmentation to create a soft probabilistic map" ,
610707 formatter_class = parser .formatter_class ,
611708 )
612- smooth_parser .add_argument ("-i" , "--input" , type = Path , help = "Path to the input (refined) segmentation NIfTI file" )
613- smooth_parser .add_argument ("-s" , "--sigma" , type = float , help = "Standard deviation for the Gaussian kernel used in smoothing" )
614709 smooth_parser .add_argument (
615- "-c" , "--cutoff" , type = float , default = 0.5 , help = "Cutoff score to remove low-confidence voxels (default: 0.5)"
710+ "-i" ,
711+ "--input" ,
712+ type = Path ,
713+ help = "Path to the input (refined) segmentation NIfTI file" ,
714+ )
715+ smooth_parser .add_argument (
716+ "-s" ,
717+ "--sigma" ,
718+ type = float ,
719+ help = "Standard deviation for the Gaussian kernel used in smoothing" ,
720+ )
721+ smooth_parser .add_argument (
722+ "-c" ,
723+ "--cutoff" ,
724+ type = float ,
725+ default = 0.5 ,
726+ help = "Cutoff score to remove low-confidence voxels (default: 0.5)" ,
727+ )
728+ smooth_parser .add_argument (
729+ "-o" ,
730+ "--output" ,
731+ type = Path ,
732+ help = "Desired output path for the smoothed segmentation" ,
616733 )
617- smooth_parser .add_argument ("-o" , "--output" , type = Path , help = "Desired output path for the smoothed segmentation" )
618734
619735 refine_parser = subparser .add_parser (
620736 "refine" ,
@@ -629,8 +745,18 @@ def add_arguments(
629745 help = "Path to the reference MRI \
630746 - usually a registered T1 weighted anatomical scan" ,
631747 )
632- refine_parser .add_argument ("-s" , "--smooth" , type = float , help = "Standard deviation for the Gaussian kernel used in smoothing" )
633- refine_parser .add_argument ("-o" , "--output" , type = Path , help = "Desired output path for the refined segmentation" )
748+ refine_parser .add_argument (
749+ "-s" ,
750+ "--smooth" ,
751+ type = float ,
752+ help = "Standard deviation for the Gaussian kernel used in smoothing" ,
753+ )
754+ refine_parser .add_argument (
755+ "-o" ,
756+ "--output" ,
757+ type = Path ,
758+ help = "Desired output path for the refined segmentation" ,
759+ )
634760
635761 if extra_args_cb is not None :
636762 extra_args_cb (resample_parser )
0 commit comments