
Commit 70df515

Merge pull request #182 from QiongWang1/master
Refactor decoding with minimal changes and reuse of existing pipeline
2 parents 915af4c + a2f805e · commit 70df515

29 files changed: 3,061 additions and 105 deletions

connectomics/config/hydra_config.py

Lines changed: 11 additions & 0 deletions
@@ -440,6 +440,7 @@ class DataConfig:

     # Sampling (for volumetric datasets)
     iter_num_per_epoch: Optional[int] = None # Alias for iter_num (if set, overrides iter_num)
+    val_iter_num: Optional[int] = None # Validation iterations per epoch (auto-calculated if None)
     use_preloaded_cache: bool = (
         True # Preload volumes into memory for fast random cropping (default: True)
     )
@@ -480,6 +481,10 @@ class SchedulerConfig:
     warmup_start_lr: float = 0.0001
     min_lr: float = 0.00001

+    # Scheduler interval control
+    interval: str = "epoch" # "epoch" or "step" - controls when scheduler steps
+    frequency: int = 1 # How often to step the scheduler
+
     # CosineAnnealing-specific
     t_max: Optional[int] = None

@@ -941,6 +946,9 @@ class SavePredictionConfig:
         enabled: Enable saving intermediate predictions (default: True)
         intensity_scale: Scale factor for predictions (e.g., 255 for uint8 visualization)
         intensity_dtype: Data type for saved predictions (e.g., 'uint8', 'float32')
+        output_formats: List of output formats to save predictions in (e.g., ['h5', 'tiff', 'nii.gz'])
+            Supported formats: 'h5', 'tiff', 'nii', 'nii.gz', 'png'
+            Default: ['h5', 'nii.gz'] for backward compatibility
     """

     enabled: bool = True # Enable saving intermediate predictions
@@ -951,6 +959,9 @@ class SavePredictionConfig:
     intensity_dtype: str = (
         "uint8" # Save as uint8 for visualization (ignored if intensity_scale < 0)
     )
+    output_formats: List[str] = field(
+        default_factory=lambda: ["h5", "nii.gz"] # Default: HDF5 + NIfTI for backward compatibility
+    )


 @dataclass
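The new "interval" and "frequency" fields match the keys of the same names in PyTorch Lightning's lr_scheduler dictionary. A minimal sketch of how a LightningModule might forward them, assuming the scheduler config is reachable at cfg.optimization.scheduler (that exact path is not shown in this commit):

# Sketch only: forwarding SchedulerConfig.interval/frequency to Lightning.
# Assumes cfg.optimization.scheduler holds the SchedulerConfig instance.
import torch
import pytorch_lightning as pl


class LitSegModel(pl.LightningModule):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.net = torch.nn.Conv3d(1, 1, 3, padding=1)

    def configure_optimizers(self):
        sched_cfg = self.cfg.optimization.scheduler  # assumed config path
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=sched_cfg.t_max or 100
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                # "epoch": step once per epoch; "step": step per batch
                "interval": sched_cfg.interval,
                "frequency": sched_cfg.frequency,  # step every N intervals
            },
        }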

connectomics/config/hydra_utils.py

Lines changed: 8 additions & 2 deletions
@@ -183,8 +183,14 @@ def validate_config(cfg: Config) -> None:
         raise ValueError("optimization.optimizer.weight_decay must be non-negative")

     # Training validation
-    if cfg.optimization.max_epochs <= 0:
-        raise ValueError("optimization.max_epochs must be positive")
+    # [FIX 2] Allow max_epochs to be 0 or negative when using step-based training
+    max_steps_cfg = getattr(cfg.optimization, "max_steps", None)
+    if max_steps_cfg is None or max_steps_cfg <= 0:
+        # Epoch-based training: max_epochs must be positive
+        if cfg.optimization.max_epochs <= 0:
+            raise ValueError("optimization.max_epochs must be positive when max_steps is not set")
+    # If max_steps is set, max_epochs can be anything (will be overridden to -1 in trainer)
+
     if cfg.optimization.gradient_clip_val < 0:
         raise ValueError("optimization.gradient_clip_val must be non-negative")
     if cfg.optimization.accumulate_grad_batches <= 0:
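Since the new rule is pure attribute logic, it is easy to exercise in isolation. A runnable sketch using a stand-in namespace instead of the repo's Config (the check_epochs helper is illustrative, not part of the commit):

# Stand-in illustration of the new validation rule.
from types import SimpleNamespace


def check_epochs(max_epochs, max_steps=None):
    opt = SimpleNamespace(max_epochs=max_epochs, max_steps=max_steps)
    max_steps_cfg = getattr(opt, "max_steps", None)
    if max_steps_cfg is None or max_steps_cfg <= 0:
        # Epoch-based training: max_epochs must be positive
        if opt.max_epochs <= 0:
            raise ValueError("optimization.max_epochs must be positive when max_steps is not set")


check_epochs(max_epochs=100)                  # epoch-based: passes
check_epochs(max_epochs=-1, max_steps=10000)  # step-based: passes, max_epochs ignored
try:
    check_epochs(max_epochs=0)                # epoch-based with invalid value
except ValueError as e:
    print(e)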

connectomics/data/augment/build.py

Lines changed: 32 additions & 28 deletions
@@ -196,7 +196,7 @@ def build_train_transforms(
     return Compose(transforms)


-def _build_eval_transforms_impl(cfg: Config, mode: str = "val", keys: list[str] = None) -> Compose:
+def _build_eval_transforms_impl(cfg: Config, mode: str = "val", keys: list[str] = None, skip_loading: bool = False) -> Compose:
     """
     Internal implementation for building evaluation transforms (validation or test).

@@ -207,6 +207,7 @@ def _build_eval_transforms_impl(cfg: Config, mode: str = "val", keys: list[str]
         cfg: Hydra Config object
         mode: 'val' or 'test' mode
         keys: Keys to transform (default: auto-detected based on mode)
+        skip_loading: Skip LoadVolumed (for pre-cached datasets)

     Returns:
         Composed MONAI transforms (no augmentation)
@@ -259,32 +260,34 @@ def _build_eval_transforms_impl(cfg: Config, mode: str = "val", keys: list[str]
     transforms = []

     # Load images first - use appropriate loader based on dataset type
-    dataset_type = getattr(cfg.data, "dataset_type", "volume")
-
-    if dataset_type == "filename":
-        # For filename-based datasets (PNG, JPG, etc.), use MONAI's LoadImaged
-        transforms.append(LoadImaged(keys=keys, image_only=False))
-        # Ensure channel-first format [C, H, W] or [C, D, H, W]
-        transforms.append(EnsureChannelFirstd(keys=keys))
-    else:
-        # For volume-based datasets (HDF5, TIFF volumes), use custom LoadVolumed
-        # Get transpose axes based on mode
-        if mode == "val":
-            transpose_axes = cfg.data.val_transpose if cfg.data.val_transpose else []
-        else: # mode == "test"
-            # Use test.data.test_transpose
-            transpose_axes = []
-            if (
-                hasattr(cfg, "test")
-                and hasattr(cfg.test, "data")
-                and hasattr(cfg.test.data, "test_transpose")
-                and cfg.test.data.test_transpose
-            ):
-                transpose_axes = cfg.test.data.test_transpose
+    # Skip loading if using pre-cached datasets
+    if not skip_loading:
+        dataset_type = getattr(cfg.data, "dataset_type", "volume")

-    transforms.append(
-        LoadVolumed(keys=keys, transpose_axes=transpose_axes if transpose_axes else None)
-    )
+        if dataset_type == "filename":
+            # For filename-based datasets (PNG, JPG, etc.), use MONAI's LoadImaged
+            transforms.append(LoadImaged(keys=keys, image_only=False))
+            # Ensure channel-first format [C, H, W] or [C, D, H, W]
+            transforms.append(EnsureChannelFirstd(keys=keys))
+        else:
+            # For volume-based datasets (HDF5, TIFF volumes), use custom LoadVolumed
+            # Get transpose axes based on mode
+            if mode == "val":
+                transpose_axes = cfg.data.val_transpose if cfg.data.val_transpose else []
+            else: # mode == "test"
+                # Use test.data.test_transpose
+                transpose_axes = []
+                if (
+                    hasattr(cfg, "test")
+                    and hasattr(cfg.test, "data")
+                    and hasattr(cfg.test.data, "test_transpose")
+                    and cfg.test.data.test_transpose
+                ):
+                    transpose_axes = cfg.test.data.test_transpose
+
+        transforms.append(
+            LoadVolumed(keys=keys, transpose_axes=transpose_axes if transpose_axes else None)
+        )

     # Apply volumetric split if enabled
     if cfg.data.split_enabled:
@@ -441,18 +444,19 @@ def _build_eval_transforms_impl(cfg: Config, mode: str = "val", keys: list[str]
     return Compose(transforms)


-def build_val_transforms(cfg: Config, keys: list[str] = None) -> Compose:
+def build_val_transforms(cfg: Config, keys: list[str] = None, skip_loading: bool = False) -> Compose:
     """
     Build validation transforms from Hydra config.

     Args:
         cfg: Hydra Config object
         keys: Keys to transform (default: auto-detected as ['image', 'label'])
+        skip_loading: Skip LoadVolumed (for pre-cached datasets)

     Returns:
         Composed MONAI transforms (no augmentation, center cropping)
     """
-    return _build_eval_transforms_impl(cfg, mode="val", keys=keys)
+    return _build_eval_transforms_impl(cfg, mode="val", keys=keys, skip_loading=skip_loading)


 def build_test_transforms(cfg: Config, keys: list[str] = None) -> Compose:
connectomics/data/augment/monai_transforms.py

Lines changed: 27 additions & 0 deletions
@@ -24,6 +24,7 @@ class RandMisAlignmentd(RandomizableTransform, MapTransform):
     Simulates section misalignment artifacts common in EM volumes.
     """

+
     def __init__(
         self,
         keys: KeysCollection,
@@ -1141,10 +1142,24 @@ def _normalize(
         self, volume: Union[np.ndarray, torch.Tensor]
     ) -> Union[np.ndarray, torch.Tensor]:
         """Apply normalization to volume."""
+        from ...utils.debug_utils import print_tensor_stats
+
         is_numpy = isinstance(volume, np.ndarray)
         if not is_numpy:
             volume = volume.numpy()

+        # DEBUG: Print raw input before normalization
+        print_tensor_stats(
+            volume,
+            stage_name="STAGE 1: RAW IMAGE (before normalization)",
+            tensor_name="image",
+            print_once=True,
+            extra_info={
+                "normalization_mode": self.mode,
+                "clip_percentiles": f"[{self.clip_percentile_low}, {self.clip_percentile_high}]"
+            }
+        )
+
         # Step 1: Percentile clipping (if enabled by non-default values)
         if self.clip_percentile_low > 0.0 or self.clip_percentile_high < 1.0:
             low_val = np.percentile(volume, self.clip_percentile_low * 100)
@@ -1171,6 +1186,18 @@ def _normalize(
         # Simple divide by K (e.g., divide-255 for uint8 images)
         volume = volume / self.divide_value

+        # DEBUG: Print after normalization
+        print_tensor_stats(
+            volume,
+            stage_name="STAGE 2: AFTER IMAGE NORMALIZATION",
+            tensor_name="image",
+            print_once=True,
+            extra_info={
+                "normalization_applied": self.mode,
+                "expected_range": "[0, 1]" if self.mode == "0-1" else "varies"
+            }
+        )
+
         return volume if is_numpy else torch.from_numpy(volume)

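The debug hooks above call print_tensor_stats from connectomics.utils.debug_utils, whose implementation is not part of this section. A minimal stand-in compatible with the call signature used above, purely as an assumption-based sketch:

# Sketch of a print-once stats helper matching the calls in _normalize.
# The real implementation lives in connectomics/utils/debug_utils.py.
import numpy as np

_printed_stages = set()


def print_tensor_stats(arr, stage_name, tensor_name="tensor",
                       print_once=False, extra_info=None):
    """Print shape/dtype/range stats for an array, optionally once per stage."""
    if print_once and stage_name in _printed_stages:
        return
    _printed_stages.add(stage_name)
    arr = np.asarray(arr)
    print(f"=== {stage_name} ===")
    print(f"{tensor_name}: shape={arr.shape}, dtype={arr.dtype}, "
          f"min={arr.min():.4f}, max={arr.max():.4f}, mean={arr.mean():.4f}")
    for k, v in (extra_info or {}).items():
        print(f"  {k}: {v}")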
connectomics/data/dataset/dataset_base.py

Lines changed: 85 additions & 4 deletions
@@ -212,7 +212,24 @@ def __init__(
             self.dataset_length = self.iter_num
         else:
             self.dataset_length = len(data_dicts)
+
+        # [FIX] Add validation reseeding support
+        self.base_seed = 0
+        self.current_epoch = 0

+    def __getitem__(self, index: int) -> Dict[str, Any]:
+        """
+        Get a data sample with caching.
+
+        When iter_num > len(data), we need to map the requested index
+        to an actual data index by using modulo operation.
+        """
+        # Map the requested index to actual data index
+        actual_index = index % len(self.data)
+
+        # Call parent's __getitem__ with the mapped index
+        return super().__getitem__(actual_index)
+
     def __len__(self) -> int:
         """
         Return dataset length.
@@ -226,13 +243,77 @@ def __len__(self) -> int:
             # Partial caching: return cached length for validation
            # For training with iter_num, we still want to iterate iter_num times
            if self.mode == 'train' and self.iter_num > 0:
-                return self.dataset_length
+                result = self.dataset_length
            else:
                # For validation/test, only iterate over cached items
-                return len(self._cache)
+                result = len(self._cache)
+        else:
+            # Full caching or no caching: use dataset_length
+            result = self.dataset_length

-        # Full caching or no caching: use dataset_length
-        return self.dataset_length
+        return result
+
+    def set_epoch(self, epoch: int, base_seed: int = 0):
+        """
+        Set current epoch for epoch-based validation reseeding.
+
+        This method enables validation to sample different patches each epoch
+        while maintaining determinism. For training, this has no effect since
+        training already uses random sampling.
+
+        Args:
+            epoch: Current training epoch
+            base_seed: Base random seed (typically from cfg.system.seed)
+
+        Usage:
+            Called by ValidationReseedingCallback at the start of each validation epoch.
+        """
+        if self.mode == "val":
+            import random
+            self.base_seed = base_seed
+            self.current_epoch = epoch
+            effective_seed = self.base_seed + epoch
+            random.seed(effective_seed)
+
+            # IMPORTANT: Print to verify reseeding is happening
+            print(f"[Validation] Set epoch={epoch}, base_seed={base_seed}, effective_seed={effective_seed}")
+            print(f"[Validation] Dataset: {type(self).__name__}@{id(self)}, mode={self.mode}, iter_num={self.iter_num}")

+    def get_sampling_fingerprint(self, num_samples: int = 5) -> str:
+        """
+        Generate a deterministic fingerprint of validation sampling.
+
+        This allows verification that validation patches change across epochs.
+        For MonaiCachedConnectomicsDataset, we sample indices that would be used.
+
+        Args:
+            num_samples: Number of random samples to include in fingerprint
+
+        Returns:
+            String representing the sampling fingerprint
+        """
+        if self.mode != "val":
+            return "N/A (training mode)"
+
+        import random
+        # Save current RNG state
+        state = random.getstate()
+
+        try:
+            # Generate deterministic samples
+            samples = []
+            for _ in range(num_samples):
+                # Sample index (same logic as __getitem__)
+                idx = random.randint(0, len(self.data) - 1)
+                samples.append(idx)
+
+            # Create fingerprint string
+            fingerprint = ", ".join([f"idx{i}" for i in samples])
+            return fingerprint
+
+        finally:
+            # Restore RNG state
+            random.setstate(state)


 class MonaiPersistentConnectomicsDataset(PersistentDataset):
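set_epoch's docstring references a ValidationReseedingCallback, presumably added elsewhere in this 29-file commit. A minimal sketch of what such a Lightning callback could look like, assuming the validation dataloaders expose their dataset directly:

# Assumption-based sketch of the callback named in set_epoch's docstring.
import pytorch_lightning as pl


class ValidationReseedingCallback(pl.Callback):
    def __init__(self, base_seed: int = 0):
        self.base_seed = base_seed

    def on_validation_epoch_start(self, trainer, pl_module):
        # Reseed every validation dataset that supports epoch-based
        # reseeding so each epoch samples different, deterministic patches.
        loaders = trainer.val_dataloaders
        if loaders is None:
            return
        if not isinstance(loaders, (list, tuple)):
            loaders = [loaders]  # Lightning 2.x may return a single loader
        for loader in loaders:
            dataset = getattr(loader, "dataset", None)
            if dataset is not None and hasattr(dataset, "set_epoch"):
                dataset.set_epoch(trainer.current_epoch, base_seed=self.base_seed)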
