Skip to content

Commit 7748ef7

Browse files
committed
added docker
1 parent 19ba100 commit 7748ef7

3 files changed

Lines changed: 124 additions & 0 deletions

File tree

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
FROM pytorch/pytorch:2.1.0-cuda11.8-cudnn8-runtime
2+
3+
WORKDIR /app
4+
5+
# system deps
6+
RUN apt-get update && apt-get install -y \
7+
git \
8+
wget \
9+
&& rm -rf /var/lib/apt/lists/*
10+
11+
# python deps
12+
COPY requirements.txt .
13+
RUN pip install --no-cache-dir -r requirements.txt
14+
15+
# copy code
16+
COPY . .
17+
18+
# run inference
19+
ENTRYPOINT ["python", "inference.py"]
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import argparse
2+
import torch
3+
import nibabel as nib
4+
from pathlib import Path
5+
6+
from monai.inferers import sliding_window_inference
7+
from monai.transforms import (
8+
Compose,
9+
LoadImaged,
10+
EnsureChannelFirstd,
11+
NormalizeIntensityd,
12+
ConcatItemsd,
13+
EnsureTyped,
14+
)
15+
16+
from models.unet import build_model
17+
18+
19+
# -----------------------------
20+
# ARGUMENTS
21+
# -----------------------------
22+
23+
parser = argparse.ArgumentParser()
24+
parser.add_argument("--input", required=True, help="Path to nacpet.nii.gz")
25+
parser.add_argument("--output", required=True, help="Path to save pseudo CT")
26+
args = parser.parse_args()
27+
28+
INPUT_PATH = Path(args.input)
29+
OUTPUT_PATH = Path(args.output)
30+
31+
OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
32+
33+
34+
# -----------------------------
35+
# CONFIG
36+
# -----------------------------
37+
38+
MODEL_PATH = "weights/best_model.pth"
39+
PATCH_SIZE = (192,192,192)
40+
SW_BATCH = 2
41+
OVERLAP = 0.75
42+
43+
44+
# -----------------------------
45+
# TRANSFORMS (PET only)
46+
# -----------------------------
47+
48+
transforms = Compose([
49+
LoadImaged(keys=["pet"]),
50+
EnsureChannelFirstd(keys=["pet"]),
51+
NormalizeIntensityd(keys=["pet"], nonzero=True, channel_wise=True),
52+
ConcatItemsd(keys=["pet"], name="input"),
53+
EnsureTyped(keys=["input"])
54+
])
55+
56+
57+
# -----------------------------
58+
# MODEL
59+
# -----------------------------
60+
61+
device = "cuda" if torch.cuda.is_available() else "cpu"
62+
print("Using device:", device)
63+
64+
model = build_model().to(device)
65+
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
66+
model.eval()
67+
68+
69+
# -----------------------------
70+
# INFERENCE
71+
# -----------------------------
72+
73+
data = {"pet": INPUT_PATH}
74+
75+
data = transforms(data)
76+
77+
x = data["input"].unsqueeze(0).to(device)
78+
79+
with torch.no_grad():
80+
81+
pred = sliding_window_inference(
82+
x,
83+
PATCH_SIZE,
84+
SW_BATCH,
85+
model,
86+
overlap=OVERLAP,
87+
mode="gaussian",
88+
)
89+
90+
pred = pred.cpu().numpy()[0,0]
91+
92+
# convert back to HU
93+
pred = pred * 3000 - 1000
94+
95+
ref = nib.load(str(INPUT_PATH))
96+
97+
nib.save(
98+
nib.Nifti1Image(pred, ref.affine, ref.header),
99+
str(OUTPUT_PATH)
100+
)
101+
102+
print("Saved:", OUTPUT_PATH)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
monai
2+
nibabel
3+
tqdm

0 commit comments

Comments
 (0)