Skip to content

Commit d3ba681

Browse files
authored
Merge pull request #1235 from samdoran/infer-max
Set maximum length for fields
2 parents 6528bc2 + d15b85d commit d3ba681

4 files changed

Lines changed: 54 additions & 16 deletions

File tree

src/models/rlsapi/requests.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class RlsapiV1Attachment(ConfigurationBase):
3131

3232
contents: str = Field(
3333
default="",
34+
max_length=65_536,
3435
description="File contents read on client",
3536
examples=["# Configuration file\nkey=value"],
3637
)
@@ -50,6 +51,7 @@ class RlsapiV1Terminal(ConfigurationBase):
5051

5152
output: str = Field(
5253
default="",
54+
max_length=65_536,
5355
description="Terminal output from client",
5456
examples=["bash: command not found", "Permission denied"],
5557
)
@@ -129,6 +131,7 @@ class RlsapiV1Context(ConfigurationBase):
129131

130132
stdin: str = Field(
131133
default="",
134+
max_length=65_536,
132135
description="Redirect input from stdin",
133136
examples=["piped input from previous command"],
134137
)
@@ -173,6 +176,7 @@ class RlsapiV1InferRequest(ConfigurationBase):
173176
question: str = Field(
174177
...,
175178
min_length=1,
179+
max_length=10_240,
176180
description="User question",
177181
examples=["How do I list files?", "How do I configure SELinux?"],
178182
)

tests/integration/endpoints/test_rlsapi_v1_integration.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import pytest
1515
from fastapi import HTTPException, status
16+
from fastapi.testclient import TestClient
1617
from llama_stack_client import APIConnectionError
1718
from pytest_mock import MockerFixture
1819

@@ -494,3 +495,22 @@ async def test_rlsapi_v1_infer_skip_rag(
494495
auth=test_auth,
495496
)
496497
assert isinstance(response, RlsapiV1InferResponse)
498+
499+
500+
@pytest.mark.parametrize(
501+
"json",
502+
(
503+
({"question": "?" * 10_241}),
504+
({"question": "Q", "context": {"stdin": "a" * 65_537}}),
505+
({"question": "Q", "context": {"attachments": {"contents": "A" * 65_537}}}),
506+
({"question": "Q", "context": {"terminal": {"output": "T" * 65_537}}}),
507+
),
508+
ids=["question", "stdin", "attachment_contents", "terminal_output"],
509+
)
510+
def test_infer_size_limit(integration_http_client: TestClient, json) -> None:
511+
"""Test that a field exceeding limit is rejected."""
512+
response = integration_http_client.post("/v1/infer", json=json)
513+
detail = response.json()["detail"]
514+
515+
assert response.status_code == status.HTTP_422_UNPROCESSABLE_CONTENT
516+
assert "string_too_long" in {item["type"] for item in detail}

tests/unit/app/endpoints/test_rlsapi_v1.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,6 @@ def test_config_error_503_matches_llm_error_503_shape(
309309
# --- Test retrieve_simple_response ---
310310

311311

312-
@pytest.mark.asyncio
313312
async def test_retrieve_simple_response_success(
314313
mock_configuration: AppConfig, mock_llm_response: None
315314
) -> None:
@@ -320,7 +319,6 @@ async def test_retrieve_simple_response_success(
320319
assert response == "This is a test LLM response."
321320

322321

323-
@pytest.mark.asyncio
324322
async def test_retrieve_simple_response_empty_output(
325323
mock_configuration: AppConfig, mock_empty_llm_response: None
326324
) -> None:
@@ -331,7 +329,6 @@ async def test_retrieve_simple_response_empty_output(
331329
assert response == ""
332330

333331

334-
@pytest.mark.asyncio
335332
async def test_retrieve_simple_response_api_connection_error(
336333
mock_configuration: AppConfig, mock_api_connection_error: None
337334
) -> None:
@@ -384,7 +381,6 @@ def test_get_rh_identity_context_with_empty_values(mocker: MockerFixture) -> Non
384381
# --- Test infer_endpoint ---
385382

386383

387-
@pytest.mark.asyncio
388384
async def test_infer_minimal_request(
389385
mocker: MockerFixture,
390386
mock_configuration: AppConfig,
@@ -409,7 +405,6 @@ async def test_infer_minimal_request(
409405
assert check_suid(response.data.request_id)
410406

411407

412-
@pytest.mark.asyncio
413408
async def test_infer_full_context_request(
414409
mocker: MockerFixture,
415410
mock_configuration: AppConfig,
@@ -441,7 +436,6 @@ async def test_infer_full_context_request(
441436
assert response.data.request_id
442437

443438

444-
@pytest.mark.asyncio
445439
async def test_infer_generates_unique_request_ids(
446440
mocker: MockerFixture,
447441
mock_configuration: AppConfig,
@@ -469,7 +463,6 @@ async def test_infer_generates_unique_request_ids(
469463
assert response1.data.request_id != response2.data.request_id
470464

471465

472-
@pytest.mark.asyncio
473466
async def test_infer_api_connection_error_returns_503(
474467
mocker: MockerFixture,
475468
mock_configuration: AppConfig,
@@ -492,7 +485,6 @@ async def test_infer_api_connection_error_returns_503(
492485
assert exc_info.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
493486

494487

495-
@pytest.mark.asyncio
496488
async def test_infer_empty_llm_response_returns_fallback(
497489
mocker: MockerFixture,
498490
mock_configuration: AppConfig,
@@ -517,7 +509,6 @@ async def test_infer_empty_llm_response_returns_fallback(
517509
# --- Test Splunk integration ---
518510

519511

520-
@pytest.mark.asyncio
521512
async def test_infer_queues_splunk_event_on_success(
522513
mocker: MockerFixture,
523514
mock_configuration: AppConfig,
@@ -542,7 +533,6 @@ async def test_infer_queues_splunk_event_on_success(
542533
assert call_args[0][2] == "infer_with_llm"
543534

544535

545-
@pytest.mark.asyncio
546536
async def test_infer_queues_splunk_error_event_on_failure(
547537
mocker: MockerFixture,
548538
mock_configuration: AppConfig,
@@ -567,7 +557,6 @@ async def test_infer_queues_splunk_error_event_on_failure(
567557
assert call_args[0][2] == "infer_error"
568558

569559

570-
@pytest.mark.asyncio
571560
async def test_infer_splunk_event_includes_rh_identity_context(
572561
mocker: MockerFixture,
573562
mock_configuration: AppConfig,
@@ -638,7 +627,6 @@ def _setup_responses_mock_with_capture(
638627
return mock_create
639628

640629

641-
@pytest.mark.asyncio
642630
async def test_retrieve_simple_response_passes_tools(
643631
mocker: MockerFixture, mock_configuration: AppConfig
644632
) -> None:
@@ -660,7 +648,6 @@ async def test_retrieve_simple_response_passes_tools(
660648
assert call_kwargs["tools"] == tools
661649

662650

663-
@pytest.mark.asyncio
664651
async def test_retrieve_simple_response_defaults_to_empty_tools(
665652
mocker: MockerFixture, mock_configuration: AppConfig
666653
) -> None:
@@ -674,7 +661,6 @@ async def test_retrieve_simple_response_defaults_to_empty_tools(
674661
assert call_kwargs["tools"] == []
675662

676663

677-
@pytest.mark.asyncio
678664
async def test_infer_endpoint_calls_get_mcp_tools(
679665
mocker: MockerFixture,
680666
mock_configuration: AppConfig,
@@ -704,7 +690,6 @@ async def test_infer_endpoint_calls_get_mcp_tools(
704690
)
705691

706692

707-
@pytest.mark.asyncio
708693
async def test_infer_generic_runtime_error_reraises(
709694
mocker: MockerFixture,
710695
mock_configuration: AppConfig,
@@ -725,7 +710,6 @@ async def test_infer_generic_runtime_error_reraises(
725710
)
726711

727712

728-
@pytest.mark.asyncio
729713
async def test_infer_generic_runtime_error_records_failure(
730714
mocker: MockerFixture,
731715
mock_configuration: AppConfig,

tests/unit/models/rlsapi/test_requests.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,3 +594,33 @@ def test_priority_order(self, make_request: Any) -> None:
594594
)
595595
result = request.get_input_source()
596596
assert result == "Q\n\nS\n\nA\n\nT"
597+
598+
599+
@pytest.mark.parametrize(
600+
("model", "field", "max_length"),
601+
[
602+
(RlsapiV1Attachment, "contents", 65_536),
603+
(RlsapiV1Terminal, "output", 65_536),
604+
(RlsapiV1Context, "stdin", 65_536),
605+
(RlsapiV1InferRequest, "question", 10_240),
606+
],
607+
ids=[
608+
"attachment-contents",
609+
"terminal-output",
610+
"context-stdin",
611+
"infer-request-question",
612+
],
613+
)
614+
def test_value_max_length(model, field, max_length) -> None:
615+
"""Test that fields with longer than allowed data are not allowed"""
616+
value = "a" * max_length
617+
bad_value = value + "a"
618+
619+
instance = model(**{field: value})
620+
with pytest.raises(
621+
ValidationError,
622+
match=f"should have at most {max_length} characters",
623+
):
624+
model(**{field: bad_value})
625+
626+
assert getattr(instance, field) == value

0 commit comments

Comments
 (0)