|
6 | 6 | 3. ``_parse_openai_completion`` — second-pass cleaning on assembled text |
7 | 7 | """ |
8 | 8 |
|
| 9 | +from collections.abc import AsyncGenerator |
9 | 10 | from unittest.mock import AsyncMock, MagicMock, patch |
10 | 11 |
|
11 | 12 | import pytest |
| 13 | +import pytest_asyncio |
12 | 14 |
|
13 | 15 | from astrbot.core.agent.tool import ( |
14 | 16 | ToolSet, # noqa: F401 – ensures the module is importable |
@@ -184,134 +186,112 @@ class TestParseOpenAICompletionCleaning: |
184 | 186 | correctly applies the extra GLM cleaning pass on top. |
185 | 187 | """ |
186 | 188 |
|
| 189 | + @pytest_asyncio.fixture |
| 190 | + async def provider(self) -> AsyncGenerator[ProviderZhipu, None]: |
| 191 | + p = _make_provider() |
| 192 | + yield p |
| 193 | + await p.terminate() |
| 194 | + |
187 | 195 | @pytest.mark.asyncio |
188 | | - async def test_null_token_content_becomes_empty(self): |
| 196 | + async def test_null_token_content_becomes_empty(self, provider: ProviderZhipu): |
189 | 197 | """content='\\n<None>' (real API response) should produce an empty reply.""" |
190 | | - provider = _make_provider() |
191 | | - try: |
192 | | - fake_completion = MagicMock() |
193 | | - parent_response = _make_llm_response("\n<None>") |
194 | | - |
195 | | - with patch.object( |
196 | | - ProviderOpenAIOfficial, |
197 | | - "_parse_openai_completion", |
198 | | - new=AsyncMock(return_value=parent_response), |
199 | | - ): |
200 | | - result = await provider._parse_openai_completion(fake_completion, None) |
201 | | - |
202 | | - assert result.completion_text == "" |
203 | | - finally: |
204 | | - await provider.terminate() |
| 198 | + fake_completion = MagicMock() |
| 199 | + parent_response = _make_llm_response("\n<None>") |
| 200 | + |
| 201 | + with patch.object( |
| 202 | + ProviderOpenAIOfficial, |
| 203 | + "_parse_openai_completion", |
| 204 | + new=AsyncMock(return_value=parent_response), |
| 205 | + ): |
| 206 | + result = await provider._parse_openai_completion(fake_completion, None) |
| 207 | + |
| 208 | + assert result.completion_text == "" |
205 | 209 |
|
206 | 210 | @pytest.mark.asyncio |
207 | | - async def test_endoftext_token_stripped_from_end(self): |
208 | | - provider = _make_provider() |
209 | | - try: |
210 | | - parent_response = _make_llm_response("当然可以!<|endoftext|>") |
211 | | - |
212 | | - with patch.object( |
213 | | - ProviderOpenAIOfficial, |
214 | | - "_parse_openai_completion", |
215 | | - new=AsyncMock(return_value=parent_response), |
216 | | - ): |
217 | | - result = await provider._parse_openai_completion(MagicMock(), None) |
218 | | - |
219 | | - assert result.completion_text == "当然可以!" |
220 | | - finally: |
221 | | - await provider.terminate() |
| 211 | + async def test_endoftext_token_stripped_from_end(self, provider: ProviderZhipu): |
| 212 | + parent_response = _make_llm_response("当然可以!<|endoftext|>") |
| 213 | + |
| 214 | + with patch.object( |
| 215 | + ProviderOpenAIOfficial, |
| 216 | + "_parse_openai_completion", |
| 217 | + new=AsyncMock(return_value=parent_response), |
| 218 | + ): |
| 219 | + result = await provider._parse_openai_completion(MagicMock(), None) |
| 220 | + |
| 221 | + assert result.completion_text == "当然可以!" |
222 | 222 |
|
223 | 223 | @pytest.mark.asyncio |
224 | | - async def test_assistant_role_token_prefix_stripped(self): |
225 | | - provider = _make_provider() |
226 | | - try: |
227 | | - parent_response = _make_llm_response("<|assistant|>我是一个AI助手。") |
228 | | - |
229 | | - with patch.object( |
230 | | - ProviderOpenAIOfficial, |
231 | | - "_parse_openai_completion", |
232 | | - new=AsyncMock(return_value=parent_response), |
233 | | - ): |
234 | | - result = await provider._parse_openai_completion(MagicMock(), None) |
235 | | - |
236 | | - assert result.completion_text == "我是一个AI助手。" |
237 | | - finally: |
238 | | - await provider.terminate() |
| 224 | + async def test_assistant_role_token_prefix_stripped(self, provider: ProviderZhipu): |
| 225 | + parent_response = _make_llm_response("<|assistant|>我是一个AI助手。") |
| 226 | + |
| 227 | + with patch.object( |
| 228 | + ProviderOpenAIOfficial, |
| 229 | + "_parse_openai_completion", |
| 230 | + new=AsyncMock(return_value=parent_response), |
| 231 | + ): |
| 232 | + result = await provider._parse_openai_completion(MagicMock(), None) |
| 233 | + |
| 234 | + assert result.completion_text == "我是一个AI助手。" |
239 | 235 |
|
240 | 236 | @pytest.mark.asyncio |
241 | | - async def test_normal_content_unchanged(self): |
| 237 | + async def test_normal_content_unchanged(self, provider: ProviderZhipu): |
242 | 238 | """Normal GLM replies must not be modified.""" |
243 | | - provider = _make_provider() |
244 | | - try: |
245 | | - normal = "好的,我来帮你解答这个问题。" |
246 | | - parent_response = _make_llm_response(normal) |
247 | | - |
248 | | - with patch.object( |
249 | | - ProviderOpenAIOfficial, |
250 | | - "_parse_openai_completion", |
251 | | - new=AsyncMock(return_value=parent_response), |
252 | | - ): |
253 | | - result = await provider._parse_openai_completion(MagicMock(), None) |
254 | | - |
255 | | - assert result.completion_text == normal |
256 | | - finally: |
257 | | - await provider.terminate() |
| 239 | + normal = "好的,我来帮你解答这个问题。" |
| 240 | + parent_response = _make_llm_response(normal) |
| 241 | + |
| 242 | + with patch.object( |
| 243 | + ProviderOpenAIOfficial, |
| 244 | + "_parse_openai_completion", |
| 245 | + new=AsyncMock(return_value=parent_response), |
| 246 | + ): |
| 247 | + result = await provider._parse_openai_completion(MagicMock(), None) |
| 248 | + |
| 249 | + assert result.completion_text == normal |
258 | 250 |
|
259 | 251 | @pytest.mark.asyncio |
260 | | - async def test_empty_completion_text_not_modified(self): |
| 252 | + async def test_empty_completion_text_not_modified(self, provider: ProviderZhipu): |
261 | 253 | """When the base class returns empty completion_text, don't error out.""" |
262 | | - provider = _make_provider() |
263 | | - try: |
264 | | - parent_response = LLMResponse("assistant") |
265 | | - parent_response.result_chain = None |
266 | | - parent_response._completion_text = "" |
267 | | - |
268 | | - with patch.object( |
269 | | - ProviderOpenAIOfficial, |
270 | | - "_parse_openai_completion", |
271 | | - new=AsyncMock(return_value=parent_response), |
272 | | - ): |
273 | | - result = await provider._parse_openai_completion(MagicMock(), None) |
274 | | - |
275 | | - assert result.completion_text == "" |
276 | | - finally: |
277 | | - await provider.terminate() |
| 254 | + parent_response = LLMResponse("assistant") |
| 255 | + parent_response.result_chain = None |
| 256 | + parent_response._completion_text = "" |
| 257 | + |
| 258 | + with patch.object( |
| 259 | + ProviderOpenAIOfficial, |
| 260 | + "_parse_openai_completion", |
| 261 | + new=AsyncMock(return_value=parent_response), |
| 262 | + ): |
| 263 | + result = await provider._parse_openai_completion(MagicMock(), None) |
| 264 | + |
| 265 | + assert result.completion_text == "" |
278 | 266 |
|
279 | 267 | @pytest.mark.asyncio |
280 | | - async def test_reasoning_content_preserved(self): |
| 268 | + async def test_reasoning_content_preserved(self, provider: ProviderZhipu): |
281 | 269 | """Cleaning must not touch reasoning_content.""" |
282 | | - provider = _make_provider() |
283 | | - try: |
284 | | - parent_response = _make_llm_response("\n<None>") |
285 | | - parent_response.reasoning_content = "思考过程:用户打了招呼,不需要回复。" |
286 | | - |
287 | | - with patch.object( |
288 | | - ProviderOpenAIOfficial, |
289 | | - "_parse_openai_completion", |
290 | | - new=AsyncMock(return_value=parent_response), |
291 | | - ): |
292 | | - result = await provider._parse_openai_completion(MagicMock(), None) |
293 | | - |
294 | | - assert result.completion_text == "" |
295 | | - assert "思考过程" in result.reasoning_content |
296 | | - finally: |
297 | | - await provider.terminate() |
| 270 | + parent_response = _make_llm_response("\n<None>") |
| 271 | + parent_response.reasoning_content = "思考过程:用户打了招呼,不需要回复。" |
| 272 | + |
| 273 | + with patch.object( |
| 274 | + ProviderOpenAIOfficial, |
| 275 | + "_parse_openai_completion", |
| 276 | + new=AsyncMock(return_value=parent_response), |
| 277 | + ): |
| 278 | + result = await provider._parse_openai_completion(MagicMock(), None) |
| 279 | + |
| 280 | + assert result.completion_text == "" |
| 281 | + assert "思考过程" in result.reasoning_content |
298 | 282 |
|
299 | 283 | @pytest.mark.asyncio |
300 | | - async def test_other_response_fields_preserved(self): |
| 284 | + async def test_other_response_fields_preserved(self, provider: ProviderZhipu): |
301 | 285 | """id, usage and other metadata must survive the cleaning pass.""" |
302 | | - provider = _make_provider() |
303 | | - try: |
304 | | - parent_response = _make_llm_response("普通回复") |
305 | | - parent_response.id = "cmp-test-id-123" |
306 | | - |
307 | | - with patch.object( |
308 | | - ProviderOpenAIOfficial, |
309 | | - "_parse_openai_completion", |
310 | | - new=AsyncMock(return_value=parent_response), |
311 | | - ): |
312 | | - result = await provider._parse_openai_completion(MagicMock(), None) |
313 | | - |
314 | | - assert result.id == "cmp-test-id-123" |
315 | | - assert result.completion_text == "普通回复" |
316 | | - finally: |
317 | | - await provider.terminate() |
| 286 | + parent_response = _make_llm_response("普通回复") |
| 287 | + parent_response.id = "cmp-test-id-123" |
| 288 | + |
| 289 | + with patch.object( |
| 290 | + ProviderOpenAIOfficial, |
| 291 | + "_parse_openai_completion", |
| 292 | + new=AsyncMock(return_value=parent_response), |
| 293 | + ): |
| 294 | + result = await provider._parse_openai_completion(MagicMock(), None) |
| 295 | + |
| 296 | + assert result.id == "cmp-test-id-123" |
| 297 | + assert result.completion_text == "普通回复" |
0 commit comments