Skip to content

Commit 0569813

Browse files
authored
XGBoost: Do not panic if a tree split is degenerate (#364)
1 parent bc7c385 commit 0569813

1 file changed

Lines changed: 58 additions & 29 deletions

File tree

src/xgboost/xgb_regressor.rs

Lines changed: 58 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -248,38 +248,47 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
248248
}
249249

250250
// A split is only valid if it results in a positive gain.
251-
if *best_split_score > 0.0 {
252-
let mut left_idxs = Vec::new();
253-
let mut right_idxs = Vec::new();
254-
for idx in idxs.iter() {
255-
if data.get((*idx, *best_feature_idx)).to_f64().unwrap() <= *best_threshold {
256-
left_idxs.push(*idx);
257-
} else {
258-
right_idxs.push(*idx);
259-
}
251+
if *best_split_score <= 0.0 {
252+
return;
253+
}
254+
255+
let mut left_idxs = Vec::new();
256+
let mut right_idxs = Vec::new();
257+
for idx in idxs.iter() {
258+
if data.get((*idx, *best_feature_idx)).to_f64().unwrap() <= *best_threshold {
259+
left_idxs.push(*idx);
260+
} else {
261+
right_idxs.push(*idx);
260262
}
263+
}
261264

262-
*left = Some(Box::new(TreeRegressor::fit(
263-
data,
264-
g,
265-
h,
266-
&left_idxs,
267-
max_depth - 1,
268-
min_child_weight,
269-
lambda,
270-
gamma,
271-
)));
272-
*right = Some(Box::new(TreeRegressor::fit(
273-
data,
274-
g,
275-
h,
276-
&right_idxs,
277-
max_depth - 1,
278-
min_child_weight,
279-
lambda,
280-
gamma,
281-
)));
265+
if left_idxs.is_empty() || right_idxs.is_empty() {
266+
// A degenerate split where all samples land on one side. This can happen when feature
267+
// values are large enough that `(x_i + x_i_next) / 2.0` overflows to +inf,
268+
// all samples satisfy `<= +inf` and right_idxs is empty.
269+
return;
282270
}
271+
272+
*left = Some(Box::new(TreeRegressor::fit(
273+
data,
274+
g,
275+
h,
276+
&left_idxs,
277+
max_depth - 1,
278+
min_child_weight,
279+
lambda,
280+
gamma,
281+
)));
282+
*right = Some(Box::new(TreeRegressor::fit(
283+
data,
284+
g,
285+
h,
286+
&right_idxs,
287+
max_depth - 1,
288+
min_child_weight,
289+
lambda,
290+
gamma,
291+
)));
283292
}
284293

285294
/// Iterates through a single feature to find the best possible split point.
@@ -733,6 +742,26 @@ mod tests {
733742
assert!((tree.right.unwrap().value - (-0.833333333)).abs() < 1e-9);
734743
}
735744

745+
/// Exercises the degenerate-split guard in insert_child_nodes.
746+
#[test]
747+
fn test_no_panic_on_degenerate_split_from_overflow() {
748+
let large = f64::MAX / 1.5;
749+
let x_vec = vec![vec![large], vec![large * 1.1]];
750+
let x = DenseMatrix::from_2d_vec(&x_vec).unwrap();
751+
let y = vec![0.0, 1.0];
752+
753+
let params = XGRegressorParameters::default()
754+
.with_n_estimators(10)
755+
.with_max_depth(3);
756+
757+
let model = XGRegressor::fit(&x, &y, params);
758+
assert!(model.is_ok(), "Fit panicked or failed: {:?}", model.err());
759+
760+
let predictions = model.unwrap().predict(&x);
761+
assert!(predictions.is_ok());
762+
assert_eq!(predictions.unwrap().len(), 2);
763+
}
764+
736765
/// A "smoke test" to ensure the main XGRegressor can fit and predict on multidimensional data.
737766
#[test]
738767
fn test_xgregressor_fit_predict_multidimensional() {

0 commit comments

Comments
 (0)