44import os
55from datetime import datetime
66from queue import Queue
7- from threading import Thread
7+ from threading import RLock , Thread
88from typing import Callable , Dict , List , Optional , Set
99
1010from langfuse ._client .environment_variables import (
@@ -77,12 +77,14 @@ class PromptCacheTaskManager(object):
7777 _threads : int
7878 _queue : Queue
7979 _processing_keys : Set [str ]
80+ _lock : RLock
8081
8182 def __init__ (self , threads : int = 1 ):
8283 self ._queue = Queue ()
8384 self ._consumers = []
8485 self ._threads = threads
8586 self ._processing_keys = set ()
87+ self ._lock = RLock ()
8688
8789 for i in range (self ._threads ):
8890 consumer = PromptCacheRefreshConsumer (self ._queue , i )
@@ -92,16 +94,20 @@ def __init__(self, threads: int = 1):
9294 atexit .register (self .shutdown )
9395
9496 def add_task (self , key : str , task : Callable [[], None ]) -> None :
95- if key not in self ._processing_keys :
96- logger .debug (f"Adding prompt cache refresh task for key: { key } " )
97- self ._processing_keys .add (key )
98- wrapped_task = self ._wrap_task (key , task )
99- self ._queue .put ((wrapped_task ))
100- else :
101- logger .debug (f"Prompt cache refresh task already submitted for key: { key } " )
97+ with self ._lock :
98+ if key not in self ._processing_keys :
99+ logger .debug (f"Adding prompt cache refresh task for key: { key } " )
100+ self ._processing_keys .add (key )
101+ wrapped_task = self ._wrap_task (key , task )
102+ self ._queue .put ((wrapped_task ))
103+ else :
104+ logger .debug (
105+ f"Prompt cache refresh task already submitted for key: { key } "
106+ )
102107
103108 def active_tasks (self ) -> int :
104- return len (self ._processing_keys )
109+ with self ._lock :
110+ return len (self ._processing_keys )
105111
106112 def wait_for_idle (self ) -> None :
107113 self ._queue .join ()
@@ -112,7 +118,8 @@ def wrapped() -> None:
112118 try :
113119 task ()
114120 finally :
115- self ._processing_keys .remove (key )
121+ with self ._lock :
122+ self ._processing_keys .remove (key )
116123 logger .debug (f"Refreshed prompt cache for key: { key } " )
117124
118125 return wrapped
@@ -139,6 +146,7 @@ def shutdown(self) -> None:
139146
140147class PromptCache :
141148 _cache : Dict [str , PromptCacheItem ]
149+ _lock : RLock
142150
143151 _task_manager : PromptCacheTaskManager
144152 """Task manager for refreshing cache"""
@@ -147,34 +155,60 @@ def __init__(
147155 self , max_prompt_refresh_workers : int = DEFAULT_PROMPT_CACHE_REFRESH_WORKERS
148156 ):
149157 self ._cache = {}
158+ self ._lock = RLock ()
150159 self ._task_manager = PromptCacheTaskManager (threads = max_prompt_refresh_workers )
151160 logger .debug ("Prompt cache initialized." )
152161
153162 def get (self , key : str ) -> Optional [PromptCacheItem ]:
154- return self ._cache .get (key , None )
163+ with self ._lock :
164+ return self ._cache .get (key , None )
155165
156166 def set (self , key : str , value : PromptClient , ttl_seconds : Optional [int ]) -> None :
157167 if ttl_seconds is None :
158168 ttl_seconds = DEFAULT_PROMPT_CACHE_TTL_SECONDS
159169
160- self ._cache [key ] = PromptCacheItem (value , ttl_seconds )
170+ with self ._lock :
171+ self ._cache [key ] = PromptCacheItem (value , ttl_seconds )
161172
162173 def delete (self , key : str ) -> None :
163- self ._cache .pop (key , None )
174+ with self ._lock :
175+ self ._cache .pop (key , None )
164176
165177 def invalidate (self , prompt_name : str ) -> None :
166178 """Invalidate all cached prompts with the given prompt name."""
167- for key in list (self ._cache ):
168- if key .startswith (prompt_name ):
169- del self ._cache [key ]
179+ with self ._lock :
180+ for key in list (self ._cache ):
181+ if key .startswith (prompt_name ):
182+ del self ._cache [key ]
170183
171184 def add_refresh_prompt_task (self , key : str , fetch_func : Callable [[], None ]) -> None :
172185 logger .debug (f"Submitting refresh task for key: { key } " )
173186 self ._task_manager .add_task (key , fetch_func )
174187
188+ def add_refresh_prompt_task_if_current (
189+ self ,
190+ key : str ,
191+ expected_item : PromptCacheItem ,
192+ fetch_func : Callable [[], None ],
193+ ) -> None :
194+ with self ._lock :
195+ current_item = self ._cache .get (key )
196+ if (
197+ current_item is not None
198+ and current_item is not expected_item
199+ and not current_item .is_expired ()
200+ ):
201+ logger .debug (
202+ f"Skipping refresh task for key: { key } because cache is already fresh."
203+ )
204+ return
205+
206+ self .add_refresh_prompt_task (key , fetch_func )
207+
175208 def clear (self ) -> None :
176209 """Clear the entire prompt cache, removing all cached prompts."""
177- self ._cache .clear ()
210+ with self ._lock :
211+ self ._cache .clear ()
178212
179213 @staticmethod
180214 def generate_cache_key (
0 commit comments