Skip to content

Commit 39ab72e

Browse files
committed
add thread safety to cache operations and implement LRU eviction
1 parent 79407c4 commit 39ab72e

File tree

2 files changed

+30
-9
lines changed

2 files changed

+30
-9
lines changed

ajet/tuner_lib/experimental/as_oai_model_server.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
from loguru import logger
2626
from pydantic import BaseModel
27+
from functools import lru_cache
2728
from fastapi import FastAPI, Header, HTTPException, Request
2829
from fastapi.responses import StreamingResponse
2930
from contextlib import asynccontextmanager
@@ -63,6 +64,9 @@ class HealthCheckRequest(BaseModel):
6364
context = zmq.Context()
6465
atexit.register(context.term)
6566

67+
@lru_cache(maxsize=128)
def ep_key(episode_uuid: str) -> str:
    """Build the shared-memory dict key under which an episode's state is stored."""
    return "episodes-" + episode_uuid
6670

6771
def get_app(max_fastapi_threads: int = 512, enable_swarm_mode=False, shared_mem_dict=None, shared_mem_dict_lock=None) -> Tuple[FastAPI, Optional[Coroutine]]:
6872

@@ -100,6 +104,14 @@ def _begin_handle_chat_completion(episode_address, int_req: InterchangeCompletio
100104

101105
result_str = ""
102106
for _ in range(50): # max 5 minutes wait
107+
108+
if enable_swarm_mode:
109+
assert shared_mem_dict is not None
110+
ep_stat = shared_mem_dict[ep_key(episode_uuid)]
111+
episode_status = ep_stat.episode_status
112+
if episode_status != "claimed":
113+
raise HTTPException(status_code=404, detail="The episode is not claimed, cannot accept new requests.")
114+
103115
try:
104116
if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string begin.")
105117

ajet/utils/tokenizer.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import copy
22
import json
3+
import threading
34
from typing import Dict, List
45

56

@@ -21,6 +22,7 @@ def cleanup_messages(messages: List[Dict]) -> List[Dict]:
2122

2223
# Cache storage
2324
_cache = {}
25+
_cache_lock = threading.Lock()
2426

2527

2628
def ajet_apply_chat_template(
@@ -41,11 +43,12 @@ def ajet_apply_chat_template(
4143
tokenize,
4244
)
4345

44-
# Check cache
45-
if cache_key in _cache:
46-
return _cache[cache_key]
46+
# Check cache with thread safety
47+
with _cache_lock:
48+
if cache_key in _cache:
49+
return _cache[cache_key]
4750

48-
# Compute result
51+
# Compute result (time consuming) - outside lock to avoid blocking other threads
4952
if tools:
5053
result = tokenizer.apply_chat_template(
5154
conversation,
@@ -60,10 +63,16 @@ def ajet_apply_chat_template(
6063
add_generation_prompt=add_generation_prompt,
6164
)
6265

63-
# Store in cache (implement LRU eviction if cache gets too large)
64-
if len(_cache) >= 1024:
65-
# Remove oldest item (first inserted)
66-
_cache.pop(next(iter(_cache)))
66+
# Store in cache with thread safety (implement LRU eviction if cache gets too large)
67+
with _cache_lock:
68+
if len(_cache) >= 1024:
69+
# Remove oldest item (first inserted)
70+
try:
71+
_cache.pop(next(iter(_cache)))
72+
except KeyError:
73+
# Cache was modified by another thread, which is fine
74+
pass
75+
76+
_cache[cache_key] = result
6777

68-
_cache[cache_key] = result
6978
return result

0 commit comments

Comments
 (0)