Commit c0b26aa

Merge pull request #6 from bic-mac-challenge/dev/predict_vram
Reduce VRAM usage at inference from 18 GB to 5.5 GB by using half precision, a batch size of 1, and patch aggregation on the CPU
2 parents 043091a + cefbcbe commit c0b26aa

2 files changed

Lines changed: 106 additions & 4 deletions


src/baseline/predict.py

Lines changed: 3 additions & 2 deletions
@@ -19,7 +19,7 @@
 
 MODEL_PATH = Path(__file__).parent / "outputs/checkpoints/best_model.pth"
 PATCH_SIZE = (192, 192, 192)
-SW_BATCH = 2
+SW_BATCH = 1  # Increase this to speed up inference at the cost of VRAM
 OVERLAP = 0.5
 
 
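To get a feel for what `SW_BATCH` trades off, the sketch below estimates how many patches and forward passes a volume needs for the `PATCH_SIZE` and `OVERLAP` values above. It assumes the stride is `patch * (1 - overlap)` and that volumes smaller than a patch are padded to a single patch; the helper names and the volume shape are hypothetical and not part of this commit.

```python
import math

PATCH_SIZE = (192, 192, 192)
OVERLAP = 0.5


def count_windows(image_size, patch_size=PATCH_SIZE, overlap=OVERLAP):
    """Count sliding-window patches, assuming a stride of patch * (1 - overlap)."""
    total = 1
    for dim, patch in zip(image_size, patch_size):
        stride = max(int(patch * (1 - overlap)), 1)
        # Assumption: a dimension smaller than the patch is padded to one patch.
        steps = math.ceil(max(dim - patch, 0) / stride) + 1
        total *= steps
    return total


def count_forward_passes(image_size, sw_batch):
    """Number of model calls when patches are grouped sw_batch at a time."""
    return math.ceil(count_windows(image_size) / sw_batch)


# Hypothetical volume shape, chosen only for illustration.
volume = (384, 384, 384)
print(count_windows(volume))            # 27
print(count_forward_passes(volume, 1))  # 27
print(count_forward_passes(volume, 2))  # 14
```

Under these assumptions, halving `SW_BATCH` roughly doubles the number of forward passes while halving the number of patches held in VRAM at once.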

@@ -44,10 +44,11 @@ def predict(features_dir, out_path):
 
     x = data["input"].unsqueeze(0).to(device)
     print("Sliding window inference...")
-    with torch.no_grad():
+    with torch.no_grad(), torch.amp.autocast("cuda"):
         pred = sliding_window_inference(
             x, PATCH_SIZE, SW_BATCH, model,
             overlap=OVERLAP, mode="gaussian", progress=True,
+            sw_device="cuda", device="cpu",
         )
 
     # Apply inverse of normalization to get HU
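The idea behind `sw_device="cuda", device="cpu"` can be illustrated with a simplified 1D sketch: each patch is moved to the compute device for the forward pass, and the result is accumulated in a buffer kept on the CPU. This is not the library's implementation; `sliding_window_1d` is a hypothetical helper, it uses uniform averaging rather than the Gaussian weighting in the diff, and it falls back to the CPU when no GPU is available.

```python
import torch


def sliding_window_1d(x, patch, stride, model, sw_device, out_device="cpu"):
    """Run `model` on overlapping 1D patches and average them on `out_device`."""
    length = x.shape[-1]
    # Accumulators live on out_device, so the full-size output never sits in VRAM.
    total = torch.zeros(x.shape, dtype=torch.float32, device=out_device)
    weight = torch.zeros(length, dtype=torch.float32, device=out_device)

    starts = list(range(0, max(length - patch, 0) + 1, stride))
    if starts[-1] + patch < length:
        starts.append(length - patch)  # make sure the tail is covered

    device_type = torch.device(sw_device).type
    use_amp = device_type == "cuda"  # mixed precision only on the GPU
    for start in starts:
        window = x[..., start:start + patch].to(sw_device)
        with torch.no_grad(), torch.autocast(device_type=device_type, enabled=use_amp):
            out = model(window)
        # Cast back to float32 and move to out_device before accumulating.
        total[..., start:start + patch] += out.float().to(out_device)
        weight[start:start + patch] += 1.0
    return total / weight


sw_device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Identity().to(sw_device)  # stand-in for the real network
x = torch.arange(10, dtype=torch.float32).reshape(1, 1, 10)
pred = sliding_window_1d(x, patch=4, stride=2, model=model, sw_device=sw_device)
print(pred.device, pred.shape)
```

Only one patch and its output occupy the compute device at a time in this sketch, which is the same reason the stitched prediction in the diff ends up on the CPU.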
