Skip to content

Commit 22008db

Browse files
Copilot authored and friggeri committed
Fix race condition in list_models caching and add missing test decorator
Co-authored-by: friggeri <106686+friggeri@users.noreply.github.com>
1 parent 0c49ace commit 22008db

2 files changed

Lines changed: 13 additions & 12 deletions

File tree

python/copilot/client.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def __init__(self, options: Optional[CopilotClientOptions] = None):
158158
self._sessions: dict[str, CopilotSession] = {}
159159
self._sessions_lock = threading.Lock()
160160
self._models_cache: Optional[list[ModelInfo]] = None
161-
self._models_cache_lock = threading.Lock()
161+
self._models_cache_lock = asyncio.Lock()
162162

163163
def _parse_cli_url(self, url: str) -> tuple[str, int]:
164164
"""
@@ -284,7 +284,7 @@ async def stop(self) -> list["StopError"]:
284284
self._client = None
285285

286286
# Clear models cache
287-
with self._models_cache_lock:
287+
async with self._models_cache_lock:
288288
self._models_cache = None
289289

290290
# Kill CLI process
@@ -332,7 +332,7 @@ async def force_stop(self) -> None:
332332
self._client = None
333333

334334
# Clear models cache
335-
with self._models_cache_lock:
335+
async with self._models_cache_lock:
336336
self._models_cache = None
337337

338338
# Kill CLI process immediately
@@ -733,21 +733,21 @@ async def list_models(self) -> list["ModelInfo"]:
733733
if not self._client:
734734
raise RuntimeError("Client not connected")
735735

736-
# Check cache first (thread-safe)
737-
with self._models_cache_lock:
736+
# Use asyncio lock to prevent race condition with concurrent calls
737+
async with self._models_cache_lock:
738+
# Check cache (already inside lock)
738739
if self._models_cache is not None:
739740
return self._models_cache
740741

741-
# Cache miss - fetch from backend
742-
response = await self._client.request("models.list", {})
743-
models_data = response.get("models", [])
744-
models = [ModelInfo.from_dict(model) for model in models_data]
742+
# Cache miss - fetch from backend while holding lock
743+
response = await self._client.request("models.list", {})
744+
models_data = response.get("models", [])
745+
models = [ModelInfo.from_dict(model) for model in models_data]
745746

746-
# Update cache (thread-safe)
747-
with self._models_cache_lock:
747+
# Update cache before releasing lock
748748
self._models_cache = models
749749

750-
return models
750+
return models
751751

752752
async def list_sessions(self) -> list["SessionMetadata"]:
753753
"""

python/e2e/test_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ async def test_should_list_models_when_authenticated(self):
136136
finally:
137137
await client.force_stop()
138138

139+
@pytest.mark.asyncio
139140
async def test_should_cache_models_list(self):
140141
"""Test that list_models caches results to avoid rate limiting"""
141142
client = CopilotClient({"cli_path": CLI_PATH, "use_stdio": True})

0 commit comments

Comments (0)