Skip to content

Commit 79543a5

Browse files
committed
Add shuffle=False option for time series data (closes #44)
Add shuffle parameter (default True) to ratio(), fixed(), and kfold(). When False, files are split in sorted order without randomization. Add --no-shuffle CLI flag. Update README with documentation.
1 parent 2551252 commit 79543a5

5 files changed

Lines changed: 159 additions & 49 deletions

File tree

README.md

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ This should get you started to do some serious deep learning on your data. [Read
4747

4848
- Split files into a training set and a validation set (and optionally a test set).
4949
- Works on any file types.
50-
- The files get shuffled.
50+
- The files get shuffled (can be disabled for time series data).
5151
- A [seed](https://docs.python.org/3/library/random.html#random.seed) makes splits reproducible.
5252
- Allows randomized [oversampling](https://en.wikipedia.org/wiki/Oversampling_and_undersampling_in_data_analysis) for imbalanced datasets.
5353
- Optionally group files by prefix or by stem.
@@ -87,14 +87,14 @@ import splitfolders
8787
# To only split into training and validation set, set a tuple to `ratio`, i.e., `(.8, .2)`.
8888
splitfolders.ratio("input_folder", output="output",
8989
seed=1337, ratio=(.8, .1, .1), group_prefix=None, group=None,
90-
formats=None, move=False) # default values
90+
formats=None, move=False, shuffle=True) # default values
9191

9292
# Split val/test with a fixed number of items, e.g. `(100, 100)`, for each set.
9393
# To only split into training and validation set, use a single number to `fixed`, i.e., `10`.
9494
# Set 3 values, e.g. `(300, 100, 100)`, to limit the number of training values.
9595
splitfolders.fixed("input_folder", output="output",
9696
seed=1337, fixed=(100, 100), oversample=False, group_prefix=None, group=None,
97-
formats=None, move=False) # default values
97+
formats=None, move=False, shuffle=True) # default values
9898

9999
# Use `fixed="auto"` with oversampling to auto-compute the val size from the smallest class.
100100
# Allocates ~20% of the smallest class to validation, rest to training.
@@ -106,7 +106,11 @@ splitfolders.fixed("input_folder", output="output",
106106
# Uses symlinks by default to avoid k× disk usage.
107107
splitfolders.kfold("input_folder", output="output",
108108
seed=1337, k=5, group_prefix=None, group=None,
109-
formats=None, move="symlink") # default values
109+
formats=None, move="symlink", shuffle=True) # default values
110+
111+
# Split without shuffling (e.g. for time series data).
112+
splitfolders.ratio("input_folder", output="output",
113+
ratio=(.8, .1, .1), shuffle=False)
110114
```
111115

112116
### Grouping files
@@ -215,7 +219,7 @@ Set
215219

216220
```
217221
Usage:
218-
splitfolders [--output] [--ratio] [--fixed] [--kfold] [--seed] [--oversample] [--group_prefix] [--group] [--formats] [--move] folder_with_images
222+
splitfolders [--output] [--ratio] [--fixed] [--kfold] [--seed] [--oversample] [--group_prefix] [--group] [--formats] [--move] [--no-shuffle] folder_with_images
219223
Options:
220224
--output path to the output folder. Defaults to `output`. Gets created if non-existent.
221225
--ratio the ratio to split. e.g. for train/val/test `.8 .1 .1 --` or for train/val `.8 .2 --`.
@@ -231,6 +235,7 @@ Options:
231235
--formats split the files based on specified extension(s)
232236
--move move the files instead of copying
233237
--symlink symlink (create shortcut) the files instead of copying
238+
--no-shuffle do not shuffle files before splitting (useful for time series data)
234239
Example:
235240
splitfolders --ratio .8 .1 .1 -- folder_with_images
236241
splitfolders --kfold 5 folder_with_images

splitfolders/cli.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,12 @@ def run():
8888
default=None,
8989
help="specify the file format(s) which should be considered for spliting the data e.g. `.png .jpeg .jpg`",
9090
)
91+
parser.add_argument(
92+
"--no-shuffle",
93+
action="store_true",
94+
default=False,
95+
help="do not shuffle files before splitting (useful for time series data)",
96+
)
9197

9298
args = parser.parse_args()
9399

@@ -97,34 +103,29 @@ def run():
97103
if args.symlink:
98104
args.move = "symlink"
99105

106+
shuffle = not args.no_shuffle
107+
100108
if args.ratio:
101-
ratio(args.input, args.output, args.seed, args.ratio, args.group_prefix, args.group, args.move, args.formats)
109+
ratio(
110+
args.input, args.output, args.seed, args.ratio,
111+
args.group_prefix, args.group, args.move, args.formats, shuffle,
112+
)
102113
elif args.fixed:
103114
if args.fixed == ["auto"]:
104115
fixed_value = "auto"
105116
else:
106117
fixed_value = [int(x) for x in args.fixed]
107118
fixed(
108-
args.input,
109-
args.output,
110-
args.seed,
111-
fixed_value,
112-
args.oversample,
113-
args.group_prefix,
114-
args.group,
115-
args.move,
116-
args.formats,
119+
args.input, args.output, args.seed, fixed_value,
120+
args.oversample, args.group_prefix, args.group, args.move,
121+
args.formats, shuffle,
117122
)
118123
elif args.kfold:
119124
kfold(
120-
args.input,
121-
args.output,
122-
args.seed,
123-
args.kfold,
124-
args.group_prefix,
125-
args.group,
125+
args.input, args.output, args.seed, args.kfold,
126+
args.group_prefix, args.group,
126127
args.move if args.move else "symlink",
127-
args.formats,
128+
args.formats, shuffle,
128129
)
129130
else:
130131
print("Please specify either your `--ratio`, `--fixed`, or `--kfold` for the split. see -h for more help.")

splitfolders/grouping.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def group_by_stem(files):
5252
return [tuple(sorted(g)) for g in sorted(stem_groups.values(), key=lambda g: g[0])]
5353

5454

55-
def setup_sibling_files(input_dir, seed, formats=None):
55+
def setup_sibling_files(input_dir, seed, formats=None, shuffle=True):
5656
"""Lists type dirs, groups files by stem across all dirs.
5757
Validates every stem exists in every dir. Returns (type_dir_names, groups)."""
5858
from .utils import list_dirs, list_files
@@ -86,12 +86,13 @@ def setup_sibling_files(input_dir, seed, formats=None):
8686
"All stems must exist in every subdirectory for group='sibling'."
8787
)
8888

89-
# Build groups: each group is a dict mapping type_dir_name -> Path
89+
# Build groups: each group is a tuple mapping type_dir_name -> Path
9090
import random
9191

9292
random.seed(seed)
9393
sorted_stems = sorted(all_stems)
94-
random.shuffle(sorted_stems)
94+
if shuffle:
95+
random.shuffle(sorted_stems)
9596

9697
groups = []
9798
for stem in sorted_stems:

splitfolders/split.py

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def valid_extensions(formats):
107107

108108
def ratio(
109109
input, output="output", seed=1337, ratio=(0.8, 0.1, 0.1),
110-
group_prefix=None, group=None, move=False, formats=None,
110+
group_prefix=None, group=None, move=False, formats=None, shuffle=True,
111111
):
112112
if not round(sum(ratio), 5) == 1: # round for floating imprecision
113113
raise ValueError("The sums of `ratio` is over 1.")
@@ -121,11 +121,14 @@ def ratio(
121121
prog_bar = tqdm(desc="Copying files", unit=" files")
122122

123123
if group == "sibling":
124-
split_sibling_dirs_ratio(input, output, ratio, seed, prog_bar if use_tqdm else None, move, formats)
124+
split_sibling_dirs_ratio(
125+
input, output, ratio, seed, prog_bar if use_tqdm else None, move, formats, shuffle,
126+
)
125127
else:
126128
for class_dir in list_dirs(input):
127129
split_class_dir_ratio(
128-
class_dir, output, ratio, seed, prog_bar if use_tqdm else None, group_prefix, group, move, formats
130+
class_dir, output, ratio, seed, prog_bar if use_tqdm else None,
131+
group_prefix, group, move, formats, shuffle,
129132
)
130133

131134
if use_tqdm:
@@ -134,7 +137,7 @@ def ratio(
134137

135138
def fixed(
136139
input, output="output", seed=1337, fixed=(100, 100), oversample=False,
137-
group_prefix=None, group=None, move=False, formats=None,
140+
group_prefix=None, group=None, move=False, formats=None, shuffle=True,
138141
):
139142
check_input_format(input)
140143
valid_extensions(formats)
@@ -170,7 +173,9 @@ def fixed(
170173
prog_bar = tqdm(desc="Copying files", unit=" files")
171174

172175
if group == "sibling":
173-
split_sibling_dirs_fixed(input, output, fixed, seed, prog_bar if use_tqdm else None, move, formats)
176+
split_sibling_dirs_fixed(
177+
input, output, fixed, seed, prog_bar if use_tqdm else None, move, formats, shuffle,
178+
)
174179
if use_tqdm:
175180
prog_bar.close()
176181
return
@@ -180,7 +185,8 @@ def fixed(
180185
for class_dir in classes_dirs:
181186
num_items.append(
182187
split_class_dir_fixed(
183-
class_dir, output, fixed, seed, prog_bar if use_tqdm else None, group_prefix, group, move, formats
188+
class_dir, output, fixed, seed, prog_bar if use_tqdm else None,
189+
group_prefix, group, move, formats, shuffle,
184190
)
185191
)
186192

@@ -222,7 +228,10 @@ def fixed(
222228
shutil.copy2(str(f_orig), str(f_dest))
223229

224230

225-
def kfold(input, output="output", seed=1337, k=5, group_prefix=None, group=None, move="symlink", formats=None):
231+
def kfold(
232+
input, output="output", seed=1337, k=5, group_prefix=None, group=None,
233+
move="symlink", formats=None, shuffle=True,
234+
):
226235
if k < 2:
227236
raise ValueError("`k` must be 2 or greater.")
228237

@@ -233,24 +242,26 @@ def kfold(input, output="output", seed=1337, k=5, group_prefix=None, group=None,
233242
prog_bar = tqdm(desc="Copying files", unit=" files")
234243

235244
if group == "sibling":
236-
split_sibling_dirs_kfold(input, output, k, seed, prog_bar if use_tqdm else None, move, formats)
245+
split_sibling_dirs_kfold(
246+
input, output, k, seed, prog_bar if use_tqdm else None, move, formats, shuffle,
247+
)
237248
else:
238249
for class_dir in list_dirs(input):
239250
split_class_dir_kfold(
240251
class_dir, output, k, seed, prog_bar if use_tqdm else None,
241-
group_prefix, group, move, formats,
252+
group_prefix, group, move, formats, shuffle,
242253
)
243254

244255
if use_tqdm:
245256
prog_bar.close()
246257

247258

248-
def split_class_dir_kfold(class_dir, output, k, seed, prog_bar, group_prefix, group, move, formats):
259+
def split_class_dir_kfold(class_dir, output, k, seed, prog_bar, group_prefix, group, move, formats, shuffle=True):
249260
"""
250261
Splits a class folder into k folds for cross-validation.
251262
Each fold directory gets train/ and val/ subdirectories.
252263
"""
253-
files = setup_files(class_dir, seed, group_prefix, group, formats)
264+
files = setup_files(class_dir, seed, group_prefix, group, formats, shuffle)
254265

255266
# Partition files into k roughly equal chunks
256267
fold_size = len(files) // k
@@ -271,25 +282,26 @@ def split_class_dir_kfold(class_dir, output, k, seed, prog_bar, group_prefix, gr
271282
copy_files(li, class_dir, fold_output, prog_bar, move)
272283

273284

274-
def setup_files(class_dir, seed, group_prefix=None, group=None, formats=None):
285+
def setup_files(class_dir, seed, group_prefix=None, group=None, formats=None, shuffle=True):
275286
"""
276-
Returns shuffled list of filenames
287+
Returns sorted (and optionally shuffled) list of filenames
277288
"""
278289
random.seed(seed) # make sure its reproducible
279290

280291
files = list_files(class_dir, formats)
281292
files = resolve_grouping(files, group_prefix, group)
282293

283294
files.sort()
284-
random.shuffle(files)
295+
if shuffle:
296+
random.shuffle(files)
285297
return files
286298

287299

288-
def split_class_dir_ratio(class_dir, output, ratio, seed, prog_bar, group_prefix, group, move, formats):
300+
def split_class_dir_ratio(class_dir, output, ratio, seed, prog_bar, group_prefix, group, move, formats, shuffle=True):
289301
"""
290302
Splits a class folder
291303
"""
292-
files = setup_files(class_dir, seed, group_prefix, group, formats)
304+
files = setup_files(class_dir, seed, group_prefix, group, formats, shuffle)
293305

294306
# the files are already sorted (and shuffled, unless shuffle=False)
295307
split_train_idx = int(ratio[0] * len(files))
@@ -299,11 +311,11 @@ def split_class_dir_ratio(class_dir, output, ratio, seed, prog_bar, group_prefix
299311
copy_files(li, class_dir, output, prog_bar, move)
300312

301313

302-
def split_class_dir_fixed(class_dir, output, fixed, seed, prog_bar, group_prefix, group, move, formats):
314+
def split_class_dir_fixed(class_dir, output, fixed, seed, prog_bar, group_prefix, group, move, formats, shuffle=True):
303315
"""
304316
Splits a class folder and returns the total number of files
305317
"""
306-
files = setup_files(class_dir, seed, group_prefix, group, formats)
318+
files = setup_files(class_dir, seed, group_prefix, group, formats, shuffle)
307319

308320
if not len(files) >= sum(fixed):
309321
raise ValueError(
@@ -389,8 +401,8 @@ def copy_sibling_files(files_type, type_dir_names, output, prog_bar, move):
389401
copy_fn(f, full_path)
390402

391403

392-
def split_sibling_dirs_ratio(input_dir, output, ratio, seed, prog_bar, move, formats):
393-
type_dir_names, groups = setup_sibling_files(input_dir, seed, formats)
404+
def split_sibling_dirs_ratio(input_dir, output, ratio, seed, prog_bar, move, formats, shuffle=True):
405+
type_dir_names, groups = setup_sibling_files(input_dir, seed, formats, shuffle)
394406

395407
split_train_idx = int(ratio[0] * len(groups))
396408
split_val_idx = split_train_idx + int(ratio[1] * len(groups))
@@ -399,8 +411,8 @@ def split_sibling_dirs_ratio(input_dir, output, ratio, seed, prog_bar, move, for
399411
copy_sibling_files(li, type_dir_names, output, prog_bar, move)
400412

401413

402-
def split_sibling_dirs_fixed(input_dir, output, fixed, seed, prog_bar, move, formats):
403-
type_dir_names, groups = setup_sibling_files(input_dir, seed, formats)
414+
def split_sibling_dirs_fixed(input_dir, output, fixed, seed, prog_bar, move, formats, shuffle=True):
415+
type_dir_names, groups = setup_sibling_files(input_dir, seed, formats, shuffle)
404416

405417
if not len(groups) >= sum(fixed):
406418
raise ValueError(
@@ -425,8 +437,8 @@ def split_sibling_dirs_fixed(input_dir, output, fixed, seed, prog_bar, move, for
425437
copy_sibling_files(li, type_dir_names, output, prog_bar, move)
426438

427439

428-
def split_sibling_dirs_kfold(input_dir, output, k, seed, prog_bar, move, formats):
429-
type_dir_names, groups = setup_sibling_files(input_dir, seed, formats)
440+
def split_sibling_dirs_kfold(input_dir, output, k, seed, prog_bar, move, formats, shuffle=True):
441+
type_dir_names, groups = setup_sibling_files(input_dir, seed, formats, shuffle)
430442

431443
fold_size = len(groups) // k
432444
remainder = len(groups) % k

0 commit comments

Comments
 (0)