-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathsample_train.py
More file actions
54 lines (43 loc) · 1.89 KB
/
sample_train.py
File metadata and controls
54 lines (43 loc) · 1.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
import shutil
import random
import argparse
def _is_within(path, root):
    """Return True if ``path`` equals ``root`` or lies underneath it.

    Both arguments are expected to be already-resolved (``os.path.realpath``)
    paths so that symlinks and ``..`` components do not defeat the check.
    """
    try:
        return os.path.commonpath([path, root]) == root
    except ValueError:
        # Raised e.g. for paths on different drives on Windows: they cannot nest.
        return False


def sample_train_data(src_dir, dst_dir, n, seed=None):
    """Copy ``src_dir`` to ``dst_dir`` with ``train_list.txt`` randomly subsampled.

    ``dst_dir/train_list.txt`` receives ``n`` lines drawn without replacement
    from ``src_dir/train_list.txt``; every other file and subdirectory of
    ``src_dir`` is copied over unchanged (files of the same name already in
    ``dst_dir`` are overwritten).

    Args:
        src_dir: Directory containing ``train_list.txt`` and companion data.
        dst_dir: Output directory; created if missing. Must not resolve to
            the same directory as ``src_dir``.
        n: Number of lines to keep. If it exceeds the number of available
            lines, a warning is printed and all lines are kept in their
            original order.
        seed: Optional seed for a private ``random.Random`` so the sample is
            reproducible. ``None`` (default) uses the global ``random`` state.

    Raises:
        ValueError: If ``n`` is negative, or if ``src_dir`` and ``dst_dir``
            resolve to the same directory (which would overwrite the source
            list with the sample).
        FileNotFoundError: If ``src_dir/train_list.txt`` does not exist.
    """
    src_real = os.path.realpath(src_dir)
    dst_real = os.path.realpath(dst_dir)
    if src_real == dst_real:
        raise ValueError(f"src_dir and dst_dir must differ: both resolve to {src_real}")
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")

    # Path to original train_list.txt
    train_list_path = os.path.join(src_dir, "train_list.txt")
    if not os.path.exists(train_list_path):
        raise FileNotFoundError(f"{train_list_path} does not exist")

    with open(train_list_path, "r") as f:
        lines = f.readlines()
    # A final line without a newline would be glued onto its successor once
    # the sample reorders lines, so make every line newline-terminated.
    if lines and not lines[-1].endswith("\n"):
        lines[-1] += "\n"

    if n > len(lines):
        print(f"Warning: requested {n} lines but only {len(lines)} available. Using all lines.")
        sampled_lines = lines
    else:
        rng = random if seed is None else random.Random(seed)
        sampled_lines = rng.sample(lines, n)

    # Created only after validation so a failed call leaves no stray directory.
    os.makedirs(dst_dir, exist_ok=True)
    dst_train_list = os.path.join(dst_dir, "train_list.txt")
    with open(dst_train_list, "w") as f:
        f.writelines(sampled_lines)

    def _ignore_dst(dirpath, names):
        # copytree ignore-callback: drop the entry that *is* dst_dir so the
        # output is never copied into itself.
        return {
            name
            for name in names
            if os.path.realpath(os.path.join(dirpath, name)) == dst_real
        }

    # Only pay for the per-entry realpath check when dst_dir is nested in src_dir.
    ignore = _ignore_dst if _is_within(dst_real, src_real) else None

    for item in os.listdir(src_dir):
        # Skip original train_list.txt (its sampled version was written above).
        if item == "train_list.txt":
            continue
        src_path = os.path.join(src_dir, item)
        if os.path.realpath(src_path) == dst_real:
            continue
        dst_path = os.path.join(dst_dir, item)
        if os.path.isdir(src_path):
            shutil.copytree(src_path, dst_path, dirs_exist_ok=True, ignore=ignore)
        else:
            shutil.copy2(src_path, dst_path)

    print(f"Sampled {len(sampled_lines)} lines to {dst_train_list}")
    print(f"Copied other files to {dst_dir}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Sample train data and copy directory")
    # All three inputs are mandatory: if omitted, argparse would pass None and
    # sample_train_data would fail later with an unhelpful TypeError.
    parser.add_argument("--src_dir", type=str, required=True, help="Source directory")
    parser.add_argument("--dst_dir", type=str, required=True, help="Destination directory")
    parser.add_argument(
        "--n",
        type=int,
        required=True,
        help="Number of lines to sample from train_list.txt",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Seed for the global random generator (makes the sample reproducible)",
    )
    args = parser.parse_args()
    # Seeding the module-level generator makes random.sample deterministic.
    if args.seed is not None:
        random.seed(args.seed)
    sample_train_data(args.src_dir, args.dst_dir, args.n)