Skip to content

Commit 31d53ed

Browse files
committed
refactor: standardize provider test method implementation
- Updated the `test` method in all provider classes to remove return values and raise exceptions for failure cases, enhancing clarity and consistency. - Adjusted related logic in the dashboard and command routes to align with the new `test` method behavior, simplifying error handling.
1 parent 2ba0460 commit 31d53ed

3 files changed

Lines changed: 33 additions & 75 deletions

File tree

astrbot/core/provider/provider.py

Lines changed: 24 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,13 @@ def meta(self) -> ProviderMeta:
4545
)
4646
return meta
4747

48-
async def test(self) -> bool:
48+
async def test(self):
4949
"""test the provider is a
5050
51-
Returns:
52-
bool: the provider is available
51+
raises:
52+
Exception: if the provider is not available
5353
"""
54-
return True
54+
...
5555

5656

5757
class Provider(AbstractProvider):
@@ -175,15 +175,11 @@ def _ensure_message_to_dicts(
175175

176176
return dicts
177177

178-
async def test(self, timeout: float = 45.0) -> bool:
179-
try:
180-
response = await asyncio.wait_for(
181-
self.text_chat(prompt="REPLY `PONG` ONLY"),
182-
timeout=timeout,
183-
)
184-
return response is not None
185-
except Exception:
186-
return False
178+
async def test(self, timeout: float = 45.0):
179+
await asyncio.wait_for(
180+
self.text_chat(prompt="REPLY `PONG` ONLY"),
181+
timeout=timeout,
182+
)
187183

188184

189185
class STTProvider(AbstractProvider):
@@ -197,19 +193,13 @@ async def get_text(self, audio_url: str) -> str:
197193
"""获取音频的文本"""
198194
raise NotImplementedError
199195

200-
async def test(self) -> bool:
201-
try:
202-
sample_audio_path = os.path.join(
203-
get_astrbot_path(),
204-
"samples",
205-
"stt_health_check.wav",
206-
)
207-
if not os.path.exists(sample_audio_path):
208-
return False
209-
text_result = await self.get_text(sample_audio_path)
210-
return isinstance(text_result, str) and bool(text_result)
211-
except Exception:
212-
return False
196+
async def test(self):
197+
sample_audio_path = os.path.join(
198+
get_astrbot_path(),
199+
"samples",
200+
"stt_health_check.wav",
201+
)
202+
await self.get_text(sample_audio_path)
213203

214204

215205
class TTSProvider(AbstractProvider):
@@ -223,12 +213,8 @@ async def get_audio(self, text: str) -> str:
223213
"""获取文本的音频,返回音频文件路径"""
224214
raise NotImplementedError
225215

226-
async def test(self) -> bool:
227-
try:
228-
audio_result = await self.get_audio("hi")
229-
return isinstance(audio_result, str) and bool(audio_result)
230-
except Exception:
231-
return False
216+
async def test(self):
217+
await self.get_audio("hi")
232218

233219

234220
class EmbeddingProvider(AbstractProvider):
@@ -252,14 +238,8 @@ def get_dim(self) -> int:
252238
"""获取向量的维度"""
253239
...
254240

255-
async def test(self) -> bool:
256-
try:
257-
embedding_result = await self.get_embedding("health_check")
258-
return isinstance(embedding_result, list) and (
259-
not embedding_result or isinstance(embedding_result[0], float)
260-
)
261-
except Exception:
262-
return False
241+
async def test(self):
242+
await self.get_embedding("astrbot")
263243

264244
async def get_embeddings_batch(
265245
self,
@@ -345,9 +325,7 @@ async def rerank(
345325
"""获取查询和文档的重排序分数"""
346326
...
347327

348-
async def test(self) -> bool:
349-
try:
350-
await self.rerank("Apple", documents=["apple", "banana"])
351-
return True
352-
except Exception:
353-
return False
328+
async def test(self):
329+
result = await self.rerank("Apple", documents=["apple", "banana"])
330+
if not result:
331+
raise Exception("Rerank provider test failed, no results returned")

astrbot/dashboard/routes/config.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -354,17 +354,11 @@ async def _test_single_provider(self, provider):
354354
)
355355

356356
try:
357-
result = await provider.test()
358-
if result:
359-
status_info["status"] = "available"
360-
logger.info(
361-
f"Provider {status_info['name']} (ID: {status_info['id']}) is available.",
362-
)
363-
else:
364-
status_info["error"] = "Provider test returned False."
365-
logger.warning(
366-
f"Provider {status_info['name']} (ID: {status_info['id']}) test returned False.",
367-
)
357+
await provider.test()
358+
status_info["status"] = "available"
359+
logger.info(
360+
f"Provider {status_info['name']} (ID: {status_info['id']}) is available.",
361+
)
368362
except Exception as e:
369363
error_message = str(e)
370364
status_info["error"] = error_message

packages/astrbot/commands/provider.py

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -34,25 +34,11 @@ async def _test_provider_capability(self, provider):
3434
provider_capability_type = meta.provider_type
3535

3636
try:
37-
result = await provider.test()
38-
if result:
39-
return True, None, None
37+
await provider.test()
38+
return True, None, None
39+
except Exception as e:
4040
err_code = "TEST_FAILED"
41-
err_reason = "Provider test returned False"
42-
self._log_reachability_failure(
43-
provider, provider_capability_type, err_code, err_reason
44-
)
45-
return False, err_code, err_reason
46-
except Exception as exc:
47-
err_code = (
48-
getattr(exc, "status_code", None)
49-
or getattr(exc, "code", None)
50-
or getattr(exc, "error_code", None)
51-
)
52-
err_reason = str(exc)
53-
if not err_code:
54-
err_code = exc.__class__.__name__
55-
41+
err_reason = str(e)
5642
self._log_reachability_failure(
5743
provider, provider_capability_type, err_code, err_reason
5844
)

0 commit comments

Comments
 (0)