1010
1111from ..metrics .metrics_seg import AdaptedRandError
1212from ..metrics .segmentation_numpy import instance_matching , instance_matching_simple , voi
13+ from .context import EvaluationContext
1314
1415logger = logging .getLogger (__name__ )
1516
@@ -51,7 +52,7 @@ def is_instance_segmentation(pred_tensor: torch.Tensor) -> bool:
5152
5253
5354def compute_instance_metrics (
54- module ,
55+ context : EvaluationContext ,
5556 pred_tensor : torch .Tensor ,
5657 labels_tensor : torch .Tensor ,
5758 volume_prefix : str ,
@@ -61,10 +62,9 @@ def compute_instance_metrics(
6162 pred_instances = pred_tensor .long ()
6263 labels_instances = labels_tensor .long ()
6364
64- if hasattr (module , "test_adapted_rand" ) and isinstance (
65- module .test_adapted_rand , torchmetrics .Metric
66- ):
67- per_volume_metric = AdaptedRandError (return_all_stats = True ).to (module .device )
65+ adapted_rand_metric = context .metric ("adapted_rand" )
66+ if context .metric_requested ("adapted_rand" ):
67+ per_volume_metric = AdaptedRandError (return_all_stats = True ).to (context .device )
6868 per_volume_metric .update (pred_instances .cpu (), labels_instances .cpu ())
6969 adapted_rand_value = per_volume_metric .compute ()
7070 if isinstance (adapted_rand_value , dict ):
@@ -82,9 +82,11 @@ def compute_instance_metrics(
8282 logger .info ("%s %s: %.6f" , volume_prefix , k , val )
8383
8484 metrics_dict ["adapted_rand_error" ] = are_score
85- module .test_adapted_rand .update (pred_instances .cpu (), labels_instances .cpu ())
85+ if isinstance (adapted_rand_metric , torchmetrics .Metric ):
86+ adapted_rand_metric .update (pred_instances .cpu (), labels_instances .cpu ())
8687
87- if hasattr (module , "test_voi" ) and isinstance (module .test_voi , torchmetrics .Metric ):
88+ voi_metric = context .metric ("voi" )
89+ if context .metric_requested ("voi" ):
8890 split , merge = voi (pred_instances .cpu ().numpy (), labels_instances .cpu ().numpy ())
8991 logger .info ("%sVOI Split: %.6f" , volume_prefix , split )
9092 logger .info ("%sVOI Merge: %.6f" , volume_prefix , merge )
@@ -94,11 +96,11 @@ def compute_instance_metrics(
9496 metrics_dict ["voi_merge" ] = merge
9597 metrics_dict ["voi_total" ] = split + merge
9698
97- module .test_voi .update (pred_instances .cpu (), labels_instances .cpu ())
99+ if isinstance (voi_metric , torchmetrics .Metric ):
100+ voi_metric .update (pred_instances .cpu (), labels_instances .cpu ())
98101
99- if hasattr (module , "test_instance_accuracy" ) and isinstance (
100- module .test_instance_accuracy , torchmetrics .Metric
101- ):
102+ instance_accuracy_metric = context .metric ("instance_accuracy" )
103+ if context .metric_requested ("instance_accuracy" ):
102104 stats = instance_matching (
103105 labels_instances .cpu ().numpy (),
104106 pred_instances .cpu ().numpy (),
@@ -108,11 +110,11 @@ def compute_instance_metrics(
108110 logger .info ("%sInstance Accuracy: %.6f" , volume_prefix , stats ["accuracy" ])
109111 metrics_dict ["instance_accuracy" ] = stats ["accuracy" ]
110112
111- module .test_instance_accuracy .update (pred_instances .cpu (), labels_instances .cpu ())
113+ if isinstance (instance_accuracy_metric , torchmetrics .Metric ):
114+ instance_accuracy_metric .update (pred_instances .cpu (), labels_instances .cpu ())
112115
113- if hasattr (module , "test_instance_accuracy_detail" ) and isinstance (
114- module .test_instance_accuracy_detail , torchmetrics .Metric
115- ):
116+ instance_accuracy_detail_metric = context .metric ("instance_accuracy_detail" )
117+ if context .metric_requested ("instance_accuracy_detail" ):
116118 stats_simple = instance_matching_simple (
117119 labels_instances .cpu ().numpy (),
118120 pred_instances .cpu ().numpy (),
@@ -133,14 +135,12 @@ def compute_instance_metrics(
133135 metrics_dict ["instance_recall_detail" ] = stats_simple ["recall" ]
134136 metrics_dict ["instance_f1_detail" ] = stats_simple ["f1" ]
135137
136- module .test_instance_accuracy_detail .update (
137- pred_instances .cpu (),
138- labels_instances .cpu (),
139- )
138+ if isinstance (instance_accuracy_detail_metric , torchmetrics .Metric ):
139+ instance_accuracy_detail_metric .update (pred_instances .cpu (), labels_instances .cpu ())
140140
141141
142142def compute_binary_metrics (
143- module ,
143+ context : EvaluationContext ,
144144 pred_tensor : torch .Tensor ,
145145 labels_tensor : torch .Tensor ,
146146 volume_prefix : str ,
@@ -158,31 +158,37 @@ def compute_binary_metrics(
158158 else labels_tensor .long ()
159159 )
160160
161- if hasattr (module , "test_jaccard" ) and module .test_jaccard is not None :
161+ jaccard_metric = context .metric ("jaccard" )
162+ if context .metric_requested ("jaccard" ):
162163 jaccard_value = torchmetrics .functional .jaccard_index (
163164 pred_binary ,
164165 labels_binary ,
165166 task = "binary" ,
166167 )
167168 logger .info ("%sJaccard: %.6f" , volume_prefix , jaccard_value .item ())
168169 metrics_dict ["jaccard" ] = jaccard_value .item ()
169- module .test_jaccard .update (pred_binary , labels_binary )
170+ if jaccard_metric is not None :
171+ jaccard_metric .update (pred_binary , labels_binary )
170172
171- if hasattr (module , "test_dice" ) and module .test_dice is not None :
173+ dice_metric = context .metric ("dice" )
174+ if context .metric_requested ("dice" ):
172175 dice_value = torchmetrics .functional .dice (pred_binary , labels_binary )
173176 logger .info ("%sDice: %.6f" , volume_prefix , dice_value .item ())
174177 metrics_dict ["dice" ] = dice_value .item ()
175- module .test_dice .update (pred_binary , labels_binary )
178+ if dice_metric is not None :
179+ dice_metric .update (pred_binary , labels_binary )
176180
177- if hasattr (module , "test_accuracy" ) and module .test_accuracy is not None :
181+ accuracy_metric = context .metric ("accuracy" )
182+ if context .metric_requested ("accuracy" ):
178183 accuracy_value = torchmetrics .functional .accuracy (
179184 pred_binary ,
180185 labels_binary ,
181186 task = "binary" ,
182187 )
183188 logger .info ("%sAccuracy: %.6f" , volume_prefix , accuracy_value .item ())
184189 metrics_dict ["accuracy" ] = accuracy_value .item ()
185- module .test_accuracy .update (pred_binary , labels_binary )
190+ if accuracy_metric is not None :
191+ accuracy_metric .update (pred_binary , labels_binary )
186192
187193
188194__all__ = [
0 commit comments