
Commit 1214b33

gkorland, sirudog, and Copilot authored
feat: add support for postgres schema selection (#475)
* feat: add support for postgres schema selection

  Add support for selecting a PostgreSQL schema instead of always using 'public'. The schema is extracted from the connection URL's options parameter (search_path), following PostgreSQL's native libpq format.

  Changes:
  - Add _parse_schema_from_url() to extract schema from connection URL
  - Thread schema parameter through all extraction methods with 'public' default
  - Add pg_namespace JOINs for correct cross-schema disambiguation
  - Add schema input field in DatabaseModal (PostgreSQL only)
  - Add comprehensive unit tests for URL schema parsing
  - Update documentation with custom schema configuration guide

  Based on PR #373 by sirudog, with the following fixes:
  - Fix pg_namespace JOIN order in extract_columns_info to prevent duplicate rows when same-named tables exist across schemas
  - Fix regex to require '=' separator (prevents mis-capture edge cases)
  - Improve $user handling to loop through all schemas instead of only checking the first two positions
  - Fix pylint line-too-long in test file

* fix: make parse_schema_from_url public to fix CI pylint

  Rename _parse_schema_from_url to parse_schema_from_url, since the method is already documented for external use and tested directly. This eliminates the W0212 (protected-access) warnings that cause CI pylint to fail with exit code 4.

* fix: address review comments on PR #475

  - Add constraint_schema qualifier to the key_column_usage JOINs in extract_columns_info to prevent cross-schema constraint-name collisions
  - Sanitize schema input in DatabaseModal to strip non-identifier characters before building the URL options
  - Add edge-case tests: empty tokens, blank quoted tokens, repeated $user entries

* chore: remove accidentally committed build artifacts

* fix: address copilot reviewer comments on PR #475

  - Fix regex to capture search_path values with spaces after commas (e.g. $user, public) by matching up to the next -c option or end of line
  - Set the session search_path explicitly after connecting so sample queries resolve to the correct schema
  - Use versionless PostgreSQL docs link (/docs/current/)
  - Clarify the case-sensitivity note for schema names in troubleshooting

* chore: gitignore build artifacts

* fix: replace ReDoS-vulnerable regex in parse_schema_from_url

  Replace (.+?)(?=\s+-c|\s*$) with [^\s,]+(?:\s*,\s*[^\s,]+)* to eliminate the polynomial backtracking flagged by CodeQL. The new pattern uses unambiguous character classes with no overlapping quantifiers.

* fix: validate schema input instead of silent sanitization, fix doc URL encoding

  - DatabaseModal: Show a validation error for invalid schema characters instead of silently stripping them. Throw an error on submit if invalid characters are present.
  - docs: URL-encode the example URL to prevent copy/paste connection failures.

* fix: revert doc URL to readable form to fix spellcheck

  The URL-encoded form (-csearch_path%3Dmy_schema) inside the Liquid capture block triggers spellcheck failures ('csearch', 'Dmy'). Reverted to the readable form, since Python's urlparse handles both formats fine.

* fix: add missing tech terms to spellcheck wordlist

  Add terms from AGENTS.md/CLAUDE.md (added in the staging merge) to the spellcheck wordlist: config, docstring, dotenv, ESLint, HSTS, init, Middleware, monorepo, PRs, pylint, pytest, Radix, Zod, and error class names. Also fix DockerHub capitalization.

* fix: ensure DB connection cleanup on error and add cursor type hints

  - Wrap the psycopg2 connection/cursor in try/finally so they are always closed, even when extract_tables_info or extract_relationships raises
  - Set conn/cursor to None after the explicit close to avoid a double close in the finally block
  - Add Any type hints to the cursor parameters on extract_tables_info, extract_columns_info, extract_foreign_keys, extract_relationships, and _execute_sample_query

* fix: increase timeout for multi-step E2E chat tests

  Mark three tests that perform multiple LLM round-trips with test.slow() to triple their timeout (60s → 180s), preventing spurious CI failures when LLM responses are slow:
  - multiple sequential queries maintain conversation history
  - switching databases clears chat history
  - duplicate record shows user-friendly error message

---------

Co-authored-by: sirudog <1550561+sirudog@users.noreply.github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
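The schema-selection behavior described above can be sketched as a standalone function. This is a minimal illustration of the parsing the commit message describes, not the committed code (the real implementation is the `PostgresLoader.parse_schema_from_url` method in `api/loaders/postgres_loader.py`, shown in the diff below); the function name and example URLs here are illustrative.

```python
import re
from urllib.parse import urlparse, parse_qs, unquote

def parse_schema_from_url(connection_url: str) -> str:
    """Return the first usable schema from the URL's options search_path."""
    parsed = urlparse(connection_url)
    options = parse_qs(parsed.query).get('options', [])
    if not options:
        return 'public'
    options_str = unquote(options[0])
    # Accept both -csearch_path=... and -c search_path=...; the character-class
    # pattern avoids the polynomial backtracking flagged by CodeQL.
    match = re.search(r'-c\s*search_path\s*=\s*([^\s,]+(?:\s*,\s*[^\s,]+)*)',
                      options_str, re.IGNORECASE)
    if not match:
        return 'public'
    # Loop through all tokens so repeated or leading $user entries are skipped.
    for token in match.group(1).split(','):
        token = token.strip().strip('"\'')
        if token and token != '$user':
            return token
    return 'public'

print(parse_schema_from_url(
    "postgresql://u:p@localhost:5432/db?options=-csearch_path%3Dsales"))      # sales
print(parse_schema_from_url(
    "postgresql://u:p@localhost:5432/db?options=-c search_path%3D$user, analytics"))  # analytics
print(parse_schema_from_url("postgresql://u:p@localhost:5432/db"))            # public
```

Note how the `$user` placeholder is skipped wherever it appears in the list, and how the missing-options case falls back to 'public', matching the defaults threaded through the extraction methods.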
1 parent 51ddf58 commit 1214b33

9 files changed

Lines changed: 365 additions & 94 deletions


.github/wordlist.txt

Lines changed: 20 additions & 1 deletion
@@ -8,6 +8,7 @@ schemas
 psycopg
 html
 PostgreSQLLoader
+PostgresLoader
 api
 postgres
 postgresql
@@ -77,6 +78,7 @@ LLM
 Ollama
 OpenAI
 OpenAI's
+DockerHub
 Dockerhub
 FDE
 github
@@ -98,4 +100,21 @@ Sanitization
 JOINs
 subqueries
 subquery
-TTL
+TTL
+
+config
+docstring
+dotenv
+ESLint
+GraphNotFoundError
+HSTS
+init
+InternalError
+InvalidArgumentError
+Middleware
+monorepo
+PRs
+pylint
+pytest
+Radix
+Zod
.gitignore

Lines changed: 4 additions & 1 deletion
@@ -31,4 +31,7 @@ demo_tokens.py
 /blob-report/
 /playwright/.cache/
 /playwright/.auth/
-e2e/.auth/
+e2e/.auth/
+# Build artifacts
+clients/python/queryweaver_client.egg-info/
+clients/ts/dist/

README.md

Lines changed: 1 addition & 0 deletions
@@ -231,6 +231,7 @@ with requests.post(url, headers=headers, json={"chat": ["Count orders last week"
             continue
         obj = json.loads(part)
         print('STREAM:', obj)
+```
 
 Notes & tips
 - Graph IDs are namespaced per-user. When calling the API directly use the plain graph id (the server will namespace by the authenticated user). For uploaded files the `database` field determines the saved graph id.
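The README snippet this hunk closes consumes a streamed chat response part by part. A minimal sketch of that newline-delimited JSON pattern, with no network call (the `consume_stream` helper and the sample chunks here are hypothetical, mirroring the `continue`/`json.loads` shape in the README loop):

```python
import json

def consume_stream(chunks):
    """Yield parsed objects from newline-delimited JSON parts,
    skipping blank parts the way the README loop does with `continue`."""
    for part in chunks:
        part = part.strip()
        if not part:
            continue
        obj = json.loads(part)
        print('STREAM:', obj)
        yield obj

# Blank parts (e.g. keep-alive newlines) are silently skipped.
events = list(consume_stream(['{"chat": "Count"}', '', '{"done": true}']))
```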

api/loaders/postgres_loader.py

Lines changed: 95 additions & 22 deletions
@@ -5,6 +5,7 @@
 import decimal
 import logging
 from typing import AsyncGenerator, Dict, Any, List, Tuple
+from urllib.parse import urlparse, parse_qs, unquote
 
 import psycopg2
 from psycopg2 import sql
@@ -52,7 +53,7 @@ class PostgresLoader(BaseLoader):
 
     @staticmethod
     def _execute_sample_query(
-        cursor, table_name: str, col_name: str, sample_size: int = 3
+        cursor: Any, table_name: str, col_name: str, sample_size: int = 3
     ) -> List[Any]:
         """
         Execute query to get random sample values for a column.
@@ -96,39 +97,96 @@ def _serialize_value(value):
             return None
         return value
 
+    @staticmethod
+    def parse_schema_from_url(connection_url: str) -> str:
+        """
+        Parse the search_path from the connection URL's options parameter.
+
+        The options parameter follows PostgreSQL's libpq format:
+        postgresql://user:pass@host:port/db?options=-csearch_path%3Dschema_name
+
+        Args:
+            connection_url: PostgreSQL connection URL
+
+        Returns:
+            The first schema from search_path, or 'public' if not specified
+        """
+        try:
+            parsed = urlparse(connection_url)
+            query_params = parse_qs(parsed.query)
+
+            options = query_params.get('options', [])
+            if not options:
+                return 'public'
+
+            options_str = unquote(options[0])
+
+            # Parse -c search_path=value from options.
+            # Format can be: -csearch_path=schema or -c search_path=schema
+            # Match comma-separated schema tokens (supports spaces after commas).
+            match = re.search(r'-c\s*search_path\s*=\s*([^\s,]+(?:\s*,\s*[^\s,]+)*)', options_str, re.IGNORECASE)
+            if match:
+                search_path = match.group(1)
+                schemas = search_path.split(',')
+                for s in schemas:
+                    s = s.strip().strip('"\'')
+                    if s and s != '$user':
+                        return s
+                return 'public'
+
+            return 'public'
+
+        except Exception:  # pylint: disable=broad-exception-caught
+            return 'public'
+
     @staticmethod
     async def load(prefix: str, connection_url: str) -> AsyncGenerator[tuple[bool, str], None]:
         """
         Load the graph data from a PostgreSQL database into the graph database.
 
         Args:
             connection_url: PostgreSQL connection URL in format:
-                postgresql://username:password@host:port/database
+                postgresql://username:password@host:port/database
+                Optionally with schema via options parameter:
+                postgresql://...?options=-csearch_path%3Dschema_name
 
         Returns:
             Tuple[bool, str]: Success status and message
         """
+        conn = None
+        cursor = None
         try:
+            # Parse schema from connection URL (defaults to 'public')
+            schema = PostgresLoader.parse_schema_from_url(connection_url)
+
             # Connect to PostgreSQL database
             conn = psycopg2.connect(connection_url)
             cursor = conn.cursor()
 
+            # Set the session search_path to the parsed schema so unqualified
+            # table references (e.g. in sample queries) resolve correctly.
+            cursor.execute(
+                sql.SQL("SET search_path TO {}").format(sql.Identifier(schema))
+            )
+
             # Extract database name from connection URL
             db_name = connection_url.split('/')[-1]
             if '?' in db_name:
                 db_name = db_name.split('?')[0]
 
             # Get all table information
             yield True, "Extracting table information..."
-            entities = PostgresLoader.extract_tables_info(cursor)
+            entities = PostgresLoader.extract_tables_info(cursor, schema)
 
             yield True, "Extracting relationship information..."
             # Get all relationship information
-            relationships = PostgresLoader.extract_relationships(cursor)
+            relationships = PostgresLoader.extract_relationships(cursor, schema)
 
-            # Close database connection
+            # Close database connection before graph loading
             cursor.close()
+            cursor = None
             conn.close()
+            conn = None
 
             yield True, "Loading data into graph..."
             # Load data into graph
@@ -144,46 +202,53 @@ async def load(prefix: str, connection_url: str) -> AsyncGenerator[tuple[bool, s
         except Exception as e:  # pylint: disable=broad-exception-caught
             logging.error("Error loading PostgreSQL schema: %s", e)
             yield False, "Failed to load PostgreSQL database schema"
+        finally:
+            if cursor is not None:
+                cursor.close()
+            if conn is not None:
+                conn.close()
 
     @staticmethod
-    def extract_tables_info(cursor) -> Dict[str, Any]:
+    def extract_tables_info(cursor: Any, schema: str = 'public') -> Dict[str, Any]:
         """
         Extract table and column information from PostgreSQL database.
 
         Args:
             cursor: Database cursor
+            schema: Database schema to extract tables from (default: 'public')
 
         Returns:
             Dict containing table information
         """
         entities = {}
 
-        # Get all tables in public schema
+        # Get all tables in the specified schema
        cursor.execute("""
             SELECT table_name, table_comment
             FROM information_schema.tables t
             LEFT JOIN (
                 SELECT schemaname, tablename, description as table_comment
                 FROM pg_tables pt
                 JOIN pg_class pc ON pc.relname = pt.tablename
+                JOIN pg_namespace pn ON pn.oid = pc.relnamespace AND pn.nspname = pt.schemaname
                 JOIN pg_description pd ON pd.objoid = pc.oid AND pd.objsubid = 0
-                WHERE pt.schemaname = 'public'
+                WHERE pt.schemaname = %s
             ) tc ON tc.tablename = t.table_name
-            WHERE t.table_schema = 'public'
+            WHERE t.table_schema = %s
             AND t.table_type = 'BASE TABLE'
             ORDER BY t.table_name;
-        """)
+        """, (schema, schema))
 
         tables = cursor.fetchall()
 
         for table_name, table_comment in tqdm.tqdm(tables, desc="Extracting table information"):
             table_name = table_name.strip()
 
             # Get column information for this table
-            columns_info = PostgresLoader.extract_columns_info(cursor, table_name)
+            columns_info = PostgresLoader.extract_columns_info(cursor, table_name, schema)
 
             # Get foreign keys for this table
-            foreign_keys = PostgresLoader.extract_foreign_keys(cursor, table_name)
+            foreign_keys = PostgresLoader.extract_foreign_keys(cursor, table_name, schema)
 
             # Generate table description
             table_description = table_comment if table_comment else f"Table: {table_name}"
@@ -201,13 +266,14 @@ def extract_tables_info(cursor) -> Dict[str, Any]:
         return entities
 
     @staticmethod
-    def extract_columns_info(cursor, table_name: str) -> Dict[str, Any]:
+    def extract_columns_info(cursor: Any, table_name: str, schema: str = 'public') -> Dict[str, Any]:
         """
         Extract column information for a specific table.
 
         Args:
             cursor: Database cursor
             table_name: Name of the table
+            schema: Database schema (default: 'public')
 
         Returns:
             Dict containing column information
@@ -230,24 +296,29 @@ def extract_columns_info(cursor, table_name: str) -> Dict[str, Any]:
                 FROM information_schema.table_constraints tc
                 JOIN information_schema.key_column_usage ku
                     ON tc.constraint_name = ku.constraint_name
+                    AND tc.constraint_schema = ku.constraint_schema
                 WHERE tc.table_name = %s
+                AND tc.table_schema = %s
                 AND tc.constraint_type = 'PRIMARY KEY'
             ) pk ON pk.column_name = c.column_name
             LEFT JOIN (
                 SELECT ku.column_name
                 FROM information_schema.table_constraints tc
                 JOIN information_schema.key_column_usage ku
                     ON tc.constraint_name = ku.constraint_name
+                    AND tc.constraint_schema = ku.constraint_schema
                 WHERE tc.table_name = %s
+                AND tc.table_schema = %s
                 AND tc.constraint_type = 'FOREIGN KEY'
             ) fk ON fk.column_name = c.column_name
-            LEFT JOIN pg_class pc ON pc.relname = c.table_name
+            LEFT JOIN pg_namespace pn ON pn.nspname = c.table_schema
+            LEFT JOIN pg_class pc ON pc.relname = c.table_name AND pc.relnamespace = pn.oid
             LEFT JOIN pg_attribute pa ON pa.attrelid = pc.oid AND pa.attname = c.column_name
             LEFT JOIN pg_description pgd ON pgd.objoid = pc.oid AND pgd.objsubid = pa.attnum
             WHERE c.table_name = %s
-            AND c.table_schema = 'public'
+            AND c.table_schema = %s
             ORDER BY c.ordinal_position;
-        """, (table_name, table_name, table_name))
+        """, (table_name, schema, table_name, schema, table_name, schema))
 
         columns = cursor.fetchall()
         columns_info = {}
@@ -289,13 +360,14 @@ def extract_columns_info(cursor, table_name: str) -> Dict[str, Any]:
         return columns_info
 
     @staticmethod
-    def extract_foreign_keys(cursor, table_name: str) -> List[Dict[str, str]]:
+    def extract_foreign_keys(cursor: Any, table_name: str, schema: str = 'public') -> List[Dict[str, str]]:
         """
         Extract foreign key information for a specific table.
 
         Args:
             cursor: Database cursor
             table_name: Name of the table
+            schema: Database schema (default: 'public')
 
         Returns:
             List of foreign key dictionaries
@@ -315,8 +387,8 @@ def extract_foreign_keys(cursor, table_name: str) -> List[Dict[str, str]]:
                 AND ccu.table_schema = tc.table_schema
             WHERE tc.constraint_type = 'FOREIGN KEY'
             AND tc.table_name = %s
-            AND tc.table_schema = 'public';
-        """, (table_name,))
+            AND tc.table_schema = %s;
+        """, (table_name, schema))
 
         foreign_keys = []
         for constraint_name, column_name, foreign_table, foreign_column in cursor.fetchall():
@@ -330,12 +402,13 @@ def extract_foreign_keys(cursor, table_name: str) -> List[Dict[str, str]]:
         return foreign_keys
 
     @staticmethod
-    def extract_relationships(cursor) -> Dict[str, List[Dict[str, str]]]:
+    def extract_relationships(cursor: Any, schema: str = 'public') -> Dict[str, List[Dict[str, str]]]:
         """
         Extract all relationship information from the database.
 
         Args:
             cursor: Database cursor
+            schema: Database schema (default: 'public')
 
         Returns:
             Dict containing relationship information
@@ -355,9 +428,9 @@ def extract_relationships(cursor) -> Dict[str, List[Dict[str, str]]]:
                 ON ccu.constraint_name = tc.constraint_name
                 AND ccu.table_schema = tc.table_schema
             WHERE tc.constraint_type = 'FOREIGN KEY'
-            AND tc.table_schema = 'public'
+            AND tc.table_schema = %s
             ORDER BY tc.table_name, tc.constraint_name;
-        """)
+        """, (schema,))
 
         relationships = {}
         for (table_name, constraint_name, column_name,
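The regex used in parse_schema_from_url above is the ReDoS-safe replacement the commit message describes. A quick check that the character-class pattern still captures multi-schema values with spaces after commas, and stops before the next -c option (the option strings below are illustrative inputs, not from the repo's tests):

```python
import re

# The replacement pattern from the diff: unambiguous character classes,
# no overlapping quantifiers, so no polynomial backtracking.
SEARCH_PATH_RE = re.compile(
    r'-c\s*search_path\s*=\s*([^\s,]+(?:\s*,\s*[^\s,]+)*)', re.IGNORECASE)

for opts in ('-csearch_path=my_schema',
             '-c search_path=$user, public -c statement_timeout=0'):
    m = SEARCH_PATH_RE.search(opts)
    print(m.group(1))  # my_schema, then: $user, public
```

The second input shows why the capture group alternates `[^\s,]+` tokens with `\s*,\s*` separators: a following `-c statement_timeout=0` is not preceded by a comma, so the match ends at `public` without any lookahead.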
