Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# development version

- New option to impute missing data after the train/test split rather than before (#301, @megancoden and @shah-priyal).
- Added `impute_in_training` option to `run_ml()`, which defaults to FALSE.
- Added `impute_in_preprocessing` option to `preprocess()`, which defaults to TRUE.

# mikropml 1.6.0

- New functions:
Expand All @@ -13,7 +17,6 @@
- Renamed the column `names` to `feat` to represent each feature or group of correlated features.
- New column `lower` and `upper` to report the bounds of the empirical 95% confidence interval from the permutation test.
See `vignette('parallel')` for an example of plotting feature importance with confidence intervals.
- Minor documentation improvements (#323, #332, @kelly-sovacool).

# mikropml 1.5.0

Expand Down
37 changes: 37 additions & 0 deletions R/impute.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#' Replace NA values with the median value of the column for continious variables in the dataset
#'
#' @param transformed_cont Data frame that may include NA values in one or more columns
#'
#' @return Data frame that has no NA values in continious numeric columns
#'
#' @examples
#' transformed_cont <- impute(transformed_cont)
#' train_data <- impute(train_data)
#' test_data <- impute(test_data)
impute <- function(transformed_cont) {
sapply_fn <- select_apply("sapply")
cl <- sapply_fn(transformed_cont, function(x) {
class(x)
})
missing <-
is.na(transformed_cont[, cl %in% c("integer", "numeric")])
n_missing <- sum(missing)
if (n_missing > 0) {
transformed_cont <- sapply_fn(transformed_cont, function(x) {
if (class(x) %in% c("integer", "numeric")) {
m <- is.na(x)
x[m] <- stats::median(x, na.rm = TRUE)
}
message(typeof(x))
message(class(x))
return(x)
}) %>% dplyr::as_tibble()
message(
paste0(
n_missing,
" missing continuous value(s) were imputed using the median value of the feature."
)
)
}
return (transformed_cont)
}
68 changes: 18 additions & 50 deletions R/preprocess.R
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ preprocess_data <- function(dataset, outcome_colname,
method = c("center", "scale"),
remove_var = "nzv", collapse_corr_feats = TRUE,
to_numeric = TRUE, group_neg_corr = TRUE,
prefilter_threshold = 1) {
prefilter_threshold = 1, impute_in_preprocessing = TRUE) {
progbar <- NULL
if (isTRUE(check_packages_installed("progressr"))) {
progbar <- progressr::progressor(steps = 20, message = "preprocessing")
Expand All @@ -70,10 +70,11 @@ preprocess_data <- function(dataset, outcome_colname,
check_outcome_column(dataset, outcome_colname, check_values = FALSE)
check_remove_var(remove_var)
pbtick(progbar)

dataset[[outcome_colname]] <- replace_spaces(dataset[[outcome_colname]])
dataset <- rm_missing_outcome(dataset, outcome_colname)
split_dat <- split_outcome_features(dataset, outcome_colname)

features <- split_dat$features
removed_feats <- character(0)
if (to_numeric) {
Expand All @@ -83,22 +84,22 @@ preprocess_data <- function(dataset, outcome_colname,
features <- feats$dat
}
pbtick(progbar)

nv_feats <- process_novar_feats(features, progbar = progbar)
pbtick(progbar)
split_feats <- process_cat_feats(nv_feats$var_feats, progbar = progbar)
pbtick(progbar)
cont_feats <- process_cont_feats(split_feats$cont_feats, method)
cont_feats <- process_cont_feats(split_feats$cont_feats, method, impute_in_preprocessing)
pbtick(progbar)

# combine all processed features
processed_feats <- dplyr::bind_cols(
cont_feats$transformed_cont,
split_feats$cat_feats,
nv_feats$novar_feats
)
pbtick(progbar)

# remove features with (near-)zero variance
feats <- get_caret_processed_df(processed_feats, remove_var)
processed_feats <- feats$processed
Expand Down Expand Up @@ -140,17 +141,15 @@ preprocess_data <- function(dataset, outcome_colname,
#' @inheritParams run_ml
#'
#' @return dataset with no missing outcomes
#' @keywords internal
#' @noRd
#' @author Zena Lapp, \email{zenalapp@@umich.edu}
#'
#' @examples
#' \dontrun{
#' rm_missing_outcome(mikropml::otu_mini_bin, "dx")
#'
#' test_df <- mikropml::otu_mini_bin
#' test_df[1:100, "dx"] <- NA
#' rm_missing_outcome(test_df, "dx")
#' }
rm_missing_outcome <- function(dataset, outcome_colname) {
n_outcome_na <- sum(is.na(dataset %>% dplyr::pull(outcome_colname)))
total_outcomes <- nrow(dataset)
Expand All @@ -168,13 +167,11 @@ rm_missing_outcome <- function(dataset, outcome_colname) {
#' @param features dataframe of features for machine learning
#'
#' @return dataframe with numeric columns where possible
#' @keywords internal
#' @noRd
#' @author Zena Lapp, \email{zenalapp@@umich.edu}
#'
#' @examples
#' \dontrun{
#' class(change_to_num(data.frame(val = c("1", "2", "3")))[[1]])
#' }
change_to_num <- function(features) {
lapply_fn <- select_apply(fun = "lapply")
check_features(features, check_missing = FALSE)
Expand Down Expand Up @@ -228,13 +225,11 @@ remove_singleton_columns <- function(dat, threshold = 1) {
#' @param progbar optional progress bar (default: `NULL`)
#'
#' @return list of two dataframes: features with variability (unprocessed) and without (processed)
#' @keywords internal
#' @noRd
#' @author Zena Lapp, \email{zenalapp@@umich.edu}
#'
#' @examples
#' \dontrun{
#' process_novar_feats(mikropml::otu_small[, 2:ncol(otu_small)])
#' }
process_novar_feats <- function(features, progbar = NULL) {
novar_feats <- NULL
var_feats <- NULL
Expand Down Expand Up @@ -297,13 +292,11 @@ process_novar_feats <- function(features, progbar = NULL) {
#' @inheritParams process_novar_feats
#'
#' @return list of two dataframes: categorical (processed) and continuous features (unprocessed)
#' @keywords internal
#' @noRd
#' @author Zena Lapp, \email{zenalapp@@umich.edu}
#'
#' @examples
#' \dontrun{
#' process_cat_feats(mikropml::otu_small[, 2:ncol(otu_small)])
#' }
process_cat_feats <- function(features, progbar = NULL) {
feature_design_cat_mat <- NULL
cont_feats <- NULL
Expand Down Expand Up @@ -367,14 +360,12 @@ process_cat_feats <- function(features, progbar = NULL) {
#' @inheritParams get_caret_processed_df
#'
#' @return dataframe of preprocessed features
#' @keywords internal
#' @noRd
#' @author Zena Lapp, \email{zenalapp@@umich.edu}
#'
#' @examples
#' \dontrun{
#' process_cont_feats(mikropml::otu_small[, 2:ncol(otu_small)], c("center", "scale"))
#' }
process_cont_feats <- function(features, method) {
process_cont_feats <- function(features, method, impute_in_preprocessing) {
transformed_cont <- NULL
removed_cont <- NULL

Expand All @@ -389,31 +380,12 @@ process_cont_feats <- function(features, method) {
transformed_cont <- feats$processed
removed_cont <- feats$removed
}
sapply_fn <- select_apply("sapply")
cl <- sapply_fn(transformed_cont, function(x) {
class(x)
})
missing <-
is.na(transformed_cont[, cl %in% c("integer", "numeric")])
n_missing <- sum(missing)
if (n_missing > 0) {
# impute missing data using the median value
transformed_cont <- sapply_fn(transformed_cont, function(x) {
if (class(x) %in% c("integer", "numeric")) {
m <- is.na(x)
x[m] <- stats::median(x, na.rm = TRUE)
}
return(x)
}) %>% dplyr::as_tibble()
message(
paste0(
n_missing,
" missing continuous value(s) were imputed using the median value of the feature."
)
)
if (impute_in_preprocessing) {
transformed_cont <- impute(transformed_cont)
}
}
}
}
return(list(transformed_cont = transformed_cont, removed_cont = removed_cont))
}

Expand Down Expand Up @@ -450,11 +422,10 @@ get_caret_processed_df <- function(features, method) {
#' @inheritParams process_novar_feats
#' @param full_rank whether matrix should be full rank or not (see `[caret::dummyVars])
#' @return design matrix
#' @keywords internal
#' @noRd
#' @author Zena Lapp, \email{zenalapp@@umich.edu}
#'
#' @examples
#' \dontrun{
#' df <- data.frame(
#' outcome = c("normal", "normal", "cancer"),
#' var1 = 1:3,
Expand All @@ -463,7 +434,6 @@ get_caret_processed_df <- function(features, method) {
#' var4 = c(0, 1, 0)
#' )
#' get_caret_dummyvars_df(df, TRUE)
#' }
get_caret_dummyvars_df <- function(features, full_rank = FALSE, progbar = NULL) {
check_features(features, check_missing = FALSE)
if (!is.null(process_novar_feats(features, progbar = progbar)$novar_feats)) {
Expand All @@ -481,13 +451,11 @@ get_caret_dummyvars_df <- function(features, full_rank = FALSE, progbar = NULL)
#' @inheritParams group_correlated_features
#'
#' @return features where perfectly correlated ones are collapsed
#' @keywords internal
#' @noRd
#' @author Zena Lapp, \email{zenalapp@@umich.edu}
#'
#' @examples
#' \dontrun{
#' collapse_correlated_features(mikropml::otu_small[, 2:ncol(otu_small)])
#' }
collapse_correlated_features <- function(features, group_neg_corr = TRUE, progbar = NULL) {
feats_nocorr <- features
grp_feats <- NULL
Expand Down
43 changes: 24 additions & 19 deletions R/run_ml.R
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ run_ml <-
group_partitions = NULL,
corr_thresh = 1,
seed = NA,
impute_after_split = FALSE,
...) {
check_all(
dataset,
Expand All @@ -162,7 +163,7 @@ run_ml <-
if (!is.na(seed)) {
set.seed(seed)
}

# `future.apply` is required for `find_feature_importance()`.
# check it here to adhere to the fail fast principle.
if (find_feature_importance) {
Expand All @@ -173,20 +174,20 @@ run_ml <-
if (find_feature_importance) {
check_cat_feats(dataset %>% dplyr::select(-outcome_colname))
}

dataset <- dataset %>%
randomize_feature_order(outcome_colname) %>%
# convert tibble to dataframe to silence warning from caret::train():
# "Warning: Setting row names on a tibble is deprecated.."
as.data.frame()

outcomes_vctr <- dataset %>% dplyr::pull(outcome_colname)

if (length(training_frac) == 1) {
training_inds <- get_partition_indices(outcomes_vctr,
training_frac = training_frac,
groups = groups,
group_partitions = group_partitions
training_frac = training_frac,
groups = groups,
group_partitions = group_partitions
)
} else {
training_inds <- training_frac
Expand All @@ -201,30 +202,34 @@ run_ml <-
}
check_training_frac(training_frac)
check_training_indices(training_inds, dataset)

train_data <- dataset[training_inds, ]
test_data <- dataset[-training_inds, ]
if (impute_after_split == TRUE) {
train_data <- impute(train_data)
test_data <- impute(test_data)
}
# train_groups & test_groups will be NULL if groups is NULL
train_groups <- groups[training_inds]
test_groups <- groups[-training_inds]

if (is.null(hyperparameters)) {
hyperparameters <- get_hyperparams_list(dataset, method)
}
tune_grid <- get_tuning_grid(hyperparameters, method)


outcome_type <- get_outcome_type(outcomes_vctr)
class_probs <- outcome_type != "continuous"

if (is.null(perf_metric_function)) {
perf_metric_function <- get_perf_metric_fn(outcome_type)
}

if (is.null(perf_metric_name)) {
perf_metric_name <- get_perf_metric_name(outcome_type)
}

if (is.null(cross_val)) {
cross_val <- define_cv(
train_data,
Expand All @@ -238,8 +243,8 @@ run_ml <-
group_partitions = group_partitions
)
}


message("Training the model...")
trained_model_caret <- train_model(
train_data = train_data,
Expand All @@ -254,7 +259,7 @@ run_ml <-
if (!is.na(seed)) {
set.seed(seed)
}

if (calculate_performance) {
performance_tbl <- get_performance_tbl(
trained_model_caret,
Expand All @@ -269,7 +274,7 @@ run_ml <-
} else {
performance_tbl <- "Skipped calculating performance"
}

if (find_feature_importance) {
message("Finding feature importance...")
feature_importance_tbl <- get_feature_importance(
Expand All @@ -287,7 +292,7 @@ run_ml <-
} else {
feature_importance_tbl <- "Skipped feature importance"
}

return(
list(
trained_model = trained_model_caret,
Expand Down
Loading