|
| 1 | +_base_: |
| 2 | +- ../connectomics/config/all_profiles.yaml |
| 3 | + |
| 4 | +experiment_name: mito_betaseg_banis_plus |
| 5 | +description: >- |
| 6 | + BetaSeg 3D mitochondria (16nm, 6-channel affinity r1+r10 + 1-channel SDT). |
| 7 | + banis_v0 (faithful BANIS-mito repro) upgraded with the strong "banis+" recipe: |
| 8 | + MedNeXt-L (vs S), EMA validation (0.999), label erosion=2, and per-channel |
| 9 | + class-balanced BCE on affinities (vs plain BCE). SDT head + binary/distance |
| 10 | + watershed decode unchanged from v0. |
| 11 | +
|
| 12 | +save_path: outputs/mito_betaseg_banis_plus |
| 13 | + |
| 14 | +default: |
| 15 | + system: |
| 16 | + num_gpus: -1 |
| 17 | + num_workers: -1 |
| 18 | + model: |
| 19 | + arch: |
| 20 | + type: mednext |
| 21 | + input_size: [128, 128, 128] |
| 22 | + output_size: [128, 128, 128] |
| 23 | + in_channels: 1 |
| 24 | + out_channels: 7 |
| 25 | + primary_head: aff |
| 26 | + heads: |
| 27 | + aff: |
| 28 | + out_channels: 6 |
| 29 | + num_blocks: 0 |
| 30 | + target_slice: "0:6" |
| 31 | + sdt: |
| 32 | + out_channels: 1 |
| 33 | + num_blocks: 0 |
| 34 | + target_slice: "6:7" |
| 35 | + loss: |
| 36 | + deep_supervision: false |
| 37 | + losses: |
| 38 | + # banis+ ("_opt"/v1) loss: per-channel class-balanced BCE over valid |
| 39 | + # affinity voxels so short- and long-range edges learn at comparable |
| 40 | + # speed (v0 used plain BCE, which lets the dominant class dominate). |
| 41 | + - function: PerChannelBCEWithLogitsLoss |
| 42 | + weight: 1.0 |
| 43 | + pred_head: aff |
| 44 | + target_slice: "0:6" |
| 45 | + kwargs: {auto_pos_weight: true, max_pos_weight: 10.0} |
| 46 | + # BANIS: mse_loss(tanh(pred), sdt_target), sdt_loss_weight = 1.0. |
| 47 | + - function: WeightedMSELoss |
| 48 | + weight: 1.0 |
| 49 | + pred_head: sdt |
| 50 | + target_slice: "6:7" |
| 51 | + kwargs: {tanh: true} |
| 52 | + # banis+ uses MedNeXt-L/k3 for higher-capacity affinity prediction. |
| 53 | + mednext: |
| 54 | + size: L |
| 55 | + kernel_size: 3 |
| 56 | + dim: 3d |
| 57 | + checkpoint_style: outside_block |
| 58 | + data: |
| 59 | + dataloader: |
| 60 | + patch_size: [128, 128, 128] |
| 61 | + target_context: [10, 10, 10] |
| 62 | + # MedNeXt-L is ~11x the params of S; keep per-GPU batch small and recover |
| 63 | + # effective batch via accumulation (see train.optimization). |
| 64 | + batch_size: 1 |
| 65 | + val_random_sampling: true |
| 66 | + cached_sampling_foreground_threshold: 0.05 |
| 67 | + cached_sampling_max_attempts: 50 |
| 68 | + cached_sampling_sample_nonzero_mask: true |
| 69 | + image_transform: |
| 70 | + normalize: "divide-255" |
| 71 | + label_transform: |
| 72 | + # banis+ adds label erosion=2 (v0 used 0). Source-index affinity storage. |
| 73 | + erosion: 2 |
| 74 | + relabel_connected_components: true |
| 75 | + relabel_connectivity: 6 |
| 76 | + resolution: [16, 16, 16] |
| 77 | + cache_dir: datasets/betaSeg |
| 78 | + targets: |
| 79 | + - name: affinity |
| 80 | + kwargs: |
| 81 | + offsets: ["1-0-0", "0-1-0", "0-0-1", "10-0-0", "0-10-0", "0-0-10"] |
| 82 | + affinity_mode: banis |
| 83 | + - name: skeleton_aware_edt |
| 84 | + kwargs: |
| 85 | + alpha: 0.8 |
| 86 | + bg_value: -1.0 |
| 87 | + augmentation: |
| 88 | + profile: aug_banis |
| 89 | + inference: |
| 90 | + execution: |
| 91 | + strategy: whole_volume |
| 92 | + model: |
| 93 | + head: "aff,sdt" |
| 94 | + channel_activations: |
| 95 | + - {channels: "0:6", activation: scale_sigmoid} |
| 96 | + - {channels: "6:7", activation: tanh} |
| 97 | + window: |
| 98 | + window_size: [128, 128, 128] |
| 99 | + sw_batch_size: 1 |
| 100 | + overlap: 0.5 |
| 101 | + blending: bump |
| 102 | + padding_mode: replicate |
| 103 | + keep_input_on_cpu: false |
| 104 | + test_time_augmentation: |
| 105 | + enabled: false |
| 106 | + save_results: true |
| 107 | + save_dtype: float16 |
| 108 | + save_backend: h5 |
| 109 | + |
| 110 | + decoding: |
| 111 | + steps: |
| 112 | + - name: decode_binary_contour_distance_watershed |
| 113 | + kwargs: |
| 114 | + binary_channels: [0, 1, 2] |
| 115 | + binary_channel_reduction: min |
| 116 | + contour_channels: |
| 117 | + distance_channels: [6] |
| 118 | + binary_threshold: |
| 119 | + - 0.9 |
| 120 | + - 0.85 |
| 121 | + contour_threshold: |
| 122 | + distance_threshold: |
| 123 | + - 0.5 |
| 124 | + - -0.5 |
| 125 | + min_instance_size: 100 |
| 126 | + min_seed_size: 20 |
| 127 | + prediction_scale: 1 |
| 128 | + |
| 129 | + evaluation: |
| 130 | + enabled: true |
| 131 | + metrics: [adapted_rand] |
| 132 | + |
| 133 | +train: |
| 134 | + system: |
| 135 | + seed: 0 |
| 136 | + optimization: |
| 137 | + optimizer: |
| 138 | + name: AdamW |
| 139 | + lr: 0.001 |
| 140 | + weight_decay: 0.01 |
| 141 | + betas: [0.9, 0.999] |
| 142 | + eps: 1.0e-8 |
| 143 | + scheduler: |
| 144 | + name: CosineAnnealingLR |
| 145 | + interval: step |
| 146 | + frequency: 1 |
| 147 | + min_lr: 0.0 |
| 148 | + params: |
| 149 | + t_max: 50000 |
| 150 | + max_epochs: 50 |
| 151 | + max_steps: 50000 |
| 152 | + n_steps_per_epoch: 5000 |
| 153 | + val_check_interval: 5000 |
| 154 | + val_check_interval_unit: step |
| 155 | + val_steps_per_epoch: 100 |
| 156 | + gradient_clip_val: 1.0 |
| 157 | + # Effective batch ~2 (batch_size 1 x accumulate 2) to fit MedNeXt-L. |
| 158 | + accumulate_grad_batches: 2 |
| 159 | + precision: "16-mixed" |
| 160 | + log_every_n_steps: 100 |
| 161 | + num_sanity_val_steps: 0 |
| 162 | + # banis+ ("_opt") EMA validation. |
| 163 | + ema: |
| 164 | + enabled: true |
| 165 | + decay: 0.999 |
| 166 | + warmup_steps: 500 |
| 167 | + validate_with_ema: true |
| 168 | + data: |
| 169 | + root_path: /projects/weilab/liupeng/data/raw/mito/betaSeg |
| 170 | + train: |
| 171 | + image: [high_c3_im.tiff, low_c1_im.tiff, low_c2_im.tiff] |
| 172 | + label: [high_c3_mito.tiff, low_c1_mito.tiff, low_c2_mito.tiff] |
| 173 | + label_aux_type: skeleton |
| 174 | + val: |
| 175 | + image: [high_c1_im.tiff] |
| 176 | + label: [high_c1_mito.tiff] |
| 177 | + label_aux_type: skeleton |
| 178 | + monitor: |
| 179 | + logging: |
| 180 | + scalar: |
| 181 | + loss: [train_loss_total_epoch, val_loss_total] |
| 182 | + loss_every_n_steps: 100 |
| 183 | + images: |
| 184 | + max_images: 8 |
| 185 | + num_slices: 2 |
| 186 | + log_every_n_epochs: 1 |
| 187 | + channel_mode: all |
| 188 | + checkpoint: |
| 189 | + monitor: val_loss_total |
| 190 | + mode: min |
| 191 | + save_top_k: 3 |
| 192 | + save_every_n_steps: 10000 |
| 193 | + save_on_train_epoch_end: false |
| 194 | + save_last: true |
| 195 | + checkpoint_filename: "{epoch:03d}-{val_loss_total:.4f}" |
| 196 | + |
| 197 | +test: |
| 198 | + data: |
| 199 | + root_path: /projects/weilab/liupeng/data/raw/mito/betaSeg |
| 200 | + test: |
| 201 | + path: "" |
| 202 | + # BANIS-mito test split: high_c2, high_c4, low_c3. |
| 203 | + image: [high_c2_im.tiff, high_c4_im.tiff, low_c3_im.tiff] |
| 204 | + label: [high_c2_mito.tiff, high_c4_mito.tiff, low_c3_mito.tiff] |
| 205 | + label_aux_type: skeleton |
| 206 | + resolution: [16, 16, 16] |
| 207 | + data_transform: |
| 208 | + pad_size: [64, 64, 64] |
0 commit comments