Skip to content

Commit 2babe3a

Browse files
authored
Merge pull request #380 from fslaborg/repo-assist/improve-issue-198-multilabel-thresholds-20260424-dabf9617a37643ae
[Repo Assist] feat: add explicit-thresholds overloads to multiLabelThresholdMap and calculateMultiLabelROC
2 parents 63ead34 + 3f6248d commit 2babe3a

2 files changed

Lines changed: 74 additions & 15 deletions

File tree

src/FSharp.Stats/Testing/ComparisonMetrics.fs

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -368,29 +368,28 @@ type ComparisonMetrics = {
368368

369369
static member multiLabelThresholdMap(
370370
actual: #IConvertible [],
371-
predictions: (#IConvertible * float []) []
372-
) =
373-
374-
// we have to use a global threshold collection for all binary threshold maps, otherwise we do not necessarily have values for macro/micro averaging for each label.
375-
let allDistinctThresholds =
371+
predictions: (#IConvertible * float []) [],
372+
thresholds: float []
373+
) =
374+
// Prefix the threshold list with (global maximum prediction + 1.0) so that the micro/macro averages are labeled with one consistent leading threshold.
375+
let globalMax =
376376
predictions
377377
|> Array.map snd
378378
|> Array.concat
379-
|> Array.distinct
380-
|> Array.sortDescending
379+
|> Array.max
381380

382-
let prefixedThresholds = [|allDistinctThresholds[0] + 1.; yield! allDistinctThresholds|]
381+
let prefixedThresholds = [|globalMax + 1.; yield! thresholds|]
383382

384383
let labelMetrics =
385-
predictions
384+
predictions
386385
|> Array.map (fun (label, preds) ->
387386
let labelTruth = actual |> Array.map (fun x -> x = label)
388-
label, BinaryConfusionMatrix.thresholdMap(labelTruth,preds,allDistinctThresholds)
387+
label, BinaryConfusionMatrix.thresholdMap(labelTruth, preds, thresholds)
389388
)
390389

391390
let transposedBCMs =
392391
labelMetrics
393-
|> Array.map (fun (x,y) -> y)
392+
|> Array.map snd
394393
|> JaggedArray.transpose
395394
|> JaggedArray.map snd
396395

@@ -405,12 +404,27 @@ type ComparisonMetrics = {
405404
|> Array.zip prefixedThresholds
406405

407406
[|
408-
yield! labelMetrics |> Array.map (fun (label, thrs) -> string label, thrs |> Array.map (fun (thr,bcm) -> thr, ComparisonMetrics.create bcm))
407+
yield! labelMetrics |> Array.map (fun (label, thrs) -> string label, thrs |> Array.map (fun (thr, bcm) -> thr, ComparisonMetrics.create bcm))
409408
"micro-average", microAverages
410409
"macro-average", macroAverages
411410
|]
412411
|> Map.ofArray
413412

413+
static member multiLabelThresholdMap(
414+
actual: #IConvertible [],
415+
predictions: (#IConvertible * float []) []
416+
) =
417+
418+
// We have to use one global threshold collection for all binary threshold maps; otherwise we would not necessarily have values at every threshold for each label, which macro/micro averaging requires.
419+
let allDistinctThresholds =
420+
predictions
421+
|> Array.map snd
422+
|> Array.concat
423+
|> Array.distinct
424+
|> Array.sortDescending
425+
426+
ComparisonMetrics.multiLabelThresholdMap(actual, predictions, allDistinctThresholds)
427+
414428
static member calculateROC(
415429
actual: seq<bool>,
416430
predictions: seq<float>,
@@ -439,6 +453,18 @@ type ComparisonMetrics = {
439453
metrics.FallOut, metrics.Sensitivity
440454
)
441455

456+
static member calculateMultiLabelROC(
457+
actual: #IConvertible [],
458+
predictions: (#IConvertible * float []) [],
459+
thresholds: float []
460+
) =
461+
ComparisonMetrics.multiLabelThresholdMap(
462+
actual,
463+
predictions,
464+
thresholds
465+
)
466+
|> Map.map (fun _k v -> v |> Array.map (fun (_,cm) -> cm.FallOut, cm.Sensitivity))
467+
442468
static member calculateMultiLabelROC(
443469
actual: #IConvertible [],
444470
predictions: (#IConvertible * float []) []
@@ -447,6 +473,4 @@ type ComparisonMetrics = {
447473
actual,
448474
predictions
449475
)
450-
|> Map.map (fun k v -> v |> Array.map (fun (_,cm) -> cm.FallOut, cm.Sensitivity)
451-
452-
)
476+
|> Map.map (fun _k v -> v |> Array.map (fun (_,cm) -> cm.FallOut, cm.Sensitivity))

tests/FSharp.Stats.Tests/Testing.fs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1170,6 +1170,41 @@ let comparisonMetricsTests =
11701170
testCase "C: threshold 0-1" (fun _ -> TestExtensions.comparisonMetricsEqualRounded 3 (snd (actual["C"][9])) (snd (expectedMetricsMap["C"][9])) "Incorrect metrics for threshold 0.1")
11711171
testCase "C: threshold 0-0" (fun _ -> TestExtensions.comparisonMetricsEqualRounded 3 (snd (actual["C"][10])) (snd (expectedMetricsMap["C"][10])) "Incorrect metrics for threshold 0.0")
11721172
]
1173+
testList "multi-label threshold map with explicit thresholds" [
1174+
// Use a coarse threshold list [0.9; 0.5; 0.1] — a subset of all distinct thresholds.
1175+
// Expected values are taken from the full-threshold test above (same data).
1176+
let explicitThresholds = [|0.9; 0.5; 0.1|]
1177+
let actualExplicit =
1178+
ComparisonMetrics.multiLabelThresholdMap(
1179+
actual = [|"A"; "A"; "A"; "A"; "A"; "B"; "B"; "B"; "C"; "C"; "C"; "C"; "C"; "C"|],
1180+
predictions = [|
1181+
"A", [|0.8; 0.7; 0.9; 0.4; 0.3; 0.1; 0.2; 0.5; 0.1; 0.1; 0.1; 0.3; 0.5; 0.4|]
1182+
"B", [|0.0; 0.1; 0.0; 0.5; 0.1; 0.8; 0.7; 0.4; 0.0; 0.1; 0.1; 0.0; 0.1; 0.3|]
1183+
"C", [|0.2; 0.2; 0.1; 0.1; 0.6; 0.1; 0.1; 0.1; 0.9; 0.8; 0.8; 0.7; 0.4; 0.3|]
1184+
|],
1185+
thresholds = explicitThresholds
1186+
)
1187+
// With 3 explicit thresholds the result should have 4 entries per label (prefix + 3)
1188+
testCase "explicit thresholds: result length" (fun _ ->
1189+
Expect.equal actualExplicit["A"].Length 4 "Expected 4 threshold entries for label A with 3 explicit thresholds"
1190+
)
1191+
// Values at threshold 0.9 should match the full-threshold result at that threshold
1192+
testCase "A: explicit threshold 0-9" (fun _ ->
1193+
TestExtensions.comparisonMetricsEqualRounded 3 (snd (actualExplicit["A"][1])) (BinaryConfusionMatrix.create(1,9,0,4) |> ComparisonMetrics.create) "Incorrect A metrics at threshold 0.9"
1194+
)
1195+
testCase "B: explicit threshold 0-5" (fun _ ->
1196+
TestExtensions.comparisonMetricsEqualRounded 3 (snd (actualExplicit["B"][2])) (BinaryConfusionMatrix.create(2,10,1,1) |> ComparisonMetrics.create) "Incorrect B metrics at threshold 0.5"
1197+
)
1198+
testCase "C: explicit threshold 0-1" (fun _ ->
1199+
TestExtensions.comparisonMetricsEqualRounded 3 (snd (actualExplicit["C"][3])) (BinaryConfusionMatrix.create(6,0,8,0) |> ComparisonMetrics.create) "Incorrect C metrics at threshold 0.1"
1200+
)
1201+
testCase "micro-average present" (fun _ ->
1202+
Expect.isTrue (actualExplicit.ContainsKey("micro-average")) "micro-average key should be present"
1203+
)
1204+
testCase "macro-average present" (fun _ ->
1205+
Expect.isTrue (actualExplicit.ContainsKey("macro-average")) "macro-average key should be present"
1206+
)
1207+
]
11731208
]
11741209

11751210

0 commit comments

Comments
 (0)