Skip to content

Commit e51a34e

Browse files
committed
refactor: linearize update target resolution
1 parent fb74338 commit e51a34e

2 files changed

Lines changed: 70 additions & 30 deletions

File tree

astrbot/core/updator.py

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -165,11 +165,7 @@ def _is_expected_nightly_fetch_error(exc: BaseException) -> bool:
165165
JSONDecodeError,
166166
FetchReleaseError,
167167
)
168-
if isinstance(exc, expected_types):
169-
return True
170-
171-
cause = getattr(exc, "__cause__", None) or getattr(exc, "__context__", None)
172-
return isinstance(cause, expected_types)
168+
return isinstance(exc, expected_types)
173169

174170
async def get_nightly_release(self) -> dict | None:
175171
nightly_release_url = f"{self.GITHUB_RELEASE_API}/tags/{self.NIGHTLY_TAG}"
@@ -192,27 +188,20 @@ async def get_nightly_release(self) -> dict | None:
192188
return None
193189
return nightly_releases[0]
194190

195-
async def get_releases_with_nightly(self) -> list:
191+
async def _fetch_all_releases(self, include_nightly: bool) -> list:
196192
releases = await self.get_releases()
193+
if not include_nightly:
194+
return releases
195+
197196
nightly_release = await self.get_nightly_release()
198197
if nightly_release and all(
199198
item.get("tag_name") != self.NIGHTLY_TAG for item in releases
200199
):
201200
releases.insert(0, nightly_release)
202201
return releases
203202

204-
def _resolve_nightly_target(self) -> tuple[str, str]:
205-
archive_base = self.GITHUB_ARCHIVE_BASE
206-
return self.NIGHTLY_TAG, (f"{archive_base}/refs/tags/{self.NIGHTLY_TAG}.zip")
207-
208-
def _resolve_commit_target(self, version_str: str) -> tuple[str, str]:
209-
if len(version_str) != 40:
210-
raise Exception("commit hash 长度不正确,应为 40")
211-
archive_base = self.GITHUB_ARCHIVE_BASE
212-
return (
213-
version_str,
214-
f"{archive_base}/{version_str}.zip",
215-
)
203+
async def get_releases_with_nightly(self) -> list:
204+
return await self._fetch_all_releases(include_nightly=True)
216205

217206
async def _resolve_update_target(
218207
self,
@@ -221,14 +210,25 @@ async def _resolve_update_target(
221210
) -> tuple[str, str]:
222211
version_str = str(version).strip() if version is not None else ""
223212

213+
if (
214+
not latest
215+
and version_str
216+
and version_str.lower() != self.NIGHTLY_TAG
217+
and not version_str.startswith("v")
218+
):
219+
if len(version_str) != 40:
220+
raise Exception("commit hash 长度不正确,应为 40")
221+
return version_str, f"{self.GITHUB_ARCHIVE_BASE}/{version_str}.zip"
222+
223+
include_nightly = version_str.lower() == self.NIGHTLY_TAG
224+
releases = await self._fetch_all_releases(include_nightly=include_nightly)
225+
224226
if latest:
225-
releases = await self.get_releases()
226227
latest_release = next(
227228
(
228229
item
229230
for item in releases
230231
if (tag := item.get("tag_name", ""))
231-
and tag.lower() != self.NIGHTLY_TAG
232232
and not PRERELEASE_TAG_REGEX.search(tag)
233233
),
234234
None,
@@ -242,16 +242,30 @@ async def _resolve_update_target(
242242
return latest_version, latest_release["zipball_url"]
243243

244244
if version_str.lower() == self.NIGHTLY_TAG:
245-
return self._resolve_nightly_target()
245+
nightly_release = next(
246+
(
247+
item
248+
for item in releases
249+
if item.get("tag_name", "").lower() == self.NIGHTLY_TAG
250+
),
251+
None,
252+
)
253+
if nightly_release is not None:
254+
return self.NIGHTLY_TAG, nightly_release["zipball_url"]
255+
return self.NIGHTLY_TAG, (
256+
f"{self.GITHUB_ARCHIVE_BASE}/refs/tags/{self.NIGHTLY_TAG}.zip"
257+
)
246258

247259
if version_str.startswith("v"):
248-
releases = await self.get_releases()
249260
for data in releases:
250261
if data.get("tag_name") == version_str:
251262
return version_str, data["zipball_url"]
252263
raise Exception(f"未找到版本号为 {version_str} 的更新文件。")
253264

254-
return self._resolve_commit_target(version_str)
265+
if version_str:
266+
raise Exception("commit hash 长度不正确,应为 40")
267+
268+
raise Exception("未指定有效的更新目标。")
255269

256270
async def update(self, reboot=False, latest=True, version=None, proxy="") -> None:
257271
if os.environ.get("ASTRBOT_CLI") or os.environ.get("ASTRBOT_LAUNCHER"):

tests/unit/test_updator.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,20 @@ async def mock_download_file(url: str, path: str, *args, **kwargs):
103103
captured["url"] = url
104104
captured["path"] = path
105105

106-
async def mock_fetch_release_info(*args, **kwargs):
107-
raise AssertionError("nightly update should not fetch stable release list")
106+
async def mock_fetch_release_info(url: str):
107+
if url == updator.ASTRBOT_RELEASE_API:
108+
return []
109+
if url == f"{updator.GITHUB_RELEASE_API}/tags/{updator.NIGHTLY_TAG}":
110+
return [
111+
{
112+
"version": "nightly",
113+
"published_at": "2026-03-02T00:00:00Z",
114+
"body": "nightly",
115+
"tag_name": "nightly",
116+
"zipball_url": "https://example.com/nightly.zip",
117+
}
118+
]
119+
raise AssertionError(f"unexpected URL: {url}")
108120

109121
def mock_unzip_file(zip_path: str, target_dir: str):
110122
captured["zip_path"] = zip_path
@@ -119,30 +131,44 @@ def mock_unzip_file(zip_path: str, target_dir: str):
119131

120132
await updator.update(latest=False, version="nightly")
121133

122-
assert captured["url"].endswith("/archive/refs/tags/nightly.zip")
134+
assert captured["url"] == "https://example.com/nightly.zip"
123135
assert captured["path"] == "temp.zip"
124136
assert captured["zip_path"] == "temp.zip"
125137
assert captured["target_dir"] == str(tmp_path)
126138

127139

128-
def test_resolve_nightly_target_uses_archive_base():
140+
@pytest.mark.asyncio
141+
async def test_resolve_update_target_nightly_uses_archive_fallback(monkeypatch):
129142
updator = AstrBotUpdator()
130143
updator.GITHUB_ARCHIVE_BASE = "https://github.com/example-org/example-repo/archive"
131144

132-
target_version, file_url = updator._resolve_nightly_target()
145+
async def mock_fetch_all_releases(*, include_nightly: bool):
146+
_ = include_nightly
147+
return []
148+
149+
monkeypatch.setattr(updator, "_fetch_all_releases", mock_fetch_all_releases)
150+
151+
target_version, file_url = await updator._resolve_update_target(
152+
latest=False,
153+
version="nightly",
154+
)
133155
assert target_version == "nightly"
134156
assert (
135157
file_url
136158
== "https://github.com/example-org/example-repo/archive/refs/tags/nightly.zip"
137159
)
138160

139161

140-
def test_resolve_commit_target_uses_archive_base():
162+
@pytest.mark.asyncio
163+
async def test_resolve_update_target_commit_uses_archive_base():
141164
updator = AstrBotUpdator()
142165
updator.GITHUB_ARCHIVE_BASE = "https://github.com/example-org/example-repo/archive"
143166
version_str = "1234567890123456789012345678901234567890"
144167

145-
target_version, file_url = updator._resolve_commit_target(version_str)
168+
target_version, file_url = await updator._resolve_update_target(
169+
latest=False,
170+
version=version_str,
171+
)
146172
assert target_version == version_str
147173
assert (
148174
file_url

0 commit comments

Comments
 (0)