Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/wordlist.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
QueryWeaver
FalkorDB
OAuth
DDL
DML
AGPL
Affero
nullability
Expand Down
5 changes: 1 addition & 4 deletions api/app_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import os
import secrets

from dotenv import load_dotenv
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import RedirectResponse, JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
Expand All @@ -24,8 +23,6 @@
from api.routes.tokens import tokens_router
from api.routes.settings import settings_router

# Load environment variables from .env file
load_dotenv()
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
Expand Down Expand Up @@ -147,7 +144,7 @@ def create_app(): # pylint: disable=too-many-statements
app.include_router(graphs_router, prefix="/graphs")
app.include_router(database_router)
app.include_router(tokens_router, prefix="/tokens")
app.include_router(settings_router, prefix="/api")
app.include_router(settings_router, prefix="/settings")



Expand Down
10 changes: 9 additions & 1 deletion api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,13 @@
import logging
import dataclasses
from typing import Union

from dotenv import load_dotenv
from litellm import embedding

# Ensure .env is loaded before Config reads os.getenv() at class definition time
load_dotenv()

# Configure litellm logging to prevent sensitive data leakage
def configure_litellm_logging():
"""Configure litellm to suppress completion logs."""
Expand Down Expand Up @@ -64,6 +69,9 @@ def _with_prefix(model: str, provider: str) -> str:
return prefix + model.removeprefix(prefix)


SUPPORTED_VENDORS = ("openai", "anthropic", "gemini", "azure", "ollama", "cohere")


@dataclasses.dataclass
class Config:
"""
Expand Down Expand Up @@ -103,7 +111,7 @@ class Config:
EMBEDDING_MODEL_NAME = "voyage/voyage-3"
else:
raise ValueError(
"Anthropic has no native embeddings. "
"ANTHROPIC_API_KEY is set, but Anthropic has no native embeddings. "
"Set EMBEDDING_MODEL or VOYAGE_API_KEY for embeddings."
)
elif os.getenv("COHERE_API_KEY"):
Expand Down
4 changes: 4 additions & 0 deletions api/core/schema_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from api.loaders.base_loader import BaseLoader
from api.loaders.postgres_loader import PostgresLoader
from api.loaders.mysql_loader import MySQLLoader
from api.loaders.snowflake_loader import SnowflakeLoader

# Use the same delimiter as in the JavaScript frontend for streaming chunks
MESSAGE_DELIMITER = "|||FALKORDB_MESSAGE_BOUNDARY|||"
Expand Down Expand Up @@ -44,6 +45,9 @@ def _step_detect_db_type(steps_counter: int, url: str) -> tuple[type[BaseLoader]
elif url.startswith("mysql://"):
db_type = "mysql"
loader = MySQLLoader
elif url.startswith("snowflake://"):
db_type = "snowflake"
loader = SnowflakeLoader
else:
raise InvalidArgumentError("Invalid database URL format")

Expand Down
12 changes: 6 additions & 6 deletions api/core/text2sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
from api.agents import AnalysisAgent, RelevancyAgent, ResponseFormatterAgent, FollowUpAgent
from api.agents.healer_agent import HealerAgent
from api.config import Config
from api.config import SUPPORTED_VENDORS
from api.extensions import db
from api.graph import find, get_db_description, get_user_rules
from api.loaders.postgres_loader import PostgresLoader
from api.loaders.mysql_loader import MySQLLoader
from api.loaders.snowflake_loader import SnowflakeLoader
from api.memory.graphiti_tool import MemoryTool
from api.sql_utils import SQLIdentifierQuoter, DatabaseSpecificQuoter

Expand Down Expand Up @@ -83,6 +85,8 @@ def get_database_type_and_loader(db_url: str):
return 'postgresql', PostgresLoader
if db_url_lower.startswith('mysql://'):
return 'mysql', MySQLLoader
if db_url_lower.startswith('snowflake://'):
return 'snowflake', SnowflakeLoader

# Default to PostgresLoader for backward compatibility
return 'postgresql', PostgresLoader
Expand Down Expand Up @@ -257,21 +261,17 @@ async def generate(): # pylint: disable=too-many-locals,too-many-branches,too-m
custom_model = chat_data.custom_model

# Validate custom model format (vendor/model)
supported_vendors = ("openai", "anthropic", "gemini", "azure", "ollama", "cohere")
if custom_model:
parts = custom_model.split("/", 1)
if len(parts) != 2 or not parts[0] or not parts[1]:
raise InvalidArgumentError(
"Invalid model format. Expected 'vendor/model' (e.g. 'openai/gpt-4.1')"
)
if parts[0] not in supported_vendors:
if parts[0] not in SUPPORTED_VENDORS:
raise InvalidArgumentError(
f"Unsupported vendor '{parts[0]}'. Supported: {', '.join(supported_vendors)}"
f"Unsupported vendor '{parts[0]}'. Supported: {', '.join(SUPPORTED_VENDORS)}"
)

if custom_api_key is not None and len(custom_api_key.strip()) < 10:
raise InvalidArgumentError("API key is too short")

agent_rel = RelevancyAgent(queries_history, result_history, custom_api_key, custom_model)
agent_an = AnalysisAgent(queries_history, result_history, custom_api_key, custom_model)
follow_up_agent = FollowUpAgent(queries_history, result_history, custom_api_key, custom_model)
Expand Down
6 changes: 5 additions & 1 deletion api/index.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
"""Main entry point for the text2sql API."""

from api.app_factory import create_app
# Load .env before any app imports that read os.getenv at module level
from dotenv import load_dotenv
load_dotenv()

from api.app_factory import create_app # pylint: disable=wrong-import-position

app = create_app()

Expand Down
Loading
Loading