Skip to content

Commit d6bcd00

Browse files
make robust to xgboost versions (#75)
* make robust to xgboost versions * update news
1 parent ca7759e commit d6bcd00

3 files changed

Lines changed: 43 additions & 12 deletions

File tree

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# bundle (development version)
22

3+
* Make to work with new versions of xgboost models (#75).
4+
35
# bundle 0.1.2
46

57
* Added bundle method for objects from `dbarts::bart()` and, by extension, `parsnip::bart(engine = "dbarts")` (#64).

R/bundle_xgboost.R

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,27 @@ bundle.xgb.Booster <- function(x, ...) {
4444
bundle_constr(
4545
object = object,
4646
situate = situate_constr(function(object) {
47-
res <- xgboost::xgb.load.raw(object, as_booster = TRUE)
47+
if (utils::packageVersion("xgboost") > "2.0.0.0") {
48+
res <- xgboost::xgb.load.raw(object)
4849

49-
res$params <- list(
50-
objective = !!x$params$objective,
51-
num_class = !!x$params$num_class
52-
)
50+
attr(res, "params") <- list(
51+
objective = !!attr(x, "params")$objective,
52+
num_class = !!attr(x, "params")$num_class
53+
)
5354

54-
res$nfeatures <- !!x$nfeatures
55-
res$feature_names <- !!x$feature_names
55+
attr(res, "nfeatures") <- !!attr(x, "nfeatures")
56+
attr(res, "feature_names") <- !!attr(x, "feature_names")
57+
} else {
58+
res <- xgboost::xgb.load.raw(object, as_booster = TRUE)
59+
60+
res$params <- list(
61+
objective = !!x$params$objective,
62+
num_class = !!x$params$num_class
63+
)
64+
65+
res$nfeatures <- !!x$nfeatures
66+
res$feature_names <- !!x$feature_names
67+
}
5668

5769
res
5870
}),

tests/testthat/test_bundle_xgboost.R

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,27 @@ test_that("bundling + unbundling xgboost fits", {
1111
set.seed(1)
1212

1313
data(agaricus.train)
14-
15-
xgb <- xgboost(data = agaricus.train$data, label = agaricus.train$label,
16-
max_depth = 2, eta = 1, nthread = 2, nrounds = 2,
17-
objective = "binary:logistic")
14+
if (utils::packageVersion("xgboost") > "2.0.0.0") {
15+
xgb <- xgboost(
16+
x = agaricus.train$data,
17+
y = agaricus.train$label,
18+
max_depth = 2,
19+
learning_rate = 1,
20+
nthread = 2,
21+
nrounds = 2,
22+
objective = "reg:squarederror"
23+
)
24+
} else {
25+
xgb <- xgboost(
26+
data = agaricus.train$data,
27+
label = agaricus.train$label,
28+
max_depth = 2,
29+
eta = 1,
30+
nthread = 2,
31+
nrounds = 2,
32+
objective = "binary:logistic"
33+
)
34+
}
1835

1936
xgb
2037
}
@@ -84,7 +101,7 @@ test_that("bundling + unbundling xgboost fits", {
84101
mod_preds <- predict(mod_fit, agaricus.test$data)
85102

86103
# check classes
87-
expect_s3_class(mod_bundle, "bundled_xgb.Booster")
104+
expect_true(any(class(mod_bundle) %in% c("bundled_xgb.Booster", "bundled_xgboost")))
88105
expect_s3_class(unbundle(mod_bundle), "xgb.Booster")
89106

90107
# ensure that the situater function didn't bring along the whole model

0 commit comments

Comments
 (0)