
Commit 625ddc1

refactor: llm api (#60)
New LLM generation workflow.

- add an empty .env
- refactor OpenAI util class
- use new OpenAI client in main
- assume .env unchanged
- fix: response processing
- use new Gemini client in main
- enable reasoning effort from cli
- document why two gemini wrapper
- add Claude API
- add claude models to supported list
- handle UnionType for Literal ReasoningEffort (see the sketch after this list)
- add vLLM support and use it as default option
- fix: use vLLM chat interface instead of gen
- env add vllm api key
- add VLLM_HOST and VLLM_PORT
- add vllm server mode
- add vLLM in dependencies
- doc: instruct to run vllm from uv
- make deprecated ollama a standalone script
- doc: revise ollama
- use 3.12
- add Ollama models
- fix: ollama model name
- fix: ollama model name
- fix: Gemini use its own EFFORT_TOKEN_MAP
- remove unused imports
- fix: google-genai version
- fix: ci with uv run
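One item above, "handle UnionType for Literal ReasoningEffort", concerns accepting a reasoning-effort value from the CLI when the annotation may be a bare `Literal` or a union such as `ReasoningEffort | None`. The diff does not show that code, so the following is only a minimal sketch of the idea; `ReasoningEffort` and `parse_effort` are hypothetical names, not the repo's actual API.

```python
import types
from typing import Literal, Union, get_args, get_origin

ReasoningEffort = Literal["low", "medium", "high"]  # hypothetical alias


def parse_effort(value: str, annotation: object = ReasoningEffort | None) -> str:
    """Validate a CLI string against a (possibly union-wrapped) Literal."""
    # PEP 604 `X | Y` on classes yields types.UnionType, while unions built
    # from typing special forms (like Literal) yield typing.Union; check both.
    if get_origin(annotation) in (types.UnionType, Union):
        arms = get_args(annotation)
    else:
        arms = (annotation,)
    allowed: set[str] = set()
    for arm in arms:
        if get_origin(arm) is Literal:
            allowed.update(get_args(arm))
    if value not in allowed:
        raise ValueError(f"reasoning effort must be one of {sorted(allowed)}, got {value!r}")
    return value
```

Under these assumptions, `parse_effort("high")` returns `"high"` and an unknown value raises `ValueError`.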
1 parent 2564136 commit 625ddc1

25 files changed: 3498 additions & 752 deletions

.env

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+OPENAI_API_KEY=
+ANTHROPIC_API_KEY=
+GEMINI_API_KEY=
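These keys are committed empty on purpose (see the README note below about `git update-index --assume-unchanged`). A minimal sketch of how they would typically be loaded, assuming the standard python-dotenv pattern (the commit pins `python-dotenv==1.0.1`; the repo's actual loading site is not shown in this diff):

```python
# Hypothetical loading sketch; the repo's actual code may differ.
import os

from dotenv import load_dotenv

load_dotenv()  # reads .env from the working directory into os.environ
openai_key = os.getenv("OPENAI_API_KEY", "")
anthropic_key = os.getenv("ANTHROPIC_API_KEY", "")
gemini_key = os.getenv("GEMINI_API_KEY", "")
```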

.github/workflows/mypy.yml

Lines changed: 3 additions & 1 deletion
@@ -17,6 +17,8 @@ jobs:
       - name: Set up Python
         run: uv sync
 
+      - name: install mypy
+        run: uv pip install mypy
+
       - name: Type Check Source Code
         run: uv run mypy src
-

.github/workflows/pylint.yml

Lines changed: 4 additions & 1 deletion
@@ -16,6 +16,9 @@ jobs:
 
       - name: Set up Python
         run: uv sync
+
+      - name: Install Pylint
+        run: uv pip install pylint
 
       - name: Lint Source Code
-        run: uv run pylint src
+        run: uv run pylint src

.python-version

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-3.11
+3.12

README.md

Lines changed: 52 additions & 8 deletions
@@ -21,7 +21,7 @@ uv sync # create a virtual environment, and install dependencies
 This script will build the benchmark (Prelude with NL) from the raw data.
 
 ```sh
-uv run scripts/preprocess_benchmark.py
+uv run scripts/preprocess_benchmark.py -o tfb.json
 ```
 
 ### TF-Bench_pure
@@ -36,7 +36,7 @@ stack exec alpharewrite-exe 1 ../tfb.json > ../tfb.pure.json
 cd ..
 ```
 
-For details, please take a look at the README of [alpharewrite](https://github.com/SecurityLab-UCD/alpharewrite).
+For details, please check out the README of [alpharewrite](https://github.com/SecurityLab-UCD/alpharewrite).
 
 ## Download Pre-built Benchmark
 
@@ -46,6 +46,14 @@ You can also download our pre-built benchmark from [Zenodo](https://doi.org/10.5
 
 ## Benchmarking!
 
+Please have your API keys ready in `.env`.
+Note that the `.env` in the repository is tracked by git;
+we recommend telling git to ignore your local changes:
+
+```sh
+git update-index --assume-unchanged .env
+```
+
 ### GPT Models
 
 To run single model:
@@ -61,22 +69,58 @@ To run all GPT models:
 uv run run_all.py --option gpt
 ```
 
-### Open Source Models
+### Open Source Models with Ollama
 
-We use [Ollama](https://ollama.com/) to manage and run the OSS models.
+We use [Ollama](https://ollama.com/) to manage and run the OSS models reported in the Appendix.
+We have since switched to vLLM for better performance and SDK design.
+Although the Ollama option is still available,
+it is no longer maintained.
+We recommend using vLLM instead.
 
 ```sh
 curl -fsSL https://ollama.com/install.sh | sh # install ollama, you need sudo for this
 ollama serve # start your own instance instead of a system service
-uv run --project . scripts/ollama_pull.sh # install required models
 ```
 
+NOTE: Ollama 0.9.0 or later is required to enable thinking parsers.
+We used 0.11.7 for our experiments.
+
 ```sh
-uv run main.py -i Benchmark-F.json -m llama3:70b
+> ollama --version
+ollama version is 0.11.7
 ```
 
-To run all Ollama models:
+Run the benchmark:
 
 ```sh
-uv run run_all.py --option ollama
+uv run scripts/experiment_ollama.py -m llama3:8b
+```
+
+### (WIP) Running Your Model with vLLM
+
+#### OpenAI-Compatible Server
+
+First, launch the vLLM OpenAI-Compatible Server (with default values; check vLLM's docs for setting your own):
+
+```sh
+uv run vllm serve openai/gpt-oss-120b --tensor-parallel-size 2 --async-scheduling
+```
+
+Then, run the benchmark:
+
+```sh
+uv run main.py -i Benchmark-F.json -m vllm_openai_chat_completion
+```
+
+NOTE: if you set your own API key, host, or port when launching the vLLM server,
+please add them to the `.env` file as well.
+If they are left empty, the default values ("", "localhost", "8000") will be used.
+We do not recommend using the defaults on a machine connected to the public web,
+as they are not secure.
+
+```
+VLLM_API_KEY=
+VLLM_HOST=
+VLLM_PORT=
 ```
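The README above describes the fallback to ("", "localhost", "8000") when the `VLLM_*` entries are empty. Below is a minimal sketch of that resolution, relying on `vllm serve` exposing an OpenAI-compatible API; the helper is illustrative and is not the repo's `vllm_openai_chat_completion` implementation.

```python
# Illustrative only: resolves VLLM_* env vars with the defaults the README names.
import os

from openai import OpenAI

api_key = os.getenv("VLLM_API_KEY") or ""  # default ""
host = os.getenv("VLLM_HOST") or "localhost"  # default "localhost"
port = os.getenv("VLLM_PORT") or "8000"  # default "8000"

# vLLM's OpenAI-compatible server accepts a placeholder key when none is set;
# its docs conventionally use "EMPTY".
client = OpenAI(api_key=api_key or "EMPTY", base_url=f"http://{host}:{port}/v1")
response = client.chat.completions.create(
    model="openai/gpt-oss-120b",  # the model name passed to `vllm serve`
    messages=[{"role": "user", "content": "What is the type of map in Haskell?"}],
)
print(response.choices[0].message.content)
```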

pyproject.toml

Lines changed: 11 additions & 8 deletions
@@ -3,34 +3,36 @@ name = "tfbench"
 version = "0.1.0"
 description = "Add your description here"
 readme = "README.md"
-requires-python = ">=3.11"
+requires-python = ">=3.12"
 dependencies = [
-    "anthropic>=0.49.0",
+    "anthropic==0.49.0",
     "dacite>=1.8.1",
+    "deprecated>=1.2.18",
+    "dotenv>=0.9.9",
     "fire==0.5.0",
     "funcy==2.0",
     "funcy-chain==0.2.0",
-    "google-genai>=1.11.0",
-    "groq==0.8.0",
+    "google-genai==1.31.0",
     "hypothesis>=6.98.6",
     "markdown-to-json==2.1.2",
     "matplotlib>=3.8.3",
     "numpy>=1.26.4",
-    "ollama>=0.2.1",
-    "openai==1.75.0",
+    "ollama==0.5.3",
+    "openai==1.99.9",
     "pathos>=0.3.3",
-    "pylint>=3.3.6",
     "pytest>=8.0.0",
     "python-dotenv==1.0.1",
     "requests==2.32.3",
-    "returns[compatible-mypy]==0.22.0",
+    "returns[compatible-mypy]>=0.26.0",
     "seaborn==0.13.2",
     "tabulate>=0.9.0",
+    "tenacity>=9.1.2",
     "tiktoken==0.7.0",
     "tqdm>=4.66.2",
     "tree-sitter==0.22.3",
     "tree-sitter-haskell==0.21.0",
     "types-requests>=2.31.0",
+    "vllm>=0.10.1.1",
 ]
 
 [build-system]
@@ -60,6 +62,7 @@ disable = [
     "too-many-statements",
     "unspecified-encoding",
     "missing-class-docstring",
+    "too-few-public-methods", # LM only have 1 public method
 ]
 fail-under = 9
 max-line-length = 120
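`tenacity` joins the dependency list without an accompanying use site in this diff; a plausible (unconfirmed) use is retrying transient LLM API failures with exponential backoff, roughly:

```python
# Sketch only; the function name and retry policy are assumptions, not repo code.
from tenacity import retry, stop_after_attempt, wait_exponential


@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=30))
def call_model(prompt: str) -> str:
    # placeholder body; in practice this would wrap an actual client call
    raise ConnectionError(f"transient failure for prompt {prompt[:20]!r}")
```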

scripts/experiment_ollama.py

Lines changed: 176 additions & 0 deletions
@@ -0,0 +1,176 @@
+"""
+(deprecated) Experiment script for OSS models using Ollama
+This script reproduces legacy results for OSS models using Ollama in our paper's Appendix.
+New models should use our vLLM option instead.
+"""
+
+from typing import Union
+import os
+import json
+
+from ollama import Client as OllamaClient, ResponseError
+import fire
+from dacite import from_dict
+from tqdm import tqdm
+from funcy_chain import Chain
+
+from tfbench.lm import get_sys_prompt
+from tfbench.common import BenchmarkTask, get_prompt
+from tfbench.postprocessing import postprocess, RESPONSE_STRATEGIES
+from tfbench.evaluation import evaluate
+
+OLLAMA_OSS = [
+    "phi3:3.8b",
+    "phi3:14b",
+    "mistral",
+    "mixtral:8x7b",
+    "mixtral:8x22b",
+    "llama3:8b",
+    "llama3:70b",
+    "llama3.1:8b",
+    "llama3.1:70b",
+    "llama3.1:405b",
+    "llama3.2:1b",
+    "llama3.2:3b",
+    "llama3.3:70b",
+    "gemma:2b",
+    "gemma:7b",
+    "gemma2:9b",
+    "gemma2:27b",
+    "qwen2:1.5b",
+    "qwen2:7b",
+    "qwen2:72b",
+    "qwen2.5:1.5b",
+    "qwen2.5:7b",
+    "qwen2.5:72b",
+    "deepseek-v2:16b",
+    "deepseek-v2:236b",
+    "deepseek-v2.5:236b",
+]
+
+
+OLLAMA_CODE = [
+    "qwen2.5-coder:1.5b",
+    "qwen2.5-coder:7b",
+    "granite-code:3b",
+    "granite-code:8b",
+    "granite-code:20b",
+    "granite-code:34b",
+    "deepseek-coder-v2:16b",
+    "deepseek-coder-v2:236b",
+]
+
+OLLAMA_MODELS = OLLAMA_OSS + OLLAMA_CODE
+
+
+def get_ollama_model(
+    client: OllamaClient,
+    model: str = "llama3:8b",
+    pure: bool = False,
+):
+    """
+    Configure and return a function to generate type signatures using an Ollama model.
+
+    Parameters:
+        client (OllamaClient): The Ollama client instance used for sending requests to the model.
+
+        model (str): Name of the model to use for generating type signatures.
+            Must be one of the predefined models in OLLAMA_MODELS.
+            Default is "llama3:8b".
+
+        pure (bool): If True, uses the original variable naming in type inference.
+            If False, uses rewritten variable naming (e.g., `v1`, `v2`, ...). Default is False.
+
+    Returns:
+        Callable[[str], Union[str, None]]:
+            A function that takes a prompt string as input and returns the generated type
+            signature as a string, or None if the generation fails.
+    """
+
+    def generate_type_signature(prompt: str) -> Union[str, None]:
+        try:
+            response = client.chat(
+                messages=[
+                    {
+                        "role": "system",
+                        "content": get_sys_prompt(pure),
+                    },
+                    {"role": "user", "content": prompt},
+                ],
+                model=model,
+            )
+        except ResponseError as e:
+            print(e)
+            return None
+
+        message = response.message
+        if message.content:
+            return str(message.content)
+
+        return None
+
+    return generate_type_signature
+
+
+def main(
+    model: str = "llama3:8b",
+    pure: bool = False,
+    port: int = 11434,
+    output_file: str | None = None,
+    log_file: str = "evaluation_log.jsonl",
+):
+    """
+    Run an experiment using various AI models to generate and evaluate type signatures.
+
+    Parameters:
+        model (str): Name of the model to use for generating type signatures. Must be one of OLLAMA_MODELS
+
+        port (int): Port number for connecting to the Ollama server.
+            Ignored for other models. Default is 11434.
+
+        pure (bool): If True, uses the original variable naming in type inference.
+            If False, uses rewritten variable naming (e.g., `v1`, `v2`, ...). Default is False.
+    """
+    assert model in OLLAMA_MODELS, f"{model} is not supported."
+
+    # hard-coding benchmark file path for experiment
+    input_file = "tfb.pure.json" if pure else "tfb.json"
+    input_file = os.path.abspath(input_file)
+    assert os.path.exists(
+        input_file
+    ), f"{input_file} does not exist! Please download or build it first."
+
+    if output_file is None:
+        os.makedirs("result", exist_ok=True)
+        output_file = f"result/{model}.txt"
+
+    client = OllamaClient(host=f"http://localhost:{port}")
+    generate = get_ollama_model(client, model, pure)
+
+    with open(input_file, "r") as fp:
+        tasks = [from_dict(data_class=BenchmarkTask, data=d) for d in json.load(fp)]
+
+    prompts = map(get_prompt, tasks)
+    responses = map(generate, tqdm(prompts, desc=model))
+    gen_results = (
+        Chain(responses)
+        .map(lambda x: x if x is not None else "")  # convert None to empty string
+        .map(lambda s: postprocess(s, RESPONSE_STRATEGIES))
+        .map(str.strip)
+        .value
+    )
+
+    with open(output_file, "w", errors="ignore") as file:
+        file.write("\n".join(gen_results))
+
+    eval_acc = evaluate(tasks, gen_results)
+    print(eval_acc)
+
+    os.makedirs(os.path.dirname(output_file), exist_ok=True)
+    with open(log_file, "a") as fp:
+        logging_result = {"model_name": model, **eval_acc, "pure": pure}
+        fp.write(f"{json.dumps(logging_result)}\n")
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
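Because the entry point is `fire.Fire(main)`, each keyword argument of `main` doubles as a CLI flag. For example, to run against `tfb.pure.json` (the script selects it when `pure` is set):

```sh
uv run scripts/experiment_ollama.py -m llama3:8b --pure
```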

scripts/preprocess_benchmark.py

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@
 
 
 def main(input_raw_benchmark_path: str = "benchmark", output_path: str = "tfb.json"):
+    """Process pre-extracted tasks from Markdown to JSON"""
 
     # read in all files ending with .md in the input_raw_benchmark_path
     tasks: list[BenchmarkTask] = []
