Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 84 additions & 1 deletion clarifai/runners/models/vllm_openai_class.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,65 @@
import re
import threading
from typing import Iterator
import time
from typing import Iterator, Optional

import httpx
from clarifai_protocol import get_item_id, register_item_abort_callback

from clarifai.runners.models.openai_class import OpenAIModelClass
from clarifai.utils.logging import logger


class VLLMMetricsPoller:
"""Polls vLLM /metrics in background; caches kv_cache usage.

Fail-open: if the poller has never succeeded or is stale, admission is allowed.
"""

KV_CACHE_HIGH = 0.8
KV_CACHE_LOW = 0.5
STALE_AFTER_SECONDS = 5.0
Comment on lines +19 to +21

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because these are class constants we can't modify them at instance level. Do you think we should put these under init so can be modified per instance?


def __init__(self, base_url: str, poll_interval: float = 5.0):
self.base_url = base_url
self.poll_interval = poll_interval
self._kv_cache = 0.0
self._lock = threading.Lock()
self._last_success = time.time()

threading.Thread(target=self._poll_loop, daemon=True).start()
logger.info(
f"[VLLMMetricsPoller] Started polling {base_url}/metrics every {poll_interval}s"
)

def _poll_loop(self):
while True:
try:
resp = httpx.get(f"{self.base_url}/metrics", timeout=1.0)
if resp.status_code == 200:
kv_cache = self._parse(
resp.text, r'vllm:kv_cache_usage_perc\{[^}]*\}\s+([\d.]+)'
)
with self._lock:
self._kv_cache = kv_cache
self._last_success = time.time()
logger.info(f"[VLLMMetricsPoller] kv_cache={kv_cache:.2%}")
except Exception as e:
logger.warning(f"[VLLMMetricsPoller] Poll failed: {e}")
time.sleep(self.poll_interval)

def _parse(self, text: str, pattern: str) -> float:
m = re.search(pattern, text)
return float(m.group(1)) if m else 0.0

def snapshot(self) -> float:
with self._lock:
return self._kv_cache

@property
def is_stale(self) -> bool:
with self._lock:
return time.time() - self._last_success > self.STALE_AFTER_SECONDS


class VLLMCancellationHandler:
Expand Down Expand Up @@ -91,6 +146,34 @@ def generate(self, prompt, ...) -> Iterator[str]:

server = None
cancellation_handler = None
_metrics_poller: Optional[VLLMMetricsPoller] = None

@property
def admission_increase_delay(self) -> float:
return 0.0

@property
def admission_decrease_delay(self) -> float:
return 0.0

def check_admission(self):
"""Three-state deadband on vLLM kv_cache usage.

Returns False above HIGH (AIMD shrinks), True below LOW (AIMD grows),
None in-between (AIMD holds). Fails open when the subclass has not
initialized ``self._metrics_poller`` or when the poller is stale.
"""
poller = self._metrics_poller
if poller is None or poller.is_stale:
return True

kv_cache = poller.snapshot()
if kv_cache > poller.KV_CACHE_HIGH:
logger.info(f"[AdmissionControl] REJECT kv_cache={kv_cache:.2%}")
return False
if kv_cache < poller.KV_CACHE_LOW:
return True
return None

def handle_liveness_probe(self) -> bool:
if self.server is None:
Expand Down
Loading