1- import asyncio
21import os
32import sys
43import time
1211from astrbot .core .utils .astrbot_path import get_astrbot_path
1312from astrbot .core .utils .io import download_file
1413
15- from .zip_updator import ReleaseInfo , RepoZipUpdator
14+ from .zip_updator import FetchReleaseError , ReleaseInfo , RepoZipUpdator
1615
1716
1817class AstrBotUpdator (RepoZipUpdator ):
@@ -154,25 +153,23 @@ async def get_releases(self, latest: bool = True) -> list:
154153 @staticmethod
155154 def _is_expected_nightly_fetch_error (exc : BaseException ) -> bool :
156155 expected_types = (
157- asyncio . TimeoutError ,
156+ TimeoutError ,
158157 aiohttp .ClientError ,
159158 JSONDecodeError ,
159+ FetchReleaseError ,
160160 )
161-
162- def _matches (error : BaseException | None ) -> bool :
163- if error is None :
164- return False
165- if isinstance (error , expected_types ):
166- return True
167- if str (error ).startswith ("请求失败,状态码:" ):
161+ current : BaseException | None = exc
162+ seen : set [int ] = set ()
163+ while current is not None and id (current ) not in seen :
164+ seen .add (id (current ))
165+ if isinstance (current , expected_types ):
168166 return True
169- return False
170-
171- return (
172- _matches (exc )
173- or _matches (getattr (exc , "__cause__" , None ))
174- or _matches (getattr (exc , "__context__" , None ))
175- )
167+ current = getattr (current , "__cause__" , None ) or getattr (
168+ current ,
169+ "__context__" ,
170+ None ,
171+ )
172+ return False
176173
177174 async def get_nightly_release (self ) -> dict | None :
178175 nightly_release_url = f"{ self .GITHUB_RELEASE_API } /tags/{ self .NIGHTLY_TAG } "
@@ -204,35 +201,43 @@ async def get_releases_with_nightly(self, latest: bool = True) -> list:
204201 releases .insert (0 , nightly_release )
205202 return releases
206203
207- def _resolve_latest (self , releases : list ) -> tuple [str , str ]:
208- latest_release = next (
209- (
210- item
211- for item in releases
212- if item .get ("tag_name" , "" ).lower () != self .NIGHTLY_TAG
213- ),
214- None ,
215- )
216- if latest_release is None :
217- raise Exception ("未找到可用的发布版本。" )
204+ async def _resolve_update_target (
205+ self ,
206+ latest : bool ,
207+ version : str | None ,
208+ ) -> tuple [str , str ]:
209+ version_str = str (version ) if version is not None else ""
210+
211+ if latest :
212+ releases = await self .get_releases (latest = True )
213+ latest_release = next (
214+ (
215+ item
216+ for item in releases
217+ if item .get ("tag_name" , "" ).lower () != self .NIGHTLY_TAG
218+ ),
219+ None ,
220+ )
221+ if latest_release is None :
222+ raise Exception ("未找到可用的发布版本。" )
218223
219- latest_version = latest_release ["tag_name" ]
220- if self .compare_version (VERSION , latest_version ) >= 0 :
221- raise Exception ("当前已经是最新版本。" )
222- return latest_version , latest_release ["zipball_url" ]
224+ latest_version = latest_release ["tag_name" ]
225+ if self .compare_version (VERSION , latest_version ) >= 0 :
226+ raise Exception ("当前已经是最新版本。" )
227+ return latest_version , latest_release ["zipball_url" ]
223228
224- def _resolve_nightly ( self ) -> tuple [ str , str ] :
225- return self .NIGHTLY_TAG , (
226- f"https://github.com/AstrBotDevs/AstrBot/archive/refs/tags/{ self .NIGHTLY_TAG } .zip"
227- )
229+ if version_str . lower () == self . NIGHTLY_TAG :
230+ return self .NIGHTLY_TAG , (
231+ f"https://github.com/AstrBotDevs/AstrBot/archive/refs/tags/{ self .NIGHTLY_TAG } .zip"
232+ )
228233
229- def _resolve_tag (self , releases : list , version_str : str ) -> tuple [str , str ]:
230- for data in releases :
231- if data ["tag_name" ] == version_str :
232- return version_str , data ["zipball_url" ]
233- raise Exception (f"未找到版本号为 { version_str } 的更新文件。" )
234+ if version_str .startswith ("v" ):
235+ releases = await self .get_releases (latest = False )
236+ for data in releases :
237+ if data ["tag_name" ] == version_str :
238+ return version_str , data ["zipball_url" ]
239+ raise Exception (f"未找到版本号为 { version_str } 的更新文件。" )
234240
235- def _resolve_commit (self , version_str : str ) -> tuple [str , str ]:
236241 if len (version_str ) != 40 :
237242 raise Exception ("commit hash 长度不正确,应为 40" )
238243 return (
@@ -246,17 +251,7 @@ async def update(self, reboot=False, latest=True, version=None, proxy="") -> Non
246251 "Error: You are running AstrBot via CLI, please use `pip` or `uv tool upgrade` to update AstrBot."
247252 ) # 避免版本管理混乱
248253
249- version_str = str (version ) if version is not None else ""
250- if latest :
251- releases = await self .get_releases (latest = True )
252- target_version , file_url = self ._resolve_latest (releases )
253- elif version_str .lower () == self .NIGHTLY_TAG :
254- target_version , file_url = self ._resolve_nightly ()
255- elif version_str .startswith ("v" ):
256- releases = await self .get_releases (latest = False )
257- target_version , file_url = self ._resolve_tag (releases , version_str )
258- else :
259- target_version , file_url = self ._resolve_commit (version_str )
254+ target_version , file_url = await self ._resolve_update_target (latest , version )
260255
261256 logger .info (f"准备更新至 AstrBot Core: { target_version } " )
262257
0 commit comments