From 33957b8cb0cc3f0466188745d174bbffa62b5b84 Mon Sep 17 00:00:00 2001 From: brendanbabb Date: Thu, 16 Apr 2026 13:01:35 -0800 Subject: [PATCH 01/12] Infra: us-west-2 region, reserved Lambda concurrency, cp311 packaging pin MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move the deployment to us-west-2, add reserved Lambda concurrency as the primary brake on fan-out into the upstream CKAN portal, and pin Lambda packaging to cp311/manylinux wheels so the ZIP works regardless of build host Python version. - terraform/aws: add lambda_reserved_concurrency (default 10) wired to aws_lambda_function.reserved_concurrent_executions. Extract the S3 backend config out of main.tf; real backend.tf is gitignored because the bucket name embeds the deployer's AWS account ID. Ship backend.tf.example as the template. - prod/staging tfvars: aws_region=us-west-2, api_quota_limit=3000 (was 1000), lambda_reserved_concurrency=10. Prod custom domain is boston-data.codeforanchorage.org; staging has no custom domain. - scripts/deploy.sh + .github/workflows/release.yml: force cp311 manylinux wheel resolution on every pip/uv install (without this, a Python 3.14 build host produces a ZIP that 502s at Lambda cold start). Detect python3/python cross-platform. Build the ZIP with stdlib zipfile instead of the `zip` binary so the packaging step works on CI images and Windows. - scripts/setup-backend.sh: fix malformed bucket name (boston-opencontext-opendataterraform-state-... → boston-opencontext- tfstate-...). - config.yaml: replace symlink-to-example with a concrete Boston CKAN config targeting data.boston.gov. ArcGIS kept disabled for reference. - local_server.py: accept POSTs on both / and /mcp so the same local server works with Claude Desktop stdio bridges and MCP Inspector. Co-Authored-By: Claude Opus 4.7 --- .github/workflows/release.yml | 26 ++++++++++++- .gitignore | 4 ++ config.yaml | 29 ++++++++++++++- local_server.py | 1 + scripts/deploy.sh | 63 ++++++++++++++++++++++---------- scripts/setup-backend.sh | 2 +- terraform/aws/backend.tf.example | 18 +++++++++ terraform/aws/main.tf | 9 +---- terraform/aws/prod.tfvars | 7 ++-- terraform/aws/staging.tfvars | 9 +++-- terraform/aws/variables.tf | 6 +++ 11 files changed, 137 insertions(+), 37 deletions(-) create mode 100644 terraform/aws/backend.tf.example diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index aa9d4f9..ca6ef03 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -112,7 +112,16 @@ jobs: echo "Version: $VERSION" - name: Install Python dependencies into package directory - run: uv pip install --target ./package -r requirements.txt + # Pin --python-version and --python-platform so wheel resolution is + # independent of the runner's ambient interpreter and OS. Matches + # the uv path in scripts/deploy.sh and the python3.11 runtime pinned + # in terraform/aws/main.tf. + run: | + uv pip install -r requirements.txt \ + --target ./package \ + --python-platform x86_64-manylinux2014 \ + --python-version 3.11 \ + --no-compile - name: Copy application code into package run: | @@ -122,9 +131,22 @@ jobs: cp examples/boston-opendata/config.yaml package/config.yaml - name: Create Lambda ZIP + # Use Python stdlib zipfile to match scripts/deploy.sh and + # .github/workflows/infra.yml — avoids depending on the `zip` binary. run: | cd package - zip -r ../opencontext-lambda-${{ steps.get_version.outputs.version }}.zip . 
+ python - "../opencontext-lambda-${{ steps.get_version.outputs.version }}.zip" <<'PY' + import os + import sys + import zipfile + + zip_path = sys.argv[1] + with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as z: + for root, _, files in os.walk("."): + for name in files: + path = os.path.join(root, name) + z.write(path, os.path.relpath(path, ".")) + PY - name: Upload Lambda ZIP artifact uses: actions/upload-artifact@v4 diff --git a/.gitignore b/.gitignore index ae0d18c..e65198f 100644 --- a/.gitignore +++ b/.gitignore @@ -217,6 +217,10 @@ terraform/*/lambda-deployment.zip **/*.tfstate **/*.tfstate.* +# Terraform backend config — contains the deployer's AWS account ID in the +# S3 bucket name. Each fork ships its own. See terraform/aws/backend.tf.example. +terraform/aws/backend.tf + # OpenContext client binaries opencontext-client opencontext-client-* diff --git a/config.yaml b/config.yaml index 407b7fd..c644758 120000 --- a/config.yaml +++ b/config.yaml @@ -1 +1,28 @@ -examples/dc-arcgis/config.yaml \ No newline at end of file +--- +server_name: "Boston OpenData MCP" +description: "City of Boston open data MCP server - Safe, conversational access to Boston's open data" +organization: "City of Boston" + +plugins: + ckan: + enabled: true + base_url: "https://data.boston.gov/" + portal_url: "https://data.boston.gov/" + city_name: "Boston" + timeout: 120 + + arcgis: + enabled: false + portal_url: "https://data-boston.hub.arcgis.com" + city_name: "Boston" + timeout: 120 + +aws: + region: "us-west-2" + lambda_name: "boston-ckan-mcp-staging" + lambda_memory: 512 + lambda_timeout: 120 + +logging: + level: "INFO" + format: "json" diff --git a/local_server.py b/local_server.py index 0ed0810..b0e3afe 100644 --- a/local_server.py +++ b/local_server.py @@ -73,6 +73,7 @@ async def start_server(): app = web.Application() app.router.add_post("/", handle_mcp_request) + app.router.add_post("/mcp", handle_mcp_request) runner = web.AppRunner(app) await runner.setup() diff --git a/scripts/deploy.sh b/scripts/deploy.sh index 79ef6db..9bcf65e 100755 --- a/scripts/deploy.sh +++ b/scripts/deploy.sh @@ -86,9 +86,13 @@ if [ ! -f "config.yaml" ]; then exit 1 fi -# Check if Python is available -if ! command -v python3 &> /dev/null; then - echo -e "${RED}❌ Error: python3 not found${NC}" +# Check if Python is available (python3 on Linux/macOS, python on Windows) +if command -v python3 &> /dev/null; then + PYTHON=python3 +elif command -v python &> /dev/null; then + PYTHON=python +else + echo -e "${RED}❌ Error: python not found${NC}" echo "Please install Python 3.11 or later." exit 1 fi @@ -103,7 +107,7 @@ fi echo -e "${YELLOW}📋 Step 1: Validating configuration...${NC}" # Count enabled plugins using Python for reliable YAML parsing -ENABLED_COUNT=$(python3 << 'EOF' +ENABLED_COUNT=$($PYTHON << 'EOF' import yaml import sys @@ -153,7 +157,7 @@ if [ $EXIT_CODE -eq 1 ]; then echo "See docs/GETTING_STARTED.md for setup instructions." 
exit 1 elif [ $EXIT_CODE -eq 2 ]; then - ENABLED_PLUGINS=$(python3 << 'EOF' + ENABLED_PLUGINS=$($PYTHON << 'EOF' import yaml with open('config.yaml', 'r') as f: config = yaml.safe_load(f) @@ -197,7 +201,7 @@ elif [ $EXIT_CODE -ne 0 ]; then exit 1 fi -ENABLED_PLUGIN=$(python3 << 'EOF' +ENABLED_PLUGIN=$($PYTHON << 'EOF' import yaml with open('config.yaml', 'r') as f: config = yaml.safe_load(f) @@ -212,7 +216,7 @@ echo -e "${GREEN}✓ Configuration valid: ${ENABLED_PLUGIN} plugin enabled${NC}" echo "" # Extract server name and AWS settings -SERVER_NAME=$(python3 << 'EOF' +SERVER_NAME=$($PYTHON << 'EOF' import yaml with open('config.yaml', 'r') as f: config = yaml.safe_load(f) @@ -220,7 +224,7 @@ print(config.get('server_name', 'my-mcp-server')) EOF ) -AWS_REGION=$(python3 << 'EOF' +AWS_REGION=$($PYTHON << 'EOF' import yaml with open('config.yaml', 'r') as f: config = yaml.safe_load(f) @@ -261,21 +265,42 @@ if command -v uv &> /dev/null; then fi else echo "uv not found, falling back to pip..." - if ! pip install -r requirements.txt -t "$PACKAGE_DIR/" --platform manylinux2014_x86_64 --only-binary :all: --no-compile --no-deps 2>/dev/null; then - echo "Platform-specific install failed, trying generic install..." - if ! pip install -r requirements.txt -t "$PACKAGE_DIR/" --no-compile 2>/dev/null; then - echo -e "${RED}❌ Error: Failed to install dependencies${NC}" - echo "Please ensure pip is available and requirements.txt is valid." - exit 1 - fi + # Lambda runtime is pinned to python3.11 in terraform/aws/main.tf, so we + # must force pip to resolve cp311 wheels regardless of the host Python + # version. Without these flags pip picks wheels for the ambient interpreter + # (e.g. cp314 on a Python 3.14 build host), which then fail to import at + # Lambda cold start with a 502 InternalServerErrorException. + if ! pip install -r requirements.txt \ + -t "$PACKAGE_DIR/" \ + --platform manylinux2014_x86_64 \ + --python-version 3.11 \ + --implementation cp \ + --abi cp311 \ + --only-binary :all: \ + --no-compile; then + echo -e "${RED}❌ Error: Failed to install dependencies for python3.11${NC}" + echo "Ensure pip >= 22 is available and every requirement has a cp311 manylinux wheel." + exit 1 fi fi -# Create zip file +# Create zip file using Python's stdlib zipfile — matches the convention in +# .github/workflows/infra.yml and avoids depending on the `zip` binary, which +# isn't present in every build environment (notably the staging CI image). ZIP_FILE="lambda-deployment.zip" -cd "$PACKAGE_DIR" -zip -r "../$ZIP_FILE" . > /dev/null -cd .. 
+(cd "$PACKAGE_DIR" && "$PYTHON" - "../$ZIP_FILE" <<'PY' +import os +import sys +import zipfile + +zip_path = sys.argv[1] +with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as z: + for root, _, files in os.walk("."): + for name in files: + path = os.path.join(root, name) + z.write(path, os.path.relpath(path, ".")) +PY +) echo -e "${GREEN}✓ Lambda package created: $ZIP_FILE${NC}" echo "" diff --git a/scripts/setup-backend.sh b/scripts/setup-backend.sh index d1015a4..0eb2c49 100755 --- a/scripts/setup-backend.sh +++ b/scripts/setup-backend.sh @@ -14,7 +14,7 @@ if [ -z "$AWS_ACCOUNT_ID" ]; then fi # Generate bucket name -BUCKET_NAME="boston-opencontext-opendataterraform-state-${AWS_ACCOUNT_ID}-${AWS_REGION}" +BUCKET_NAME="boston-opencontext-tfstate-${AWS_ACCOUNT_ID}-${AWS_REGION}" TABLE_NAME="terraform-state-lock" echo "AWS Account ID: $AWS_ACCOUNT_ID" diff --git a/terraform/aws/backend.tf.example b/terraform/aws/backend.tf.example new file mode 100644 index 0000000..78a7fe9 --- /dev/null +++ b/terraform/aws/backend.tf.example @@ -0,0 +1,18 @@ +# Copy this file to `backend.tf` and replace with your AWS +# account ID. Terraform does not allow variable interpolation inside the +# backend block (it's evaluated before variables are resolved), so the value +# has to be hardcoded. `backend.tf` itself is gitignored — each fork ships +# its own. +# +# Run `scripts/setup-backend.sh` first to create the S3 state bucket and +# DynamoDB lock table, then `terraform init` against `terraform/aws/`. + +terraform { + backend "s3" { + bucket = "boston-opencontext-tfstate--us-west-2" + key = "terraform.tfstate" + region = "us-west-2" + dynamodb_table = "terraform-state-lock" + encrypt = true + } +} diff --git a/terraform/aws/main.tf b/terraform/aws/main.tf index 0fa555e..2093d6a 100644 --- a/terraform/aws/main.tf +++ b/terraform/aws/main.tf @@ -1,13 +1,6 @@ terraform { required_version = ">= 1.0" - backend "s3" { - bucket = "opencontext-terraform-state" - key = "opencontext/terraform.tfstate" - region = "us-east-1" - encrypt = true - } - required_providers { aws = { source = "hashicorp/aws" @@ -81,6 +74,8 @@ resource "aws_lambda_function" "mcp_server" { memory_size = local.lambda_memory timeout = local.lambda_timeout + reserved_concurrent_executions = var.lambda_reserved_concurrency + environment { variables = { OPENCONTEXT_CONFIG = local.config_json diff --git a/terraform/aws/prod.tfvars b/terraform/aws/prod.tfvars index 5a7d340..3178b40 100644 --- a/terraform/aws/prod.tfvars +++ b/terraform/aws/prod.tfvars @@ -1,10 +1,11 @@ lambda_name = "boston-opencontext-mcp-prod" stage_name = "prod" -aws_region = "us-east-1" +aws_region = "us-west-2" config_file = "config.yaml" lambda_memory = 512 lambda_timeout = 120 -api_quota_limit = 1000 +lambda_reserved_concurrency = 10 +api_quota_limit = 3000 api_rate_limit = 5 api_burst_limit = 10 -custom_domain = "data-mcp.boston.gov" +custom_domain = "boston-data.codeforanchorage.org" diff --git a/terraform/aws/staging.tfvars b/terraform/aws/staging.tfvars index 6802a87..b104e82 100644 --- a/terraform/aws/staging.tfvars +++ b/terraform/aws/staging.tfvars @@ -1,10 +1,11 @@ -lambda_name = "boston-opencontext-mcp-staging" +lambda_name = "boston-ckan-mcp-staging" stage_name = "staging" -aws_region = "us-east-1" +aws_region = "us-west-2" config_file = "config.yaml" lambda_memory = 512 lambda_timeout = 120 -api_quota_limit = 1000 +lambda_reserved_concurrency = 10 +api_quota_limit = 3000 api_rate_limit = 5 api_burst_limit = 10 -custom_domain = "data-mcp-staging.boston.gov" 
+custom_domain = ""
diff --git a/terraform/aws/variables.tf b/terraform/aws/variables.tf
index 49cc07a..66befc4 100644
--- a/terraform/aws/variables.tf
+++ b/terraform/aws/variables.tf
@@ -28,6 +28,12 @@ variable "lambda_timeout" {
   default     = 120
 }
 
+variable "lambda_reserved_concurrency" {
+  description = "Maximum concurrent Lambda invocations. Caps how hard a single abusive client can fan out into the upstream open-data portal. Set to -1 to disable the limit (use AWS account-wide concurrency)."
+  type        = number
+  default     = 10
+}
+
 variable "api_quota_limit" {
   description = "API Gateway daily request quota"
   type        = number

From d0a18cda67df0b93c5a50d36e48d61da41564d25 Mon Sep 17 00:00:00 2001
From: brendanbabb
Date: Thu, 16 Apr 2026 13:03:35 -0800
Subject: [PATCH 02/12] Security: harden SQL validator, add SafeSQLBuilder,
 cap request body size
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Tighten the two attack surfaces that directly forward user-controlled
input into upstream CKAN: the execute_sql path and the aggregate_data
path. Add a request-body-size cap on the HTTP handler to bound the work
a single JSON-RPC call can cost. See docs/SECURITY.md for the full
threat model.

- plugins/ckan/sql_validator.py
  * SQLValidator: shrink MAX_SQL_LENGTH 50000 → 8192; strip SQL
    comments before keyword/function scans so /* ... */ and -- ...
    obfuscation can't smuggle forbidden tokens past the checks; expand
    FORBIDDEN_KEYWORDS with PREPARE/COPY/LISTEN/NOTIFY/VACUUM/
    ANALYZE/CLUSTER/REINDEX/LOAD/DO; add FORBIDDEN_FUNCTIONS
    (xp_cmdshell, pg_sleep, pg_read_file, pg_ls_dir, pg_stat_file,
    lo_import, lo_export, current_setting, set_config, dblink); walk
    the sqlparse AST to require every FROM/JOIN target to be a
    UUID-quoted resource or a CTE alias (rejects schema-qualified
    targets like pg_catalog.pg_class); match INTO OUTFILE/DUMPFILE.
  * New enforce_row_limit: appends LIMIT 10000 to any validated SQL
    that lacks a top-level LIMIT so a caller can't trigger an unbounded
    scan on a multi-million-row CKAN DataStore table.
  * New SafeSQLBuilder: typed, allowlist-only builder for the
    aggregate_data path. Identifiers must match [A-Za-z_]\w*, metric
    expressions must be count(*) or {count|sum|avg|min|max|stddev}
    ([DISTINCT] <column>), filter values coerced per type with '
    escaping, order_by parsed and quoted, limit clamped to 10000,
    HAVING values must be numeric.

- plugins/ckan/plugin.py: route aggregate_data through SafeSQLBuilder
  (was string concatenation); call SQLValidator.enforce_row_limit after
  validate_query.

- server/http_handler.py: reject JSON-RPC bodies > 65 KB with HTTP 413
  before parsing. The MCP surface fits in a few KB; a megabyte payload
  is either a bug or abuse.

- tests: cover body-size cap at and over the boundary, each new
  forbidden keyword/function, comment obfuscation, schema-qualified
  FROM rejection, enforce_row_limit behavior, and every SafeSQLBuilder
  method.
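Illustrative sketch of the SafeSQLBuilder contract (made-up resource ID
and column names; mirrors what aggregate_data in plugins/ckan/plugin.py
now assembles internally):

    from plugins.ckan.sql_validator import SafeSQLBuilder

    rid = SafeSQLBuilder.validate_resource_id(
        "11111111-2222-3333-4444-555555555555")       # UUID or ValueError
    col = SafeSQLBuilder.quote_identifier("neighborhood")  # -> "neighborhood"
    n = SafeSQLBuilder.validate_metric_expr("count(*)")    # -> count(*)
    cond = SafeSQLBuilder.build_filter_condition("status", "O'Brien")
    # -> "status" = 'O''Brien'  (identifier allowlisted, quote doubled)
    sql = (f'SELECT {col}, {n} AS "n" FROM "{rid}" WHERE {cond} '
           f'GROUP BY {col} LIMIT {SafeSQLBuilder.clamp_limit(10**6)}')
    # clamp_limit caps the 10**6 request at 10000; any input that fails
    # an allowlist raises ValueError before SQL is assembled, which
    # aggregate_data surfaces as an error result instead of executing.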
Co-Authored-By: Claude Opus 4.7 --- plugins/ckan/plugin.py | 121 +++++----- plugins/ckan/sql_validator.py | 397 +++++++++++++++++++++++++++++---- server/http_handler.py | 39 ++++ tests/test_ckan_plugin.py | 21 +- tests/test_http_handler.py | 67 +++++- tests/test_sql_validator.py | 409 ++++++++++++++++++++++++++++++---- 6 files changed, 903 insertions(+), 151 deletions(-) diff --git a/plugins/ckan/plugin.py b/plugins/ckan/plugin.py index 8870961..906546e 100644 --- a/plugins/ckan/plugin.py +++ b/plugins/ckan/plugin.py @@ -16,7 +16,7 @@ from core.interfaces import DataPlugin, PluginType, ToolDefinition, ToolResult from plugins.ckan.config_schema import CKANPluginConfig -from plugins.ckan.sql_validator import SQLValidator +from plugins.ckan.sql_validator import SafeSQLBuilder, SQLValidator logger = logging.getLogger(__name__) @@ -491,12 +491,9 @@ async def query_data( Returns: List of data records """ - params = {"resource_id": resource_id, "limit": limit} - - # Convert filters to CKAN filter format + params: Dict[str, Any] = {"resource_id": resource_id, "limit": limit} if filters: - for field, value in filters.items(): - params[f"filters[{field}]"] = value + params["filters"] = filters response = await self._call_ckan_api("datastore_search", params) return response.get("result", {}).get("records", []) @@ -530,6 +527,9 @@ async def execute_sql(self, sql: str) -> Dict[str, Any]: if not is_valid: return {"error": True, "message": error} + # Bound upstream scan cost: append LIMIT if the caller didn't set one. + sql = SQLValidator.enforce_row_limit(sql) + # Log SQL execution (truncated for security) logger.info("Executing SQL", extra={"sql": sql[:500]}) @@ -562,56 +562,69 @@ async def aggregate_data( ) -> Dict[str, Any]: """Aggregate data with GROUP BY. - Args: - resource_id: Resource ID (must be valid UUID) - group_by: List of fields to group by - metrics: Dictionary of metric_name: sql_expression (e.g., {"count": "count(*)"}) - filters: Optional WHERE clause filters (field: value pairs) - having: Optional HAVING clause filters (expression: value pairs) - order_by: Optional field to order by - limit: Maximum number of results - - Returns: - Dictionary with success flag, records, fields, or error message + Every identifier, metric expression, filter value, and LIMIT is + validated against a strict allowlist via ``SafeSQLBuilder`` before + the SQL is assembled, so caller-supplied strings cannot escape into + the generated query. 
""" - # SELECT - select_fields = ", ".join(group_by) if group_by else "" - select_metrics = ", ".join( - [f"{expr} as {name}" for name, expr in metrics.items()] - ) - select_clause = ( - f"{select_fields}, {select_metrics}" if select_fields else select_metrics - ) + try: + resource_id = SafeSQLBuilder.validate_resource_id(resource_id) + if not metrics: + raise ValueError("metrics must be non-empty") + + group_by_quoted = [ + SafeSQLBuilder.quote_identifier(f) for f in (group_by or []) + ] + + metric_parts: List[str] = [] + for alias, expr in metrics.items(): + alias_quoted = SafeSQLBuilder.quote_identifier(alias) + expr_quoted = SafeSQLBuilder.validate_metric_expr(expr) + metric_parts.append(f"{expr_quoted} AS {alias_quoted}") + + select_clause = ", ".join(group_by_quoted + metric_parts) + + where_clause = "" + if filters: + conditions = [ + SafeSQLBuilder.build_filter_condition(f, v) + for f, v in filters.items() + ] + where_clause = " WHERE " + " AND ".join(conditions) + + group_clause = "" + if group_by_quoted: + group_clause = " GROUP BY " + ", ".join(group_by_quoted) + + having_clause = "" + if having: + having_parts: List[str] = [] + for expr, value in having.items(): + expr_quoted = SafeSQLBuilder.validate_metric_expr(expr) + if isinstance(value, bool) or not isinstance( + value, (int, float) + ): + raise ValueError( + f"HAVING value must be numeric: {value!r}" + ) + having_parts.append(f"{expr_quoted} > {value}") + having_clause = " HAVING " + " AND ".join(having_parts) + + order_clause = "" + if order_by: + order_clause = " ORDER BY " + SafeSQLBuilder.validate_order_by( + order_by + ) - # WHERE - where_clause = "" - if filters: - conditions = [] - for field, value in filters.items(): - if isinstance(value, str): - # Escape single quotes in SQL strings - escaped_value = value.replace("'", "''") - conditions.append(f"{field} = '{escaped_value}'") - elif value is None: - conditions.append(f"{field} IS NULL") - else: - conditions.append(f"{field} = {value}") - where_clause = "WHERE " + " AND ".join(conditions) - - # GROUP BY - group_clause = f"GROUP BY {', '.join(group_by)}" if group_by else "" - - # HAVING - having_clause = "" - if having: - conditions = [f"{expr} > {value}" for expr, value in having.items()] - having_clause = "HAVING " + " AND ".join(conditions) - - # ORDER BY - order_clause = f"ORDER BY {order_by}" if order_by else "" - - # Build SQL - sql = f'SELECT {select_clause} FROM "{resource_id}" {where_clause} {group_clause} {having_clause} {order_clause} LIMIT {limit}'.strip() + limit_int = SafeSQLBuilder.clamp_limit(limit) + except ValueError as e: + return {"error": True, "message": str(e)} + + sql = ( + f'SELECT {select_clause} FROM "{resource_id}"' + f"{where_clause}{group_clause}{having_clause}{order_clause}" + f" LIMIT {limit_int}" + ) return await self.execute_sql(sql) diff --git a/plugins/ckan/sql_validator.py b/plugins/ckan/sql_validator.py index a3b3400..ae73524 100644 --- a/plugins/ckan/sql_validator.py +++ b/plugins/ckan/sql_validator.py @@ -1,19 +1,51 @@ -"""SQL validator for CKAN plugin. +"""SQL validator and safe SQL builder for CKAN plugin. -Provides security validation for SQL queries to prevent SQL injection -and destructive operations. +Two concerns live here: + +- ``SQLValidator`` hardens the ``execute_sql`` path. It rejects anything that + isn't a single SELECT against a UUID-quoted CKAN resource (or a CTE alias + thereof). 
Comments are stripped before keyword/function scanning so that + block-comment obfuscation cannot slip forbidden tokens past the check. + +- ``SafeSQLBuilder`` powers ``aggregate_data``. It validates every identifier, + metric expression, filter value, and LIMIT against an allowlist so that + caller-supplied strings can never reach the generated SQL unescaped. """ import re -from typing import Tuple, Optional +from typing import Any, List, Optional, Set, Tuple import sqlparse +from sqlparse.sql import Identifier, IdentifierList, Parenthesis, TokenList +from sqlparse.tokens import Keyword + + +_UUID_RE = re.compile( + r"^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$", + re.IGNORECASE, +) +_SIMPLE_IDENT_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") +_METRIC_RE = re.compile( + r"^\s*(count|sum|avg|min|max|stddev)\s*" + r"\(\s*(?:(distinct)\s+)?([A-Za-z_][A-Za-z0-9_]*)\s*\)\s*$", + re.IGNORECASE, +) +_COUNT_STAR_RE = re.compile(r"^\s*count\s*\(\s*\*\s*\)\s*$", re.IGNORECASE) +_BLOCK_COMMENT_RE = re.compile(r"/\*.*?\*/", re.DOTALL) +_LINE_COMMENT_RE = re.compile(r"--[^\n]*") + + +def _strip_comments(sql: str) -> str: + """Remove SQL comments so content-based scans can't be hidden behind them.""" + without_block = _BLOCK_COMMENT_RE.sub(" ", sql) + return _LINE_COMMENT_RE.sub(" ", without_block) class SQLValidator: """Validates SQL queries for security before execution.""" - MAX_SQL_LENGTH = 50000 + MAX_SQL_LENGTH = 8192 + DEFAULT_ROW_LIMIT = 10000 FORBIDDEN_KEYWORDS = [ "INSERT", "UPDATE", @@ -29,70 +61,337 @@ class SQLValidator: "CALL", "DECLARE", "SET", + "PREPARE", + "COPY", + "LISTEN", + "NOTIFY", + "VACUUM", + "ANALYZE", + "CLUSTER", + "REINDEX", + "LOAD", + "DO", + ] + FORBIDDEN_FUNCTIONS = [ + "xp_cmdshell", + "pg_sleep", + "pg_read_file", + "pg_read_binary_file", + "pg_ls_dir", + "pg_stat_file", + "lo_import", + "lo_export", + "current_setting", + "set_config", + "dblink", ] @staticmethod - def validate_query(sql: str) -> Tuple[bool, Optional[str]]: - """Validate SQL security. Returns (is_valid, error_message). - - Args: - sql: SQL query string to validate - - Returns: - Tuple of (is_valid: bool, error_message: Optional[str]) - If is_valid is True, error_message is None. - If is_valid is False, error_message contains the reason. - """ - # 1. Basic checks + def validate_query(sql: Any) -> Tuple[bool, Optional[str]]: + """Validate SQL security. Returns (is_valid, error_message).""" if not sql or not isinstance(sql, str): return False, "SQL must be non-empty string" sql = sql.strip() + if not sql: + return False, "SQL must be non-empty string" if len(sql) > SQLValidator.MAX_SQL_LENGTH: - return ( - False, - f"SQL too long (max {SQLValidator.MAX_SQL_LENGTH})", - ) + return False, f"SQL too long (max {SQLValidator.MAX_SQL_LENGTH})" + + # Strip comments so keyword/function scans can't be bypassed by hiding + # payloads inside /* ... */ or -- comments. + sql_scan = _strip_comments(sql) - # 2. Block forbidden keywords (check before SELECT check to get specific error messages) for keyword in SQLValidator.FORBIDDEN_KEYWORDS: - if re.search(rf"\b{keyword}\b", sql, re.IGNORECASE): + if re.search(rf"\b{keyword}\b", sql_scan, re.IGNORECASE): return False, f"Forbidden keyword: {keyword}" - # 3. 
Must start with SELECT or WITH (for CTEs) - sql_upper = sql.upper().strip() + for fn in SQLValidator.FORBIDDEN_FUNCTIONS: + if re.search(rf"\b{re.escape(fn)}\b", sql_scan, re.IGNORECASE): + return False, f"Forbidden function: {fn}" + + sql_upper = sql_scan.lstrip().upper() if not (sql_upper.startswith("SELECT") or sql_upper.startswith("WITH")): return False, "Only SELECT queries allowed" - # 4. Block dangerous patterns - patterns = [ - (r";.*(?:DROP|DELETE|INSERT)", "Multiple statements detected"), - (r"--.*(?:DROP|DELETE)", "Dangerous comment detected"), - (r"xp_cmdshell", "Command execution detected"), + for pattern, msg in [ (r"into\s+outfile", "File write detected"), - (r"pg_sleep", "Sleep function detected"), - ] - for pattern, msg in patterns: - if re.search(pattern, sql, re.IGNORECASE): + (r"into\s+dumpfile", "File write detected"), + ]: + if re.search(pattern, sql_scan, re.IGNORECASE): return False, msg - # 5. Validate with sqlparse try: parsed = sqlparse.parse(sql) - if len(parsed) != 1: - return False, "Multiple statements not allowed" - statement_type = parsed[0].get_type() - # sqlparse returns "SELECT" for SELECT statements and CTEs (WITH ... SELECT) - # If type is None, it might be a CTE - we already validated it starts with WITH or SELECT above - if statement_type is not None and statement_type != "SELECT": - return False, "Only SELECT statements allowed" except Exception as e: - return False, f"SQL parsing error: {str(e)}" + return False, f"SQL parsing error: {e}" + if len(parsed) != 1: + return False, "Multiple statements not allowed" + statement = parsed[0] + statement_type = statement.get_type() + if statement_type is not None and statement_type not in ("SELECT", "UNKNOWN"): + return False, "Only SELECT statements allowed" - # 6. Validate resource IDs are UUIDs - resource_ids = re.findall(r'"([a-f0-9-]{36})"', sql, re.IGNORECASE) - uuid_pattern = r"^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$" - for rid in resource_ids: - if not re.match(uuid_pattern, rid, re.IGNORECASE): - return False, f"Invalid UUID format: {rid}" + try: + cte_aliases = _extract_cte_aliases(statement) + targets = _extract_from_join_targets(statement) + except Exception as e: + return False, f"Could not parse FROM/JOIN clause: {e}" + + if not targets: + return False, ( + "No FROM/JOIN target found (query must reference a resource)" + ) + + for name, parent in targets: + if parent is not None: + return False, ( + f"Schema-qualified FROM/JOIN target not allowed: " + f"{parent}.{name}" + ) + if _UUID_RE.match(name): + continue + if name in cte_aliases: + continue + return False, ( + f"FROM/JOIN target must be a UUID-quoted resource or CTE alias: " + f"{name}" + ) return True, None + + @classmethod + def enforce_row_limit(cls, sql: str) -> str: + """Append ``LIMIT`` to an already-validated query if it lacks one. + + Bounds the upstream scan cost on CKAN: every ``execute_sql`` path + resolves to a query capped at ``DEFAULT_ROW_LIMIT`` rows even if the + caller forgot to set one. A user-supplied top-level ``LIMIT`` is + preserved as-is; a ``LIMIT`` buried inside a subquery or CTE does not + count — we only treat the outermost statement. 
+ """ + try: + parsed = sqlparse.parse(sql) + except Exception: + return sql + if not parsed: + return sql + statement = parsed[0] + for tok in statement.tokens: + if isinstance(tok, Parenthesis): + continue + if tok.ttype in Keyword and tok.normalized.upper() == "LIMIT": + return sql + stripped = sql.rstrip().rstrip(";").rstrip() + return f"{stripped} LIMIT {cls.DEFAULT_ROW_LIMIT}" + + +def _extract_cte_aliases(statement: TokenList) -> Set[str]: + """If statement is a CTE (``WITH ...``), collect the alias names.""" + aliases: Set[str] = set() + for tok in statement.tokens: + if tok.is_whitespace: + continue + if tok.ttype in Keyword and tok.normalized.upper() == "WITH": + # Next non-whitespace token should be the alias declarations. + idx = statement.tokens.index(tok) + for nxt in statement.tokens[idx + 1 :]: + if nxt.is_whitespace: + continue + if isinstance(nxt, IdentifierList): + for ident in nxt.get_identifiers(): + if isinstance(ident, Identifier): + name = ident.get_real_name() + if name: + aliases.add(name) + elif isinstance(nxt, Identifier): + name = nxt.get_real_name() + if name: + aliases.add(name) + break + break + # Not a CTE. + return aliases + return aliases + + +def _extract_from_join_targets( + statement: TokenList, +) -> List[Tuple[str, Optional[str]]]: + """Walk sqlparse tokens to extract every FROM/JOIN table reference. + + Returns a list of ``(name, parent)`` tuples. ``parent`` is the schema + qualifier (e.g. ``pg_catalog``) if present, otherwise ``None``. + Subqueries are recursed into rather than recorded. Aliases attached to + CTEs and subqueries are skipped because the subquery's inner FROM is + what we care about. + """ + results: List[Tuple[str, Optional[str]]] = [] + + def record(ident: Identifier) -> None: + name = ident.get_real_name() + parent = ident.get_parent_name() + if name is None: + # Couldn't parse — be conservative and reject. + results.append((str(ident).strip(), "?")) + return + results.append((name, parent)) + + def walk(token_list: TokenList) -> None: + expecting = False + for tok in token_list.tokens: + if tok.is_whitespace: + continue + + if tok.ttype in Keyword: + upper = tok.normalized.upper() + if upper == "FROM" or "JOIN" in upper: + expecting = True + else: + expecting = False + continue + + if isinstance(tok, Parenthesis): + walk(tok) + expecting = False + continue + + if expecting: + if isinstance(tok, IdentifierList): + for ident in tok.get_identifiers(): + if isinstance(ident, Identifier): + first = ident.token_first(skip_ws=True, skip_cm=True) + if isinstance(first, Parenthesis): + walk(first) + else: + record(ident) + elif isinstance(ident, Parenthesis): + walk(ident) + expecting = False + continue + if isinstance(tok, Identifier): + first = tok.token_first(skip_ws=True, skip_cm=True) + if isinstance(first, Parenthesis): + walk(first) + else: + record(tok) + expecting = False + continue + if isinstance(tok, TokenList): + walk(tok) + expecting = False + continue + + if isinstance(tok, TokenList): + walk(tok) + + walk(statement) + return results + + +class SafeSQLBuilder: + """Build safe SQL fragments for ``aggregate_data``. + + Every method either returns a validated, quoted SQL fragment or raises + ``ValueError``. Callers should surface ``ValueError`` as a user-visible + error without executing anything against CKAN. 
+ """ + + MAX_LIMIT = 10000 + ALLOWED_AGG_FUNCTIONS = {"count", "sum", "avg", "min", "max", "stddev"} + + @staticmethod + def validate_resource_id(resource_id: Any) -> str: + if not isinstance(resource_id, str) or not _UUID_RE.match(resource_id): + raise ValueError( + f"resource_id must be a valid UUID (got: {resource_id!r})" + ) + return resource_id + + @staticmethod + def quote_identifier(name: Any) -> str: + """Validate a column/alias name and return its double-quoted form.""" + if not isinstance(name, str) or not _SIMPLE_IDENT_RE.match(name): + raise ValueError( + f"Invalid identifier (must match [A-Za-z_][A-Za-z0-9_]*): " + f"{name!r}" + ) + return f'"{name}"' + + @staticmethod + def validate_metric_expr(expr: Any) -> str: + """Validate an aggregate expression against an allowlist. + + Accepted forms: + - ``count(*)`` + - ``{count|sum|avg|min|max|stddev}()`` + - ``{count|sum|avg|min|max|stddev}(DISTINCT )`` + + Returns the canonicalized form with identifiers double-quoted. + """ + if not isinstance(expr, str): + raise ValueError(f"metric expression must be a string: {expr!r}") + if _COUNT_STAR_RE.match(expr): + return "count(*)" + m = _METRIC_RE.match(expr) + if not m: + raise ValueError( + "Invalid metric expression (allowed: count(*), " + "{count|sum|avg|min|max|stddev}([DISTINCT] )): " + f"{expr!r}" + ) + func = m.group(1).lower() + distinct = "DISTINCT " if m.group(2) else "" + ident = m.group(3) + return f'{func}({distinct}"{ident}")' + + @staticmethod + def build_filter_condition(field: Any, value: Any) -> str: + """Build a safe WHERE condition. + + Field names are validated as identifiers. Values are coerced: + ``None`` → ``IS NULL``, booleans → ``TRUE``/``FALSE``, numbers are + formatted, and strings are single-quoted with embedded quotes + escaped. + """ + quoted = SafeSQLBuilder.quote_identifier(field) + if value is None: + return f"{quoted} IS NULL" + if isinstance(value, bool): + return f"{quoted} = {'TRUE' if value else 'FALSE'}" + if isinstance(value, (int, float)): + return f"{quoted} = {value}" + if isinstance(value, str): + escaped = value.replace("'", "''") + return f"{quoted} = '{escaped}'" + raise ValueError( + f"Unsupported filter value type for {field!r}: " + f"{type(value).__name__}" + ) + + @staticmethod + def validate_order_by(order_by: Any) -> str: + """Validate an ``ORDER BY`` clause: `` [ASC|DESC]``.""" + if not isinstance(order_by, str): + raise ValueError(f"order_by must be a string: {order_by!r}") + m = re.match( + r"^\s*([A-Za-z_][A-Za-z0-9_]*)\s*(ASC|DESC)?\s*$", + order_by, + re.IGNORECASE, + ) + if not m: + raise ValueError( + f"Invalid order_by (expected identifier [ASC|DESC]): " + f"{order_by!r}" + ) + ident = f'"{m.group(1)}"' + direction = f" {m.group(2).upper()}" if m.group(2) else "" + return f"{ident}{direction}" + + @staticmethod + def clamp_limit(limit: Any) -> int: + """Accept a positive int and clamp to ``MAX_LIMIT``.""" + if isinstance(limit, bool) or not isinstance(limit, int): + raise ValueError(f"limit must be an integer: {limit!r}") + if limit < 1: + raise ValueError(f"limit must be >= 1: {limit}") + return min(limit, SafeSQLBuilder.MAX_LIMIT) diff --git a/server/http_handler.py b/server/http_handler.py index e73a740..c075aeb 100644 --- a/server/http_handler.py +++ b/server/http_handler.py @@ -50,6 +50,12 @@ _mcp_server: Optional[MCPServer] = None _config: Optional[Dict[str, Any]] = None +# Reject JSON-RPC request bodies larger than this before parsing. 
The MCP +# surface is small — every legitimate tool call fits well under a few KB — +# so this is a cheap DoS guard against attackers flooding the Lambda with +# megabyte-sized payloads. +MAX_BODY_SIZE = 65536 + def _load_config() -> Dict[str, Any]: """Load configuration from environment or embedded config. @@ -226,6 +232,39 @@ async def handle_request( error_body, ) + # Body size cap — reject oversized payloads before JSON parse. + body_bytes = ( + len(body.encode("utf-8")) if isinstance(body, str) else len(body or b"") + ) + if body_bytes > MAX_BODY_SIZE: + duration_ms = (time.perf_counter() - start_time) * 1000 + error_body = json.dumps( + { + "jsonrpc": "2.0", + "id": None, + "error": { + "code": -32600, + "message": "Payload too large", + "data": ( + f"Request body {body_bytes} bytes exceeds " + f"{MAX_BODY_SIZE} byte limit" + ), + }, + } + ) + logger.warning( + f"413 error: body {body_bytes} bytes exceeds {MAX_BODY_SIZE}", + extra={ + "request_id": request_id, + "request_path": path, + "http_method": method, + "duration_ms": duration_ms, + }, + ) + error_headers = {"Content-Type": "application/json"} + error_headers.update(self._get_cors_headers()) + return (413, error_headers, error_body) + # Parse JSON to check if this is an initialize request # NOTE: This is intentionally parsing the JSON body separately from the # later parsing in _mcp_server.handle_http_request(). This early parsing diff --git a/tests/test_ckan_plugin.py b/tests/test_ckan_plugin.py index 20a7fc7..01fb657 100644 --- a/tests/test_ckan_plugin.py +++ b/tests/test_ckan_plugin.py @@ -403,8 +403,7 @@ async def test_query_data_passes_filters(self, ckan_config): params = call_args[1]["json"] assert params["resource_id"] == "resource-123" assert params["limit"] == 50 - assert params["filters[status]"] == "Open" - assert params["filters[category]"] == "311" + assert params["filters"] == {"status": "Open", "category": "311"} class TestExecuteTool: @@ -496,7 +495,7 @@ async def test_execute_tool_execute_sql_succeeds(self, ckan_config): result = await plugin.execute_tool( "execute_sql", { - "sql": 'SELECT * FROM "abc-123-def-456-ghi-789-012-345-678-901" LIMIT 1' + "sql": 'SELECT * FROM "11111111-2222-3333-4444-555555555555" LIMIT 1' }, ) @@ -599,7 +598,12 @@ async def test_execute_sql_returns_error_when_ckan_body_has_success_false( mock_response_sql = Mock() mock_response_sql.json.return_value = { "success": False, - "error": {"message": 'relation "fake-uuid" does not exist'}, + "error": { + "message": ( + 'relation "11111111-2222-3333-4444-555555555555" ' + "does not exist" + ) + }, } mock_response_sql.raise_for_status = Mock() mock_client.post = AsyncMock( @@ -610,7 +614,12 @@ async def test_execute_sql_returns_error_when_ckan_body_has_success_false( await plugin.initialize() result = await plugin.execute_tool( "execute_sql", - {"sql": 'SELECT * FROM "fake-uuid" LIMIT 1'}, + { + "sql": ( + 'SELECT * FROM ' + '"11111111-2222-3333-4444-555555555555" LIMIT 1' + ) + }, ) assert result.success is False @@ -647,7 +656,7 @@ async def test_aggregate_data_returns_error_when_ckan_body_has_success_false( result = await plugin.execute_tool( "aggregate_data", { - "resource_id": "bad-resource-id", + "resource_id": "11111111-2222-3333-4444-555555555555", "metrics": {"count": "count(*)"}, }, ) diff --git a/tests/test_http_handler.py b/tests/test_http_handler.py index 0b617f6..52f3b1a 100644 --- a/tests/test_http_handler.py +++ b/tests/test_http_handler.py @@ -9,7 +9,12 @@ import os from unittest.mock import AsyncMock, MagicMock, patch 
-from server.http_handler import UniversalHTTPHandler, _initialize_server, _load_config +from server.http_handler import ( + MAX_BODY_SIZE, + UniversalHTTPHandler, + _initialize_server, + _load_config, +) from core.validators import ConfigurationError @@ -555,3 +560,63 @@ async def test_initialize_server_raises_on_configuration_error(self): await _initialize_server() assert "Configuration error" in str(exc_info.value) + + +class TestBodySizeCap: + """The handler rejects request bodies larger than ``MAX_BODY_SIZE``.""" + + @pytest.mark.asyncio + async def test_oversized_body_returns_413_without_parsing(self): + """A body one byte over the cap is rejected with 413 and never reaches the MCP server.""" + handler = UniversalHTTPHandler() + oversized = "x" * (MAX_BODY_SIZE + 1) + + with patch("server.http_handler._mcp_server") as mock_mcp_server: + mock_mcp_server.handle_http_request = AsyncMock() + + status, headers, body = await handler.handle_request( + method="POST", + path="/mcp", + body=oversized, + headers={}, + ) + + assert status == 413 + mock_mcp_server.handle_http_request.assert_not_called() + payload = json.loads(body) + assert payload["error"]["message"] == "Payload too large" + assert str(MAX_BODY_SIZE) in payload["error"]["data"] + assert headers["Content-Type"] == "application/json" + + @pytest.mark.asyncio + async def test_body_exactly_at_cap_is_accepted(self): + """A body exactly at the cap is accepted (only strictly over should 413).""" + handler = UniversalHTTPHandler() + # Build a request that serializes to exactly MAX_BODY_SIZE bytes. + base = {"jsonrpc": "2.0", "id": 1, "method": "ping", "params": {"pad": ""}} + overhead = len(json.dumps(base)) + base["params"]["pad"] = "x" * (MAX_BODY_SIZE - overhead) + body_str = json.dumps(base) + assert len(body_str.encode("utf-8")) == MAX_BODY_SIZE + + with ( + patch("server.http_handler._initialize_server"), + patch("server.http_handler._mcp_server") as mock_mcp_server, + ): + mock_mcp_server.handle_http_request = AsyncMock( + return_value={ + "statusCode": 200, + "headers": {}, + "body": json.dumps({"result": "ok"}), + } + ) + + status, _, _ = await handler.handle_request( + method="POST", + path="/mcp", + body=body_str, + headers={}, + ) + + assert status == 200 + mock_mcp_server.handle_http_request.assert_called_once() diff --git a/tests/test_sql_validator.py b/tests/test_sql_validator.py index 6131cb2..4dfb980 100644 --- a/tests/test_sql_validator.py +++ b/tests/test_sql_validator.py @@ -4,7 +4,9 @@ and destructive operations while allowing valid SELECT queries. 
""" -from plugins.ckan.sql_validator import SQLValidator +import pytest + +from plugins.ckan.sql_validator import SafeSQLBuilder, SQLValidator class TestValidSelectQueries: @@ -12,14 +14,14 @@ class TestValidSelectQueries: def test_simple_select_passes(self): """Test that simple SELECT query passes.""" - sql = 'SELECT * FROM "abc-123-def-456-ghi-789-012-345-678-901"' + sql = 'SELECT * FROM "11111111-2222-3333-4444-555555555555"' is_valid, error = SQLValidator.validate_query(sql) assert is_valid is True assert error is None def test_select_with_where_clause_passes(self): """Test that SELECT with WHERE clause passes.""" - sql = "SELECT * FROM \"abc-123-def-456-ghi-789-012-345-678-901\" WHERE status = 'Open'" + sql = "SELECT * FROM \"11111111-2222-3333-4444-555555555555\" WHERE status = 'Open'" is_valid, error = SQLValidator.validate_query(sql) assert is_valid is True assert error is None @@ -27,7 +29,7 @@ def test_select_with_where_clause_passes(self): def test_select_with_order_by_passes(self): """Test that SELECT with ORDER BY passes.""" sql = ( - 'SELECT * FROM "abc-123-def-456-ghi-789-012-345-678-901" ORDER BY date DESC' + 'SELECT * FROM "11111111-2222-3333-4444-555555555555" ORDER BY date DESC' ) is_valid, error = SQLValidator.validate_query(sql) assert is_valid is True @@ -35,35 +37,35 @@ def test_select_with_order_by_passes(self): def test_select_with_limit_passes(self): """Test that SELECT with LIMIT passes.""" - sql = 'SELECT * FROM "abc-123-def-456-ghi-789-012-345-678-901" LIMIT 10' + sql = 'SELECT * FROM "11111111-2222-3333-4444-555555555555" LIMIT 10' is_valid, error = SQLValidator.validate_query(sql) assert is_valid is True assert error is None def test_select_specific_columns_passes(self): """Test that SELECT with specific columns passes.""" - sql = 'SELECT id, name, status FROM "abc-123-def-456-ghi-789-012-345-678-901"' + sql = 'SELECT id, name, status FROM "11111111-2222-3333-4444-555555555555"' is_valid, error = SQLValidator.validate_query(sql) assert is_valid is True assert error is None def test_select_with_count_passes(self): """Test that SELECT with COUNT passes.""" - sql = 'SELECT COUNT(*) FROM "abc-123-def-456-ghi-789-012-345-678-901"' + sql = 'SELECT COUNT(*) FROM "11111111-2222-3333-4444-555555555555"' is_valid, error = SQLValidator.validate_query(sql) assert is_valid is True assert error is None def test_select_with_group_by_passes(self): """Test that SELECT with GROUP BY passes.""" - sql = 'SELECT status, COUNT(*) FROM "abc-123-def-456-ghi-789-012-345-678-901" GROUP BY status' + sql = 'SELECT status, COUNT(*) FROM "11111111-2222-3333-4444-555555555555" GROUP BY status' is_valid, error = SQLValidator.validate_query(sql) assert is_valid is True assert error is None def test_select_with_join_passes(self): """Test that SELECT with JOIN passes.""" - sql = 'SELECT a.* FROM "abc-123-def-456-ghi-789-012-345-678-901" a JOIN "def-456-ghi-789-012-345-678-901-234" b ON a.id = b.id' + sql = 'SELECT a.* FROM "11111111-2222-3333-4444-555555555555" a JOIN "66666666-7777-8888-9999-aaaaaaaaaaaa" b ON a.id = b.id' is_valid, error = SQLValidator.validate_query(sql) assert is_valid is True assert error is None @@ -72,7 +74,7 @@ def test_select_with_cte_passes(self): """Test that SELECT with CTE passes.""" sql = """ WITH subquery AS ( - SELECT * FROM "abc-123-def-456-ghi-789-012-345-678-901" + SELECT * FROM "11111111-2222-3333-4444-555555555555" ) SELECT * FROM subquery """ @@ -82,7 +84,7 @@ def test_select_with_cte_passes(self): def test_select_with_window_function_passes(self): 
"""Test that SELECT with window functions passes.""" - sql = 'SELECT *, RANK() OVER (PARTITION BY status ORDER BY date) FROM "abc-123-def-456-ghi-789-012-345-678-901"' + sql = 'SELECT *, RANK() OVER (PARTITION BY status ORDER BY date) FROM "11111111-2222-3333-4444-555555555555"' is_valid, error = SQLValidator.validate_query(sql) assert is_valid is True assert error is None @@ -96,7 +98,7 @@ def test_select_with_valid_uuid_format_passes(self): def test_select_case_insensitive_passes(self): """Test that SELECT in lowercase passes.""" - sql = 'select * from "abc-123-def-456-ghi-789-012-345-678-901"' + sql = 'select * from "11111111-2222-3333-4444-555555555555"' is_valid, error = SQLValidator.validate_query(sql) assert is_valid is True assert error is None @@ -108,7 +110,7 @@ class TestRejectDestructiveOperations: def test_insert_statement_rejected(self): """Test that INSERT statements are rejected.""" sql = ( - "INSERT INTO \"abc-123-def-456-ghi-789-012-345-678-901\" VALUES (1, 'test')" + "INSERT INTO \"11111111-2222-3333-4444-555555555555\" VALUES (1, 'test')" ) is_valid, error = SQLValidator.validate_query(sql) assert is_valid is False @@ -117,7 +119,7 @@ def test_insert_statement_rejected(self): def test_update_statement_rejected(self): """Test that UPDATE statements are rejected.""" - sql = "UPDATE \"abc-123-def-456-ghi-789-012-345-678-901\" SET status = 'Closed'" + sql = "UPDATE \"11111111-2222-3333-4444-555555555555\" SET status = 'Closed'" is_valid, error = SQLValidator.validate_query(sql) assert is_valid is False assert error is not None @@ -125,7 +127,7 @@ def test_update_statement_rejected(self): def test_delete_statement_rejected(self): """Test that DELETE statements are rejected.""" - sql = 'DELETE FROM "abc-123-def-456-ghi-789-012-345-678-901" WHERE id = 1' + sql = 'DELETE FROM "11111111-2222-3333-4444-555555555555" WHERE id = 1' is_valid, error = SQLValidator.validate_query(sql) assert is_valid is False assert error is not None @@ -133,7 +135,7 @@ def test_delete_statement_rejected(self): def test_drop_statement_rejected(self): """Test that DROP statements are rejected.""" - sql = 'DROP TABLE "abc-123-def-456-ghi-789-012-345-678-901"' + sql = 'DROP TABLE "11111111-2222-3333-4444-555555555555"' is_valid, error = SQLValidator.validate_query(sql) assert is_valid is False assert error is not None @@ -150,7 +152,7 @@ def test_create_statement_rejected(self): def test_alter_statement_rejected(self): """Test that ALTER statements are rejected.""" sql = ( - 'ALTER TABLE "abc-123-def-456-ghi-789-012-345-678-901" ADD COLUMN test INT' + 'ALTER TABLE "11111111-2222-3333-4444-555555555555" ADD COLUMN test INT' ) is_valid, error = SQLValidator.validate_query(sql) assert is_valid is False @@ -159,7 +161,7 @@ def test_alter_statement_rejected(self): def test_truncate_statement_rejected(self): """Test that TRUNCATE statements are rejected.""" - sql = 'TRUNCATE TABLE "abc-123-def-456-ghi-789-012-345-678-901"' + sql = 'TRUNCATE TABLE "11111111-2222-3333-4444-555555555555"' is_valid, error = SQLValidator.validate_query(sql) assert is_valid is False assert error is not None @@ -167,7 +169,7 @@ def test_truncate_statement_rejected(self): def test_grant_statement_rejected(self): """Test that GRANT statements are rejected.""" - sql = 'GRANT SELECT ON "abc-123-def-456-ghi-789-012-345-678-901" TO user' + sql = 'GRANT SELECT ON "11111111-2222-3333-4444-555555555555" TO user' is_valid, error = SQLValidator.validate_query(sql) assert is_valid is False assert error is not None @@ -175,7 +177,7 @@ def 
test_grant_statement_rejected(self):
 
     def test_revoke_statement_rejected(self):
         """Test that REVOKE statements are rejected."""
-        sql = 'REVOKE SELECT ON "abc-123-def-456-ghi-789-012-345-678-901" FROM user'
+        sql = 'REVOKE SELECT ON "11111111-2222-3333-4444-555555555555" FROM user'
         is_valid, error = SQLValidator.validate_query(sql)
         assert is_valid is False
         assert error is not None
@@ -188,7 +190,7 @@ class TestRejectSQLInjection:
 
     def test_multiple_statements_rejected(self):
         """Test that multiple statements are rejected."""
         sql = (
-            'SELECT * FROM "abc-123-def-456-ghi-789-012-345-678-901"; DROP TABLE users;'
+            'SELECT * FROM "11111111-2222-3333-4444-555555555555"; DROP TABLE users;'
         )
         is_valid, error = SQLValidator.validate_query(sql)
         assert is_valid is False
@@ -197,7 +199,7 @@
 
     def test_multiple_select_statements_rejected(self):
         """Test that multiple SELECT statements are rejected."""
-        sql = 'SELECT * FROM "abc-123-def-456-ghi-789-012-345-678-901"; SELECT * FROM "def-456-ghi-789-012-345-678-901-234"'
+        sql = 'SELECT * FROM "11111111-2222-3333-4444-555555555555"; SELECT * FROM "66666666-7777-8888-9999-aaaaaaaaaaaa"'
         is_valid, error = SQLValidator.validate_query(sql)
         assert is_valid is False
         assert error is not None
@@ -205,9 +207,10 @@
 
-    def test_dangerous_comment_rejected(self):
-        """Test that dangerous comments are rejected."""
-        sql = 'SELECT * FROM "abc-123-def-456-ghi-789-012-345-678-901" -- DROP TABLE users'
-        is_valid, error = SQLValidator.validate_query(sql)
-        # This might pass if comment handling is lenient, but should ideally fail
-        # The validator should catch DROP in comments
-        assert is_valid is False
-        assert error is not None
+    def test_comment_only_drop_is_inert(self):
+        """A DROP that appears only inside a trailing comment is stripped, not fatal."""
+        sql = 'SELECT * FROM "11111111-2222-3333-4444-555555555555" -- DROP TABLE users'
+        is_valid, error = SQLValidator.validate_query(sql)
+        # Comments are stripped before the keyword scan (see
+        # TestCommentStrippingBeforeKeywordScan), so a DROP that can never
+        # execute no longer causes a false rejection.
+        assert is_valid is True
+        assert error is None
@@ -214,7 +216,7 @@
 
     def test_command_execution_pattern_rejected(self):
         """Test that command execution patterns are rejected."""
-        sql = "SELECT * FROM \"abc-123-def-456-ghi-789-012-345-678-901\" WHERE xp_cmdshell('dir')"
+        sql = "SELECT * FROM \"11111111-2222-3333-4444-555555555555\" WHERE xp_cmdshell('dir')"
         is_valid, error = SQLValidator.validate_query(sql)
         assert is_valid is False
         assert error is not None
@@ -222,7 +224,7 @@
 
     def test_file_write_pattern_rejected(self):
         """Test that file write patterns are rejected."""
-        sql = 'SELECT * INTO OUTFILE "/tmp/test" FROM "abc-123-def-456-ghi-789-012-345-678-901"'
+        sql = 'SELECT * INTO OUTFILE "/tmp/test" FROM "11111111-2222-3333-4444-555555555555"'
         is_valid, error = SQLValidator.validate_query(sql)
         assert is_valid is False
         assert error is not None
@@ -231,7 +233,7 @@
 
     def test_sleep_function_rejected(self):
         """Test that sleep functions are rejected."""
         sql = (
-            'SELECT * FROM "abc-123-def-456-ghi-789-012-345-678-901" WHERE pg_sleep(10)'
+            'SELECT * FROM "11111111-2222-3333-4444-555555555555" WHERE pg_sleep(10)'
        )
         is_valid, error = SQLValidator.validate_query(sql)
         assert is_valid is False
@@ -241,7 +243,7 @@
 
     def test_union_based_injection_detected(self):
         """Test that UNION-based injection attempts are detected."""
         # This should fail because it's not a valid SELECT structure
-        sql = 'SELECT * FROM "abc-123-def-456-ghi-789-012-345-678-901" UNION SELECT * FROM users'
+        sql = 'SELECT * FROM "11111111-2222-3333-4444-555555555555" UNION SELECT * FROM users'
         is_valid, error = SQLValidator.validate_query(sql)
         # UNION might be
valid in some contexts, but should be checked # The validator should parse and validate the structure @@ -278,7 +280,7 @@ def test_whitespace_only_rejected(self): def test_too_long_query_rejected(self): """Test that queries exceeding max length are rejected.""" # Create a query that exceeds MAX_SQL_LENGTH - base_query = 'SELECT * FROM "abc-123-def-456-ghi-789-012-345-678-901" WHERE ' + base_query = 'SELECT * FROM "11111111-2222-3333-4444-555555555555" WHERE ' padding = "x" * (SQLValidator.MAX_SQL_LENGTH + 100) sql = base_query + padding is_valid, error = SQLValidator.validate_query(sql) @@ -317,7 +319,7 @@ def test_malformed_uuid_rejected(self): def test_uuid_without_quotes_passes_if_no_uuid_check(self): """Test that UUID without quotes might pass (depends on validator).""" - sql = "SELECT * FROM abc-123-def-456-ghi-789-012-345-678-901" + sql = "SELECT * FROM 11111111-2222-3333-4444-555555555555" is_valid, error = SQLValidator.validate_query(sql) # Without quotes, UUID pattern won't match, so won't be validated # But should still pass SELECT validation @@ -330,7 +332,7 @@ class TestRejectForbiddenKeywords: def test_execute_keyword_rejected(self): """Test that EXECUTE keyword is rejected.""" sql = ( - 'SELECT * FROM "abc-123-def-456-ghi-789-012-345-678-901" WHERE EXECUTE test' + 'SELECT * FROM "11111111-2222-3333-4444-555555555555" WHERE EXECUTE test' ) is_valid, error = SQLValidator.validate_query(sql) assert is_valid is False @@ -339,7 +341,7 @@ def test_execute_keyword_rejected(self): def test_exec_keyword_rejected(self): """Test that EXEC keyword is rejected.""" - sql = 'SELECT * FROM "abc-123-def-456-ghi-789-012-345-678-901" WHERE EXEC test' + sql = 'SELECT * FROM "11111111-2222-3333-4444-555555555555" WHERE EXEC test' is_valid, error = SQLValidator.validate_query(sql) assert is_valid is False assert error is not None @@ -347,7 +349,7 @@ def test_exec_keyword_rejected(self): def test_call_keyword_rejected(self): """Test that CALL keyword is rejected.""" - sql = 'SELECT * FROM "abc-123-def-456-ghi-789-012-345-678-901" WHERE CALL test' + sql = 'SELECT * FROM "11111111-2222-3333-4444-555555555555" WHERE CALL test' is_valid, error = SQLValidator.validate_query(sql) assert is_valid is False assert error is not None @@ -356,7 +358,7 @@ def test_call_keyword_rejected(self): def test_declare_keyword_rejected(self): """Test that DECLARE keyword is rejected.""" sql = ( - 'SELECT * FROM "abc-123-def-456-ghi-789-012-345-678-901" WHERE DECLARE @var' + 'SELECT * FROM "11111111-2222-3333-4444-555555555555" WHERE DECLARE @var' ) is_valid, error = SQLValidator.validate_query(sql) assert is_valid is False @@ -366,7 +368,7 @@ def test_declare_keyword_rejected(self): def test_set_keyword_in_where_might_pass(self): """Test that SET keyword in WHERE clause might pass (context-dependent).""" sql = ( - 'SELECT * FROM "abc-123-def-456-ghi-789-012-345-678-901" WHERE status = SET' + 'SELECT * FROM "11111111-2222-3333-4444-555555555555" WHERE status = SET' ) is_valid, error = SQLValidator.validate_query(sql) # SET as a value might pass, but SET as keyword should be caught @@ -379,28 +381,28 @@ class TestEdgeCases: def test_select_with_nested_subquery_passes(self): """Test that SELECT with nested subquery passes.""" - sql = 'SELECT * FROM (SELECT * FROM "abc-123-def-456-ghi-789-012-345-678-901") sub' + sql = 'SELECT * FROM (SELECT * FROM "11111111-2222-3333-4444-555555555555") sub' is_valid, error = SQLValidator.validate_query(sql) assert is_valid is True assert error is None def 
test_select_with_having_clause_passes(self): """Test that SELECT with HAVING clause passes.""" - sql = 'SELECT status, COUNT(*) FROM "abc-123-def-456-ghi-789-012-345-678-901" GROUP BY status HAVING COUNT(*) > 10' + sql = 'SELECT status, COUNT(*) FROM "11111111-2222-3333-4444-555555555555" GROUP BY status HAVING COUNT(*) > 10' is_valid, error = SQLValidator.validate_query(sql) assert is_valid is True assert error is None def test_select_with_distinct_passes(self): """Test that SELECT DISTINCT passes.""" - sql = 'SELECT DISTINCT status FROM "abc-123-def-456-ghi-789-012-345-678-901"' + sql = 'SELECT DISTINCT status FROM "11111111-2222-3333-4444-555555555555"' is_valid, error = SQLValidator.validate_query(sql) assert is_valid is True assert error is None def test_select_with_aggregate_functions_passes(self): """Test that SELECT with aggregate functions passes.""" - sql = 'SELECT AVG(value), MAX(value), MIN(value) FROM "abc-123-def-456-ghi-789-012-345-678-901"' + sql = 'SELECT AVG(value), MAX(value), MIN(value) FROM "11111111-2222-3333-4444-555555555555"' is_valid, error = SQLValidator.validate_query(sql) assert is_valid is True assert error is None @@ -408,7 +410,7 @@ def test_select_with_aggregate_functions_passes(self): def test_select_exactly_max_length_passes(self): """Test that query exactly at max length passes.""" # Create query exactly at MAX_SQL_LENGTH - base_query = 'SELECT * FROM "abc-123-def-456-ghi-789-012-345-678-901"' + base_query = 'SELECT * FROM "11111111-2222-3333-4444-555555555555"' padding_length = SQLValidator.MAX_SQL_LENGTH - len(base_query) if padding_length > 0: sql = base_query + " " + "x" * (padding_length - 1) @@ -418,14 +420,339 @@ def test_select_exactly_max_length_passes(self): def test_select_with_special_characters_passes(self): """Test that SELECT with special characters passes.""" - sql = "SELECT * FROM \"abc-123-def-456-ghi-789-012-345-678-901\" WHERE name = 'O'Brien'" + sql = "SELECT * FROM \"11111111-2222-3333-4444-555555555555\" WHERE name = 'O'Brien'" is_valid, error = SQLValidator.validate_query(sql) assert is_valid is True assert error is None def test_select_with_regex_patterns_passes(self): """Test that SELECT with regex patterns passes.""" - sql = "SELECT * FROM \"abc-123-def-456-ghi-789-012-345-678-901\" WHERE name ~ '^[A-Z]'" + sql = "SELECT * FROM \"11111111-2222-3333-4444-555555555555\" WHERE name ~ '^[A-Z]'" is_valid, error = SQLValidator.validate_query(sql) assert is_valid is True assert error is None + + +class TestFromJoinTargetEnforcement: + """FROM/JOIN targets must be UUID-quoted resources or CTE aliases.""" + + def test_schema_qualified_target_rejected(self): + sql = "SELECT * FROM pg_catalog.pg_user" + is_valid, error = SQLValidator.validate_query(sql) + assert is_valid is False + assert "Schema-qualified" in error + + def test_information_schema_rejected(self): + sql = "SELECT * FROM information_schema.columns" + is_valid, error = SQLValidator.validate_query(sql) + assert is_valid is False + assert "Schema-qualified" in error + + def test_union_to_unknown_table_rejected(self): + sql = ( + 'SELECT * FROM "11111111-2222-3333-4444-555555555555" ' + "UNION SELECT * FROM users" + ) + is_valid, error = SQLValidator.validate_query(sql) + assert is_valid is False + assert "users" in error + + def test_quoted_non_uuid_rejected(self): + sql = 'SELECT * FROM "not-a-uuid-at-all-really"' + is_valid, error = SQLValidator.validate_query(sql) + assert is_valid is False + assert "UUID" in error or "Invalid" in error + + def 
test_cte_alias_accepted(self): + sql = ( + 'WITH sub AS (SELECT * FROM "11111111-2222-3333-4444-555555555555") ' + "SELECT * FROM sub" + ) + is_valid, error = SQLValidator.validate_query(sql) + assert is_valid is True, error + + def test_subquery_with_alias_accepted(self): + sql = ( + 'SELECT * FROM (SELECT * FROM "11111111-2222-3333-4444-555555555555")' + " sub" + ) + is_valid, error = SQLValidator.validate_query(sql) + assert is_valid is True, error + + def test_join_with_one_unknown_target_rejected(self): + sql = ( + 'SELECT * FROM "11111111-2222-3333-4444-555555555555" a ' + 'JOIN "pg_user" b ON a.id = b.id' + ) + is_valid, error = SQLValidator.validate_query(sql) + assert is_valid is False + + def test_bare_select_without_from_rejected(self): + sql = "SELECT 1" + is_valid, error = SQLValidator.validate_query(sql) + assert is_valid is False + assert "FROM" in error + + +class TestCommentStrippingBeforeKeywordScan: + """Forbidden keywords hidden in comments must not slip past the scanner.""" + + def test_block_comment_hiding_select_prefix_rejected(self): + sql = '/*SELECT*/ DELETE FROM "11111111-2222-3333-4444-555555555555"' + is_valid, error = SQLValidator.validate_query(sql) + assert is_valid is False + assert "DELETE" in error + + def test_line_comment_hiding_delete_rejected(self): + sql = ( + '-- comment\nDELETE FROM "11111111-2222-3333-4444-555555555555"' + ) + is_valid, error = SQLValidator.validate_query(sql) + assert is_valid is False + assert "DELETE" in error + + def test_benign_block_comment_accepted(self): + sql = 'SELECT * /* hello */ FROM "11111111-2222-3333-4444-555555555555"' + is_valid, error = SQLValidator.validate_query(sql) + assert is_valid is True, error + + +class TestForbiddenFunctions: + """Postgres functions useful for data exfiltration are blocked by name.""" + + @pytest.mark.parametrize( + "fn", + [ + "pg_read_file", + "pg_ls_dir", + "pg_stat_file", + "lo_import", + "lo_export", + "current_setting", + "set_config", + "dblink", + ], + ) + def test_forbidden_function_rejected(self, fn): + sql = ( + f"SELECT {fn}('x') FROM " + '"11111111-2222-3333-4444-555555555555"' + ) + is_valid, error = SQLValidator.validate_query(sql) + assert is_valid is False + assert fn in error + + +class TestSafeSQLBuilderIdentifier: + def test_valid_identifier_quoted(self): + assert SafeSQLBuilder.quote_identifier("neighborhood") == '"neighborhood"' + + def test_underscore_and_digits(self): + assert SafeSQLBuilder.quote_identifier("col_1") == '"col_1"' + + @pytest.mark.parametrize( + "bad", + [ + "col; DROP TABLE x", + "col)", + "1col", + "col space", + "col.other", + "col--", + "", + None, + 42, + ], + ) + def test_bad_identifier_rejected(self, bad): + with pytest.raises(ValueError): + SafeSQLBuilder.quote_identifier(bad) + + +class TestSafeSQLBuilderMetric: + def test_count_star(self): + assert SafeSQLBuilder.validate_metric_expr("count(*)") == "count(*)" + + def test_count_star_whitespace(self): + assert ( + SafeSQLBuilder.validate_metric_expr(" COUNT ( * ) ") + == "count(*)" + ) + + @pytest.mark.parametrize( + "expr,expected", + [ + ("sum(amount)", 'sum("amount")'), + ("avg(value)", 'avg("value")'), + ("min(x)", 'min("x")'), + ("max(x)", 'max("x")'), + ("stddev(y)", 'stddev("y")'), + ("count(distinct user_id)", 'count(DISTINCT "user_id")'), + ], + ) + def test_aggregate_quotes_identifier(self, expr, expected): + assert SafeSQLBuilder.validate_metric_expr(expr) == expected + + @pytest.mark.parametrize( + "bad", + [ + "pg_sleep(10)", + "count(*)); DROP TABLE x--", + "sum(x + y)", + 
"sum(x); select 1", + "count(*) + 1", + "concat(a, b)", + "sum(x.y)", + "", + None, + ], + ) + def test_bad_metric_rejected(self, bad): + with pytest.raises(ValueError): + SafeSQLBuilder.validate_metric_expr(bad) + + +class TestSafeSQLBuilderFilter: + def test_integer_value(self): + assert SafeSQLBuilder.build_filter_condition("id", 42) == '"id" = 42' + + def test_float_value(self): + assert ( + SafeSQLBuilder.build_filter_condition("lat", 42.5) + == '"lat" = 42.5' + ) + + def test_none_value(self): + assert ( + SafeSQLBuilder.build_filter_condition("status", None) + == '"status" IS NULL' + ) + + def test_bool_true(self): + assert ( + SafeSQLBuilder.build_filter_condition("active", True) + == '"active" = TRUE' + ) + + def test_string_value_escaped(self): + assert ( + SafeSQLBuilder.build_filter_condition("name", "O'Brien") + == "\"name\" = 'O''Brien'" + ) + + def test_string_injection_escaped_not_executed(self): + got = SafeSQLBuilder.build_filter_condition("name", "x' OR 1=1--") + assert got == "\"name\" = 'x'' OR 1=1--'" + + def test_bad_field_rejected(self): + with pytest.raises(ValueError): + SafeSQLBuilder.build_filter_condition("name; DROP TABLE x", "ok") + + def test_unsupported_value_type_rejected(self): + with pytest.raises(ValueError): + SafeSQLBuilder.build_filter_condition("name", ["list"]) + + +class TestSafeSQLBuilderOrderAndLimit: + def test_order_by_plain(self): + assert SafeSQLBuilder.validate_order_by("date") == '"date"' + + def test_order_by_desc(self): + assert SafeSQLBuilder.validate_order_by("date DESC") == '"date" DESC' + + def test_order_by_asc_lower(self): + assert SafeSQLBuilder.validate_order_by("date asc") == '"date" ASC' + + @pytest.mark.parametrize( + "bad", ["date; DROP", "date, other", "1", "a.b", "", None] + ) + def test_bad_order_by_rejected(self, bad): + with pytest.raises(ValueError): + SafeSQLBuilder.validate_order_by(bad) + + def test_limit_clamped_to_max(self): + assert SafeSQLBuilder.clamp_limit(10**9) == SafeSQLBuilder.MAX_LIMIT + + def test_limit_passthrough(self): + assert SafeSQLBuilder.clamp_limit(50) == 50 + + @pytest.mark.parametrize("bad", [0, -1, "10", None, True, 1.5]) + def test_bad_limit_rejected(self, bad): + with pytest.raises(ValueError): + SafeSQLBuilder.clamp_limit(bad) + + +class TestSafeSQLBuilderResourceId: + def test_valid_uuid(self): + uuid = "11111111-2222-3333-4444-555555555555" + assert SafeSQLBuilder.validate_resource_id(uuid) == uuid + + @pytest.mark.parametrize( + "bad", + [ + "pg_catalog.pg_user", + "not-a-uuid", + "11111111-2222-3333-4444-55555555555", # too short + "", + None, + 123, + ], + ) + def test_bad_resource_id_rejected(self, bad): + with pytest.raises(ValueError): + SafeSQLBuilder.validate_resource_id(bad) + + +class TestEnforceRowLimit: + """``SQLValidator.enforce_row_limit`` appends a LIMIT if absent.""" + + UUID = "11111111-2222-3333-4444-555555555555" + + def test_appends_limit_when_missing(self): + sql = f'SELECT * FROM "{self.UUID}"' + out = SQLValidator.enforce_row_limit(sql) + assert out.endswith(f"LIMIT {SQLValidator.DEFAULT_ROW_LIMIT}") + + def test_preserves_existing_top_level_limit(self): + sql = f'SELECT * FROM "{self.UUID}" LIMIT 5' + out = SQLValidator.enforce_row_limit(sql) + assert out == sql + + def test_preserves_limit_case_insensitive(self): + sql = f'SELECT * FROM "{self.UUID}" limit 5' + out = SQLValidator.enforce_row_limit(sql) + assert out == sql + + def test_subquery_limit_does_not_count_as_top_level(self): + sql = ( + f'SELECT * FROM (SELECT * FROM "{self.UUID}" LIMIT 5) sub' + ) + out 
= SQLValidator.enforce_row_limit(sql) + assert out.endswith(f"LIMIT {SQLValidator.DEFAULT_ROW_LIMIT}") + + def test_cte_without_top_level_limit_gets_limit_appended(self): + sql = ( + f'WITH t AS (SELECT neighborhood FROM "{self.UUID}") ' + "SELECT * FROM t" + ) + out = SQLValidator.enforce_row_limit(sql) + assert out.endswith(f"LIMIT {SQLValidator.DEFAULT_ROW_LIMIT}") + + def test_cte_with_top_level_limit_preserved(self): + sql = ( + f'WITH t AS (SELECT neighborhood FROM "{self.UUID}") ' + "SELECT * FROM t LIMIT 3" + ) + out = SQLValidator.enforce_row_limit(sql) + assert out == sql + + def test_trailing_semicolon_stripped_before_append(self): + sql = f'SELECT * FROM "{self.UUID}";' + out = SQLValidator.enforce_row_limit(sql) + assert ";" not in out.split("LIMIT")[0] + assert out.endswith(f"LIMIT {SQLValidator.DEFAULT_ROW_LIMIT}") + + def test_trailing_whitespace_handled(self): + sql = f'SELECT * FROM "{self.UUID}" \n ' + out = SQLValidator.enforce_row_limit(sql) + assert out.endswith(f"LIMIT {SQLValidator.DEFAULT_ROW_LIMIT}") From 40702642ea4bd27f15b6383fccbb2ae83f8763a9 Mon Sep 17 00:00:00 2001 From: brendanbabb Date: Thu, 16 Apr 2026 13:03:49 -0800 Subject: [PATCH 03/12] Docs: AWS deployment + security guides; add Python stdio bridge MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Document the Boston fork's AWS hosting and security posture, with portal protection as the top design constraint. Add stdio_bridge.py as a Python alternative to the Go stdio client for Claude Desktop. - docs/AWS_DEPLOYMENT.md: how this fork is hosted (us-west-2, custom domain, reserved concurrency, cp311 packaging), what changed vs. upstream's single-region default, and how to operate the stack. Leads with the portal-protection design constraint. - docs/SECURITY.md: the full rationale behind the hardening changes, organized around who is being protected — upstream portal first, deployment second, end users third. Covers the SQL validator and SafeSQLBuilder surface, rate limits and body-size cap, privacy posture (stateless, no PII, 14-day log retention, SQL truncated to 500 chars), and known gaps. - README.md: link both new docs from the documentation table. - stdio_bridge.py: minimal Python stdio-to-HTTP bridge. Reads JSON-RPC messages from stdin, POSTs them to the local/remote MCP server, writes responses to stdout. Useful where the Go client is impractical (Windows, no Go toolchain). - CLAUDE.md: repo guidance for Claude Code sessions — commands, request flow, architecture notes, single-plugin rule. Co-Authored-By: Claude Opus 4.7 --- CLAUDE.md | 72 ++++++++++++++++ README.md | 2 + docs/AWS_DEPLOYMENT.md | 185 +++++++++++++++++++++++++++++++++++++++++ docs/SECURITY.md | 143 +++++++++++++++++++++++++++++++ stdio_bridge.py | 87 +++++++++++++++++++ 5 files changed, 489 insertions(+) create mode 100644 CLAUDE.md create mode 100644 docs/AWS_DEPLOYMENT.md create mode 100644 docs/SECURITY.md create mode 100644 stdio_bridge.py diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..5b88205 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,72 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. 
+ +## Build & Development Commands + +```bash +# Install dependencies (uv preferred, pip fallback) +uv sync # or: pip install -r requirements.txt + +# Run local MCP server (no Lambda needed) +python3 scripts/local_server.py # Serves on http://localhost:8000/mcp +# Or: python3 local_server.py # Alternate entry point, serves on / and /mcp + +# Validate config +python3 -c "from core.validators import load_and_validate_config; load_and_validate_config('config.yaml')" + +# Tests +uv run pytest tests/ -n auto # All tests, parallel +uv run pytest tests/test_ckan_plugin.py -v # Single file +uv run pytest tests/test_ckan_plugin.py::TestClass::test_name -v # Single test +uv run pytest tests/ --cov=core --cov=plugins --cov-report=term-missing # With coverage (80% minimum) + +# Linting (ruff) +uv run ruff check core/ plugins/ server/ tests/ # Check +uv run ruff check core/ plugins/ server/ tests/ --fix # Auto-fix +uv run ruff format core/ plugins/ server/ tests/ # Format + +# Pre-commit hooks +pre-commit run --all-files + +# Go client (requires Go 1.21+) +cd client && make build + +# Deploy to AWS +./scripts/deploy.sh --environment staging +``` + +## Architecture + +**Core rule: One Fork = One MCP Server.** Each deployment runs exactly ONE plugin. This is enforced at config validation time (`core/validators.py`) and at runtime (`PluginManager.load_plugins()`). To deploy multiple MCP servers, fork the repo per plugin. + +**Request flow:** +``` +Claude (stdio) → Go client (client/) or stdio_bridge.py → HTTP POST /mcp + → Lambda (server/adapters/aws_lambda.py) or local_server.py + → server/http_handler.py → core/mcp_server.py (JSON-RPC 2.0) + → core/plugin_manager.py → Plugin → External API +``` + +**Key modules:** +- `core/interfaces.py` — Abstract bases: `MCPPlugin`, `DataPlugin`, plus `ToolDefinition`, `ToolResult`, `PluginType` enum +- `core/plugin_manager.py` — Discovers plugins by scanning `plugins/` and `custom_plugins/` for `plugin.py` files. Registers tools with `pluginname__toolname` prefix. Routes `tools/call` to the correct plugin. +- `core/mcp_server.py` — Handles MCP JSON-RPC methods: `initialize`, `tools/list`, `tools/call`, `ping` +- `core/validators.py` — Loads config from `config.yaml` (local) or `OPENCONTEXT_CONFIG` env var (Lambda). Enforces single-plugin rule. +- `server/adapters/aws_lambda.py` — AWS Lambda entry point (handler: `server.adapters.aws_lambda.lambda_handler`). Also `server/lambda_handler.py` as legacy entry point. +- `server/http_handler.py` — Cloud-agnostic HTTP handler shared by Lambda and local server +- `stdio_bridge.py` — Python stdio-to-HTTP bridge for connecting Claude Desktop/Code to the local server (alternative to Go client) + +**Built-in plugins** (`plugins/`): `ckan`, `arcgis`, `socrata` — each implements `DataPlugin` with `search_datasets`, `get_dataset`, `query_data`. Custom plugins go in `custom_plugins/` and are auto-discovered. + +## Plugin Development + +New plugins must implement `MCPPlugin` (or `DataPlugin` for data sources). Place in `custom_plugins/<plugin_name>/plugin.py`. The class must define `plugin_name`, `plugin_type`, `plugin_version` and implement `initialize()`, `shutdown()`, `get_tools()`, `execute_tool()`, `health_check()`. Tool names are auto-prefixed — return bare names from `get_tools()`. + +## Configuration + +Copy `config-example.yaml` to `config.yaml`. Enable exactly one plugin. Config supports `${ENV_VAR}` substitution. For Lambda, config is serialized to the `OPENCONTEXT_CONFIG` env var by Terraform.
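+
+`${ENV_VAR}` substitution of this kind is conceptually a token replace over the raw config text before YAML parsing. Below is a minimal sketch of the idea; the authoritative rules (including what happens when a variable is unset) live in `core/validators.py`, and the helper name here is hypothetical:
+
+```python
+import os
+import re
+
+# Matches ${VAR} tokens; assumes bare names with no default-value syntax.
+_ENV_TOKEN = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
+
+
+def substitute_env(raw_config_text: str) -> str:
+    """Hypothetical sketch: expand ${ENV_VAR} tokens before YAML parsing."""
+
+    def _lookup(match: re.Match) -> str:
+        value = os.environ.get(match.group(1))
+        if value is None:
+            # Raising here is an assumption, not the documented behavior.
+            raise ValueError(f"config references unset env var {match.group(1)!r}")
+        return value
+
+    return _ENV_TOKEN.sub(_lookup, raw_config_text)
+```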
+ +## CI + +GitHub Actions (`.github/workflows/ci.yml`) runs ruff lint/format, pip-audit, pytest with coverage, and Go tests on push to main/develop and on PRs. diff --git a/README.md b/README.md index 7b7e121..de84a69 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,8 @@ See [Getting Started](docs/GETTING_STARTED.md) for full setup. | [Getting Started](docs/GETTING_STARTED.md) | Setup and usage | | [Architecture](docs/ARCHITECTURE.md) | System design and plugins | | [Deployment](docs/DEPLOYMENT.md) | AWS, Terraform, monitoring | +| [AWS Deployment (Boston)](docs/AWS_DEPLOYMENT.md) | Boston fork: region, concurrency, domain, packaging | +| [Security](docs/SECURITY.md) | SQL hardening, rate limits, upstream-portal protection | | [Testing](docs/TESTING.md) | Local testing (Terminal, Claude, MCP Inspector) | diff --git a/docs/AWS_DEPLOYMENT.md b/docs/AWS_DEPLOYMENT.md new file mode 100644 index 0000000..2c66f4e --- /dev/null +++ b/docs/AWS_DEPLOYMENT.md @@ -0,0 +1,185 @@ +# AWS Deployment (Boston fork) + +This document describes how the Boston-specific deployment of OpenContext is hosted on AWS, what changed relative to the upstream defaults, and how to operate the stack. It complements [DEPLOYMENT.md](DEPLOYMENT.md), which covers the upstream single-Lambda/API-Gateway architecture. + +- **Public endpoint (prod):** `https://boston-data.codeforanchorage.org` +- **Upstream data source:** Boston CKAN portal at `https://data.boston.gov/` +- **Runtime:** AWS Lambda (Python 3.11) behind API Gateway, us-west-2 + +> **Design constraint:** this fork's top operational priority is **not overwhelming `data.boston.gov`**. It is a shared civic resource, not our infrastructure. Every defensive control below — reserved Lambda concurrency, API Gateway rate limits and daily quota, enforced `LIMIT` on SQL, clamped aggregation limits, body-size caps — exists to keep this MCP server from becoming the noisiest client on that portal. See [SECURITY.md §1](SECURITY.md#1-protecting-the-upstream-data-portal) for the full rationale. + +--- + +## 1. What changed in this fork + +The upstream deployment assumes a single-region (us-east-1) Lambda with a standard rate-limited API Gateway in front of it. This fork makes the following operational changes: + +### 1.1 Region moved to us-west-2 + +Terraform variables and the deploy script default to `us-west-2`: + +- `terraform/aws/prod.tfvars`, `terraform/aws/staging.tfvars`: `aws_region = "us-west-2"` +- `config.yaml`: `aws.region: "us-west-2"` + +The move is for co-location with other Code for Anchorage infrastructure and has no functional effect on the Lambda. Cost numbers in [DEPLOYMENT.md](DEPLOYMENT.md#cost-us-east-1) still apply; us-west-2 pricing is effectively identical for Lambda and API Gateway. + +### 1.2 Terraform backend extracted and renamed + +The upstream `main.tf` hard-coded an `opencontext-terraform-state` bucket in us-east-1. This fork moves the backend into its own file so the bootstrap account+region+bucket are explicit, and renames the bucket to the convention used by `scripts/setup-backend.sh`: + +`terraform/aws/backend.tf` (new file): + +```hcl +terraform { + backend "s3" { + bucket = "boston-opencontext-tfstate-<ACCOUNT_ID>-us-west-2" + key = "terraform.tfstate" + region = "us-west-2" + dynamodb_table = "terraform-state-lock" + encrypt = true + } +} +``` + +The actual `backend.tf` is gitignored because the bucket name embeds the deployer's AWS account ID (the Code for Anchorage account here) — Terraform cannot interpolate variables into a backend block, so the literal value has to live in the file.
A DynamoDB table (`terraform-state-lock`) is used for state locking — forked deployments should run `scripts/setup-backend.sh` to create both the bucket and the lock table, update the account ID in `backend.tf`, then `terraform init` against `terraform/aws/`. + +### 1.3 Reserved Lambda concurrency + +A new `lambda_reserved_concurrency` variable caps the number of concurrent Lambda invocations. Default is **10**, set in both staging and prod `.tfvars`. + +```hcl +# terraform/aws/variables.tf +variable "lambda_reserved_concurrency" { + default = 10 +} +``` + +This serves two purposes. The first is cost containment: a surprise traffic spike can't run up the bill. The second, more important one, is **protecting the upstream open-data portal**. Boston's CKAN portal is a shared civic resource; if a misbehaving client fans out into thousands of parallel SQL queries, reserved concurrency bounds how much of that load we can relay. See [SECURITY.md §1](SECURITY.md#1-protecting-the-upstream-data-portal) for the full threat model. + +Set to `-1` to disable the cap (fall back to the account-wide concurrency limit). Don't do this in prod without a reason. + +### 1.4 API Gateway quota raised, rate limits unchanged + +``` +api_quota_limit = 3000 # was 1000 upstream +api_rate_limit = 5 # unchanged — sustained req/s +api_burst_limit = 10 # unchanged — burst req/s +``` + +The daily quota was raised to 3000 after staging traffic showed legitimate per-connector usage (tool discovery + a handful of queries per conversation) could brush against 1000/day for a single user. The per-second rate is kept low deliberately — see [SECURITY.md §3](SECURITY.md#3-rate-limiting-and-body-size-this-deployment). + +### 1.5 Custom domain + +Prod now fronts the API Gateway with an ACM cert and the custom domain `boston-data.codeforanchorage.org`. Staging has no custom domain (`custom_domain = ""`) — use the raw API Gateway URL from `terraform output`. + +### 1.6 Cross-platform, 3.11-pinned packaging + +Both `scripts/deploy.sh` and `.github/workflows/release.yml` were updated so the Lambda ZIP matches the runtime regardless of the build host. + +- Detects `python3` or falls back to `python` (Windows build hosts). +- Forces cp311 manylinux wheels on every dependency install: + ```bash + pip install -r requirements.txt -t ./package \ + --platform manylinux2014_x86_64 \ + --python-version 3.11 \ + --implementation cp \ + --abi cp311 \ + --only-binary :all: \ + --no-compile + ``` + Without the pin, a build host running Python 3.14 will pull cp314 wheels that fail to import at Lambda cold start with a 502 `InternalServerErrorException`. +- Builds the ZIP with Python's stdlib `zipfile` module instead of the `zip` binary, which isn't present on every runner (notably the staging CI image and Windows). + +### 1.7 `local_server.py` serves both `/` and `/mcp` + +The Claude Desktop stdio bridge posts to `/mcp`; some earlier testing tools post to `/`. The local dev server now accepts both so you can point Claude Desktop and MCP Inspector at the same endpoint without editing routes. + +### 1.8 Concrete Boston CKAN `config.yaml` + +Upstream `config.yaml` is a symlink to the DC ArcGIS example. This fork replaces it with a concrete CKAN config targeting `data.boston.gov`. ArcGIS is kept `enabled: false` in the file for reference (Boston's ArcGIS hub at `data-boston.hub.arcgis.com` returns 401 without auth; CKAN is the public entry point).
+ +```yaml +plugins: + ckan: + enabled: true + base_url: "https://data.boston.gov/" + portal_url: "https://data.boston.gov/" + city_name: "Boston" + timeout: 120 + arcgis: + enabled: false +``` + +--- + +## 2. Operator reference + +### 2.1 First-time bootstrap + +```bash +# 1. Create the state bucket + lock table (once per account/region) +export AWS_REGION=us-west-2 +./scripts/setup-backend.sh + +# 2. Initialize Terraform against the S3 backend +cd terraform/aws +terraform init +``` + +### 2.2 Deploying changes + +The deploy script validates `config.yaml`, builds a cp311/manylinux Lambda ZIP, and runs `terraform apply`: + +```bash +# Staging +./scripts/deploy.sh --environment staging + +# Prod +./scripts/deploy.sh --environment prod +``` + +Under the hood: + +1. Counts enabled plugins (must be exactly one — enforced by `core/validators.py`). +2. Builds `lambda-deployment.zip` with dependencies forced to cp311 manylinux wheels. +3. `terraform apply -var-file=<env>.tfvars` against `terraform/aws/`. + +### 2.3 Environment configuration + +| Variable | Staging | Prod | +| ------------------------------- | ---------------------------- | ------------------------------------------ | +| `lambda_name` | `boston-ckan-mcp-staging` | `boston-opencontext-mcp-prod` | +| `aws_region` | `us-west-2` | `us-west-2` | +| `lambda_memory` | 512 MB | 512 MB | +| `lambda_timeout` | 120 s | 120 s | +| `lambda_reserved_concurrency` | 10 | 10 | +| `api_quota_limit` | 3000 / day | 3000 / day | +| `api_rate_limit` / `burst` | 5 / 10 req/s | 5 / 10 req/s | +| `custom_domain` | *(none)* | `boston-data.codeforanchorage.org` | + +### 2.4 Getting the endpoint URL + +```bash +cd terraform/aws +terraform output -raw api_gateway_url # Custom domain on prod, exec-api URL on staging +``` + +### 2.5 Monitoring + +CloudWatch log group `/aws/lambda/<lambda_name>`, 14-day retention. Logs are JSON-structured (`logging.format: json` in `config.yaml`) and include a `request_id` field you can join against API Gateway access logs. + +```bash +aws logs tail /aws/lambda/boston-opencontext-mcp-prod --follow --region us-west-2 +``` + +### 2.6 Cost + +Expected steady-state cost at current quota is well under \$5/month: at 3000 requests/day × 30 days × 512 MB × ~1 s, Lambda runs roughly \$1–2/month. API Gateway REST API adds ~\$3.50 per million requests; at 100k/month that is ~\$0.35. Route 53 hosted zone + ACM cert are the fixed floor (~\$0.50/month). + +--- + +## 3. Known limitations + +- **Single-region, single-AZ.** No failover. Fine for a civic-data read proxy; not for critical services. +- **Reserved concurrency is a fuse, not a queue.** Beyond 10 in-flight requests, API Gateway returns 429. Clients must retry with backoff. +- **ArcGIS plugin is disabled.** Enabling it requires an authenticated portal; Boston's hub returns 401 without auth. diff --git a/docs/SECURITY.md b/docs/SECURITY.md new file mode 100644 index 0000000..87ae626 --- /dev/null +++ b/docs/SECURITY.md @@ -0,0 +1,143 @@ +# Security, privacy, and portal protection + +This document covers the controls this fork adds to protect three distinct parties: + +1. **The upstream open-data portal** (`data.boston.gov`) — the most important, because it is a shared civic resource, not our infrastructure. +2. **This MCP deployment** — AWS account cost and availability. +3. **End users** — privacy of who asks what. + +All three are addressed by the same change set. The sections below describe what was added, why, and what to look at if you are forking this for a different portal.
+ +> **Upstream-portal-first ethos.** An MCP server is a traffic amplifier: a single LLM conversation can fan out into many upstream queries. Being a good citizen of someone else's public API is the top design constraint in this fork. + +--- + +## 1. Protecting the upstream data portal + +Boston's CKAN portal at `data.boston.gov` is a public, unauthenticated civic resource. It is shared by journalists, researchers, city staff, and anyone else building on the open-data ecosystem. An MCP server in front of it can easily become the noisiest client on the portal — one Claude conversation can translate into dozens of SQL queries, each of which hits CKAN's DataStore. + +Four layers of defense keep this fork from becoming that client: + +### 1.1 Reserved Lambda concurrency (hard cap) + +`terraform/aws/variables.tf` defines `lambda_reserved_concurrency` (default 10). Only 10 Lambda invocations run concurrently; additional requests get throttled by AWS before they reach the portal. Even if a client bypasses the API Gateway rate limit (e.g. via the Lambda Function URL), they cannot drive more than 10 parallel upstream SQL queries through this deployment. + +### 1.2 API Gateway rate limit and daily quota + +- `api_rate_limit = 5` sustained req/s, `api_burst_limit = 10` — per-client-key. +- `api_quota_limit = 3000` requests/day — per-client-key. + +These are conservative on purpose. The MCP surface is small (tool discovery + a handful of queries per conversation), so 5 rps is well above legitimate use. + +### 1.3 Enforced `LIMIT` on every `execute_sql` query + +CKAN's DataStore will happily execute an unbounded `SELECT *` against a multi-million-row table. This fork prevents that: `SQLValidator.enforce_row_limit` appends `LIMIT 10000` (the `DEFAULT_ROW_LIMIT`) to any validated SQL that doesn't already declare one at the top level. A user-supplied top-level `LIMIT` is kept as-is; a `LIMIT` buried inside a subquery or CTE does not count. + +See `plugins/ckan/sql_validator.py:SQLValidator.enforce_row_limit`. + +### 1.4 `aggregate_data` `LIMIT` is clamped + +`SafeSQLBuilder.clamp_limit` (in the same file) enforces `MAX_LIMIT = 10000` on the aggregation path. A caller who asks for `limit: 999999999` gets 10000. + +### 1.5 When forking for another portal + +If you reuse this fork for another city's CKAN/ArcGIS portal, revisit: + +- **`lambda_reserved_concurrency`** — a smaller portal might warrant a lower cap. +- **`DEFAULT_ROW_LIMIT` and `MAX_LIMIT`** (`plugins/ckan/sql_validator.py`) — both set to 10000; lower them if the target portal is slower or more sensitive. +- **Respect `Retry-After` and portal-published rate limits.** This fork does not yet do adaptive backoff; see [§5 Known gaps](#5-known-gaps). + +--- + +## 2. SQL injection and query validation + +The `execute_sql` and `aggregate_data` tools both forward user-controlled input into SQL sent to CKAN DataStore. Upstream had a regex-based SQL validator; this fork replaces it with a defense-in-depth validator plus a typed, allowlist-only builder for aggregation. + +### 2.1 `SQLValidator` (execute_sql path) + +`plugins/ckan/sql_validator.py:SQLValidator.validate_query` enforces: + +- **Length cap reduced from 50,000 to 8,192 bytes.** No legitimate MCP-generated query is that long. +- **Comment stripping before scanning.** `/* ... */` and `-- ...` are removed before keyword and function scans run, so obfuscated payloads like `SEL/**/ECT ... UNI/**/ON` or `DROP /* hidden */ TABLE` cannot slip past.
+- **Expanded forbidden keyword list.** Upstream blocked the obvious DDL verbs; this fork adds `PREPARE`, `COPY`, `LISTEN`, `NOTIFY`, `VACUUM`, `ANALYZE`, `CLUSTER`, `REINDEX`, `LOAD`, `DO`. +- **Forbidden function list.** `xp_cmdshell`, `pg_sleep`, `pg_read_file`, `pg_read_binary_file`, `pg_ls_dir`, `pg_stat_file`, `lo_import`, `lo_export`, `current_setting`, `set_config`, `dblink`. +- **File-write pattern match.** `INTO OUTFILE` and `INTO DUMPFILE`. +- **AST-validated `FROM`/`JOIN` targets.** Uses `sqlparse` to walk the statement; every table reference must be either a UUID-quoted CKAN resource ID (e.g. `"11111111-2222-3333-4444-555555555555"`) or a CTE alias declared in the same statement. Schema-qualified targets like `pg_catalog.pg_class` are rejected. +- **Single-statement only.** Multiple statements separated by `;` are rejected. +- **SELECT/WITH only.** The statement type must be `SELECT` (including `WITH ... SELECT`). + +### 2.2 `SafeSQLBuilder` (aggregate_data path) + +The upstream `aggregate_data` implementation built SQL by string concatenation — including for user-supplied `group_by`, `metrics`, `filters`, `having`, and `order_by`. This fork rewrites the path to use `SafeSQLBuilder`, which treats every caller-supplied value as untrusted input: + +| Input | Validation | +| -------------- | ------------------------------------------------------------------------------------- | +| `resource_id` | Must match UUID regex. | +| Column name | Must match `[A-Za-z_][A-Za-z0-9_]*` — then double-quoted. | +| Metric expr | Allowlist: `count(*)`, `{count\|sum\|avg\|min\|max\|stddev}([DISTINCT] <column>)`. | +| Filter value | Coerced by type: `None → IS NULL`, bool → `TRUE`/`FALSE`, int/float formatted, string single-quoted with `'` escaped to `''`. | +| `order_by` | Must match `<column> [ASC\|DESC]`. | +| `having` value | Must be numeric. | +| `limit` | Must be a positive `int`; clamped to `MAX_LIMIT = 10000`. | + +Anything not on the allowlist raises `ValueError` and is surfaced to the caller as an error; nothing is executed against CKAN. + +### 2.3 Tests + +`tests/test_sql_validator.py` and `tests/test_ckan_plugin.py` cover valid queries, each forbidden keyword and function, comment-based obfuscation, schema-qualified FROM targets, UUID validation, and every `SafeSQLBuilder` method. + +--- + +## 3. Rate limiting and body size (this deployment) + +### 3.1 Request body size cap + +`server/http_handler.py` rejects any JSON-RPC body larger than **65,536 bytes (64 KB)** with HTTP 413 before the JSON parser runs. The MCP surface is small — every legitimate tool call fits well under a few KB — so a megabyte-sized payload is either a bug or abuse. Tests: `tests/test_http_handler.py:TestBodySizeCap`. + +### 3.2 API Gateway + +See [§1.2](#12-api-gateway-rate-limit-and-daily-quota). + +### 3.3 SQL length cap + +Upstream allowed 50 KB SQL strings; this fork drops the cap to 8 KB (`SQLValidator.MAX_SQL_LENGTH`). Combined with the body-size cap, an attacker cannot inflate the work we relay to CKAN via a single huge query. + +--- + +## 4. Privacy + +### 4.1 What this server stores + +**Nothing user-identifying, by design.** This deployment is stateless: + +- No database. No user accounts. No cookies. No session tokens. +- CloudWatch logs capture per-request: `request_id`, HTTP method/path, duration, status, and (truncated) tool name and SQL. Log retention is 14 days. +- SQL log entries are truncated to 500 characters (`plugins/ckan/plugin.py: logger.info("Executing SQL", extra={"sql": sql[:500]})`).
+- API Gateway access logs may record caller IPs per AWS defaults — treat these as the only identifying data we retain. + +### 4.2 What the upstream portal sees + +From CKAN's perspective, this server is a single client. End-user identity is not forwarded: every upstream request is made by the Lambda using its own outbound IP pool. This is a privacy win (your CKAN query isn't tied to your IP) but means rate-limit abuse by one user affects everyone sharing the deployment — which is exactly why [§1](#1-protecting-the-upstream-data-portal) exists. + +### 4.3 What users should know + +Connectors built on top of this MCP pass prompts through Claude. This deployment only sees the tool calls that Claude generates — not the user's raw prompt — but those tool calls (especially `execute_sql`) may contain content the user typed. The 14-day log retention and truncation are there to minimize this, but anyone deploying this should treat CloudWatch as "may contain incidental user content" for compliance purposes. + +### 4.4 Data is public + +All data this server returns comes from `data.boston.gov`, which is public open data. There is no private, PII-bearing, or licensed content behind this API. If you fork for a portal with non-public or licensed data, that changes the threat model substantially — add authentication in front of API Gateway. + +--- + +## 5. Known gaps + +- **No adaptive backoff on upstream errors.** If CKAN starts returning 429 or 5xx, this server does not currently slow down — it just relays the error. A future change should honor `Retry-After` and apply exponential backoff. +- **No per-tool rate limiting.** The API Gateway limit is per-client-key across all tools; a caller could spend their entire 5 rps on `execute_sql`. This is fine for now (reserved concurrency is the backstop) but worth revisiting if usage patterns change. +- **Lambda Function URL is public.** The Terraform stack still creates one for debugging; it bypasses the API Gateway quota. Disable it (`create_lambda_url = false` if the variable is added, or remove the resource) before handing out the URL publicly. +- **No authentication.** API Gateway uses usage plans + API keys for rate limiting, but there is no per-user auth. Appropriate for a public open-data proxy; not appropriate for anything else. + +--- + +## 6. Reporting a vulnerability + +Please open a private security advisory on the GitHub repo rather than filing a public issue. Include a proof-of-concept request and the expected vs. actual behavior. diff --git a/stdio_bridge.py b/stdio_bridge.py new file mode 100644 index 0000000..de6cb74 --- /dev/null +++ b/stdio_bridge.py @@ -0,0 +1,87 @@ +"""Stdio-to-HTTP bridge for connecting Claude to OpenContext MCP server. + +Reads JSON-RPC messages from stdin, forwards them to the local HTTP server, +and writes responses to stdout. This bridges Claude's stdio MCP transport +to the OpenContext HTTP-based MCP server. 
+""" + +import json +import sys +import urllib.request +import urllib.error + +SERVER_URL = "http://localhost:8000/mcp" + + +def main(): + url = sys.argv[1] if len(sys.argv) > 1 else SERVER_URL + if not url.endswith("/mcp"): + url = url.rstrip("/") + "/mcp" + + for line in sys.stdin: + line = line.strip() + if not line: + continue + + try: + request = json.loads(line) + except json.JSONDecodeError: + error_resp = { + "jsonrpc": "2.0", + "id": None, + "error": {"code": -32700, "message": "Parse error"}, + } + print(json.dumps(error_resp), flush=True) + continue + + is_notification = request.get("id") is None + + try: + req = urllib.request.Request( + url, + data=json.dumps(request).encode("utf-8"), + headers={"Content-Type": "application/json"}, + method="POST", + ) + with urllib.request.urlopen(req, timeout=120) as resp: + body = resp.read().decode("utf-8") + + if not body: + if is_notification: + continue + error_resp = { + "jsonrpc": "2.0", + "id": request.get("id"), + "error": {"code": -32603, "message": "Empty response"}, + } + print(json.dumps(error_resp), flush=True) + continue + + response = json.loads(body) + if is_notification: + continue + print(json.dumps(response), flush=True) + + except urllib.error.HTTPError as e: + if is_notification: + continue + error_resp = { + "jsonrpc": "2.0", + "id": request.get("id"), + "error": {"code": -32603, "message": f"HTTP {e.code}"}, + } + print(json.dumps(error_resp), flush=True) + + except Exception as e: + if is_notification: + continue + error_resp = { + "jsonrpc": "2.0", + "id": request.get("id"), + "error": {"code": -32603, "message": str(e)}, + } + print(json.dumps(error_resp), flush=True) + + +if __name__ == "__main__": + main() From f69bbc864e32420413d64cdf4b53967ba5686605 Mon Sep 17 00:00:00 2001 From: brendanbabb Date: Tue, 21 Apr 2026 18:43:52 -0800 Subject: [PATCH 04/12] Docs: document stdio_bridge.py for Claude Desktop / Claude Code Adds Method 4 to docs/TESTING.md covering how to wire the local HTTP server to Claude Desktop (claude_desktop_config.json) and Claude Code (.mcp.json) via stdio_bridge.py. The bridge was previously only mentioned in CLAUDE.md. Co-Authored-By: Claude Opus 4.7 --- docs/TESTING.md | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/docs/TESTING.md b/docs/TESTING.md index 61fdeb7..b298821 100644 --- a/docs/TESTING.md +++ b/docs/TESTING.md @@ -74,6 +74,50 @@ MCP Inspector is a web-based tool for testing MCP servers. --- +## Method 4: Claude Desktop / Claude Code via `stdio_bridge.py` + +Claude Desktop and Claude Code speak MCP over stdio, not HTTP. `stdio_bridge.py` is a small Python adapter that reads JSON-RPC from stdin, forwards it to the local HTTP server, and writes responses back to stdout — a dependency-free replacement for the Go client in `client/` for cases where you just want a local stdio connection. + +Start the local server first (`python local_server.py` from the repo root — this entry point accepts both `/` and `/mcp`), then register the bridge as an MCP server. 
+ +**Claude Desktop** — edit `claude_desktop_config.json`: + +```json +{ + "mcpServers": { + "boston-opendata": { + "command": "python", + "args": [ + "C:/projects/boston/OpenContext/stdio_bridge.py", + "http://localhost:8000/mcp" + ] + } + } +} +``` + +**Claude Code** — add an `.mcp.json` at the project root (or under `.claude/`): + +```json +{ + "mcpServers": { + "boston-opendata": { + "command": "python", + "args": [ + "C:/projects/boston/OpenContext/stdio_bridge.py", + "http://localhost:8000/mcp" + ] + } + } +} +``` + +The URL argument is optional — it defaults to `http://localhost:8000/mcp`. Pass a different URL to point at a deployed endpoint instead of the local server. + +Adjust the interpreter path (e.g. `C:/projects/boston/OpenContext/venv/Scripts/python.exe`) if you want to pin to a specific virtualenv. `stdio_bridge.py` uses only Python stdlib, so no extra installs are required. + +--- + ## Quick Checks Optional checks before starting the server. From 2f3c09fc9b0ed64d48e8878ca6501f0a2f4d99e6 Mon Sep 17 00:00:00 2001 From: brendanbabb Date: Thu, 30 Apr 2026 11:00:16 -0800 Subject: [PATCH 05/12] CKAN: improve GPT-4o tool chaining (next-step hints, suggested_resource_id, search_and_query) GPT-4o struggles to extract a resource UUID from one tool result and pass it to the next call. This patch makes the CKAN plugin friendlier to weaker multi-step tool callers without changing existing tool semantics: - Tool descriptions now end with explicit "Next step" guidance naming the follow-up tool (ckan__query_data, ckan__get_dataset, etc.). - Per-parameter descriptions spell out provenance ("the `id` field inside the `resources` array returned by ckan__search_datasets"). - search_datasets / get_dataset formatted output leads with a NEXT STEP block exposing suggested_resource_id and a literal suggested_call line. - New composite tool ckan__search_and_query that chains search + query server-side, eliminating the cross-call ID-passing problem entirely. Tests: tool count 6 -> 7, three new cases for search_and_query happy path, no-match, and dataset-with-no-resources. Full CKAN suite passes (32/32). --- plugins/ckan/plugin.py | 504 +++++++++++++++++++++++++++++++++----- tests/test_ckan_plugin.py | 128 +++++++++- 2 files changed, 576 insertions(+), 56 deletions(-) diff --git a/plugins/ckan/plugin.py b/plugins/ckan/plugin.py index 906546e..43a9cbf 100644 --- a/plugins/ckan/plugin.py +++ b/plugins/ckan/plugin.py @@ -155,20 +155,39 @@ def get_tools(self) -> List[ToolDefinition]: Returns: List of tool definitions """ + city = self.plugin_config.city_name return [ ToolDefinition( name="search_datasets", - description=f"Search for datasets in {self.plugin_config.city_name}'s open data portal", + description=( + f"Search for datasets in {city}'s open data portal by keyword.\n\n" + "Returns a list of CKAN datasets. 
Each dataset contains a " + "`resources` array; each resource has its own `id` (a UUID) " + "that identifies a queryable table.\n\n" + "Next step:\n" + " - EASIEST: if you just want data rows, call " + "`ckan__search_and_query` with the same query — it combines " + "search + query in one call.\n" + " - Otherwise: pick a resource from `resources[].id` in the " + "response and call `ckan__query_data` with that value as " + "`resource_id`.\n" + " - To inspect a dataset's resources first, call " + "`ckan__get_dataset` with `dataset_id` set to the dataset's " + "`id` or `name`.\n\n" + "The formatted response surfaces a `suggested_resource_id` " + "and `suggested_next_tool` line at the top — read those to " + "pick the next call." + ), input_schema={ "type": "object", "properties": { "query": { "type": "string", - "description": "Search query string", + "description": "Free-text search keywords (e.g. '311 service requests', 'building permits').", }, "limit": { "type": "integer", - "description": "Maximum number of results (default: 20)", + "description": "Maximum number of datasets to return (default: 20).", "default": 20, }, }, @@ -177,13 +196,27 @@ def get_tools(self) -> List[ToolDefinition]: ), ToolDefinition( name="get_dataset", - description=f"Get detailed information about a specific dataset from {self.plugin_config.city_name}'s open data portal", + description=( + f"Get full metadata for one dataset in {city}'s open data " + "portal, including its `resources` array.\n\n" + "Use this to find the resource UUIDs needed by " + "`ckan__query_data`, `ckan__get_schema`, " + "`ckan__aggregate_data`, or `ckan__execute_sql`. The " + "response lists each resource with its `Resource ID` " + "(a UUID).\n\n" + "Next step: call `ckan__query_data` with `resource_id` set " + "to one of the `Resource ID` values from this response." + ), input_schema={ "type": "object", "properties": { "dataset_id": { "type": "string", - "description": "Dataset ID or name", + "description": ( + "Dataset ID or slug. Provenance: the `id` " + "(or `name`) field of a dataset returned by " + "`ckan__search_datasets`. NOT a resource UUID." + ), }, }, "required": ["dataset_id"], @@ -191,21 +224,38 @@ def get_tools(self) -> List[ToolDefinition]: ), ToolDefinition( name="query_data", - description=f"Query data from a specific resource in {self.plugin_config.city_name}'s open data portal", + description=( + f"Query rows from a specific resource in {city}'s open " + "data portal.\n\n" + "The `resource_id` parameter is a CKAN resource UUID — " + "NOT a dataset ID. Get one by first calling " + "`ckan__search_datasets` or `ckan__get_dataset` and " + "reading the `id` inside the `resources` array.\n\n" + "Tip: if you only have a keyword and no resource_id yet, " + "use `ckan__search_and_query` instead — it does the " + "lookup and the data fetch in a single call." + ), input_schema={ "type": "object", "properties": { "resource_id": { "type": "string", - "description": "Resource ID to query", + "description": ( + "CKAN resource UUID (36-char, e.g. " + "'11111111-2222-3333-4444-555555555555'). " + "Provenance: the `id` field inside the " + "`resources` array returned by " + "`ckan__search_datasets` or " + "`ckan__get_dataset`. This is NOT a dataset ID." + ), }, "filters": { "type": "object", - "description": "Optional filters (field: value pairs)", + "description": "Optional filters as field:value pairs (e.g. 
{\"status\": \"Open\"}).", }, "limit": { "type": "integer", - "description": "Maximum number of records (default: 100)", + "description": "Maximum number of records (default: 100).", "default": 100, }, }, @@ -214,13 +264,28 @@ def get_tools(self) -> List[ToolDefinition]: ), ToolDefinition( name="get_schema", - description=f"Get schema information for a resource in {self.plugin_config.city_name}'s open data portal", + description=( + f"Get the schema (field names and types) for a resource " + f"in {city}'s open data portal.\n\n" + "Call this BEFORE `ckan__aggregate_data` or " + "`ckan__execute_sql` so you know the exact field names " + "to reference in `group_by`, `metrics`, SELECT, or WHERE " + "clauses.\n\n" + "Next step: pass the field names you discover to " + "`ckan__aggregate_data` (in `group_by` / `metrics`) or " + "to `ckan__execute_sql`." + ), input_schema={ "type": "object", "properties": { "resource_id": { "type": "string", - "description": "Resource ID", + "description": ( + "CKAN resource UUID. Provenance: the `id` " + "inside the `resources` array returned by " + "`ckan__search_datasets` or " + "`ckan__get_dataset`." + ), }, }, "required": ["resource_id"], @@ -228,25 +293,29 @@ def get_tools(self) -> List[ToolDefinition]: ), ToolDefinition( name="execute_sql", - description="""Execute raw PostgreSQL SELECT query. - -⚠️ Advanced users only. For complex queries requiring full SQL. - -Security: Only SELECT allowed. INSERT/UPDATE/DELETE blocked. - -Examples: -- Window functions: RANK() OVER (...) -- CTEs: WITH subquery AS (...) -- Complex aggregations: PERCENTILE_CONT(0.5) WITHIN GROUP - -Resource IDs must be double-quoted: FROM "uuid-here" -""", + description=( + f"Execute a raw PostgreSQL SELECT query against " + f"{city}'s CKAN datastore.\n\n" + "⚠️ Advanced users only. For complex queries requiring " + "full SQL.\n\n" + "Security: Only SELECT allowed. INSERT/UPDATE/DELETE " + "blocked.\n\n" + "Examples:\n" + "- Window functions: RANK() OVER (...)\n" + "- CTEs: WITH subquery AS (...)\n" + "- Complex aggregations: PERCENTILE_CONT(0.5) WITHIN GROUP\n\n" + "Resource IDs must be double-quoted: FROM \"uuid-here\"\n\n" + "Prerequisites:\n" + " - resource UUID for the FROM clause: get from " + "`ckan__search_datasets` or `ckan__get_dataset`.\n" + " - field names: get from `ckan__get_schema`." + ), input_schema={ "type": "object", "properties": { "sql": { "type": "string", - "description": "PostgreSQL SELECT statement", + "description": "PostgreSQL SELECT statement. Resource UUIDs in FROM must be double-quoted.", }, }, "required": ["sql"], @@ -254,26 +323,44 @@ def get_tools(self) -> List[ToolDefinition]: ), ToolDefinition( name="aggregate_data", - description=f"""Aggregate data with GROUP BY from {self.plugin_config.city_name}'s open data portal. 
- -Prerequisites: get_schema for field names - -Examples: -- Count by field: group_by=["neighborhood"], metrics={{count: "count(*)"}} -- Multiple metrics: metrics={{total: "count(*)", avg: "avg(field)"}} -- With filters: filters={{"status": "Open"}} - -Supports: count(*), sum(), avg(), min(), max(), stddev() -""", + description=( + f"Aggregate data with GROUP BY from {city}'s open data " + "portal.\n\n" + "Prerequisites:\n" + " - `resource_id`: get from `ckan__search_datasets` / " + "`ckan__get_dataset` (the `id` inside the `resources` " + "array).\n" + " - field names for `group_by` / `metrics`: get from " + "`ckan__get_schema`.\n\n" + "Examples:\n" + '- Count by field: group_by=["neighborhood"], ' + 'metrics={"count": "count(*)"}\n' + '- Multiple metrics: metrics={"total": "count(*)", ' + '"avg": "avg(field)"}\n' + '- With filters: filters={"status": "Open"}\n\n' + "Supports: count(*), sum(), avg(), min(), max(), stddev()." + ), input_schema={ "type": "object", "properties": { - "resource_id": {"type": "string"}, + "resource_id": { + "type": "string", + "description": ( + "CKAN resource UUID. Provenance: the `id` " + "inside the `resources` array returned by " + "`ckan__search_datasets` or " + "`ckan__get_dataset`." + ), + }, "group_by": { "type": "array", "items": {"type": "string"}, + "description": "Field names to group by. Get exact names from `ckan__get_schema`.", + }, + "metrics": { + "type": "object", + "description": "Map of alias -> aggregate expression, e.g. {\"count\": \"count(*)\"}.", }, - "metrics": {"type": "object"}, "filters": {"type": "object"}, "having": {"type": "object"}, "order_by": {"type": "string"}, @@ -282,6 +369,54 @@ def get_tools(self) -> List[ToolDefinition]: "required": ["resource_id", "metrics"], }, ), + ToolDefinition( + name="search_and_query", + description=( + f"ONE-CALL keyword-to-data for {city}'s open data " + "portal: searches for the best-matching dataset and " + "immediately returns rows from its first resource — no " + "tool chaining required.\n\n" + "Use this when you have a keyword (e.g. " + "'311 service requests', 'building permits') and want " + "actual data rows. It combines " + "`ckan__search_datasets` + `ckan__query_data` into a " + "single server-side step, so you do NOT need to extract " + "a resource_id from a previous response.\n\n" + "Returns: data rows from the chosen dataset's chosen " + "resource, plus a header showing which dataset and " + "resource were used so you can drill deeper with " + "`ckan__query_data` or `ckan__get_dataset` if needed." + ), + input_schema={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Free-text search keywords (e.g. 
'311 service requests').", + }, + "limit": { + "type": "integer", + "description": "Maximum number of data rows to return (default: 100).", + "default": 100, + }, + "filters": { + "type": "object", + "description": "Optional row-level filters as field:value pairs, applied to the matched resource.", + }, + "dataset_index": { + "type": "integer", + "description": "Which search result to use (0 = best match, default 0).", + "default": 0, + }, + "resource_index": { + "type": "integer", + "description": "Which resource within the chosen dataset to query (0 = first, default 0).", + "default": 0, + }, + }, + "required": ["query"], + }, + ), ] async def execute_tool( @@ -394,6 +529,43 @@ async def execute_tool( success=True, ) + elif tool_name == "search_and_query": + query = arguments.get("query") + if not query: + return ToolResult( + content=[], + success=False, + error_message="query is required", + ) + limit = arguments.get("limit", 100) + filters = arguments.get("filters") or {} + dataset_index = arguments.get("dataset_index", 0) + resource_index = arguments.get("resource_index", 0) + composite = await self.search_and_query( + query=query, + limit=limit, + filters=filters, + dataset_index=dataset_index, + resource_index=resource_index, + ) + if composite.get("error"): + return ToolResult( + content=[], + success=False, + error_message=composite.get( + "message", "search_and_query failed" + ), + ) + return ToolResult( + content=[ + { + "type": "text", + "text": self._format_search_and_query(composite, limit), + } + ], + success=True, + ) + elif tool_name == "aggregate_data": resource_id = arguments.get("resource_id") if not resource_id: @@ -628,6 +800,100 @@ async def aggregate_data( return await self.execute_sql(sql) + async def search_and_query( + self, + query: str, + limit: int = 100, + filters: Optional[Dict[str, Any]] = None, + dataset_index: int = 0, + resource_index: int = 0, + ) -> Dict[str, Any]: + """Search for a dataset and immediately query its first resource. + + Combines search_datasets + query_data into one server-side step so + callers don't have to extract a resource_id from a previous response. + + Returns: + Dict with either {"error": True, "message": ...} or + {"dataset": {...}, "resource": {...}, "records": [...]}. + """ + # Cap how many search results we fetch so dataset_index can pick a + # non-best match without an unbounded scan. + search_rows = max(dataset_index + 1, 5) + datasets = await self.search_datasets(query, limit=search_rows) + if not datasets: + return { + "error": True, + "message": ( + f"No datasets found for query {query!r} in " + f"{self.plugin_config.city_name}'s open data portal." + ), + } + + if dataset_index < 0 or dataset_index >= len(datasets): + return { + "error": True, + "message": ( + f"dataset_index {dataset_index} is out of range " + f"(found {len(datasets)} dataset(s))." + ), + } + + chosen_dataset = datasets[dataset_index] + resources = chosen_dataset.get("resources") or [] + if not resources: + return { + "error": True, + "message": ( + f"Dataset {chosen_dataset.get('id')!r} has no resources. " + f"Try a different dataset_index or call ckan__get_dataset " + f"to inspect available resources." + ), + } + + if resource_index < 0 or resource_index >= len(resources): + return { + "error": True, + "message": ( + f"resource_index {resource_index} is out of range for " + f"dataset {chosen_dataset.get('id')!r} " + f"(has {len(resources)} resource(s))." 
+ ), + } + + chosen_resource = resources[resource_index] + resource_id = chosen_resource.get("id") + if not resource_id: + return { + "error": True, + "message": ( + f"Resource at index {resource_index} of dataset " + f"{chosen_dataset.get('id')!r} has no id." + ), + } + + try: + records = await self.query_data( + resource_id=resource_id, + filters=filters or None, + limit=limit, + ) + except Exception as e: + return { + "error": True, + "message": ( + f"Found dataset {chosen_dataset.get('id')!r} resource " + f"{resource_id!r} but query_data failed: {e}" + ), + } + + return { + "dataset": chosen_dataset, + "resource": chosen_resource, + "records": records, + "alternate_datasets": datasets, + } + async def health_check(self) -> bool: """Check if CKAN API is accessible. @@ -646,9 +912,43 @@ def _format_search_results(self, datasets: List[Dict[str, Any]]) -> str: if not datasets: return f"No datasets found in {self.plugin_config.city_name}'s open data portal." - lines = [ + suggested_resource_id: Optional[str] = None + suggested_dataset_id: Optional[str] = None + for ds in datasets: + resources = ds.get("resources") or [] + if resources: + suggested_resource_id = resources[0].get("id") + suggested_dataset_id = ds.get("id") + if suggested_resource_id: + break + + lines: List[str] = [] + if suggested_resource_id: + lines.extend( + [ + "=== NEXT STEP (read this first) ===", + f"suggested_resource_id: {suggested_resource_id}", + "suggested_next_tool: ckan__query_data", + f"suggested_call: ckan__query_data(resource_id=\"{suggested_resource_id}\")", + "(or use ckan__search_and_query for a one-call keyword-to-data flow)", + "===================================", + "", + ] + ) + else: + lines.extend( + [ + "=== NEXT STEP ===", + "No resource UUIDs were attached to these datasets. Call " + "ckan__get_dataset with a dataset_id below to look them up.", + "=================", + "", + ] + ) + + lines.append( f"Found {len(datasets)} dataset(s) in {self.plugin_config.city_name}'s open data portal:\n" - ] + ) for i, dataset in enumerate(datasets, 1): title = dataset.get("title", "Untitled") @@ -658,9 +958,15 @@ def _format_search_results(self, datasets: List[Dict[str, Any]]) -> str: if dataset.get("notes") else "No description" ) + resources = dataset.get("resources") or [] + first_resource_id = resources[0].get("id") if resources else None lines.append(f"{i}. {title}") - lines.append(f" ID: {dataset_id}") + lines.append(f" dataset_id: {dataset_id}") + if first_resource_id: + lines.append( + f" resource_id (use this with ckan__query_data): {first_resource_id}" + ) lines.append(f" Description: {notes}") lines.append( f" Portal: {self.plugin_config.portal_url}/dataset/{dataset_id}" @@ -669,8 +975,12 @@ def _format_search_results(self, datasets: List[Dict[str, Any]]) -> str: lines.append( f"View all datasets at: {self.plugin_config.portal_url}\n" - f"Use get_dataset tool with a dataset ID to get more details." + f"Use ckan__get_dataset with a dataset_id (above) for full resource details, " + f"or ckan__query_data with a resource_id to fetch rows." ) + if suggested_dataset_id: + # Hint for narrative chaining: makes the dataset_id discoverable too. 
+ lines.append(f"suggested_dataset_id: {suggested_dataset_id}") return "\n".join(lines) @@ -680,17 +990,36 @@ def _format_dataset(self, dataset: Dict[str, Any]) -> str: dataset_id = dataset.get("id", "unknown") notes = dataset.get("notes", "No description") organization = dataset.get("organization", {}).get("title", "Unknown") - resources = dataset.get("resources", []) + resources = dataset.get("resources", []) or [] - lines = [ - f"Dataset: {title}", - f"ID: {dataset_id}", - f"Organization: {organization}", - f"Description: {notes}", - "", - f"Portal URL: {self.plugin_config.portal_url}/dataset/{dataset_id}", - "", - ] + suggested_resource_id = ( + resources[0].get("id") if resources else None + ) + + lines: List[str] = [] + if suggested_resource_id: + lines.extend( + [ + "=== NEXT STEP (read this first) ===", + f"suggested_resource_id: {suggested_resource_id}", + "suggested_next_tool: ckan__query_data", + f"suggested_call: ckan__query_data(resource_id=\"{suggested_resource_id}\")", + "===================================", + "", + ] + ) + + lines.extend( + [ + f"Dataset: {title}", + f"dataset_id: {dataset_id}", + f"Organization: {organization}", + f"Description: {notes}", + "", + f"Portal URL: {self.plugin_config.portal_url}/dataset/{dataset_id}", + "", + ] + ) if resources: lines.append(f"Resources ({len(resources)}):") @@ -699,12 +1028,15 @@ def _format_dataset(self, dataset: Dict[str, Any]) -> str: res_id = resource.get("id", "unknown") res_format = resource.get("format", "unknown") lines.append(f" {i}. {res_name} ({res_format})") - lines.append(f" Resource ID: {res_id}") + lines.append(f" resource_id: {res_id}") lines.append( - f" Use query_data tool with resource_id='{res_id}' to query this data" + f" Use ckan__query_data with resource_id=\"{res_id}\" to fetch rows." ) else: - lines.append("No resources available for this dataset.") + lines.append( + "No resources available for this dataset. Try a different " + "dataset_id or use ckan__search_datasets again." + ) return "\n".join(lines) @@ -746,6 +1078,68 @@ def _format_schema(self, fields: List[Dict[str, Any]]) -> str: return "\n".join(lines) + def _format_search_and_query( + self, composite: Dict[str, Any], limit: int + ) -> str: + """Format a search_and_query composite result for user display.""" + dataset = composite.get("dataset", {}) or {} + resource = composite.get("resource", {}) or {} + records = composite.get("records", []) or [] + alternates = composite.get("alternate_datasets", []) or [] + + dataset_id = dataset.get("id", "unknown") + dataset_title = dataset.get("title", "Untitled") + resource_id = resource.get("id", "unknown") + resource_name = resource.get("name", "Unnamed") + + lines: List[str] = [ + "=== search_and_query result ===", + f"matched_dataset: {dataset_title}", + f"dataset_id: {dataset_id}", + f"resource_id (use with ckan__query_data): {resource_id}", + f"resource_name: {resource_name}", + f"row_count: {len(records)} (limit={limit})", + "================================", + "", + ] + + if not records: + lines.append( + "No rows returned. Try broadening filters or pick a different " + "dataset/resource (see alternates below)." + ) + else: + lines.append(f"Showing up to 5 of {len(records)} record(s):") + for i, record in enumerate(records[:5], 1): + lines.append(f"Record {i}:") + for key, value in record.items(): + if key != "_id": + lines.append(f" {key}: {value}") + lines.append("") + if len(records) > 5: + lines.append(f"... 
and {len(records) - 5} more record(s)") + + if len(alternates) > 1: + lines.append("") + lines.append( + "Other matching datasets (pass dataset_index=N to switch):" + ) + for i, alt in enumerate(alternates): + if i == 0: + continue # the chosen one + alt_title = alt.get("title", "Untitled") + alt_id = alt.get("id", "unknown") + alt_resources = alt.get("resources") or [] + alt_resource_id = ( + alt_resources[0].get("id") if alt_resources else None + ) + lines.append(f" [dataset_index={i}] {alt_title}") + lines.append(f" dataset_id: {alt_id}") + if alt_resource_id: + lines.append(f" resource_id: {alt_resource_id}") + + return "\n".join(lines) + def _format_sql_results( self, records: List[Dict[str, Any]], fields: List[Dict[str, Any]] ) -> str: diff --git a/tests/test_ckan_plugin.py b/tests/test_ckan_plugin.py index 01fb657..7614c8a 100644 --- a/tests/test_ckan_plugin.py +++ b/tests/test_ckan_plugin.py @@ -136,7 +136,7 @@ def test_get_tools_returns_all_tools(self, ckan_config): plugin = CKANPlugin(ckan_config) tools = plugin.get_tools() - assert len(tools) == 6 + assert len(tools) == 7 tool_names = [t.name for t in tools] assert "search_datasets" in tool_names assert "get_dataset" in tool_names @@ -144,6 +144,7 @@ def test_get_tools_returns_all_tools(self, ckan_config): assert "get_schema" in tool_names assert "execute_sql" in tool_names assert "aggregate_data" in tool_names + assert "search_and_query" in tool_names def test_get_tools_includes_city_name_in_descriptions(self, ckan_config): """Test that tool descriptions include city name.""" @@ -543,6 +544,131 @@ async def test_execute_tool_execute_sql_missing_param(self, ckan_config): assert result.success is False assert "required" in result.error_message.lower() + @pytest.mark.asyncio + async def test_execute_tool_search_and_query_succeeds(self, ckan_config): + """search_and_query returns rows from the first resource of the first match.""" + plugin = CKANPlugin(ckan_config) + + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_response_init = Mock() + mock_response_init.json.return_value = {"success": True} + mock_response_init.raise_for_status = Mock() + # 1) search_datasets + mock_response_search = Mock() + mock_response_search.json.return_value = { + "result": { + "results": [ + { + "id": "dataset-1", + "title": "311 Service Requests", + "resources": [ + { + "id": "11111111-2222-3333-4444-555555555555", + "name": "311 CSV", + "format": "CSV", + } + ], + } + ] + } + } + mock_response_search.raise_for_status = Mock() + # 2) datastore_search (query_data) + mock_response_query = Mock() + mock_response_query.json.return_value = { + "result": { + "records": [ + {"_id": 1, "type": "Pothole"}, + {"_id": 2, "type": "Streetlight"}, + ] + } + } + mock_response_query.raise_for_status = Mock() + mock_client.post = AsyncMock( + side_effect=[ + mock_response_init, + mock_response_search, + mock_response_query, + ] + ) + mock_client_class.return_value = mock_client + + await plugin.initialize() + result = await plugin.execute_tool( + "search_and_query", {"query": "311", "limit": 10} + ) + + assert result.success is True + assert len(result.content) == 1 + text = result.content[0]["text"] + # Header surfaces the chosen IDs + assert "11111111-2222-3333-4444-555555555555" in text + assert "dataset-1" in text + # Rows from the second mocked call show up + assert "Pothole" in text or "Streetlight" in text + + @pytest.mark.asyncio + async def test_execute_tool_search_and_query_no_matches(self, ckan_config): + 
"""search_and_query returns an error when no datasets match.""" + plugin = CKANPlugin(ckan_config) + + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_response_init = Mock() + mock_response_init.json.return_value = {"success": True} + mock_response_init.raise_for_status = Mock() + mock_response_search = Mock() + mock_response_search.json.return_value = {"result": {"results": []}} + mock_response_search.raise_for_status = Mock() + mock_client.post = AsyncMock( + side_effect=[mock_response_init, mock_response_search] + ) + mock_client_class.return_value = mock_client + + await plugin.initialize() + result = await plugin.execute_tool( + "search_and_query", {"query": "nonexistent-keyword-xyz"} + ) + + assert result.success is False + assert result.error_message is not None + assert "No datasets found" in result.error_message + + @pytest.mark.asyncio + async def test_execute_tool_search_and_query_dataset_has_no_resources( + self, ckan_config + ): + """search_and_query reports an error when the matched dataset has no resources.""" + plugin = CKANPlugin(ckan_config) + + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_response_init = Mock() + mock_response_init.json.return_value = {"success": True} + mock_response_init.raise_for_status = Mock() + mock_response_search = Mock() + mock_response_search.json.return_value = { + "result": { + "results": [ + {"id": "dataset-empty", "title": "Empty", "resources": []} + ] + } + } + mock_response_search.raise_for_status = Mock() + mock_client.post = AsyncMock( + side_effect=[mock_response_init, mock_response_search] + ) + mock_client_class.return_value = mock_client + + await plugin.initialize() + result = await plugin.execute_tool( + "search_and_query", {"query": "anything"} + ) + + assert result.success is False + assert "no resources" in (result.error_message or "").lower() + @pytest.mark.asyncio async def test_execute_tool_unknown_tool(self, ckan_config): """Test executing unknown tool.""" From ddcee2e5fb7c2fbb5ea8f054ca6ad6171effcc23 Mon Sep 17 00:00:00 2001 From: brendanbabb Date: Thu, 30 Apr 2026 11:20:14 -0800 Subject: [PATCH 06/12] CKAN: filter for datastore_active resources (fixes parks 404) Boston attaches each dataset as 5-7 resources (GeoJSON, KML, SHP, PDF, ArcGIS REST, CSV) but only the CSV is loaded into the queryable Postgres datastore. The previous code blindly handed the model `resources[0].id`, which is the GeoJSON one for most parks/permits/etc. datasets, and the caller then got a 404 from datastore_search. Fixes: - New _is_queryable / _first_queryable_resource helpers. - search_datasets formatter: suggested_resource_id and per-dataset resource_id lines now point at the datastore_active resource only. When no resource is queryable, say so explicitly instead of suggesting a doomed UUID. - get_dataset formatter: NEXT STEP block points at the queryable resource; resource list labels each as [QUERYABLE] or [DOWNLOAD-ONLY] and surfaces the download URL for non-queryable ones. - search_and_query: walks search results and resources to find one with datastore_active=true (configurable via dataset_index/resource_index). Falls through to the next dataset rather than 404-ing on a download- only first match. - query_data: 404s now append a hint about the datastore_active gotcha pointing at get_dataset / search_and_query so the model doesn't keep retrying the same UUID. Tool descriptions updated to call out the queryable-vs-download distinction. 
Tests: 35 passing (was 32). Three new regression tests:
- parks-style multi-resource shape skips GeoJSON and queries CSV
- walks to the next dataset when first has no queryable resource
- query_data 404 includes the datastore_active hint

Live smoke test against data.boston.gov confirms ckan__search_and_query
"parks" now returns real Park_Features rows (Iacono/Readville Playground,
East Boston Memorial Park, etc.) instead of 404.
---
 plugins/ckan/plugin.py | 294 +++++++++++++++++++++++++++++---------
 tests/test_ckan_plugin.py | 187 +++++++++++++++++++++++-
 2 files changed, 416 insertions(+), 65 deletions(-)

diff --git a/plugins/ckan/plugin.py b/plugins/ckan/plugin.py
index 43a9cbf..fc57777 100644
--- a/plugins/ckan/plugin.py
+++ b/plugins/ckan/plugin.py
@@ -84,6 +84,25 @@ async def shutdown(self) -> None:
 self._initialized = False
 logger.info("CKAN plugin shut down")

+ @staticmethod
+ def _is_queryable(resource: Dict[str, Any]) -> bool:
+ """A CKAN resource is queryable via datastore_search only if it has
+ been loaded into CKAN's Postgres datastore. Boston attaches each
+ dataset as 5–7 download-only resources (GeoJSON, KML, SHP, ...) plus
+ a single CSV that's actually loaded; only that one returns rows."""
+ return bool(resource.get("datastore_active"))
+
+ @classmethod
+ def _first_queryable_resource(
+ cls, dataset: Dict[str, Any]
+ ) -> Optional[Dict[str, Any]]:
+ """Return the first resource of a dataset that is loaded into the
+ datastore (i.e. answers datastore_search), or None if none are."""
+ for res in dataset.get("resources") or []:
+ if cls._is_queryable(res):
+ return res
+ return None
+
 def _parse_ckan_error(
 self, response_body: Dict[str, Any], context: str = ""
 ) -> str:
@@ -231,9 +250,19 @@ def get_tools(self) -> List[ToolDefinition]:
 "NOT a dataset ID. Get one by first calling "
 "`ckan__search_datasets` or `ckan__get_dataset` and "
 "reading the `id` inside the `resources` array.\n\n"
- "Tip: if you only have a keyword and no resource_id yet, "
- "use `ckan__search_and_query` instead — it does the "
- "lookup and the data fetch in a single call."
+ "IMPORTANT: only resources with `datastore_active=true` "
+ "are queryable here. Boston datasets typically attach "
+ "5–7 resources (GeoJSON, KML, SHP, PDF, ArcGIS REST, "
+ "CSV) but only the CSV one is loaded into the "
+ "datastore. If you call this tool with a download-only "
+ "resource UUID it will return 404. The output of "
+ "`ckan__search_datasets` and `ckan__get_dataset` "
+ "labels resources as QUERYABLE or DOWNLOAD-ONLY — pick "
+ "the QUERYABLE one.\n\n"
+ "Tip: if you only have a keyword and no resource_id "
+ "yet, use `ckan__search_and_query` instead — it does "
+ "the lookup and the data fetch in a single call and "
+ "auto-picks the datastore-loaded resource."
 ),
 input_schema={
 "type": "object",
@@ -374,14 +403,21 @@ def get_tools(self) -> List[ToolDefinition]:
 description=(
 f"ONE-CALL keyword-to-data for {city}'s open data "
 "portal: searches for the best-matching dataset and "
- "immediately returns rows from its first resource — no "
- "tool chaining required.\n\n"
+ "immediately returns rows from its first datastore-"
+ "loaded resource — no tool chaining required.\n\n"
 "Use this when you have a keyword (e.g. "
- "'311 service requests', 'building permits') and want "
- "actual data rows. It combines "
+ "'311 service requests', 'parks', 'building permits') "
+ "and want actual data rows. 
It combines " "`ckan__search_datasets` + `ckan__query_data` into a " - "single server-side step, so you do NOT need to extract " - "a resource_id from a previous response.\n\n" + "single server-side step, so you do NOT need to " + "extract a resource_id from a previous response.\n\n" + "Auto-picks the right resource: Boston datasets " + "typically attach 5–7 resources (GeoJSON, KML, SHP, " + "PDF, ArcGIS REST, CSV) but only the CSV is loaded " + "into the queryable datastore. This tool walks the " + "search results and skips datasets / resources that " + "aren't datastore-active, so you don't get a 404 " + "from a download-only resource.\n\n" "Returns: data rows from the chosen dataset's chosen " "resource, plus a header showing which dataset and " "resource were used so you can drill deeper with " @@ -405,13 +441,22 @@ def get_tools(self) -> List[ToolDefinition]: }, "dataset_index": { "type": "integer", - "description": "Which search result to use (0 = best match, default 0).", - "default": 0, + "description": ( + "Which search result to use (0 = best " + "match). If omitted, walks the search " + "results until one with a queryable " + "(datastore_active) resource is found." + ), }, "resource_index": { "type": "integer", - "description": "Which resource within the chosen dataset to query (0 = first, default 0).", - "default": 0, + "description": ( + "Which resource within the chosen dataset " + "to query. If omitted, auto-picks the " + "first datastore_active resource (Boston " + "datasets typically attach 5–7 resources " + "but only the CSV is queryable)." + ), }, }, "required": ["query"], @@ -539,8 +584,8 @@ async def execute_tool( ) limit = arguments.get("limit", 100) filters = arguments.get("filters") or {} - dataset_index = arguments.get("dataset_index", 0) - resource_index = arguments.get("resource_index", 0) + dataset_index = arguments.get("dataset_index") + resource_index = arguments.get("resource_index") composite = await self.search_and_query( query=query, limit=limit, @@ -667,7 +712,26 @@ async def query_data( if filters: params["filters"] = filters - response = await self._call_ckan_api("datastore_search", params) + try: + response = await self._call_ckan_api("datastore_search", params) + except RuntimeError as e: + msg = str(e) + # 404 from datastore_search almost always means the resource UUID + # is real but isn't loaded into the queryable Postgres datastore + # (datastore_active=false). Surface that explicitly so callers + # don't keep retrying with the same UUID. + if "404" in msg or "not found" in msg.lower(): + raise RuntimeError( + f"{msg}\n" + "Hint: this resource may exist as a file download " + "(GeoJSON/KML/SHP/PDF) but not be loaded into the " + "datastore (datastore_active=false). Call " + "ckan__get_dataset on the parent dataset to find a " + "QUERYABLE resource (typically the CSV one), or use " + "ckan__search_and_query, which auto-picks the " + "datastore-loaded resource." + ) from e + raise return response.get("result", {}).get("records", []) async def get_schema(self, resource_id: str) -> Dict[str, Any]: @@ -805,21 +869,29 @@ async def search_and_query( query: str, limit: int = 100, filters: Optional[Dict[str, Any]] = None, - dataset_index: int = 0, - resource_index: int = 0, + dataset_index: Optional[int] = None, + resource_index: Optional[int] = None, ) -> Dict[str, Any]: - """Search for a dataset and immediately query its first resource. + """Search for a dataset and immediately query its first queryable resource. 
Combines search_datasets + query_data into one server-side step so callers don't have to extract a resource_id from a previous response. + Picks the first resource where datastore_active=true (skipping + download-only resources like GeoJSON/KML/SHP/PDF). When the caller + leaves dataset_index/resource_index unset, walks the search results + until one with a queryable resource is found. + Returns: Dict with either {"error": True, "message": ...} or {"dataset": {...}, "resource": {...}, "records": [...]}. """ + explicit_dataset = dataset_index is not None + explicit_resource = resource_index is not None + ds_idx = dataset_index or 0 # Cap how many search results we fetch so dataset_index can pick a # non-best match without an unbounded scan. - search_rows = max(dataset_index + 1, 5) + search_rows = max(ds_idx + 1, 10) datasets = await self.search_datasets(query, limit=search_rows) if not datasets: return { @@ -830,45 +902,99 @@ async def search_and_query( ), } - if dataset_index < 0 or dataset_index >= len(datasets): - return { - "error": True, - "message": ( - f"dataset_index {dataset_index} is out of range " - f"(found {len(datasets)} dataset(s))." - ), - } + if explicit_dataset: + if ds_idx < 0 or ds_idx >= len(datasets): + return { + "error": True, + "message": ( + f"dataset_index {ds_idx} is out of range " + f"(found {len(datasets)} dataset(s))." + ), + } + candidate_indices = [ds_idx] + else: + # Auto-walk: try the best match first, fall through to the next + # datasets if it has no queryable resource. + candidate_indices = list(range(len(datasets))) - chosen_dataset = datasets[dataset_index] - resources = chosen_dataset.get("resources") or [] - if not resources: - return { - "error": True, - "message": ( - f"Dataset {chosen_dataset.get('id')!r} has no resources. " - f"Try a different dataset_index or call ckan__get_dataset " - f"to inspect available resources." - ), - } + chosen_dataset: Optional[Dict[str, Any]] = None + chosen_resource: Optional[Dict[str, Any]] = None + skipped_summary: List[str] = [] - if resource_index < 0 or resource_index >= len(resources): + for idx in candidate_indices: + ds = datasets[idx] + resources = ds.get("resources") or [] + if not resources: + skipped_summary.append( + f" [{idx}] {ds.get('title') or ds.get('id')}: " + "no resources" + ) + continue + + if explicit_resource: + if resource_index < 0 or resource_index >= len(resources): + return { + "error": True, + "message": ( + f"resource_index {resource_index} is out of " + f"range for dataset {ds.get('id')!r} " + f"(has {len(resources)} resource(s))." + ), + } + resource = resources[resource_index] + if not self._is_queryable(resource): + return { + "error": True, + "message": ( + f"resource_index {resource_index} of dataset " + f"{ds.get('id')!r} has datastore_active=false " + "(download-only). Pick a different " + "resource_index, or omit it to auto-pick the " + "queryable one." 
+ ), + } + chosen_dataset, chosen_resource = ds, resource + break + + queryable = self._first_queryable_resource(ds) + if queryable: + chosen_dataset, chosen_resource = ds, queryable + break + + formats = sorted( + { + (r.get("format") or "?").upper() + for r in resources + } + ) + skipped_summary.append( + f" [{idx}] {ds.get('title') or ds.get('id')}: " + f"no datastore-loaded resource (formats: {', '.join(formats)})" + ) + + if chosen_dataset is None or chosen_resource is None: + details = ( + "\n".join(skipped_summary) + if skipped_summary + else " (no datasets inspected)" + ) return { "error": True, "message": ( - f"resource_index {resource_index} is out of range for " - f"dataset {chosen_dataset.get('id')!r} " - f"(has {len(resources)} resource(s))." + f"No queryable (datastore_active) resource found among " + f"{len(datasets)} matching dataset(s) for query " + f"{query!r}.\nSkipped:\n{details}\nTry a different " + "keyword or call ckan__get_dataset to inspect resources." ), } - chosen_resource = resources[resource_index] resource_id = chosen_resource.get("id") if not resource_id: return { "error": True, "message": ( - f"Resource at index {resource_index} of dataset " - f"{chosen_dataset.get('id')!r} has no id." + f"Chosen resource of dataset {chosen_dataset.get('id')!r}" + " has no id." ), } @@ -915,12 +1041,11 @@ def _format_search_results(self, datasets: List[Dict[str, Any]]) -> str: suggested_resource_id: Optional[str] = None suggested_dataset_id: Optional[str] = None for ds in datasets: - resources = ds.get("resources") or [] - if resources: - suggested_resource_id = resources[0].get("id") + queryable = self._first_queryable_resource(ds) + if queryable and queryable.get("id"): + suggested_resource_id = queryable.get("id") suggested_dataset_id = ds.get("id") - if suggested_resource_id: - break + break lines: List[str] = [] if suggested_resource_id: @@ -930,7 +1055,10 @@ def _format_search_results(self, datasets: List[Dict[str, Any]]) -> str: f"suggested_resource_id: {suggested_resource_id}", "suggested_next_tool: ckan__query_data", f"suggested_call: ckan__query_data(resource_id=\"{suggested_resource_id}\")", - "(or use ckan__search_and_query for a one-call keyword-to-data flow)", + "(this is the datastore-loaded resource — only such " + "resources can be queried; others are file downloads.)", + "(or use ckan__search_and_query for a one-call " + "keyword-to-data flow.)", "===================================", "", ] @@ -939,8 +1067,11 @@ def _format_search_results(self, datasets: List[Dict[str, Any]]) -> str: lines.extend( [ "=== NEXT STEP ===", - "No resource UUIDs were attached to these datasets. Call " - "ckan__get_dataset with a dataset_id below to look them up.", + "None of the matched datasets have a queryable resource " + "(datastore_active=true). The attached resources are " + "file downloads only. Try a different search keyword, " + "or call ckan__get_dataset to inspect non-datastore " + "resources.", "=================", "", ] @@ -959,13 +1090,24 @@ def _format_search_results(self, datasets: List[Dict[str, Any]]) -> str: else "No description" ) resources = dataset.get("resources") or [] - first_resource_id = resources[0].get("id") if resources else None + queryable = self._first_queryable_resource(dataset) + queryable_id = queryable.get("id") if queryable else None + queryable_format = queryable.get("format") if queryable else None lines.append(f"{i}. 
{title}") lines.append(f" dataset_id: {dataset_id}") - if first_resource_id: + if queryable_id: + fmt = f" [{queryable_format}]" if queryable_format else "" lines.append( - f" resource_id (use this with ckan__query_data): {first_resource_id}" + f" resource_id (use this with ckan__query_data){fmt}: " + f"{queryable_id}" + ) + elif resources: + lines.append( + f" resource_id: NONE QUERYABLE — this dataset has " + f"{len(resources)} resource(s) but none are loaded into " + "the datastore (datastore_active=false). Use " + "ckan__get_dataset for download URLs." ) lines.append(f" Description: {notes}") lines.append( @@ -992,9 +1134,8 @@ def _format_dataset(self, dataset: Dict[str, Any]) -> str: organization = dataset.get("organization", {}).get("title", "Unknown") resources = dataset.get("resources", []) or [] - suggested_resource_id = ( - resources[0].get("id") if resources else None - ) + queryable = self._first_queryable_resource(dataset) + suggested_resource_id = queryable.get("id") if queryable else None lines: List[str] = [] if suggested_resource_id: @@ -1004,10 +1145,24 @@ def _format_dataset(self, dataset: Dict[str, Any]) -> str: f"suggested_resource_id: {suggested_resource_id}", "suggested_next_tool: ckan__query_data", f"suggested_call: ckan__query_data(resource_id=\"{suggested_resource_id}\")", + "(this is the datastore-loaded resource; the others are " + "file downloads only.)", "===================================", "", ] ) + elif resources: + lines.extend( + [ + "=== NEXT STEP ===", + f"This dataset has {len(resources)} resource(s) but none " + "are loaded into the datastore (datastore_active=false), " + "so ckan__query_data will not work on them. They are " + "file downloads — see URLs below.", + "=================", + "", + ] + ) lines.extend( [ @@ -1027,11 +1182,22 @@ def _format_dataset(self, dataset: Dict[str, Any]) -> str: res_name = resource.get("name", "Unnamed") res_id = resource.get("id", "unknown") res_format = resource.get("format", "unknown") - lines.append(f" {i}. {res_name} ({res_format})") + res_url = resource.get("url", "") + queryable_flag = self._is_queryable(resource) + marker = "QUERYABLE" if queryable_flag else "DOWNLOAD-ONLY" + lines.append(f" {i}. [{marker}] {res_name} ({res_format})") lines.append(f" resource_id: {res_id}") - lines.append( - f" Use ckan__query_data with resource_id=\"{res_id}\" to fetch rows." - ) + if queryable_flag: + lines.append( + f" Use ckan__query_data with resource_id=\"{res_id}\" to fetch rows." + ) + else: + if res_url: + lines.append(f" download_url: {res_url}") + lines.append( + " (not loaded into datastore — ckan__query_data " + "will return 404 for this resource_id)" + ) else: lines.append( "No resources available for this dataset. 
Try a different " diff --git a/tests/test_ckan_plugin.py b/tests/test_ckan_plugin.py index 7614c8a..fd27df1 100644 --- a/tests/test_ckan_plugin.py +++ b/tests/test_ckan_plugin.py @@ -567,6 +567,7 @@ async def test_execute_tool_search_and_query_succeeds(self, ckan_config): "id": "11111111-2222-3333-4444-555555555555", "name": "311 CSV", "format": "CSV", + "datastore_active": True, } ], } @@ -667,7 +668,191 @@ async def test_execute_tool_search_and_query_dataset_has_no_resources( ) assert result.success is False - assert "no resources" in (result.error_message or "").lower() + err = (result.error_message or "").lower() + assert "no queryable" in err or "no resources" in err + + @pytest.mark.asyncio + async def test_execute_tool_search_and_query_skips_download_only_resources( + self, ckan_config + ): + """Parks-style regression: a dataset with GeoJSON/KML/SHP first and a + single datastore_active CSV — the composite tool must skip past the + download-only resources and query the CSV one.""" + plugin = CKANPlugin(ckan_config) + + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_response_init = Mock() + mock_response_init.json.return_value = {"success": True} + mock_response_init.raise_for_status = Mock() + mock_response_search = Mock() + mock_response_search.json.return_value = { + "result": { + "results": [ + { + "id": "dataset-parks", + "title": "Park_Features", + "resources": [ + { + "id": "0826fc19-4ff8-44a5-b9c4-916960d8cfb3", + "format": "GeoJSON", + "datastore_active": False, + }, + { + "id": "4d28fc98-c503-4065-987f-9fbc41947fc4", + "format": "CSV", + "datastore_active": True, + }, + { + "id": "5f130274-b67e-44e6-9c72-4175a2dca339", + "format": "SHP", + "datastore_active": False, + }, + ], + } + ] + } + } + mock_response_search.raise_for_status = Mock() + mock_response_query = Mock() + mock_response_query.json.return_value = { + "result": { + "records": [ + {"_id": 1, "park_name": "Boston Common"}, + {"_id": 2, "park_name": "Franklin Park"}, + ] + } + } + mock_response_query.raise_for_status = Mock() + mock_client.post = AsyncMock( + side_effect=[ + mock_response_init, + mock_response_search, + mock_response_query, + ] + ) + mock_client_class.return_value = mock_client + + await plugin.initialize() + result = await plugin.execute_tool( + "search_and_query", {"query": "parks"} + ) + + assert result.success is True + text = result.content[0]["text"] + # Picked the CSV resource, not the GeoJSON one + assert "4d28fc98-c503-4065-987f-9fbc41947fc4" in text + assert "0826fc19-4ff8-44a5-b9c4-916960d8cfb3" not in text + assert "Boston Common" in text or "Franklin Park" in text + # And the third call's body asked for the CSV resource + third_call = mock_client.post.call_args_list[2] + assert third_call[1]["json"]["resource_id"] == ( + "4d28fc98-c503-4065-987f-9fbc41947fc4" + ) + + @pytest.mark.asyncio + async def test_execute_tool_search_and_query_walks_to_next_dataset( + self, ckan_config + ): + """If the best-match dataset has no datastore_active resource, the + composite tool falls through to the next dataset.""" + plugin = CKANPlugin(ckan_config) + + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_response_init = Mock() + mock_response_init.json.return_value = {"success": True} + mock_response_init.raise_for_status = Mock() + mock_response_search = Mock() + mock_response_search.json.return_value = { + "result": { + "results": [ + { + "id": "dataset-no-datastore", + "title": "PDFs only", + "resources": [ + { + 
"id": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", + "format": "PDF", + "datastore_active": False, + } + ], + }, + { + "id": "dataset-with-csv", + "title": "Has CSV", + "resources": [ + { + "id": "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", + "format": "CSV", + "datastore_active": True, + } + ], + }, + ] + } + } + mock_response_search.raise_for_status = Mock() + mock_response_query = Mock() + mock_response_query.json.return_value = { + "result": {"records": [{"_id": 1, "x": 1}]} + } + mock_response_query.raise_for_status = Mock() + mock_client.post = AsyncMock( + side_effect=[ + mock_response_init, + mock_response_search, + mock_response_query, + ] + ) + mock_client_class.return_value = mock_client + + await plugin.initialize() + result = await plugin.execute_tool( + "search_and_query", {"query": "anything"} + ) + + assert result.success is True + text = result.content[0]["text"] + assert "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb" in text + assert "Has CSV" in text + + @pytest.mark.asyncio + async def test_query_data_404_includes_datastore_active_hint(self, ckan_config): + """A 404 from query_data should append the datastore_active hint.""" + plugin = CKANPlugin(ckan_config) + + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_response_init = Mock() + mock_response_init.json.return_value = {"success": True} + mock_response_init.raise_for_status = Mock() + mock_response_404 = Mock() + mock_response_404.status_code = 404 + mock_response_404.json.return_value = { + "success": False, + "error": {"message": "Resource not found"}, + } + mock_response_404.raise_for_status = Mock( + side_effect=httpx.HTTPStatusError( + "Not Found", + request=Mock(), + response=mock_response_404, + ) + ) + mock_client.post = AsyncMock( + side_effect=[mock_response_init, mock_response_404] + ) + mock_client_class.return_value = mock_client + + await plugin.initialize() + result = await plugin.execute_tool( + "query_data", + {"resource_id": "0826fc19-4ff8-44a5-b9c4-916960d8cfb3"}, + ) + + assert result.success is False + assert "datastore_active" in (result.error_message or "") @pytest.mark.asyncio async def test_execute_tool_unknown_tool(self, ckan_config): From 668e24f25cc18f67f2119e6943d7160c32920a16 Mon Sep 17 00:00:00 2001 From: brendanbabb Date: Thu, 30 Apr 2026 11:22:54 -0800 Subject: [PATCH 07/12] CKAN: alternates list in search_and_query also surfaces queryable resource_id --- plugins/ckan/plugin.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/plugins/ckan/plugin.py b/plugins/ckan/plugin.py index fc57777..659b7a3 100644 --- a/plugins/ckan/plugin.py +++ b/plugins/ckan/plugin.py @@ -1290,19 +1290,25 @@ def _format_search_and_query( lines.append( "Other matching datasets (pass dataset_index=N to switch):" ) + chosen_dataset_id = dataset.get("id") for i, alt in enumerate(alternates): - if i == 0: - continue # the chosen one + if alt.get("id") == chosen_dataset_id: + continue # skip the dataset we already returned rows for alt_title = alt.get("title", "Untitled") alt_id = alt.get("id", "unknown") - alt_resources = alt.get("resources") or [] - alt_resource_id = ( - alt_resources[0].get("id") if alt_resources else None - ) + alt_queryable = self._first_queryable_resource(alt) lines.append(f" [dataset_index={i}] {alt_title}") lines.append(f" dataset_id: {alt_id}") - if alt_resource_id: - lines.append(f" resource_id: {alt_resource_id}") + if alt_queryable: + lines.append( + f" resource_id (queryable): " + f"{alt_queryable.get('id')}" + ) + else: + 
lines.append( + " (no datastore-loaded resource — " + "download-only)" + ) return "\n".join(lines) From 5825317e335db449b1d392652b0e100fe64bff32 Mon Sep 17 00:00:00 2001 From: brendanbabb Date: Thu, 30 Apr 2026 11:50:17 -0800 Subject: [PATCH 08/12] CKAN: add structured `where`, schema footer, sharper descriptions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Real failure observed: model asked for "311 service requests closed on 4/29", called ckan__search_and_query without filters, and got 10 unfiltered rows (almost none closed on 4/29; 85 actually were). Two design problems: 1) `query` matches dataset metadata, not row content — but the model read it as a row filter. 2) `filters` is equality-only against datastore_search, so it can't express "close_date in [4/29, 4/30)" even if used. Fixes (A + B + C from the proposal): A) Tool descriptions explicitly distinguish dataset-metadata search from row filtering and steer to the right knob (filters / where / execute_sql) with concrete examples. B) Every successful query_data / search_and_query response now ends with a "Filterable columns" footer listing field names + types, pulled from the datastore response. The next pivot to where / execute_sql becomes a one-shot. C) New structured `where` argument on query_data and search_and_query. Routes through datastore_search_sql via SafeSQLBuilder; supports eq, ne, gt, gte, lt, lte, in, not_in, like, ilike, is_null. Strings are length-capped (256) and quote-escaped; IN-lists capped (100); identifiers validated as before. Live verification: search_and_query("311", where={close_date: {gte: "2026-04-29", lt: "2026-04-30"}, case_status: "Closed"}) now returns exactly 85 rows — matching ground truth from a manual SQL count. Tests: +26 new (170 -> 196). Coverage: - where-clause builder for every operator + edge case (injection escaping, bad ops/types, oversized lists/strings, unknown operators) - query_data routes through datastore_search_sql when `where` set - query_data validation errors short-circuit before any API call - schema footer surfaces in both SQL and non-SQL paths --- plugins/ckan/plugin.py | 305 +++++++++++++++++++++++++++++----- plugins/ckan/sql_validator.py | 137 +++++++++++++++ tests/test_ckan_plugin.py | 137 +++++++++++++++ tests/test_sql_validator.py | 122 ++++++++++++++ 4 files changed, 662 insertions(+), 39 deletions(-) diff --git a/plugins/ckan/plugin.py b/plugins/ckan/plugin.py index 659b7a3..721824c 100644 --- a/plugins/ckan/plugin.py +++ b/plugins/ckan/plugin.py @@ -4,7 +4,7 @@ """ import logging -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple import httpx from tenacity import ( @@ -202,7 +202,13 @@ def get_tools(self) -> List[ToolDefinition]: "properties": { "query": { "type": "string", - "description": "Free-text search keywords (e.g. '311 service requests', 'building permits').", + "description": ( + "Free-text keywords matched against " + "DATASET METADATA (title/tags/desc), NOT " + "row content. Use the row-returning " + "tools' `where` argument to filter ROWS. " + "Examples: '311', 'parks'." + ), }, "limit": { "type": "integer", @@ -259,6 +265,24 @@ def get_tools(self) -> List[ToolDefinition]: "`ckan__search_datasets` and `ckan__get_dataset` " "labels resources as QUERYABLE or DOWNLOAD-ONLY — pick " "the QUERYABLE one.\n\n" + "FILTERING — pick the right knob:\n" + " - `filters` is EQUALITY ONLY (case_status='Closed'). 
" + "Cannot do dates, ranges, BETWEEN, IN, LIKE, or any " + "comparison. A timestamp column will NEVER match an " + "equality filter on a date string like '2026-04-29'.\n" + " - `where` is structured comparison. Use this for " + "date ranges, numeric bounds, IN-lists, LIKE/ILIKE, " + "or NULL checks. Example for 'closed on 2026-04-29':\n" + " where = {\"close_date\": {\"gte\": " + "\"2026-04-29\", \"lt\": \"2026-04-30\"}, " + "\"case_status\": \"Closed\"}\n" + " - For window functions, CTEs, joins, or anything " + "the structured `where` can't express, use " + "`ckan__execute_sql` instead.\n\n" + "Note: `query` arguments on search tools match dataset " + "metadata (titles/tags), NOT row content. To filter " + "ROWS by date/status/etc., use `where` here or in " + "`ckan__search_and_query`.\n\n" "Tip: if you only have a keyword and no resource_id " "yet, use `ckan__search_and_query` instead — it does " "the lookup and the data fetch in a single call and " @@ -280,7 +304,30 @@ def get_tools(self) -> List[ToolDefinition]: }, "filters": { "type": "object", - "description": "Optional filters as field:value pairs (e.g. {\"status\": \"Open\"}).", + "description": ( + "EQUALITY-ONLY filters as field:value " + "pairs (e.g. {\"status\": \"Open\"}). For " + "ranges/dates/IN/LIKE, use `where` " + "instead — `filters` cannot express " + "anything other than exact equality." + ), + }, + "where": { + "type": "object", + "description": ( + "Structured WHERE clause supporting " + "comparison operators. Each entry is " + "either {field: scalar} (equality) or " + "{field: {op: value, ...}} where op is " + "one of: eq, ne, gt, gte, lt, lte, in, " + "not_in, like, ilike, is_null. Example " + "for 'closed on 2026-04-29': " + "{\"close_date\": {\"gte\": " + "\"2026-04-29\", \"lt\": \"2026-04-30\"}, " + "\"case_status\": \"Closed\"}. The " + "schema footer in any prior query result " + "lists available column names and types." + ), }, "limit": { "type": "integer", @@ -325,19 +372,33 @@ def get_tools(self) -> List[ToolDefinition]: description=( f"Execute a raw PostgreSQL SELECT query against " f"{city}'s CKAN datastore.\n\n" - "⚠️ Advanced users only. For complex queries requiring " - "full SQL.\n\n" + "⚠️ Use this only when the structured `where` " + "argument on `ckan__query_data` / " + "`ckan__search_and_query` cannot express your filter " + "(e.g. window functions, CTEs, joins, aggregations " + "beyond ckan__aggregate_data).\n\n" "Security: Only SELECT allowed. INSERT/UPDATE/DELETE " "blocked.\n\n" - "Examples:\n" + "Concrete examples:\n" + "- Closed on a specific date:\n" + " SELECT * FROM \"\" WHERE " + "close_date >= '2026-04-29' AND close_date < " + "'2026-04-30' AND case_status = 'Closed' LIMIT 100\n" + "- Counts by day:\n" + " SELECT date_trunc('day', close_date) AS d, " + "COUNT(*) FROM \"\" GROUP BY d " + "ORDER BY d DESC LIMIT 30\n" "- Window functions: RANK() OVER (...)\n" - "- CTEs: WITH subquery AS (...)\n" - "- Complex aggregations: PERCENTILE_CONT(0.5) WITHIN GROUP\n\n" - "Resource IDs must be double-quoted: FROM \"uuid-here\"\n\n" + "- CTEs: WITH subquery AS (...)\n\n" + "Resource IDs must be double-quoted: " + "FROM \"uuid-here\"\n\n" "Prerequisites:\n" " - resource UUID for the FROM clause: get from " "`ckan__search_datasets` or `ckan__get_dataset`.\n" - " - field names: get from `ckan__get_schema`." + " - field names: get from `ckan__get_schema`, or " + "the 'Filterable columns' footer of any prior " + "successful `ckan__query_data` / " + "`ckan__search_and_query` call." 
), input_schema={ "type": "object", @@ -418,17 +479,34 @@ def get_tools(self) -> List[ToolDefinition]: "search results and skips datasets / resources that " "aren't datastore-active, so you don't get a 404 " "from a download-only resource.\n\n" + "WHAT `query` MEANS: `query` matches dataset metadata " + "(title, tags, description) — it does NOT filter ROWS. " + "If the user asks for '311 requests closed on 4/29', " + "the `query` finds the 311 dataset and `where` does " + "the row filtering:\n" + " query=\"311\", where={\"close_date\": {\"gte\": " + "\"2026-04-29\", \"lt\": \"2026-04-30\"}, " + "\"case_status\": \"Closed\"}\n\n" "Returns: data rows from the chosen dataset's chosen " - "resource, plus a header showing which dataset and " - "resource were used so you can drill deeper with " - "`ckan__query_data` or `ckan__get_dataset` if needed." + "resource, plus a 'Filterable columns' footer listing " + "the schema (so you can refine with `where` or pivot " + "to `ckan__execute_sql` for joins/CTEs/window funcs), " + "plus a header showing which dataset and resource " + "were used." ), input_schema={ "type": "object", "properties": { "query": { "type": "string", - "description": "Free-text search keywords (e.g. '311 service requests').", + "description": ( + "Free-text keywords matched against " + "DATASET METADATA (title/tags/desc), NOT " + "row content. Use `where` (or " + "`ckan__execute_sql`) to filter ROWS by " + "date/status/etc. Examples: '311', " + "'parks', 'building permits'." + ), }, "limit": { "type": "integer", @@ -437,7 +515,26 @@ def get_tools(self) -> List[ToolDefinition]: }, "filters": { "type": "object", - "description": "Optional row-level filters as field:value pairs, applied to the matched resource.", + "description": ( + "EQUALITY-ONLY row filters (e.g. " + "{\"case_status\": \"Closed\"}). For " + "ranges/dates/IN/LIKE, use `where` " + "instead." + ), + }, + "where": { + "type": "object", + "description": ( + "Structured WHERE clause for ranges, " + "dates, IN, LIKE, NULL checks. Each " + "entry is {field: scalar} (equality) or " + "{field: {op: value, ...}} where op is " + "one of: eq, ne, gt, gte, lt, lte, in, " + "not_in, like, ilike, is_null. The " + "right knob for 'closed on 2026-04-29': " + "{\"close_date\": {\"gte\": " + "\"2026-04-29\", \"lt\": \"2026-04-30\"}}." 
+ ), }, "dataset_index": { "type": "integer", @@ -518,14 +615,28 @@ async def execute_tool( success=False, error_message="resource_id is required", ) - filters = arguments.get("filters", {}) + filters = arguments.get("filters") or {} + where = arguments.get("where") or None limit = arguments.get("limit", 100) - data = await self.query_data(resource_id, filters, limit) + records, fields, error = await self._query_with_schema( + resource_id=resource_id, + filters=filters, + limit=limit, + where=where, + ) + if error: + return ToolResult( + content=[], + success=False, + error_message=error, + ) return ToolResult( content=[ { "type": "text", - "text": self._format_query_results(data, limit), + "text": self._format_query_results( + records, fields, limit + ), } ], success=True, @@ -584,12 +695,14 @@ async def execute_tool( ) limit = arguments.get("limit", 100) filters = arguments.get("filters") or {} + where = arguments.get("where") or None dataset_index = arguments.get("dataset_index") resource_index = arguments.get("resource_index") composite = await self.search_and_query( query=query, limit=limit, filters=filters, + where=where, dataset_index=dataset_index, resource_index=resource_index, ) @@ -697,17 +810,81 @@ async def query_data( resource_id: str, filters: Optional[Dict[str, Any]] = None, limit: int = 100, + where: Optional[Dict[str, Any]] = None, ) -> List[Dict[str, Any]]: """Query data from a specific resource. Args: resource_id: Resource ID - filters: Optional filters (field: value pairs) + filters: Equality-only filters (field: value) passed to + CKAN's datastore_search limit: Maximum number of records + where: Structured WHERE spec supporting comparison operators + (gt/gte/lt/lte/in/not_in/like/ilike/is_null). When set, + routes through datastore_search_sql for a real WHERE clause. Returns: - List of data records + List of data records (the schema-aware variant is + ``_query_with_schema``). + """ + records, _fields, error = await self._query_with_schema( + resource_id=resource_id, + filters=filters, + limit=limit, + where=where, + ) + if error: + raise RuntimeError(error) + return records + + async def _query_with_schema( + self, + resource_id: str, + filters: Optional[Dict[str, Any]] = None, + limit: int = 100, + where: Optional[Dict[str, Any]] = None, + ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], Optional[str]]: + """Query datastore and return (records, fields, error_message). + + Routes through datastore_search_sql when ``where`` is set so the + caller can express ranges/IN/LIKE; otherwise falls back to the + cheaper datastore_search equality path. """ + if where: + try: + validated_id = SafeSQLBuilder.validate_resource_id(resource_id) + where_sql = SafeSQLBuilder.build_where_clause(where) + limit_int = SafeSQLBuilder.clamp_limit(limit) + except ValueError as e: + return [], [], str(e) + + sql_parts = [f'SELECT * FROM "{validated_id}"'] + if where_sql: + sql_parts.append(f" WHERE {where_sql}") + if filters: + # Equality filters can ride alongside `where` clauses. 
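+ # Each one is rendered via build_filter_condition and ANDed
+ # onto the structured clause built from `where`.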
+ try: + eq_conds = [ + SafeSQLBuilder.build_filter_condition(f, v) + for f, v in filters.items() + ] + except ValueError as e: + return [], [], str(e) + joiner = " AND " if where_sql else " WHERE " + sql_parts.append(joiner + " AND ".join(eq_conds)) + sql_parts.append(f" LIMIT {limit_int}") + sql = "".join(sql_parts) + + result = await self.execute_sql(sql) + if result.get("error"): + return [], [], result.get("message", "SQL execution failed") + return ( + result.get("records", []), + result.get("fields", []), + None, + ) + + # No `where` → cheap datastore_search path. params: Dict[str, Any] = {"resource_id": resource_id, "limit": limit} if filters: params["filters"] = filters @@ -716,12 +893,10 @@ async def query_data( response = await self._call_ckan_api("datastore_search", params) except RuntimeError as e: msg = str(e) - # 404 from datastore_search almost always means the resource UUID - # is real but isn't loaded into the queryable Postgres datastore - # (datastore_active=false). Surface that explicitly so callers - # don't keep retrying with the same UUID. if "404" in msg or "not found" in msg.lower(): - raise RuntimeError( + return ( + [], + [], f"{msg}\n" "Hint: this resource may exist as a file download " "(GeoJSON/KML/SHP/PDF) but not be loaded into the " @@ -729,10 +904,16 @@ async def query_data( "ckan__get_dataset on the parent dataset to find a " "QUERYABLE resource (typically the CSV one), or use " "ckan__search_and_query, which auto-picks the " - "datastore-loaded resource." - ) from e - raise - return response.get("result", {}).get("records", []) + "datastore-loaded resource.", + ) + return [], [], msg + + result = response.get("result", {}) + return ( + result.get("records", []), + result.get("fields", []), + None, + ) async def get_schema(self, resource_id: str) -> Dict[str, Any]: """Get schema information for a resource. @@ -869,6 +1050,7 @@ async def search_and_query( query: str, limit: int = 100, filters: Optional[Dict[str, Any]] = None, + where: Optional[Dict[str, Any]] = None, dataset_index: Optional[int] = None, resource_index: Optional[int] = None, ) -> Dict[str, Any]: @@ -998,18 +1180,18 @@ async def search_and_query( ), } - try: - records = await self.query_data( - resource_id=resource_id, - filters=filters or None, - limit=limit, - ) - except Exception as e: + records, fields, error = await self._query_with_schema( + resource_id=resource_id, + filters=filters or None, + where=where, + limit=limit, + ) + if error: return { "error": True, "message": ( f"Found dataset {chosen_dataset.get('id')!r} resource " - f"{resource_id!r} but query_data failed: {e}" + f"{resource_id!r} but query_data failed: {error}" ), } @@ -1017,6 +1199,7 @@ async def search_and_query( "dataset": chosen_dataset, "resource": chosen_resource, "records": records, + "fields": fields, "alternate_datasets": datasets, } @@ -1206,10 +1389,17 @@ def _format_dataset(self, dataset: Dict[str, Any]) -> str: return "\n".join(lines) - def _format_query_results(self, records: List[Dict[str, Any]], limit: int) -> str: + def _format_query_results( + self, + records: List[Dict[str, Any]], + fields: Optional[List[Dict[str, Any]]] = None, + limit: int = 100, + ) -> str: """Format query results for user display.""" if not records: - return "No records found matching the query." + text = "No records found matching the query." 
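+ # Attach the schema footer even when nothing matched, so the
+ # caller can see the real column names and refine its filter.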
+ schema_footer = self._format_schema_footer(fields) + return f"{text}\n\n{schema_footer}" if schema_footer else text lines = [f"Found {len(records)} record(s) (showing up to {limit}):\n"] @@ -1224,6 +1414,37 @@ def _format_query_results(self, records: List[Dict[str, Any]], limit: int) -> st if len(records) > 5: lines.append(f"... and {len(records) - 5} more record(s)") + schema_footer = self._format_schema_footer(fields) + if schema_footer: + lines.append("") + lines.append(schema_footer) + + return "\n".join(lines) + + def _format_schema_footer( + self, fields: Optional[List[Dict[str, Any]]] + ) -> str: + """Render a per-call 'Filterable columns' block listing every field + the model can pass to ``where``, ``filters``, or reference in + ``execute_sql``. + + We surface this on every successful row-returning call so the next + pivot (e.g. 'now filter by close_date') is a one-shot.""" + if not fields: + return "" + usable = [ + f for f in fields if f.get("id") and f.get("id") != "_id" + ] + if not usable: + return "" + lines = [ + "Filterable columns (use these names in `where`, `filters`, " + "or `execute_sql`):" + ] + for f in usable: + fid = f.get("id") + ftype = f.get("type", "?") + lines.append(f" - {fid} ({ftype})") return "\n".join(lines) def _format_schema(self, fields: List[Dict[str, Any]]) -> str: @@ -1251,6 +1472,7 @@ def _format_search_and_query( dataset = composite.get("dataset", {}) or {} resource = composite.get("resource", {}) or {} records = composite.get("records", []) or [] + fields = composite.get("fields", []) or [] alternates = composite.get("alternate_datasets", []) or [] dataset_id = dataset.get("id", "unknown") @@ -1285,6 +1507,11 @@ def _format_search_and_query( if len(records) > 5: lines.append(f"... and {len(records) - 5} more record(s)") + schema_footer = self._format_schema_footer(fields) + if schema_footer: + lines.append("") + lines.append(schema_footer) + if len(alternates) > 1: lines.append("") lines.append( diff --git a/plugins/ckan/sql_validator.py b/plugins/ckan/sql_validator.py index ae73524..99a8114 100644 --- a/plugins/ckan/sql_validator.py +++ b/plugins/ckan/sql_validator.py @@ -368,6 +368,143 @@ def build_filter_condition(field: Any, value: Any) -> str: f"{type(value).__name__}" ) + # Operator → SQL fragment for build_where_clause. Restricted to a known- + # safe set; arbitrary operators are rejected. + _COMPARISON_OPS = { + "eq": "=", + "ne": "!=", + "gt": ">", + "gte": ">=", + "lt": "<", + "lte": "<=", + } + _MAX_STRING_VALUE_LEN = 256 + _MAX_IN_LIST_LEN = 100 + ALLOWED_WHERE_OPS = ( + "eq", + "ne", + "gt", + "gte", + "lt", + "lte", + "in", + "not_in", + "like", + "ilike", + "is_null", + ) + + @classmethod + def _format_scalar(cls, value: Any, op_label: str) -> str: + """Format a scalar as a SQL literal. Strings are single-quoted with + embedded quotes escaped; numbers and bools are inlined.""" + if isinstance(value, bool): + return "TRUE" if value else "FALSE" + if isinstance(value, (int, float)): + return str(value) + if isinstance(value, str): + if len(value) > cls._MAX_STRING_VALUE_LEN: + raise ValueError( + f"value too long for {op_label!r}: " + f"{len(value)} > {cls._MAX_STRING_VALUE_LEN}" + ) + return "'" + value.replace("'", "''") + "'" + raise ValueError( + f"unsupported value type for {op_label!r}: " + f"{type(value).__name__}" + ) + + @classmethod + def build_where_clause(cls, where: Any) -> str: + """Build a parameter-validated SQL WHERE fragment from a structured + spec. 
+
+ Accepted shapes per field:
+ - ``"col": <scalar>`` — equality (or ``IS NULL`` if ``None``).
+ - ``"col": {"op": value, ...}`` — one or more comparison clauses
+ ANDed together. Allowed ops:
+
+ ``eq``, ``ne``, ``gt``, ``gte``, ``lt``, ``lte`` — scalar value.
+ ``in``, ``not_in`` — list of scalars.
+ ``like``, ``ilike`` — string pattern.
+ ``is_null`` — bool.
+
+ Returns the WHERE fragment WITHOUT the leading ``WHERE`` (or ``""``
+ if ``where`` is empty/None). Raises ``ValueError`` on any disallowed
+ operator, identifier, or value.
+ """
+ if where in (None, {}):
+ return ""
+ if not isinstance(where, dict):
+ raise ValueError(
+ f"where must be a dict, got: {type(where).__name__}"
+ )
+
+ parts: List[str] = []
+ for field, spec in where.items():
+ quoted = SafeSQLBuilder.quote_identifier(field)
+ if not isinstance(spec, dict):
+ # Scalar shorthand → equality / IS NULL.
+ parts.append(SafeSQLBuilder.build_filter_condition(field, spec))
+ continue
+ if not spec:
+ raise ValueError(
+ f"empty operator dict for field {field!r}"
+ )
+ for op, val in spec.items():
+ if not isinstance(op, str):
+ raise ValueError(
+ f"operator must be a string for {field!r}: {op!r}"
+ )
+ op_lower = op.lower()
+ if op_lower in cls._COMPARISON_OPS:
+ sql_op = cls._COMPARISON_OPS[op_lower]
+ parts.append(
+ f"{quoted} {sql_op} {cls._format_scalar(val, op_lower)}"
+ )
+ elif op_lower in ("in", "not_in"):
+ if not isinstance(val, list) or not val:
+ raise ValueError(
+ f"{op_lower!r} requires a non-empty list for "
+ f"{field!r}, got: {val!r}"
+ )
+ if len(val) > cls._MAX_IN_LIST_LEN:
+ raise ValueError(
+ f"{op_lower!r} list too long for {field!r}: "
+ f"{len(val)} > {cls._MAX_IN_LIST_LEN}"
+ )
+ items = ", ".join(
+ cls._format_scalar(v, op_lower) for v in val
+ )
+ sql_kw = "IN" if op_lower == "in" else "NOT IN"
+ parts.append(f"{quoted} {sql_kw} ({items})")
+ elif op_lower in ("like", "ilike"):
+ if not isinstance(val, str):
+ raise ValueError(
+ f"{op_lower!r} requires a string pattern for "
+ f"{field!r}, got: {val!r}"
+ )
+ parts.append(
+ f"{quoted} {op_lower.upper()} "
+ f"{cls._format_scalar(val, op_lower)}"
+ )
+ elif op_lower == "is_null":
+ if not isinstance(val, bool):
+ raise ValueError(
+ f"'is_null' requires a bool for {field!r}, "
+ f"got: {val!r}"
+ )
+ parts.append(
+ f"{quoted} {'IS NULL' if val else 'IS NOT NULL'}"
+ )
+ else:
+ raise ValueError(
+ f"Unknown operator {op!r} for field {field!r}. "
+ f"Allowed: {', '.join(cls.ALLOWED_WHERE_OPS)}"
+ )
+
+ return " AND ".join(parts)
+
 @staticmethod
 def validate_order_by(order_by: Any) -> str:
 """Validate an ``ORDER BY`` clause: ``<column> [ASC|DESC]``."""
diff --git a/tests/test_ckan_plugin.py b/tests/test_ckan_plugin.py
index fd27df1..2e1e773 100644
--- a/tests/test_ckan_plugin.py
+++ b/tests/test_ckan_plugin.py
@@ -817,6 +817,143 @@ async def test_execute_tool_search_and_query_walks_to_next_dataset(
 assert "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb" in text
 assert "Has CSV" in text

+ @pytest.mark.asyncio
+ async def test_execute_tool_query_data_with_where_uses_sql_endpoint(
+ self, ckan_config
+ ):
+ """When `where` is supplied, query_data must route through
+ datastore_search_sql with a built WHERE clause — not through
+ datastore_search (equality-only)."""
+ plugin = CKANPlugin(ckan_config)
+
+ with patch("httpx.AsyncClient") as mock_client_class:
+ mock_client = AsyncMock()
+ mock_response_init = Mock()
+ mock_response_init.json.return_value = {"success": True}
+ mock_response_init.raise_for_status = Mock()
+ mock_response_sql = Mock()
+ mock_response_sql.json.return_value = {
+ "result": {
+ "records": [
+ {"_id": 1, "case_id": "BCS-1", "case_status": "Closed"}
+ ],
+ "fields": [
+ {"id": "case_id", "type": "text"},
+ {"id": "close_date", "type": "timestamp"},
+ {"id": "case_status", "type": "text"},
+ ],
+ }
+ }
+ mock_response_sql.raise_for_status = Mock()
+ mock_client.post = AsyncMock(
+ side_effect=[mock_response_init, mock_response_sql]
+ )
+ mock_client_class.return_value = mock_client
+
+ await plugin.initialize()
+ result = await plugin.execute_tool(
+ "query_data",
+ {
+ "resource_id": "11111111-2222-3333-4444-555555555555",
+ "where": {
+ "close_date": {
+ "gte": "2026-04-29",
+ "lt": "2026-04-30",
+ },
+ "case_status": "Closed",
+ },
+ "limit": 5,
+ },
+ )
+
+ assert result.success is True
+ text = result.content[0]["text"]
+ assert "BCS-1" in text
+ # Schema footer surfaces filterable columns
+ assert "Filterable columns" in text
+ assert "close_date" in text
+ # Verify the second POST hit datastore_search_sql with a SQL
+ # body containing the expected WHERE clause.
+ second_call = mock_client.post.call_args_list[1]
+ assert second_call[0][0] == "/api/3/action/datastore_search_sql"
+ sql = second_call[1]["json"]["sql"]
+ assert (
+ 'FROM "11111111-2222-3333-4444-555555555555"' in sql
+ )
+ assert '"close_date" >= \'2026-04-29\'' in sql
+ assert '"close_date" < \'2026-04-30\'' in sql
+ assert '"case_status" = \'Closed\'' in sql
+ assert "LIMIT 5" in sql
+
+ @pytest.mark.asyncio
+ async def test_execute_tool_query_data_where_validation_error_surfaces(
+ self, ckan_config
+ ):
+ """A bad `where` operator returns a clean error — no API call."""
+ plugin = CKANPlugin(ckan_config)
+
+ with patch("httpx.AsyncClient") as mock_client_class:
+ mock_client = AsyncMock()
+ mock_response_init = Mock()
+ mock_response_init.json.return_value = {"success": True}
+ mock_response_init.raise_for_status = Mock()
+ mock_client.post = AsyncMock(return_value=mock_response_init)
+ mock_client_class.return_value = mock_client
+
+ await plugin.initialize()
+ result = await plugin.execute_tool(
+ "query_data",
+ {
+ "resource_id": "11111111-2222-3333-4444-555555555555",
+ "where": {"col": {"regex": "."}},
+ },
+ )
+
+ assert result.success is False
+ assert "Unknown operator" in (result.error_message or "")
+ # Only the init POST should have happened — no SQL call.
+ assert mock_client.post.call_count == 1 + + @pytest.mark.asyncio + async def test_execute_tool_query_data_schema_footer_in_normal_path( + self, ckan_config + ): + """The non-SQL (no `where`) path also returns the schema footer.""" + plugin = CKANPlugin(ckan_config) + + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_response_init = Mock() + mock_response_init.json.return_value = {"success": True} + mock_response_init.raise_for_status = Mock() + mock_response_query = Mock() + mock_response_query.json.return_value = { + "result": { + "records": [{"_id": 1, "x": "y"}], + "fields": [ + {"id": "x", "type": "text"}, + {"id": "z", "type": "int"}, + ], + } + } + mock_response_query.raise_for_status = Mock() + mock_client.post = AsyncMock( + side_effect=[mock_response_init, mock_response_query] + ) + mock_client_class.return_value = mock_client + + await plugin.initialize() + result = await plugin.execute_tool( + "query_data", + {"resource_id": "11111111-2222-3333-4444-555555555555"}, + ) + + assert result.success is True + text = result.content[0]["text"] + assert "Filterable columns" in text + assert "x (text)" in text + assert "z (int)" in text + @pytest.mark.asyncio async def test_query_data_404_includes_datastore_active_hint(self, ckan_config): """A 404 from query_data should append the datastore_active hint.""" diff --git a/tests/test_sql_validator.py b/tests/test_sql_validator.py index 4dfb980..42eedf9 100644 --- a/tests/test_sql_validator.py +++ b/tests/test_sql_validator.py @@ -653,6 +653,128 @@ def test_unsupported_value_type_rejected(self): SafeSQLBuilder.build_filter_condition("name", ["list"]) +class TestSafeSQLBuilderWhereClause: + def test_empty_returns_empty_string(self): + assert SafeSQLBuilder.build_where_clause(None) == "" + assert SafeSQLBuilder.build_where_clause({}) == "" + + def test_scalar_shorthand_equality(self): + got = SafeSQLBuilder.build_where_clause({"status": "Closed"}) + assert got == "\"status\" = 'Closed'" + + def test_scalar_none_is_null(self): + got = SafeSQLBuilder.build_where_clause({"close_date": None}) + assert got == '"close_date" IS NULL' + + def test_date_range(self): + got = SafeSQLBuilder.build_where_clause( + {"close_date": {"gte": "2026-04-29", "lt": "2026-04-30"}} + ) + assert got == ( + "\"close_date\" >= '2026-04-29' AND " + "\"close_date\" < '2026-04-30'" + ) + + def test_mixed_scalar_and_range(self): + got = SafeSQLBuilder.build_where_clause( + { + "close_date": {"gte": "2026-04-29", "lt": "2026-04-30"}, + "case_status": "Closed", + } + ) + assert "\"case_status\" = 'Closed'" in got + assert "\"close_date\" >= '2026-04-29'" in got + assert "\"close_date\" < '2026-04-30'" in got + assert " AND " in got + + def test_numeric_comparisons(self): + got = SafeSQLBuilder.build_where_clause({"count": {"gt": 5, "lte": 10}}) + assert got == '"count" > 5 AND "count" <= 10' + + def test_in_list_strings(self): + got = SafeSQLBuilder.build_where_clause( + {"neighborhood": {"in": ["Roxbury", "Dorchester"]}} + ) + assert got == "\"neighborhood\" IN ('Roxbury', 'Dorchester')" + + def test_not_in_list(self): + got = SafeSQLBuilder.build_where_clause( + {"status": {"not_in": ["Open", "Pending"]}} + ) + assert got == "\"status\" NOT IN ('Open', 'Pending')" + + def test_like_escaped(self): + got = SafeSQLBuilder.build_where_clause( + {"address": {"like": "%Beacon%"}} + ) + assert got == "\"address\" LIKE '%Beacon%'" + + def test_ilike(self): + got = SafeSQLBuilder.build_where_clause( + {"name": {"ilike": "boston%"}} + ) + 
assert got == "\"name\" ILIKE 'boston%'" + + def test_is_null_true(self): + got = SafeSQLBuilder.build_where_clause({"close_date": {"is_null": True}}) + assert got == '"close_date" IS NULL' + + def test_is_null_false(self): + got = SafeSQLBuilder.build_where_clause( + {"close_date": {"is_null": False}} + ) + assert got == '"close_date" IS NOT NULL' + + def test_quote_injection_in_string_value_escaped(self): + got = SafeSQLBuilder.build_where_clause( + {"name": {"eq": "x' OR 1=1--"}} + ) + assert got == "\"name\" = 'x'' OR 1=1--'" + + def test_quote_injection_in_in_list_escaped(self): + got = SafeSQLBuilder.build_where_clause( + {"name": {"in": ["a", "b' OR 1=1--"]}} + ) + assert got == "\"name\" IN ('a', 'b'' OR 1=1--')" + + def test_unknown_operator_rejected(self): + with pytest.raises(ValueError, match="Unknown operator"): + SafeSQLBuilder.build_where_clause({"x": {"regex": "."}}) + + def test_bad_field_rejected(self): + with pytest.raises(ValueError): + SafeSQLBuilder.build_where_clause({"x; DROP TABLE": 1}) + + def test_in_requires_list(self): + with pytest.raises(ValueError, match="non-empty list"): + SafeSQLBuilder.build_where_clause({"x": {"in": "single"}}) + + def test_in_rejects_empty_list(self): + with pytest.raises(ValueError, match="non-empty list"): + SafeSQLBuilder.build_where_clause({"x": {"in": []}}) + + def test_in_rejects_oversized_list(self): + with pytest.raises(ValueError, match="too long"): + SafeSQLBuilder.build_where_clause({"x": {"in": list(range(101))}}) + + def test_like_rejects_non_string(self): + with pytest.raises(ValueError, match="string pattern"): + SafeSQLBuilder.build_where_clause({"x": {"like": 5}}) + + def test_is_null_rejects_non_bool(self): + with pytest.raises(ValueError, match="bool"): + SafeSQLBuilder.build_where_clause({"x": {"is_null": "yes"}}) + + def test_overlong_string_rejected(self): + big = "a" * 1000 + with pytest.raises(ValueError, match="too long"): + SafeSQLBuilder.build_where_clause({"x": {"eq": big}}) + + def test_non_dict_top_level_rejected(self): + with pytest.raises(ValueError, match="must be a dict"): + SafeSQLBuilder.build_where_clause(["x"]) + + class TestSafeSQLBuilderOrderAndLimit: def test_order_by_plain(self): assert SafeSQLBuilder.validate_order_by("date") == '"date"' From c1608e211c891305cedcbed665ab366f456bbf58 Mon Sep 17 00:00:00 2001 From: brendanbabb Date: Thu, 30 Apr 2026 12:24:03 -0800 Subject: [PATCH 09/12] CKAN: surface sibling resources + add resource_name selector MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Boston's 311 dataset attaches 22 queryable CSV resources (a rolling 'NEW SYSTEM' view plus per-year archives 2011–2026). The model only ever saw whichever the auto-pick chose first — typically the rolling view — which silently dropped older data on the floor for historical questions. Two improvements: (1) `_format_search_and_query` now emits an "Other queryable resources in this dataset" block listing every datastore_active resource other than the chosen one (name, format, resource_id), so the model can see that '311 - 2020', '311 - 2019', etc. exist. (2) New `resource_name` argument on search_and_query: case-insensitive substring match against resource `name`. Takes precedence over `resource_index`. With dataset_index pinned, a no-match returns a clean error listing the available names. 
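The matching rule, as a standalone sketch (illustrative only, not the
plugin's exact code; `resources` and `pick_by_name` are stand-ins for a
dataset's package_show resources array and the new helper):

    def pick_by_name(resources, name_query):
        # First datastore-loaded resource whose name contains the query,
        # case-insensitively ('2020' matches '311 SERVICE REQUESTS - 2020').
        needle = name_query.casefold()
        for r in resources:
            name = (r.get("name") or "").casefold()
            if r.get("datastore_active") and needle in name:
                return r
        return None  # caller errors out (pinned dataset) or tries the next one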
Selection precedence (high → low): resource_name > resource_index > first datastore_active resource Tool description updated with concrete example (resource_name='2020' picks the 2020 archive). The 'Other queryable resources' block tells the model these names exist; the description tells it how to use them. Live verification: search_and_query("311", resource_name="2020") now queries 311 SERVICE REQUESTS - 2020 (resource 6ff6a6fd-...) and returns 2020 records. Default no-resource_name query still picks the rolling NEW SYSTEM and surfaces all 21 archive siblings in the response. Tests: +3 (38 -> 41) covering resource_name match, no-match error, and the siblings block (queryable-only, excludes download-only resources). --- plugins/ckan/plugin.py | 151 ++++++++++++++++++++++++-- tests/test_ckan_plugin.py | 216 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 359 insertions(+), 8 deletions(-) diff --git a/plugins/ckan/plugin.py b/plugins/ckan/plugin.py index 721824c..0d36206 100644 --- a/plugins/ckan/plugin.py +++ b/plugins/ckan/plugin.py @@ -491,8 +491,20 @@ def get_tools(self) -> List[ToolDefinition]: "resource, plus a 'Filterable columns' footer listing " "the schema (so you can refine with `where` or pivot " "to `ckan__execute_sql` for joins/CTEs/window funcs), " - "plus a header showing which dataset and resource " - "were used." + "an 'Other queryable resources in this dataset' " + "block listing siblings (e.g. per-year archives), " + "and a header showing which dataset and resource " + "were used.\n\n" + "MULTI-RESOURCE DATASETS: a single dataset can hold " + "many queryable resources. Boston's 311 dataset has " + "22 (a rolling 'NEW SYSTEM' view plus per-year " + "archives 2011–2026). Use `resource_name` to pick a " + "specific one — e.g. resource_name=\"2020\" picks " + "'311 Service Requests - 2020'. If you don't pass " + "`resource_name`, the first datastore-loaded " + "resource is used (which is typically the rolling " + "current view, NOT historical archives — so older " + "questions need `resource_name`)." ), input_schema={ "type": "object", @@ -552,7 +564,28 @@ def get_tools(self) -> List[ToolDefinition]: "to query. If omitted, auto-picks the " "first datastore_active resource (Boston " "datasets typically attach 5–7 resources " - "but only the CSV is queryable)." + "but only the CSV is queryable). " + "`resource_name` takes precedence." + ), + }, + "resource_name": { + "type": "string", + "description": ( + "Case-insensitive substring match on a " + "resource's `name`. Use this to pick a " + "specific archive when a dataset has " + "multiple queryable resources (e.g. " + "Boston's 311 dataset has per-year " + "archives '311 Service Requests - 2020', " + "'... - 2021', etc., plus a rolling " + "'NEW SYSTEM'). Examples: " + "resource_name=\"2020\" picks the 2020 " + "archive; resource_name=\"NEW SYSTEM\" " + "picks the rolling current view. The " + "alternates list in any prior " + "search_and_query response shows " + "available names. Takes precedence over " + "`resource_index`." 
), }, }, @@ -698,6 +731,7 @@ async def execute_tool( where = arguments.get("where") or None dataset_index = arguments.get("dataset_index") resource_index = arguments.get("resource_index") + resource_name = arguments.get("resource_name") composite = await self.search_and_query( query=query, limit=limit, @@ -705,6 +739,7 @@ async def execute_tool( where=where, dataset_index=dataset_index, resource_index=resource_index, + resource_name=resource_name, ) if composite.get("error"): return ToolResult( @@ -1045,6 +1080,34 @@ async def aggregate_data( return await self.execute_sql(sql) + @staticmethod + def _queryable_resources(dataset: Dict[str, Any]) -> List[Dict[str, Any]]: + """All datastore_active resources in a dataset, in package_show order.""" + return [r for r in (dataset.get("resources") or []) if r.get("datastore_active")] + + @classmethod + def _resource_by_name( + cls, + dataset: Dict[str, Any], + name_query: str, + queryable_only: bool = True, + ) -> Optional[Dict[str, Any]]: + """Pick the first resource whose `name` contains `name_query` + (case-insensitive substring).""" + if not name_query: + return None + needle = name_query.casefold() + candidates = ( + cls._queryable_resources(dataset) + if queryable_only + else (dataset.get("resources") or []) + ) + for r in candidates: + res_name = (r.get("name") or "").casefold() + if needle in res_name: + return r + return None + async def search_and_query( self, query: str, @@ -1053,16 +1116,18 @@ async def search_and_query( where: Optional[Dict[str, Any]] = None, dataset_index: Optional[int] = None, resource_index: Optional[int] = None, + resource_name: Optional[str] = None, ) -> Dict[str, Any]: - """Search for a dataset and immediately query its first queryable resource. + """Search for a dataset and immediately query a queryable resource. Combines search_datasets + query_data into one server-side step so callers don't have to extract a resource_id from a previous response. - Picks the first resource where datastore_active=true (skipping - download-only resources like GeoJSON/KML/SHP/PDF). When the caller - leaves dataset_index/resource_index unset, walks the search results - until one with a queryable resource is found. + Resource selection precedence (highest to lowest): + 1. ``resource_name`` (substring match on resource ``name``). + 2. ``resource_index`` (explicit position in the dataset's + resources array). + 3. First ``datastore_active`` resource in the dataset. Returns: Dict with either {"error": True, "message": ...} or @@ -1113,6 +1178,36 @@ async def search_and_query( ) continue + # 1) name match wins if provided + if resource_name: + matched = self._resource_by_name(ds, resource_name) + if matched is not None: + chosen_dataset, chosen_resource = ds, matched + break + # Only error out if the user fixed the dataset too. + if explicit_dataset: + queryable_names = [ + r.get("name") or "(unnamed)" + for r in self._queryable_resources(ds) + ] + return { + "error": True, + "message": ( + f"No queryable resource in dataset " + f"{ds.get('id')!r} has a name matching " + f"{resource_name!r}. Available queryable " + f"resource names: " + f"{queryable_names or '(none)'}." + ), + } + # Otherwise fall through and try the next dataset. 
+ skipped_summary.append( + f" [{idx}] {ds.get('title') or ds.get('id')}: " + f"no resource name matching {resource_name!r}" + ) + continue + + # 2) explicit positional pick if explicit_resource: if resource_index < 0 or resource_index >= len(resources): return { @@ -1138,6 +1233,7 @@ async def search_and_query( chosen_dataset, chosen_resource = ds, resource break + # 3) auto-pick the first queryable resource queryable = self._first_queryable_resource(ds) if queryable: chosen_dataset, chosen_resource = ds, queryable @@ -1154,6 +1250,24 @@ async def search_and_query( f"no datastore-loaded resource (formats: {', '.join(formats)})" ) + # If we walked all datasets and resource_name was set but never + # matched, give a name-specific error rather than the generic + # "no queryable resource" one. + if ( + chosen_dataset is None + and resource_name + and not explicit_dataset + ): + return { + "error": True, + "message": ( + f"No dataset in the {len(datasets)} matches for query " + f"{query!r} has a queryable resource whose name " + f"matches {resource_name!r}.\nSkipped:\n" + + ("\n".join(skipped_summary) or " (no datasets inspected)") + ), + } + if chosen_dataset is None or chosen_resource is None: details = ( "\n".join(skipped_summary) @@ -1512,6 +1626,27 @@ def _format_search_and_query( lines.append("") lines.append(schema_footer) + # Sibling queryable resources within the SAME dataset. Boston's 311 + # dataset has 22 (a rolling view + per-year archives back to 2011) + # — without this block the model can't see them. + sibling_queryables = self._queryable_resources(dataset) + chosen_resource_id = resource.get("id") + siblings = [ + r for r in sibling_queryables if r.get("id") != chosen_resource_id + ] + if siblings: + lines.append("") + lines.append( + f"Other queryable resources in this dataset " + f"(pass resource_name=... to pick one):" + ) + for r in siblings: + r_name = r.get("name") or "(unnamed)" + r_fmt = r.get("format") or "?" + r_id = r.get("id") or "?" + lines.append(f" - {r_name} [{r_fmt}]") + lines.append(f" resource_id: {r_id}") + if len(alternates) > 1: lines.append("") lines.append( diff --git a/tests/test_ckan_plugin.py b/tests/test_ckan_plugin.py index 2e1e773..e58c32a 100644 --- a/tests/test_ckan_plugin.py +++ b/tests/test_ckan_plugin.py @@ -750,6 +750,222 @@ async def test_execute_tool_search_and_query_skips_download_only_resources( "4d28fc98-c503-4065-987f-9fbc41947fc4" ) + @pytest.mark.asyncio + async def test_execute_tool_search_and_query_resource_name_picks_archive( + self, ckan_config + ): + """Boston-style regression: a 311 dataset with a rolling NEW SYSTEM + plus per-year archives. 
resource_name='2020' must pick the 2020 + archive, not the first datastore_active resource.""" + plugin = CKANPlugin(ckan_config) + + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_response_init = Mock() + mock_response_init.json.return_value = {"success": True} + mock_response_init.raise_for_status = Mock() + mock_response_search = Mock() + mock_response_search.json.return_value = { + "result": { + "results": [ + { + "id": "dataset-311", + "title": "311 Service Requests", + "resources": [ + { + "id": "new-uuid", + "name": "311 Service Requests - NEW SYSTEM", + "format": "CSV", + "datastore_active": True, + }, + { + "id": "2020-uuid", + "name": "311 SERVICE REQUESTS - 2020", + "format": "CSV", + "datastore_active": True, + }, + { + "id": "2021-uuid", + "name": "311 SERVICE REQUESTS - 2021", + "format": "CSV", + "datastore_active": True, + }, + ], + } + ] + } + } + mock_response_search.raise_for_status = Mock() + mock_response_query = Mock() + mock_response_query.json.return_value = { + "result": { + "records": [{"_id": 1, "case_id": "X-2020"}], + "fields": [{"id": "case_id", "type": "text"}], + } + } + mock_response_query.raise_for_status = Mock() + mock_client.post = AsyncMock( + side_effect=[ + mock_response_init, + mock_response_search, + mock_response_query, + ] + ) + mock_client_class.return_value = mock_client + + await plugin.initialize() + result = await plugin.execute_tool( + "search_and_query", + {"query": "311", "resource_name": "2020"}, + ) + + assert result.success is True + text = result.content[0]["text"] + # Picked the 2020 archive (case-insensitive substring on name) + assert "2020-uuid" in text + assert "X-2020" in text + # Sibling block surfaces the other queryable resources + assert "Other queryable resources in this dataset" in text + assert "311 Service Requests - NEW SYSTEM" in text + assert "311 SERVICE REQUESTS - 2021" in text + # And the third call's body actually queried 2020-uuid + third_call = mock_client.post.call_args_list[2] + assert third_call[1]["json"]["resource_id"] == "2020-uuid" + + @pytest.mark.asyncio + async def test_execute_tool_search_and_query_resource_name_no_match_errors( + self, ckan_config + ): + """resource_name with no match returns a clean error listing names.""" + plugin = CKANPlugin(ckan_config) + + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_response_init = Mock() + mock_response_init.json.return_value = {"success": True} + mock_response_init.raise_for_status = Mock() + mock_response_search = Mock() + mock_response_search.json.return_value = { + "result": { + "results": [ + { + "id": "dataset-311", + "title": "311", + "resources": [ + { + "id": "new-uuid", + "name": "NEW SYSTEM", + "format": "CSV", + "datastore_active": True, + } + ], + } + ] + } + } + mock_response_search.raise_for_status = Mock() + mock_client.post = AsyncMock( + side_effect=[mock_response_init, mock_response_search] + ) + mock_client_class.return_value = mock_client + + await plugin.initialize() + result = await plugin.execute_tool( + "search_and_query", + { + "query": "311", + "resource_name": "1999", + "dataset_index": 0, + }, + ) + + assert result.success is False + err = result.error_message or "" + assert "1999" in err + assert "NEW SYSTEM" in err + + @pytest.mark.asyncio + async def test_execute_tool_search_and_query_siblings_block_lists_archives( + self, ckan_config + ): + """Siblings block lists every queryable resource of the chosen + dataset other than the chosen one — 
even when resource_name is + not used.""" + plugin = CKANPlugin(ckan_config) + + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_response_init = Mock() + mock_response_init.json.return_value = {"success": True} + mock_response_init.raise_for_status = Mock() + mock_response_search = Mock() + mock_response_search.json.return_value = { + "result": { + "results": [ + { + "id": "ds-311", + "title": "311", + "resources": [ + { + "id": "new", + "name": "NEW SYSTEM", + "format": "CSV", + "datastore_active": True, + }, + { + "id": "y2025", + "name": "311 - 2025", + "format": "CSV", + "datastore_active": True, + }, + { + "id": "y2024", + "name": "311 - 2024", + "format": "CSV", + "datastore_active": True, + }, + { + "id": "geojson", + "name": "GeoJSON", + "format": "GeoJSON", + "datastore_active": False, + }, + ], + } + ] + } + } + mock_response_search.raise_for_status = Mock() + mock_response_query = Mock() + mock_response_query.json.return_value = { + "result": {"records": [{"_id": 1}], "fields": []} + } + mock_response_query.raise_for_status = Mock() + mock_client.post = AsyncMock( + side_effect=[ + mock_response_init, + mock_response_search, + mock_response_query, + ] + ) + mock_client_class.return_value = mock_client + + await plugin.initialize() + result = await plugin.execute_tool( + "search_and_query", {"query": "311"} + ) + + assert result.success is True + text = result.content[0]["text"] + assert "Other queryable resources in this dataset" in text + assert "311 - 2025" in text + assert "311 - 2024" in text + # Only QUERYABLE siblings — the GeoJSON should not appear + # in the siblings block + assert "GeoJSON" not in text.split( + "Other queryable resources in this dataset" + )[1] + @pytest.mark.asyncio async def test_execute_tool_search_and_query_walks_to_next_dataset( self, ckan_config From d9b6fca36406b1fe73f17e9c849fe5c59fd9bc62 Mon Sep 17 00:00:00 2001 From: brendanbabb Date: Thu, 30 Apr 2026 12:47:37 -0800 Subject: [PATCH 10/12] CKAN: surface true total + warn on limit-truncation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Real-world failure: bot answered "How many 311 requests closed on 4/29/2016?" with 100 — but 100 was the LIMIT; the true count is 531. The model was reading "Found 100 record(s) (showing up to 100)" as the answer to a counting question. Fixes: 1. Capture CKAN's `total` from datastore_search responses (it's already returned by CKAN; we were discarding it). Plumbed through _query_with_schema → ToolResult → formatter. 2. SQL (`where`) path: when SELECT * hits the limit exactly we don't know the true total, so issue a cheap SELECT COUNT(*) follow-up with the same WHERE. This means we always tell the model a real number when it asks a counting question via search_and_query or query_data, regardless of which path was taken. 3. Formatter rewrite: - Header line is now `total_matching_rows: N (returned K, limit=L)` when total is known and exceeds returned. Removed the misleading "Found 100 record(s)" wording that conflated rows-returned with true count. - When truncation is detected, prepend a `=== TRUNCATED ===` block stating "the answer is N, NOT K" and pointing at ckan__aggregate_data for cheap counts and ckan__execute_sql for custom LIMIT/ORDER BY. - When total can't be determined (count follow-up failed) and rows == limit, prepend `=== MAY BE TRUNCATED ===`. 
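The SQL-path decision in fix 2, sketched (`run_sql` here is a
hypothetical stand-in for the plugin's execute_sql; the real code also
tolerates a failed count):

    rows = run_sql(f'SELECT * FROM "{rid}" WHERE {where_sql} LIMIT {limit}')
    if len(rows) < limit:
        total = len(rows)  # under the cap, so len() already IS the total
    else:
        # hit the cap: true total unknown, re-ask with the same WHERE
        counted = run_sql(f'SELECT COUNT(*) AS n FROM "{rid}" WHERE {where_sql}')
        total = counted[0]["n"] if counted else None  # None -> MAY BE TRUNCATED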
Live verification (boston prod CKAN): - search_and_query "311 closed on 2016-04-29" with default limit=100 → returns the 100 sample rows AND header "total_matching_rows: 531" AND the TRUNCATED warning telling the model the answer is 531. - search_and_query "311 closed on 2026-04-29" with limit=100 → returns 85 rows, no TRUNCATED warning (85 < 100), header reads "total_matching_rows: 85". Tests: +5 (41 -> 46) covering total surfaced from datastore_search, no-warning path when under limit, COUNT(*) follow-up fires only when truncated on the SQL path, and graceful fallback when count fails. --- plugins/ckan/plugin.py | 247 ++++++++++++++++++++++++++++++++----- tests/test_ckan_plugin.py | 248 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 463 insertions(+), 32 deletions(-) diff --git a/plugins/ckan/plugin.py b/plugins/ckan/plugin.py index 0d36206..a2bfa97 100644 --- a/plugins/ckan/plugin.py +++ b/plugins/ckan/plugin.py @@ -651,7 +651,7 @@ async def execute_tool( filters = arguments.get("filters") or {} where = arguments.get("where") or None limit = arguments.get("limit", 100) - records, fields, error = await self._query_with_schema( + records, fields, total, error = await self._query_with_schema( resource_id=resource_id, filters=filters, limit=limit, @@ -668,7 +668,7 @@ async def execute_tool( { "type": "text", "text": self._format_query_results( - records, fields, limit + records, fields, limit, total ), } ], @@ -862,7 +862,7 @@ async def query_data( List of data records (the schema-aware variant is ``_query_with_schema``). """ - records, _fields, error = await self._query_with_schema( + records, _fields, _total, error = await self._query_with_schema( resource_id=resource_id, filters=filters, limit=limit, @@ -872,14 +872,66 @@ async def query_data( raise RuntimeError(error) return records + async def _count_via_sql( + self, + resource_id: str, + where_sql: str, + filters: Optional[Dict[str, Any]] = None, + ) -> Optional[int]: + """Run a SELECT COUNT(*) with the same filters to discover the true + row total when a SELECT * hit the limit. + + Returns ``None`` if the count call itself fails (we'd rather show + a 'TRUNCATED' warning than block the data response on a failed + count). Returns the integer total on success. + """ + try: + sql_parts = [f'SELECT COUNT(*) AS n FROM "{resource_id}"'] + if where_sql: + sql_parts.append(f" WHERE {where_sql}") + if filters: + eq_conds = [ + SafeSQLBuilder.build_filter_condition(f, v) + for f, v in filters.items() + ] + joiner = " AND " if where_sql else " WHERE " + sql_parts.append(joiner + " AND ".join(eq_conds)) + sql = "".join(sql_parts) + result = await self.execute_sql(sql) + if result.get("error"): + return None + recs = result.get("records") or [] + if not recs: + return None + n = recs[0].get("n") or recs[0].get("count") + if n is None: + return None + try: + return int(n) + except (TypeError, ValueError): + return None + except Exception: + return None + async def _query_with_schema( self, resource_id: str, filters: Optional[Dict[str, Any]] = None, limit: int = 100, where: Optional[Dict[str, Any]] = None, - ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], Optional[str]]: - """Query datastore and return (records, fields, error_message). + ) -> Tuple[ + List[Dict[str, Any]], + List[Dict[str, Any]], + Optional[int], + Optional[str], + ]: + """Query datastore and return (records, fields, total, error). + + ``total`` is the true number of rows matching the filter, regardless + of LIMIT. 
CKAN's datastore_search returns this for free; for the + SQL (``where``) path we issue a follow-up COUNT(*) only when the + result hit the limit (otherwise len(records) IS the total). + Returns ``None`` when we couldn't determine a total. Routes through datastore_search_sql when ``where`` is set so the caller can express ranges/IN/LIKE; otherwise falls back to the @@ -891,7 +943,7 @@ async def _query_with_schema( where_sql = SafeSQLBuilder.build_where_clause(where) limit_int = SafeSQLBuilder.clamp_limit(limit) except ValueError as e: - return [], [], str(e) + return [], [], None, str(e) sql_parts = [f'SELECT * FROM "{validated_id}"'] if where_sql: @@ -904,7 +956,7 @@ async def _query_with_schema( for f, v in filters.items() ] except ValueError as e: - return [], [], str(e) + return [], [], None, str(e) joiner = " AND " if where_sql else " WHERE " sql_parts.append(joiner + " AND ".join(eq_conds)) sql_parts.append(f" LIMIT {limit_int}") @@ -912,12 +964,24 @@ async def _query_with_schema( result = await self.execute_sql(sql) if result.get("error"): - return [], [], result.get("message", "SQL execution failed") - return ( - result.get("records", []), - result.get("fields", []), - None, - ) + return [], [], None, result.get( + "message", "SQL execution failed" + ) + records = result.get("records", []) + fields = result.get("fields", []) + + # If we hit the LIMIT exactly, we don't actually know the total — + # do a cheap COUNT(*) follow-up so the model gets a real number + # instead of mistaking the limit for the count. + total: Optional[int] + if len(records) >= limit_int: + total = await self._count_via_sql( + validated_id, where_sql, filters + ) + else: + total = len(records) + + return records, fields, total, None # No `where` → cheap datastore_search path. params: Dict[str, Any] = {"resource_id": resource_id, "limit": limit} @@ -932,6 +996,7 @@ async def _query_with_schema( return ( [], [], + None, f"{msg}\n" "Hint: this resource may exist as a file download " "(GeoJSON/KML/SHP/PDF) but not be loaded into the " @@ -941,12 +1006,20 @@ async def _query_with_schema( "ckan__search_and_query, which auto-picks the " "datastore-loaded resource.", ) - return [], [], msg + return [], [], None, msg result = response.get("result", {}) + # CKAN returns `total` for free here — a true count of rows + # matching the filter, not capped by limit. + total_val = result.get("total") + try: + total = int(total_val) if total_val is not None else None + except (TypeError, ValueError): + total = None return ( result.get("records", []), result.get("fields", []), + total, None, ) @@ -1294,7 +1367,7 @@ async def search_and_query( ), } - records, fields, error = await self._query_with_schema( + records, fields, total, error = await self._query_with_schema( resource_id=resource_id, filters=filters or None, where=where, @@ -1314,6 +1387,7 @@ async def search_and_query( "resource": chosen_resource, "records": records, "fields": fields, + "total": total, "alternate_datasets": datasets, } @@ -1508,14 +1582,44 @@ def _format_query_results( records: List[Dict[str, Any]], fields: Optional[List[Dict[str, Any]]] = None, limit: int = 100, + total: Optional[int] = None, ) -> str: """Format query results for user display.""" + n_returned = len(records) + truncated_warning = self._format_truncation_block( + n_returned, limit, total + ) + if not records: text = "No records found matching the query." 
+ parts = [truncated_warning, text] if truncated_warning else [text] schema_footer = self._format_schema_footer(fields) - return f"{text}\n\n{schema_footer}" if schema_footer else text + if schema_footer: + parts.append("") + parts.append(schema_footer) + return "\n".join(parts) - lines = [f"Found {len(records)} record(s) (showing up to {limit}):\n"] + lines: List[str] = [] + if truncated_warning: + lines.append(truncated_warning) + lines.append("") + + # Header line: prefer "true total" wording when known, since the + # model has consistently been mistaking len(records) for the total. + if total is not None and total != n_returned: + lines.append( + f"total_matching_rows: {total} (returned {n_returned}, " + f"limit={limit})\n" + ) + elif total is not None: + lines.append( + f"total_matching_rows: {total} (limit={limit})\n" + ) + else: + lines.append( + f"returned_rows: {n_returned} (limit={limit}, " + "total unknown — see warning above if any)\n" + ) # Show first few records as examples for i, record in enumerate(records[:5], 1): @@ -1525,8 +1629,8 @@ def _format_query_results( lines.append(f" {key}: {value}") lines.append("") - if len(records) > 5: - lines.append(f"... and {len(records) - 5} more record(s)") + if n_returned > 5: + lines.append(f"... and {n_returned - 5} more record(s) returned") schema_footer = self._format_schema_footer(fields) if schema_footer: @@ -1535,6 +1639,53 @@ def _format_query_results( return "\n".join(lines) + def _format_truncation_block( + self, + n_returned: int, + limit: int, + total: Optional[int], + ) -> str: + """Emit a clear warning when the result set is — or might be — + truncated by LIMIT. + + Returns ``""`` when no warning is needed (result fits within limit + and total is known/equal to returned).""" + # Total known and matches returned → fits within limit, no warning. + if total is not None and total <= n_returned: + return "" + + # Total known but exceeds returned → exact truncation, exact total. + if total is not None and total > n_returned: + return ( + "=== TRUNCATED ===\n" + f"This query has {total} matching rows, but only " + f"{n_returned} were returned (limit={limit}). For " + "counting questions, the answer is " + f"{total}, NOT {n_returned}. To return more rows, raise " + "`limit` (max 10000) or use `ckan__execute_sql` with " + "your own LIMIT/ORDER BY. For just the count, use " + "ckan__aggregate_data with metrics=" + '{"count": "count(*)"} and a matching filter — ' + "it's cheaper than fetching rows.\n" + "=================" + ) + + # Total unknown and we hit the limit exactly → likely truncated. + if total is None and n_returned >= limit: + return ( + "=== MAY BE TRUNCATED ===\n" + f"Result returned exactly the requested limit " + f"({limit}) and the true total could not be determined. " + "Treat this as a possibly-incomplete sample. 
For " + "counting questions, do NOT report " + f"{n_returned} as the answer — use ckan__aggregate_data " + 'with metrics={"count": "count(*)"} and the same ' + "filter, or re-run with a higher limit.\n" + "========================" + ) + + return "" + def _format_schema_footer( self, fields: Optional[List[Dict[str, Any]]] ) -> str: @@ -1587,23 +1738,50 @@ def _format_search_and_query( resource = composite.get("resource", {}) or {} records = composite.get("records", []) or [] fields = composite.get("fields", []) or [] + total = composite.get("total") alternates = composite.get("alternate_datasets", []) or [] dataset_id = dataset.get("id", "unknown") dataset_title = dataset.get("title", "Untitled") resource_id = resource.get("id", "unknown") resource_name = resource.get("name", "Unnamed") + n_returned = len(records) + + # Total line: prefer "true total" so the model can't read the + # returned-rows count as the answer to a counting question. + if total is not None and total != n_returned: + count_line = ( + f"total_matching_rows: {total} " + f"(returned {n_returned}, limit={limit})" + ) + elif total is not None: + count_line = f"total_matching_rows: {total} (limit={limit})" + else: + count_line = ( + f"returned_rows: {n_returned} " + f"(limit={limit}, total unknown)" + ) - lines: List[str] = [ - "=== search_and_query result ===", - f"matched_dataset: {dataset_title}", - f"dataset_id: {dataset_id}", - f"resource_id (use with ckan__query_data): {resource_id}", - f"resource_name: {resource_name}", - f"row_count: {len(records)} (limit={limit})", - "================================", - "", - ] + lines: List[str] = [] + truncated_warning = self._format_truncation_block( + n_returned, limit, total + ) + if truncated_warning: + lines.append(truncated_warning) + lines.append("") + + lines.extend( + [ + "=== search_and_query result ===", + f"matched_dataset: {dataset_title}", + f"dataset_id: {dataset_id}", + f"resource_id (use with ckan__query_data): {resource_id}", + f"resource_name: {resource_name}", + count_line, + "================================", + "", + ] + ) if not records: lines.append( @@ -1611,15 +1789,20 @@ def _format_search_and_query( "dataset/resource (see alternates below)." ) else: - lines.append(f"Showing up to 5 of {len(records)} record(s):") + preview_caption = ( + f"Showing first 5 of {n_returned} returned" + + (f" (true total: {total})" if total is not None else "") + + ":" + ) + lines.append(preview_caption) for i, record in enumerate(records[:5], 1): lines.append(f"Record {i}:") for key, value in record.items(): if key != "_id": lines.append(f" {key}: {value}") lines.append("") - if len(records) > 5: - lines.append(f"... and {len(records) - 5} more record(s)") + if n_returned > 5: + lines.append(f"... 
and {n_returned - 5} more record(s) returned") schema_footer = self._format_schema_footer(fields) if schema_footer: diff --git a/tests/test_ckan_plugin.py b/tests/test_ckan_plugin.py index e58c32a..a3309c8 100644 --- a/tests/test_ckan_plugin.py +++ b/tests/test_ckan_plugin.py @@ -1170,6 +1170,254 @@ async def test_execute_tool_query_data_schema_footer_in_normal_path( assert "x (text)" in text assert "z (int)" in text + @pytest.mark.asyncio + async def test_query_data_surfaces_total_from_datastore_search( + self, ckan_config + ): + """When CKAN returns `total`, format prefers total_matching_rows + over returned_rows.""" + plugin = CKANPlugin(ckan_config) + + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_response_init = Mock() + mock_response_init.json.return_value = {"success": True} + mock_response_init.raise_for_status = Mock() + mock_response_query = Mock() + mock_response_query.json.return_value = { + "result": { + "records": [{"_id": i} for i in range(100)], + "fields": [{"id": "_id", "type": "int"}], + "total": 531, + } + } + mock_response_query.raise_for_status = Mock() + mock_client.post = AsyncMock( + side_effect=[mock_response_init, mock_response_query] + ) + mock_client_class.return_value = mock_client + + await plugin.initialize() + result = await plugin.execute_tool( + "query_data", + { + "resource_id": "11111111-2222-3333-4444-555555555555", + "limit": 100, + }, + ) + + assert result.success is True + text = result.content[0]["text"] + assert "total_matching_rows: 531" in text + assert "TRUNCATED" in text + assert "the answer is 531, NOT 100" in text + assert "ckan__aggregate_data" in text + + @pytest.mark.asyncio + async def test_query_data_no_truncation_warning_when_under_limit( + self, ckan_config + ): + """When records returned < limit, no truncation warning shown.""" + plugin = CKANPlugin(ckan_config) + + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_response_init = Mock() + mock_response_init.json.return_value = {"success": True} + mock_response_init.raise_for_status = Mock() + mock_response_query = Mock() + mock_response_query.json.return_value = { + "result": { + "records": [{"_id": i} for i in range(85)], + "fields": [{"id": "_id", "type": "int"}], + "total": 85, + } + } + mock_response_query.raise_for_status = Mock() + mock_client.post = AsyncMock( + side_effect=[mock_response_init, mock_response_query] + ) + mock_client_class.return_value = mock_client + + await plugin.initialize() + result = await plugin.execute_tool( + "query_data", + { + "resource_id": "11111111-2222-3333-4444-555555555555", + "limit": 100, + }, + ) + + assert result.success is True + text = result.content[0]["text"] + assert "total_matching_rows: 85" in text + assert "TRUNCATED" not in text + + @pytest.mark.asyncio + async def test_query_data_where_path_does_count_followup_when_truncated( + self, ckan_config + ): + """SQL (`where`) path: when SELECT * hits the limit, the plugin + must do a COUNT(*) follow-up so the model gets a real total + rather than mistaking the limit for the count.""" + plugin = CKANPlugin(ckan_config) + + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_response_init = Mock() + mock_response_init.json.return_value = {"success": True} + mock_response_init.raise_for_status = Mock() + # First SQL call: SELECT * returns exactly limit rows + mock_response_select = Mock() + mock_response_select.json.return_value = { + "result": { + "records": [ + {"_id": 
i, "case_id": f"c{i}"} for i in range(100) + ], + "fields": [ + {"id": "case_id", "type": "text"}, + {"id": "closed_dt", "type": "timestamp"}, + ], + } + } + mock_response_select.raise_for_status = Mock() + # Follow-up SQL call: SELECT COUNT(*) returns the true total + mock_response_count = Mock() + mock_response_count.json.return_value = { + "result": {"records": [{"n": 531}]} + } + mock_response_count.raise_for_status = Mock() + mock_client.post = AsyncMock( + side_effect=[ + mock_response_init, + mock_response_select, + mock_response_count, + ] + ) + mock_client_class.return_value = mock_client + + await plugin.initialize() + result = await plugin.execute_tool( + "query_data", + { + "resource_id": "11111111-2222-3333-4444-555555555555", + "where": { + "closed_dt": { + "gte": "2016-04-29", + "lt": "2016-04-30", + } + }, + "limit": 100, + }, + ) + + assert result.success is True + text = result.content[0]["text"] + assert "total_matching_rows: 531" in text + assert "TRUNCATED" in text + assert "the answer is 531, NOT 100" in text + # Follow-up COUNT(*) actually issued + assert mock_client.post.call_count == 3 + count_call = mock_client.post.call_args_list[2] + count_sql = count_call[1]["json"]["sql"] + assert "COUNT(*)" in count_sql + assert '"closed_dt" >= \'2016-04-29\'' in count_sql + + @pytest.mark.asyncio + async def test_query_data_where_path_no_count_when_under_limit( + self, ckan_config + ): + """SQL path: if records returned < limit we already know the total + — no extra COUNT(*) call should fire.""" + plugin = CKANPlugin(ckan_config) + + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_response_init = Mock() + mock_response_init.json.return_value = {"success": True} + mock_response_init.raise_for_status = Mock() + mock_response_select = Mock() + mock_response_select.json.return_value = { + "result": { + "records": [{"_id": i} for i in range(85)], + "fields": [], + } + } + mock_response_select.raise_for_status = Mock() + mock_client.post = AsyncMock( + side_effect=[mock_response_init, mock_response_select] + ) + mock_client_class.return_value = mock_client + + await plugin.initialize() + result = await plugin.execute_tool( + "query_data", + { + "resource_id": "11111111-2222-3333-4444-555555555555", + "where": {"x": {"gt": 1}}, + "limit": 100, + }, + ) + + assert result.success is True + text = result.content[0]["text"] + assert "total_matching_rows: 85" in text + assert "TRUNCATED" not in text + # init + SELECT only, no COUNT(*) + assert mock_client.post.call_count == 2 + + @pytest.mark.asyncio + async def test_query_data_where_path_count_failure_falls_back_to_warning( + self, ckan_config + ): + """If the COUNT(*) follow-up fails, we still return the data with + a 'MAY BE TRUNCATED' warning rather than failing the whole call.""" + plugin = CKANPlugin(ckan_config) + + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_response_init = Mock() + mock_response_init.json.return_value = {"success": True} + mock_response_init.raise_for_status = Mock() + mock_response_select = Mock() + mock_response_select.json.return_value = { + "result": { + "records": [{"_id": i} for i in range(100)], + "fields": [], + } + } + mock_response_select.raise_for_status = Mock() + # COUNT(*) call fails server-side + mock_response_count_fail = Mock() + mock_response_count_fail.json.return_value = { + "success": False, + "error": {"message": "COUNT failed"}, + } + mock_response_count_fail.raise_for_status = Mock() + mock_client.post 
= AsyncMock( + side_effect=[ + mock_response_init, + mock_response_select, + mock_response_count_fail, + ] + ) + mock_client_class.return_value = mock_client + + await plugin.initialize() + result = await plugin.execute_tool( + "query_data", + { + "resource_id": "11111111-2222-3333-4444-555555555555", + "where": {"x": {"gt": 1}}, + "limit": 100, + }, + ) + + # Whole call still succeeds — count-failure must not block data + assert result.success is True + text = result.content[0]["text"] + assert "MAY BE TRUNCATED" in text + @pytest.mark.asyncio async def test_query_data_404_includes_datastore_active_hint(self, ckan_config): """A 404 from query_data should append the datastore_active hint.""" From 47fbdca3cdc4e262d73349b1092e52ad2ae92175 Mon Sep 17 00:00:00 2001 From: brendanbabb Date: Thu, 30 Apr 2026 13:13:33 -0800 Subject: [PATCH 11/12] CKAN: 'X of Y' phrasing + truncation guards on every row-returning path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit User flagged: limit-as-count confusion isn't unique to query_data. Same trap exists in execute_sql ("returned 100 because the SQL said LIMIT 100") and in search_datasets ("found 20 datasets" while CKAN knows there are actually 47 matches, since package_search returns `count`). Switched to the user-suggested 'X of Y' phrasing throughout. Changes: 1. Shared _format_count_header helper renders three cases: - "100 of 531 rows returned (limit=100; raise limit or use ckan__aggregate_data for full count)." - "85 rows returned (full result, limit=100)." - "100 rows returned (limit=100, true total unknown — see TRUNCATED warning above if any)." Used in query_data, search_and_query, and search_datasets (with unit="matching dataset(s) shown"). 2. search_datasets now reads CKAN's `count` from package_search and surfaces it: "5 of 21 matching dataset(s) shown" instead of "Found 5 dataset(s)". 3. execute_sql now extracts the effective LIMIT from the validated SQL via SQLValidator.extract_top_level_limit() and emits a "MAY BE TRUNCATED" block when len(records) >= effective_limit (datastore_search_sql doesn't return a `total`, so this heuristic is the best we can do). aggregate_data inherits the same guard since it formats results via _format_sql_results. 4. Existing query_data/search_and_query header lines switched to the X-of-Y phrasing for consistency. The TRUNCATED warning text now matches the rest of the corpus. Live verification (boston prod CKAN): - search_datasets("parks", limit=5) → "5 of 21 matching dataset(s) shown (limit=5; raise limit to see more)". - execute_sql with LIMIT 100 returning 100 rows → MAY BE TRUNCATED block at top, "100 rows returned (limit=100, true total unknown...)". Tests: +8 (204 -> 212). New: search_datasets count rendering, two execute_sql truncation cases, four extract_top_level_limit unit tests (simple, semicolon, missing, subquery-ignored), enforce-then-extract round trip. 
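Condensed, the execute_sql truncation heuristic amounts to this predicate
(a sketch only; the real call sites thread effective_limit through the
result dict instead of recomputing it per call):

    def may_be_truncated(sql: str, records: list) -> bool:
        # enforce_row_limit guarantees a top-level LIMIT exists, so
        # extract_top_level_limit can read back the effective cap.
        bounded = SQLValidator.enforce_row_limit(sql)
        limit = SQLValidator.extract_top_level_limit(bounded)
        return limit is not None and len(records) >= limit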
--- plugins/ckan/plugin.py | 193 +++++++++++++++++++++++++--------- plugins/ckan/sql_validator.py | 32 ++++++ tests/test_ckan_plugin.py | 125 +++++++++++++++++++++- tests/test_sql_validator.py | 43 ++++++++ 4 files changed, 338 insertions(+), 55 deletions(-) diff --git a/plugins/ckan/plugin.py b/plugins/ckan/plugin.py index a2bfa97..ce0bbef 100644 --- a/plugins/ckan/plugin.py +++ b/plugins/ckan/plugin.py @@ -610,12 +610,16 @@ async def execute_tool( if tool_name == "search_datasets": query = arguments.get("query", "") limit = arguments.get("limit", 20) - datasets = await self.search_datasets(query, limit) + datasets, total = await self._search_datasets_with_count( + query, limit + ) return ToolResult( content=[ { "type": "text", - "text": self._format_search_results(datasets), + "text": self._format_search_results( + datasets, total=total, limit=limit + ), } ], success=True, @@ -712,7 +716,10 @@ async def execute_tool( # Format SQL results records = result.get("records", []) fields = result.get("fields", []) - formatted_text = self._format_sql_results(records, fields) + effective_limit = result.get("effective_limit") + formatted_text = self._format_sql_results( + records, fields, effective_limit=effective_limit + ) return ToolResult( content=[{"type": "text", "text": formatted_text}], success=True, @@ -790,7 +797,9 @@ async def execute_tool( error_message=result.get("message", "Aggregation failed"), ) formatted = self._format_sql_results( - result.get("records", []), result.get("fields", []) + result.get("records", []), + result.get("fields", []), + effective_limit=result.get("effective_limit"), ) return ToolResult( content=[{"type": "text", "text": formatted}], success=True @@ -821,12 +830,29 @@ async def search_datasets( limit: Maximum number of results Returns: - List of dataset metadata dictionaries + List of dataset metadata dictionaries (count-aware variant is + ``_search_datasets_with_count``). """ + datasets, _count = await self._search_datasets_with_count(query, limit) + return datasets + + async def _search_datasets_with_count( + self, query: str, limit: int = 20 + ) -> Tuple[List[Dict[str, Any]], Optional[int]]: + """Same as search_datasets but also returns CKAN's `count` — the + true number of datasets matching the query, regardless of the row + cap. Lets the formatter say "20 of 47 matching datasets returned" + instead of just "Found 20".""" response = await self._call_ckan_api( "package_search", {"q": query, "rows": limit} ) - return response.get("result", {}).get("results", []) + result = response.get("result", {}) + count_val = result.get("count") + try: + count = int(count_val) if count_val is not None else None + except (TypeError, ValueError): + count = None + return result.get("results", []), count async def get_dataset(self, dataset_id: str) -> Dict[str, Any]: """Get detailed metadata for a specific dataset. @@ -1045,7 +1071,8 @@ async def execute_sql(self, sql: str) -> Dict[str, Any]: sql: PostgreSQL SELECT statement Returns: - Dictionary with success flag, records, fields, or error message + Dictionary with success flag, records, fields, effective_limit, + or error message """ # Validate SQL is_valid, error = SQLValidator.validate_query(sql) @@ -1054,6 +1081,7 @@ async def execute_sql(self, sql: str) -> Dict[str, Any]: # Bound upstream scan cost: append LIMIT if the caller didn't set one. 
sql = SQLValidator.enforce_row_limit(sql) + effective_limit = SQLValidator.extract_top_level_limit(sql) # Log SQL execution (truncated for security) logger.info("Executing SQL", extra={"sql": sql[:500]}) @@ -1070,6 +1098,7 @@ async def execute_sql(self, sql: str) -> Dict[str, Any]: "success": True, "records": result.get("result", {}).get("records", []), "fields": result.get("result", {}).get("fields", []), + "effective_limit": effective_limit, } except Exception as e: logger.error(f"SQL execution failed: {e}", exc_info=True) @@ -1404,7 +1433,12 @@ async def health_check(self) -> bool: logger.error(f"Health check failed: {e}") return False - def _format_search_results(self, datasets: List[Dict[str, Any]]) -> str: + def _format_search_results( + self, + datasets: List[Dict[str, Any]], + total: Optional[int] = None, + limit: int = 20, + ) -> str: """Format search results for user display.""" if not datasets: return f"No datasets found in {self.plugin_config.city_name}'s open data portal." @@ -1448,9 +1482,25 @@ def _format_search_results(self, datasets: List[Dict[str, Any]]) -> str: ] ) - lines.append( - f"Found {len(datasets)} dataset(s) in {self.plugin_config.city_name}'s open data portal:\n" - ) + # Lead with the X-of-Y framing so the model can't mistake the + # results-shown count for the count of matching datasets. + n_returned = len(datasets) + if total is not None and total > n_returned: + lines.append( + f"{n_returned} of {total} matching dataset(s) shown " + f"(limit={limit}; raise limit to see more) in " + f"{self.plugin_config.city_name}'s open data portal:\n" + ) + elif total is not None: + lines.append( + f"{total} matching dataset(s) (full result, limit={limit}) " + f"in {self.plugin_config.city_name}'s open data portal:\n" + ) + else: + lines.append( + f"{n_returned} dataset(s) in " + f"{self.plugin_config.city_name}'s open data portal:\n" + ) for i, dataset in enumerate(datasets, 1): title = dataset.get("title", "Untitled") @@ -1604,22 +1654,11 @@ def _format_query_results( lines.append(truncated_warning) lines.append("") - # Header line: prefer "true total" wording when known, since the - # model has consistently been mistaking len(records) for the total. - if total is not None and total != n_returned: - lines.append( - f"total_matching_rows: {total} (returned {n_returned}, " - f"limit={limit})\n" - ) - elif total is not None: - lines.append( - f"total_matching_rows: {total} (limit={limit})\n" - ) - else: - lines.append( - f"returned_rows: {n_returned} (limit={limit}, " - "total unknown — see warning above if any)\n" - ) + # Header line: lead with the X-of-Y framing so the model can't + # mistake the rows-returned count for the answer to a counting + # question. + lines.append(self._format_count_header(n_returned, limit, total)) + lines.append("") # Show first few records as examples for i, record in enumerate(records[:5], 1): @@ -1639,6 +1678,31 @@ def _format_query_results( return "\n".join(lines) + @staticmethod + def _format_count_header( + n_returned: int, + limit: int, + total: Optional[int], + unit: str = "rows", + ) -> str: + """One-line "X of Y" summary used at the top of every row-returning + response. The model has been mistaking returned-rows for true count; + this phrasing makes the partial/total distinction unambiguous.""" + if total is not None and total > n_returned: + return ( + f"{n_returned} of {total} {unit} returned " + f"(limit={limit}; raise limit or use ckan__aggregate_data " + "for full count)." 
+ ) + if total is not None: + # total == n_returned, all rows returned + return f"{total} {unit} returned (full result, limit={limit})." + # total unknown + return ( + f"{n_returned} {unit} returned (limit={limit}, " + "true total unknown — see TRUNCATED warning above if any)." + ) + def _format_truncation_block( self, n_returned: int, @@ -1747,20 +1811,7 @@ def _format_search_and_query( resource_name = resource.get("name", "Unnamed") n_returned = len(records) - # Total line: prefer "true total" so the model can't read the - # returned-rows count as the answer to a counting question. - if total is not None and total != n_returned: - count_line = ( - f"total_matching_rows: {total} " - f"(returned {n_returned}, limit={limit})" - ) - elif total is not None: - count_line = f"total_matching_rows: {total} (limit={limit})" - else: - count_line = ( - f"returned_rows: {n_returned} " - f"(limit={limit}, total unknown)" - ) + count_line = self._format_count_header(n_returned, limit, total) lines: List[str] = [] truncated_warning = self._format_truncation_block( @@ -1789,11 +1840,13 @@ def _format_search_and_query( "dataset/resource (see alternates below)." ) else: - preview_caption = ( - f"Showing first 5 of {n_returned} returned" - + (f" (true total: {total})" if total is not None else "") - + ":" - ) + if total is not None and total > n_returned: + preview_caption = ( + f"Showing first 5 of {n_returned} returned " + f"(true total: {total}):" + ) + else: + preview_caption = f"Showing first 5 of {n_returned} returned:" lines.append(preview_caption) for i, record in enumerate(records[:5], 1): lines.append(f"Record {i}:") @@ -1858,21 +1911,59 @@ def _format_search_and_query( return "\n".join(lines) def _format_sql_results( - self, records: List[Dict[str, Any]], fields: List[Dict[str, Any]] + self, + records: List[Dict[str, Any]], + fields: List[Dict[str, Any]], + effective_limit: Optional[int] = None, ) -> str: """Format SQL query results for user display. Args: records: List of record dictionaries fields: List of field metadata dictionaries + effective_limit: The LIMIT clause that was actually executed — + either user-supplied or the enforced default. Used to + detect truncation: if len(records) >= effective_limit, the + result was almost certainly capped. Returns: Formatted string representation of results """ + n_returned = len(records) + + # Heuristic truncation detection — datastore_search_sql doesn't + # return a "total"; the only signal is "did we hit our LIMIT?" + truncation_block = "" + if effective_limit is not None and n_returned >= effective_limit: + truncation_block = ( + "=== MAY BE TRUNCATED ===\n" + f"This SQL returned exactly the LIMIT ({effective_limit}) " + "rows. The true total could not be determined from " + "datastore_search_sql alone. For counting questions, do " + f"NOT report {n_returned} as the answer — instead run a " + "separate SELECT COUNT(*) with the same WHERE clause, or " + "use ckan__aggregate_data with metrics=" + '{"count": "count(*)"}.\n' + "========================" + ) + if not records: - return "No records found matching the SQL query." + text = "No records found matching the SQL query." + return f"{truncation_block}\n\n{text}" if truncation_block else text + + lines: List[str] = [] + if truncation_block: + lines.append(truncation_block) + lines.append("") - lines = [f"SQL Query Results: {len(records)} record(s)\n"] + # Header — total is unknown for raw SQL, so show "X rows returned". 
+ if effective_limit is not None: + lines.append( + f"{n_returned} rows returned (limit={effective_limit}, " + "true total unknown — see warning above if any).\n" + ) + else: + lines.append(f"{n_returned} rows returned.\n") # Show field names if available if fields: @@ -1887,7 +1978,7 @@ def _format_sql_results( lines.append(f" {key}: {value}") lines.append("") - if len(records) > 10: - lines.append(f"... and {len(records) - 10} more record(s)") + if n_returned > 10: + lines.append(f"... and {n_returned - 10} more record(s) returned") return "\n".join(lines) diff --git a/plugins/ckan/sql_validator.py b/plugins/ckan/sql_validator.py index 99a8114..d5ce19b 100644 --- a/plugins/ckan/sql_validator.py +++ b/plugins/ckan/sql_validator.py @@ -159,6 +159,38 @@ def validate_query(sql: Any) -> Tuple[bool, Optional[str]]: return True, None + @classmethod + def extract_top_level_limit(cls, sql: str) -> Optional[int]: + """Return the integer LIMIT of the outermost statement, or None + if there isn't one (or it can't be parsed). Subquery / CTE limits + are ignored — same scoping rules as ``enforce_row_limit``. + + Used by callers to detect truncation: if a SELECT returned the + same number of rows as the effective top-level LIMIT, the result + is almost certainly capped.""" + try: + parsed = sqlparse.parse(sql) + except Exception: + return None + if not parsed: + return None + statement = parsed[0] + # Walk top-level tokens for a LIMIT keyword followed by an integer. + toks = [t for t in statement.tokens if not isinstance(t, Parenthesis)] + for i, tok in enumerate(toks): + if tok.ttype in Keyword and tok.normalized.upper() == "LIMIT": + # next non-whitespace token should be the integer literal + for nxt in toks[i + 1 :]: + if nxt.is_whitespace: + continue + text = nxt.value.strip().rstrip(";") + try: + return int(text) + except (TypeError, ValueError): + return None + return None + return None + @classmethod def enforce_row_limit(cls, sql: str) -> str: """Append ``LIMIT`` to an already-validated query if it lacks one. 
diff --git a/tests/test_ckan_plugin.py b/tests/test_ckan_plugin.py index a3309c8..0e22285 100644 --- a/tests/test_ckan_plugin.py +++ b/tests/test_ckan_plugin.py @@ -544,6 +544,123 @@ async def test_execute_tool_execute_sql_missing_param(self, ckan_config): assert result.success is False assert "required" in result.error_message.lower() + @pytest.mark.asyncio + async def test_execute_tool_search_datasets_surfaces_total_count( + self, ckan_config + ): + """search_datasets reads CKAN's `count` and renders X-of-Y.""" + plugin = CKANPlugin(ckan_config) + + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_response_init = Mock() + mock_response_init.json.return_value = {"success": True} + mock_response_init.raise_for_status = Mock() + mock_response_search = Mock() + mock_response_search.json.return_value = { + "result": { + "count": 47, + "results": [ + {"id": f"d{i}", "title": f"Dataset {i}", "resources": []} + for i in range(20) + ], + } + } + mock_response_search.raise_for_status = Mock() + mock_client.post = AsyncMock( + side_effect=[mock_response_init, mock_response_search] + ) + mock_client_class.return_value = mock_client + + await plugin.initialize() + result = await plugin.execute_tool( + "search_datasets", {"query": "parks", "limit": 20} + ) + + assert result.success is True + text = result.content[0]["text"] + assert "20 of 47 matching dataset(s) shown" in text + + @pytest.mark.asyncio + async def test_execute_tool_execute_sql_truncated_warning(self, ckan_config): + """execute_sql warns when len(records) hits the LIMIT clause.""" + plugin = CKANPlugin(ckan_config) + + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_response_init = Mock() + mock_response_init.json.return_value = {"success": True} + mock_response_init.raise_for_status = Mock() + mock_response_sql = Mock() + mock_response_sql.json.return_value = { + "result": { + "records": [{"_id": i, "x": i} for i in range(100)], + "fields": [{"id": "x", "type": "int"}], + } + } + mock_response_sql.raise_for_status = Mock() + mock_client.post = AsyncMock( + side_effect=[mock_response_init, mock_response_sql] + ) + mock_client_class.return_value = mock_client + + await plugin.initialize() + result = await plugin.execute_tool( + "execute_sql", + { + "sql": ( + 'SELECT * FROM "11111111-2222-3333-4444-555555555555" ' + "LIMIT 100" + ) + }, + ) + + assert result.success is True + text = result.content[0]["text"] + assert "MAY BE TRUNCATED" in text + assert "ckan__aggregate_data" in text or "COUNT(*)" in text + + @pytest.mark.asyncio + async def test_execute_tool_execute_sql_no_warning_under_limit( + self, ckan_config + ): + """execute_sql does not warn when fewer rows returned than LIMIT.""" + plugin = CKANPlugin(ckan_config) + + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_response_init = Mock() + mock_response_init.json.return_value = {"success": True} + mock_response_init.raise_for_status = Mock() + mock_response_sql = Mock() + mock_response_sql.json.return_value = { + "result": { + "records": [{"_id": i, "x": i} for i in range(7)], + "fields": [{"id": "x", "type": "int"}], + } + } + mock_response_sql.raise_for_status = Mock() + mock_client.post = AsyncMock( + side_effect=[mock_response_init, mock_response_sql] + ) + mock_client_class.return_value = mock_client + + await plugin.initialize() + result = await plugin.execute_tool( + "execute_sql", + { + "sql": ( + 'SELECT * FROM "11111111-2222-3333-4444-555555555555" ' 
+ "LIMIT 100" + ) + }, + ) + + assert result.success is True + text = result.content[0]["text"] + assert "TRUNCATED" not in text + assert "7 rows returned" in text + @pytest.mark.asyncio async def test_execute_tool_search_and_query_succeeds(self, ckan_config): """search_and_query returns rows from the first resource of the first match.""" @@ -1208,7 +1325,7 @@ async def test_query_data_surfaces_total_from_datastore_search( assert result.success is True text = result.content[0]["text"] - assert "total_matching_rows: 531" in text + assert "100 of 531" in text assert "TRUNCATED" in text assert "the answer is 531, NOT 100" in text assert "ckan__aggregate_data" in text @@ -1250,7 +1367,7 @@ async def test_query_data_no_truncation_warning_when_under_limit( assert result.success is True text = result.content[0]["text"] - assert "total_matching_rows: 85" in text + assert "85 rows returned" in text assert "TRUNCATED" not in text @pytest.mark.asyncio @@ -1313,7 +1430,7 @@ async def test_query_data_where_path_does_count_followup_when_truncated( assert result.success is True text = result.content[0]["text"] - assert "total_matching_rows: 531" in text + assert "100 of 531" in text assert "TRUNCATED" in text assert "the answer is 531, NOT 100" in text # Follow-up COUNT(*) actually issued @@ -1361,7 +1478,7 @@ async def test_query_data_where_path_no_count_when_under_limit( assert result.success is True text = result.content[0]["text"] - assert "total_matching_rows: 85" in text + assert "85 rows returned" in text assert "TRUNCATED" not in text # init + SELECT only, no COUNT(*) assert mock_client.post.call_count == 2 diff --git a/tests/test_sql_validator.py b/tests/test_sql_validator.py index 42eedf9..8d541e3 100644 --- a/tests/test_sql_validator.py +++ b/tests/test_sql_validator.py @@ -804,6 +804,49 @@ def test_bad_limit_rejected(self, bad): SafeSQLBuilder.clamp_limit(bad) +class TestExtractTopLevelLimit: + UUID = "11111111-2222-3333-4444-555555555555" + + def test_simple_limit(self): + assert ( + SQLValidator.extract_top_level_limit( + f'SELECT * FROM "{self.UUID}" LIMIT 50' + ) + == 50 + ) + + def test_limit_with_trailing_semicolon(self): + assert ( + SQLValidator.extract_top_level_limit( + f'SELECT * FROM "{self.UUID}" LIMIT 100;' + ) + == 100 + ) + + def test_no_limit_returns_none(self): + assert ( + SQLValidator.extract_top_level_limit(f'SELECT * FROM "{self.UUID}"') + is None + ) + + def test_subquery_limit_ignored(self): + # Top-level statement has no LIMIT; the subquery's LIMIT does + # not count. 
+ sql = ( + f'SELECT * FROM (SELECT * FROM "{self.UUID}" LIMIT 5) sub ' + ) + assert SQLValidator.extract_top_level_limit(sql) is None + + def test_after_enforce_row_limit(self): + sql = SQLValidator.enforce_row_limit( + f'SELECT * FROM "{self.UUID}"' + ) + assert ( + SQLValidator.extract_top_level_limit(sql) + == SQLValidator.DEFAULT_ROW_LIMIT + ) + + class TestSafeSQLBuilderResourceId: def test_valid_uuid(self): uuid = "11111111-2222-3333-4444-555555555555" From a8e9feab6a7a58217afa111c514848b8999ada08 Mon Sep 17 00:00:00 2001 From: brendanbabb Date: Thu, 30 Apr 2026 15:36:07 -0800 Subject: [PATCH 12/12] Lint: drop extraneous f-string prefixes (ruff F541) --- plugins/ckan/plugin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/ckan/plugin.py b/plugins/ckan/plugin.py index ce0bbef..626413f 100644 --- a/plugins/ckan/plugin.py +++ b/plugins/ckan/plugin.py @@ -1873,8 +1873,8 @@ def _format_search_and_query( if siblings: lines.append("") lines.append( - f"Other queryable resources in this dataset " - f"(pass resource_name=... to pick one):" + "Other queryable resources in this dataset " + "(pass resource_name=... to pick one):" ) for r in siblings: r_name = r.get("name") or "(unnamed)"
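
Reviewer notes (illustrative sketches, not part of any patch in this series):

The reworded assertions in tests/test_ckan_plugin.py pin down a three-way
row-count summary without showing the formatter itself. The sketch below is
consistent with every string the tests assert on; summarize_rows is a
hypothetical name rather than the actual helper in plugins/ckan/plugin.py,
and search_datasets renders a dataset-flavored variant of the same idea
("X of Y matching dataset(s) shown").

    def summarize_rows(returned: int, total: int | None, limit: int | None) -> str:
        """Sketch of the row-count summary the updated tests assert on.

        - Known total above the returned count -> "X of Y" plus a hard
          TRUNCATED warning steering callers to ckan__aggregate_data.
        - Unknown total but returned hits limit -> softer MAY BE TRUNCATED
          (CKAN's SQL endpoint does not report a total).
        - Otherwise -> plain "N rows returned".
        """
        if total is not None and returned < total:
            return (
                f"{returned} of {total} matching rows shown.\n"
                f"RESULTS TRUNCATED: the answer is {total}, NOT {returned}. "
                "Use ckan__aggregate_data (COUNT(*) / GROUP BY) for totals."
            )
        if limit is not None and returned >= limit:
            return (
                f"{returned} rows returned. Output MAY BE TRUNCATED at "
                f"LIMIT {limit}; verify the true total with COUNT(*)."
            )
        return f"{returned} rows returned"

Under these rules, summarize_rows(100, 531, 100) produces the hard warning
the query_data tests expect, summarize_rows(100, None, 100) the softer
execute_sql one, and summarize_rows(85, None, 100) stays quiet.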
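Similarly, TestExtractTopLevelLimit in tests/test_sql_validator.py fixes the
observable contract of SQLValidator.extract_top_level_limit without including
its body. Below is one implementation that satisfies all four cases, using a
parenthesis-depth scan; quoted-string handling is skipped for brevity, and
DEFAULT_ROW_LIMIT = 100 is an assumed value. Treat it as a sketch, not the
shipped validator.

    import re


    DEFAULT_ROW_LIMIT = 100  # assumed; mirrors SQLValidator.DEFAULT_ROW_LIMIT


    def extract_top_level_limit(sql: str) -> int | None:
        """Return the LIMIT of the top-level statement, or None if absent."""
        # Tolerate trailing whitespace and semicolons.
        stmt = sql.strip().rstrip(";").rstrip()
        # Blank out parenthesized spans so a subquery's LIMIT can never match.
        depth, flat = 0, []
        for ch in stmt:
            if ch == "(":
                depth += 1
            flat.append(ch if depth == 0 else " ")
            if ch == ")" and depth > 0:
                depth -= 1
        match = re.search(r"\bLIMIT\s+(\d+)\s*$", "".join(flat), re.IGNORECASE)
        return int(match.group(1)) if match else None


    def enforce_row_limit(sql: str) -> str:
        """Append a default LIMIT when a statement has no top-level LIMIT."""
        # Clamping of oversized explicit limits is out of scope here; the
        # test file exercises that separately via SafeSQLBuilder.clamp_limit.
        if extract_top_level_limit(sql) is None:
            return f"{sql.strip().rstrip(';')} LIMIT {DEFAULT_ROW_LIMIT}"
        return sql

Blanking parenthesized spans before matching is what makes the subquery case
return None: the inner LIMIT 5 never reaches the top-level regex.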