Skip to content

Commit b0de01d

Browse files
reactive admission loop — OOM evict-retry, proper load failure handling
1 parent 5d5ed1e commit b0de01d

2 files changed

Lines changed: 78 additions & 21 deletions

File tree

inference_model_manager/inference_model_manager/model_manager_process.py

Lines changed: 75 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -861,18 +861,21 @@ async def _reply_result(
861861
# Cold path — load / wake
862862
# ------------------------------------------------------------------
863863

864+
def _is_cuda_oom(self, exc: BaseException) -> bool:
865+
"""Check if exception is a CUDA out-of-memory error."""
866+
msg = str(exc).lower()
867+
return "cuda" in msg and ("out of memory" in msg or "oom" in msg)
868+
864869
async def _load_model(
865870
self, model_id: str, api_key: str = "", device: str = ""
866871
) -> None:
867-
"""Load model via ModelManager, falling back to stub on failure/no manager.
872+
"""Load model via ModelManager with reactive admission loop.
873+
874+
On CUDA OOM: evicts coldest model, retries. Repeats until success,
875+
no cold models left (_ERR_SERVER_FULL), or non-OOM failure (_ERR_LOAD_FAILED).
868876
869877
model_id is the full routing key (may include ":instance" suffix).
870878
model_id_or_path strips the suffix to fetch the correct weights.
871-
device selects which GPU to use (forwarded to manager.load).
872-
873-
Stub mode: marks model_id loaded immediately (no real model).
874-
T_ENSURE_LOADED waiters get T_MODEL_READY, but T_SUBMIT will return
875-
T_ERROR because no backend is registered.
876879
"""
877880
loop = asyncio.get_running_loop()
878881
fs = self._models.get(model_id)
@@ -882,14 +885,18 @@ async def _load_model(
882885
await self._wake_model(model_id)
883886
return
884887

885-
if self._manager is not None:
886-
# Strip ":instance" suffix to get the actual model weights identifier
887-
model_id_or_path = model_id.rsplit(":", 1)[0]
888+
if self._manager is None:
889+
# Stub mode: no manager
890+
self._stub_load(model_id, fs)
891+
return
892+
893+
model_id_or_path = model_id.rsplit(":", 1)[0]
894+
max_retries = 5 # safety cap — never loop forever
895+
896+
for attempt in range(max_retries):
888897
logger.info(
889-
"MMP: loading '%s' (weights=%s device=%s) via ModelManager",
890-
model_id,
891-
model_id_or_path,
892-
device or "default",
898+
"MMP: loading '%s' (weights=%s device=%s attempt=%d)",
899+
model_id, model_id_or_path, device or "default", attempt + 1,
893900
)
894901
try:
895902
await loop.run_in_executor(
@@ -912,20 +919,70 @@ async def _load_model(
912919
if backend is not None and hasattr(backend, "signal_slot"):
913920
self.register_backend(model_id, backend)
914921
return
915-
except Exception:
916-
logger.exception(
917-
"MMP: manager.load('%s') raised — stub-loading without backend",
918-
model_id,
922+
# Loaded but no backend — fall through to stub
923+
break
924+
925+
except Exception as exc:
926+
if not self._is_cuda_oom(exc):
927+
# Non-OOM failure — don't retry
928+
logger.exception("MMP: load '%s' failed (non-OOM)", model_id)
929+
self._fail_load(model_id, fs, _ERR_LOAD_FAILED)
930+
return
931+
932+
# CUDA OOM — try to evict a cold model and retry
933+
candidate = self._pick_eviction_candidate()
934+
if candidate is None:
935+
logger.error(
936+
"MMP: OOM loading '%s' — all models hot, cannot evict",
937+
model_id,
938+
)
939+
self._fail_load(model_id, fs, _ERR_SERVER_FULL)
940+
return
941+
942+
logger.warning(
943+
"MMP: OOM loading '%s' — evicting cold model '%s' and retrying",
944+
model_id, candidate,
919945
)
946+
self._evict_model(candidate)
947+
948+
# Clean up the failed partial load before retrying
949+
try:
950+
self._manager.unload(model_id)
951+
except Exception:
952+
pass
920953

921-
# Stub: mark loaded, flush waiters; no backend registered
954+
# Exhausted retries — should not normally reach here
955+
logger.error("MMP: load '%s' exhausted %d retries", model_id, max_retries)
956+
self._fail_load(model_id, fs, _ERR_LOAD_FAILED)
957+
958+
def _stub_load(self, model_id: str, fs: Optional[ModelState]) -> None:
    """Record *model_id* as loaded without registering any backend.

    Args:
        model_id: Routing key of the model being stub-loaded.
        fs: Existing state entry for the model, or None to look up /
            create one in ``self._models``.

    The state's ``loading`` flag is cleared, ``loaded`` is set, and
    ``_flush_load_waiters`` is then invoked for the model.
    """
    logger.info("MMP: '%s' stub-loaded (no real model)", model_id)
    # Only build a fresh ModelState when the caller did not hand one in.
    state = fs if fs is not None else self._models.setdefault(model_id, ModelState())
    state.loading = False
    state.loaded = True
    self._flush_load_waiters(model_id)
928966

967+
def _fail_load(self, model_id: str, fs: Optional[ModelState], err_code: int) -> None:
    """Reset a model's load state after a failed load and report the error.

    Args:
        model_id: Routing key of the model whose load failed.
        fs: State entry for the model; when None it is looked up in
            ``self._models`` (and may legitimately be absent).
        err_code: Error code packed into each T_ERROR reply as one
            unsigned byte.

    Every queued load waiter receives a T_ERROR message whose payload is
    ``struct.pack(">QB", req_id, err_code)``. Per-model entries in the
    backend, access and request-time maps are removed afterwards,
    whether or not a state entry was found.
    """
    state = self._models.get(model_id) if fs is None else fs
    if state is not None:
        # Detach the waiter list first so the state no longer references it.
        pending = state.load_waiters
        state.load_waiters = []
        state.loading = False
        state.loaded = False
        for identity, req_id, _ in pending:
            payload = struct.pack(">QB", req_id, err_code)
            # NOTE(review): the task handle is not retained here; confirm
            # that fire-and-forget sends are acceptable for error replies.
            asyncio.create_task(self._send(identity, T_ERROR, payload))
    # Drop any partial bookkeeping left behind by the failed load.
    self._backends.pop(model_id, None)
    self._model_access.pop(model_id, None)
    self._model_request_times.pop(model_id, None)
985+
929986
def _flush_load_waiters(self, model_id: str) -> None:
930987
"""Notify all T_ENSURE_LOADED waiters for this model_id.
931988

inference_model_manager/tests/integration_tests/backends/test_model_manager_process_cold_path.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -403,13 +403,13 @@ def test_ensure_loaded_triggers_manager_load(self):
403403
assert msg == T_MODEL_READY
404404
assert any(c.startswith("load:yolov8n") for c in mgr.calls)
405405

406-
def test_ensure_loaded_unknown_falls_back_to_stub(self):
407-
"""Unknown flavor (not in mock manager) → stub load → T_MODEL_READY."""
406+
def test_ensure_loaded_unknown_returns_error(self):
    """A model the mock manager does not know must yield T_ERROR.

    The manager is created with no models registered, so the load
    triggered by ensure_loaded() is expected to fail rather than fall
    back to a stub load.
    """
    # Empty mock manager: nothing is registered with it.
    harness = _MMPHarness(manager=_MockManager())
    reply = harness.ensure_loaded("no-such-model")
    harness.teardown()
    assert reply == T_ERROR
413413

414414
def test_end_to_end_with_mock_manager(self):
415415
"""Full lifecycle: load → alloc → submit → result → free."""

0 commit comments

Comments
 (0)