Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ FASTAPI_SECRET_KEY=your_super_secret_key_here
# Example: redis://localhost:6379/0
FALKORDB_URL=redis://localhost:6379/0 # REQUIRED - change to your FalkorDB URL

# Optional: name of the central user-management graph (User/Identity/Token/UsageEvent).
# Defaults to "Organizations"; override to share or isolate that graph.
# ORGANIZATIONS_GRAPH=Organizations

# Optional: separate host/port settings for local testing (only used if FALKORDB_URL is not set)
# FALKORDB_HOST=localhost
# FALKORDB_PORT=6379
Expand Down
9 changes: 5 additions & 4 deletions api/auth/user_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from fastapi import Request, HTTPException, status
from pydantic import BaseModel
from api.config import ORGANIZATIONS_GRAPH
from api.extensions import db

# Get secret key for sessions
Expand Down Expand Up @@ -45,7 +46,7 @@ async def _get_user_info(api_token: str) -> Optional[Dict[str, Any]]:

try:
# Select the Organizations graph
organizations_graph = db.select_graph("Organizations")
organizations_graph = db.select_graph(ORGANIZATIONS_GRAPH)

result = await organizations_graph.query(
query,
Expand Down Expand Up @@ -83,7 +84,7 @@ async def delete_user_token(api_token: str):
"""
try:
# Select the Organizations graph
organizations_graph = db.select_graph("Organizations")
organizations_graph = db.select_graph(ORGANIZATIONS_GRAPH)

await organizations_graph.query(
query,
Expand Down Expand Up @@ -117,7 +118,7 @@ async def ensure_user_in_organizations( # pylint: disable=too-many-arguments, d
return validation_result

try:
organizations_graph = db.select_graph("Organizations")
organizations_graph = db.select_graph(ORGANIZATIONS_GRAPH)
first_name, last_name = _extract_name_parts(name)

merge_query = _build_user_merge_query()
Expand Down Expand Up @@ -166,7 +167,7 @@ async def update_identity_last_login(provider, provider_user_id):
return

try:
organizations_graph = db.select_graph("Organizations")
organizations_graph = db.select_graph(ORGANIZATIONS_GRAPH)
update_query = """
MATCH (identity:Identity {provider: $provider, provider_user_id: $provider_user_id})
SET identity.last_login = timestamp()
Expand Down
5 changes: 5 additions & 0 deletions api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
# Ensure .env is loaded before Config reads os.getenv() at class definition time
load_dotenv()

# Central user-management graph holding User/Identity/Token (and UsageEvent)
# nodes. Single source of truth for the name used across auth, tokens, and
# usage tracking — override with the ORGANIZATIONS_GRAPH env var.
ORGANIZATIONS_GRAPH = os.getenv("ORGANIZATIONS_GRAPH") or "Organizations"

# Configure litellm logging to prevent sensitive data leakage
def configure_litellm_logging():
"""Configure litellm to suppress completion logs."""
Expand Down
7 changes: 4 additions & 3 deletions api/routes/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from pydantic import BaseModel

from api.auth.user_management import delete_user_token, ensure_user_in_organizations, validate_user
from api.config import ORGANIZATIONS_GRAPH
from api.extensions import db

# Import GENERAL_PREFIX from graphs route
Expand Down Expand Up @@ -136,7 +137,7 @@ def _validate_email(email: str) -> bool:
async def _set_mail_hash(email: str, password_hash: str) -> bool:
"""Set email hash for the user in the database."""
try:
organizations_graph = db.select_graph("Organizations")
organizations_graph = db.select_graph(ORGANIZATIONS_GRAPH)

# Sanitize inputs for logging
safe_email = _sanitize_for_log(email)
Expand Down Expand Up @@ -178,7 +179,7 @@ async def _email_account_exists(email: str) -> bool:
Exceptions are intentionally not swallowed so callers fail closed (treat the
account as existing / abort the signup) rather than issuing a session token.
"""
organizations_graph = db.select_graph("Organizations")
organizations_graph = db.select_graph(ORGANIZATIONS_GRAPH)
# Use a UNION of two label-scoped lookups so each side hits the (label, email)
# index and short-circuits with LIMIT 1. This avoids both a full-graph scan and
# the Cartesian product that two chained OPTIONAL MATCH clauses would produce.
Expand All @@ -204,7 +205,7 @@ def _is_request_secure(request: Request) -> bool:
async def _authenticate_email_user(email: str, password: str):
"""Authenticate an email user."""
try:
organizations_graph = db.select_graph("Organizations")
organizations_graph = db.select_graph(ORGANIZATIONS_GRAPH)

# Find user by email
query = """
Expand Down
38 changes: 33 additions & 5 deletions api/routes/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,39 @@
from api.graph import get_user_rules, set_user_rules
from api.auth.user_management import token_required
from api.routes.tokens import UNAUTHORIZED_RESPONSE
from api.routes.usage_tracking import record_query_usage_background

graphs_router = APIRouter(tags=["Graphs & Databases"])


async def _serialize_pipeline(gen):
async def _serialize_pipeline(gen, *, user_id, namespaced):
Comment thread
galshubeli marked this conversation as resolved.
"""Serialize pipeline events to the wire format and stop on ``_Final``.

Pure encoding loop — no exception handling here. Each route handler
wraps iteration in its own ``try/except`` so the broad-except (which
emits a generic error event without leaking stack data) lives in the
route function CodeQL already accepts, not in a shared helper.

Always-on usage tracking lives here in the route layer (not in
``api/core``) so it ships with the hosted app, never the PyPI SDK. Exactly
one event is recorded per query, derived from the final ``QueryResult`` —
skipping the destructive-confirmation prompt, which has no outcome yet (the
``/confirm`` call records that query).
"""
final = None
async for event in gen:
if isinstance(event, _Final):
return
final = event.value
break
yield json.dumps(event) + MESSAGE_DELIMITER
if final is not None and not final.requires_confirmation:
# "Success" = a valid query that ran without error. error_message is
# None alone isn't enough: off-topic / not-SQL-translatable results
# carry is_valid=False with no error, and must not inflate success_count.
record_query_usage_background(
user_id, namespaced,
success=final.is_valid and final.error_message is None,
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Comment thread
galshubeli marked this conversation as resolved.


class GraphData(BaseModel):
Expand Down Expand Up @@ -170,7 +187,7 @@ async def query_graph(
# the StreamingResponse is iterated. Surfacing client errors as HTTP 400
# requires a synchronous check before we hand the stream to the response.
try:
graph_name(request.state.user_id, graph_id)
namespaced = graph_name(request.state.user_id, graph_id)
validate_and_truncate_chat(chat_data)
validate_custom_model(getattr(chat_data, "custom_model", None))
except InvalidArgumentError as iae:
Expand All @@ -180,13 +197,19 @@ async def query_graph(
async def stream():
try:
async for chunk in _serialize_pipeline(
run_query(request.state.user_id, graph_id, chat_data)
run_query(request.state.user_id, graph_id, chat_data),
user_id=request.state.user_id, namespaced=namespaced,
):
yield chunk
except Exception: # pylint: disable=broad-exception-caught
# Don't leak stack traces (CodeQL: information exposure through
# exception). Log internally; emit a generic error event.
logging.exception("Streaming query failed")
# Pipeline crashed before _Final, so _serialize_pipeline didn't
# record — count this attempt as a failure here.
record_query_usage_background(
request.state.user_id, namespaced, success=False
)
yield json.dumps({
"type": "error",
"final_response": True,
Expand Down Expand Up @@ -225,12 +248,17 @@ async def confirm_destructive_operation(
async def stream():
try:
async for chunk in _serialize_pipeline(
run_confirmed(request.state.user_id, graph_id, confirm_data)
run_confirmed(request.state.user_id, graph_id, confirm_data),
user_id=request.state.user_id, namespaced=namespaced,
):
yield chunk
except Exception: # pylint: disable=broad-exception-caught
# See note on the query endpoint above (CodeQL).
logging.exception("Streaming confirmed-destructive query failed")
# Pipeline crashed before _Final — record the failed attempt here.
record_query_usage_background(
request.state.user_id, namespaced, success=False
)
yield json.dumps({
"type": "error",
"final_response": True,
Expand Down
5 changes: 3 additions & 2 deletions api/routes/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pydantic import BaseModel

from api.auth.user_management import token_required
from api.config import ORGANIZATIONS_GRAPH
from api.extensions import db

UNAUTHORIZED_RESPONSE = {"description": "Unauthorized - Please log in or provide a valid API token"}
Expand Down Expand Up @@ -80,7 +81,7 @@ async def list_tokens(request: Request) -> TokenListResponse:
user_email = request.state.user_email

# Get tokens from Organizations graph
organizations_graph = db.select_graph("Organizations")
organizations_graph = db.select_graph(ORGANIZATIONS_GRAPH)

# Get user information by API token and then get all associated tokens that connected
# to the Identity of provider='api'
Expand Down Expand Up @@ -120,7 +121,7 @@ async def delete_token(request: Request, token_id: str) -> JSONResponse:
user_email = request.state.user_email

# Delete token from Organizations graph
organizations_graph = db.select_graph("Organizations")
organizations_graph = db.select_graph(ORGANIZATIONS_GRAPH)

# Delete the token
delete_query = """
Expand Down
148 changes: 148 additions & 0 deletions api/routes/usage_tracking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
"""Always-on, provider-agnostic per-query usage tracking.

This is deliberately independent of the optional LLM conversational-memory
feature (``api/memory/graphiti_tool.py``). Memory writes are opt-in
(``use_memory``), gated to OpenAI/Azure providers, and lazily created — so
they cannot be used to measure adoption. This module records *every* query,
regardless of provider or the ``use_memory`` flag, onto the central
``Organizations`` graph (which already holds ``User``/``Identity``/``Token``).

For each query we maintain, fire-and-forget:

* Denormalized counters + activity timestamps on the ``User`` node
(``query_count``/``success_count``/``error_count``/``last_active``/
``first_query_at``) for cheap reads.
* A per-query ``(:UsageEvent)`` node linked ``(User)-[:PERFORMED]->`` carrying
``graph_id``/``is_demo``/``success``/``timestamp`` for time-series, per-DB
and success-rate analytics.

Writes never block or fail a request: they run as background tasks whose
exceptions are logged and swallowed, mirroring
``api.core.pipeline.save_memory_background``.
"""

import asyncio
import base64
import binascii
import hashlib
import logging
from typing import Optional

from api.config import ORGANIZATIONS_GRAPH
from api.core.db_resolver import resolve_db
from api.core.pipeline import background_tasks_var, is_general_graph

# Single round-trip: bump the User counters/timestamps and append a UsageEvent.
# Uses MATCH (not MERGE) on User so an unknown email is a silent no-op rather
# than creating a phantom user from the query path. ``timestamp()`` is FalkorDB
# epoch-millis, matching every other timestamp in the Organizations graph.
_RECORD_USAGE_CYPHER = """
MATCH (u:User {email: $email})
SET u.query_count = coalesce(u.query_count, 0) + 1,
u.success_count = coalesce(u.success_count, 0) + (CASE WHEN $success THEN 1 ELSE 0 END),
u.error_count = coalesce(u.error_count, 0) + (CASE WHEN $success THEN 0 ELSE 1 END),
u.last_active = timestamp(),
u.first_query_at = coalesce(u.first_query_at, timestamp())
CREATE (u)-[:PERFORMED]->(e:UsageEvent {
graph_id: $graph_id,
is_demo: $is_demo,
success: $success,
timestamp: timestamp()
})
"""


def _decode_email(user_id: str) -> Optional[str]:
"""Recover the user's email from the base64 ``user_id``.

Inverse of ``base64.b64encode(email.encode())`` in
``api/auth/user_management.py``. Returns ``None`` on malformed input so the
caller can skip tracking instead of raising.
"""
if not user_id:
return None
try:
email = base64.b64decode(user_id, validate=True).decode("utf-8")
except (binascii.Error, ValueError, UnicodeDecodeError):
logging.warning("Usage tracking: could not decode user_id to email")
return None
# b64decode is lenient about padding/length; require an email-shaped result
# so a malformed id can't trigger a phantom DB write (matches the docstring).
if "@" not in email:
logging.warning("Usage tracking: decoded user_id is not a valid email")
return None
return email


async def _write_usage(email: str, graph_id: str, is_demo: bool, success: bool, db) -> None:
"""Perform the single Cypher write against the Organizations graph."""
organizations_graph = resolve_db(db).select_graph(ORGANIZATIONS_GRAPH)
await organizations_graph.query(
_RECORD_USAGE_CYPHER,
{
"email": email,
"graph_id": graph_id,
"is_demo": is_demo,
"success": success,
},
)
# Structured-ish log line so usage is visible to log aggregators even
# before any read API exists. graph_id is the namespaced name
# ({base64(email)}_{db}) and base64 email is reversible, so log a short
# stable hash instead of the raw value — this keeps user identity out of
# logs and also neutralizes the CodeQL log-injection vector.
graph_ref = hashlib.sha256(graph_id.encode()).hexdigest()[:12]
logging.info(
"usage_event graph=%s is_demo=%s success=%s",
graph_ref, is_demo, success,
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.


def record_query_usage_background(
user_id: str,
namespaced: str,
success: bool,
*,
db=None,
task_sink: Optional[set] = None,
) -> None:
"""Schedule fire-and-forget usage tracking for one query.

Returns immediately. The write runs as a background task whose failure is
logged but never propagated, so tracking can never break or delay a query
response. Called unconditionally at pipeline completion — independent of
``use_memory`` and the LLM provider.

Args:
user_id: Base64-encoded email (the namespacing id used by the routes).
namespaced: The fully-namespaced graph name the query ran against;
already demo-aware, so it doubles as the recorded ``graph_id``.
success: Whether SQL execution succeeded (no execution error).
db: Optional FalkorDB handle; resolves to the server singleton when None.
task_sink: Optional set the scheduled task is added to (and auto-removed
from on completion) so callers can await any in-flight tracking
writes before shutdown.
"""
email = _decode_email(user_id)
if email is None:
return

is_demo = is_general_graph(namespaced)
sink = task_sink if task_sink is not None else background_tasks_var.get()

task = asyncio.create_task(
_write_usage(email, namespaced, is_demo, success, db)
)

if sink is not None:
sink.add(task)
task.add_done_callback(sink.discard)

def _log_done(t: "asyncio.Task") -> None:
if t.cancelled():
return
exc = t.exception()
if exc is not None:
logging.error("Usage tracking save failed: %s", exc) # nosemgrep

task.add_done_callback(_log_done)
Loading