
Commit 71aa564

lizzij authored and copybara-github committed
fix: Propagate LiteLLM finish_reason to LlmResponse for use in callbacks
Closes #3114.

Co-authored-by: Eliza Huang <heliza@google.com>
COPYBARA_INTEGRATE_REVIEW=#3319 from lizzij:fix/litellm-finish-reason-3109 da6ed0a
PiperOrigin-RevId: 825783229
1 parent: 6429457

3 files changed, 146 additions and 3 deletions
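The commit's motivation (#3114) is making the finish reason visible to agent callbacks. As orientation before the diffs, a minimal sketch of that consumer, assuming ADK's standard after_model_callback hook; the agent name, model string, and truncation handling are illustrative, not part of this change:

from google.adk.agents import Agent
from google.adk.models.lite_llm import LiteLlm
from google.genai import types


def warn_on_truncation(callback_context, llm_response):
  # With this fix, finish_reason is populated for LiteLLM-backed models
  # as a types.FinishReason enum value.
  if llm_response.finish_reason == types.FinishReason.MAX_TOKENS:
    print("Model hit the token limit; the reply may be truncated.")
  return None  # Returning None keeps the original response.


agent = Agent(
    name="demo_agent",
    model=LiteLlm(model="openai/gpt-4o"),
    after_model_callback=warn_on_truncation,
)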


src/google/adk/models/lite_llm.py

Lines changed: 30 additions & 2 deletions
@@ -64,6 +64,21 @@
 _NEW_LINE = "\n"
 _EXCLUDED_PART_FIELD = {"inline_data": {"data"}}
 
+# Mapping of LiteLLM finish_reason strings to FinishReason enum values
+# Note: tool_calls/function_call map to STOP because:
+# 1. FinishReason.TOOL_CALL enum does not exist (as of google-genai 0.8.0)
+# 2. Tool calls represent normal completion (model stopped to invoke tools)
+# 3. Gemini native responses use STOP for tool calls (see lite_llm.py:910)
+_FINISH_REASON_MAPPING = {
+    "length": types.FinishReason.MAX_TOKENS,
+    "stop": types.FinishReason.STOP,
+    "tool_calls": (
+        types.FinishReason.STOP
+    ),  # Normal completion with tool invocation
+    "function_call": types.FinishReason.STOP,  # Legacy function call variant
+    "content_filter": types.FinishReason.SAFETY,
+}
+
 
 class ChatCompletionFileUrlObject(TypedDict, total=False):
   file_data: str
@@ -541,13 +556,26 @@ def _model_response_to_generate_content_response(
   """
 
   message = None
-  if response.get("choices", None):
-    message = response["choices"][0].get("message", None)
+  finish_reason = None
+  if (choices := response.get("choices")) and choices:
+    first_choice = choices[0]
+    message = first_choice.get("message", None)
+    finish_reason = first_choice.get("finish_reason", None)
 
   if not message:
     raise ValueError("No message in response")
 
   llm_response = _message_to_generate_content_response(message)
+  if finish_reason:
+    # If LiteLLM already provides a FinishReason enum (e.g., for Gemini), use
+    # it directly. Otherwise, map the finish_reason string to the enum.
+    if isinstance(finish_reason, types.FinishReason):
+      llm_response.finish_reason = finish_reason
+    else:
+      finish_reason_str = str(finish_reason).lower()
+      llm_response.finish_reason = _FINISH_REASON_MAPPING.get(
+          finish_reason_str, types.FinishReason.OTHER
+      )
   if response.get("usage", None):
     llm_response.usage_metadata = types.GenerateContentResponseUsageMetadata(
         prompt_token_count=response["usage"].get("prompt_tokens", 0),
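To see the new mapping end to end, a small sketch that feeds a hand-built response through the private helper patched above; _model_response_to_generate_content_response is internal API, and the sketch assumes litellm's ModelResponse coerces dict choices (recent versions do):

from google.adk.models.lite_llm import (
    _model_response_to_generate_content_response,
)
from google.genai import types
from litellm import ModelResponse

response = ModelResponse(
    choices=[{
        "message": {"role": "assistant", "content": "Hi there"},
        "finish_reason": "length",  # plain provider string, not an enum
    }]
)
llm_response = _model_response_to_generate_content_response(response)
assert llm_response.finish_reason == types.FinishReason.MAX_TOKENS

# Unrecognized strings fall back to OTHER rather than raising.
response.choices[0].finish_reason = "some_custom_reason"
llm_response = _model_response_to_generate_content_response(response)
assert llm_response.finish_reason == types.FinishReason.OTHER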

src/google/adk/telemetry/tracing.py

Lines changed: 5 additions & 1 deletion
@@ -303,9 +303,13 @@ def trace_call_llm(
         llm_response.usage_metadata.candidates_token_count,
     )
   if llm_response.finish_reason:
+    try:
+      finish_reason_str = llm_response.finish_reason.value.lower()
+    except AttributeError:
+      finish_reason_str = str(llm_response.finish_reason).lower()
     span.set_attribute(
         'gen_ai.response.finish_reasons',
-        [llm_response.finish_reason.value.lower()],
+        [finish_reason_str],
     )
 
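The try/except above exists because finish_reason can be a FinishReason enum (which has a .value) on the Gemini path, or a plain string on other code paths. A tiny sketch of the distinction, with hypothetical values:

from google.genai import types

enum_reason = types.FinishReason.MAX_TOKENS
print(enum_reason.value.lower())  # "max_tokens": the enum path

raw_reason = "length"  # a bare provider string has no .value attribute
try:
  print(raw_reason.value.lower())
except AttributeError:
  print(str(raw_reason).lower())  # "length": the fallback path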
tests/unittests/models/test_litellm.py

Lines changed: 111 additions & 0 deletions
@@ -19,6 +19,7 @@
 import warnings
 
 from google.adk.models.lite_llm import _content_to_message_param
+from google.adk.models.lite_llm import _FINISH_REASON_MAPPING
 from google.adk.models.lite_llm import _function_declaration_to_tool_param
 from google.adk.models.lite_llm import _get_content
 from google.adk.models.lite_llm import _message_to_generate_content_response
@@ -1938,3 +1939,113 @@ def test_non_gemini_litellm_no_warning():
     # Test with non-Gemini model
     LiteLlm(model="openai/gpt-4o")
     assert len(w) == 0
+
+
+@pytest.mark.parametrize(
+    "finish_reason,response_content,expected_content,has_tool_calls",
+    [
+        ("length", "Test response", "Test response", False),
+        ("stop", "Complete response", "Complete response", False),
+        (
+            "tool_calls",
+            "",
+            "",
+            True,
+        ),
+        ("content_filter", "", "", False),
+    ],
+    ids=["length", "stop", "tool_calls", "content_filter"],
+)
+@pytest.mark.asyncio
+async def test_finish_reason_propagation(
+    mock_acompletion,
+    lite_llm_instance,
+    finish_reason,
+    response_content,
+    expected_content,
+    has_tool_calls,
+):
+  """Test that finish_reason is properly propagated from LiteLLM response."""
+  tool_calls = None
+  if has_tool_calls:
+    tool_calls = [
+        ChatCompletionMessageToolCall(
+            type="function",
+            id="test_id",
+            function=Function(
+                name="test_function",
+                arguments='{"arg": "value"}',
+            ),
+        )
+    ]
+
+  mock_response = ModelResponse(
+      choices=[
+          Choices(
+              message=ChatCompletionAssistantMessage(
+                  role="assistant",
+                  content=response_content,
+                  tool_calls=tool_calls,
+              ),
+              finish_reason=finish_reason,
+          )
+      ]
+  )
+  mock_acompletion.return_value = mock_response
+
+  llm_request = LlmRequest(
+      contents=[
+          types.Content(
+              role="user", parts=[types.Part.from_text(text="Test prompt")]
+          )
+      ],
+  )
+
+  async for response in lite_llm_instance.generate_content_async(llm_request):
+    assert response.content.role == "model"
+    # Verify finish_reason is mapped to FinishReason enum
+    assert isinstance(response.finish_reason, types.FinishReason)
+    # Verify correct enum mapping using the actual mapping from lite_llm
+    assert response.finish_reason == _FINISH_REASON_MAPPING[finish_reason]
+    if expected_content:
+      assert response.content.parts[0].text == expected_content
+    if has_tool_calls:
+      assert len(response.content.parts) > 0
+      assert response.content.parts[-1].function_call.name == "test_function"
+
+  mock_acompletion.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_finish_reason_unknown_maps_to_other(
+    mock_acompletion, lite_llm_instance
+):
+  """Test that unknown finish_reason values map to FinishReason.OTHER."""
+  mock_response = ModelResponse(
+      choices=[
+          Choices(
+              message=ChatCompletionAssistantMessage(
+                  role="assistant",
+                  content="Test response",
+              ),
+              finish_reason="unknown_reason_type",
+          )
+      ]
+  )
+  mock_acompletion.return_value = mock_response
+
+  llm_request = LlmRequest(
+      contents=[
+          types.Content(
+              role="user", parts=[types.Part.from_text(text="Test prompt")]
+          )
+      ],
+  )
+
+  async for response in lite_llm_instance.generate_content_async(llm_request):
+    assert response.content.role == "model"
+    # Unknown finish_reason should map to OTHER
+    assert isinstance(response.finish_reason, types.FinishReason)
+    assert response.finish_reason == types.FinishReason.OTHER
+
+  mock_acompletion.assert_called_once()
