11"""Fairness evaluator."""
2-
32import inspect
43import itertools
54import logging
65import warnings
76from datetime import datetime
87from typing import Any , Callable , Dict , List , Optional , Union
98
9+ import array_api_compat .numpy
1010import numpy as np
11- import numpy .typing as npt
1211import pandas as pd
1312from datasets import Dataset , config
1413from datasets .features import Features
2120 get_columns_as_numpy_array ,
2221 set_decode ,
2322)
24- from cyclops .evaluate .metrics .factory import create_metric
25- from cyclops .evaluate .metrics .functional .precision_recall_curve import (
23+ from cyclops .evaluate .metrics .experimental .functional .precision_recall_curve import (
2624 _format_thresholds ,
25+ _validate_thresholds ,
2726)
28- from cyclops .evaluate .metrics .metric import Metric , MetricCollection , OperatorMetric
29- from cyclops .evaluate .metrics .utils import (
30- _check_thresholds ,
31- _get_value_if_singleton_array ,
32- )
27+ from cyclops .evaluate .metrics .experimental .metric import Metric , OperatorMetric
28+ from cyclops .evaluate .metrics .experimental .metric_dict import MetricDict
29+ from cyclops .evaluate .metrics .experimental .utils .types import Array
30+ from cyclops .evaluate .metrics .factory import create_metric
3331from cyclops .evaluate .utils import _format_column_names
3432from cyclops .utils .log import setup_logging
3533
3937
4038
4139def evaluate_fairness (
42- metrics : Union [str , Callable [..., Any ], Metric , MetricCollection ],
40+ metrics : Union [str , Callable [..., Any ], Metric , MetricDict ],
4341 dataset : Dataset ,
4442 groups : Union [str , List [str ]],
4543 target_columns : Union [str , List [str ]],
@@ -62,7 +60,7 @@ def evaluate_fairness(
6260
6361 Parameters
6462 ----------
65- metrics : Union[str, Callable[..., Any], Metric, MetricCollection ]
63+ metrics : Union[str, Callable[..., Any], Metric, MetricDict ]
6664 The metric or metrics to compute. If a string, it should be the name of a
6765 metric provided by CyclOps. If a callable, it should be a function that
6866 takes target, prediction, and optionally threshold/thresholds as arguments
@@ -147,18 +145,14 @@ def evaluate_fairness(
147145 raise TypeError (
148146 "Expected `dataset` to be of type `Dataset`, but got " f"{ type (dataset )} ." ,
149147 )
148+ _validate_thresholds (thresholds )
150149
151- _check_thresholds (thresholds )
152- fmt_thresholds : npt .NDArray [np .float_ ] = _format_thresholds ( # type: ignore
153- thresholds ,
154- )
155-
156- metrics_ : Union [Callable [..., Any ], MetricCollection ] = _format_metrics (
150+ metrics_ : Union [Callable [..., Any ], MetricDict ] = _format_metrics (
157151 metrics ,
158152 metric_name ,
159153 ** (metric_kwargs or {}),
160154 )
161-
155+ fmt_thresholds = _format_thresholds ( thresholds , xp = array_api_compat . numpy )
162156 fmt_groups : List [str ] = _format_column_names (groups )
163157 fmt_target_columns : List [str ] = _format_column_names (target_columns )
164158 fmt_prediction_columns : List [str ] = _format_column_names (prediction_columns )
@@ -361,15 +355,15 @@ def warn_too_many_unique_values(
361355
362356
363357def _format_metrics (
364- metrics : Union [str , Callable [..., Any ], Metric , MetricCollection ],
358+ metrics : Union [str , Callable [..., Any ], Metric , MetricDict ],
365359 metric_name : Optional [str ] = None ,
366360 ** metric_kwargs : Any ,
367- ) -> Union [Callable [..., Any ], Metric , MetricCollection ]:
361+ ) -> Union [Callable [..., Any ], Metric , MetricDict ]:
368362 """Format the metrics argument.
369363
370364 Parameters
371365 ----------
372- metrics : Union[str, Callable[..., Any], Metric, MetricCollection ]
366+ metrics : Union[str, Callable[..., Any], Metric, MetricDict ]
373367 The metrics to use for computing the metric results.
374368 metric_name : str, optional, default=None
375369 The name of the metric. This is only used if `metrics` is a callable.
@@ -379,23 +373,23 @@ def _format_metrics(
379373
380374 Returns
381375 -------
382- Union[Callable[..., Any], Metric, MetricCollection ]
376+ Union[Callable[..., Any], Metric, MetricDict ]
383377 The formatted metrics.
384378
385379 Raises
386380 ------
387381 TypeError
388- If `metrics` is not of type `str`, `Callable`, `Metric`, or `MetricCollection `.
382+ If `metrics` is not of type `str`, `Callable`, `Metric`, or `MetricDict `.
389383
390384 """
391385 if isinstance (metrics , str ):
392- metrics = create_metric (metric_name = metrics , ** metric_kwargs )
386+ metrics = create_metric (metric_name = metrics , experimental = True , ** metric_kwargs )
393387 if isinstance (metrics , Metric ):
394388 if metric_name is not None and isinstance (metrics , OperatorMetric ):
395389 # single metric created from arithmetic operation, with given name
396- return MetricCollection ({metric_name : metrics })
397- return MetricCollection (metrics )
398- if isinstance (metrics , MetricCollection ):
390+ return MetricDict ({metric_name : metrics })
391+ return MetricDict (metrics )
392+ if isinstance (metrics , MetricDict ):
399393 return metrics
400394 if callable (metrics ):
401395 if metric_name is None :
@@ -407,7 +401,7 @@ def _format_metrics(
407401 return metrics
408402
409403 raise TypeError (
410- f"Expected `metrics` to be of type `str`, `Metric`, `MetricCollection `, or "
404+ f"Expected `metrics` to be of type `str`, `Metric`, `MetricDict `, or "
411405 f"`Callable`, but got { type (metrics )} ." ,
412406 )
413407
@@ -701,7 +695,7 @@ def _get_slice_spec(
701695
702696
703697def _compute_metrics ( # noqa: C901, PLR0912
704- metrics : Union [Callable [..., Any ], MetricCollection ],
698+ metrics : Union [Callable [..., Any ], MetricDict ],
705699 dataset : Dataset ,
706700 target_columns : List [str ],
707701 prediction_column : str ,
@@ -713,7 +707,7 @@ def _compute_metrics( # noqa: C901, PLR0912
713707
714708 Parameters
715709 ----------
716- metrics : Union[Callable, MetricCollection ]
710+ metrics : Union[Callable, MetricDict ]
717711 The metrics to compute.
718712 dataset : Dataset
719713 The dataset to compute the metrics on.
@@ -738,12 +732,19 @@ def _compute_metrics( # noqa: C901, PLR0912
738732 "Encountered empty dataset while computing metrics. "
739733 "The metric values will be set to `None`."
740734 )
741- if isinstance (metrics , MetricCollection ):
735+ if isinstance (metrics , MetricDict ):
742736 if threshold is not None :
743737 # set the threshold for each metric in the collection
744738 for name , metric in metrics .items ():
745- if hasattr (metric , "threshold" ):
739+ if isinstance ( metric , Metric ) and hasattr (metric , "threshold" ):
746740 metric .threshold = threshold
741+ elif isinstance (metric , OperatorMetric ):
742+ if hasattr (metric .metric_a , "threshold" ) and hasattr (
743+ metric .metric_b ,
744+ "threshold" ,
745+ ):
746+ metric .metric_a .threshold = threshold
747+ metric .metric_b .threshold = threshold # type: ignore[union-attr]
747748 else :
748749 LOGGER .warning (
749750 "Metric %s does not have a threshold attribute. "
@@ -754,7 +755,7 @@ def _compute_metrics( # noqa: C901, PLR0912
754755 if len (dataset ) == 0 :
755756 warnings .warn (empty_dataset_msg , RuntimeWarning , stacklevel = 1 )
756757 results : Dict [str , Any ] = {
757- metric_name : float ("NaN" ) for metric_name in metrics
758+ metric_name : float ("NaN" ) for metric_name in metrics # type: ignore[attr-defined]
758759 }
759760 elif (
760761 batch_size is None or batch_size <= 0
@@ -779,11 +780,11 @@ def _compute_metrics( # noqa: C901, PLR0912
779780 columns = prediction_column ,
780781 )
781782
782- metrics .update_state (targets , predictions )
783+ metrics .update (targets , predictions )
783784
784785 results = metrics .compute ()
785786
786- metrics .reset_state ()
787+ metrics .reset ()
787788
788789 return results
789790 if callable (metrics ):
@@ -817,26 +818,26 @@ def _compute_metrics( # noqa: C901, PLR0912
817818 return {metric_name .title (): output }
818819
819820 raise TypeError (
820- "The `metrics` argument must be a string, a Metric, a MetricCollection , "
821+ "The `metrics` argument must be a string, a Metric, a MetricDict , "
821822 f"or a callable. Got { type (metrics )} ." ,
822823 )
823824
824825
825826def _get_metric_results_for_prediction_and_slice (
826- metrics : Union [Callable [..., Any ], MetricCollection ],
827+ metrics : Union [Callable [..., Any ], MetricDict ],
827828 dataset : Dataset ,
828829 target_columns : List [str ],
829830 prediction_column : str ,
830831 slice_name : str ,
831832 batch_size : Optional [int ] = config .DEFAULT_MAX_BATCH_SIZE ,
832833 metric_name : Optional [str ] = None ,
833- thresholds : Optional [npt . NDArray [ np . float_ ] ] = None ,
834+ thresholds : Optional [Array ] = None ,
834835) -> Dict [str , Dict [str , Any ]]:
835836 """Compute metrics for a slice of a dataset.
836837
837838 Parameters
838839 ----------
839- metrics : Union[Callable, MetricCollection ]
840+ metrics : Union[Callable, MetricDict ]
840841 The metrics to compute.
841842 dataset : Dataset
842843 The dataset to compute the metrics on.
@@ -850,7 +851,7 @@ def _get_metric_results_for_prediction_and_slice(
850851 The batch size to use for the computation.
851852 metric_name : Optional[str]
852853 The name of the metric to compute.
853- thresholds : Optional[List[float] ]
854+ thresholds : Optional[Array ]
854855 The thresholds to use for the metrics.
855856
856857 Returns
@@ -873,7 +874,7 @@ def _get_metric_results_for_prediction_and_slice(
873874 return {slice_name : metric_output }
874875
875876 results : Dict [str , Dict [str , Any ]] = {}
876- for threshold in thresholds :
877+ for threshold in thresholds : # type: ignore[attr-defined]
877878 metric_output = _compute_metrics (
878879 metrics = metrics ,
879880 dataset = dataset ,
@@ -969,11 +970,7 @@ def _compute_parity_metrics(
969970 )
970971
971972 parity_results [key ].setdefault (slice_name , {}).update (
972- {
973- parity_metric_name : _get_value_if_singleton_array (
974- parity_metric_value ,
975- ),
976- },
973+ {parity_metric_name : parity_metric_value },
977974 )
978975
979976 return parity_results
0 commit comments