Skip to content

Commit 8fb3cf1

Browse files
fcogidi and amrit110 authored
Integrate experimental metrics with other modules (#549)
* integrate experimental metrics with other modules
* add average precision metric to experimental metrics package
* fix tutorials
* Add type hints and keyword arguments to metrics classes
* Update nbsphinx version to 0.9.3
* Update nbconvert version to 7.14.2
* Fix type annotations and formatting issues
* Update kernel display name in mortality_prediction.ipynb
* Add guard clause to prevent module execution on import
* Update `torch_distributed.py` with type hints
* Add multiclass and multilabel average precision metrics
* Change jupyter kernel
* Fix type annotations for metric values in ClassificationPlotter

---------

Co-authored-by: Amrit K <amritk@vectorinstitute.ai>
1 parent 5c4ebb2 commit 8fb3cf1

33 files changed

Lines changed: 1900 additions & 394 deletions

cyclops/evaluate/evaluator.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
"""Evaluate one or more models on a dataset."""
2-
32
import logging
43
import warnings
54
from dataclasses import asdict
@@ -16,7 +15,9 @@
1615
)
1716
from cyclops.evaluate.fairness.config import FairnessConfig
1817
from cyclops.evaluate.fairness.evaluator import evaluate_fairness
19-
from cyclops.evaluate.metrics.metric import Metric, MetricCollection
18+
from cyclops.evaluate.metrics.experimental.metric import Metric
19+
from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict
20+
from cyclops.evaluate.metrics.experimental.utils.types import Array
2021
from cyclops.evaluate.utils import _format_column_names, choose_split
2122
from cyclops.utils.log import setup_logging
2223

@@ -27,7 +28,7 @@
2728

2829
def evaluate(
2930
dataset: Union[str, Dataset, DatasetDict],
30-
metrics: Union[Metric, Sequence[Metric], Dict[str, Metric], MetricCollection],
31+
metrics: Union[Metric, Sequence[Metric], Dict[str, Metric], MetricDict],
3132
target_columns: Union[str, List[str]],
3233
prediction_columns: Union[str, List[str]],
3334
ignore_columns: Optional[Union[str, List[str]]] = None,
@@ -47,7 +48,7 @@ def evaluate(
4748
The dataset to evaluate on. If a string, the dataset will be loaded
4849
using `datasets.load_dataset`. If `DatasetDict`, the `split` argument
4950
must be specified.
50-
metrics : Union[Metric, Sequence[Metric], Dict[str, Metric], MetricCollection]
51+
metrics : Union[Metric, Sequence[Metric], Dict[str, Metric], MetricDict]
5152
The metrics to compute.
5253
target_columns : Union[str, List[str]]
5354
The name of the column(s) containing the target values. A string value
@@ -202,28 +203,28 @@ def _load_data(
202203

203204

204205
def _prepare_metrics(
205-
metrics: Union[Metric, Sequence[Metric], Dict[str, Metric], MetricCollection],
206-
) -> MetricCollection:
206+
metrics: Union[Metric, Sequence[Metric], Dict[str, Metric], MetricDict],
207+
) -> MetricDict:
207208
"""Prepare metrics for evaluation."""
208-
# TODO: wrap in BootstrappedMetric if computing confidence intervals
209+
# TODO [fcogidi]: wrap in BootstrappedMetric if computing confidence intervals
209210
if isinstance(metrics, (Metric, Sequence, Dict)) and not isinstance(
210211
metrics,
211-
MetricCollection,
212+
MetricDict,
212213
):
213-
return MetricCollection(metrics)
214-
if isinstance(metrics, MetricCollection):
214+
return MetricDict(metrics) # type: ignore[arg-type]
215+
if isinstance(metrics, MetricDict):
215216
return metrics
216217

217218
raise TypeError(
218219
f"Invalid type for `metrics`: {type(metrics)}. "
219220
"Expected one of: Metric, Sequence[Metric], Dict[str, Metric], "
220-
"MetricCollection.",
221+
"MetricDict.",
221222
)
222223

223224

224225
def _compute_metrics(
225226
dataset: Dataset,
226-
metrics: MetricCollection,
227+
metrics: MetricDict,
227228
slice_spec: SliceSpec,
228229
target_columns: Union[str, List[str]],
229230
prediction_columns: Union[str, List[str]],
@@ -266,8 +267,8 @@ def _compute_metrics(
266267
RuntimeWarning,
267268
stacklevel=1,
268269
)
269-
metric_output = {
270-
metric_name: float("NaN") for metric_name in metrics
270+
metric_output: Dict[str, Array] = {
271+
metric_name: float("NaN") for metric_name in metrics # type: ignore[attr-defined,misc]
271272
}
272273
elif (
273274
batch_size is None or batch_size < 0
@@ -293,10 +294,10 @@ def _compute_metrics(
293294
)
294295

295296
# update the metric state
296-
metrics.update_state(targets, predictions)
297+
metrics.update(targets, predictions)
297298

298299
metric_output = metrics.compute()
299-
metrics.reset_state()
300+
metrics.reset()
300301

301302
model_name: str = "model_for_%s" % prediction_column
302303
results.setdefault(model_name, {})

cyclops/evaluate/fairness/config.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55

66
from datasets import Dataset, config
77

8-
from cyclops.evaluate.metrics.metric import Metric, MetricCollection
8+
from cyclops.evaluate.metrics.experimental.metric import Metric
9+
from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict
910

1011

1112
@dataclass
1213
class FairnessConfig:
1314
"""Configuration for fairness metrics."""
1415

15-
metrics: Union[str, Callable[..., Any], Metric, MetricCollection]
16+
metrics: Union[str, Callable[..., Any], Metric, MetricDict]
1617
dataset: Dataset
1718
groups: Union[str, List[str]]
1819
target_columns: Union[str, List[str]]

cyclops/evaluate/fairness/evaluator.py

Lines changed: 43 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
"""Fairness evaluator."""
2-
32
import inspect
43
import itertools
54
import logging
65
import warnings
76
from datetime import datetime
87
from typing import Any, Callable, Dict, List, Optional, Union
98

9+
import array_api_compat.numpy
1010
import numpy as np
11-
import numpy.typing as npt
1211
import pandas as pd
1312
from datasets import Dataset, config
1413
from datasets.features import Features
@@ -21,15 +20,14 @@
2120
get_columns_as_numpy_array,
2221
set_decode,
2322
)
24-
from cyclops.evaluate.metrics.factory import create_metric
25-
from cyclops.evaluate.metrics.functional.precision_recall_curve import (
23+
from cyclops.evaluate.metrics.experimental.functional.precision_recall_curve import (
2624
_format_thresholds,
25+
_validate_thresholds,
2726
)
28-
from cyclops.evaluate.metrics.metric import Metric, MetricCollection, OperatorMetric
29-
from cyclops.evaluate.metrics.utils import (
30-
_check_thresholds,
31-
_get_value_if_singleton_array,
32-
)
27+
from cyclops.evaluate.metrics.experimental.metric import Metric, OperatorMetric
28+
from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict
29+
from cyclops.evaluate.metrics.experimental.utils.types import Array
30+
from cyclops.evaluate.metrics.factory import create_metric
3331
from cyclops.evaluate.utils import _format_column_names
3432
from cyclops.utils.log import setup_logging
3533

@@ -39,7 +37,7 @@
3937

4038

4139
def evaluate_fairness(
42-
metrics: Union[str, Callable[..., Any], Metric, MetricCollection],
40+
metrics: Union[str, Callable[..., Any], Metric, MetricDict],
4341
dataset: Dataset,
4442
groups: Union[str, List[str]],
4543
target_columns: Union[str, List[str]],
@@ -62,7 +60,7 @@ def evaluate_fairness(
6260
6361
Parameters
6462
----------
65-
metrics : Union[str, Callable[..., Any], Metric, MetricCollection]
63+
metrics : Union[str, Callable[..., Any], Metric, MetricDict]
6664
The metric or metrics to compute. If a string, it should be the name of a
6765
metric provided by CyclOps. If a callable, it should be a function that
6866
takes target, prediction, and optionally threshold/thresholds as arguments
@@ -147,18 +145,14 @@ def evaluate_fairness(
147145
raise TypeError(
148146
"Expected `dataset` to be of type `Dataset`, but got " f"{type(dataset)}.",
149147
)
148+
_validate_thresholds(thresholds)
150149

151-
_check_thresholds(thresholds)
152-
fmt_thresholds: npt.NDArray[np.float_] = _format_thresholds( # type: ignore
153-
thresholds,
154-
)
155-
156-
metrics_: Union[Callable[..., Any], MetricCollection] = _format_metrics(
150+
metrics_: Union[Callable[..., Any], MetricDict] = _format_metrics(
157151
metrics,
158152
metric_name,
159153
**(metric_kwargs or {}),
160154
)
161-
155+
fmt_thresholds = _format_thresholds(thresholds, xp=array_api_compat.numpy)
162156
fmt_groups: List[str] = _format_column_names(groups)
163157
fmt_target_columns: List[str] = _format_column_names(target_columns)
164158
fmt_prediction_columns: List[str] = _format_column_names(prediction_columns)
@@ -361,15 +355,15 @@ def warn_too_many_unique_values(
361355

362356

363357
def _format_metrics(
364-
metrics: Union[str, Callable[..., Any], Metric, MetricCollection],
358+
metrics: Union[str, Callable[..., Any], Metric, MetricDict],
365359
metric_name: Optional[str] = None,
366360
**metric_kwargs: Any,
367-
) -> Union[Callable[..., Any], Metric, MetricCollection]:
361+
) -> Union[Callable[..., Any], Metric, MetricDict]:
368362
"""Format the metrics argument.
369363
370364
Parameters
371365
----------
372-
metrics : Union[str, Callable[..., Any], Metric, MetricCollection]
366+
metrics : Union[str, Callable[..., Any], Metric, MetricDict]
373367
The metrics to use for computing the metric results.
374368
metric_name : str, optional, default=None
375369
The name of the metric. This is only used if `metrics` is a callable.
@@ -379,23 +373,23 @@ def _format_metrics(
379373
380374
Returns
381375
-------
382-
Union[Callable[..., Any], Metric, MetricCollection]
376+
Union[Callable[..., Any], Metric, MetricDict]
383377
The formatted metrics.
384378
385379
Raises
386380
------
387381
TypeError
388-
If `metrics` is not of type `str`, `Callable`, `Metric`, or `MetricCollection`.
382+
If `metrics` is not of type `str`, `Callable`, `Metric`, or `MetricDict`.
389383
390384
"""
391385
if isinstance(metrics, str):
392-
metrics = create_metric(metric_name=metrics, **metric_kwargs)
386+
metrics = create_metric(metric_name=metrics, experimental=True, **metric_kwargs)
393387
if isinstance(metrics, Metric):
394388
if metric_name is not None and isinstance(metrics, OperatorMetric):
395389
# single metric created from arithmetic operation, with given name
396-
return MetricCollection({metric_name: metrics})
397-
return MetricCollection(metrics)
398-
if isinstance(metrics, MetricCollection):
390+
return MetricDict({metric_name: metrics})
391+
return MetricDict(metrics)
392+
if isinstance(metrics, MetricDict):
399393
return metrics
400394
if callable(metrics):
401395
if metric_name is None:
@@ -407,7 +401,7 @@ def _format_metrics(
407401
return metrics
408402

409403
raise TypeError(
410-
f"Expected `metrics` to be of type `str`, `Metric`, `MetricCollection`, or "
404+
f"Expected `metrics` to be of type `str`, `Metric`, `MetricDict`, or "
411405
f"`Callable`, but got {type(metrics)}.",
412406
)
413407

@@ -701,7 +695,7 @@ def _get_slice_spec(
701695

702696

703697
def _compute_metrics( # noqa: C901, PLR0912
704-
metrics: Union[Callable[..., Any], MetricCollection],
698+
metrics: Union[Callable[..., Any], MetricDict],
705699
dataset: Dataset,
706700
target_columns: List[str],
707701
prediction_column: str,
@@ -713,7 +707,7 @@ def _compute_metrics( # noqa: C901, PLR0912
713707
714708
Parameters
715709
----------
716-
metrics : Union[Callable, MetricCollection]
710+
metrics : Union[Callable, MetricDict]
717711
The metrics to compute.
718712
dataset : Dataset
719713
The dataset to compute the metrics on.
@@ -738,12 +732,19 @@ def _compute_metrics( # noqa: C901, PLR0912
738732
"Encountered empty dataset while computing metrics. "
739733
"The metric values will be set to `None`."
740734
)
741-
if isinstance(metrics, MetricCollection):
735+
if isinstance(metrics, MetricDict):
742736
if threshold is not None:
743737
# set the threshold for each metric in the collection
744738
for name, metric in metrics.items():
745-
if hasattr(metric, "threshold"):
739+
if isinstance(metric, Metric) and hasattr(metric, "threshold"):
746740
metric.threshold = threshold
741+
elif isinstance(metric, OperatorMetric):
742+
if hasattr(metric.metric_a, "threshold") and hasattr(
743+
metric.metric_b,
744+
"threshold",
745+
):
746+
metric.metric_a.threshold = threshold
747+
metric.metric_b.threshold = threshold # type: ignore[union-attr]
747748
else:
748749
LOGGER.warning(
749750
"Metric %s does not have a threshold attribute. "
@@ -754,7 +755,7 @@ def _compute_metrics( # noqa: C901, PLR0912
754755
if len(dataset) == 0:
755756
warnings.warn(empty_dataset_msg, RuntimeWarning, stacklevel=1)
756757
results: Dict[str, Any] = {
757-
metric_name: float("NaN") for metric_name in metrics
758+
metric_name: float("NaN") for metric_name in metrics # type: ignore[attr-defined]
758759
}
759760
elif (
760761
batch_size is None or batch_size <= 0
@@ -779,11 +780,11 @@ def _compute_metrics( # noqa: C901, PLR0912
779780
columns=prediction_column,
780781
)
781782

782-
metrics.update_state(targets, predictions)
783+
metrics.update(targets, predictions)
783784

784785
results = metrics.compute()
785786

786-
metrics.reset_state()
787+
metrics.reset()
787788

788789
return results
789790
if callable(metrics):
@@ -817,26 +818,26 @@ def _compute_metrics( # noqa: C901, PLR0912
817818
return {metric_name.title(): output}
818819

819820
raise TypeError(
820-
"The `metrics` argument must be a string, a Metric, a MetricCollection, "
821+
"The `metrics` argument must be a string, a Metric, a MetricDict, "
821822
f"or a callable. Got {type(metrics)}.",
822823
)
823824

824825

825826
def _get_metric_results_for_prediction_and_slice(
826-
metrics: Union[Callable[..., Any], MetricCollection],
827+
metrics: Union[Callable[..., Any], MetricDict],
827828
dataset: Dataset,
828829
target_columns: List[str],
829830
prediction_column: str,
830831
slice_name: str,
831832
batch_size: Optional[int] = config.DEFAULT_MAX_BATCH_SIZE,
832833
metric_name: Optional[str] = None,
833-
thresholds: Optional[npt.NDArray[np.float_]] = None,
834+
thresholds: Optional[Array] = None,
834835
) -> Dict[str, Dict[str, Any]]:
835836
"""Compute metrics for a slice of a dataset.
836837
837838
Parameters
838839
----------
839-
metrics : Union[Callable, MetricCollection]
840+
metrics : Union[Callable, MetricDict]
840841
The metrics to compute.
841842
dataset : Dataset
842843
The dataset to compute the metrics on.
@@ -850,7 +851,7 @@ def _get_metric_results_for_prediction_and_slice(
850851
The batch size to use for the computation.
851852
metric_name : Optional[str]
852853
The name of the metric to compute.
853-
thresholds : Optional[List[float]]
854+
thresholds : Optional[Array]
854855
The thresholds to use for the metrics.
855856
856857
Returns
@@ -873,7 +874,7 @@ def _get_metric_results_for_prediction_and_slice(
873874
return {slice_name: metric_output}
874875

875876
results: Dict[str, Dict[str, Any]] = {}
876-
for threshold in thresholds:
877+
for threshold in thresholds: # type: ignore[attr-defined]
877878
metric_output = _compute_metrics(
878879
metrics=metrics,
879880
dataset=dataset,
@@ -969,11 +970,7 @@ def _compute_parity_metrics(
969970
)
970971

971972
parity_results[key].setdefault(slice_name, {}).update(
972-
{
973-
parity_metric_name: _get_value_if_singleton_array(
974-
parity_metric_value,
975-
),
976-
},
973+
{parity_metric_name: parity_metric_value},
977974
)
978975

979976
return parity_results

cyclops/evaluate/metrics/experimental/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@
99
MulticlassAUROC,
1010
MultilabelAUROC,
1111
)
12+
from cyclops.evaluate.metrics.experimental.average_precision import (
13+
BinaryAveragePrecision,
14+
MulticlassAveragePrecision,
15+
MultilabelAveragePrecision,
16+
)
1217
from cyclops.evaluate.metrics.experimental.confusion_matrix import (
1318
BinaryConfusionMatrix,
1419
MulticlassConfusionMatrix,

0 commit comments

Comments (0)