Skip to content

Commit d95a4b7

Browse files
authored
Add long context dataset benchmark support (#227)
1 parent 5240b3b commit d95a4b7

5 files changed

Lines changed: 32 additions & 16 deletions

File tree

benchmarks/benchmark_serving.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,21 @@ def load_openorca_dataset_pkl(
279279
return [(prompt, output) for prompt, output in zip(prompts, outputs)]
280280

281281

282+
def load_longcontext_dataset_pkl(
    dataset_path: str,
) -> list[tuple[Any, Any]]:
  """Load a long-context benchmark dataset from a pickled pandas DataFrame.

  Args:
    dataset_path: Path to a pickle file containing a pandas DataFrame with
      "input" and "ref_output" columns.

  Returns:
    A list of (input, ref_output) tuples, one per DataFrame row, preserving
    row order.

  Raises:
    FileNotFoundError: If dataset_path does not name an existing file.
  """
  # Raise explicitly instead of using `assert`, which is stripped under
  # `python -O` and would let a bad path fall through to read_pickle.
  if not os.path.isfile(dataset_path):
    raise FileNotFoundError(f"Dataset file not found: {dataset_path}")

  # read pickle file
  data = pandas.read_pickle(dataset_path)

  # Column-wise zip avoids the per-row overhead of DataFrame.iterrows()
  # while producing the same (input, ref_output) pairs in the same order.
  return list(zip(data["input"], data["ref_output"]))
295+
296+
282297
def load_mmlu_dataset_csv(dataset_path: str) -> tuple[Any, dict[str, str]]:
283298
assert dataset_path != ""
284299
dataset = []
@@ -837,7 +852,14 @@ def parse_args() -> argparse.Namespace:
837852
"--dataset",
838853
type=str,
839854
default="test",
840-
choices=["test", "sharegpt", "openorca", "mmlu", "math500"],
855+
choices=[
856+
"test",
857+
"sharegpt",
858+
"openorca",
859+
"mmlu",
860+
"math500",
861+
"longcontext",
862+
],
841863
help="The dataset name.",
842864
)
843865
parser.add_argument("--dataset-path", type=str, help="Path to the dataset.")
@@ -1057,6 +1079,10 @@ def main(args: argparse.Namespace):
10571079
dataset = load_math500_dataset(
10581080
args.dataset_path,
10591081
)
1082+
elif args.dataset == "longcontext":
1083+
dataset = load_longcontext_dataset_pkl(
1084+
args.dataset_path,
1085+
)
10601086
else:
10611087
raise ValueError(
10621088
f"Fatal Error: Uncognized input parameters: {args.dataset}"

jetstream/core/orchestrator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,7 @@ def _prefill_thread(self, idx: int):
628628
is_bos,
629629
prefill_engine.max_prefill_length,
630630
prefill_engine.use_chunked_prefill,
631-
prefill_engine.chunk_size,
631+
prefill_engine.prefill_chunk_size,
632632
)
633633
)
634634
prefill_result = None
@@ -649,7 +649,7 @@ def _prefill_thread(self, idx: int):
649649
t_l_array = jnp.expand_dims(
650650
jnp.arange(
651651
0,
652-
chunk_num * prefill_engine.chunk_size
652+
chunk_num * prefill_engine.prefill_chunk_size
653653
+ true_lengths_of_chunks[chunk_num],
654654
),
655655
1,

jetstream/engine/engine_api.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -445,11 +445,6 @@ def colocated_cpus(self) -> Union[list[CpuDevices], None]:
445445
def use_chunked_prefill(self) -> bool:
446446
return self._downstream_engine.use_chunked_prefill
447447

448-
@property
449-
def chunk_size(self) -> bool:
450-
"""Maximum prefill length."""
451-
return self._downstream_engine.prefill_chunk_size
452-
453448
@property
454449
def prefill_chunk_size(self) -> int:
455450
"""Maximum prefill length."""

jetstream/engine/mock_engine.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -500,11 +500,6 @@ def use_chunked_prefill(self) -> bool:
500500
"""Maximum prefill length."""
501501
return self._use_chunked_prefill
502502

503-
@property
504-
def chunk_size(self) -> bool:
505-
"""Maximum prefill length."""
506-
return 2
507-
508503
@property
509504
def prefill_chunk_size(self) -> int:
510505
"""Maximum prefill length."""

jetstream/tests/core/test_orchestrator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,9 @@ async def test_orchestrator_chunked_prefill(self, interleaved_mode: bool):
108108
max_tokens=3,
109109
)
110110
iterator = client.Decode(request)
111-
# chr of [266, 332, 415].
112-
expected_text = ["B", "R", "g", ""]
113-
expected_token_ids = [66, 82, 103, None]
111+
# chr of [135, 168, 210].
112+
expected_text = ["\x87", "¨", "Ò", ""]
113+
expected_token_ids = [135, 168, 210, None]
114114
counter = 0
115115
async for resp in iterator:
116116
output_text = resp.stream_content.samples[0].text

0 commit comments

Comments
 (0)