Skip to content

Commit 6ed8bf4

Browse files
committed
Add new accuracy test cases
1 parent 80e8cfc commit 6ed8bf4

7 files changed

Lines changed: 207 additions & 9 deletions

File tree

test/common/db_utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,6 @@ def _get_db():
6565
db_config = _get_db_config()
6666
_db_enabled = db_config.get("enabled", False)
6767

68-
backup_str = db_config.get("backup", "results/")
69-
_backup_path = Path(backup_str).resolve()
70-
_backup_path.mkdir(parents=True, exist_ok=True)
71-
logger.info(f"Backup directory set to: {_backup_path}")
72-
7368
if not _db_enabled:
7469
return None
7570

@@ -205,10 +200,16 @@ def write_to_db(table_name: str, data: Dict[str, Any]) -> bool:
205200

206201

207202
def database_connection(build_id: str) -> None:
203+
global _backup_path
208204
logger.info(f"Setting test build ID: {build_id}")
209205
_set_test_build_id(build_id)
210206

211207
db_config = _get_db_config()
208+
backup_str = db_config.get("backup", "results/")
209+
_backup_path = Path(backup_str).resolve()
210+
_backup_path.mkdir(parents=True, exist_ok=True)
211+
logger.info(f"Backup directory set to: {_backup_path}")
212+
212213
if not db_config.get("enabled", False):
213214
logger.info("Database connection skipped because enabled=false.")
214215
return

test/common/uc_eval/datasets/doc_qa/Galaxy_Railroad.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.

test/common/uc_eval/datasets/doc_qa/prompt.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.

test/common/uc_eval/task.py

Lines changed: 89 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
import json
2+
import os
13
import time
24
from abc import ABC, abstractmethod
35
from typing import Any, Dict, List, Union
46

7+
import pandas as pd
58
from common.uc_eval.utils.config_loader import ConfigLoader, TaskFactory
69
from common.uc_eval.utils.data_class import (
710
BenchmarkModeType,
@@ -241,6 +244,7 @@ def __init__(
241244
self.prompt_tokens = perf_config.prompt_tokens
242245
self.output_tokens = perf_config.output_tokens
243246
self.prefix_cache_num = perf_config.prefix_cache_num
247+
self.enable_warmup = perf_config.enable_warmup
244248
self.prompt_seed = 0 if self.enable_prefix_cache else -1
245249
self.stable_perf = self.benchmark_mode == BenchmarkModeType.STABLE_PREF
246250
self.stable_rate = stable_rate
@@ -272,7 +276,11 @@ def process(self):
272276
logger.info(
273277
f"Performance benchmark running with: enable prefix cache: ({self.enable_prefix_cache}), {syntheric_params=}"
274278
)
275-
if self.enable_prefix_cache and self.prefix_cache_num[idx] > 0:
279+
if (
280+
self.enable_prefix_cache
281+
and self.prefix_cache_num[idx] > 0
282+
and self.enable_warmup
283+
):
276284
logger.info(f"Begin build kvcache...")
277285
input_data = self.dataset.prepare_data(syntheric_params)
278286
self.client.handle_requests_with_pool(
@@ -359,10 +367,11 @@ def __init__(
359367
)
360368
self.dataset_file_path = perf_config.dataset_file_path
361369
self.max_tokens = model_config.payload.get("max_tokens")
370+
self.enable_warmup = perf_config.enable_warmup
362371

363372
def process(self):
364373
cases_list = self.dataset.prepare_data(self.dataset_file_path)
365-
if self.enable_prefix_cache:
374+
if self.enable_prefix_cache and self.enable_warmup:
366375
logger.info("Begin build kvcache...")
367376
self.client.handle_requests_with_pool(
368377
cases_list, self.parallel_num, BAD_COMPLETION_TOKENS_THR
@@ -389,10 +398,39 @@ def __init__(
389398
self.dataset_file_path = eval_config.dataset_file_path
390399
self.max_tokens = model_config.payload.get("max_tokens")
391400
self.eval_cls = eval_config.eval_class
401+
self.prompt_split_ratio = eval_config.prompt_split_ratio
402+
self.enable_warmup = eval_config.enable_warmup
403+
self.enable_clear_hbm = model_config.enable_clear_hbm
404+
self.round = getattr(eval_config, "round", 0)
405+
406+
def _split_prompt_by_tokens(
407+
self, prompt: str, tokenizer, split_ratio: float
408+
) -> str:
409+
"""Split prompt by token ratio and return the first part."""
410+
tokens = tokenizer.encode(prompt)
411+
split_idx = int(len(tokens) * split_ratio)
412+
first_tokens = tokens[:split_idx]
413+
return tokenizer.decode(first_tokens, skip_special_tokens=False)
392414

393415
def process(self):
394416
cases_list = self.dataset.prepare_data(self.dataset_file_path)
395-
if self.enable_prefix_cache:
417+
418+
if self.prompt_split_ratio is not None and 0 < self.prompt_split_ratio < 1:
419+
logger.info(
420+
f"Applying prompt split ratio: {self.prompt_split_ratio} (only sending first {self.prompt_split_ratio*100:.0f}% of prompt)"
421+
)
422+
tokenizer = self.client.tokenizer
423+
modified_cases = []
424+
for case in cases_list:
425+
case_name, context, question, answer = case
426+
full_prompt = context + question
427+
split_prompt = self._split_prompt_by_tokens(
428+
full_prompt, tokenizer, self.prompt_split_ratio
429+
)
430+
modified_cases.append([case_name, split_prompt, "", answer])
431+
cases_list = modified_cases
432+
433+
if self.enable_prefix_cache and self.enable_warmup:
396434
logger.info("Begin build kvcache...")
397435
self.client.handle_requests_with_pool(
398436
cases_list, self.parallel_num, BAD_COMPLETION_TOKENS_THR
@@ -402,8 +440,56 @@ def process(self):
402440
records: List[RequestRecord] = self.client.handle_requests_with_pool(
403441
cases_list, self.parallel_num, self.max_tokens
404442
)
443+
444+
if self.prompt_split_ratio is not None and 0 < self.prompt_split_ratio < 1:
445+
logger.info(
446+
f"Skipping accuracy evaluation when prompt_split_ratio={self.prompt_split_ratio} (service ran but no accuracy check)"
447+
)
448+
from common.uc_eval.utils.data_class import LatencyStatistics
449+
450+
empty_latency = LatencyStatistics()
451+
empty_latency.metric_dict = {}
452+
return empty_latency, len(records)
453+
405454
metric_result, match_record_list = self.benchmark.perf_show(
406455
records, self.parallel_num
407456
)
457+
458+
if self.enable_clear_hbm:
459+
self.client.clear_hbm()
460+
408461
self.save_eval_cases_excel(match_record_list, self.eval_cls)
462+
self.compare_first_round_results(match_record_list, self.round)
409463
return metric_result, len(records)
464+
465+
def compare_first_round_results(
    self, match_record_list: List[RequestRecord], round: int
):
    """Persist round-1 outputs and, on round 2, diff this run against them.

    ``round == 0`` disables the check entirely. ``round == 1`` caches each
    case's output to a JSON file in the working directory. ``round == 2``
    reloads that cache, logs per-case whether the outputs match, and then
    deletes the cache file. Cases absent from the cache are skipped.
    """
    if round == 0:
        return
    cache_file = "first_round_outputs.json"
    if round == 1:
        outputs = {rec.case_name: rec.output_data for rec in match_record_list}
        with open(cache_file, "w", encoding="utf-8") as fh:
            json.dump(outputs, fh, ensure_ascii=False, indent=2)
        logger.info(f"First round outputs saved to {cache_file}")
    elif round == 2:
        # Nothing to compare against if round 1 never ran.
        if not os.path.exists(cache_file):
            return
        with open(cache_file, "r", encoding="utf-8") as fh:
            cached = json.load(fh)
        for rec in match_record_list:
            if rec.case_name not in cached:
                continue
            previous = cached[rec.case_name]
            logger.info(f"First Round Output: {previous}")
            logger.info(f"Second Round Output: {rec.output_data}")
            if previous == rec.output_data:
                logger.info(
                    f"Case {rec.case_name}: The output results are consistent"
                )
            else:
                logger.error(
                    f"Case {rec.case_name}: The output results are inconsistent."
                )
        os.remove(cache_file)

test/common/uc_eval/utils/config_loader.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,12 @@ def create_task(
194194
client_kwargs = {}
195195
if data_type is DatasetType.MULTI_DIALOGUE:
196196
client_kwargs["enable_prefix_cache"] = perf_config.enable_prefix_cache
197+
elif data_type is DatasetType.DOC_QA and eval_config:
198+
if (
199+
hasattr(eval_config, "prompt_split_ratio")
200+
and eval_config.prompt_split_ratio is not None
201+
):
202+
client_kwargs["prompt_split_ratio"] = eval_config.prompt_split_ratio
197203
return (
198204
cls._dataset[data_type](tokenizer_path),
199205
cls._client[data_type](model_config, stream, **client_kwargs),

test/common/uc_eval/utils/data_class.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,14 @@ class ModelConfig:
3636
class EvalConfig:
3737
data_type: str = ""
3838
dataset_file_path: str = ""
39-
enable_prefix_cache: str = False
39+
enable_prefix_cache: bool = False
4040
parallel_num: int = 1
4141
benchmark_mode: str = "evaluate"
4242
metrics: Optional[List[str]] = field(default_factory=list)
4343
eval_class: Optional[str] = None
44+
prompt_split_ratio: Optional[float] = None
45+
enable_warmup: bool = True
46+
round: int = 0
4447

4548

4649
@dataclass
@@ -53,6 +56,7 @@ class PerfConfig:
5356
output_tokens: List[int] = field(default_factory=list)
5457
prefix_cache_num: List[float] = field(default_factory=list)
5558
benchmark_mode: str = ""
59+
enable_warmup: bool = True
5660

5761

5862
@dataclass

test/suites/E2E/test_accuracy.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import dataclasses
2+
import json
3+
4+
import pytest
5+
from common.capture_utils import export_vars
6+
from common.config_utils import config_utils as config_instance
7+
from common.uc_eval.task import DocQaEvalTask
8+
from common.uc_eval.utils.data_class import EvalConfig, ModelConfig
9+
10+
11+
@pytest.fixture(scope="session")
def model_config() -> ModelConfig:
    """Session-scoped ModelConfig built from the "models" test-config section.

    Keys not declared on ModelConfig and keys whose value is None are
    dropped. A ``payload`` given as a JSON string is parsed into a dict.

    Raises:
        ValueError: if ``payload`` is a string but not valid JSON.
    """
    cfg = config_instance.get_config("models") or {}
    # Set membership is O(1) per lookup, vs. O(n) on the original list.
    valid_fields = {field.name for field in dataclasses.fields(ModelConfig)}
    kwargs = {k: v for k, v in cfg.items() if k in valid_fields and v is not None}
    if "payload" in kwargs and isinstance(kwargs["payload"], str):
        try:
            kwargs["payload"] = json.loads(kwargs["payload"])
        except json.JSONDecodeError as e:
            # Chain the decode error so the original position info survives.
            raise ValueError(f"Invalid payload JSON format: {e}") from e
    return ModelConfig(**kwargs)
22+
23+
24+
_DOC_QA_BASE_CONFIG = {
25+
"data_type": "doc_qa",
26+
"dataset_file_path": "../../common/uc_eval/datasets/doc_qa/Galaxy_Railroad.json",
27+
"enable_prefix_cache": True,
28+
"parallel_num": 1,
29+
"benchmark_mode": "evaluate",
30+
"metrics": ["accuracy", "bootstrap-accuracy", "f1-score"],
31+
"eval_class": "common.uc_eval.utils.metric:Includes",
32+
}
33+
34+
# Parametrized DocQA accuracy cases. The ids now describe the actual
# configuration (the originals were mislabeled: the round-1 full-prompt case
# claimed "warmup" despite enable_warmup=False, the 0.5-split case claimed
# "full-prompt", and the round-2 full-prompt case claimed "half-prompt").
doc_qa_eval_cases = [
    pytest.param(
        EvalConfig(
            **{
                **_DOC_QA_BASE_CONFIG,
                "prompt_split_ratio": None,
                "enable_warmup": False,
                "round": 1,
            }
        ),
        id="doc-qa-full-prompt-no-warmup-round1-evaluate",
    ),
    pytest.param(
        EvalConfig(
            **{**_DOC_QA_BASE_CONFIG, "prompt_split_ratio": 0.5, "enable_warmup": False}
        ),
        id="doc-qa-half-prompt-no-warmup-evaluate",
    ),
    pytest.param(
        EvalConfig(
            **{
                **_DOC_QA_BASE_CONFIG,
                "prompt_split_ratio": None,
                "enable_warmup": False,
                "round": 2,
            }
        ),
        id="doc-qa-full-prompt-no-warmup-round2-evaluate",
    ),
]
64+
65+
# Payload overrides plus the HBM-clearing switch. The id now mirrors the
# actual values (the original said "max_tokens_2048_clear_hbm_true" while
# the values were max_tokens=1024 and enable_clear_hbm=False).
test_configs = [
    pytest.param(
        {"max_tokens": 1024, "ignore_eos": True, "temperature": 0.7},
        False,  # enable_clear_hbm
        id="max_tokens_1024_clear_hbm_false",
    ),
]
72+
73+
74+
@pytest.mark.feature("accu_test")
@pytest.mark.stage(2)
@pytest.mark.parametrize("eval_config", doc_qa_eval_cases)
@pytest.mark.parametrize("payload_updates,enable_clear_hbm", test_configs)
@export_vars
def test_doc_qa_perf(
    eval_config: EvalConfig,
    model_config: ModelConfig,
    payload_updates: dict,
    enable_clear_hbm: bool,
    request: pytest.FixtureRequest,
):
    """Run one DocQA accuracy evaluation for a (payload, eval_config) combo.

    NOTE(review): this mutates the session-scoped ``model_config`` fixture
    in place (``payload`` and ``enable_clear_hbm``), so changes leak across
    parametrized cases — confirm that is intentional.
    """
    report_dir = config_instance.get_config("reports").get("base_dir")

    payload = model_config.payload
    if isinstance(payload, str):
        payload = json.loads(payload)
        model_config.payload = payload
    payload.update(payload_updates)

    # Full-prompt runs always clear HBM; split-prompt runs follow the param.
    model_config.enable_clear_hbm = (
        True if eval_config.prompt_split_ratio is None else enable_clear_hbm
    )

    task = DocQaEvalTask(model_config, eval_config, report_dir)
    return {"_name": request.node.callspec.id, "_data": task.run()}

0 commit comments

Comments
 (0)