-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_downsample_all_set.py
More file actions
66 lines (51 loc) · 2.23 KB
/
run_downsample_all_set.py
File metadata and controls
66 lines (51 loc) · 2.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import time
import fire
import random
import shutil
from utils.data_io import DataIO
from utils.init_functions import random_setup
def main(
    seed: int = 42,
    num_sample_subsets: int = 3,
) -> None:
    """
    Downsample the "all" test split into subsets the size of the "gem" test split.

    For every turn ``t`` in ``range(num_sample_subsets)`` the ``random`` module is re-seeded with
    ``seed + t`` and ``len(gem_test)`` items are drawn without replacement from the "all" test data.
    Each subset is written to ``data/intent_grasp/all2gem_<seed + t>/`` as ``test.jsonl`` and
    ``test.parquet``, next to a copy of the "all" split's ``metadata.json``.

    If copying ``metadata.json`` fails, the error is printed and the function returns early
    without writing the current subset or any later one.

    :param seed: Random seed of all modules.
    :param num_sample_subsets: The number of downsample subsets.
    :return: None.
    :raises FileNotFoundError: If either input ``test.jsonl`` file does not exist.
    :raises ValueError: If a loaded split is not a non-empty list, or the "gem" split is larger
        than the "all" split (so it cannot be sampled without replacement).
    """
    timer_start = time.perf_counter()
    random_setup(seed=seed)

    all_dir = "data/intent_grasp/all"
    all_test_fp = os.path.join(all_dir, "test.jsonl")
    gem_test_fp = os.path.join("data/intent_grasp/gem", "test.jsonl")

    # Explicit checks instead of `assert`: assertions are stripped under `python -O`.
    for fp in (all_test_fp, gem_test_fp):
        if not os.path.isfile(fp):
            raise FileNotFoundError(f"Required test file not found: {fp}")

    all_test_data = DataIO.load_jsonl(all_test_fp, verbose=True)
    gem_test_data = DataIO.load_jsonl(gem_test_fp, verbose=True)
    for fp, data in ((all_test_fp, all_test_data), (gem_test_fp, gem_test_data)):
        if not isinstance(data, list) or not data:
            raise ValueError(f"Expected a non-empty list of data items from: {fp}")

    num_all_test = len(all_test_data)
    num_gem_test = len(gem_test_data)
    print(f">>> [num_all_test = {num_all_test}] [num_gem_test = {num_gem_test}]")

    # Fail before creating any output directory, with a message that names both sizes,
    # rather than letting random.sample raise its generic "sample larger than population" error.
    if num_gem_test > num_all_test:
        raise ValueError(
            f"Cannot sample {num_gem_test} items from only {num_all_test} items without replacement"
        )

    metadata_fp = os.path.join(all_dir, "metadata.json")  # loop-invariant source path

    # Sample num_gem_test data items from all_test_data
    for sample_turn in range(num_sample_subsets):
        # A distinct, reproducible seed per subset; it also names the output directory.
        cur_seed = seed + sample_turn
        random.seed(cur_seed)
        # random.sample always returns a list of exactly num_gem_test items, so no re-check is needed.
        cur_sample_data = random.sample(all_test_data, num_gem_test)

        cur_save_dir = os.path.join("data/intent_grasp", f"all2gem_{cur_seed}")
        os.makedirs(cur_save_dir, exist_ok=True)
        try:
            shutil.copy(metadata_fp, cur_save_dir)
        except OSError as e:
            # Only filesystem failures (missing file, permissions, ...) are handled here;
            # anything else is a programming error and should propagate.
            print(e)
            return None

        DataIO.save_jsonl(os.path.join(cur_save_dir, "test.jsonl"), cur_sample_data, mode="w", verbose=True)
        DataIO.save_parquet(os.path.join(cur_save_dir, "test.parquet"), cur_sample_data, verbose=True)
        print(f">>> Done saving. cur_save_dir: {cur_save_dir}")

    timer_end = time.perf_counter()
    total_sec = timer_end - timer_start
    print(f"Total Running Time: {total_sec:.1f} sec ({total_sec / 60:.1f} min; {total_sec / 3600:.2f} h)")
if __name__ == "__main__":
    # Script entry point: hand `main` to python-fire, which builds the command-line
    # interface from its parameters (e.g. `--seed 42 --num_sample_subsets 3`).
    fire.Fire(main)