Skip to content

Commit b20a08d

Browse files
author
Donglai Wei
committed
Add NERL inference evaluation
1 parent deaf74f commit b20a08d

4 files changed

Lines changed: 472 additions & 10 deletions

File tree

connectomics/config/schema/inference.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,18 @@ class EvaluationConfig:
194194
metrics: Optional[List[str]] = None # e.g., ['dice', 'jaccard', 'accuracy']
195195
prediction_threshold: float = 0.5 # Probability/logit threshold for binary metrics
196196
instance_iou_threshold: float = 0.5 # IoU threshold for instance matching metrics
197+
# Neurite ERL evaluation via lib/em_erl. nerl_graph accepts an ERLGraph
198+
# .npz or a BANIS/NISB-style NetworkX skeleton.pkl.
199+
nerl_graph: Any = None
200+
nerl_mask: Any = None
201+
nerl_resolution: Optional[List[float]] = None
202+
nerl_merge_threshold: int = 1
203+
nerl_chunk_num: int = 1
204+
nerl_skeleton_id_attribute: str = "id"
205+
nerl_skeleton_position_attribute: str = "index_position"
206+
nerl_skeleton_edge_length_attribute: str = "edge_length"
207+
nerl_skeleton_position_order: str = "xyz"
208+
nerl_prediction_position_order: Optional[str] = None
197209

198210

199211
@dataclass

connectomics/training/lightning/model.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,24 @@ def _save_metrics_to_file(self, metrics_dict: Dict[str, Any]):
850850
f.write(f" Accuracy: {metrics_dict['accuracy']:.6f}\n")
851851
f.write("\n")
852852

853+
if "nerl" in metrics_dict:
854+
f.write("Neurite ERL Metrics:\n")
855+
f.write("-" * 80 + "\n")
856+
f.write(f" NERL: {metrics_dict['nerl']:.6f}\n")
857+
if "nerl_erl" in metrics_dict:
858+
f.write(f" ERL: {metrics_dict['nerl_erl']:.6f}\n")
859+
if "nerl_max_erl" in metrics_dict:
860+
f.write(
861+
f" Max ERL: {metrics_dict['nerl_max_erl']:.6f}\n"
862+
)
863+
if "nerl_num_skeletons" in metrics_dict:
864+
f.write(
865+
f" Skeletons: {metrics_dict['nerl_num_skeletons']}\n"
866+
)
867+
if "nerl_graph" in metrics_dict:
868+
f.write(f" Graph: {metrics_dict['nerl_graph']}\n")
869+
f.write("\n")
870+
853871
f.write("=" * 80 + "\n")
854872

855873
logger.info(f"Metrics saved to: {metrics_file}")
@@ -925,6 +943,10 @@ def _log_decode_experiment(
925943
"instance_precision_detail",
926944
"instance_recall_detail",
927945
"instance_f1_detail",
946+
"nerl",
947+
"nerl_erl",
948+
"nerl_max_erl",
949+
"nerl_num_skeletons",
928950
]
929951

930952
header_cols = (

0 commit comments

Comments
 (0)