@@ -248,38 +248,47 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
248248 }
249249
250250 // A split is only valid if it results in a positive gain.
251- if * best_split_score > 0.0 {
252- let mut left_idxs = Vec :: new ( ) ;
253- let mut right_idxs = Vec :: new ( ) ;
254- for idx in idxs. iter ( ) {
255- if data. get ( ( * idx, * best_feature_idx) ) . to_f64 ( ) . unwrap ( ) <= * best_threshold {
256- left_idxs. push ( * idx) ;
257- } else {
258- right_idxs. push ( * idx) ;
259- }
251+ if * best_split_score <= 0.0 {
252+ return ;
253+ }
254+
255+ let mut left_idxs = Vec :: new ( ) ;
256+ let mut right_idxs = Vec :: new ( ) ;
257+ for idx in idxs. iter ( ) {
258+ if data. get ( ( * idx, * best_feature_idx) ) . to_f64 ( ) . unwrap ( ) <= * best_threshold {
259+ left_idxs. push ( * idx) ;
260+ } else {
261+ right_idxs. push ( * idx) ;
260262 }
263+ }
261264
262- * left = Some ( Box :: new ( TreeRegressor :: fit (
263- data,
264- g,
265- h,
266- & left_idxs,
267- max_depth - 1 ,
268- min_child_weight,
269- lambda,
270- gamma,
271- ) ) ) ;
272- * right = Some ( Box :: new ( TreeRegressor :: fit (
273- data,
274- g,
275- h,
276- & right_idxs,
277- max_depth - 1 ,
278- min_child_weight,
279- lambda,
280- gamma,
281- ) ) ) ;
265+ if left_idxs. is_empty ( ) || right_idxs. is_empty ( ) {
266+ // A degenerate split where all samples land on one side. This can happen when feature
267+ // values are large enough that `(x_i + x_i_next) / 2.0` overflows to +inf; every sample
268+ // then satisfies `<= +inf`, so `right_idxs` is empty.
269+ return ;
282270 }
271+
272+ * left = Some ( Box :: new ( TreeRegressor :: fit (
273+ data,
274+ g,
275+ h,
276+ & left_idxs,
277+ max_depth - 1 ,
278+ min_child_weight,
279+ lambda,
280+ gamma,
281+ ) ) ) ;
282+ * right = Some ( Box :: new ( TreeRegressor :: fit (
283+ data,
284+ g,
285+ h,
286+ & right_idxs,
287+ max_depth - 1 ,
288+ min_child_weight,
289+ lambda,
290+ gamma,
291+ ) ) ) ;
283292 }
284293
285294 /// Iterates through a single feature to find the best possible split point.
@@ -733,6 +742,26 @@ mod tests {
733742 assert ! ( ( tree. right. unwrap( ) . value - ( -0.833333333 ) ) . abs( ) < 1e-9 ) ;
734743 }
735744
/// Regression test for the degenerate-split guard in `insert_child_nodes`.
///
/// The two feature values are `f64::MAX / 1.5` and 1.1 times that, so their sum
/// (about `1.4 * f64::MAX`) is not representable and a midpoint computed as
/// `(a + b) / 2.0` evaluates to `+inf`. Fitting and predicting on this input
/// must both return `Ok`, and prediction must yield one value per input row.
#[test]
fn test_no_panic_on_degenerate_split_from_overflow() {
    // Single-feature design matrix with two rows whose sum overflows f64.
    let huge = f64::MAX / 1.5;
    let rows = vec![vec![huge], vec![huge * 1.1]];
    let x = DenseMatrix::from_2d_vec(&rows).unwrap();
    let y = vec![0.0, 1.0];

    let model = XGRegressor::fit(
        &x,
        &y,
        XGRegressorParameters::default()
            .with_n_estimators(10)
            .with_max_depth(3),
    );
    assert!(model.is_ok(), "Fit panicked or failed: {:?}", model.err());

    // Predict on the training rows themselves: one prediction per row is expected.
    let predictions = model.unwrap().predict(&x);
    assert!(predictions.is_ok());
    assert_eq!(predictions.unwrap().len(), 2);
}
764+
736765 /// A "smoke test" to ensure the main XGRegressor can fit and predict on multidimensional data.
737766 #[ test]
738767 fn test_xgregressor_fit_predict_multidimensional ( ) {
0 commit comments