Skip to content

Commit 8e34104

Browse files
committed
Requested changes
1 parent 84fdf4b commit 8e34104

10 files changed

Lines changed: 84 additions & 110 deletions

File tree

docs/providers.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@ Example of LCS config section:
8484

8585
```yaml
8686
azure_entra_id:
87-
tenant_id: ${env.AZURE_TENANT_ID}
88-
client_id: ${env.AZURE_CLIENT_ID}
89-
client_secret: ${env.AZURE_CLIENT_SECRET}
87+
tenant_id: ${env.TENANT_ID}
88+
client_id: ${env.CLIENT_ID}
89+
client_secret: ${env.CLIENT_SECRET}
9090
# scope: "https://cognitiveservices.azure.com/.default" # optional, this is the default
9191
```
9292

scripts/llama-stack-entrypoint.sh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,24 +12,24 @@ ENV_FILE="/opt/app-root/.env"
1212
# Enrich config if lightspeed config exists
1313
if [ -f "$LIGHTSPEED_CONFIG" ]; then
1414
echo "Enriching llama-stack config..."
15+
ENRICHMENT_FAILED=0
1516
python3 /opt/app-root/llama_stack_configuration.py \
1617
-c "$LIGHTSPEED_CONFIG" \
1718
-i "$INPUT_CONFIG" \
1819
-o "$ENRICHED_CONFIG" \
19-
-e "$ENV_FILE" 2>&1 || true
20+
-e "$ENV_FILE" 2>&1 || ENRICHMENT_FAILED=1
2021

2122
# Source .env if generated (contains AZURE_API_KEY)
2223
if [ -f "$ENV_FILE" ]; then
2324
# shellcheck source=/dev/null
2425
set -a && . "$ENV_FILE" && set +a
2526
fi
2627

27-
if [ -f "$ENRICHED_CONFIG" ]; then
28+
if [ -f "$ENRICHED_CONFIG" ] && [ "$ENRICHMENT_FAILED" -eq 0 ]; then
2829
echo "Using enriched config: $ENRICHED_CONFIG"
2930
exec llama stack run "$ENRICHED_CONFIG"
3031
fi
3132
fi
3233

3334
echo "Using original config: $INPUT_CONFIG"
3435
exec llama stack run "$INPUT_CONFIG"
35-

src/app/endpoints/query.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -307,36 +307,34 @@ async def query_endpoint_handler_base( # pylint: disable=R0914
307307
try:
308308
check_tokens_available(configuration.quota_limiters, user_id)
309309
# try to get Llama Stack client
310-
client_holder = AsyncLlamaStackClientHolder()
311-
client = client_holder.get_client()
310+
client = AsyncLlamaStackClientHolder().get_client()
312311
llama_stack_model_id, model_id, provider_id = select_model_and_provider_id(
313312
await client.models.list(),
314313
*evaluate_model_hints(
315314
user_conversation=user_conversation, query_request=query_request
316315
),
317316
)
318317

319-
if provider_id == "azure":
320-
if (
321-
AzureEntraIDManager().is_entra_id_configured
322-
and AzureEntraIDManager().is_token_expired
323-
):
324-
await AzureEntraIDManager().refresh_token()
325-
326-
if client_holder.is_library_client:
327-
client = await client_holder.reload_library_client()
328-
else:
329-
azure_config = next(
330-
p.config
331-
for p in await client.providers.list()
332-
if p.provider_type == "remote::azure"
333-
)
334-
client = client_holder.update_provider_data(
335-
{
336-
"azure_api_key": AzureEntraIDManager().access_token.get_secret_value(),
337-
"azure_api_base": str(azure_config.get("api_base")),
338-
}
339-
)
318+
if (
319+
provider_id == "azure"
320+
and AzureEntraIDManager().is_entra_id_configured
321+
and AzureEntraIDManager().is_token_expired
322+
and AzureEntraIDManager().refresh_token()
323+
):
324+
if AsyncLlamaStackClientHolder().is_library_client:
325+
client = await AsyncLlamaStackClientHolder().reload_library_client()
326+
else:
327+
azure_config = next(
328+
p.config
329+
for p in await client.providers.list()
330+
if p.provider_type == "remote::azure"
331+
)
332+
client = AsyncLlamaStackClientHolder().update_provider_data(
333+
{
334+
"azure_api_key": AzureEntraIDManager().access_token.get_secret_value(),
335+
"azure_api_base": str(azure_config.get("api_base")),
336+
}
337+
)
340338

341339
summary, conversation_id, referenced_documents, token_usage = (
342340
await retrieve_response_func(

src/app/endpoints/streaming_query.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -884,36 +884,34 @@ async def streaming_query_endpoint_handler_base( # pylint: disable=too-many-loc
884884

885885
try:
886886
# try to get Llama Stack client
887-
client_holder = AsyncLlamaStackClientHolder()
888-
client = client_holder.get_client()
887+
client = AsyncLlamaStackClientHolder().get_client()
889888
llama_stack_model_id, model_id, provider_id = select_model_and_provider_id(
890889
await client.models.list(),
891890
*evaluate_model_hints(
892891
user_conversation=user_conversation, query_request=query_request
893892
),
894893
)
895894

896-
if provider_id == "azure":
897-
if (
898-
AzureEntraIDManager().is_entra_id_configured
899-
and AzureEntraIDManager().is_token_expired
900-
):
901-
await AzureEntraIDManager().refresh_token()
902-
903-
if client_holder.is_library_client:
904-
client = await client_holder.reload_library_client()
905-
else:
906-
azure_config = next(
907-
p.config
908-
for p in await client.providers.list()
909-
if p.provider_type == "remote::azure"
910-
)
911-
client = client_holder.update_provider_data(
912-
{
913-
"azure_api_key": AzureEntraIDManager().access_token.get_secret_value(),
914-
"azure_api_base": str(azure_config.get("api_base")),
915-
}
916-
)
895+
if (
896+
provider_id == "azure"
897+
and AzureEntraIDManager().is_entra_id_configured
898+
and AzureEntraIDManager().is_token_expired
899+
and AzureEntraIDManager().refresh_token()
900+
):
901+
if AsyncLlamaStackClientHolder().is_library_client:
902+
client = await AsyncLlamaStackClientHolder().reload_library_client()
903+
else:
904+
azure_config = next(
905+
p.config
906+
for p in await client.providers.list()
907+
if p.provider_type == "remote::azure"
908+
)
909+
client = AsyncLlamaStackClientHolder().update_provider_data(
910+
{
911+
"azure_api_key": AzureEntraIDManager().access_token.get_secret_value(),
912+
"azure_api_base": str(azure_config.get("api_base")),
913+
}
914+
)
917915

918916
response, conversation_id = await retrieve_response_func(
919917
client,

src/app/main.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,11 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
4444
azure_config = configuration.configuration.azure_entra_id
4545
if azure_config is not None:
4646
AzureEntraIDManager().set_config(azure_config)
47-
try:
48-
await AzureEntraIDManager().refresh_token()
49-
os.environ["AZURE_API_KEY"] = (
50-
AzureEntraIDManager().access_token.get_secret_value()
47+
if not AzureEntraIDManager().refresh_token():
48+
logger.warning(
49+
"Failed to refresh Azure token at startup. "
50+
"Token refresh will be retried on next Azure request."
5151
)
52-
logger.info("Azure Entra ID token set in environment")
53-
except ValueError as e:
54-
logger.error("Failed to refresh Azure token: %s", e)
5552

5653
await AsyncLlamaStackClientHolder().load(configuration.configuration.llama_stack)
5754
client = AsyncLlamaStackClientHolder().get_client()

src/authorization/azure_token_manager.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@ def access_token(self) -> SecretStr:
5656
"""Return the access token from environment variable as SecretStr."""
5757
return SecretStr(os.environ.get("AZURE_API_KEY", ""))
5858

59-
async def refresh_token(self) -> None:
59+
def refresh_token(self) -> bool:
6060
"""Refresh the cached Azure access token.
6161
62-
This is async to enforce proper ordering in the event loop -
63-
callers must await this before using the refreshed token.
62+
Returns:
63+
bool: True if token was successfully refreshed, False otherwise.
6464
6565
Raises:
6666
ValueError: If Entra ID configuration has not been set.
@@ -72,6 +72,8 @@ async def refresh_token(self) -> None:
7272
token_obj = self._retrieve_access_token()
7373
if token_obj:
7474
self._update_access_token(token_obj.token, token_obj.expires_on)
75+
return True
76+
return False
7577

7678
def _update_access_token(self, token: str, expires_on: int) -> None:
7779
"""Update the token in env var and track expiration time."""
@@ -96,5 +98,5 @@ def _retrieve_access_token(self) -> Optional[AccessToken]:
9698
return credential.get_token(self._entra_id_config.scope)
9799

98100
except (ClientAuthenticationError, CredentialUnavailableError):
99-
logger.error("Failed to retrieve Azure access token")
101+
logger.warning("Failed to retrieve Azure access token")
100102
return None

src/llama_stack_configuration.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import logging
99
import os
1010
from argparse import ArgumentParser
11+
from pathlib import Path
1112
from typing import Any
1213

1314
from azure.core.exceptions import ClientAuthenticationError
@@ -62,14 +63,14 @@ def setup_azure_entra_id_token(
6263
tenant_id = azure_config.get("tenant_id")
6364
client_id = azure_config.get("client_id")
6465
client_secret = azure_config.get("client_secret")
66+
scope = azure_config.get("scope", "https://cognitiveservices.azure.com/.default")
6567

6668
if not all([tenant_id, client_id, client_secret]):
6769
logger.warning(
6870
"Azure Entra ID: Missing required fields (tenant_id, client_id, client_secret)"
6971
)
7072
return
7173

72-
scope = "https://cognitiveservices.azure.com/.default"
7374
try:
7475
credential = ClientSecretCredential(
7576
tenant_id=str(tenant_id),
@@ -80,10 +81,12 @@ def setup_azure_entra_id_token(
8081
token = credential.get_token(scope)
8182

8283
# Write to .env file
84+
# Create file if it doesn't exist
85+
Path(env_file).touch()
86+
8387
lines = []
84-
if os.path.exists(env_file):
85-
with open(env_file, "r", encoding="utf-8") as f:
86-
lines = f.readlines()
88+
with open(env_file, "r", encoding="utf-8") as f:
89+
lines = f.readlines()
8790

8891
# Update or add AZURE_API_KEY
8992
key_found = False
@@ -123,7 +126,7 @@ def construct_vector_dbs_section(
123126
ls_config (dict[str, Any]): Existing Llama Stack configuration mapping
124127
used as the base; existing `vector_dbs` entries are preserved if
125128
present.
126-
byok_rag (list[ByokRag]): List of BYOK RAG definitions to be added to
129+
byok_rag (list[dict[str, Any]]): List of BYOK RAG definitions to be added to
127130
the `vector_dbs` section.
128131
129132
Returns:
@@ -143,10 +146,10 @@ def construct_vector_dbs_section(
143146
for brag in byok_rag:
144147
output.append(
145148
{
146-
"vector_db_id": brag.get("vector_db_id"),
149+
"vector_db_id": brag.get("vector_db_id", ""),
147150
"provider_id": "byok_" + brag.get("vector_db_id", ""),
148-
"embedding_model": brag.get("embedding_model", "all-MiniLM-L6-v2"),
149-
"embedding_dimension": brag.get("embedding_dimension", 384),
151+
"embedding_model": brag.get("embedding_model", ""),
152+
"embedding_dimension": brag.get("embedding_dimension"),
150153
}
151154
)
152155
logger.info(
@@ -170,7 +173,7 @@ def construct_vector_io_providers_section(
170173
ls_config (dict[str, Any]): Existing Llama Stack configuration
171174
dictionary; if it contains providers.vector_io, those entries are used
172175
as the starting list.
173-
byok_rag (list[ByokRag]): List of BYOK RAG specifications to convert
176+
byok_rag (list[dict[str, Any]]): List of BYOK RAG specifications to convert
174177
into provider entries.
175178
176179
Returns:
@@ -303,13 +306,9 @@ def main() -> None:
303306
)
304307
args = parser.parse_args()
305308

306-
try:
307-
with open(args.config, "r", encoding="utf-8") as f:
308-
config = yaml.safe_load(f)
309-
config = replace_env_vars(config)
310-
except FileNotFoundError:
311-
logger.error("Config not found: %s", args.config)
312-
return
309+
with open(args.config, "r", encoding="utf-8") as f:
310+
config = yaml.safe_load(f)
311+
config = replace_env_vars(config)
313312

314313
generate_configuration(args.input, args.output, config, args.env_file)
315314

test.containerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ RUN pip install faiss-cpu==1.11.0 azure-identity && \
1313
# Copy enrichment scripts for runtime config enrichment
1414
COPY src/llama_stack_configuration.py /opt/app-root/llama_stack_configuration.py
1515
COPY scripts/llama-stack-entrypoint.sh /opt/app-root/enrich-entrypoint.sh
16-
RUN chmod +x /opt/app-root/enrich-entrypoint.sh /opt/app-root/llama_stack_configuration.py && \
16+
RUN chmod +x /opt/app-root/enrich-entrypoint.sh && \
1717
chown 1001:0 /opt/app-root/enrich-entrypoint.sh /opt/app-root/llama_stack_configuration.py
1818

1919
# Switch back to the original user

tests/unit/authorization/test_azure_token_manager.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,12 @@ def test_token_expiration_logic(self, token_manager: AzureEntraIDManager) -> Non
7575
token_manager._expires_on = 0
7676
assert token_manager.is_token_expired
7777

78-
@pytest.mark.asyncio
79-
async def test_refresh_token_raises_without_config(
78+
def test_refresh_token_raises_without_config(
8079
self, token_manager: AzureEntraIDManager
8180
) -> None:
8281
"""Raise ValueError when refresh_token is called without config."""
8382
with pytest.raises(ValueError, match="Azure Entra ID configuration not set"):
84-
await token_manager.refresh_token()
83+
token_manager.refresh_token()
8584

8685
def test_update_access_token_sets_token_and_expiration(
8786
self, token_manager: AzureEntraIDManager
@@ -92,8 +91,7 @@ def test_update_access_token_sets_token_and_expiration(
9291
assert token_manager.access_token.get_secret_value() == "test-token"
9392
assert token_manager._expires_on == expires_on - TOKEN_EXPIRATION_LEEWAY
9493

95-
@pytest.mark.asyncio
96-
async def test_refresh_token_success(
94+
def test_refresh_token_success(
9795
self,
9896
token_manager: AzureEntraIDManager,
9997
dummy_config: AzureEntraIdConfiguration,
@@ -111,16 +109,14 @@ async def test_refresh_token_success(
111109
return_value=mock_credential_instance,
112110
)
113111

114-
await token_manager.refresh_token()
112+
result = token_manager.refresh_token()
115113

114+
assert result is True
116115
assert token_manager.access_token.get_secret_value() == "token_value"
117116
assert not token_manager.is_token_expired
118-
mock_credential_instance.get_token.assert_called_once_with(
119-
"https://cognitiveservices.azure.com/.default"
120-
)
117+
mock_credential_instance.get_token.assert_called_once_with(dummy_config.scope)
121118

122-
@pytest.mark.asyncio
123-
async def test_refresh_token_failure_logs_error(
119+
def test_refresh_token_failure_logs_error(
124120
self,
125121
token_manager: AzureEntraIDManager,
126122
dummy_config: AzureEntraIdConfiguration,
@@ -138,8 +134,9 @@ async def test_refresh_token_failure_logs_error(
138134
return_value=mock_credential_instance,
139135
)
140136

141-
with caplog.at_level("ERROR"):
142-
await token_manager.refresh_token()
137+
with caplog.at_level("WARNING"):
138+
result = token_manager.refresh_token()
139+
assert result is False
143140
assert "Failed to retrieve Azure access token" in caplog.text
144141

145142
def test_token_expired_property_dynamic(

tests/unit/test_llama_stack_configuration.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -65,15 +65,6 @@ def test_construct_vector_dbs_section_adds_new() -> None:
6565
assert output[0]["embedding_dimension"] == 512
6666

6767

68-
def test_construct_vector_dbs_section_uses_defaults() -> None:
69-
"""Test uses default values when not specified."""
70-
ls_config: dict[str, Any] = {}
71-
byok_rag = [{"vector_db_id": "db1"}]
72-
output = construct_vector_dbs_section(ls_config, byok_rag)
73-
assert output[0]["embedding_model"] == "all-MiniLM-L6-v2"
74-
assert output[0]["embedding_dimension"] == 384
75-
76-
7768
def test_construct_vector_dbs_section_merge() -> None:
7869
"""Test merges existing and new entries."""
7970
ls_config = {"vector_dbs": [{"vector_db_id": "existing"}]}
@@ -119,14 +110,6 @@ def test_construct_vector_io_providers_section_adds_new() -> None:
119110
assert output[0]["provider_type"] == "inline::faiss"
120111

121112

122-
def test_construct_vector_io_providers_section_uses_defaults() -> None:
123-
"""Test uses default values when not specified."""
124-
ls_config: dict[str, Any] = {"providers": {}}
125-
byok_rag = [{"vector_db_id": "db1"}]
126-
output = construct_vector_io_providers_section(ls_config, byok_rag)
127-
assert output[0]["provider_type"] == "inline::faiss"
128-
129-
130113
# =============================================================================
131114
# Test generate_configuration
132115
# =============================================================================

0 commit comments

Comments
 (0)