Skip to content

Commit 2f6a34f

Browse files
fix litellm audio management
1 parent 179380f commit 2f6a34f

2 files changed

Lines changed: 121 additions & 11 deletions

File tree

src/google/adk/models/lite_llm.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,17 @@
103103
_MEDIA_URL_CONTENT_TYPE_BY_MAJOR_MIME_TYPE = {
104104
"image": "image_url",
105105
"video": "video_url",
106-
"audio": "audio_url",
106+
}
107+
108+
# LiteLLM input_audio only accepts "mp3" and "wav" as format values.
109+
# Maps audio MIME subtypes (including common aliases) to the canonical format.
110+
_AUDIO_MIME_SUBTYPE_TO_FORMAT: dict[str, str] = {
111+
"mpeg": "mp3",
112+
"mp3": "mp3",
113+
"wav": "wav",
114+
"x-wav": "wav",
115+
"wave": "wav",
116+
"vnd.wave": "wav",
107117
}
108118

109119
# Mapping of LiteLLM finish_reason strings to FinishReason enum values
@@ -1048,6 +1058,21 @@ async def _get_content(
10481058
"type": url_content_type,
10491059
url_content_type: {"url": data_uri},
10501060
})
1061+
elif mime_type.startswith("audio/"):
1062+
audio_subtype = mime_type.split("/", 1)[1]
1063+
audio_format = _AUDIO_MIME_SUBTYPE_TO_FORMAT.get(audio_subtype)
1064+
if audio_format is None:
1065+
raise ValueError(
1066+
f"Unsupported audio MIME type '{part.inline_data.mime_type}'."
1067+
" LiteLLM input_audio only supports mp3 and wav."
1068+
)
1069+
content_objects.append({
1070+
"type": "input_audio",
1071+
"input_audio": {
1072+
"data": base64_string,
1073+
"format": audio_format,
1074+
},
1075+
})
10511076
elif mime_type in _SUPPORTED_FILE_CONTENT_MIME_TYPES:
10521077
# OpenAI/Azure require file_id from uploaded file, not inline data
10531078
if provider in _FILE_ID_REQUIRED_PROVIDERS:

tests/unittests/models/test_litellm.py

Lines changed: 95 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2873,12 +2873,6 @@ async def test_get_content_file_uri_file_id_required_falls_back_to_text(
28732873
"video_url",
28742874
id="video",
28752875
),
2876-
pytest.param(
2877-
"https://example.com/audio.mp3",
2878-
"audio/mpeg",
2879-
"audio_url",
2880-
id="audio",
2881-
),
28822876
],
28832877
)
28842878
async def test_get_content_file_uri_media_url_file_id_required_uses_url_type(
@@ -2899,6 +2893,32 @@ async def test_get_content_file_uri_media_url_file_id_required_uses_url_type(
28992893
}]
29002894

29012895

2896+
@pytest.mark.asyncio
2897+
@pytest.mark.parametrize(
2898+
"provider,model",
2899+
[
2900+
("openai", "openai/gpt-4o-audio-preview"),
2901+
("azure", "azure/gpt-4o-audio-preview"),
2902+
],
2903+
)
2904+
async def test_get_content_file_uri_audio_http_url_file_id_required_falls_back_to_text(
2905+
provider, model
2906+
):
2907+
# audio_url is not a valid LiteLLM content type; HTTP audio URLs for
2908+
# file-id-required providers fall back to a text reference.
2909+
parts = [
2910+
types.Part(
2911+
file_data=types.FileData(
2912+
file_uri="https://example.com/audio.mp3",
2913+
mime_type="audio/mpeg",
2914+
display_name="audio.mp3",
2915+
)
2916+
)
2917+
]
2918+
content = await _get_content(parts, provider=provider, model=model)
2919+
assert content == [{"type": "text", "text": '[File reference: "audio.mp3"]'}]
2920+
2921+
29022922
@pytest.mark.asyncio
29032923
@pytest.mark.parametrize(
29042924
"provider,model",
@@ -3144,16 +3164,81 @@ async def test_get_content_file_uri_mime_type_inference(
31443164

31453165
@pytest.mark.asyncio
31463166
async def test_get_content_audio():
3167+
# Audio inline_data must produce an input_audio block (not audio_url).
3168+
# The data field is raw base64 (no data URI prefix) and format is the
3169+
# MIME subtype extracted from the MIME type.
31473170
parts = [
31483171
types.Part.from_bytes(data=b"test_audio_data", mime_type="audio/mpeg")
31493172
]
31503173
content = await _get_content(parts)
3151-
assert content[0]["type"] == "audio_url"
3174+
assert content[0]["type"] == "input_audio"
3175+
assert content[0]["input_audio"]["data"] == "dGVzdF9hdWRpb19kYXRh"
3176+
assert content[0]["input_audio"]["format"] == "mp3"
3177+
assert "url" not in content[0]["input_audio"]
3178+
3179+
3180+
@pytest.mark.asyncio
3181+
@pytest.mark.parametrize(
3182+
"mime_type,expected_format",
3183+
[
3184+
pytest.param("audio/mpeg", "mp3", id="mpeg_to_mp3"),
3185+
pytest.param("audio/mp3", "mp3", id="mp3_alias"),
3186+
pytest.param("audio/wav", "wav", id="wav"),
3187+
pytest.param("audio/x-wav", "wav", id="x-wav_alias"),
3188+
pytest.param("audio/wave", "wav", id="wave_alias"),
3189+
],
3190+
)
3191+
async def test_get_content_audio_formats(mime_type, expected_format):
3192+
# Only mp3 and wav are valid input_audio formats; verify MIME aliases map
3193+
# to the correct canonical format string.
3194+
parts = [types.Part.from_bytes(data=b"audio_bytes", mime_type=mime_type)]
3195+
content = await _get_content(parts)
3196+
assert content[0]["type"] == "input_audio"
3197+
assert content[0]["input_audio"]["format"] == expected_format
31523198
assert (
3153-
content[0]["audio_url"]["url"]
3154-
== "data:audio/mpeg;base64,dGVzdF9hdWRpb19kYXRh"
3199+
content[0]["input_audio"]["data"]
3200+
== base64.b64encode(b"audio_bytes").decode()
31553201
)
3156-
assert "format" not in content[0]["audio_url"]
3202+
3203+
3204+
@pytest.mark.asyncio
3205+
@pytest.mark.parametrize(
3206+
"mime_type",
3207+
["audio/mp4", "audio/ogg", "audio/webm", "audio/aac"],
3208+
)
3209+
async def test_get_content_audio_unsupported_format_raises(mime_type):
3210+
# Formats other than mp3/wav are not supported by LiteLLM input_audio and
3211+
# should raise a ValueError rather than silently producing a bad payload.
3212+
parts = [types.Part.from_bytes(data=b"audio_bytes", mime_type=mime_type)]
3213+
with pytest.raises(ValueError, match="Unsupported audio MIME type"):
3214+
await _get_content(parts)
3215+
3216+
3217+
@pytest.mark.asyncio
3218+
async def test_get_content_audio_raw_base64_not_data_uri():
3219+
# Ensure the data field is raw base64 with no "data:audio/...;base64," prefix.
3220+
raw_bytes = b"\x00\x01\x02\x03"
3221+
parts = [types.Part.from_bytes(data=raw_bytes, mime_type="audio/wav")]
3222+
content = await _get_content(parts)
3223+
audio_data = content[0]["input_audio"]["data"]
3224+
assert not audio_data.startswith("data:")
3225+
assert audio_data == base64.b64encode(raw_bytes).decode()
3226+
3227+
3228+
@pytest.mark.asyncio
3229+
async def test_get_content_audio_mixed_with_text():
3230+
# When audio is combined with text, both parts appear as separate content
3231+
# objects: text block followed by input_audio block.
3232+
parts = [
3233+
types.Part.from_text(text="What is said in this audio?"),
3234+
types.Part.from_bytes(data=b"test_audio_data", mime_type="audio/mpeg"),
3235+
]
3236+
content = await _get_content(parts)
3237+
assert len(content) == 2
3238+
assert content[0] == {"type": "text", "text": "What is said in this audio?"}
3239+
assert content[1]["type"] == "input_audio"
3240+
assert content[1]["input_audio"]["data"] == "dGVzdF9hdWRpb19kYXRh"
3241+
assert content[1]["input_audio"]["format"] == "mp3"
31573242

31583243

31593244
def test_to_litellm_role():

0 commit comments

Comments
 (0)