Skip to content

Commit e21820e

Browse files
committed
added gmi measure
1 parent ddbd878 commit e21820e

10 files changed

Lines changed: 36 additions & 346 deletions

File tree

fairbench/v2/blocks/measures/classification.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,3 +257,32 @@ def f1(predictions, labels, sensitive=None):
257257
quantities.an(n),
258258
],
259259
)
260+
261+
262+
263+
@c.measure("the geometric mean of tpr and tnr - accounts for class imbalance")
264+
def gmi(predictions, labels, sensitive=None):
265+
predictions = np.array(predictions)
266+
labels = np.array(labels)
267+
sensitive = np.ones_like(predictions) if sensitive is None else np.array(sensitive)
268+
269+
tp = (predictions * sensitive * labels).sum()
270+
tn = ((1 - predictions) * sensitive * (1 - labels)).sum()
271+
p = (labels * sensitive).sum()
272+
n = ((1 - labels) * sensitive).sum()
273+
fn = p - tp
274+
fp = n - tn
275+
276+
precision = 0 if (tp + fp) == 0 else tp / (tp + fp)
277+
recall = 0 if (tp + fn) == 0 else tp / (tp + fn)
278+
value = precision*recall
279+
280+
return c.Value(
281+
c.TargetedNumber(value, 1),
282+
depends=[
283+
quantities.tp(tp),
284+
quantities.tn(tn),
285+
quantities.ap(p),
286+
quantities.an(n),
287+
],
288+
)

fairbench/v2/core/framework.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def wrapper(**kwargs) -> Value:
3838
or isinstance(value.value, int)
3939
), f"{descriptor} computed {type(value.value)} instead of float, int, Number, or TargetedNumber"
4040
if unit and 1 < float(value) < 1 + eps: # take care of rounding errors
41-
value = 1.0
41+
value.value = 1.0
4242
assert not unit or (
4343
0 <= float(value) <= 1
4444
), f"{descriptor} computed {float(value)} that is not in [0,1]"

fairbench/v2/core/values.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,9 @@ def exists(self) -> bool:
337337
return True
338338
return False
339339

340-
def rebase(self, dep: Descriptor):
340+
def rebase(self, dep: Descriptor|str):
341+
if isinstance(dep,str):
342+
dep = Descriptor(dep, self.descriptor.role)
341343
return Value(self.value, dep, list(self.depends.values()))
342344

343345
def tostring(self, tab="", depth=0, details: bool = False):

fairbench/v2/reports/adhoc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
blocks.measures.tnr,
1313
blocks.measures.ppv,
1414
blocks.measures.f1,
15+
blocks.measures.gmi,
1516
blocks.measures.tar,
1617
blocks.measures.trr,
1718
blocks.measures.lift,

playground/ausc/differentials.py

Lines changed: 0 additions & 157 deletions
This file was deleted.

playground/ausc/fb_interface.py

Lines changed: 0 additions & 17 deletions
This file was deleted.

playground/ausc/models.py

Lines changed: 0 additions & 42 deletions
This file was deleted.

0 commit comments

Comments
 (0)