|
5 | 5 | """ |
6 | 6 |
|
7 | 7 | import pytest |
| 8 | +from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage |
| 9 | +from openai.types.chat.chat_completion import Choice |
8 | 10 |
|
9 | 11 | from mellea.backends import ModelOption |
10 | 12 | from mellea.backends.openai import OpenAIBackend |
| 13 | +from mellea.core.base import ModelOutputThunk |
11 | 14 |
|
12 | 15 |
|
13 | 16 | def _make_backend(model_options: dict | None = None) -> OpenAIBackend: |
@@ -168,5 +171,127 @@ def test_make_backend_specific_unknown_mellea_keys_removed(backend): |
168 | 171 | assert ModelOption.SYSTEM_PROMPT not in result |
169 | 172 |
|
170 | 173 |
|
| 174 | +# --- processing(): reasoning / thinking trace extraction --- |
| 175 | + |
| 176 | + |
def _vllm_chat_completion(reasoning: str, content: str | None) -> ChatCompletion:
    """Return a non-streaming ``ChatCompletion`` shaped like a vLLM thinking-model reply.

    vLLM puts the reasoning trace under a ``reasoning`` key of the raw message
    dict instead of exposing ``reasoning_content`` on the SDK object. Because
    the openai SDK's ``ChatCompletionMessage`` accepts the extra field, the
    trace is only visible through ``model_dump()``.

    Args:
        reasoning: Text stored under the message's ``reasoning`` key.
        content: Assistant content; ``None`` models a null ``content`` field.
    """
    raw_message = {"role": "assistant", "content": content, "reasoning": reasoning}
    only_choice = Choice(
        index=0,
        finish_reason="stop",
        message=ChatCompletionMessage.model_validate(raw_message),
    )
    return ChatCompletion(
        id="vllm-test",
        created=0,
        model="qwen3",
        object="chat.completion",
        choices=[only_choice],
    )
| 195 | + |
| 196 | + |
async def test_processing_captures_vllm_reasoning_field(backend):
    """Non-streaming: the raw vLLM ``reasoning`` key is stored in mot._thinking."""
    completion = _vllm_chat_completion(reasoning="2 + 2 equals 4.", content="4")
    assistant_message = completion.choices[0].message
    # Precondition: the SDK message has no ``reasoning_content`` attribute, so
    # the trace can only come from the extra ``reasoning`` field.
    assert not hasattr(assistant_message, "reasoning_content")

    thunk = ModelOutputThunk(value=None)
    await backend.processing(thunk, completion)

    assert thunk._thinking == "2 + 2 equals 4."
    assert thunk._underlying_value == "4"
| 208 | + |
| 209 | + |
async def test_processing_vllm_reasoning_with_null_content(backend):
    """Non-streaming: a null ``content`` does not prevent capturing the reasoning."""
    completion = _vllm_chat_completion(reasoning="some thinking", content=None)
    thunk = ModelOutputThunk(value=None)

    await backend.processing(thunk, completion)

    # The trace is kept, and the null content is expected to surface as "".
    assert thunk._thinking == "some thinking"
    assert thunk._underlying_value == ""
| 219 | + |
| 220 | + |
async def test_processing_streaming_captures_vllm_reasoning_field(backend):
    """Streaming: ``reasoning`` deltas from successive chunks build up mot._thinking."""
    # Fields shared by every chunk of the simulated stream.
    envelope = {
        "id": "vllm-stream",
        "created": 0,
        "model": "qwen3",
        "object": "chat.completion.chunk",
    }
    # First delta carries only reasoning; the second adds the answer text.
    deltas = (
        {"role": "assistant", "content": None, "reasoning": "first "},
        {"content": "ans", "reasoning": "second"},
    )
    stream = [
        ChatCompletionChunk.model_validate(
            {
                **envelope,
                "choices": [{"index": 0, "delta": delta, "finish_reason": None}],
            }
        )
        for delta in deltas
    ]

    thunk = ModelOutputThunk(value=None)
    for piece in stream:
        await backend.processing(thunk, piece)

    assert thunk._thinking == "first second"
    assert thunk._underlying_value == "ans"
| 264 | + |
| 265 | + |
async def test_processing_reasoning_content_still_used(backend):
    """Regression guard: the existing ``reasoning_content`` path keeps working.

    Some providers expose the trace as ``reasoning_content`` on the message
    object itself. That path must not regress in favour of the raw-dict
    fallback.
    """
    raw_message = {
        "role": "assistant",
        "content": "answer",
        "reasoning_content": "attribute-style trace",
    }
    only_choice = Choice(
        index=0,
        finish_reason="stop",
        message=ChatCompletionMessage.model_validate(raw_message),
    )
    completion = ChatCompletion(
        id="rc-test",
        created=0,
        model="fake",
        object="chat.completion",
        choices=[only_choice],
    )
    # Precondition: here the trace IS reachable as an attribute on the message.
    assert hasattr(completion.choices[0].message, "reasoning_content")

    thunk = ModelOutputThunk(value=None)
    await backend.processing(thunk, completion)

    assert thunk._thinking == "attribute-style trace"
    assert thunk._underlying_value == "answer"
| 294 | + |
| 295 | + |
if __name__ == "__main__":
    # pytest.main() returns the run's exit code instead of exiting itself.
    # Raise SystemExit with it so that running this file directly reports
    # failures through a non-zero process status.
    raise SystemExit(pytest.main([__file__, "-v"]))
0 commit comments