DiffCAS: Inference-time CT-free Diffusion Model for Physics-aware Multi-slice Attenuation Correction in Cardiac SPECT
DiffCAS generates attenuation-corrected (AC) cardiac SPECT images directly from non-attenuation-corrected (NAC) inputs without requiring CT at inference time. It is built on a Brownian Bridge Diffusion Model (BBDM) and trained with a two-stage Teacher–Student framework:
- DiffCas-Teacher: a CT-conditioned diffusion model that learns physics-informed attenuation priors from paired NAC/CT/AC training data.
- DiffCas-Student: a CT-free student distilled from the Teacher. It is trained with a reconstruction loss plus a cosine-similarity feature-alignment loss, so that at inference time it approximates the Teacher's internal representations without needing CT.
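For intuition about the BBDM backbone, the Brownian bridge forward process from the original BBDM formulation interpolates between the target image `x_0` (here AC) and the source image `y` (here NAC): `x_t = (1 - m_t) * x_0 + m_t * y + sqrt(delta_t) * eps`, with `delta_t = 2 * s * (m_t - m_t**2)`. The NumPy sketch below is a minimal illustration that assumes a linear schedule `m_t = t / T` and variance scale `s = 1`. The function `bridge_sample` is hypothetical and not part of this repository, whose configured schedule may differ (for example, by clipping the endpoints).

```python
import numpy as np


def bridge_sample(x0, y, t, T=500, s=1.0, rng=None):
    """Draw x_t from the Brownian bridge q(x_t | x_0, y).

    x0: target image (AC), y: source image (NAC), same shape, in [-1, 1].
    Assumes a linear schedule m_t = t / T and delta_t = 2 * s * (m_t - m_t**2).
    """
    rng = np.random.default_rng() if rng is None else rng
    m_t = t / T
    # Variance is zero at both ends (t = 0 and t = T) and peaks at t = T / 2.
    delta_t = 2.0 * s * (m_t - m_t ** 2)
    noise = rng.standard_normal(x0.shape)
    return (1.0 - m_t) * x0 + m_t * y + np.sqrt(delta_t) * noise
```

Unlike a standard DDPM, the chain ends at the NAC image rather than at pure noise: `bridge_sample(x0, y, 0)` returns `x0` and `bridge_sample(x0, y, T)` returns `y`, so sampling starts from the NAC input.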
```
DiffCAS/
├── DiffCas-Teacher/              # CT-conditioned teacher model
│   ├── main.py                   # Entry point (train / test)
│   ├── configs/
│   │   └── DiffCas-Teacher.yaml
│   ├── datasets/                 # Dataset loaders
│   ├── runners/
│   │   └── BaseRunner.py         # Training/validation/test loop
│   ├── evaluation/               # FID, LPIPS, diversity metrics
│   ├── calculate_metrics.py      # SSIM, PSNR, RMSE
│   ├── preprocess_and_evaluation.py
│   ├── Register.py
│   ├── utils.py
│   └── environment.yml
│
└── DiffCas-Student/              # CT-free student model
    ├── main.py                   # Entry point (train / test)
    ├── configs/
    │   └── Template-DiffCas-S.yaml
    ├── datasets/                 # Dataset loaders
    ├── runners/
    │   └── BaseRunner.py         # Training loop with distillation losses
    ├── evaluation/               # FID, LPIPS, diversity metrics
    ├── model/                    # UNet and registration (Reg / Transformer_2D)
    ├── shell/
    │   └── Template-shell.sh     # Example training / test commands
    ├── calculate_metrics.py      # SSIM, PSNR, RMSE
    ├── preprocess_and_evaluation.py
    ├── Register.py
    ├── utils.py
    └── environment.yml
```
- Brownian Bridge Diffusion Model (BBDM): models the NAC→AC translation as a stochastic bridge between the two image distributions, rather than generating images from pure Gaussian noise.
- Physics-aware Teacher: conditions on CT images through the `CT_transformerUNet` conditioning key, encoding patient-specific attenuation maps.
- Multi-slice contextual learning: each sample spans the target slice ±8 neighboring slices (18 input/output channels), capturing volumetric context.
- Knowledge distillation: the Student is trained with a combined loss of pixel-level L1 reconstruction and cosine-similarity alignment with the Teacher's feature maps at three UNet stages (input block, middle block, output block).
- Spatial registration: a learnable deformable registration network (`Reg` / `Transformer_2D`) is co-trained with the Student to handle slice alignment.
- Exponential Moving Average (EMA): EMA weights (decay 0.995) are used during validation and inference for more stable outputs.
- Multi-GPU training: native PyTorch DDP with the NCCL backend.
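The Student objective described above can be sketched in NumPy as an L1 reconstruction term plus a `1 - cosine similarity` alignment term averaged over the matched UNet stages. This is a minimal sketch, not the repository's implementation (that lives in `runners/BaseRunner.py`): the weight `lambda_align`, the per-sample flattening of feature maps, and both function names are assumptions. In a PyTorch loop the same alignment term could be built from `torch.nn.functional.cosine_similarity` on flattened features, with the frozen Teacher evaluated under `torch.no_grad()`.

```python
import numpy as np


def cosine_alignment_loss(f_student, f_teacher, eps=1e-8):
    """Mean (1 - cosine similarity) between per-sample flattened feature maps."""
    s = f_student.reshape(f_student.shape[0], -1)
    t = f_teacher.reshape(f_teacher.shape[0], -1)
    cos = (s * t).sum(axis=1) / (
        np.linalg.norm(s, axis=1) * np.linalg.norm(t, axis=1) + eps
    )
    return float(np.mean(1.0 - cos))


def student_loss(pred, target, student_feats, teacher_feats, lambda_align=1.0):
    """Hypothetical combined loss: L1 reconstruction + feature alignment.

    student_feats / teacher_feats: sequences of arrays, one per UNet stage
    (e.g. input block, middle block, output block). Teacher features are
    assumed to come from a frozen Teacher and act as fixed targets.
    """
    recon = float(np.mean(np.abs(pred - target)))
    align = np.mean(
        [cosine_alignment_loss(s, t) for s, t in zip(student_feats, teacher_feats)]
    )
    return recon + lambda_align * float(align)
```

Cosine similarity only penalizes differences in feature direction, not magnitude, which leaves the Student free to use a different activation scale than the Teacher.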
Both modules share the same conda environment (`BBDM`):

```bash
conda env create -f DiffCas-Teacher/environment.yml
conda activate BBDM
```

Key dependencies:
| Package | Version |
|---|---|
| Python | 3.9 |
| PyTorch | 1.12.1+cu113 |
| torchvision | 0.13.1+cu113 |
| pytorch-lightning | 1.9.3 |
| lpips | 0.1.4 |
| einops | 0.6.0 |
| transformers | 4.26.1 |
| scikit-image | 0.20.0 |
Both Teacher and Student expect a dataset in custom_aligned format.
Set dataset_path in the config YAML to point to your data directory.
The default image resolution is 64×64 with 3 channels, normalized
to [-1, 1] (to_normal: True).
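With `to_normal: True`, intensities end up in [-1, 1]. A minimal sketch of that mapping and its inverse, assuming images are first scaled to [0, 1]; the per-volume min-max step and all helper names are hypothetical, and the repository's loaders may scale SPECT counts differently.

```python
import numpy as np


def minmax_scale(volume, eps=1e-8):
    """Hypothetical per-volume min-max scaling of raw counts to [0, 1]."""
    vmin, vmax = float(volume.min()), float(volume.max())
    return (volume - vmin) / (vmax - vmin + eps)


def to_normal(img01):
    """Map an image in [0, 1] to [-1, 1]."""
    return (img01 - 0.5) * 2.0


def from_normal(img):
    """Inverse mapping from [-1, 1] back to [0, 1], clipped for safety."""
    return np.clip((img + 1.0) / 2.0, 0.0, 1.0)
```

Predictions produced in [-1, 1] would be passed through the inverse mapping before being saved or compared with ground truth in [0, 1].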
Train the Teacher on a single GPU:

```bash
cd DiffCas-Teacher
python main.py \
    --config configs/DiffCas-Teacher.yaml \
    --train \
    --save_top \
    --gpu_ids 0
```

Multi-GPU (e.g., 4 GPUs):
```bash
python main.py \
    --config configs/DiffCas-Teacher.yaml \
    --train \
    --save_top \
    --gpu_ids 0,1,2,3
```

Resume from a checkpoint:
```bash
python main.py \
    --config configs/DiffCas-Teacher.yaml \
    --train \
    --resume_model path/to/last_model.pth \
    --resume_optim path/to/last_optim_sche.pth \
    --gpu_ids 0
```

To train the Student, edit `configs/Template-DiffCas-S.yaml` and set `teacher_ckpt` to the Teacher checkpoint path, then run:
```bash
cd DiffCas-Student
python main.py \
    --config configs/Template-DiffCas-S.yaml \
    --train \
    --sample_at_start \
    --save_top \
    --gpu_ids 0 \
    --resume_model path/to/model_ckpt \
    --resume_optim path/to/optim_ckpt
```

Training hyperparameters:

| Hyperparameter | Teacher | Student |
|---|---|---|
| Epochs | 200 | 200 |
| Max steps | 400 000 | 400 000 |
| Batch size (train) | 4 | 8 |
| Optimizer | Adam | Adam |
| Learning rate | 1e-4 | 1e-4 |
| LR scheduler | ReduceLROnPlateau | ReduceLROnPlateau |
| Diffusion timesteps | 500 | 500 |
| Sampling steps | 100 | 100 |
| EMA decay | 0.995 | 0.995 |
| Grad accumulation | 2 | 2 |
| Neighbor slices | 8 | 8 |
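The EMA decay of 0.995 means each update keeps 99.5% of the shadow weights and blends in 0.5% of the current weights: `theta_ema = d * theta_ema + (1 - d) * theta`. The framework-agnostic sketch below works over a dict of NumPy parameters; the `EMA` class is hypothetical, and the repository's version operates on PyTorch modules and may add details such as a warm-up step or an update interval.

```python
import numpy as np


class EMA:
    """Exponential moving average of model parameters (shadow copy)."""

    def __init__(self, params, decay=0.995):
        self.decay = decay
        # Shadow weights start as a copy of the initial parameters.
        self.shadow = {k: v.copy() for k, v in params.items()}

    def update(self, params):
        d = self.decay
        for k, v in params.items():
            # shadow <- d * shadow + (1 - d) * current
            self.shadow[k] = d * self.shadow[k] + (1.0 - d) * v
```

With `d = 0.995` the average has an effective window of roughly `1 / (1 - d) = 200` updates. Validation and inference then read `ema.shadow` in place of the raw weights.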
To run inference with a trained checkpoint, starting from the repository root:

```bash
# Teacher (requires CT at test time)
cd DiffCas-Teacher
python main.py \
    --config configs/DiffCas-Teacher.yaml \
    --sample_to_eval \
    --gpu_ids 0 \
    --resume_model path/to/model_ckpt
```

```bash
# Student (CT-free inference)
cd DiffCas-Student
python main.py \
    --config configs/Template-DiffCas-S.yaml \
    --sample_to_eval \
    --gpu_ids 0 \
    --resume_model path/to/model_ckpt
```

Results are written to the `results/` directory, under a path derived from `dataset_name` and `model_name` in the config.
Compute image quality metrics on saved predictions:
```bash
# SSIM / PSNR / RMSE (computed inline during sampling via calculate_metrics.py)

# LPIPS
python preprocess_and_evaluation.py \
    --func_name LPIPS \
    --source_dir path/to/predictions \
    --target_dir path/to/ground_truth \
    --num_samples 5

# Diversity across stochastic samples
python preprocess_and_evaluation.py \
    --func_name diversity \
    --source_dir path/to/predictions \
    --num_samples 5
```

Supported metrics: SSIM, PSNR, RMSE, LPIPS, Diversity.
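For reference, RMSE and PSNR can be computed as below. This is a sketch rather than `calculate_metrics.py` itself: it assumes both images share the same scale with a peak value of 1.0 (the repository may use a different data range). SSIM is available in scikit-image, already a dependency, as `skimage.metrics.structural_similarity`.

```python
import numpy as np


def rmse(pred, target):
    """Root-mean-square error between two arrays of equal shape."""
    diff = pred.astype(np.float64) - target
    return float(np.sqrt(np.mean(diff ** 2)))


def psnr(pred, target, data_range=1.0):
    """Peak signal-to-noise ratio in dB: 20 * log10(data_range / RMSE)."""
    err = rmse(pred, target)
    if err == 0.0:
        return float("inf")  # identical images
    return float(20.0 * np.log10(data_range / err))
```

For example, a uniform error of 0.1 on a [0, 1] image gives an RMSE of 0.1 and a PSNR of 20 dB.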
If you find this work useful, please cite:
```bibtex
@article{vu2026diffcas,
  title={DiffCAS: Inference-time CT-free Diffusion Model for Physics-aware Multi-slice Attenuation Correction in Cardiac SPECT},
  author={Vu, Hoang Minh and Pham, Trung Kien and Nguyen, Thi Ha Chi and
          Nguyen, Hai Dang and Nguyen, Dac Thai and Son, Mai Hong and
          Nguyen, Thanh Trung and Nguyen, Trung Thanh and Nguyen, Phi Le},
  year={2026}
}
```

For questions or collaborations, please open an issue or contact the corresponding authors.