Skip to content

Commit eb98e3a

Browse files
committed
fix(prompt-cache): avoid redundant refresh races
1 parent 871fc31 commit eb98e3a

3 files changed

Lines changed: 105 additions & 18 deletions

File tree

langfuse/_client/client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3482,8 +3482,9 @@ def refresh_task() -> None:
34823482
fetch_timeout_seconds=fetch_timeout_seconds,
34833483
)
34843484

3485-
self._resources.prompt_cache.add_refresh_prompt_task(
3485+
self._resources.prompt_cache.add_refresh_prompt_task_if_current(
34863486
cache_key,
3487+
cached_prompt,
34873488
refresh_task,
34883489
)
34893490
langfuse_logger.debug(

langfuse/_utils/prompt_cache.py

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os
55
from datetime import datetime
66
from queue import Queue
7-
from threading import Thread
7+
from threading import RLock, Thread
88
from typing import Callable, Dict, List, Optional, Set
99

1010
from langfuse._client.environment_variables import (
@@ -77,12 +77,14 @@ class PromptCacheTaskManager(object):
7777
_threads: int
7878
_queue: Queue
7979
_processing_keys: Set[str]
80+
_lock: RLock
8081

8182
def __init__(self, threads: int = 1):
8283
self._queue = Queue()
8384
self._consumers = []
8485
self._threads = threads
8586
self._processing_keys = set()
87+
self._lock = RLock()
8688

8789
for i in range(self._threads):
8890
consumer = PromptCacheRefreshConsumer(self._queue, i)
@@ -92,16 +94,20 @@ def __init__(self, threads: int = 1):
9294
atexit.register(self.shutdown)
9395

9496
def add_task(self, key: str, task: Callable[[], None]) -> None:
95-
if key not in self._processing_keys:
96-
logger.debug(f"Adding prompt cache refresh task for key: {key}")
97-
self._processing_keys.add(key)
98-
wrapped_task = self._wrap_task(key, task)
99-
self._queue.put((wrapped_task))
100-
else:
101-
logger.debug(f"Prompt cache refresh task already submitted for key: {key}")
97+
with self._lock:
98+
if key not in self._processing_keys:
99+
logger.debug(f"Adding prompt cache refresh task for key: {key}")
100+
self._processing_keys.add(key)
101+
wrapped_task = self._wrap_task(key, task)
102+
self._queue.put((wrapped_task))
103+
else:
104+
logger.debug(
105+
f"Prompt cache refresh task already submitted for key: {key}"
106+
)
102107

103108
def active_tasks(self) -> int:
104-
return len(self._processing_keys)
109+
with self._lock:
110+
return len(self._processing_keys)
105111

106112
def wait_for_idle(self) -> None:
107113
self._queue.join()
@@ -112,7 +118,8 @@ def wrapped() -> None:
112118
try:
113119
task()
114120
finally:
115-
self._processing_keys.remove(key)
121+
with self._lock:
122+
self._processing_keys.remove(key)
116123
logger.debug(f"Refreshed prompt cache for key: {key}")
117124

118125
return wrapped
@@ -139,6 +146,7 @@ def shutdown(self) -> None:
139146

140147
class PromptCache:
141148
_cache: Dict[str, PromptCacheItem]
149+
_lock: RLock
142150

143151
_task_manager: PromptCacheTaskManager
144152
"""Task manager for refreshing cache"""
@@ -147,34 +155,60 @@ def __init__(
147155
self, max_prompt_refresh_workers: int = DEFAULT_PROMPT_CACHE_REFRESH_WORKERS
148156
):
149157
self._cache = {}
158+
self._lock = RLock()
150159
self._task_manager = PromptCacheTaskManager(threads=max_prompt_refresh_workers)
151160
logger.debug("Prompt cache initialized.")
152161

153162
def get(self, key: str) -> Optional[PromptCacheItem]:
154-
return self._cache.get(key, None)
163+
with self._lock:
164+
return self._cache.get(key, None)
155165

156166
def set(self, key: str, value: PromptClient, ttl_seconds: Optional[int]) -> None:
157167
if ttl_seconds is None:
158168
ttl_seconds = DEFAULT_PROMPT_CACHE_TTL_SECONDS
159169

160-
self._cache[key] = PromptCacheItem(value, ttl_seconds)
170+
with self._lock:
171+
self._cache[key] = PromptCacheItem(value, ttl_seconds)
161172

162173
def delete(self, key: str) -> None:
163-
self._cache.pop(key, None)
174+
with self._lock:
175+
self._cache.pop(key, None)
164176

165177
def invalidate(self, prompt_name: str) -> None:
166178
"""Invalidate all cached prompts with the given prompt name."""
167-
for key in list(self._cache):
168-
if key.startswith(prompt_name):
169-
del self._cache[key]
179+
with self._lock:
180+
for key in list(self._cache):
181+
if key.startswith(prompt_name):
182+
del self._cache[key]
170183

171184
def add_refresh_prompt_task(self, key: str, fetch_func: Callable[[], None]) -> None:
172185
logger.debug(f"Submitting refresh task for key: {key}")
173186
self._task_manager.add_task(key, fetch_func)
174187

188+
def add_refresh_prompt_task_if_current(
189+
self,
190+
key: str,
191+
expected_item: PromptCacheItem,
192+
fetch_func: Callable[[], None],
193+
) -> None:
194+
with self._lock:
195+
current_item = self._cache.get(key)
196+
if (
197+
current_item is not None
198+
and current_item is not expected_item
199+
and not current_item.is_expired()
200+
):
201+
logger.debug(
202+
f"Skipping refresh task for key: {key} because cache is already fresh."
203+
)
204+
return
205+
206+
self.add_refresh_prompt_task(key, fetch_func)
207+
175208
def clear(self) -> None:
176209
"""Clear the entire prompt cache, removing all cached prompts."""
177-
self._cache.clear()
210+
with self._lock:
211+
self._cache.clear()
178212

179213
@staticmethod
180214
def generate_cache_key(

tests/unit/test_prompt.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,58 @@ def test_get_stale_prompt_when_expired_cache_default_ttl(mock_time, langfuse: La
492492
assert updated_result == TextPromptClient(updated_prompt)
493493

494494

495+
@patch.object(PromptCacheItem, "get_epoch_seconds")
496+
def test_skip_redundant_refresh_when_cache_already_updated(
497+
mock_time, langfuse: Langfuse
498+
) -> None:
499+
prompt_name = "test_skip_redundant_refresh_when_cache_already_updated"
500+
cache_key = PromptCache.generate_cache_key(prompt_name, version=None, label=None)
501+
502+
mock_time.return_value = 0
503+
504+
initial_prompt = Prompt_Text(
505+
name=prompt_name,
506+
version=1,
507+
prompt="Make me laugh",
508+
labels=[],
509+
type="text",
510+
config={},
511+
tags=[],
512+
)
513+
updated_prompt = Prompt_Text(
514+
name=prompt_name,
515+
version=2,
516+
prompt="Make me laugh",
517+
labels=[],
518+
type="text",
519+
config={},
520+
tags=[],
521+
)
522+
523+
stale_result = TextPromptClient(initial_prompt)
524+
fresh_result = TextPromptClient(updated_prompt)
525+
526+
langfuse._resources.prompt_cache.set(cache_key, stale_result, None)
527+
stale_item = langfuse._resources.prompt_cache.get(cache_key)
528+
assert stale_item is not None
529+
530+
mock_time.return_value = DEFAULT_PROMPT_CACHE_TTL_SECONDS + 1
531+
assert stale_item.is_expired()
532+
533+
langfuse._resources.prompt_cache.set(cache_key, fresh_result, None)
534+
535+
add_task_mock = Mock()
536+
langfuse._resources.prompt_cache._task_manager.add_task = add_task_mock
537+
538+
langfuse._resources.prompt_cache.add_refresh_prompt_task_if_current(
539+
cache_key,
540+
stale_item,
541+
Mock(),
542+
)
543+
544+
add_task_mock.assert_not_called()
545+
546+
495547
@patch.object(PromptCacheItem, "get_epoch_seconds")
496548
def test_get_fresh_prompt_when_expired_cache_default_ttl(mock_time, langfuse: Langfuse):
497549
mock_time.return_value = 0

0 commit comments

Comments
 (0)