From 85d83b2a58c424e3d9002ce55fd388d222bc50f5 Mon Sep 17 00:00:00 2001 From: Aleksandr Kukharenko Date: Tue, 10 Mar 2026 15:53:26 +0200 Subject: [PATCH] feat: per-user rate limit multiplier with admin endpoints --- ...0001_add_rate_limit_multiplier_to_users.py | 31 ++ src/api/v1/__init__.py | 2 +- src/api/v1/billing/admin.py | 409 ++++++++++++++++++ src/api/v1/billing/index.py | 297 +------------ src/api/v1/chat/index.py | 1 + src/api/v1/embeddings/index.py | 3 + src/db/models/user.py | 3 +- src/dependencies.py | 5 + src/schemas/billing.py | 21 + .../rate_limiting/rate_limit_service.py | 17 +- 10 files changed, 494 insertions(+), 295 deletions(-) create mode 100644 alembic/versions/2026_03_10_0001_add_rate_limit_multiplier_to_users.py create mode 100644 src/api/v1/billing/admin.py diff --git a/alembic/versions/2026_03_10_0001_add_rate_limit_multiplier_to_users.py b/alembic/versions/2026_03_10_0001_add_rate_limit_multiplier_to_users.py new file mode 100644 index 00000000..fe561802 --- /dev/null +++ b/alembic/versions/2026_03_10_0001_add_rate_limit_multiplier_to_users.py @@ -0,0 +1,31 @@ +"""Add rate_limit_multiplier to users table + +Per-user scaling factor for rate limits (RPM/TPM). +Default 1.0 means no change; 2.0 doubles the limits, 0.5 halves them. +Managed exclusively by admin endpoints — never exposed to end users. + +Revision ID: add_rate_limit_mult +Revises: drop_email_name_2026 +Create Date: 2026-03-10 00:01:00.000000 +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +revision: str = 'add_rate_limit_mult' +down_revision: Union[str, None] = 'drop_email_name_2026' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + 'users', + sa.Column('rate_limit_multiplier', sa.Float(), nullable=False, server_default='1.0'), + ) + + +def downgrade() -> None: + op.drop_column('users', 'rate_limit_multiplier') diff --git a/src/api/v1/__init__.py b/src/api/v1/__init__.py index 52b37fe0..92549a55 100644 --- a/src/api/v1/__init__.py +++ b/src/api/v1/__init__.py @@ -8,7 +8,7 @@ from .embeddings.index import router as embeddings_router from .audio.index import router as audio_router from .billing.index import router as billing_router -from .billing.index import admin_router as billing_admin_router +from .billing.admin import admin_router as billing_admin_router from .webhooks.stripe import stripe_webhook_router from .webhooks.coinbase import coinbase_webhook_router from .wallet.index import router as wallet_router diff --git a/src/api/v1/billing/admin.py b/src/api/v1/billing/admin.py new file mode 100644 index 00000000..afef28d3 --- /dev/null +++ b/src/api/v1/billing/admin.py @@ -0,0 +1,409 @@ +""" +Admin billing API endpoints. +Protected by X-Admin-Secret header. Served on /admin/docs Swagger page. +""" +from fastapi import APIRouter, Depends, HTTPException, status, Query, Header +from sqlalchemy.ext.asyncio import AsyncSession +from typing import Optional +import secrets + +from ....db.database import get_db_session +from ....db.models import User +from ....dependencies import get_current_user +from ....services.billing_service import billing_service +from ....services.staking_service import staking_service +from ....crud import credits as credits_crud +from ....crud import user as user_crud +from ....schemas.billing import ( + BalanceResponse, + StakingSettingsRequest, + StakingSettingsResponse, + StakingRefreshResponse, + ManualTopupRequest, + ManualTopupResponse, + RateLimitMultiplierRequest, + RateLimitMultiplierResponse, +) +from ....core.logging_config import get_api_logger +from ....core.config import settings +from ....services.cache_service import cache_service + +logger = get_api_logger() + +admin_router = APIRouter(tags=["Billing Admin"]) + + +# === Admin Authentication === + +async def verify_billing_admin_secret( + x_admin_secret: Optional[str] = Header(None, alias="X-Admin-Secret") +) -> bool: + """ + Verify the admin secret for protected billing endpoints. + + Requires the X-Admin-Secret header to match BILLING_ADMIN_SECRET env variable. + """ + if not settings.BILLING_ADMIN_SECRET: + logger.warning( + "Billing admin endpoint called but BILLING_ADMIN_SECRET is not configured", + event_type="billing_admin_not_configured" + ) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Admin billing endpoints are not configured. Set BILLING_ADMIN_SECRET environment variable." + ) + + if not x_admin_secret: + logger.warning( + "Billing admin endpoint called without X-Admin-Secret header", + event_type="billing_admin_missing_secret" + ) + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Missing X-Admin-Secret header" + ) + + if not secrets.compare_digest(x_admin_secret, settings.BILLING_ADMIN_SECRET): + logger.warning( + "Billing admin endpoint called with invalid secret", + event_type="billing_admin_invalid_secret" + ) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Invalid admin secret" + ) + + return True + + +# === Staking Settings Endpoints === + +@admin_router.post("/staking/settings", response_model=StakingSettingsResponse) +async def set_staking_settings( + staking_request: StakingSettingsRequest, + db: AsyncSession = Depends(get_db_session), + current_user: User = Depends(get_current_user), + _admin_verified: bool = Depends(verify_billing_admin_secret), +): + """ + Set the daily staking allowance amount. + + **Admin/Dev endpoint** - Requires X-Admin-Secret header. + + This updates the configured daily amount but does NOT trigger an immediate refresh. + The new amount will take effect on the next daily refresh. + """ + try: + balance = await credits_crud.set_staking_daily_amount( + db=db, + user_id=current_user.id, + amount=staking_request.daily_amount, + ) + + logger.info( + "Staking settings updated by admin", + user_id=current_user.id, + daily_amount=str(staking_request.daily_amount), + event_type="billing_admin_staking_settings" + ) + + return StakingSettingsResponse( + daily_amount=balance.staking_daily_amount, + message="Staking daily amount updated", + ) + except Exception as e: + logger.error( + "Error in set_staking_settings endpoint", + user_id=current_user.id, + error=str(e), + error_type=type(e).__name__, + event_type="billing_staking_settings_error" + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error updating staking settings: {str(e)}" + ) + + +@admin_router.post("/staking/refresh", response_model=StakingRefreshResponse) +async def trigger_staking_refresh( + db: AsyncSession = Depends(get_db_session), + current_user: User = Depends(get_current_user), + _admin_verified: bool = Depends(verify_billing_admin_secret), +): + """ + Trigger the daily staking sync from Builders API. + + **Admin/Dev endpoint** - Requires X-Admin-Secret header. + + This operation: + 1. Fetches all stakers from Builders API + 2. Updates staked_amount for all linked wallets + 3. Calculates daily credits for each user (total_staked / 100) + 4. Creates ledger entries (transactions) for each user refresh + 5. Updates user balances + + Idempotent: Users already refreshed today will be skipped. + The staking bucket resets to the calculated daily amount (does not accumulate). + """ + logger.info( + "Staking sync triggered by admin", + user_id=current_user.id, + event_type="billing_admin_staking_sync" + ) + + try: + summary = await staking_service.run_daily_sync(db) + + return StakingRefreshResponse( + success=summary.get("success", True), + message="Staking sync completed successfully", + stakers_fetched=summary.get("stakers_fetched"), + total_wallets=summary.get("total_wallets"), + wallets_updated=summary.get("wallets_updated"), + users_processed=summary.get("users_processed"), + users_skipped=summary.get("users_skipped"), + users_failed=summary.get("users_failed"), + duration_seconds=summary.get("duration_seconds"), + ) + except Exception as e: + logger.error( + "Staking sync failed", + user_id=current_user.id, + error=str(e), + event_type="billing_admin_staking_sync_failed" + ) + return StakingRefreshResponse( + success=False, + message=f"Staking sync failed: {str(e)}", + ) + + +# === Manual Credit Adjustment === + +@admin_router.post("/credits/adjust", response_model=ManualTopupResponse) +async def adjust_credits( + request: ManualTopupRequest, + db: AsyncSession = Depends(get_db_session), + current_user: User = Depends(get_current_user), + _admin_verified: bool = Depends(verify_billing_admin_secret), +): + """ + Manually adjust credits for an account (add or subtract). + + **Admin/Dev endpoint** - Requires X-Admin-Secret header. + + - Positive amount: Adds credits (simulates a purchase) + - Negative amount: Subtracts credits (admin correction/chargeback) + - user_id (optional): Target user ID (database primary key integer) + - cognito_user_id (optional): Target Cognito user ID (UUID) + - If neither provided, adjusts current user's credits. + + This endpoint is for development/admin purposes to manage credits + without integrating with payment providers. + """ + try: + target_user_id = current_user.id + + if request.user_id is not None: + target_user_id = request.user_id + elif request.cognito_user_id is not None: + target_user = await user_crud.get_user_by_cognito_id(db, str(request.cognito_user_id)) + if not target_user: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"User with cognito_user_id {request.cognito_user_id} not found" + ) + target_user_id = target_user.id + + entry, new_balance = await billing_service.adjust_credits( + db=db, + user_id=target_user_id, + amount=request.amount_usd, + description=request.description, + ) + + action = "added" if request.amount_usd >= 0 else "subtracted" + logger.info( + f"Manual credit adjustment by admin: {action}", + user_id=str(target_user_id), + admin_user_id=str(current_user.id), + amount=str(request.amount_usd), + new_balance=str(new_balance), + event_type="billing_admin_credit_adjust" + ) + + return ManualTopupResponse( + ledger_entry_id=entry.id, + amount_added=request.amount_usd, + new_paid_balance=new_balance, + message=f"Credits {action} successfully", + ) + except HTTPException: + raise + except Exception as e: + logger.error( + "Error in adjust_credits endpoint", + user_id=current_user.id, + error=str(e), + error_type=type(e).__name__, + event_type="billing_credit_adjust_error" + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error adjusting credits: {str(e)}" + ) + + +# === Balance Reconciliation === + +@admin_router.post("/balance/reconcile", response_model=BalanceResponse) +async def reconcile_balance( + user_id: Optional[int] = Query(default=None, description="Target user ID (defaults to current user)"), + db: AsyncSession = Depends(get_db_session), + current_user: User = Depends(get_current_user), + _admin_verified: bool = Depends(verify_billing_admin_secret), +): + """ + Reconcile the cached balance against the ledger (source of truth). + + **Admin/Dev endpoint** - Requires X-Admin-Secret header. + + Fixes drift in `paid_pending_holds` caused by partial transaction failures + where the ledger entry was updated but the balance cache was not. + + Recomputes `paid_pending_holds` from the sum of all pending usage_hold entries in the ledger. + + Parameters: + - user_id: Target user ID (defaults to current user) + """ + try: + target_user_id = user_id if user_id is not None else current_user.id + + logger.info( + "Balance reconciliation triggered by admin", + target_user_id=target_user_id, + admin_user_id=current_user.id, + event_type="billing_admin_reconcile", + ) + + result = await billing_service.reconcile_balance(db, target_user_id) + return result + except Exception as e: + logger.error( + "Error in reconcile_balance endpoint", + user_id=current_user.id, + error=str(e), + error_type=type(e).__name__, + event_type="billing_reconcile_error", + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error reconciling balance: {str(e)}", + ) + + +# === Rate Limit Multiplier === + +@admin_router.post("/rate-limit/multiplier", response_model=RateLimitMultiplierResponse) +async def set_rate_limit_multiplier( + request: RateLimitMultiplierRequest, + db: AsyncSession = Depends(get_db_session), + current_user: User = Depends(get_current_user), + _admin_verified: bool = Depends(verify_billing_admin_secret), +): + """ + Set the rate limit multiplier for a user. + + **Admin endpoint** - Requires X-Admin-Secret header. + + The multiplier scales all RPM/TPM limits for the target user: + - 1.0 = default limits + - 2.0 = double the limits + - 0.5 = half the limits + + Applies to all models. Takes effect on the next request. + """ + try: + target_user = await user_crud.get_user_by_cognito_id(db, request.cognito_user_id) + if not target_user: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"User with cognito_user_id {request.cognito_user_id} not found", + ) + + target_user.rate_limit_multiplier = request.multiplier + await db.commit() + await db.refresh(target_user) + + await cache_service.delete("user", request.cognito_user_id) + + logger.info( + "Rate limit multiplier updated by admin", + target_user_id=target_user.id, + target_cognito_id=request.cognito_user_id, + admin_user_id=current_user.id, + multiplier=request.multiplier, + event_type="admin_rate_limit_multiplier_set", + ) + + return RateLimitMultiplierResponse( + cognito_user_id=target_user.cognito_user_id, + user_id=target_user.id, + multiplier=target_user.rate_limit_multiplier, + message=f"Rate limit multiplier set to {request.multiplier}", + ) + except HTTPException: + raise + except Exception as e: + logger.error( + "Error setting rate limit multiplier", + error=str(e), + error_type=type(e).__name__, + event_type="admin_rate_limit_multiplier_error", + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error setting rate limit multiplier: {str(e)}", + ) + + +@admin_router.get("/rate-limit/multiplier/{cognito_user_id}", response_model=RateLimitMultiplierResponse) +async def get_rate_limit_multiplier( + cognito_user_id: str, + db: AsyncSession = Depends(get_db_session), + current_user: User = Depends(get_current_user), + _admin_verified: bool = Depends(verify_billing_admin_secret), +): + """ + Get the current rate limit multiplier for a user. + + **Admin endpoint** - Requires X-Admin-Secret header. + """ + try: + target_user = await user_crud.get_user_by_cognito_id(db, cognito_user_id) + if not target_user: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"User with cognito_user_id {cognito_user_id} not found", + ) + + return RateLimitMultiplierResponse( + cognito_user_id=target_user.cognito_user_id, + user_id=target_user.id, + multiplier=target_user.rate_limit_multiplier, + message="Current rate limit multiplier", + ) + except HTTPException: + raise + except Exception as e: + logger.error( + "Error getting rate limit multiplier", + error=str(e), + error_type=type(e).__name__, + event_type="admin_rate_limit_multiplier_get_error", + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error getting rate limit multiplier: {str(e)}", + ) diff --git a/src/api/v1/billing/index.py b/src/api/v1/billing/index.py index f9ee6e0d..2f590669 100644 --- a/src/api/v1/billing/index.py +++ b/src/api/v1/billing/index.py @@ -1,22 +1,18 @@ """ Billing API endpoints for credits management. -Provides REST API for viewing balance, transactions, spending metrics, and staking settings. +Provides REST API for viewing balance, transactions, spending metrics, and overage settings. """ -from fastapi import APIRouter, Depends, HTTPException, Request, status, Query, Header -from fastapi.security import APIKeyHeader +from fastapi import APIRouter, Depends, HTTPException, Request, status, Query from sqlalchemy.ext.asyncio import AsyncSession -from typing import Optional, List -from datetime import datetime, date +from typing import Optional +from datetime import datetime from decimal import Decimal -import secrets from ....db.database import get_db_session from ....db.models import User, LedgerEntryType from ....dependencies import get_current_user, get_api_key_user from ....services.billing_service import billing_service -from ....services.staking_service import staking_service from ....crud import credits as credits_crud -from ....crud import user as user_crud from ....schemas.billing import ( BalanceResponse, LedgerEntryResponse, @@ -26,67 +22,16 @@ SpendingModeEnum, UsageListResponse, UsageEntryResponse, - StakingSettingsRequest, - StakingSettingsResponse, - StakingRefreshResponse, - ManualTopupRequest, - ManualTopupResponse, OverageSettingsRequest, OverageSettingsResponse, LedgerStatusEnum, LedgerEntryTypeEnum, ) from ....core.logging_config import get_api_logger -from ....core.config import settings logger = get_api_logger() router = APIRouter(tags=["Billing"]) -admin_router = APIRouter(tags=["Billing Admin"]) - - -# === Admin Authentication === - -async def verify_billing_admin_secret( - x_admin_secret: Optional[str] = Header(None, alias="X-Admin-Secret") -) -> bool: - """ - Verify the admin secret for protected billing endpoints. - - Requires the X-Admin-Secret header to match BILLING_ADMIN_SECRET env variable. - """ - if not settings.BILLING_ADMIN_SECRET: - logger.warning( - "Billing admin endpoint called but BILLING_ADMIN_SECRET is not configured", - event_type="billing_admin_not_configured" - ) - raise HTTPException( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Admin billing endpoints are not configured. Set BILLING_ADMIN_SECRET environment variable." - ) - - if not x_admin_secret: - logger.warning( - "Billing admin endpoint called without X-Admin-Secret header", - event_type="billing_admin_missing_secret" - ) - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Missing X-Admin-Secret header" - ) - - # Use constant-time comparison to prevent timing attacks - if not secrets.compare_digest(x_admin_secret, settings.BILLING_ADMIN_SECRET): - logger.warning( - "Billing admin endpoint called with invalid secret", - event_type="billing_admin_invalid_secret" - ) - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Invalid admin secret" - ) - - return True # === Balance Endpoint === @@ -501,239 +446,7 @@ async def list_usage_for_month( ) -# === Staking Settings Endpoints (Admin Protected) === - -@admin_router.post("/staking/settings", response_model=StakingSettingsResponse) -async def set_staking_settings( - staking_request: StakingSettingsRequest, - db: AsyncSession = Depends(get_db_session), - current_user: User = Depends(get_current_user), - _admin_verified: bool = Depends(verify_billing_admin_secret), -): - """ - Set the daily staking allowance amount. - - **Admin/Dev endpoint** - Requires X-Admin-Secret header. - - This updates the configured daily amount but does NOT trigger an immediate refresh. - The new amount will take effect on the next daily refresh. - """ - try: - balance = await credits_crud.set_staking_daily_amount( - db=db, - user_id=current_user.id, - amount=staking_request.daily_amount, - ) - - logger.info( - "Staking settings updated by admin", - user_id=current_user.id, - daily_amount=str(staking_request.daily_amount), - event_type="billing_admin_staking_settings" - ) - - return StakingSettingsResponse( - daily_amount=balance.staking_daily_amount, - message="Staking daily amount updated", - ) - except Exception as e: - logger.error( - "Error in set_staking_settings endpoint", - user_id=current_user.id, - error=str(e), - error_type=type(e).__name__, - event_type="billing_staking_settings_error" - ) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error updating staking settings: {str(e)}" - ) - - -@admin_router.post("/staking/refresh", response_model=StakingRefreshResponse) -async def trigger_staking_refresh( - db: AsyncSession = Depends(get_db_session), - current_user: User = Depends(get_current_user), - _admin_verified: bool = Depends(verify_billing_admin_secret), -): - """ - Trigger the daily staking sync from Builders API. - - **Admin/Dev endpoint** - Requires X-Admin-Secret header. - - This operation: - 1. Fetches all stakers from Builders API - 2. Updates staked_amount for all linked wallets - 3. Calculates daily credits for each user (total_staked / 100) - 4. Creates ledger entries (transactions) for each user refresh - 5. Updates user balances - - Idempotent: Users already refreshed today will be skipped. - The staking bucket resets to the calculated daily amount (does not accumulate). - """ - logger.info( - "Staking sync triggered by admin", - user_id=current_user.id, - event_type="billing_admin_staking_sync" - ) - - try: - summary = await staking_service.run_daily_sync(db) - - return StakingRefreshResponse( - success=summary.get("success", True), - message="Staking sync completed successfully", - stakers_fetched=summary.get("stakers_fetched"), - total_wallets=summary.get("total_wallets"), - wallets_updated=summary.get("wallets_updated"), - users_processed=summary.get("users_processed"), - users_skipped=summary.get("users_skipped"), - users_failed=summary.get("users_failed"), - duration_seconds=summary.get("duration_seconds"), - ) - except Exception as e: - logger.error( - "Staking sync failed", - user_id=current_user.id, - error=str(e), - event_type="billing_admin_staking_sync_failed" - ) - return StakingRefreshResponse( - success=False, - message=f"Staking sync failed: {str(e)}", - ) - - -# === Manual Credit Top-up (Admin/Dev endpoint) === - -@admin_router.post("/credits/adjust", response_model=ManualTopupResponse) -async def adjust_credits( - request: ManualTopupRequest, - db: AsyncSession = Depends(get_db_session), - current_user: User = Depends(get_current_user), - _admin_verified: bool = Depends(verify_billing_admin_secret), -): - """ - Manually adjust credits for an account (add or subtract). - - **Admin/Dev endpoint** - Requires X-Admin-Secret header. - - - Positive amount: Adds credits (simulates a purchase) - - Negative amount: Subtracts credits (admin correction/chargeback) - - user_id (optional): Target user ID (database primary key integer) - - cognito_user_id (optional): Target Cognito user ID (UUID) - - If neither provided, adjusts current user's credits. - - This endpoint is for development/admin purposes to manage credits - without integrating with payment providers. - """ - try: - # Determine target user ID - target_user_id = current_user.id - - if request.user_id is not None: - target_user_id = request.user_id - elif request.cognito_user_id is not None: - # Look up user by cognito_user_id - target_user = await user_crud.get_user_by_cognito_id(db, str(request.cognito_user_id)) - if not target_user: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"User with cognito_user_id {request.cognito_user_id} not found" - ) - target_user_id = target_user.id - - entry, new_balance = await billing_service.adjust_credits( - db=db, - user_id=target_user_id, - amount=request.amount_usd, - description=request.description, - ) - - action = "added" if request.amount_usd >= 0 else "subtracted" - logger.info( - f"Manual credit adjustment by admin: {action}", - user_id=str(target_user_id), - admin_user_id=str(current_user.id), - amount=str(request.amount_usd), - new_balance=str(new_balance), - event_type="billing_admin_credit_adjust" - ) - - message = f"Credits {action} successfully" - - return ManualTopupResponse( - ledger_entry_id=entry.id, - amount_added=request.amount_usd, - new_paid_balance=new_balance, - message=message, - ) - except HTTPException: - # Re-raise HTTP exceptions as-is - raise - except Exception as e: - logger.error( - "Error in adjust_credits endpoint", - user_id=current_user.id, - error=str(e), - error_type=type(e).__name__, - event_type="billing_credit_adjust_error" - ) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error adjusting credits: {str(e)}" - ) - - -# === Balance Reconciliation (Admin endpoint) === - -@admin_router.post("/balance/reconcile", response_model=BalanceResponse) -async def reconcile_balance( - user_id: Optional[int] = Query(default=None, description="Target user ID (defaults to current user)"), - db: AsyncSession = Depends(get_db_session), - current_user: User = Depends(get_current_user), - _admin_verified: bool = Depends(verify_billing_admin_secret), -): - """ - Reconcile the cached balance against the ledger (source of truth). - - **Admin/Dev endpoint** - Requires X-Admin-Secret header. - - Fixes drift in `paid_pending_holds` caused by partial transaction failures - where the ledger entry was updated but the balance cache was not. - - Recomputes `paid_pending_holds` from the sum of all pending usage_hold entries in the ledger. - - Parameters: - - user_id: Target user ID (defaults to current user) - """ - try: - target_user_id = user_id if user_id is not None else current_user.id - - logger.info( - "Balance reconciliation triggered by admin", - target_user_id=target_user_id, - admin_user_id=current_user.id, - event_type="billing_admin_reconcile", - ) - - result = await billing_service.reconcile_balance(db, target_user_id) - return result - except Exception as e: - logger.error( - "Error in reconcile_balance endpoint", - user_id=current_user.id, - error=str(e), - error_type=type(e).__name__, - event_type="billing_reconcile_error", - ) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error reconciling balance: {str(e)}", - ) - -# Export routers +# Export router billing_router = router -billing_admin_router = admin_router diff --git a/src/api/v1/chat/index.py b/src/api/v1/chat/index.py index 5aa7abfe..3d25b22a 100644 --- a/src/api/v1/chat/index.py +++ b/src/api/v1/chat/index.py @@ -273,6 +273,7 @@ async def _check_rate_limits( model=model, estimated_tokens=estimated_tokens, request_id=request_id, + multiplier=getattr(user, "rate_limit_multiplier", 1.0) or 1.0, ) if not result.allowed and result.status != RateLimitStatus.ERROR: diff --git a/src/api/v1/embeddings/index.py b/src/api/v1/embeddings/index.py index f8e24da0..e3a5506d 100644 --- a/src/api/v1/embeddings/index.py +++ b/src/api/v1/embeddings/index.py @@ -76,6 +76,7 @@ async def create_embeddings( db_api_key=db_api_key, requested_model=requested_model, request_data=request_data, + user=user, ) # Create billing hold @@ -259,6 +260,7 @@ async def _check_rate_limits( db_api_key: APIKey, requested_model: Optional[str], request_data: EmbeddingRequest, + user=None, ) -> RateLimitResult: """ Check rate limits (RPM and TPM) before processing the request. @@ -296,6 +298,7 @@ async def _check_rate_limits( model=requested_model, estimated_tokens=estimated_tokens, request_id=request_id, + multiplier=getattr(user, "rate_limit_multiplier", 1.0) or 1.0, ) if not result.allowed: diff --git a/src/db/models/user.py b/src/db/models/user.py index c2509874..4808261e 100644 --- a/src/db/models/user.py +++ b/src/db/models/user.py @@ -4,7 +4,7 @@ PII (email, name) lives exclusively in Cognito. The database only stores cognito_user_id as the identity key and application-level fields. """ -from sqlalchemy import Column, Integer, String, Boolean, DateTime +from sqlalchemy import Column, Integer, String, Boolean, DateTime, Float from sqlalchemy.orm import relationship from datetime import datetime @@ -20,6 +20,7 @@ class User(Base): is_active = Column(Boolean, default=True) age_verified = Column(Boolean, default=False, nullable=False) age_verified_at = Column(DateTime, nullable=True) + rate_limit_multiplier = Column(Float, nullable=False, default=1.0, server_default="1.0") created_at = Column(DateTime, default=datetime.utcnow) updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) diff --git a/src/dependencies.py b/src/dependencies.py index f80db152..b4dd320e 100644 --- a/src/dependencies.py +++ b/src/dependencies.py @@ -178,6 +178,7 @@ async def get_current_user( 'cognito_user_id': user.cognito_user_id, 'created_at': user.created_at.isoformat() if user.created_at else None, 'updated_at': user.updated_at.isoformat() if user.updated_at else None, + 'rate_limit_multiplier': user.rate_limit_multiplier, } await cache_service.set("user", cognito_user_id, user_cache_data, ttl_seconds=600) @@ -195,6 +196,7 @@ async def get_current_user( 'cognito_user_id': user.cognito_user_id, 'created_at': user.created_at.isoformat() if user.created_at else None, 'updated_at': user.updated_at.isoformat() if user.updated_at else None, + 'rate_limit_multiplier': user.rate_limit_multiplier, } await cache_service.set("user", cognito_user_id, user_cache_data, ttl_seconds=600) @@ -285,6 +287,7 @@ async def get_api_key_auth( 'cognito_user_id': test_user.cognito_user_id, 'created_at': test_user.created_at, 'updated_at': test_user.updated_at, + 'rate_limit_multiplier': test_user.rate_limit_multiplier, } # Fetch the test user's first active API key (if any) result = await db.execute( @@ -377,6 +380,7 @@ async def _build_auth_from_cache( cognito_user_id=ud.get("cognito_user_id"), created_at=datetime.fromisoformat(ud["created_at"]) if ud.get("created_at") else None, updated_at=datetime.fromisoformat(ud["updated_at"]) if ud.get("updated_at") else None, + rate_limit_multiplier=ud.get("rate_limit_multiplier", 1.0), ) # ── Deserialize APIKey ────────────────────────────────────────────── @@ -459,6 +463,7 @@ async def _build_auth_from_db( 'cognito_user_id': db_user.cognito_user_id, 'created_at': db_user.created_at, 'updated_at': db_user.updated_at, + 'rate_limit_multiplier': db_user.rate_limit_multiplier, } api_key_dict = { 'id': db_api_key.id, diff --git a/src/schemas/billing.py b/src/schemas/billing.py index 1e096fa0..d0b6da6c 100644 --- a/src/schemas/billing.py +++ b/src/schemas/billing.py @@ -343,3 +343,24 @@ class RefundResponse(BaseModel): amount_refunded: DecimalStr success: bool error: Optional[str] = None + + +# === Rate Limit Multiplier (Admin) === + +class RateLimitMultiplierRequest(BaseModel): + """Request for POST /billing/rate-limit/multiplier.""" + cognito_user_id: str = Field(..., description="Target user's Cognito ID (UUID)") + multiplier: float = Field( + ..., + gt=0, + le=100, + description="Rate limit multiplier (1.0 = default, 2.0 = double limits, 0.5 = half limits)", + ) + + +class RateLimitMultiplierResponse(BaseModel): + """Response for rate limit multiplier operations.""" + cognito_user_id: str + user_id: int + multiplier: float + message: str diff --git a/src/services/rate_limiting/rate_limit_service.py b/src/services/rate_limiting/rate_limit_service.py index 42c5eabc..0fa41bf1 100644 --- a/src/services/rate_limiting/rate_limit_service.py +++ b/src/services/rate_limiting/rate_limit_service.py @@ -14,6 +14,7 @@ from src.core.logging_config import get_core_logger from .types import ( + RateLimitConfig, RateLimitResult, RateLimitStatus, RateLimitHeaders, @@ -85,6 +86,7 @@ async def check_rate_limit( model: Optional[str] = None, estimated_tokens: int = 0, request_id: Optional[str] = None, + multiplier: float = 1.0, ) -> RateLimitResult: """ Check if a request is within rate limits and increment RPM only. @@ -98,6 +100,7 @@ async def check_rate_limit( model: The model being requested estimated_tokens: Estimated input tokens for TPM check (not incremented) request_id: Unique identifier for this request + multiplier: Per-user scaling factor for RPM/TPM limits (default 1.0) Returns: RateLimitResult with the check result @@ -112,7 +115,17 @@ async def check_rate_limit( await self.initialize() # Get rate limit config for this model - config, model_group = self._rules.get_config_for_model(model) + base_config, model_group = self._rules.get_config_for_model(model) + + # Apply per-user multiplier to scale limits + if multiplier != 1.0: + config = RateLimitConfig( + rpm=max(1, int(base_config.rpm * multiplier)), + tpm=max(1, int(base_config.tpm * multiplier)) if base_config.tpm > 0 else 0, + window_seconds=base_config.window_seconds, + ) + else: + config = base_config # Calculate window boundary and reset time # Window is aligned to fixed intervals (e.g., each minute boundary) @@ -147,6 +160,7 @@ async def check_rate_limit( rpm_limit=rpm_limit, retry_after=retry_after, reset_at=reset_at, + multiplier=multiplier, event_type="rate_limit_exceeded_rpm", ) @@ -181,6 +195,7 @@ async def check_rate_limit( estimated_tokens=estimated_tokens, retry_after=retry_after, reset_at=reset_at, + multiplier=multiplier, event_type="rate_limit_exceeded_tpm", )