|
| 1 | +# SPDX-License-Identifier: AGPL-3.0-or-later |
| 2 | +"""Per-user monthly API-call quota — write-side and gate-side. |
| 3 | +
|
| 4 | +The pricing page advertises 500/month (Free), 10 000/month (Pro), |
| 5 | +100 000/month (Business). Until now the limits in |
| 6 | +``app/core/quotas.py`` were informational; this module wires them up |
| 7 | +so the system actually enforces what the pricing page promises. |
| 8 | +
|
| 9 | +Two responsibilities: |
| 10 | +
|
| 11 | +1. **Writer** — :func:`record_usage` inserts one ``UsageRecord`` row |
| 12 | + per successful conversion / compression. Called from the success |
| 13 | + branch of ``_do_convert`` / ``_do_compress`` (single + batch). |
| 14 | +2. **Gate** — :func:`enforce_monthly_quota` counts the rows for the |
| 15 | + current calendar month and raises ``HTTPException(429)`` when the |
| 16 | + user is at or above their tier limit. Called *after* the |
| 17 | + concurrency slot is acquired and *before* file I/O begins, so a |
| 18 | + refused request never touches the temp dir. |
| 19 | +
|
| 20 | +Session ownership mirrors :mod:`app.core.audit` and |
| 21 | +:mod:`app.core.metrics`: each helper opens its own |
| 22 | +``AsyncSession`` from ``AsyncSessionLocal``. The route does not need |
| 23 | +to thread a ``db=`` parameter through. Tests pass an explicit |
| 24 | +``db=`` for the in-memory SQLite engine. |
| 25 | +
|
| 26 | +Time window |
| 27 | +----------- |
| 28 | +Calendar month, UTC. Picked because: |
| 29 | +
|
| 30 | +* It matches how the pricing page is read ("you get 10 k per month"). |
| 31 | +* Users see their reset boundary in their own calendar (1st of the |
| 32 | + next month at 00:00 UTC) — cheap to display, easy to remember. |
| 33 | +* A rolling 30-day window is smoother under load but harder to |
| 34 | + communicate ("when does my quota reset?" → "depends which calls |
| 35 | + you made"). Not worth the cognitive cost for an MVP. |
| 36 | +
|
| 37 | +Anonymous tier (no ``user_id``) skips both the writer and the gate — |
| 38 | +the per-IP rate-limiter (10/min) is the only constraint. |
| 39 | +``Enterprise`` (``api_calls_per_month=None``) is unlimited and is |
| 40 | +also exempt from the gate; ``record_usage`` still writes its row so |
| 41 | +the cockpit gets accurate counts. |
| 42 | +""" |
| 43 | + |
| 44 | +from __future__ import annotations |
| 45 | + |
| 46 | +import logging |
| 47 | +import uuid |
| 48 | +from datetime import datetime, timezone |
| 49 | + |
| 50 | +from fastapi import HTTPException, status |
| 51 | +from sqlalchemy import func, select |
| 52 | +from sqlalchemy.ext.asyncio import AsyncSession |
| 53 | + |
| 54 | +from app.core.quotas import get_quota |
| 55 | +from app.db.base import AsyncSessionLocal |
| 56 | +from app.db.models import UsageRecord, User |
| 57 | + |
| 58 | +logger = logging.getLogger(__name__) |
| 59 | + |
| 60 | + |
| 61 | +def _month_start(now: datetime) -> datetime: |
| 62 | + """Return the UTC timestamp at the start of the given moment's calendar month.""" |
| 63 | + return now.astimezone(timezone.utc).replace(day=1, hour=0, minute=0, second=0, microsecond=0) |
| 64 | + |
| 65 | + |
| 66 | +def _next_month_start(now: datetime) -> datetime: |
| 67 | + """Return the UTC timestamp at the start of the *following* calendar month. |
| 68 | +
|
| 69 | + Used for the ``Retry-After`` header so a refused caller knows when their |
| 70 | + quota resets. Computed as "1st of (this month + 1)" — December rolls |
| 71 | + forward to January of next year. |
| 72 | + """ |
| 73 | + month_start = _month_start(now) |
| 74 | + if month_start.month == 12: |
| 75 | + return month_start.replace(year=month_start.year + 1, month=1) |
| 76 | + return month_start.replace(month=month_start.month + 1) |
| 77 | + |
| 78 | + |
| 79 | +async def monthly_call_count( |
| 80 | + db: AsyncSession, |
| 81 | + user_id: uuid.UUID, |
| 82 | + *, |
| 83 | + now: datetime | None = None, |
| 84 | +) -> int: |
| 85 | + """Count this user's ``UsageRecord`` rows for the current calendar month. |
| 86 | +
|
| 87 | + The index on ``(user_id, timestamp)`` (migration 007) makes this a fast |
| 88 | + range scan even at 100 k rows/user/month for the Business tier. |
| 89 | + """ |
| 90 | + if now is None: |
| 91 | + now = datetime.now(timezone.utc) |
| 92 | + start = _month_start(now) |
| 93 | + stmt = ( |
| 94 | + select(func.count()) |
| 95 | + .select_from(UsageRecord) |
| 96 | + .where( |
| 97 | + UsageRecord.user_id == user_id, |
| 98 | + UsageRecord.timestamp >= start, |
| 99 | + ) |
| 100 | + ) |
| 101 | + result = await db.execute(stmt) |
| 102 | + return int(result.scalar() or 0) |
| 103 | + |
| 104 | + |
| 105 | +async def enforce_monthly_quota( |
| 106 | + user: User | None, |
| 107 | + *, |
| 108 | + db: AsyncSession | None = None, |
| 109 | + now: datetime | None = None, |
| 110 | +) -> None: |
| 111 | + """Raise ``HTTPException(429)`` if the user is at or above their monthly limit. |
| 112 | +
|
| 113 | + No-op when: |
| 114 | +
|
| 115 | + * ``user is None`` — anonymous tier; per-IP rate-limiter is the |
| 116 | + only gate. |
| 117 | + * ``user.tier`` is ``enterprise`` or otherwise has |
| 118 | + ``api_calls_per_month=None`` — unlimited tier. |
| 119 | + * ``AsyncSessionLocal is None`` and no ``db=`` passed — Community |
| 120 | + Edition without ``DATABASE_URL``; nothing to count against. |
| 121 | + """ |
| 122 | + if user is None: |
| 123 | + return |
| 124 | + |
| 125 | + tier = user.tier.value if hasattr(user.tier, "value") else str(user.tier) |
| 126 | + quota = get_quota(tier) |
| 127 | + if quota.api_calls_per_month is None: |
| 128 | + return |
| 129 | + |
| 130 | + if now is None: |
| 131 | + now = datetime.now(timezone.utc) |
| 132 | + |
| 133 | + if db is not None: |
| 134 | + used = await monthly_call_count(db, user.id, now=now) |
| 135 | + else: |
| 136 | + if AsyncSessionLocal is None: |
| 137 | + return |
| 138 | + async with AsyncSessionLocal() as session: |
| 139 | + used = await monthly_call_count(session, user.id, now=now) |
| 140 | + |
| 141 | + if used >= quota.api_calls_per_month: |
| 142 | + retry_at = _next_month_start(now) |
| 143 | + retry_after_seconds = max(int((retry_at - now).total_seconds()), 1) |
| 144 | + detail = ( |
| 145 | + f"Monthly API call limit reached ({quota.api_calls_per_month} per month " |
| 146 | + f"for tier '{tier}'). Quota resets {retry_at.isoformat()}. Upgrade your plan " |
| 147 | + "or wait until the reset to continue." |
| 148 | + ) |
| 149 | + # 429 is the conventional rate-limit code; Retry-After is in |
| 150 | + # seconds per RFC 9110 § 10.2.3. |
| 151 | + raise HTTPException( |
| 152 | + status_code=status.HTTP_429_TOO_MANY_REQUESTS, |
| 153 | + detail=detail, |
| 154 | + headers={"Retry-After": str(retry_after_seconds)}, |
| 155 | + ) |
| 156 | + |
| 157 | + |
| 158 | +async def record_usage( |
| 159 | + *, |
| 160 | + user_id: uuid.UUID | None, |
| 161 | + api_key_id: uuid.UUID | None, |
| 162 | + endpoint: str, |
| 163 | + file_size_bytes: int, |
| 164 | + duration_ms: int, |
| 165 | + db: AsyncSession | None = None, |
| 166 | +) -> None: |
| 167 | + """Append one ``UsageRecord`` for a successful conversion / compression. |
| 168 | +
|
| 169 | + Fire-and-forget by design — failures are logged at ``WARNING`` and |
| 170 | + never bubble into the request path. The audit log |
| 171 | + (:mod:`app.core.audit`) is the source-of-truth for compliance |
| 172 | + purposes; ``UsageRecord`` is the lightweight per-user counter that |
| 173 | + powers the monthly-quota gate and the dashboard usage display. |
| 174 | +
|
| 175 | + Anonymous tier (``user_id is None`` and ``api_key_id is None``) |
| 176 | + is a no-op — there is no caller identity to attribute the row to. |
| 177 | + """ |
| 178 | + if user_id is None and api_key_id is None: |
| 179 | + return |
| 180 | + |
| 181 | + if db is not None: |
| 182 | + await _insert(db, user_id, api_key_id, endpoint, file_size_bytes, duration_ms) |
| 183 | + return |
| 184 | + |
| 185 | + if AsyncSessionLocal is None: |
| 186 | + return |
| 187 | + |
| 188 | + try: |
| 189 | + async with AsyncSessionLocal() as session: |
| 190 | + await _insert(session, user_id, api_key_id, endpoint, file_size_bytes, duration_ms) |
| 191 | + except Exception: |
| 192 | + logger.warning( |
| 193 | + "record_usage failed for endpoint=%s user_id=%s", |
| 194 | + endpoint, |
| 195 | + user_id, |
| 196 | + exc_info=True, |
| 197 | + ) |
| 198 | + |
| 199 | + |
| 200 | +async def _insert( |
| 201 | + db: AsyncSession, |
| 202 | + user_id: uuid.UUID | None, |
| 203 | + api_key_id: uuid.UUID | None, |
| 204 | + endpoint: str, |
| 205 | + file_size_bytes: int, |
| 206 | + duration_ms: int, |
| 207 | +) -> None: |
| 208 | + """Single INSERT, owned-session caller commits.""" |
| 209 | + row = UsageRecord( |
| 210 | + user_id=user_id, |
| 211 | + api_key_id=api_key_id, |
| 212 | + endpoint=endpoint, |
| 213 | + file_size_bytes=file_size_bytes, |
| 214 | + duration_ms=duration_ms, |
| 215 | + ) |
| 216 | + db.add(row) |
| 217 | + await db.commit() |
0 commit comments