@@ -861,18 +861,21 @@ async def _reply_result(
861861 # Cold path — load / wake
862862 # ------------------------------------------------------------------
863863
864+ def _is_cuda_oom (self , exc : BaseException ) -> bool :
865+ """Check if exception is a CUDA out-of-memory error."""
866+ msg = str (exc ).lower ()
867+ return "cuda" in msg and ("out of memory" in msg or "oom" in msg )
868+
864869 async def _load_model (
865870 self , model_id : str , api_key : str = "" , device : str = ""
866871 ) -> None :
867- """Load model via ModelManager, falling back to stub on failure/no manager.
872+ """Load model via ModelManager with reactive admission loop.
873+
874+ On CUDA OOM: evicts coldest model, retries. Repeats until success,
875+ no cold models left (_ERR_SERVER_FULL), or non-OOM failure (_ERR_LOAD_FAILED).
868876
869877 model_id is the full routing key (may include ":instance" suffix).
870878 model_id_or_path strips the suffix to fetch the correct weights.
871- device selects which GPU to use (forwarded to manager.load).
872-
873- Stub mode: marks model_id loaded immediately (no real model).
874- T_ENSURE_LOADED waiters get T_MODEL_READY, but T_SUBMIT will return
875- T_ERROR because no backend is registered.
876879 """
877880 loop = asyncio .get_running_loop ()
878881 fs = self ._models .get (model_id )
@@ -882,14 +885,18 @@ async def _load_model(
882885 await self ._wake_model (model_id )
883886 return
884887
885- if self ._manager is not None :
886- # Strip ":instance" suffix to get the actual model weights identifier
887- model_id_or_path = model_id .rsplit (":" , 1 )[0 ]
888+ if self ._manager is None :
889+ # Stub mode: no manager
890+ self ._stub_load (model_id , fs )
891+ return
892+
893+ model_id_or_path = model_id .rsplit (":" , 1 )[0 ]
894+ max_retries = 5 # safety cap — never loop forever
895+
896+ for attempt in range (max_retries ):
888897 logger .info (
889- "MMP: loading '%s' (weights=%s device=%s) via ModelManager" ,
890- model_id ,
891- model_id_or_path ,
892- device or "default" ,
898+ "MMP: loading '%s' (weights=%s device=%s attempt=%d)" ,
899+ model_id , model_id_or_path , device or "default" , attempt + 1 ,
893900 )
894901 try :
895902 await loop .run_in_executor (
@@ -912,20 +919,70 @@ async def _load_model(
912919 if backend is not None and hasattr (backend , "signal_slot" ):
913920 self .register_backend (model_id , backend )
914921 return
915- except Exception :
916- logger .exception (
917- "MMP: manager.load('%s') raised — stub-loading without backend" ,
918- model_id ,
922+ # Loaded but no backend — fall through to stub
923+ break
924+
925+ except Exception as exc :
926+ if not self ._is_cuda_oom (exc ):
927+ # Non-OOM failure — don't retry
928+ logger .exception ("MMP: load '%s' failed (non-OOM)" , model_id )
929+ self ._fail_load (model_id , fs , _ERR_LOAD_FAILED )
930+ return
931+
932+ # CUDA OOM — try to evict a cold model and retry
933+ candidate = self ._pick_eviction_candidate ()
934+ if candidate is None :
935+ logger .error (
936+ "MMP: OOM loading '%s' — all models hot, cannot evict" ,
937+ model_id ,
938+ )
939+ self ._fail_load (model_id , fs , _ERR_SERVER_FULL )
940+ return
941+
942+ logger .warning (
943+ "MMP: OOM loading '%s' — evicting cold model '%s' and retrying" ,
944+ model_id , candidate ,
919945 )
946+ self ._evict_model (candidate )
947+
948+ # Clean up the failed partial load before retrying
949+ try :
950+ self ._manager .unload (model_id )
951+ except Exception :
952+ pass
920953
921- # Stub: mark loaded, flush waiters; no backend registered
954+ # Exhausted retries — should not normally reach here
955+ logger .error ("MMP: load '%s' exhausted %d retries" , model_id , max_retries )
956+ self ._fail_load (model_id , fs , _ERR_LOAD_FAILED )
957+
def _stub_load(self, model_id: str, fs: Optional[ModelState]) -> None:
    """Flag *model_id* as loaded without a real backend and notify its waiters.

    Args:
        model_id: Routing key of the model to mark as loaded.
        fs: Existing state record for the model, or None to fetch (or
            create) the entry in ``self._models``.
    """
    logger.info("MMP: '%s' stub-loaded (no real model)", model_id)
    # Reuse the caller's record when given; otherwise register a fresh one.
    state = fs if fs is not None else self._models.setdefault(model_id, ModelState())
    state.loading, state.loaded = False, True
    self._flush_load_waiters(model_id)
928966
967+ def _fail_load (self , model_id : str , fs : Optional [ModelState ], err_code : int ) -> None :
968+ """Clean up ModelState and notify waiters with T_ERROR on load failure."""
969+ if fs is None :
970+ fs = self ._models .get (model_id )
971+ if fs is not None :
972+ waiters , fs .load_waiters = fs .load_waiters , []
973+ fs .loading = False
974+ fs .loaded = False
975+ for identity , req_id , _ in waiters :
976+ asyncio .create_task (
977+ self ._send (
978+ identity , T_ERROR , struct .pack (">QB" , req_id , err_code )
979+ )
980+ )
981+ # Clean up any partial state
982+ self ._backends .pop (model_id , None )
983+ self ._model_access .pop (model_id , None )
984+ self ._model_request_times .pop (model_id , None )
985+
929986 def _flush_load_waiters (self , model_id : str ) -> None :
930987 """Notify all T_ENSURE_LOADED waiters for this model_id.
931988
0 commit comments