Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit 8d5d573

Browse files
committed
use local variables to hold unique classes
1 parent 06392d2 commit 8d5d573

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

bigframes/ml/metrics/_metrics.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -331,18 +331,15 @@ def _precision_score_per_class(y_true: bpd.Series, y_pred: bpd.Series) -> pd.Ser
331331
def _precision_score_binary_pos_only(
332332
y_true: bpd.Series, y_pred: bpd.Series, pos_label: int | float | bool | str
333333
) -> float:
334-
if (
335-
y_true.unique(keep_order=False).count() != 2
336-
or y_pred.unique(keep_order=False).count() != 2
337-
):
334+
y_true_classes = y_true.unique(keep_order=False)
335+
y_pred_classes = y_pred.unique(keep_order=False)
336+
337+
if y_true_classes.count() != 2 or y_pred_classes.count() != 2:
338338
raise ValueError(
339339
"Target is multiclass but average='binary'. Please choose another average setting."
340340
)
341341

342-
total_labels = set(
343-
y_true.unique(keep_order=False).to_list()
344-
+ y_pred.unique(keep_order=False).to_list()
345-
)
342+
total_labels = set(y_true_classes.to_list() + y_pred_classes.to_list())
346343

347344
if len(total_labels) != 2:
348345
raise ValueError(

0 commit comments

Comments
 (0)