Skip to content

Commit 835bfc1

Browse files
authored
Merge pull request lightspeed-core#722 from tisnik/lcore-741-consume-quota-mechanism
LCORE-741: consume quota mechanism
2 parents 8e414a2 + 5af4dad commit 835bfc1

3 files changed

Lines changed: 121 additions & 1 deletion

File tree

src/app/endpoints/query.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@
5353
validate_conversation_ownership,
5454
validate_model_provider_override,
5555
)
56+
from utils.quota import (
57+
get_available_quotas,
58+
check_tokens_available,
59+
consume_tokens,
60+
)
5661
from utils.mcp_headers import handle_mcp_headers_with_toolgroups, mcp_headers_dependency
5762
from utils.transcripts import store_transcript
5863
from utils.types import TurnSummary
@@ -273,6 +278,7 @@ async def query_endpoint_handler( # pylint: disable=R0914
273278
logger.debug("Query does not contain conversation ID")
274279

275280
try:
281+
check_tokens_available(configuration.quota_limiters, user_id)
276282
# try to get Llama Stack client
277283
client = AsyncLlamaStackClientHolder().get_client()
278284
llama_stack_model_id, model_id, provider_id = select_model_and_provider_id(
@@ -344,6 +350,13 @@ async def query_endpoint_handler( # pylint: disable=R0914
344350
referenced_documents=referenced_documents if referenced_documents else None,
345351
)
346352

353+
consume_tokens(
354+
configuration.quota_limiters,
355+
user_id,
356+
input_tokens=token_usage.input_tokens,
357+
output_tokens=token_usage.output_tokens,
358+
)
359+
347360
store_conversation_into_cache(
348361
configuration,
349362
user_id,
@@ -372,6 +385,8 @@ async def query_endpoint_handler( # pylint: disable=R0914
372385

373386
logger.info("Using referenced documents from response...")
374387

388+
available_quotas = get_available_quotas(configuration.quota_limiters, user_id)
389+
375390
logger.info("Building final response...")
376391
response = QueryResponse(
377392
conversation_id=conversation_id,
@@ -382,7 +397,7 @@ async def query_endpoint_handler( # pylint: disable=R0914
382397
truncated=False, # TODO: implement truncation detection
383398
input_tokens=token_usage.input_tokens,
384399
output_tokens=token_usage.output_tokens,
385-
available_quotas={}, # TODO: implement quota tracking
400+
available_quotas=available_quotas,
386401
)
387402
logger.info("Query processing completed successfully!")
388403
return response

src/utils/quota.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
"""Quota handling helper functions."""
2+
3+
import psycopg2
4+
5+
from fastapi import HTTPException, status
6+
7+
from quota.quota_limiter import QuotaLimiter
8+
from quota.quota_exceed_error import QuotaExceedError
9+
10+
from log import get_logger
11+
12+
logger = get_logger(__name__)
13+
14+
15+
def consume_tokens(
    quota_limiters: list[QuotaLimiter],
    user_id: str,
    input_tokens: int,
    output_tokens: int,
) -> None:
    """Charge token usage against every configured quota limiter.

    Each limiter in ``quota_limiters`` is asked, in list order, to consume
    the given input and output token counts, with ``user_id`` passed as the
    limiter's ``subject_id``. Exceptions raised by a limiter are not handled
    here: they propagate to the caller and the remaining limiters are not
    charged.

    Args:
        quota_limiters: List of quota limiter instances to consume tokens from.
        user_id: Identifier of the user consuming tokens.
        input_tokens: Number of input tokens to consume.
        output_tokens: Number of output tokens to consume.

    Returns:
        None
    """
    # the same usage figures are charged to all configured quota limiters
    usage = {"input_tokens": input_tokens, "output_tokens": output_tokens}
    for limiter in quota_limiters:
        limiter.consume_tokens(subject_id=user_id, **usage)
39+
40+
41+
def check_tokens_available(quota_limiters: list[QuotaLimiter], user_id: str) -> None:
    """Check if tokens are available for user.

    Every limiter in ``quota_limiters`` is asked, in list order, to ensure
    that quota is available for ``user_id`` (passed as ``subject_id``). The
    first limiter that raises stops the check; its exception is logged
    together with its cause and translated into an HTTP error.

    Args:
        quota_limiters: List of quota limiter instances to check.
        user_id: Identifier of the user to check quota for.

    Returns:
        None

    Raises:
        HTTPException: With status 500 if a limiter raises ``psycopg2.Error``,
            or status 429 if a limiter raises ``QuotaExceedError``. In both
            cases ``detail`` carries the generic message under ``"response"``
            and the stringified original exception under ``"cause"``.
    """
    try:
        # check available tokens using all configured quota limiters
        for quota_limiter in quota_limiters:
            quota_limiter.ensure_available_quota(subject_id=user_id)
    except psycopg2.Error as pg_error:
        message = "Error communicating with quota database backend"
        # log the underlying cause too; the generic message alone does not
        # allow the failure to be diagnosed from the logs
        logger.error("%s: %s", message, pg_error)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={
                "response": message,
                "cause": str(pg_error),
            },
        ) from pg_error
    except QuotaExceedError as quota_exceed_error:
        message = "The quota has been exceeded"
        # the cause is logged as well so the log entry is self-contained
        logger.error("%s: %s", message, quota_exceed_error)
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail={
                "response": message,
                "cause": str(quota_exceed_error),
            },
        ) from quota_exceed_error
79+
80+
81+
def get_available_quotas(
    quota_limiters: list[QuotaLimiter],
    user_id: str,
) -> dict[str, int]:
    """Get quota available from all quota limiters.

    Each limiter is queried, in list order, for the quota available to
    ``user_id``. Results are keyed by the limiter's class name, so when
    several limiters share a class only the value from the last one of them
    is kept.

    Args:
        quota_limiters: List of quota limiter instances to query.
        user_id: Identifier of the user to get quotas for.

    Returns:
        Dictionary mapping quota limiter class names to available token
        counts; empty when no limiters are configured.
    """
    # retrieve available tokens using all configured quota limiters;
    # __class__ (not type()) is read so objects that override __class__
    # keep reporting the name they advertise
    return {
        quota_limiter.__class__.__name__: quota_limiter.available_quota(user_id)
        for quota_limiter in quota_limiters
    }

tests/unit/app/endpoints/test_query.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ async def _test_query_endpoint_handler(
184184
mock_config.user_data_collection_configuration.transcripts_enabled = (
185185
store_transcript_to_file
186186
)
187+
mock_config.quota_limiters = []
187188
mocker.patch("app.endpoints.query.configuration", mock_config)
188189

189190
mock_store_in_cache = mocker.patch(
@@ -1434,6 +1435,7 @@ async def test_auth_tuple_unpacking_in_query_endpoint_handler(
14341435
# Mock dependencies
14351436
mock_config = mocker.Mock()
14361437
mock_config.llama_stack_configuration = mocker.Mock()
1438+
mock_config.quota_limiters = []
14371439
mocker.patch("app.endpoints.query.configuration", mock_config)
14381440

14391441
mock_client = mocker.AsyncMock()
@@ -1499,6 +1501,7 @@ async def test_query_endpoint_handler_no_tools_true(mocker, dummy_request) -> No
14991501

15001502
mock_config = mocker.Mock()
15011503
mock_config.user_data_collection_configuration.transcripts_disabled = True
1504+
mock_config.quota_limiters = []
15021505
mocker.patch("app.endpoints.query.configuration", mock_config)
15031506

15041507
summary = TurnSummary(
@@ -1555,6 +1558,7 @@ async def test_query_endpoint_handler_no_tools_false(mocker, dummy_request) -> N
15551558

15561559
mock_config = mocker.Mock()
15571560
mock_config.user_data_collection_configuration.transcripts_disabled = True
1561+
mock_config.quota_limiters = []
15581562
mocker.patch("app.endpoints.query.configuration", mock_config)
15591563

15601564
summary = TurnSummary(

0 commit comments

Comments
 (0)