Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Type: Package
Package: modelbased
Title: Estimation of Model-Based Predictions, Contrasts and Means
Version: 0.14.0
Version: 0.14.0.1
Authors@R:
c(person(given = "Dominique",
family = "Makowski",
Expand Down
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# modelbased (devel)

## Changes

* Support for models of class `nestedLogit`.

# modelbased 0.14.0

## Changes
Expand Down
159 changes: 101 additions & 58 deletions R/estimate_predicted.R
Original file line number Diff line number Diff line change
Expand Up @@ -255,15 +255,17 @@
#' estimate_relation(model)
#' }
#' @export
estimate_expectation <- function(model,
data = NULL,
by = NULL,
predict = "expectation",
ci = 0.95,
transform = NULL,
iterations = NULL,
keep_iterations = FALSE,
...) {
estimate_expectation <- function(
model,
data = NULL,
by = NULL,
predict = "expectation",
ci = 0.95,
transform = NULL,
iterations = NULL,
keep_iterations = FALSE,
...
) {
Comment thread
strengejacke marked this conversation as resolved.
.estimate_predicted(
model,
data = data,
Expand All @@ -280,15 +282,17 @@ estimate_expectation <- function(model,

#' @rdname estimate_expectation
#' @export
estimate_link <- function(model,
data = "grid",
by = NULL,
predict = "link",
ci = 0.95,
transform = NULL,
iterations = NULL,
keep_iterations = FALSE,
...) {
estimate_link <- function(
model,
data = "grid",
by = NULL,
predict = "link",
ci = 0.95,
transform = NULL,
iterations = NULL,
keep_iterations = FALSE,
...
) {
# reset to NULL if only "by" was specified
if (missing(data) && !missing(by)) {
data <- NULL
Expand All @@ -309,15 +313,17 @@ estimate_link <- function(model,

#' @rdname estimate_expectation
#' @export
estimate_prediction <- function(model,
data = NULL,
by = NULL,
predict = "prediction",
ci = 0.95,
transform = NULL,
iterations = NULL,
keep_iterations = FALSE,
...) {
estimate_prediction <- function(
model,
data = NULL,
by = NULL,
predict = "prediction",
ci = 0.95,
transform = NULL,
iterations = NULL,
keep_iterations = FALSE,
...
) {
.estimate_predicted(
model,
data = data,
Expand All @@ -333,15 +339,17 @@ estimate_prediction <- function(model,

#' @rdname estimate_expectation
#' @export
estimate_relation <- function(model,
data = "grid",
by = NULL,
predict = "expectation",
ci = 0.95,
transform = NULL,
iterations = NULL,
keep_iterations = FALSE,
...) {
estimate_relation <- function(
model,
data = "grid",
by = NULL,
predict = "expectation",
ci = 0.95,
transform = NULL,
iterations = NULL,
keep_iterations = FALSE,
...
) {
# reset to NULL if only "by" was specified
if (missing(data) && !missing(by)) {
data <- NULL
Expand All @@ -364,15 +372,17 @@ estimate_relation <- function(model,
# Internal ----------------------------------------------------------------

#' @keywords internal
.estimate_predicted <- function(model,
data = "grid",
by = NULL,
predict = "expectation",
ci = 0.95,
transform = NULL,
iterations = NULL,
keep_iterations = FALSE,
...) {
.estimate_predicted <- function(
model,
data = "grid",
by = NULL,
predict = "expectation",
ci = 0.95,
transform = NULL,
iterations = NULL,
keep_iterations = FALSE,
...
) {
# return early for htest
if (inherits(model, "htest")) {
return(insight::get_predicted(model, ...))
Expand All @@ -384,7 +394,14 @@ estimate_relation <- function(model,
}

# keep_iterations cannot be larger than interations
if (!is.null(keep_iterations) && !is.null(iterations) && is.numeric(keep_iterations) && is.numeric(iterations) && keep_iterations > iterations) { # nolint
if (
!is.null(keep_iterations) &&
!is.null(iterations) &&
is.numeric(keep_iterations) &&
is.numeric(iterations) &&
keep_iterations > iterations
) {
# nolint
insight::format_error("`keep_iterations` cannot be larger than `iterations`.")
}

Expand Down Expand Up @@ -450,7 +467,12 @@ estimate_relation <- function(model,
data <- model_data
} else if (!is.data.frame(data)) {
if (is_grid) {
data <- insight::get_datagrid(model, reference = model_data, include_response = is_nullmodel, ...)
data <- insight::get_datagrid(
model,
reference = model_data,
include_response = is_nullmodel,
...
)
} else {
insight::format_error(
"The `data` argument must either NULL, \"grid\" or another data frame."
Expand All @@ -462,7 +484,12 @@ estimate_relation <- function(model,
grid_specs <- attributes(data)

# Get response for later residuals -------------
if (!is.null(model_response) && length(model_response) == 1 && model_response %in% names(data)) { # nolint
if (
!is.null(model_response) &&
length(model_response) == 1 &&
model_response %in% names(data)
) {
# nolint
response <- data[[model_response]]
} else {
response <- NULL
Expand Down Expand Up @@ -492,7 +519,10 @@ estimate_relation <- function(model,
)

# for predicting grouplevel random effects, add "allow.new.levels"
if (!is.null(grouplevel_effects) && any(grouplevel_effects %in% grid_specs$at_spec$varname)) {
if (
!is.null(grouplevel_effects) &&
any(grouplevel_effects %in% grid_specs$at_spec$varname)
) {
prediction_args$allow.new.levels <- TRUE
dots$allow.new.levels <- NULL
}
Expand All @@ -511,13 +541,22 @@ estimate_relation <- function(model,
}

# remove response variable from data frame, as this variable is predicted
if (!is.null(model_response) && length(model_response) == 1 && model_response %in% colnames(out)) { # nolint
if (
!is.null(model_response) &&
length(model_response) == 1 &&
model_response %in% colnames(out)
) {
# nolint
out[[model_response]] <- NULL
}

# keep row-column, but make sure it's integer
if ("Row" %in% colnames(out)) {
out[["Row"]] <- insight::format_value(out[["Row"]], protect_integers = TRUE)
if (inherits(model, "nestedLogit")) {
out[["Row"]] <- NULL
} else {
out[["Row"]] <- insight::format_value(out[["Row"]], protect_integers = TRUE)
}
}

# Add residuals
Expand Down Expand Up @@ -557,17 +596,21 @@ estimate_relation <- function(model,
by = grid_specs$at,
type = "predictions",
model = model,
info = c(
grid_specs,
list(predict = predict),
transform = !is.null(transform)
)
info = c(grid_specs, list(predict = predict), transform = !is.null(transform))
Comment thread
strengejacke marked this conversation as resolved.
)

attributes(out) <- c(attributes(out), grid_specs[!names(grid_specs) %in% names(attributes(out))])
attributes(out) <- c(
attributes(out),
grid_specs[!names(grid_specs) %in% names(attributes(out))]
)

# Class
class(out) <- c(paste0("estimate_", predict), "estimate_predicted", "see_estimate_predicted", class(out))
class(out) <- c(
paste0("estimate_", predict),
"estimate_predicted",
"see_estimate_predicted",
class(out)
)

out
}
Loading