Skip to content

Commit deaf74f

Browse files
author
Donglai Wei
committed
Add faithful BANIS base reproduction
1 parent 6b5d55c commit deaf74f

10 files changed

Lines changed: 537 additions & 43 deletions

File tree

connectomics/config/pipeline/config_io.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,15 @@ def validate_config(cfg: Config) -> None:
370370
"data.dataloader.patch_size must be 2D or 3D "
371371
f"(got length {len(cfg.data.dataloader.patch_size)})"
372372
)
373+
target_context = getattr(cfg.data.dataloader, "target_context", None) or []
374+
if target_context:
375+
if len(target_context) != len(cfg.data.dataloader.patch_size):
376+
raise ValueError(
377+
"data.dataloader.target_context must match patch_size dimensionality "
378+
f"({len(target_context)} vs {len(cfg.data.dataloader.patch_size)})"
379+
)
380+
if any(int(v) < 0 for v in target_context):
381+
raise ValueError("data.dataloader.target_context values must be non-negative")
373382
if cfg.data.dataloader.batch_size <= 0:
374383
raise ValueError("data.dataloader.batch_size must be positive")
375384

@@ -417,6 +426,9 @@ def validate_config(cfg: Config) -> None:
417426

418427
if cfg.optimization.gradient_clip_val < 0:
419428
raise ValueError("optimization.gradient_clip_val must be non-negative")
429+
val_check_unit = str(getattr(cfg.optimization, "val_check_interval_unit", "epoch")).lower()
430+
if val_check_unit not in {"epoch", "step"}:
431+
raise ValueError("optimization.val_check_interval_unit must be 'epoch' or 'step'")
420432
if cfg.optimization.accumulate_grad_batches <= 0:
421433
raise ValueError("optimization.accumulate_grad_batches must be positive")
422434
if hasattr(cfg.optimization, "ema") and getattr(cfg.optimization.ema, "enabled", False):

connectomics/config/schema/data.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ class LabelTransformConfig:
7676
output_dtype: Optional[str] = "float32"
7777
output_key_format: str = "{key}_{task}"
7878
allow_missing_keys: bool = False
79+
relabel_connected_components: bool = False # Relabel disconnected same-ID crop components.
80+
relabel_connectivity: int = 6 # Connectivity for relabel_connected_components in 3D.
7981
segment_id: Optional[List[int]] = None
8082
boundary_thickness: int = 1
8183
resolution: Optional[List[float]] = None # Forwarded into compatible label targets.
@@ -175,6 +177,9 @@ class DataloaderConfig:
175177

176178
batch_size: int = 4
177179
patch_size: List[int] = field(default_factory=lambda: [128, 128, 128])
180+
target_context: List[int] = field(
181+
default_factory=lambda: [0, 0, 0]
182+
) # Extra positive-side crop context used for target generation, then cropped away.
178183
pin_memory: bool = True
179184
use_preloaded_cache_train: bool = True # Preload training volumes into memory
180185
use_preloaded_cache_val: bool = True # Preload validation volumes into memory
@@ -196,6 +201,7 @@ class DataloaderConfig:
196201
False # Voxel approach: center crops on random nonzero mask voxels (stronger guarantee)
197202
)
198203
reject_sampling: Optional[Dict[str, Any]] = None # Dict with 'size_thres' and 'p' keys
204+
val_random_sampling: bool = False # If true, validation samples random patches, not center crops.
199205

200206

201207
@dataclass

connectomics/config/schema/monitor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class CheckpointConfig:
1818
save_every_n_epochs: int = 1
1919
save_every_n_steps: Optional[int] = None
2020
step_checkpoint_filename: str = "step-{step:08d}"
21+
save_on_train_epoch_end: bool = True
2122
use_timestamp: bool = True
2223

2324

connectomics/config/schema/optimization.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ class OptimizationConfig:
9494
precision: str = "16-mixed" # "32", "16-mixed", "bf16-mixed"
9595

9696
# Validation scheduling
97-
val_check_interval: Union[int, float] = 1.0 # Validate every N epochs
97+
val_check_interval: Union[int, float] = 1.0 # Validate every N epochs or steps.
98+
val_check_interval_unit: str = "epoch" # "epoch" or "step"
9899

99100
# Logging
100101
log_every_n_steps: int = 100

connectomics/data/augmentation/build.py

Lines changed: 98 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,48 @@ def _strict_binarize_mask(mask, threshold: float = 0.0):
6464
return (mask > threshold).astype(mask.dtype, copy=False)
6565

6666

67+
def _target_context(cfg: Config) -> tuple[int, ...]:
68+
context = getattr(cfg.data.dataloader, "target_context", None) or []
69+
if not context:
70+
return tuple(0 for _ in cfg.data.dataloader.patch_size)
71+
return tuple(int(v) for v in context)
72+
73+
74+
def _effective_patch_size(cfg: Config) -> tuple[int, ...] | None:
    """Patch size grown by the target context on each axis.

    Returns None when no patch size is configured. Raises ValueError when the
    configured context and patch size disagree in dimensionality.
    """
    base = cfg.data.dataloader.patch_size
    if not base:
        return None

    patch_size = tuple(base)
    context = _target_context(cfg)
    if len(context) != len(patch_size):
        raise ValueError(
            "data.dataloader.target_context must have the same length as patch_size: "
            f"{context} vs {patch_size}"
        )
    # Grow each axis by its context so targets can be computed with extra margin.
    return tuple(int(size) + int(extra) for size, extra in zip(patch_size, context))
85+
86+
87+
def _append_banis_pre_target_transforms(transforms: list, label_cfg) -> None:
88+
if bool(getattr(label_cfg, "relabel_connected_components", False)):
89+
from ..processing.transforms import RelabelConnectedComponentsd
90+
91+
transforms.append(
92+
RelabelConnectedComponentsd(
93+
keys=["label"],
94+
connectivity=int(getattr(label_cfg, "relabel_connectivity", 6)),
95+
)
96+
)
97+
98+
99+
def _append_target_context_crop(transforms: list, cfg: Config) -> None:
    """Append a crop back to ``patch_size`` after target-context padding was used.

    No-op when no positive context is configured, so the pipeline is unchanged
    for configs that never set target_context.
    """
    context = _target_context(cfg)
    needs_crop = bool(context) and any(extra > 0 for extra in context)
    if not needs_crop:
        return

    # Imported lazily so the transform module is only loaded when needed.
    from ..processing.transforms import LeadingSpatialCropd

    roi = tuple(cfg.data.dataloader.patch_size)
    transforms.append(LeadingSpatialCropd(roi_size=roi))
107+
108+
67109
def _build_nnunet_preprocess_transform(keys, nnunet_pre_cfg, source_spacing):
68110
"""Build NNUNetPreprocessd transform from config."""
69111
source_spacing = getattr(nnunet_pre_cfg, "source_spacing", None) or source_spacing
@@ -176,9 +218,7 @@ def build_train_transforms(
176218

177219
# Ensure target patch size is respected (unless using pre-cached dataset)
178220
if not skip_loading:
179-
patch_size = (
180-
tuple(cfg.data.dataloader.patch_size) if cfg.data.dataloader.patch_size else None
181-
)
221+
patch_size = _effective_patch_size(cfg)
182222
if patch_size and all(size > 0 for size in patch_size):
183223
# Pad smaller volumes so random crops always succeed
184224
transforms.append(
@@ -208,6 +248,10 @@ def build_train_transforms(
208248
)
209249
)
210250

251+
label_cfg = getattr(cfg.data, "label_transform", None)
252+
if "label" in keys and label_cfg is not None:
253+
_append_banis_pre_target_transforms(transforms, label_cfg)
254+
211255
# Add augmentations if enabled
212256
if cfg.data.augmentation is not None:
213257
# Pass do_2d flag to augmentation builder
@@ -218,12 +262,10 @@ def build_train_transforms(
218262
transforms.extend(_build_augmentations(cfg.data.augmentation, keys, do_2d=do_2d))
219263

220264
# Label transformations (affinity, distance transform, etc.)
221-
if hasattr(cfg.data, "label_transform"):
265+
if label_cfg is not None:
222266
from ..processing.build import create_label_transform_pipeline
223267
from ..processing.transforms import SegErosionInstanced
224268

225-
label_cfg = cfg.data.label_transform
226-
227269
# Apply instance erosion first if specified
228270
if hasattr(label_cfg, "erosion") and label_cfg.erosion > 0:
229271
transforms.append(SegErosionInstanced(keys=["label"], tsz_h=label_cfg.erosion))
@@ -235,6 +277,8 @@ def build_train_transforms(
235277
else:
236278
transforms.append(label_transform)
237279

280+
_append_target_context_crop(transforms, cfg)
281+
238282
# NOTE: Do NOT squeeze labels here!
239283
# - DiceLoss needs (B, 1, H, W) with to_onehot_y=True
240284
# - CrossEntropyLoss needs (B, H, W)
@@ -472,7 +516,13 @@ def _resolve_eval_split():
472516
)
473517
)
474518

475-
patch_size = tuple(data_cfg.dataloader.patch_size) if data_cfg.dataloader.patch_size else None
519+
patch_size = (
520+
_effective_patch_size(cfg)
521+
if mode == "val"
522+
else tuple(data_cfg.dataloader.patch_size)
523+
if data_cfg.dataloader.patch_size
524+
else None
525+
)
476526
if patch_size and all(size > 0 for size in patch_size):
477527
transforms.append(
478528
SpatialPadd(
@@ -510,12 +560,22 @@ def _resolve_eval_split():
510560
# Test: Skip cropping to enable sliding window inference on full volumes
511561
if mode == "val":
512562
if patch_size and all(size > 0 for size in patch_size):
513-
transforms.append(
514-
CenterSpatialCropd(
515-
keys=keys,
516-
roi_size=patch_size,
563+
if bool(getattr(data_cfg.dataloader, "val_random_sampling", False)):
564+
transforms.append(
565+
RandSpatialCropd(
566+
keys=keys,
567+
roi_size=patch_size,
568+
random_center=True,
569+
random_size=False,
570+
)
571+
)
572+
else:
573+
transforms.append(
574+
CenterSpatialCropd(
575+
keys=keys,
576+
roi_size=patch_size,
577+
)
517578
)
518-
)
519579
# else: mode == "test" -> no cropping for sliding window inference
520580

521581
# Normalization - use smart normalization
@@ -530,6 +590,10 @@ def _resolve_eval_split():
530590
)
531591
)
532592

593+
label_cfg = getattr(data_cfg, "label_transform", None)
594+
if mode == "val" and "label" in keys and label_cfg is not None:
595+
_append_banis_pre_target_transforms(transforms, label_cfg)
596+
533597
# Only process labels if 'label' is in keys
534598
if "label" in keys:
535599
# Label transformations (affinity, distance transform, etc.)
@@ -558,6 +622,9 @@ def _resolve_eval_split():
558622
else:
559623
transforms.append(label_transform)
560624

625+
if mode == "val":
626+
_append_target_context_crop(transforms, cfg)
627+
561628
# NOTE: Do NOT squeeze labels here!
562629
# - DiceLoss needs (B, 1, H, W) with to_onehot_y=True
563630
# - CrossEntropyLoss needs (B, H, W)
@@ -733,16 +800,6 @@ def _build_augmentations(aug_cfg: AugmentationConfig, keys: list[str], do_2d: bo
733800

734801
# Intensity augmentations (only for images)
735802
if aug_cfg.intensity.enabled:
736-
if aug_cfg.intensity.gaussian_noise_prob > 0:
737-
transforms.append(
738-
RandGaussianNoised(
739-
keys=["image"],
740-
prob=aug_cfg.intensity.gaussian_noise_prob,
741-
std=aug_cfg.intensity.gaussian_noise_std,
742-
sample_std=True,
743-
)
744-
)
745-
746803
if getattr(aug_cfg.intensity, "banis_style", False):
747804
transforms.append(
748805
RandMulAddIntensityd(
@@ -752,7 +809,26 @@ def _build_augmentations(aug_cfg: AugmentationConfig, keys: list[str], do_2d: bo
752809
add_range=aug_cfg.intensity.add_range,
753810
)
754811
)
812+
if aug_cfg.intensity.gaussian_noise_prob > 0:
813+
transforms.append(
814+
RandGaussianNoised(
815+
keys=["image"],
816+
prob=aug_cfg.intensity.gaussian_noise_prob,
817+
std=aug_cfg.intensity.gaussian_noise_std,
818+
sample_std=True,
819+
)
820+
)
755821
else:
822+
if aug_cfg.intensity.gaussian_noise_prob > 0:
823+
transforms.append(
824+
RandGaussianNoised(
825+
keys=["image"],
826+
prob=aug_cfg.intensity.gaussian_noise_prob,
827+
std=aug_cfg.intensity.gaussian_noise_std,
828+
sample_std=True,
829+
)
830+
)
831+
756832
if aug_cfg.intensity.shift_intensity_prob > 0:
757833
transforms.append(
758834
RandShiftIntensityd(

connectomics/data/processing/transforms.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,95 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
470470
return d
471471

472472

473+
class RelabelConnectedComponentsd(MapTransform):
    """Relabel disconnected components inside each crop while preserving unlabeled voxels.

    Runs cc3d connected-components labeling on each 3D label crop so that
    spatially disconnected pieces sharing an instance ID get distinct IDs.
    Voxels equal to -1 are masked before labeling and restored to -1 after.
    Torch tensors are moved to CPU NumPy first; the output stored back into
    the dict is always an int32 NumPy array (a leading singleton channel dim,
    if present on input, is restored).

    Args:
        keys: Dictionary keys whose values are relabeled.
        connectivity: Connectivity passed to ``cc3d.connected_components``
            (default 6; presumably 6/18/26 in 3D — confirm against cc3d docs).
        allow_missing_keys: Forwarded to ``MapTransform``.

    Raises:
        ModuleNotFoundError: If cc3d is not installed.
        ValueError: If a crop is not 3D (optionally with a singleton channel).
    """

    def __init__(
        self,
        keys: KeysCollection,
        connectivity: int = 6,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        # Normalize eagerly so cc3d always receives a plain int.
        self.connectivity = int(connectivity)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Lazy import: cc3d is only required when this transform actually runs.
        try:
            import cc3d
        except ModuleNotFoundError as exc:  # pragma: no cover
            raise ModuleNotFoundError(
                "relabel_connected_components requires cc3d. Install connected-components-3d."
            ) from exc

        d = dict(data)  # shallow copy; never mutate the caller's dict
        for key in self.key_iterator(d):
            if key not in d:
                # Defensive skip for keys the iterator yields but the dict lacks.
                continue

            label = d[key]
            # Torch tensors -> CPU NumPy; anything else is coerced via asarray.
            label_np = label.detach().cpu().numpy() if isinstance(label, torch.Tensor) else label
            label_np = np.asarray(label_np)
            # Accept (1, D, H, W) by stripping the channel dim and restoring it below.
            restore_channel_dim = label_np.ndim == 4 and label_np.shape[0] == 1
            seg = label_np[0] if restore_channel_dim else label_np
            if seg.ndim != 3:
                raise ValueError(
                    "RelabelConnectedComponentsd expects a 3D label crop "
                    f"with optional singleton channel dim, got {tuple(label_np.shape)}"
                )

            # Remember "invalid"/unlabeled voxels (-1) so relabeling cannot clobber them.
            invalid = seg == -1
            relabeled = cc3d.connected_components(
                seg,
                connectivity=self.connectivity,
                out_dtype=np.uint32,
            ).astype(np.int32)  # signed int32 so the -1 sentinel can be written back
            relabeled[invalid] = -1
            # NOTE(review): output is NumPy even when the input was a torch tensor —
            # downstream transforms presumably accept ndarrays; confirm.
            d[key] = relabeled[None, ...] if restore_channel_dim else relabeled
        return d
518+
519+
520+
class LeadingSpatialCropd:
    """Crop arrays/tensors to ``roi_size`` from the low-index spatial corner.

    The last ``len(roi_size)`` axes of each selected value are truncated to
    ``roi_size`` starting at index 0; leading (e.g. channel) axes are kept.
    Values without an ``ndim`` attribute, or with fewer dims than
    ``roi_size``, pass through untouched. With ``keys=None`` every dict entry
    is considered.
    """

    def __init__(
        self,
        roi_size: Sequence[int],
        keys: Optional[Sequence[str]] = None,
        allow_missing_keys: bool = True,
    ) -> None:
        self.roi_size = tuple(map(int, roi_size))
        self.keys = None if keys is None else tuple(keys)
        self.allow_missing_keys = allow_missing_keys

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        out = dict(data)
        n_spatial = len(self.roi_size)
        targets = tuple(out.keys()) if self.keys is None else self.keys

        for name in targets:
            if name not in out:
                if self.allow_missing_keys:
                    continue
                raise KeyError(name)

            entry = out[name]
            rank = getattr(entry, "ndim", None)
            # Skip plain scalars/strings and anything with too few dims to crop.
            if rank is None or int(rank) < n_spatial:
                continue
            rank = int(rank)

            tail = tuple(int(s) for s in entry.shape[-n_spatial:])
            if tail == self.roi_size:
                continue  # already the requested size; keep the original object

            # Full slices on leading axes, [0:size) on each trailing spatial axis.
            window = [slice(None)] * rank
            offset = rank - n_spatial
            for i, extent in enumerate(self.roi_size):
                window[offset + i] = slice(0, extent)
            out[name] = entry[tuple(window)]

        return out
560+
561+
473562
class EnergyQuantized(MapTransform):
474563
"""Quantize continuous energy maps using MONAI MapTransform.
475564
@@ -881,6 +970,8 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
881970
"SegErosiond",
882971
"SegDilationd",
883972
"SegErosionInstanced",
973+
"RelabelConnectedComponentsd",
974+
"LeadingSpatialCropd",
884975
"EnergyQuantized",
885976
"DecodeQuantized",
886977
"SegSelectiond",

0 commit comments

Comments
 (0)