Skip to content

Commit 152d138

Browse files
committed
changes for case weights and tidymodels/censored#163
1 parent 0e9f4ba commit 152d138

4 files changed

Lines changed: 38 additions & 15 deletions

File tree

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,4 +89,4 @@ Config/rcmdcheck/ignore-inconsequential-notes: true
8989
Encoding: UTF-8
9090
LazyData: true
9191
Roxygen: list(markdown = TRUE)
92-
RoxygenNote: 7.1.2
92+
RoxygenNote: 7.1.2.9000

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# parsnip (development version)
22

3+
* `xgb_train()` now allows for case weights via a new `weights` argument.
4+
35
# parsnip 0.2.1
46

57
* Fixed a major bug in spark models induced in the previous version (#671).

R/boost_tree.R

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ check_args.boost_tree <- function(object) {
213213
invisible(object)
214214
}
215215

216+
216217
# xgboost helpers --------------------------------------------------------------
217218

218219
#' Boosted trees via xgboost
@@ -256,11 +257,11 @@ check_args.boost_tree <- function(object) {
256257
#' @keywords internal
257258
#' @export
258259
xgb_train <- function(
259-
x, y,
260-
max_depth = 6, nrounds = 15, eta = 0.3, colsample_bynode = NULL,
261-
colsample_bytree = NULL, min_child_weight = 1, gamma = 0, subsample = 1,
262-
validation = 0, early_stop = NULL, objective = NULL, counts = TRUE,
263-
event_level = c("first", "second"), ...) {
260+
x, y,
261+
max_depth = 6, nrounds = 15, eta = 0.3, colsample_bynode = NULL,
262+
colsample_bytree = NULL, min_child_weight = 1, gamma = 0, subsample = 1,
263+
validation = 0, early_stop = NULL, objective = NULL, counts = TRUE,
264+
event_level = c("first", "second"), weights = NULL, ...) {
264265

265266
event_level <- rlang::arg_match(event_level, c("first", "second"))
266267
others <- list(...)
@@ -295,7 +296,11 @@ xgb_train <- function(
295296
n <- nrow(x)
296297
p <- ncol(x)
297298

298-
x <- as_xgb_data(x, y, validation, event_level)
299+
x <-
300+
as_xgb_data(x, y,
301+
validation = validation,
302+
event_level = event_level,
303+
weights = weights)
299304

300305

301306
if (!is.numeric(subsample) || subsample < 0 || subsample > 1) {
@@ -401,7 +406,7 @@ xgb_pred <- function(object, newdata, ...) {
401406
}
402407

403408

404-
as_xgb_data <- function(x, y, validation = 0, event_level = "first", ...) {
409+
as_xgb_data <- function(x, y, validation = 0, weights = NULL, event_level = "first", ...) {
405410
lvls <- levels(y)
406411
n <- nrow(x)
407412

@@ -424,22 +429,36 @@ as_xgb_data <- function(x, y, validation = 0, event_level = "first", ...) {
424429

425430
if (!inherits(x, "xgb.DMatrix")) {
426431
if (validation > 0) {
432+
# Split data
427433
m <- floor(n * (1 - validation)) + 1
428434
trn_index <- sample(1:n, size = max(m, 2))
429-
wlist <-
430-
list(validation = xgboost::xgb.DMatrix(x[-trn_index, ], label = y[-trn_index], missing = NA))
431-
dat <- xgboost::xgb.DMatrix(x[trn_index, ], label = y[trn_index], missing = NA)
435+
val_data <- xgboost::xgb.DMatrix(x[-trn_index,], label = y[-trn_index], missing = NA)
436+
watch_list <- list(validation = val_data)
437+
438+
info_list <- list(label = y[trn_index])
439+
if (!is.null(weights)) {
440+
info_list$weight <- weights[trn_index]
441+
}
442+
dat <- xgboost::xgb.DMatrix(x[trn_index,], missing = NA, info = info_list)
443+
432444

433445
} else {
434-
dat <- xgboost::xgb.DMatrix(x, label = y, missing = NA)
435-
wlist <- list(training = dat)
446+
info_list <- list(label = y)
447+
if (!is.null(weights)) {
448+
info_list$weight <- weights
449+
}
450+
dat <- xgboost::xgb.DMatrix(x, missing = NA, info = info_list)
451+
watch_list <- list(training = dat)
436452
}
437453
} else {
438454
dat <- xgboost::setinfo(x, "label", y)
439-
wlist <- list(training = dat)
455+
if (!is.null(weights)) {
456+
dat <- xgboost::setinfo(x, "weight", weights)
457+
}
458+
watch_list <- list(training = dat)
440459
}
441460

442-
list(data = dat, watchlist = wlist)
461+
list(data = dat, watchlist = watch_list)
443462
}
444463

445464
get_event_level <- function(model_spec){
@@ -452,6 +471,7 @@ get_event_level <- function(model_spec){
452471
event_level
453472
}
454473

474+
455475
#' @export
456476
#' @rdname multi_predict
457477
#' @param trees An integer vector for the number of trees in the ensemble.

man/xgb_train.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)