@@ -13,7 +13,8 @@ kfold_ml <- function(X, Y, kfold = 5, FUN, ...){ #, threshold = 5000
1313 Y = as.matrix(Y )
1414
1515 # ind_lst <- createFolds(1:nrow(X), k = kfold, list = TRUE)
16- ind_lst <- Ipaper :: chunk(1 : nrow(X ), kfold )
16+ # ind_lst <- Ipaper::chunk(1:nrow(X), kfold)
17+ ind_lst <- chunk_stratified(Y , kfold )
1718
1819 res <- future_map(ind_lst , kfold_calib ,
1920 X = X , Y = Y ,
@@ -23,6 +24,32 @@ kfold_ml <- function(X, Y, kfold = 5, FUN, ...){ #, threshold = 5000
2324 kfold_tidy(res , ind_lst , Y )
2425}
2526
27+ # chunk <- function(x, nchunk = 6) {
28+ # split(x, cut(seq_along(x), nchunk, labels = FALSE)) %>% set_names(NULL)
29+ # }
30+
31+ # ' @export
32+ chunk_stratified <- function (y , kfold = 5 ) {
33+ # 1. 获取按目标变量 Y 值大小排序的对应索引
34+ idx_sorted <- order(y )
35+
36+ # 2. 计算能被切分成多少个大小为 kfold 的区块
37+ n_blocks <- ceiling(length(y ) / kfold )
38+
39+ # 3. 在每个区块内部进行 1:kfold 的随机乱序排列
40+ # 保证局部随机性,同时维持宏观的分布均匀
41+ set.seed(42 ) # 固定种子,保证交叉验证结果可精确复现
42+ groups <- unlist(lapply(1 : n_blocks , function (x ) sample(1 : kfold )))
43+
44+ # 4. 截去尾端多余的组号(对应 length(y) 不能整除 kfold 的情况)
45+ groups <- groups [1 : length(y )]
46+
47+ # 5. 将排序后的索引按照打乱后的组号分发,并去除 list 的 names
48+ ind_lst <- unname(split(idx_sorted , groups ))
49+ return (ind_lst )
50+ }
51+
52+
2653# ' @inheritParams ranger::ranger
2754# ' @rdname kfold_ml
2855# ' @export
0 commit comments