Skip to content

Commit ff9e20f

Browse files
committed
BUG: worker dead with xoscar v0.8.0
1 parent 4e0de86 commit ff9e20f

1 file changed

Lines changed: 69 additions & 46 deletions

File tree

xinference/core/worker.py

Lines changed: 69 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -51,6 +51,7 @@
5151
XINFERENCE_DISABLE_METRICS,
5252
XINFERENCE_ENABLE_VIRTUAL_ENV,
5353
XINFERENCE_HEALTH_CHECK_INTERVAL,
54+
XINFERENCE_HEALTH_CHECK_TIMEOUT,
5455
XINFERENCE_VIRTUAL_ENV_DIR,
5556
XINFERENCE_VIRTUAL_ENV_SKIP_INSTALLED,
5657
)
@@ -191,6 +192,14 @@ def __init__(
191192

192193
self._lock = asyncio.Lock()
193194

195+
async def _reset_supervisor_refs(self):
196+
async with self._lock:
197+
self._supervisor_ref = None
198+
self._status_guard_ref = None
199+
self._event_collector_ref = None
200+
self._cache_tracker_ref = None
201+
self._progress_tracker_ref = None
202+
194203
async def recover_sub_pool(self, address):
195204
logger.warning("Process %s is down.", address)
196205
# Xoscar does not remove the address from sub_processes.
@@ -437,52 +446,53 @@ async def get_supervisor_ref(self, add_worker: bool = True) -> xo.ActorRefType:
437446
"""
438447
from .supervisor import SupervisorActor
439448

440-
if self._supervisor_ref is not None:
441-
return self._supervisor_ref
442-
supervisor_ref = await xo.actor_ref( # type: ignore
443-
address=self._supervisor_address, uid=SupervisorActor.default_uid()
444-
)
445-
# Prevent concurrent operations leads to double initialization, check again.
446-
if self._supervisor_ref is not None:
449+
async with self._lock:
450+
if self._supervisor_ref is not None:
451+
return self._supervisor_ref
452+
supervisor_ref = await xo.actor_ref( # type: ignore
453+
address=self._supervisor_address, uid=SupervisorActor.default_uid()
454+
)
455+
# Prevent concurrent operations leads to double initialization, check again.
456+
if self._supervisor_ref is not None:
457+
return self._supervisor_ref
458+
self._supervisor_ref = supervisor_ref
459+
if add_worker and len(self._model_uid_to_model) == 0:
460+
# Newly started (or restarted), has no model, notify supervisor
461+
await self._supervisor_ref.add_worker(self.address)
462+
logger.info("Connected to supervisor as a fresh worker")
463+
464+
self._status_guard_ref = await xo.actor_ref(
465+
address=self._supervisor_address, uid=StatusGuardActor.default_uid()
466+
)
467+
self._event_collector_ref = await xo.actor_ref(
468+
address=self._supervisor_address, uid=EventCollectorActor.default_uid()
469+
)
470+
self._cache_tracker_ref = await xo.actor_ref(
471+
address=self._supervisor_address, uid=CacheTrackerActor.default_uid()
472+
)
473+
self._progress_tracker_ref = None
474+
# cache_tracker is on supervisor
475+
from ..model.audio import get_audio_model_descriptions
476+
from ..model.embedding import get_embedding_model_descriptions
477+
from ..model.flexible import get_flexible_model_descriptions
478+
from ..model.image import get_image_model_descriptions
479+
from ..model.llm import get_llm_version_infos
480+
from ..model.rerank import get_rerank_model_descriptions
481+
from ..model.video import get_video_model_descriptions
482+
483+
# record model version
484+
model_version_infos: Dict[str, List[Dict]] = {} # type: ignore
485+
model_version_infos.update(get_llm_version_infos())
486+
model_version_infos.update(get_embedding_model_descriptions())
487+
model_version_infos.update(get_rerank_model_descriptions())
488+
model_version_infos.update(get_image_model_descriptions())
489+
model_version_infos.update(get_audio_model_descriptions())
490+
model_version_infos.update(get_video_model_descriptions())
491+
model_version_infos.update(get_flexible_model_descriptions())
492+
await self._cache_tracker_ref.record_model_version(
493+
model_version_infos, self.address
494+
)
447495
return self._supervisor_ref
448-
self._supervisor_ref = supervisor_ref
449-
if add_worker and len(self._model_uid_to_model) == 0:
450-
# Newly started (or restarted), has no model, notify supervisor
451-
await self._supervisor_ref.add_worker(self.address)
452-
logger.info("Connected to supervisor as a fresh worker")
453-
454-
self._status_guard_ref = await xo.actor_ref(
455-
address=self._supervisor_address, uid=StatusGuardActor.default_uid()
456-
)
457-
self._event_collector_ref = await xo.actor_ref(
458-
address=self._supervisor_address, uid=EventCollectorActor.default_uid()
459-
)
460-
self._cache_tracker_ref = await xo.actor_ref(
461-
address=self._supervisor_address, uid=CacheTrackerActor.default_uid()
462-
)
463-
self._progress_tracker_ref = None
464-
# cache_tracker is on supervisor
465-
from ..model.audio import get_audio_model_descriptions
466-
from ..model.embedding import get_embedding_model_descriptions
467-
from ..model.flexible import get_flexible_model_descriptions
468-
from ..model.image import get_image_model_descriptions
469-
from ..model.llm import get_llm_version_infos
470-
from ..model.rerank import get_rerank_model_descriptions
471-
from ..model.video import get_video_model_descriptions
472-
473-
# record model version
474-
model_version_infos: Dict[str, List[Dict]] = {} # type: ignore
475-
model_version_infos.update(get_llm_version_infos())
476-
model_version_infos.update(get_embedding_model_descriptions())
477-
model_version_infos.update(get_rerank_model_descriptions())
478-
model_version_infos.update(get_image_model_descriptions())
479-
model_version_infos.update(get_audio_model_descriptions())
480-
model_version_infos.update(get_video_model_descriptions())
481-
model_version_infos.update(get_flexible_model_descriptions())
482-
await self._cache_tracker_ref.record_model_version(
483-
model_version_infos, self.address
484-
)
485-
return self._supervisor_ref
486496

487497
@staticmethod
488498
def get_devices_count():
@@ -1837,7 +1847,20 @@ async def report_status(self):
18371847
except Exception:
18381848
logger.exception("Report status got error.")
18391849
supervisor_ref = await self.get_supervisor_ref()
1840-
await supervisor_ref.report_worker_status(self.address, status)
1850+
try:
1851+
await asyncio.wait_for(
1852+
supervisor_ref.report_worker_status(self.address, status),
1853+
timeout=XINFERENCE_HEALTH_CHECK_TIMEOUT,
1854+
)
1855+
except asyncio.TimeoutError:
1856+
logger.warning(
1857+
"report_worker_status timed out, will reset supervisor refs for retry"
1858+
)
1859+
await self._reset_supervisor_refs()
1860+
raise
1861+
except Exception:
1862+
await self._reset_supervisor_refs()
1863+
raise
18411864

18421865
async def _periodical_report_status(self):
18431866
while True:

0 commit comments

Comments
 (0)