Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions examples/propagate_noise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import mritk


def main(c_target, SNR=25):
    """Propagate measurement noise through the Look-Locker / Mixed / hybrid T1 pipelines.

    Simulates N noisy T1 estimates for a single true tracer concentration,
    maps the estimates back to concentrations, and saves a 3x2 figure of
    histograms comparing the Mixed, Look-Locker and hybrid estimators.

    Args:
        c_target: True tracer concentration (mmol/L) used to generate the data.
        SNR: Signal-to-noise ratio of the simulated acquisitions (default 25).
    """
    # Physical parameters
    r1 = 3.2  # Longitudinal relaxivity (s^-1 L mmol^-1)
    T10 = 4.5  # Native T1 time of CSF (seconds)
    sigma = 1.0 / SNR  # Theoretical max signals approach 1.0 (based on M0=1.0)
    M0 = 1.0

    # Sequence parameters (TR, TI and Look-Locker sampling taken from the Gonzo paper)
    TR = 9.6
    TI = 2.65
    t_LL = np.linspace(0.115, 2.754, 14)  # Look-Locker: 14 data points over ~2.75 s

    N = 5000  # Number of simulated voxels

    T1_target = mritk.concentration.T1_from_concentration_expr(c=c_target, t1_0=T10, r1=r1)
    np.random.seed(42)  # For reproducibility
    c_true_array = np.full(N, c_target)
    T1_true = mritk.concentration.T1_from_concentration_expr(c=c_true_array, t1_0=T10, r1=r1)

    # Generate noisy Look-Locker T1 estimates.
    T1_est_LL = mritk.looklocker.T1_to_noisy_T1_looklocker(
        T1_true,
        t_LL=t_LL,
        M0=M0,
        sigma=sigma,
    )
    T1_est_LL /= 1000.0  # Convert ms to seconds for consistency

    # Generate noisy Mixed-sequence T1 estimates via a signal-ratio lookup grid.
    T1_range = np.linspace(0.1, 10.0, 5000)
    S_SE_range, S_IR_range = mritk.mixed.T1_to_mixed_signals(T1_range, TR=TR, TI=TI)
    ratio_range = S_IR_range / S_SE_range
    T1_est_mixed = mritk.mixed.T1_to_noisy_T1_mixed(
        T1_true,
        TR=TR,
        TI=TI,
        f_grid=ratio_range,
        t_grid=T1_range,
        sigma=sigma,
    )

    # Combine the estimators; threshold is in seconds (the LL estimates were rescaled above).
    T1_est_hybrid = mritk.hybrid.compute_hybrid_t1_array(
        ll_data=T1_est_LL,
        mixed_data=T1_est_mixed,
        mask=None,
        threshold=1.5,
    )

    # Drop failed (NaN) fits before plotting.
    T1_mixed = T1_est_mixed[~np.isnan(T1_est_mixed)]
    T1_LL = T1_est_LL[~np.isnan(T1_est_LL)]
    T1_hybrid = T1_est_hybrid[~np.isnan(T1_est_hybrid)]

    c_mixed = mritk.concentration.concentration_from_T1_expr(T1_mixed, t1_0=T10, r1=r1)
    c_LL = mritk.concentration.concentration_from_T1_expr(T1_LL, t1_0=T10, r1=r1)
    c_hybrid = mritk.concentration.concentration_from_T1_expr(T1_hybrid, t1_0=T10, r1=r1)

    # One row per estimator: T1 distributions (left column), concentrations (right column).
    fig, axes = plt.subplots(3, 2, figsize=(10, 10))

    _plot_estimator_row(
        axes[0], T1_mixed, c_mixed, T1_target, c_target,
        hist_color="orange", kde_color="darkorange", dist_label="Mixed Distribution",
        t1_title="Mixed Sequence T1 Estimation",
        c_title="Mixed Sequence Concentration Estimation",
    )
    _plot_estimator_row(
        axes[1], T1_LL, c_LL, T1_target, c_target,
        hist_color="blue", kde_color="darkblue", dist_label="Look-Locker Distribution",
        t1_title="Look-Locker Sequence T1 Estimation",
        c_title="Look-Locker Sequence Concentration Estimation",
    )
    _plot_estimator_row(
        axes[2], T1_hybrid, c_hybrid, T1_target, c_target,
        hist_color="purple", kde_color="indigo", dist_label="Hybrid Distribution",
        t1_title="Final Hybrid Pipeline T1 Estimation",
        c_title="Final Hybrid Pipeline Concentration Estimation",
    )
    # Axis labels only on the bottom row (columns share their x meaning).
    axes[2, 0].set_xlabel("T1 Relaxation Time (seconds)")
    axes[2, 1].set_xlabel("Concentration (mmol/L)")

    for ax in axes.flatten():
        ax.set_ylabel("Density")

    fig.tight_layout()
    fig.savefig(f"hybrid_pipeline_histograms_c{c_target}_{SNR}_{TR}_{TI}.png", dpi=300)


def _plot_estimator_row(ax_row, t1_vals, c_vals, t1_target, c_target, *, hist_color, kde_color, dist_label, t1_title, c_title):
    """Draw one estimator's panels: T1 distribution (left axis) and concentration distribution (right axis).

    Each panel shows a density histogram, a KDE overlay, and a vertical line at the true value.
    """
    ax_t1, ax_c = ax_row

    sns.histplot(t1_vals, bins=60, stat="density", color=hist_color, alpha=0.5, ax=ax_t1, label=dist_label)
    sns.kdeplot(t1_vals, color=kde_color, linestyle="-", ax=ax_t1)
    ax_t1.axvline(t1_target, color="red", linestyle="solid", linewidth=2, label=f"True T1 = {t1_target:.2f} s")
    ax_t1.set_title(t1_title)
    ax_t1.legend()

    sns.histplot(c_vals, bins=60, stat="density", color=hist_color, alpha=0.5, ax=ax_c, label=dist_label)
    sns.kdeplot(c_vals, color=kde_color, linestyle="-", ax=ax_c)
    ax_c.axvline(c_target, color="red", linestyle="solid", linewidth=2, label=f"True c = {c_target}")
    ax_c.set_title(c_title)
    ax_c.legend()


if __name__ == "__main__":
    # Single representative run; loop over c_target / SNR values here for a parameter sweep.
    main(c_target=0.05, SNR=25)
2 changes: 2 additions & 0 deletions src/mritk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
statistics,
utils,
)
from .data import MRIData

meta = metadata("mritk")
__version__ = meta["Version"]
Expand All @@ -35,4 +36,5 @@
"hybrid",
"r1",
"statistics",
"MRIData",
]
43 changes: 41 additions & 2 deletions src/mritk/concentration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import argparse
import logging
import typing
from collections.abc import Callable
from pathlib import Path

Expand All @@ -16,8 +17,10 @@

logger = logging.getLogger(__name__)

T = typing.TypeVar("T", np.ndarray, float)

def concentration_from_T1_expr(t1: np.ndarray, t1_0: np.ndarray, r1: float) -> np.ndarray:

def concentration_from_T1_expr(t1: T, t1_0: T, r1: float) -> T:
"""
Computes tracer concentration from T1 relaxation times.

Expand All @@ -34,7 +37,7 @@ def concentration_from_T1_expr(t1: np.ndarray, t1_0: np.ndarray, r1: float) -> n
return (1.0 / r1) * ((1.0 / t1) - (1.0 / t1_0))


def concentration_from_R1_expr(r1_map: np.ndarray, r1_0_map: np.ndarray, r1: float) -> np.ndarray:
def concentration_from_R1_expr(r1_map: T, r1_0_map: T, r1: float) -> T:
"""
Computes tracer concentration from R1 relaxation rates.

Expand All @@ -51,6 +54,42 @@ def concentration_from_R1_expr(r1_map: np.ndarray, r1_0_map: np.ndarray, r1: flo
return (1.0 / r1) * (r1_map - r1_0_map)


def R1_from_concentration_expr(c: T, r1_0: T, r1: float) -> T:
    """
    Computes post-contrast R1 relaxation rates from tracer concentration.

    Inverse of ``concentration_from_R1_expr``, using the linear relaxivity
    model: R1 = C * r1 + R1_0.

    Args:
        c (np.ndarray | float): Tracer concentration(s).
        r1_0 (np.ndarray | float): Pre-contrast (baseline) R1 relaxation rate(s).
        r1 (float): Relaxivity of the contrast agent.

    Returns:
        np.ndarray | float: Post-contrast R1 value(s), matching the input type.
    """
    # Concentration contributes linearly on top of the baseline rate.
    contrast_rate = r1 * c
    return contrast_rate + r1_0


def T1_from_concentration_expr(c: T, t1_0: T, r1: float) -> T:
    """
    Computes post-contrast T1 relaxation times from tracer concentration.

    Formula: T1 = 1 / (C * r1 + (1 / T1_0))

    Args:
        c (np.ndarray | float): Tracer concentration(s).
        t1_0 (np.ndarray | float): Pre-contrast (baseline) T1 relaxation time(s).
        r1 (float): Relaxivity of the contrast agent.

    Returns:
        np.ndarray | float: Post-contrast T1 value(s), matching the input type.
    """
    # The post-contrast rate R1 = C*r1 + 1/T1_0; T1 is its reciprocal.
    # NOTE(review): t1_0 == 0 divides by zero here — callers should mask such voxels first.
    baseline_rate = 1.0 / t1_0
    return 1.0 / (c * r1 + baseline_rate)


def compute_concentration_from_T1_array(
t1_data: np.ndarray, t10_data: np.ndarray, r1: float, mask: np.ndarray | None = None
) -> np.ndarray:
Expand Down
38 changes: 37 additions & 1 deletion src/mritk/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


from pathlib import Path
from typing import Optional
from typing import Callable, Optional

import nibabel
import numpy as np
Expand Down Expand Up @@ -178,6 +178,42 @@ def find_nearest_valid_voxels(query_indices: np.ndarray, mask: np.ndarray, k: in
return dof_neighbours


def aggregate_nearest_valid_voxels(
    data: np.ndarray, query_indices: np.ndarray, mask: np.ndarray | None = None, k: int = 10, agg_func: Callable = np.nanmedian
) -> np.ndarray:
    """Finds the `k` nearest valid voxels for each query index and aggregates their values.

    For every query voxel, looks up the `k` nearest mask-valid voxels via
    ``find_nearest_valid_voxels`` and reduces their data values with `agg_func`.

    Args:
        data: (X, Y, Z) array of data values to sample from.
        query_indices: (N, 3) array of voxel indices to find neighbors for.
        mask: (X, Y, Z) boolean mask array indicating which voxels are valid neighbors.
            If None, all finite (non-NaN, non-inf) voxels of `data` are treated as valid.
        k: Number of nearest neighbors to find.
        agg_func: Aggregation function to apply over the `k` neighbors
            (default: ``np.nanmedian``, so residual NaNs among neighbors are ignored).
            Must accept an `axis` keyword argument.

    Returns:
        (N,) array of aggregated values for each query point.
    """

    if mask is None:
        # If no mask is provided, consider all finite (non-NaN, non-inf) voxels as valid.
        mask = np.isfinite(data)
    # 1. Get the spatial indices of the nearest valid voxels.
    # Returned shape is (3, k, N) where 3 corresponds to the x, y, z dimensions.
    # (Shape contract assumed from find_nearest_valid_voxels — confirm against its definition.)
    nearest_coords = find_nearest_valid_voxels(query_indices, mask, k)

    # 2. Extract the data values using advanced numpy indexing.
    # nearest_coords[0], nearest_coords[1], and nearest_coords[2] each have shape (k, N).
    # The resulting extracted_values array will also have shape (k, N).
    extracted_values = data[nearest_coords[0], nearest_coords[1], nearest_coords[2]]

    # 3. Aggregate the extracted values across the neighbor dimension (axis=0).
    # The output will be collapsed to shape (N,).
    aggregated_values = agg_func(extracted_values, axis=0)

    return aggregated_values


def apply_affine(T: np.ndarray, X: np.ndarray) -> np.ndarray:
"""Apply a homogeneous affine transformation matrix to a set of points.

Expand Down
2 changes: 1 addition & 1 deletion src/mritk/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def get_datasets() -> dict[str, Dataset]:
name="Test Data",
description="A small test dataset for testing functionality (based on the Gonzo dataset).",
license="CC-BY-4.0",
links={"mritk-test-data.zip": download_link_google_drive("1YVXoV1UhmpkMIeaNKeS9eqCsdMULwKBO")},
links={"mritk-test-data.zip": download_link_google_drive("1IYbomfJ38REUstbCdiqc3W39RaHrZfje")},
),
"gonzo": Dataset(
name="The Gonzo Dataset",
Expand Down
12 changes: 9 additions & 3 deletions src/mritk/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
logger = logging.getLogger(__name__)


def compute_hybrid_t1_array(ll_data: np.ndarray, mixed_data: np.ndarray, mask: np.ndarray, threshold: float) -> np.ndarray:
def compute_hybrid_t1_array(
ll_data: np.ndarray, mixed_data: np.ndarray, mask: np.ndarray | None = None, threshold: float = 1500
) -> np.ndarray:
"""
Creates a hybrid T1 array by selectively substituting Look-Locker voxels with Mixed voxels.

Expand All @@ -27,15 +29,19 @@ def compute_hybrid_t1_array(ll_data: np.ndarray, mixed_data: np.ndarray, mask: n
ll_data (np.ndarray): 3D numpy array of Look-Locker T1 values.
mixed_data (np.ndarray): 3D numpy array of Mixed T1 values.
mask (np.ndarray): 3D boolean mask (typically eroded CSF).
threshold (float): T1 threshold value (in ms).
threshold (float): T1 threshold value (in ms), default is 1500 ms.

Returns:
np.ndarray: Hybrid 3D T1 array.
"""
logger.debug("Computing hybrid T1 array with threshold %.2f ms.", threshold)
hybrid = ll_data.copy()
newmask = mask & (ll_data > threshold) & (mixed_data > threshold)
newmask = (ll_data > threshold) & (mixed_data > threshold)
if mask is not None:
newmask &= mask
logger.debug("Substituting %d voxels based on threshold and mask criteria.", np.sum(newmask))
hybrid[newmask] = mixed_data[newmask]

return hybrid


Expand Down
Loading
Loading