Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions metrics/f1/f1.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@
- 'weighted': Calculate metrics for each label, and find their average weighted by support (the number of true instances for each label). This alters `'macro'` to account for label imbalance. This option can result in an F-score that is not between precision and recall.
- 'samples': Calculate metrics for each instance, and find their average (only meaningful for multilabel classification).
sample_weight (`list` of `float`): Sample weights. Defaults to None.
zero_division (`int` or `"warn"`, optional): Passed directly to sklearn's `f1_score`. Sets the value to return when a zero division occurs, i.e. when a label has no predicted samples or no true samples. Accepts `0`, `1`, or `"warn"`; if omitted, sklearn's own default applies.

- 0: Returns 0 when there is a zero division.
- 1: Returns 1 when there is a zero division.
- `"warn"`: Issues a warning and then returns 0 when there is a zero division.

Returns:
f1 (`float` or `array` of `float`): F1 score or list of f1 scores, depending on the value passed to `average`. Minimum possible value is 0. Maximum possible value is 1. Higher f1 scores are better.
Expand Down Expand Up @@ -84,6 +89,13 @@
>>> results = f1_metric.compute(predictions=[[0, 1, 1], [1, 1, 0]], references=[[0, 1, 1], [0, 1, 0]], average="macro")
>>> print(round(results['f1'], 2))
0.67

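Example 6-The same multiclass example as in Example 4, but with `zero_division` set to `1`, so a label with neither predicted nor true samples (label `3` here) scores 1.0 instead of 0.0. Labels that have true samples but no predictions (labels `1` and `2`) still score 0.0.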
>>> predictions = [0, 0, 0, 0, 0]
>>> references = [0, 1, 0, 1, 2]
>>> results = f1_metric.compute(predictions=predictions, references=references, average=None, labels=[0, 1, 2, 3], zero_division=1)
>>> print([round(res, 2) for res in results['f1']])
[0.57, 0.0, 0.0, 1.0]
"""


Expand Down Expand Up @@ -123,8 +135,17 @@ def _info(self):
reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html"],
)

def _compute(self, predictions, references, labels=None, pos_label=1, average="binary", sample_weight=None):
def _compute(
self,
predictions,
references,
labels=None,
**kwargs,
):
score = f1_score(
references, predictions, labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight
references,
predictions,
labels=labels,
**kwargs,
)
Comment on lines +138 to 150
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given there's no typing anyway, specifying defaults and argument keywords is redundant. Unless the evaluate package has other defaults, it might be even more future proof to just pass kwargs:

Suggested change
def _compute(
self,
predictions,
references,
labels=None,
pos_label=1,
average="binary",
sample_weight=None,
zero_division="warn",
):
score = f1_score(
references, predictions, labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight
references,
predictions,
labels=labels,
pos_label=pos_label,
average=average,
sample_weight=sample_weight,
zero_division=zero_division,
)
def _compute(self, predictions, references, labels=None, **kwargs):
    """Compute the F1 score by delegating to `f1_score`.

    Args:
        predictions: Predicted labels; handed to `f1_score` as its second
            positional argument.
        references: Ground-truth labels; handed to `f1_score` as its first
            positional argument.
        labels: Forwarded unchanged as the `labels` keyword. Defaults to None.
        **kwargs: Any further keyword arguments, forwarded to `f1_score`
            untouched, so their defaults are whatever `f1_score` defines.

    Returns:
        A dict with the single key `"f1"`. The value is the score object
        itself when it reports a `size` greater than 1, otherwise the score
        converted to a plain `float`.
    """
    result = f1_score(references, predictions, labels=labels, **kwargs)

    # A result without a `size` attribute is treated as having size 1.
    if getattr(result, "size", 1) > 1:
        return {"f1": result}
    return {"f1": float(result)}