From 3c38a6b19a19c32f398826e2811a125143da5910 Mon Sep 17 00:00:00 2001 From: Georg Schramm Date: Tue, 26 May 2026 14:56:55 +0200 Subject: [PATCH 01/17] add TOF variance reduction example --- .../03_algorithms/05_tof_vs_nontof.py | 269 ++++++++++++++++++ 1 file changed, 269 insertions(+) create mode 100644 docs/examples/03_algorithms/05_tof_vs_nontof.py diff --git a/docs/examples/03_algorithms/05_tof_vs_nontof.py b/docs/examples/03_algorithms/05_tof_vs_nontof.py new file mode 100644 index 00000000..376c74f0 --- /dev/null +++ b/docs/examples/03_algorithms/05_tof_vs_nontof.py @@ -0,0 +1,269 @@ +""" +TOF vs NON-TOF RECONSTRUCTIONS in a 2D uniform cylinder +======================================================= + +This example compares variance reduction due to the presence of TOF information. +""" + +# %% +from __future__ import annotations +import matplotlib.pyplot as plt +import numpy as np + +import parallelproj.operators +import parallelproj.tof +import parallelproj.pet_scanners +import parallelproj.pet_lors +import parallelproj.projectors +from parallelproj import to_numpy_array, Array +from parallelproj.functions import NegPoissonLogL, C2AffineObjective, C1Function + +from copy import copy + +# %% +from array_utils import suggest_array_backend_and_device + +# To use a specific backend and/or device, replace the None arguments, e.g.: +# xp, dev = suggest_array_backend_and_device(backend="numpy", dev="cpu") or by setting xp and dev manually +xp, dev = suggest_array_backend_and_device(None, None) + +# %% + +num_epochs = 700 +fwhm_tof_mm = 30.0 +sm_fwhm_mm = 9.0 +cylinder_radius = 140 +count_factor = 0.3 + +# %% +# Setup of the forward model :math:`\bar{y}(x) = A x + s` +# -------------------------------------------------------- +# +# We setup a linear forward operator :math:`A` consisting of an +# image-based resolution model, a non-TOF PET projector and an attenuation model +# +# .. note:: +# The OSEM implementation below works with all linear operators that +# subclass :class:`.LinearOperator` (e.g. the high-level projectors). + +num_rings = 1 +scanner = parallelproj.pet_scanners.RegularPolygonPETScannerGeometry( + xp, + dev, + radius=300.0, + num_sides=28, + num_lor_endpoints_per_side=16, + lor_spacing=4.0, + ring_positions=xp.asarray([0], dtype=xp.float32, device=dev), + symmetry_axis=2, +) + +# %% +# setup the LOR descriptor that defines the sinogram + +img_shape = (151, 151, 1) +voxel_size = (2.0, 2.0, 2.0) + +lor_desc = parallelproj.pet_lors.RegularPolygonPETLORDescriptor( + scanner, + parallelproj.pet_lors.Michelogram(scanner.num_rings, max_ring_difference=2, span=1), + radial_trim=150, + sinogram_order=parallelproj.pet_lors.SinogramSpatialAxisOrder.RVP, +) + +proj_non_tof = parallelproj.projectors.RegularPolygonPETProjector( + lor_desc, img_shape=img_shape, voxel_size=voxel_size +) + +# setup a uniform circle +x_pos = voxel_size[0] * ( + xp.arange(img_shape[0], device=dev, dtype=xp.float32) - img_shape[0] / 2 + 0.5 +) +X, Y = xp.meshgrid(x_pos, x_pos, indexing="ij") +RHO = xp.sqrt(X**2 + Y**2) + +x_true = xp.ones(img_shape, device=dev, dtype=xp.float32) +x_true[..., 0] = count_factor * (RHO <= cylinder_radius) + +# %% +# Attenuation image and sinogram setup +# ------------------------------------ + +# setup an attenuation image +x_att = 0.01 * xp.astype(x_true > 0, xp.float32) +# calculate the attenuation sinogram +att_sino = xp.exp(-proj_non_tof(x_att)) + +# %% +# Complete PET forward model setup +# -------------------------------- +# +# We combine an image-based resolution model, +# a non-TOF or TOF PET projector and an attenuation model +# into a single linear operator. + +proj_tof = copy(proj_non_tof) + +proj_tof.tof_parameters = parallelproj.tof.TOFParameters( + num_tofbins=int(300 / (fwhm_tof_mm / 5.0)) + 1, + tofbin_width=fwhm_tof_mm / 4.0, + sigma_tof=fwhm_tof_mm / 2.35, +) + +# For TOF, att_sino has no TOF-bins dimension while the projector output does. +# broadcast_to adds a trailing singleton via expand_dims and broadcasts it over +# the TOF-bins axis without copying data (zero-stride view). +att_values_tof = xp.broadcast_to(xp.expand_dims(att_sino, axis=-1), proj_tof.out_shape) +att_op_tof = parallelproj.operators.ElementwiseMultiplicationOperator(att_values_tof) +att_op_non_tof = parallelproj.operators.ElementwiseMultiplicationOperator(att_sino) + +res_model = parallelproj.operators.GaussianFilterOperator( + img_shape, sigma=[4.0 / (2.35 * float(vs)) for vs in proj_tof.voxel_size] +) + +# compose all 3 operators into a single linear operator +pet_lin_op_tof = parallelproj.operators.CompositeLinearOperator( + (att_op_tof, proj_tof, res_model) +) + +# setup non-TOF fwd model +pet_lin_op_non_tof = parallelproj.operators.CompositeLinearOperator( + (att_op_non_tof, proj_non_tof, res_model) +) + +# %% +# Simulation of projection data +# ----------------------------- +# +# We setup an arbitrary ground truth :math:`x_{true}` and simulate +# noise-free and noisy data :math:`y` by adding Poisson noise. + +# simulated noise-free data +noise_free_data_tof = pet_lin_op_tof(x_true) + +# generate a contant contamination sinogram +contamination_tof = xp.full( + noise_free_data_tof.shape, + 0.5 * float(xp.mean(noise_free_data_tof)), + device=dev, + dtype=xp.float32, +) + +noise_free_data_tof += contamination_tof + +# add Poisson noise +np.random.seed(1) +y_tof = xp.asarray( + np.random.poisson(to_numpy_array(noise_free_data_tof)), + device=dev, + dtype=xp.float32, +) + +y_non_tof = xp.sum(y_tof, axis=-1) +contamination_non_tof = xp.sum(contamination_tof, axis=-1) + +# %% +# EM update +# --------- +# +# The EM update used in MLEM and OSEM is :cite:p:`Dempster1977` +# :cite:p:`Shepp1982` :cite:p:`Lange1984` :cite:p:`Hudson1994` +# +# .. math:: +# x^+ = \frac{x}{A^H 1} A^H \frac{y}{A x + s} +# +# which can be rewritten as a preconditioned gradient descent step with +# diagonal preconditioner :math:`D = \operatorname{diag}(x / (A^H 1))`: +# +# .. math:: +# x^+ = x - D \, \nabla_x f(x). +# +# We implement this as a single function used by both MLEM and OSEM. + + +def em_update( + x_cur: Array, + negpoissonlogl: C1Function, + adj_ones: Array, +) -> Array: + """EM update re-written as preconditioned GD step""" + em_diag_precond = x_cur / adj_ones + return x_cur - em_diag_precond * negpoissonlogl.gradient(x_cur) + + +# %% +# NON-TOF EM reconstruction +# ------------------------- + +sm_op = parallelproj.operators.GaussianFilterOperator( + in_shape=img_shape, sigma=sm_fwhm_mm / (2.35 * voxel_size[0]) +) + +full_data_fidelity_non_tof = C2AffineObjective( + NegPoissonLogL(y_non_tof), pet_lin_op_non_tof, contamination_non_tof +) + +adjoint_ones_non_tof = pet_lin_op_non_tof.adjoint( + xp.ones(pet_lin_op_non_tof.out_shape, dtype=xp.float32, device=dev) +) + +x_mlem_non_tof = count_factor * xp.ones(img_shape, device=dev, dtype=xp.float32) +recons_non_tof = xp.ones((num_epochs,) + img_shape, device=dev, dtype=xp.float32) + +for i in range(num_epochs): + print(f"NON-TOF MLEM epoch {(i + 1):04} / {num_epochs:04}", end="\r") + x_mlem_non_tof = em_update( + x_mlem_non_tof, full_data_fidelity_non_tof, adjoint_ones_non_tof + ) + recons_non_tof[i, ...] = sm_op(x_mlem_non_tof) +print() + + +# %% +# TOF EM reconstruction +# --------------------- + +full_data_fidelity_tof = C2AffineObjective( + NegPoissonLogL(y_tof), pet_lin_op_tof, contamination_tof +) + +adjoint_ones_tof = pet_lin_op_tof.adjoint( + xp.ones(pet_lin_op_tof.out_shape, dtype=xp.float32, device=dev) +) + +x_mlem_tof = count_factor * xp.ones(img_shape, device=dev, dtype=xp.float32) +recons_tof = xp.ones((num_epochs,) + img_shape, device=dev, dtype=xp.float32) + +for i in range(num_epochs): + print(f"TOF MLEM epoch {(i + 1):04} / {num_epochs:04}", end="\r") + x_mlem_tof = em_update(x_mlem_tof, full_data_fidelity_tof, adjoint_ones_tof) + recons_tof[i, ...] = sm_op(x_mlem_tof) + + +# %% +# Visualize (smoothed) reconstructions +# ------------------------------------ + +roi_std_non_tof = np.array([float(x[:, :, 0][RHO < 25].std()) for x in recons_non_tof]) +roi_std_tof = np.array([float(x[:, :, 0][RHO < 25].std()) for x in recons_tof]) +epochs = np.arange(1, 1 + num_epochs) + +ims = dict(vmin=0, vmax=xp.max(recons_non_tof), cmap="Greys") + +fig, ax = plt.subplots(2, 2, figsize=(6, 6), layout="constrained", sharex="col") +ax[0, 0].plot(epochs, roi_std_non_tof, label="non-TOF") +ax[0, 0].plot(epochs, roi_std_tof, label="TOF") +ax[0, 0].legend() +ax[1, 0].plot(roi_std_non_tof / roi_std_tof) +ax[1, 0].set_xlabel("epoch") +ax[0, 0].set_ylabel("std.dev in central ROI") +ax[1, 0].set_ylabel("std.dev(non-TOF) / std.dev(TOF)") +ax[0, 0].grid(ls=":") +ax[1, 0].grid(ls=":") + +ax[0, 1].imshow(to_numpy_array(recons_non_tof[-1, :, :, 0]), **ims) +ax[1, 1].imshow(to_numpy_array(recons_tof[-1, :, :, 0]), **ims) +ax[0, 1].set_title(f"Non-TOF {num_epochs} epochs", fontsize="medium") +ax[1, 1].set_title(f"TOF ({fwhm_tof_mm}mm FWHM) {num_epochs} epochs", fontsize="medium") + +fig.show() From 5d1b618be847da4169607f0b3a96ce6deff8eb90 Mon Sep 17 00:00:00 2001 From: Georg Schramm Date: Tue, 26 May 2026 16:32:20 +0200 Subject: [PATCH 02/17] wip TOF variance reduction --- .../03_algorithms/05_run_tof_variance.py | 454 ++++++++++++++++++ .../03_algorithms/05_tof_vs_nontof.py | 269 ----------- 2 files changed, 454 insertions(+), 269 deletions(-) create mode 100644 docs/examples/03_algorithms/05_run_tof_variance.py delete mode 100644 docs/examples/03_algorithms/05_tof_vs_nontof.py diff --git a/docs/examples/03_algorithms/05_run_tof_variance.py b/docs/examples/03_algorithms/05_run_tof_variance.py new file mode 100644 index 00000000..e1715aaa --- /dev/null +++ b/docs/examples/03_algorithms/05_run_tof_variance.py @@ -0,0 +1,454 @@ +""" +TOF vs non-TOF: variance reduction in a uniform cylinder +========================================================= + +Why TOF reduces image noise +---------------------------- + +In a standard (non-TOF) PET scan each detected coincidence event tells us +that an annihilation occurred *somewhere* along a line of response (LOR), +but gives no information about *where* along it. + +Time-of-flight (TOF) PET additionally measures the small difference in +arrival times of the two 511 keV photons and uses that difference to +localise the annihilation along the LOR to a Gaussian probability kernel: + +.. math:: + + h(\\ell) = \\frac{1}{\\sqrt{2\\pi}\\,\\sigma_\\text{TOF}} + \\exp\\!\\left(-\\frac{\\ell^2}{2\\sigma_\\text{TOF}^2}\\right), + \\qquad + \\sigma_\\text{TOF} = \\frac{c}{2} \\cdot \\frac{\\Delta t_\\text{FWHM}}{2.355} + +where :math:`\\ell` is the distance from the LOR midpoint and +:math:`\\Delta t_\\text{FWHM}` is the scanner's coincidence timing +resolution (CTR). A CTR of 200 ps corresponds to a spatial FWHM of +≈ 30 mm. + +It is know that TOF reduces the variance in the center of a 2D cylinder with +diameter :math:`D`, where the SNR gain is given by + +.. math:: + + G_\\text{TOF} \\approx \\sqrt{0.66 \\frac{D}{\\text{FWHM}_\\text{TOF}}}. + +For :math:`D = 240` mm and :math:`\\text{FWHM} = 30` mm this gives +:math:`G \\approx 2.0`, i.e. a roughly three-fold noise reduction at the +centre. + +The convergence-speed trap +--------------------------- + +TOF reconstruction also **converges faster** than non-TOF MLEM. +This creates a common pitfall: if both reconstructions are stopped at the +*same* (small) number of iterations, the TOF image may appear *noisier* +than the non-TOF image — not because TOF is worse, but because TOF has +already converged past its low-noise plateau while non-TOF is still +climbing. Conversely, at very early iterations non-TOF may look +smoother simply because it has not yet amplified the noise. + +To observe the *true* asymptotic advantage of TOF one must run **both** +algorithms long enough to reach (approximate) convergence. + +What this example shows +------------------------ + +* A single-ring 2-D scanner with a uniform circular phantom. +* **Full MLEM** run for 700 iterations, with a mild post-filter applied + after each iteration so that the smoothed image is stored. +* The standard deviation inside a small (25 mm-radius) central ROI is + tracked vs. iteration number — this is a fast single-realisation proxy + for the true noise level that avoids the need for Monte Carlo repeats. +* Four-panel figure: + + - **Top-left**: raw std.dev curves vs. iteration — shows faster TOF + convergence *and* the lower asymptotic noise level. + - **Bottom-left**: ratio non-TOF / TOF std.dev — values > 1 confirm + that non-TOF is noisier; the ratio rises as iterations increase and + both methods converge. + - **Top/bottom-right**: final smoothed images for visual comparison. + +.. note:: + + The standard deviation in a single noise realisation is used here as + a proxy for the true noise standard deviation. +""" + +# %% +from __future__ import annotations +import matplotlib.pyplot as plt +import numpy as np + +import parallelproj.operators +import parallelproj.tof +import parallelproj.pet_scanners +import parallelproj.pet_lors +import parallelproj.projectors +from parallelproj import to_numpy_array, Array +from parallelproj.functions import NegPoissonLogL, C2AffineObjective, C1Function + +from copy import copy + +# %% +from array_utils import suggest_array_backend_and_device + +# To use a specific backend and/or device, replace the None arguments, e.g.: +# xp, dev = suggest_array_backend_and_device(backend="numpy", dev="cpu") or by setting xp and dev manually +xp, dev = suggest_array_backend_and_device(None, None) + +# %% +# Key simulation parameters +# ------------------------- +# +# ``num_epochs`` controls how many MLEM iterations are stored. 700 is +# enough for both non-TOF and TOF to be well past their respective +# convergence knees, so the asymptotic noise levels are clearly visible. +# +# ``fwhm_tof_mm = 30 mm`` corresponds to a coincidence timing resolution +# of approximately 200 ps — representative of state-of-the-art clinical +# scanners as of 2025. +# +# ``sm_fwhm_mm`` is the FWHM of the Gaussian post-filter applied after +# every iteration. A mild 9 mm filter is applied to suppress +# high-frequency "salt-and-pepper" noise while preserving the +# convergence-related noise trend. +# +# ``count_factor`` scales the phantom activity to control the total number +# of detected events. Moderate counts (0.3) give a clearly visible noise +# difference between TOF and non-TOF. +# +# ``cylinder_radius`` (in voxels) defines the uniform phantom disk. + +num_epochs = 300 +fwhm_tof_mm = 30.0 +fwhm_res_model_mm = 4.0 +sm_fwhm_mm = 9.0 +cylinder_radius_mm = 120 +count_factor = 0.3 + +# %% +# Scanner and image geometry +# -------------------------- +# +# We use a **single-ring scanner** (``num_rings=1``) so that the +# reconstruction is effectively 2-D. This keeps computation fast and +# isolates the transaxial TOF effect without axial compression artefacts. +# +# The scanner radius of 300 mm and 28 × 16 = 448 detector elements give a +# realistic clinical-scale geometry. The single image plane has +# 151 × 151 × 1 voxels of 2 mm side length, yielding a 302 mm transaxial +# field of view. + +num_rings = 1 +scanner = parallelproj.pet_scanners.RegularPolygonPETScannerGeometry( + xp, + dev, + radius=300.0, + num_sides=28, + num_lor_endpoints_per_side=16, + lor_spacing=4.0, + ring_positions=xp.asarray([0], dtype=xp.float32, device=dev), + symmetry_axis=2, +) + +# %% +# LOR descriptor and projectors +# ----------------------------- +# +# The :class:`.RegularPolygonPETLORDescriptor` maps detector pairs to +# sinogram bins. ``max_ring_difference=2`` is harmless here (single ring) +# and ``radial_trim=150`` discards the outermost radial bins that fall +# outside the cylinder FOV. + +img_shape = (151, 151, 1) +voxel_size = (2.0, 2.0, 2.0) + +lor_desc = parallelproj.pet_lors.RegularPolygonPETLORDescriptor( + scanner, + parallelproj.pet_lors.Michelogram(scanner.num_rings, max_ring_difference=2, span=1), + radial_trim=150, + sinogram_order=parallelproj.pet_lors.SinogramSpatialAxisOrder.RVP, +) + +proj_non_tof = parallelproj.projectors.RegularPolygonPETProjector( + lor_desc, img_shape=img_shape, voxel_size=voxel_size +) + +# %% +# Uniform cylinder phantom +# ------------------------ +# +# A disk of radius ``cylinder_radius`` voxels centred in the +# FOV with uniform activity ``count_factor``. + +x_pos = voxel_size[0] * ( + xp.arange(img_shape[0], device=dev, dtype=xp.float32) - img_shape[0] / 2 + 0.5 +) +X, Y = xp.meshgrid(x_pos, x_pos, indexing="ij") +RHO = xp.sqrt(X**2 + Y**2) + +x_true = xp.ones(img_shape, device=dev, dtype=xp.float32) +x_true[..., 0] = count_factor * (RHO <= cylinder_radius_mm) + +# %% +# Attenuation model +# ----------------- +# +# A uniform water-equivalent attenuation coefficient of +# :math:`\mu = 0.01\,\text{mm}^{-1}` is used inside the cylinder. +# Attenuation is the same for TOF and non-TOF; it is included here for +# realism and has no bearing on the variance comparison. + +x_att = 0.01 * xp.astype(x_true > 0, xp.float32) +att_sino = xp.exp(-proj_non_tof(x_att)) + +# %% +# TOF projector setup +# ------------------- +# +# The TOF projector is a copy of the non-TOF projector with +# :class:`.TOFParameters` attached. The bin width is set to +# :math:`\text{FWHM}/4` so that each TOF kernel spans approximately +# 4 bins (good Gaussian sampling). The total number of bins is chosen to +# cover an FOV of (300 mm) with some margin. +# +# Both forward operators also include the same image-based resolution model +# (:class:`.GaussianFilterOperator`, FWHM = 4 mm) to model finite detector +# resolution. + +proj_tof = copy(proj_non_tof) + +proj_tof.tof_parameters = parallelproj.tof.TOFParameters( + num_tofbins=int(300 / (fwhm_tof_mm / 4.0)) + 1, + tofbin_width=fwhm_tof_mm / 4.0, + sigma_tof=fwhm_tof_mm / 2.35, +) + +# For TOF, att_sino has no TOF-bins dimension while the projector output does. +# broadcast_to adds a trailing singleton via expand_dims and broadcasts it over +# the TOF-bins axis without copying data (zero-stride view). +att_values_tof = xp.broadcast_to(xp.expand_dims(att_sino, axis=-1), proj_tof.out_shape) +att_op_tof = parallelproj.operators.ElementwiseMultiplicationOperator(att_values_tof) +att_op_non_tof = parallelproj.operators.ElementwiseMultiplicationOperator(att_sino) + +res_model = parallelproj.operators.GaussianFilterOperator( + img_shape, + sigma=[fwhm_res_model_mm / (2.35 * float(vs)) for vs in proj_tof.voxel_size], +) + +# compose all 3 operators into a single linear operator +pet_lin_op_tof = parallelproj.operators.CompositeLinearOperator( + (att_op_tof, proj_tof, res_model) +) + +# setup non-TOF fwd model +pet_lin_op_non_tof = parallelproj.operators.CompositeLinearOperator( + (att_op_non_tof, proj_non_tof, res_model) +) + +# %% +# Data simulation +# --------------- +# +# The TOF sinogram is simulated once and a constant scatter/randoms +# contamination (50 % of the mean prompt rate) is added before Poisson +# sampling. +# +# The non-TOF sinogram is obtained by **summing the noisy TOF sinogram over +# its TOF-bin axis**. This marginalisation is mathematically equivalent to +# discarding the timing information in a real scanner, and it ensures that +# both reconstructions see exactly the same Poisson noise realisation — +# they differ only in how much of the timing information they exploit. + +noise_free_data_tof = pet_lin_op_tof(x_true) + +contamination_tof = xp.full( + noise_free_data_tof.shape, + 0.5 * float(xp.mean(noise_free_data_tof)), + device=dev, + dtype=xp.float32, +) + +noise_free_data_tof += contamination_tof + +np.random.seed(1) +y_tof = xp.asarray( + np.random.poisson(to_numpy_array(noise_free_data_tof)), + device=dev, + dtype=xp.float32, +) + +# marginalise: sum over TOF bins gives the non-TOF sinogram +y_non_tof = xp.sum(y_tof, axis=-1) +contamination_non_tof = xp.sum(contamination_tof, axis=-1) + +# %% +# EM update rule +# -------------- +# +# The standard MLEM update :cite:p:`Shepp1982` :cite:p:`Lange1984` can be +# written as a preconditioned gradient-descent step: +# +# .. math:: +# x^+ = x - D\,\nabla_x f(x), +# \qquad +# D = \operatorname{diag}\!\left(\frac{x}{A^H \mathbf{1}}\right) +# +# where :math:`f(x) = \sum_i [\bar{y}_i - y_i \log \bar{y}_i]` is the +# negative Poisson log-likelihood and :math:`A^H \mathbf{1}` is the +# sensitivity image. The same update is used for both non-TOF and TOF; +# the only difference is the forward operator :math:`A`. + + +def em_update( + x_cur: Array, + negpoissonlogl: C1Function, + adj_ones: Array, +) -> Array: + """EM update re-written as preconditioned GD step""" + em_diag_precond = x_cur / adj_ones + return x_cur - em_diag_precond * negpoissonlogl.gradient(x_cur) + + +# %% +# Non-TOF MLEM +# ------------ +# +# We run ``num_epochs`` full-data MLEM iterations and store the +# **post-filtered** image after every iteration. Applying the same +# post-filter (:class:`.GaussianFilterOperator`, FWHM = ``sm_fwhm_mm``) +# at each iteration mirrors the typical clinical workflow where a +# reconstruction is post-smoothed before evaluation. Storing all +# intermediate images lets us plot noise vs. iteration and observe both +# the convergence speed and the asymptotic noise level. + +sm_op = parallelproj.operators.GaussianFilterOperator( + in_shape=img_shape, sigma=sm_fwhm_mm / (2.35 * voxel_size[0]) +) + +full_data_fidelity_non_tof = C2AffineObjective( + NegPoissonLogL(y_non_tof), pet_lin_op_non_tof, contamination_non_tof +) + +adjoint_ones_non_tof = pet_lin_op_non_tof.adjoint( + xp.ones(pet_lin_op_non_tof.out_shape, dtype=xp.float32, device=dev) +) + +x_mlem_non_tof = count_factor * xp.ones(img_shape, device=dev, dtype=xp.float32) +recons_non_tof = xp.ones((num_epochs,) + img_shape, device=dev, dtype=xp.float32) + +for i in range(num_epochs): + print(f"NON-TOF MLEM epoch {(i + 1):04} / {num_epochs:04}", end="\r") + x_mlem_non_tof = em_update( + x_mlem_non_tof, full_data_fidelity_non_tof, adjoint_ones_non_tof + ) + recons_non_tof[i, ...] = sm_op(x_mlem_non_tof) +print() + + +# %% +# TOF MLEM +# -------- +# +# Identical loop but using the TOF forward operator and TOF sinogram. +# Because the TOF update is more informative (each LOR contributes noise +# over the kernel width ≈ 30 mm rather than the full chord ≈ 600 mm), the +# image converges to its maximum-likelihood solution in fewer iterations. + +full_data_fidelity_tof = C2AffineObjective( + NegPoissonLogL(y_tof), pet_lin_op_tof, contamination_tof +) + +adjoint_ones_tof = pet_lin_op_tof.adjoint( + xp.ones(pet_lin_op_tof.out_shape, dtype=xp.float32, device=dev) +) + +x_mlem_tof = count_factor * xp.ones(img_shape, device=dev, dtype=xp.float32) +recons_tof = xp.ones((num_epochs,) + img_shape, device=dev, dtype=xp.float32) + +for i in range(num_epochs): + print(f"TOF MLEM epoch {(i + 1):04} / {num_epochs:04}", end="\r") + x_mlem_tof = em_update(x_mlem_tof, full_data_fidelity_tof, adjoint_ones_tof) + recons_tof[i, ...] = sm_op(x_mlem_tof) + + +# %% +# Noise vs. iteration in the central ROI +# --------------------------------------- +# +# The standard deviation of voxel values inside a small 25 mm-radius +# central ROI is used as a single-realisation proxy for the true noise +# standard deviation. Because the phantom is uniform, every voxel inside +# the ROI has the same expected value, so spatial variability equals noise +# variability. +# +# Two effects are visible in the plot: +# +# 1. **Faster convergence of TOF**: the TOF std.dev curve rises steeply and +# then falls to its asymptote in far fewer iterations than non-TOF. +# At early iteration counts TOF can therefore appear *noisier* — not +# because TOF is worse, but because it has already amplified noise while +# non-TOF is still initialisation-smooth. +# +# 2. **Lower asymptotic noise for TOF**: once both curves have stabilised, +# the TOF std.dev is clearly below the non-TOF std.dev. The ratio plot +# (bottom-left) shows this: values > 1 confirm the TOF advantage, and +# the ratio continues to grow as both algorithms converge. + +roi_std_non_tof = np.array([float(x[:, :, 0][RHO < 25].std()) for x in recons_non_tof]) +roi_std_tof = np.array([float(x[:, :, 0][RHO < 25].std()) for x in recons_tof]) +epochs = np.arange(1, 1 + num_epochs) + +# %% +# Visualisation +# ------------- +# +# The four-panel figure summarises the comparison: +# +# * **Top-left**: std.dev in the central 25 mm ROI vs. MLEM iteration for +# non-TOF (orange) and TOF (blue). Note how TOF rises *and falls* faster; +# comparing at a fixed early iteration can give the wrong conclusion. +# * **Bottom-left**: ratio of std.devs (non-TOF / TOF). The ratio +# increases with iteration count and stabilises above 1, quantifying the +# asymptotic noise advantage of TOF. +# * **Top-right / bottom-right**: final smoothed images after ``num_epochs`` +# iterations. Visual noise in the uniform disk is lower for TOF. + +ims = dict(vmin=0, vmax=xp.max(recons_non_tof), cmap="Greys") + +fig, ax = plt.subplots(2, 2, figsize=(6, 6), layout="constrained", sharex="col") +ax[0, 0].plot(epochs, roi_std_non_tof, label="non-TOF", color="tab:orange") +ax[0, 0].plot( + epochs, roi_std_tof, label=f"TOF ({fwhm_tof_mm:.0f} mm FWHM)", color="tab:blue" +) +ax[0, 0].legend(fontsize=8) +ax[1, 0].plot(epochs, roi_std_non_tof / roi_std_tof, color="tab:green") +ax[1, 0].axhline(1.0, color="gray", ls=":", lw=0.8) +ax[1, 0].set_xlabel("MLEM iteration") +ax[0, 0].set_ylabel("std.dev in central ROI") +ax[1, 0].set_ylabel("std.dev ratio (non-TOF / TOF)") +ax[0, 0].set_title( + f"(central 25 mm ROI, {sm_fwhm_mm:.0f} mm post-filter)", + fontsize=8, +) +ax[0, 0].grid(ls=":") +ax[1, 0].grid(ls=":") + +ax[0, 1].imshow(to_numpy_array(recons_non_tof[-1, :, :, 0]), **ims) +ax[1, 1].imshow(to_numpy_array(recons_tof[-1, :, :, 0]), **ims) +ax[0, 1].set_title( + f"non-TOF ({num_epochs} iter)\nstd.dev = {roi_std_non_tof[-1]:.4f}", + fontsize=8, +) +ax[1, 1].set_title( + f"TOF {fwhm_tof_mm:.0f} mm ({num_epochs} iter)\nstd.dev = {roi_std_tof[-1]:.4f}", + fontsize=8, +) +for a in ax[:, 1]: + a.set_axis_off() + +fig.suptitle( + f"TOF variance reduction — uniform cylinder Ø {2*cylinder_radius_mm:.0f} mm", + fontsize=9, +) +fig.show() diff --git a/docs/examples/03_algorithms/05_tof_vs_nontof.py b/docs/examples/03_algorithms/05_tof_vs_nontof.py deleted file mode 100644 index 376c74f0..00000000 --- a/docs/examples/03_algorithms/05_tof_vs_nontof.py +++ /dev/null @@ -1,269 +0,0 @@ -""" -TOF vs NON-TOF RECONSTRUCTIONS in a 2D uniform cylinder -======================================================= - -This example compares variance reduction due to the presence of TOF information. -""" - -# %% -from __future__ import annotations -import matplotlib.pyplot as plt -import numpy as np - -import parallelproj.operators -import parallelproj.tof -import parallelproj.pet_scanners -import parallelproj.pet_lors -import parallelproj.projectors -from parallelproj import to_numpy_array, Array -from parallelproj.functions import NegPoissonLogL, C2AffineObjective, C1Function - -from copy import copy - -# %% -from array_utils import suggest_array_backend_and_device - -# To use a specific backend and/or device, replace the None arguments, e.g.: -# xp, dev = suggest_array_backend_and_device(backend="numpy", dev="cpu") or by setting xp and dev manually -xp, dev = suggest_array_backend_and_device(None, None) - -# %% - -num_epochs = 700 -fwhm_tof_mm = 30.0 -sm_fwhm_mm = 9.0 -cylinder_radius = 140 -count_factor = 0.3 - -# %% -# Setup of the forward model :math:`\bar{y}(x) = A x + s` -# -------------------------------------------------------- -# -# We setup a linear forward operator :math:`A` consisting of an -# image-based resolution model, a non-TOF PET projector and an attenuation model -# -# .. note:: -# The OSEM implementation below works with all linear operators that -# subclass :class:`.LinearOperator` (e.g. the high-level projectors). - -num_rings = 1 -scanner = parallelproj.pet_scanners.RegularPolygonPETScannerGeometry( - xp, - dev, - radius=300.0, - num_sides=28, - num_lor_endpoints_per_side=16, - lor_spacing=4.0, - ring_positions=xp.asarray([0], dtype=xp.float32, device=dev), - symmetry_axis=2, -) - -# %% -# setup the LOR descriptor that defines the sinogram - -img_shape = (151, 151, 1) -voxel_size = (2.0, 2.0, 2.0) - -lor_desc = parallelproj.pet_lors.RegularPolygonPETLORDescriptor( - scanner, - parallelproj.pet_lors.Michelogram(scanner.num_rings, max_ring_difference=2, span=1), - radial_trim=150, - sinogram_order=parallelproj.pet_lors.SinogramSpatialAxisOrder.RVP, -) - -proj_non_tof = parallelproj.projectors.RegularPolygonPETProjector( - lor_desc, img_shape=img_shape, voxel_size=voxel_size -) - -# setup a uniform circle -x_pos = voxel_size[0] * ( - xp.arange(img_shape[0], device=dev, dtype=xp.float32) - img_shape[0] / 2 + 0.5 -) -X, Y = xp.meshgrid(x_pos, x_pos, indexing="ij") -RHO = xp.sqrt(X**2 + Y**2) - -x_true = xp.ones(img_shape, device=dev, dtype=xp.float32) -x_true[..., 0] = count_factor * (RHO <= cylinder_radius) - -# %% -# Attenuation image and sinogram setup -# ------------------------------------ - -# setup an attenuation image -x_att = 0.01 * xp.astype(x_true > 0, xp.float32) -# calculate the attenuation sinogram -att_sino = xp.exp(-proj_non_tof(x_att)) - -# %% -# Complete PET forward model setup -# -------------------------------- -# -# We combine an image-based resolution model, -# a non-TOF or TOF PET projector and an attenuation model -# into a single linear operator. - -proj_tof = copy(proj_non_tof) - -proj_tof.tof_parameters = parallelproj.tof.TOFParameters( - num_tofbins=int(300 / (fwhm_tof_mm / 5.0)) + 1, - tofbin_width=fwhm_tof_mm / 4.0, - sigma_tof=fwhm_tof_mm / 2.35, -) - -# For TOF, att_sino has no TOF-bins dimension while the projector output does. -# broadcast_to adds a trailing singleton via expand_dims and broadcasts it over -# the TOF-bins axis without copying data (zero-stride view). -att_values_tof = xp.broadcast_to(xp.expand_dims(att_sino, axis=-1), proj_tof.out_shape) -att_op_tof = parallelproj.operators.ElementwiseMultiplicationOperator(att_values_tof) -att_op_non_tof = parallelproj.operators.ElementwiseMultiplicationOperator(att_sino) - -res_model = parallelproj.operators.GaussianFilterOperator( - img_shape, sigma=[4.0 / (2.35 * float(vs)) for vs in proj_tof.voxel_size] -) - -# compose all 3 operators into a single linear operator -pet_lin_op_tof = parallelproj.operators.CompositeLinearOperator( - (att_op_tof, proj_tof, res_model) -) - -# setup non-TOF fwd model -pet_lin_op_non_tof = parallelproj.operators.CompositeLinearOperator( - (att_op_non_tof, proj_non_tof, res_model) -) - -# %% -# Simulation of projection data -# ----------------------------- -# -# We setup an arbitrary ground truth :math:`x_{true}` and simulate -# noise-free and noisy data :math:`y` by adding Poisson noise. - -# simulated noise-free data -noise_free_data_tof = pet_lin_op_tof(x_true) - -# generate a contant contamination sinogram -contamination_tof = xp.full( - noise_free_data_tof.shape, - 0.5 * float(xp.mean(noise_free_data_tof)), - device=dev, - dtype=xp.float32, -) - -noise_free_data_tof += contamination_tof - -# add Poisson noise -np.random.seed(1) -y_tof = xp.asarray( - np.random.poisson(to_numpy_array(noise_free_data_tof)), - device=dev, - dtype=xp.float32, -) - -y_non_tof = xp.sum(y_tof, axis=-1) -contamination_non_tof = xp.sum(contamination_tof, axis=-1) - -# %% -# EM update -# --------- -# -# The EM update used in MLEM and OSEM is :cite:p:`Dempster1977` -# :cite:p:`Shepp1982` :cite:p:`Lange1984` :cite:p:`Hudson1994` -# -# .. math:: -# x^+ = \frac{x}{A^H 1} A^H \frac{y}{A x + s} -# -# which can be rewritten as a preconditioned gradient descent step with -# diagonal preconditioner :math:`D = \operatorname{diag}(x / (A^H 1))`: -# -# .. math:: -# x^+ = x - D \, \nabla_x f(x). -# -# We implement this as a single function used by both MLEM and OSEM. - - -def em_update( - x_cur: Array, - negpoissonlogl: C1Function, - adj_ones: Array, -) -> Array: - """EM update re-written as preconditioned GD step""" - em_diag_precond = x_cur / adj_ones - return x_cur - em_diag_precond * negpoissonlogl.gradient(x_cur) - - -# %% -# NON-TOF EM reconstruction -# ------------------------- - -sm_op = parallelproj.operators.GaussianFilterOperator( - in_shape=img_shape, sigma=sm_fwhm_mm / (2.35 * voxel_size[0]) -) - -full_data_fidelity_non_tof = C2AffineObjective( - NegPoissonLogL(y_non_tof), pet_lin_op_non_tof, contamination_non_tof -) - -adjoint_ones_non_tof = pet_lin_op_non_tof.adjoint( - xp.ones(pet_lin_op_non_tof.out_shape, dtype=xp.float32, device=dev) -) - -x_mlem_non_tof = count_factor * xp.ones(img_shape, device=dev, dtype=xp.float32) -recons_non_tof = xp.ones((num_epochs,) + img_shape, device=dev, dtype=xp.float32) - -for i in range(num_epochs): - print(f"NON-TOF MLEM epoch {(i + 1):04} / {num_epochs:04}", end="\r") - x_mlem_non_tof = em_update( - x_mlem_non_tof, full_data_fidelity_non_tof, adjoint_ones_non_tof - ) - recons_non_tof[i, ...] = sm_op(x_mlem_non_tof) -print() - - -# %% -# TOF EM reconstruction -# --------------------- - -full_data_fidelity_tof = C2AffineObjective( - NegPoissonLogL(y_tof), pet_lin_op_tof, contamination_tof -) - -adjoint_ones_tof = pet_lin_op_tof.adjoint( - xp.ones(pet_lin_op_tof.out_shape, dtype=xp.float32, device=dev) -) - -x_mlem_tof = count_factor * xp.ones(img_shape, device=dev, dtype=xp.float32) -recons_tof = xp.ones((num_epochs,) + img_shape, device=dev, dtype=xp.float32) - -for i in range(num_epochs): - print(f"TOF MLEM epoch {(i + 1):04} / {num_epochs:04}", end="\r") - x_mlem_tof = em_update(x_mlem_tof, full_data_fidelity_tof, adjoint_ones_tof) - recons_tof[i, ...] = sm_op(x_mlem_tof) - - -# %% -# Visualize (smoothed) reconstructions -# ------------------------------------ - -roi_std_non_tof = np.array([float(x[:, :, 0][RHO < 25].std()) for x in recons_non_tof]) -roi_std_tof = np.array([float(x[:, :, 0][RHO < 25].std()) for x in recons_tof]) -epochs = np.arange(1, 1 + num_epochs) - -ims = dict(vmin=0, vmax=xp.max(recons_non_tof), cmap="Greys") - -fig, ax = plt.subplots(2, 2, figsize=(6, 6), layout="constrained", sharex="col") -ax[0, 0].plot(epochs, roi_std_non_tof, label="non-TOF") -ax[0, 0].plot(epochs, roi_std_tof, label="TOF") -ax[0, 0].legend() -ax[1, 0].plot(roi_std_non_tof / roi_std_tof) -ax[1, 0].set_xlabel("epoch") -ax[0, 0].set_ylabel("std.dev in central ROI") -ax[1, 0].set_ylabel("std.dev(non-TOF) / std.dev(TOF)") -ax[0, 0].grid(ls=":") -ax[1, 0].grid(ls=":") - -ax[0, 1].imshow(to_numpy_array(recons_non_tof[-1, :, :, 0]), **ims) -ax[1, 1].imshow(to_numpy_array(recons_tof[-1, :, :, 0]), **ims) -ax[0, 1].set_title(f"Non-TOF {num_epochs} epochs", fontsize="medium") -ax[1, 1].set_title(f"TOF ({fwhm_tof_mm}mm FWHM) {num_epochs} epochs", fontsize="medium") - -fig.show() From 3fe0fca6e086ca2de0dc2938fc3237f570eb752a Mon Sep 17 00:00:00 2001 From: Georg Schramm Date: Tue, 26 May 2026 17:09:59 +0200 Subject: [PATCH 03/17] wip TOF gain --- .../03_algorithms/05_run_tof_variance.py | 385 +++++++++++++----- 1 file changed, 291 insertions(+), 94 deletions(-) diff --git a/docs/examples/03_algorithms/05_run_tof_variance.py b/docs/examples/03_algorithms/05_run_tof_variance.py index e1715aaa..359a9e3e 100644 --- a/docs/examples/03_algorithms/05_run_tof_variance.py +++ b/docs/examples/03_algorithms/05_run_tof_variance.py @@ -25,26 +25,26 @@ resolution (CTR). A CTR of 200 ps corresponds to a spatial FWHM of ≈ 30 mm. -It is know that TOF reduces the variance in the center of a 2D cylinder with -diameter :math:`D`, where the SNR gain is given by +It is known that TOF reduces the variance in the center of a 2D cylinder +with diameter :math:`D`, where the SNR gain is approximately .. math:: G_\\text{TOF} \\approx \\sqrt{0.66 \\frac{D}{\\text{FWHM}_\\text{TOF}}}. For :math:`D = 240` mm and :math:`\\text{FWHM} = 30` mm this gives -:math:`G \\approx 2.0`, i.e. a roughly three-fold noise reduction at the +:math:`G \\approx 2.0`, i.e. a roughly two-fold noise reduction at the centre. The convergence-speed trap --------------------------- -TOF reconstruction also **converges faster** than non-TOF MLEM. +TOF reconstruction also **converges faster** than non-TOF reconstruction. This creates a common pitfall: if both reconstructions are stopped at the -*same* (small) number of iterations, the TOF image may appear *noisier* +*same* (small) number of epochs, the TOF image may appear *noisier* than the non-TOF image — not because TOF is worse, but because TOF has already converged past its low-noise plateau while non-TOF is still -climbing. Conversely, at very early iterations non-TOF may look +climbing. Conversely, at very early epochs non-TOF may look smoother simply because it has not yet amplified the noise. To observe the *true* asymptotic advantage of TOF one must run **both** @@ -54,28 +54,40 @@ ------------------------ * A single-ring 2-D scanner with a uniform circular phantom. -* **Full MLEM** run for 700 iterations, with a mild post-filter applied - after each iteration so that the smoothed image is stored. +* **SVRG** (:func:`00_run_mlem_osem_svrg`) with ``num_subsets=28`` + subsets run for ``num_epochs=10`` epochs (warm-started by a single + OSEM epoch), applied independently to the non-TOF and TOF forward + models. 10 SVRG epochs are sufficient for both to reach their + respective noise plateaux. * The standard deviation inside a small (25 mm-radius) central ROI is - tracked vs. iteration number — this is a fast single-realisation proxy + tracked after every epoch — this is a fast single-realisation proxy for the true noise level that avoids the need for Monte Carlo repeats. -* Four-panel figure: +* Eight-panel figure: - - **Top-left**: raw std.dev curves vs. iteration — shows faster TOF - convergence *and* the lower asymptotic noise level. + - **Top-left**: std.dev curves vs. SVRG epoch (epoch 0 = OSEM warm + start) — shows faster TOF convergence *and* the lower asymptotic + noise level. - **Bottom-left**: ratio non-TOF / TOF std.dev — values > 1 confirm - that non-TOF is noisier; the ratio rises as iterations increase and - both methods converge. - - **Top/bottom-right**: final smoothed images for visual comparison. + that non-TOF is noisier; the ratio stabilises above 1 once both + algorithms have converged. + - **2nd column**: smoothed images after the OSEM warm start (epoch 0). + - **3rd column**: smoothed images after 1 SVRG epoch, illustrating + that at very early iterations TOF may appear noisier (it has already + amplified noise while non-TOF is still initialisation-smooth). + - **Right column**: final smoothed images after ``num_epochs`` SVRG + epochs for visual comparison of the asymptotic noise levels. .. note:: The standard deviation in a single noise realisation is used here as - a proxy for the true noise standard deviation. + a proxy for the true noise standard deviation. For a uniform phantom + spatial variability inside the ROI equals the noise variability, so + the single realisation is sufficient. """ # %% from __future__ import annotations +from collections.abc import Sequence import matplotlib.pyplot as plt import numpy as np @@ -119,12 +131,14 @@ # # ``cylinder_radius`` (in voxels) defines the uniform phantom disk. -num_epochs = 300 +num_subsets = 14 +num_epochs = 10 fwhm_tof_mm = 30.0 fwhm_res_model_mm = 4.0 sm_fwhm_mm = 9.0 cylinder_radius_mm = 120 count_factor = 0.3 +step_size = 2.0 # %% # Scanner and image geometry @@ -283,93 +297,239 @@ contamination_non_tof = xp.sum(contamination_tof, axis=-1) # %% -# EM update rule -# -------------- +# Post-filter and subset splitting +# --------------------------------- # -# The standard MLEM update :cite:p:`Shepp1982` :cite:p:`Lange1984` can be -# written as a preconditioned gradient-descent step: +# A mild Gaussian post-filter is applied after each SVRG epoch so that the +# stored image matches the typical clinical workflow. # -# .. math:: -# x^+ = x - D\,\nabla_x f(x), -# \qquad -# D = \operatorname{diag}\!\left(\frac{x}{A^H \mathbf{1}}\right) +# The sinogram views are split into ``num_subsets`` disjoint groups. +# Non-TOF data and the attenuation sinogram are 3-D (R × V × P); TOF data +# adds a fourth TOF-bin axis. We therefore request 3-D slices for +# attenuation / non-TOF data indexing and 4-D slices for TOF data indexing. + +sm_op = parallelproj.operators.GaussianFilterOperator( + in_shape=img_shape, sigma=sm_fwhm_mm / (2.35 * voxel_size[0]) +) + +# 3-D slices: used for non-TOF data *and* to index att_sino +subset_views, subset_slices_nt = lor_desc.get_distributed_views_and_slices( + num_subsets, 3 +) +# 4-D slices: used to index TOF data and contamination +_, subset_slices_tof = lor_desc.get_distributed_views_and_slices(num_subsets, 4) + +# %% +# SVRG helper functions +# --------------------- # -# where :math:`f(x) = \sum_i [\bar{y}_i - y_i \log \bar{y}_i]` is the -# negative Poisson log-likelihood and :math:`A^H \mathbf{1}` is the -# sensitivity image. The same update is used for both non-TOF and TOF; -# the only difference is the forward operator :math:`A`. +# These two functions implement SVRG exactly as in +# :ref:`sphx_glr_examples_03_algorithms_00_run_mlem_osem_svrg.py`. +# ``svrg_calc_snapshot_gradients`` computes and stores all per-subset +# gradients at the current anchor point; ``svrg_update`` performs a single +# variance-reduced subset step. def em_update( x_cur: Array, - negpoissonlogl: C1Function, + data_fidelity: C1Function, adj_ones: Array, ) -> Array: - """EM update re-written as preconditioned GD step""" - em_diag_precond = x_cur / adj_ones - return x_cur - em_diag_precond * negpoissonlogl.gradient(x_cur) + """EM update (preconditioned gradient step) used for the warm start.""" + return x_cur - (x_cur / adj_ones) * data_fidelity.gradient(x_cur) -# %% -# Non-TOF MLEM -# ------------ -# -# We run ``num_epochs`` full-data MLEM iterations and store the -# **post-filtered** image after every iteration. Applying the same -# post-filter (:class:`.GaussianFilterOperator`, FWHM = ``sm_fwhm_mm``) -# at each iteration mirrors the typical clinical workflow where a -# reconstruction is post-smoothed before evaluation. Storing all -# intermediate images lets us plot noise vs. iteration and observe both -# the convergence speed and the asymptotic noise level. +def svrg_calc_snapshot_gradients( + x_cur: Array, + subset_obj_functions: Sequence[C1Function], +) -> tuple[Array, Array]: + """Compute and store per-subset gradients at the anchor point.""" + m = len(subset_obj_functions) + stored = xp.zeros((m,) + x_cur.shape, dtype=x_cur.dtype, device=dev) + for k, df in enumerate(subset_obj_functions): + stored[k] = df.gradient(x_cur) + return stored, xp.sum(stored, axis=0) -sm_op = parallelproj.operators.GaussianFilterOperator( - in_shape=img_shape, sigma=sm_fwhm_mm / (2.35 * voxel_size[0]) -) -full_data_fidelity_non_tof = C2AffineObjective( - NegPoissonLogL(y_non_tof), pet_lin_op_non_tof, contamination_non_tof -) +def svrg_update( + x_cur: Array, + subset_idx: int, + subset_obj_functions: Sequence[C1Function], + stored_grads: Array, + full_grad: Array, + precond: Array, + step_size: float = 1.0, +) -> Array: + """Single variance-reduced subset update.""" + m = len(subset_obj_functions) + grad_k = subset_obj_functions[subset_idx].gradient(x_cur) + approx_grad = m * (grad_k - stored_grads[subset_idx]) + full_grad + return xp.clip(x_cur - step_size * precond * approx_grad, 0, None) -adjoint_ones_non_tof = pet_lin_op_non_tof.adjoint( - xp.ones(pet_lin_op_non_tof.out_shape, dtype=xp.float32, device=dev) -) -x_mlem_non_tof = count_factor * xp.ones(img_shape, device=dev, dtype=xp.float32) -recons_non_tof = xp.ones((num_epochs,) + img_shape, device=dev, dtype=xp.float32) +# %% +# Non-TOF: subset operators, warm start, and SVRG +# ------------------------------------------------ +# +# One :class:`.CompositeLinearOperator` is built per subset, combining +# the subset projector, the attenuation diagonal, and the resolution model. +# Sensitivity images :math:`(A^k)^H \mathbf{1}` are pre-computed once and +# summed to obtain the full :math:`A^H \mathbf{1}`. +# +# The warm start runs a single OSEM epoch, which moves the initial flat +# image close enough to the solution for the SVRG preconditioner to be +# meaningful from the very first epoch. + +proj_non_tof.clear_cached_lor_endpoints() +subset_linops_nt = [] +for i in range(num_subsets): + sp = copy(proj_non_tof) + sp.views = subset_views[i] + att_op_k = parallelproj.operators.ElementwiseMultiplicationOperator( + att_sino[subset_slices_nt[i]] + ) + subset_linops_nt.append( + parallelproj.operators.CompositeLinearOperator([att_op_k, sp, res_model]) + ) + +subset_adj_ones_nt = xp.zeros((num_subsets,) + img_shape, dtype=xp.float32, device=dev) +for k, op in enumerate(subset_linops_nt): + subset_adj_ones_nt[k] = op.adjoint( + xp.ones(op.out_shape, dtype=xp.float32, device=dev) + ) +adjoint_ones_nt = xp.sum(subset_adj_ones_nt, axis=0) -for i in range(num_epochs): - print(f"NON-TOF MLEM epoch {(i + 1):04} / {num_epochs:04}", end="\r") - x_mlem_non_tof = em_update( - x_mlem_non_tof, full_data_fidelity_non_tof, adjoint_ones_non_tof +subset_fidelities_nt = [ + C2AffineObjective( + NegPoissonLogL(y_non_tof[subset_slices_nt[k]]), + subset_linops_nt[k], + contamination_non_tof[subset_slices_nt[k]], ) - recons_non_tof[i, ...] = sm_op(x_mlem_non_tof) + for k in range(num_subsets) +] + +# --- warm start: 1 OSEM epoch --- +x_nt = count_factor * xp.ones(img_shape, device=dev, dtype=xp.float32) +for k in range(num_subsets): + print(f" non-TOF warm-start subset {k + 1:03}/{num_subsets:03}", end="\r") + x_nt = em_update(x_nt, subset_fidelities_nt[k], subset_adj_ones_nt[k]) print() +x_nt_warmstart_sm = sm_op(x_nt) +# --- SVRG loop --- +recons_non_tof = xp.ones((num_epochs,) + img_shape, device=dev, dtype=xp.float32) +svrg_precond_nt = x_nt / adjoint_ones_nt +stored_grads_nt, full_grad_nt = None, None + +for epoch in range(num_epochs): + if epoch % 2 == 0: + if epoch <= 4: + svrg_precond_nt = x_nt / adjoint_ones_nt + stored_grads_nt, full_grad_nt = svrg_calc_snapshot_gradients( + x_nt, subset_fidelities_nt + ) + x_nt = xp.clip(x_nt - svrg_precond_nt * full_grad_nt, 0, None) + + for k in range(num_subsets): + print( + f" non-TOF SVRG epoch {epoch + 1:02}/{num_epochs:02}," + f" subset {k + 1:03}/{num_subsets:03}", + end="\r", + ) + x_nt = svrg_update( + x_nt, + k, + subset_fidelities_nt, + stored_grads_nt, + full_grad_nt, + svrg_precond_nt, + step_size=step_size, + ) + recons_non_tof[epoch, ...] = sm_op(x_nt) +print() # %% -# TOF MLEM -# -------- +# TOF: subset operators, warm start, and SVRG +# -------------------------------------------- +# +# Identical structure to the non-TOF case. The only differences are: +# +# * Each subset attenuation operator must broadcast ``att_sino`` over the +# TOF-bin axis (zero-copy via :func:`xp.broadcast_to`). +# * Data and contamination are sliced with 4-D ``subset_slices_tof``. # -# Identical loop but using the TOF forward operator and TOF sinogram. -# Because the TOF update is more informative (each LOR contributes noise -# over the kernel width ≈ 30 mm rather than the full chord ≈ 600 mm), the -# image converges to its maximum-likelihood solution in fewer iterations. +# Because the TOF forward model localises each event to ≈ 30 mm along the +# LOR rather than the full ≈ 600 mm chord, every gradient step is more +# informative and the algorithm reaches its noise floor in fewer epochs. + +proj_tof.clear_cached_lor_endpoints() +subset_linops_tof = [] +for i in range(num_subsets): + sp = copy(proj_tof) + sp.views = subset_views[i] + att_values_k = xp.broadcast_to( + xp.expand_dims(att_sino[subset_slices_nt[i]], axis=-1), sp.out_shape + ) + att_op_k = parallelproj.operators.ElementwiseMultiplicationOperator(att_values_k) + subset_linops_tof.append( + parallelproj.operators.CompositeLinearOperator([att_op_k, sp, res_model]) + ) -full_data_fidelity_tof = C2AffineObjective( - NegPoissonLogL(y_tof), pet_lin_op_tof, contamination_tof -) +subset_adj_ones_tof = xp.zeros((num_subsets,) + img_shape, dtype=xp.float32, device=dev) +for k, op in enumerate(subset_linops_tof): + subset_adj_ones_tof[k] = op.adjoint( + xp.ones(op.out_shape, dtype=xp.float32, device=dev) + ) +adjoint_ones_tof = xp.sum(subset_adj_ones_tof, axis=0) -adjoint_ones_tof = pet_lin_op_tof.adjoint( - xp.ones(pet_lin_op_tof.out_shape, dtype=xp.float32, device=dev) -) +subset_fidelities_tof = [ + C2AffineObjective( + NegPoissonLogL(y_tof[subset_slices_tof[k]]), + subset_linops_tof[k], + contamination_tof[subset_slices_tof[k]], + ) + for k in range(num_subsets) +] + +# --- warm start: 1 OSEM epoch --- +x_tof = count_factor * xp.ones(img_shape, device=dev, dtype=xp.float32) +for k in range(num_subsets): + print(f" TOF warm-start subset {k + 1:03}/{num_subsets:03}", end="\r") + x_tof = em_update(x_tof, subset_fidelities_tof[k], subset_adj_ones_tof[k]) +print() +x_tof_warmstart_sm = sm_op(x_tof) -x_mlem_tof = count_factor * xp.ones(img_shape, device=dev, dtype=xp.float32) +# --- SVRG loop --- recons_tof = xp.ones((num_epochs,) + img_shape, device=dev, dtype=xp.float32) - -for i in range(num_epochs): - print(f"TOF MLEM epoch {(i + 1):04} / {num_epochs:04}", end="\r") - x_mlem_tof = em_update(x_mlem_tof, full_data_fidelity_tof, adjoint_ones_tof) - recons_tof[i, ...] = sm_op(x_mlem_tof) +svrg_precond_tof = x_tof / adjoint_ones_tof +stored_grads_tof, full_grad_tof = None, None + +for epoch in range(num_epochs): + if epoch % 2 == 0: + if epoch <= 4: + svrg_precond_tof = x_tof / adjoint_ones_tof + stored_grads_tof, full_grad_tof = svrg_calc_snapshot_gradients( + x_tof, subset_fidelities_tof + ) + x_tof = xp.clip(x_tof - svrg_precond_tof * full_grad_tof, 0, None) + + for k in range(num_subsets): + print( + f" TOF SVRG epoch {epoch + 1:02}/{num_epochs:02}," + f" subset {k + 1:03}/{num_subsets:03}", + end="\r", + ) + x_tof = svrg_update( + x_tof, + k, + subset_fidelities_tof, + stored_grads_tof, + full_grad_tof, + svrg_precond_tof, + step_size=step_size, + ) + recons_tof[epoch, ...] = sm_op(x_tof) +print() # %% @@ -397,26 +557,37 @@ def em_update( roi_std_non_tof = np.array([float(x[:, :, 0][RHO < 25].std()) for x in recons_non_tof]) roi_std_tof = np.array([float(x[:, :, 0][RHO < 25].std()) for x in recons_tof]) -epochs = np.arange(1, 1 + num_epochs) +# prepend warm-start (epoch 0) so the x-axis starts at 0 +roi_std_non_tof = np.concatenate( + [[float(x_nt_warmstart_sm[:, :, 0][RHO < 25].std())], roi_std_non_tof] +) +roi_std_tof = np.concatenate( + [[float(x_tof_warmstart_sm[:, :, 0][RHO < 25].std())], roi_std_tof] +) +epochs = np.arange(0, num_epochs + 1) # 0 = OSEM warm start # %% # Visualisation # ------------- # -# The four-panel figure summarises the comparison: +# The eight-panel figure summarises the comparison: # -# * **Top-left**: std.dev in the central 25 mm ROI vs. MLEM iteration for -# non-TOF (orange) and TOF (blue). Note how TOF rises *and falls* faster; -# comparing at a fixed early iteration can give the wrong conclusion. +# * **Top-left**: std.dev in the central 25 mm ROI vs. SVRG epoch (epoch 0 +# is the OSEM warm start) for non-TOF (orange) and TOF (blue). Note how +# TOF rises *and falls* faster; comparing at a fixed early epoch can give +# the wrong conclusion. # * **Bottom-left**: ratio of std.devs (non-TOF / TOF). The ratio -# increases with iteration count and stabilises above 1, quantifying the +# increases with epoch count and stabilises above 1, quantifying the # asymptotic noise advantage of TOF. -# * **Top-right / bottom-right**: final smoothed images after ``num_epochs`` -# iterations. Visual noise in the uniform disk is lower for TOF. +# * **2nd column**: smoothed warm-start images (epoch 0). +# * **3rd column**: smoothed images after 1 SVRG epoch. At this early +# stage TOF may look noisier than non-TOF because it has converged further. +# * **Right column**: final smoothed images after ``num_epochs`` SVRG +# epochs. Visual noise in the uniform disk is lower for TOF. ims = dict(vmin=0, vmax=xp.max(recons_non_tof), cmap="Greys") -fig, ax = plt.subplots(2, 2, figsize=(6, 6), layout="constrained", sharex="col") +fig, ax = plt.subplots(2, 4, figsize=(12, 6), layout="constrained", sharex="col") ax[0, 0].plot(epochs, roi_std_non_tof, label="non-TOF", color="tab:orange") ax[0, 0].plot( epochs, roi_std_tof, label=f"TOF ({fwhm_tof_mm:.0f} mm FWHM)", color="tab:blue" @@ -424,7 +595,7 @@ def em_update( ax[0, 0].legend(fontsize=8) ax[1, 0].plot(epochs, roi_std_non_tof / roi_std_tof, color="tab:green") ax[1, 0].axhline(1.0, color="gray", ls=":", lw=0.8) -ax[1, 0].set_xlabel("MLEM iteration") +ax[1, 0].set_xlabel(f"SVRG epoch ({num_subsets} subsets)") ax[0, 0].set_ylabel("std.dev in central ROI") ax[1, 0].set_ylabel("std.dev ratio (non-TOF / TOF)") ax[0, 0].set_title( @@ -434,21 +605,47 @@ def em_update( ax[0, 0].grid(ls=":") ax[1, 0].grid(ls=":") -ax[0, 1].imshow(to_numpy_array(recons_non_tof[-1, :, :, 0]), **ims) -ax[1, 1].imshow(to_numpy_array(recons_tof[-1, :, :, 0]), **ims) +# warm-start images (epoch 0) +ax[0, 1].imshow(to_numpy_array(x_nt_warmstart_sm[:, :, 0]), **ims) +ax[1, 1].imshow(to_numpy_array(x_tof_warmstart_sm[:, :, 0]), **ims) ax[0, 1].set_title( - f"non-TOF ({num_epochs} iter)\nstd.dev = {roi_std_non_tof[-1]:.4f}", + f"non-TOF (epoch 0)\nstd.dev = {roi_std_non_tof[0]:.4f}", fontsize=8, ) ax[1, 1].set_title( - f"TOF {fwhm_tof_mm:.0f} mm ({num_epochs} iter)\nstd.dev = {roi_std_tof[-1]:.4f}", + f"TOF {fwhm_tof_mm:.0f} mm (epoch 0)\nstd.dev = {roi_std_tof[0]:.4f}", + fontsize=8, +) + +# epoch 1 images (index 0 in recons arrays) +ax[0, 2].imshow(to_numpy_array(recons_non_tof[0, :, :, 0]), **ims) +ax[1, 2].imshow(to_numpy_array(recons_tof[0, :, :, 0]), **ims) +ax[0, 2].set_title( + f"non-TOF (epoch 1)\nstd.dev = {roi_std_non_tof[1]:.4f}", + fontsize=8, +) +ax[1, 2].set_title( + f"TOF {fwhm_tof_mm:.0f} mm (epoch 1)\nstd.dev = {roi_std_tof[1]:.4f}", fontsize=8, ) -for a in ax[:, 1]: + +# final-epoch images +ax[0, 3].imshow(to_numpy_array(recons_non_tof[-1, :, :, 0]), **ims) +ax[1, 3].imshow(to_numpy_array(recons_tof[-1, :, :, 0]), **ims) +ax[0, 3].set_title( + f"non-TOF (epoch {num_epochs})\nstd.dev = {roi_std_non_tof[-1]:.4f}", + fontsize=8, +) +ax[1, 3].set_title( + f"TOF {fwhm_tof_mm:.0f} mm (epoch {num_epochs})\nstd.dev = {roi_std_tof[-1]:.4f}", + fontsize=8, +) + +for a in ax[:, 1:].flat: a.set_axis_off() fig.suptitle( - f"TOF variance reduction — uniform cylinder Ø {2*cylinder_radius_mm:.0f} mm", + f"TOF variance reduction - uniform cylinder Ø {2*cylinder_radius_mm:.0f} mm - 9mm post filter", fontsize=9, ) fig.show() From 710bb0a75f7cbae7eed7e9818992b668d4ba8213 Mon Sep 17 00:00:00 2001 From: Georg Schramm Date: Wed, 27 May 2026 09:14:25 +0200 Subject: [PATCH 04/17] fix cross ref --- docs/examples/03_algorithms/05_run_tof_variance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/examples/03_algorithms/05_run_tof_variance.py b/docs/examples/03_algorithms/05_run_tof_variance.py index 359a9e3e..a56f1312 100644 --- a/docs/examples/03_algorithms/05_run_tof_variance.py +++ b/docs/examples/03_algorithms/05_run_tof_variance.py @@ -324,7 +324,7 @@ # --------------------- # # These two functions implement SVRG exactly as in -# :ref:`sphx_glr_examples_03_algorithms_00_run_mlem_osem_svrg.py`. +# :ref:`sphx_glr_03_algorithms_00_run_mlem_osem_svrg.py`. # ``svrg_calc_snapshot_gradients`` computes and stores all per-subset # gradients at the current anchor point; ``svrg_update`` performs a single # variance-reduced subset step. From 189e69006d1e8518cd3c84ea7b5923ba5c7f0d5b Mon Sep 17 00:00:00 2001 From: Georg Schramm Date: Wed, 27 May 2026 09:27:04 +0200 Subject: [PATCH 05/17] wip cross ref --- docs/examples/03_algorithms/05_run_tof_variance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/examples/03_algorithms/05_run_tof_variance.py b/docs/examples/03_algorithms/05_run_tof_variance.py index a56f1312..8deccfee 100644 --- a/docs/examples/03_algorithms/05_run_tof_variance.py +++ b/docs/examples/03_algorithms/05_run_tof_variance.py @@ -324,7 +324,7 @@ # --------------------- # # These two functions implement SVRG exactly as in -# :ref:`sphx_glr_03_algorithms_00_run_mlem_osem_svrg.py`. +# :ref:`sphx_glr_auto_examples_03_algorithms_00_run_mlem_osem_svrg.py`. # ``svrg_calc_snapshot_gradients`` computes and stores all per-subset # gradients at the current anchor point; ``svrg_update`` performs a single # variance-reduced subset step. From fd1252b5729a55186c0c6b54b18f29b6f9f43122 Mon Sep 17 00:00:00 2001 From: Georg Schramm Date: Wed, 27 May 2026 15:15:56 +0200 Subject: [PATCH 06/17] wip logcosh prior --- src/parallelproj/functions.py | 76 +++++++++++++++++++++++++++++++---- tests/test_functions.py | 69 +++++++++++++++++++++++++++++++ 2 files changed, 138 insertions(+), 7 deletions(-) diff --git a/src/parallelproj/functions.py b/src/parallelproj/functions.py index 338bc78e..550766b6 100644 --- a/src/parallelproj/functions.py +++ b/src/parallelproj/functions.py @@ -1,3 +1,4 @@ +import math from abc import ABC, abstractmethod from collections.abc import Sequence @@ -592,10 +593,7 @@ def _call(self, x: Array) -> float: safe_x = xp.where(self._mask, x, xp.ones_like(x)) return float( xp.sum( - x - - xp.where( - self._mask, self._data * xp.log(safe_x), xp.zeros_like(x) - ) + x - xp.where(self._mask, self._data * xp.log(safe_x), xp.zeros_like(x)) ) ) @@ -607,9 +605,7 @@ def _gradient(self, x: Array) -> Array: def _hessian_diag_vec_prod(self, x: Array, v: Array) -> Array: xp = get_namespace(x) safe_x = xp.where(self._mask, x, xp.ones_like(x)) - return ( - xp.where(self._mask, self._data / (safe_x**2), xp.zeros_like(x)) * v - ) + return xp.where(self._mask, self._data / (safe_x**2), xp.zeros_like(x)) * v def _prox_convex_conj(self, y: Array, sigma: float | Array) -> Array: """Proximal operator of the convex conjugate, safe for virtual bins. @@ -759,6 +755,72 @@ def _prox_convex_conj(self, y: Array, sigma: float | Array) -> Array: return self._weights * numerator / (self._weights + sigma) +class LogCosh(C2Function): + """Sum of log-cosh values, a smooth approximation to the L1 norm. + + Implements + + .. math:: + + f(x) = \\sum_i \\log(\\cosh(x_i)) + + which satisfies :math:`f(0) = 0` and behaves like + :math:`\\sum_i |x_i| - n \\log 2` for large :math:`|x_i|`. + + Gradient: + + .. math:: + + \\nabla f(x)_i = \\tanh(x_i) + + Diagonal Hessian-vector product: + + .. math:: + + \\operatorname{diag}(H_f(x))_i \\cdot v_i + = \\operatorname{sech}^2(x_i)\\, v_i + = (1 - \\tanh^2(x_i))\\, v_i + + The function value is computed via the numerically stable identity + + .. math:: + + \\log(\\cosh(x)) = |x| + \\log(1 + e^{-2|x|}) - \\log 2 + + which avoids the overflow that :math:`\\cosh(x) = (e^x + e^{-x})/2` + would cause for large :math:`|x|`. + + Parameters + ---------- + beta : float, optional + Multiplicative scale factor :math:`\\beta`. Defaults to ``1.0``. + """ + + def __init__(self, beta: float = 1.0): + self._log2 = math.log(2) + super().__init__(beta) + + def _call(self, x: Array) -> float: + xp = get_namespace(x) + ax = xp.abs(x) + return ( + float(xp.sum(ax + xp.log(1 + xp.exp(-2 * ax)))) + - math.prod(x.shape) * self._log2 + ) + + def _gradient(self, x: Array) -> Array: + xp = get_namespace(x) + return xp.tanh(x) + + def _call_and_gradient(self, x: Array) -> tuple[float, Array]: + return self._call(x), self._gradient(x) + + def _hessian_diag_vec_prod(self, x: Array, v: Array) -> Array: + xp = get_namespace(x) + t = xp.tanh(x) + return (1 - t**2) * v + + class SumC1Function(C1Function): """Sum of an arbitrary number of :class:`C1Function` objects. diff --git a/tests/test_functions.py b/tests/test_functions.py index bf532f2e..5b1c6711 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -326,6 +326,75 @@ def test_half_sq_l2_beta_scaling(xp: ModuleType, dev: str): assert allclose(f3.gradient(x), 3.0 * f1.gradient(x)) +# --------------------------------------------------------------------------- +# LogCosh +# --------------------------------------------------------------------------- + +_X_LC_NP = _np.asarray([1.0, -2.0, 0.5]) +_V_LC_NP = _np.asarray([1.0, -1.0, 2.0]) + + +def test_log_cosh_call_at_zero(xp: ModuleType, dev: str): + """f(0) must be exactly (up to float rounding) zero.""" + x = xp.asarray(_np.zeros(3), device=dev) + f = ppf.LogCosh() + assert abs(f(x)) < 1e-6 + + +def test_log_cosh_call(xp: ModuleType, dev: str): + x = xp.asarray(_X_LC_NP, device=dev) + f = ppf.LogCosh() + expected = float(_np.sum(_np.log(_np.cosh(_X_LC_NP)))) + assert abs(f(x) - expected) < 1e-5 + + +def test_log_cosh_gradient(xp: ModuleType, dev: str): + x = xp.asarray(_X_LC_NP, device=dev) + f = ppf.LogCosh() + grad = f.gradient(x) + fd_grad = finite_diff_gradient(f, _X_LC_NP, xp, dev) + assert allclose(grad, fd_grad, atol=1e-4, rtol=1e-4) + + +def test_log_cosh_call_and_gradient(xp: ModuleType, dev: str): + x = xp.asarray(_X_LC_NP, device=dev) + f = ppf.LogCosh() + val, grad = f.call_and_gradient(x) + assert abs(val - f(x)) < 1e-8 + assert allclose(grad, f.gradient(x)) + + +def test_log_cosh_hessian_diag_vec_prod(xp: ModuleType, dev: str): + """Hessian diagonal is sech^2(x) = 1 - tanh^2(x).""" + x = xp.asarray(_X_LC_NP, device=dev) + v = xp.asarray(_V_LC_NP, device=dev) + f = ppf.LogCosh() + hv = f.hessian_diag_vec_prod(x, v) + expected = xp.asarray( + (1 - _np.tanh(_X_LC_NP) ** 2) * _V_LC_NP, device=dev + ) + assert allclose(hv, expected, atol=1e-5) + + +def test_log_cosh_beta_scaling(xp: ModuleType, dev: str): + x = xp.asarray(_X_LC_NP, device=dev) + f1 = ppf.LogCosh(beta=1.0) + f3 = ppf.LogCosh(beta=3.0) + assert abs(f3(x) - 3.0 * f1(x)) < 1e-8 + assert allclose(f3.gradient(x), 3.0 * f1.gradient(x)) + + +def test_log_cosh_overflow_safe(xp: ModuleType, dev: str): + """Naive cosh(100) overflows float32; our stable form must not.""" + import math + + x = xp.asarray(_np.asarray([100.0, -100.0]), dtype=xp.float32, device=dev) + f = ppf.LogCosh() + expected = 2 * (100.0 - math.log(2.0)) + assert abs(f(x) - expected) < 1e-3 + assert all(xp.isfinite(f.gradient(x))) + + # --------------------------------------------------------------------------- # SumC1Function / SumC2Function (via __add__) # --------------------------------------------------------------------------- From 7e5f19cffb6ce5da25e2ea42b8d3521d74e4894a Mon Sep 17 00:00:00 2001 From: Georg Schramm Date: Wed, 27 May 2026 15:23:30 +0200 Subject: [PATCH 07/17] add transition parameter delta --- src/parallelproj/functions.py | 52 +++++++++++++++++++++++----------- tests/test_functions.py | 53 ++++++++++++++++++++++++----------- 2 files changed, 72 insertions(+), 33 deletions(-) diff --git a/src/parallelproj/functions.py b/src/parallelproj/functions.py index 550766b6..d6550fbd 100644 --- a/src/parallelproj/functions.py +++ b/src/parallelproj/functions.py @@ -756,69 +756,89 @@ def _prox_convex_conj(self, y: Array, sigma: float | Array) -> Array: class LogCosh(C2Function): - """Sum of log-cosh values, a smooth approximation to the L1 norm. + """Sum of scaled log-cosh values, a smooth approximation to the L1 norm. Implements .. math:: - f(x) = \\sum_i \\log(\\cosh(x_i)) + f(x) = \\sum_i \\log\\!\\left(\\cosh\\!\\left(\\frac{x_i}{\\delta}\\right)\\right) - which satisfies :math:`f(0) = 0` and behaves like - :math:`\\sum_i |x_i| - n \\log 2` for large :math:`|x_i|`. + where :math:`\\delta > 0` is a transition scale parameter (default 1). + The function satisfies :math:`f(0) = 0` and has two limiting regimes: + + * **Quadratic** for :math:`|x_i| \\ll \\delta`: + :math:`\\log(\\cosh(u)) \\approx u^2/2`, so + :math:`f(x) \\approx \\tfrac{1}{2\\delta^2}\\sum_i x_i^2`. + * **Linear** for :math:`|x_i| \\gg \\delta`: + :math:`f(x) \\approx \\tfrac{1}{\\delta}\\sum_i |x_i| - n\\log 2`. Gradient: .. math:: - \\nabla f(x)_i = \\tanh(x_i) + \\nabla f(x)_i = \\frac{1}{\\delta}\\,\\tanh\\!\\left(\\frac{x_i}{\\delta}\\right) Diagonal Hessian-vector product: .. math:: \\operatorname{diag}(H_f(x))_i \\cdot v_i - = \\operatorname{sech}^2(x_i)\\, v_i - = (1 - \\tanh^2(x_i))\\, v_i + = \\frac{1}{\\delta^2}\\,\\operatorname{sech}^2\\!\\left(\\frac{x_i}{\\delta}\\right) v_i + = \\frac{1 - \\tanh^2(x_i/\\delta)}{\\delta^2}\\, v_i The function value is computed via the numerically stable identity .. math:: - \\log(\\cosh(x)) = |x| + \\log(1 + e^{-2|x|}) - \\log 2 + \\log(\\cosh(z)) = |z| + \\log(1 + e^{-2|z|}) - \\log 2, \\quad z = x/\\delta - which avoids the overflow that :math:`\\cosh(x) = (e^x + e^{-x})/2` - would cause for large :math:`|x|`. + which avoids the overflow that :math:`\\cosh(z) = (e^z + e^{-z})/2` + would cause for large :math:`|z|`. Parameters ---------- + delta : float or None, optional + Transition scale :math:`\\delta > 0`. ``None`` (default) is + equivalent to :math:`\\delta = 1` but skips the division entirely. beta : float, optional Multiplicative scale factor :math:`\\beta`. Defaults to ``1.0``. """ - def __init__(self, beta: float = 1.0): + def __init__(self, delta: float | None = None, beta: float = 1.0): + self._delta = delta self._log2 = math.log(2) super().__init__(beta) + @property + def delta(self) -> float | None: + """Transition scale :math:`\\delta`.""" + return self._delta + def _call(self, x: Array) -> float: xp = get_namespace(x) - ax = xp.abs(x) + z = x if self._delta is None else x / self._delta + az = xp.abs(z) return ( - float(xp.sum(ax + xp.log(1 + xp.exp(-2 * ax)))) + float(xp.sum(az + xp.log(1 + xp.exp(-2 * az)))) - math.prod(x.shape) * self._log2 ) def _gradient(self, x: Array) -> Array: xp = get_namespace(x) - return xp.tanh(x) + z = x if self._delta is None else x / self._delta + t = xp.tanh(z) + return t if self._delta is None else t / self._delta def _call_and_gradient(self, x: Array) -> tuple[float, Array]: return self._call(x), self._gradient(x) def _hessian_diag_vec_prod(self, x: Array, v: Array) -> Array: xp = get_namespace(x) - t = xp.tanh(x) - return (1 - t**2) * v + z = x if self._delta is None else x / self._delta + t = xp.tanh(z) + h = 1 - t**2 + return h * v if self._delta is None else h * v / self._delta**2 class SumC1Function(C1Function): diff --git a/tests/test_functions.py b/tests/test_functions.py index 5b1c6711..6d7f1fab 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -332,54 +332,67 @@ def test_half_sq_l2_beta_scaling(xp: ModuleType, dev: str): _X_LC_NP = _np.asarray([1.0, -2.0, 0.5]) _V_LC_NP = _np.asarray([1.0, -1.0, 2.0]) +_DELTA = 2.0 def test_log_cosh_call_at_zero(xp: ModuleType, dev: str): - """f(0) must be exactly (up to float rounding) zero.""" + """f(0) must be exactly (up to float rounding) zero for any delta.""" x = xp.asarray(_np.zeros(3), device=dev) - f = ppf.LogCosh() - assert abs(f(x)) < 1e-6 + for f in (ppf.LogCosh(), ppf.LogCosh(delta=_DELTA)): + assert abs(f(x)) < 1e-6 def test_log_cosh_call(xp: ModuleType, dev: str): x = xp.asarray(_X_LC_NP, device=dev) + f = ppf.LogCosh() expected = float(_np.sum(_np.log(_np.cosh(_X_LC_NP)))) assert abs(f(x) - expected) < 1e-5 + f_d = ppf.LogCosh(delta=_DELTA) + expected_d = float(_np.sum(_np.log(_np.cosh(_X_LC_NP / _DELTA)))) + assert abs(f_d(x) - expected_d) < 1e-5 + def test_log_cosh_gradient(xp: ModuleType, dev: str): x = xp.asarray(_X_LC_NP, device=dev) + f = ppf.LogCosh() - grad = f.gradient(x) - fd_grad = finite_diff_gradient(f, _X_LC_NP, xp, dev) - assert allclose(grad, fd_grad, atol=1e-4, rtol=1e-4) + assert allclose(f.gradient(x), finite_diff_gradient(f, _X_LC_NP, xp, dev), atol=1e-4, rtol=1e-4) + + f_d = ppf.LogCosh(delta=_DELTA) + assert allclose(f_d.gradient(x), finite_diff_gradient(f_d, _X_LC_NP, xp, dev), atol=1e-4, rtol=1e-4) def test_log_cosh_call_and_gradient(xp: ModuleType, dev: str): x = xp.asarray(_X_LC_NP, device=dev) - f = ppf.LogCosh() - val, grad = f.call_and_gradient(x) - assert abs(val - f(x)) < 1e-8 - assert allclose(grad, f.gradient(x)) + + for f in (ppf.LogCosh(), ppf.LogCosh(delta=_DELTA)): + val, grad = f.call_and_gradient(x) + assert abs(val - f(x)) < 1e-8 + assert allclose(grad, f.gradient(x)) def test_log_cosh_hessian_diag_vec_prod(xp: ModuleType, dev: str): - """Hessian diagonal is sech^2(x) = 1 - tanh^2(x).""" + """Hessian diagonal is sech^2(x/delta)/delta^2 = (1 - tanh^2(x/delta))/delta^2.""" x = xp.asarray(_X_LC_NP, device=dev) v = xp.asarray(_V_LC_NP, device=dev) + f = ppf.LogCosh() - hv = f.hessian_diag_vec_prod(x, v) - expected = xp.asarray( - (1 - _np.tanh(_X_LC_NP) ** 2) * _V_LC_NP, device=dev + expected = xp.asarray((1 - _np.tanh(_X_LC_NP) ** 2) * _V_LC_NP, device=dev) + assert allclose(f.hessian_diag_vec_prod(x, v), expected, atol=1e-5) + + f_d = ppf.LogCosh(delta=_DELTA) + expected_d = xp.asarray( + (1 - _np.tanh(_X_LC_NP / _DELTA) ** 2) / _DELTA**2 * _V_LC_NP, device=dev ) - assert allclose(hv, expected, atol=1e-5) + assert allclose(f_d.hessian_diag_vec_prod(x, v), expected_d, atol=1e-5) def test_log_cosh_beta_scaling(xp: ModuleType, dev: str): x = xp.asarray(_X_LC_NP, device=dev) - f1 = ppf.LogCosh(beta=1.0) - f3 = ppf.LogCosh(beta=3.0) + f1 = ppf.LogCosh(delta=_DELTA, beta=1.0) + f3 = ppf.LogCosh(delta=_DELTA, beta=3.0) assert abs(f3(x) - 3.0 * f1(x)) < 1e-8 assert allclose(f3.gradient(x), 3.0 * f1.gradient(x)) @@ -389,11 +402,17 @@ def test_log_cosh_overflow_safe(xp: ModuleType, dev: str): import math x = xp.asarray(_np.asarray([100.0, -100.0]), dtype=xp.float32, device=dev) + f = ppf.LogCosh() expected = 2 * (100.0 - math.log(2.0)) assert abs(f(x) - expected) < 1e-3 assert all(xp.isfinite(f.gradient(x))) + f_d = ppf.LogCosh(delta=0.5) + expected_d = 2 * (100.0 / 0.5 - math.log(2.0)) + assert abs(f_d(x) - expected_d) < 1e-1 + assert all(xp.isfinite(f_d.gradient(x))) + # --------------------------------------------------------------------------- # SumC1Function / SumC2Function (via __add__) From 35c94a80991fba13f174e7384deae9e036093986 Mon Sep 17 00:00:00 2001 From: Georg Schramm Date: Wed, 27 May 2026 16:08:37 +0200 Subject: [PATCH 08/17] wip logcosh SGD / SVRG example --- .../examples/03_algorithms/01_run_sgd_svrg.py | 80 +++++++++++++------ 1 file changed, 54 insertions(+), 26 deletions(-) diff --git a/docs/examples/03_algorithms/01_run_sgd_svrg.py b/docs/examples/03_algorithms/01_run_sgd_svrg.py index d1c09657..fe1ca694 100644 --- a/docs/examples/03_algorithms/01_run_sgd_svrg.py +++ b/docs/examples/03_algorithms/01_run_sgd_svrg.py @@ -1,6 +1,6 @@ """ -Convergence comparison: SGD vs SVRG with quadratic regularization -================================================================= +Convergence comparison: SGD vs SVRG with logcosh regularization +================================================================ This example compares the convergence speed (per epoch) of two algorithms for minimising the regularised negative Poisson log-likelihood @@ -10,12 +10,17 @@ + \\beta \\, R(x), \\qquad \\bar{y}(x) = A x + s -where the quadratic penalty is +where the edge-preserving logcosh penalty is .. math:: - R(x) = \\frac{1}{2} \\| G x \\|_2^2 - -and :math:`G` is the finite forward-difference operator. + R(x) = \\sum_i \\log\\!\\cosh\\!\\left(\\frac{(Gx)_i}{\\delta}\\right) + +and :math:`G` is the finite forward-difference operator. The scale +:math:`\\delta` sets the transition between the quadratic regime +(:math:`|(Gx)_i| \\ll \\delta`) and the linear (L1-like) regime +(:math:`|(Gx)_i| \\gg \\delta`). Gradients at true edges are placed +in the linear regime by choosing :math:`\\delta` well below the typical +edge gradient in the ground truth image. The objective is decomposed into :math:`m` subset functions .. math:: @@ -56,7 +61,7 @@ from parallelproj import to_numpy_array, Array from parallelproj.functions import ( NegPoissonLogL, - HalfSquaredL2Deviation, + LogCosh, C2AffineObjective, C1Function, ) @@ -74,17 +79,22 @@ # %% # number of subsets for SGD and SVRG -num_subsets = 24 +num_subsets = 12 # if run on a CPU limit the number of epochs -num_epochs = (120 if dev == "cpu" else 1200) // num_subsets +num_epochs = (120 if dev == "cpu" else 240) // num_subsets # regularisation weight beta -beta = 3.0 +beta = 0.3 +# delta value relative to max of ground truth image for logcosh prior +delta_rel = 0.1 # step size for SGD and SVRG updates step_size = 1.0 +# factor that scales the ground truth image (also reconstruction) and the number of counts +count_factor = 1.0 + # %% # Setup of the forward model :math:`\bar{y}(x) = A x + s` # -------------------------------------------------------- @@ -122,7 +132,7 @@ ) # setup a simple test image containing a few "hot rods" -x_true = elliptic_cylinder_phantom( +x_true = count_factor * elliptic_cylinder_phantom( xp, dev, image_shape=img_shape, voxel_size=voxel_size ) @@ -240,20 +250,28 @@ # Regularisation and subset objective functions # --------------------------------------------- # -# The quadratic penalty :math:`R(x) = \frac{1}{2} \| G x \|_2^2` is built -# from the :class:`.FiniteForwardDifference` operator :math:`G`. +# The logcosh penalty :math:`R(x) = \sum_i \log\cosh((Gx)_i/\delta)` is +# built from the :class:`.FiniteForwardDifference` operator :math:`G` and +# :class:`.LogCosh`. # The full regulariser ``reg`` (weight :math:`\beta`) is used only for # the total objective evaluation. Each subset function # # .. math:: # f_k(x) = \sum_{i \in S_k} \left( \bar{y}_i(x) - y_i \log \bar{y}_i(x) \right) + \frac{\beta}{m} R(x) # -# is formed by adding a :class:`.HalfSquaredL2Deviation` scaled by -# :math:`\beta / m` to the subset data fidelity, so that -# :math:`\sum_k f_k(x) = F(x)`. +# is formed by adding a :class:`.LogCosh` scaled by :math:`\beta / m` to +# the subset data fidelity, so that :math:`\sum_k f_k(x) = F(x)`. +# +# ``delta`` is set to one third of the median nonzero finite-difference +# magnitude in the ground truth image. This places typical edge gradients +# (~3× delta) firmly in the linear regime of logcosh while smooth-region +# gradients (~0) remain quadratic. G = parallelproj.operators.FiniteForwardDifference(pet_lin_op.in_shape) -reg = C2AffineObjective(HalfSquaredL2Deviation(beta=beta), G) + +delta = float(xp.max(x_true)) * delta_rel + +reg = C2AffineObjective(LogCosh(delta=delta, beta=beta), G) # %% # Setup of objective functions and sensitivity image @@ -268,7 +286,7 @@ ) # reg/m term shared by all subset objectives -reg_per_subset = C2AffineObjective(HalfSquaredL2Deviation(beta=beta / num_subsets), G) +reg_per_subset = C2AffineObjective(LogCosh(delta=delta, beta=beta / num_subsets), G) # f_k = data_fidelity_k + (beta/m) * R(x) subset_objectives = [ @@ -318,8 +336,11 @@ # .. math:: # x^+ = \left(x - D\, m\,\nabla f_k(x)\right)_+. -df_sgd = xp.zeros(num_epochs, dtype=xp.float32, device=dev) +df_sgd = xp.zeros(num_epochs + 1, dtype=xp.float32, device=dev) x_sgd = xp.asarray(x_init, copy=True) +df_sgd[0] = total_objective(x_sgd) +sgd_recons = xp.zeros((num_epochs + 1,) + img_shape) +sgd_recons[0, ...] = x_sgd for i in range(num_epochs): if i % 2 == 0 and i <= 4: @@ -334,7 +355,8 @@ approx_grad = num_subsets * subset_objectives[k].gradient(x_sgd) x_sgd = xp.clip(x_sgd - step_size * sgd_precond * approx_grad, 0, None) - df_sgd[i] = total_objective(x_sgd) + df_sgd[i + 1] = total_objective(x_sgd) + sgd_recons[i + 1, ...] = x_sgd print() # %% @@ -389,8 +411,11 @@ def svrg_update( x_svrg = xp.asarray(x_init, copy=True) +svrg_recons = xp.zeros((num_epochs + 1,) + img_shape) +svrg_recons[0, ...] = x_svrg -df_svrg = xp.zeros(num_epochs, dtype=xp.float32, device=dev) +df_svrg = xp.zeros(num_epochs + 1, dtype=xp.float32, device=dev) +df_svrg[0] = total_objective(x_svrg) for epoch in range(num_epochs): if epoch % 2 == 0: @@ -419,7 +444,8 @@ def svrg_update( step_size=step_size, ) - df_svrg[epoch] = total_objective(x_svrg) + df_svrg[epoch + 1] = total_objective(x_svrg) + svrg_recons[epoch + 1, ...] = x_svrg # %% # Convergence comparison @@ -431,10 +457,12 @@ def svrg_update( # an anchor phase costs two full data passes (snapshot + subset updates), # and one full pass otherwise. -epochs = np.arange(1, num_epochs + 1) +epochs = np.arange(num_epochs + 1) osem_passes = epochs.copy() -svrg_passes_per_epoch = np.where(np.arange(num_epochs) % 2 == 0, 2, 1) +svrg_passes_per_epoch = np.concatenate( + [[0], np.where(np.arange(num_epochs) % 2 == 0, 2, 1)] +) svrg_cumulative_passes = np.cumsum(svrg_passes_per_epoch) df_min = min(float(xp.min(df_sgd)), float(xp.min(df_svrg))) @@ -471,12 +499,12 @@ def svrg_update( # %% fig, axs, widgets = show_vol_cuts( - to_numpy_array(x_sgd), voxel_size=voxel_size, fig_title="SGD result" + to_numpy_array(sgd_recons), voxel_size=voxel_size, fig_title="SGD result" ) fig.show() # %% fig2, axs2, widgets = show_vol_cuts( - to_numpy_array(x_svrg), voxel_size=voxel_size, fig_title="SVRG result" + to_numpy_array(svrg_recons), voxel_size=voxel_size, fig_title="SVRG result" ) fig2.show() From f5d2d0323b64ba2c7519f34a4b55ed35a7809f0d Mon Sep 17 00:00:00 2001 From: Georg Schramm Date: Wed, 27 May 2026 16:41:41 +0200 Subject: [PATCH 09/17] wip log(cosh) --- .../examples/03_algorithms/01_run_sgd_svrg.py | 38 ++++++++++++------- src/parallelproj/functions.py | 38 +++++++++++-------- tests/test_functions.py | 11 ++++-- 3 files changed, 54 insertions(+), 33 deletions(-) diff --git a/docs/examples/03_algorithms/01_run_sgd_svrg.py b/docs/examples/03_algorithms/01_run_sgd_svrg.py index fe1ca694..09e19ff7 100644 --- a/docs/examples/03_algorithms/01_run_sgd_svrg.py +++ b/docs/examples/03_algorithms/01_run_sgd_svrg.py @@ -13,14 +13,22 @@ where the edge-preserving logcosh penalty is .. math:: - R(x) = \\sum_i \\log\\!\\cosh\\!\\left(\\frac{(Gx)_i}{\\delta}\\right) - -and :math:`G` is the finite forward-difference operator. The scale -:math:`\\delta` sets the transition between the quadratic regime -(:math:`|(Gx)_i| \\ll \\delta`) and the linear (L1-like) regime -(:math:`|(Gx)_i| \\gg \\delta`). Gradients at true edges are placed -in the linear regime by choosing :math:`\\delta` well below the typical -edge gradient in the ground truth image. + R(x) = \\delta \\sum_i \\log\\!\\cosh\\!\\left(\\frac{(Gx)_i}{\\delta}\\right) + +and :math:`G` is the finite forward-difference operator. The :math:`\\delta` +prefactor ensures the asymptotic gradient magnitude equals 1 regardless of +:math:`\\delta`, so the regularisation strength :math:`\\beta` retains the +same meaning across different choices of :math:`\\delta`. The scale +:math:`\\delta` itself controls the transition between two regimes: + +* **Quadratic** for :math:`|(Gx)_i| \\ll \\delta`: + :math:`R(x) \\approx \\tfrac{1}{2\\delta}\\|Gx\\|_2^2`. +* **Linear** for :math:`|(Gx)_i| \\gg \\delta`: + :math:`R(x) \\approx \\|Gx\\|_1 - n\\,\\delta\\log 2 \\approx \\|Gx\\|_1`. + +Setting :math:`\\delta` well below the typical edge gradient places true +edges in the linear regime (edge-preserving) while penalising smooth-region +deviations quadratically. The objective is decomposed into :math:`m` subset functions .. math:: @@ -85,7 +93,7 @@ num_epochs = (120 if dev == "cpu" else 240) // num_subsets # regularisation weight beta -beta = 0.3 +beta = 1.0 # delta value relative to max of ground truth image for logcosh prior delta_rel = 0.1 @@ -250,7 +258,8 @@ # Regularisation and subset objective functions # --------------------------------------------- # -# The logcosh penalty :math:`R(x) = \sum_i \log\cosh((Gx)_i/\delta)` is +# The logcosh penalty +# :math:`R(x) = \delta \sum_i \log\cosh\!\left((Gx)_i/\delta\right)` is # built from the :class:`.FiniteForwardDifference` operator :math:`G` and # :class:`.LogCosh`. # The full regulariser ``reg`` (weight :math:`\beta`) is used only for @@ -262,10 +271,11 @@ # is formed by adding a :class:`.LogCosh` scaled by :math:`\beta / m` to # the subset data fidelity, so that :math:`\sum_k f_k(x) = F(x)`. # -# ``delta`` is set to one third of the median nonzero finite-difference -# magnitude in the ground truth image. This places typical edge gradients -# (~3× delta) firmly in the linear regime of logcosh while smooth-region -# gradients (~0) remain quadratic. +# ``delta`` is set to ``delta_rel`` times the maximum of the ground truth +# image. With ``delta_rel = 0.1`` edges with gradient equal to the image +# maximum have :math:`|(Gx)|/\delta = 10`, placing them firmly in the +# linear regime (:math:`\tanh(10) \approx 1`), while smooth-region +# gradients near zero remain quadratic. G = parallelproj.operators.FiniteForwardDifference(pet_lin_op.in_shape) diff --git a/src/parallelproj/functions.py b/src/parallelproj/functions.py index d6550fbd..54e8ecac 100644 --- a/src/parallelproj/functions.py +++ b/src/parallelproj/functions.py @@ -762,36 +762,41 @@ class LogCosh(C2Function): .. math:: - f(x) = \\sum_i \\log\\!\\left(\\cosh\\!\\left(\\frac{x_i}{\\delta}\\right)\\right) + f(x) = \\delta \\sum_i \\log\\!\\left(\\cosh\\!\\left(\\frac{x_i}{\\delta}\\right)\\right) where :math:`\\delta > 0` is a transition scale parameter (default 1). The function satisfies :math:`f(0) = 0` and has two limiting regimes: * **Quadratic** for :math:`|x_i| \\ll \\delta`: - :math:`\\log(\\cosh(u)) \\approx u^2/2`, so - :math:`f(x) \\approx \\tfrac{1}{2\\delta^2}\\sum_i x_i^2`. + :math:`\\delta\\log(\\cosh(u)) \\approx u^2/2`, so + :math:`f(x) \\approx \\tfrac{1}{2\\delta}\\sum_i x_i^2`. * **Linear** for :math:`|x_i| \\gg \\delta`: - :math:`f(x) \\approx \\tfrac{1}{\\delta}\\sum_i |x_i| - n\\log 2`. + :math:`f(x) \\approx \\sum_i |x_i| - n\\,\\delta\\log 2 \\approx \\sum_i |x_i|`. + + The :math:`\\delta` prefactor ensures the asymptotic slope equals 1 + regardless of :math:`\\delta`, so the transition scale and the gradient + magnitude at saturation are decoupled. Gradient: .. math:: - \\nabla f(x)_i = \\frac{1}{\\delta}\\,\\tanh\\!\\left(\\frac{x_i}{\\delta}\\right) + \\nabla f(x)_i = \\tanh\\!\\left(\\frac{x_i}{\\delta}\\right) Diagonal Hessian-vector product: .. math:: \\operatorname{diag}(H_f(x))_i \\cdot v_i - = \\frac{1}{\\delta^2}\\,\\operatorname{sech}^2\\!\\left(\\frac{x_i}{\\delta}\\right) v_i - = \\frac{1 - \\tanh^2(x_i/\\delta)}{\\delta^2}\\, v_i + = \\frac{1}{\\delta}\\,\\operatorname{sech}^2\\!\\left(\\frac{x_i}{\\delta}\\right) v_i + = \\frac{1 - \\tanh^2(x_i/\\delta)}{\\delta}\\, v_i The function value is computed via the numerically stable identity .. math:: - \\log(\\cosh(z)) = |z| + \\log(1 + e^{-2|z|}) - \\log 2, \\quad z = x/\\delta + \\delta\\log(\\cosh(z)) = \\delta\\bigl(|z| + \\log(1 + e^{-2|z|}) - \\log 2\\bigr), + \\quad z = x/\\delta which avoids the overflow that :math:`\\cosh(z) = (e^z + e^{-z})/2` would cause for large :math:`|z|`. @@ -819,26 +824,27 @@ def _call(self, x: Array) -> float: xp = get_namespace(x) z = x if self._delta is None else x / self._delta az = xp.abs(z) - return ( - float(xp.sum(az + xp.log(1 + xp.exp(-2 * az)))) - - math.prod(x.shape) * self._log2 - ) + raw = float(xp.sum(az + xp.log(1 + xp.exp(-2 * az)))) - math.prod(x.shape) * self._log2 + return raw if self._delta is None else self._delta * raw def _gradient(self, x: Array) -> Array: xp = get_namespace(x) z = x if self._delta is None else x / self._delta - t = xp.tanh(z) - return t if self._delta is None else t / self._delta + return xp.tanh(z) def _call_and_gradient(self, x: Array) -> tuple[float, Array]: - return self._call(x), self._gradient(x) + xp = get_namespace(x) + z = x if self._delta is None else x / self._delta + az = xp.abs(z) + raw = float(xp.sum(az + xp.log(1 + xp.exp(-2 * az)))) - math.prod(x.shape) * self._log2 + return (raw if self._delta is None else self._delta * raw), xp.tanh(z) def _hessian_diag_vec_prod(self, x: Array, v: Array) -> Array: xp = get_namespace(x) z = x if self._delta is None else x / self._delta t = xp.tanh(z) h = 1 - t**2 - return h * v if self._delta is None else h * v / self._delta**2 + return h * v if self._delta is None else h * v / self._delta class SumC1Function(C1Function): diff --git a/tests/test_functions.py b/tests/test_functions.py index 6d7f1fab..ef7c2e9a 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -335,6 +335,11 @@ def test_half_sq_l2_beta_scaling(xp: ModuleType, dev: str): _DELTA = 2.0 +def test_log_cosh_delta_property(xp: ModuleType, dev: str): + assert ppf.LogCosh().delta is None + assert ppf.LogCosh(delta=_DELTA).delta == _DELTA + + def test_log_cosh_call_at_zero(xp: ModuleType, dev: str): """f(0) must be exactly (up to float rounding) zero for any delta.""" x = xp.asarray(_np.zeros(3), device=dev) @@ -350,7 +355,7 @@ def test_log_cosh_call(xp: ModuleType, dev: str): assert abs(f(x) - expected) < 1e-5 f_d = ppf.LogCosh(delta=_DELTA) - expected_d = float(_np.sum(_np.log(_np.cosh(_X_LC_NP / _DELTA)))) + expected_d = float(_DELTA * _np.sum(_np.log(_np.cosh(_X_LC_NP / _DELTA)))) assert abs(f_d(x) - expected_d) < 1e-5 @@ -384,7 +389,7 @@ def test_log_cosh_hessian_diag_vec_prod(xp: ModuleType, dev: str): f_d = ppf.LogCosh(delta=_DELTA) expected_d = xp.asarray( - (1 - _np.tanh(_X_LC_NP / _DELTA) ** 2) / _DELTA**2 * _V_LC_NP, device=dev + (1 - _np.tanh(_X_LC_NP / _DELTA) ** 2) / _DELTA * _V_LC_NP, device=dev ) assert allclose(f_d.hessian_diag_vec_prod(x, v), expected_d, atol=1e-5) @@ -409,7 +414,7 @@ def test_log_cosh_overflow_safe(xp: ModuleType, dev: str): assert all(xp.isfinite(f.gradient(x))) f_d = ppf.LogCosh(delta=0.5) - expected_d = 2 * (100.0 / 0.5 - math.log(2.0)) + expected_d = 2 * (100.0 - 0.5 * math.log(2.0)) assert abs(f_d(x) - expected_d) < 1e-1 assert all(xp.isfinite(f_d.gradient(x))) From d5e734225ecea35c37b01efd8ce24f557a7c87a8 Mon Sep 17 00:00:00 2001 From: Georg Schramm Date: Wed, 27 May 2026 21:46:30 +0200 Subject: [PATCH 10/17] add zigzag sampling --- .../05_run_zigzag_comparison.py | 147 ++++++++++++++++++ src/parallelproj/pet_lors.py | 65 ++++++-- tests/test_pet_lors.py | 57 +++++++ 3 files changed, 259 insertions(+), 10 deletions(-) create mode 100644 docs/examples/01_pet_geometry/05_run_zigzag_comparison.py diff --git a/docs/examples/01_pet_geometry/05_run_zigzag_comparison.py b/docs/examples/01_pet_geometry/05_run_zigzag_comparison.py new file mode 100644 index 00000000..4b7e33cc --- /dev/null +++ b/docs/examples/01_pet_geometry/05_run_zigzag_comparison.py @@ -0,0 +1,147 @@ +""" +Zig-zag LOR sampling: END_FIRST vs START_FIRST +=============================================== + +For a given sinogram view, a regular polygon PET scanner connects pairs of +in-ring detector endpoints in a zig-zag pattern as the radial bin index moves +from the central LOR toward the sinogram edges. Two conventions exist for the +ordering of those pairs: + +* **END_FIRST** (default): the *end* detector steps to the next position before + the start detector does. + Pairs at view 0 (n=8): (0,7), (0,6), (1,6), (1,5), (2,5), (2,4), (3,4), (3,3) + +* **START_FIRST**: the *start* detector steps first. + Pairs at view 0 (n=8): (0,7), (1,7), (1,6), (2,6), (2,5), (3,5), (3,4), (4,4) + +:class:`.SinogramZigZagOrder` selects the convention via the ``zig_zag_order`` +parameter of :class:`.RegularPolygonPETLORDescriptor`. + +This example visualises both conventions for a minimal scanner with 1 ring and +8 detector endpoints. +""" + +# %% +import numpy as np +import matplotlib.pyplot as plt +import parallelproj.pet_scanners +import parallelproj.pet_lors + +# %% +from array_utils import suggest_array_backend_and_device + +xp, dev = suggest_array_backend_and_device(None, None) + +# %% +# Scanner setup +# ------------- +# One ring with 8 detector endpoints, no radial trimming. + +n_endpoints = 8 +scanner = parallelproj.pet_scanners.RegularPolygonPETScannerGeometry( + xp, + dev, + radius=100.0, + num_sides=n_endpoints, + num_lor_endpoints_per_side=1, + lor_spacing=1.0, + ring_positions=xp.asarray([0.0], device=dev), + symmetry_axis=2, +) + +# %% +# Build LOR descriptors for both zig-zag conventions +# --------------------------------------------------- + +lor_end_first = parallelproj.pet_lors.RegularPolygonPETLORDescriptor( + scanner, + radial_trim=0, + zig_zag_order=parallelproj.pet_lors.SinogramZigZagOrder.END_FIRST, +) + +lor_start_first = parallelproj.pet_lors.RegularPolygonPETLORDescriptor( + scanner, + radial_trim=0, + zig_zag_order=parallelproj.pet_lors.SinogramZigZagOrder.START_FIRST, +) + +# %% +# Print the (start, end) detector index pairs for each view +# --------------------------------------------------------- + +print("END_FIRST — (start, end) detector pairs per view:") +for view in range(lor_end_first.num_views): + s = lor_end_first.start_in_ring_index[view, :].tolist() + e = lor_end_first.end_in_ring_index[view, :].tolist() + pairs = list(zip(s, e)) + print(f" view {view}: {pairs}") + +print() +print("START_FIRST — (start, end) detector pairs per view:") +for view in range(lor_start_first.num_views): + s = lor_start_first.start_in_ring_index[view, :].tolist() + e = lor_start_first.end_in_ring_index[view, :].tolist() + pairs = list(zip(s, e)) + print(f" view {view}: {pairs}") + +# %% +# Visualisation: all LORs coloured by radial bin for view 0 +# ---------------------------------------------------------- +# Detector endpoint positions lie on a circle. + +angles = 2 * np.pi * np.arange(n_endpoints) / n_endpoints +xdet = np.cos(angles) +ydet = np.sin(angles) + +cmap = plt.colormaps["tab10"].resampled(lor_end_first.num_rad) + +fig, axes = plt.subplots(1, 2, figsize=(10, 5)) + +for ax, lor_desc, title in zip( + axes, + [lor_end_first, lor_start_first], + ["END_FIRST (default)", "START_FIRST"], +): + # draw detector ring + circle = plt.Circle((0, 0), 1.0, fill=False, color="gray", lw=1, ls="--") + ax.add_patch(circle) + + # mark detector endpoints + ax.scatter(xdet, ydet, color="k", zorder=5, s=40) + for idx in range(n_endpoints): + ax.text( + 1.12 * xdet[idx], + 1.12 * ydet[idx], + str(idx), + ha="center", + va="center", + fontsize=9, + ) + + # draw LORs for view 0, coloured by radial bin + view = 0 + s_idx = lor_desc.start_in_ring_index[view, :].tolist() + e_idx = lor_desc.end_in_ring_index[view, :].tolist() + for rad_bin, (si, ei) in enumerate(zip(s_idx, e_idx)): + color = cmap(rad_bin) + ax.plot( + [xdet[si], xdet[ei]], + [ydet[si], ydet[ei]], + color=color, + lw=2, + label=f"rad {rad_bin}: ({si},{ei})", + ) + + ax.set_xlim(-1.35, 1.35) + ax.set_ylim(-1.35, 1.35) + ax.set_aspect("equal") + ax.legend(fontsize=7, loc="lower right") + ax.set_title(f"View 0 — {title}") + ax.axis("off") + +fig.suptitle( + f"Zig-zag LOR sampling for view 0 (n={n_endpoints} detectors, radial_trim=0)", + fontsize=11, +) +fig.tight_layout() +plt.show() diff --git a/src/parallelproj/pet_lors.py b/src/parallelproj/pet_lors.py index a3109392..c430b6ad 100644 --- a/src/parallelproj/pet_lors.py +++ b/src/parallelproj/pet_lors.py @@ -41,6 +41,28 @@ class SinogramSpatialAxisOrder(enum.Enum): """[plane,view,radial]""" +class SinogramZigZagOrder(enum.Enum): + """Zig-zag ordering of in-ring detector pairs for each sinogram view. + + For a scanner with :math:`n` detector endpoints per ring and view index 0, + the two variants differ in which detector (start or end) steps first as the + radial bin index increases from the central LOR outward. + + ``END_FIRST`` + The *end* detector steps first for each new radial pair. + Pairs (start, end) at view 0: (0,n-1), (0,n-2), (1,n-2), (1,n-3), … + + ``START_FIRST`` + The *start* detector steps first for each new radial pair. + Pairs (start, end) at view 0: (0,n-1), (1,n-1), (1,n-2), (2,n-2), … + """ + + END_FIRST = enum.auto() + """End crystal steps first (default, historically used convention).""" + START_FIRST = enum.auto() + """Start crystal steps first.""" + + class Michelogram: """Axial plane layout for a cylindrical PET scanner under odd span. @@ -1012,6 +1034,7 @@ def __init__( michelogram: Michelogram | None = None, radial_trim: int = 3, sinogram_order: SinogramSpatialAxisOrder = SinogramSpatialAxisOrder.RVP, + zig_zag_order: SinogramZigZagOrder = SinogramZigZagOrder.END_FIRST, ) -> None: """ @@ -1032,6 +1055,9 @@ def __init__( sinogram_order : SinogramSpatialAxisOrder, optional the order of the sinogram axes. Defaults to ``SinogramSpatialAxisOrder.RVP``. + zig_zag_order : SinogramZigZagOrder, optional + the zig-zag ordering convention for in-ring detector pairs. + Defaults to ``SinogramZigZagOrder.END_FIRST``. """ super().__init__(scanner) @@ -1058,6 +1084,7 @@ def __init__( self._num_views = scanner.num_lor_endpoints_per_ring // 2 self._sinogram_order = sinogram_order + self._zig_zag_order = zig_zag_order # declare all attributes set by the setup methods so they are # visible in __init__ @@ -1178,6 +1205,11 @@ def sinogram_order(self) -> SinogramSpatialAxisOrder: """the order of the sinogram axes""" return self._sinogram_order + @property + def zig_zag_order(self) -> SinogramZigZagOrder: + """the zig-zag ordering convention for in-ring detector pairs""" + return self._zig_zag_order + @property def plane_axis_num(self) -> int: """the axis number of the plane axis""" @@ -1273,21 +1305,34 @@ def _setup_view_indices(self) -> None: (self._num_views, self._num_rad), dtype=self.xp.int32, device=self.dev ) + # slice for radial trimming; -0 == 0 in Python so guard explicitly + trim = self._radial_trim + rad_slc = slice(trim, -trim if trim > 0 else None) + for view in np.arange(self._num_views): + if self._zig_zag_order is SinogramZigZagOrder.END_FIRST: + # end crystal steps first: (0,n-1),(0,n-2),(1,n-2),(1,n-3),... + start_seq = self.xp.concat( + (self.xp.arange(m) // 2, self.xp.asarray([n // 2])) + ) + end_seq = self.xp.concat( + (self.xp.asarray([-1]), -((self.xp.arange(m) + 4) // 2)) + ) + else: + # start crystal steps first: (0,n-1),(1,n-1),(1,n-2),(2,n-2),... + start_seq = self.xp.concat( + ((self.xp.arange(m) + 1) // 2, self.xp.asarray([n // 2])) + ) + end_seq = self.xp.concat( + (self.xp.asarray([-1]), -((self.xp.arange(m) + 3) // 2)) + ) + self._start_in_ring_index[view, :] = self.xp.astype( - ( - self.xp.concat((self.xp.arange(m) // 2, self.xp.asarray([n // 2]))) - - int(view) - )[self._radial_trim : -self._radial_trim], + (start_seq - int(view))[rad_slc], self.xp.int32, ) self._end_in_ring_index[view, :] = self.xp.astype( - ( - self.xp.concat( - (self.xp.asarray([-1]), -((self.xp.arange(m) + 4) // 2)) - ) - - int(view) - )[self._radial_trim : -self._radial_trim], + (end_seq - int(view))[rad_slc], self.xp.int32, ) diff --git a/tests/test_pet_lors.py b/tests/test_pet_lors.py index 5759111f..72e2d3cf 100644 --- a/tests/test_pet_lors.py +++ b/tests/test_pet_lors.py @@ -347,6 +347,63 @@ def _check_plane(idx, exp_sz, exp_ez, exp_mult, exp_seg): _check_plane(46, 9.5, 1.5, 4, -1) +def test_zig_zag_order(xp: ModuleType, dev: str) -> None: + """zig_zag_order property and START_FIRST branch are exercised.""" + import numpy as np + + num_rings = 1 + scanner = pps.RegularPolygonPETScannerGeometry( + xp, + dev, + radius=100.0, + num_sides=8, + num_lor_endpoints_per_side=1, + lor_spacing=1.0, + ring_positions=xp.asarray([0.0], device=dev), + symmetry_axis=2, + ) + + lor_ef = ppl.RegularPolygonPETLORDescriptor( + scanner, + radial_trim=0, + zig_zag_order=ppl.SinogramZigZagOrder.END_FIRST, + ) + lor_sf = ppl.RegularPolygonPETLORDescriptor( + scanner, + radial_trim=0, + zig_zag_order=ppl.SinogramZigZagOrder.START_FIRST, + ) + + # property is readable and returns the correct enum value + assert lor_ef.zig_zag_order is ppl.SinogramZigZagOrder.END_FIRST + assert lor_sf.zig_zag_order is ppl.SinogramZigZagOrder.START_FIRST + + # both share the same sinogram shape + assert lor_ef.spatial_sinogram_shape == lor_sf.spatial_sinogram_shape + + # view 0: END_FIRST has end step first, START_FIRST has start step first + # n=8: END_FIRST pairs: (0,7),(0,6),(1,6),... START_FIRST: (0,7),(1,7),(1,6),... + ef_start = to_numpy_array(lor_ef.start_in_ring_index[0, :]) + ef_end = to_numpy_array(lor_ef.end_in_ring_index[0, :]) + sf_start = to_numpy_array(lor_sf.start_in_ring_index[0, :]) + sf_end = to_numpy_array(lor_sf.end_in_ring_index[0, :]) + + # first radial bin is the same for both (central LOR) + assert int(ef_start[0]) == int(sf_start[0]) + assert int(ef_end[0]) == int(sf_end[0]) + + # second radial bin differs: END_FIRST keeps start, advances end; + # START_FIRST advances start, keeps end + assert int(ef_start[1]) == int(ef_start[0]) # start unchanged + assert int(sf_end[1]) == int(sf_end[0]) # end unchanged + assert int(ef_end[1]) != int(ef_end[0]) # end stepped + assert int(sf_start[1]) != int(sf_start[0]) # start stepped + + # the two conventions produce different index arrays + assert not np.array_equal(ef_start, sf_start) + assert not np.array_equal(ef_end, sf_end) + + def test_show_michelogram(xp: ModuleType, dev: str) -> None: num_rings = 3 scanner = pps.DemoPETScannerGeometry(xp, dev, num_rings, symmetry_axis=2) From 1b7998dde53df07d0ec49f0e0ab46cf608faedb6 Mon Sep 17 00:00:00 2001 From: Georg Schramm Date: Wed, 27 May 2026 21:51:38 +0200 Subject: [PATCH 11/17] add missing unit tests --- tests/test_backend.py | 17 +++++++++++++++++ tests/test_operators.py | 14 ++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/tests/test_backend.py b/tests/test_backend.py index a387193e..93601332 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -3,6 +3,7 @@ import pytest import array_api_compat.numpy as np from types import ModuleType +from unittest.mock import patch from parallelproj import to_numpy_array, count_event_multiplicity from parallelproj._backend import empty_cuda_cache @@ -10,6 +11,22 @@ from .config import pytestmark +def test_version_fallback(xp: ModuleType, dev: str) -> None: + """Lines 5-6 of __init__.py: PackageNotFoundError causes __version__ = 'unknown'.""" + import importlib + import importlib.metadata + import parallelproj + from importlib.metadata import PackageNotFoundError + + with patch.object(importlib.metadata, "version", side_effect=PackageNotFoundError): + importlib.reload(parallelproj) + + assert parallelproj.__version__ == "unknown" + + # Restore normal state + importlib.reload(parallelproj) + + def test_event_multiplicity(xp: ModuleType, dev: str) -> None: events = xp.asarray( diff --git a/tests/test_operators.py b/tests/test_operators.py index b43e5665..846c98d6 100644 --- a/tests/test_operators.py +++ b/tests/test_operators.py @@ -433,6 +433,20 @@ def test_norm_verbose(xp: ModuleType, dev: str): A.norm(xp, dev, verbose=True, num_iter=2) +def test_cupy_import_fallback(xp: ModuleType, dev: str) -> None: + """Lines 21-22 of operators.py: cp = None when cupy is unavailable.""" + import sys + import importlib + import parallelproj.operators as ppo + + with patch.dict(sys.modules, {"cupy": None}): + importlib.reload(ppo) + assert ppo.cp is None + + # Restore normal state + importlib.reload(ppo) + + def test_gaussian_array_api_strict(xp: ModuleType, dev: str): """Covers the array_api_strict branch in GaussianFilterOperator._apply (line 502).""" if importlib.util.find_spec("array_api_strict") is None: From 473009e2fb69c393df6e836b048ee743848f3b24 Mon Sep 17 00:00:00 2001 From: Georg Schramm Date: Wed, 27 May 2026 22:30:16 +0200 Subject: [PATCH 12/17] wip counter clockwise endpoints --- .../01_run_regular_polygon_pet_scanner.py | 51 +++++++++++++++++++ src/parallelproj/pet_scanners.py | 49 ++++++++++++++++++ tests/test_pet_scanners.py | 44 ++++++++++++++++ 3 files changed, 144 insertions(+) diff --git a/docs/examples/01_pet_geometry/01_run_regular_polygon_pet_scanner.py b/docs/examples/01_pet_geometry/01_run_regular_polygon_pet_scanner.py index a8b3623b..e9ee2325 100644 --- a/docs/examples/01_pet_geometry/01_run_regular_polygon_pet_scanner.py +++ b/docs/examples/01_pet_geometry/01_run_regular_polygon_pet_scanner.py @@ -137,3 +137,54 @@ ax2a = fig2.add_subplot(111, projection="3d") open_scanner.show_lor_endpoints(ax2a) fig2.show() + +# %% +# Endpoint ordering: clockwise vs counterclockwise +# ------------------------------------------------ +# +# By default, endpoint indices increase **clockwise** when the ring is viewed +# from the positive symmetry-axis direction. +# :class:`.RingEndpointOrdering` lets you switch to **counterclockwise** ordering. +# The two conventions produce the same physical detector positions, but with +# reversed index assignment — endpoint ``k`` in the CCW scanner occupies the +# same position as endpoint ``N-1-k`` in the CW scanner. + +cw_scanner = parallelproj.pet_scanners.RegularPolygonPETScannerGeometry( + xp, + dev, + radius=65.0, + num_sides=8, + num_lor_endpoints_per_side=2, + lor_spacing=20.0, + ring_positions=xp.asarray([0.0], device=dev), + symmetry_axis=2, + ring_endpoint_ordering=parallelproj.pet_scanners.RingEndpointOrdering.CLOCKWISE, +) + +ccw_scanner = parallelproj.pet_scanners.RegularPolygonPETScannerGeometry( + xp, + dev, + radius=65.0, + num_sides=8, + num_lor_endpoints_per_side=2, + lor_spacing=20.0, + ring_positions=xp.asarray([0.0], device=dev), + symmetry_axis=2, + ring_endpoint_ordering=parallelproj.pet_scanners.RingEndpointOrdering.COUNTERCLOCKWISE, +) + +fig3, axes = plt.subplots( + 1, 2, figsize=(10, 5), subplot_kw={"projection": "3d"}, layout="constrained" +) + +for ax, scanner, title in zip( + axes, + [cw_scanner, ccw_scanner], + ["CLOCKWISE (default)", "COUNTERCLOCKWISE"], +): + scanner.show_lor_endpoints(ax, show_linear_index=True, annotation_fontsize=10) + ax.view_init(elev=90, azim=-90) + ax.set_title(f"ring_endpoint_ordering = {title}\n(symmetry_axis=2, viewed from +z)") + +fig3.suptitle("Endpoint index ordering conventions", fontsize=12) +fig3.show() diff --git a/src/parallelproj/pet_scanners.py b/src/parallelproj/pet_scanners.py index 992a1471..a6dfed9d 100644 --- a/src/parallelproj/pet_scanners.py +++ b/src/parallelproj/pet_scanners.py @@ -3,6 +3,7 @@ from __future__ import annotations import abc +import enum from collections.abc import Sequence from types import ModuleType from typing import TYPE_CHECKING @@ -15,6 +16,29 @@ from mpl_toolkits.mplot3d import Axes3D +class RingEndpointOrdering(enum.Enum): + """Direction in which endpoint indices increase around a ring. + + For any symmetry axis, when the ring is viewed from the **positive** + symmetry-axis direction using the natural right-handed camera frame + (right = next cyclic axis, up = next-next cyclic axis), the two + conventions are: + + ``CLOCKWISE`` + Index 0 is at the top (12 o'clock) and subsequent indices advance + clockwise. This is the default and matches the original behaviour. + + ``COUNTERCLOCKWISE`` + Index 0 is at the top (12 o'clock) and subsequent indices advance + counterclockwise. + """ + + CLOCKWISE = enum.auto() + """Indices increase clockwise (default).""" + COUNTERCLOCKWISE = enum.auto() + """Indices increase counterclockwise.""" + + class PETScannerModule(abc.ABC): """abstract base class for PET scanner module""" @@ -305,6 +329,7 @@ def __init__( ax1: int = 1, affine_transformation_matrix: Array | None = None, phis: None | Array = None, + ring_endpoint_ordering: RingEndpointOrdering = RingEndpointOrdering.CLOCKWISE, ) -> None: """ @@ -332,6 +357,9 @@ def __init__( phis : None | Array, optional angle of each side, by default None means that the sides are equally spaced around a circle + ring_endpoint_ordering : RingEndpointOrdering, optional + direction in which endpoint indices increase around the ring, by + default ``RingEndpointOrdering.CLOCKWISE``. """ self._radius = radius @@ -340,6 +368,7 @@ def __init__( self._ax0 = ax0 self._ax1 = ax1 self._lor_spacing = lor_spacing + self._ring_endpoint_ordering = ring_endpoint_ordering super().__init__( xp, dev, @@ -428,11 +457,20 @@ def phis(self) -> Array: """ return self._phis + @property + def ring_endpoint_ordering(self) -> RingEndpointOrdering: + """direction in which endpoint indices increase around the ring""" + return self._ring_endpoint_ordering + # abstract method from base class to be implemented def get_raw_lor_endpoints(self, inds: Array | None = None) -> Array: if inds is None: inds = self.lor_endpoint_numbers + if self._ring_endpoint_ordering is RingEndpointOrdering.COUNTERCLOCKWISE: + n_total = self._num_sides * self._num_lor_endpoints_per_side + inds = self.xp.astype(n_total - 1 - inds, self.xp.int32) + side = inds // self.num_lor_endpoints_per_side tmp = inds - side * self.num_lor_endpoints_per_side tmp = self.xp.astype(tmp, self.xp.float32) - ( @@ -649,6 +687,7 @@ def __init__( ring_positions: Array, symmetry_axis: int, phis: None | Array = None, + ring_endpoint_ordering: RingEndpointOrdering = RingEndpointOrdering.CLOCKWISE, ) -> None: """ Parameters @@ -672,6 +711,9 @@ def __init__( phis : None | Array, optional angle of each side, by default None means that the sides are equally spaced around a circle + ring_endpoint_ordering : RingEndpointOrdering, optional + direction in which endpoint indices increase around the ring, by + default ``RingEndpointOrdering.CLOCKWISE``. """ self._radius = radius @@ -680,6 +722,7 @@ def __init__( self._lor_spacing = lor_spacing self._symmetry_axis = symmetry_axis self._ring_positions = ring_positions + self._ring_endpoint_ordering = ring_endpoint_ordering if symmetry_axis == 0: self._ax0 = 2 @@ -711,6 +754,7 @@ def __init__( ax0=self._ax0, ax1=self._ax1, phis=phis, + ring_endpoint_ordering=ring_endpoint_ordering, ) ) @@ -751,6 +795,11 @@ def symmetry_axis(self) -> int: """The symmetry axis. Also called axial (or ring) direction.""" return self._symmetry_axis + @property + def ring_endpoint_ordering(self) -> RingEndpointOrdering: + """direction in which endpoint indices increase around the ring""" + return self._ring_endpoint_ordering + @property def all_lor_endpoints_ring_number(self) -> Array: """the ring (regular polygon) number of all LOR endpoints""" diff --git a/tests/test_pet_scanners.py b/tests/test_pet_scanners.py index a0a06ad2..e737d263 100644 --- a/tests/test_pet_scanners.py +++ b/tests/test_pet_scanners.py @@ -164,6 +164,50 @@ def test_regular_polygon_pet_scanner(xp: ModuleType, dev: str) -> None: assert xp.all(scanner2.modules[0].phis == phis) +def test_ring_endpoint_ordering(xp: ModuleType, dev: str) -> None: + """RingEndpointOrdering property, CCW branch, and reversal invariant.""" + import numpy as np + from parallelproj import to_numpy_array + + n_sides = 8 + n_per_side = 2 + + for symmetry_axis in [0, 1, 2]: + s_cw = pps.RegularPolygonPETScannerGeometry( + xp, + dev, + radius=1.0, + num_sides=n_sides, + num_lor_endpoints_per_side=n_per_side, + lor_spacing=0.1, + ring_positions=xp.asarray([0.0], dtype=xp.float32, device=dev), + symmetry_axis=symmetry_axis, + ring_endpoint_ordering=pps.RingEndpointOrdering.CLOCKWISE, + ) + s_ccw = pps.RegularPolygonPETScannerGeometry( + xp, + dev, + radius=1.0, + num_sides=n_sides, + num_lor_endpoints_per_side=n_per_side, + lor_spacing=0.1, + ring_positions=xp.asarray([0.0], dtype=xp.float32, device=dev), + symmetry_axis=symmetry_axis, + ring_endpoint_ordering=pps.RingEndpointOrdering.COUNTERCLOCKWISE, + ) + + # properties are readable and propagate to the module level + assert s_cw.ring_endpoint_ordering is pps.RingEndpointOrdering.CLOCKWISE + assert s_ccw.ring_endpoint_ordering is pps.RingEndpointOrdering.COUNTERCLOCKWISE + assert s_cw.modules[0].ring_endpoint_ordering is pps.RingEndpointOrdering.CLOCKWISE + assert s_ccw.modules[0].ring_endpoint_ordering is pps.RingEndpointOrdering.COUNTERCLOCKWISE + + # CCW is the mirror image: endpoint k in CCW == endpoint (N-1-k) in CW + cw_pts = to_numpy_array(s_cw.all_lor_endpoints) + ccw_pts = to_numpy_array(s_ccw.all_lor_endpoints) + assert np.allclose(cw_pts, ccw_pts[::-1], atol=1e-5) + + def test_regular_polygon_pet_scanner_invalid_symmetry_axis( xp: ModuleType, dev: str ) -> None: From ad48481f747528b4196620c0f9d43f3b4b9e64ad Mon Sep 17 00:00:00 2001 From: Georg Schramm Date: Wed, 27 May 2026 22:42:54 +0200 Subject: [PATCH 13/17] wip counterclockwise orderering --- src/parallelproj/pet_scanners.py | 11 +++++++++-- tests/test_pet_scanners.py | 13 +++++++++++-- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/src/parallelproj/pet_scanners.py b/src/parallelproj/pet_scanners.py index a6dfed9d..c54a2a7c 100644 --- a/src/parallelproj/pet_scanners.py +++ b/src/parallelproj/pet_scanners.py @@ -468,8 +468,15 @@ def get_raw_lor_endpoints(self, inds: Array | None = None) -> Array: inds = self.lor_endpoint_numbers if self._ring_endpoint_ordering is RingEndpointOrdering.COUNTERCLOCKWISE: - n_total = self._num_sides * self._num_lor_endpoints_per_side - inds = self.xp.astype(n_total - 1 - inds, self.xp.int32) + side = inds // self._num_lor_endpoints_per_side + within = inds - side * self._num_lor_endpoints_per_side + new_side = self.xp.astype( + (self._num_sides - side) % self._num_sides, self.xp.int32 + ) + new_within = self.xp.astype( + self._num_lor_endpoints_per_side - 1 - within, self.xp.int32 + ) + inds = new_side * self._num_lor_endpoints_per_side + new_within side = inds // self.num_lor_endpoints_per_side tmp = inds - side * self.num_lor_endpoints_per_side diff --git a/tests/test_pet_scanners.py b/tests/test_pet_scanners.py index e737d263..ab6fd39b 100644 --- a/tests/test_pet_scanners.py +++ b/tests/test_pet_scanners.py @@ -202,10 +202,19 @@ def test_ring_endpoint_ordering(xp: ModuleType, dev: str) -> None: assert s_cw.modules[0].ring_endpoint_ordering is pps.RingEndpointOrdering.CLOCKWISE assert s_ccw.modules[0].ring_endpoint_ordering is pps.RingEndpointOrdering.COUNTERCLOCKWISE - # CCW is the mirror image: endpoint k in CCW == endpoint (N-1-k) in CW + # CCW[i] maps to CW[j] where side order is reversed (keeping side 0 + # fixed) and within-side order is also reversed: + # j = ((n_sides - i//n_per_side) % n_sides) * n_per_side + # + (n_per_side - 1 - i % n_per_side) cw_pts = to_numpy_array(s_cw.all_lor_endpoints) ccw_pts = to_numpy_array(s_ccw.all_lor_endpoints) - assert np.allclose(cw_pts, ccw_pts[::-1], atol=1e-5) + n_total = n_sides * n_per_side + expected_indices = [ + ((n_sides - i // n_per_side) % n_sides) * n_per_side + + (n_per_side - 1 - i % n_per_side) + for i in range(n_total) + ] + assert np.allclose(ccw_pts, cw_pts[expected_indices], atol=1e-5) def test_regular_polygon_pet_scanner_invalid_symmetry_axis( From 787affac45d579e473c534762bb8ac7f4be5af72 Mon Sep 17 00:00:00 2001 From: Georg Schramm Date: Wed, 27 May 2026 23:11:26 +0200 Subject: [PATCH 14/17] add phi0 in endpoint descriptor --- .../01_run_regular_polygon_pet_scanner.py | 69 +++++++++-------- src/parallelproj/pet_scanners.py | 26 ++++++- tests/test_pet_scanners.py | 74 +++++++++++++++++++ 3 files changed, 133 insertions(+), 36 deletions(-) diff --git a/docs/examples/01_pet_geometry/01_run_regular_polygon_pet_scanner.py b/docs/examples/01_pet_geometry/01_run_regular_polygon_pet_scanner.py index e9ee2325..61269f35 100644 --- a/docs/examples/01_pet_geometry/01_run_regular_polygon_pet_scanner.py +++ b/docs/examples/01_pet_geometry/01_run_regular_polygon_pet_scanner.py @@ -139,52 +139,51 @@ fig2.show() # %% -# Endpoint ordering: clockwise vs counterclockwise -# ------------------------------------------------ +# Endpoint ordering and phi0: all four combinations +# -------------------------------------------------- # # By default, endpoint indices increase **clockwise** when the ring is viewed # from the positive symmetry-axis direction. # :class:`.RingEndpointOrdering` lets you switch to **counterclockwise** ordering. -# The two conventions produce the same physical detector positions, but with -# reversed index assignment — endpoint ``k`` in the CCW scanner occupies the -# same position as endpoint ``N-1-k`` in the CW scanner. +# The ``phi0`` parameter rotates the starting angle of side 0 (in radians); it +# is ignored when ``phis`` is supplied explicitly. +# +# The 2×2 grid below shows all combinations of CW/CCW ordering with +# ``phi0=0`` and ``phi0=π/8`` (half a polygon step for an 8-sided scanner). -cw_scanner = parallelproj.pet_scanners.RegularPolygonPETScannerGeometry( - xp, - dev, - radius=65.0, - num_sides=8, - num_lor_endpoints_per_side=2, - lor_spacing=20.0, - ring_positions=xp.asarray([0.0], device=dev), - symmetry_axis=2, - ring_endpoint_ordering=parallelproj.pet_scanners.RingEndpointOrdering.CLOCKWISE, -) +import math -ccw_scanner = parallelproj.pet_scanners.RegularPolygonPETScannerGeometry( - xp, - dev, - radius=65.0, - num_sides=8, - num_lor_endpoints_per_side=2, - lor_spacing=20.0, - ring_positions=xp.asarray([0.0], device=dev), - symmetry_axis=2, - ring_endpoint_ordering=parallelproj.pet_scanners.RingEndpointOrdering.COUNTERCLOCKWISE, -) +_RO = parallelproj.pet_scanners.RingEndpointOrdering + +configs = [ + (_RO.CLOCKWISE, 0.0, "CW, phi0=0"), + (_RO.COUNTERCLOCKWISE, 0.0, "CCW, phi0=0"), + (_RO.CLOCKWISE, math.pi / 3, "CW, phi0=π/3"), + (_RO.COUNTERCLOCKWISE, math.pi / 3, "CCW, phi0=π/3"), +] fig3, axes = plt.subplots( - 1, 2, figsize=(10, 5), subplot_kw={"projection": "3d"}, layout="constrained" + 2, 2, figsize=(10, 10), subplot_kw={"projection": "3d"}, layout="constrained" ) -for ax, scanner, title in zip( - axes, - [cw_scanner, ccw_scanner], - ["CLOCKWISE (default)", "COUNTERCLOCKWISE"], -): +for ax, (ordering, phi0, title) in zip(axes.flat, configs): + scanner = parallelproj.pet_scanners.RegularPolygonPETScannerGeometry( + xp, + dev, + radius=65.0, + num_sides=8, + num_lor_endpoints_per_side=2, + lor_spacing=20.0, + ring_positions=xp.asarray([0.0], device=dev), + symmetry_axis=2, + ring_endpoint_ordering=ordering, + phi0=phi0, + ) scanner.show_lor_endpoints(ax, show_linear_index=True, annotation_fontsize=10) ax.view_init(elev=90, azim=-90) - ax.set_title(f"ring_endpoint_ordering = {title}\n(symmetry_axis=2, viewed from +z)") + ax.set_title(title, fontsize="medium") -fig3.suptitle("Endpoint index ordering conventions", fontsize=12) +fig3.suptitle( + "Endpoint ordering × phi0 (symmetry_axis=2, viewed from +z)", fontsize=12 +) fig3.show() diff --git a/src/parallelproj/pet_scanners.py b/src/parallelproj/pet_scanners.py index c54a2a7c..8d98ee07 100644 --- a/src/parallelproj/pet_scanners.py +++ b/src/parallelproj/pet_scanners.py @@ -330,6 +330,7 @@ def __init__( affine_transformation_matrix: Array | None = None, phis: None | Array = None, ring_endpoint_ordering: RingEndpointOrdering = RingEndpointOrdering.CLOCKWISE, + phi0: float = 0.0, ) -> None: """ @@ -360,6 +361,10 @@ def __init__( ring_endpoint_ordering : RingEndpointOrdering, optional direction in which endpoint indices increase around the ring, by default ``RingEndpointOrdering.CLOCKWISE``. + phi0 : float, optional + azimuthal offset of side 0 in radians, by default 0. + Only applied when ``phis`` is ``None``; ignored when ``phis`` is + provided explicitly. """ self._radius = radius @@ -369,6 +374,7 @@ def __init__( self._ax1 = ax1 self._lor_spacing = lor_spacing self._ring_endpoint_ordering = ring_endpoint_ordering + self._phi0 = phi0 super().__init__( xp, dev, @@ -379,7 +385,8 @@ def __init__( # angle of each "side" if phis is None: self._phis = ( - 2 + phi0 + + 2 * self.xp.pi * self.xp.arange(self._num_sides, dtype=xp.float32, device=dev) / self.num_sides @@ -462,6 +469,11 @@ def ring_endpoint_ordering(self) -> RingEndpointOrdering: """direction in which endpoint indices increase around the ring""" return self._ring_endpoint_ordering + @property + def phi0(self) -> float: + """azimuthal offset of side 0 in radians (only applied when phis=None)""" + return self._phi0 + # abstract method from base class to be implemented def get_raw_lor_endpoints(self, inds: Array | None = None) -> Array: if inds is None: @@ -695,6 +707,7 @@ def __init__( symmetry_axis: int, phis: None | Array = None, ring_endpoint_ordering: RingEndpointOrdering = RingEndpointOrdering.CLOCKWISE, + phi0: float = 0.0, ) -> None: """ Parameters @@ -721,6 +734,10 @@ def __init__( ring_endpoint_ordering : RingEndpointOrdering, optional direction in which endpoint indices increase around the ring, by default ``RingEndpointOrdering.CLOCKWISE``. + phi0 : float, optional + azimuthal offset of side 0 in radians, by default 0. + Only applied when ``phis`` is ``None``; ignored when ``phis`` is + provided explicitly. """ self._radius = radius @@ -730,6 +747,7 @@ def __init__( self._symmetry_axis = symmetry_axis self._ring_positions = ring_positions self._ring_endpoint_ordering = ring_endpoint_ordering + self._phi0 = phi0 if symmetry_axis == 0: self._ax0 = 2 @@ -762,6 +780,7 @@ def __init__( ax1=self._ax1, phis=phis, ring_endpoint_ordering=ring_endpoint_ordering, + phi0=phi0, ) ) @@ -807,6 +826,11 @@ def ring_endpoint_ordering(self) -> RingEndpointOrdering: """direction in which endpoint indices increase around the ring""" return self._ring_endpoint_ordering + @property + def phi0(self) -> float: + """azimuthal offset of side 0 in radians (only applied when phis=None)""" + return self._phi0 + @property def all_lor_endpoints_ring_number(self) -> Array: """the ring (regular polygon) number of all LOR endpoints""" diff --git a/tests/test_pet_scanners.py b/tests/test_pet_scanners.py index ab6fd39b..e9cf230d 100644 --- a/tests/test_pet_scanners.py +++ b/tests/test_pet_scanners.py @@ -217,6 +217,80 @@ def test_ring_endpoint_ordering(xp: ModuleType, dev: str) -> None: assert np.allclose(ccw_pts, cw_pts[expected_indices], atol=1e-5) +def test_phi0(xp: ModuleType, dev: str) -> None: + """phi0 rotates side 0; ignored when phis is supplied explicitly.""" + import math + import numpy as np + from parallelproj import to_numpy_array + + n_sides = 8 + n_per_side = 2 + phi0_val = math.pi / n_sides # half a polygon step + + # --- scanner-level property propagates to module level --- + s = pps.RegularPolygonPETScannerGeometry( + xp, + dev, + radius=1.0, + num_sides=n_sides, + num_lor_endpoints_per_side=n_per_side, + lor_spacing=0.1, + ring_positions=xp.asarray([0.0], dtype=xp.float32, device=dev), + symmetry_axis=2, + phi0=phi0_val, + ) + assert s.phi0 == phi0_val + assert s.modules[0].phi0 == phi0_val + + # --- phis are offset by phi0 relative to phi0=0 case --- + s0 = pps.RegularPolygonPETScannerGeometry( + xp, + dev, + radius=1.0, + num_sides=n_sides, + num_lor_endpoints_per_side=n_per_side, + lor_spacing=0.1, + ring_positions=xp.asarray([0.0], dtype=xp.float32, device=dev), + symmetry_axis=2, + phi0=0.0, + ) + phis_with = to_numpy_array(s.modules[0].phis) + phis_base = to_numpy_array(s0.modules[0].phis) + assert np.allclose(phis_with, phis_base + phi0_val, atol=1e-5) + + # --- endpoint coordinates are consistent with the shifted phis --- + # phi0 rotates the ring in the (ax0, ax1) plane. + # For symmetry_axis=2: ax0=col1, ax1=col0. + # A shift phi → phi+phi0 acts as a rotation in (ax0, ax1) by phi0: + # new[:, ax0] = cos(phi0)*old[:, ax0] - sin(phi0)*old[:, ax1] + # new[:, ax1] = sin(phi0)*old[:, ax0] + cos(phi0)*old[:, ax1] + ax0, ax1 = 1, 0 # for symmetry_axis=2 + pts_new = to_numpy_array(s.all_lor_endpoints) + pts_old = to_numpy_array(s0.all_lor_endpoints) + cos_a, sin_a = math.cos(phi0_val), math.sin(phi0_val) + pts_expected = pts_old.copy() + pts_expected[:, ax0] = cos_a * pts_old[:, ax0] - sin_a * pts_old[:, ax1] + pts_expected[:, ax1] = sin_a * pts_old[:, ax0] + cos_a * pts_old[:, ax1] + assert np.allclose(pts_new, pts_expected, atol=1e-5) + + # --- phi0 is ignored when phis are supplied explicitly --- + phis_explicit = xp.asarray([0.0, math.pi / 2], dtype=xp.float32, device=dev) + s_explicit = pps.RegularPolygonPETScannerGeometry( + xp, + dev, + radius=1.0, + num_sides=2, + num_lor_endpoints_per_side=n_per_side, + lor_spacing=0.1, + ring_positions=xp.asarray([0.0], dtype=xp.float32, device=dev), + symmetry_axis=2, + phis=phis_explicit, + phi0=99.0, # should be ignored + ) + # phis on the module should be the supplied array, unaffected by phi0 + assert xp.all(s_explicit.modules[0].phis == phis_explicit) + + def test_regular_polygon_pet_scanner_invalid_symmetry_axis( xp: ModuleType, dev: str ) -> None: From 600f4d3192e763207cb5f8cc9e5e44e8859baf5d Mon Sep 17 00:00:00 2001 From: Georg Schramm Date: Thu, 28 May 2026 14:31:53 +0200 Subject: [PATCH 15/17] wip unlister --- .../05_run_unlister.py | 272 +++++++++ src/parallelproj/__init__.py | 2 + src/parallelproj/projectors.py | 188 ++++--- src/parallelproj/unlist.py | 228 ++++++++ tests/test_unlist.py | 523 ++++++++++++++++++ 5 files changed, 1134 insertions(+), 79 deletions(-) create mode 100644 docs/examples/02_pet_sinogram_projections/05_run_unlister.py create mode 100644 src/parallelproj/unlist.py create mode 100644 tests/test_unlist.py diff --git a/docs/examples/02_pet_sinogram_projections/05_run_unlister.py b/docs/examples/02_pet_sinogram_projections/05_run_unlister.py new file mode 100644 index 00000000..070e51e9 --- /dev/null +++ b/docs/examples/02_pet_sinogram_projections/05_run_unlister.py @@ -0,0 +1,272 @@ +""" +Listmode to sinogram histogramming (unlister) +============================================= + +This example demonstrates :func:`.regular_polygon_events_to_sinogram`, +which histograms listmode crystal-pair events into a sinogram. + +We build a full simulation pipeline: + +1. **Forward-project** an ``elliptic_cylinder_phantom`` with a realistic PET + scanner into a span-1 sinogram. +2. **Add Poisson noise** to obtain an integer sinogram ``y_span1``. +3. **Convert** ``y_span1`` to crystal-index listmode events + ``(d1, r1, d2, r2)`` — one event row per detected photon-pair — using + :meth:`.RegularPolygonPETProjector.convert_sinogram_to_crystal_index_events`. +4. **Unlist into span-1** and verify that the recovered sinogram is + *identical* to ``y_span1`` (exact integer round-trip). +5. **Unlist directly into span-3** and verify that the result equals + :class:`.SinogramAxialCompressionOperator` applied to ``y_span1`` — no + re-binning step is needed; the unlister uses the span-3 ring-pair table + directly. +6. Repeat steps 3–4 for a **TOF sinogram** to confirm that TOF bins are + round-tripped correctly as well. + +The span-3 equality demonstrates a useful property: instead of first +histogramming into span-1 and then compressing, you can bin straight into +any odd-span sinogram in a single pass over the event list. +""" + +# %% +import numpy as np +import matplotlib.pyplot as plt + +import parallelproj.pet_scanners +import parallelproj.pet_lors as ppl +import parallelproj.projectors +import parallelproj.tof +from parallelproj import to_numpy_array +from parallelproj.unlist import regular_polygon_events_to_sinogram +from img import elliptic_cylinder_phantom + +# %% +from array_utils import suggest_array_backend_and_device + +xp, dev = suggest_array_backend_and_device(None, None) + +# %% +# Scanner and LOR descriptor +# -------------------------- +# +# We use the same scanner geometry as the listmode reconstruction example: +# 16 sides × 12 crystals per side = 192 crystals per ring, 5 rings. + +num_rings = 5 +scanner = parallelproj.pet_scanners.RegularPolygonPETScannerGeometry( + xp, + dev, + radius=65.0, + num_sides=16, + num_lor_endpoints_per_side=12, + lor_spacing=2.3, + ring_positions=xp.linspace(-10, 10, num_rings, device=dev), + symmetry_axis=2, +) + +img_shape = (40, 40, 8) +voxel_size = (2.0, 2.0, 2.0) + +lor_desc = ppl.RegularPolygonPETLORDescriptor( + scanner, + ppl.Michelogram(scanner.num_rings, max_ring_difference=2, span=1), + radial_trim=10, + sinogram_order=ppl.SinogramSpatialAxisOrder.RVP, +) + +proj = parallelproj.projectors.RegularPolygonPETProjector( + lor_desc, img_shape=img_shape, voxel_size=voxel_size +) + +print(f"Scanner : {scanner.num_lor_endpoints_per_ring} crystals/ring × {scanner.num_rings} rings") +print(f"Sinogram : {lor_desc.spatial_sinogram_shape} (num_rad × num_views × num_planes)") + +# %% +# Simulate non-TOF PET data +# ------------------------- +# +# Forward-project the phantom and add Poisson noise. + +x_true = elliptic_cylinder_phantom(xp, dev) +noise_free = proj(x_true) + +np.random.seed(0) +y_span1 = xp.asarray( + np.random.poisson(to_numpy_array(noise_free)).astype(np.int32), + device=dev, + dtype=xp.int32, +) + +total_counts = int(xp.sum(y_span1)) +print(f"\nNon-TOF span-1 sinogram : shape={tuple(y_span1.shape)}, total counts={total_counts}") + +# %% +# Non-TOF round-trip (span-1) +# --------------------------- +# +# Convert the integer sinogram to crystal-index events and unlist back into +# span-1. With ``radial_trim=10`` the scanner has no self-pair bins, so +# every count is round-tripped exactly. + +events = proj.convert_sinogram_to_crystal_index_events(y_span1, shuffle=True) +print(f"\nNumber of events : {len(events)}") + +sino_unlisted = regular_polygon_events_to_sinogram(lor_desc, events) +y_span1_np = to_numpy_array(y_span1).astype(np.float32) + +print(f"Span-1 round-trip exact match : {np.array_equal(sino_unlisted, y_span1_np)}") +print(f"Max absolute difference : {float(np.max(np.abs(sino_unlisted - y_span1_np))):.0f}") + +# %% +# Span-3 comparison +# ----------------- +# +# :class:`.SinogramAxialCompressionOperator` sums span-1 planes that share +# the same span-3 segment and axial midpoint. Unlisting the *same events* +# directly into the span-3 descriptor must produce the identical result, +# because the span-3 ring-pair lookup table collapses those ring pairs in +# exactly the same way. + +op_compress = ppl.SinogramAxialCompressionOperator(lor_desc, target_span=3) +span3_desc = op_compress.out_lor_descriptor + +y_span3_np = to_numpy_array(op_compress(xp.astype(y_span1, xp.float32))) +sino_span3_unlisted = regular_polygon_events_to_sinogram(span3_desc, events) + +print(f"\nSpan-3 sinogram shape (operator) : {y_span3_np.shape}") +print(f"Span-3 sinogram shape (unlisted) : {sino_span3_unlisted.shape}") +print(f"Span-3 exact match : {np.array_equal(sino_span3_unlisted, y_span3_np)}") +print(f"Max absolute difference : {float(np.max(np.abs(sino_span3_unlisted - y_span3_np))):.0f}") + +# %% +# TOF simulation +# -------------- +# +# Enable TOF on the projector and re-simulate. + +proj.tof_parameters = parallelproj.tof.TOFParameters( + num_tofbins=13, tofbin_width=12.0, sigma_tof=12.0 +) +num_tof_bins = proj.tof_parameters.num_tofbins + +noise_free_tof = proj(x_true) + +np.random.seed(1) +y_tof = xp.asarray( + np.random.poisson(to_numpy_array(noise_free_tof)).astype(np.int32), + device=dev, + dtype=xp.int32, +) + +print(f"\nTOF sinogram shape : {tuple(y_tof.shape)}, total counts={int(xp.sum(y_tof))}") + +# %% +# TOF round-trip (span-1) +# ----------------------- +# +# ``convert_sinogram_to_crystal_index_events`` detects the trailing TOF +# dimension automatically (4-D input) and returns ``(d1, r1, d2, r2, tof_bin)`` +# rows, where bin 0 is the bin closest to d1 (the xstart crystal). + +events_tof = proj.convert_sinogram_to_crystal_index_events(y_tof, shuffle=True) +print(f"\nNumber of TOF events : {len(events_tof)}") + +sino_tof_unlisted = regular_polygon_events_to_sinogram( + lor_desc, events_tof, num_tof_bins=num_tof_bins +) +y_tof_np = to_numpy_array(y_tof).astype(np.float32) + +print(f"TOF round-trip exact match : {np.array_equal(sino_tof_unlisted, y_tof_np)}") +print(f"Max absolute difference : {float(np.max(np.abs(sino_tof_unlisted - y_tof_np))):.0f}") + +# %% +# Visualisation +# ------------- +# +# Non-TOF: ground-truth span-1 sinogram, unlisted span-1 sinogram, difference +# and the span-3 comparison. + +v_ax = lor_desc.view_axis_num + +fig, axes = plt.subplots(2, 3, figsize=(13, 8)) + +# --- row 0: span-1 round-trip --- +vmax1 = float(np.max(y_span1_np)) + +ax = axes[0, 0] +ax.imshow(y_span1_np.sum(axis=v_ax), aspect="auto", vmin=0, vmax=vmax1) +ax.set_title("y_span1 (ground truth,\nsummed over views)") +ax.set_xlabel("planes") +ax.set_ylabel("radial") + +ax = axes[0, 1] +ax.imshow(sino_unlisted.sum(axis=v_ax), aspect="auto", vmin=0, vmax=vmax1) +ax.set_title("Unlisted span-1\n(summed over views)") +ax.set_xlabel("planes") + +ax = axes[0, 2] +diff1 = sino_unlisted - y_span1_np +im = ax.imshow(diff1.sum(axis=v_ax), aspect="auto", cmap="bwr", vmin=-1, vmax=1) +ax.set_title("Difference\n(must be all zeros)") +ax.set_xlabel("planes") +fig.colorbar(im, ax=ax) + +# --- row 1: span-3 comparison --- +v_ax3 = span3_desc.view_axis_num +vmax3 = float(np.max(y_span3_np)) + +ax = axes[1, 0] +ax.imshow(y_span3_np.sum(axis=v_ax3), aspect="auto", vmin=0, vmax=vmax3) +ax.set_title("y_span3 via\nSinogramAxialCompressionOperator") +ax.set_xlabel("planes") +ax.set_ylabel("radial") + +ax = axes[1, 1] +ax.imshow(sino_span3_unlisted.sum(axis=v_ax3), aspect="auto", vmin=0, vmax=vmax3) +ax.set_title("Unlisted directly\ninto span-3") +ax.set_xlabel("planes") + +ax = axes[1, 2] +diff3 = sino_span3_unlisted - y_span3_np +im3 = ax.imshow(diff3.sum(axis=v_ax3), aspect="auto", cmap="bwr", vmin=-1, vmax=1) +ax.set_title("Span-3 difference\n(must be all zeros)") +ax.set_xlabel("planes") +fig.colorbar(im3, ax=ax) + +fig.suptitle("Non-TOF sinogram round-trips (radial × planes, summed over views)") +fig.tight_layout() +fig.show() + +# %% +# TOF comparison: sinogram summed over TOF bins. + +fig2, axes2 = plt.subplots(1, 3, figsize=(13, 4)) + +y_tof_spatial = y_tof_np.sum(axis=-1) +sino_tof_spatial = sino_tof_unlisted.sum(axis=-1) +vmax_tof = float(np.max(y_tof_spatial)) + +ax = axes2[0] +ax.imshow(y_tof_spatial.sum(axis=v_ax), aspect="auto", vmin=0, vmax=vmax_tof) +ax.set_title("y_tof (ground truth,\nTOF-summed, view-summed)") +ax.set_xlabel("planes") +ax.set_ylabel("radial") + +ax = axes2[1] +ax.imshow(sino_tof_spatial.sum(axis=v_ax), aspect="auto", vmin=0, vmax=vmax_tof) +ax.set_title("Unlisted TOF span-1\n(TOF-summed, view-summed)") +ax.set_xlabel("planes") + +ax = axes2[2] +diff_tof = sino_tof_unlisted - y_tof_np +im_tof = ax.imshow( + diff_tof.sum(axis=(-1, v_ax)), aspect="auto", cmap="bwr", vmin=-1, vmax=1 +) +ax.set_title("TOF difference\n(must be all zeros)") +ax.set_xlabel("planes") +fig2.colorbar(im_tof, ax=ax) + +fig2.suptitle( + "TOF sinogram round-trip (radial × planes, summed over views and TOF bins)" +) +fig2.tight_layout() +fig2.show() diff --git a/src/parallelproj/__init__.py b/src/parallelproj/__init__.py index 533f9bac..fdf2a532 100644 --- a/src/parallelproj/__init__.py +++ b/src/parallelproj/__init__.py @@ -7,6 +7,7 @@ from ._backend import Array, empty_cuda_cache, to_numpy_array, count_event_multiplicity +from .unlist import regular_polygon_events_to_sinogram __all__ = [ "__version__", @@ -14,4 +15,5 @@ "empty_cuda_cache", "to_numpy_array", "count_event_multiplicity", + "regular_polygon_events_to_sinogram", ] diff --git a/src/parallelproj/projectors.py b/src/parallelproj/projectors.py index 82243c24..d44f24a5 100644 --- a/src/parallelproj/projectors.py +++ b/src/parallelproj/projectors.py @@ -11,7 +11,7 @@ from matplotlib.figure import Figure from matplotlib.patches import Rectangle from mpl_toolkits.mplot3d import Axes3D -from array_api_compat import device, get_namespace, size +from array_api_compat import device, get_namespace import parallelproj_core @@ -842,28 +842,47 @@ def show_geometry( self.lor_descriptor.scanner.show_lor_endpoints(ax) - def convert_sinogram_to_listmode( + def convert_sinogram_to_crystal_index_events( self, sinogram: Array, shuffle: bool = False - ) -> tuple[Array, Array, Array | None]: - """convert a non-TOF or TOF emission sinogram to listmode events + ) -> np.ndarray: + """Convert a non-TOF or TOF span-1 sinogram to crystal-index events. + + Each count in the sinogram becomes one row in the output array. + Non-TOF rows are ``(d1, r1, d2, r2)``; TOF rows add a trailing + ``tof_bin`` column, where bin 0 is the bin closest to ``d1`` + (the xstart crystal). The output is ready for direct use with + :func:`.regular_polygon_events_to_sinogram`. Parameters ---------- sinogram : Array - an integer (TOF or non-TOF) emission sinogram + Integer span-1 sinogram. + Non-TOF shape: ``lor_descriptor.spatial_sinogram_shape``. + TOF shape: ``(*lor_descriptor.spatial_sinogram_shape, num_tof_bins)``. shuffle : bool, optional - if True, randomly shuffle the order of the output events, - by default False. Shuffling is implemented via - ``numpy.random.permutation(num_events)``, which draws from - numpy's global random state. Use ``numpy.random.seed()`` - before calling this method for reproducible results. + Randomly shuffle the output rows (default ``False``). + Uses numpy's global random state; call ``numpy.random.seed`` + before this method for reproducible results. Returns ------- - tuple[Array, Array, Array | None] - event_start_coordinates, event_end_coordinates, event_tofbin - in case of non-TOF, event_tofbin is None + events : np.ndarray, shape (N, 4) or (N, 5), dtype int32 + Crystal-index events. Columns are + ``(d1, r1, d2, r2)`` or ``(d1, r1, d2, r2, tof_bin)``. + + Raises + ------ + TypeError + If ``sinogram`` does not have an integer dtype. + ValueError + If the LOR descriptor has ``span > 1``; :attr:`start_plane_index` + is only defined for span-1 descriptors. """ + lor_desc = self.lor_descriptor + if lor_desc.michelogram.span != 1: + raise ValueError( + "convert_sinogram_to_crystal_index_events requires a span-1 LOR descriptor" + ) integer_dtypes = ( self.xp.int8, @@ -880,81 +899,92 @@ def convert_sinogram_to_listmode( f"sinogram must have an integer dtype, got {sinogram.dtype}" ) - num_events = int(self.xp.sum(sinogram)) - - event_start_coords = self.xp.empty( - (num_events, 3), device=self._dev, dtype=self.xp.float32 - ) - event_end_coords = self.xp.empty( - (num_events, 3), device=self._dev, dtype=self.xp.float32 - ) - - if self.tof and self._tof_parameters is not None: - num_tofbins = self._tof_parameters.num_tofbins - event_tofbins = self.xp.empty( - (num_events,), device=self._dev, dtype=self.xp.int16 - ) + sc = to_numpy_array(lor_desc.start_in_ring_index) # (num_views, num_rad) + ec = to_numpy_array(lor_desc.end_in_ring_index) + sr = to_numpy_array(lor_desc.start_plane_index) # (num_planes,) + er = to_numpy_array(lor_desc.end_plane_index) + + p_ax = lor_desc.plane_axis_num + v_ax = lor_desc.view_axis_num + r_ax = lor_desc.radial_axis_num + + sino_np = to_numpy_array(sinogram).astype(np.int32) + tof_mode = sino_np.ndim == 4 + + valid_vr = sc != ec # self-pair bins are unphysical; skip them + + # Reorder axes to (view, radial, plane[, tof]) so that np.where + # yields events in the same view-first, radial-second, plane-third + # order as the original view-by-view loop in convert_sinogram_to_listmode. + if tof_mode: + sino_vrpt = np.transpose(sino_np, (v_ax, r_ax, p_ax, 3)) + counts = sino_vrpt * valid_vr[:, :, None, None] + v_idx, r_idx, p_idx, t_idx = np.where(counts > 0) + cnt = counts[v_idx, r_idx, p_idx, t_idx] + rv = np.repeat(v_idx, cnt) + rr = np.repeat(r_idx, cnt) + rp = np.repeat(p_idx, cnt) + rt = np.repeat(t_idx, cnt) + events = np.column_stack( + [sc[rv, rr], sr[rp], ec[rv, rr], er[rp], rt] + ).astype(np.int32) else: - num_tofbins = 1 - event_tofbins = None - - event_offset = 0 + sino_vrp = np.transpose(sino_np, (v_ax, r_ax, p_ax)) + counts = sino_vrp * valid_vr[:, :, None] + v_idx, r_idx, p_idx = np.where(counts > 0) + cnt = counts[v_idx, r_idx, p_idx] + rv = np.repeat(v_idx, cnt) + rr = np.repeat(r_idx, cnt) + rp = np.repeat(p_idx, cnt) + events = np.column_stack( + [sc[rv, rr], sr[rp], ec[rv, rr], er[rp]] + ).astype(np.int32) - # we convert view by view to save memory - for view in range(self.lor_descriptor.num_views): - xstart, xend = self.lor_descriptor.get_lor_coordinates( - views=self.xp.asarray([view], device=self._dev) - ) - xstart = self.xp.reshape(xstart, (-1, 3)) - xend = self.xp.reshape(xend, (-1, 3)) - - sino_view = self.xp.take( - sinogram, - self.xp.asarray([view], device=self._dev), - axis=self.lor_descriptor.view_axis_num, - ) - sino_view = self.xp.squeeze( - sino_view, axis=self.lor_descriptor.view_axis_num - ) + if shuffle: + perm = np.random.permutation(len(events)) + events = events[perm] - # flatten all dims (lor dims + optional tofbin last dim) to 1D - # for TOF: flat_idx = lor_idx * num_tofbins + tofbin_idx - # for non-TOF: num_tofbins == 1, flat_idx == lor_idx - ss = self.xp.reshape(sino_view, (size(sino_view),)) + return events - # array_api_strict does not support repeat with array-valued counts; - # we convert back and forth to numpy cpu array as a workaround - flat_counts = to_numpy_array(ss).astype(int) - event_flat_inds = np.repeat(np.arange(ss.shape[0]), flat_counts) - num_events_view = int(event_flat_inds.shape[0]) + def convert_sinogram_to_listmode( + self, sinogram: Array, shuffle: bool = False + ) -> tuple[Array, Array, Array | None]: + """convert a non-TOF or TOF emission sinogram to listmode events - event_lor_inds = self.xp.asarray( - event_flat_inds // num_tofbins, device=self._dev - ) + Parameters + ---------- + sinogram : Array + an integer (TOF or non-TOF) emission sinogram + shuffle : bool, optional + if True, randomly shuffle the order of the output events, + by default False. Shuffling is implemented via + ``numpy.random.permutation(num_events)``, which draws from + numpy's global random state. Use ``numpy.random.seed()`` + before calling this method for reproducible results. - event_start_coords[event_offset : (event_offset + num_events_view), :] = ( - self.xp.take(xstart, event_lor_inds, axis=0) - ) - event_end_coords[event_offset : (event_offset + num_events_view), :] = ( - self.xp.take(xend, event_lor_inds, axis=0) - ) + Returns + ------- + tuple[Array, Array, Array | None] + event_start_coordinates, event_end_coordinates, event_tofbin + in case of non-TOF, event_tofbin is None + """ + events = self.convert_sinogram_to_crystal_index_events(sinogram, shuffle=shuffle) - if event_tofbins is not None: - event_tofbins[event_offset : (event_offset + num_events_view)] = ( - self.xp.asarray( - (event_flat_inds % num_tofbins).astype(np.int16), - device=self._dev, - ) - ) + scanner = self.lor_descriptor.scanner + d1 = self.xp.asarray(events[:, 0].astype(np.int64), device=self._dev) + r1 = self.xp.asarray(events[:, 1].astype(np.int64), device=self._dev) + d2 = self.xp.asarray(events[:, 2].astype(np.int64), device=self._dev) + r2 = self.xp.asarray(events[:, 3].astype(np.int64), device=self._dev) - event_offset += num_events_view + event_start_coords = scanner.get_lor_endpoints(r1, d1) + event_end_coords = scanner.get_lor_endpoints(r2, d2) - if shuffle: - perm = self.xp.asarray(np.random.permutation(num_events), device=self._dev) - event_start_coords = self.xp.take(event_start_coords, perm, axis=0) - event_end_coords = self.xp.take(event_end_coords, perm, axis=0) - if event_tofbins is not None: - event_tofbins = self.xp.take(event_tofbins, perm, axis=0) + if events.shape[1] == 5: + event_tofbins = self.xp.asarray( + events[:, 4].astype(np.int16), device=self._dev + ) + else: + event_tofbins = None return event_start_coords, event_end_coords, event_tofbins diff --git a/src/parallelproj/unlist.py b/src/parallelproj/unlist.py new file mode 100644 index 00000000..2d2cab8d --- /dev/null +++ b/src/parallelproj/unlist.py @@ -0,0 +1,228 @@ +"""Listmode-to-sinogram histogrammer for RegularPolygonPETScannerGeometry.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import numpy as np + +from ._backend import to_numpy_array +from .pet_lors import RegularPolygonPETLORDescriptor + +if TYPE_CHECKING: + pass + + +def _build_inring_luts( + lor_descriptor: RegularPolygonPETLORDescriptor, +) -> tuple[np.ndarray, np.ndarray]: + """Build in-ring lookup tables from a LOR descriptor. + + Parameters + ---------- + lor_descriptor : RegularPolygonPETLORDescriptor + + Returns + ------- + inring_lut : np.ndarray, shape (n, n), dtype int32 + ``inring_lut[d1, d2]`` is the flat ``view * num_rad + rad`` sinogram + index for the crystal pair, or ``-1`` if the pair is not a valid LOR. + inring_tof_sign : np.ndarray, shape (n, n), dtype int8 + ``inring_tof_sign[d1, d2]`` is ``+1`` if d1 is the canonical *start* + crystal (the xstart endpoint used by the sinogram convention) and + ``-1`` if d1 is the end crystal. Zero on the diagonal. + """ + n = lor_descriptor.scanner.num_lor_endpoints_per_ring + num_rad = lor_descriptor.num_rad + num_views = lor_descriptor.num_views + + inring_lut = np.full((n, n), -1, dtype=np.int32) + inring_tof_sign = np.zeros((n, n), dtype=np.int8) + + # start/end in-ring crystal indices: shape (num_views, num_rad) + ds = to_numpy_array(lor_descriptor.start_in_ring_index).ravel().astype(np.intp) + de = to_numpy_array(lor_descriptor.end_in_ring_index).ravel().astype(np.intp) + + v_idx = np.repeat(np.arange(num_views, dtype=np.intp), num_rad) + r_idx = np.tile(np.arange(num_rad, dtype=np.intp), num_views) + flat_vr = (v_idx * num_rad + r_idx).astype(np.int32) + + inring_lut[ds, de] = flat_vr + inring_lut[de, ds] = flat_vr + inring_tof_sign[ds, de] = 1 + inring_tof_sign[de, ds] = -1 + + # Self-pairs (start == end crystal) are mathematical artifacts of the + # zig-zag sinogram parameterisation; no physical coincidence can produce + # them. Invalidate explicitly so they are dropped by the validity check. + np.fill_diagonal(inring_lut, -1) + np.fill_diagonal(inring_tof_sign, 0) + + return inring_lut, inring_tof_sign + + +def regular_polygon_events_to_sinogram( + lor_descriptor: RegularPolygonPETLORDescriptor, + events: Any, + num_tof_bins: int | None = None, + tof_bin_sign: int = 1, +) -> np.ndarray: + """Histogram listmode events into a sinogram. + + Parameters + ---------- + lor_descriptor : RegularPolygonPETLORDescriptor + LOR descriptor defining the sinogram geometry. + events : array-like, shape (N, 4) or (N, 5) + Listmode events. Each row is ``(d1, r1, d2, r2)`` for non-TOF or + ``(d1, r1, d2, r2, tof_bin)`` for TOF, where + + - ``d1``, ``d2`` are in-ring crystal indices + (0 … num_lor_endpoints_per_ring - 1) + - ``r1``, ``r2`` are ring indices (0 … num_rings - 1) + - ``tof_bin`` is an **unsigned** TOF bin index with the convention + set by ``tof_bin_sign`` + + Events outside the sinogram FOV (invalid crystal pair, ring pair + beyond ``max_ring_difference``, out-of-range indices) are silently + discarded. + num_tof_bins : int or None + Number of TOF bins. Required when ``events`` has 5 columns; must + be ``None`` when ``events`` has 4 columns. + tof_bin_sign : {+1, -1}, optional + TOF bin direction convention in the input events: + + * ``+1`` (default): bin 0 is closest to detector d1. This matches + the parallelproj sinogram convention (bin 0 = closest to xstart) + when d1 is the start crystal. + * ``-1``: bin 0 is closest to detector d2. + + A flip is applied automatically so that the output sinogram always + uses the parallelproj convention (bin 0 = closest to xstart). + + Returns + ------- + sinogram : np.ndarray + Histogram sinogram. Shape is ``spatial_sinogram_shape`` for non-TOF + or ``(*spatial_sinogram_shape, num_tof_bins)`` for TOF. + Dtype is ``float32``. + """ + if tof_bin_sign not in (1, -1): + raise ValueError("tof_bin_sign must be +1 or -1") + + events_np = np.asarray(to_numpy_array(events), dtype=np.int32) + + if events_np.ndim != 2: + raise ValueError("events must be a 2D array") + + n_events, n_cols = events_np.shape + + if n_cols == 4: + if num_tof_bins is not None: + raise ValueError( + "events has 4 columns (non-TOF) but num_tof_bins was specified" + ) + tof_mode = False + elif n_cols == 5: + if num_tof_bins is None: + raise ValueError( + "events has 5 columns (TOF) but num_tof_bins was not specified" + ) + tof_mode = True + else: + raise ValueError( + f"events must have 4 (non-TOF) or 5 (TOF) columns, got {n_cols}" + ) + + shape_spatial = lor_descriptor.spatial_sinogram_shape + + if n_events == 0: + if tof_mode: + return np.zeros((*shape_spatial, num_tof_bins), dtype=np.float32) + return np.zeros(shape_spatial, dtype=np.float32) + + d1 = events_np[:, 0] + r1 = events_np[:, 1] + d2 = events_np[:, 2] + r2 = events_np[:, 3] + if tof_mode: + tof_raw = events_np[:, 4] + + inring_lut, inring_tof_sign_lut = _build_inring_luts(lor_descriptor) + ring_pair_table = lor_descriptor.michelogram.plane_for_ring_pair_table + + n_crystals = lor_descriptor.scanner.num_lor_endpoints_per_ring + num_rings = lor_descriptor.scanner.num_rings + num_rad = lor_descriptor.num_rad + + # primary validity: all indices within bounds + valid = ( + (d1 >= 0) + & (d1 < n_crystals) + & (d2 >= 0) + & (d2 < n_crystals) + & (r1 >= 0) + & (r1 < num_rings) + & (r2 >= 0) + & (r2 < num_rings) + ) + + # safe fallback indices for out-of-bounds events (they will be masked out) + d1s = np.where(valid, d1, 0) + d2s = np.where(valid, d2, 0) + + inring_flat = inring_lut[d1s, d2s] + tof_sign_vals = inring_tof_sign_lut[d1s, d2s] + + # the crystal pair must form a valid LOR + valid &= inring_flat >= 0 + + # canonical ring ordering: +1 means d1 is xstart, so r1 is the start ring. + # Use safe fallback (0) for out-of-range ring indices so the table lookup + # never receives negative or OOB values (numpy wraps negative indices). + r1_safe = np.where(valid, r1, 0) + r2_safe = np.where(valid, r2, 0) + is_d1_start = tof_sign_vals == 1 + r_start = np.where(is_d1_start, r1_safe, r2_safe) + r_end = np.where(is_d1_start, r2_safe, r1_safe) + + # plane lookup + plane_idx = ring_pair_table[r_start, r_end] + valid &= plane_idx >= 0 + + if tof_mode: + valid &= (tof_raw >= 0) & (tof_raw < num_tof_bins) + tof_raw_safe = np.where(valid, tof_raw, 0) + # flip the bin when the canonical direction does not match tof_bin_sign + flip = (tof_bin_sign * tof_sign_vals) == -1 + sinogram_tof = np.where(flip, num_tof_bins - 1 - tof_raw_safe, tof_raw_safe) + + # decompose the in-ring flat index into view and radial indices + safe_flat = np.where(valid, inring_flat, 0) + view_idx = safe_flat // num_rad + rad_idx = safe_flat % num_rad + + # compute the flat sinogram index respecting sinogram_order + p_ax = lor_descriptor.plane_axis_num + v_ax = lor_descriptor.view_axis_num + r_ax = lor_descriptor.radial_axis_num + strides = [int(np.prod(shape_spatial[i + 1 :])) for i in range(3)] + + safe_plane = np.where(valid, plane_idx, 0) + flat_sino = ( + rad_idx.astype(np.int64) * strides[r_ax] + + view_idx.astype(np.int64) * strides[v_ax] + + safe_plane.astype(np.int64) * strides[p_ax] + ) + + if tof_mode: + flat_sino = flat_sino * num_tof_bins + sinogram_tof.astype(np.int64) + total_bins = int(np.prod(shape_spatial)) * num_tof_bins + output_shape = (*shape_spatial, num_tof_bins) + else: + total_bins = int(np.prod(shape_spatial)) + output_shape = shape_spatial + + flat_sino_valid = flat_sino[valid] + sino_flat = np.bincount(flat_sino_valid, minlength=total_bins).astype(np.float32) + return sino_flat.reshape(output_shape) diff --git a/tests/test_unlist.py b/tests/test_unlist.py new file mode 100644 index 00000000..0ce14210 --- /dev/null +++ b/tests/test_unlist.py @@ -0,0 +1,523 @@ +"""Tests for parallelproj.unlist — listmode-to-sinogram histogrammer.""" +from __future__ import annotations + +import numpy as np +import pytest +from types import ModuleType + +import parallelproj.pet_scanners as pps +import parallelproj.pet_lors as ppl +from parallelproj import to_numpy_array +from parallelproj.unlist import _build_inring_luts, regular_polygon_events_to_sinogram + +from .config import pytestmark # noqa: F401 (pytest picks this up as a module marker) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _small_scanner(xp: ModuleType, dev: str, num_rings: int = 3): + """4 sides × 4 endpoints/side = 16 crystals/ring.""" + return pps.RegularPolygonPETScannerGeometry( + xp, + dev, + radius=100.0, + num_sides=4, + num_lor_endpoints_per_side=4, + lor_spacing=5.0, + ring_positions=xp.linspace(-10.0, 10.0, num_rings, device=dev), + symmetry_axis=2, + ) + + +def _lor_desc( + scanner: pps.RegularPolygonPETScannerGeometry, + span: int = 1, + max_ring_difference: int | None = None, + sinogram_order: ppl.SinogramSpatialAxisOrder = ppl.SinogramSpatialAxisOrder.RVP, +) -> ppl.RegularPolygonPETLORDescriptor: + """Build a LOR descriptor for a scanner.""" + if max_ring_difference is None: + max_ring_difference = scanner.num_rings - 1 + return ppl.RegularPolygonPETLORDescriptor( + scanner, + ppl.Michelogram( + scanner.num_rings, + max_ring_difference=max_ring_difference, + span=span, + ), + radial_trim=1, + sinogram_order=sinogram_order, + ) + + +def _first_valid_vr(sc: np.ndarray, ec: np.ndarray) -> tuple[int, int]: + """Return (v, r) of the first non-self-pair position (sc[v,r] != ec[v,r]).""" + idxs = np.argwhere(sc != ec) + return int(idxs[0, 0]), int(idxs[0, 1]) + + +# --------------------------------------------------------------------------- +# LUT structure tests +# --------------------------------------------------------------------------- + +def test_inring_lut_symmetry(xp: ModuleType, dev: str) -> None: + """inring_lut is symmetric; inring_tof_sign is antisymmetric off-diagonal; + diagonal is invalidated after self-pair fill.""" + scanner = _small_scanner(xp, dev) + desc = _lor_desc(scanner) + lut, sign = _build_inring_luts(desc) + assert np.array_equal(lut, lut.T), "inring_lut must be symmetric" + assert np.all(lut.diagonal() == -1), "self-pair LUT entries must be -1" + assert np.all(sign.diagonal() == 0), "self-pair sign entries must be 0" + n = scanner.num_lor_endpoints_per_ring + off = ~np.eye(n, dtype=bool) + assert np.array_equal(sign[off], -sign.T[off]), "sign must be antisymmetric off-diagonal" + + +def test_inring_lut_valid_pairs_covered(xp: ModuleType, dev: str) -> None: + """Every non-self-pair (view, rad) maps to the expected flat LUT index.""" + scanner = _small_scanner(xp, dev) + desc = _lor_desc(scanner) + lut, _ = _build_inring_luts(desc) + + ds = to_numpy_array(desc.start_in_ring_index).ravel().astype(int) + de = to_numpy_array(desc.end_in_ring_index).ravel().astype(int) + for k, (d1, d2) in enumerate(zip(ds, de)): + if d1 == d2: + assert lut[d1, d2] == -1, f"self-pair at k={k} must be invalidated" + else: + assert lut[d1, d2] == k, f"LUT entry mismatch at k={k}" + assert lut[d2, d1] == k, f"LUT reverse entry mismatch at k={k}" + + +# --------------------------------------------------------------------------- +# Non-TOF round-trip +# --------------------------------------------------------------------------- + +def test_non_tof_round_trip(xp: ModuleType, dev: str) -> None: + """Non-self-pair sinogram bins each hit exactly once; self-pair bins stay 0.""" + scanner = _small_scanner(xp, dev, num_rings=3) + desc = _lor_desc(scanner, span=1) + + sc = to_numpy_array(desc.start_in_ring_index) # (num_views, num_rad) + ec = to_numpy_array(desc.end_in_ring_index) + sr = to_numpy_array(desc.start_plane_index) # (num_planes,) + er = to_numpy_array(desc.end_plane_index) + + # only include events for non-self-pair (v, r) positions + valid_vr = sc != ec # (num_views, num_rad) bool mask + rows = [ + [int(sc[v, r]), int(sr[p]), int(ec[v, r]), int(er[p])] + for p in range(desc.num_planes) + for v in range(desc.num_views) + for r in range(desc.num_rad) + if valid_vr[v, r] + ] + events = xp.asarray(rows, dtype=xp.int32, device=dev) + sino = regular_polygon_events_to_sinogram(desc, events) + + assert sino.shape == desc.spatial_sinogram_shape + assert np.max(sino) == 1.0, "no bin should accumulate more than one count" + assert float(sino.sum()) == float(len(rows)) + + +def test_non_tof_accumulation(xp: ModuleType, dev: str) -> None: + """Duplicate events accumulate correctly.""" + scanner = _small_scanner(xp, dev, num_rings=2) + desc = _lor_desc(scanner) + + sc = to_numpy_array(desc.start_in_ring_index) + ec = to_numpy_array(desc.end_in_ring_index) + sr = to_numpy_array(desc.start_plane_index) + er = to_numpy_array(desc.end_plane_index) + + v0, r0 = _first_valid_vr(sc, ec) + row = [int(sc[v0, r0]), int(sr[0]), int(ec[v0, r0]), int(er[0])] + events = xp.asarray([row, row, row], dtype=xp.int32, device=dev) + sino = regular_polygon_events_to_sinogram(desc, events) + assert sino.sum() == 3.0 + + +# --------------------------------------------------------------------------- +# Crystal-swap invariance (non-TOF) +# --------------------------------------------------------------------------- + +def test_crystal_swap_non_tof(xp: ModuleType, dev: str) -> None: + """(d1,r1,d2,r2) and (d2,r2,d1,r1) must produce the same sinogram.""" + scanner = _small_scanner(xp, dev, num_rings=2) + desc = _lor_desc(scanner) + + sc = to_numpy_array(desc.start_in_ring_index) + ec = to_numpy_array(desc.end_in_ring_index) + sr = to_numpy_array(desc.start_plane_index) + er = to_numpy_array(desc.end_plane_index) + + v0, r0 = _first_valid_vr(sc, ec) + d1, d2 = int(sc[v0, r0]), int(ec[v0, r0]) + r1, r2 = int(sr[0]), int(er[0]) + + ev_fwd = xp.asarray([[d1, r1, d2, r2]], dtype=xp.int32, device=dev) + ev_bwd = xp.asarray([[d2, r2, d1, r1]], dtype=xp.int32, device=dev) + + assert np.allclose( + regular_polygon_events_to_sinogram(desc, ev_fwd), + regular_polygon_events_to_sinogram(desc, ev_bwd), + ), "Crystal swap must not change the sinogram bin" + + +# --------------------------------------------------------------------------- +# TOF round-trip +# --------------------------------------------------------------------------- + +def test_tof_single_event_lands_in_correct_bin(xp: ModuleType, dev: str) -> None: + """A single TOF event with d1=xstart lands in the expected TOF bin.""" + num_tof_bins = 7 + scanner = _small_scanner(xp, dev, num_rings=2) + desc = _lor_desc(scanner) + + sc = to_numpy_array(desc.start_in_ring_index) + ec = to_numpy_array(desc.end_in_ring_index) + sr = to_numpy_array(desc.start_plane_index) + er = to_numpy_array(desc.end_plane_index) + + v0, r0 = _first_valid_vr(sc, ec) + d1, d2 = int(sc[v0, r0]), int(ec[v0, r0]) # d1 is xstart → no flip + r1, r2 = int(sr[0]), int(er[0]) + tof_bin = 2 + + events = xp.asarray([[d1, r1, d2, r2, tof_bin]], dtype=xp.int32, device=dev) + sino = regular_polygon_events_to_sinogram(desc, events, num_tof_bins=num_tof_bins) + + assert sino.shape == (*desc.spatial_sinogram_shape, num_tof_bins) + assert sino.sum() == 1.0 + + nz = np.argwhere(sino > 0) + assert len(nz) == 1 + # TOF bin axis is trailing — no flip because d1 is xstart and tof_bin_sign=+1 + assert nz[0, -1] == tof_bin, "TOF bin must be stored unflipped when d1 is xstart" + + +def test_tof_round_trip(xp: ModuleType, dev: str) -> None: + """Non-self-pair TOF sinogram bins each hit exactly once.""" + num_tof_bins = 5 + scanner = _small_scanner(xp, dev, num_rings=2) + desc = _lor_desc(scanner, span=1) + + sc = to_numpy_array(desc.start_in_ring_index) + ec = to_numpy_array(desc.end_in_ring_index) + sr = to_numpy_array(desc.start_plane_index) + er = to_numpy_array(desc.end_plane_index) + + valid_vr = sc != ec + rows = [ + [int(sc[v, r]), int(sr[p]), int(ec[v, r]), int(er[p]), t] + for p in range(desc.num_planes) + for v in range(desc.num_views) + for r in range(desc.num_rad) + if valid_vr[v, r] + for t in range(num_tof_bins) + ] + events = xp.asarray(rows, dtype=xp.int32, device=dev) + sino = regular_polygon_events_to_sinogram(desc, events, num_tof_bins=num_tof_bins) + + assert sino.shape == (*desc.spatial_sinogram_shape, num_tof_bins) + assert np.max(sino) == 1.0 + assert float(sino.sum()) == float(len(rows)) + + +# --------------------------------------------------------------------------- +# TOF direction: crystal-swap with mirrored TOF bin +# --------------------------------------------------------------------------- + +def test_tof_crystal_swap_mirrored_bin(xp: ModuleType, dev: str) -> None: + """Swapping (d1,r1,d2,r2) and mirroring the TOF bin gives the same sinogram. + + Convention: bin 0 = closest to d1 (tof_bin_sign=+1, default). + + Event (d1_start, r1, d2_end, r2, bin=k) and + Event (d2_end, r2, d1_start, r1, bin=(n-1-k)) + must both land in the same sinogram bin at TOF index k. + """ + num_tof_bins = 7 + scanner = _small_scanner(xp, dev, num_rings=2) + desc = _lor_desc(scanner) + + sc = to_numpy_array(desc.start_in_ring_index) + ec = to_numpy_array(desc.end_in_ring_index) + sr = to_numpy_array(desc.start_plane_index) + er = to_numpy_array(desc.end_plane_index) + + v0, r0 = _first_valid_vr(sc, ec) + d1, d2 = int(sc[v0, r0]), int(ec[v0, r0]) # d1 = xstart + r1, r2 = int(sr[0]), int(er[0]) + tof_bin = 2 + tof_bin_mirror = num_tof_bins - 1 - tof_bin # = 4 + + ev_fwd = xp.asarray([[d1, r1, d2, r2, tof_bin]], dtype=xp.int32, device=dev) + ev_bwd = xp.asarray([[d2, r2, d1, r1, tof_bin_mirror]], dtype=xp.int32, device=dev) + + sino_fwd = regular_polygon_events_to_sinogram(desc, ev_fwd, num_tof_bins=num_tof_bins) + sino_bwd = regular_polygon_events_to_sinogram(desc, ev_bwd, num_tof_bins=num_tof_bins) + + assert np.allclose(sino_fwd, sino_bwd), ( + "Crystal swap with mirrored TOF bin must produce the same sinogram" + ) + # Both should land at TOF index tof_bin (not tof_bin_mirror) + nz_fwd = np.argwhere(sino_fwd > 0) + nz_bwd = np.argwhere(sino_bwd > 0) + assert nz_fwd[0, -1] == tof_bin + assert nz_bwd[0, -1] == tof_bin + + +# --------------------------------------------------------------------------- +# tof_bin_sign=-1 +# --------------------------------------------------------------------------- + +def test_tof_bin_sign_minus1(xp: ModuleType, dev: str) -> None: + """tof_bin_sign=-1 (bin 0 closest to d2): mirrored raw bin gives same sinogram. + + With sign=-1, physical offset k from d1 corresponds to raw bin (n-1-k) + because bin 0 counts from d2 instead of d1. + """ + num_tof_bins = 7 + scanner = _small_scanner(xp, dev, num_rings=2) + desc = _lor_desc(scanner) + + sc = to_numpy_array(desc.start_in_ring_index) + ec = to_numpy_array(desc.end_in_ring_index) + sr = to_numpy_array(desc.start_plane_index) + er = to_numpy_array(desc.end_plane_index) + + v0, r0 = _first_valid_vr(sc, ec) + d1, d2 = int(sc[v0, r0]), int(ec[v0, r0]) + r1, r2 = int(sr[0]), int(er[0]) + phys_bin = 2 # physical sinogram bin index + + # sign=+1: raw bin = phys_bin (bin 0 from d1) + ev_p1 = xp.asarray([[d1, r1, d2, r2, phys_bin]], dtype=xp.int32, device=dev) + # sign=-1: bin 0 is from d2, so phys_bin from d1 = (n-1-phys_bin) from d2 + raw_m1 = num_tof_bins - 1 - phys_bin + ev_m1 = xp.asarray([[d1, r1, d2, r2, raw_m1]], dtype=xp.int32, device=dev) + + sino_p1 = regular_polygon_events_to_sinogram( + desc, ev_p1, num_tof_bins=num_tof_bins, tof_bin_sign=1 + ) + sino_m1 = regular_polygon_events_to_sinogram( + desc, ev_m1, num_tof_bins=num_tof_bins, tof_bin_sign=-1 + ) + + assert np.allclose(sino_p1, sino_m1), ( + "tof_bin_sign=-1 with mirrored raw bin must give the same sinogram" + ) + + +# --------------------------------------------------------------------------- +# Out-of-FOV events silently dropped +# --------------------------------------------------------------------------- + +def test_out_of_fov_events_dropped(xp: ModuleType, dev: str) -> None: + """Events outside the FOV are silently discarded — sinogram stays zero.""" + scanner = _small_scanner(xp, dev, num_rings=3) + desc = _lor_desc(scanner, max_ring_difference=1) + + n_crystals = scanner.num_lor_endpoints_per_ring + num_rings = scanner.num_rings + + invalid_rows = [ + [n_crystals, 0, 0, 0], # d1 out of range + [-1, 0, 0, 0], # d1 negative + [0, 0, 1, num_rings], # r2 out of range + [0, 0, 1, -1], # r2 negative + [0, 0, 1, 2], # ring difference (=2) > max_ring_difference (=1) + [0, 0, 0, 0], # same crystal → invalidated by fill_diagonal + ] + events = xp.asarray(invalid_rows, dtype=xp.int32, device=dev) + sino = regular_polygon_events_to_sinogram(desc, events) + assert np.all(sino == 0), "All out-of-FOV events must be silently dropped" + + +def test_out_of_range_tof_bin_dropped(xp: ModuleType, dev: str) -> None: + """TOF events with out-of-range bins are dropped.""" + num_tof_bins = 5 + scanner = _small_scanner(xp, dev, num_rings=2) + desc = _lor_desc(scanner) + + sc = to_numpy_array(desc.start_in_ring_index) + ec = to_numpy_array(desc.end_in_ring_index) + sr = to_numpy_array(desc.start_plane_index) + er = to_numpy_array(desc.end_plane_index) + + v0, r0 = _first_valid_vr(sc, ec) + d1, d2 = int(sc[v0, r0]), int(ec[v0, r0]) + r1, r2 = int(sr[0]), int(er[0]) + + events = xp.asarray( + [ + [d1, r1, d2, r2, num_tof_bins], # bin == num_tof_bins → out of range + [d1, r1, d2, r2, -1], # negative bin + ], + dtype=xp.int32, + device=dev, + ) + sino = regular_polygon_events_to_sinogram(desc, events, num_tof_bins=num_tof_bins) + assert np.all(sino == 0) + + +# --------------------------------------------------------------------------- +# Span > 1 +# --------------------------------------------------------------------------- + +def test_span3_round_trip(xp: ModuleType, dev: str) -> None: + """Span-3: total count equals number of contributing (ring-pair, view, rad) triples.""" + scanner = _small_scanner(xp, dev, num_rings=5) + desc = _lor_desc(scanner, span=3, max_ring_difference=3) + + m = desc.michelogram + sc = to_numpy_array(desc.start_in_ring_index) + ec = to_numpy_array(desc.end_in_ring_index) + valid_vr = sc != ec + + rows = [ + [int(sc[v, r]), int(m.plane_start_rings[p, k]), int(ec[v, r]), int(m.plane_end_rings[p, k])] + for p in range(desc.num_planes) + for k in range(int(m.plane_multiplicity[p])) + for v in range(desc.num_views) + for r in range(desc.num_rad) + if valid_vr[v, r] + ] + events = xp.asarray(rows, dtype=xp.int32, device=dev) + sino = regular_polygon_events_to_sinogram(desc, events) + + assert sino.shape == desc.spatial_sinogram_shape + assert float(sino.sum()) == float(len(rows)) + + +def test_span3_single_event_plane(xp: ModuleType, dev: str) -> None: + """A single span-3 event lands in the correct sinogram plane.""" + scanner = _small_scanner(xp, dev, num_rings=5) + desc = _lor_desc(scanner, span=3, max_ring_difference=3) + + m = desc.michelogram + sc = to_numpy_array(desc.start_in_ring_index) + ec = to_numpy_array(desc.end_in_ring_index) + + v0, r0 = _first_valid_vr(sc, ec) + target_plane = 3 + r1 = int(m.plane_start_rings[target_plane, 0]) + r2 = int(m.plane_end_rings[target_plane, 0]) + d1 = int(sc[v0, r0]) + d2 = int(ec[v0, r0]) + + events = xp.asarray([[d1, r1, d2, r2]], dtype=xp.int32, device=dev) + sino = regular_polygon_events_to_sinogram(desc, events) + + assert sino.sum() == 1.0 + nz = np.argwhere(sino > 0) + assert len(nz) == 1 + assert nz[0, desc.plane_axis_num] == target_plane + + +# --------------------------------------------------------------------------- +# Sinogram order variants +# --------------------------------------------------------------------------- + +def test_sinogram_order_variants(xp: ModuleType, dev: str) -> None: + """All sinogram axis orderings give the correct shape and total count.""" + scanner = _small_scanner(xp, dev, num_rings=2) + + for order in [ + ppl.SinogramSpatialAxisOrder.RVP, + ppl.SinogramSpatialAxisOrder.PVR, + ppl.SinogramSpatialAxisOrder.VRP, + ]: + desc = _lor_desc(scanner, sinogram_order=order) + sc = to_numpy_array(desc.start_in_ring_index) + ec = to_numpy_array(desc.end_in_ring_index) + sr = to_numpy_array(desc.start_plane_index) + er = to_numpy_array(desc.end_plane_index) + + v0, r0 = _first_valid_vr(sc, ec) + d1, d2 = int(sc[v0, r0]), int(ec[v0, r0]) + r1, r2 = int(sr[0]), int(er[0]) + + events = xp.asarray([[d1, r1, d2, r2]], dtype=xp.int32, device=dev) + sino = regular_polygon_events_to_sinogram(desc, events) + + assert sino.shape == desc.spatial_sinogram_shape, f"Wrong shape for {order}" + assert sino.sum() == 1.0, f"Wrong total count for {order}" + + +# --------------------------------------------------------------------------- +# Empty events +# --------------------------------------------------------------------------- + +def test_empty_events_non_tof(xp: ModuleType, dev: str) -> None: + """Empty non-TOF event array returns a zero sinogram of the right shape.""" + scanner = _small_scanner(xp, dev) + desc = _lor_desc(scanner) + events_np = np.zeros((0, 4), dtype=np.int32) + sino = regular_polygon_events_to_sinogram(desc, events_np) + assert sino.shape == desc.spatial_sinogram_shape + assert np.all(sino == 0) + + +def test_empty_events_tof(xp: ModuleType, dev: str) -> None: + """Empty TOF event array returns a zero sinogram of the right shape.""" + num_tof_bins = 5 + scanner = _small_scanner(xp, dev) + desc = _lor_desc(scanner) + events_np = np.zeros((0, 5), dtype=np.int32) + sino = regular_polygon_events_to_sinogram(desc, events_np, num_tof_bins=num_tof_bins) + assert sino.shape == (*desc.spatial_sinogram_shape, num_tof_bins) + assert np.all(sino == 0) + + +# --------------------------------------------------------------------------- +# Error cases +# --------------------------------------------------------------------------- + +def test_error_bad_tof_bin_sign(xp: ModuleType, dev: str) -> None: + """tof_bin_sign values other than ±1 must raise ValueError.""" + scanner = _small_scanner(xp, dev) + desc = _lor_desc(scanner) + events = xp.asarray([[0, 0, 1, 0, 0]], dtype=xp.int32, device=dev) + with pytest.raises(ValueError, match="tof_bin_sign"): + regular_polygon_events_to_sinogram(desc, events, num_tof_bins=5, tof_bin_sign=0) + + +def test_error_missing_num_tof_bins(xp: ModuleType, dev: str) -> None: + """5-column events without num_tof_bins must raise ValueError.""" + scanner = _small_scanner(xp, dev) + desc = _lor_desc(scanner) + events = xp.asarray([[0, 0, 1, 0, 0]], dtype=xp.int32, device=dev) + with pytest.raises(ValueError, match="num_tof_bins"): + regular_polygon_events_to_sinogram(desc, events) + + +def test_error_spurious_num_tof_bins(xp: ModuleType, dev: str) -> None: + """4-column events with num_tof_bins specified must raise ValueError.""" + scanner = _small_scanner(xp, dev) + desc = _lor_desc(scanner) + events = xp.asarray([[0, 0, 1, 0]], dtype=xp.int32, device=dev) + with pytest.raises(ValueError, match="num_tof_bins"): + regular_polygon_events_to_sinogram(desc, events, num_tof_bins=5) + + +def test_error_wrong_col_count(xp: ModuleType, dev: str) -> None: + """Event arrays with column count other than 4 or 5 must raise ValueError.""" + scanner = _small_scanner(xp, dev) + desc = _lor_desc(scanner) + events = xp.asarray([[0, 0, 1]], dtype=xp.int32, device=dev) + with pytest.raises(ValueError, match="columns"): + regular_polygon_events_to_sinogram(desc, events) + + +def test_error_1d_events(xp: ModuleType, dev: str) -> None: + """1-D event array must raise ValueError.""" + scanner = _small_scanner(xp, dev) + desc = _lor_desc(scanner) + events = xp.asarray([0, 0, 1, 0], dtype=xp.int32, device=dev) + with pytest.raises(ValueError, match="2D"): + regular_polygon_events_to_sinogram(desc, events) From ff265eb405a41437ee13ced156748595439c1592 Mon Sep 17 00:00:00 2001 From: Georg Schramm Date: Thu, 28 May 2026 14:42:30 +0200 Subject: [PATCH 16/17] wip unlister --- .../05_run_unlister.py | 125 +++++------------- src/parallelproj/unlist.py | 8 +- 2 files changed, 36 insertions(+), 97 deletions(-) diff --git a/docs/examples/02_pet_sinogram_projections/05_run_unlister.py b/docs/examples/02_pet_sinogram_projections/05_run_unlister.py index 070e51e9..fb2004cc 100644 --- a/docs/examples/02_pet_sinogram_projections/05_run_unlister.py +++ b/docs/examples/02_pet_sinogram_projections/05_run_unlister.py @@ -38,6 +38,7 @@ from parallelproj import to_numpy_array from parallelproj.unlist import regular_polygon_events_to_sinogram from img import elliptic_cylinder_phantom +from vis import show_vol_cuts # %% from array_utils import suggest_array_backend_and_device @@ -111,10 +112,10 @@ print(f"\nNumber of events : {len(events)}") sino_unlisted = regular_polygon_events_to_sinogram(lor_desc, events) -y_span1_np = to_numpy_array(y_span1).astype(np.float32) +y_span1_np = to_numpy_array(y_span1) -print(f"Span-1 round-trip exact match : {np.array_equal(sino_unlisted, y_span1_np)}") -print(f"Max absolute difference : {float(np.max(np.abs(sino_unlisted - y_span1_np))):.0f}") +assert np.array_equal(sino_unlisted, y_span1_np), "Span-1 round-trip failed" +print("Span-1 round-trip: OK") # %% # Span-3 comparison @@ -129,13 +130,11 @@ op_compress = ppl.SinogramAxialCompressionOperator(lor_desc, target_span=3) span3_desc = op_compress.out_lor_descriptor -y_span3_np = to_numpy_array(op_compress(xp.astype(y_span1, xp.float32))) +y_span3_np = to_numpy_array(op_compress(xp.astype(y_span1, xp.float32))).astype(np.int32) sino_span3_unlisted = regular_polygon_events_to_sinogram(span3_desc, events) -print(f"\nSpan-3 sinogram shape (operator) : {y_span3_np.shape}") -print(f"Span-3 sinogram shape (unlisted) : {sino_span3_unlisted.shape}") -print(f"Span-3 exact match : {np.array_equal(sino_span3_unlisted, y_span3_np)}") -print(f"Max absolute difference : {float(np.max(np.abs(sino_span3_unlisted - y_span3_np))):.0f}") +assert np.array_equal(sino_span3_unlisted, y_span3_np), "Span-3 round-trip failed" +print("Span-3 round-trip: OK") # %% # TOF simulation @@ -173,100 +172,40 @@ sino_tof_unlisted = regular_polygon_events_to_sinogram( lor_desc, events_tof, num_tof_bins=num_tof_bins ) -y_tof_np = to_numpy_array(y_tof).astype(np.float32) +y_tof_np = to_numpy_array(y_tof) -print(f"TOF round-trip exact match : {np.array_equal(sino_tof_unlisted, y_tof_np)}") -print(f"Max absolute difference : {float(np.max(np.abs(sino_tof_unlisted - y_tof_np))):.0f}") +assert np.array_equal(sino_tof_unlisted, y_tof_np), "TOF round-trip failed" +print("TOF round-trip: OK") # %% # Visualisation # ------------- # -# Non-TOF: ground-truth span-1 sinogram, unlisted span-1 sinogram, difference -# and the span-3 comparison. - -v_ax = lor_desc.view_axis_num - -fig, axes = plt.subplots(2, 3, figsize=(13, 8)) - -# --- row 0: span-1 round-trip --- -vmax1 = float(np.max(y_span1_np)) - -ax = axes[0, 0] -ax.imshow(y_span1_np.sum(axis=v_ax), aspect="auto", vmin=0, vmax=vmax1) -ax.set_title("y_span1 (ground truth,\nsummed over views)") -ax.set_xlabel("planes") -ax.set_ylabel("radial") - -ax = axes[0, 1] -ax.imshow(sino_unlisted.sum(axis=v_ax), aspect="auto", vmin=0, vmax=vmax1) -ax.set_title("Unlisted span-1\n(summed over views)") -ax.set_xlabel("planes") - -ax = axes[0, 2] -diff1 = sino_unlisted - y_span1_np -im = ax.imshow(diff1.sum(axis=v_ax), aspect="auto", cmap="bwr", vmin=-1, vmax=1) -ax.set_title("Difference\n(must be all zeros)") -ax.set_xlabel("planes") -fig.colorbar(im, ax=ax) - -# --- row 1: span-3 comparison --- -v_ax3 = span3_desc.view_axis_num -vmax3 = float(np.max(y_span3_np)) - -ax = axes[1, 0] -ax.imshow(y_span3_np.sum(axis=v_ax3), aspect="auto", vmin=0, vmax=vmax3) -ax.set_title("y_span3 via\nSinogramAxialCompressionOperator") -ax.set_xlabel("planes") -ax.set_ylabel("radial") - -ax = axes[1, 1] -ax.imshow(sino_span3_unlisted.sum(axis=v_ax3), aspect="auto", vmin=0, vmax=vmax3) -ax.set_title("Unlisted directly\ninto span-3") -ax.set_xlabel("planes") - -ax = axes[1, 2] -diff3 = sino_span3_unlisted - y_span3_np -im3 = ax.imshow(diff3.sum(axis=v_ax3), aspect="auto", cmap="bwr", vmin=-1, vmax=1) -ax.set_title("Span-3 difference\n(must be all zeros)") -ax.set_xlabel("planes") -fig.colorbar(im3, ax=ax) - -fig.suptitle("Non-TOF sinogram round-trips (radial × planes, summed over views)") -fig.tight_layout() -fig.show() +# :func:`.show_vol_cuts` handles both 3-D (non-TOF) and 4-D (TOF) sinograms +# with the same call. For the 4-D case the TOF axis is moved to the front so +# the full-width leading-axis slider browses individual TOF bins. # %% -# TOF comparison: sinogram summed over TOF bins. - -fig2, axes2 = plt.subplots(1, 3, figsize=(13, 4)) - -y_tof_spatial = y_tof_np.sum(axis=-1) -sino_tof_spatial = sino_tof_unlisted.sum(axis=-1) -vmax_tof = float(np.max(y_tof_spatial)) - -ax = axes2[0] -ax.imshow(y_tof_spatial.sum(axis=v_ax), aspect="auto", vmin=0, vmax=vmax_tof) -ax.set_title("y_tof (ground truth,\nTOF-summed, view-summed)") -ax.set_xlabel("planes") -ax.set_ylabel("radial") - -ax = axes2[1] -ax.imshow(sino_tof_spatial.sum(axis=v_ax), aspect="auto", vmin=0, vmax=vmax_tof) -ax.set_title("Unlisted TOF span-1\n(TOF-summed, view-summed)") -ax.set_xlabel("planes") +_, _, _w1 = show_vol_cuts( + y_span1_np, + axis_labels=("rad", "view", "plane"), + fig_title="y_span1 (non-TOF span-1)", +) -ax = axes2[2] -diff_tof = sino_tof_unlisted - y_tof_np -im_tof = ax.imshow( - diff_tof.sum(axis=(-1, v_ax)), aspect="auto", cmap="bwr", vmin=-1, vmax=1 +# %% +_, _, _w2 = show_vol_cuts( + y_span3_np, + axis_labels=("rad", "view", "plane"), + fig_title="y_span3 (span-3 via SinogramAxialCompressionOperator)", ) -ax.set_title("TOF difference\n(must be all zeros)") -ax.set_xlabel("planes") -fig2.colorbar(im_tof, ax=ax) -fig2.suptitle( - "TOF sinogram round-trip (radial × planes, summed over views and TOF bins)" +# %% +# Transpose from (rad, view, plane, tof) → (tof, rad, view, plane) so the +# full-width slider browses TOF bins. +_, _, _w3 = show_vol_cuts( + y_tof_np.transpose(3, 0, 1, 2), + axis_labels=("tof", "rad", "view", "plane"), + fig_title="y_tof (TOF span-1)", ) -fig2.tight_layout() -fig2.show() + +plt.show() diff --git a/src/parallelproj/unlist.py b/src/parallelproj/unlist.py index 2d2cab8d..7fadb399 100644 --- a/src/parallelproj/unlist.py +++ b/src/parallelproj/unlist.py @@ -105,7 +105,7 @@ def regular_polygon_events_to_sinogram( sinogram : np.ndarray Histogram sinogram. Shape is ``spatial_sinogram_shape`` for non-TOF or ``(*spatial_sinogram_shape, num_tof_bins)`` for TOF. - Dtype is ``float32``. + Dtype is ``int32``. """ if tof_bin_sign not in (1, -1): raise ValueError("tof_bin_sign must be +1 or -1") @@ -138,8 +138,8 @@ def regular_polygon_events_to_sinogram( if n_events == 0: if tof_mode: - return np.zeros((*shape_spatial, num_tof_bins), dtype=np.float32) - return np.zeros(shape_spatial, dtype=np.float32) + return np.zeros((*shape_spatial, num_tof_bins), dtype=np.int32) + return np.zeros(shape_spatial, dtype=np.int32) d1 = events_np[:, 0] r1 = events_np[:, 1] @@ -224,5 +224,5 @@ def regular_polygon_events_to_sinogram( output_shape = shape_spatial flat_sino_valid = flat_sino[valid] - sino_flat = np.bincount(flat_sino_valid, minlength=total_bins).astype(np.float32) + sino_flat = np.bincount(flat_sino_valid, minlength=total_bins).astype(np.int32) return sino_flat.reshape(output_shape) From caada960ca900f68930f847eb1d23be977ed3a17 Mon Sep 17 00:00:00 2001 From: Georg Schramm Date: Thu, 28 May 2026 17:43:19 +0200 Subject: [PATCH 17/17] use native bincount to support GPU unlisting --- .../05_run_unlister.py | 55 ++++++---- src/parallelproj/unlist.py | 103 ++++++++++++------ tests/test_unlist.py | 37 ++++++- 3 files changed, 137 insertions(+), 58 deletions(-) diff --git a/docs/examples/02_pet_sinogram_projections/05_run_unlister.py b/docs/examples/02_pet_sinogram_projections/05_run_unlister.py index fb2004cc..b9bb7e00 100644 --- a/docs/examples/02_pet_sinogram_projections/05_run_unlister.py +++ b/docs/examples/02_pet_sinogram_projections/05_run_unlister.py @@ -78,8 +78,12 @@ lor_desc, img_shape=img_shape, voxel_size=voxel_size ) -print(f"Scanner : {scanner.num_lor_endpoints_per_ring} crystals/ring × {scanner.num_rings} rings") -print(f"Sinogram : {lor_desc.spatial_sinogram_shape} (num_rad × num_views × num_planes)") +print( + f"Scanner : {scanner.num_lor_endpoints_per_ring} crystals/ring × {scanner.num_rings} rings" +) +print( + f"Sinogram : {lor_desc.spatial_sinogram_shape} (num_rad × num_views × num_planes)" +) # %% # Simulate non-TOF PET data @@ -98,23 +102,30 @@ ) total_counts = int(xp.sum(y_span1)) -print(f"\nNon-TOF span-1 sinogram : shape={tuple(y_span1.shape)}, total counts={total_counts}") +print( + f"\nNon-TOF span-1 sinogram : shape={tuple(y_span1.shape)}, total counts={total_counts}" +) # %% # Non-TOF round-trip (span-1) # --------------------------- # -# Convert the integer sinogram to crystal-index events and unlist back into -# span-1. With ``radial_trim=10`` the scanner has no self-pair bins, so -# every count is round-tripped exactly. +# Convert the integer sinogram to crystal-index events, move them to the +# active device, then unlist. The returned sinogram lives on the same +# device as the events — no device transfer needed for the comparison. +# +# With ``radial_trim=10`` the scanner has no self-pair bins, so every count +# is round-tripped exactly. -events = proj.convert_sinogram_to_crystal_index_events(y_span1, shuffle=True) -print(f"\nNumber of events : {len(events)}") +events = xp.asarray( + proj.convert_sinogram_to_crystal_index_events(y_span1, shuffle=True), + device=dev, +) +print(f"\nNumber of events : {events.shape[0]}") sino_unlisted = regular_polygon_events_to_sinogram(lor_desc, events) -y_span1_np = to_numpy_array(y_span1) -assert np.array_equal(sino_unlisted, y_span1_np), "Span-1 round-trip failed" +assert bool(xp.all(sino_unlisted == y_span1)), "Span-1 round-trip failed" print("Span-1 round-trip: OK") # %% @@ -130,10 +141,10 @@ op_compress = ppl.SinogramAxialCompressionOperator(lor_desc, target_span=3) span3_desc = op_compress.out_lor_descriptor -y_span3_np = to_numpy_array(op_compress(xp.astype(y_span1, xp.float32))).astype(np.int32) +y_span3 = xp.astype(op_compress(xp.astype(y_span1, xp.float32)), xp.int32) sino_span3_unlisted = regular_polygon_events_to_sinogram(span3_desc, events) -assert np.array_equal(sino_span3_unlisted, y_span3_np), "Span-3 round-trip failed" +assert bool(xp.all(sino_span3_unlisted == y_span3)), "Span-3 round-trip failed" print("Span-3 round-trip: OK") # %% @@ -156,7 +167,9 @@ dtype=xp.int32, ) -print(f"\nTOF sinogram shape : {tuple(y_tof.shape)}, total counts={int(xp.sum(y_tof))}") +print( + f"\nTOF sinogram shape : {tuple(y_tof.shape)}, total counts={int(xp.sum(y_tof))}" +) # %% # TOF round-trip (span-1) @@ -166,15 +179,17 @@ # dimension automatically (4-D input) and returns ``(d1, r1, d2, r2, tof_bin)`` # rows, where bin 0 is the bin closest to d1 (the xstart crystal). -events_tof = proj.convert_sinogram_to_crystal_index_events(y_tof, shuffle=True) -print(f"\nNumber of TOF events : {len(events_tof)}") +events_tof = xp.asarray( + proj.convert_sinogram_to_crystal_index_events(y_tof, shuffle=True), + device=dev, +) +print(f"\nNumber of TOF events : {events_tof.shape[0]}") sino_tof_unlisted = regular_polygon_events_to_sinogram( lor_desc, events_tof, num_tof_bins=num_tof_bins ) -y_tof_np = to_numpy_array(y_tof) -assert np.array_equal(sino_tof_unlisted, y_tof_np), "TOF round-trip failed" +assert bool(xp.all(sino_tof_unlisted == y_tof)), "TOF round-trip failed" print("TOF round-trip: OK") # %% @@ -187,14 +202,14 @@ # %% _, _, _w1 = show_vol_cuts( - y_span1_np, + to_numpy_array(y_span1), axis_labels=("rad", "view", "plane"), fig_title="y_span1 (non-TOF span-1)", ) # %% _, _, _w2 = show_vol_cuts( - y_span3_np, + to_numpy_array(y_span3), axis_labels=("rad", "view", "plane"), fig_title="y_span3 (span-3 via SinogramAxialCompressionOperator)", ) @@ -203,7 +218,7 @@ # Transpose from (rad, view, plane, tof) → (tof, rad, view, plane) so the # full-width slider browses TOF bins. _, _, _w3 = show_vol_cuts( - y_tof_np.transpose(3, 0, 1, 2), + to_numpy_array(y_tof).transpose(3, 0, 1, 2), axis_labels=("tof", "rad", "view", "plane"), fig_title="y_tof (TOF span-1)", ) diff --git a/src/parallelproj/unlist.py b/src/parallelproj/unlist.py index 7fadb399..67c5edc4 100644 --- a/src/parallelproj/unlist.py +++ b/src/parallelproj/unlist.py @@ -5,8 +5,9 @@ from typing import TYPE_CHECKING, Any import numpy as np +import array_api_compat -from ._backend import to_numpy_array +from ._backend import Array, to_numpy_array from .pet_lors import RegularPolygonPETLORDescriptor if TYPE_CHECKING: @@ -66,7 +67,7 @@ def regular_polygon_events_to_sinogram( events: Any, num_tof_bins: int | None = None, tof_bin_sign: int = 1, -) -> np.ndarray: +) -> Array: """Histogram listmode events into a sinogram. Parameters @@ -102,20 +103,38 @@ def regular_polygon_events_to_sinogram( Returns ------- - sinogram : np.ndarray - Histogram sinogram. Shape is ``spatial_sinogram_shape`` for non-TOF - or ``(*spatial_sinogram_shape, num_tof_bins)`` for TOF. + sinogram : Array + Histogram sinogram on the same device as ``events``. + Shape is ``spatial_sinogram_shape`` for non-TOF or + ``(*spatial_sinogram_shape, num_tof_bins)`` for TOF. Dtype is ``int32``. + + Raises + ------ + NotImplementedError + If the array backend of ``events`` does not provide ``bincount`` + (e.g. ``array_api_strict``). Supported backends are + **numpy**, **cupy**, and **torch**. """ + xp = array_api_compat.get_namespace(events) + dev = array_api_compat.device(events) + + if not hasattr(xp, "bincount"): + raise NotImplementedError( + "regular_polygon_events_to_sinogram requires a backend with " + "bincount support (numpy, cupy, torch); " + f"got {xp.__name__!r}" + ) + if tof_bin_sign not in (1, -1): raise ValueError("tof_bin_sign must be +1 or -1") - events_np = np.asarray(to_numpy_array(events), dtype=np.int32) + events_xp = xp.asarray(events, dtype=xp.int32, device=dev) - if events_np.ndim != 2: + if events_xp.ndim != 2: raise ValueError("events must be a 2D array") - n_events, n_cols = events_np.shape + n_events, n_cols = events_xp.shape if n_cols == 4: if num_tof_bins is not None: @@ -138,18 +157,22 @@ def regular_polygon_events_to_sinogram( if n_events == 0: if tof_mode: - return np.zeros((*shape_spatial, num_tof_bins), dtype=np.int32) - return np.zeros(shape_spatial, dtype=np.int32) + return xp.zeros((*shape_spatial, num_tof_bins), dtype=xp.int32, device=dev) + return xp.zeros(shape_spatial, dtype=xp.int32, device=dev) - d1 = events_np[:, 0] - r1 = events_np[:, 1] - d2 = events_np[:, 2] - r2 = events_np[:, 3] + d1 = events_xp[:, 0] + r1 = events_xp[:, 1] + d2 = events_xp[:, 2] + r2 = events_xp[:, 3] if tof_mode: - tof_raw = events_np[:, 4] + tof_raw = events_xp[:, 4] - inring_lut, inring_tof_sign_lut = _build_inring_luts(lor_descriptor) - ring_pair_table = lor_descriptor.michelogram.plane_for_ring_pair_table + inring_lut_np, inring_tof_sign_lut_np = _build_inring_luts(lor_descriptor) + inring_lut = xp.asarray(inring_lut_np, device=dev) + inring_tof_sign_lut = xp.asarray(inring_tof_sign_lut_np, device=dev) + ring_pair_table = xp.asarray( + lor_descriptor.michelogram.plane_for_ring_pair_table, device=dev + ) n_crystals = lor_descriptor.scanner.num_lor_endpoints_per_ring num_rings = lor_descriptor.scanner.num_rings @@ -168,37 +191,41 @@ def regular_polygon_events_to_sinogram( ) # safe fallback indices for out-of-bounds events (they will be masked out) - d1s = np.where(valid, d1, 0) - d2s = np.where(valid, d2, 0) + d1s = xp.where(valid, d1, xp.zeros(1, dtype=xp.int32, device=dev)) + d2s = xp.where(valid, d2, xp.zeros(1, dtype=xp.int32, device=dev)) inring_flat = inring_lut[d1s, d2s] tof_sign_vals = inring_tof_sign_lut[d1s, d2s] # the crystal pair must form a valid LOR - valid &= inring_flat >= 0 + valid = valid & (inring_flat >= 0) # canonical ring ordering: +1 means d1 is xstart, so r1 is the start ring. # Use safe fallback (0) for out-of-range ring indices so the table lookup # never receives negative or OOB values (numpy wraps negative indices). - r1_safe = np.where(valid, r1, 0) - r2_safe = np.where(valid, r2, 0) + r1_safe = xp.where(valid, r1, xp.zeros(1, dtype=xp.int32, device=dev)) + r2_safe = xp.where(valid, r2, xp.zeros(1, dtype=xp.int32, device=dev)) is_d1_start = tof_sign_vals == 1 - r_start = np.where(is_d1_start, r1_safe, r2_safe) - r_end = np.where(is_d1_start, r2_safe, r1_safe) + r_start = xp.where(is_d1_start, r1_safe, r2_safe) + r_end = xp.where(is_d1_start, r2_safe, r1_safe) # plane lookup plane_idx = ring_pair_table[r_start, r_end] - valid &= plane_idx >= 0 + valid = valid & (plane_idx >= 0) if tof_mode: - valid &= (tof_raw >= 0) & (tof_raw < num_tof_bins) - tof_raw_safe = np.where(valid, tof_raw, 0) + valid = valid & (tof_raw >= 0) & (tof_raw < num_tof_bins) + tof_raw_safe = xp.where(valid, tof_raw, xp.zeros(1, dtype=xp.int32, device=dev)) # flip the bin when the canonical direction does not match tof_bin_sign flip = (tof_bin_sign * tof_sign_vals) == -1 - sinogram_tof = np.where(flip, num_tof_bins - 1 - tof_raw_safe, tof_raw_safe) + sinogram_tof = xp.where( + flip, + xp.asarray(num_tof_bins - 1, dtype=xp.int32, device=dev) - tof_raw_safe, + tof_raw_safe, + ) # decompose the in-ring flat index into view and radial indices - safe_flat = np.where(valid, inring_flat, 0) + safe_flat = xp.where(valid, inring_flat, xp.zeros(1, dtype=xp.int32, device=dev)) view_idx = safe_flat // num_rad rad_idx = safe_flat % num_rad @@ -208,15 +235,15 @@ def regular_polygon_events_to_sinogram( r_ax = lor_descriptor.radial_axis_num strides = [int(np.prod(shape_spatial[i + 1 :])) for i in range(3)] - safe_plane = np.where(valid, plane_idx, 0) + safe_plane = xp.where(valid, plane_idx, xp.zeros(1, dtype=xp.int32, device=dev)) flat_sino = ( - rad_idx.astype(np.int64) * strides[r_ax] - + view_idx.astype(np.int64) * strides[v_ax] - + safe_plane.astype(np.int64) * strides[p_ax] + xp.astype(rad_idx, xp.int64) * strides[r_ax] + + xp.astype(view_idx, xp.int64) * strides[v_ax] + + xp.astype(safe_plane, xp.int64) * strides[p_ax] ) if tof_mode: - flat_sino = flat_sino * num_tof_bins + sinogram_tof.astype(np.int64) + flat_sino = flat_sino * num_tof_bins + xp.astype(sinogram_tof, xp.int64) total_bins = int(np.prod(shape_spatial)) * num_tof_bins output_shape = (*shape_spatial, num_tof_bins) else: @@ -224,5 +251,9 @@ def regular_polygon_events_to_sinogram( output_shape = shape_spatial flat_sino_valid = flat_sino[valid] - sino_flat = np.bincount(flat_sino_valid, minlength=total_bins).astype(np.int32) - return sino_flat.reshape(output_shape) + if flat_sino_valid.shape[0] == 0: + return xp.zeros(output_shape, dtype=xp.int32, device=dev) + sino_flat = xp.astype( + xp.bincount(flat_sino_valid, minlength=total_bins), xp.int32 + ) + return xp.reshape(sino_flat, output_shape) diff --git a/tests/test_unlist.py b/tests/test_unlist.py index 0ce14210..29722169 100644 --- a/tests/test_unlist.py +++ b/tests/test_unlist.py @@ -8,9 +8,25 @@ import parallelproj.pet_scanners as pps import parallelproj.pet_lors as ppl from parallelproj import to_numpy_array -from parallelproj.unlist import _build_inring_luts, regular_polygon_events_to_sinogram +from parallelproj.unlist import _build_inring_luts, regular_polygon_events_to_sinogram as _sinogram_native -from .config import pytestmark # noqa: F401 (pytest picks this up as a module marker) +def regular_polygon_events_to_sinogram(*args, **kwargs): + """Thin wrapper that always returns a numpy array for test assertions.""" + return to_numpy_array(_sinogram_native(*args, **kwargs)) + +from .config import xp_dev_list + +# Only run against backends that provide bincount (numpy, torch, cupy). +# array_api_strict is intentionally excluded; a dedicated test below checks +# that NotImplementedError is raised for backends without bincount. +pytestmark = pytest.mark.parametrize( + "xp,dev", [(xp, dev) for xp, dev in xp_dev_list if hasattr(xp, "bincount")] +) + +# Backends that do NOT have bincount — used by the NotImplementedError test. +_no_bincount_xp_dev = [ + (xp, dev) for xp, dev in xp_dev_list if not hasattr(xp, "bincount") +] # --------------------------------------------------------------------------- @@ -521,3 +537,20 @@ def test_error_1d_events(xp: ModuleType, dev: str) -> None: events = xp.asarray([0, 0, 1, 0], dtype=xp.int32, device=dev) with pytest.raises(ValueError, match="2D"): regular_polygon_events_to_sinogram(desc, events) + + +@pytest.mark.skipif(not _no_bincount_xp_dev, reason="no non-bincount backends available") +def test_error_no_bincount_backend(xp: ModuleType, dev: str) -> None: # noqa: ARG001 + """Backends without bincount must raise NotImplementedError. + + The module-level (xp, dev) are from a bincount backend and are ignored; + the test unconditionally exercises array_api_strict. + """ + import array_api_strict as xp_strict + + import numpy as np_plain + scanner = _small_scanner(np_plain, "cpu") + desc = _lor_desc(scanner) + events = xp_strict.asarray([[0, 0, 1, 0]]) + with pytest.raises(NotImplementedError, match="bincount"): + regular_polygon_events_to_sinogram(desc, events)