Skip to content

Commit fb74338

Browse files
committed
refactor: simplify update target resolution flow
1 parent 64b556b commit fb74338

File tree

2 files changed

+36
-53
lines changed

2 files changed

+36
-53
lines changed

astrbot/core/updator.py

Lines changed: 27 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import os
2-
import re
32
import sys
43
import time
54
from json import JSONDecodeError
@@ -34,6 +33,7 @@ def __init__(self, repo_mirror: str = "") -> None:
3433
self.GITHUB_RELEASE_API = (
3534
"https://api.github.com/repos/AstrBotDevs/AstrBot/releases"
3635
)
36+
self.GITHUB_ARCHIVE_BASE = "https://github.com/AstrBotDevs/AstrBot/archive"
3737
self.NIGHTLY_TAG = NIGHTLY_TAG
3838

3939
def terminate_child_processes(self) -> None:
@@ -201,51 +201,14 @@ async def get_releases_with_nightly(self) -> list:
201201
releases.insert(0, nightly_release)
202202
return releases
203203

204-
async def _resolve_latest_target(self) -> tuple[str, str]:
205-
releases = await self.get_releases()
206-
latest_release = next(
207-
(
208-
item
209-
for item in releases
210-
if (tag := item.get("tag_name", ""))
211-
and tag.lower() != self.NIGHTLY_TAG
212-
and not PRERELEASE_TAG_REGEX.search(tag)
213-
),
214-
None,
215-
)
216-
if latest_release is None:
217-
raise Exception("未找到可用的发布版本。")
218-
219-
latest_version = latest_release["tag_name"]
220-
if self.compare_version(VERSION, latest_version) >= 0:
221-
raise Exception("当前已经是最新版本。")
222-
return latest_version, latest_release["zipball_url"]
223-
224-
def _resolve_github_archive_base(self) -> str:
225-
match = re.search(
226-
r"/repos/([^/]+)/([^/]+)/releases/?$",
227-
self.GITHUB_RELEASE_API,
228-
)
229-
if match is None:
230-
raise Exception("GITHUB_RELEASE_API 格式不正确,无法解析仓库信息。")
231-
owner, repo = match.groups()
232-
return f"https://github.com/{owner}/{repo}/archive"
233-
234204
def _resolve_nightly_target(self) -> tuple[str, str]:
235-
archive_base = self._resolve_github_archive_base()
205+
archive_base = self.GITHUB_ARCHIVE_BASE
236206
return self.NIGHTLY_TAG, (f"{archive_base}/refs/tags/{self.NIGHTLY_TAG}.zip")
237207

238-
async def _resolve_tag_target(self, version_str: str) -> tuple[str, str]:
239-
releases = await self.get_releases()
240-
for data in releases:
241-
if data["tag_name"] == version_str:
242-
return version_str, data["zipball_url"]
243-
raise Exception(f"未找到版本号为 {version_str} 的更新文件。")
244-
245208
def _resolve_commit_target(self, version_str: str) -> tuple[str, str]:
246209
if len(version_str) != 40:
247210
raise Exception("commit hash 长度不正确,应为 40")
248-
archive_base = self._resolve_github_archive_base()
211+
archive_base = self.GITHUB_ARCHIVE_BASE
249212
return (
250213
version_str,
251214
f"{archive_base}/{version_str}.zip",
@@ -256,16 +219,37 @@ async def _resolve_update_target(
256219
latest: bool,
257220
version: str | None,
258221
) -> tuple[str, str]:
259-
version_str = str(version) if version is not None else ""
222+
version_str = str(version).strip() if version is not None else ""
260223

261224
if latest:
262-
return await self._resolve_latest_target()
225+
releases = await self.get_releases()
226+
latest_release = next(
227+
(
228+
item
229+
for item in releases
230+
if (tag := item.get("tag_name", ""))
231+
and tag.lower() != self.NIGHTLY_TAG
232+
and not PRERELEASE_TAG_REGEX.search(tag)
233+
),
234+
None,
235+
)
236+
if latest_release is None:
237+
raise Exception("未找到可用的发布版本。")
238+
239+
latest_version = latest_release["tag_name"]
240+
if self.compare_version(VERSION, latest_version) >= 0:
241+
raise Exception("当前已经是最新版本。")
242+
return latest_version, latest_release["zipball_url"]
263243

264244
if version_str.lower() == self.NIGHTLY_TAG:
265245
return self._resolve_nightly_target()
266246

267247
if version_str.startswith("v"):
268-
return await self._resolve_tag_target(version_str)
248+
releases = await self.get_releases()
249+
for data in releases:
250+
if data.get("tag_name") == version_str:
251+
return version_str, data["zipball_url"]
252+
raise Exception(f"未找到版本号为 {version_str} 的更新文件。")
269253

270254
return self._resolve_commit_target(version_str)
271255

tests/unit/test_updator.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,9 @@ def mock_unzip_file(zip_path: str, target_dir: str):
125125
assert captured["target_dir"] == str(tmp_path)
126126

127127

128-
def test_resolve_nightly_target_uses_repo_from_release_api():
128+
def test_resolve_nightly_target_uses_archive_base():
129129
updator = AstrBotUpdator()
130-
updator.GITHUB_RELEASE_API = (
131-
"https://api.github.com/repos/example-org/example-repo/releases"
132-
)
130+
updator.GITHUB_ARCHIVE_BASE = "https://github.com/example-org/example-repo/archive"
133131

134132
target_version, file_url = updator._resolve_nightly_target()
135133
assert target_version == "nightly"
@@ -139,11 +137,9 @@ def test_resolve_nightly_target_uses_repo_from_release_api():
139137
)
140138

141139

142-
def test_resolve_commit_target_uses_repo_from_release_api():
140+
def test_resolve_commit_target_uses_archive_base():
143141
updator = AstrBotUpdator()
144-
updator.GITHUB_RELEASE_API = (
145-
"https://api.github.com/repos/example-org/example-repo/releases"
146-
)
142+
updator.GITHUB_ARCHIVE_BASE = "https://github.com/example-org/example-repo/archive"
147143
version_str = "1234567890123456789012345678901234567890"
148144

149145
target_version, file_url = updator._resolve_commit_target(version_str)
@@ -247,7 +243,7 @@ async def mock_fetch_release_info(url: str):
247243

248244

249245
@pytest.mark.asyncio
250-
async def test_resolve_latest_target_skips_prerelease_tags(monkeypatch):
246+
async def test_resolve_update_target_skips_prerelease_tags_for_latest(monkeypatch):
251247
updator = AstrBotUpdator()
252248
releases = [
253249
{
@@ -272,7 +268,10 @@ async def mock_get_releases():
272268
monkeypatch.setattr(updator, "get_releases", mock_get_releases)
273269
monkeypatch.setattr(updator, "compare_version", lambda _current, _target: -1)
274270

275-
target_version, file_url = await updator._resolve_latest_target()
271+
target_version, file_url = await updator._resolve_update_target(
272+
latest=True,
273+
version=None,
274+
)
276275
assert target_version == "v9.9.8"
277276
assert file_url == "https://example.com/stable.zip"
278277

0 commit comments

Comments
 (0)