diff --git a/.github/wordlist.txt b/.github/wordlist.txt index c726f155..5eacb0df 100644 --- a/.github/wordlist.txt +++ b/.github/wordlist.txt @@ -1,6 +1,8 @@ QueryWeaver FalkorDB OAuth +DDL +DML AGPL Affero nullability diff --git a/api/app_factory.py b/api/app_factory.py index 80b7b713..3b95b5bf 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -5,7 +5,6 @@ import os import secrets -from dotenv import load_dotenv from fastapi import FastAPI, Request, HTTPException from fastapi.responses import RedirectResponse, JSONResponse, FileResponse from fastapi.staticfiles import StaticFiles @@ -24,8 +23,6 @@ from api.routes.tokens import tokens_router from api.routes.settings import settings_router -# Load environment variables from .env file -load_dotenv() logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" ) @@ -147,7 +144,7 @@ def create_app(): # pylint: disable=too-many-statements app.include_router(graphs_router, prefix="/graphs") app.include_router(database_router) app.include_router(tokens_router, prefix="/tokens") - app.include_router(settings_router, prefix="/api") + app.include_router(settings_router, prefix="/settings") diff --git a/api/config.py b/api/config.py index 9f517ff1..b029970c 100644 --- a/api/config.py +++ b/api/config.py @@ -7,8 +7,13 @@ import logging import dataclasses from typing import Union + +from dotenv import load_dotenv from litellm import embedding +# Ensure .env is loaded before Config reads os.getenv() at class definition time +load_dotenv() + # Configure litellm logging to prevent sensitive data leakage def configure_litellm_logging(): """Configure litellm to suppress completion logs.""" @@ -64,6 +69,9 @@ def _with_prefix(model: str, provider: str) -> str: return prefix + model.removeprefix(prefix) +SUPPORTED_VENDORS = ("openai", "anthropic", "gemini", "azure", "ollama", "cohere") + + @dataclasses.dataclass class Config: """ @@ -103,7 +111,7 @@ class Config: EMBEDDING_MODEL_NAME = "voyage/voyage-3" else: raise ValueError( - "Anthropic has no native embeddings. " + "ANTHROPIC_API_KEY is set, but Anthropic has no native embeddings. " "Set EMBEDDING_MODEL or VOYAGE_API_KEY for embeddings." 
) elif os.getenv("COHERE_API_KEY"): diff --git a/api/core/schema_loader.py b/api/core/schema_loader.py index bb4dcedb..b1568514 100644 --- a/api/core/schema_loader.py +++ b/api/core/schema_loader.py @@ -13,6 +13,7 @@ from api.loaders.base_loader import BaseLoader from api.loaders.postgres_loader import PostgresLoader from api.loaders.mysql_loader import MySQLLoader +from api.loaders.snowflake_loader import SnowflakeLoader # Use the same delimiter as in the JavaScript frontend for streaming chunks MESSAGE_DELIMITER = "|||FALKORDB_MESSAGE_BOUNDARY|||" @@ -44,6 +45,9 @@ def _step_detect_db_type(steps_counter: int, url: str) -> tuple[type[BaseLoader] elif url.startswith("mysql://"): db_type = "mysql" loader = MySQLLoader + elif url.startswith("snowflake://"): + db_type = "snowflake" + loader = SnowflakeLoader else: raise InvalidArgumentError("Invalid database URL format") diff --git a/api/core/text2sql.py b/api/core/text2sql.py index bd0f20b9..65bc0b4f 100644 --- a/api/core/text2sql.py +++ b/api/core/text2sql.py @@ -15,10 +15,12 @@ from api.agents import AnalysisAgent, RelevancyAgent, ResponseFormatterAgent, FollowUpAgent from api.agents.healer_agent import HealerAgent from api.config import Config +from api.config import SUPPORTED_VENDORS from api.extensions import db from api.graph import find, get_db_description, get_user_rules from api.loaders.postgres_loader import PostgresLoader from api.loaders.mysql_loader import MySQLLoader +from api.loaders.snowflake_loader import SnowflakeLoader from api.memory.graphiti_tool import MemoryTool from api.sql_utils import SQLIdentifierQuoter, DatabaseSpecificQuoter @@ -83,6 +85,8 @@ def get_database_type_and_loader(db_url: str): return 'postgresql', PostgresLoader if db_url_lower.startswith('mysql://'): return 'mysql', MySQLLoader + if db_url_lower.startswith('snowflake://'): + return 'snowflake', SnowflakeLoader # Default to PostgresLoader for backward compatibility return 'postgresql', PostgresLoader @@ -257,21 +261,17 @@ async def generate(): # pylint: disable=too-many-locals,too-many-branches,too-m custom_model = chat_data.custom_model # Validate custom model format (vendor/model) - supported_vendors = ("openai", "anthropic", "gemini", "azure", "ollama", "cohere") if custom_model: parts = custom_model.split("/", 1) if len(parts) != 2 or not parts[0] or not parts[1]: raise InvalidArgumentError( "Invalid model format. Expected 'vendor/model' (e.g. 'openai/gpt-4.1')" ) - if parts[0] not in supported_vendors: + if parts[0] not in SUPPORTED_VENDORS: raise InvalidArgumentError( - f"Unsupported vendor '{parts[0]}'. Supported: {', '.join(supported_vendors)}" + f"Unsupported vendor '{parts[0]}'. 
Supported: {', '.join(SUPPORTED_VENDORS)}" ) - if custom_api_key is not None and len(custom_api_key.strip()) < 10: - raise InvalidArgumentError("API key is too short") - agent_rel = RelevancyAgent(queries_history, result_history, custom_api_key, custom_model) agent_an = AnalysisAgent(queries_history, result_history, custom_api_key, custom_model) follow_up_agent = FollowUpAgent(queries_history, result_history, custom_api_key, custom_model) diff --git a/api/index.py b/api/index.py index 829e3e0a..1bb7d061 100644 --- a/api/index.py +++ b/api/index.py @@ -1,6 +1,10 @@ """Main entry point for the text2sql API.""" -from api.app_factory import create_app +# Load .env before any app imports that read os.getenv at module level +from dotenv import load_dotenv +load_dotenv() + +from api.app_factory import create_app # pylint: disable=wrong-import-position app = create_app() diff --git a/api/loaders/snowflake_loader.py b/api/loaders/snowflake_loader.py new file mode 100644 index 00000000..7685daa9 --- /dev/null +++ b/api/loaders/snowflake_loader.py @@ -0,0 +1,711 @@ +"""Snowflake loader for loading database schemas into FalkorDB graphs.""" + +import base64 +import datetime +import decimal +import logging +import re +from typing import AsyncGenerator, Dict, Any, List, Tuple +from urllib.parse import urlparse, parse_qs + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization + +import tqdm +import snowflake.connector +from snowflake.connector import DictCursor + +from api.loaders.base_loader import BaseLoader +from api.loaders.graph_loader import load_to_graph + + +class SnowflakeQueryError(Exception): + """Exception raised for Snowflake query execution errors.""" + + +class SnowflakeConnectionError(Exception): + """Exception raised for Snowflake connection errors.""" + + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + + +class SnowflakeLoader(BaseLoader): + """ + Loader for Snowflake databases that connects and extracts schema information. + """ + + # DDL operations that modify database schema # pylint: disable=duplicate-code + SCHEMA_MODIFYING_OPERATIONS = { + 'CREATE', 'ALTER', 'DROP', 'RENAME', 'TRUNCATE' + } + + # More specific patterns for schema-affecting operations + SCHEMA_PATTERNS = [ # pylint: disable=duplicate-code + r'^\s*CREATE\s+TABLE', + r'^\s*CREATE\s+INDEX', + r'^\s*CREATE\s+UNIQUE\s+INDEX', + r'^\s*ALTER\s+TABLE', + r'^\s*DROP\s+TABLE', + r'^\s*DROP\s+INDEX', + r'^\s*RENAME\s+TABLE', + r'^\s*TRUNCATE\s+TABLE', + r'^\s*CREATE\s+VIEW', + r'^\s*DROP\s+VIEW', + r'^\s*CREATE\s+DATABASE', + r'^\s*DROP\s+DATABASE', + r'^\s*CREATE\s+SCHEMA', + r'^\s*DROP\s+SCHEMA', + ] + + @staticmethod + def _validate_identifier(identifier: str, identifier_type: str = "identifier") -> None: + """ + Validate that an identifier (table, column, database, schema name) is safe. + + Args: + identifier: The identifier to validate + identifier_type: Type of identifier for error messages + + Raises: + ValueError: If identifier contains invalid characters + """ + # Allow alphanumeric, underscore, dollar sign, and limit to reasonable length + # Snowflake identifiers can contain these characters when quoted + if not re.match(r'^[A-Za-z0-9_$]+$', identifier): + raise ValueError( + f"Invalid {identifier_type}: {identifier!r}. " + "Only alphanumeric characters, underscore, and dollar sign are allowed." 
+ ) + if len(identifier) > 255: + raise ValueError(f"Invalid {identifier_type}: exceeds maximum length of 255") + + @staticmethod + def _quote_identifier(identifier: str) -> str: + """ + Safely quote a Snowflake identifier by escaping double quotes. + + Args: + identifier: The identifier to quote + + Returns: + Quoted identifier safe for SQL interpolation + """ + # Escape any existing double quotes by doubling them + escaped = identifier.replace('"', '""') + return f'"{escaped}"' + + @staticmethod + def _execute_sample_query( + cursor, table_name: str, col_name: str, sample_size: int = 3 + ) -> List[Any]: + """ + Execute query to get random sample values for a column. + Snowflake implementation using SAMPLE for random sampling. + """ + # Validate identifiers to prevent SQL injection + SnowflakeLoader._validate_identifier(table_name, "table name") + SnowflakeLoader._validate_identifier(col_name, "column name") + + # Validate sample_size is a positive integer + if not isinstance(sample_size, int) or sample_size <= 0: + raise ValueError(f"sample_size must be a positive integer, got {sample_size!r}") + + # Quote identifiers safely + quoted_table = SnowflakeLoader._quote_identifier(table_name) + quoted_col = SnowflakeLoader._quote_identifier(col_name) + + # Oversample by 10x to increase the chance of getting sample_size + # distinct non-null values after filtering (Snowflake's SAMPLE clause + # returns approximate row counts, and rows may contain NULLs or duplicates) + sample_rows = sample_size * 10 + + query = f""" + SELECT DISTINCT {quoted_col} + FROM {quoted_table} SAMPLE ({sample_rows} ROWS) + WHERE {quoted_col} IS NOT NULL + LIMIT %s; + """ + cursor.execute(query, (sample_size,)) + + sample_results = cursor.fetchall() + # DictCursor returns dicts; extract the column value by name + return [row[col_name] for row in sample_results if row[col_name] is not None] + + @staticmethod + def _serialize_value(value): + """ + Convert non-JSON serializable values to JSON serializable format. + + Args: + value: The value to serialize + + Returns: + JSON serializable version of the value + """ + if isinstance(value, (datetime.date, datetime.datetime)): + return value.isoformat() + if isinstance(value, datetime.time): + return value.isoformat() + if isinstance(value, decimal.Decimal): + return float(value) + if value is None: + return None + return value + + @staticmethod + def _parse_snowflake_url(connection_url: str) -> Dict[str, Any]: # pylint: disable=too-many-locals + """ + Parse Snowflake connection URL into components. + + Supports two authentication modes: + - Password: snowflake://user:pass@account/db/schema?warehouse=WH + - Key-pair: snowflake://user@account/db/schema?warehouse=WH&private_key=BASE64_PEM + (optionally with &private_key_passphrase=PASSPHRASE) + + Args: + connection_url: Snowflake connection URL + + Returns: + Dict with connection parameters for snowflake.connector.connect() + """ + if not connection_url.startswith('snowflake://'): + raise ValueError( + "Invalid Snowflake URL format. 
Expected " + "snowflake://username:password@account/database/schema?warehouse=warehouse_name" + ) + + parsed = urlparse(connection_url) + + if not parsed.username: + raise ValueError("Snowflake URL must include username") + + username = parsed.username + password = parsed.password or "" + + if not parsed.hostname: + raise ValueError("Snowflake URL must include account") + account = parsed.hostname + + path_parts = [p for p in parsed.path.split('/') if p] + if len(path_parts) < 1: + raise ValueError("Snowflake URL must include database name") + + database = path_parts[0] + schema = path_parts[1] if len(path_parts) > 1 else "PUBLIC" + + query_params = parse_qs(parsed.query) + warehouse = query_params.get('warehouse', ['COMPUTE_WH'])[0] + + # Validate all identifiers + SnowflakeLoader._validate_identifier(database, "database") + SnowflakeLoader._validate_identifier(schema, "schema") + SnowflakeLoader._validate_identifier(warehouse, "warehouse") + + conn_params: Dict[str, Any] = { + 'user': username, + 'account': account, + 'database': database, + 'schema': schema, + 'warehouse': warehouse, + 'login_timeout': 30, + 'network_timeout': 60, + } + + # Check for key-pair authentication + private_key_b64 = query_params.get('private_key', [None])[0] + if private_key_b64: + passphrase = query_params.get('private_key_passphrase', [None])[0] + passphrase_bytes = passphrase.encode() if passphrase else None + + try: + # Handle both standard and URL-safe base64 (browsers may + # convert '+' to spaces when URL-encoding query params) + cleaned_b64 = private_key_b64.replace(' ', '+') + pem_bytes = base64.b64decode(cleaned_b64) + private_key = serialization.load_pem_private_key( + pem_bytes, + password=passphrase_bytes, + backend=default_backend(), + ) + conn_params['private_key'] = private_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + except Exception as e: + raise ValueError(f"Failed to load private key: {e}") from e + else: + conn_params['password'] = password + + return conn_params + + @staticmethod + async def load(prefix: str, connection_url: str) -> AsyncGenerator[ + tuple[bool, str], None + ]: + """ + Load the graph data from a Snowflake database into the graph database. + + Args: + connection_url: Snowflake connection URL in format: + snowflake://username:password@account/database/schema?warehouse=warehouse_name + + Returns: + Tuple[bool, str]: Success status and message + """ + try: + # Parse connection URL + conn_params = SnowflakeLoader._parse_snowflake_url(connection_url) + + # Connect to Snowflake database + conn = snowflake.connector.connect(**conn_params) + cursor = conn.cursor(DictCursor) + + # Get database and schema name + db_name = conn_params['database'] + # Snowflake stores unquoted identifiers in UPPERCASE; + # INFORMATION_SCHEMA lookups require the canonical form. + schema_name = conn_params['schema'].upper() + + # Get all table information + yield True, "Extracting table information..." + entities = SnowflakeLoader.extract_tables_info(cursor, db_name, schema_name) + + # Get all relationship information + yield True, "Extracting relationship information..." + relationships = SnowflakeLoader.extract_relationships(cursor, db_name, schema_name) + + # Close database connection + cursor.close() + conn.close() + + # Load data into graph + yield True, "Loading data into graph..." 
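+            # Note: the graph is keyed as f"{prefix}_{db_name}"; refresh_graph_schema()
+            # later splits the graph_id on "_" to recover the prefix, so the two must stay in sync.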
+ await load_to_graph(f"{prefix}_{db_name}", entities, relationships, + db_name=db_name, db_url=connection_url) + + yield True, (f"Snowflake schema loaded successfully. " + f"Found {len(entities)} tables.") + + except snowflake.connector.Error as e: + logging.error("Snowflake error: %s", e) + yield False, f"Snowflake error: {e}" + except Exception as e: # pylint: disable=broad-exception-caught + logging.error("Error loading Snowflake schema: %s", e) + yield False, f"Failed to load Snowflake database schema: {e}" + + @staticmethod + def extract_tables_info(cursor, db_name: str, schema_name: str) -> Dict[str, Any]: + """ + Extract table and column information from Snowflake database. + + Args: + cursor: Database cursor + db_name: Database name + schema_name: Schema name + + Returns: + Dict containing table information + """ + # Validate identifiers to prevent SQL injection + SnowflakeLoader._validate_identifier(db_name, "database name") + SnowflakeLoader._validate_identifier(schema_name, "schema name") + + entities = {} + + # Get all tables in the schema + # Use quoted identifiers for database name, parameterize schema_name + quoted_db = SnowflakeLoader._quote_identifier(db_name) + cursor.execute(f""" + SELECT TABLE_NAME, COMMENT + FROM {quoted_db}.INFORMATION_SCHEMA.TABLES + WHERE TABLE_SCHEMA = %s + AND TABLE_TYPE = 'BASE TABLE' + ORDER BY TABLE_NAME; + """, (schema_name,)) + + tables = cursor.fetchall() + + for table_info in tqdm.tqdm(tables, desc="Extracting table information"): + table_name = table_info['TABLE_NAME'] + table_comment = table_info['COMMENT'] + + # Get column information for this table + columns_info = SnowflakeLoader.extract_columns_info( + cursor, db_name, schema_name, table_name + ) + + # Get foreign keys for this table + foreign_keys = SnowflakeLoader.extract_foreign_keys( + cursor, db_name, schema_name, table_name + ) + + # Generate table description + table_description = table_comment if table_comment else f"Table: {table_name}" + + # Get column descriptions for batch embedding + col_descriptions = [col_info['description'] for col_info in columns_info.values()] + + entities[table_name] = { + 'description': table_description, + 'columns': columns_info, + 'foreign_keys': foreign_keys, + 'col_descriptions': col_descriptions + } + + return entities + + @staticmethod + def extract_columns_info( # pylint: disable=too-many-locals + cursor, db_name: str, schema_name: str, table_name: str + ) -> Dict[str, Any]: + """ + Extract column information for a specific table. 
+ + Args: + cursor: Database cursor + db_name: Database name + schema_name: Schema name + table_name: Name of the table + + Returns: + Dict containing column information + """ + # Validate identifiers to prevent SQL injection + SnowflakeLoader._validate_identifier(db_name, "database name") + SnowflakeLoader._validate_identifier(schema_name, "schema name") + SnowflakeLoader._validate_identifier(table_name, "table name") + + quoted_db = SnowflakeLoader._quote_identifier(db_name) + + cursor.execute(f""" + SELECT + COLUMN_NAME, + DATA_TYPE, + IS_NULLABLE, + COLUMN_DEFAULT, + COMMENT + FROM {quoted_db}.INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_SCHEMA = %s + AND TABLE_NAME = %s + ORDER BY ORDINAL_POSITION; + """, (schema_name, table_name)) + + columns = cursor.fetchall() + columns_info = {} + + # Get primary key information using Snowflake's SHOW command + quoted_table = SnowflakeLoader._quote_identifier(table_name) + quoted_schema = SnowflakeLoader._quote_identifier(schema_name) + cursor.execute(f"SHOW PRIMARY KEYS IN TABLE {quoted_db}.{quoted_schema}.{quoted_table}") + primary_keys = {row['column_name'] for row in cursor.fetchall()} + + # Get foreign key columns using Snowflake's SHOW IMPORTED KEYS + cursor.execute(f"SHOW IMPORTED KEYS IN TABLE {quoted_db}.{quoted_schema}.{quoted_table}") + foreign_keys_cols = {row['fk_column_name'] for row in cursor.fetchall()} + + for col_info in columns: + col_name = col_info['COLUMN_NAME'] + data_type = col_info['DATA_TYPE'] + is_nullable = col_info['IS_NULLABLE'] + column_default = col_info['COLUMN_DEFAULT'] + column_comment = col_info['COMMENT'] + + # Determine key type + if col_name in primary_keys: + key_type = 'PRIMARY KEY' + elif col_name in foreign_keys_cols: + key_type = 'FOREIGN KEY' + else: + key_type = 'NONE' + + # Generate column description + description_parts = [] + if column_comment: + description_parts.append(column_comment) + else: + description_parts.append(f"Column {col_name} of type {data_type}") + + if key_type != 'NONE': + description_parts.append(f"({key_type})") + + if is_nullable == 'NO': + description_parts.append("(NOT NULL)") + + if column_default is not None: + description_parts.append(f"(Default: {column_default})") + + # Extract sample values for the column (stored separately, not in description) + sample_values = SnowflakeLoader.extract_sample_values_for_column( + cursor, table_name, col_name + ) + + columns_info[col_name] = { + 'type': data_type, + 'null': is_nullable, + 'key': key_type, + 'description': ' '.join(description_parts), + 'default': column_default, + 'sample_values': sample_values + } + + return columns_info + + @staticmethod + def extract_foreign_keys( + cursor, db_name: str, schema_name: str, table_name: str + ) -> List[Dict[str, str]]: + """ + Extract foreign key information for a specific table. 
+ + Args: + cursor: Database cursor + db_name: Database name + schema_name: Schema name + table_name: Name of the table + + Returns: + List of foreign key dictionaries + """ + # Validate identifiers to prevent SQL injection + SnowflakeLoader._validate_identifier(db_name, "database name") + SnowflakeLoader._validate_identifier(schema_name, "schema name") + SnowflakeLoader._validate_identifier(table_name, "table name") + + quoted_db = SnowflakeLoader._quote_identifier(db_name) + quoted_schema = SnowflakeLoader._quote_identifier(schema_name) + quoted_table = SnowflakeLoader._quote_identifier(table_name) + + # Use Snowflake's SHOW IMPORTED KEYS for foreign key information + cursor.execute(f"SHOW IMPORTED KEYS IN TABLE {quoted_db}.{quoted_schema}.{quoted_table}") + + foreign_keys = [] + for fk_info in cursor.fetchall(): + foreign_keys.append({ + 'constraint_name': fk_info['fk_name'], + 'column': fk_info['fk_column_name'], + 'referenced_table': fk_info['pk_table_name'], + 'referenced_column': fk_info['pk_column_name'] + }) + + return foreign_keys + + @staticmethod + def extract_relationships( + cursor, db_name: str, schema_name: str + ) -> Dict[str, List[Dict[str, str]]]: + """ + Extract all relationship information from the database. + + Args: + cursor: Database cursor + db_name: Database name + schema_name: Schema name + + Returns: + Dict containing relationship information + """ + # Validate identifiers to prevent SQL injection + SnowflakeLoader._validate_identifier(db_name, "database name") + SnowflakeLoader._validate_identifier(schema_name, "schema name") + + quoted_db = SnowflakeLoader._quote_identifier(db_name) + quoted_schema = SnowflakeLoader._quote_identifier(schema_name) + + # Use Snowflake's SHOW IMPORTED KEYS for each table to get relationships + cursor.execute(f""" + SELECT TABLE_NAME + FROM {quoted_db}.INFORMATION_SCHEMA.TABLES + WHERE TABLE_SCHEMA = %s + AND TABLE_TYPE = 'BASE TABLE' + ORDER BY TABLE_NAME; + """, (schema_name,)) + tables = [row['TABLE_NAME'] for row in cursor.fetchall()] + + relationships = {} + for tbl in tables: + SnowflakeLoader._validate_identifier(tbl, "table name") + quoted_table = SnowflakeLoader._quote_identifier(tbl) + cursor.execute( + f"SHOW IMPORTED KEYS IN TABLE {quoted_db}.{quoted_schema}.{quoted_table}" + ) + for rel_info in cursor.fetchall(): + constraint_name = rel_info['fk_name'] + + if constraint_name not in relationships: + relationships[constraint_name] = [] + + relationships[constraint_name].append({ + 'from': rel_info['fk_table_name'], + 'to': rel_info['pk_table_name'], + 'source_column': rel_info['fk_column_name'], + 'target_column': rel_info['pk_column_name'], + 'note': f'Foreign key constraint: {constraint_name}' + }) + + return relationships + + @staticmethod + def is_schema_modifying_query(sql_query: str) -> Tuple[bool, str]: + """ + Check if a SQL query modifies the database schema. 
+ + Args: + sql_query: The SQL query to check + + Returns: + Tuple of (is_schema_modifying, operation_type) + """ + if not sql_query or not sql_query.strip(): + return False, "" + + # Clean and normalize the query + normalized_query = sql_query.strip().upper() + + # Check for basic DDL operations + first_word = normalized_query.split()[0] if normalized_query.split() else "" + if first_word in SnowflakeLoader.SCHEMA_MODIFYING_OPERATIONS: + # Additional pattern matching for more precise detection + for pattern in SnowflakeLoader.SCHEMA_PATTERNS: + if re.match(pattern, normalized_query, re.IGNORECASE): + return True, first_word + + # If it's a known DDL operation but doesn't match specific patterns, + # still consider it schema-modifying (better safe than sorry) + return True, first_word + + return False, "" + + @staticmethod + async def refresh_graph_schema(graph_id: str, db_url: str) -> Tuple[bool, str]: + """ + Refresh the graph schema by clearing existing data and reloading from the database. + + Args: + graph_id: The graph ID to refresh + db_url: Database connection URL + + Returns: + Tuple of (success, message) + """ + try: + logging.info("Schema modification detected. Refreshing graph schema.") + + # Import here to avoid circular imports + from api.extensions import db # pylint: disable=import-error,import-outside-toplevel + + # Clear existing graph data + # Drop current graph before reloading + graph = db.select_graph(graph_id) + await graph.delete() + + # Extract prefix from graph_id (remove database name part) + # graph_id format is typically "prefix_database_name" + parts = graph_id.split('_') + if len(parts) >= 2: + # Reconstruct prefix by joining all parts except the last one + prefix = '_'.join(parts[:-1]) + else: + prefix = graph_id + + # Reuse the existing load method to reload the schema + success = False + message = "" + async for progress_tuple in SnowflakeLoader.load(prefix, db_url): + success, message = progress_tuple + + if success: + logging.info("Graph schema refreshed successfully.") + return True, message + + logging.error("Schema refresh failed") + return False, "Failed to reload schema" + + except Exception as e: # pylint: disable=broad-exception-caught + # Log the error and return failure + logging.error("Error refreshing graph schema: %s", str(e)) + error_msg = "Error refreshing graph schema" + logging.error(error_msg) + return False, error_msg + + @staticmethod + def execute_sql_query(sql_query: str, db_url: str) -> List[Dict[str, Any]]: + """ + Execute a SQL query on the Snowflake database and return the results. 
+ + Args: + sql_query: The SQL query to execute + db_url: Snowflake connection URL in format: + snowflake://username:password@account/database/schema?warehouse=warehouse_name + + Returns: + List of dictionaries containing the query results + """ + try: + # Parse connection URL + conn_params = SnowflakeLoader._parse_snowflake_url(db_url) + + # Connect to Snowflake database + conn = snowflake.connector.connect(**conn_params) + cursor = conn.cursor(DictCursor) + + # Execute the SQL query + cursor.execute(sql_query) + + # Check if the query returns results (SELECT queries) + if cursor.description is not None: + # This is a SELECT query or similar that returns rows + results = cursor.fetchall() + result_list = [] + for row in results: + # Serialize each value to ensure JSON compatibility + serialized_row = { + key: SnowflakeLoader._serialize_value(value) + for key, value in row.items() + } + result_list.append(serialized_row) + else: + # This is an INSERT, UPDATE, DELETE, or other non-SELECT query + # Return information about the operation + affected_rows = cursor.rowcount + sql_type = sql_query.strip().split()[0].upper() + + if sql_type in ['INSERT', 'UPDATE', 'DELETE']: + result_list = [{ + "operation": sql_type, + "affected_rows": affected_rows, + "status": "success" + }] + else: + # For other types of queries (CREATE, DROP, etc.) + result_list = [{ + "operation": sql_type, + "status": "success" + }] + + # Commit the transaction for write operations + conn.commit() + + # Close database connection + cursor.close() + conn.close() + + return result_list + + except snowflake.connector.Error as e: + # Rollback in case of error + if 'conn' in locals(): + conn.rollback() + cursor.close() + conn.close() + logging.error("Snowflake query execution error: %s", e) + raise SnowflakeQueryError(f"Snowflake query execution error: {str(e)}") from e + except Exception as e: + # Rollback in case of error + if 'conn' in locals(): + conn.rollback() + cursor.close() + conn.close() + logging.error("Error executing SQL query: %s", e) + raise SnowflakeQueryError(f"Error executing SQL query: {str(e)}") from e diff --git a/api/memory/graphiti_tool.py b/api/memory/graphiti_tool.py index 3dc7279c..00a1b3e6 100644 --- a/api/memory/graphiti_tool.py +++ b/api/memory/graphiti_tool.py @@ -65,6 +65,7 @@ def __init__(self, user_id: str, graph_id: str): # Create Graphiti client with Azure OpenAI configuration self.graphiti_client = create_graphiti_client(falkor_driver) + self.memory_enabled = self.graphiti_client is not None self.user_id = user_id self.graph_id = graph_id @@ -82,6 +83,9 @@ async def create(cls, user_id: str, graph_id: str, use_direct_entities: bool = T """Async factory to construct and initialize the tool.""" self = cls(user_id, graph_id) + if not self.memory_enabled: + return self + await self._ensure_entity_nodes_direct(user_id, graph_id) @@ -301,16 +305,18 @@ async def add_new_memory(self, conversation: Dict[str, Any], history: Tuple[List async def save_query_memory(self, query: str, sql_query: str, success: bool, error: Optional[str] = None) -> bool: """ Save individual query memory directly to the database node. 
- + Args: query: The user's natural language query sql_query: The generated SQL query success: Whether the query execution was successful error: Error message if the query failed - + Returns: bool: True if memory was saved successfully, False otherwise """ + if not self.memory_enabled: + return False try: database_name = self.graph_id database_node_name = f"Database {database_name}" @@ -396,6 +402,8 @@ async def retrieve_similar_queries(self, query: str, limit: int = 5) -> List[Dic Returns: A list of similar query metadata. """ + if not self.memory_enabled: + return [] try: database_name = self.graph_id @@ -457,10 +465,12 @@ async def search_user_summary(self, limit: int = 5) -> str: Args: query: Natural language query to search for limit: Maximum number of results to return - + Returns: List of user node summaries with metadata """ + if not self.memory_enabled: + return "" try: driver = self.graphiti_client.driver query = """ @@ -508,14 +518,16 @@ async def extract_episode_from_rel(self, rel_result): async def search_database_facts(self, query: str, limit: int = 5, episode_limit: int = 3) -> str: """ Search for database-specific facts and interaction history using database node as center. - + Args: query: Natural language query to search for database facts limit: Maximum number of results to return - + Returns: String containing all relevant database facts with time relevancy information """ + if not self.memory_enabled: + return "" try: driver = self.graphiti_client.driver query = """ @@ -568,15 +580,18 @@ async def search_memories(self, query: str, user_limit: int = 5, database_limit: """ Run both user summary and database facts searches concurrently for better performance. Also builds a comprehensive memory context string for the analysis agent. - + Args: query: Natural language query to search for database facts user_limit: Maximum number of results for user summary search database_limit: Maximum number of results for database facts search - + Returns: - Dict containing user_summary, database_facts, similar_queries, and memory_context + A formatted memory context string combining user summary, database facts, + and similar query history, or empty string if memory is disabled. """ + if not self.memory_enabled: + return "" try: # Run both searches concurrently using asyncio.gather user_summary_task = self.search_user_summary(limit=user_limit) @@ -835,10 +850,13 @@ def create_graphiti_client(falkor_driver: FalkorDriver) -> Graphiti: else: # Non-OpenAI/Azure providers (Gemini, Anthropic, Ollama, Cohere): # Graphiti memory requires OpenAI-compatible embeddings. - # Use LiteLLM embeddings via Config instead. - graphiti_client = Graphiti( - graph_driver=falkor_driver, + # Memory is not supported for these providers. + logging.warning( + "Memory is only supported with Azure or OpenAI providers. " + "Current provider: %s. 
Memory will be disabled.", + getattr(Config, 'LLM_PROVIDER', 'unknown') ) + return None return graphiti_client diff --git a/api/routes/settings.py b/api/routes/settings.py index 5f6ebd42..554081d2 100644 --- a/api/routes/settings.py +++ b/api/routes/settings.py @@ -42,14 +42,12 @@ async def validate_api_key(request: Request, data: ValidateKeyRequest): # pylin status_code=400 ) - # Validate vendor is supported - supported_vendors = ( - "openai", "anthropic", "gemini", "azure", "ollama", "cohere", - ) - if vendor not in supported_vendors: - allowed = ", ".join(supported_vendors) + # Validate vendor — only key-based vendors can be validated via API call + validatable_vendors = ("openai", "anthropic", "gemini", "cohere") + if vendor not in validatable_vendors: + allowed = ", ".join(validatable_vendors) return JSONResponse( - content={"valid": False, "error": f"Unsupported vendor. Supported: {allowed}"}, + content={"valid": False, "error": f"Unsupported vendor for key validation. Supported: {allowed}"}, status_code=400 ) diff --git a/api/utils.py b/api/utils.py index 845ca604..e6979876 100644 --- a/api/utils.py +++ b/api/utils.py @@ -98,7 +98,8 @@ def create_combined_description( # pylint: disable=too-many-locals if isinstance(batch_response, Exception): table_info[table_name]["description"] = table_name else: - content = batch_response.choices[0].message["content"].strip() + msg_content = batch_response.choices[0].message["content"] + content = msg_content.strip() if msg_content else table_name table_info[table_name]["description"] = content return table_info diff --git a/app/src/components/modals/DatabaseModal.tsx b/app/src/components/modals/DatabaseModal.tsx index 32590e63..e3a7f47a 100644 --- a/app/src/components/modals/DatabaseModal.tsx +++ b/app/src/components/modals/DatabaseModal.tsx @@ -31,6 +31,13 @@ const DatabaseModal = ({ open, onOpenChange }: DatabaseModalProps) => { const [password, setPassword] = useState(""); const [schema, setSchema] = useState(""); const [schemaError, setSchemaError] = useState(""); + // Snowflake-specific fields + const [account, setAccount] = useState(""); + const [snowflakeSchema, setSnowflakeSchema] = useState("PUBLIC"); + const [warehouse, setWarehouse] = useState("COMPUTE_WH"); + const [authMode, setAuthMode] = useState<'password' | 'keypair'>('password'); + const [privateKey, setPrivateKey] = useState(""); + const [privateKeyPassphrase, setPrivateKeyPassphrase] = useState(""); const [isConnecting, setIsConnecting] = useState(false); const [connectionSteps, setConnectionSteps] = useState([]); const { refreshGraphs } = useDatabase(); @@ -75,13 +82,32 @@ const DatabaseModal = ({ open, onOpenChange }: DatabaseModalProps) => { return; } } else { - if (!selectedDatabase || !host || !port || !database || !username) { - toast({ - title: "Missing Information", - description: "Please fill in all required fields", - variant: "destructive", - }); - return; + if (selectedDatabase === 'snowflake') { + if (!account || !database || !username) { + toast({ + title: "Missing Information", + description: "Please fill in all required fields (account, database, username)", + variant: "destructive", + }); + return; + } + if (authMode === 'keypair' && !privateKey) { + toast({ + title: "Missing Information", + description: "Please paste your private key in PEM format", + variant: "destructive", + }); + return; + } + } else { + if (!selectedDatabase || !host || !port || !database || !username) { + toast({ + title: "Missing Information", + description: "Please fill in all 
required fields", + variant: "destructive", + }); + return; + } } } @@ -92,20 +118,37 @@ const DatabaseModal = ({ open, onOpenChange }: DatabaseModalProps) => { // Build the connection URL let dbUrl = connectionUrl; if (connectionMode === 'manual') { - const protocol = selectedDatabase === 'mysql' ? 'mysql' : 'postgresql'; - const builtUrl = new URL(`${protocol}://${host}:${port}/${database}`); - builtUrl.username = username; - builtUrl.password = password; - - // Append schema option for PostgreSQL if provided - if (selectedDatabase === 'postgresql' && schema.trim()) { - if (/[^a-zA-Z0-9_]/.test(schema.trim())) { - throw new Error('Schema name can only contain letters, digits, and underscores'); + if (selectedDatabase === 'snowflake') { + // Build Snowflake URL: snowflake://user@account/database/schema?warehouse=WH + const builtUrl = new URL(`snowflake://${account}/${database}/${snowflakeSchema}`); + builtUrl.username = username; + if (authMode === 'keypair' && privateKey) { + // Base64-encode the PEM key for safe URL transport + builtUrl.searchParams.set('private_key', btoa(privateKey)); + if (privateKeyPassphrase) { + builtUrl.searchParams.set('private_key_passphrase', privateKeyPassphrase); + } + } else { + builtUrl.password = password; } - builtUrl.searchParams.set('options', `-csearch_path=${schema.trim()}`); - } + builtUrl.searchParams.set('warehouse', warehouse); + dbUrl = builtUrl.toString(); + } else { + const protocol = selectedDatabase === 'mysql' ? 'mysql' : 'postgresql'; + const builtUrl = new URL(`${protocol}://${host}:${port}/${database}`); + builtUrl.username = username; + builtUrl.password = password; - dbUrl = builtUrl.toString(); + // Append schema option for PostgreSQL if provided + if (selectedDatabase === 'postgresql' && schema.trim()) { + if (/[^a-zA-Z0-9_]/.test(schema.trim())) { + throw new Error('Schema name can only contain letters, digits, and underscores'); + } + builtUrl.searchParams.set('options', `-csearch_path=${schema.trim()}`); + } + + dbUrl = builtUrl.toString(); + } } // Make streaming request @@ -190,6 +233,12 @@ const DatabaseModal = ({ open, onOpenChange }: DatabaseModalProps) => { setPassword(""); setSchema(""); setSchemaError(""); + setAccount(""); + setSnowflakeSchema("PUBLIC"); + setWarehouse("COMPUTE_WH"); + setAuthMode('password'); + setPrivateKey(""); + setPrivateKeyPassphrase(""); setConnectionSteps([]); }, 1000); } else { @@ -252,7 +301,7 @@ const DatabaseModal = ({ open, onOpenChange }: DatabaseModalProps) => { Connect to Database - Connect to PostgreSQL or MySQL database using a connection URL or manual entry.{" "} + Connect to PostgreSQL, MySQL, or Snowflake database using a connection URL or manual entry.{" "} { MySQL + +
+
+ Snowflake +
+
@@ -328,106 +383,247 @@ const DatabaseModal = ({ open, onOpenChange }: DatabaseModalProps) => { placeholder={ selectedDatabase === 'postgresql' ? 'postgresql://username:password@host:5432/database' - : 'mysql://username:password@host:3306/database' + : selectedDatabase === 'mysql' + ? 'mysql://username:password@host:3306/database' + : 'snowflake://username:password@account/database/schema?warehouse=warehouse_name' } value={connectionUrl} onChange={(e) => setConnectionUrl(e.target.value)} className="bg-muted border-border font-mono text-sm focus-visible:ring-purple-500" />

- Enter your database connection string + {selectedDatabase === 'snowflake' + ? 'Enter your Snowflake connection string (schema defaults to PUBLIC, warehouse to COMPUTE_WH)' + : 'Enter your database connection string'}

)} {selectedDatabase && connectionMode === 'manual' && ( <> -
- - setHost(e.target.value)} - className="bg-muted border-border focus-visible:ring-purple-500" - /> -
- -
- - setPort(e.target.value)} - className="bg-muted border-border focus-visible:ring-purple-500" - /> -
+ {selectedDatabase === 'snowflake' ? ( + <> +
+ + setAccount(e.target.value)} + className="bg-muted border-border focus-visible:ring-purple-500" + /> +

+ Your Snowflake account identifier (e.g., myorg-account) +

+
-
- - setDatabase(e.target.value)} - className="bg-muted border-border focus-visible:ring-purple-500" - /> -
+
+ + setDatabase(e.target.value)} + className="bg-muted border-border focus-visible:ring-purple-500" + /> +
-
- - setUsername(e.target.value)} - className="bg-muted border-border focus-visible:ring-purple-500" - /> -
+
+ + setSnowflakeSchema(e.target.value)} + className="bg-muted border-border focus-visible:ring-purple-500" + /> +

+ Defaults to PUBLIC if not specified +

+
-
- - setPassword(e.target.value)} - className="bg-muted border-border focus-visible:ring-purple-500" - /> -
- - {/* Schema field - PostgreSQL only */} - {selectedDatabase === 'postgresql' && ( -
- - { - const val = e.target.value; - setSchema(val); - if (val && /[^a-zA-Z0-9_]/.test(val)) { - setSchemaError('Schema name can only contain letters, digits, and underscores'); - } else { - setSchemaError(''); - } - }} - className={`bg-muted border-border ${schemaError ? 'border-red-500' : ''}`} - /> - {schemaError ? ( -

{schemaError}

- ) : ( +
+ + setWarehouse(e.target.value)} + className="bg-muted border-border focus-visible:ring-purple-500" + />

- Leave empty to use the default 'public' schema + Defaults to COMPUTE_WH if not specified

+
+ +
+ + setUsername(e.target.value)} + className="bg-muted border-border focus-visible:ring-purple-500" + /> +
+ + {/* Auth Mode Toggle */} +
+ +
+ + +
+
+ + {authMode === 'password' ? ( +
+ + setPassword(e.target.value)} + className="bg-muted border-border focus-visible:ring-purple-500" + /> +
+ ) : ( + <> +
+ +