Skip to content

Commit f5c19bc

Browse files
authored
fixes to formatting and benchmark eval (#224)
1 parent a75f64f commit f5c19bc

3 files changed

Lines changed: 22 additions & 9 deletions

File tree

benchmarks/eval_accuracy.py

Lines changed: 4 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -66,11 +66,8 @@ def eval_accuracy(request_outputs_dict, match_type):
6666
for output in request_outputs_dict:
6767
preds.append(output["generated_text"])
6868
targets.append(output["original_output"])
69-
if match_type != "math":
70-
preds, targets = postprocess_text(preds, targets)
7169

7270
if match_type == "math":
73-
7471
correct_ans = 0
7572
wrong_ans = 0
7673
for p, t in zip(preds, targets):
@@ -87,11 +84,11 @@ def eval_accuracy(request_outputs_dict, match_type):
8784
result["literal"] = correct_ans / total_ans if total_ans > 0 else 0.0
8885
result["gen_len"] = total_ans
8986
result["gen_num"] = total_ans
90-
if match_type == "rouge":
87+
88+
else:
9189
metric = evaluate.load("rouge")
92-
nltk.download("punkt")
93-
preds = []
94-
targets = []
90+
nltk.download("punkt_tab")
91+
preds, targets = postprocess_text(preds, targets)
9592
result = metric.compute(
9693
predictions=preds,
9794
references=targets,

jetstream/core/orchestrator.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1384,8 +1384,8 @@ async def Decode( # pylint: disable=invalid-overridden-method
13841384
if ttft == 0:
13851385
ttft = time.perf_counter() - request_start_time
13861386
if ttft > 2.0:
1387-
logging.info(
1388-
datetime.now(),
1387+
logger.info( # pylint: disable=logging-fstring-interpolation
1388+
f"{datetime.now()}: "
13891389
f"Slow TTFT: {ttft:.2f}s,"
13901390
f" stats={active_request.metadata.stats()},"
13911391
f" prefill_qsize={self._driver.prefill_backlog_size()}",

jetstream/engine/engine_api.py

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -367,12 +367,14 @@ def insert(
367367
prefix: Prefix,
368368
decode_state: DecodeState,
369369
slot: int,
370+
request_id: Optional[uuid.UUID] = None,
370371
) -> DecodeState:
371372

372373
decode_state = self._downstream_engine.insert(
373374
prefix=prefix,
374375
decode_state=decode_state,
375376
slot=slot,
377+
request_id=request_id,
376378
)
377379
return decode_state
378380

@@ -438,3 +440,17 @@ def mesh(self) -> jax.sharding.Mesh:
438440
@property
439441
def colocated_cpus(self) -> Union[list[CpuDevices], None]:
440442
return self._downstream_engine.colocated_cpus
443+
444+
@property
445+
def use_chunked_prefill(self) -> bool:
446+
return self._downstream_engine.use_chunked_prefill
447+
448+
@property
449+
def chunk_size(self) -> bool:
450+
"""Maximum prefill length."""
451+
return self._downstream_engine.prefill_chunk_size
452+
453+
@property
454+
def prefill_chunk_size(self) -> int:
455+
"""Maximum prefill length."""
456+
return self._downstream_engine.prefill_chunk_size

0 commit comments

Comments
 (0)