File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -66,11 +66,8 @@ def eval_accuracy(request_outputs_dict, match_type):
6666 for output in request_outputs_dict :
6767 preds .append (output ["generated_text" ])
6868 targets .append (output ["original_output" ])
69- if match_type != "math" :
70- preds , targets = postprocess_text (preds , targets )
7169
7270 if match_type == "math" :
73-
7471 correct_ans = 0
7572 wrong_ans = 0
7673 for p , t in zip (preds , targets ):
@@ -87,11 +84,11 @@ def eval_accuracy(request_outputs_dict, match_type):
8784 result ["literal" ] = correct_ans / total_ans if total_ans > 0 else 0.0
8885 result ["gen_len" ] = total_ans
8986 result ["gen_num" ] = total_ans
90- if match_type == "rouge" :
87+
88+ else :
9189 metric = evaluate .load ("rouge" )
92- nltk .download ("punkt" )
93- preds = []
94- targets = []
90+ nltk .download ("punkt_tab" )
91+ preds , targets = postprocess_text (preds , targets )
9592 result = metric .compute (
9693 predictions = preds ,
9794 references = targets ,
Original file line number Diff line number Diff line change @@ -1384,8 +1384,8 @@ async def Decode( # pylint: disable=invalid-overridden-method
13841384 if ttft == 0 :
13851385 ttft = time .perf_counter () - request_start_time
13861386 if ttft > 2.0 :
1387- logging .info (
1388- datetime .now (),
1387+ logger .info ( # pylint: disable=logging-fstring-interpolation
1388+ f" { datetime .now ()} : "
13891389 f"Slow TTFT: { ttft :.2f} s,"
13901390 f" stats={ active_request .metadata .stats ()} ,"
13911391 f" prefill_qsize={ self ._driver .prefill_backlog_size ()} " ,
Original file line number Diff line number Diff line change @@ -367,12 +367,14 @@ def insert(
367367 prefix : Prefix ,
368368 decode_state : DecodeState ,
369369 slot : int ,
370+ request_id : Optional [uuid .UUID ] = None ,
370371 ) -> DecodeState :
371372
372373 decode_state = self ._downstream_engine .insert (
373374 prefix = prefix ,
374375 decode_state = decode_state ,
375376 slot = slot ,
377+ request_id = request_id ,
376378 )
377379 return decode_state
378380
@@ -438,3 +440,17 @@ def mesh(self) -> jax.sharding.Mesh:
@property
def colocated_cpus(self) -> Union[list[CpuDevices], None]:
    """Forward the ``colocated_cpus`` value of the downstream engine."""
    downstream = self._downstream_engine
    return downstream.colocated_cpus
443+
@property
def use_chunked_prefill(self) -> bool:
    """Forward the ``use_chunked_prefill`` value of the downstream engine."""
    downstream = self._downstream_engine
    return downstream.use_chunked_prefill
447+
@property
def chunk_size(self) -> int:
    """Return the prefill chunk size of the downstream engine.

    This forwards ``self._downstream_engine.prefill_chunk_size``, the same
    value exposed by the ``prefill_chunk_size`` property, so the result is
    an integer size rather than a boolean flag.
    """
    return self._downstream_engine.prefill_chunk_size
452+
@property
def prefill_chunk_size(self) -> int:
    """Maximum prefill length, as reported by the downstream engine."""
    downstream = self._downstream_engine
    return downstream.prefill_chunk_size
You can’t perform that action at this time.
0 commit comments