-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathprepare_sft_data_code.py
More file actions
159 lines (126 loc) · 5.68 KB
/
prepare_sft_data_code.py
File metadata and controls
159 lines (126 loc) · 5.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import json
import random
import argparse
from tqdm import tqdm
from multiprocessing import Pool, cpu_count
from functools import partial
from utils import fix_gptoss_completion, get_after_think
from transformers import AutoTokenizer
from env_config import load_env, env_int, env_str
def qwq_prompt(problem):
    """Build the QWQ chat prompt for a single problem statement.

    The result is one complete user turn containing ``problem`` followed by
    the opening marker of the assistant turn, so that generated text (or a
    training completion) continues directly after it.
    """
    user_turn = f"<|im_start|>user\n{problem}<|im_end|>\n"
    assistant_header = "<|im_start|>assistant\n"
    return user_turn + assistant_header
def extract_code(text: str) -> str:
    """Return the text lying between the last two fence lines of ``text``.

    A fence line is any line that contains three consecutive backticks.
    The lines strictly between the final two fence lines are joined with
    newlines and returned. When ``text`` has fewer than two fence lines the
    result is the empty string.
    """
    lines = text.split("\n")

    # Collect the position of every line that carries a fence marker.
    fence_positions = []
    for position, line in enumerate(lines):
        if "```" in line:
            fence_positions.append(position)

    if len(fence_positions) >= 2:
        opening, closing = fence_positions[-2:]
        return "\n".join(lines[opening + 1:closing])
    return ""
def process_item_batch(items_batch, tokenizer_path, max_len, min_len):
    """Filter and format one batch of raw items into prompt/completion pairs.

    The tokenizer is loaded inside this function (rather than passed in) so
    that each worker process calling it builds its own instance.

    An item is kept only when all of the following hold:
      * ``fix_gptoss_completion`` succeeds on its first completion;
      * ``extract_code`` finds non-blank code in the text returned by
        ``get_after_think``;
      * the token count of prompt + completion, plus a fixed margin of 20,
        satisfies ``min_len <= count < max_len``;
      * its formatted prompt has not already been kept in this batch.

    Args:
        items_batch: Iterable of dicts, read via ``item["prompt"]`` and
            ``item["completions"][0]``.
        tokenizer_path: Path or name handed to
            ``AutoTokenizer.from_pretrained``.
        max_len: Exclusive upper bound on the padded token count.
        min_len: Inclusive lower bound on the padded token count.

    Returns:
        A tuple ``(batch_results, batch_unique_prompts)``: the list of kept
        ``{"prompt": ..., "completion": ...}`` dicts, and the set of
        formatted prompts that were kept.
    """
    # Initialize tokenizer in each process
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)

    batch_results = []
    batch_unique_prompts = set()

    for item in items_batch:
        # Filtering is deliberately best-effort: any malformed item is
        # skipped. Only ``Exception`` is caught, so KeyboardInterrupt and
        # SystemExit still propagate and the worker can be stopped.
        try:
            prompt = item["prompt"]
            completion = fix_gptoss_completion(item["completions"][0])

            code = extract_code(get_after_think(completion))
            if code is None or not code.strip():
                continue

            # Format prompt and completion
            formatted_prompt = qwq_prompt(prompt).strip()
            formatted_completion = "\n" + completion

            # Tokenize and check length.
            # NOTE(review): the +20 is presumably headroom for tokens added
            # later in the training pipeline -- confirm.
            token_ids = tokenizer(formatted_prompt + formatted_completion)['input_ids']
            tok_len = len(token_ids) + 20
            if tok_len >= max_len or tok_len < min_len:
                continue

            # Keep only the first accepted item for each prompt in this batch.
            if formatted_prompt not in batch_unique_prompts:
                batch_unique_prompts.add(formatted_prompt)
                batch_results.append({
                    "prompt": formatted_prompt,
                    "completion": formatted_completion,
                })
        except Exception:
            continue

    return batch_results, batch_unique_prompts
if __name__ == "__main__":
    # Script entry point: read a JSONL file, filter/format its items in
    # parallel worker processes, de-duplicate by prompt, shuffle, and write
    # the survivors back out as JSONL.
    load_env()

    parser = argparse.ArgumentParser(description='Process data files for training')
    # Namespaced env vars avoid collisions across scripts; unprefixed vars remain supported for compatibility.
    data_path_default = env_str("PREPARE_SFT_DATA_CODE_DATA_PATH") or env_str("DATA_PATH")
    output_path_default = env_str("PREPARE_SFT_DATA_CODE_OUTPUT_PATH") or env_str("OUTPUT_PATH")
    tokenizer_path_default = env_str("PREPARE_SFT_DATA_CODE_TOKENIZER_PATH") or env_str("TOKENIZER_PATH", default="/personal/xueliang/hf_models/Qwen2.5-7B-Instruct")

    # Integer settings are resolved with an explicit ``is None`` test rather
    # than ``or``: with ``or`` an explicit namespaced value of 0 is falsy and
    # would silently be replaced by the unprefixed variable.
    # NOTE(review): assumes env_int returns its ``default`` when the variable
    # is unset -- confirm against env_config.
    min_len_default = env_int("PREPARE_SFT_DATA_CODE_MIN_LEN", default=None)
    if min_len_default is None:
        min_len_default = env_int("MIN_LEN", default=0)
    max_len_default = env_int("PREPARE_SFT_DATA_CODE_MAX_LEN", default=None)
    if max_len_default is None:
        max_len_default = env_int("MAX_LEN", default=16384)

    parser.add_argument('--data_path', type=str, default=data_path_default, required=data_path_default is None, help='Input data file path')
    parser.add_argument('--output_path', type=str, default=output_path_default, required=output_path_default is None, help='Output file path')
    parser.add_argument('--tokenizer_path', type=str, default=tokenizer_path_default, help='Tokenizer path')
    parser.add_argument('--min_len', type=int, default=min_len_default, help='Minimum token length')
    parser.add_argument('--max_len', type=int, default=max_len_default, help='Maximum token length')

    args = parser.parse_args()

    # Load all items from single input file; lines that are not valid JSON
    # (including blank lines) are skipped.
    all_items = []
    try:
        with open(args.data_path, encoding="utf-8") as f:
            for line in f:
                try:
                    item = json.loads(line.strip())
                    all_items.append(item)
                except json.JSONDecodeError:
                    continue
    except FileNotFoundError:
        print(f"Error: File {args.data_path} not found")
        # SystemExit is raised directly: the bare ``exit`` helper is injected
        # by the ``site`` module and is not guaranteed to exist in programs.
        raise SystemExit(1)

    print(f"Loaded {len(all_items)} items total")

    # Determine number of processes and batch size
    num_processes = min(cpu_count(), 32)  # Limit to 32 processes to avoid memory issues
    batch_size = max(len(all_items) // num_processes, 1)

    # Split items into batches
    item_batches = [
        all_items[i:i + batch_size]
        for i in range(0, len(all_items), batch_size)
    ]

    print(f"Processing with {num_processes} processes, {len(item_batches)} batches")

    # Create partial function with fixed parameters
    process_func = partial(
        process_item_batch,
        tokenizer_path=args.tokenizer_path,
        max_len=args.max_len,
        min_len=args.min_len,
    )

    # Process batches in parallel
    results = []
    global_unique_prompts = set()

    with Pool(processes=num_processes) as pool:
        batch_results = list(tqdm(
            pool.imap(process_func, item_batches),
            total=len(item_batches),
            desc="Processing batches",
        ))

    # Combine results and ensure global uniqueness. Each worker only
    # de-duplicates within its own batch, so prompts are re-checked here; the
    # per-batch prompt set returned by the worker is not needed for that.
    for batch_result, _ in batch_results:
        for item in batch_result:
            if item["prompt"] not in global_unique_prompts:
                global_unique_prompts.add(item["prompt"])
                results.append(item)

    print(f"Final results: {len(results)} unique items")

    # Shuffle results
    random.shuffle(results)

    # Print first result as example
    if results:
        print("Example result:")
        print(results[0])

    # Write to output file
    with open(args.output_path, "w", encoding="utf-8") as f:
        for item in results:
            f.write(json.dumps(item) + "\n")

    print(f"Results written to {args.output_path}")