Skip to content

Commit 9ec48ec

Browse files
committed
gof distinguish train and valid
1 parent 9673500 commit 9ec48ec

6 files changed

Lines changed: 20 additions & 20 deletions

File tree

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ export(previous_tn)
2020
import(magrittr)
2121
import(ranger)
2222
import(xgboost)
23+
importFrom(Ipaper,melt_list)
2324
importFrom(data.table,as.data.table)
2425
importFrom(data.table,data.table)
2526
importFrom(dplyr,across)

R/GOF.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ GOF <- function(yobs, ysim, w, include.cv = FALSE, include.r = TRUE) {
151151

152152
if (include.r) out <- c(R2 = R2, out, R = R, pvalue = pvalue)
153153
if (include.cv) out <- c(out, CV_obs = CV_obs, CV_sim = CV_sim)
154-
return(out)
154+
as.data.table(as.list(out))
155155
}
156156

157157
#' weighted CV

R/kfold-package.R

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#' @importFrom data.table data.table as.data.table
44
#' @importFrom purrr map is_empty
55
#' @importFrom dplyr mutate across
6+
#' @importFrom Ipaper melt_list
67
#' @import magrittr
78
#'
89
"_PACKAGE"
@@ -16,9 +17,7 @@ NULL
1617
.onLoad <- function(libname, pkgname) {
1718
if (getRversion() >= "2.15.1") {
1819
utils::globalVariables(
19-
c(
20-
"."
21-
)
20+
c(".")
2221
)
2322
}
2423
}

R/kfold_calib.R

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,28 +14,31 @@ kfold_calib <- function(X, Y, FUN = xgboost, index=NULL, ..., ratio_valid=0.3) {
1414
y_test <- Y[index, , drop = F]
1515

1616
m <- FUN(x_train, y_train, ...)
17-
ypred <- predict(m, x_test)
18-
list(gof = GOF(y_test, ypred), ypred = ypred, model = m)
17+
ypred_train <- predict(m, x_train)
18+
ypred_test <- predict(m, x_test)
19+
20+
gof = list(
21+
train = GOF(y_train, ypred_train),
22+
test = GOF(y_test, ypred_test)
23+
) %>% melt_list("type")
24+
list(gof = gof, ypred = ypred_test, model = m)
1925
}
2026

2127
#' @export
2228
kfold_tidy <- function(res, ind_lst, Y) {
2329
kfold_names <- names(ind_lst)
2430
if (is.null(kfold_names)) kfold_names <- paste0(seq_along(ind_lst))
2531

26-
## 3. GOF information get
32+
## GOF information get
2733
val <- map(res, ~ .x$ypred) %>% unlist() # pred value
2834
ypred <- Y * NA
2935
ypred[unlist(ind_lst)] <- val
30-
info_all <- GOF(Y, ypred)
36+
info_all <- cbind(type = "valid", GOF(Y, ypred))
3137

3238
model <- map(res, "model")
33-
gof <- map(res, "gof") %>%
39+
gof <- map(res, "gof") %>% set_names(kfold_names) %>%
3440
c(., all = list(info_all)) %>%
35-
do.call(rbind, .) %>%
36-
as.data.table()
37-
gof$kfold <- c(kfold_names, "all")
38-
41+
melt_list("kfold") %>% data.table()
3942
listk(gof, ypred, index = ind_lst, model) %>% set_class("kfold") # how to return back to original value?
4043
}
4144

R/kford_ml.R

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,8 @@ kfold_rf <- function(X, Y, kfold = 5,
3838
#' @import xgboost
3939
#' @rdname kfold_ml
4040
#' @export
41-
kfold_xgboost <- function(X, Y, kfold = 5, verbose = FALSE, nrounds = 500, ...) {
42-
kfold_ml(X, Y, kfold,
43-
FUN = xgboost, nrounds = nrounds, verbose = verbose, ...)
41+
kfold_xgboost <- function(X, Y, kfold = 5, nrounds = 500, ...) {
42+
kfold_ml(X, Y, kfold, FUN = xgboost, nrounds = nrounds, ...)
4443
}
4544

4645
#' @rdname kfold_ml
@@ -58,7 +57,7 @@ predict.ranger <- function(object, data = NULL, ...) {
5857

5958
#' @import ranger
6059
ranger <- function(x, y, ntree = 500, ...) {
61-
ranger::ranger(x = x, y = y, num.trees = ntree, ...)
60+
ranger::ranger(x = x, y = drop(y), num.trees = ntree, ...)
6261
}
6362

6463
# ' @export

man/kfold_ml.Rd

Lines changed: 1 addition & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)