|
5 | 5 | import csv |
6 | 6 | import json |
7 | 7 | import sys |
| 8 | +from typing import TYPE_CHECKING, Callable |
8 | 9 |
|
9 | | -from .adapters import create_ssh_tunnel, get_adapter |
10 | 10 | from .config import ( |
11 | 11 | AUTH_TYPE_LABELS, |
12 | 12 | AuthType, |
|
16 | 16 | load_connections, |
17 | 17 | save_connections, |
18 | 18 | ) |
| 19 | +from .services import ConnectionSession, QueryResult, QueryService |
| 20 | + |
| 21 | +if TYPE_CHECKING: |
| 22 | + from .services import HistoryStoreProtocol |
19 | 23 |
|
20 | 24 |
|
21 | 25 | def cmd_connection_list(args) -> int: |
@@ -219,8 +223,98 @@ def cmd_connection_delete(args) -> int: |
219 | 223 | return 0 |
220 | 224 |
|
221 | 225 |
|
222 | | -def cmd_query(args) -> int: |
223 | | - """Execute a SQL query against a connection.""" |
| 226 | +def _stream_csv_output(cursor, columns: list[str]) -> int: |
| 227 | + """Stream CSV output from cursor using fetchmany.""" |
| 228 | + writer = csv.writer(sys.stdout) |
| 229 | + writer.writerow(columns) |
| 230 | + row_count = 0 |
| 231 | + batch_size = 1000 |
| 232 | + while True: |
| 233 | + rows = cursor.fetchmany(batch_size) |
| 234 | + if not rows: |
| 235 | + break |
| 236 | + for row in rows: |
| 237 | + writer.writerow(str(val) if val is not None else "" for val in row) |
| 238 | + row_count += 1 |
| 239 | + return row_count |
| 240 | + |
| 241 | + |
| 242 | +def _stream_json_output(cursor, columns: list[str]) -> int: |
| 243 | + """Stream JSON output from cursor using fetchmany (JSON array format).""" |
| 244 | + print("[") |
| 245 | + first = True |
| 246 | + row_count = 0 |
| 247 | + batch_size = 1000 |
| 248 | + while True: |
| 249 | + rows = cursor.fetchmany(batch_size) |
| 250 | + if not rows: |
| 251 | + break |
| 252 | + for row in rows: |
| 253 | + if not first: |
| 254 | + print(",") |
| 255 | + first = False |
| 256 | + obj = dict(zip(columns, [val if val is not None else None for val in row])) |
| 257 | + print(json.dumps(obj, default=str), end="") |
| 258 | + row_count += 1 |
| 259 | + print("\n]") |
| 260 | + return row_count |
| 261 | + |
| 262 | + |
| 263 | +def _output_table(columns: list[str], rows: list[tuple], truncated: bool) -> None: |
| 264 | + """Output query results in table format with optimized width calculation.""" |
| 265 | + MAX_COL_WIDTH = 50 # Cap column width to avoid excessive line length |
| 266 | + |
| 267 | + # Calculate column widths (only scan first 100 rows for performance) |
| 268 | + col_widths = [min(len(col), MAX_COL_WIDTH) for col in columns] |
| 269 | + for row in rows[:100]: |
| 270 | + for i, val in enumerate(row): |
| 271 | + val_str = str(val) if val is not None else "NULL" |
| 272 | + col_widths[i] = min(MAX_COL_WIDTH, max(col_widths[i], len(val_str))) |
| 273 | + |
| 274 | + # Print header |
| 275 | + header_parts = [] |
| 276 | + for i, col in enumerate(columns): |
| 277 | + col_display = col[:col_widths[i]] if len(col) > col_widths[i] else col |
| 278 | + header_parts.append(col_display.ljust(col_widths[i])) |
| 279 | + header = " | ".join(header_parts) |
| 280 | + print(header) |
| 281 | + print("-" * len(header)) |
| 282 | + |
| 283 | + # Print rows |
| 284 | + for row in rows: |
| 285 | + row_parts = [] |
| 286 | + for i, val in enumerate(row): |
| 287 | + val_str = str(val) if val is not None else "NULL" |
| 288 | + if len(val_str) > col_widths[i]: |
| 289 | + val_str = val_str[: col_widths[i] - 2] + ".." |
| 290 | + row_parts.append(val_str.ljust(col_widths[i])) |
| 291 | + print(" | ".join(row_parts)) |
| 292 | + |
| 293 | + # Print count with truncation notice |
| 294 | + if truncated: |
| 295 | + print(f"\n({len(rows)} rows shown, results truncated)") |
| 296 | + else: |
| 297 | + print(f"\n({len(rows)} row(s) returned)") |
| 298 | + |
| 299 | + |
| 300 | +def cmd_query( |
| 301 | + args, |
| 302 | + *, |
| 303 | + session_factory: Callable[[ConnectionConfig], ConnectionSession] | None = None, |
| 304 | + query_service: QueryService | None = None, |
| 305 | +) -> int: |
| 306 | + """Execute a SQL query against a connection. |
| 307 | +
|
| 308 | + Args: |
| 309 | + args: Parsed command-line arguments. |
| 310 | + session_factory: Optional factory for creating ConnectionSession. |
| 311 | + Defaults to ConnectionSession.create. Useful for testing. |
| 312 | + query_service: Optional QueryService instance. |
| 313 | + Defaults to a new QueryService(). Useful for testing. |
| 314 | +
|
| 315 | + Returns: |
| 316 | + Exit code (0 for success, 1 for error). |
| 317 | + """ |
224 | 318 | connections = load_connections() |
225 | 319 |
|
226 | 320 | config = None |
@@ -253,86 +347,84 @@ def cmd_query(args) -> int: |
253 | 347 | print("Error: Either --query or --file must be provided.") |
254 | 348 | return 1 |
255 | 349 |
|
256 | | - tunnel = None |
| 350 | + # Determine row limit (0 means unlimited) |
| 351 | + max_rows = args.limit if args.limit > 0 else None |
| 352 | + |
| 353 | + # Use injected or default factories |
| 354 | + create_session = session_factory or ConnectionSession.create |
| 355 | + service = query_service or QueryService() |
| 356 | + |
257 | 357 | try: |
258 | | - from dataclasses import replace |
259 | | - |
260 | | - # Create SSH tunnel if enabled |
261 | | - tunnel, host, port = create_ssh_tunnel(config) |
262 | | - if tunnel: |
263 | | - connect_config = replace(config, server=host, port=str(port)) |
264 | | - else: |
265 | | - connect_config = config |
266 | | - |
267 | | - adapter = get_adapter(config.db_type) |
268 | | - db_conn = adapter.connect(connect_config) |
269 | | - |
270 | | - # Detect query type to avoid executing non-SELECT statements twice |
271 | | - query_type = query.strip().upper().split()[0] if query.strip() else "" |
272 | | - is_select_query = query_type in ("SELECT", "WITH", "SHOW", "DESCRIBE", "EXPLAIN", "PRAGMA") |
273 | | - |
274 | | - if is_select_query: |
275 | | - columns, rows, _truncated = adapter.execute_query(db_conn, query) |
276 | | - else: |
277 | | - columns, rows = [], [] |
278 | | - |
279 | | - if columns: |
280 | | - if args.format == "csv": |
281 | | - # Use proper CSV writer for correct quoting/escaping |
282 | | - writer = csv.writer(sys.stdout) |
283 | | - writer.writerow(columns) |
284 | | - for row in rows: |
285 | | - writer.writerow(str(val) if val is not None else "" for val in row) |
286 | | - elif args.format == "json": |
287 | | - result = [] |
288 | | - for row in rows: |
289 | | - result.append( |
290 | | - dict( |
291 | | - zip( |
292 | | - columns, |
293 | | - [val if val is not None else None for val in row], |
294 | | - ) |
295 | | - ) |
296 | | - ) |
297 | | - print(json.dumps(result, indent=2, default=str)) |
| 358 | + # Use ConnectionSession for automatic resource cleanup |
| 359 | + with create_session(config) as session: |
| 360 | + # For unlimited streaming output (CSV/JSON only), use direct cursor access |
| 361 | + from .services.query import is_select_query |
| 362 | + |
| 363 | + if max_rows is None and args.format in ("csv", "json") and is_select_query(query): |
| 364 | + # Stream directly from cursor for unlimited CSV/JSON |
| 365 | + cursor = session.connection.cursor() |
| 366 | + cursor.execute(query) |
| 367 | + |
| 368 | + if not cursor.description: |
| 369 | + print("Query executed successfully (no results)") |
| 370 | + return 0 |
| 371 | + |
| 372 | + columns = [col[0] for col in cursor.description] |
| 373 | + |
| 374 | + if args.format == "csv": |
| 375 | + row_count = _stream_csv_output(cursor, columns) |
| 376 | + else: |
| 377 | + row_count = _stream_json_output(cursor, columns) |
| 378 | + |
| 379 | + # Save to history |
| 380 | + service._save_to_history(config.name, query) |
| 381 | + print(f"\n({row_count} row(s) returned)", file=sys.stderr) |
| 382 | + return 0 |
| 383 | + |
| 384 | + # Standard execution with QueryService (with row limit) |
| 385 | + result = service.execute( |
| 386 | + connection=session.connection, |
| 387 | + adapter=session.adapter, |
| 388 | + query=query, |
| 389 | + config=config, |
| 390 | + max_rows=max_rows, |
| 391 | + save_to_history=True, |
| 392 | + ) |
| 393 | + |
| 394 | + if isinstance(result, QueryResult): |
| 395 | + columns = result.columns |
| 396 | + rows = result.rows |
| 397 | + |
| 398 | + if args.format == "csv": |
| 399 | + writer = csv.writer(sys.stdout) |
| 400 | + writer.writerow(columns) |
| 401 | + for row in rows: |
| 402 | + writer.writerow(str(val) if val is not None else "" for val in row) |
| 403 | + if result.truncated: |
| 404 | + print(f"\n({len(rows)} rows shown, results truncated)", file=sys.stderr) |
| 405 | + else: |
| 406 | + print(f"\n({len(rows)} row(s) returned)", file=sys.stderr) |
| 407 | + elif args.format == "json": |
| 408 | + json_result = [ |
| 409 | + dict(zip(columns, [val if val is not None else None for val in row])) |
| 410 | + for row in rows |
| 411 | + ] |
| 412 | + print(json.dumps(json_result, indent=2, default=str)) |
| 413 | + if result.truncated: |
| 414 | + print(f"\n({len(rows)} rows shown, results truncated)", file=sys.stderr) |
| 415 | + else: |
| 416 | + print(f"\n({len(rows)} row(s) returned)", file=sys.stderr) |
| 417 | + else: |
| 418 | + _output_table(columns, rows, result.truncated) |
298 | 419 | else: |
299 | | - col_widths = [len(col) for col in columns] |
300 | | - for row in rows: |
301 | | - for i, val in enumerate(row): |
302 | | - col_widths[i] = max( |
303 | | - col_widths[i], len(str(val) if val is not None else "NULL") |
304 | | - ) |
305 | | - |
306 | | - header = " | ".join( |
307 | | - col.ljust(col_widths[i]) for i, col in enumerate(columns) |
308 | | - ) |
309 | | - print(header) |
310 | | - print("-" * len(header)) |
311 | | - |
312 | | - for row in rows: |
313 | | - row_str = " | ".join( |
314 | | - (str(val) if val is not None else "NULL").ljust(col_widths[i]) |
315 | | - for i, val in enumerate(row) |
316 | | - ) |
317 | | - print(row_str) |
318 | | - |
319 | | - print(f"\n({len(rows)} row(s) returned)") |
320 | | - else: |
321 | | - affected = adapter.execute_non_query(db_conn, query) |
322 | | - print(f"Query executed successfully. Rows affected: {affected}") |
323 | | - |
324 | | - db_conn.close() |
325 | | - if tunnel: |
326 | | - tunnel.stop() |
327 | | - return 0 |
| 420 | + # NonQueryResult |
| 421 | + print(f"Query executed successfully. Rows affected: {result.rows_affected}") |
| 422 | + |
| 423 | + return 0 |
328 | 424 |
|
329 | 425 | except ImportError as e: |
330 | 426 | print(f"Error: Required module not installed: {e}") |
331 | | - if tunnel: |
332 | | - tunnel.stop() |
333 | 427 | return 1 |
334 | 428 | except Exception as e: |
335 | 429 | print(f"Error: {e}") |
336 | | - if tunnel: |
337 | | - tunnel.stop() |
338 | 430 | return 1 |
0 commit comments