Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 44 additions & 12 deletions src/utils/llama_stack_version.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Check if the Llama Stack version is supported by the LCS."""

import asyncio
import re

from llama_stack_client._client import AsyncLlamaStackClient
from llama_stack_client import APIConnectionError, AsyncLlamaStackClient
from semver import Version

from constants import (
Expand All @@ -13,33 +14,64 @@

logger = get_logger(__name__)

# Retry settings for waiting on Llama Stack readiness during startup.
# When LCS runs as a sidecar alongside Llama Stack, both containers start
# concurrently and Llama Stack may not be ready when LCS attempts its
# first version check.
# Default number of connection attempts made before giving up.
_DEFAULT_MAX_RETRIES = 5
# Default pause, in seconds, between two consecutive attempts.
_DEFAULT_RETRY_DELAY = 2


class InvalidLlamaStackVersionException(Exception):
    """Llama Stack version is not valid.

    Per the contract documented on check_llama_stack_version, this is raised
    when the detected version is outside the supported range or cannot be
    parsed.
    """


async def check_llama_stack_version(
    client: AsyncLlamaStackClient,
    max_retries: int = _DEFAULT_MAX_RETRIES,
    retry_delay: int = _DEFAULT_RETRY_DELAY,
) -> None:
    """
    Verify the connected Llama Stack's version is within the supported range.

    This coroutine fetches the Llama Stack version from the provided client
    and validates it against the configured minimal and maximal supported
    versions. Connection attempts are retried with a fixed delay to handle
    the case where Llama Stack is still starting up (e.g., when running as
    a sidecar in the same pod).

    Args:
        client: The async Llama Stack client.
        max_retries: Maximum number of connection attempts before giving up.
            Must be at least 1.
        retry_delay: Delay in seconds between retry attempts.

    Raises:
        ValueError: If max_retries is less than 1.
        APIConnectionError: If Llama Stack is unreachable after all retries.
        InvalidLlamaStackVersionException: If the detected version is outside
            the supported range or cannot be parsed.
    """
    # Validate up front so a misconfigured caller fails before any network
    # call is attempted.
    if max_retries < 1:
        raise ValueError("max_retries must be >= 1")

    for attempt in range(max_retries):
        try:
            version_info = await client.inspect.version()
            compare_versions(
                version_info.version,
                MINIMAL_SUPPORTED_LLAMA_STACK_VERSION,
                MAXIMAL_SUPPORTED_LLAMA_STACK_VERSION,
            )
            return
        except APIConnectionError:
            # Only connection failures are retried; the last failed attempt
            # re-raises so the caller sees the original error.
            if attempt == max_retries - 1:
                raise
            logger.warning(
                "Llama Stack not ready (attempt %d/%d), retrying in %ds...",
                attempt + 1,
                max_retries,
                retry_delay,
            )
            await asyncio.sleep(retry_delay)


def compare_versions(version_info: str, minimal: str, maximal: str) -> None:
Expand Down
41 changes: 41 additions & 0 deletions tests/unit/utils/test_llama_stack_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any

import pytest
from llama_stack_client import APIConnectionError
from llama_stack_client.types import VersionInfo
from pytest_mock import MockerFixture
from pytest_subtests import SubTests
Expand Down Expand Up @@ -115,3 +116,43 @@ async def test_check_llama_stack_version_too_big_version(
with subtests.test(msg="Increased all numbers"):
bigger_version = max_version.bump_major().bump_minor().bump_patch()
await _check_version_must_fail(mock_client, bigger_version)


@pytest.mark.asyncio
async def test_check_llama_stack_version_retries_on_connection_error(
    mocker: MockerFixture,
) -> None:
    """Test that check_llama_stack_version retries on APIConnectionError."""
    mock_client = mocker.AsyncMock()
    # Patch sleep so the test does not actually wait between attempts.
    mock_sleep = mocker.patch("utils.llama_stack_version.asyncio.sleep")

    # Fail twice with connection error, then succeed
    mock_client.inspect.version.side_effect = [
        APIConnectionError(request=mocker.MagicMock()),
        APIConnectionError(request=mocker.MagicMock()),
        VersionInfo(version=MINIMAL_SUPPORTED_LLAMA_STACK_VERSION),
    ]

    await check_llama_stack_version(mock_client, max_retries=5, retry_delay=1)

    # Two failures plus the successful call; one sleep after each failure.
    assert mock_client.inspect.version.call_count == 3
    assert mock_sleep.call_count == 2
    # The configured delay must be the value handed to asyncio.sleep.
    mock_sleep.assert_called_with(1)


@pytest.mark.asyncio
async def test_check_llama_stack_version_raises_after_max_retries(
    mocker: MockerFixture,
) -> None:
    """Test that check_llama_stack_version raises after all retries are exhausted."""
    mock_client = mocker.AsyncMock()
    # Patch sleep so the test does not actually wait between attempts.
    mock_sleep = mocker.patch("utils.llama_stack_version.asyncio.sleep")

    # Every attempt fails with a connection error.
    mock_client.inspect.version.side_effect = APIConnectionError(
        request=mocker.MagicMock()
    )

    with pytest.raises(APIConnectionError):
        await check_llama_stack_version(mock_client, max_retries=3, retry_delay=1)

    # All three attempts are made; no sleep follows the final failed attempt.
    assert mock_client.inspect.version.call_count == 3
    assert mock_sleep.call_count == 2
    # The configured delay must be the value handed to asyncio.sleep.
    mock_sleep.assert_called_with(1)
Loading