diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index b2efdf35..aa0fce55 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -37,7 +37,6 @@ jobs: PYTHON: ${{ matrix.python-version }} # Do not tear down Testcontainers TC_KEEPALIVE: true - # https://docs.github.com/en/actions/using-containerized-services/about-service-containers services: cratedb: diff --git a/.github/workflows/nlsql.yml b/.github/workflows/nlsql.yml new file mode 100644 index 00000000..010195d5 --- /dev/null +++ b/.github/workflows/nlsql.yml @@ -0,0 +1,83 @@ +--- +name: "Tests: NLSQL" + +on: + pull_request: + paths: + - '.github/workflows/nlsql.yml' + - 'cratedb_toolkit/query/nlsql/**' + - 'tests/query/*nlsql*' + - 'pyproject.toml' + push: + branches: [ main ] + paths: + - '.github/workflows/nlsql.yml' + - 'cratedb_toolkit/query/nlsql/**' + - 'tests/query/*nlsql*' + - 'pyproject.toml' + + # Allow job to be triggered manually. + workflow_dispatch: + + # Run the job each night after CrateDB nightly has been published. + schedule: + - cron: '0 3 * * *' + +# Cancel in-progress jobs when pushing to the same branch. +concurrency: + cancel-in-progress: true + group: ${{ github.workflow }}-${{ github.ref }} + +jobs: + + tests: + + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: ["ubuntu-latest"] + python-version: [ + "3.10", + "3.14", + ] + + env: + OS: ${{ matrix.os }} + PYTHON: ${{ matrix.python-version }} + # Do not tear down Testcontainers + TC_KEEPALIVE: true + + name: Python ${{ matrix.python-version }} on OS ${{ matrix.os }} + steps: + + - name: Acquire sources + uses: actions/checkout@v6 + + - name: Install uv + uses: astral-sh/setup-uv@v7 + with: + activate-environment: 'true' + cache-suffix: ${{ matrix.python-version }} + enable-cache: true + python-version: ${{ matrix.python-version }} + + - name: Set up project + run: | + # Install package in editable mode. 
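+          # Note: the "nlsql" extra pulls in the LlamaIndex provider integrations
+          # declared in pyproject.toml.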
+ uv pip install --editable='.[nlsql,test]' + + - name: Run software tests + env: + ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + OPENROUTER_API_KEY: ${{ secrets.OPENROUTER_API_KEY }} + run: | + pytest -m nlsql + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v6 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + with: + fail_ci_if_error: true diff --git a/CHANGES.md b/CHANGES.md index 3df4b9c6..7b027e73 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -5,6 +5,7 @@ - Kinesis: Added `ctk kinesis` CLI group with `list-checkpoints` and `prune-checkpoints` commands for checkpoint table maintenance - Dependencies: Permitted installation of click 8.3 +- DataQuery: Help agents turn natural language into SQL queries ## 2026/03/16 v0.0.46 - I/O: API improvements: `ctk {load,save} table` became `ctk {load,save}` diff --git a/README.md b/README.md index f610fdea..2c7067ff 100644 --- a/README.md +++ b/README.md @@ -19,9 +19,10 @@ [![ci-main][ci-main-badge]][ci-main-workflow] [![ci-cloud][ci-cloud-badge]][ci-cloud-workflow] +[![ci-nlsql][ci-nlsql-badge]][ci-nlsql-workflow] + [![ci-dynamodb][ci-dynamodb-badge]][ci-dynamodb-workflow] [![ci-influxdb][ci-influxdb-badge]][ci-influxdb-workflow] - [![ci-kinesis][ci-kinesis-badge]][ci-kinesis-workflow] [![ci-mongodb][ci-mongodb-badge]][ci-mongodb-workflow] [![ci-postgresql][ci-postgresql-badge]][ci-postgresql-workflow] @@ -99,6 +100,8 @@ pip install 'cratedb-toolkit[full]==0.0.38' [ci-kinesis-workflow]: https://github.com/crate/cratedb-toolkit/actions/workflows/kinesis.yml [ci-mongodb-badge]: https://github.com/crate/cratedb-toolkit/actions/workflows/mongodb.yml/badge.svg [ci-mongodb-workflow]: https://github.com/crate/cratedb-toolkit/actions/workflows/mongodb.yml +[ci-nlsql-badge]: https://github.com/crate/cratedb-toolkit/actions/workflows/nlsql.yml/badge.svg +[ci-nlsql-workflow]: https://github.com/crate/cratedb-toolkit/actions/workflows/nlsql.yml [ci-postgresql-badge]: https://github.com/crate/cratedb-toolkit/actions/workflows/postgresql.yml/badge.svg [ci-postgresql-workflow]: https://github.com/crate/cratedb-toolkit/actions/workflows/postgresql.yml [ci-pymongo-badge]: https://github.com/crate/cratedb-toolkit/actions/workflows/pymongo.yml/badge.svg diff --git a/cratedb_toolkit/query/cli.py b/cratedb_toolkit/query/cli.py index 208dd09c..1bed0f0b 100644 --- a/cratedb_toolkit/query/cli.py +++ b/cratedb_toolkit/query/cli.py @@ -1,26 +1,9 @@ -import logging - -import click -from click_aliases import ClickAliasedGroup - -from ..util.cli import boot_click +from ..util.app import make_cli from .convert.cli import convert_query from .mcp.cli import cli as mcp_cli +from .nlsql.cli import llm_cli -logger = logging.getLogger(__name__) - - -@click.group(cls=ClickAliasedGroup) -@click.option("--verbose", is_flag=True, required=False, help="Turn on logging") -@click.option("--debug", is_flag=True, required=False, help="Turn on logging with debug level") -@click.version_option() -@click.pass_context -def cli(ctx: click.Context, verbose: bool, debug: bool): - """ - Query utilities. 
- """ - return boot_click(ctx, verbose, debug) - - +cli = make_cli() cli.add_command(convert_query, name="convert") +cli.add_command(llm_cli, name="nlsql") cli.add_command(mcp_cli, name="mcp") diff --git a/cratedb_toolkit/query/nlsql/__init__.py b/cratedb_toolkit/query/nlsql/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cratedb_toolkit/query/nlsql/api.py b/cratedb_toolkit/query/nlsql/api.py new file mode 100644 index 00000000..ec09bd6e --- /dev/null +++ b/cratedb_toolkit/query/nlsql/api.py @@ -0,0 +1,99 @@ +""" +Use an LLM to query a database in human language using LlamaIndex' NLSQLTableQueryEngine. +""" + +import contextlib +import dataclasses +import logging +from typing import Optional + +from cratedb_toolkit.query.nlsql.model import DatabaseInfo, ModelInfo + +logger = logging.getLogger(__name__) + +llama_index_import_error: Optional[ImportError] = None + +try: + from llama_index.core.base.response.schema import RESPONSE_TYPE + from llama_index.core.llms import LLM + from llama_index.core.query_engine import NLSQLTableQueryEngine + from llama_index.core.utilities.sql_wrapper import SQLDatabase +except ImportError as exc: + llama_index_import_error = exc + + +@dataclasses.dataclass +class DataQuery: + """ + DataQuery helps agents turn natural language into SQL queries. + It's the little sister of Google's QueryData product. [1] + + We recommend evaluating the Text-to-SQL interface using the Gemma models if you are + looking at non-frontier variants that need less resources for inference. However, + depending on the complexity of your problem, you may also want to use cutting-edge + models with your provider of choice at the cost of higher resource usage. + + Attention: Any natural language SQL table query engine and Text-to-SQL application + should be aware that executing arbitrary SQL queries can be a security risk. + It is recommended to take precautions as needed, such as using restricted roles, + read-only databases, sandboxing, etc. + + [1] https://cloud.google.com/blog/products/databases/introducing-querydata-for-near-100-percent-accurate-data-agents + [2] https://github.com/kupp0/multi-db-property-search-data-agents + """ + + db: DatabaseInfo + model: ModelInfo + query_engine: Optional["NLSQLTableQueryEngine"] = None + permit_all_statements: bool = False + + def __post_init__(self): + """Initialize query engine.""" + if self.query_engine is None: + self.setup() + + def setup(self): + """Configure database connection and query engine.""" + if llama_index_import_error: + raise ImportError( + "NLSQL support requires installing `cratedb-toolkit[nlsql]`" + ) from llama_index_import_error + + from cratedb_toolkit.query.nlsql.util import configure_llm, disable_embeddings + + # Configure model. + logger.info("Configuring LLM: provider=%s, name=%s", self.model.provider.name, self.model.name) + llm: LLM = configure_llm(self.model) + logger.info("Selected LLM: %s", llm.metadata.model_dump_json()) + + # Configure database. + self.db.setup() + + # schema = quote_relation_name(self.db.schema) if self.db.schema else None # noqa: ERA001 + + # Configure NLSQL query engine. 
+ logger.info("Creating query engine") + sql_database = SQLDatabase( + self.db.get_engine(), + schema=self.db.schema, + ignore_tables=self.db.ignore_tables, + include_tables=self.db.include_tables, + ) + with disable_embeddings(): + self.query_engine = NLSQLTableQueryEngine( + sql_database=sql_database, + llm=llm, + ) + + def ask(self, question: str) -> "RESPONSE_TYPE": + """Invoke an inquiry to the LLM.""" + from cratedb_toolkit.query.nlsql.sqlgate import enable_sql_gateway + + if not self.query_engine: + raise ValueError("Query engine not configured") + if self.permit_all_statements: + sql_gateway = contextlib.nullcontext + else: + sql_gateway = enable_sql_gateway + with sql_gateway(): + return self.query_engine.query(question) diff --git a/cratedb_toolkit/query/nlsql/cli.py b/cratedb_toolkit/query/nlsql/cli.py new file mode 100644 index 00000000..3ed10e96 --- /dev/null +++ b/cratedb_toolkit/query/nlsql/cli.py @@ -0,0 +1,112 @@ +import json +import logging +import os +import sys +from typing import Optional + +import click +from dotenv import load_dotenv + +from cratedb_toolkit.option import ( + option_cluster_id, + option_cluster_name, + option_cluster_url, + option_password, + option_schema, + option_username, +) +from cratedb_toolkit.query.nlsql.api import DataQuery +from cratedb_toolkit.query.nlsql.model import DatabaseInfo +from cratedb_toolkit.util.common import setup_logging +from cratedb_toolkit.util.data import asbool + +logger = logging.getLogger(__name__) + + +def help_llm(): + """ + Use an LLM to query the database in human language. + + Synopsis + ======== + + export CRATEDB_CLUSTER_URL=crate://localhost/ + ctk query nlsql "What is the average value for sensor 1?" + + """ # noqa: E501 + + +@click.command() +@click.argument("question") +@option_cluster_id +@option_cluster_name +@option_cluster_url +@option_username +@option_password +@option_schema +@click.option("--llm-provider", type=str, required=False, help="LLM provider name") +@click.option("--llm-endpoint", type=str, required=False, help="LLM endpoint URL") +@click.option( + "--llm-instance", type=str, required=False, help="LLM model resource name, e.g. with Azure OpenAI service" +) +@click.option("--llm-name", type=str, required=False, help="LLM model name for completions") +@click.option("--llm-api-key", type=str, required=False, help="LLM API key") +@click.option("--llm-api-version", type=str, required=False, help="LLM API version") +@click.pass_context +def llm_cli( + ctx: click.Context, + question: str, + cluster_id: str, + cluster_name: str, + cluster_url: str, + username: str, + password: str, + schema: str, + llm_provider: Optional[str], + llm_endpoint: Optional[str], + llm_instance: Optional[str], + llm_name: Optional[str], + llm_api_key: Optional[str], + llm_api_version: Optional[str], +): + """ + Use an LLM to query a database in human language. + """ + from cratedb_toolkit.query.nlsql.util import read_llm_options + + setup_logging() + load_dotenv() + + # Read question. + if question == "-": + question = sys.stdin.read().strip() + + schema = schema or "doc" + permit_all_statements = asbool(os.getenv("NLSQL_PERMIT_ALL_STATEMENTS")) + + # Connect to database and configure LLM. + dburi = ctx.meta["address"].cluster_url + + # Configure natural language query machinery. 
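+    # `DatabaseInfo` carries the connection URL and schema; `read_llm_options()`
+    # merges the CLI options with their LLM_* environment variable fallbacks and
+    # falls back to a provider-specific default model when --llm-name is omitted.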
+ dataquery = DataQuery( + db=DatabaseInfo( + dburi=dburi, + schema=schema, + ), + model=read_llm_options( + llm_provider=llm_provider, + llm_name=llm_name, + llm_endpoint=llm_endpoint, + llm_instance=llm_instance, + llm_api_key=llm_api_key, + llm_api_version=llm_api_version, + ), + permit_all_statements=permit_all_statements, + ) + + # Submit query. + response = dataquery.ask(question) + output = {"question": question, "answer": str(response)} + if response.metadata: + output.update(next(iter(response.metadata.values()))) + print(json.dumps(output, indent=2, default=str), file=sys.stdout) # noqa: T201 diff --git a/cratedb_toolkit/query/nlsql/model.py b/cratedb_toolkit/query/nlsql/model.py new file mode 100644 index 00000000..ca161661 --- /dev/null +++ b/cratedb_toolkit/query/nlsql/model.py @@ -0,0 +1,72 @@ +import dataclasses +from enum import Enum +from typing import List, Optional + +import sqlalchemy as sa +import sqlalchemy.event +from sqlalchemy_cratedb.support import quote_relation_name + + +class ModelProvider(Enum): + """Model provider choices.""" + + AMAZON_BEDROCK = "amazon_bedrock" + AMAZON_BEDROCK_CONVERSE = "amazon_bedrock_converse" + ANTHROPIC = "anthropic" + AZURE = "azure" + GOOGLE = "google" + HUGGINGFACE_SERVERLESS = "huggingface_serverless" + LLAMAFILE = "llamafile" + MISTRAL = "mistral" + OLLAMA = "ollama" + OPENAI = "openai" + OPENROUTER = "openrouter" + RUNPOD_SERVERLESS = "runpod_serverless" + + +@dataclasses.dataclass +class ModelInfo: + """Information about the model.""" + + provider: ModelProvider + name: Optional[str] = None + endpoint: Optional[str] = None + instance: Optional[str] = None + api_key: Optional[str] = None + api_version: Optional[str] = None + + +@dataclasses.dataclass +class DatabaseInfo: + """Information about the database.""" + + dburi: Optional[str] = None + engine: Optional[sa.engine.Engine] = None + schema: Optional[str] = None + ignore_tables: Optional[List[str]] = None + include_tables: Optional[List[str]] = None + _listener_registered: bool = dataclasses.field(default=False, init=False, repr=False) + + def setup(self): + """Set up SQLAlchemy engine and schema.""" + + if self.engine is None: + if self.dburi is None: + raise ValueError("Either SQLAlchemy connection URL or database engine object required") + self.engine = sa.create_engine(self.dburi, echo=False) + + def receive_engine_connect(conn): + """Configure search path.""" + if self.schema is not None: + conn.execute(sa.text(f"SET search_path={quote_relation_name(self.schema)};")) + conn.commit() + + if not self._listener_registered: + sqlalchemy.event.listen(self.engine, "engine_connect", receive_engine_connect) + self._listener_registered = True + + def get_engine(self) -> sa.engine.Engine: + """Return SQLAlchemy engine object.""" + if self.engine is None: + raise RuntimeError("Engine is not configured. 
Call setup() first.") + return self.engine diff --git a/cratedb_toolkit/query/nlsql/sqlgate.py b/cratedb_toolkit/query/nlsql/sqlgate.py new file mode 100644 index 00000000..b3b4ac1c --- /dev/null +++ b/cratedb_toolkit/query/nlsql/sqlgate.py @@ -0,0 +1,151 @@ +import contextlib +import dataclasses +import logging +import threading +from typing import Any, Callable, Dict, Tuple + +import sqlparse +from sqlalchemy.exc import ProgrammingError +from sqlparse.tokens import Keyword + +_protection_lock = threading.RLock() + +logger = logging.getLogger(__name__) + + +def make_protected_run_sql(original_run_sql: Callable) -> Callable: + """ + Replacement method for `SQLDatabase.run_sql` that only permits read-only queries. + """ + + def _protected_run_sql(self, command: str) -> Tuple[str, Dict]: + if not sql_is_permitted(command): + raise ProgrammingError(command, {}, ValueError("Rejected SQL command")) + return original_run_sql(self, command) + + return _protected_run_sql + + +@contextlib.contextmanager +def enable_sql_gateway(): + """ + Enable the SQL gateway for software-enforced read-only queries. + """ + from llama_index.core import SQLDatabase + + with _protection_lock: + original_run_sql = SQLDatabase.run_sql + try: + SQLDatabase.run_sql = make_protected_run_sql(original_run_sql) # ty: ignore[invalid-assignment] + yield + finally: + SQLDatabase.run_sql = original_run_sql # ty: ignore[invalid-assignment] + + +def sql_is_permitted(expression: str) -> bool: + """ + Validate the SQL expression, only permit read queries by default. + + NOTE: For serious protections, please use a dedicated read-only database user. + + FIXME: Revisit implementation, it might be too naive or weak. + Issue: https://github.com/crate/cratedb-mcp/issues/10 + Question: Does SQLAlchemy provide a solid read-only mode, or any other library? + """ + is_dql = SqlStatementClassifier(expression=expression).is_dql + if is_dql: + logger.info("Permitted SQL expression: %s", expression and expression[:50]) + else: + logger.warning("Denied SQL expression: %s", expression and expression[:50]) + return is_dql + + +@dataclasses.dataclass +class SqlStatementClassifier: + """ + Helper to classify an SQL statement. + + Here, most importantly: Provide the `is_dql` property that + signals truthfulness for read-only SQL SELECT statements only. + """ + + expression: str + + _parsed_sqlparse: Any = dataclasses.field(init=False, default=None) + + def __post_init__(self) -> None: + if self.expression is None: + self.expression = "" + if self.expression: + self.expression = self.expression.strip() + + def parse_sqlparse(self) -> Tuple[sqlparse.sql.Statement, ...]: + """ + Parse expression using traditional `sqlparse` library. + """ + if self._parsed_sqlparse is None: + self._parsed_sqlparse = sqlparse.parse(self.expression) + return self._parsed_sqlparse + + @property + def is_dql(self) -> bool: + """ + Is it a DQL statement, which effectively invokes read-only operations only? + """ + + if not self.expression: + return False + + # Check if the expression is valid and if it's a DQL/SELECT statement, + # also trying to consider `SELECT ... INTO ...` and evasive + # `SELECT * FROM users; \uff1b DROP TABLE users` statements. + return self.is_select and not self.is_camouflage + + @property + def is_select(self) -> bool: + """ + Whether the expression is an SQL SELECT statement. + """ + return self.operation == "SELECT" + + @property + def operation(self) -> str: + """ + The SQL operation: SELECT, INSERT, UPDATE, DELETE, CREATE, etc. 
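+
+        Determined via ``sqlparse.parse(...)[0].get_type()``, which returns
+        ``"UNKNOWN"`` for expressions it cannot classify.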
+ """ + parsed = self.parse_sqlparse() + return parsed[0].get_type().upper() + + @property + def is_camouflage(self) -> bool: + """ + Innocent-looking `SELECT` statements can evade filters. + """ + return self.is_select_into or self.is_evasive + + @property + def is_select_into(self) -> bool: + """ + Use traditional `sqlparse` for catching `SELECT ... INTO ...` statements. + Examples: + SELECT * INTO foobar FROM bazqux + SELECT * FROM bazqux INTO foobar + """ + # Flatten all tokens (including nested ones) and match on type+value. + statement = self.parse_sqlparse()[0] + return any(token.ttype is Keyword and token.value.upper() == "INTO" for token in statement.flatten()) + + @property + def is_evasive(self) -> bool: + """ + Use traditional `sqlparse` for catching evasive SQL statements. + + A practice picked up from CodeRabbit was to reject multiple statements + to prevent potential SQL injections. Is it a viable suggestion? + + Examples: + + SELECT * FROM users; \uff1b DROP TABLE users + """ + parsed = self.parse_sqlparse() + return len(parsed) > 1 diff --git a/cratedb_toolkit/query/nlsql/util.py b/cratedb_toolkit/query/nlsql/util.py new file mode 100644 index 00000000..28669a01 --- /dev/null +++ b/cratedb_toolkit/query/nlsql/util.py @@ -0,0 +1,328 @@ +import contextlib +import os +import threading +from typing import TYPE_CHECKING, Optional + +from cratedb_toolkit.query.nlsql.model import ModelInfo, ModelProvider + +if TYPE_CHECKING: + from llama_index.core.base.embeddings.base import BaseEmbedding + from llama_index.core.callbacks import CallbackManager + from llama_index.core.embeddings.utils import EmbedType + from llama_index.core.llms import LLM + +llama_index_import_error: Optional[ImportError] = None + +try: + from llama_index.core import MockEmbedding, set_global_handler, settings + from llama_index.core.embeddings import utils +except ImportError as exc: + llama_index_import_error = exc + + +def ensure_llama_index() -> None: + if llama_index_import_error is not None: + raise ImportError("NLSQL support requires installing `cratedb-toolkit[nlsql]`") from llama_index_import_error + + +_embedding_resolver_lock = threading.RLock() + + +def _mock_embed_model( + embed_model: Optional["EmbedType"] = None, + callback_manager: Optional["CallbackManager"] = None, +) -> "BaseEmbedding": + """Stub that suppresses embedding resolution without print/side effects.""" + return MockEmbedding(embed_dim=1) + + +@contextlib.contextmanager +def disable_embeddings(): + """ + Temporarily suppress LlamaIndex's embedding resolver. + + ``NLSQLTableQueryEngine`` does not require embeddings, but LlamaIndex may + still invoke ``resolve_embed_model`` during construction. This context + manager replaces both resolution hooks with a no-op stub and guarantees + the originals are restored on exit, even if an exception is raised. 
+ """ + ensure_llama_index() + with _embedding_resolver_lock: + original_utils = utils.resolve_embed_model + original_settings = settings.resolve_embed_model + try: + utils.resolve_embed_model = _mock_embed_model # ty: ignore[invalid-assignment] + settings.resolve_embed_model = _mock_embed_model # ty: ignore[invalid-assignment] + yield + finally: + utils.resolve_embed_model = original_utils # ty: ignore[invalid-assignment] + settings.resolve_embed_model = original_settings # ty: ignore[invalid-assignment] + + +DEFAULT_MODEL_MAP = { + # https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html + ModelProvider.AMAZON_BEDROCK: "global.anthropic.claude-haiku-4-5-20251001-v1:0", + # ModelProvider.AMAZON_BEDROCK_CONVERSE: "amazon.nova-micro-v1:0", # noqa: ERA001 + ModelProvider.AMAZON_BEDROCK_CONVERSE: "global.amazon.nova-2-lite-v1:0", + # ModelProvider.AMAZON_BEDROCK_CONVERSE: "global.anthropic.claude-haiku-4-5-20251001-v1:0", # noqa: ERA001 + ModelProvider.ANTHROPIC: "claude-haiku-4-5", + ModelProvider.AZURE: "gpt-4.1", # TODO: Not validated yet. + ModelProvider.GOOGLE: "gemini-2.5-flash", # TODO: Not validated yet. + ModelProvider.HUGGINGFACE_SERVERLESS: "HuggingFaceH4/zephyr-7b-alpha", # TODO: Not validated yet. + ModelProvider.OLLAMA: "gemma3:1b", + ModelProvider.OPENAI: "gpt-4o-mini", + ModelProvider.OPENROUTER: "gryphe/mythomax-l2-13b", + ModelProvider.LLAMAFILE: "n/a", # Only one model per process. + ModelProvider.MISTRAL: "mistral-medium-latest", # TODO: Not validated yet. + ModelProvider.RUNPOD_SERVERLESS: "gemma3:270m", +} + + +def read_llm_options( + llm_provider: Optional[str], + llm_endpoint: Optional[str], + llm_instance: Optional[str], + llm_name: Optional[str], + llm_api_key: Optional[str], + llm_api_version: Optional[str], +) -> ModelInfo: + """Read options and apply parameter sanity checks and heuristics.""" + + llm_provider = llm_provider or os.getenv("LLM_PROVIDER") + llm_endpoint = llm_endpoint or os.getenv("LLM_ENDPOINT") + llm_instance = llm_instance or os.getenv("LLM_INSTANCE") + llm_name = llm_name or os.getenv("LLM_NAME") + llm_api_key = llm_api_key or os.getenv("LLM_API_KEY") + llm_api_version = llm_api_version or os.getenv("LLM_API_VERSION") + if not llm_provider: + raise ValueError("LLM provider name is required") + + provider = ModelProvider(llm_provider) + + if not llm_name: + llm_name = DEFAULT_MODEL_MAP.get(provider) + + if provider is ModelProvider.ANTHROPIC: + llm_api_key = llm_api_key or os.getenv("ANTHROPIC_API_KEY") + if not llm_api_key: + raise ValueError( + "LLM API key not defined. Use either CLI/API parameter or ANTHROPIC_API_KEY environment variable." + ) + elif provider is ModelProvider.AZURE: + llm_endpoint = llm_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT") + llm_api_key = llm_api_key or os.getenv("AZURE_OPENAI_API_KEY") + llm_api_version = llm_api_version or os.getenv("OPENAI_API_VERSION") + if not llm_api_key: + raise ValueError( + "LLM API key not defined. Use either CLI/API parameter or AZURE_OPENAI_API_KEY environment variable." + ) + if not llm_endpoint: + raise ValueError( + "Azure OpenAI endpoint not defined. Use either CLI/API parameter or " + "AZURE_OPENAI_ENDPOINT environment variable." + ) + if not llm_api_version: + raise ValueError( + "Azure OpenAI API version not defined. Use either CLI/API parameter or " + "OPENAI_API_VERSION environment variable." 
+ ) + elif provider is ModelProvider.GOOGLE: + llm_api_key = llm_api_key or os.getenv("GOOGLE_API_KEY") + if not llm_api_key: + raise ValueError( + "LLM API key not defined. Use either CLI/API parameter or GOOGLE_API_KEY environment variable." + ) + elif provider is ModelProvider.HUGGINGFACE_SERVERLESS: + llm_api_key = llm_api_key or os.getenv("HF_TOKEN") + if not llm_api_key: + raise ValueError( + "LLM API token not defined. Use either CLI/API parameter or HF_TOKEN environment variable." + ) + elif provider is ModelProvider.MISTRAL: + llm_api_key = llm_api_key or os.getenv("MISTRAL_API_KEY") + if not llm_api_key: + raise ValueError( + "LLM API key not defined. Use either CLI/API parameter or MISTRAL_API_KEY environment variable." + ) + elif provider is ModelProvider.OPENAI: + llm_api_key = llm_api_key or os.getenv("OPENAI_API_KEY") + if not llm_api_key: + raise ValueError( + "LLM API key not defined. Use either CLI/API parameter or OPENAI_API_KEY environment variable." + ) + elif provider is ModelProvider.OPENROUTER: + llm_api_key = llm_api_key or os.getenv("OPENROUTER_API_KEY") + if not llm_api_key: + raise ValueError( + "LLM API key not defined. Use either CLI/API parameter or OPENROUTER_API_KEY environment variable." + ) + elif provider is ModelProvider.RUNPOD_SERVERLESS: + llm_api_key = llm_api_key or os.getenv("RUNPOD_API_KEY") + if not llm_api_key: + raise ValueError( + "LLM API key not defined. Use either CLI/API parameter or RUNPOD_API_KEY environment variable." + ) + if not llm_endpoint: + raise ValueError( + "Runpod serverless endpoint not defined. " + "Use either CLI/API parameter or LLM_ENDPOINT environment variable." + ) + return ModelInfo( + provider=provider, + endpoint=llm_endpoint, + instance=llm_instance, + name=llm_name, + api_key=llm_api_key, + api_version=llm_api_version, + ) + + +def configure_llm(info: ModelInfo, debug: bool = False) -> "LLM": + """ + Configure LLM inference, local or remote. + + Supports Amazon Bedrock (+ Converse), Anthropic, Azure OpenAI, Google Gemini, + Hugging Face Inference API, Llamafile, Mistral, Ollama, OpenAI, OpenRouter, + and Runpod Serverless (OpenAI-compatible). + """ + ensure_llama_index() + + completion_model = info.name + + if not info.provider: + raise ValueError("LLM model provider not defined") + if not completion_model: + raise ValueError("LLM model name not defined") + + # https://docs.llamaindex.ai/en/stable/understanding/tracing_and_debugging/tracing_and_debugging/ + if debug: + set_global_handler("simple") + + # Select completions model. + if info.provider is ModelProvider.AMAZON_BEDROCK: + from llama_index.llms.bedrock import Bedrock + from llama_index.llms.bedrock_converse.utils import bedrock_modelname_to_context_size + + llm = Bedrock( + model=completion_model, + temperature=0.0, + context_size=bedrock_modelname_to_context_size(completion_model), + ) + elif info.provider is ModelProvider.AMAZON_BEDROCK_CONVERSE: + from llama_index.llms.bedrock_converse import BedrockConverse + + region_name = os.getenv("AWS_REGION") or os.getenv("AWS_DEFAULT_REGION") + llm = BedrockConverse( + model=completion_model, + temperature=0.0, + region_name=region_name, + ) + elif info.provider is ModelProvider.ANTHROPIC: + from llama_index.llms.anthropic import Anthropic + from llama_index.llms.anthropic.utils import CLAUDE_MODELS + + # TODO: Add new model types to upstream `llama-index-llms-anthropic`. 
+ CLAUDE_MODELS.update({"claude-opus-4-7": 200000, "claude-sonnet-4-6": 1000000, "claude-haiku-4-5": 200000}) + + llm = Anthropic( + model=completion_model, + temperature=0.0, + base_url=info.endpoint, + api_key=info.api_key, + ) + elif info.provider is ModelProvider.AZURE: + from llama_index.llms.azure_openai import AzureOpenAI + + if not info.instance: + raise ValueError("Azure OpenAI deployment/engine instance name not defined") + llm = AzureOpenAI( + model=completion_model, + temperature=0.0, + engine=info.instance, + azure_endpoint=info.endpoint, + api_key=info.api_key, + api_version=info.api_version, + ) + + elif info.provider is ModelProvider.GOOGLE: + from llama_index.llms.google_genai import GoogleGenAI + + llm = GoogleGenAI( + model=completion_model, + temperature=0.0, + api_key=info.api_key, + ) + elif info.provider is ModelProvider.HUGGINGFACE_SERVERLESS: + from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI + + llm = HuggingFaceInferenceAPI( + model=completion_model, + temperature=0.1, + base_url=info.endpoint, + token=info.api_key, + ) + + elif info.provider is ModelProvider.LLAMAFILE: + from llama_index.llms.llamafile import Llamafile + + llm = Llamafile( + base_url=info.endpoint or "http://localhost:8080", + temperature=0.0, + ) + elif info.provider is ModelProvider.MISTRAL: + from llama_index.llms.mistralai import MistralAI + + llm = MistralAI( + model=completion_model, + temperature=0.0, + endpoint=info.endpoint, + api_key=info.api_key, + ) + + elif info.provider is ModelProvider.OLLAMA: + # https://docs.llamaindex.ai/en/stable/api_reference/llms/ollama/ + from llama_index.llms.ollama import Ollama + + llm = Ollama( + base_url=info.endpoint or "http://localhost:11434", + model=completion_model, + temperature=0.0, + request_timeout=120.0, + keep_alive=-1, + ) + elif info.provider is ModelProvider.OPENAI: + from llama_index.llms.openai import OpenAI + + llm = OpenAI( + model=completion_model, + temperature=0.0, + api_key=info.api_key, + api_version=info.api_version, + ) + elif info.provider is ModelProvider.OPENROUTER: + from llama_index.llms.openrouter.base import DEFAULT_API_BASE, DEFAULT_MODEL, OpenRouter + + llm = OpenRouter( + model=completion_model or DEFAULT_MODEL, + temperature=0.0, + api_base=info.endpoint or DEFAULT_API_BASE, + api_key=info.api_key, + ) + elif info.provider is ModelProvider.RUNPOD_SERVERLESS: + from llama_index.llms.openai_like import OpenAILike + + if not info.name: + raise ValueError("LLM model name is required") + if not info.endpoint: + raise ValueError("Runpod serverless endpoint is required") + + llm = OpenAILike( + model=info.name, + temperature=0.0, + api_base=info.endpoint, + api_key=info.api_key, + ) + else: + raise ValueError(f"LLM model provider not implemented: {info.provider}") + + return llm diff --git a/doc/query/index.md b/doc/query/index.md index 80ebd146..a1c35cd3 100644 --- a/doc/query/index.md +++ b/doc/query/index.md @@ -6,6 +6,7 @@ expressions: Adapters, converters, migration support tasks, etc. 
```{toctree} :maxdepth: 2 +nlsql/index mcp/index convert ``` diff --git a/doc/query/nlsql/backlog.md b/doc/query/nlsql/backlog.md new file mode 100644 index 00000000..10b2cda9 --- /dev/null +++ b/doc/query/nlsql/backlog.md @@ -0,0 +1,112 @@ +--- +orphan: true +--- + +# NLSQL backlog + +## Iteration +1 + +- More examples + - https://huggingface.co/PipableAI/pip-sql-1.3b + - https://motherduck.com/blog/duckdb-text2sql-llm/ + - https://huggingface.co/Ellbendls/Qwen-2.5-3b-Text_to_SQL-GGUF + - https://github.com/distil-labs/distil-text2sql#usage-examples + - https://app.readytensor.ai/publications/generating-sql-from-natural-language-using-llama-32-jOImvIBGCfwt + +## Iteration +2 + +- Document `--include-tables`. +- Use as agentic tool? SKILLS.md? AGENTS.md? +- Exercise example that draws a table from database results. +- Exercise example that draws a graph from database results. +- Exercise example that uses time ranges. +- Exercise example that needs SQL JOINs. +- Exercise example that uses vector database features. +- Is the machinery using pgvector-specific prompt instructions + that should be adjusted for CrateDB? +- Demonstrate Gemma3 on Bedrock + - https://aws.amazon.com/bedrock/pricing/ + - https://github.com/run-llama/llama_index/pull/21380 +- Extract NLSQL from LlamaIndex into nlsql2? + - https://pypi.org/project/nlsql/ + - https://pypi.org/project/nlsql-api/ +- Bug: `WARNING : Denied SQL expression: SELECT DISTINCT c.customer_id, c.name`, + but not because of `SELECT DISTINCT`, but because Amazon's `amazon.nova-2-lite` + adds an explanation like this to the SQL statement, not protected by an SQL + comment or any such. + > This query selects customers who do not have any corresponding entries in + > the `orders` table. It uses a `LEFT JOIN` to combine `customers` with + > `orders`, and filters for rows where `order_id` is `NULL`, indicating no + > orders were placed by that customer. The `DISTINCT` keyword ensures each + > customer is listed only once. + +## Iteration +3 + +- Add providers: anyscale,openllm,vllm +- Validate providers: Azure, Google, Hugging Face, Mistral +- Tests: When using the vanilla schema `testdrive-data` with `from tests.conftest import TESTDRIVE_DATA_SCHEMA`, + the LLM gets confused, and thinks the table is called `sensor_data`. The error message is: + » The error indicates that the specified table, "sensor_data," is not recognized in the "testdrive-data" schema. +- How to prevent queries like `Who is Shakespeare?`? +- Maintain chat memory/context. 
+ https://github.com/run-llama/llama_index/discussions/11424 +- https://unsloth.ai/docs/models/qwen3.5 + ```shell + ollama run hf.co/unsloth/Qwen3.5-0.8B-GGUF:UD-Q4_K_XL + ``` + +### Fine tuning +- Text2SQirreL 🐿️ : Query your data in plain English + https://github.com/distil-labs/distil-text2sql +- https://yia333.medium.com/enhancing-text-to-sql-with-a-fine-tuned-7b-llm-for-database-interactions-fa754dc2e992 +- https://www.promptlayer.com/models/pip-sql-13b-gguf/ + https://huggingface.co/PipableAI/pip-sql-1.3b +- https://huggingface.co/QuantFactory/Meta-Llama-3.1-8B-Text-to-SQL-GGUF +- https://motherduck.com/blog/duckdb-text2sql-llm/ + https://github.com/NumbersStationAI/DuckDB-NSQL +- https://huggingface.co/srujanamadiraju/nl-sql-gemma2b +- https://github.com/raghujhts13/text-to-sql + https://huggingface.co/TheBloke/CodeLlama-7B-Instruct-GGUF +- https://app.readytensor.ai/publications/generating-sql-from-natural-language-using-llama-32-jOImvIBGCfwt + https://huggingface.co/sai-santhosh/text-2-sql-gguf +- https://huggingface.co/Ellbendls/Qwen-3-4b-Text_to_SQL-GGUF +- https://huggingface.co/Ellbendls/Qwen-2.5-3b-Text_to_SQL-GGUF/blob/main/Qwen-2.5-3b-Text_to_SQL.gguf +- https://www.jan.ai/docs/desktop/jan-models/lucy +- More runtimes + https://docs.docker.com/ai/model-runner/ + +## Notes + +LlamaIndex provides access to many LLM model inference engines and services via +Python packages available on PyPI prefixed with `llama-index-llms-`. +We've unlocked a few popular ones, but there are certainly many more. + +- Inference: anyscale,localai,mistral-rs,openllm,rapid-mlx +- API I: databricks,deepseek,huggingface,ibm,litellm,llama-api,llama-cpp,openai-like +- API II: azure-inference,cortex,grok,groq,meta,minimax,mlx,octoai,perplexity +- Router: cloudflare-ai-gateway,featherlessai,modelscope,nano-gpt,neutrino,ovhcloud +- More I: Dolly, Pythia, Nano-GPT (litellm), DuckDB-NSQL, nsql-llama-2-7B, pip-sql-1.3b-GGUF, SQLCoder-7B, Ellbendls/Qwen-3-4b-Text_to_SQL-GGUF +- More II: kwaipilot/kat-coder-pro-v2, undi95/remm-slerp-l2-13b + +## llamafile + +```shell +export LLM_PROVIDER="llamafile" +export LLM_ENDPOINT="http://localhost:8080/" +export LLM_NAME="n/a" +export LLM_ENDPOINT="http://localhost:8080/" +``` +```shell +wget https://huggingface.co/mozilla-ai/Llama-3.2-1B-Instruct-llamafile/resolve/main/Llama-3.2-1B-Instruct-Q6_K.llamafile +wget https://huggingface.co/mozilla-ai/llamafile_0.10.0/resolve/main/Qwen3.5-0.8B-Q8_0.llamafile +./Llama-3.2-1B-Instruct-Q6_K.llamafile +./Qwen3.5-0.8B-Q8_0.llamafile +``` +```shell +wget "https://github.com/mozilla-ai/llamafile/releases/download/0.10.0/llamafile-0.10.0" +wget "https://huggingface.co/Ellbendls/Qwen-3-4b-Text_to_SQL-GGUF/resolve/main/Qwen-3-4b-Text_to_SQL-q2_k.gguf?download=true" +``` + +## Security + +- https://github.com/rodrigo-pedro/P2SQL diff --git a/doc/query/nlsql/example-employee.md b/doc/query/nlsql/example-employee.md new file mode 100644 index 00000000..f196210f --- /dev/null +++ b/doc/query/nlsql/example-employee.md @@ -0,0 +1,54 @@ +(nlsql-example-employee)= + +# NLSQL with employee data + +Let's use a single `employees` database table +and populate it with a few records worth of data. + +:::{rubric} Provision +::: + +Create table and insert data. 
+ +```sql +CREATE TABLE employees (id INT, name TEXT, department TEXT, hire_date TIMESTAMP); + +INSERT INTO employees (id, name, department, hire_date) VALUES +(1, 'Alice Johnson', 'Engineering', '2022-03-15'), +(2, 'Bob Smith', 'Marketing', '2021-07-01'), +(3, 'Carol Lee', 'Human Resources', '2020-11-23'), +(4, 'David Brown', 'Finance', '2019-05-30'), +(5, 'Eva Green', 'Engineering', '2023-01-10'), +(6, 'Frank Miller', 'Sales', '2019-08-12'), +(7, 'Grace Kim', 'Sales', '2021-02-18'), +(8, 'Henry Davis', 'Sales', '2022-06-25'), +(9, 'Isabella Martinez', 'Sales', '2020-12-05'), +(10, 'Jack Wilson', 'Sales', '2023-09-14'); +``` + +:::{rubric} Query +::: + +Submit a typical query in human language. + +```shell +ctk query nlsql "List all employees in the 'Sales' department hired after 2022." +``` + +:::{rubric} Response +::: + +The model figures out the SQL statement, the engine runs it, and +uses the model again to come back with an answer in human language: +```text +The employees in the Sales department hired after 2022 are Henry Davis and Jack Wilson. +``` + +The SQL statement was: +```sql +SELECT + name FROM employees +WHERE + department = 'Sales' AND + hire_date > '2022-01-01'; +``` diff --git a/doc/query/nlsql/example-product.md b/doc/query/nlsql/example-product.md new file mode 100644 index 00000000..0dc58b47 --- /dev/null +++ b/doc/query/nlsql/example-product.md @@ -0,0 +1,198 @@ +(nlsql-example-product)= + +# NLSQL with product orders + +Let's use a basic products / orders / customers database. + +```sql +CREATE TABLE customers (customer_id INTEGER, name VARCHAR, city VARCHAR, email_address VARCHAR, gender_code VARCHAR); +CREATE TABLE orders (order_id INTEGER, customer_id INTEGER, amount INTEGER); +CREATE TABLE products (product_id INTEGER, name VARCHAR, price NUMERIC(2), size VARCHAR); +CREATE TABLE order_items (order_id INTEGER, product_id INTEGER); +``` + +## Basic JOINs and filtering + +:::{rubric} Provision +::: + +Create table and insert data. +Populate the table using a few records worth of example data. + +```sql +-- customers +INSERT INTO customers (customer_id, name, city) VALUES +(1, 'Alice', 'Berlin'), +(2, 'Bob', 'Munich'), +(3, 'Charlie', 'Hamburg'); + +-- products +INSERT INTO products (product_id, name) VALUES +(1, 'Laptop'), +(2, 'Phone'), +(3, 'Headphones'); + +-- orders +INSERT INTO orders (order_id, customer_id, amount) VALUES +(101, 1, 1200), +(102, 2, 800), +(103, 1, 200), +(104, 3, 150); + +-- order_items +-- Alice bought Laptop, Bob bought Phone, Alice bought Headphones, +-- Charlie bought Headphones, Charlie also bought Phone. +INSERT INTO order_items (order_id, product_id) VALUES +(101, 1), +(102, 2), +(103, 3), +(104, 3), +(104, 2); +``` + +:::{rubric} Query +::: + +Submit a typical query in human language. + +```shell +ctk query nlsql "List all customers with orders over €500." +``` + +:::{rubric} Response +::: + +The model figures out the SQL statement, the engine runs it, and +uses the model again to come back with an answer in human language: + +> The query results show that the customers 'Alice' from Berlin +> and 'Bob' from Munich have placed orders over €500. + +The SQL statement was: +```sql +SELECT customers.name, customers.city +FROM customers JOIN orders ON customers.customer_id = orders.customer_id +WHERE orders.amount > 500; +``` + +## Advanced JOINs and filtering + +:::{rubric} Provision +::: + +Create table and insert data. +Add a few customers in New York and others elsewhere. +Synthesize orders with amounts both above and below the average. 
+ +```sql +INSERT INTO customers (customer_id, name, city) VALUES +(1, 'Alice Johnson', 'New York'), +(2, 'Bob Smith', 'Los Angeles'), +(3, 'Carol Lee', 'New York'), +(4, 'David Brown', 'Chicago'); + +INSERT INTO orders (order_id, customer_id, amount) VALUES +(101, 1, 500), -- NY, high +(102, 1, 150), -- NY, low +(103, 2, 300), -- non-NY +(104, 3, 700), -- NY, high +(105, 4, 200); -- non-NY + +INSERT INTO products (product_id, name) VALUES +(1001, 'Laptop'), +(1002, 'Phone'), +(1003, 'Tablet'), +(1004, 'Headphones'); + +INSERT INTO order_items (order_id, product_id) VALUES +(101, 1001), +(101, 1004), +(102, 1002), +(103, 1003), +(104, 1001), +(104, 1002), +(105, 1004); +``` + +:::{rubric} Query +::: + +Submit a typical query in human language. + +```shell +ctk query nlsql "Get the names of products that were ordered by customers in New York who spent more than the average amount." +``` + +:::{rubric} Response +::: + +The model figures out the SQL statement, the engine runs it, and +uses the model again to come back with a synthesized response +based on the provided SQL query and its result: + +> The query identifies the top 10 product names ordered by customers in New York +> who spent more than the average order amount. +> The results show that "Laptop", "Phone", and "Headphones" were among the most +> popular products purchased by New York customers with high spending. + +The SQL statement was: +```sql +SELECT + p.name FROM products AS p + JOIN order_items AS oi ON p.product_id = oi.product_id + JOIN orders AS o ON oi.order_id = o.order_id + JOIN customers AS c ON o.customer_id = c.customer_id +WHERE + c.city = 'New York' +ORDER BY + o.amount DESC LIMIT 10; +``` + +## JOINs and grouping + +:::{rubric} Provision +::: + +```sql +INSERT INTO customers (customer_id, name, city, email_address, gender_code) VALUES +(1, 'Alice Johnson', 'New York', 'alice@example.com', 'F'), +(2, 'Bob Smith', 'Los Angeles', 'bob@example.com', 'M'), +(3, 'Carol Lee', 'Chicago', 'carol@example.com', 'F'), +(4, 'David Brown', 'Houston', 'david@example.com', 'M'), +(5, 'Eva Green', 'Phoenix', 'eva@example.com', 'F'), +(6, 'Frank Miller', 'Miami', 'frank@example.com', 'M'), +(7, 'Grace Kim', 'Seattle', 'grace@example.com', 'F'), +(8, 'Henry Davis', 'Boston', 'henry@example.com', 'O'); -- least common gender + +INSERT INTO orders (order_id, customer_id, amount) VALUES +(101, 1, 120), +(102, 2, 200), +(103, 3, 150), +(104, 4, 300), +(105, 6, 80); + +INSERT INTO products (product_id, name, price, size) VALUES +(1001, 'T-Shirt', 20, 'M'), +(1002, 'Jeans', 50, 'L'), +(1003, 'Jacket', 80, 'XL'), +(1004, 'Sneakers', 60, '42'), +(1005, 'Hat', 15, 'S'); + +INSERT INTO order_items (order_id, product_id) VALUES +(101, 1001), +(101, 1005), +(102, 1002), +(103, 1003), +(104, 1004), +(105, 1001); +``` + +:::{rubric} Q & A +::: + +- Q: What are the email address and town of the customers who are of the least common gender? + SQL: `SELECT email_address, city FROM customers GROUP BY gender_code ORDER BY count(*) ASC LIMIT 1` +- Q: What are the product price and the product size of the products whose price is above average? + SQL: `SELECT products.price, products.size FROM products WHERE products.price > (SELECT AVG(price) FROM products)` +- Q: Which customers did not make any orders? 
+ SQL: `SELECT c.name FROM customers AS c LEFT JOIN orders AS o ON c.customer_id = o.customer_id WHERE o.order_id IS NULL;` diff --git a/doc/query/nlsql/example-sensor.md b/doc/query/nlsql/example-sensor.md new file mode 100644 index 00000000..e937fd17 --- /dev/null +++ b/doc/query/nlsql/example-sensor.md @@ -0,0 +1,85 @@ +(nlsql-example-sensor)= + +# NLSQL with sensor data + +Let's use a single `time_series_data` database table +and populate it with a few records worth of time series data. + +:::{rubric} Provision +::: + +Create table and insert data. + +```sql +CREATE TABLE IF NOT EXISTS time_series_data ( + timestamp TIMESTAMP, + value DOUBLE, + location STRING, + sensor_id INT +); + +INSERT INTO time_series_data (timestamp, value, location, sensor_id) +VALUES + ('2023-09-14T00:00:00', 10.5, 'Sensor A', 1), + ('2023-09-14T01:00:00', 15.2, 'Sensor A', 1), + ('2023-09-14T02:00:00', 18.9, 'Sensor A', 1), + ('2023-09-14T03:00:00', 12.7, 'Sensor B', 2), + ('2023-09-14T04:00:00', 17.3, 'Sensor B', 2), + ('2023-09-14T05:00:00', 20.1, 'Sensor B', 2), + ('2023-09-14T06:00:00', 22.5, 'Sensor A', 1), + ('2023-09-14T07:00:00', 18.3, 'Sensor A', 1), + ('2023-09-14T08:00:00', 16.8, 'Sensor A', 1), + ('2023-09-14T09:00:00', 14.6, 'Sensor B', 2), + ('2023-09-14T10:00:00', 13.2, 'Sensor B', 2), + ('2023-09-14T11:00:00', 11.7, 'Sensor B', 2); + +REFRESH TABLE time_series_data; +``` + +:::{rubric} Query +::: + +Submit a typical query in human language. + +```shell +ctk query nlsql "What is the average value for sensor 1?" +``` + +:::{rubric} Response +::: + +The model figures out the SQL statement, the engine runs it, and +uses the model again to come back with an answer in human language: +```text +The average value for sensor 1 is approximately 17.03. +``` + +The SQL statement was: +```sql +SELECT AVG(value) FROM time_series_data WHERE sensor_id = 1; +``` + +:::{rubric} Multiple languages +::: + +The NLSQL conversation works well in multiple languages. + +> Q: ¿Cuál es el valor medio del sensor 1? +>
+> A: El valor medio del sensor 1 es 17.0333. + +> Q: Quelle est la valeur moyenne du capteur 1 ? +>
+> A: La valeur moyenne du capteur 1 est de 17,0333. + +> Q: What is the average value for sensor 1? +>
+> A: The average value for sensor 1 is approximately 17.03. + +> Q: Wie lautet der Durchschnittswert für Sensor 1? +>
+> A: Der Durchschnittswert für Sensor 1 beträgt 17,0333. + +> Q: Qual è il valore medio del sensore 1? +>
+> A: Il valore medio del sensore 1 è pari a 17,0333. diff --git a/doc/query/nlsql/example-weather.md b/doc/query/nlsql/example-weather.md new file mode 100644 index 00000000..e0ded75e --- /dev/null +++ b/doc/query/nlsql/example-weather.md @@ -0,0 +1,53 @@ +(nlsql-example-weather)= + +# NLSQL with weather data + +Let's use a basic database including weather observations. + +:::{rubric} Provision +::: + +Create table and insert data. + +```sql +CREATE TABLE weather (zip_code VARCHAR, city VARCHAR, temperature_fahrenheit INTEGER, mean_visibility_miles INTEGER); + +INSERT INTO weather (zip_code, city, temperature_fahrenheit, mean_visibility_miles) VALUES +('10001', 'New York', 85, 8), -- visibility < 10 +('90001', 'Los Angeles', 95, 12), -- temp > 90 +('60601', 'Chicago', 88, 9), -- visibility < 10 +('73301', 'Austin', 102, 15), -- temp > 90 +('94102', 'San Francisco', 65, 7), -- visibility < 10 +('85001', 'Phoenix', 110, 20), -- temp > 90 +('33101', 'Miami', 91, 11); -- temp > 90 +``` + +:::{rubric} Query +::: + +Submit typical queries in human language. + +```shell +ctk query nlsql "Find the zip code where the mean visibility is lower than 10." +ctk query nlsql "Find all cities with temperatures above 90°F." +``` + +:::{rubric} Response +::: + +The model figures out the SQL statements, the engine runs it, and +uses the model again to come back with answers in human language: +```text +The zip codes with a mean visibility of less than 10 miles are 94102, 10001, and 60601. +``` +```text +The cities with temperatures above 90°F are Miami, Austin, Phoenix, and Los Angeles. +``` + +The SQL statements were: +```sql +SELECT zip_code FROM weather WHERE mean_visibility_miles < 10; +``` +```sql +SELECT city FROM weather WHERE temperature_fahrenheit > 90; +``` diff --git a/doc/query/nlsql/index.md b/doc/query/nlsql/index.md new file mode 100644 index 00000000..6a4feb3f --- /dev/null +++ b/doc/query/nlsql/index.md @@ -0,0 +1,425 @@ +(nlsql)= + +# Natural language (NLSQL) + +:::{div} sd-text-muted +Talk to your data in natural language. +::: + +The CrateDB NLSQL package helps agents turn natural language into database queries, +like [Vanna AI] or Google's [QueryData] but tailored to CrateDB. + +## About + +NLSQL provides a straightforward way to turn natural language into executable +SQL by combining an LLM with explicit database context. It positions itself as +an execution layer for data agents: agents handle reasoning and orchestration, +while the NLSQL layer reliably generates, checks, and runs SQL against +databases, returning results for downstream actions. + +The trade-off is explicit: you shift effort from prompt tuning to context +engineering and maintenance, but gain near-100% accuracy, stronger guardrails, +and production reliability—especially for multistep or mission-critical +workflows where probabilistic errors are unacceptable. + +## Install + +```shell +uv pip install --upgrade 'cratedb-toolkit[nlsql]' +``` + +## Synopsis + +```shell +ctk query nlsql \ + --cluster-url="crate://crate@localhost:4200/?ssl=false" \ + --llm-provider="" \ + --llm-name="" \ + --llm-api-key="" \ + "What is the average value for sensor 1?" +``` + +## Coverage + +:::{rubric} Providers +::: + +Supports a range of providers: +Amazon Bedrock (+ Converse), Anthropic, Azure OpenAI, Google AI, +Hugging Face Inference API, llamafile, Mistral, Ollama, OpenAI, +OpenRouter, or Runpod Serverless (OpenAI-compatible). + +:::{rubric} Models +::: + +A range of models can be selected from the providers enumerated above. 
+We recommend Gemini, Gemma3, Llama 3.1, Qwen 2.5, or later, +for example Gemma-3-1B, Llama-3.2-1B-Instruct, or Qwen3.5-0.8B. + +## Details + +The NLSQL interface works by wrapping a SQL database and exposing a query +interface where plain-language questions are translated into SQL, executed, +and returned as answers. +Developers configure the engine with a database connection and a +bounded set of tables, ensuring the model generates queries only within a +known schema and avoids context overflow. + +The procedure follows a schema-grounded approach: the engine injects table +structure (and optionally examples or retrieved context) into the prompt so +the LLM can synthesize accurate queries instead of guessing. It can also +integrate with retrieval components to dynamically select relevant tables +or augment prompts at query time for more complex setups. + +The engine acts as a thin orchestration layer for Text-to-SQL purposes, +and for building NLSQL systems: +it handles prompt construction, query generation, execution, +and result formatting, while leaving control, safety (e.g., read-only +roles), and schema design to the developer. + +## Security + +Any Text-to-SQL application should be aware that executing +arbitrary SQL queries can be a security risk. It is recommended to +take precautions as needed, such as using restricted roles, read-only +databases, sandboxing, etc. + +While we recommend to use a dedicated read-only user/role to guarantee +100% safety, CrateDB NLSQL also prevents [Prompt-to-SQL Injections] by +default, by classifying the SQL statement and only permitting access +for `SELECT` statements. + +The `permit_all_statements` API argument or the `NLSQL_PERMIT_ALL_STATEMENTS` +environment variable can be used to relax that default when set to a boolean +value, to allow all types of statements. Only enable this flag when you are +sure about this behaviour. + +## Usage + +CrateDB NLSQL provides a command line interface and a Python API. + +### CLI + +When using `ctk query nlsql` on the command line, we recommend to use +environment variables to configure database and LLM connectivity. + +:::{rubric} Configure database +::: + +For connecting to CrateDB on localhost, use a connection string like this: +```shell +export CRATEDB_CLUSTER_URL="crate://crate:crate@localhost:4200/?ssl=false" +``` + +For connecting to CrateDB Cloud, use a connection string like this: +```shell +export CRATEDB_CLUSTER_URL="crate://admin:dZ...6LqB@example.eks1.eu-west-1.aws.cratedb.net:4200/?ssl=true" +``` + +:::{rubric} Configure LLM +::: + +Configure LLM provider, model, and access credentials when applicable. +Available providers are `amazon_bedrock`, `amazon_bedrock_converse`, +`anthropic`, `azure`, `google`, `huggingface_serverless`, `llamafile`, +`mistral`, `ollama`, `openai`, `openrouter`, `runpod_serverless`. +Available models and label formats depend on the provider's conventions. +To authenticate with LLM APIs, use corresponding `*_API_KEY` environment +variables like outlined below. + +```shell +export LLM_PROVIDER="" +export LLM_NAME="" +export PROVIDER_API_KEY="" +``` +Note that `LLM_NAME` is an optional configuration setting: By default, +selecting a provider automatically selects a cost-effective standard +model that is suitable for Text-to-SQL. + +Select and configure the LLM of your choice. + +::::{tab-set} +:::{tab-item} Amazon +Use Amazon Nova on [Amazon Bedrock]. 
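+The Bedrock Converse integration reads the region from the standard AWS
+environment and authenticates via the default AWS credential chain, so in
+addition to the exports below you may need something like (placeholder value):
+```shell
+export AWS_REGION="us-east-1"
+```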
+```shell +export LLM_PROVIDER="amazon_bedrock_converse" +export LLM_NAME="global.amazon.nova-2-lite-v1:0" +``` +::: +:::{tab-item} Anthropic +Use [Anthropic Claude]. +```shell +export LLM_PROVIDER="anthropic" +export LLM_NAME="claude-haiku-4-5" +export ANTHROPIC_API_KEY="" +``` +::: +:::{tab-item} Azure +Use GPT on [Azure OpenAI]. +```shell +export LLM_PROVIDER="azure" +export LLM_NAME="gpt-4.1" +export LLM_INSTANCE="my-gpt4-deployment" +export AZURE_OPENAI_ENDPOINT="https://acme-openai.openai.azure.com/" +export AZURE_OPENAI_API_KEY="" +``` +::: +:::{tab-item} Google +Use [Gemini Flash] from Google. +```shell +export LLM_PROVIDER="google" +export LLM_NAME="gemini-2.5-flash" +export GOOGLE_API_KEY="" +``` +::: +:::{tab-item} Hugging Face Serverless +Use Zephyr on the [Hugging Face Serverless Inference API]. +```shell +export LLM_PROVIDER="huggingface_serverless" +export LLM_NAME="HuggingFaceH4/zephyr-7b-alpha" +export HF_TOKEN="" +``` +::: +:::{tab-item} Mistral +Use models from [Mistral AI]. +```shell +export LLM_PROVIDER="mistral" +export LLM_NAME="mistral-medium-latest" +export MISTRAL_API_KEY="" +``` +::: +:::{tab-item} Ollama +Use [Ollama] to run models on your own machines. + +For connecting to dedicated LLM instances, use the `LLM_ENDPOINT` environment +variable. For example, to connect to a self-managed Ollama instance, configure +those environment variables: +```shell +export LLM_PROVIDER="ollama" +export LLM_ENDPOINT="http://100.83.17.54:11434/" +export LLM_NAME="gemma3:270m" +``` +Before running `ctk query nlsql`, acquire models: +```shell +ollama pull gemma3:270m # 290 MB +ollama pull gemma3:1b # 820 MB +ollama pull llama3.2:1b # 1.3 GB +ollama pull qwen2.5:0.5b # 400 MB +ollama pull qwen3:0.6b # 520 MB +ollama pull hf.co/Menlo/Lucy-128k-gguf:Q4_K_M # 1.1 GB +``` +::: +:::{tab-item} OpenAI +Use [GPT‑4o mini] from [OpenAI]. +```shell +export LLM_PROVIDER="openai" +export LLM_NAME="gpt-4o-mini" +export OPENAI_API_KEY="" +``` +::: +:::{tab-item} OpenRouter +Choose from many models available via [OpenRouter]. +```shell +export LLM_PROVIDER="openrouter" +export LLM_NAME="google/gemma-3-4b-it:free" +export OPENROUTER_API_KEY="" +``` +Alternative model names: +```text +google/gemma-3n-e2b-it:free +google/gemini-2.0-flash-lite-001 +google/gemini-2.5-flash-lite +gryphe/mythomax-l2-13b +ibm-granite/granite-4.0-h-micro +liquid/lfm-2.5-1.2b-instruct:free +meta-llama/llama-3.2-3b-instruct +mistralai/mistral-nemo +mistralai/mistral-small-24b-instruct-2501 +openai/gpt-oss-20b:free +openai/gpt-oss-120b:free +``` +::: +:::{tab-item} Runpod Serverless +Use Gemma3 on [Runpod Serverless]. +```shell +export LLM_PROVIDER="runpod_serverless" +export LLM_ENDPOINT="https://api.runpod.ai/v2//openai/v1" +export LLM_NAME="gemma3:270m" +export RUNPOD_API_KEY="" +``` +::: +:::: + +### API + +A sketch to use NLSQL from Python programs. + +```python +import sqlalchemy as sa +from cratedb_toolkit.query.nlsql.api import DataQuery +from cratedb_toolkit.query.nlsql.model import DatabaseInfo, ModelInfo, ModelProvider + +# Configure database. +# For connecting to CrateDB on localhost, use a connection string like this: +engine = sa.create_engine("crate://crate:crate@localhost:4200/?ssl=false") +# For connecting to CrateDB Cloud, use a connection string like this: +# engine = sa.create_engine("crate://admin:dZ...6LqB@example.eks1.eu-west-1.aws.cratedb.net:4200/?ssl=true") +schema = "doc" + +# Configure an LLM-based query engine. 
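+# Note: `ModelProvider.ACME` and "foo-frontier-7.1" below are placeholders;
+# substitute a real provider and model, e.g. ModelProvider.OPENAI and "gpt-4o-mini".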
+dataquery = DataQuery( + db=DatabaseInfo(engine=engine, schema=schema), + model=ModelInfo(provider=ModelProvider.ACME, name="foo-frontier-7.1"), +) + +# Query database. +response = dataquery.ask("What is the average value for sensor 1?") +print(response) +``` + +Select and configure the LLM of your choice. + +::::{tab-set} +:::{tab-item} Amazon +Use Amazon Nova on [Amazon Bedrock]. +```python +dataquery = DataQuery( + db=DatabaseInfo(engine=engine, schema=schema), + model=ModelInfo(provider=ModelProvider.AMAZON_BEDROCK_CONVERSE, name="global.amazon.nova-2-lite-v1:0"), +) +``` +::: +:::{tab-item} Anthropic +Use [Anthropic Claude] Sonnet. +```python +dataquery = DataQuery( + db=DatabaseInfo(engine=engine, schema=schema), + model=ModelInfo( + provider=ModelProvider.ANTHROPIC, + name="claude-sonnet-4-0", + api_key="", + ), +) +``` +::: +:::{tab-item} Azure +Use GPT on [Azure OpenAI]. +```python +dataquery = DataQuery( + db=DatabaseInfo(engine=engine, schema=schema), + model=ModelInfo( + provider=ModelProvider.AZURE, + name="gpt-4.1", + instance="my-gpt4-deployment", + endpoint="https://acme-openai.openai.azure.com/", + api_key="", + ), +) +``` +::: +:::{tab-item} Google +Use [Gemini Flash] from Google. +```python +dataquery = DataQuery( + db=DatabaseInfo(engine=engine, schema=schema), + model=ModelInfo(provider=ModelProvider.GOOGLE, name="gemini-2.5-flash"), +) +``` +::: +:::{tab-item} Hugging Face Serverless +Use Zephyr on the [Hugging Face Serverless Inference API]. +```python +dataquery = DataQuery( + db=DatabaseInfo(engine=engine, schema=schema), + model=ModelInfo(provider=ModelProvider.HUGGINGFACE_SERVERLESS, name="HuggingFaceH4/zephyr-7b-alpha"), +) +``` +::: +:::{tab-item} Mistral +Use models from [Mistral AI]. +```python +dataquery = DataQuery( + db=DatabaseInfo(engine=engine, schema=schema), + model=ModelInfo(provider=ModelProvider.MISTRAL, name="mistral-medium-latest"), +) +``` +::: +:::{tab-item} Ollama +Use [Ollama] to run models on your own machines, for example Gemma3. +```python +dataquery = DataQuery( + db=DatabaseInfo(engine=engine, schema=schema), + model=ModelInfo(provider=ModelProvider.OLLAMA, name="gemma3:1b"), +) +``` +::: +:::{tab-item} OpenAI +Use [GPT‑4o mini] from [OpenAI]. +```python +dataquery = DataQuery( + db=DatabaseInfo(engine=engine, schema=schema), + model=ModelInfo(provider=ModelProvider.OPENAI, name="gpt-4o-mini"), +) +``` +::: +:::{tab-item} OpenRouter +Choose from many models available via [OpenRouter], for example Gemma3. +```python +dataquery = DataQuery( + db=DatabaseInfo(engine=engine, schema=schema), + model=ModelInfo(provider=ModelProvider.OPENROUTER, name="google/gemma-3-4b-it:free"), +) +``` +::: +:::{tab-item} Runpod Serverless +Use Gemma3 on [Runpod Serverless]. +```python +dataquery = DataQuery( + db=DatabaseInfo(engine=engine, schema=schema), + model=ModelInfo( + provider=ModelProvider.RUNPOD_SERVERLESS, + name="gemma3:270m", + endpoint="https://api.runpod.ai/v2//openai/v1", + api_key="", + ), +) +``` +::: +:::: + +## Examples + +The {ref}`nlsql-example-sensor` demonstrates a basic database inquiry +using the question »What is the average value for sensor 1?« to acquire +information from a single table. + +{ref}`nlsql-example-employee`, {ref}`nlsql-example-product`, and +{ref}`nlsql-example-weather` explore and demonstrate other kinds +of query variants. 
+ + +```{toctree} +:maxdepth: 1 +:hidden: + +Employee data example +Product orders example +Sensor data example +Weather data example +``` + + +[Amazon Bedrock]: https://docs.aws.amazon.com/bedrock/ +[Anthropic Claude]: https://platform.claude.com/docs/ +[Azure OpenAI]: https://azure.microsoft.com/en-gb/pricing/details/azure-openai/ +[Gemini Flash]: https://deepmind.google/models/gemini/flash/ +[GPT‑4o mini]: https://openai.com/index/gpt-4o-mini-advancing-cost-efficient-intelligence/ +[Hugging Face Serverless Inference API]: https://huggingface.co/learn/cookbook/enterprise_hub_serverless_inference_api +[Mistral AI]: https://mistral.ai/ +[Ollama]: https://ollama.com/ +[OpenAI]: https://openai.com/ +[OpenRouter]: https://openrouter.ai/ +[Prompt-to-SQL Injections]: https://syssec.dpss.inesc-id.pt/papers/pedro_icse25.pdf +[QueryData]: https://cloud.google.com/blog/products/databases/introducing-querydata-for-near-100-percent-accurate-data-agents +[Runpod Serverless]: https://www.runpod.io/product/serverless +[Vanna AI]: https://vanna.ai/ diff --git a/pyproject.toml b/pyproject.toml index c5e3f0cb..d9e553f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -145,6 +145,7 @@ optional-dependencies.develop = [ optional-dependencies.docs = [ "furo", "myst-parser[linkify]>=0.18,<6", + "roman-numerals-py<4", "sphinx<8", "sphinx-autobuild", "sphinx-copybutton", @@ -163,7 +164,7 @@ optional-dependencies.dynamodb = [ "cratedb-toolkit[io-recipe]", ] optional-dependencies.full = [ - "cratedb-toolkit[cfr,datasets,docs-api,io-curated,mcp,service]", + "cratedb-toolkit[cfr,datasets,docs-api,io-curated,mcp,nlsql,service]", ] optional-dependencies.iceberg = [ "cratedb-toolkit[io]", @@ -222,6 +223,20 @@ optional-dependencies.mongodb = [ "rich>=3.3.2,<16", "undatum<1.2", ] +optional-dependencies.nlsql = [ + "llama-index-llms-anthropic<0.12; python_version>='3.10'", + "llama-index-llms-azure-openai<0.6; python_version>='3.10'", + "llama-index-llms-bedrock<0.6; python_version>='3.10'", + "llama-index-llms-bedrock-converse<0.15; python_version>='3.10'", + "llama-index-llms-google-genai<0.10; python_version>='3.10'", + "llama-index-llms-huggingface-api<0.8; python_version>='3.10'", + "llama-index-llms-llamafile<0.6; python_version>='3.10'", + "llama-index-llms-mistralai<0.11; python_version>='3.10'", + "llama-index-llms-ollama<0.11; python_version>='3.10'", + "llama-index-llms-openai<0.8; python_version>='3.10'", + "llama-index-llms-openai-like<0.8; python_version>='3.10'", + "llama-index-llms-openrouter<0.6; python_version>='3.10'", +] optional-dependencies.pymongo = [ "jessiql==1.0.0rc1", "numpy<2", @@ -353,6 +368,7 @@ analysis.allowed-unresolved-imports = [ "jessiql.**", "kaggle.**", "kinesis.**", + "llama_index.**", "lorrystream.**", "mcp.**", "pymongo.**", @@ -380,6 +396,7 @@ ini_options.markers = [ "influxdb", "kinesis", "mongodb", + "nlsql", "postgresql", "pymongo", "python", diff --git a/tests/query/test_convert.py b/tests/query/test_convert.py index 00791a5f..d8a4bcf5 100644 --- a/tests/query/test_convert.py +++ b/tests/query/test_convert.py @@ -1,6 +1,6 @@ from click.testing import CliRunner -from cratedb_toolkit.query.cli import cli +from cratedb_toolkit.query.convert.cli import convert_query as cli def test_query_convert_ddb_relocate_pks(): @@ -12,8 +12,8 @@ def test_query_convert_ddb_relocate_pks(): result = runner.invoke( cli, input="SELECT * FROM foobar WHERE data['PK']", - args="convert --type=ddb-relocate-pks --primary-keys=PK,SK -", + args="--type=ddb-relocate-pks --primary-keys=PK,SK -", 
catch_exceptions=False, ) - assert result.exit_code == 0 + assert result.exit_code == 0, result.output assert result.output == "SELECT * FROM foobar WHERE pk['PK']" diff --git a/tests/query/test_nlsql.py b/tests/query/test_nlsql.py new file mode 100644 index 00000000..9060fe50 --- /dev/null +++ b/tests/query/test_nlsql.py @@ -0,0 +1,242 @@ +import json +import os +import sys + +import pytest +from click.testing import CliRunner + +from cratedb_toolkit.query.cli import cli + +TESTDRIVE_DATA_SCHEMA = "testdrive" + + +pytestmark = pytest.mark.nlsql + +if sys.version_info < (3, 10): + pytest.skip("Only available for Python 3.10+", allow_module_level=True) # ty: ignore[invalid-argument-type,too-many-positional-arguments] + + +@pytest.fixture(scope="session", autouse=True) +def reset_environment(): + """ + Reset environment variables. + """ + envvars = ["NLSQL_PERMIT_ALL_STATEMENTS"] + for envvar in envvars: + os.environ.pop(envvar, None) + + +@pytest.fixture +def provision_db(cratedb): + sql_ddl = f""" +CREATE TABLE "{TESTDRIVE_DATA_SCHEMA}".time_series_data ( + timestamp TIMESTAMP, + value DOUBLE, + location STRING, + sensor_id INT +); +""" + sql_dml = f""" +INSERT INTO "{TESTDRIVE_DATA_SCHEMA}".time_series_data (timestamp, value, location, sensor_id) +VALUES + ('2023-09-14T00:00:00', 10.5, 'Sensor A', 1), + ('2023-09-14T01:00:00', 15.2, 'Sensor A', 1), + ('2023-09-14T02:00:00', 18.9, 'Sensor A', 1), + ('2023-09-14T03:00:00', 12.7, 'Sensor B', 2), + ('2023-09-14T04:00:00', 17.3, 'Sensor B', 2), + ('2023-09-14T05:00:00', 20.1, 'Sensor B', 2), + ('2023-09-14T06:00:00', 22.5, 'Sensor A', 1), + ('2023-09-14T07:00:00', 18.3, 'Sensor A', 1), + ('2023-09-14T08:00:00', 16.8, 'Sensor A', 1), + ('2023-09-14T09:00:00', 14.6, 'Sensor B', 2), + ('2023-09-14T10:00:00', 13.2, 'Sensor B', 2), + ('2023-09-14T11:00:00', 11.7, 'Sensor B', 2); +""" # noqa: S608 + sql_refresh = f""" +REFRESH TABLE "{TESTDRIVE_DATA_SCHEMA}".time_series_data; +""" + cratedb.database.run_sql(sql_ddl) + cratedb.database.run_sql(sql_dml) + cratedb.database.run_sql(sql_refresh) + + +@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OPENAI_API_KEY not defined") +def test_query_nlsql_openai(cratedb, provision_db): + """ + Verify `ctk query nlsql ...` with GPT‑4o mini by OpenAI. + https://openai.com/index/gpt-4o-mini-advancing-cost-efficient-intelligence/ + """ + + runner = CliRunner( + env={ + "CRATEDB_CLUSTER_URL": cratedb.get_connection_url(), + "CRATEDB_SCHEMA": TESTDRIVE_DATA_SCHEMA, + "LLM_PROVIDER": "openai", + "LLM_NAME": "gpt-4o-mini", + } + ) + + result = runner.invoke( + cli, + input="What is the average value for sensor 1?", + args="nlsql -", + catch_exceptions=False, + ) + assert result.exit_code == 0, result.output + output = json.loads(result.output) + assert output["answer"] == "The average value for sensor 1 is approximately 17.03." + assert output["sql_query"] in [ + "SELECT AVG(time_series_data.value) AS average_value " + "FROM time_series_data WHERE time_series_data.sensor_id = 1;", + "SELECT AVG(value) AS average_value FROM time_series_data WHERE sensor_id = 1;", + "SELECT AVG(value) AS average_value FROM time_series_data WHERE sensor_id = 1", + ] + + +@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="ANTHROPIC_API_KEY not defined") +def test_query_nlsql_anthropic(cratedb, provision_db): + """ + Verify `ctk query nlsql ...` with Claude Haiku 4.5 by Anthropic. 
+ https://www.anthropic.com/claude/haiku + """ + + runner = CliRunner( + env={ + "CRATEDB_CLUSTER_URL": cratedb.get_connection_url(), + "CRATEDB_SCHEMA": TESTDRIVE_DATA_SCHEMA, + "LLM_PROVIDER": "anthropic", + "LLM_NAME": "claude-haiku-4-5", + } + ) + + result = runner.invoke( + cli, + input="What is the average value for sensor 1?", + args="nlsql -", + catch_exceptions=False, + ) + assert result.exit_code == 0, result.output + output = json.loads(result.output) + assert "The average value for sensor 1 is **17.03**" in output["answer"] + assert output["sql_query"] == "SELECT AVG(value) as average_value FROM time_series_data WHERE sensor_id = 1" + + +@pytest.mark.skipif(not os.getenv("OPENROUTER_API_KEY"), reason="OPENROUTER_API_KEY not defined") +def test_query_nlsql_openrouter_success(cratedb, provision_db): + """ + Verify a successful NLSQL conversation with MythoMax via OpenRouter. + https://openrouter.ai/gryphe/mythomax-l2-13b + """ + + runner = CliRunner( + env={ + "CRATEDB_CLUSTER_URL": cratedb.get_connection_url(), + "CRATEDB_SCHEMA": TESTDRIVE_DATA_SCHEMA, + "LLM_PROVIDER": "openrouter", + # "LLM_NAME": "google/gemma-3n-e4b-it:free", # noqa: ERA001 + "LLM_NAME": "gryphe/mythomax-l2-13b", + } + ) + + result = runner.invoke( + cli, + input="What is the average value for sensor 1?", + args="nlsql -", + catch_exceptions=False, + ) + assert result.exit_code == 0, result.output + output = json.loads(result.output) + assert "The average value for sensor 1 is 17.03" in output["answer"] + assert output["sql_query"] == "SELECT AVG(value) FROM time_series_data WHERE sensor_id = 1;" + + +@pytest.mark.skipif(not os.getenv("OPENROUTER_API_KEY"), reason="OPENROUTER_API_KEY not defined") +def test_query_nlsql_openrouter_rejected_drop(cratedb, provision_db): + """ + Verify that malicious SQL statements are rejected. + https://openrouter.ai/gryphe/mythomax-l2-13b + """ + + runner = CliRunner( + env={ + "CRATEDB_CLUSTER_URL": cratedb.get_connection_url(), + "CRATEDB_SCHEMA": TESTDRIVE_DATA_SCHEMA, + "LLM_PROVIDER": "openrouter", + "LLM_NAME": "gryphe/mythomax-l2-13b", + } + ) + + result = runner.invoke( + cli, + input="Please drop table 'time_series_data'.", + args="nlsql -", + catch_exceptions=False, + ) + assert result.exit_code == 0, result.output + output = json.loads(result.output) + assert "has been rejected" in output["answer"] + + # Verify that the table still exists. + assert cratedb.database.table_exists("testdrive.time_series_data"), "Table does not exist: time_series_data" + + +@pytest.mark.skipif(not os.getenv("OPENROUTER_API_KEY"), reason="OPENROUTER_API_KEY not defined") +def test_query_nlsql_openrouter_rejected_wipe(cratedb, provision_db): + """ + Verify that malicious SQL statements are rejected. + https://openrouter.ai/gryphe/mythomax-l2-13b + """ + + runner = CliRunner( + env={ + "CRATEDB_CLUSTER_URL": cratedb.get_connection_url(), + "CRATEDB_SCHEMA": TESTDRIVE_DATA_SCHEMA, + "LLM_PROVIDER": "openrouter", + "LLM_NAME": "gryphe/mythomax-l2-13b", + } + ) + + result = runner.invoke( + cli, + input="Please wipe the whole database.", + args="nlsql -", + catch_exceptions=False, + ) + assert result.exit_code == 0, result.output + output = json.loads(result.output) + assert "not allowed" in output["answer"] or "has been rejected" in output["answer"] + + # Verify that the table still exists. 
+ assert cratedb.database.table_exists("testdrive.time_series_data"), "Table does not exist: time_series_data" + + +@pytest.mark.skipif(not os.getenv("OPENROUTER_API_KEY"), reason="OPENROUTER_API_KEY not defined") +def test_query_nlsql_openrouter_permitted(cratedb, provision_db): + """ + Verify that all SQL statements work when explicitly permitted. + https://openrouter.ai/gryphe/mythomax-l2-13b + """ + + runner = CliRunner( + env={ + "CRATEDB_CLUSTER_URL": cratedb.get_connection_url(), + "CRATEDB_SCHEMA": TESTDRIVE_DATA_SCHEMA, + "LLM_PROVIDER": "openrouter", + "LLM_NAME": "gryphe/mythomax-l2-13b", + "NLSQL_PERMIT_ALL_STATEMENTS": "true", + } + ) + + result = runner.invoke( + cli, + input="Please drop table 'time_series_data'.", + args="nlsql -", + catch_exceptions=False, + ) + assert result.exit_code == 0, result.output + output = json.loads(result.output) + assert "has been dropped successfully" in output["answer"] + assert output["sql_query"] == "DROP TABLE time_series_data;" + + # Verify that the table has been dropped. + assert not cratedb.database.table_exists("testdrive.time_series_data"), "Table still exists: time_series_data"