1+ import json
2+ import os
13import time
24from abc import ABC , abstractmethod
35from typing import Any , Dict , List , Union
46
7+ import pandas as pd
58from common .uc_eval .utils .config_loader import ConfigLoader , TaskFactory
69from common .uc_eval .utils .data_class import (
710 BenchmarkModeType ,
@@ -241,6 +244,7 @@ def __init__(
241244 self .prompt_tokens = perf_config .prompt_tokens
242245 self .output_tokens = perf_config .output_tokens
243246 self .prefix_cache_num = perf_config .prefix_cache_num
247+ self .enable_warmup = perf_config .enable_warmup
244248 self .prompt_seed = 0 if self .enable_prefix_cache else - 1
245249 self .stable_perf = self .benchmark_mode == BenchmarkModeType .STABLE_PREF
246250 self .stable_rate = stable_rate
@@ -272,7 +276,11 @@ def process(self):
272276 logger .info (
273277 f"Performance benchmark running with: enable prefix cache: ({ self .enable_prefix_cache } ), { syntheric_params = } "
274278 )
275- if self .enable_prefix_cache and self .prefix_cache_num [idx ] > 0 :
279+ if (
280+ self .enable_prefix_cache
281+ and self .prefix_cache_num [idx ] > 0
282+ and self .enable_warmup
283+ ):
276284 logger .info (f"Begin build kvcache..." )
277285 input_data = self .dataset .prepare_data (syntheric_params )
278286 self .client .handle_requests_with_pool (
@@ -359,10 +367,11 @@ def __init__(
359367 )
360368 self .dataset_file_path = perf_config .dataset_file_path
361369 self .max_tokens = model_config .payload .get ("max_tokens" )
370+ self .enable_warmup = perf_config .enable_warmup
362371
363372 def process (self ):
364373 cases_list = self .dataset .prepare_data (self .dataset_file_path )
365- if self .enable_prefix_cache :
374+ if self .enable_prefix_cache and self . enable_warmup :
366375 logger .info ("Begin build kvcache..." )
367376 self .client .handle_requests_with_pool (
368377 cases_list , self .parallel_num , BAD_COMPLETION_TOKENS_THR
@@ -389,10 +398,39 @@ def __init__(
389398 self .dataset_file_path = eval_config .dataset_file_path
390399 self .max_tokens = model_config .payload .get ("max_tokens" )
391400 self .eval_cls = eval_config .eval_class
401+ self .prompt_split_ratio = eval_config .prompt_split_ratio
402+ self .enable_warmup = eval_config .enable_warmup
403+ self .enable_clear_hbm = model_config .enable_clear_hbm
404+ self .round = getattr (eval_config , "round" , 0 )
405+
def _split_prompt_by_tokens(
    self, prompt: str, tokenizer, split_ratio: float
) -> str:
    """Return the leading portion of ``prompt``, measured in tokens.

    The prompt is encoded with ``tokenizer.encode``, the first
    ``int(len(tokens) * split_ratio)`` token ids are kept (the product is
    truncated toward zero), and that prefix is turned back into text with
    ``tokenizer.decode(..., skip_special_tokens=False)``.

    Args:
        prompt: Text to truncate.
        tokenizer: Object exposing ``encode(text)`` and
            ``decode(ids, skip_special_tokens=...)``.
        split_ratio: Fraction of the token sequence to keep.

    Returns:
        The decoded text of the kept token prefix.
    """
    # NOTE(review): any special tokens that ``encode`` inserts count toward
    # the ratio and are re-emitted by ``decode`` because
    # skip_special_tokens=False - confirm this is the intended behaviour.
    token_ids = tokenizer.encode(prompt)
    keep_count = int(len(token_ids) * split_ratio)
    return tokenizer.decode(token_ids[:keep_count], skip_special_tokens=False)
392414
393415 def process (self ):
394416 cases_list = self .dataset .prepare_data (self .dataset_file_path )
395- if self .enable_prefix_cache :
417+
418+ if self .prompt_split_ratio is not None and 0 < self .prompt_split_ratio < 1 :
419+ logger .info (
420+ f"Applying prompt split ratio: { self .prompt_split_ratio } (only sending first { self .prompt_split_ratio * 100 :.0f} % of prompt)"
421+ )
422+ tokenizer = self .client .tokenizer
423+ modified_cases = []
424+ for case in cases_list :
425+ case_name , context , question , answer = case
426+ full_prompt = context + question
427+ split_prompt = self ._split_prompt_by_tokens (
428+ full_prompt , tokenizer , self .prompt_split_ratio
429+ )
430+ modified_cases .append ([case_name , split_prompt , "" , answer ])
431+ cases_list = modified_cases
432+
433+ if self .enable_prefix_cache and self .enable_warmup :
396434 logger .info ("Begin build kvcache..." )
397435 self .client .handle_requests_with_pool (
398436 cases_list , self .parallel_num , BAD_COMPLETION_TOKENS_THR
@@ -402,8 +440,56 @@ def process(self):
402440 records : List [RequestRecord ] = self .client .handle_requests_with_pool (
403441 cases_list , self .parallel_num , self .max_tokens
404442 )
443+
444+ if self .prompt_split_ratio is not None and 0 < self .prompt_split_ratio < 1 :
445+ logger .info (
446+ f"Skipping accuracy evaluation when prompt_split_ratio={ self .prompt_split_ratio } (service ran but no accuracy check)"
447+ )
448+ from common .uc_eval .utils .data_class import LatencyStatistics
449+
450+ empty_latency = LatencyStatistics ()
451+ empty_latency .metric_dict = {}
452+ return empty_latency , len (records )
453+
405454 metric_result , match_record_list = self .benchmark .perf_show (
406455 records , self .parallel_num
407456 )
457+
458+ if self .enable_clear_hbm :
459+ self .client .clear_hbm ()
460+
408461 self .save_eval_cases_excel (match_record_list , self .eval_cls )
462+ self .compare_first_round_results (match_record_list , self .round )
409463 return metric_result , len (records )
464+
def compare_first_round_results(
    self, match_record_list: List[RequestRecord], round: int
):
    """Cache round-1 outputs and check round-2 outputs against them.

    Behaviour by ``round``:

    * ``1`` - write ``{case_name: output_data}`` for every record to
      ``first_round_outputs.json`` in the current working directory.
    * ``2`` - load that file, log both outputs for every record whose case
      name was cached, log an error when they differ (info when they are
      equal), then delete the file.
    * any other value (including ``0``) - do nothing.

    Args:
        match_record_list: Records exposing ``case_name`` and
            ``output_data`` attributes.
        round: Which evaluation round is being processed. (The name shadows
            the builtin but is kept so keyword callers keep working.)
    """
    if round not in (1, 2):
        return

    cache_file = "first_round_outputs.json"

    if round == 1:
        # JSON object keys are always strings, so normalise the case name
        # here and again on lookup; otherwise a non-string case name (for
        # example an int id) would never be found in round 2.
        first_round_data = {
            str(record.case_name): record.output_data
            for record in match_record_list
        }
        with open(cache_file, "w", encoding="utf-8") as f:
            json.dump(first_round_data, f, ensure_ascii=False, indent=2)
        logger.info(f"First round outputs saved to {cache_file} ")
        return

    # round == 2: compare against whatever round 1 stored.
    try:
        with open(cache_file, "r", encoding="utf-8") as f:
            first_round_data = json.load(f)
    except FileNotFoundError:
        # Still best-effort (no exception), but no longer silent.
        logger.warning(
            f"First round outputs file {cache_file} not found; "
            "skipping round comparison."
        )
        return

    for record in match_record_list:
        cache_key = str(record.case_name)
        if cache_key not in first_round_data:
            continue
        first_output = first_round_data[cache_key]
        logger.info(f"First Round Output: {first_output} ")
        logger.info(f"Second Round Output: {record.output_data} ")
        if first_output != record.output_data:
            logger.error(
                f"Case {record.case_name} : The output results are inconsistent."
            )
        else:
            logger.info(
                f"Case {record.case_name} : The output results are consistent"
            )

    # The cache only serves one round-1/round-2 pair; drop it afterwards.
    os.remove(cache_file)
0 commit comments