
Commit e7fc206

Author: Donglai Wei

Extract evaluation runtime context

1 parent: 2c954e6

10 files changed: 421 additions & 233 deletions


connectomics/evaluation/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,5 +1,6 @@
 """Evaluation stage orchestration."""
 
+from .context import EvaluationContext
 from .curvilinear import evaluate_directory, evaluate_file_pair
 from .nerl import import_em_erl
 from .report import (
@@ -12,6 +13,7 @@
 from .stage import EvaluationStageResult, run_evaluation_stage
 
 __all__ = [
+    "EvaluationContext",
     "EvaluationStageResult",
     "compute_test_metrics",
     "configured_evaluation_metrics",

connectomics/evaluation/context.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+"""Explicit evaluation runtime context."""
+
+from __future__ import annotations
+
+from collections.abc import Callable, Mapping
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any
+
+from ..config.pipeline.dict_utils import cfg_get
+
+
+@dataclass
+class EvaluationContext:
+    """Runtime inputs required by evaluation helpers.
+
+    The context is intentionally independent of ``ConnectomicsModule``. Lightning
+    code owns translation from module state into this explicit contract.
+    """
+
+    cfg: Any
+    evaluation_cfg: Any = None
+    inference_cfg: Any = None
+    device: Any = "cpu"
+    enabled: bool | None = None
+    checkpoint_path: str | Path | None = None
+    output_path: str | Path | None = None
+    metrics: Mapping[str, Any] = field(default_factory=dict)
+    log_fn: Callable[..., None] | None = None
+    metrics_sink: Callable[[dict[str, Any]], None] | None = None
+    distributed_single_volume_sharding: bool = False
+
+    def cfg_value(self, cfg_obj: Any, name: str, default: Any = None) -> Any:
+        return cfg_get(cfg_obj, name, default)
+
+    @property
+    def is_enabled(self) -> bool:
+        if self.enabled is not None:
+            return bool(self.enabled)
+        return bool(self.cfg_value(self.evaluation_cfg, "enabled", False))
+
+    @property
+    def requested_metrics(self) -> set[str]:
+        metrics = self.cfg_value(self.evaluation_cfg, "metrics", None)
+        if metrics is None:
+            return set()
+        if isinstance(metrics, str):
+            return {metrics.lower()}
+        return {str(metric).lower() for metric in metrics}
+
+    def metric_requested(self, metric_name: str) -> bool:
+        return metric_name.lower() in self.requested_metrics
+
+    def metric(self, metric_name: str) -> Any:
+        key = metric_name.lower()
+        return self.metrics.get(key, self.metrics.get(f"test_{key}"))
+
+    def resolved_output_path(self) -> str | Path | None:
+        if self.output_path is not None:
+            return self.output_path
+        save_prediction_cfg = self.cfg_value(self.inference_cfg, "save_prediction", None)
+        return self.cfg_value(save_prediction_cfg, "output_path", None)
+
+    def log_metric(self, name: str, value: Any, **kwargs: Any) -> None:
+        if self.log_fn is None:
+            return
+        self.log_fn(name, value, **kwargs)
+
+    def persist_metrics(self, metrics_dict: dict[str, Any]) -> bool:
+        if self.metrics_sink is None:
+            return False
+        self.metrics_sink(metrics_dict)
+        return True
+
+
+__all__ = ["EvaluationContext"]
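
For orientation, a minimal usage sketch of the new context (hypothetical, not part of this commit; the plain-dict config relies on the assumption that cfg_get falls back to mapping lookup):

import torchmetrics

from connectomics.evaluation import EvaluationContext

# Plain dicts are assumed to satisfy cfg_get; swap in your real config type.
context = EvaluationContext(
    cfg={},
    evaluation_cfg={"enabled": True, "metrics": ["VOI", "adapted_rand"]},
    device="cpu",
    # Placeholder metric object; real code would register torchmetrics here.
    metrics={"test_voi": torchmetrics.MeanMetric()},
    log_fn=lambda name, value, **kwargs: print(name, value),
)

assert context.is_enabled
assert context.metric_requested("voi")       # metric names are lower-cased
assert context.metric("voi") is not None     # falls back to the "test_voi" key
context.log_metric("voi_total", 0.0)         # routed through log_fn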

connectomics/evaluation/metrics.py

Lines changed: 32 additions & 26 deletions
@@ -10,6 +10,7 @@
 
 from ..metrics.metrics_seg import AdaptedRandError
 from ..metrics.segmentation_numpy import instance_matching, instance_matching_simple, voi
+from .context import EvaluationContext
 
 logger = logging.getLogger(__name__)
 
@@ -51,7 +52,7 @@ def is_instance_segmentation(pred_tensor: torch.Tensor) -> bool:
 
 
 def compute_instance_metrics(
-    module,
+    context: EvaluationContext,
     pred_tensor: torch.Tensor,
     labels_tensor: torch.Tensor,
     volume_prefix: str,
@@ -61,10 +62,9 @@
     pred_instances = pred_tensor.long()
     labels_instances = labels_tensor.long()
 
-    if hasattr(module, "test_adapted_rand") and isinstance(
-        module.test_adapted_rand, torchmetrics.Metric
-    ):
-        per_volume_metric = AdaptedRandError(return_all_stats=True).to(module.device)
+    adapted_rand_metric = context.metric("adapted_rand")
+    if context.metric_requested("adapted_rand"):
+        per_volume_metric = AdaptedRandError(return_all_stats=True).to(context.device)
         per_volume_metric.update(pred_instances.cpu(), labels_instances.cpu())
         adapted_rand_value = per_volume_metric.compute()
         if isinstance(adapted_rand_value, dict):
@@ -82,9 +82,11 @@
                 logger.info("%s %s: %.6f", volume_prefix, k, val)
 
         metrics_dict["adapted_rand_error"] = are_score
-        module.test_adapted_rand.update(pred_instances.cpu(), labels_instances.cpu())
+        if isinstance(adapted_rand_metric, torchmetrics.Metric):
+            adapted_rand_metric.update(pred_instances.cpu(), labels_instances.cpu())
 
-    if hasattr(module, "test_voi") and isinstance(module.test_voi, torchmetrics.Metric):
+    voi_metric = context.metric("voi")
+    if context.metric_requested("voi"):
         split, merge = voi(pred_instances.cpu().numpy(), labels_instances.cpu().numpy())
         logger.info("%sVOI Split: %.6f", volume_prefix, split)
         logger.info("%sVOI Merge: %.6f", volume_prefix, merge)
@@ -94,11 +96,11 @@
         metrics_dict["voi_merge"] = merge
         metrics_dict["voi_total"] = split + merge
 
-        module.test_voi.update(pred_instances.cpu(), labels_instances.cpu())
+        if isinstance(voi_metric, torchmetrics.Metric):
+            voi_metric.update(pred_instances.cpu(), labels_instances.cpu())
 
-    if hasattr(module, "test_instance_accuracy") and isinstance(
-        module.test_instance_accuracy, torchmetrics.Metric
-    ):
+    instance_accuracy_metric = context.metric("instance_accuracy")
+    if context.metric_requested("instance_accuracy"):
         stats = instance_matching(
             labels_instances.cpu().numpy(),
             pred_instances.cpu().numpy(),
@@ -108,11 +110,11 @@
         logger.info("%sInstance Accuracy: %.6f", volume_prefix, stats["accuracy"])
         metrics_dict["instance_accuracy"] = stats["accuracy"]
 
-        module.test_instance_accuracy.update(pred_instances.cpu(), labels_instances.cpu())
+        if isinstance(instance_accuracy_metric, torchmetrics.Metric):
+            instance_accuracy_metric.update(pred_instances.cpu(), labels_instances.cpu())
 
-    if hasattr(module, "test_instance_accuracy_detail") and isinstance(
-        module.test_instance_accuracy_detail, torchmetrics.Metric
-    ):
+    instance_accuracy_detail_metric = context.metric("instance_accuracy_detail")
+    if context.metric_requested("instance_accuracy_detail"):
         stats_simple = instance_matching_simple(
             labels_instances.cpu().numpy(),
             pred_instances.cpu().numpy(),
@@ -133,14 +135,12 @@
         metrics_dict["instance_recall_detail"] = stats_simple["recall"]
         metrics_dict["instance_f1_detail"] = stats_simple["f1"]
 
-        module.test_instance_accuracy_detail.update(
-            pred_instances.cpu(),
-            labels_instances.cpu(),
-        )
+        if isinstance(instance_accuracy_detail_metric, torchmetrics.Metric):
+            instance_accuracy_detail_metric.update(pred_instances.cpu(), labels_instances.cpu())
 
 
 def compute_binary_metrics(
-    module,
+    context: EvaluationContext,
     pred_tensor: torch.Tensor,
    labels_tensor: torch.Tensor,
     volume_prefix: str,
@@ -158,31 +158,37 @@
         else labels_tensor.long()
     )
 
-    if hasattr(module, "test_jaccard") and module.test_jaccard is not None:
+    jaccard_metric = context.metric("jaccard")
+    if context.metric_requested("jaccard"):
         jaccard_value = torchmetrics.functional.jaccard_index(
             pred_binary,
             labels_binary,
             task="binary",
         )
         logger.info("%sJaccard: %.6f", volume_prefix, jaccard_value.item())
         metrics_dict["jaccard"] = jaccard_value.item()
-        module.test_jaccard.update(pred_binary, labels_binary)
+        if jaccard_metric is not None:
+            jaccard_metric.update(pred_binary, labels_binary)
 
-    if hasattr(module, "test_dice") and module.test_dice is not None:
+    dice_metric = context.metric("dice")
+    if context.metric_requested("dice"):
         dice_value = torchmetrics.functional.dice(pred_binary, labels_binary)
         logger.info("%sDice: %.6f", volume_prefix, dice_value.item())
         metrics_dict["dice"] = dice_value.item()
-        module.test_dice.update(pred_binary, labels_binary)
+        if dice_metric is not None:
+            dice_metric.update(pred_binary, labels_binary)
 
-    if hasattr(module, "test_accuracy") and module.test_accuracy is not None:
+    accuracy_metric = context.metric("accuracy")
+    if context.metric_requested("accuracy"):
         accuracy_value = torchmetrics.functional.accuracy(
             pred_binary,
             labels_binary,
             task="binary",
         )
         logger.info("%sAccuracy: %.6f", volume_prefix, accuracy_value.item())
         metrics_dict["accuracy"] = accuracy_value.item()
-        module.test_accuracy.update(pred_binary, labels_binary)
+        if accuracy_metric is not None:
+            accuracy_metric.update(pred_binary, labels_binary)
 
 
 __all__ = [
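
Per the docstring in context.py, Lightning code owns the module-to-context translation. A hedged sketch of what that adapter might look like (the module attributes cfg, evaluation_cfg, and inference_cfg are assumptions; the test_* metric attributes match the code removed above, and module.log is the standard LightningModule.log):

from connectomics.evaluation import EvaluationContext

_METRIC_NAMES = (
    "adapted_rand",
    "voi",
    "instance_accuracy",
    "instance_accuracy_detail",
    "jaccard",
    "dice",
    "accuracy",
)


def build_evaluation_context(module) -> EvaluationContext:
    """Translate module state into the explicit contract (sketch only)."""
    # Collect whichever test_* metrics the module happens to carry.
    metrics = {
        name: metric
        for name in _METRIC_NAMES
        if (metric := getattr(module, f"test_{name}", None)) is not None
    }
    return EvaluationContext(
        cfg=module.cfg,                                    # assumed attribute
        evaluation_cfg=getattr(module, "evaluation_cfg", None),
        inference_cfg=getattr(module, "inference_cfg", None),
        device=module.device,
        metrics=metrics,
        log_fn=module.log,                                 # LightningModule.log
    )

The helpers in metrics.py then read everything through context.metric(...) and context.metric_requested(...) instead of reaching into the module directly.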
