forked from eldarkurtic/GuardBench
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvllm_server_eval.py
More file actions
139 lines (121 loc) · 5.68 KB
/
Copy pathvllm_server_eval.py
File metadata and controls
139 lines (121 loc) · 5.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from argparse import ArgumentParser
import torch
from openai import OpenAI
from guardbench import benchmark
from transformers import AutoTokenizer
def moderate(
conversations: list[list[dict[str, str]]],
tokenizer: AutoTokenizer,
client: OpenAI,
model: str,
safe_keyword: str,
unsafe_keyword: str,
top_logprobs: int | None = None,
logit_bias_strength: float = 0.0,
) -> list[float]:
# Llama Guard does not support conversation starting with the assistant
# Therefore, we drop the first utterance if it is from the assistant
for i, x in enumerate(conversations):
if x[0]["role"] == "assistant":
conversations[i] = x[1:]
assert len(conversations) == 1, "Batch-size must be 1 for now; batching not yet implemented"
# default behavior: request all tokens
if top_logprobs is None:
top_logprobs = tokenizer.vocab_size
logit_bias = None
if logit_bias_strength and top_logprobs < tokenizer.vocab_size:
safe_ids = tokenizer.encode(safe_keyword, add_special_tokens=False)
unsafe_ids = tokenizer.encode(unsafe_keyword, add_special_tokens=False)
assert len(safe_ids) == 1, f"{safe_keyword=} must be a single token but got ids={safe_ids}"
assert len(unsafe_ids) == 1, f"{unsafe_keyword=} must be a single token but got ids={unsafe_ids}"
# OpenAI API expects token ids as *strings* in the logit_bias map.
# Adding the SAME bias to both tokens preserves their relative logit difference,
# but makes it very likely they appear in the returned top_logprobs list.
logit_bias = {str(safe_ids[0]): float(logit_bias_strength), str(unsafe_ids[0]): float(logit_bias_strength)}
resp = client.chat.completions.create(
model=model,
messages=conversations[0],
max_tokens=1,
temperature=0.0,
logprobs=True,
top_logprobs=top_logprobs,
logit_bias=logit_bias,
# NOTE: prompts are good for this model, I've manually verified with:
# extra_body={
# "echo": True,
# "return_token_ids": True,
# "prompt_logprobs": 5,
# }
)
assert len(resp.choices) == 1, "Model should produce only one generation per prompt"
resp = resp.choices[0]
assert len(resp.logprobs.content) == 1, "Model should produce only one token due to max_tokens=1 above"
generated_token = resp.logprobs.content[0] # TODO: currently we don't check if the generated token is safe/unsafe_keyword
if top_logprobs == tokenizer.vocab_size:
assert len(generated_token.top_logprobs) == tokenizer.vocab_size, "Model should produce logits for all tokens in the vocabulary"
# .logprob below corresponds to raw logits because the model should be served with
# vllm serve <mdl_name> --logprobs-mode raw_logits
safe_logit = None
unsafe_logit = None
for item in generated_token.top_logprobs:
if item.token == safe_keyword:
safe_logit = item.logprob
if item.token == unsafe_keyword:
unsafe_logit = item.logprob
if safe_logit is not None and unsafe_logit is not None:
break
assert safe_logit is not None and unsafe_logit is not None, f"Model should produce logits for {safe_keyword} and {unsafe_keyword}"
unsafe_prob = torch.softmax(torch.tensor([safe_logit, unsafe_logit]), dim=-1)[1].item()
return [unsafe_prob]
def main():
    """Parse command-line options and run the benchmark against a local server.

    The options are read from ``sys.argv``. A tokenizer is loaded for
    ``--model``, a client is pointed at ``http://localhost:<vllm_port>/v1``,
    and ``benchmark`` is called with ``moderate`` as the moderation function.

    Example arguments:
        --model meta-llama/LlamaGuard-7b --datasets beaver_tails_330k
        --output_dir outputs_eldar_sweep --top_logprobs 2
        --logit_bias_strength 100
    """
    # Options are registered in this order, which is also the --help order.
    option_specs = (
        ("--model", {"type": str, "required": True}),
        ("--datasets", {"nargs": "+", "required": True}),
        ("--output_dir", {"type": str, "required": True}),
        ("--batch_size", {"type": int, "default": 1}),
        ("--vllm_port", {"type": int, "default": 8000}),
        ("--vllm_api_key", {"type": str, "default": "EMPTY"}),
        ("--safe_keyword", {"type": str, "default": "safe"}),
        ("--unsafe_keyword", {"type": str, "default": "unsafe"}),
        (
            "--top_logprobs",
            {
                "type": int,
                "default": -1,
                "help": "If -1, requests full vocab (slow). Try 2 with --logit_bias_strength>0 for speed.",
            },
        ),
        (
            "--logit_bias_strength",
            {
                "type": float,
                "default": 0.0,
                "help": "If >0 and top_logprobs is small, biases safe/unsafe equally so they appear in top-k.",
            },
        ),
    )
    cli = ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    # A lone "all" is forwarded as the plain string, not as a one-item list.
    selected_datasets = "all" if opts.datasets == ["all"] else opts.datasets

    tokenizer = AutoTokenizer.from_pretrained(opts.model)
    client = OpenAI(
        base_url=f"http://localhost:{opts.vllm_port}/v1",
        api_key=opts.vllm_api_key,
    )

    # -1 is the command-line sentinel for "request the whole vocabulary".
    if opts.top_logprobs == -1:
        requested_top_logprobs = tokenizer.vocab_size
    else:
        requested_top_logprobs = opts.top_logprobs

    benchmark(
        moderate=moderate,
        model_name=opts.model,
        batch_size=opts.batch_size,
        datasets=selected_datasets,
        out_dir=opts.output_dir,
        # Moderate kwargs - the following arguments are given as input to `moderate`
        client=client,
        tokenizer=tokenizer,
        model=opts.model,
        safe_keyword=opts.safe_keyword,
        unsafe_keyword=opts.unsafe_keyword,
        top_logprobs=requested_top_logprobs,
        logit_bias_strength=opts.logit_bias_strength,
    )
# Run the command-line entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()