Skip to content

Commit 7873271

Browse files
authored
Merge pull request #51 from scientificcomputing/segmentation-lut
Improve handling of segmentation LUTs
2 parents 491f1a0 + 594a94c commit 7873271

4 files changed

Lines changed: 248 additions & 102 deletions

File tree

src/mritk/segmentation.py

Lines changed: 186 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ class Segmentation:
9898
mri (MRIData): The MRIData object containing the segmentation volume and affine.
9999
lut (Optional[pd.DataFrame], optional): A pandas DataFrame mapping numerical labels
100100
to their descriptions. If None, a default numerical mapping is generated. Defaults to None.
101+
Assumes that entries are indexed by the "label" column. If there is no "label" column
102+
the current index is renamed to "label"
101103
"""
102104

103105
mri: MRIData
@@ -111,16 +113,27 @@ def __init__(self, mri: MRIData, lut: pd.DataFrame | None = None):
111113
self.rois = np.unique(self.mri.data[self.mri.data > 0])
112114

113115
if lut is None:
114-
self.lut = pd.DataFrame({"Label": self.rois}, index=self.rois)
115-
else:
116-
self.lut = lut
116+
lut = pd.DataFrame(
117+
{
118+
"label": self.rois.astype(int),
119+
"description": self.rois.astype(int).astype(str),
120+
}
121+
).set_index("label")
122+
123+
self.set_lut(lut, label_column="label" if "label" in lut.columns else None)
124+
self._preprocess_lut()
117125

118-
# Identify the primary label column dynamically
119-
self.label_name = "Label" if "Label" in self.lut.columns else self.lut.columns[0]
126+
def _preprocess_lut(self) -> pd.DataFrame:
127+
# dummy function for subclasses to override if they need to preprocess the LUT after loading
128+
pass
120129

121130
@classmethod
122131
def from_file(
123-
cls, seg_path: Path, dtype: npt.DTypeLike | None = None, orient: bool = True, lut_path: Path | None = None
132+
cls,
133+
seg_path: Path,
134+
dtype: npt.DTypeLike | None = None,
135+
orient: bool = True,
136+
lut_path: Path | None = None,
124137
) -> "Segmentation":
125138
"""Loads a Segmentation from a NIfTI file.
126139
@@ -136,19 +149,29 @@ def from_file(
136149
logger.info(f"Loading segmentation from {seg_path}.")
137150
mri = MRIData.from_file(seg_path, dtype=dtype, orient=orient)
138151

139-
if lut_path is None and seg_path.with_suffix(".json").exists():
140-
lut_path = seg_path.with_suffix(".json")
152+
if lut_path is None:
153+
if seg_path.with_suffix(".csv").exists():
154+
lut_path = seg_path.with_suffix(".csv")
155+
lut = pd.read_csv(lut_path)
156+
elif seg_path.with_suffix(".json").exists():
157+
lut_path = seg_path.with_suffix(".json")
158+
lut = pd.read_json(lut_path)
141159

142160
if lut_path is not None:
143161
logger.info(f"Loading LUT from {lut_path}.")
144-
lut = pd.read_json(lut_path)
145162
else:
146-
rois = np.unique(mri.data[mri.data > 0])
147-
lut = pd.DataFrame({"Label": rois}, index=rois)
163+
lut = None
148164

149165
return cls(mri=mri, lut=lut)
150166

151-
def save(self, output_path: Path, dtype: npt.DTypeLike | None = None, intent_code: int = 1006, lut_path: Path | None = None):
167+
def save(
168+
self,
169+
output_path: Path,
170+
dtype: npt.DTypeLike | None = None,
171+
intent_code: int = 1006,
172+
lut_path: Path | None = None,
173+
lut_suffix=".csv",
174+
):
152175
"""Saves the Segmentation to a NIfTI file.
153176
154177
Args:
@@ -157,25 +180,35 @@ def save(self, output_path: Path, dtype: npt.DTypeLike | None = None, intent_cod
157180
intent_code (int, optional): The NIfTI intent code to set in the header. Defaults to 1006 (NIFTI_INTENT_LABEL).
158181
"""
159182
self.mri.save(output_path, dtype=dtype, intent_code=intent_code)
183+
160184
if lut_path is not None:
161-
self.lut.to_json(lut_path, orient="index")
185+
write_lut(lut_path, self.lut)
162186
else:
163-
self.lut.to_json(output_path.with_suffix(".json"), orient="index")
187+
filename = output_path.name.removesuffix("".join(output_path.suffixes))
188+
write_lut(output_path.parent.joinpath(filename).with_suffix(lut_suffix), self.lut)
164189

165-
def set_lut(self, lut: pd.DataFrame, label_column: str = "Label"):
190+
def set_lut(self, lut: pd.DataFrame, label_column: str | None = None):
166191
"""Sets the Lookup Table (LUT) for the segmentation, ensuring it matches the present ROIs.
167192
168193
Args:
169194
lut (pd.DataFrame): A pandas DataFrame mapping numerical labels
170195
to their descriptions. If None, a default numerical mapping is generated. Defaults to None.
171196
label_column (str, optional): The name of the column in the LUT that contains the label
172-
descriptions. Defaults to "Label".
197+
descriptions which will be used as the index. If None, use the current index. Defaults to None.
198+
If the index is not already named, it is renamed to "label".
173199
"""
174200

175201
self.lut = lut
176-
self.label_name = label_column
177-
if self.label_name not in self.lut.columns:
178-
raise ValueError(f"Specified label column '{self.label_name}' not found in LUT.")
202+
203+
if label_column is not None:
204+
self.lut = lut.set_index(label_column)
205+
self.label_name = label_column
206+
else:
207+
if lut.index.name is not None: # If lut index already is named, use it
208+
self.label_name = lut.index.name
209+
else: # Use label as default name for axis
210+
self.label_name = "label"
211+
self.lut = lut.rename_axis(self.label_name)
179212

180213
@property
181214
def num_rois(self) -> int:
@@ -209,8 +242,7 @@ def get_roi_labels(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataFr
209242

210243
if not np.isin(rois, self.rois).all():
211244
raise ValueError("Some of the provided ROIs are not present in the segmentation.")
212-
213-
return self.lut.loc[self.lut.index.isin(rois), [self.label_name]].rename_axis("ROI").reset_index()
245+
return self.lut.loc[rois.astype(self.lut.index.dtype)]
214246

215247
def resample_to_reference(self, reference_mri: MRIData) -> "Segmentation":
216248
"""
@@ -292,7 +324,11 @@ class FreeSurferSegmentation(Segmentation):
292324

293325
@classmethod
294326
def from_file(
295-
cls, filepath: Path | str, dtype: npt.DTypeLike | None = None, orient: bool = True, lut_path: Path | None = None
327+
cls,
328+
filepath: Path | str,
329+
dtype: npt.DTypeLike | None = None,
330+
orient: bool = True,
331+
lut_path: Path | None = None,
296332
) -> "FreeSurferSegmentation":
297333
"""
298334
Load a FreeSurfer segmentation from a NIfTI file, automatically resolving the LUT.
@@ -309,13 +345,13 @@ def from_file(
309345
"""
310346
resolved_lut_path = resolve_freesurfer_lut_path(lut_path)
311347
lut = read_freesurfer_lut(resolved_lut_path)
312-
313-
# FreeSurfer LUTs index by the "label" column
314-
lut = lut.set_index("label") if "label" in lut.columns else lut
315-
316348
mri = MRIData.from_file(filepath, dtype=dtype, orient=orient)
317349
return cls(mri=mri, lut=lut)
318350

351+
def _preprocess_lut(self) -> pd.DataFrame:
352+
# FreeSurfer LUTs index by the "label" column
353+
self.lut = self.lut.query("label < 10000") # Most used FreeSurfer labels
354+
319355

320356
class ExtendedFreeSurferSegmentation(FreeSurferSegmentation):
321357
"""
@@ -326,6 +362,22 @@ class ExtendedFreeSurferSegmentation(FreeSurferSegmentation):
326362
the base FreeSurfer anatomical label (modulus 10000).
327363
"""
328364

365+
def _preprocess_lut(self) -> pd.DataFrame:
366+
super()._preprocess_lut()
367+
368+
# Add CSF and dura tags
369+
base_lut = self.lut.copy()
370+
for i, tissue_type in enumerate(["CSF", "Dura"]):
371+
tissue_lut = base_lut.copy()
372+
tissue_lut.index += 10000 if tissue_type == "CSF" else 20000
373+
tissue_lut["description"] = tissue_lut["description"] + f"-{tissue_type}"
374+
if np.all(np.isin(["R", "G", "B"], base_lut.columns)):
375+
for col in ["R", "G", "B"]:
376+
tissue_lut[col] = np.clip(
377+
tissue_lut[col] * (0.5 + 0.5 * i), 0, 1
378+
) # Shift colors towards blue for CSF and red for Dura
379+
self.lut = pd.concat([self.lut, tissue_lut])
380+
329381
def get_roi_labels(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataFrame:
330382
"""
331383
Retrieves descriptive mappings including the augmented tissue type classifications.
@@ -338,21 +390,12 @@ def get_roi_labels(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataFr
338390
pd.DataFrame: A DataFrame mapping the requested ROIs to their base descriptions
339391
and their computed 'tissue_type'.
340392
"""
341-
rois = self.rois if rois is None else rois
342393

343-
# Use modulus 10000 to extract the base anatomical label from the superclass LUT
344-
freesurfer_labels = super().get_roi_labels(rois % 10000).rename(columns={"ROI": "FreeSurfer_ROI"})
394+
roi_labels = super().get_roi_labels(rois)
345395

346-
# Get the broad tissue categories based on the numerical offsets
396+
# Add column specifying tissue_type:
347397
tissue_type = self.get_tissue_type(rois)
348-
349-
# Merge the base anatomical names with the tissue types
350-
return freesurfer_labels.merge(
351-
tissue_type,
352-
left_on="FreeSurfer_ROI",
353-
right_on="FreeSurfer_ROI",
354-
how="outer",
355-
).drop(columns=["FreeSurfer_ROI"])[["ROI", self.label_name, "tissue_type"]]
398+
return pd.merge(roi_labels, tissue_type, on="label")
356399

357400
def get_tissue_type(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataFrame:
358401
"""
@@ -372,15 +415,14 @@ def get_tissue_type(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataF
372415
"""
373416
rois = self.rois if rois is None else rois
374417

375-
tissue_types = pd.Series(
376-
data=np.where(rois < 10000, "Parenchyma", np.where(rois < 20000, "CSF", "Dura")),
377-
index=rois,
378-
name="tissue_type",
379-
)
418+
tissue_types = pd.DataFrame(
419+
{
420+
self.label_name: rois,
421+
"tissue_type": np.where(rois < 10000, "Parenchyma", np.where(rois < 20000, "CSF", "Dura")),
422+
}
423+
).set_index(self.label_name)
380424

381-
ret = pd.DataFrame(tissue_types, columns=["tissue_type"]).rename_axis("ROI").reset_index()
382-
ret["FreeSurfer_ROI"] = ret["ROI"] % 10000
383-
return ret
425+
return tissue_types
384426

385427

386428
@dataclass
@@ -574,15 +616,63 @@ def write_lut(filename: Path, table: pd.DataFrame):
574616
"""
575617
newtable = table.copy()
576618

577-
# Re-scale RGB values to [0, 255]
578-
for col in ["R", "G", "B"]:
579-
newtable[col] = (newtable[col] * 255).astype(int)
619+
if np.all(np.isin(["R", "G", "B"], table.columns)):
620+
# Re-scale RGB values to [0, 255]
621+
for col in ["R", "G", "B"]:
622+
newtable[col] = (newtable[col] * 255).astype(int)
580623

581-
# Reverse Alpha inversion and scale to [0, 255]
582-
newtable["A"] = 255 - (newtable["A"] * 255).astype(int)
624+
# Reverse Alpha inversion and scale to [0, 255]
625+
newtable["A"] = 255 - (newtable["A"] * 255).astype(int)
583626

584627
# Save as tab-separated values without headers or indices
585-
newtable.to_csv(filename, sep="\t", index=False, header=False)
628+
if filename.suffix == ".csv":
629+
newtable.to_csv(filename, sep="\t", index=True, header=False)
630+
elif filename.suffix == ".json":
631+
newtable.to_json(filename, index=True, header=False)
632+
else:
633+
newtable.to_csv(filename, sep="\t", index=True, header=False)
634+
635+
636+
def procedural_freesurfer_lut(labels: list, descriptions: list, cmap: str | None = None) -> pd.DataFrame:
637+
"""
638+
Generates a FreeSurfer compatible lut with colors for each label in a procedural manner
639+
640+
Args:
641+
labels (list): list of labels to include in the lut
642+
descriptions (list): list of descriptions associated to each label
643+
cmap (str, optional): Colormap for label regions. Defaults to "hsv".
644+
645+
Returns:
646+
pd.DataFrame: DataFrame indexed by the label, with RGBA columns
647+
"""
648+
N = len(labels)
649+
if not N == len(descriptions):
650+
raise ValueError("Label and descriptions lists must have same length")
651+
652+
if cmap is not None: # If a colormap is specified, use cmap from matplotlib
653+
import matplotlib.pyplot as plt
654+
655+
# Get evenly spaced values between 0 and 1 based on the number of labels
656+
color_indices = np.linspace(0, 0.95, N)
657+
# Sample a colormap
658+
rgb_float = plt.get_cmap(cmap)(color_indices)
659+
else:
660+
rgb_float = []
661+
import colorsys
662+
663+
for i in range(N):
664+
h = i / N
665+
rgb = list(colorsys.hsv_to_rgb(h, 1.0, 1.0))
666+
rgb.append(1.0) # Add transparency
667+
rgb_float.append(rgb)
668+
rgb_float = np.array(rgb_float)
669+
670+
# Create the DataFrame
671+
df_colors = pd.DataFrame(rgb_float, columns=["R", "G", "B", "A"], index=labels)
672+
df_colors.index.name = "label"
673+
df_colors["description"] = descriptions
674+
lut = df_colors[["description", "R", "G", "B", "A"]]
675+
return lut
586676

587677

588678
def add_arguments(
@@ -592,7 +682,9 @@ def add_arguments(
592682
subparser = parser.add_subparsers(dest="seg-command", help="Commands for segmentation processing")
593683

594684
resample_parser = subparser.add_parser(
595-
"resample", help="Resample a segmentation to match the space of a reference MRI", formatter_class=parser.formatter_class
685+
"resample",
686+
help="Resample a segmentation to match the space of a reference MRI",
687+
formatter_class=parser.formatter_class,
596688
)
597689
resample_parser.add_argument("-i", "--input", type=Path, help="Path to the input segmentation NIfTI file")
598690
resample_parser.add_argument(
@@ -602,19 +694,43 @@ def add_arguments(
602694
help="Path to the reference MRI \
603695
- usually a registered T1 weighted anatomical scan",
604696
)
605-
resample_parser.add_argument("-o", "--output", type=Path, help="Desired output path for the resampled segmentation")
697+
resample_parser.add_argument(
698+
"-o",
699+
"--output",
700+
type=Path,
701+
help="Desired output path for the resampled segmentation",
702+
)
606703

607704
smooth_parser = subparser.add_parser(
608705
"smooth",
609706
help="Apply Gaussian smoothing to a segmentation to create a soft probabilistic map",
610707
formatter_class=parser.formatter_class,
611708
)
612-
smooth_parser.add_argument("-i", "--input", type=Path, help="Path to the input (refined) segmentation NIfTI file")
613-
smooth_parser.add_argument("-s", "--sigma", type=float, help="Standard deviation for the Gaussian kernel used in smoothing")
614709
smooth_parser.add_argument(
615-
"-c", "--cutoff", type=float, default=0.5, help="Cutoff score to remove low-confidence voxels (default: 0.5)"
710+
"-i",
711+
"--input",
712+
type=Path,
713+
help="Path to the input (refined) segmentation NIfTI file",
714+
)
715+
smooth_parser.add_argument(
716+
"-s",
717+
"--sigma",
718+
type=float,
719+
help="Standard deviation for the Gaussian kernel used in smoothing",
720+
)
721+
smooth_parser.add_argument(
722+
"-c",
723+
"--cutoff",
724+
type=float,
725+
default=0.5,
726+
help="Cutoff score to remove low-confidence voxels (default: 0.5)",
727+
)
728+
smooth_parser.add_argument(
729+
"-o",
730+
"--output",
731+
type=Path,
732+
help="Desired output path for the smoothed segmentation",
616733
)
617-
smooth_parser.add_argument("-o", "--output", type=Path, help="Desired output path for the smoothed segmentation")
618734

619735
refine_parser = subparser.add_parser(
620736
"refine",
@@ -629,8 +745,18 @@ def add_arguments(
629745
help="Path to the reference MRI \
630746
- usually a registered T1 weighted anatomical scan",
631747
)
632-
refine_parser.add_argument("-s", "--smooth", type=float, help="Standard deviation for the Gaussian kernel used in smoothing")
633-
refine_parser.add_argument("-o", "--output", type=Path, help="Desired output path for the refined segmentation")
748+
refine_parser.add_argument(
749+
"-s",
750+
"--smooth",
751+
type=float,
752+
help="Standard deviation for the Gaussian kernel used in smoothing",
753+
)
754+
refine_parser.add_argument(
755+
"-o",
756+
"--output",
757+
type=Path,
758+
help="Desired output path for the refined segmentation",
759+
)
634760

635761
if extra_args_cb is not None:
636762
extra_args_cb(resample_parser)

0 commit comments

Comments
 (0)