Skip to content

Commit c67b8b0

Browse files
authored
Allow for assessing performance in multi-class classification setting (#176)
1 parent 477ad5d commit c67b8b0

1 file changed

Lines changed: 6 additions & 1 deletion

File tree

selene_sdk/utils/performance_metrics.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,12 @@ def compute_score(prediction, target, metric_fn,
200200
`(None, [])`.
201201
"""
202202
feature_scores = np.ones(target.shape[1]) * np.nan
203-
for index, feature_preds in enumerate(prediction.T):
203+
# Deal with the case of multi-class classification, where each example only has one target value but multiple prediction values
204+
if target.shape[1] == 1 and prediction.shape[1] > 1:
205+
prediction = [prediction]
206+
else:
207+
prediction = prediction.T
208+
for index, feature_preds in enumerate(prediction):
204209
feature_targets = target[:, index]
205210
if len(np.unique(feature_targets)) > 0 and \
206211
np.count_nonzero(feature_targets) > report_gt_feature_n_positives:

0 commit comments

Comments
 (0)