From 2f53166c73255f994c12f798849a459643b6d18b Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Mon, 8 Dec 2025 14:22:19 -0800 Subject: [PATCH] make xgboost example version switched --- R/bundle_xgboost.R | 14 ++++++++++---- man/bundle_xgboost.Rd | 12 +++++++++--- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/R/bundle_xgboost.R b/R/bundle_xgboost.R index cc8c43c..7365b08 100644 --- a/R/bundle_xgboost.R +++ b/R/bundle_xgboost.R @@ -22,9 +22,15 @@ #' data(agaricus.train) #' data(agaricus.test) #' -#' xgb <- xgboost(data = agaricus.train$data, label = agaricus.train$label, -#' max_depth = 2, eta = 1, nthread = 2, nrounds = 2, -#' objective = "binary:logistic") +#' if (utils::packageVersion("xgboost") >= "2.0.0.0") { +#' xgb <- xgboost(x = agaricus.train$data, y = agaricus.train$label, +#' max_depth = 2, learning_rate = 1, nthread = 2, nrounds = 2, +#' objective = "reg:squarederror") +#' } else { +#' xgb <- xgboost(data = agaricus.train$data, label = agaricus.train$label, +#' max_depth = 2, eta = 1, nthread = 2, nrounds = 2, +#' objective = "binary:logistic") +#' } #' #' xgb_bundle <- bundle(xgb) #' @@ -56,7 +62,7 @@ bundle.xgb.Booster <- function(x, ...) { attr(res, "feature_names") <- !!attr(x, "feature_names") } else { res <- xgboost::xgb.load.raw(object, as_booster = TRUE) - + res$params <- list( objective = !!x$params$objective, num_class = !!x$params$num_class diff --git a/man/bundle_xgboost.Rd b/man/bundle_xgboost.Rd index 3ebb759..1850ed7 100644 --- a/man/bundle_xgboost.Rd +++ b/man/bundle_xgboost.Rd @@ -85,9 +85,15 @@ set.seed(1) data(agaricus.train) data(agaricus.test) -xgb <- xgboost(data = agaricus.train$data, label = agaricus.train$label, - max_depth = 2, eta = 1, nthread = 2, nrounds = 2, - objective = "binary:logistic") +if (utils::packageVersion("xgboost") >= "2.0.0.0") { + xgb <- xgboost(x = agaricus.train$data, y = agaricus.train$label, + max_depth = 2, learning_rate = 1, nthread = 2, nrounds = 2, + objective = "reg:squarederror") +} else { + xgb <- xgboost(data = agaricus.train$data, label = agaricus.train$label, + max_depth = 2, eta = 1, nthread = 2, nrounds = 2, + objective = "binary:logistic") +} xgb_bundle <- bundle(xgb)