|
4 | 4 | import time |
5 | 5 | from typing import List |
6 | 6 |
|
| 7 | +import litellm |
7 | 8 | from litellm import acompletion |
8 | | -from typing import Dict |
| 9 | +from litellm.types.utils import ModelResponse, Choices |
| 10 | +from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper |
9 | 11 |
|
10 | 12 | from eval_protocol.dataset_logger import default_logger |
11 | 13 | from eval_protocol.models import EvaluationRow, Message |
@@ -62,12 +64,21 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: |
62 | 64 | if row.tools is not None: |
63 | 65 | request_params["tools"] = row.tools |
64 | 66 |
|
65 | | - # Dynamic import to avoid static dependency/lint errors if LiteLLM isn't installed yet |
66 | | - import importlib |
| 67 | + if request_params.get("stream") is True: |
| 68 | + chunks = [] |
| 69 | + stream = await acompletion(**request_params) |
67 | 70 |
|
68 | | - _litellm = importlib.import_module("litellm") |
69 | | - acompletion = getattr(_litellm, "acompletion") |
70 | | - response = await acompletion(**request_params) |
| 71 | + assert isinstance(stream, CustomStreamWrapper), "Stream should be a CustomStreamWrapper" |
| 72 | + |
| 73 | + async for chunk in stream: # pyright: ignore[reportGeneralTypeIssues] |
| 74 | + chunks.append(chunk) |
| 75 | + response = litellm.stream_chunk_builder(chunks, messages_payload) |
| 76 | + else: |
| 77 | + response = await acompletion(**request_params) |
| 78 | + |
| 79 | + assert response is not None, "Response is None" |
| 80 | + assert isinstance(response, ModelResponse), "Response should be ModelResponse" |
| 81 | + assert isinstance(response.choices[0], Choices), "Response choice should be a Choices" |
71 | 82 |
|
72 | 83 | assistant_content = response.choices[0].message.content or "" |
73 | 84 | tool_calls = response.choices[0].message.tool_calls if response.choices[0].message.tool_calls else None |
@@ -110,11 +121,12 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: |
110 | 121 | tool_calls=converted_tool_calls, |
111 | 122 | ) |
112 | 123 | ] |
113 | | - |
114 | | - row.execution_metadata.usage = CompletionUsage( |
115 | | - prompt_tokens=response.usage.prompt_tokens, |
116 | | - completion_tokens=response.usage.completion_tokens, |
117 | | - total_tokens=response.usage.total_tokens, |
| 124 | + row.execution_metadata.usage = ( |
| 125 | + CompletionUsage( # Note: LiteLLM sets usage dynamically via setattr(), not as a typed field |
| 126 | + prompt_tokens=response.usage.prompt_tokens, # pyright: ignore[reportAttributeAccessIssue] |
| 127 | + completion_tokens=response.usage.completion_tokens, # pyright: ignore[reportAttributeAccessIssue] |
| 128 | + total_tokens=response.usage.total_tokens, # pyright: ignore[reportAttributeAccessIssue] |
| 129 | + ) |
118 | 130 | ) |
119 | 131 |
|
120 | 132 | row.messages = messages |
|
0 commit comments