-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcreate_split_k_fold.py
More file actions
73 lines (54 loc) · 2.41 KB
/
create_split_k_fold.py
File metadata and controls
73 lines (54 loc) · 2.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
import argparse
# Input metadata table and output split file (paths relative to the working directory).
METADATA_FILE = "EMVD/metadata_files.csv"
OUTPUT_FILE = "new_split_kfolds.csv"
# Fraction of each fold's non-eval rows labelled 'train'; the remainder is 'valid'.
TRAIN_FRACTION = 0.7
# Metadata columns removed from the output file (dropped before and after the merge).
_METADATA_COLUMNS = ['singer_id', 'type', 'name', 'range', 'vowel', 'authors_rank', 'duration(s)']


def _load_filtered_metadata(file_meta):
    """Read the metadata CSV and keep only the rows that take part in the folds.

    Keeps rows whose 'type' is 'Technique', whose 'authors_rank' is not '0' and
    whose 'name' is not 'GrindInhale', sorted by 'singer_id', with a fresh
    0..n-1 index.
    """
    df_meta = pd.read_csv(file_meta)
    df_meta = df_meta[df_meta['type'] == 'Technique']
    # NOTE(review): this compares against the *string* '0'. If pandas parses
    # 'authors_rank' as a numeric column, the filter keeps every row. Confirm the dtype.
    df_meta = df_meta[df_meta['authors_rank'] != '0']
    df_meta = df_meta[df_meta['name'] != 'GrindInhale']
    df_meta = df_meta.sort_values(by=['singer_id'])
    # KFold yields positional indices that are later used as .loc labels, so
    # the index must be a 0..n-1 RangeIndex.
    return df_meta.reset_index(drop=True)


def _assign_kfold_splits(df_meta, n_splits, seed, train_fraction=TRAIN_FRACTION):
    """Add one 'split{i}' column per fold labelling each row 'train'/'valid'/'eval'.

    For fold i, the KFold test part is 'eval'. The remaining rows are split
    once more: ``int(train_fraction * len(rest))`` rows become 'train' and the
    others 'valid'. ``seed`` drives both random splits.
    """
    # NOTE(review): rows are shuffled individually, not grouped by singer, so
    # one singer can appear in train and eval of the same fold. Confirm this is intended.
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
    split_gen = kf.split(df_meta.index.values.tolist())
    for split_idx, (fulltrain_index, eval_index) in enumerate(split_gen):
        train_samples = int(train_fraction * len(fulltrain_index))
        _, valid_index = train_test_split(
            fulltrain_index, train_size=train_samples, random_state=seed
        )
        split_name = f'split{split_idx}'
        df_meta[split_name] = 'train'
        df_meta.loc[valid_index, split_name] = 'valid'
        df_meta.loc[eval_index, split_name] = 'eval'
    return df_meta


def _save_splits(file_meta, df_meta, out_path):
    """Left-join the split columns onto the full, unfiltered metadata and write a CSV."""
    df = pd.read_csv(file_meta)
    split_columns = df_meta.drop(columns=_METADATA_COLUMNS)
    merged_df = df.merge(split_columns, on='file_name', how='left')
    merged_df = merged_df.drop(columns=_METADATA_COLUMNS)
    # Rows removed by the filters get NaN split labels from the left join.
    # Every NaN in the result is replaced with '-'.
    merged_df = merged_df.fillna('-')
    merged_df.to_csv(out_path, index=False)


def main(config):
    """Build K-fold train/valid/eval splits for the EMVD technique recordings.

    Reads METADATA_FILE, assigns every kept row to train/valid/eval for each of
    ``config.n_splits`` folds, and writes OUTPUT_FILE. The output has one
    'split{i}' column per fold, with '-' for rows that were filtered out.

    Args:
        config: object with ``seed`` and ``n_splits`` attributes. Both are
            converted with ``int()``, so string values from the command line
            also work. ``seed`` seeds numpy and both sklearn splitters.
    """
    seed = int(config.seed)
    n_splits = int(config.n_splits)
    np.random.seed(seed)
    df_meta = _load_filtered_metadata(METADATA_FILE)
    df_meta = _assign_kfold_splits(df_meta, n_splits, seed)
    _save_splits(METADATA_FILE, df_meta, OUTPUT_FILE)
def build_parser():
    """Return the command-line parser for this script.

    Both options are parsed as ``int``. Command-line values therefore reach
    ``np.random.seed`` and ``KFold`` as integers rather than strings.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-seed', '--seed', type=int, help='chosen seed for numpy', default=0)
    parser.add_argument('-n_splits', '--n_splits', type=int, help='number of splits for the cross-validation', default=4)
    return parser


if __name__ == "__main__":
    config = build_parser().parse_args()
    main(config)