diff --git a/NEWS.md b/NEWS.md index 4845286..1940359 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # bundle (development version) +* Make to work with new versions of xgboost models (#75). + # bundle 0.1.2 * Added bundle method for objects from `dbarts::bart()` and, by extension, `parsnip::bart(engine = "dbarts")` (#64). diff --git a/R/bundle_xgboost.R b/R/bundle_xgboost.R index 1e02f25..cc8c43c 100644 --- a/R/bundle_xgboost.R +++ b/R/bundle_xgboost.R @@ -44,15 +44,27 @@ bundle.xgb.Booster <- function(x, ...) { bundle_constr( object = object, situate = situate_constr(function(object) { - res <- xgboost::xgb.load.raw(object, as_booster = TRUE) + if (utils::packageVersion("xgboost") > "2.0.0.0") { + res <- xgboost::xgb.load.raw(object) - res$params <- list( - objective = !!x$params$objective, - num_class = !!x$params$num_class - ) + attr(res, "params") <- list( + objective = !!attr(x, "params")$objective, + num_class = !!attr(x, "params")$num_class + ) - res$nfeatures <- !!x$nfeatures - res$feature_names <- !!x$feature_names + attr(res, "nfeatures") <- !!attr(x, "nfeatures") + attr(res, "feature_names") <- !!attr(x, "feature_names") + } else { + res <- xgboost::xgb.load.raw(object, as_booster = TRUE) + + res$params <- list( + objective = !!x$params$objective, + num_class = !!x$params$num_class + ) + + res$nfeatures <- !!x$nfeatures + res$feature_names <- !!x$feature_names + } res }), diff --git a/tests/testthat/test_bundle_xgboost.R b/tests/testthat/test_bundle_xgboost.R index dfc71d0..c6b9256 100644 --- a/tests/testthat/test_bundle_xgboost.R +++ b/tests/testthat/test_bundle_xgboost.R @@ -11,10 +11,27 @@ test_that("bundling + unbundling xgboost fits", { set.seed(1) data(agaricus.train) - - xgb <- xgboost(data = agaricus.train$data, label = agaricus.train$label, - max_depth = 2, eta = 1, nthread = 2, nrounds = 2, - objective = "binary:logistic") + if (utils::packageVersion("xgboost") > "2.0.0.0") { + xgb <- xgboost( + x = agaricus.train$data, + y = agaricus.train$label, + max_depth = 2, + learning_rate = 1, + nthread = 2, + nrounds = 2, + objective = "reg:squarederror" + ) + } else { + xgb <- xgboost( + data = agaricus.train$data, + label = agaricus.train$label, + max_depth = 2, + eta = 1, + nthread = 2, + nrounds = 2, + objective = "binary:logistic" + ) + } xgb } @@ -84,7 +101,7 @@ test_that("bundling + unbundling xgboost fits", { mod_preds <- predict(mod_fit, agaricus.test$data) # check classes - expect_s3_class(mod_bundle, "bundled_xgb.Booster") + expect_true(any(class(mod_bundle) %in% c("bundled_xgb.Booster", "bundled_xgboost"))) expect_s3_class(unbundle(mod_bundle), "xgb.Booster") # ensure that the situater function didn't bring along the whole model