Skip to content

Commit 0716531

Browse files
committed
temp
1 parent 2ebb14e commit 0716531

13 files changed

Lines changed: 1679 additions & 0 deletions

dpo/generate_dpo_pairs.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
"""
2+
Script 1: Generate two outputs from Qwen-2.5-7B for each input using vLLM
3+
"""
4+
import argparse
5+
import json
6+
7+
from tqdm import tqdm
8+
from vllm import LLM, SamplingParams
9+
10+
11+
def parse_args():
12+
parser = argparse.ArgumentParser(
13+
description="Generate two outputs per input using Qwen-2.5-7B with vLLM"
14+
)
15+
parser.add_argument(
16+
"--model",
17+
type=str,
18+
default="Qwen/Qwen2.5-7B-Instruct",
19+
help="HuggingFace model repo or path"
20+
)
21+
parser.add_argument(
22+
"--input_path",
23+
type=str,
24+
default="data/sotopia_grpo.json",
25+
help="Path to input JSON file"
26+
)
27+
parser.add_argument(
28+
"--output_path",
29+
type=str,
30+
default="data/dpo_pairs_generated.json",
31+
help="Path to output JSON file"
32+
)
33+
parser.add_argument(
34+
"--max_tokens",
35+
type=int,
36+
default=256,
37+
help="Maximum new tokens to generate"
38+
)
39+
parser.add_argument(
40+
"--temperature",
41+
type=float,
42+
default=0.7,
43+
help="Sampling temperature"
44+
)
45+
parser.add_argument(
46+
"--top_p",
47+
type=float,
48+
default=0.9,
49+
help="Top-p sampling parameter"
50+
)
51+
parser.add_argument(
52+
"--num_samples",
53+
type=int,
54+
default=None,
55+
help="Number of samples to process (None for all)"
56+
)
57+
parser.add_argument(
58+
"--tensor_parallel_size",
59+
type=int,
60+
default=1,
61+
help="Number of GPUs for tensor parallelism"
62+
)
63+
parser.add_argument(
64+
"--gpu_memory_utilization",
65+
type=float,
66+
default=0.9,
67+
help="GPU memory utilization (0.0 to 1.0)"
68+
)
69+
parser.add_argument(
70+
"--batch_size",
71+
type=int,
72+
default=1000,
73+
help="Number of samples per batch (for progress tracking and memory management)"
74+
)
75+
parser.add_argument(
76+
"--test",
77+
action="store_true",
78+
help="Test mode: only process one batch"
79+
)
80+
return parser.parse_args()
81+
82+
83+
def process_batch(llm, batch_data, batch_start_idx, sampling_params):
84+
"""Process a batch of inputs and return results."""
85+
# Prepare prompts for this batch (2 per input for output1 and output2)
86+
all_prompts = []
87+
prompt_to_idx = []
88+
89+
for local_idx, example in enumerate(batch_data):
90+
input_text = example['input']
91+
messages = [{"role": "user", "content": input_text}]
92+
# Add two prompts for each input
93+
all_prompts.append(messages)
94+
prompt_to_idx.append((local_idx, 1))
95+
all_prompts.append(messages)
96+
prompt_to_idx.append((local_idx, 2))
97+
98+
# Generate all outputs in batch
99+
outputs = llm.chat(
100+
messages=all_prompts,
101+
sampling_params=sampling_params,
102+
)
103+
104+
# Organize results
105+
results = [{
106+
"input": example['input'],
107+
"output1": None,
108+
"output2": None,
109+
"original_output": example.get('output', None),
110+
} for example in batch_data]
111+
112+
for output, (local_idx, output_num) in zip(outputs, prompt_to_idx):
113+
generated_text = output.outputs[0].text.strip()
114+
if output_num == 1:
115+
results[local_idx]["output1"] = generated_text
116+
else:
117+
results[local_idx]["output2"] = generated_text
118+
119+
return results
120+
121+
122+
def main():
123+
args = parse_args()
124+
125+
# Load input data
126+
print(f"Loading input data from: {args.input_path}")
127+
with open(args.input_path, 'r') as f:
128+
input_data = json.load(f)
129+
130+
if args.num_samples is not None:
131+
input_data = input_data[:args.num_samples]
132+
133+
total_samples = len(input_data)
134+
print(f"Processing {total_samples} samples in batches of {args.batch_size}...")
135+
136+
# Initialize vLLM
137+
print(f"Loading model: {args.model}")
138+
llm = LLM(
139+
model=args.model,
140+
tensor_parallel_size=args.tensor_parallel_size,
141+
gpu_memory_utilization=args.gpu_memory_utilization,
142+
trust_remote_code=True,
143+
)
144+
145+
# Sampling parameters
146+
sampling_params = SamplingParams(
147+
max_tokens=args.max_tokens,
148+
temperature=args.temperature,
149+
top_p=args.top_p,
150+
)
151+
152+
# Process in batches for better progress tracking and memory management
153+
all_results = []
154+
num_batches = (total_samples + args.batch_size - 1) // args.batch_size
155+
156+
if args.test:
157+
num_batches = 1
158+
print("[TEST MODE] Only processing 1 batch")
159+
160+
for batch_idx in tqdm(range(num_batches), desc="Processing batches"):
161+
start_idx = batch_idx * args.batch_size
162+
end_idx = min(start_idx + args.batch_size, total_samples)
163+
batch_data = input_data[start_idx:end_idx]
164+
165+
batch_results = process_batch(llm, batch_data, start_idx, sampling_params)
166+
all_results.extend(batch_results)
167+
168+
# Save intermediate results after each batch
169+
print(f"\nSaving intermediate results ({end_idx}/{total_samples} samples)...")
170+
with open(args.output_path, 'w') as f:
171+
json.dump(all_results, f, indent=2, ensure_ascii=False)
172+
173+
print(f"\nDone! Generated {len(all_results)} pairs.")
174+
print(f"Results saved to: {args.output_path}")
175+
176+
177+
if __name__ == "__main__":
178+
main()

dpo/requirements.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
vllm>=0.6.0
2+
transformers>=4.45.0
3+
peft>=0.13.0
4+
torch>=2.1.0
5+
tqdm>=4.66.0
6+
accelerate>=1.0.0
7+
trl>=0.12.0
8+
wandb>=0.18.0
9+

0 commit comments

Comments
 (0)