Skip to content

Commit b7cb2f3

Browse files
committed
Fix PromQL injection, connection leaks, and missing timeouts
Addresses critical findings from quality framework audit:

- Add _escape_promql_label() and _promql_filter() to the reporter to prevent PromQL injection via cluster/node/database/index names (FM-4)
- Apply escaping to the H001 base_filter, the H002 idx_scan query, and the H004 redundant-index queries (the most dangerous sites, where DB metadata such as index_name/table_name is interpolated)
- Add escape_promql_label() to the Flask backend and apply it when building filters
- Fix connection leaks in the CLI: two locations in postgres-ai.ts where Client.connect() had no finally block (mon targets add, interactive add)
- Add connectionTimeoutMillis: 10000 to all Client() instances in the CLI
- Add connect_timeout=10 to all psycopg2.connect() calls (reporter + Flask)
- Add 11 unit tests for PromQL escaping, covering injection attempts, backslash/quote handling, and normal PostgreSQL identifiers

https://claude.ai/code/session_01TKKnEc2Yn2zM64bwCJ2UaX
1 parent c00f80e commit b7cb2f3

4 files changed

Lines changed: 167 additions & 35 deletions

File tree

cli/bin/postgres-ai.ts

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2452,16 +2452,20 @@ mon
24522452

24532453
// Test connection
24542454
console.log("Testing connection to the added instance...");
2455-
try {
2456-
const client = new Client({ connectionString: connStr });
2457-
await client.connect();
2458-
const result = await client.query("select version();");
2459-
console.log("✓ Connection successful");
2460-
console.log(`${result.rows[0].version}\n`);
2461-
await client.end();
2462-
} catch (error) {
2463-
const message = error instanceof Error ? error.message : String(error);
2464-
console.error(`✗ Connection failed: ${message}\n`);
2455+
{
2456+
let testClient: InstanceType<typeof Client> | null = null;
2457+
try {
2458+
testClient = new Client({ connectionString: connStr, connectionTimeoutMillis: 10000 });
2459+
await testClient.connect();
2460+
const result = await testClient.query("select version();");
2461+
console.log("✓ Connection successful");
2462+
console.log(`${result.rows[0].version}\n`);
2463+
} catch (error) {
2464+
const message = error instanceof Error ? error.message : String(error);
2465+
console.error(`✗ Connection failed: ${message}\n`);
2466+
} finally {
2467+
if (testClient) await testClient.end();
2468+
}
24652469
}
24662470
} else if (opts.yes) {
24672471
// Auto-yes mode without database URL - skip database setup
@@ -2496,16 +2500,20 @@ mon
24962500

24972501
// Test connection
24982502
console.log("Testing connection to the added instance...");
2499-
try {
2500-
const client = new Client({ connectionString: connStr });
2501-
await client.connect();
2502-
const result = await client.query("select version();");
2503-
console.log("✓ Connection successful");
2504-
console.log(`${result.rows[0].version}\n`);
2505-
await client.end();
2506-
} catch (error) {
2507-
const message = error instanceof Error ? error.message : String(error);
2508-
console.error(`✗ Connection failed: ${message}\n`);
2503+
{
2504+
let testClient: InstanceType<typeof Client> | null = null;
2505+
try {
2506+
testClient = new Client({ connectionString: connStr, connectionTimeoutMillis: 10000 });
2507+
await testClient.connect();
2508+
const result = await testClient.query("select version();");
2509+
console.log("✓ Connection successful");
2510+
console.log(`${result.rows[0].version}\n`);
2511+
} catch (error) {
2512+
const message = error instanceof Error ? error.message : String(error);
2513+
console.error(`✗ Connection failed: ${message}\n`);
2514+
} finally {
2515+
if (testClient) await testClient.end();
2516+
}
25092517
}
25102518
}
25112519
} else {
@@ -3292,7 +3300,7 @@ targets
32923300
console.log(`Testing connection to monitoring target '${name}'...`);
32933301

32943302
// Use native pg client instead of requiring psql to be installed
3295-
const client = new Client({ connectionString: instance.conn_str });
3303+
const client = new Client({ connectionString: instance.conn_str, connectionTimeoutMillis: 10000 });
32963304

32973305
try {
32983306
await client.connect();

monitoring_flask_backend/app.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@
1616
logger = logging.getLogger(__name__)
1717

1818

19+
def escape_promql_label(value: str) -> str:
20+
"""Escape a value for safe use inside PromQL label matchers (double-quoted strings)."""
21+
return value.replace("\\", "\\\\").replace('"', '\\"')
22+
23+
1924
def smart_truncate_query(query: str, max_length: int = 40) -> str:
2025
"""
2126
Smart SQL query truncation for display names.
@@ -250,7 +255,7 @@ def get_query_texts_from_sink(db_name: str = None, truncation_mode: str = 'smart
250255

251256
conn = None
252257
try:
253-
conn = psycopg2.connect(POSTGRES_SINK_URL)
258+
conn = psycopg2.connect(POSTGRES_SINK_URL, connect_timeout=10)
254259
with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cursor:
255260
# Skip db_name filter if it's empty, "All", or contains special chars
256261
use_db_filter = db_name and db_name.lower() not in ('all', '') and not db_name.startswith('$')
@@ -384,14 +389,14 @@ def get_pgss_metrics_csv():
384389
# Build the base query for pg_stat_statements metrics
385390
base_query = 'pgwatch_pg_stat_statements_calls'
386391

387-
# Add filters if provided
392+
# Add filters if provided (escape values to prevent PromQL injection)
388393
filters = []
389394
if cluster_name:
390-
filters.append(f'cluster="{cluster_name}"')
395+
filters.append(f'cluster="{escape_promql_label(cluster_name)}"')
391396
if node_name:
392-
filters.append(f'instance=~".*{node_name}.*"')
397+
filters.append(f'instance=~".*{escape_promql_label(node_name)}.*"')
393398
if db_name:
394-
filters.append(f'datname="{db_name}"')
399+
filters.append(f'datname="{escape_promql_label(db_name)}"')
395400

396401
if filters:
397402
base_query += '{' + ','.join(filters) + '}'
@@ -1176,7 +1181,7 @@ def get_query_texts():
11761181

11771182
conn = None
11781183
try:
1179-
conn = psycopg2.connect(POSTGRES_SINK_URL)
1184+
conn = psycopg2.connect(POSTGRES_SINK_URL, connect_timeout=10)
11801185
with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cursor:
11811186
# Skip db_name filter if it's empty, "All", or contains special chars
11821187
use_db_filter = db_name and db_name.lower() not in ('all', '') and not db_name.startswith('$')
@@ -1287,7 +1292,7 @@ def get_query_info_metrics():
12871292

12881293
conn = None
12891294
try:
1290-
conn = psycopg2.connect(POSTGRES_SINK_URL)
1295+
conn = psycopg2.connect(POSTGRES_SINK_URL, connect_timeout=10)
12911296
with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cursor:
12921297
# Skip db_name filter if it's empty, "All", or contains special chars
12931298
use_db_filter = db_name and db_name.lower() not in ('all', '') and not db_name.startswith('$')

reporter/postgres_reports.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,29 @@ def _load_build_metadata(self) -> Dict[str, Optional[str]]:
190190
"build_ts": self._read_text_file(build_ts_path),
191191
}
192192

193+
@staticmethod
194+
def _escape_promql_label(value: str) -> str:
195+
"""Escape a value for use inside PromQL label matchers.
196+
197+
PromQL label values are enclosed in double quotes and only need
198+
backslash and double-quote characters escaped (like Go strings).
199+
This prevents PromQL injection when cluster names, database names,
200+
or other user-controlled strings contain special characters.
201+
"""
202+
return value.replace("\\", "\\\\").replace('"', '\\"')
203+
204+
def _promql_filter(self, **labels: str) -> str:
205+
"""Build a PromQL label filter string from keyword arguments.
206+
207+
Example:
208+
self._promql_filter(cluster="my-cluster", node_name="node-01")
209+
# returns: '{cluster="my-cluster", node_name="node-01"}'
210+
"""
211+
parts = []
212+
for key, val in labels.items():
213+
parts.append(f'{key}="{self._escape_promql_label(val)}"')
214+
return "{" + ", ".join(parts) + "}"
215+
193216
def test_connection(self) -> bool:
194217
"""Test connection to Prometheus."""
195218
try:
@@ -205,9 +228,9 @@ def connect_postgres_sink(self) -> bool:
205228
return False
206229
if psycopg2 is None:
207230
raise RuntimeError("psycopg2 is required for postgres sink access but is not installed")
208-
231+
209232
try:
210-
self.pg_conn = psycopg2.connect(self.postgres_sink_url)
233+
self.pg_conn = psycopg2.connect(self.postgres_sink_url, connect_timeout=10)
211234
return True
212235
except Exception as e:
213236
logger.error(f"Postgres sink connection failed: {e}")
@@ -670,7 +693,8 @@ def generate_h001_invalid_indexes_report(self, cluster: str = "local", node_name
670693

671694
# Query all invalid indexes metrics and merge by index key
672695
# Each field is a separate metric in pgwatch prometheus export
673-
base_filter = f'cluster="{cluster}", node_name="{node_name}", datname="{db_name}"'
696+
_esc = self._escape_promql_label
697+
base_filter = f'cluster="{_esc(cluster)}", node_name="{_esc(node_name)}", datname="{_esc(db_name)}"'
674698

675699
# Query primary metric (index_size_bytes) - this determines which indexes exist
676700
size_query = f'last_over_time(pgwatch_pg_invalid_indexes_index_size_bytes{{{base_filter}}}[3h])'
@@ -831,7 +855,8 @@ def generate_h002_unused_indexes_report(self, cluster: str = "local", node_name:
831855
index_size_bytes = float(item['value'][1]) if item.get('value') else 0
832856

833857
# Query other related metrics for this index
834-
idx_scan_query = f'last_over_time(pgwatch_unused_indexes_idx_scan{{cluster="{cluster}", node_name="{node_name}", datname="{db_name}", schema_name="{schema_name}", table_name="{table_name}", index_name="{index_name}"}}[3h])'
858+
_e = self._escape_promql_label
859+
idx_scan_query = f'last_over_time(pgwatch_unused_indexes_idx_scan{{cluster="{_e(cluster)}", node_name="{_e(node_name)}", datname="{_e(db_name)}", schema_name="{_e(schema_name)}", table_name="{_e(table_name)}", index_name="{_e(index_name)}"}}[3h])'
835860
idx_scan_result = self.query_instant(idx_scan_query)
836861
idx_scan = float(idx_scan_result['data']['result'][0]['value'][1]) if idx_scan_result.get('data',
837862
{}).get(
@@ -941,18 +966,20 @@ def generate_h004_redundant_indexes_report(self, cluster: str = "local", node_na
941966
index_size_bytes = float(item['value'][1]) if item.get('value') else 0
942967

943968
# Query other related metrics for this index
944-
table_size_query = f'last_over_time(pgwatch_redundant_indexes_table_size_bytes{{cluster="{cluster}", node_name="{node_name}", dbname="{db_name}", schema_name="{schema_name}", table_name="{table_name}", index_name="{index_name}"}}[3h])'
969+
_e = self._escape_promql_label
970+
_idx_filter = f'cluster="{_e(cluster)}", node_name="{_e(node_name)}", dbname="{_e(db_name)}", schema_name="{_e(schema_name)}", table_name="{_e(table_name)}", index_name="{_e(index_name)}"'
971+
table_size_query = f'last_over_time(pgwatch_redundant_indexes_table_size_bytes{{{_idx_filter}}}[3h])'
945972
table_size_result = self.query_instant(table_size_query)
946973
table_size_bytes = float(
947974
table_size_result['data']['result'][0]['value'][1]) if table_size_result.get('data', {}).get(
948975
'result') else 0
949976

950-
index_usage_query = f'last_over_time(pgwatch_redundant_indexes_index_usage{{cluster="{cluster}", node_name="{node_name}", dbname="{db_name}", schema_name="{schema_name}", table_name="{table_name}", index_name="{index_name}"}}[3h])'
977+
index_usage_query = f'last_over_time(pgwatch_redundant_indexes_index_usage{{{_idx_filter}}}[3h])'
951978
index_usage_result = self.query_instant(index_usage_query)
952979
index_usage = float(index_usage_result['data']['result'][0]['value'][1]) if index_usage_result.get(
953980
'data', {}).get('result') else 0
954981

955-
supports_fk_query = f'last_over_time(pgwatch_redundant_indexes_supports_fk{{cluster="{cluster}", node_name="{node_name}", dbname="{db_name}", schema_name="{schema_name}", table_name="{table_name}", index_name="{index_name}"}}[3h])'
982+
supports_fk_query = f'last_over_time(pgwatch_redundant_indexes_supports_fk{{{_idx_filter}}}[3h])'
956983
supports_fk_result = self.query_instant(supports_fk_query)
957984
supports_fk = bool(
958985
int(supports_fk_result['data']['result'][0]['value'][1])) if supports_fk_result.get('data',
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
"""Tests for PromQL label escaping and query filter building.
2+
3+
These tests verify that user-controlled values (cluster names, database names,
4+
index names, etc.) are properly escaped when interpolated into PromQL queries,
5+
preventing PromQL injection attacks.
6+
"""
7+
import pytest
8+
9+
from reporter.postgres_reports import PostgresReportGenerator
10+
11+
12+
@pytest.fixture
13+
def generator():
14+
"""Create a generator instance for testing."""
15+
return PostgresReportGenerator(
16+
prometheus_url="http://prom.test",
17+
postgres_sink_url="",
18+
)
19+
20+
21+
class TestEscapePromqlLabel:
22+
"""Tests for _escape_promql_label static method."""
23+
24+
@pytest.mark.unit
25+
def test_plain_string_unchanged(self):
26+
assert PostgresReportGenerator._escape_promql_label("my-cluster") == "my-cluster"
27+
28+
@pytest.mark.unit
29+
def test_escapes_double_quotes(self):
30+
assert PostgresReportGenerator._escape_promql_label('db"name') == 'db\\"name'
31+
32+
@pytest.mark.unit
33+
def test_escapes_backslashes(self):
34+
assert PostgresReportGenerator._escape_promql_label("path\\to") == "path\\\\to"
35+
36+
@pytest.mark.unit
37+
def test_escapes_backslash_before_quote(self):
38+
"""Backslash must be escaped first, then quote."""
39+
result = PostgresReportGenerator._escape_promql_label('a\\"b')
40+
assert result == 'a\\\\\\"b'
41+
42+
@pytest.mark.unit
43+
def test_empty_string(self):
44+
assert PostgresReportGenerator._escape_promql_label("") == ""
45+
46+
@pytest.mark.unit
47+
def test_injection_attempt_closing_brace(self):
48+
"""A value like: db"}} OR vector(1) should not break out of the label."""
49+
result = PostgresReportGenerator._escape_promql_label('db"}} OR vector(1)')
50+
assert '"' not in result or result.count('\\"') == result.count('"')
51+
assert result == 'db\\"}} OR vector(1)'
52+
53+
@pytest.mark.unit
54+
def test_normal_postgres_identifiers(self):
55+
"""Common PostgreSQL identifiers should pass through unchanged."""
56+
identifiers = [
57+
"public",
58+
"my_table",
59+
"idx_users_email",
60+
"pg_stat_statements",
61+
"node-01",
62+
"cluster.local",
63+
]
64+
for ident in identifiers:
65+
assert PostgresReportGenerator._escape_promql_label(ident) == ident
66+
67+
68+
class TestPromqlFilter:
69+
"""Tests for _promql_filter method."""
70+
71+
@pytest.mark.unit
72+
def test_single_label(self, generator):
73+
result = generator._promql_filter(cluster="local")
74+
assert result == '{cluster="local"}'
75+
76+
@pytest.mark.unit
77+
def test_multiple_labels(self, generator):
78+
result = generator._promql_filter(cluster="local", node_name="node-01")
79+
assert 'cluster="local"' in result
80+
assert 'node_name="node-01"' in result
81+
assert result.startswith("{")
82+
assert result.endswith("}")
83+
84+
@pytest.mark.unit
85+
def test_escapes_values(self, generator):
86+
result = generator._promql_filter(cluster='my"cluster')
87+
assert result == '{cluster="my\\"cluster"}'
88+
89+
@pytest.mark.unit
90+
def test_empty_labels(self, generator):
91+
result = generator._promql_filter()
92+
assert result == "{}"

0 commit comments

Comments (0)