
Commit 351462e

Support long context dataset accuracy measurement (#230)
1 parent 9d19631 commit 351462e

5 files changed

Lines changed: 229 additions & 17 deletions


Makefile

Lines changed: 1 addition & 1 deletion
@@ -51,4 +51,4 @@ unit-tests:
 	coverage run -m unittest -v
 
 check-test-coverage:
-	coverage report -m --omit="jetstream/core/proto/*,jetstream/engine/tokenizer_pb2.py,jetstream/external_tokenizers/*,benchmarks/benchmark_serving.py,benchmarks/eval_accuracy.py,benchmarks/eval_accuracy_mmlu.py,benchmarks/math_utils.py" --fail-under=96
+	coverage report -m --omit="jetstream/core/proto/*,jetstream/engine/tokenizer_pb2.py,jetstream/external_tokenizers/*,benchmarks/benchmark_serving.py,benchmarks/eval_accuracy.py,benchmarks/eval_accuracy_mmlu.py,benchmarks/eval_accuracy_longcontext.py,benchmarks/math_utils.py" --fail-under=96

benchmarks/benchmark_serving.py

Lines changed: 23 additions & 15 deletions
@@ -71,6 +71,7 @@
 
 from benchmarks.eval_accuracy import eval_accuracy
 from benchmarks.eval_accuracy_mmlu import eval_accuracy_mmlu
+from benchmarks.eval_accuracy_longcontext import eval_accuracy_longcontext
 from benchmarks.metrics import CounterMetric, EventMetric
 import grpc
 from jetstream.core.proto import jetstream_pb2
@@ -166,6 +167,7 @@ class InputRequest:
   output: str = ""
   output_len: int = 0
   sample_idx: int = -1
+  metric: str = ""
 
 
 @dataclass
@@ -187,10 +189,12 @@ def to_dict(self):
       prompt = self.input_request.prompt
       original_output = self.input_request.output
       sample_idx = self.input_request.sample_idx
+      metric = self.input_request.metric
     else:
       prompt = None
      original_output = None
       sample_idx = None
+      metric = None
     return {
         "prompt": prompt,
         "original_output": original_output,
@@ -201,6 +205,7 @@ def to_dict(self):
         "ttst_sec": self.ttst_sec,
         "prompt_len": self.prompt_len,
         "sample_idx": sample_idx,
+        "metric": metric,
     }
 
 
@@ -282,17 +287,19 @@ def load_openorca_dataset_pkl(
 
 def load_longcontext_dataset_pkl(
     dataset_path: str,
-) -> list[tuple[Any, Any]]:
+) -> tuple[list[tuple[Any, Any]], list]:
   assert os.path.isfile(dataset_path)
 
   # read pickle file
   data = pandas.read_pickle(dataset_path)
 
   samples = []
+  metrics = []
   for _, row in data.iterrows():
-    samples.append((row["input"], row["ref_output"]))
+    samples.append((row["input"], row["gt_output"]))
+    metrics.append(row["metric"])
 
-  return samples
+  return samples, metrics
 
 
 def load_mmlu_dataset_csv(dataset_path: str) -> tuple[Any, dict[str, str]]:
@@ -421,7 +428,6 @@ def filter_dataset(
     tokenized_dataset: list[tuple[str, Any, str, int, int, int]],
     dataset_type: str,
     max_output_length: int = 0,
-    run_mmlu_dataset: bool = False,
     min_input_length: int = 4,
     max_input_length: int = 0,
     max_target_length: int = 0,
@@ -443,7 +449,8 @@ def filter_dataset(
       sample_idx,
   ) in tokenized_dataset:
     if prompt_len < min_input_length or (
-        not (run_mmlu_dataset or dataset_type == "math500") and output_len < 4
+        not (dataset_type == "mmlu" or dataset_type == "math500")
+        and output_len < 4
     ):
       # Prune too short sequences.
       # This is because TGI causes errors when the input or output length
@@ -479,11 +486,11 @@ def sample_requests(
     dataset_type: str,
     max_output_length: int = 0,
     oversample_multiplier: float = 1.2,
-    run_mmlu_dataset: bool = False,
     min_input_length: int = 4,
     max_input_length: int = 0,
     max_target_length: int = 0,
     max_output_multiplier: int = 0,
+    metrics: Optional[list[str]] = None,
 ) -> list[InputRequest]:
 
   # Original dataset size
@@ -521,13 +528,16 @@ def sample_requests(
       tokenized_dataset,
       dataset_type,
       max_output_length,
-      run_mmlu_dataset,
       min_input_length,
       max_input_length,
       max_target_length,
       max_output_multiplier,
   )
 
+  if metrics is not None:
+    for request in input_requests:
+      request.metric = metrics[request.sample_idx]
+
   # Sample the requests.
   if len(input_requests) > num_requests:
     input_requests = random.sample(input_requests, num_requests)
@@ -1068,11 +1078,6 @@ def parse_args() -> argparse.Namespace:
       choices=["HELM", "Harness", ""],
       help="mmlu method/format to generate shots",
   )
-  parser.add_argument(
-      "--run-mmlu-dataset",
-      action="store_true",
-      help="specify if it's for mmlu dataset",
-  )
   return parser.parse_args()
 
 
@@ -1094,6 +1099,7 @@ def main(args: argparse.Namespace):
   tokenizer = get_tokenizer(
       model_id, tokenizer_id, use_hf_tokenizer, hf_access_token
   )
+  metrics = None
   if tokenizer == "test" or args.dataset == "test":
     input_requests = mock_requests(
         args.total_mock_requests
@@ -1116,7 +1122,7 @@ def main(args: argparse.Namespace):
         args.dataset_path,
     )
   elif args.dataset == "longcontext":
-    dataset = load_longcontext_dataset_pkl(
+    dataset, metrics = load_longcontext_dataset_pkl(
         args.dataset_path,
     )
   else:
@@ -1134,11 +1140,11 @@ def main(args: argparse.Namespace):
         num_requests=args.num_prompts,
         dataset_type=args.dataset,
         max_output_length=args.max_output_length,
-        run_mmlu_dataset=args.run_mmlu_dataset,
         min_input_length=args.min_input_length,
         max_input_length=args.max_input_length,
         max_target_length=args.max_target_length,
         max_output_multiplier=args.max_output_multiplier,
+        metrics=metrics,
     )
 
   warmup_requests = None
@@ -1184,8 +1190,10 @@ def main(args: argparse.Namespace):
   # Process output
   output = [output.to_dict() for output in request_outputs]
   if args.run_eval:
-    if args.run_mmlu_dataset:
+    if args.dataset == "mmlu":
       eval_json = eval_accuracy_mmlu(output)
+    elif args.dataset == "longcontext":
+      eval_json = eval_accuracy_longcontext(output)
     else:
       eval_json = eval_accuracy(output, args.dataset[:4])
 

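With the loader change above, a long-context pickle is expected to carry three columns — input, gt_output, and metric — and load_longcontext_dataset_pkl now returns the per-sample metric names alongside the (prompt, reference) pairs so that sample_requests can attach them to each InputRequest. A minimal sketch of that round trip; the rows and the /tmp path are illustrative assumptions, not part of this commit:

# Sketch: build a toy dataset in the layout the loader expects, then read it back.
import pandas as pd

from benchmarks.benchmark_serving import load_longcontext_dataset_pkl

rows = [
    {"input": "long prompt ...", "gt_output": "reference summary", "metric": "rouge"},
    {"input": "haystack text ...", "gt_output": "hidden uuids", "metric": "niah_em"},
    {"input": "question ...", "gt_output": "answer a | answer b", "metric": "qa_em"},
]
pd.DataFrame(rows).to_pickle("/tmp/longcontext_toy.pkl")

samples, metrics = load_longcontext_dataset_pkl("/tmp/longcontext_toy.pkl")
# samples -> [(input, gt_output), ...]; metrics -> ["rouge", "niah_em", "qa_em"]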
benchmarks/eval_accuracy_longcontext.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
Evaluate accuracy of JetStream online serving only for long context dataset.
17+
"""
18+
19+
import argparse
20+
import nltk
21+
from tqdm import tqdm
22+
import pandas as pd
23+
import json
24+
import re
25+
from multiprocessing import Pool, cpu_count
26+
import numpy as np
27+
from rouge_score import rouge_scorer
28+
29+
30+
scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
31+
32+
33+
def rouge(label, pred):
34+
"""Returns the ROUGE-L score based on the label and the prediction"""
35+
score = scorer.score(label, pred)
36+
return {
37+
"rougeL": 100 * score["rougeL"].fmeasure,
38+
}
39+
40+
41+
def niah_em(label, pred):
42+
"""
43+
Returns the NIAH (needle in a haystack) score based on the label and
44+
the prediction.
45+
Reference: https://github.com/gkamradt/LLMTest_NeedleInAHaystack
46+
"""
47+
label_uuids = re.findall(r"[\w]{8}-[\w]{4}-[\w]{4}-[\w]{4}-[\w]{12}", label)
48+
pred_uuids = re.findall(r"[\w]{8}-[\w]{4}-[\w]{4}-[\w]{4}-[\w]{12}", pred)
49+
50+
if len(pred_uuids) == 0:
51+
return {"exact_match": 0.0}
52+
53+
# https://github.com/hsiehjackson/RULER/blob/main/scripts/eval/synthetic/constants.py#L28
54+
score = (
55+
sum(
56+
[
57+
sum([1.0 if r.lower() in pred.lower() else 0.0 for r in ref])
58+
/ len(ref)
59+
for pred, ref in zip(pred_uuids, label_uuids)
60+
]
61+
)
62+
/ len(pred_uuids)
63+
* 100
64+
)
65+
66+
return {"exact_match": round(score, 2)}
67+
68+
69+
def qa_em(label, pred):
70+
"""
71+
Returns the QA score based on the label and the prediction.
72+
Reference: https://github.com/mlcommons/inference/blob/master/\
73+
language/llama3.1-405b/evaluate-accuracy.py#L69
74+
"""
75+
answer_substring = pred
76+
77+
if "Answer: " in pred:
78+
last_answer_index = pred.rfind("Answer: ")
79+
if last_answer_index == -1:
80+
return {"exact_match": 0.0}
81+
82+
answer_substring = pred[last_answer_index + len("Answer: ") :]
83+
84+
if answer_substring in label:
85+
return {"exact_match": 100.0}
86+
87+
normalized_answer = re.sub(r"\s+", "", answer_substring).lower()
88+
label_entries = [
89+
re.sub(r"\s+", "", entry).lower() for entry in label.split("|")
90+
]
91+
92+
match_found = any(entry in normalized_answer for entry in label_entries)
93+
return {"exact_match": 100.0 if match_found else 0.0}
94+
95+
96+
def postprocess_text(preds, targets):
97+
preds = [pred.strip() for pred in preds]
98+
targets = [target.strip() for target in targets]
99+
100+
# rougeLSum expects newline after each sentence
101+
preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
102+
targets = ["\n".join(nltk.sent_tokenize(target)) for target in targets]
103+
104+
return preds, targets
105+
106+
107+
def process_item(item):
108+
pred, target, metric = item
109+
if metric == "rouge":
110+
metric_eval = rouge(target, pred)
111+
elif metric == "niah_em":
112+
metric_eval = niah_em(target, pred)
113+
elif metric == "qa_em":
114+
metric_eval = qa_em(target, pred)
115+
else:
116+
raise ValueError(f"Unknown metric: {metric}")
117+
return metric_eval
118+
119+
120+
def run_evaluation(preds, targets, target_metrics, n_process=None):
121+
n_process = cpu_count() if n_process is None else n_process
122+
with Pool(n_process) as pool:
123+
accuracies = list(
124+
tqdm(
125+
pool.imap(process_item, zip(preds, targets, target_metrics)),
126+
total=len(preds),
127+
)
128+
)
129+
df = pd.DataFrame({"accuracy": accuracies, "metric": target_metrics})
130+
return df.accuracy.apply(pd.Series).describe().loc["mean"].to_dict()
131+
132+
133+
def eval_accuracy_longcontext(request_outputs_dict):
134+
nltk.download("punkt")
135+
preds = []
136+
targets = []
137+
target_metrics = []
138+
for output in request_outputs_dict:
139+
preds.append(output["generated_text"])
140+
targets.append(output["original_output"])
141+
target_metrics.append(output["metric"])
142+
preds, targets = postprocess_text(preds, targets)
143+
result = run_evaluation(preds, targets, target_metrics)
144+
result = dict(result)
145+
prediction_lens = [len(pred) for pred in preds]
146+
result["gen_len"] = int(np.sum(prediction_lens))
147+
result["gen_num"] = len(preds)
148+
print("\nResults\n")
149+
print(result)
150+
return result
151+
152+
153+
def main(args):
154+
with open(args.output_path, "r", encoding="utf-8") as f:
155+
request_outputs_dict = json.load(f)
156+
157+
eval_accuracy_longcontext(request_outputs_dict)
158+
159+
160+
if __name__ == "__main__":
161+
parser = argparse.ArgumentParser()
162+
parser.add_argument(
163+
"--output_path",
164+
type=str,
165+
default="/tmp/request-outputs.json",
166+
help="File path which has original_output and inference generated_text.",
167+
)
168+
169+
parsed_args = parser.parse_args()
170+
171+
main(parsed_args)
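The new module can be exercised on its own, either via its CLI (--output_path, default /tmp/request-outputs.json) or in-process as sketched below; the example outputs are made up, and the only keys each entry needs are generated_text, original_output, and metric. Note that the call downloads nltk's punkt data and spawns a multiprocessing Pool, so it belongs under a __main__ guard:

# Minimal in-process sketch with illustrative request outputs (not from a real run).
from benchmarks.eval_accuracy_longcontext import eval_accuracy_longcontext

if __name__ == "__main__":
  request_outputs = [
      {"generated_text": "Paris.", "original_output": "Paris", "metric": "qa_em"},
      {"generated_text": "A short summary.", "original_output": "A short summary.", "metric": "rouge"},
  ]
  # Returns per-metric means plus gen_len / gen_num counts.
  print(eval_accuracy_longcontext(request_outputs))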

benchmarks/requirements.in

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-nltk
+nltk==3.8.1
 evaluate
 rouge-score
 transformers

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+"""Tests for long context accuracy measurement."""
+
+import unittest
+
+from benchmarks.eval_accuracy_longcontext import eval_accuracy_longcontext
+
+
+class TestEvalAccuracy(unittest.TestCase):
+  """Tests for long context accuracy measurement."""
+
+  def setUp(self):
+    self._request_outputs_dict = [
+        {"generated_text": "abc", "original_output": "abc", "metric": "rouge"},
+        {"generated_text": "abc", "original_output": "abc", "metric": "rouge"},
+        {"generated_text": "abc", "original_output": "abc", "metric": "qa_em"},
+        {"generated_text": "abc", "original_output": "abc", "metric": "qa_em"},
+        {
+            "generated_text": "abc",
+            "original_output": "abc",
+            "metric": "niah_em",
+        },
+        {
+            "generated_text": "abc",
+            "original_output": "abc",
+            "metric": "niah_em",
+        },
+    ]
+
+  def test_eval_accuracy_longcontext(self):
+    self.assertEqual(
+        eval_accuracy_longcontext(self._request_outputs_dict),
+        {"rougeL": 100.0, "exact_match": 50.0, "gen_len": 18, "gen_num": 6},
+    )
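For context on the expected values: run_evaluation expands the per-sample score dicts into a DataFrame, so each row has either a rougeL or an exact_match cell and NaN elsewhere, and describe().loc["mean"] averages each column while skipping NaN. Identical strings score 100 under ROUGE-L and qa_em, while niah_em scores 0 because "abc" contains no UUID, so exact_match averages to 50. A standalone sketch of just that aggregation step, using the same toy scores:

# Sketch of the NaN-skipping column means behind the expected test values.
import pandas as pd

accuracies = [
    {"rougeL": 100.0},
    {"rougeL": 100.0},
    {"exact_match": 100.0},  # qa_em on identical strings
    {"exact_match": 100.0},
    {"exact_match": 0.0},  # niah_em finds no UUIDs in "abc"
    {"exact_match": 0.0},
]
df = pd.DataFrame({"accuracy": accuracies})
print(df.accuracy.apply(pd.Series).describe().loc["mean"].to_dict())
# -> rougeL: 100.0, exact_match: 50.0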
