Skip to content

Commit 1aa7028

Browse files
final baseline
1 parent 06b9ad3 commit 1aa7028

4 files changed

Lines changed: 32 additions & 32 deletions

File tree

src/baseline/config.yaml

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
1-
data_dir: /data/bic-mac-data/train
2-
cache_dir: /data/bic-mac-cache
3-
batch_size: 1
1+
data_dir: /data/bic-mac-data/train #CHANGE to your dataset path
42
num_workers: 2
5-
output_dir: outputs4
3+
output_dir: outputs
64

75
epochs: 250
86
learning_rate: 0.0003
97

108
patch_size: [192,192,192]
11-
train_num_samples: 2
12-
val_num_samples: 8
9+
train_num_samples: 2 #Number of patches to sample per train subject
10+
val_num_samples: 8 #Number of patches to sample per val subject

src/baseline/predict.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919

2020
MODEL_PATH = Path(__file__).parent / "weights/best_model.pth"
21-
MODEL_PATH = Path("/sonne/hinge/Projects/challenge-codebase/src/baseline/outputs2/checkpoints/last_model.pth")
21+
MODEL_PATH = Path("/sonne/hinge/Projects/challenge-codebase/src/baseline/outputs4/checkpoints/best_model.pth")
2222
PATCH_SIZE = (192, 192, 192)
2323
SW_BATCH = 2
2424
OVERLAP = 0.5
@@ -29,7 +29,7 @@ def predict(features_dir, out_path):
2929
transforms = Compose([
3030
LoadImaged(keys=["nacpet"]),
3131
EnsureChannelFirstd(keys=["nacpet"]),
32-
NormalizeIntensityd(keys=["nacpet"], nonzero=False, subtrahend=0, channel_wise=True),
32+
NormalizeIntensityd(keys=["nacpet"], nonzero=False, subtrahend=[0], channel_wise=True),
3333
ConcatItemsd(keys=["nacpet"], name="input"),
3434
EnsureTyped(keys=["input"]),
3535
])

src/baseline/train.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ def main():
3838
train_dataset = CacheDataset(
3939
data=train_data,
4040
transform=train_transforms,
41-
cache_rate=1.0,
42-
num_workers=8,
41+
cache_rate=1.0, # Change this to reduce memory footprint
42+
num_workers=cfg["num_workers"],
4343
)
4444
loader = DataLoader(
4545
train_dataset,
@@ -55,7 +55,7 @@ def main():
5555
data=val_data,
5656
transform=val_transforms,
5757
cache_rate=1.0,
58-
num_workers=4,
58+
num_workers=cfg["num_workers"],
5959
)
6060
val_loader = DataLoader(
6161
val_dataset,
@@ -106,15 +106,15 @@ def main():
106106

107107
x = batch["input"].to(device)
108108
y = batch["ct"].to(device)
109-
mask = batch["prediction_mask"].to(device)
110-
109+
mask = batch["prediction_mask"].bool().to(device)
110+
y[~mask] = 0 # don't bother trying to predict the bed
111111
optimizer.zero_grad()
112112

113113
with torch.amp.autocast("cuda"):
114114

115115
pred = model(x)
116116

117-
loss = l1_loss(pred[mask.bool()], y[mask.bool()])
117+
loss = l1_loss(pred, y)
118118

119119
scaler.scale(loss).backward()
120120
scaler.step(optimizer)
@@ -135,10 +135,12 @@ def main():
135135
for batch in val_loader:
136136
x = batch["input"].to(device)
137137
y = batch["ct"].to(device)
138-
mask = batch["prediction_mask"].to(device)
138+
mask = batch["prediction_mask"].bool().to(device)
139+
y[~mask] = 0 # don't bother trying to predict the bed
140+
139141
with torch.amp.autocast("cuda"):
140142
pred = model(x)
141-
loss = l1_loss(pred[mask.bool()], y[mask.bool()])
143+
loss = l1_loss(pred, y)
142144
val_loss += loss.item()
143145
avg_val_loss = val_loss / len(val_loader)
144146

src/baseline/transforms.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -45,23 +45,23 @@ def get_transforms(patch_size, num_samples=2):
4545
num_samples=num_samples
4646
),
4747

48-
RandGaussianNoised(keys=["input"], prob=0.5, mean=0.0, std=0.05),
49-
RandScaleIntensityd(keys=["input"], factors=0.1, prob=0.5),
50-
RandShiftIntensityd(keys=["input"], offsets=0.1, prob=0.5),
51-
RandGaussianSmoothd(
52-
keys=["input"],
53-
sigma_x=(0.5, 1.0), sigma_y=(0.5, 1.0), sigma_z=(0.5, 1.0),
54-
prob=0.3,
55-
),
48+
#RandGaussianNoised(keys=["input"], prob=0.5, mean=0.0, std=0.05),
49+
#RandScaleIntensityd(keys=["input"], factors=0.1, prob=0.5),
50+
#RandShiftIntensityd(keys=["input"], offsets=0.1, prob=0.5),
51+
#RandGaussianSmoothd(
52+
# keys=["input"],
53+
# sigma_x=(0.5, 1.0), sigma_y=(0.5, 1.0), sigma_z=(0.5, 1.0),
54+
# prob=0.3,
55+
#),
5656

57-
RandAffined(
58-
keys=["input", "ct", "prediction_mask"],
59-
prob=0.5,
60-
rotate_range=(0.087, 0.087, 0.087), # ±5°
61-
scale_range=(0.05, 0.05, 0.05), # ±5%
62-
mode=("bilinear", "bilinear", "nearest"),
63-
padding_mode="border",
64-
),
57+
# RandAffined(
58+
# keys=["input", "ct", "prediction_mask"],
59+
# prob=0.5,
60+
# rotate_range=(0.087, 0.087, 0.087), # ±5°
61+
# scale_range=(0.05, 0.05, 0.05), # ±5%
62+
# mode=("bilinear", "bilinear", "nearest"),
63+
# padding_mode="border",
64+
# ),
6565

6666
EnsureTyped(keys=["input", "ct", "prediction_mask"]),
6767

0 commit comments

Comments
 (0)