diff --git a/src/Microsoft.ML.LightGbm/LightGbmRegressionTrainer.cs b/src/Microsoft.ML.LightGbm/LightGbmRegressionTrainer.cs
index d16ced2de7..62c20bcdd1 100644
--- a/src/Microsoft.ML.LightGbm/LightGbmRegressionTrainer.cs
+++ b/src/Microsoft.ML.LightGbm/LightGbmRegressionTrainer.cs
@@ -129,6 +129,21 @@ public enum EvaluateMetricType
MeanSquaredError
};
+ /// <summary>
+ /// The type of regression objective to use.
+ /// </summary>
+ public enum RegressionObjective
+ {
+ /// <summary>
+ /// Standard L2 (least squares) regression.
+ /// </summary>
+ Regression,
+ /// <summary>
+ /// Quantile regression. Use <see cref="Alpha"/> to set the target quantile.
+ /// </summary>
+ Quantile
+ };
+
/// <summary>
/// Determines what evaluation metric to use.
/// </summary>
@@ -137,6 +152,25 @@ public enum EvaluateMetricType
ShortName = "em")]
public EvaluateMetricType EvaluationMetric = EvaluateMetricType.RootMeanSquaredError;
+ /// <summary>
+ /// The regression objective type. Use <see cref="RegressionObjective.Quantile"/> with <see cref="Alpha"/>
+ /// for quantile regression.
+ /// </summary>
+ [Argument(ArgumentType.AtMostOnce,
+ HelpText = "Regression objective type. Use 'Quantile' for quantile regression.",
+ ShortName = "obj")]
+ public RegressionObjective Objective = RegressionObjective.Regression;
+
+ /// <summary>
+ /// The quantile to predict when <see cref="Objective"/> is <see cref="RegressionObjective.Quantile"/>.
+ /// Must be in the open interval (0, 1). For example, 0.05 for the 5th percentile or
+ /// 0.95 for the 95th percentile.
+ /// </summary>
+ [Argument(ArgumentType.AtMostOnce,
+ HelpText = "The alpha (quantile) value for quantile regression. Must be in (0, 1).",
+ ShortName = "qa")]
+ public double Alpha = 0.5;
+
static Options()
{
NameMapping.Add(nameof(EvaluateMetricType), "metric");
@@ -145,6 +179,7 @@ static Options()
NameMapping.Add(nameof(EvaluateMetricType.MeanAbsoluteError), "mae");
NameMapping.Add(nameof(EvaluateMetricType.RootMeanSquaredError), "rmse");
NameMapping.Add(nameof(EvaluateMetricType.MeanSquaredError), "mse");
+ NameMapping.Add(nameof(Objective), "_regression_objective");
}
internal override Dictionary<string, object> ToDictionary(IHost host)
@@ -240,7 +275,19 @@ private protected override void CheckDataValid(IChannel ch, RoleMappedData data)
private protected override void CheckAndUpdateParametersBeforeTraining(IChannel ch, RoleMappedData data, float[] labels, int[] groups)
{
- GbmOptions["objective"] = "regression";
+ var regressionOptions = (Options)LightGbmTrainerOptions;
+
+ if (regressionOptions.Objective == Options.RegressionObjective.Quantile)
+ {
+ Contracts.CheckUserArg(regressionOptions.Alpha > 0 && regressionOptions.Alpha < 1,
+ nameof(Options.Alpha), "Alpha for quantile regression must be in the open interval (0, 1).");
+ GbmOptions["objective"] = "quantile";
+ GbmOptions["alpha"] = regressionOptions.Alpha;
+ }
+ else
+ {
+ GbmOptions["objective"] = "regression";
+ }
}
private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json
index 55e9439cf4..b30265cea2 100644
--- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json
+++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json
@@ -13266,6 +13266,36 @@
"IsNullable": false,
"Default": "RootMeanSquaredError"
},
+ {
+ "Name": "Objective",
+ "Type": {
+ "Kind": "Enum",
+ "Values": [
+ "Regression",
+ "Quantile"
+ ]
+ },
+ "Desc": "Regression objective type. Use 'Quantile' for quantile regression.",
+ "Aliases": [
+ "obj"
+ ],
+ "Required": false,
+ "SortOrder": 150.0,
+ "IsNullable": false,
+ "Default": "Regression"
+ },
+ {
+ "Name": "Alpha",
+ "Type": "Float",
+ "Desc": "The alpha (quantile) value for quantile regression. Must be in (0, 1).",
+ "Aliases": [
+ "qa"
+ ],
+ "Required": false,
+ "SortOrder": 150.0,
+ "IsNullable": false,
+ "Default": 0.5
+ },
{
"Name": "MaximumBinCountPerFeature",
"Type": "Int",
diff --git a/test/BaselineOutput/Common/EntryPoints/netcoreapp/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/netcoreapp/core_manifest.json
index a2cc2b6836..88e80589b7 100644
--- a/test/BaselineOutput/Common/EntryPoints/netcoreapp/core_manifest.json
+++ b/test/BaselineOutput/Common/EntryPoints/netcoreapp/core_manifest.json
@@ -13266,6 +13266,36 @@
"IsNullable": false,
"Default": "RootMeanSquaredError"
},
+ {
+ "Name": "Objective",
+ "Type": {
+ "Kind": "Enum",
+ "Values": [
+ "Regression",
+ "Quantile"
+ ]
+ },
+ "Desc": "Regression objective type. Use 'Quantile' for quantile regression.",
+ "Aliases": [
+ "obj"
+ ],
+ "Required": false,
+ "SortOrder": 150.0,
+ "IsNullable": false,
+ "Default": "Regression"
+ },
+ {
+ "Name": "Alpha",
+ "Type": "Float",
+ "Desc": "The alpha (quantile) value for quantile regression. Must be in (0, 1).",
+ "Aliases": [
+ "qa"
+ ],
+ "Required": false,
+ "SortOrder": 150.0,
+ "IsNullable": false,
+ "Default": 0.5
+ },
{
"Name": "MaximumBinCountPerFeature",
"Type": "Int",
diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs
index b36dfa574a..978c18afdf 100644
--- a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs
+++ b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs
@@ -244,6 +244,127 @@ public void LightGBMRegressorEstimator()
Done();
}
+ /// <summary>
+ /// LightGbmRegressionTrainer with quantile objective TrainerEstimator test
+ /// </summary>
+ [LightGBMFact]
+ public void LightGBMQuantileRegressorEstimator()
+ {
+ var dataView = GetRegressionPipeline();
+
+ var trainer = ML.Regression.Trainers.LightGbm(new LightGbmRegressionTrainer.Options
+ {
+ Objective = LightGbmRegressionTrainer.Options.RegressionObjective.Quantile,
+ Alpha = 0.5,
+ NumberOfIterations = 10,
+ NumberOfLeaves = 5,
+ });
+
+ TestEstimatorCore(trainer, dataView);
+ var model = trainer.Fit(dataView, dataView);
+
+ var gbmParameters = trainer.GetGbmParameters();
+ Assert.True(gbmParameters.ContainsKey("objective"));
+ Assert.Equal("quantile", gbmParameters["objective"]);
+ Assert.True(gbmParameters.ContainsKey("alpha"));
+ Assert.Equal(0.5, gbmParameters["alpha"]);
+
+ Done();
+ }
+
+ /// <summary>
+ /// Verify that quantile regression predictions with different alpha values
+ /// produce appropriately ordered results (lower quantile less than upper quantile).
+ /// </summary>
+ [LightGBMFact]
+ public void LightGBMQuantileRegressorPredictionOrdering()
+ {
+ var dataView = GetRegressionPipeline();
+
+ // Train model for the 5th percentile
+ var trainerLow = ML.Regression.Trainers.LightGbm(new LightGbmRegressionTrainer.Options
+ {
+ Objective = LightGbmRegressionTrainer.Options.RegressionObjective.Quantile,
+ Alpha = 0.05,
+ NumberOfIterations = 50,
+ NumberOfLeaves = 10,
+ Seed = 42,
+ Deterministic = true,
+ });
+
+ // Train model for the 95th percentile
+ var trainerHigh = ML.Regression.Trainers.LightGbm(new LightGbmRegressionTrainer.Options
+ {
+ Objective = LightGbmRegressionTrainer.Options.RegressionObjective.Quantile,
+ Alpha = 0.95,
+ NumberOfIterations = 50,
+ NumberOfLeaves = 10,
+ Seed = 42,
+ Deterministic = true,
+ });
+
+ var modelLow = trainerLow.Fit(dataView);
+ var modelHigh = trainerHigh.Fit(dataView);
+
+ var predictionsLow = modelLow.Transform(dataView);
+ var predictionsHigh = modelHigh.Transform(dataView);
+
+ var scoresLow = predictionsLow.GetColumn<float>(predictionsLow.Schema["Score"]).ToArray();
+ var scoresHigh = predictionsHigh.GetColumn<float>(predictionsHigh.Schema["Score"]).ToArray();
+
+ Assert.Equal(scoresLow.Length, scoresHigh.Length);
+ Assert.True(scoresLow.Length > 0);
+
+ // The 95th percentile predictions should generally be at least as large as the
+ // 5th percentile predictions. Allow a small numerical tolerance and a limited
+ // number of crossings since the models are trained independently.
+ const float tolerance = 1e-4f;
+ var orderedCount = Enumerable.Range(0, scoresLow.Length)
+ .Count(i => scoresHigh[i] + tolerance >= scoresLow[i]);
+ var orderedRatio = (float)orderedCount / scoresLow.Length;
+
+ Assert.True(orderedRatio >= 0.90f,
+ $"Expected the 95th percentile prediction to be >= the 5th percentile prediction for most rows, " +
+ $"but only {orderedCount} of {scoresLow.Length} rows satisfied the condition " +
+ $"({orderedRatio:P2}, tolerance={tolerance}).");
+
+ Done();
+ }
+
+ /// <summary>
+ /// Verify that invalid Alpha values are rejected for quantile regression.
+ /// </summary>
+ [LightGBMFact]
+ public void LightGBMQuantileRegressorInvalidAlpha()
+ {
+ var dataView = GetRegressionPipeline();
+
+ // Alpha = 0 should fail
+ var trainerZero = ML.Regression.Trainers.LightGbm(new LightGbmRegressionTrainer.Options
+ {
+ Objective = LightGbmRegressionTrainer.Options.RegressionObjective.Quantile,
+ Alpha = 0.0,
+ });
+ Assert.Throws<ArgumentOutOfRangeException>(() => trainerZero.Fit(dataView));
+
+ // Alpha = 1 should fail
+ var trainerOne = ML.Regression.Trainers.LightGbm(new LightGbmRegressionTrainer.Options
+ {
+ Objective = LightGbmRegressionTrainer.Options.RegressionObjective.Quantile,
+ Alpha = 1.0,
+ });
+ Assert.Throws<ArgumentOutOfRangeException>(() => trainerOne.Fit(dataView));
+
+ // Alpha = -0.1 should fail
+ var trainerNeg = ML.Regression.Trainers.LightGbm(new LightGbmRegressionTrainer.Options
+ {
+ Objective = LightGbmRegressionTrainer.Options.RegressionObjective.Quantile,
+ Alpha = -0.1,
+ });
+ Assert.Throws<ArgumentOutOfRangeException>(() => trainerNeg.Fit(dataView));
+
+ Done();
+ }
/// <summary>
/// RegressionGamTrainer TrainerEstimator test