Skip to content

Commit 205814e

Browse files
committed
Update EntryPointCatalog baselines with Objective and Alpha fields. Improved unit tests.
1 parent 453123a commit 205814e

4 files changed

Lines changed: 88 additions & 9 deletions

File tree

src/Microsoft.ML.LightGbm/LightGbmRegressionTrainer.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ public enum RegressionObjective
142142
/// Quantile regression. Use <see cref="Alpha"/> to set the target quantile.
143143
/// </summary>
144144
Quantile
145-
}
145+
};
146146

147147
/// <summary>
148148
/// Determines what evaluation metric to use.
@@ -179,6 +179,7 @@ static Options()
179179
NameMapping.Add(nameof(EvaluateMetricType.MeanAbsoluteError), "mae");
180180
NameMapping.Add(nameof(EvaluateMetricType.RootMeanSquaredError), "rmse");
181181
NameMapping.Add(nameof(EvaluateMetricType.MeanSquaredError), "mse");
182+
NameMapping.Add(nameof(Objective), "_regression_objective");
182183
}
183184

184185
internal override Dictionary<string, object> ToDictionary(IHost host)

test/BaselineOutput/Common/EntryPoints/core_manifest.json

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13266,6 +13266,36 @@
1326613266
"IsNullable": false,
1326713267
"Default": "RootMeanSquaredError"
1326813268
},
13269+
{
13270+
"Name": "Objective",
13271+
"Type": {
13272+
"Kind": "Enum",
13273+
"Values": [
13274+
"Regression",
13275+
"Quantile"
13276+
]
13277+
},
13278+
"Desc": "Regression objective type. Use 'Quantile' for quantile regression.",
13279+
"Aliases": [
13280+
"obj"
13281+
],
13282+
"Required": false,
13283+
"SortOrder": 150.0,
13284+
"IsNullable": false,
13285+
"Default": "Regression"
13286+
},
13287+
{
13288+
"Name": "Alpha",
13289+
"Type": "Float",
13290+
"Desc": "The alpha (quantile) value for quantile regression. Must be in (0, 1).",
13291+
"Aliases": [
13292+
"qa"
13293+
],
13294+
"Required": false,
13295+
"SortOrder": 150.0,
13296+
"IsNullable": false,
13297+
"Default": 0.5
13298+
},
1326913299
{
1327013300
"Name": "MaximumBinCountPerFeature",
1327113301
"Type": "Int",

test/BaselineOutput/Common/EntryPoints/netcoreapp/core_manifest.json

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13266,6 +13266,36 @@
1326613266
"IsNullable": false,
1326713267
"Default": "RootMeanSquaredError"
1326813268
},
13269+
{
13270+
"Name": "Objective",
13271+
"Type": {
13272+
"Kind": "Enum",
13273+
"Values": [
13274+
"Regression",
13275+
"Quantile"
13276+
]
13277+
},
13278+
"Desc": "Regression objective type. Use 'Quantile' for quantile regression.",
13279+
"Aliases": [
13280+
"obj"
13281+
],
13282+
"Required": false,
13283+
"SortOrder": 150.0,
13284+
"IsNullable": false,
13285+
"Default": "Regression"
13286+
},
13287+
{
13288+
"Name": "Alpha",
13289+
"Type": "Float",
13290+
"Desc": "The alpha (quantile) value for quantile regression. Must be in (0, 1).",
13291+
"Aliases": [
13292+
"qa"
13293+
],
13294+
"Required": false,
13295+
"SortOrder": 150.0,
13296+
"IsNullable": false,
13297+
"Default": 0.5
13298+
},
1326913299
{
1327013300
"Name": "MaximumBinCountPerFeature",
1327113301
"Type": "Int",

test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,13 @@ public void LightGBMQuantileRegressorEstimator()
262262

263263
TestEstimatorCore(trainer, dataView);
264264
var model = trainer.Fit(dataView, dataView);
265+
266+
var gbmParameters = trainer.GetGbmParameters();
267+
Assert.True(gbmParameters.ContainsKey("objective"));
268+
Assert.Equal("quantile", gbmParameters["objective"]);
269+
Assert.True(gbmParameters.ContainsKey("alpha"));
270+
Assert.Equal(0.5, gbmParameters["alpha"]);
271+
265272
Done();
266273
}
267274

@@ -281,6 +288,8 @@ public void LightGBMQuantileRegressorPredictionOrdering()
281288
Alpha = 0.05,
282289
NumberOfIterations = 50,
283290
NumberOfLeaves = 10,
291+
Seed = 42,
292+
Deterministic = true,
284293
});
285294

286295
// Train model for the 95th percentile
@@ -290,6 +299,8 @@ public void LightGBMQuantileRegressorPredictionOrdering()
290299
Alpha = 0.95,
291300
NumberOfIterations = 50,
292301
NumberOfLeaves = 10,
302+
Seed = 42,
303+
Deterministic = true,
293304
});
294305

295306
var modelLow = trainerLow.Fit(dataView);
@@ -304,11 +315,18 @@ public void LightGBMQuantileRegressorPredictionOrdering()
304315
Assert.Equal(scoresLow.Length, scoresHigh.Length);
305316
Assert.True(scoresLow.Length > 0);
306317

307-
// On average, the 95th percentile predictions should exceed the 5th percentile predictions.
308-
var avgLow = scoresLow.Average();
309-
var avgHigh = scoresHigh.Average();
310-
Assert.True(avgHigh > avgLow,
311-
$"Expected average 95th percentile ({avgHigh}) > average 5th percentile ({avgLow})");
318+
// The 95th percentile predictions should generally be at least as large as the
319+
// 5th percentile predictions. Allow a small numerical tolerance and a limited
320+
// number of crossings since the models are trained independently.
321+
const float tolerance = 1e-4f;
322+
var orderedCount = Enumerable.Range(0, scoresLow.Length)
323+
.Count(i => scoresHigh[i] + tolerance >= scoresLow[i]);
324+
var orderedRatio = (float)orderedCount / scoresLow.Length;
325+
326+
Assert.True(orderedRatio >= 0.90f,
327+
$"Expected the 95th percentile prediction to be >= the 5th percentile prediction for most rows, " +
328+
$"but only {orderedCount} of {scoresLow.Length} rows satisfied the condition " +
329+
$"({orderedRatio:P2}, tolerance={tolerance}).");
312330

313331
Done();
314332
}
@@ -327,23 +345,23 @@ public void LightGBMQuantileRegressorInvalidAlpha()
327345
Objective = LightGbmRegressionTrainer.Options.RegressionObjective.Quantile,
328346
Alpha = 0.0,
329347
});
330-
Assert.ThrowsAny<Exception>(() => trainerZero.Fit(dataView));
348+
Assert.Throws<ArgumentOutOfRangeException>(() => trainerZero.Fit(dataView));
331349

332350
// Alpha = 1 should fail
333351
var trainerOne = ML.Regression.Trainers.LightGbm(new LightGbmRegressionTrainer.Options
334352
{
335353
Objective = LightGbmRegressionTrainer.Options.RegressionObjective.Quantile,
336354
Alpha = 1.0,
337355
});
338-
Assert.ThrowsAny<Exception>(() => trainerOne.Fit(dataView));
356+
Assert.Throws<ArgumentOutOfRangeException>(() => trainerOne.Fit(dataView));
339357

340358
// Alpha = -0.1 should fail
341359
var trainerNeg = ML.Regression.Trainers.LightGbm(new LightGbmRegressionTrainer.Options
342360
{
343361
Objective = LightGbmRegressionTrainer.Options.RegressionObjective.Quantile,
344362
Alpha = -0.1,
345363
});
346-
Assert.ThrowsAny<Exception>(() => trainerNeg.Fit(dataView));
364+
Assert.Throws<ArgumentOutOfRangeException>(() => trainerNeg.Fit(dataView));
347365

348366
Done();
349367
}

0 commit comments

Comments
 (0)