Skip to content

Commit b004cc5

Browse files
[feat] Add json parser and enable parser config (#19)
* Add json parser and enable parser config Signed-off-by: Rashid Kaleem <230885705+arekay-nv@users.noreply.github.com> Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>
1 parent c821540 commit b004cc5

9 files changed

Lines changed: 100 additions & 55 deletions

File tree

src/inference_endpoint/commands/benchmark.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,9 @@ async def run_benchmark_command(args: argparse.Namespace) -> None:
176176
# ===== YAML MODE - Load from config file =====
177177
config_path = args.config # Required by argparse
178178
try:
179-
effective_config = ConfigLoader.load_yaml(Path(config_path))
179+
effective_config: BenchmarkConfig = ConfigLoader.load_yaml(
180+
Path(config_path)
181+
)
180182

181183
# Only auxiliary params allowed (output)
182184
mode_str = getattr(args, "mode", None)
@@ -203,7 +205,9 @@ async def run_benchmark_command(args: argparse.Namespace) -> None:
203205
elif benchmark_mode_str in ("offline", "online"):
204206
# ===== CLI MODE - Build config from CLI params =====
205207
benchmark_mode = TestType(benchmark_mode_str) # TestType values are lowercase
206-
effective_config = _build_config_from_cli(args, benchmark_mode_str)
208+
effective_config: BenchmarkConfig = _build_config_from_cli(
209+
args, benchmark_mode_str
210+
)
207211
test_mode = (
208212
TestMode(args.mode) if getattr(args, "mode", None) else TestMode.PERF
209213
)
@@ -264,7 +268,7 @@ def _build_config_from_cli(
264268
name=args.dataset.stem,
265269
type=DatasetType.PERFORMANCE,
266270
path=str(args.dataset),
267-
format="pkl", # Will be inferred by DataLoaderFactory
271+
format=None, # Will be inferred by DataLoaderFactory
268272
)
269273
],
270274
settings=Settings(
@@ -289,6 +293,7 @@ def _build_config_from_cli(
289293
),
290294
),
291295
model_params=ModelParams(
296+
name=args.model,
292297
temperature=0.7,
293298
max_new_tokens=args.max_output_tokens if args.max_output_tokens else 1024,
294299
osl_distribution=OSLDistribution(
@@ -327,8 +332,7 @@ def _get_dataset_path(args: argparse.Namespace, config: BenchmarkConfig) -> Path
327332
2. Validate all dataset paths exist
328333
3. Support dataset interleaving strategies
329334
"""
330-
# Priority: CLI args > config
331-
if args.dataset:
335+
if hasattr(args, "dataset") and args.dataset:
332336
dataset_path = Path(args.dataset)
333337
else:
334338
# TODO: Multi-dataset - currently just picks single dataset
@@ -431,6 +435,8 @@ def _run_benchmark(
431435
model_name = getattr(args, "model", None)
432436
if not model_name and config.submission_ref:
433437
model_name = config.submission_ref.model
438+
if not model_name and config.model_params.name:
439+
model_name = config.model_params.name
434440

435441
if model_name:
436442
try:
@@ -476,17 +482,17 @@ def _run_benchmark(
476482
logger.info("Streaming: disabled (auto, offline mode)")
477483

478484
try:
479-
# Create loader using factory
480-
def parser(x):
481-
return {
482-
"prompt": x.text_input,
483-
"output": x.ref_output,
484-
"model": model_name,
485-
"stream": enable_streaming, # Enable streaming only for online mode
486-
}
485+
if any(d.parser for d in config.datasets):
486+
key_maps = [d.parser for d in config.datasets]
487+
else:
488+
key_maps = None
489+
logger.info(f"Parser key maps: {key_maps}")
487490

488491
dataloader = DataLoaderFactory.create_loader(
489-
dataset_path, format=dataset_format, parser=parser
492+
dataset_path,
493+
format=dataset_format,
494+
key_maps=key_maps,
495+
metadata={"model": model_name, "stream": enable_streaming},
490496
)
491497
dataloader.load()
492498
logger.info(f"Loaded {dataloader.num_samples()} samples")

src/inference_endpoint/config/schema.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ class OSLDistribution(BaseModel):
135135
class ModelParams(BaseModel):
136136
"""Model generation parameters."""
137137

138+
name: str | None = None
138139
temperature: float = 0.7
139140
top_k: int | None = None
140141
top_p: float | None = None
@@ -179,9 +180,10 @@ class Dataset(BaseModel):
179180
name: str
180181
type: DatasetType
181182
path: str
182-
format: str = "pkl"
183+
format: str | None = None
183184
samples: int | None = None
184185
eval_method: EvalMethod | None = None
186+
parser: dict | None = None
185187

186188

187189
class RuntimeConfig(BaseModel):

src/inference_endpoint/config/templates/offline_template.yaml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,24 @@ version: "1.0"
44
type: "offline"
55

66
model_params:
7+
name: "meta-llama/Llama-3.1-8B-Instruct"
78
temperature: 0.7
89
top_p: 0.9
910
max_new_tokens: 1024
1011

1112
datasets:
1213
- name: "perf-test"
1314
type: "performance"
14-
path: "datasets/openorca.pkl"
15+
path: "tests/datasets/dummy_1k.pkl"
1516
format: "pkl"
1617
samples: 1000
18+
parser:
19+
prompt: "text_input"
1720

1821
settings:
1922
runtime:
20-
min_duration_ms: 600000 # 10 minutes
21-
max_duration_ms: 1800000 # 30 minutes
23+
min_duration_ms: 60000 # 1 minute
24+
max_duration_ms: 180000 # 3 minutes
2225
scheduler_random_seed: 42 # For Poisson/distribution sampling
2326
dataloader_random_seed: 42 # For dataset shuffling
2427

src/inference_endpoint/config/templates/online_template.yaml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,23 @@ version: "1.0"
44
type: "online"
55

66
model_params:
7+
name: "meta-llama/Llama-3.1-8B-Instruct"
78
temperature: 0.7
89
top_p: 0.9
910
max_new_tokens: 1024
1011

1112
datasets:
1213
- name: "latency-test"
1314
type: "performance"
14-
path: "datasets/queries.pkl"
15-
format: "pkl"
15+
path: "cnn_dailymail_train.json"
1616
samples: 500
17+
parser:
18+
prompt: "article"
1719

1820
settings:
1921
runtime:
20-
min_duration_ms: 600000 # 10 minutes
21-
max_duration_ms: 1800000 # 30 minutes
22+
min_duration_ms: 60000 # 1 minute
23+
max_duration_ms: 180000 # 3 minutes
2224
scheduler_random_seed: 42 # For Poisson/distribution sampling
2325
dataloader_random_seed: 42 # For dataset shuffling
2426

src/inference_endpoint/dataset_manager/dataloader.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import json
1617
import pickle
1718
from abc import ABC, abstractmethod
1819
from collections.abc import Callable
@@ -158,7 +159,7 @@ def __init__(
158159
self.parser = parser
159160
self.logger = getLogger(__name__)
160161
if parser is None:
161-
162+
# TODO : remove this default implementation
162163
def extract_text_input(row):
163164
return row.text_input
164165

@@ -228,7 +229,9 @@ def load_sample(self, index: int) -> Any:
228229
Loads a sample from the data.
229230
"""
230231
assert self.loaded, "Data is not loaded. Call load() to load the data."
231-
return self.parser(self.data.iloc[index])
232+
x = self.parser(self.data.iloc[index])
233+
self.logger.debug(f"Loaded sample from pickle file at {index} with keys: {x}")
234+
return x
232235

233236
def get_column_names(self):
234237
return self.data.columns
@@ -289,3 +292,34 @@ def __init__(self, file_path, parser: Callable[[Any], Any] = None):
289292
parser (Callable[[Any], Any], optional): Callable to parse individual data samples. If not provided, defaults to the parent class's parsing mechanism.
290293
"""
291294
super().__init__(file_path, parser=parser)
295+
296+
297+
class JsonlReader(DataLoader):
298+
def __init__(
299+
self,
300+
file_path,
301+
parser: Callable[[Any], Any] = None,
302+
metadata: dict | None = None,
303+
):
304+
if parser is None:
305+
# TODO: Implement a parser interface where yaml files specify the fields to parse
306+
def default_parser(x):
307+
# Use cnn/daily mail dataset as an example for now.
308+
return {"prompt": x["article"]} | metadata
309+
310+
parser = default_parser
311+
super().__init__()
312+
self.file_path = file_path
313+
self.data = []
314+
self.parser = parser
315+
316+
def load(self):
317+
with open(self.file_path) as file:
318+
for line in file:
319+
self.data.append(self.parser(json.loads(line)))
320+
321+
def load_sample(self, index: int) -> Any:
322+
return self.data[index]
323+
324+
def num_samples(self):
325+
return len(self.data)

src/inference_endpoint/dataset_manager/factory.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@
1919
"""
2020

2121
import logging
22-
from collections.abc import Callable
2322
from pathlib import Path
2423

2524
from .dataloader import (
2625
DataLoader,
2726
HFDataLoader,
27+
JsonlReader,
2828
PickleReader,
2929
)
3030

@@ -44,15 +44,16 @@ class DataLoaderFactory:
4444
def create_loader(
4545
dataset_path: Path | str,
4646
format: str = "pkl",
47-
parser: Callable | None = None,
47+
key_maps: list[dict[str, str]] | None = None,
48+
metadata: dict | None = None,
4849
**kwargs,
4950
) -> DataLoader:
5051
"""Create appropriate dataset loader based on format.
5152
5253
Args:
5354
dataset_path: Path to dataset file or directory
5455
format: Dataset format ("pkl", "jsonl", "hf")
55-
parser: Optional parser function for data transformation
56+
key_maps: List of key-mapping dicts (output key -> dataset field) for the parser; only the first entry is currently used
5657
**kwargs: Additional arguments for specific loaders
5758
5859
Returns:
@@ -61,27 +62,22 @@ def create_loader(
6162
Raises:
6263
ValueError: If format is unsupported
6364
"""
64-
format = format.lower()
65-
66-
if format == "pkl" or format == "pickle":
67-
# Pickle format - use DeepSeekR1ChatCompletionDataLoader
68-
if parser is None:
69-
# Default parser for chat completion format
70-
def default_parser(x):
71-
return {"prompt": x.text_input, "output": x.ref_output}
65+
if key_maps is None:
66+
# Default mapping: read the `prompt` value from the dataset's `text_input` field
67+
key_maps = [{"prompt": "text_input"}]
7268

73-
parser = default_parser
69+
def parser(x):
70+
# TODO : handle the entire key_maps list
71+
return {k: x[v] for k, v in key_maps[0].items()} | (metadata or {})
7472

73+
format = format.lower()
74+
if format == "pkl" or format == "pickle":
7575
logger.info(f"Creating pickle dataset loader for {dataset_path}")
7676
return PickleReader(dataset_path, parser=parser)
7777

7878
elif format == "jsonl" or format == "json":
7979
# JSON Lines format
80-
# TODO: Implement JSONLDataLoader
81-
logger.error("JSONL format not yet implemented")
82-
raise NotImplementedError(
83-
"JSONL dataset format not yet supported. " "Supported formats: pkl, hf"
84-
)
80+
return JsonlReader(dataset_path, parser=parser, metadata=metadata)
8581

8682
elif format == "hf" or format == "huggingface":
8783
# HuggingFace dataset

src/inference_endpoint/openai/openai_adapter.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
ModelIdsShared,
3030
Object7,
3131
ReasoningEffort,
32-
Role,
3332
Role5,
3433
Role6,
3534
ServiceTier,
@@ -71,14 +70,14 @@ def to_openai_request(query: Query) -> CreateChatCompletionRequest:
7170

7271
request = CreateChatCompletionRequest(
7372
model=ModelIdsShared(query.data.get("model", "no-model-name")),
74-
# service_tier=ServiceTier.auto,
7573
reasoning_effort=ReasoningEffort.medium,
7674
messages=[
77-
{
78-
"role": Role.assistant.value,
79-
"content": "You are a helpful assistant.",
80-
},
8175
{"role": Role5.user.value, "content": query.data["prompt"]},
76+
# TODO remove this once we have a way to handle the assistant message
77+
# {
78+
# "role": Role.assistant.value,
79+
# "content": "You are a helpful assistant.",
80+
# },
8281
],
8382
stream=query.data.get("stream", False),
8483
max_completion_tokens=query.data.get("max_completion_tokens", 100),

tests/unit/openai/test_openai_types.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@ def test_create_chat_completion_request_from_query(self):
6161
)
6262

6363
messages = query.model_dump(mode="json")["messages"]
64-
assert len(messages) == 2, f"Expected 2 messages, got {len(messages)}"
64+
assert len(messages) == 1, f"Expected 1 messages, got {len(messages)}"
65+
# TODO : cleanup this once we have a way to handle the assistant message
6566
for message in messages:
6667
assert message["role"] in [
6768
"assistant",
@@ -73,10 +74,10 @@ def test_create_chat_completion_request_from_query(self):
7374
assert (
7475
message["name"] is None
7576
), f"Expected name to be None, got {message['name']}"
76-
if message["role"] == "assistant":
77-
assert message["content"] == "You are a helpful assistant."
78-
else:
77+
if message["role"] == "user":
7978
assert message["content"] == "Test prompt"
79+
# TODO : cleanup this once we have a way to handle the assistant message
80+
# assert message["content"] == "You are a helpful assistant."
8081

8182
def test_create_chat_completion_response_from_query_result(self):
8283
message_content = "You are a helpful assistant."

tests/unit/test_core_types.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,9 @@ def test_query_creation(self) -> None:
4141
query = OpenAIAdapter.to_openai_request(
4242
Query(id="test-123", data=payload)
4343
).model_dump(mode="json")
44-
assert query["messages"][0]["content"] == "You are a helpful assistant."
45-
assert query["messages"][1]["content"] == "Test prompt"
44+
assert query["messages"][0]["content"] == "Test prompt"
45+
# TODO : remove this once we have a way to handle the assistant message
46+
# assert query["messages"][1]["content"] == "You are a helpful assistant."
4647
assert query["model"] == "test-model"
4748
assert query["max_completion_tokens"] == 100
4849
assert query["temperature"] == 0.7 # default value
@@ -60,8 +61,9 @@ def test_query_store_load(self) -> None:
6061
query_loaded = OpenAIAdapter.to_openai_request(
6162
Query(id="test-123", data=payload)
6263
)
63-
assert query_loaded.messages[0].root.content == "You are a helpful assistant."
64-
assert query_loaded.messages[1].root.content == payload["prompt"]
64+
assert query_loaded.messages[0].root.content == payload["prompt"]
65+
# TODO : remove this once we have a way to handle the assistant message
66+
# assert query_loaded.messages[1].root.content == "You are a helpful assistant."
6567
assert query_loaded.model.root == payload["model"]
6668
assert query_loaded.max_completion_tokens == payload["max_completion_tokens"]
6769
assert query_loaded.temperature == payload["temperature"]

0 commit comments

Comments
 (0)