|
| 1 | +from random import sample |
| 2 | +import numpy as np |
| 3 | +import os |
| 4 | + |
| 5 | + |
| 6 | +Y_COLUMN = 'label' |
| 7 | +DROP_COLUMS = ['raw', 'path'] |
| 8 | + |
| 9 | + |
| 10 | +def _shuffle(data): |
| 11 | + return data.reindex(np.random.permutation(data.index)) |
| 12 | + |
| 13 | + |
| 14 | +def split_random(data, test_size=0.3): |
| 15 | + test_size = int(data.shape[0] * test_size) |
| 16 | + train_size = len(data) - test_size |
| 17 | + |
| 18 | + data = _shuffle(data)[:test_size + train_size] |
| 19 | + data = data.reindex().drop(DROP_COLUMS, axis=1) |
| 20 | + |
| 21 | + test = data[:test_size] |
| 22 | + train = data[test_size:] |
| 23 | + |
| 24 | + test_Y = test[Y_COLUMN] |
| 25 | + test_X = test.drop(Y_COLUMN, axis=1) |
| 26 | + |
| 27 | + train_Y = train[Y_COLUMN] |
| 28 | + train_X = train.drop(Y_COLUMN, axis=1) |
| 29 | + |
| 30 | + return train_X, train_Y, test_X, test_Y |
| 31 | + |
| 32 | + |
| 33 | +def _train_test_indices(num_samples, idx_file_path, test_size=0.3): |
| 34 | + if not os.path.isfile(idx_file_path): |
| 35 | + test_indices = np.array(sample(range(num_samples), k=int(num_samples * test_size))) |
| 36 | + test_indices.dump(idx_file_path) |
| 37 | + |
| 38 | + test_indices = np.load(idx_file_path) |
| 39 | + train_indices = np.array([i for i in range(num_samples) if i not in test_indices]) |
| 40 | + |
| 41 | + return train_indices, test_indices |
| 42 | + |
| 43 | + |
| 44 | +def split_fixed(data, idx_file_path): |
| 45 | + data = data.reindex().drop(DROP_COLUMS, axis=1) |
| 46 | + |
| 47 | + train_indices, test_indices = _train_test_indices(data.shape[0], idx_file_path) |
| 48 | + |
| 49 | + test = data.iloc[test_indices] |
| 50 | + train = data.iloc[train_indices] |
| 51 | + |
| 52 | + test_Y = test[Y_COLUMN] |
| 53 | + test_X = test.drop(Y_COLUMN, axis=1) |
| 54 | + |
| 55 | + train_Y = train[Y_COLUMN] |
| 56 | + train_X = train.drop(Y_COLUMN, axis=1) |
| 57 | + |
| 58 | + return train_X, train_Y, test_X, test_Y |
0 commit comments