@@ -213,6 +213,7 @@ check_args.boost_tree <- function(object) {
213213 invisible (object )
214214}
215215
216+
216217# xgboost helpers --------------------------------------------------------------
217218
218219# ' Boosted trees via xgboost
@@ -256,11 +257,11 @@ check_args.boost_tree <- function(object) {
256257# ' @keywords internal
257258# ' @export
258259xgb_train <- function (
259- x , y ,
260- max_depth = 6 , nrounds = 15 , eta = 0.3 , colsample_bynode = NULL ,
261- colsample_bytree = NULL , min_child_weight = 1 , gamma = 0 , subsample = 1 ,
262- validation = 0 , early_stop = NULL , objective = NULL , counts = TRUE ,
263- event_level = c(" first" , " second" ), ... ) {
260+ x , y ,
261+ max_depth = 6 , nrounds = 15 , eta = 0.3 , colsample_bynode = NULL ,
262+ colsample_bytree = NULL , min_child_weight = 1 , gamma = 0 , subsample = 1 ,
263+ validation = 0 , early_stop = NULL , objective = NULL , counts = TRUE ,
264+ event_level = c(" first" , " second" ), weights = NULL , ... ) {
264265
265266 event_level <- rlang :: arg_match(event_level , c(" first" , " second" ))
266267 others <- list (... )
@@ -295,7 +296,11 @@ xgb_train <- function(
295296 n <- nrow(x )
296297 p <- ncol(x )
297298
298- x <- as_xgb_data(x , y , validation , event_level )
299+ x <-
300+ as_xgb_data(x , y ,
301+ validation = validation ,
302+ event_level = event_level ,
303+ weights = weights )
299304
300305
301306 if (! is.numeric(subsample ) || subsample < 0 || subsample > 1 ) {
@@ -401,7 +406,7 @@ xgb_pred <- function(object, newdata, ...) {
401406}
402407
403408
404- as_xgb_data <- function (x , y , validation = 0 , event_level = " first" , ... ) {
409+ as_xgb_data <- function (x , y , validation = 0 , weights = NULL , event_level = " first" , ... ) {
405410 lvls <- levels(y )
406411 n <- nrow(x )
407412
@@ -424,22 +429,36 @@ as_xgb_data <- function(x, y, validation = 0, event_level = "first", ...) {
424429
425430 if (! inherits(x , " xgb.DMatrix" )) {
426431 if (validation > 0 ) {
432+ # Split data
427433 m <- floor(n * (1 - validation )) + 1
428434 trn_index <- sample(1 : n , size = max(m , 2 ))
429- wlist <-
430- list (validation = xgboost :: xgb.DMatrix(x [- trn_index , ], label = y [- trn_index ], missing = NA ))
431- dat <- xgboost :: xgb.DMatrix(x [trn_index , ], label = y [trn_index ], missing = NA )
435+ val_data <- xgboost :: xgb.DMatrix(x [- trn_index ,], label = y [- trn_index ], missing = NA )
436+ watch_list <- list (validation = val_data )
437+
438+ info_list <- list (label = y [trn_index ])
439+ if (! is.null(weights )) {
440+ info_list $ weight <- weights [trn_index ]
441+ }
442+ dat <- xgboost :: xgb.DMatrix(x [trn_index ,], missing = NA , info = info_list )
443+
432444
433445 } else {
434- dat <- xgboost :: xgb.DMatrix(x , label = y , missing = NA )
435- wlist <- list (training = dat )
446+ info_list <- list (label = y )
447+ if (! is.null(weights )) {
448+ info_list $ weight <- weights
449+ }
450+ dat <- xgboost :: xgb.DMatrix(x , missing = NA , info = info_list )
451+ watch_list <- list (training = dat )
436452 }
437453 } else {
438454 dat <- xgboost :: setinfo(x , " label" , y )
439- wlist <- list (training = dat )
455+ if (! is.null(weights )) {
456+ dat <- xgboost :: setinfo(x , " weight" , weights )
457+ }
458+ watch_list <- list (training = dat )
440459 }
441460
442- list (data = dat , watchlist = wlist )
461+ list (data = dat , watchlist = watch_list )
443462}
444463
445464get_event_level <- function (model_spec ){
@@ -452,6 +471,7 @@ get_event_level <- function(model_spec){
452471 event_level
453472}
454473
474+
455475# ' @export
456476# ' @rdname multi_predict
457477# ' @param trees An integer vector for the number of trees in the ensemble.
0 commit comments