Skip to content

Commit 2b06d4c

Browse files
committed
chore: relax the validation logic
1 parent d1aef60 commit 2b06d4c

1 file changed

Lines changed: 34 additions & 13 deletions

File tree

src/google/adk/cli/fast_api.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -746,6 +746,7 @@ async def _get_a2a_runner_async() -> Runner:
746746

747747
import inspect
748748
import json
749+
from pydantic import ValidationError as _ValidationError
749750

750751
from google.adk.agents import Agent
751752
import google.auth
@@ -842,30 +843,40 @@ async def context_propagation(
842843
response_model_exclude_none=True,
843844
response_class=JSONResponse,
844845
)
845-
async def query(request: _QueryRequest):
846+
async def query(request: Request):
847+
try:
848+
body = await request.json()
849+
except json.JSONDecodeError as exc:
850+
raise HTTPException(
851+
status_code=400, detail=f"Invalid JSON: {exc}"
852+
)
853+
try:
854+
parsed = _QueryRequest.model_validate(body)
855+
except _ValidationError as exc:
856+
raise HTTPException(status_code=400, detail=exc.errors())
846857
if not adk_app._tmpl_attrs.get("runner"):
847858
adk_app._tmpl_attrs["runner"] = await adk_web_server.get_runner_async(
848859
app_name=gemini_enterprise_app_name
849860
)
850-
if request.class_method is None:
861+
if parsed.class_method is None:
851862
raise HTTPException(
852863
status_code=400, detail="class_method cannot be None"
853864
)
854-
if request.class_method not in _ALLOWED_AGENT_ENGINE_CLASS_METHODS:
865+
if parsed.class_method not in _ALLOWED_AGENT_ENGINE_CLASS_METHODS:
855866
raise HTTPException(
856867
status_code=400,
857-
detail=f"class_method {request.class_method} is not allowed",
868+
detail=f"class_method {parsed.class_method} is not allowed",
858869
)
859-
method = getattr(adk_app, request.class_method)
860-
output = await _invoke_callable_or_raise(method, request.input or {})
870+
method = getattr(adk_app, parsed.class_method)
871+
output = await _invoke_callable_or_raise(method, parsed.input or {})
861872

862873
try:
863874
json_serialized_content = jsonable_encoder({"output": output})
864875
except ValueError as encoding_error:
865876
logging.exception(
866877
"FastAPI could not JSON-encode the response from invocation method"
867878
" %s. Error: %s. Invocation method's original response: %r",
868-
request.class_method,
879+
parsed.class_method,
869880
encoding_error,
870881
output,
871882
)
@@ -877,22 +888,32 @@ async def query(request: _QueryRequest):
877888
response_model_exclude_none=True,
878889
response_class=StreamingResponse,
879890
)
880-
async def stream_query(request: _QueryRequest):
891+
async def stream_query(request: Request):
892+
try:
893+
body = await request.json()
894+
except json.JSONDecodeError as exc:
895+
raise HTTPException(
896+
status_code=400, detail=f"Invalid JSON: {exc}"
897+
)
898+
try:
899+
parsed = _QueryRequest.model_validate(body)
900+
except _ValidationError as exc:
901+
raise HTTPException(status_code=400, detail=exc.errors())
881902
if not adk_app._tmpl_attrs.get("runner"):
882903
adk_app._tmpl_attrs["runner"] = await adk_web_server.get_runner_async(
883904
app_name=gemini_enterprise_app_name
884905
)
885-
if request.class_method is None:
906+
if parsed.class_method is None:
886907
raise HTTPException(
887908
status_code=400, detail="class_method cannot be None"
888909
)
889-
if request.class_method not in _ALLOWED_AGENT_ENGINE_CLASS_METHODS:
910+
if parsed.class_method not in _ALLOWED_AGENT_ENGINE_CLASS_METHODS:
890911
raise HTTPException(
891912
status_code=400,
892-
detail=f"class_method {request.class_method} is not allowed",
913+
detail=f"class_method {parsed.class_method} is not allowed",
893914
)
894-
method = getattr(adk_app, request.class_method)
895-
output = await _invoke_callable_or_raise(method, request.input or {})
915+
method = getattr(adk_app, parsed.class_method)
916+
output = await _invoke_callable_or_raise(method, parsed.input or {})
896917

897918
if inspect.isgenerator(output):
898919

0 commit comments

Comments
 (0)