-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_hyperparameter_sweep.py
More file actions
168 lines (150 loc) · 7.19 KB
/
Copy pathrun_hyperparameter_sweep.py
File metadata and controls
168 lines (150 loc) · 7.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
#!/usr/bin/env python3
import argparse
import itertools
from run_kmer_specmer import run_kmer_specme_experiment
def run_hyperparameter_sweep(
    protein_name,
    kmer_dist_path,
    exp_path,
    model_type,
    device,
    num_candidates_list,
    k_values_list,
    n_draft_tokens_list,
    temperature_list,
    num_experiments=1,
    dist_type="normalized",
    verbose=False,
    seed=None,
    batch_size=4,
    draft_model_size="small",
    target_model_size="medium"
):
    """
    Run one experiment for every combination of the swept hyperparameters.

    The sweep covers the Cartesian product of the four ``*_list`` arguments,
    visited with ``num_candidates_list`` varying slowest and
    ``temperature_list`` varying fastest. For each combination a banner with
    its 1-based index and values is printed, then the combination is passed
    to ``run_kmer_specme_experiment`` together with the remaining arguments,
    which are forwarded unchanged to every run.

    Args:
        protein_name: Name of the protein to generate
        kmer_dist_path: Directory containing k-mer distribution files
        exp_path: Location to store results
        model_type: "progen2" or "tranception"
        device: Device to run on
        num_candidates_list: Candidate counts to sweep over
        k_values_list: k-value configurations to sweep over; each element is
            a list of k values, or a falsy value (shown as "default")
        n_draft_tokens_list: n_draft_tokens values to sweep over
        temperature_list: Temperature values to sweep over
        num_experiments: Number of sequences to generate per configuration
        dist_type: Distribution type to use
        verbose: Whether to print detailed output
        seed: Random seed (the same value is forwarded to every run)
        batch_size: Number of candidates to process in parallel
        draft_model_size: Size of draft ProGen2 model (small, medium, large, etc.)
        target_model_size: Size of target ProGen2 model (small, medium, large, etc.)
    """
    # The swept axes, ordered from slowest- to fastest-varying.
    grid_axes = (
        num_candidates_list,
        k_values_list,
        n_draft_tokens_list,
        temperature_list,
    )

    # Size of the full grid: the product of the axis lengths.
    total_configs = 1
    for axis in grid_axes:
        total_configs *= len(axis)
    print(f"Starting hyperparameter sweep with {total_configs} configurations")

    # Settings that are identical for every configuration in the sweep.
    fixed_settings = dict(
        protein_name=protein_name,
        num_experiments=num_experiments,
        exp_path=exp_path,
        kmer_dist_path=kmer_dist_path,
        dist_type=dist_type,
        device=device,
        verbose=verbose,
        seed=seed,
        model_type=model_type,
        batch_size=batch_size,
        draft_model_size=draft_model_size,
        target_model_size=target_model_size,
    )

    # itertools.product advances its last argument fastest, which matches a
    # nest of for-loops written in the same axis order.
    for current_config, combo in enumerate(itertools.product(*grid_axes), start=1):
        num_candidates, k_values, n_draft_tokens, temperature = combo

        # Human-readable rendering of the k values for the banner.
        if k_values:
            k_values_str = ', '.join(str(k) for k in k_values)
        else:
            k_values_str = "default"

        banner = (
            f"\n======= Configuration {current_config}/{total_configs} =======",
            f"num_candidates: {num_candidates}",
            f"k_values: [{k_values_str}]",
            f"n_draft_tokens: {n_draft_tokens}",
            f"temperature: {temperature}",
        )
        for line in banner:
            print(line)

        # Launch the experiment for this grid point.
        run_kmer_specme_experiment(
            n_draft_tokens=n_draft_tokens,
            temperature=temperature,
            num_candidates=num_candidates,
            k_values=k_values,
            **fixed_settings,
        )
def main():
    """Parse the command-line options and launch the hyperparameter sweep.

    Each ``--k_values_configs`` entry is a comma-separated string of integers
    and is converted to a list of ints; an empty string becomes ``None``.
    The parsed options are then forwarded to ``run_hyperparameter_sweep``.
    """
    parser = argparse.ArgumentParser(description="Run hyperparameter sweep for k-mer guided speculative decoding")
    add = parser.add_argument

    # Inputs and output location (--protein and --kmer_dist_path are mandatory)
    add("--protein", type=str, required=True,
        help="Protein name (GFP, RBP1, etc.)")
    add("--kmer_dist_path", type=str, required=True,
        help="Directory containing k-mer distribution files")
    add("--exp_path", type=str, default="./results/sweep",
        help="Directory to save results")

    # Model selection
    add("--model_type", type=str, default="progen2", choices=["progen2", "tranception"],
        help="Which model to use")
    add("--device", type=str, default="cuda:0",
        help="Device to run on")

    # Axes of the sweep
    add("--num_candidates_list", type=int, nargs="+", default=[3, 5, 10],
        help="List of number of candidates to try")
    add("--k_values_configs", type=str, nargs="+", default=["1,3", "1,3,5"],
        help="Comma-separated k values for each configuration (e.g., '1,3' '1,3,5')")
    add("--n_draft_tokens_list", type=int, nargs="+", default=[3, 5, 10],
        help="List of n_draft_tokens values to try")
    add("--temperature_list", type=float, nargs="+", default=[0.7, 1.0, 1.4],
        help="List of temperature values to try")

    # Settings shared by every configuration
    add("--num_experiments", type=int, default=1,
        help="Number of sequences to generate per configuration")
    add("--dist_type", type=str, default="normalized",
        help="Distribution type to use")
    add("--batch_size", type=int, default=5,
        help="Number of candidates to process in parallel")
    add("--seed", type=int, default=None,
        help="Random seed")
    add("--verbose", action="store_true",
        help="Print detailed output")

    # Both model-size options accept the same set of values.
    model_sizes = ["small", "medium", "large", "base", "oas", "xlarge"]
    add("--draft_model_size", type=str, default="small",
        choices=model_sizes,
        help="Size of the draft ProGen2 model")
    add("--target_model_size", type=str, default="medium",
        choices=model_sizes,
        help="Size of the target ProGen2 model")

    args = parser.parse_args()

    # "1,3,5" -> [1, 3, 5]; an empty string stands for "no explicit k values".
    k_values_list = [
        [int(piece) for piece in spec.split(',')] if spec else None
        for spec in args.k_values_configs
    ]

    # Hand everything over to the sweep driver.
    run_hyperparameter_sweep(
        protein_name=args.protein,
        kmer_dist_path=args.kmer_dist_path,
        exp_path=args.exp_path,
        model_type=args.model_type,
        device=args.device,
        num_candidates_list=args.num_candidates_list,
        k_values_list=k_values_list,
        n_draft_tokens_list=args.n_draft_tokens_list,
        temperature_list=args.temperature_list,
        num_experiments=args.num_experiments,
        dist_type=args.dist_type,
        verbose=args.verbose,
        seed=args.seed,
        batch_size=args.batch_size,
        draft_model_size=args.draft_model_size,
        target_model_size=args.target_model_size,
    )
# Run the command-line entry point only when this file is executed as a
# script; importing the module does not start a sweep.
if __name__ == "__main__":
    main()