1+ import argparse
2+ import torch
3+ import nibabel as nib
4+ from pathlib import Path
5+
6+ from monai .inferers import sliding_window_inference
7+ from monai .transforms import (
8+ Compose ,
9+ LoadImaged ,
10+ EnsureChannelFirstd ,
11+ NormalizeIntensityd ,
12+ ConcatItemsd ,
13+ EnsureTyped ,
14+ )
15+
16+ from models .unet import build_model
17+
18+
19+ # -----------------------------
20+ # ARGUMENTS
21+ # -----------------------------
22+
23+ parser = argparse .ArgumentParser ()
24+ parser .add_argument ("--input" , required = True , help = "Path to nacpet.nii.gz" )
25+ parser .add_argument ("--output" , required = True , help = "Path to save pseudo CT" )
26+ args = parser .parse_args ()
27+
28+ INPUT_PATH = Path (args .input )
29+ OUTPUT_PATH = Path (args .output )
30+
31+ OUTPUT_PATH .parent .mkdir (parents = True , exist_ok = True )
32+
33+
34+ # -----------------------------
35+ # CONFIG
36+ # -----------------------------
37+
38+ MODEL_PATH = "weights/best_model.pth"
39+ PATCH_SIZE = (192 ,192 ,192 )
40+ SW_BATCH = 2
41+ OVERLAP = 0.75
42+
43+
44+ # -----------------------------
45+ # TRANSFORMS (PET only)
46+ # -----------------------------
47+
48+ transforms = Compose ([
49+ LoadImaged (keys = ["pet" ]),
50+ EnsureChannelFirstd (keys = ["pet" ]),
51+ NormalizeIntensityd (keys = ["pet" ], nonzero = True , channel_wise = True ),
52+ ConcatItemsd (keys = ["pet" ], name = "input" ),
53+ EnsureTyped (keys = ["input" ])
54+ ])
55+
56+
57+ # -----------------------------
58+ # MODEL
59+ # -----------------------------
60+
61+ device = "cuda" if torch .cuda .is_available () else "cpu"
62+ print ("Using device:" , device )
63+
64+ model = build_model ().to (device )
65+ model .load_state_dict (torch .load (MODEL_PATH , map_location = device ))
66+ model .eval ()
67+
68+
69+ # -----------------------------
70+ # INFERENCE
71+ # -----------------------------
72+
73+ data = {"pet" : INPUT_PATH }
74+
75+ data = transforms (data )
76+
77+ x = data ["input" ].unsqueeze (0 ).to (device )
78+
79+ with torch .no_grad ():
80+
81+ pred = sliding_window_inference (
82+ x ,
83+ PATCH_SIZE ,
84+ SW_BATCH ,
85+ model ,
86+ overlap = OVERLAP ,
87+ mode = "gaussian" ,
88+ )
89+
90+ pred = pred .cpu ().numpy ()[0 ,0 ]
91+
92+ # convert back to HU
93+ pred = pred * 3000 - 1000
94+
95+ ref = nib .load (str (INPUT_PATH ))
96+
97+ nib .save (
98+ nib .Nifti1Image (pred , ref .affine , ref .header ),
99+ str (OUTPUT_PATH )
100+ )
101+
102+ print ("Saved:" , OUTPUT_PATH )
0 commit comments