Skip to content

Commit fec8415

Browse files
committed
refactor: streamline update target resolution and errors
1 parent 26ac634 commit fec8415

4 files changed

Lines changed: 101 additions & 137 deletions

File tree

astrbot/core/updator.py

Lines changed: 67 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
import os
22
import sys
33
import time
4-
from json import JSONDecodeError
54

6-
import aiohttp
75
import psutil
86

97
from astrbot.core import logger
@@ -18,26 +16,17 @@
1816

1917
from .zip_updator import (
2018
PRERELEASE_TAG_REGEX,
21-
FetchReleaseError,
2219
ReleaseInfo,
2320
RepoZipUpdator,
2421
)
2522

2623

27-
class InvalidUpdateTargetError(ValueError):
28-
"""Raised when update target arguments are invalid."""
24+
class AstrBotUpdateError(RuntimeError):
25+
"""Domain error for update-related failures."""
2926

30-
31-
class NoAvailableReleaseError(RuntimeError):
32-
"""Raised when no eligible release can be selected."""
33-
34-
35-
class AlreadyUpToDateError(RuntimeError):
36-
"""Raised when current version is already the latest stable version."""
37-
38-
39-
class UpdateFileNotFoundError(RuntimeError):
40-
"""Raised when no update file exists for a requested tag/version."""
27+
def __init__(self, reason: str, message: str) -> None:
28+
super().__init__(message)
29+
self.reason = reason
4130

4231

4332
class AstrBotUpdator(RepoZipUpdator):
@@ -172,32 +161,20 @@ async def check_update(
172161
consider_prerelease,
173162
)
174163

175-
async def get_releases(self, include_nightly: bool = False) -> list[dict]:
176-
releases = await self.fetch_release_info(self.ASTRBOT_RELEASE_API)
177-
if not include_nightly:
178-
return releases
164+
async def get_releases(self) -> list[dict]:
165+
return await self.fetch_release_info(self.ASTRBOT_RELEASE_API)
179166

167+
async def get_releases_with_nightly(self) -> list[dict]:
168+
releases = await self.get_releases()
180169
nightly_release_url = f"{self.GITHUB_RELEASE_API}/tags/{self.NIGHTLY_TAG}"
181-
expected_error_types = (
182-
TimeoutError,
183-
aiohttp.ClientError,
184-
JSONDecodeError,
185-
FetchReleaseError,
186-
)
187170
try:
188171
nightly_releases = await self.fetch_release_info(nightly_release_url)
189172
except Exception as e:
190-
if isinstance(e, expected_error_types):
191-
logger.warning(
192-
"获取 nightly 发布信息失败,跳过 nightly。"
193-
f"url={nightly_release_url}, error_type={type(e).__name__}, detail={e}",
194-
)
195-
return releases
196-
logger.exception(
197-
"获取 nightly 发布信息出现非预期异常。"
198-
f"url={nightly_release_url}, error_type={type(e).__name__}",
173+
logger.warning(
174+
"获取 nightly 发布信息失败,跳过 nightly。"
175+
f"url={nightly_release_url}, error_type={type(e).__name__}, detail={e}",
199176
)
200-
raise
177+
return releases
201178

202179
if not nightly_releases:
203180
return releases
@@ -215,15 +192,16 @@ def _parse_update_target(
215192
version_str = str(version).strip() if version is not None else ""
216193

217194
if latest and version_str:
218-
raise InvalidUpdateTargetError(
219-
"latest=True 时不能同时指定 version,请将 latest 设为 False。"
195+
raise AstrBotUpdateError(
196+
"invalid_target",
197+
"latest=True 时不能同时指定 version,请将 latest 设为 False。",
220198
)
221199

222200
if latest:
223201
return "latest", ""
224202

225203
if not version_str:
226-
raise InvalidUpdateTargetError("未指定有效的更新目标。")
204+
raise AstrBotUpdateError("invalid_target", "未指定有效的更新目标。")
227205

228206
if version_str.lower() == self.NIGHTLY_TAG:
229207
return "nightly", self.NIGHTLY_TAG
@@ -232,67 +210,63 @@ def _parse_update_target(
232210
return "tag", version_str
233211

234212
if len(version_str) != 40:
235-
raise InvalidUpdateTargetError("commit hash 长度不正确,应为 40")
213+
raise AstrBotUpdateError(
214+
"invalid_target", "commit hash 长度不正确,应为 40"
215+
)
236216
return "commit", version_str
237217

238-
async def _resolve_latest_target(self) -> tuple[str, str]:
239-
releases = await self.get_releases()
240-
latest_release = next(
241-
(
242-
item
243-
for item in releases
244-
if (tag := item.get("tag_name", ""))
245-
and not PRERELEASE_TAG_REGEX.search(tag)
246-
),
247-
None,
248-
)
249-
if latest_release is None:
250-
raise NoAvailableReleaseError("未找到可用的发布版本。")
251-
252-
latest_version = latest_release["tag_name"]
253-
if self.compare_version(VERSION, latest_version) >= 0:
254-
raise AlreadyUpToDateError("当前已经是最新版本。")
255-
return latest_version, latest_release["zipball_url"]
256-
257-
async def _resolve_nightly_target(self) -> tuple[str, str]:
258-
releases = await self.get_releases(include_nightly=True)
259-
nightly_release = next(
260-
(
261-
item
262-
for item in releases
263-
if item.get("tag_name", "").lower() == self.NIGHTLY_TAG
264-
),
265-
None,
266-
)
267-
if nightly_release is not None:
268-
return self.NIGHTLY_TAG, nightly_release["zipball_url"]
269-
return self.NIGHTLY_TAG, (
270-
f"{self.GITHUB_ARCHIVE_BASE}/refs/tags/{self.NIGHTLY_TAG}.zip"
271-
)
272-
273-
async def _resolve_tag_target(self, version_str: str) -> tuple[str, str]:
274-
releases = await self.get_releases()
275-
for data in releases:
276-
if data.get("tag_name") == version_str:
277-
return version_str, data["zipball_url"]
278-
raise UpdateFileNotFoundError(f"未找到版本号为 {version_str} 的更新文件。")
279-
280-
def _resolve_commit_target(self, version_str: str) -> tuple[str, str]:
281-
return version_str, f"{self.GITHUB_ARCHIVE_BASE}/{version_str}.zip"
282-
283218
async def _resolve_update_target(
284219
self,
285220
latest: bool,
286221
version: str | None,
287222
) -> tuple[str, str]:
288223
kind, value = self._parse_update_target(latest, version)
224+
289225
if kind == "latest":
290-
return await self._resolve_latest_target()
226+
releases = await self.get_releases()
227+
latest_release = next(
228+
(
229+
item
230+
for item in releases
231+
if (tag := item.get("tag_name", ""))
232+
and not PRERELEASE_TAG_REGEX.search(tag)
233+
),
234+
None,
235+
)
236+
if latest_release is None:
237+
raise AstrBotUpdateError("no_release", "未找到可用的发布版本。")
238+
latest_version = latest_release["tag_name"]
239+
if self.compare_version(VERSION, latest_version) >= 0:
240+
raise AstrBotUpdateError("up_to_date", "当前已经是最新版本。")
241+
return latest_version, latest_release["zipball_url"]
242+
291243
if kind == "nightly":
292-
return await self._resolve_nightly_target()
244+
releases = await self.get_releases_with_nightly()
245+
nightly_release = next(
246+
(
247+
item
248+
for item in releases
249+
if item.get("tag_name", "").lower() == self.NIGHTLY_TAG
250+
),
251+
None,
252+
)
253+
if nightly_release is not None:
254+
return self.NIGHTLY_TAG, nightly_release["zipball_url"]
255+
return self.NIGHTLY_TAG, (
256+
f"{self.GITHUB_ARCHIVE_BASE}/refs/tags/{self.NIGHTLY_TAG}.zip"
257+
)
258+
293259
if kind == "tag":
294-
return await self._resolve_tag_target(value)
295-
return self._resolve_commit_target(value)
260+
releases = await self.get_releases()
261+
for data in releases:
262+
if data.get("tag_name") == value:
263+
return value, data["zipball_url"]
264+
raise AstrBotUpdateError(
265+
"file_not_found",
266+
f"未找到版本号为 {value} 的更新文件。",
267+
)
268+
269+
return value, f"{self.GITHUB_ARCHIVE_BASE}/{value}.zip"
296270

297271
async def update(self, reboot=False, latest=True, version=None, proxy="") -> None:
298272
if os.environ.get("ASTRBOT_CLI") or os.environ.get("ASTRBOT_LAUNCHER"):
@@ -305,8 +279,10 @@ async def update(self, reboot=False, latest=True, version=None, proxy="") -> Non
305279
latest,
306280
version,
307281
)
308-
except InvalidUpdateTargetError as e:
309-
raise InvalidUpdateTargetError(f"更新参数错误: {e}") from e
282+
except AstrBotUpdateError as e:
283+
if e.reason == "invalid_target":
284+
raise AstrBotUpdateError("invalid_target", f"更新参数错误: {e}") from e
285+
raise
310286

311287
logger.info(f"准备更新至 AstrBot Core: {target_version}")
312288

astrbot/dashboard/routes/update.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ async def check_update(self):
7979

8080
async def get_releases(self):
8181
try:
82-
ret = await self.astrbot_updator.get_releases(include_nightly=True)
82+
ret = await self.astrbot_updator.get_releases_with_nightly()
8383
return Response().ok(ret).__dict__
8484
except Exception as e:
8585
logger.error(f"/api/update/releases: {traceback.format_exc()}")

scripts/release/release_constants_loader.py

Lines changed: 19 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from __future__ import annotations
22

3-
import ast
3+
import importlib.util
44
from pathlib import Path
5-
from typing import Any
65

76

87
def _constants_file() -> Path:
@@ -16,42 +15,30 @@ def _constants_file() -> Path:
1615

1716
def load_release_constants(*names: str) -> dict[str, str]:
1817
constants_path = _constants_file()
19-
source = constants_path.read_text(encoding="utf-8")
20-
tree = ast.parse(source, filename=str(constants_path))
21-
22-
wanted = set(names)
23-
values: dict[str, str] = {}
18+
spec = importlib.util.spec_from_file_location(
19+
"astrbot_core_release_constants_tmp",
20+
constants_path,
21+
)
22+
if spec is None or spec.loader is None:
23+
raise RuntimeError(f"Failed to load spec for {constants_path}")
2424

25-
for node in tree.body:
26-
target_name: str | None = None
27-
value_node: Any | None = None
25+
module = importlib.util.module_from_spec(spec)
26+
spec.loader.exec_module(module) # type: ignore[union-attr]
2827

29-
if isinstance(node, ast.Assign):
30-
for target in node.targets:
31-
if isinstance(target, ast.Name):
32-
target_name = target.id
33-
break
34-
value_node = node.value
35-
elif isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name):
36-
target_name = node.target.id
37-
value_node = node.value
28+
values: dict[str, str] = {}
29+
missing: list[str] = []
3830

39-
if not target_name or target_name not in wanted:
31+
for name in names:
32+
value = getattr(module, name, None)
33+
if not isinstance(value, str):
34+
missing.append(name)
4035
continue
41-
if not isinstance(value_node, ast.Constant) or not isinstance(
42-
value_node.value,
43-
str,
44-
):
36+
value = value.strip()
37+
if not value:
38+
missing.append(name)
4539
continue
40+
values[name] = value
4641

47-
value = value_node.value.strip()
48-
if value:
49-
values[target_name] = value
50-
51-
if len(values) == len(wanted):
52-
break
53-
54-
missing = [name for name in names if name not in values]
5542
if missing:
5643
missing_str = ", ".join(missing)
5744
raise RuntimeError(

tests/unit/test_updator.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from astrbot.core.updator import AstrBotUpdator, InvalidUpdateTargetError
3+
from astrbot.core.updator import AstrBotUpdateError, AstrBotUpdator
44
from astrbot.core.zip_updator import FetchReleaseError, RepoZipUpdator
55

66

@@ -142,11 +142,10 @@ async def test_resolve_update_target_nightly_uses_archive_fallback(monkeypatch):
142142
updator = AstrBotUpdator()
143143
updator.GITHUB_ARCHIVE_BASE = "https://github.com/example-org/example-repo/archive"
144144

145-
async def mock_get_releases(include_nightly: bool = False):
146-
_ = include_nightly
145+
async def mock_get_releases_with_nightly():
147146
return []
148147

149-
monkeypatch.setattr(updator, "get_releases", mock_get_releases)
148+
monkeypatch.setattr(updator, "get_releases_with_nightly", mock_get_releases_with_nightly)
150149

151150
target_version, file_url = await updator._resolve_update_target(
152151
latest=False,
@@ -204,7 +203,7 @@ async def mock_fetch_release_info(url: str):
204203

205204
monkeypatch.setattr(updator, "fetch_release_info", mock_fetch_release_info)
206205

207-
releases = await updator.get_releases(include_nightly=True)
206+
releases = await updator.get_releases_with_nightly()
208207

209208
assert releases[0]["tag_name"] == "nightly"
210209
assert releases[1]["tag_name"] == "v9.9.9"
@@ -238,7 +237,7 @@ async def mock_fetch_release_info(url: str):
238237

239238
monkeypatch.setattr(updator, "fetch_release_info", mock_fetch_release_info)
240239

241-
releases = await updator.get_releases(include_nightly=True)
240+
releases = await updator.get_releases_with_nightly()
242241

243242
nightly_releases = [item for item in releases if item["tag_name"] == "nightly"]
244243
assert len(nightly_releases) == 1
@@ -288,8 +287,7 @@ async def test_resolve_update_target_skips_prerelease_tags_for_latest(monkeypatc
288287
},
289288
]
290289

291-
async def mock_get_releases(include_nightly: bool = False):
292-
assert include_nightly is False
290+
async def mock_get_releases():
293291
return releases
294292

295293
monkeypatch.setattr(updator, "get_releases", mock_get_releases)
@@ -308,7 +306,7 @@ async def test_resolve_update_target_rejects_version_when_latest_true():
308306
updator = AstrBotUpdator()
309307

310308
with pytest.raises(
311-
InvalidUpdateTargetError,
309+
AstrBotUpdateError,
312310
match="latest=True 时不能同时指定 version,请将 latest 设为 False。",
313311
):
314312
await updator._resolve_update_target(
@@ -337,13 +335,15 @@ async def mock_fetch_release_info(url: str):
337335

338336
monkeypatch.setattr(updator, "fetch_release_info", mock_fetch_release_info)
339337

340-
releases = await updator.get_releases(include_nightly=True)
338+
releases = await updator.get_releases_with_nightly()
341339
assert len(releases) == 1
342340
assert releases[0]["tag_name"] == "v9.9.9"
343341

344342

345343
@pytest.mark.asyncio
346-
async def test_get_releases_with_nightly_raises_for_unexpected_nightly_error(monkeypatch):
344+
async def test_get_releases_with_nightly_falls_back_on_unexpected_nightly_error(
345+
monkeypatch,
346+
):
347347
updator = AstrBotUpdator()
348348
stable_release = {
349349
"version": "v9.9.9",
@@ -362,8 +362,9 @@ async def mock_fetch_release_info(url: str):
362362

363363
monkeypatch.setattr(updator, "fetch_release_info", mock_fetch_release_info)
364364

365-
with pytest.raises(KeyError):
366-
await updator.get_releases(include_nightly=True)
365+
releases = await updator.get_releases_with_nightly()
366+
assert len(releases) == 1
367+
assert releases[0]["tag_name"] == "v9.9.9"
367368

368369

369370
@pytest.mark.asyncio

0 commit comments

Comments
 (0)