Skip to content

Commit df4739d

Browse files
authored
Add AUROC (#546)
* Add AUROC metric to experimental module
* Refactor binary and multiclass ROC functions
* Refactor tests to use a common thresholds list
* Fix mypy error
1 parent 6a91acf commit df4739d

11 files changed

Lines changed: 1770 additions & 58 deletions

File tree

cyclops/evaluate/metrics/experimental/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44
MulticlassAccuracy,
55
MultilabelAccuracy,
66
)
7+
from cyclops.evaluate.metrics.experimental.auroc import (
8+
BinaryAUROC,
9+
MulticlassAUROC,
10+
MultilabelAUROC,
11+
)
712
from cyclops.evaluate.metrics.experimental.confusion_matrix import (
813
BinaryConfusionMatrix,
914
MulticlassConfusionMatrix,
Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
"""Classes for computing the area under the ROC curve."""
2+
from typing import List, Literal, Optional, Tuple, Union
3+
4+
from cyclops.evaluate.metrics.experimental.functional.auroc import (
5+
_binary_auroc_compute,
6+
_binary_auroc_validate_args,
7+
_multiclass_auroc_compute,
8+
_multiclass_auroc_validate_args,
9+
_multilabel_auroc_compute,
10+
_multilabel_auroc_validate_args,
11+
)
12+
from cyclops.evaluate.metrics.experimental.precision_recall_curve import (
13+
BinaryPrecisionRecallCurve,
14+
MulticlassPrecisionRecallCurve,
15+
MultilabelPrecisionRecallCurve,
16+
)
17+
from cyclops.evaluate.metrics.experimental.utils.ops import dim_zero_cat
18+
from cyclops.evaluate.metrics.experimental.utils.types import Array
19+
20+
21+
class BinaryAUROC(BinaryPrecisionRecallCurve):
    """Area under the Receiver Operating Characteristic (ROC) curve.

    Parameters
    ----------
    max_fpr : float, optional, default=None
        If not `None`, computes the maximum area under the curve up to the given
        false positive rate value. Must be a float in the range (0, 1].
    thresholds : Union[int, List[float], Array], optional, default=None
        The thresholds to use for computing the ROC curve. Can be one of the following:
        - `None`: use all unique values in `preds` as thresholds.
        - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range
          [0, 1].
        - `List[float]`: use the values in the list as bins for the thresholds.
        - `Array`: use the values in the Array as bins for the thresholds. The
          array must be 1D.
    ignore_index : int, optional, default=None
        The value in `target` that should be ignored when computing the AUROC.
        If `None`, all values in `target` are used.

    Examples
    --------
    >>> import numpy.array_api as anp
    >>> from cyclops.evaluate.metrics.experimental import BinaryAUROC
    >>> target = anp.asarray([0, 1, 1, 0, 1, 0, 0, 1])
    >>> preds = anp.asarray([0.1, 0.4, 0.35, 0.8, 0.2, 0.6, 0.7, 0.3])
    >>> auroc = BinaryAUROC(thresholds=None)
    >>> auroc(target, preds)
    Array(0.25, dtype=float32)
    >>> auroc = BinaryAUROC(thresholds=5)
    >>> auroc(target, preds)
    Array(0.21875, dtype=float32)

    """

    name: str = "AUC ROC Curve"

    def __init__(
        self,
        max_fpr: Optional[float] = None,
        thresholds: Optional[Union[int, List[float], Array]] = None,
        ignore_index: Optional[int] = None,
    ) -> None:
        """Initialize the BinaryAUROC metric."""
        super().__init__(thresholds=thresholds, ignore_index=ignore_index)
        # Validate the AUROC-specific arguments (e.g. `max_fpr`) on top of the
        # validation already performed by the parent class.
        _binary_auroc_validate_args(
            max_fpr=max_fpr,
            thresholds=thresholds,
            ignore_index=ignore_index,
        )
        self.max_fpr = max_fpr

    def _compute_metric(self) -> Array:  # type: ignore[override]
        """Compute the AUROC."""
        # Fix: removed a stray `""` that was implicitly concatenated onto the
        # docstring above in the original (`"""Compute the AUROC.""" ""`).
        # When `thresholds` is None the state is the accumulated (target, preds)
        # pair; otherwise it is the running confusion matrix from the parent.
        state = (
            (dim_zero_cat(self.target), dim_zero_cat(self.preds))  # type: ignore[attr-defined]
            if self.thresholds is None
            else self.confmat  # type: ignore[attr-defined]
        )
        return _binary_auroc_compute(state, thresholds=self.thresholds, max_fpr=self.max_fpr)  # type: ignore
81+
82+
class MulticlassAUROC(MulticlassPrecisionRecallCurve):
    """Area under the Receiver Operating Characteristic (ROC) curve.

    Parameters
    ----------
    num_classes : int
        The number of classes in the classification problem.
    thresholds : Union[int, List[float], Array], optional, default=None
        The thresholds to use for computing the ROC curve. Can be one of the following:
        - `None`: use all unique values in `preds` as thresholds.
        - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range
          [0, 1].
        - `List[float]`: use the values in the list as bins for the thresholds.
        - `Array`: use the values in the Array as bins for the thresholds. The
          array must be 1D.
    average : {"macro", "weighted", "none"}, optional, default="macro"
        The type of averaging to use for computing the AUROC. Can be one of
        the following:
        - `"macro"`: interpolates the curves from each class at a combined set of
          thresholds and then average over the classwise interpolated curves.
        - `"weighted"`: average over the classwise curves weighted by the support
          (the number of true instances for each class).
        - `"none"`: do not average over the classwise curves.
    ignore_index : int or Tuple[int], optional, default=None
        The value(s) in `target` that should be ignored when computing the AUROC.
        If `None`, all values in `target` are used.

    Examples
    --------
    >>> import numpy.array_api as anp
    >>> from cyclops.evaluate.metrics.experimental import MulticlassAUROC
    >>> target = anp.asarray([0, 1, 2, 0, 1, 2])
    >>> preds = anp.asarray(
    ...     [[0.11, 0.22, 0.67],
    ...      [0.84, 0.73, 0.12],
    ...      [0.33, 0.92, 0.44],
    ...      [0.11, 0.22, 0.67],
    ...      [0.84, 0.73, 0.12],
    ...      [0.33, 0.92, 0.44]])
    >>> auroc = MulticlassAUROC(num_classes=3, average="macro", thresholds=None)
    >>> auroc(target, preds)
    Array(0.33333334, dtype=float32)
    >>> auroc = MulticlassAUROC(num_classes=3, average=None, thresholds=None)
    >>> auroc(target, preds)
    Array([0. , 0.5, 0.5], dtype=float32)
    >>> auroc = MulticlassAUROC(num_classes=3, average="macro", thresholds=5)
    >>> auroc(target, preds)
    Array(0.33333334, dtype=float32)
    >>> auroc = MulticlassAUROC(num_classes=3, average=None, thresholds=5)
    >>> auroc(target, preds)
    Array([0. , 0.5, 0.5], dtype=float32)

    """

    name: str = "AUC ROC Curve"

    def __init__(
        self,
        num_classes: int,
        thresholds: Optional[Union[int, List[float], Array]] = None,
        average: Optional[Literal["macro", "weighted", "none"]] = "macro",
        ignore_index: Optional[Union[int, Tuple[int]]] = None,
    ) -> None:
        """Initialize the MulticlassAUROC metric."""
        super().__init__(
            num_classes,
            thresholds=thresholds,
            ignore_index=ignore_index,
        )
        # AUROC-specific validation of `average`, on top of the checks the
        # parent class already performs on the shared arguments.
        _multiclass_auroc_validate_args(
            num_classes=num_classes,
            thresholds=thresholds,
            average=average,
            ignore_index=ignore_index,
        )
        self.average = average  # type: ignore[assignment]

    def _compute_metric(self) -> Array:  # type: ignore[override]
        """Compute the AUROC."""
        # Without fixed thresholds, all (target, preds) pairs are accumulated
        # and concatenated here; otherwise the parent maintains a running
        # confusion matrix that serves as the state.
        if self.thresholds is None:
            state = (
                dim_zero_cat(self.target),  # type: ignore[attr-defined]
                dim_zero_cat(self.preds),  # type: ignore[attr-defined]
            )
        else:
            state = self.confmat  # type: ignore[attr-defined]
        return _multiclass_auroc_compute(
            state,
            self.num_classes,
            thresholds=self.thresholds,  # type: ignore[arg-type]
            average=self.average,  # type: ignore[arg-type]
        )
171+
172+
173+
class MultilabelAUROC(MultilabelPrecisionRecallCurve):
    """Area under the Receiver Operating Characteristic (ROC) curve.

    Parameters
    ----------
    num_labels : int
        The number of labels in the multilabel classification problem.
    thresholds : Union[int, List[float], Array], optional, default=None
        The thresholds to use for computing the ROC curve. Can be one of the following:
        - `None`: use all unique values in `preds` as thresholds.
        - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range
          [0, 1].
        - `List[float]`: use the values in the list as bins for the thresholds.
        - `Array`: use the values in the Array as bins for the thresholds. The
          array must be 1D.
    average : {"micro", "macro", "weighted", "none"}, optional, default="macro"
        The type of averaging to use for computing the AUROC. Can be one of
        the following:
        - `"micro"`: compute the AUROC globally by considering each element of the
          label indicator matrix as a label.
        - `"macro"`: compute the AUROC for each label and average them.
        - `"weighted"`: compute the AUROC for each label and average them weighted
          by the support (the number of true instances for each label).
        - `"none"`: do not average over the labelwise AUROC.
    ignore_index : int, optional, default=None
        The value in `target` that should be ignored when computing the AUROC.
        If `None`, all values in `target` are used.

    Examples
    --------
    >>> import numpy.array_api as anp
    >>> from cyclops.evaluate.metrics.experimental import MultilabelAUROC
    >>> target = anp.asarray([[0, 1, 0], [1, 1, 0], [0, 0, 1]])
    >>> preds = anp.asarray(
    ...     [[0.11, 0.22, 0.67], [0.84, 0.73, 0.12], [0.33, 0.92, 0.44]],
    ... )
    >>> auroc = MultilabelAUROC(num_labels=3, average="macro", thresholds=None)
    >>> auroc(target, preds)
    Array(0.5, dtype=float32)
    >>> auroc = MultilabelAUROC(num_labels=3, average=None, thresholds=None)
    >>> auroc(target, preds)
    Array([1. , 0. , 0.5], dtype=float32)
    >>> auroc = MultilabelAUROC(num_labels=3, average="macro", thresholds=5)
    >>> auroc(target, preds)
    Array(0.5, dtype=float32)
    >>> auroc = MultilabelAUROC(num_labels=3, average=None, thresholds=5)
    >>> auroc(target, preds)
    Array([1. , 0. , 0.5], dtype=float32)

    """

    name: str = "AUC ROC Curve"

    def __init__(
        self,
        num_labels: int,
        thresholds: Optional[Union[int, List[float], Array]] = None,
        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
        ignore_index: Optional[int] = None,
    ) -> None:
        """Initialize the MultilabelAUROC metric."""
        super().__init__(
            num_labels,
            thresholds=thresholds,
            ignore_index=ignore_index,
        )
        # AUROC-specific validation of `average`, on top of the checks the
        # parent class already performs on the shared arguments.
        _multilabel_auroc_validate_args(
            num_labels=num_labels,
            thresholds=thresholds,
            average=average,
            ignore_index=ignore_index,
        )
        self.average = average

    def _compute_metric(self) -> Array:  # type: ignore[override]
        """Compute the AUROC."""
        # Without fixed thresholds the state is the concatenated (target, preds)
        # pair; otherwise it is the running confusion matrix from the parent.
        state = (
            (dim_zero_cat(self.target), dim_zero_cat(self.preds))  # type: ignore[attr-defined]
            if self.thresholds is None
            else self.confmat  # type: ignore[attr-defined]
        )
        return _multilabel_auroc_compute(
            state,
            self.num_labels,
            thresholds=self.thresholds,  # type: ignore[arg-type]
            average=self.average,
            ignore_index=self.ignore_index,
        )

cyclops/evaluate/metrics/experimental/functional/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44
multiclass_accuracy,
55
multilabel_accuracy,
66
)
7+
from cyclops.evaluate.metrics.experimental.functional.auroc import (
8+
binary_auroc,
9+
multiclass_auroc,
10+
multilabel_auroc,
11+
)
712
from cyclops.evaluate.metrics.experimental.functional.confusion_matrix import (
813
binary_confusion_matrix,
914
multiclass_confusion_matrix,

0 commit comments

Comments
 (0)