Skip to content

Commit fd0a344

Browse files
committed
chunk_stratified表现极好, deprecated chunk
1 parent 606cb63 commit fd0a344

4 files changed

Lines changed: 35 additions & 1 deletion

File tree

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ S3method(print,kfold)
99
S3method(select,matrix)
1010
export(GOF)
1111
export(NSE)
12+
export(chunk_stratified)
1213
export(cv_coef)
1314
export(kfold_calib)
1415
export(kfold_lm)

R/kford_ml.R

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ kfold_ml <- function(X, Y, kfold = 5, FUN, ...){ #, threshold = 5000
1313
Y = as.matrix(Y)
1414

1515
# ind_lst <- createFolds(1:nrow(X), k = kfold, list = TRUE)
16-
ind_lst <- Ipaper::chunk(1:nrow(X), kfold)
16+
# ind_lst <- Ipaper::chunk(1:nrow(X), kfold)
17+
ind_lst <- chunk_stratified(Y, kfold)
1718

1819
res <- future_map(ind_lst, kfold_calib,
1920
X = X, Y = Y,
@@ -23,6 +24,32 @@ kfold_ml <- function(X, Y, kfold = 5, FUN, ...){ #, threshold = 5000
2324
kfold_tidy(res, ind_lst, Y)
2425
}
2526

27+
# chunk <- function(x, nchunk = 6) {
28+
# split(x, cut(seq_along(x), nchunk, labels = FALSE)) %>% set_names(NULL)
29+
# }
30+
31+
#' @export
32+
chunk_stratified <- function(y, kfold = 5) {
33+
# 1. 获取按目标变量 Y 值大小排序的对应索引
34+
idx_sorted <- order(y)
35+
36+
# 2. 计算能被切分成多少个大小为 kfold 的区块
37+
n_blocks <- ceiling(length(y) / kfold)
38+
39+
# 3. 在每个区块内部进行 1:kfold 的随机乱序排列
40+
# 保证局部随机性,同时维持宏观的分布均匀
41+
set.seed(42) # 固定种子,保证交叉验证结果可精确复现
42+
groups <- unlist(lapply(1:n_blocks, function(x) sample(1:kfold)))
43+
44+
# 4. 截去尾端多余的组号(对应 length(y) 不能整除 kfold 的情况)
45+
groups <- groups[1:length(y)]
46+
47+
# 5. 将排序后的索引按照打乱后的组号分发,并去除 list 的 names
48+
ind_lst <- unname(split(idx_sorted, groups))
49+
return(ind_lst)
50+
}
51+
52+
2653
#' @inheritParams ranger::ranger
2754
#' @rdname kfold_ml
2855
#' @export

README.Rmd

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
output: github_document
33
---
44

5+
[![R-CMD-check](https://github.com/rpkgs/rtrend/workflows/R-CMD-check/badge.svg)](https://github.com/rpkgs/rtrend/actions)
6+
[![codecov](https://codecov.io/gh/rpkgs/rtrend/branch/master/graph/badge.svg)](https://codecov.io/gh/rpkgs/rtrend)
7+
8+
59
<!-- README.md is generated from README.Rmd. Please edit that file -->
610

711
```{r, include = FALSE}

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# kfold
55

66
<!-- badges: start -->
7+
[![R-CMD-check](https://github.com/rpkgs/rtrend/workflows/R-CMD-check/badge.svg)](https://github.com/rpkgs/rtrend/actions)
8+
[![codecov](https://codecov.io/gh/rpkgs/rtrend/branch/master/graph/badge.svg)](https://codecov.io/gh/rpkgs/rtrend)
79
<!-- badges: end -->
810

911
The goal of kfold is to …

0 commit comments

Comments
 (0)