@@ -64,6 +64,48 @@ def _strict_binarize_mask(mask, threshold: float = 0.0):
6464 return (mask > threshold ).astype (mask .dtype , copy = False )
6565
6666
67+ def _target_context (cfg : Config ) -> tuple [int , ...]:
68+ context = getattr (cfg .data .dataloader , "target_context" , None ) or []
69+ if not context :
70+ return tuple (0 for _ in cfg .data .dataloader .patch_size )
71+ return tuple (int (v ) for v in context )
72+
73+
74+ def _effective_patch_size (cfg : Config ) -> tuple [int , ...] | None :
75+ patch_size = tuple (cfg .data .dataloader .patch_size ) if cfg .data .dataloader .patch_size else None
76+ if patch_size is None :
77+ return None
78+ context = _target_context (cfg )
79+ if len (context ) != len (patch_size ):
80+ raise ValueError (
81+ "data.dataloader.target_context must have the same length as patch_size: "
82+ f"{ context } vs { patch_size } "
83+ )
84+ return tuple (int (patch_size [i ]) + int (context [i ]) for i in range (len (patch_size )))
85+
86+
87+ def _append_banis_pre_target_transforms (transforms : list , label_cfg ) -> None :
88+ if bool (getattr (label_cfg , "relabel_connected_components" , False )):
89+ from ..processing .transforms import RelabelConnectedComponentsd
90+
91+ transforms .append (
92+ RelabelConnectedComponentsd (
93+ keys = ["label" ],
94+ connectivity = int (getattr (label_cfg , "relabel_connectivity" , 6 )),
95+ )
96+ )
97+
98+
99+ def _append_target_context_crop (transforms : list , cfg : Config ) -> None :
100+ context = _target_context (cfg )
101+ if not context or not any (v > 0 for v in context ):
102+ return
103+
104+ from ..processing .transforms import LeadingSpatialCropd
105+
106+ transforms .append (LeadingSpatialCropd (roi_size = tuple (cfg .data .dataloader .patch_size )))
107+
108+
67109def _build_nnunet_preprocess_transform (keys , nnunet_pre_cfg , source_spacing ):
68110 """Build NNUNetPreprocessd transform from config."""
69111 source_spacing = getattr (nnunet_pre_cfg , "source_spacing" , None ) or source_spacing
@@ -176,9 +218,7 @@ def build_train_transforms(
176218
177219 # Ensure target patch size is respected (unless using pre-cached dataset)
178220 if not skip_loading :
179- patch_size = (
180- tuple (cfg .data .dataloader .patch_size ) if cfg .data .dataloader .patch_size else None
181- )
221+ patch_size = _effective_patch_size (cfg )
182222 if patch_size and all (size > 0 for size in patch_size ):
183223 # Pad smaller volumes so random crops always succeed
184224 transforms .append (
@@ -208,6 +248,10 @@ def build_train_transforms(
208248 )
209249 )
210250
251+ label_cfg = getattr (cfg .data , "label_transform" , None )
252+ if "label" in keys and label_cfg is not None :
253+ _append_banis_pre_target_transforms (transforms , label_cfg )
254+
211255 # Add augmentations if enabled
212256 if cfg .data .augmentation is not None :
213257 # Pass do_2d flag to augmentation builder
@@ -218,12 +262,10 @@ def build_train_transforms(
218262 transforms .extend (_build_augmentations (cfg .data .augmentation , keys , do_2d = do_2d ))
219263
220264 # Label transformations (affinity, distance transform, etc.)
221- if hasattr ( cfg . data , "label_transform" ) :
265+ if label_cfg is not None :
222266 from ..processing .build import create_label_transform_pipeline
223267 from ..processing .transforms import SegErosionInstanced
224268
225- label_cfg = cfg .data .label_transform
226-
227269 # Apply instance erosion first if specified
228270 if hasattr (label_cfg , "erosion" ) and label_cfg .erosion > 0 :
229271 transforms .append (SegErosionInstanced (keys = ["label" ], tsz_h = label_cfg .erosion ))
@@ -235,6 +277,8 @@ def build_train_transforms(
235277 else :
236278 transforms .append (label_transform )
237279
280+ _append_target_context_crop (transforms , cfg )
281+
238282 # NOTE: Do NOT squeeze labels here!
239283 # - DiceLoss needs (B, 1, H, W) with to_onehot_y=True
240284 # - CrossEntropyLoss needs (B, H, W)
@@ -472,7 +516,13 @@ def _resolve_eval_split():
472516 )
473517 )
474518
475- patch_size = tuple (data_cfg .dataloader .patch_size ) if data_cfg .dataloader .patch_size else None
519+ patch_size = (
520+ _effective_patch_size (cfg )
521+ if mode == "val"
522+ else tuple (data_cfg .dataloader .patch_size )
523+ if data_cfg .dataloader .patch_size
524+ else None
525+ )
476526 if patch_size and all (size > 0 for size in patch_size ):
477527 transforms .append (
478528 SpatialPadd (
@@ -510,12 +560,22 @@ def _resolve_eval_split():
510560 # Test: Skip cropping to enable sliding window inference on full volumes
511561 if mode == "val" :
512562 if patch_size and all (size > 0 for size in patch_size ):
513- transforms .append (
514- CenterSpatialCropd (
515- keys = keys ,
516- roi_size = patch_size ,
563+ if bool (getattr (data_cfg .dataloader , "val_random_sampling" , False )):
564+ transforms .append (
565+ RandSpatialCropd (
566+ keys = keys ,
567+ roi_size = patch_size ,
568+ random_center = True ,
569+ random_size = False ,
570+ )
571+ )
572+ else :
573+ transforms .append (
574+ CenterSpatialCropd (
575+ keys = keys ,
576+ roi_size = patch_size ,
577+ )
517578 )
518- )
519579 # else: mode == "test" -> no cropping for sliding window inference
520580
521581 # Normalization - use smart normalization
@@ -530,6 +590,10 @@ def _resolve_eval_split():
530590 )
531591 )
532592
593+ label_cfg = getattr (data_cfg , "label_transform" , None )
594+ if mode == "val" and "label" in keys and label_cfg is not None :
595+ _append_banis_pre_target_transforms (transforms , label_cfg )
596+
533597 # Only process labels if 'label' is in keys
534598 if "label" in keys :
535599 # Label transformations (affinity, distance transform, etc.)
@@ -558,6 +622,9 @@ def _resolve_eval_split():
558622 else :
559623 transforms .append (label_transform )
560624
625+ if mode == "val" :
626+ _append_target_context_crop (transforms , cfg )
627+
561628 # NOTE: Do NOT squeeze labels here!
562629 # - DiceLoss needs (B, 1, H, W) with to_onehot_y=True
563630 # - CrossEntropyLoss needs (B, H, W)
@@ -733,16 +800,6 @@ def _build_augmentations(aug_cfg: AugmentationConfig, keys: list[str], do_2d: bo
733800
734801 # Intensity augmentations (only for images)
735802 if aug_cfg .intensity .enabled :
736- if aug_cfg .intensity .gaussian_noise_prob > 0 :
737- transforms .append (
738- RandGaussianNoised (
739- keys = ["image" ],
740- prob = aug_cfg .intensity .gaussian_noise_prob ,
741- std = aug_cfg .intensity .gaussian_noise_std ,
742- sample_std = True ,
743- )
744- )
745-
746803 if getattr (aug_cfg .intensity , "banis_style" , False ):
747804 transforms .append (
748805 RandMulAddIntensityd (
@@ -752,7 +809,26 @@ def _build_augmentations(aug_cfg: AugmentationConfig, keys: list[str], do_2d: bo
752809 add_range = aug_cfg .intensity .add_range ,
753810 )
754811 )
812+ if aug_cfg .intensity .gaussian_noise_prob > 0 :
813+ transforms .append (
814+ RandGaussianNoised (
815+ keys = ["image" ],
816+ prob = aug_cfg .intensity .gaussian_noise_prob ,
817+ std = aug_cfg .intensity .gaussian_noise_std ,
818+ sample_std = True ,
819+ )
820+ )
755821 else :
822+ if aug_cfg .intensity .gaussian_noise_prob > 0 :
823+ transforms .append (
824+ RandGaussianNoised (
825+ keys = ["image" ],
826+ prob = aug_cfg .intensity .gaussian_noise_prob ,
827+ std = aug_cfg .intensity .gaussian_noise_std ,
828+ sample_std = True ,
829+ )
830+ )
831+
756832 if aug_cfg .intensity .shift_intensity_prob > 0 :
757833 transforms .append (
758834 RandShiftIntensityd (
0 commit comments