diff --git a/.env.example b/.env.example index c37f9772..eeecd753 100644 --- a/.env.example +++ b/.env.example @@ -37,6 +37,10 @@ FASTAPI_SECRET_KEY=your_super_secret_key_here # Example: redis://localhost:6379/0 FALKORDB_URL=redis://localhost:6379/0 # REQUIRED - change to your FalkorDB URL +# Optional: name of the central user-management graph (User/Identity/Token/UsageEvent). +# Defaults to "Organizations"; override to share or isolate that graph. +# ORGANIZATIONS_GRAPH=Organizations + # Optional: separate host/port settings for local testing (only used if FALKORDB_URL is not set) # FALKORDB_HOST=localhost # FALKORDB_PORT=6379 diff --git a/api/auth/user_management.py b/api/auth/user_management.py index 3b5cd00f..7272b071 100644 --- a/api/auth/user_management.py +++ b/api/auth/user_management.py @@ -9,6 +9,7 @@ from fastapi import Request, HTTPException, status from pydantic import BaseModel +from api.config import ORGANIZATIONS_GRAPH from api.extensions import db # Get secret key for sessions @@ -45,7 +46,7 @@ async def _get_user_info(api_token: str) -> Optional[Dict[str, Any]]: try: # Select the Organizations graph - organizations_graph = db.select_graph("Organizations") + organizations_graph = db.select_graph(ORGANIZATIONS_GRAPH) result = await organizations_graph.query( query, @@ -83,7 +84,7 @@ async def delete_user_token(api_token: str): """ try: # Select the Organizations graph - organizations_graph = db.select_graph("Organizations") + organizations_graph = db.select_graph(ORGANIZATIONS_GRAPH) await organizations_graph.query( query, @@ -117,7 +118,7 @@ async def ensure_user_in_organizations( # pylint: disable=too-many-arguments, d return validation_result try: - organizations_graph = db.select_graph("Organizations") + organizations_graph = db.select_graph(ORGANIZATIONS_GRAPH) first_name, last_name = _extract_name_parts(name) merge_query = _build_user_merge_query() @@ -166,7 +167,7 @@ async def update_identity_last_login(provider, provider_user_id): return 
try: - organizations_graph = db.select_graph("Organizations") + organizations_graph = db.select_graph(ORGANIZATIONS_GRAPH) update_query = """ MATCH (identity:Identity {provider: $provider, provider_user_id: $provider_user_id}) SET identity.last_login = timestamp() diff --git a/api/config.py b/api/config.py index b029970c..dce8a34c 100644 --- a/api/config.py +++ b/api/config.py @@ -14,6 +14,11 @@ # Ensure .env is loaded before Config reads os.getenv() at class definition time load_dotenv() +# Central user-management graph holding User/Identity/Token (and UsageEvent) +# nodes. Single source of truth for the name used across auth, tokens, and +# usage tracking — override with the ORGANIZATIONS_GRAPH env var. +ORGANIZATIONS_GRAPH = os.getenv("ORGANIZATIONS_GRAPH") or "Organizations" + # Configure litellm logging to prevent sensitive data leakage def configure_litellm_logging(): """Configure litellm to suppress completion logs.""" diff --git a/api/routes/auth.py b/api/routes/auth.py index 680a0c59..e55e61b3 100644 --- a/api/routes/auth.py +++ b/api/routes/auth.py @@ -21,6 +21,7 @@ from pydantic import BaseModel from api.auth.user_management import delete_user_token, ensure_user_in_organizations, validate_user +from api.config import ORGANIZATIONS_GRAPH from api.extensions import db # Import GENERAL_PREFIX from graphs route @@ -136,7 +137,7 @@ def _validate_email(email: str) -> bool: async def _set_mail_hash(email: str, password_hash: str) -> bool: """Set email hash for the user in the database.""" try: - organizations_graph = db.select_graph("Organizations") + organizations_graph = db.select_graph(ORGANIZATIONS_GRAPH) # Sanitize inputs for logging safe_email = _sanitize_for_log(email) @@ -178,7 +179,7 @@ async def _email_account_exists(email: str) -> bool: Exceptions are intentionally not swallowed so callers fail closed (treat the account as existing / abort the signup) rather than issuing a session token. 
""" - organizations_graph = db.select_graph("Organizations") + organizations_graph = db.select_graph(ORGANIZATIONS_GRAPH) # Use a UNION of two label-scoped lookups so each side hits the (label, email) # index and short-circuits with LIMIT 1. This avoids both a full-graph scan and # the Cartesian product that two chained OPTIONAL MATCH clauses would produce. @@ -204,7 +205,7 @@ def _is_request_secure(request: Request) -> bool: async def _authenticate_email_user(email: str, password: str): """Authenticate an email user.""" try: - organizations_graph = db.select_graph("Organizations") + organizations_graph = db.select_graph(ORGANIZATIONS_GRAPH) # Find user by email query = """ diff --git a/api/routes/graphs.py b/api/routes/graphs.py index 95b6c08c..7c1d1c54 100644 --- a/api/routes/graphs.py +++ b/api/routes/graphs.py @@ -29,22 +29,39 @@ from api.graph import get_user_rules, set_user_rules from api.auth.user_management import token_required from api.routes.tokens import UNAUTHORIZED_RESPONSE +from api.routes.usage_tracking import record_query_usage_background graphs_router = APIRouter(tags=["Graphs & Databases"]) -async def _serialize_pipeline(gen): +async def _serialize_pipeline(gen, *, user_id, namespaced): """Serialize pipeline events to the wire format and stop on ``_Final``. Pure encoding loop — no exception handling here. Each route handler wraps iteration in its own ``try/except`` so the broad-except (which emits a generic error event without leaking stack data) lives in the route function CodeQL already accepts, not in a shared helper. + + Always-on usage tracking lives here in the route layer (not in + ``api/core``) so it ships with the hosted app, never the PyPI SDK. Exactly + one event is recorded per query, derived from the final ``QueryResult`` — + skipping the destructive-confirmation prompt, which has no outcome yet (the + ``/confirm`` call records that query). 
""" + final = None async for event in gen: if isinstance(event, _Final): - return + final = event.value + break yield json.dumps(event) + MESSAGE_DELIMITER + if final is not None and not final.requires_confirmation: + # "Success" = a valid query that ran without error. error_message is + # None alone isn't enough: off-topic / not-SQL-translatable results + # carry is_valid=False with no error, and must not inflate success_count. + record_query_usage_background( + user_id, namespaced, + success=final.is_valid and final.error_message is None, + ) class GraphData(BaseModel): @@ -170,7 +187,7 @@ async def query_graph( # the StreamingResponse is iterated. Surfacing client errors as HTTP 400 # requires a synchronous check before we hand the stream to the response. try: - graph_name(request.state.user_id, graph_id) + namespaced = graph_name(request.state.user_id, graph_id) validate_and_truncate_chat(chat_data) validate_custom_model(getattr(chat_data, "custom_model", None)) except InvalidArgumentError as iae: @@ -180,13 +197,19 @@ async def query_graph( async def stream(): try: async for chunk in _serialize_pipeline( - run_query(request.state.user_id, graph_id, chat_data) + run_query(request.state.user_id, graph_id, chat_data), + user_id=request.state.user_id, namespaced=namespaced, ): yield chunk except Exception: # pylint: disable=broad-exception-caught # Don't leak stack traces (CodeQL: information exposure through # exception). Log internally; emit a generic error event. logging.exception("Streaming query failed") + # Pipeline crashed before _Final, so _serialize_pipeline didn't + # record — count this attempt as a failure here. 
+ record_query_usage_background( + request.state.user_id, namespaced, success=False + ) yield json.dumps({ "type": "error", "final_response": True, @@ -225,12 +248,17 @@ async def confirm_destructive_operation( async def stream(): try: async for chunk in _serialize_pipeline( - run_confirmed(request.state.user_id, graph_id, confirm_data) + run_confirmed(request.state.user_id, graph_id, confirm_data), + user_id=request.state.user_id, namespaced=namespaced, ): yield chunk except Exception: # pylint: disable=broad-exception-caught # See note on the query endpoint above (CodeQL). logging.exception("Streaming confirmed-destructive query failed") + # Pipeline crashed before _Final — record the failed attempt here. + record_query_usage_background( + request.state.user_id, namespaced, success=False + ) yield json.dumps({ "type": "error", "final_response": True, diff --git a/api/routes/tokens.py b/api/routes/tokens.py index 03fa0f5a..7df4a492 100644 --- a/api/routes/tokens.py +++ b/api/routes/tokens.py @@ -9,6 +9,7 @@ from pydantic import BaseModel from api.auth.user_management import token_required +from api.config import ORGANIZATIONS_GRAPH from api.extensions import db UNAUTHORIZED_RESPONSE = {"description": "Unauthorized - Please log in or provide a valid API token"} @@ -80,7 +81,7 @@ async def list_tokens(request: Request) -> TokenListResponse: user_email = request.state.user_email # Get tokens from Organizations graph - organizations_graph = db.select_graph("Organizations") + organizations_graph = db.select_graph(ORGANIZATIONS_GRAPH) # Get user information by API token and then get all associated tokens that connected # to the Identity of provider='api' @@ -120,7 +121,7 @@ async def delete_token(request: Request, token_id: str) -> JSONResponse: user_email = request.state.user_email # Delete token from Organizations graph - organizations_graph = db.select_graph("Organizations") + organizations_graph = db.select_graph(ORGANIZATIONS_GRAPH) # Delete the token delete_query 
= """ diff --git a/api/routes/usage_tracking.py b/api/routes/usage_tracking.py new file mode 100644 index 00000000..83ded999 --- /dev/null +++ b/api/routes/usage_tracking.py @@ -0,0 +1,148 @@ +"""Always-on, provider-agnostic per-query usage tracking. + +This is deliberately independent of the optional LLM conversational-memory +feature (``api/memory/graphiti_tool.py``). Memory writes are opt-in +(``use_memory``), gated to OpenAI/Azure providers, and lazily created — so +they cannot be used to measure adoption. This module records *every* query, +regardless of provider or the ``use_memory`` flag, onto the central +``Organizations`` graph (which already holds ``User``/``Identity``/``Token``). + +For each query we maintain, fire-and-forget: + +* Denormalized counters + activity timestamps on the ``User`` node + (``query_count``/``success_count``/``error_count``/``last_active``/ + ``first_query_at``) for cheap reads. +* A per-query ``(:UsageEvent)`` node linked ``(User)-[:PERFORMED]->`` carrying + ``graph_id``/``is_demo``/``success``/``timestamp`` for time-series, per-DB + and success-rate analytics. + +Writes never block or fail a request: they run as background tasks whose +exceptions are logged and swallowed, mirroring +``api.core.pipeline.save_memory_background``. +""" + +import asyncio +import base64 +import binascii +import hashlib +import logging +from typing import Optional + +from api.config import ORGANIZATIONS_GRAPH +from api.core.db_resolver import resolve_db +from api.core.pipeline import background_tasks_var, is_general_graph + +# Single round-trip: bump the User counters/timestamps and append a UsageEvent. +# Uses MATCH (not MERGE) on User so an unknown email is a silent no-op rather +# than creating a phantom user from the query path. ``timestamp()`` is FalkorDB +# epoch-millis, matching every other timestamp in the Organizations graph. 
# One round-trip: MATCH the User (no MERGE, so an unknown email matches
# nothing and writes nothing), bump its counters/timestamps, and append a
# UsageEvent linked via PERFORMED.
_RECORD_USAGE_CYPHER = """
MATCH (u:User {email: $email})
SET u.query_count = coalesce(u.query_count, 0) + 1,
    u.success_count = coalesce(u.success_count, 0) + (CASE WHEN $success THEN 1 ELSE 0 END),
    u.error_count = coalesce(u.error_count, 0) + (CASE WHEN $success THEN 0 ELSE 1 END),
    u.last_active = timestamp(),
    u.first_query_at = coalesce(u.first_query_at, timestamp())
CREATE (u)-[:PERFORMED]->(e:UsageEvent {
    graph_id: $graph_id,
    is_demo: $is_demo,
    success: $success,
    timestamp: timestamp()
})
"""

# Strong references for tracking tasks scheduled without any caller-provided
# sink. The event loop only keeps weak references to tasks, so a task nobody
# else holds could be garbage-collected before its write completes.
_UNTRACKED_TASKS: set = set()


def _decode_email(user_id: str) -> Optional[str]:
    """Recover the user's email from the base64 ``user_id``.

    Expected to be the inverse of ``base64.b64encode(email.encode())`` (the
    encoding side lives in ``api/auth/user_management.py``, not visible here).

    Args:
        user_id: Base64 text that should decode to a UTF-8 email address.

    Returns:
        The decoded email, or ``None`` when ``user_id`` is empty, is not a
        string/bytes value, is not valid base64, is not valid UTF-8, or does
        not contain ``"@"`` — so the caller can skip tracking instead of
        raising.
    """
    if not user_id:
        return None
    try:
        email = base64.b64decode(user_id, validate=True).decode("utf-8")
    except (binascii.Error, ValueError, UnicodeDecodeError, TypeError):
        # TypeError covers a non-str/bytes id; the others cover bad base64
        # characters/padding and bytes that are not UTF-8.
        logging.warning("Usage tracking: could not decode user_id to email")
        return None
    # Well-formed base64 can still decode to arbitrary text; require an
    # email-shaped result so a bogus id never reaches the database write.
    if "@" not in email:
        logging.warning("Usage tracking: decoded user_id is not a valid email")
        return None
    return email


async def _write_usage(email: str, graph_id: str, is_demo: bool, success: bool, db) -> None:
    """Run the single usage-recording Cypher statement and log the event.

    Args:
        email: Email used to MATCH the ``User`` node.
        graph_id: Graph name stored on the ``UsageEvent``.
        is_demo: Stored as the event's ``is_demo`` flag.
        success: Stored on the event and used to pick which counter to bump.
        db: Handle passed to ``resolve_db``; ``None`` is forwarded as-is.

    Raises:
        Whatever the graph ``query`` call raises; callers that schedule this
        as a background task are responsible for logging it.
    """
    organizations_graph = resolve_db(db).select_graph(ORGANIZATIONS_GRAPH)
    await organizations_graph.query(
        _RECORD_USAGE_CYPHER,
        {
            "email": email,
            "graph_id": graph_id,
            "is_demo": is_demo,
            "success": success,
        },
    )
    # Structured-ish log line so usage is visible to log aggregators. Log a
    # short stable hash instead of the raw graph name: the name may embed the
    # reversible base64 email, and hashing also keeps caller-controlled text
    # out of the log line.
    graph_ref = hashlib.sha256(graph_id.encode()).hexdigest()[:12]
    logging.info(
        "usage_event graph=%s is_demo=%s success=%s",
        graph_ref, is_demo, success,
    )


def _log_task_failure(task: "asyncio.Task") -> None:
    """Done-callback: log (never raise) a failed usage-tracking task."""
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logging.error("Usage tracking save failed: %s", exc)  # nosemgrep


def record_query_usage_background(
    user_id: str,
    namespaced: str,
    success: bool,
    *,
    db=None,
    task_sink: Optional[set] = None,
) -> None:
    """Schedule fire-and-forget usage tracking for one query.

    Returns immediately and never raises: a malformed ``user_id`` is skipped,
    a failure to schedule the task is logged, and a failure of the write
    itself is logged by the task's done-callback.

    Args:
        user_id: Base64-encoded email (the namespacing id used by the routes).
        namespaced: The fully-namespaced graph name the query ran against;
            recorded as the event's ``graph_id``.
        success: Whether the query counts as successful. Coerced with
            ``bool()`` so a falsy non-bool is stored as ``False`` rather than
            a null property.
        db: Optional FalkorDB handle, passed through to ``resolve_db``.
        task_sink: Optional set the scheduled task is added to (and removed
            from on completion) so callers can await in-flight writes. When
            omitted, the set from ``background_tasks_var`` is used; if that
            is ``None`` too, a module-level set keeps the task alive.
    """
    email = _decode_email(user_id)
    if email is None:
        return

    try:
        is_demo = is_general_graph(namespaced)
        sink = task_sink if task_sink is not None else background_tasks_var.get()
    except Exception:  # pylint: disable=broad-exception-caught
        # Best-effort boundary: tracking must not break the request path.
        logging.exception("Usage tracking could not be scheduled")
        return
    if sink is None:
        sink = _UNTRACKED_TASKS

    write = _write_usage(email, namespaced, is_demo, bool(success), db)
    try:
        task = asyncio.create_task(write)
    except RuntimeError:
        # No running event loop: close the coroutine so it does not trigger a
        # "never awaited" warning, and skip tracking for this query.
        write.close()
        logging.exception("Usage tracking could not be scheduled")
        return

    sink.add(task)
    task.add_done_callback(sink.discard)
    task.add_done_callback(_log_task_failure)
"""Tests for always-on per-query usage tracking.

The recorder in ``api/routes/usage_tracking.py`` is exercised against a mocked
FalkorDB handle. The tests pin the parameters of the write, the skip path for
malformed ids, the task-sink bookkeeping, and that a failing write is logged
instead of propagated.
"""

import asyncio
import base64
import inspect
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from api.routes import usage_tracking
from api.routes.usage_tracking import (
    _decode_email,
    record_query_usage_background,
)

pytestmark = [pytest.mark.unit]

EMAIL = "gal.shubeli@falkordb.com"
USER_ID = base64.b64encode(EMAIL.encode()).decode()


def _mock_db():
    """A FalkorDB-like mock whose ``select_graph(...).query`` is awaitable."""
    graph = MagicMock()
    graph.query = AsyncMock(return_value=MagicMock(result_set=[]))
    db = MagicMock()
    db.select_graph.return_value = graph
    return db, graph


async def _drain(sink):
    """Await every background task the recorder scheduled into ``sink``."""
    # Snapshot first: the recorder's done-callbacks remove tasks from ``sink``
    # while we wait. Exceptions are collected, not raised.
    await asyncio.gather(*list(sink), return_exceptions=True)


class TestDecodeEmail:
    """``_decode_email`` returns the email or ``None``; it never raises."""

    def test_decodes_valid_user_id(self):
        assert _decode_email(USER_ID) == EMAIL

    def test_returns_none_for_empty(self):
        assert _decode_email("") is None

    def test_returns_none_for_garbage(self):
        # Malformed base64 must yield None, not raise.
        assert _decode_email("!!!not-base64!!!") is None

    def test_returns_none_for_valid_base64_that_is_not_an_email(self):
        # Decodes cleanly but isn't email-shaped -> skip (no phantom write).
        not_email = base64.b64encode(b"notanemail").decode()
        assert _decode_email(not_email) is None


class TestRecordQueryUsage:
    """End-to-end behavior of ``record_query_usage_background`` on a mock DB."""

    @pytest.mark.asyncio
    async def test_records_successful_query_event(self):
        db, graph = _mock_db()
        sink: set = set()
        with patch.object(usage_tracking, "resolve_db", return_value=db) as mock_resolve, \
             patch.object(usage_tracking, "is_general_graph", return_value=False):
            record_query_usage_background(
                USER_ID, f"{USER_ID}_mydb", success=True, db=db, task_sink=sink
            )
            await _drain(sink)

        # The injected handle must reach resolve_db, not be silently dropped.
        mock_resolve.assert_called_once_with(db)
        db.select_graph.assert_called_once_with(usage_tracking.ORGANIZATIONS_GRAPH)
        graph.query.assert_awaited_once()
        cypher, params = graph.query.await_args.args
        assert "MATCH (u:User {email: $email})" in cypher
        assert ":UsageEvent" in cypher
        assert params == {
            "email": EMAIL,
            "graph_id": f"{USER_ID}_mydb",
            "is_demo": False,
            "success": True,
        }
        # Completed tasks are auto-removed from the caller's sink.
        assert not sink

    @pytest.mark.asyncio
    async def test_records_failed_query_event(self):
        db, graph = _mock_db()
        sink: set = set()
        with patch.object(usage_tracking, "resolve_db", return_value=db), \
             patch.object(usage_tracking, "is_general_graph", return_value=False):
            record_query_usage_background(
                USER_ID, f"{USER_ID}_mydb", success=False, db=db, task_sink=sink
            )
            await _drain(sink)

        _cypher, params = graph.query.await_args.args
        assert params["success"] is False

    @pytest.mark.asyncio
    async def test_demo_graph_is_flagged(self):
        db, graph = _mock_db()
        sink: set = set()
        with patch.object(usage_tracking, "resolve_db", return_value=db), \
             patch.object(usage_tracking, "is_general_graph", return_value=True):
            record_query_usage_background(
                USER_ID, "DEMO_CRM", success=True, db=db, task_sink=sink
            )
            await _drain(sink)

        _cypher, params = graph.query.await_args.args
        assert params["is_demo"] is True
        assert params["graph_id"] == "DEMO_CRM"

    @pytest.mark.asyncio
    async def test_invalid_user_id_skips_write(self):
        db, graph = _mock_db()
        sink: set = set()
        with patch.object(usage_tracking, "resolve_db", return_value=db):
            record_query_usage_background(
                "!!!bad!!!", "x_y", success=True, db=db, task_sink=sink
            )
            await _drain(sink)

        # No task scheduled, no graph touched.
        assert not sink
        graph.query.assert_not_awaited()
        db.select_graph.assert_not_called()

    @pytest.mark.asyncio
    async def test_write_failure_is_swallowed(self):
        db, graph = _mock_db()
        graph.query.side_effect = RuntimeError("falkordb down")
        sink: set = set()
        with patch.object(usage_tracking, "resolve_db", return_value=db), \
             patch.object(usage_tracking, "is_general_graph", return_value=False), \
             patch.object(usage_tracking.logging, "error") as mock_log_error:
            # The synchronous call must not raise despite the write failing.
            record_query_usage_background(
                USER_ID, f"{USER_ID}_mydb", success=True, db=db, task_sink=sink
            )
            await _drain(sink)

        # Failure was logged by the done-callback, not propagated.
        assert any(
            "Usage tracking save failed" in str(call.args[0])
            for call in mock_log_error.call_args_list
        )
        # A failed task is removed from the sink just like a successful one.
        assert not sink


class TestUngatedDesign:
    """The recorder's signature offers no way to gate tracking."""

    def test_recorder_has_no_memory_or_provider_parameter(self):
        """Tracking cannot be gated by ``use_memory`` or the LLM provider:
        the recorder simply has no such inputs."""
        params = set(inspect.signature(record_query_usage_background).parameters)
        assert params == {"user_id", "namespaced", "success", "db", "task_sink"}
        assert "use_memory" not in params
        assert "provider" not in params