
Commit 4f90c8b

Author: Donglai Wei
Commit message: fix bug when non-deep-supervision
Parent: ac279bb
14 files changed: 933 additions, 53 deletions

.claude/repos_other/BANIS_SUMMARY.md (1 addition, 1 deletion)

```diff
@@ -4,7 +4,7 @@
 
 BANIS is a baseline implementation for the **Neuron Instance Segmentation Benchmark (NISB)**, providing an easily adaptable framework for neuron instance segmentation in electron microscopy (EM) images. The project combines affinity prediction with modern deep learning architectures (MedNeXt) and simple connected components for post-processing.
 
-**Repository Location**: `/projects/weilab/weidf/lib/banis`
+**Repository Location**: `/projects/weilab/weidf/lib/seg/banis`
 
 **Key Features**:
 - Affinity-based segmentation approach with short and long-range affinities
```

.gitignore (1 addition)

```diff
@@ -154,3 +154,4 @@ lib/
 
 # Development logs and documentation
 tmp/
+crackit/
```

connectomics/config/__init__.py (2 additions)

```diff
@@ -23,6 +23,7 @@
     auto_plan_config,
     AutoConfigPlanner,
     AutoPlanResult,
+    resolve_runtime_resource_sentinels,
 )
 
 # GPU utilities
@@ -53,6 +54,7 @@
     "auto_plan_config",
     "AutoConfigPlanner",
     "AutoPlanResult",
+    "resolve_runtime_resource_sentinels",
     # GPU utilities
     "get_gpu_info",
     "print_gpu_info",
```

connectomics/config/auto_config.py (73 additions)

```diff
@@ -16,6 +16,7 @@
 from dataclasses import dataclass, field
 from omegaconf import OmegaConf, DictConfig
 import warnings
+import os
 
 from .gpu_utils import (
     get_gpu_info,
@@ -25,6 +26,78 @@
 )
 
 
+def _available_cpus_for_current_run() -> int:
+    """
+    Detect CPU slots available to the current process (SLURM/cgroup aware).
+
+    Priority:
+    1) CPU affinity mask (best under cgroups/SLURM)
+    2) SLURM_CPUS_PER_TASK
+    3) os.cpu_count()
+    """
+    try:
+        affinity = os.sched_getaffinity(0)
+        if affinity:
+            return len(affinity)
+    except Exception:
+        pass
+
+    slurm_cpus_per_task = os.environ.get("SLURM_CPUS_PER_TASK")
+    if slurm_cpus_per_task and slurm_cpus_per_task.isdigit():
+        return max(int(slurm_cpus_per_task), 1)
+
+    return max(os.cpu_count() or 1, 1)
+
+
+def resolve_runtime_resource_sentinels(
+    config: DictConfig,
+    print_results: bool = True,
+) -> DictConfig:
+    """
+    Resolve runtime resource sentinels in system.{training,inference}.
+
+    Sentinel convention:
+    - num_gpus = -1    -> use all GPUs visible to this run
+    - num_workers = -1 -> use all CPU slots available to this run
+
+    This is runtime-oriented (SLURM/cgroup aware) and complements auto-planning.
+    """
+    if not hasattr(config, "system"):
+        return config
+
+    gpu_info = get_gpu_info()
+    available_gpus = gpu_info["num_gpus"] if gpu_info["cuda_available"] else 0
+    available_cpus = _available_cpus_for_current_run()
+
+    for section_name in ("training", "inference"):
+        section = getattr(config.system, section_name, None)
+        if section is None:
+            continue
+
+        if getattr(section, "num_gpus", None) == -1:
+            section.num_gpus = available_gpus
+            if print_results:
+                print(
+                    f"🔧 Auto-detected system.{section_name}.num_gpus: "
+                    f"-1 → {section.num_gpus}"
+                )
+
+        if getattr(section, "num_workers", None) == -1:
+            section.num_workers = available_cpus
+            if print_results:
+                print(
+                    f"🔧 Auto-detected system.{section_name}.num_workers: "
+                    f"-1 → {section.num_workers}"
+                )
+
+        if getattr(section, "num_gpus", 0) < -1:
+            raise ValueError(f"system.{section_name}.num_gpus must be >= -1")
+        if getattr(section, "num_workers", 0) < -1:
+            raise ValueError(f"system.{section_name}.num_workers must be >= -1")
+
+    return config
+
+
 @dataclass
 class AutoPlanResult:
     """Results from automatic planning."""
```

connectomics/training/deep_supervision.py (34 additions, 6 deletions)

```diff
@@ -34,6 +34,16 @@ def _loss_supports_weight(loss_fn: nn.Module) -> bool:
     return False
 
 
+def _is_class_index_loss(loss_fn: nn.Module) -> bool:
+    """Return True if loss expects class-index labels (1 channel target).
+
+    Cross-entropy style losses consume dense logits [B, C, ...] and class-index
+    targets [B, 1, ...] or [B, ...], unlike BCE/MSE-style losses that require
+    channel-aligned dense targets [B, C, ...].
+    """
+    return loss_fn.__class__.__name__ in {"CrossEntropyLoss", "CrossEntropyLossWrapper"}
+
+
 class DeepSupervisionHandler:
     """
     Handler for deep supervision and multi-task learning.
@@ -130,15 +140,33 @@ def compute_multitask_loss(
 
         # Extract channels for this task from outputs
         task_output = outputs[:, start_ch:end_ch, ...]
-        end_ch - start_ch
+        task_output_channels = end_ch - start_ch
+
+        # Determine label channel convention per task:
+        # - CE-style losses use class-index labels (1 channel)
+        # - Dense losses (BCE/MSE/MAE/Dice/etc.) use channel-aligned labels
+        task_loss_fns = [self.loss_functions[idx] for idx in loss_indices]
+        uses_class_index_targets = any(_is_class_index_loss(fn) for fn in task_loss_fns)
+        uses_dense_targets = any(not _is_class_index_loss(fn) for fn in task_loss_fns)
+
+        if uses_class_index_targets and uses_dense_targets:
+            raise ValueError(
+                f"Task '{task_name}' mixes class-index and dense target losses. "
+                "Use either CE-style losses only, or dense losses only, per task."
+            )
 
-        # Determine number of label channels needed
-        # For softmax-based losses (2+ output channels), label has 1 channel
-        # For sigmoid-based losses (1 output channel), label has 1 channel
-        # So labels always use 1 channel per task
-        num_label_channels = 1
+        num_label_channels = 1 if uses_class_index_targets else task_output_channels
 
         # Extract label channels
+        if label_ch_offset + num_label_channels > labels.shape[1]:
+            raise ValueError(
+                f"Label channel mismatch for task '{task_name}': expected "
+                f"{num_label_channels} channel(s) at offset {label_ch_offset}, "
+                f"but label tensor has {labels.shape[1]} total channels. "
+                f"Task output slice is [{start_ch}:{end_ch}] "
+                f"({task_output_channels} channel(s))."
+            )
+
         task_label = labels[:, label_ch_offset:label_ch_offset + num_label_channels, ...]
         label_ch_offset += num_label_channels
```
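To make the convention concrete, here is a toy sketch (not repo code) of the channel bookkeeping the handler now enforces. The task layout, shapes, and loop are invented for illustration; only the 1-channel-for-CE vs. channel-aligned-for-dense rule comes from the diff:

```python
# Toy illustration: CE-style tasks consume 1 class-index label channel,
# dense (BCE/MSE-style) tasks consume as many label channels as they output.
import torch

outputs = torch.randn(2, 5, 8, 8)  # task A -> output channels 0:3, task B -> 3:5
labels = torch.cat(
    [
        torch.randint(0, 3, (2, 1, 8, 8)).float(),  # task A: class indices (1 channel)
        torch.rand(2, 2, 8, 8),                     # task B: dense targets (2 channels)
    ],
    dim=1,
)

offset = 0
for name, (start, end), ce_style in [("A", (0, 3), True), ("B", (3, 5), False)]:
    n_label = 1 if ce_style else (end - start)   # the new num_label_channels rule
    task_out = outputs[:, start:end]
    task_lbl = labels[:, offset:offset + n_label]
    offset += n_label
    print(name, tuple(task_out.shape), tuple(task_lbl.shape))
```

Under the old code every task consumed exactly one label channel, so a dense 3-channel task would silently read the wrong slice; the new shape check turns that into an explicit error.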

connectomics/training/lit/data_factory.py (25 additions, 12 deletions)

```diff
@@ -299,7 +299,7 @@ def create_datamodule(
         train_json_empty = True
     else:
         # Check if JSON file is empty or has no images
-        with open(json_path, "r") as f:
+        with open(json_path) as f:
            json_data = json.load(f)
         image_files = json_data.get(cfg.data.train_image_key, [])
         if not image_files:
@@ -463,11 +463,23 @@ def create_datamodule(
     if val_data_dicts:
         print(f"  Val dataset size: {len(val_data_dicts)}")
 
-    # Auto-compute iter_num from volume size if not specified (only for training)
+    # Auto-compute iter_num from volume size if not specified (only for training).
+    # IMPORTANT: cfg.data.iter_num_per_epoch is interpreted as optimizer steps/epoch.
+    # Dataset iter_num is sample-count based, so we convert steps -> samples.
     iter_num = None
     if mode == "train":
-        iter_num = cfg.data.iter_num_per_epoch
-        if iter_num == -1 and dataset_type != "filename":
+        iter_num_cfg = cfg.data.iter_num_per_epoch
+        if iter_num_cfg > 0:
+            # Convert requested steps/epoch to per-epoch sample count expected by datasets.
+            # Account for per-device batch size and number of training devices.
+            num_devices = cfg.system.training.num_gpus if cfg.system.training.num_gpus > 0 else 1
+            iter_num = int(iter_num_cfg * cfg.system.training.batch_size * num_devices)
+            print(
+                f"  Requested iter_num_per_epoch={iter_num_cfg} steps -> "
+                f"dataset samples={iter_num} "
+                f"(batch_size={cfg.system.training.batch_size}, devices={num_devices})"
+            )
+        elif iter_num_cfg == -1 and dataset_type != "filename":
             # For filename datasets, iter_num is determined by the number of files
             print("📊 Auto-computing iter_num from volume size...")
             import h5py
@@ -506,8 +518,11 @@ def create_datamodule(
             print(f"  Stride: {cfg.data.stride}")
             print(f"  Samples per volume: {samples_per_vol}")
             print(f"  ✅ Total possible samples (iter_num): {iter_num:,}")
-            print(f"  ✅ Batches per epoch: {iter_num // cfg.system.training.batch_size:,}")
+            # Approximate steps/epoch for informational logging.
+            num_devices = cfg.system.training.num_gpus if cfg.system.training.num_gpus > 0 else 1
+            denom = max(1, cfg.system.training.batch_size * num_devices)
+            print(f"  ✅ Approx steps per epoch: {iter_num // denom:,}")
-        elif iter_num == -1 and dataset_type == "filename":
+        elif iter_num_cfg == -1 and dataset_type == "filename":
             # For filename datasets, iter_num will be determined by dataset length
             print("  Filename dataset: iter_num will be determined by number of files in JSON")
 
@@ -558,9 +573,9 @@ def create_datamodule(
         pad_size = getattr(cfg.data.image_transform, "pad_size", None) or getattr(
             cfg.data, "pad_size", None
         )
-        pad_mode = getattr(
-            cfg.data.image_transform, "pad_mode", None
-        ) or getattr(cfg.data, "pad_mode", "reflect")
+        pad_mode = getattr(cfg.data.image_transform, "pad_mode", None) or getattr(
+            cfg.data, "pad_mode", "reflect"
+        )
 
         # Create optimized cached datasets
         train_dataset = CachedVolumeDataset(
@@ -589,9 +604,7 @@ def create_datamodule(
             persistent_workers=preloaded_num_workers > 0,
         )
 
-        print(
-            f"  Preload policy: train={train_preload_cfg}, val={val_preload_cfg}"
-        )
+        print(f"  Preload policy: train={train_preload_cfg}, val={val_preload_cfg}")
 
         # Create validation dataset and loader if validation data exists
         val_loader = None
```
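A worked example of the steps-to-samples conversion above, with hypothetical numbers:

```python
# Hypothetical request: 250 optimizer steps/epoch, batch_size=8, 4 training GPUs.
iter_num_cfg = 250
batch_size = 8
num_devices = 4

# steps -> per-epoch sample count, as in create_datamodule above
iter_num = iter_num_cfg * batch_size * num_devices            # 250 * 8 * 4 = 8000 samples

# and back again, for the informational "Approx steps per epoch" log
approx_steps = iter_num // max(1, batch_size * num_devices)   # 8000 // 32 = 250
```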

connectomics/training/lit/utils.py (4 additions)

```diff
@@ -21,6 +21,7 @@
     Config,
     load_config,
     resolve_data_paths,
+    resolve_runtime_resource_sentinels,
     update_from_cli,
     validate_config,
 )
@@ -226,6 +227,9 @@ def setup_config(args) -> Config:
         cfg.data.cellmap["input_array_info"]["shape"] = [64, 64, 64]
         cfg.data.cellmap["target_array_info"]["shape"] = [64, 64, 64]
 
+    # Resolve -1 sentinels (auto-max resources for current runtime allocation).
+    cfg = resolve_runtime_resource_sentinels(cfg, print_results=True)
+
     # CPU-only fallback: avoid multiprocessing workers when no CUDA is available
     if not torch.cuda.is_available():
         if cfg.system.training.num_workers > 0:
```

justfile (78 additions, 21 deletions)

```diff
@@ -5,6 +5,29 @@
 default:
     @just --list
 
+# Resolve SLURM time limit for a partition (fallback to sensible defaults).
+_slurm-time-limit partition:
+    #!/usr/bin/env bash
+    set -euo pipefail
+    time_limit=$(sinfo -p {{partition}} -h -o "%l" | head -1)
+    if [ -z "$time_limit" ] || [ "$time_limit" = "infinite" ]; then
+        case "{{partition}}" in
+            short|interactive)
+                time_limit="12:00:00"
+                ;;
+            medium)
+                time_limit="2-00:00:00"
+                ;;
+            long)
+                time_limit="5-00:00:00"
+                ;;
+            *)
+                time_limit="7-00:00:00"
+                ;;
+        esac
+    fi
+    echo "$time_limit"
+
 # ============================================================================
 # Setup & Data
 # ============================================================================
@@ -89,31 +112,17 @@ tensorboard-run experiment timestamp port='6006':
 #   just slurm short 8 4 "python scripts/main.py --config tutorials/lucchi.yaml"
 #   just slurm short 8 4 "just train lucchi++" "" "64G"  # override memory
 # Time limits: short=12h, medium=2d, long=5d
+# CPU-only convenience wrapper for single-task jobs.
+#   just slurm-cpu short 8 0 "python scripts/downsample_nisb.py --splits train"
 slurm partition num_cpu num_gpu cmd constraint='' mem='32G':
     #!/usr/bin/env bash
     constraint_flag=""
     if [ -n "{{constraint}}" ]; then
         constraint_flag="--constraint={{constraint}}"
     fi
 
-    # Set time limit to partition maximum
-    time_limit=$(sinfo -p {{partition}} -h -o "%l" | head -1)
-    if [ -z "$time_limit" ] || [ "$time_limit" = "infinite" ]; then
-        case "{{partition}}" in
-            short|interactive)
-                time_limit="12:00:00"
-                ;;
-            medium)
-                time_limit="2-00:00:00"
-                ;;
-            long)
-                time_limit="5-00:00:00"
-                ;;
-            *)
-                time_limit="7-00:00:00"
-                ;;
-        esac
-    fi
+    # Resolve partition time limit (with fallback defaults)
+    time_limit=$(just _slurm-time-limit {{partition}})
 
     # Run the command exactly as provided (no auto "just" wrapping).
     sbatch --job-name="pytc_{{cmd}}" \
@@ -129,9 +138,57 @@ slurm partition num_cpu num_gpu cmd constraint='' mem='32G':
        $constraint_flag \
        --wrap="mkdir -p \$HOME/.just && export JUST_TEMPDIR=\$HOME/.just TMPDIR=\$HOME/.just NCCL_SOCKET_FAMILY=AF_INET && source /projects/weilab/weidf/lib/miniconda3/bin/activate pytc && cd $PWD && srun --ntasks=1 --gpus-per-task={{num_gpu}} --cpus-per-task={{num_cpu}} {{cmd}}"
 
-# Alias for slurm (kept for backward compatibility)
-slurm-sh partition num_cpu num_gpu cmd constraint='' mem='32G':
-    just slurm {{partition}} {{num_cpu}} {{num_gpu}} {{cmd}} {{constraint}} {{mem}}
+# Generic CPU-only multi-task launcher (single node, no GPU).
+# Example:
+#   just slurm-cpu-parallel short 7 1 "python scripts/downsample_nisb.py --task \$SLURM_PROCID"
+slurm-cpu-parallel partition num_tasks='7' cpu_per_task='4' cmd='' constraint='' mem='64G':
+    #!/usr/bin/env bash
+    set -euo pipefail
+    mkdir -p slurm_outputs
+    cmd_value='{{cmd}}'
+    if [ -z "$cmd_value" ]; then
+        echo "Error: cmd must be provided. Usage:"
+        echo "  just slurm-cpu-parallel <partition> <num_tasks> <cpu_per_task> \"<command>\" [constraint] [mem]"
+        exit 2
+    fi
+
+    constraint_value='{{constraint}}'
+    constraint_flag=""
+    if [ -n "$constraint_value" ]; then
+        constraint_flag="--constraint=$constraint_value"
+    fi
+
+    # Resolve partition time limit (with fallback defaults)
+    time_limit=$(just _slurm-time-limit {{partition}})
+
+    sbatch --job-name="pytc_cpu_{{num_tasks}}t" \
+        --partition={{partition}} \
+        --output=slurm_outputs/slurm-%j.out \
+        --error=slurm_outputs/slurm-%j.err \
+        --nodes=1 \
+        --ntasks={{num_tasks}} \
+        --gpus-per-task=0 \
+        --cpus-per-task={{cpu_per_task}} \
+        --mem={{mem}} \
+        --time=$time_limit \
+        $constraint_flag \
+        --wrap="mkdir -p \$HOME/.just && export JUST_TEMPDIR=\$HOME/.just TMPDIR=\$HOME/.just && source /projects/weilab/weidf/lib/miniconda3/bin/activate pytc && cd $PWD && srun --ntasks={{num_tasks}} --gpus-per-task=0 --cpus-per-task={{cpu_per_task}} bash -c '$cmd_value'"
+
+# Generic CPU-only multi-task launcher for sharded scripts.
+# Automatically appends:
+#   --num-shards $SLURM_NTASKS --shard-index $SLURM_PROCID
+# Example:
+#   just slurm-cpu-sharded short 7 1 "python scripts/downsample_nisb.py"
+slurm-cpu-sharded partition num_tasks='7' cpu_per_task='4' cmd='' constraint='' mem='64G':
+    #!/usr/bin/env bash
+    set -euo pipefail
+    cmd_value='{{cmd}}'
+    if [ -z "$cmd_value" ]; then
+        echo "Error: cmd must be provided. Usage:"
+        echo "  just slurm-cpu-sharded <partition> <num_tasks> <cpu_per_task> \"<command>\" [constraint] [mem]"
+        exit 2
+    fi
+    just slurm-cpu-parallel {{partition}} {{num_tasks}} {{cpu_per_task}} "{{cmd}} --num-shards \$SLURM_NTASKS --shard-index \$SLURM_PROCID" "{{constraint}}" "{{mem}}"
 
 # Launch parameter sweep from config (e.g., just sweep tutorials/sweep_example.yaml)
 sweep config:
```
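For reference, a hypothetical Python script that consumes the flags `slurm-cpu-sharded` appends. The script name, work list, and round-robin split are illustrative; only the `--num-shards`/`--shard-index` flag names come from the recipe:

```python
# Hypothetical consumer of the flags appended by `just slurm-cpu-sharded`.
import argparse


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-shards", type=int, default=1)
    parser.add_argument("--shard-index", type=int, default=0)
    args = parser.parse_args()

    items = [f"volume_{i:03d}" for i in range(100)]   # placeholder work list
    mine = items[args.shard_index::args.num_shards]   # round-robin sharding by rank
    for item in mine:
        print(f"shard {args.shard_index}/{args.num_shards}: processing {item}")


if __name__ == "__main__":
    main()
```

Each SLURM task sees its own `$SLURM_PROCID`, so the tasks partition the work list without any coordination.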
