Skip to content

Commit 19ba100

Browse files
committed
changed input to 1 channel(pet)
1 parent b20b38c commit 19ba100

3 files changed

Lines changed: 14 additions & 29 deletions

File tree

src/baseline/v2/datasets/transforms.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,9 @@
55
NormalizeIntensityd,
66
ScaleIntensityRanged,
77
ConcatItemsd,
8-
RandCropByPosNegLabeld,
8+
RandSpatialCropd,
99
RandFlipd,
1010
EnsureTyped,
11-
Lambdad
1211
)
1312

1413

@@ -17,21 +16,16 @@ def get_train_transforms(patch_size, spacing):
1716
transforms = Compose(
1817
[
1918

20-
LoadImaged(keys=["pet", "topogram", "mri_in", "mri_out", "ct"]),
19+
LoadImaged(keys=["pet", "ct"]),
2120

22-
EnsureChannelFirstd(keys=["pet", "topogram", "mri_in", "mri_out", "ct"]),
21+
EnsureChannelFirstd(keys=["pet", "ct"]),
2322

24-
# expand topogram depth
25-
Lambdad(keys=["topogram"], func=lambda x: x.repeat(1,1,1,531)),
26-
27-
# normalize PET and MRI
2823
NormalizeIntensityd(
29-
keys=["pet","mri_in","mri_out"],
24+
keys=["pet"],
3025
nonzero=True,
3126
channel_wise=True
3227
),
3328

34-
# normalize CT
3529
ScaleIntensityRanged(
3630
keys=["ct"],
3731
a_min=-1000,
@@ -41,20 +35,16 @@ def get_train_transforms(patch_size, spacing):
4135
clip=True,
4236
),
4337

44-
# combine modalities
38+
# now input is just PET
4539
ConcatItemsd(
46-
keys=["pet","topogram","mri_in","mri_out"],
40+
keys=["pet"],
4741
name="input"
4842
),
4943

50-
# patch sampling
51-
RandCropByPosNegLabeld(
44+
RandSpatialCropd(
5245
keys=["input","ct"],
53-
label_key="ct",
54-
spatial_size=patch_size,
55-
pos=1,
56-
neg=1,
57-
num_samples=1
46+
roi_size=patch_size,
47+
random_size=False
5848
),
5949

6050
RandFlipd(

src/baseline/v2/inference/inference.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,20 +50,18 @@
5050

5151
transforms = Compose(
5252
[
53-
LoadImaged(keys=["pet","topogram","mri_in","mri_out"]),
53+
LoadImaged(keys=["pet"]),
5454

55-
EnsureChannelFirstd(keys=["pet","topogram","mri_in","mri_out"]),
56-
57-
Lambdad(keys=["topogram"], func=lambda x: x.repeat(1,1,1,531)),
55+
EnsureChannelFirstd(keys=["pet"]),
5856

5957
NormalizeIntensityd(
60-
keys=["pet","mri_in","mri_out"],
58+
keys=["pet"],
6159
nonzero=True,
6260
channel_wise=True
6361
),
6462

6563
ConcatItemsd(
66-
keys=["pet","topogram","mri_in","mri_out"],
64+
keys=["pet"],
6765
name="input"
6866
),
6967

@@ -99,9 +97,6 @@
9997

10098
data = {
10199
"pet": sub/"features/nacpet.nii.gz",
102-
"topogram": sub/"features/topogram.nii.gz",
103-
"mri_in": sub/"features/mri_combined_in_phase.nii.gz",
104-
"mri_out": sub/"features/mri_combined_out_phase.nii.gz",
105100
}
106101

107102
data = transforms(data)

src/baseline/v2/models/unet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,6 @@ def forward(self, x):
147147
def build_model():
148148

149149
return UNet3D(
150-
in_channels=4,
150+
in_channels=1, # only PET as input
151151
out_channels=1
152152
)

0 commit comments

Comments
 (0)