Skip to content

Commit 7d11856

Browse files
committed
feat(credits): add per-user rank-based credit limits, global monthly cap, and pre-authorization to prevent overspend
1 parent 69f0fd6 commit 7d11856

3 files changed

Lines changed: 241 additions & 0 deletions

File tree

src/discord_rag_bot/config.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,22 @@ class Settings(BaseSettings):
8787
pg_password: Optional[str] = None
8888
pg_database: Optional[str] = None
8989

90+
# Credits & Budgeting
91+
credit_enabled: bool = False
92+
credit_period: str = "month" # month|rolling
93+
credit_global_cap: int = 100000 # total credits per period across all users
94+
credit_default_limit: int = 1000 # per-user default credits per period
95+
# JSON maps, e.g.: {"gold": 5000, "silver": 2000}
96+
credit_rank_limits: dict[str, int] = {}
97+
# Map role name -> rank (JSON), e.g.: {"Gold": "gold", "VIP": "gold"}
98+
credit_role_ranks_by_name: dict[str, str] = {}
99+
# Map role ID (as string) -> rank (JSON), e.g.: {"123456": "gold"}
100+
credit_role_ranks_by_id: dict[str, str] = {}
101+
# Estimation: ~tokens per char and expected output tokens; 1 credit per 1k tokens by default
102+
credit_tokens_per_char: float = 0.25
103+
credit_est_output_tokens: int = 600
104+
credit_per_1k_tokens: float = 1.0
105+
90106
@property
91107
def db(self) -> Db:
92108
if not all([self.pg_host, self.pg_user, self.pg_password, self.pg_database]):
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
from datetime import datetime, timezone
5+
from typing import Optional, Tuple
6+
7+
import asyncpg
8+
9+
from ..config import settings
10+
11+
12+
def _dsn() -> str:
13+
db = settings.db
14+
return f"postgresql://{db.user}:{db.password}@{db.host}:{db.port}/{db.database}"
15+
16+
17+
def _period_start(dt: Optional[datetime] = None) -> datetime:
18+
dt = dt or datetime.now(timezone.utc)
19+
# Monthly period boundary (first day of month at 00:00 UTC)
20+
return datetime(dt.year, dt.month, 1, tzinfo=timezone.utc)
21+
22+
23+
async def _ensure_async(conn: asyncpg.Connection) -> None:
24+
await conn.execute(
25+
"""
26+
CREATE TABLE IF NOT EXISTS bot_credits_user (
27+
user_id BIGINT NOT NULL,
28+
period_start TIMESTAMPTZ NOT NULL,
29+
used_credits INTEGER NOT NULL DEFAULT 0,
30+
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
31+
PRIMARY KEY (user_id, period_start)
32+
);
33+
CREATE INDEX IF NOT EXISTS idx_bot_credits_user_period ON bot_credits_user(period_start);
34+
CREATE TABLE IF NOT EXISTS bot_credits_global (
35+
period_start TIMESTAMPTZ PRIMARY KEY,
36+
used_credits INTEGER NOT NULL DEFAULT 0,
37+
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
38+
);
39+
"""
40+
)
41+
42+
43+
def ensure_store() -> None:
44+
async def run():
45+
conn = await asyncpg.connect(_dsn())
46+
try:
47+
await _ensure_async(conn)
48+
finally:
49+
await conn.close()
50+
51+
asyncio.run(run())
52+
53+
54+
def _credits_for_text_chars(in_chars: int, out_est_tokens: int) -> int:
55+
tokens_in = int(round(in_chars * float(settings.credit_tokens_per_char)))
56+
tokens_total = tokens_in + int(out_est_tokens)
57+
per_k = float(settings.credit_per_1k_tokens) or 1.0
58+
# credits ~ tokens_total / 1000 * per_k
59+
credits = int((tokens_total + 999) // 1000 * per_k)
60+
return max(1, credits)
61+
62+
63+
def estimate_credits_for_question(question: str) -> int:
64+
return _credits_for_text_chars(len(question or ""), int(settings.credit_est_output_tokens))
65+
66+
67+
def pre_authorize(user_id: int, est_credits: int, *, now: Optional[datetime] = None, user_limit_override: Optional[int] = None) -> Tuple[bool, int, int]:
68+
"""Reserve credits if within per-user limit and global cap.
69+
70+
Returns (ok, user_used_after, global_used_after). If ok is False, no change.
71+
"""
72+
period = _period_start(now)
73+
74+
async def run() -> Tuple[bool, int, int]:
75+
conn = await asyncpg.connect(_dsn())
76+
try:
77+
await _ensure_async(conn)
78+
async with conn.transaction():
79+
# Load current usage
80+
row_u = await conn.fetchrow(
81+
"SELECT used_credits FROM bot_credits_user WHERE user_id=$1 AND period_start=$2",
82+
int(user_id),
83+
period,
84+
)
85+
user_used = int(row_u[0]) if row_u else 0
86+
row_g = await conn.fetchrow(
87+
"SELECT used_credits FROM bot_credits_global WHERE period_start=$1",
88+
period,
89+
)
90+
global_used = int(row_g[0]) if row_g else 0
91+
92+
user_limit = int(user_limit_override or settings.credit_default_limit)
93+
# Limit can be overridden by ranks; resolved outside and passed? For now use default; the caller can pass a higher est or pre-check limit.
94+
# Enforce limits
95+
if settings.credit_enabled:
96+
if user_used + est_credits > user_limit:
97+
return False, user_used, global_used
98+
if global_used + est_credits > int(settings.credit_global_cap):
99+
return False, user_used, global_used
100+
101+
# Upsert increments
102+
await conn.execute(
103+
"""
104+
INSERT INTO bot_credits_user(user_id, period_start, used_credits)
105+
VALUES ($1, $2, $3)
106+
ON CONFLICT (user_id, period_start)
107+
DO UPDATE SET used_credits = bot_credits_user.used_credits + EXCLUDED.used_credits, updated_at=NOW()
108+
""",
109+
int(user_id),
110+
period,
111+
int(est_credits),
112+
)
113+
await conn.execute(
114+
"""
115+
INSERT INTO bot_credits_global(period_start, used_credits)
116+
VALUES ($1, $2)
117+
ON CONFLICT (period_start)
118+
DO UPDATE SET used_credits = bot_credits_global.used_credits + EXCLUDED.used_credits, updated_at=NOW()
119+
""",
120+
period,
121+
int(est_credits),
122+
)
123+
return True, user_used + est_credits, global_used + est_credits
124+
finally:
125+
await conn.close()
126+
127+
return asyncio.run(run())
128+
129+
130+
def adjust_usage(user_id: int, delta: int, *, now: Optional[datetime] = None) -> None:
131+
"""Adjust usage by delta (can be negative). Best-effort."""
132+
if delta == 0:
133+
return
134+
period = _period_start(now)
135+
136+
async def run():
137+
conn = await asyncpg.connect(_dsn())
138+
try:
139+
await _ensure_async(conn)
140+
async with conn.transaction():
141+
await conn.execute(
142+
"""
143+
INSERT INTO bot_credits_user(user_id, period_start, used_credits)
144+
VALUES ($1, $2, $3)
145+
ON CONFLICT (user_id, period_start)
146+
DO UPDATE SET used_credits = GREATEST(0, bot_credits_user.used_credits + EXCLUDED.used_credits), updated_at=NOW()
147+
""",
148+
int(user_id), period, int(delta)
149+
)
150+
await conn.execute(
151+
"""
152+
INSERT INTO bot_credits_global(period_start, used_credits)
153+
VALUES ($1, $2)
154+
ON CONFLICT (period_start)
155+
DO UPDATE SET used_credits = GREATEST(0, bot_credits_global.used_credits + EXCLUDED.used_credits), updated_at=NOW()
156+
""",
157+
period, int(delta)
158+
)
159+
finally:
160+
await conn.close()
161+
162+
asyncio.run(run())
163+
164+
165+
def resolve_user_limit_from_roles(*, member_roles: list[tuple[int, str]]) -> int:
166+
"""Compute per-user credit limit based on configured rank mappings and role names/ids.
167+
168+
member_roles: list of (role_id, role_name)
169+
"""
170+
# Determine rank candidates from roles
171+
by_name = settings.credit_role_ranks_by_name or {}
172+
by_id = settings.credit_role_ranks_by_id or {}
173+
ranks: set[str] = set()
174+
for rid, name in member_roles:
175+
if name and name in by_name:
176+
ranks.add(by_name[name])
177+
s_rid = str(rid)
178+
if s_rid in by_id:
179+
ranks.add(by_id[s_rid])
180+
# Map ranks to limits and take max
181+
rank_limits = settings.credit_rank_limits or {}
182+
max_limit = int(settings.credit_default_limit)
183+
for r in ranks:
184+
lim = rank_limits.get(r)
185+
if isinstance(lim, int) and lim > max_limit:
186+
max_limit = lim
187+
return max_limit

src/discord_rag_bot/listeners/chat.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from ..infrastructure.gating import should_use_rag
1313
from ..infrastructure.language import get_language_hint
1414
from rag_core.metrics import discord_messages_processed_total, rag_queries_total
15+
from ..infrastructure.credits import estimate_credits_for_question, pre_authorize, adjust_usage, resolve_user_limit_from_roles
1516

1617

1718
class ChatListenerCog(commands.Cog):
@@ -145,6 +146,30 @@ def run_query() -> tuple[str, list[str]]:
145146

146147
# Send friendly placeholder reply and then edit when ready
147148
placeholder_msg = await message.reply("🧠 Einen kleinen Moment – ich suche passende Informationen und schreibe die Antwort …")
149+
# Credits: pre-authorize based on estimate (if enabled)
150+
est_credits = 0
151+
reserved = 0
152+
user_limit = None
153+
if getattr(self.bot, "services", None) and getattr(self.bot.services, "rag", None): # basic sanity
154+
try:
155+
from ..config import settings as _settings
156+
if getattr(_settings, "credit_enabled", False):
157+
est_credits = estimate_credits_for_question(question)
158+
# Resolve user limit by roles
159+
roles = []
160+
if isinstance(message.author, discord.Member):
161+
for r in message.author.roles:
162+
roles.append((int(getattr(r, "id", 0) or 0), str(getattr(r, "name", "") or "")))
163+
user_limit = resolve_user_limit_from_roles(member_roles=roles)
164+
# Temporarily set default limit to computed user_limit for reservation
165+
# Pre-authorize in a thread to avoid event-loop blocking
166+
ok, _, _ = await asyncio.to_thread(pre_authorize, int(message.author.id), int(est_credits), user_limit_override=int(user_limit))
167+
if not ok:
168+
await placeholder_msg.edit(content="❌ Keine Credits mehr verfügbar (Limit oder globales Budget erreicht). Bitte später erneut versuchen.")
169+
return
170+
reserved = est_credits
171+
except Exception:
172+
pass
148173
# Save the incoming user message into memory (best-effort)
149174
try:
150175
self.bot.services.memory.record_user_message( # type: ignore[attr-defined]
@@ -170,6 +195,19 @@ def run_query() -> tuple[str, list[str]]:
170195
except Exception:
171196
# Fallback: send a fresh reply if edit fails
172197
await message.reply(clip_discord_message(text))
198+
# Adjust credits after answer based on actual output length (best-effort)
199+
try:
200+
if reserved > 0:
201+
from ..config import settings as _settings
202+
if getattr(_settings, "credit_enabled", False):
203+
# crude estimate: input + actual output
204+
out_tokens = int(len(text or "") * float(getattr(_settings, "credit_tokens_per_char", 0.25)))
205+
final = int((int(len(question) * float(_settings.credit_tokens_per_char)) + out_tokens + 999) // 1000 * float(_settings.credit_per_1k_tokens))
206+
delta = max(0, int(final) - int(reserved))
207+
if delta != 0:
208+
await asyncio.to_thread(pre_authorize, int(message.author.id), int(delta))
209+
except Exception:
210+
pass
173211
# Save bot answer and update summary in background (best-effort)
174212
try:
175213
self.bot.services.memory.record_assistant_message( # type: ignore[attr-defined]

0 commit comments

Comments
 (0)