Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion R-package/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,6 @@ Imports:
data.table (>= 1.9.6),
jsonlite (>= 1.0)
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.3
Encoding: UTF-8
SystemRequirements: GNU make, C++17
Config/roxygen2/version: 8.0.0
45 changes: 44 additions & 1 deletion R-package/R/xgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -1321,7 +1321,9 @@ xgboost <- function(
#'
#' Note that this check might add some sizable latency to the predictions, so it's
#' recommended to disable it for performance-sensitive applications.
#' @param ... Not used.
#' @param ... Legacy boolean flags (e.g., \code{predcontrib}) are supported for
#' backward compatibility but are deprecated and map to the \code{type} argument.
#' Any other arguments passed to \code{...} will cause an error.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually not how most libraries work. Consider for example:

data("mtcars")
model <- lm(mpg ~ ., data=mtcars)
pred <- predict(model, mtcars, predcontrib = TRUE)

Better to throw a warning than to error out.

#' @return Either a numeric vector (for 1D outputs), numeric matrix (for 2D outputs), numeric array
#' (for 3D and higher), or `factor` (for class predictions). See documentation for parameter `type`
#' for details about what the output type and shape will be.
Expand Down Expand Up @@ -1352,6 +1354,47 @@ predict.xgboost <- function(
validate_features = TRUE,
...
) {
dots <- list(...)
if (length(dots) > 0) {
mapping <- list(
Comment thread
Sanidhyavijay24 marked this conversation as resolved.
outputmargin = "raw",
predleaf = "leaf",
predcontrib = "contrib",
approxcontrib = "contrib",
predinteraction = "interaction"
)
found_legacy <- intersect(names(dots), names(mapping))

if (length(found_legacy) > 0) {
for (legacy_arg in found_legacy) {
if (isTRUE(dots[[legacy_arg]])) {
type <- mapping[[legacy_arg]]
warning(
sprintf(
"Argument '%s' is deprecated. Please use type = '%s' instead.",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not the case that it's deprecated. predict.xgboost is a new function introduced in version 3 which never had this argument. Better to raise a warning without making this choice.

legacy_arg, type
),
call. = FALSE
)
break
}
}
dots[found_legacy] <- NULL
}

if (length(dots) > 0) {
dot_names <- names(dots)
if (is.null(dot_names)) {
dot_names <- rep("<unnamed>", length(dots))
}
dot_names[dot_names == ""] <- "<unnamed>"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where does this <unnamed> come from?

stop(
"predict.xgboost: arguments in '...' are not supported (",
paste(dot_names, collapse = ", "), ")."
)
}
}

if (inherits(newdata, "xgb.DMatrix")) {
stop(
"Predictions on 'xgb.DMatrix' objects are not supported with 'xgboost' class.",
Expand Down
4 changes: 3 additions & 1 deletion R-package/man/predict.xgboost.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion R-package/man/xgb.plot.deepness.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion R-package/man/xgb.plot.shap.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

50 changes: 50 additions & 0 deletions R-package/tests/testthat/test_xgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -1131,3 +1131,53 @@ test_that("Linear booster importance uses class names", {
expect_true(is.factor(imp$Class))
expect_equal(levels(imp$Class), levels(y))
})

test_that("predict.xgboost maps legacy boolean flags to type", {
y <- mtcars$mpg
x <- as.matrix(mtcars[, -1L])
model <- xgboost(x, y, nthreads = 1L, nrounds = 1L)

expect_warning(
pred <- predict(model, x, predcontrib = TRUE),
paste0(
"Argument 'predcontrib' is deprecated. ",
"Please use type = 'contrib' instead."
),
fixed = TRUE
)
expect_equal(dim(pred), c(nrow(x), ncol(x) + 1L))

# Test multiple flags (first is FALSE, second is TRUE)
expect_warning(
pred2 <- predict(model, x, outputmargin = FALSE, predinteraction = TRUE),
paste0(
"Argument 'predinteraction' is deprecated. ",
"Please use type = 'interaction' instead."
),
fixed = TRUE
)
expect_equal(dim(pred2), c(nrow(x), ncol(x) + 1L, ncol(x) + 1L))

# Test conflict between legacy flag and explicit type
expect_warning(
pred3 <- predict(model, x, type = "response", approxcontrib = TRUE),
paste0(
"Argument 'approxcontrib' is deprecated. ",
"Please use type = 'contrib' instead."
),
fixed = TRUE
)
expect_equal(dim(pred3), c(nrow(x), ncol(x) + 1L))

# Test unsupported arguments
expect_error(
predict(model, x, foobar = TRUE),
"predict.xgboost: arguments in '...' are not supported (foobar).",
fixed = TRUE
)
expect_error(
predict(model, x, "response", NULL, NULL, TRUE, "some_unnamed_arg"),
"predict.xgboost: arguments in '...' are not supported (<unnamed>).",
fixed = TRUE
)
})
Comment thread
Sanidhyavijay24 marked this conversation as resolved.
Loading