Skip to content

Commit 99105ca

Browse files
feat(litellm): Add async callbacks
1 parent 9ae99be commit 99105ca

File tree

2 files changed

+204
-2
lines changed

2 files changed

+204
-2
lines changed

sentry_sdk/integrations/litellm.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,10 @@ def _input_callback(kwargs: "Dict[str, Any]") -> None:
170170
set_data_normalized(span, f"gen_ai.litellm.{key}", value)
171171

172172

173+
async def _async_input_callback(kwargs: "Dict[str, Any]") -> None:
    """Async adapter for litellm's async input callback hook.

    The actual span setup is synchronous, so this simply delegates to
    ``_input_callback``.
    """
    _input_callback(kwargs)
175+
176+
173177
def _success_callback(
174178
kwargs: "Dict[str, Any]",
175179
completion_response: "Any",
@@ -233,10 +237,28 @@ def _success_callback(
233237
is_streaming = kwargs.get("stream")
234238
# Callback is fired multiple times when streaming a response.
235239
# Streaming flag checked at https://github.com/BerriAI/litellm/blob/33c3f13443eaf990ac8c6e3da78bddbc2b7d0e7a/litellm/litellm_core_utils/litellm_logging.py#L1603
236-
if is_streaming is not True or "complete_streaming_response" in kwargs:
240+
if (
241+
is_streaming is not True
242+
or "complete_streaming_response" in kwargs
243+
or "async_complete_streaming_response" in kwargs
244+
):
237245
span.__exit__(None, None, None)
238246

239247

248+
async def _async_success_callback(
    kwargs: "Dict[str, Any]",
    completion_response: "Any",
    start_time: "datetime",
    end_time: "datetime",
) -> None:
    """Async adapter for litellm's async success callback hook.

    Span finalization is synchronous, so delegate straight to
    ``_success_callback`` with the same arguments.
    """
    _success_callback(kwargs, completion_response, start_time, end_time)
260+
261+
240262
def _failure_callback(
241263
kwargs: "Dict[str, Any]",
242264
exception: Exception,
@@ -261,6 +283,20 @@ def _failure_callback(
261283
span.__exit__(type(exception), exception, None)
262284

263285

286+
async def _async_failure_callback(
    kwargs: "Dict[str, Any]",
    exception: Exception,
    start_time: "datetime",
    end_time: "datetime",
) -> None:
    """Async adapter for litellm's async failure callback hook.

    Error handling itself is synchronous, so delegate straight to
    ``_failure_callback`` with the same arguments.
    """
    _failure_callback(kwargs, exception, start_time, end_time)
298+
299+
264300
class LiteLLMIntegration(Integration):
265301
"""
266302
LiteLLM integration for Sentry.
@@ -318,11 +354,17 @@ def setup_once() -> None:
318354
litellm.input_callback = input_callback or []
319355
if _input_callback not in litellm.input_callback:
320356
litellm.input_callback.append(_input_callback)
357+
if _async_input_callback not in litellm.input_callback:
358+
litellm.input_callback.append(_async_input_callback)
321359

322360
litellm.success_callback = success_callback or []
323361
if _success_callback not in litellm.success_callback:
324362
litellm.success_callback.append(_success_callback)
363+
if _async_success_callback not in litellm.success_callback:
364+
litellm.success_callback.append(_async_success_callback)
325365

326366
litellm.failure_callback = failure_callback or []
327367
if _failure_callback not in litellm.failure_callback:
328368
litellm.failure_callback.append(_failure_callback)
369+
if _async_failure_callback not in litellm.failure_callback:
370+
litellm.failure_callback.append(_async_failure_callback)

tests/integrations/litellm/test_litellm.py

Lines changed: 161 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
import pytest
44
import time
5+
import asyncio
56
from unittest import mock
67
from datetime import datetime
78

@@ -31,13 +32,14 @@ async def __call__(self, *args, **kwargs):
3132
)
3233
from sentry_sdk.utils import package_version
3334

34-
from openai import OpenAI
35+
from openai import OpenAI, AsyncOpenAI
3536

3637
from concurrent.futures import ThreadPoolExecutor
3738

3839
import litellm.utils as litellm_utils
3940
from litellm.litellm_core_utils import streaming_handler
4041
from litellm.litellm_core_utils import thread_pool_executor
42+
from litellm.litellm_core_utils.logging_worker import GLOBAL_LOGGING_WORKER
4143
from litellm.llms.custom_httpx.http_handler import HTTPHandler
4244

4345

@@ -240,6 +242,89 @@ def test_nonstreaming_chat_completion(
240242
assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 30
241243

242244

245+
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.parametrize(
    "send_default_pii, include_prompts",
    [
        (True, True),
        (True, False),
        (False, True),
        (False, False),
    ],
)
async def test_async_nonstreaming_chat_completion(
    sentry_init,
    capture_events,
    send_default_pii,
    include_prompts,
    get_model_response,
    nonstreaming_chat_completions_model_response,
):
    # Async counterpart of test_nonstreaming_chat_completion: verifies that
    # litellm.acompletion() is instrumented through the async Sentry callbacks.
    sentry_init(
        integrations=[LiteLLMIntegration(include_prompts=include_prompts)],
        traces_sample_rate=1.0,
        send_default_pii=send_default_pii,
    )
    events = capture_events()

    messages = [{"role": "user", "content": "Hello!"}]

    client = AsyncOpenAI(api_key="z")

    model_response = get_model_response(
        nonstreaming_chat_completions_model_response,
        serialize_pydantic=True,
        request_headers={"X-Stainless-Raw-Response": "true"},
    )

    # Stub the underlying HTTP transport so no real network request is made.
    with mock.patch.object(
        client.completions._client._client,
        "send",
        return_value=model_response,
    ):
        with start_transaction(name="litellm test"):
            await litellm.acompletion(
                model="gpt-3.5-turbo",
                messages=messages,
                client=client,
            )

    # litellm runs async callbacks through a background logging worker; flush
    # it (plus a short grace period) so the span is finished before we
    # inspect the captured events.
    await GLOBAL_LOGGING_WORKER.flush()
    await asyncio.sleep(0.5)

    assert len(events) == 1
    (event,) = events

    assert event["type"] == "transaction"
    assert event["transaction"] == "litellm test"

    chat_spans = list(
        x
        for x in event["spans"]
        if x["op"] == OP.GEN_AI_CHAT and x["origin"] == "auto.ai.litellm"
    )
    assert len(chat_spans) == 1
    span = chat_spans[0]

    assert span["op"] == OP.GEN_AI_CHAT
    assert span["description"] == "chat gpt-3.5-turbo"
    assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "gpt-3.5-turbo"
    assert span["data"][SPANDATA.GEN_AI_RESPONSE_MODEL] == "gpt-3.5-turbo"
    assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "openai"
    assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"

    # Prompt/response content is only recorded when both PII sending and
    # prompt capture are enabled.
    if send_default_pii and include_prompts:
        assert SPANDATA.GEN_AI_REQUEST_MESSAGES in span["data"]
        assert SPANDATA.GEN_AI_RESPONSE_TEXT in span["data"]
    else:
        assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"]
        assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]

    # Token counts come from the mocked model response fixture.
    assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
    assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 20
    assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 30
326+
327+
243328
@pytest.mark.parametrize(
244329
"send_default_pii, include_prompts",
245330
[
@@ -311,6 +396,81 @@ def test_streaming_chat_completion(
311396
assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
312397

313398

399+
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.parametrize(
    "send_default_pii, include_prompts",
    [
        (True, True),
        (True, False),
        (False, True),
        (False, False),
    ],
)
async def test_async_streaming_chat_completion(
    sentry_init,
    capture_events,
    send_default_pii,
    include_prompts,
    get_model_response,
    async_iterator,
    server_side_event_chunks,
    streaming_chat_completions_model_response,
):
    # Async counterpart of test_streaming_chat_completion: verifies that a
    # streamed litellm.acompletion() produces exactly one finished chat span
    # (the callback fires multiple times while streaming; the span must only
    # be closed once the stream completes).
    sentry_init(
        integrations=[LiteLLMIntegration(include_prompts=include_prompts)],
        traces_sample_rate=1.0,
        send_default_pii=send_default_pii,
    )
    events = capture_events()

    messages = [{"role": "user", "content": "Hello!"}]

    client = AsyncOpenAI(api_key="z")

    # Build a mocked SSE stream of response chunks for the transport to emit.
    model_response = get_model_response(
        async_iterator(
            server_side_event_chunks(
                streaming_chat_completions_model_response,
                include_event_type=False,
            ),
        ),
        request_headers={"X-Stainless-Raw-Response": "true"},
    )

    # Stub the underlying HTTP transport so no real network request is made.
    with mock.patch.object(
        client.completions._client._client,
        "send",
        return_value=model_response,
    ):
        with start_transaction(name="litellm test"):
            response = await litellm.acompletion(
                model="gpt-3.5-turbo",
                messages=messages,
                client=client,
                stream=True,
            )
            # Consume the stream so the final (complete) callback fires.
            async for _ in response:
                pass

    # litellm runs async callbacks through a background logging worker; flush
    # it (plus a short grace period) so the span is finished before we
    # inspect the captured events.
    await GLOBAL_LOGGING_WORKER.flush()
    await asyncio.sleep(0.5)

    assert len(events) == 1
    (event,) = events

    assert event["type"] == "transaction"
    chat_spans = list(
        x
        for x in event["spans"]
        if x["op"] == OP.GEN_AI_CHAT and x["origin"] == "auto.ai.litellm"
    )
    assert len(chat_spans) == 1
    span = chat_spans[0]

    assert span["op"] == OP.GEN_AI_CHAT
    assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
472+
473+
314474
def test_embeddings_create(sentry_init, capture_events, clear_litellm_cache):
315475
"""
316476
Test that litellm.embedding() calls are properly instrumented.

0 commit comments

Comments
 (0)