Skip to content

Commit 7996112

Browse files
aperepellizzij
authored andcommitted
feat: map LiteLLM finish_reason strings to FinishReason enum
- Map finish_reason strings to proper FinishReason enum values in lite_llm.py - 'length' -> FinishReason.MAX_TOKENS - 'stop' -> FinishReason.STOP - 'tool_calls'/'function_call' -> FinishReason.STOP - 'content_filter' -> FinishReason.SAFETY - unknown values -> FinishReason.OTHER - Add clarifying comment in tracing.py for string fallback path - Update test_litellm.py to verify enum mapping: - Assert finish_reason is FinishReason enum instance - Verify correct enum values for each finish_reason string - Add test for unknown finish_reason mapping to OTHER Benefits: - Type consistency with Gemini native responses - Avoids runtime warnings from string finish_reason - Enables proper instanceof checks in callbacks - Better integration with ADK telemetry
1 parent ab6f577 commit 7996112

3 files changed

Lines changed: 64 additions & 2 deletions

File tree

src/google/adk/models/lite_llm.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,21 @@ def _model_response_to_generate_content_response(
505505

506506
llm_response = _message_to_generate_content_response(message)
507507
if finish_reason:
508-
llm_response.finish_reason = finish_reason
508+
# Map LiteLLM finish_reason strings to FinishReason enum
509+
# This provides type consistency with Gemini native responses and avoids warnings
510+
finish_reason_str = str(finish_reason).lower()
511+
if finish_reason_str == "length":
512+
llm_response.finish_reason = types.FinishReason.MAX_TOKENS
513+
elif finish_reason_str == "stop":
514+
llm_response.finish_reason = types.FinishReason.STOP
515+
elif "tool" in finish_reason_str or "function" in finish_reason_str:
516+
# Handle tool_calls, function_call variants
517+
llm_response.finish_reason = types.FinishReason.STOP
518+
elif finish_reason_str == "content_filter":
519+
llm_response.finish_reason = types.FinishReason.SAFETY
520+
else:
521+
# For unknown reasons, use OTHER
522+
llm_response.finish_reason = types.FinishReason.OTHER
509523
if response.get("usage", None):
510524
llm_response.usage_metadata = types.GenerateContentResponseUsageMetadata(
511525
prompt_token_count=response["usage"].get("prompt_tokens", 0),

src/google/adk/telemetry/tracing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,7 @@ def trace_call_llm(
306306
if isinstance(llm_response.finish_reason, types.FinishReason):
307307
finish_reason_str = llm_response.finish_reason.name.lower()
308308
else:
309+
# Fallback for string values (should not occur with LiteLLM after enum mapping)
309310
finish_reason_str = str(llm_response.finish_reason).lower()
310311
span.set_attribute(
311312
'gen_ai.response.finish_reasons',

tests/unittests/models/test_litellm.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1967,11 +1967,58 @@ async def test_finish_reason_propagation(
19671967

19681968
async for response in lite_llm_instance.generate_content_async(llm_request):
19691969
assert response.content.role == "model"
1970-
assert response.finish_reason == finish_reason
1970+
# Verify finish_reason is mapped to FinishReason enum, not raw string
1971+
assert isinstance(response.finish_reason, types.FinishReason)
1972+
# Verify correct enum mapping
1973+
if finish_reason == "length":
1974+
assert response.finish_reason == types.FinishReason.MAX_TOKENS
1975+
elif finish_reason == "stop":
1976+
assert response.finish_reason == types.FinishReason.STOP
1977+
elif finish_reason == "tool_calls":
1978+
assert response.finish_reason == types.FinishReason.STOP
1979+
elif finish_reason == "content_filter":
1980+
assert response.finish_reason == types.FinishReason.SAFETY
19711981
if expected_content:
19721982
assert response.content.parts[0].text == expected_content
19731983
if has_tool_calls:
19741984
assert len(response.content.parts) > 0
19751985
assert response.content.parts[-1].function_call.name == "test_function"
19761986

19771987
mock_acompletion.assert_called_once()
1988+
1989+
1990+
1991+
@pytest.mark.asyncio
1992+
async def test_finish_reason_unknown_maps_to_other(
1993+
mock_acompletion, lite_llm_instance
1994+
):
1995+
"""Test that unknown finish_reason values map to FinishReason.OTHER."""
1996+
mock_response = ModelResponse(
1997+
choices=[
1998+
Choices(
1999+
message=ChatCompletionAssistantMessage(
2000+
role="assistant",
2001+
content="Test response",
2002+
),
2003+
finish_reason="unknown_reason_type",
2004+
)
2005+
]
2006+
)
2007+
mock_acompletion.return_value = mock_response
2008+
2009+
llm_request = LlmRequest(
2010+
contents=[
2011+
types.Content(
2012+
role="user", parts=[types.Part.from_text(text="Test prompt")]
2013+
)
2014+
],
2015+
)
2016+
2017+
async for response in lite_llm_instance.generate_content_async(llm_request):
2018+
assert response.content.role == "model"
2019+
# Unknown finish_reason should map to OTHER
2020+
assert isinstance(response.finish_reason, types.FinishReason)
2021+
assert response.finish_reason == types.FinishReason.OTHER
2022+
2023+
mock_acompletion.assert_called_once()
2024+

0 commit comments

Comments
 (0)