Skip to content

Commit ee1fb2b

Browse files
committed
fix: validate probabilities for temperature scaling
1 parent 055f55c commit ee1fb2b

1 file changed

Lines changed: 27 additions & 0 deletions

File tree

src/PredictionModelBuilder.cs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -721,14 +721,41 @@ private static void TryComputeClassificationCalibrationArtifacts(
721721
private static T FitTemperatureFromProbabilities(Tensor<T> probabilities, Vector<int> labels, int batch, int classes, INumericOperations<T> numOps)
722722
{
723723
var eps = numOps.FromDouble(1e-12);
724+
var sumTolerance = numOps.FromDouble(1e-3);
724725
var logits = new Matrix<T>(batch, classes);
725726
var flat = probabilities.ToVector();
726727

727728
for (int i = 0; i < batch; i++)
728729
{
730+
var sum = numOps.Zero;
729731
for (int c = 0; c < classes; c++)
730732
{
731733
var p = flat[i * classes + c];
734+
if (numOps.LessThan(p, numOps.Zero))
735+
{
736+
p = numOps.Zero;
737+
}
738+
sum = numOps.Add(sum, p);
739+
}
740+
741+
if (numOps.LessThan(sum, eps))
742+
{
743+
throw new ArgumentException("Temperature scaling requires per-sample probabilities with a positive sum.", nameof(probabilities));
744+
}
745+
746+
var shouldNormalize = numOps.GreaterThan(numOps.Abs(numOps.Subtract(sum, numOps.One)), sumTolerance);
747+
748+
for (int c = 0; c < classes; c++)
749+
{
750+
var p = flat[i * classes + c];
751+
if (numOps.LessThan(p, numOps.Zero))
752+
{
753+
p = numOps.Zero;
754+
}
755+
if (shouldNormalize)
756+
{
757+
p = numOps.Divide(p, sum);
758+
}
732759
if (numOps.LessThan(p, eps))
733760
{
734761
p = eps;

0 commit comments

Comments
 (0)