Skip to content

Commit 89bed43

Browse files
GWeale and copybara-github
authored and committed
fix: Add finish reason mapping and remove custom file URI handling in LiteLLM
Introduces a function to map LiteLLM finish reason strings to the internal types.FinishReason enum and populates the finish_reason field in LlmResponse. Removes custom logic for handling file URIs, including special casing for different providers, and updates tests accordingly. Closes #4125. Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 856421317
1 parent 8264211 commit 89bed43

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

src/google/adk/models/lite_llm.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,18 @@
110110
)
111111

112112

113+
def _map_finish_reason(
    finish_reason: Any,
) -> types.FinishReason | None:
  """Maps a LiteLLM finish_reason value to a google-genai FinishReason enum.

  Args:
    finish_reason: The raw finish reason from LiteLLM. May be None, an
      already-converted ``types.FinishReason``, or any stringifiable value
      (e.g. "stop", "length", "tool_calls").

  Returns:
    None for falsy input, the value itself if it is already a
    ``types.FinishReason``, otherwise the mapped enum member — falling back
    to ``types.FinishReason.OTHER`` for unrecognized values.
  """
  if not finish_reason:
    return None
  if isinstance(finish_reason, types.FinishReason):
    # Already converted upstream; pass through untouched.
    return finish_reason
  # Normalize to a lowercase string key before the table lookup.
  normalized_key = str(finish_reason).lower()
  return _FINISH_REASON_MAPPING.get(normalized_key, types.FinishReason.OTHER)
123+
124+
113125
def _get_provider_from_model(model: str) -> str:
114126
"""Extracts the provider name from a LiteLLM model string.
115127
@@ -1840,6 +1852,9 @@ async def generate_content_async(
18401852
else None,
18411853
)
18421854
)
1855+
aggregated_llm_response_with_tool_call.finish_reason = (
1856+
_map_finish_reason(finish_reason)
1857+
)
18431858
text = ""
18441859
reasoning_parts = []
18451860
function_calls.clear()
@@ -1854,6 +1869,9 @@ async def generate_content_async(
18541869
if reasoning_parts
18551870
else None,
18561871
)
1872+
aggregated_llm_response.finish_reason = _map_finish_reason(
1873+
finish_reason
1874+
)
18571875
text = ""
18581876
reasoning_parts = []
18591877

tests/unittests/models/test_litellm.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2880,6 +2880,7 @@ async def test_generate_content_async_stream(
28802880
"test_arg": "test_value"
28812881
}
28822882
assert responses[3].content.parts[-1].function_call.id == "test_tool_call_id"
2883+
assert responses[3].finish_reason == types.FinishReason.STOP
28832884
assert responses[3].model_version == "test_model"
28842885
mock_completion.assert_called_once()
28852886

@@ -2900,6 +2901,55 @@ async def test_generate_content_async_stream(
29002901
)
29012902

29022903

2904+
@pytest.mark.asyncio
async def test_generate_content_async_stream_sets_finish_reason(
    mock_completion, lite_llm_instance
):
  """Streaming responses: the final aggregated chunk carries finish_reason.

  Feeds two content deltas with finish_reason=None followed by a terminal
  chunk with finish_reason="stop", then checks that the last yielded
  response is non-partial, has FinishReason.STOP, and aggregates the text.
  """
  stream_chunks = [
      ModelResponse(
          model="test_model",
          choices=[
              StreamingChoices(
                  finish_reason=None,
                  delta=Delta(role="assistant", content="Hello "),
              )
          ],
      ),
      ModelResponse(
          model="test_model",
          choices=[
              StreamingChoices(
                  finish_reason=None,
                  delta=Delta(role="assistant", content="world"),
              )
          ],
      ),
      ModelResponse(
          model="test_model",
          choices=[StreamingChoices(finish_reason="stop", delta=Delta())],
      ),
  ]
  mock_completion.return_value = iter(stream_chunks)

  llm_request = LlmRequest(
      contents=[
          types.Content(
              role="user", parts=[types.Part.from_text(text="Test prompt")]
          )
      ],
  )

  responses = []
  async for response in lite_llm_instance.generate_content_async(
      llm_request, stream=True
  ):
    responses.append(response)

  final_response = responses[-1]
  assert final_response.partial is False
  assert final_response.finish_reason == types.FinishReason.STOP
  assert final_response.content.parts[0].text == "Hello world"
2951+
2952+
29032953
@pytest.mark.asyncio
29042954
async def test_generate_content_async_stream_with_usage_metadata(
29052955
mock_completion, lite_llm_instance
@@ -2944,6 +2994,7 @@ async def test_generate_content_async_stream_with_usage_metadata(
29442994
"test_arg": "test_value"
29452995
}
29462996
assert responses[3].content.parts[-1].function_call.id == "test_tool_call_id"
2997+
assert responses[3].finish_reason == types.FinishReason.STOP
29472998

29482999
assert responses[3].usage_metadata.prompt_token_count == 10
29493000
assert responses[3].usage_metadata.candidates_token_count == 5

0 commit comments

Comments
 (0)