Skip to content

Commit 2634d6c

Browse files
hot/cold eviction, idle timeout, request rate tracking, cache hit/miss
1 parent c6364c2 commit 2634d6c

4 files changed

Lines changed: 114 additions & 27 deletions

File tree

inference_model_manager/inference_model_manager/model_manager_process.py

Lines changed: 84 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
_ERR_BACKEND = 4
8989
_ERR_LOAD_FAILED = 5
9090
_ERR_NOT_LOADED = 6
91+
_ERR_SERVER_FULL = 7
9192

9293

9394
# ---------------------------------------------------------------------------
@@ -252,6 +253,7 @@ def __init__(
252253
evict_threshold: float = 0.9,
253254
evict_check_interval_s: float = 5.0,
254255
monitor_interval_s: float = 5.0,
256+
idle_timeout_s: float = 300.0,
255257
manager: Optional[Any] = None,
256258
decoder: str = "imagecodecs",
257259
batch_max_size: int = 0,
@@ -279,6 +281,7 @@ def __init__(
279281
self._evict_threshold = evict_threshold
280282
self._evict_check_interval_s = evict_check_interval_s
281283
self._monitor_interval_s = monitor_interval_s
284+
self._idle_timeout_s = idle_timeout_s
282285
self._manager = manager
283286
self._own_manager = False # set True only if we created it
284287

@@ -295,9 +298,16 @@ def __init__(
295298
# model_id → BackendLike (registered by register_backend or _load_model)
296299
self._backends: dict[str, BackendLike] = {}
297300

298-
# flavor → monotonic timestamp of last T_SUBMIT (LRU eviction)
301+
# flavor → monotonic timestamp of last T_SUBMIT (LRU eviction + hot/cold)
299302
self._model_access: dict[str, float] = {}
300303

304+
# flavor → list of request timestamps (sliding window for request rate)
305+
self._model_request_times: dict[str, list[float]] = {}
306+
307+
# Cache hit/miss counters (model already loaded vs triggered load)
308+
self._cache_hits: int = 0
309+
self._cache_misses: int = 0
310+
301311
# Latest monitoring snapshot (updated by _monitoring_loop)
302312
self._stats_snapshot: dict[str, Any] = {}
303313

@@ -526,9 +536,12 @@ async def _handle_ensure_loaded(self, identity: bytes, data: list[bytes]) -> Non
526536

527537
fs = self._models.get(model_id)
528538
if fs is not None and fs.loaded:
539+
self._cache_hits += 1
529540
await self._send(identity, T_MODEL_READY, struct.pack(">Q", req_id))
530541
return
531542

543+
self._cache_misses += 1
544+
532545
if fs is None:
533546
fs = ModelState()
534547
self._models[model_id] = fs
@@ -725,12 +738,25 @@ async def _handle_stats(self, identity: bytes, data: list[bytes]) -> None:
725738
snapshot.update(self._manager.stats())
726739
except Exception:
727740
pass
741+
now = time.monotonic()
728742
model_stats: dict[str, Any] = {}
729743
for f, fs in self._models.items():
744+
last_access = self._model_access.get(f)
745+
idle_s = (now - last_access) if last_access else None
746+
req_times = self._model_request_times.get(f, [])
747+
# Trim stale entries for accurate rate
748+
cutoff = now - 60.0
749+
while req_times and req_times[0] < cutoff:
750+
req_times.pop(0)
751+
730752
entry: dict[str, Any] = {
731753
"loaded": fs.loaded,
732754
"sleeping": fs.sleeping,
733755
"loading": fs.loading,
756+
"last_access_ts": last_access,
757+
"idle_s": round(idle_s, 1) if idle_s is not None else None,
758+
"is_cold": idle_s is not None and idle_s > self._idle_timeout_s,
759+
"request_rate_60s": len(req_times),
734760
}
735761
backend = self._backends.get(f)
736762
if backend is not None:
@@ -747,6 +773,9 @@ async def _handle_stats(self, identity: bytes, data: list[bytes]) -> None:
747773
"mmp_free_slots": self._pool.free_count if self._pool else 0,
748774
"mmp_total_slots": self._n_slots,
749775
"mmp_pending": len(self._pending),
776+
"mmp_cache_hits": self._cache_hits,
777+
"mmp_cache_misses": self._cache_misses,
778+
"mmp_idle_timeout_s": self._idle_timeout_s,
750779
"mmp_models": model_stats,
751780
}
752781
)
@@ -764,7 +793,15 @@ async def _handle_stats(self, identity: bytes, data: list[bytes]) -> None:
764793
def _forward_to_backend(
765794
self, model_id: str, slot_id: int, req_id: int, params_bytes: bytes = b"{}"
766795
) -> None:
767-
self._model_access[model_id] = time.monotonic() # LRU update
796+
now = time.monotonic()
797+
self._model_access[model_id] = now
798+
# Sliding window: append timestamp, trim old entries
799+
times = self._model_request_times.setdefault(model_id, [])
800+
times.append(now)
801+
# Keep only last 60s of timestamps
802+
cutoff = now - 60.0
803+
while times and times[0] < cutoff:
804+
times.pop(0)
768805
backend = self._backends.get(model_id)
769806
if backend is None:
770807
logger.warning("MMP: no backend for '%s', req_id=%d", model_id, req_id)
@@ -964,7 +1001,11 @@ async def _eviction_loop(self) -> None:
9641001
logger.exception("MMP: error in eviction loop")
9651002

9661003
def _check_and_evict(self) -> None:
967-
"""Evict LRU model if GPU memory exceeds threshold.
1004+
"""Evict cold models if GPU memory exceeds threshold.
1005+
1006+
Cold-first: models idle > idle_timeout_s are evicted first (oldest first).
1007+
If all models are hot (recent traffic), no eviction — _ERR_SERVER_FULL
1008+
will be returned by _load_model when it can't free space.
9681009
9691010
Uses drain_and_unload (graceful: stop accepting, finish in-flight, then kill).
9701011
Next request for evicted model triggers _load_model which reloads from disk cache.
@@ -974,28 +1015,45 @@ def _check_and_evict(self) -> None:
9741015
gpu_frac = _gpu_used_fraction()
9751016
if gpu_frac < self._evict_threshold:
9761017
return
977-
lru = self._lru_evictable_model()
978-
if lru is None:
1018+
candidate = self._pick_eviction_candidate()
1019+
if candidate is None:
1020+
logger.warning(
1021+
"MMP: GPU %.0f%% > %.0f%% threshold but all models are hot — no eviction",
1022+
gpu_frac * 100,
1023+
self._evict_threshold * 100,
1024+
)
9791025
return
9801026
logger.warning(
981-
"MMP: GPU %.0f%% > %.0f%% threshold — evicting '%s'",
1027+
"MMP: GPU %.0f%% > %.0f%% threshold — evicting cold model '%s'",
9821028
gpu_frac * 100,
9831029
self._evict_threshold * 100,
984-
lru,
1030+
candidate,
9851031
)
1032+
self._evict_model(candidate)
1033+
1034+
def _evict_model(self, model_id: str) -> bool:
1035+
"""Evict a single model. Returns True on success."""
9861036
try:
987-
self._manager.unload(lru, drain=True)
988-
fs = self._models.get(lru)
1037+
self._manager.unload(model_id, drain=True)
1038+
fs = self._models.get(model_id)
9891039
if fs:
9901040
fs.loaded = False
9911041
fs.loading = False
9921042
fs.sleeping = False
993-
self._backends.pop(lru, None)
1043+
self._backends.pop(model_id, None)
1044+
self._model_access.pop(model_id, None)
1045+
self._model_request_times.pop(model_id, None)
1046+
return True
9941047
except Exception:
995-
logger.warning("MMP: eviction of '%s' failed", lru, exc_info=True)
1048+
logger.warning("MMP: eviction of '%s' failed", model_id, exc_info=True)
1049+
return False
9961050

997-
def _lru_evictable_model(self) -> Optional[str]:
998-
"""LRU flavor that is loaded, not sleeping, and not currently in-flight."""
1051+
def _pick_eviction_candidate(self) -> Optional[str]:
1052+
"""Pick best model to evict. Cold models first, then LRU among warm.
1053+
1054+
Returns None if no evictable model (all hot and in-flight, or nothing loaded).
1055+
"""
1056+
now = time.monotonic()
9991057
in_flight = {mid for _, _, mid in self._pending.values()}
10001058
candidates = [
10011059
f
@@ -1004,7 +1062,19 @@ def _lru_evictable_model(self) -> Optional[str]:
10041062
]
10051063
if not candidates:
10061064
return None
1007-
return min(candidates, key=lambda f: self._model_access.get(f, 0.0))
1065+
1066+
# Partition into cold (idle > timeout) and hot
1067+
cold = [
1068+
f for f in candidates
1069+
if (now - self._model_access.get(f, 0.0)) > self._idle_timeout_s
1070+
]
1071+
1072+
if cold:
1073+
# Among cold models, evict the one with oldest last access (most stale)
1074+
return min(cold, key=lambda f: self._model_access.get(f, 0.0))
1075+
1076+
# All candidates are hot — return None (caller decides: error or force-evict)
1077+
return None
10081078

10091079
# ------------------------------------------------------------------
10101080
# Monitoring loop

inference_model_manager/tests/integration_tests/backends/test_model_manager_process_cold_path.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -444,37 +444,50 @@ def _make_mmp(self) -> ModelManagerProcess:
444444
input_mb=_TEST_INPUT_MB,
445445
)
446446

447-
def test_lru_evictable_returns_none_when_no_loaded_models(self):
447+
def test_eviction_returns_none_when_no_loaded_models(self):
448448
mmp = self._make_mmp()
449-
assert mmp._lru_evictable_model() is None
449+
assert mmp._pick_eviction_candidate() is None
450450

451-
def test_lru_evictable_picks_oldest_access(self):
451+
def test_eviction_picks_coldest_model(self):
452452
mmp = self._make_mmp()
453453
from inference_model_manager.model_manager_process import ModelState
454454
mmp._models["a"] = ModelState(loaded=True)
455455
mmp._models["b"] = ModelState(loaded=True)
456-
mmp._model_access["a"] = 1.0 # older
457-
mmp._model_access["b"] = 2.0 # newer
458-
assert mmp._lru_evictable_model() == "a"
456+
# Both cold (access time far in the past vs idle_timeout_s=300)
457+
mmp._model_access["a"] = 1.0 # older = colder
458+
mmp._model_access["b"] = 2.0
459+
assert mmp._pick_eviction_candidate() == "a"
459460

460-
def test_lru_evictable_skips_sleeping(self):
461+
def test_eviction_returns_none_when_all_hot(self):
462+
mmp = self._make_mmp()
463+
from inference_model_manager.model_manager_process import ModelState
464+
import time
465+
now = time.monotonic()
466+
mmp._models["a"] = ModelState(loaded=True)
467+
mmp._models["b"] = ModelState(loaded=True)
468+
# Both hot (accessed just now, well within idle_timeout_s=300)
469+
mmp._model_access["a"] = now
470+
mmp._model_access["b"] = now
471+
assert mmp._pick_eviction_candidate() is None
472+
473+
def test_eviction_skips_sleeping(self):
461474
mmp = self._make_mmp()
462475
from inference_model_manager.model_manager_process import ModelState
463476
mmp._models["a"] = ModelState(loaded=False, sleeping=True)
464477
mmp._models["b"] = ModelState(loaded=True)
465478
mmp._model_access["a"] = 1.0
466-
mmp._model_access["b"] = 2.0
467-
assert mmp._lru_evictable_model() == "b"
479+
mmp._model_access["b"] = 1.0 # cold
480+
assert mmp._pick_eviction_candidate() == "b"
468481

469-
def test_lru_evictable_skips_in_flight(self):
482+
def test_eviction_skips_in_flight(self):
470483
mmp = self._make_mmp()
471484
from inference_model_manager.model_manager_process import ModelState
472485
mmp._models["a"] = ModelState(loaded=True)
473486
mmp._models["b"] = ModelState(loaded=True)
474-
mmp._model_access["a"] = 1.0
475-
mmp._model_access["b"] = 2.0
487+
mmp._model_access["a"] = 1.0 # cold
488+
mmp._model_access["b"] = 1.0 # cold
476489
mmp._pending[42] = (b"id", 0, "a") # "a" is in-flight
477-
assert mmp._lru_evictable_model() == "b"
490+
assert mmp._pick_eviction_candidate() == "b"
478491

479492
def test_check_and_evict_calls_manager_unload_drain(self):
480493
mgr = _MockManager()

inference_server/inference_server/launcher.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def launch_orchestrated(
118118
decoder: str = "imagecodecs",
119119
batch_max_size: int = 0,
120120
batch_max_wait_ms: float = 5.0,
121+
idle_timeout_s: float = 300.0,
121122
) -> LaunchHandle:
122123
"""Start a ModelManagerProcess and return a LaunchHandle.
123124
@@ -150,6 +151,7 @@ def launch_orchestrated(
150151
decoder=decoder,
151152
batch_max_size=batch_max_size,
152153
batch_max_wait_ms=batch_max_wait_ms,
154+
idle_timeout_s=idle_timeout_s,
153155
)
154156

155157
ready = threading.Event()

inference_server/inference_server/server.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ def main() -> None:
184184
decoder = os.environ.get("INFERENCE_DECODER", "imagecodecs")
185185
batch_max_size = int(os.environ.get("INFERENCE_BATCH_MAX_SIZE", "0"))
186186
batch_max_wait = float(os.environ.get("INFERENCE_BATCH_MAX_WAIT_MS", "5.0"))
187+
idle_timeout = float(os.environ.get("INFERENCE_MODEL_IDLE_TIMEOUT_S", "300.0"))
187188

188189
# ── HTTP / TLS config ──────────────────────────────────────────────────
189190
host = os.environ.get("HOST", "0.0.0.0")
@@ -209,6 +210,7 @@ def main() -> None:
209210
decoder=decoder,
210211
batch_max_size=batch_max_size,
211212
batch_max_wait_ms=batch_max_wait,
213+
idle_timeout_s=idle_timeout,
212214
)
213215
logger.info("MMP ready: addr=%s shm=%s", handle.mmp_addr, handle.shm_name)
214216

0 commit comments

Comments
 (0)