Skip to content

Commit a2d258f

Browse files
merge baseline
2 parents 3651e8f + 1aa7028 commit a2d258f

5 files changed

Lines changed: 107 additions & 76 deletions

File tree

.gitignore

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,7 @@
22
conversion/
33
CLAUDE.md
44
outputs*/
5-
outputs2/
6-
75
*.nii.gz
8-
v3/
96
# Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,python
107
# Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,python
118

src/baseline/config.yaml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
data_dir: /data/bic-mac-data/train
2-
cache_dir: /data/bic-mac-cache
3-
batch_size: 1 #batch_size is actually 2, since monai samples two patches per volume
1+
data_dir: /data/bic-mac-data/train  # CHANGE this to your dataset path
42
num_workers: 2
3+
output_dir: outputs
54

65
epochs: 250
76
learning_rate: 0.0003
87

98
patch_size: [192,192,192]
9+
train_num_samples: 2  # Number of patches to sample per training subject
10+
val_num_samples: 8  # Number of patches to sample per validation subject

src/baseline/predict.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919

2020
MODEL_PATH = Path(__file__).parent / "weights/best_model.pth"
21+
MODEL_PATH = Path("/sonne/hinge/Projects/challenge-codebase/src/baseline/outputs4/checkpoints/best_model.pth")
2122
PATCH_SIZE = (192, 192, 192)
2223
SW_BATCH = 2
2324
OVERLAP = 0.5
@@ -28,7 +29,7 @@ def predict(features_dir, out_path):
2829
transforms = Compose([
2930
LoadImaged(keys=["nacpet"]),
3031
EnsureChannelFirstd(keys=["nacpet"]),
31-
NormalizeIntensityd(keys=["nacpet"], nonzero=True, channel_wise=True),
32+
NormalizeIntensityd(keys=["nacpet"], nonzero=False, subtrahend=[0], channel_wise=True),
3233
ConcatItemsd(keys=["nacpet"], name="input"),
3334
EnsureTyped(keys=["input"]),
3435
])

src/baseline/train.py

Lines changed: 74 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,23 @@
22
import torch
33
import yaml
44
import matplotlib.pyplot as plt
5-
import torch.nn.functional as F
6-
7-
from monai.data import Dataset, DataLoader, CacheDataset, PersistentDataset
5+
from monai.data import DataLoader, CacheDataset
86
from tqdm import tqdm
97

108
from dataset import get_dataset
11-
from transforms import get_train_transforms
9+
from transforms import get_transforms
1210
from unet import build_model
1311

1412

1513
torch.backends.cudnn.benchmark = True
1614

17-
# -----------------------------
18-
# CONFIG
19-
# -----------------------------
15+
2016

2117
def load_config():
    """Load the training configuration from ``config.yaml``.

    The path is relative, so the file is resolved against the current
    working directory of the process.

    Returns:
        The parsed YAML document, as returned by ``yaml.safe_load``.

    Raises:
        OSError: If ``config.yaml`` cannot be opened.
    """
    # Decode explicitly as UTF-8 rather than relying on the platform's
    # locale-dependent default text encoding.
    with open("config.yaml", encoding="utf-8") as f:
        return yaml.safe_load(f)
2420

2521

26-
# -----------------------------
27-
# TRAIN
28-
# -----------------------------
2922

3023
def main():
3124

@@ -35,30 +28,39 @@ def main():
3528

3629
print("Using device:", device)
3730

38-
data = get_dataset(cfg["data_dir"])
31+
all_data = get_dataset(cfg["data_dir"])
32+
val_data, train_data = all_data[:2], all_data[2:]
3933

40-
transforms = get_train_transforms(
41-
cfg["patch_size"],
42-
)
34+
train_transforms = get_transforms(cfg["patch_size"], cfg["train_num_samples"])
35+
val_transforms = get_transforms(cfg["patch_size"], cfg["val_num_samples"])
4336

44-
print("Preparing dataset ...")
45-
dataset = PersistentDataset(
46-
data=data,
47-
transform=transforms,
48-
cache_dir=cfg["cache_dir"]
37+
print("Caching train dataset...")
38+
train_dataset = CacheDataset(
39+
data=train_data,
40+
transform=train_transforms,
41+
cache_rate=1.0, # Change this to reduce memory footprint
42+
num_workers=cfg["num_workers"],
43+
)
44+
loader = DataLoader(
45+
train_dataset,
46+
batch_size=cfg["batch_size"],
47+
shuffle=True,
48+
num_workers=cfg["num_workers"],
49+
pin_memory=True,
50+
persistent_workers=True
4951
)
5052

51-
print("Caching dataset...")
52-
dataset = CacheDataset(
53-
data=dataset,
53+
print("Caching val dataset...")
54+
val_dataset = CacheDataset(
55+
data=val_data,
56+
transform=val_transforms,
5457
cache_rate=1.0,
55-
num_workers=8,
58+
num_workers=cfg["num_workers"],
5659
)
57-
58-
loader = DataLoader(
59-
dataset,
60+
val_loader = DataLoader(
61+
val_dataset,
6062
batch_size=cfg["batch_size"],
61-
shuffle=True,
63+
shuffle=False,
6264
num_workers=cfg["num_workers"],
6365
pin_memory=True,
6466
persistent_workers=True
@@ -77,17 +79,18 @@ def main():
7779
T_max=cfg["epochs"]
7880
)
7981

82+
scaler = torch.amp.GradScaler("cuda")
8083
l1_loss = torch.nn.L1Loss()
8184

82-
scaler = torch.amp.GradScaler("cuda")
83-
84-
os.makedirs("outputs/checkpoints", exist_ok=True)
85-
os.makedirs("outputs/logs", exist_ok=True)
86-
os.makedirs("outputs/plots", exist_ok=True)
85+
out = cfg["output_dir"]
86+
os.makedirs(f"{out}/checkpoints", exist_ok=True)
87+
os.makedirs(f"{out}/logs", exist_ok=True)
88+
os.makedirs(f"{out}/plots", exist_ok=True)
8789

88-
best_loss = float("inf")
90+
best_val_loss = float("inf")
8991

90-
loss_history = []
92+
train_loss_history = []
93+
val_loss_history = []
9194

9295
print("Starting training...")
9396

@@ -101,9 +104,10 @@ def main():
101104

102105
for batch in pbar:
103106

104-
x = batch["input"].to(device)
105-
y = batch["ct"].to(device)
106-
107+
x = batch["input"].to(device)
108+
y = batch["ct"].to(device)
109+
mask = batch["prediction_mask"].bool().to(device)
110+
y[~mask] = 0 # don't bother trying to predict the bed
107111
optimizer.zero_grad()
108112

109113
with torch.amp.autocast("cuda"):
@@ -120,41 +124,60 @@ def main():
120124

121125
pbar.set_description(f"loss {loss.item():.4f}")
122126

123-
avg_loss = epoch_loss / len(loader)
127+
avg_train_loss = epoch_loss / len(loader)
124128

125-
print("Epoch", epoch, "Loss", avg_loss)
129+
scheduler.step()
126130

127-
loss_history.append(avg_loss)
131+
# validation
132+
model.eval()
133+
val_loss = 0
134+
with torch.no_grad():
135+
for batch in val_loader:
136+
x = batch["input"].to(device)
137+
y = batch["ct"].to(device)
138+
mask = batch["prediction_mask"].bool().to(device)
139+
y[~mask] = 0 # don't bother trying to predict the bed
128140

129-
scheduler.step()
141+
with torch.amp.autocast("cuda"):
142+
pred = model(x)
143+
loss = l1_loss(pred, y)
144+
val_loss += loss.item()
145+
avg_val_loss = val_loss / len(val_loader)
146+
147+
print(f"Epoch {epoch} train={avg_train_loss:.4f} val={avg_val_loss:.4f}")
148+
149+
train_loss_history.append(avg_train_loss)
150+
val_loss_history.append(avg_val_loss)
130151

131-
# best checkpoint
132-
if avg_loss < best_loss:
152+
# best checkpoint (by val)
153+
if avg_val_loss < best_val_loss:
133154

134-
best_loss = avg_loss
155+
best_val_loss = avg_val_loss
135156

136157
torch.save(
137158
model.state_dict(),
138-
"outputs/checkpoints/best_model.pth"
159+
f"{out}/checkpoints/best_model.pth"
139160
)
140161

141162
# last checkpoint
142163
torch.save(
143164
model.state_dict(),
144-
"outputs/checkpoints/last_model.pth"
165+
f"{out}/checkpoints/last_model.pth"
145166
)
146167

147168
# log
148-
with open("outputs/logs/train_log.txt","a") as f:
149-
f.write(f"{epoch},{avg_loss}\n")
169+
with open(f"{out}/logs/train_log.txt", "a") as f:
170+
f.write(f"{epoch},{avg_train_loss},{avg_val_loss}\n")
150171

151172
# plot loss
152173
plt.figure()
153-
plt.plot(loss_history)
174+
plt.plot(train_loss_history, label="train")
175+
plt.plot(val_loss_history, label="val")
154176
plt.xlabel("Epoch")
155177
plt.ylabel("Loss")
156-
plt.title("Training Loss")
157-
plt.savefig("outputs/plots/loss_curve.png")
178+
plt.title("Train / Val Loss")
179+
plt.legend()
180+
plt.savefig(f"{out}/plots/loss_curve.png")
158181
plt.close()
159182

160183

src/baseline/transforms.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,21 @@
55
# however, your model may use all images and metadata available
66
# under the /features folder
77

8-
def get_train_transforms(patch_size):
8+
def get_transforms(patch_size, num_samples=2):
99

1010
transforms = Compose(
1111
[
1212

13-
LoadImaged(keys=["nacpet", "ct"]),
13+
LoadImaged(keys=["nacpet", "ct", "prediction_mask"]),
1414

15-
EnsureChannelFirstd(keys=["nacpet", "ct"]),
15+
EnsureChannelFirstd(keys=["nacpet", "ct", "prediction_mask"]),
1616

1717
NormalizeIntensityd(
1818
keys=["nacpet"],
19-
nonzero=True,
20-
channel_wise=True
19+
nonzero=False,
20+
channel_wise=True,
21+
subtrahend=[0]
22+
2123
),
2224

2325
ScaleIntensityRanged(
@@ -35,26 +37,33 @@ def get_train_transforms(patch_size):
3537
name="input"
3638
),
3739

40+
# Crop first so all random augmentations run on small patches
3841
RandSpatialCropSamplesd(
39-
keys=["input","ct"],
42+
keys=["input", "ct", "prediction_mask"],
4043
roi_size=patch_size,
4144
random_size=False,
42-
num_samples=2
45+
num_samples=num_samples
4346
),
4447

45-
RandFlipd(
46-
keys=["input","ct"],
47-
spatial_axis=0,
48-
prob=0.5
49-
),
48+
#RandGaussianNoised(keys=["input"], prob=0.5, mean=0.0, std=0.05),
49+
#RandScaleIntensityd(keys=["input"], factors=0.1, prob=0.5),
50+
#RandShiftIntensityd(keys=["input"], offsets=0.1, prob=0.5),
51+
#RandGaussianSmoothd(
52+
# keys=["input"],
53+
# sigma_x=(0.5, 1.0), sigma_y=(0.5, 1.0), sigma_z=(0.5, 1.0),
54+
# prob=0.3,
55+
#),
5056

51-
RandFlipd(
52-
keys=["input","ct"],
53-
spatial_axis=1,
54-
prob=0.5
55-
),
57+
# RandAffined(
58+
# keys=["input", "ct", "prediction_mask"],
59+
# prob=0.5,
60+
# rotate_range=(0.087, 0.087, 0.087), # ±5°
61+
# scale_range=(0.05, 0.05, 0.05), # ±5%
62+
# mode=("bilinear", "bilinear", "nearest"),
63+
# padding_mode="border",
64+
# ),
5665

57-
EnsureTyped(keys=["input","ct"]),
66+
EnsureTyped(keys=["input", "ct", "prediction_mask"]),
5867

5968
]
6069
)

0 commit comments

Comments
 (0)