|
16 | 16 |
|
17 | 17 | from omegaconf import DictConfig, ListConfig, OmegaConf |
18 | 18 |
|
19 | | -from ...data.processing.build import count_stacked_label_transform_channels |
20 | | -from ...models.architectures.registry import get_architecture_info |
21 | | -from ...utils.channel_slices import infer_min_required_channels |
22 | | -from ...utils.model_outputs import resolve_configured_output_head, resolve_output_heads |
23 | 19 | from ..schema import Config |
24 | 20 | from ..schema.root import MergeContext |
25 | 21 | from .profile_engine import _YAML_PROFILE_ENGINE |
@@ -486,289 +482,6 @@ def validate_config(cfg: Config) -> None: |
486 | 482 | f"when model.heads has multiple entries ({sorted(model_heads.keys())})" |
487 | 483 | ) |
488 | 484 |
|
489 | | - # --- Cross-section coherence validation (UX 2.4 #3) --- |
490 | | - _validate_cross_section_coherence(cfg) |
491 | | - |
492 | | - |
493 | | -def _architecture_supports_deep_supervision(arch_type: str) -> bool: |
494 | | - """Infer deep-supervision support from architecture registry metadata.""" |
495 | | - arch_info = get_architecture_info().get(arch_type) |
496 | | - if arch_info is None: |
497 | | - return True |
498 | | - |
499 | | - module_name = arch_info.get("module", "") |
500 | | - # Current MONAI wrappers are single-scale and do not expose deep supervision. |
501 | | - return not module_name.endswith("monai_models") |
502 | | - |
503 | | - |
504 | | -def _validate_cross_section_coherence(cfg: Config) -> None: |
505 | | - """Validate resolved cross-section coherence and raise clear errors.""" |
506 | | - # 1) model.input_size vs data.dataloader.patch_size mismatch |
507 | | - model_input = list(cfg.model.input_size) |
508 | | - patch_size = list(cfg.data.dataloader.patch_size) |
509 | | - if model_input != patch_size: |
510 | | - raise ValueError( |
511 | | - "Cross-section validation failed: model.input_size " |
512 | | - f"{model_input} must match data.dataloader.patch_size {patch_size}." |
513 | | - ) |
514 | | - |
515 | | - # 2) output-channel coherence vs loss/label/decoding/activation channel usage |
516 | | - out_channels = cfg.model.out_channels |
517 | | - model_heads = getattr(cfg.model, "heads", None) or {} |
518 | | - primary_head = getattr(cfg.model, "primary_head", None) |
519 | | - sole_head = next(iter(model_heads.keys())) if len(model_heads) == 1 else None |
520 | | - required_output_channels: List[tuple[str, int]] = [] |
521 | | - |
522 | | - label_cfg = getattr(cfg.data, "label_transform", None) |
523 | | - stacked_label_channels = ( |
524 | | - count_stacked_label_transform_channels(label_cfg) if label_cfg is not None else None |
525 | | - ) |
526 | | - |
527 | | - def _resolve_selector_head(entry: Any, *, selector_key: str) -> Optional[str]: |
528 | | - if selector_key == "pred2_slice": |
529 | | - selector_head = entry.get("pred2_head", entry.get("pred_head", None)) |
530 | | - else: |
531 | | - selector_head = entry.get("pred_head", None) |
532 | | - |
533 | | - if selector_head is None: |
534 | | - selector_head = primary_head or sole_head |
535 | | - return selector_head |
536 | | - |
537 | | - def _validate_head_channel_capacity( |
538 | | - *, |
539 | | - selector_key: str, |
540 | | - selector_head: str, |
541 | | - min_channels: int, |
542 | | - loss_idx: int, |
543 | | - ) -> bool: |
544 | | - if selector_head not in model_heads: |
545 | | - return False |
546 | | - |
547 | | - head_channels = int(getattr(model_heads[selector_head], "out_channels", 0)) |
548 | | - if min_channels > head_channels: |
549 | | - raise ValueError( |
550 | | - "Cross-section validation failed: " |
551 | | - f"model.loss.losses[{loss_idx}].{selector_key} requires at least " |
552 | | - f"{min_channels} channels in head '{selector_head}', but " |
553 | | - f"model.heads.{selector_head}.out_channels is {head_channels}." |
554 | | - ) |
555 | | - return True |
556 | | - |
557 | | - def _validate_label_channel_capacity(selector_value: Any, *, path: str) -> None: |
558 | | - min_channels = infer_min_required_channels(selector_value, context=path) |
559 | | - if min_channels is None: |
560 | | - return |
561 | | - |
562 | | - if stacked_label_channels is not None: |
563 | | - if min_channels > stacked_label_channels: |
564 | | - raise ValueError( |
565 | | - "Cross-section validation failed: " |
566 | | - f"{path} requires at least {min_channels} stacked label channels, but " |
567 | | - f"data.label_transform.targets produces {stacked_label_channels}." |
568 | | - ) |
569 | | - return |
570 | | - |
571 | | - if not model_heads: |
572 | | - required_output_channels.append((path, min_channels)) |
573 | | - |
574 | | - # 2a) Loss channel selectors |
575 | | - model_loss_cfg = getattr(cfg.model, "loss", None) |
576 | | - losses_cfg = getattr(model_loss_cfg, "losses", None) if model_loss_cfg else None |
577 | | - if losses_cfg is not None: |
578 | | - for i, entry in enumerate(losses_cfg): |
579 | | - if not isinstance(entry, dict): |
580 | | - continue |
581 | | - selector_keys = ("pred_slice", "target_slice", "mask_slice", "pred2_slice") |
582 | | - for selector_key in selector_keys: |
583 | | - min_channels = infer_min_required_channels( |
584 | | - entry.get(selector_key), |
585 | | - context=f"model.loss.losses[{i}].{selector_key}", |
586 | | - ) |
587 | | - if min_channels is not None: |
588 | | - path = f"model.loss.losses[{i}].{selector_key}" |
589 | | - if selector_key in {"pred_slice", "pred2_slice"} and model_heads: |
590 | | - selector_head = _resolve_selector_head(entry, selector_key=selector_key) |
591 | | - if selector_head is not None and _validate_head_channel_capacity( |
592 | | - selector_key=selector_key, |
593 | | - selector_head=selector_head, |
594 | | - min_channels=min_channels, |
595 | | - loss_idx=i, |
596 | | - ): |
597 | | - continue |
598 | | - if selector_key in {"target_slice", "mask_slice"}: |
599 | | - _validate_label_channel_capacity(entry.get(selector_key), path=path) |
600 | | - continue |
601 | | - required_output_channels.append((path, min_channels)) |
602 | | - |
603 | | - # 2b) Label transform targets (legacy lower-bound expectation for flat outputs) |
604 | | - if not model_heads and stacked_label_channels: |
605 | | - required_output_channels.append(("data.label_transform.targets", stacked_label_channels)) |
606 | | - |
607 | | - # 2c) Explicit head-to-label routing |
608 | | - if model_heads: |
609 | | - for head_name, head_cfg in model_heads.items(): |
610 | | - target_slice = getattr(head_cfg, "target_slice", None) |
611 | | - if target_slice is None: |
612 | | - continue |
613 | | - _validate_label_channel_capacity( |
614 | | - target_slice, |
615 | | - path=f"model.heads.{head_name}.target_slice", |
616 | | - ) |
617 | | - |
618 | | - # 2d) Decoding kwargs channel selectors (*_channels) |
619 | | - decoding_cfg = getattr(cfg, "decoding", None) |
620 | | - decode_has_channel_selection = False |
621 | | - decode_output_head = None |
622 | | - decode_available_channels = out_channels |
623 | | - decode_channel_scope = "model output" |
624 | | - if isinstance(decoding_cfg, list): |
625 | | - for i, decode_step in enumerate(decoding_cfg): |
626 | | - kwargs = getattr(decode_step, "kwargs", None) |
627 | | - if not isinstance(kwargs, dict): |
628 | | - continue |
629 | | - if any(key.endswith("_channels") for key in kwargs): |
630 | | - decode_has_channel_selection = True |
631 | | - break |
632 | | - |
633 | | - if model_heads and decode_has_channel_selection: |
634 | | - decode_heads = resolve_output_heads(cfg, purpose="decode channel selection") |
635 | | - if len(model_heads) > 1 and not decode_heads: |
636 | | - raise ValueError( |
637 | | - "Cross-section validation failed: decode channel selectors require " |
638 | | - "inference.head or model.primary_head when model.heads has multiple " |
639 | | - f"entries ({sorted(model_heads.keys())})." |
640 | | - ) |
641 | | - if len(decode_heads) > 1: |
642 | | - decode_available_channels = sum( |
643 | | - int(getattr(model_heads[h], "out_channels", 0)) for h in decode_heads |
644 | | - ) |
645 | | - decode_channel_scope = f"merged heads {decode_heads}" |
646 | | - decode_output_head = decode_heads[0] |
647 | | - elif decode_heads: |
648 | | - decode_output_head = decode_heads[0] |
649 | | - if decode_output_head in model_heads: |
650 | | - decode_available_channels = int( |
651 | | - getattr(model_heads[decode_output_head], "out_channels", out_channels) |
652 | | - ) |
653 | | - decode_channel_scope = f"head '{decode_output_head}'" |
654 | | - |
655 | | - for i, decode_step in enumerate(decoding_cfg): |
656 | | - kwargs = getattr(decode_step, "kwargs", None) |
657 | | - if not isinstance(kwargs, dict): |
658 | | - continue |
659 | | - for key, value in kwargs.items(): |
660 | | - if not key.endswith("_channels"): |
661 | | - continue |
662 | | - min_channels = infer_min_required_channels( |
663 | | - value, |
664 | | - context=f"decoding[{i}].kwargs.{key}", |
665 | | - ) |
666 | | - if min_channels is not None: |
667 | | - path = f"decoding[{i}].kwargs.{key}" |
668 | | - if model_heads and decode_has_channel_selection: |
669 | | - if min_channels > decode_available_channels: |
670 | | - raise ValueError( |
671 | | - "Cross-section validation failed: " |
672 | | - f"{path} requires at least {min_channels} channels in " |
673 | | - f"{decode_channel_scope}, but only " |
674 | | - f"{decode_available_channels} are available." |
675 | | - ) |
676 | | - continue |
677 | | - required_output_channels.append((path, min_channels)) |
678 | | - |
679 | | - # 2e) Inference channel selectors |
680 | | - tta_cfg = getattr(cfg.inference, "test_time_augmentation", None) |
681 | | - channel_activations = getattr(tta_cfg, "channel_activations", None) if tta_cfg else None |
682 | | - select_channel = getattr(cfg.inference, "select_channel", None) |
683 | | - inference_has_channel_selection = bool(channel_activations) or select_channel is not None |
684 | | - tta_heads = ( |
685 | | - resolve_output_heads(cfg, purpose="inference channel selection") if model_heads else [] |
686 | | - ) |
687 | | - tta_output_head = tta_heads[0] if tta_heads else None |
688 | | - if model_heads and len(model_heads) > 1 and inference_has_channel_selection and not tta_heads: |
689 | | - raise ValueError( |
690 | | - "Cross-section validation failed: inference channel selectors require inference.head " |
691 | | - "or model.primary_head when model.heads has multiple entries " |
692 | | - f"({sorted(model_heads.keys())})." |
693 | | - ) |
694 | | - if len(tta_heads) > 1: |
695 | | - tta_available_channels = sum( |
696 | | - int(getattr(model_heads[h], "out_channels", 0)) for h in tta_heads |
697 | | - ) |
698 | | - tta_channel_scope = f"merged heads {tta_heads}" |
699 | | - else: |
700 | | - tta_available_channels = ( |
701 | | - int(getattr(model_heads[tta_output_head], "out_channels", out_channels)) |
702 | | - if tta_output_head in model_heads |
703 | | - else out_channels |
704 | | - ) |
705 | | - tta_channel_scope = ( |
706 | | - f"head '{tta_output_head}'" if tta_output_head in model_heads else "model output" |
707 | | - ) |
708 | | - |
709 | | - def _validate_tta_channel_capacity(selector_value: Any, *, path: str) -> None: |
710 | | - min_selector_channels = infer_min_required_channels( |
711 | | - selector_value, |
712 | | - context=path, |
713 | | - ) |
714 | | - if min_selector_channels is None: |
715 | | - return |
716 | | - if min_selector_channels > tta_available_channels: |
717 | | - raise ValueError( |
718 | | - "Cross-section validation failed: " |
719 | | - f"{path} requires at least {min_selector_channels} channels in {tta_channel_scope}, " |
720 | | - f"but only {tta_available_channels} are available." |
721 | | - ) |
722 | | - |
723 | | - if isinstance(channel_activations, list): |
724 | | - for i, spec in enumerate(channel_activations): |
725 | | - if not isinstance(spec, dict): |
726 | | - raise ValueError( |
727 | | - "Cross-section validation failed: " |
728 | | - f"inference.test_time_augmentation.channel_activations[{i}] " |
729 | | - "must be a mapping with 'channels' and 'activation'." |
730 | | - ) |
731 | | - if "channels" not in spec or "activation" not in spec: |
732 | | - raise ValueError( |
733 | | - "Cross-section validation failed: " |
734 | | - f"inference.test_time_augmentation.channel_activations[{i}] " |
735 | | - "must define both 'channels' and 'activation'." |
736 | | - ) |
737 | | - _validate_tta_channel_capacity( |
738 | | - spec["channels"], |
739 | | - path=f"inference.test_time_augmentation.channel_activations[{i}].channels", |
740 | | - ) |
741 | | - _validate_tta_channel_capacity( |
742 | | - select_channel, |
743 | | - path="inference.select_channel", |
744 | | - ) |
745 | | - |
746 | | - if required_output_channels: |
747 | | - required_max = max(req for _, req in required_output_channels) |
748 | | - if required_max > out_channels: |
749 | | - details = ", ".join( |
750 | | - f"{path} needs >= {req}" |
751 | | - for path, req in sorted(required_output_channels, key=lambda x: x[1], reverse=True) |
752 | | - ) |
753 | | - raise ValueError( |
754 | | - "Cross-section validation failed: model.out_channels is " |
755 | | - f"{out_channels}, but resolved pipeline components require at least " |
756 | | - f"{required_max} channels ({details})." |
757 | | - ) |
758 | | - |
759 | | - # 3) deep_supervision=True with architectures that don't support it |
760 | | - deep_supervision = ( |
761 | | - getattr(model_loss_cfg, "deep_supervision", False) if model_loss_cfg else False |
762 | | - ) |
763 | | - if deep_supervision: |
764 | | - arch_type = getattr(cfg.model.arch, "type", "") |
765 | | - if not _architecture_supports_deep_supervision(arch_type): |
766 | | - raise ValueError( |
767 | | - "Cross-section validation failed: model.loss.deep_supervision=True but " |
768 | | - f"architecture '{arch_type}' does not support deep supervision. " |
769 | | - "Use MedNeXt/RSUNet or disable deep supervision." |
770 | | - ) |
771 | | - |
772 | 485 |
|
773 | 486 | # --------------------------------------------------------------------------- |
774 | 487 | # Config hashing and naming |
|
0 commit comments