@@ -162,6 +162,8 @@ def __init__(
162162 mode : str = "train" ,
163163 pad_size : Optional [Tuple [int , ...]] = None ,
164164 pad_mode : str = "reflect" ,
165+ max_attempts : int = 10 ,
166+ foreground_threshold : float = 0.05 ,
165167 ):
166168 self .image_paths = image_paths
167169 self .label_paths = label_paths if label_paths else [None ] * len (image_paths )
@@ -194,6 +196,8 @@ def __init__(
194196 self ._d2_rejected_patches = 0
195197 self ._d2_foreground_fracs = []
196198 self ._d2_last_report_step = 0
199+ self .max_attempts = max_attempts
200+ self .foreground_threshold = foreground_threshold
197201
198202 # Load all volumes into memory
199203 print (f" Loading { len (image_paths )} volumes into memory..." )
@@ -271,13 +275,16 @@ def __init__(
271275 # Support both 2D and 3D: get last N dimensions matching patch_size
272276 ndim = len (self .patch_size )
273277 self .volume_sizes = [img .shape [- ndim :] for img in self .cached_images ] # (Z, Y, X) or (Y, X)
274-
278+
275279 # [D2 DIAGNOSTIC] Print foreground sampling configuration
276280 if self .mode == "train" :
277- print (f" [D2] Foreground sampling ENABLED:" )
278- print (f" - Minimum foreground threshold: 5.0%" )
279- print (f" - Max retry attempts: 10" )
280- print (f" - Will report statistics every 100 batches" )
281+ if self .foreground_threshold > 0 :
282+ print (" [D2] Foreground sampling ENABLED:" )
283+ print (f" - Minimum foreground threshold: { self .foreground_threshold * 100 :.1f} %" )
284+ print (f" - Max retry attempts: { self .max_attempts } " )
285+ print (" - Will report statistics every 100 batches" )
286+ else :
287+ print (" [D2] Foreground sampling DISABLED (threshold <= 0)" )
281288
282289 def _apply_padding (
283290 self , volume : np .ndarray , mode : Optional [str ] = None , constant_values : float = 0
@@ -366,15 +373,15 @@ def __len__(self) -> int:
366373 def set_epoch (self , epoch : int , base_seed : int = 0 ):
367374 """
368375 Set current epoch for epoch-based validation reseeding.
369-
376+
370377 This method enables validation to sample different patches each epoch
371378 while maintaining determinism. For training, this has no effect since
372379 training already uses random sampling.
373-
380+
374381 Args:
375382 epoch: Current training epoch
376383 base_seed: Base random seed (typically from cfg.system.seed)
377-
384+
378385 Usage:
379386 Called by ValidationReseedingCallback at the start of each validation epoch:
380387 if hasattr(dataset, 'set_epoch'):
@@ -387,32 +394,36 @@ def set_epoch(self, epoch: int, base_seed: int = 0):
387394 self .current_epoch = epoch
388395 effective_seed = self .base_seed + epoch
389396 random .seed (effective_seed )
390-
397+
391398 # IMPORTANT: Print to verify reseeding is happening
392399 # This should appear in logs at the start of EACH validation epoch
393- print (f"[Validation] Set epoch={ epoch } , base_seed={ base_seed } , effective_seed={ effective_seed } " )
394- print (f"[Validation] Dataset: { type (self ).__name__ } @{ id (self )} , mode={ self .mode } , iter_num={ self .iter_num } " )
395-
400+ print (
401+ f"[Validation] Set epoch={ epoch } , base_seed={ base_seed } , effective_seed={ effective_seed } "
402+ )
403+ print (
404+ f"[Validation] Dataset: { type (self ).__name__ } @{ id (self )} , mode={ self .mode } , iter_num={ self .iter_num } "
405+ )
406+
396407 def get_sampling_fingerprint (self , num_samples : int = 5 ) -> str :
397408 """
398409 Generate a deterministic fingerprint of validation sampling.
399-
410+
400411 This allows verification that validation patches change across epochs.
401412 The fingerprint is based on the first N random samples that would be
402413 generated with the current RNG state.
403-
414+
404415 Args:
405416 num_samples: Number of random samples to include in fingerprint
406-
417+
407418 Returns:
408419 String representing the sampling fingerprint
409420 """
410421 if self .mode != "val" :
411422 return "N/A (training mode)"
412-
423+
413424 # Save current RNG state
414425 state = random .getstate ()
415-
426+
416427 try :
417428 # Generate deterministic samples
418429 samples = []
@@ -422,11 +433,11 @@ def get_sampling_fingerprint(self, num_samples: int = 5) -> str:
422433 # Sample patch position
423434 pos = self ._get_random_crop_position (vol_idx )
424435 samples .append ((vol_idx , pos ))
425-
436+
426437 # Create fingerprint string
427438 fingerprint = ", " .join ([f"v{ v } @{ p } " for v , p in samples ])
428439 return fingerprint
429-
440+
430441 finally :
431442 # Restore RNG state
432443 random .setstate (state )
@@ -488,25 +499,46 @@ def __getitem__(self, index: int) -> Dict[str, Any]:
488499 label = self .cached_labels [vol_idx ]
489500 mask = self .cached_masks [vol_idx ]
490501
491- # [D2] Foreground-aware patch sampling: ensure patches contain sufficient mitochondria
492- # This prevents SDT collapse by avoiding background-only patches
493- max_attempts = 10
494- foreground_threshold = 0.05 # Require at least 5% foreground (SDT > 0)
495-
496- # [D2 DIAGNOSTIC] Track sampling attempts
497- attempts_used = 0
502+ # [D2] Foreground-aware patch sampling: optional retry loop for training.
503+ # Disabled only when foreground_threshold <= 0 (the constructor default of 0.05 keeps it enabled).
504+ max_attempts = self .max_attempts
505+ foreground_threshold = self .foreground_threshold
506+ use_foreground_sampling = (
507+ self .mode == "train" and label is not None and foreground_threshold > 0
508+ )
509+
510+ # [D2 DIAGNOSTIC] Track sampling attempts only when foreground sampling is active.
511+ attempts_used = 1
498512 final_foreground_frac = 0.0
499-
500- for attempt in range (max_attempts ):
501- attempts_used = attempt + 1
502-
503- # Get crop position
513+
514+ if use_foreground_sampling :
515+ for attempt in range (max_attempts ):
516+ attempts_used = attempt + 1
517+ pos = self ._get_random_crop_position (vol_idx )
518+
519+ # Crop using fast numpy slicing (like v1)
520+ image_crop = crop_volume (image , self .patch_size , pos )
521+ label_crop = crop_volume (label , self .patch_size , pos )
522+ if mask is not None :
523+ mask_crop = crop_volume (mask , self .patch_size , pos )
524+ else :
525+ mask_crop = np .zeros_like (image_crop )
526+
527+ foreground_frac = (label_crop > 0 ).sum () / label_crop .size
528+ final_foreground_frac = foreground_frac
529+
530+ if foreground_frac >= foreground_threshold :
531+ break
532+
533+ # [D2 DIAGNOSTIC] Patch rejected, increment counter
534+ self ._d2_rejected_patches += 1
535+ else :
536+ # Standard single-crop behavior (no foreground-based retry)
504537 if self .mode == "train" :
505538 pos = self ._get_random_crop_position (vol_idx )
506539 else :
507540 pos = self ._get_center_crop_position (vol_idx )
508541
509- # Crop using fast numpy slicing (like v1)
510542 image_crop = crop_volume (image , self .patch_size , pos )
511543 if label is not None :
512544 label_crop = crop_volume (label , self .patch_size , pos )
@@ -518,39 +550,30 @@ def __getitem__(self, index: int) -> Dict[str, Any]:
518550 else :
519551 mask_crop = np .zeros_like (image_crop )
520552
521- # [D2] Check if patch has sufficient foreground (only during training)
522- if self .mode == "train" and label is not None :
523- foreground_frac = (label_crop > 0 ).sum () / label_crop .size
524- final_foreground_frac = foreground_frac
525-
526- if foreground_frac >= foreground_threshold :
527- # [D2 DIAGNOSTIC] Good patch found
528- break
529- else :
530- # [D2 DIAGNOSTIC] Patch rejected, increment counter
531- self ._d2_rejected_patches += 1
532- else :
533- # Val/test mode or no label: accept any patch
534- break
535-
536- # [D2 DIAGNOSTIC] Record statistics
537- self ._d2_total_samples += 1
538- self ._d2_total_attempts += attempts_used
539- self ._d2_foreground_fracs .append (final_foreground_frac * 100 ) # Convert to percentage
540-
553+ # [D2 DIAGNOSTIC] Record/report sampling stats only when enabled.
554+ if use_foreground_sampling :
555+ self ._d2_total_samples += 1
556+ self ._d2_total_attempts += attempts_used
557+ self ._d2_foreground_fracs .append (final_foreground_frac * 100 ) # percentage
558+
541559 # [D2 DIAGNOSTIC] Print report every 100 samples (not too verbose)
542- if self . mode == "train" and self ._d2_total_samples % 100 == 0 :
560+ if use_foreground_sampling and self ._d2_total_samples % 100 == 0 :
543561 avg_attempts = self ._d2_total_attempts / self ._d2_total_samples
544562 reject_rate = (self ._d2_rejected_patches / self ._d2_total_attempts ) * 100
545563 avg_fg = sum (self ._d2_foreground_fracs ) / len (self ._d2_foreground_fracs )
546564 min_fg = min (self ._d2_foreground_fracs )
547565 max_fg = max (self ._d2_foreground_fracs )
548-
566+
549567 print (f"[D2 Sampling Stats after { self ._d2_total_samples } batches]" )
550568 print (f" Avg attempts per patch: { avg_attempts :.2f} /{ max_attempts } " )
551- print (f" Patches rejected: { self ._d2_rejected_patches } /{ self ._d2_total_attempts } ({ reject_rate :.1f} %)" )
569+ print (
570+ f" Patches rejected: { self ._d2_rejected_patches } /{ self ._d2_total_attempts } ({ reject_rate :.1f} %)"
571+ )
552572 print (f" Final foreground %: avg={ avg_fg :.1f} %, min={ min_fg :.1f} %, max={ max_fg :.1f} %" )
553- print (f" Threshold: { foreground_threshold * 100 :.1f} % (5% minimum)" )
573+ print (
574+ f" Threshold: { foreground_threshold * 100 :.1f} % "
575+ f"({ self .foreground_threshold * 100 :.1f} % minimum)"
576+ )
554577
555578 # Create data dict
556579 data = {
0 commit comments