Skip to content

Commit 71f98c5

Browse files
committed
Keep history for unsaved connections in session
1 parent 0975972 commit 71f98c5

File tree

2 files changed

+159
-13
lines changed

2 files changed

+159
-13
lines changed

sqlit/domains/query/ui/mixins/query_execution.py

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,30 @@ def _get_history_store(self: QueryMixinHost) -> Any:
125125
return store
126126
return self.services.history_store
127127

128+
def _get_unsaved_history_store(self: QueryMixinHost) -> Any:
129+
store = getattr(self, "_unsaved_history_store", None)
130+
if store is None:
131+
from sqlit.domains.query.store.memory import InMemoryHistoryStore
132+
133+
store = InMemoryHistoryStore()
134+
self._unsaved_history_store = store
135+
return store
136+
137+
def _should_save_query_history(self: QueryMixinHost, config: Any) -> bool:
138+
"""Return True if the connection is saved and history should be persisted."""
139+
name = getattr(config, "name", "")
140+
if not name:
141+
return False
142+
connections = getattr(self, "connections", None) or []
143+
return any(getattr(conn, "name", None) == name for conn in connections)
144+
145+
def _save_query_history(self: QueryMixinHost, config: Any, query: str) -> None:
146+
"""Save query history only for saved connections."""
147+
if self._should_save_query_history(config):
148+
self._get_history_store().save_query(config.name, query)
149+
return
150+
self._get_unsaved_history_store().save_query(config.name, query)
151+
128152
def _get_query_service(self: QueryMixinHost, provider: Any) -> QueryService:
129153
if self._query_service is None or (
130154
self._query_service_db_type is not None
@@ -262,8 +286,6 @@ async def _run_query_async(self: QueryMixinHost, query: str, keep_insert_mode: b
262286

263287
# Use TransactionExecutor for transaction-aware query execution
264288
executor = self._get_transaction_executor(config, provider)
265-
service = self._get_query_service(provider)
266-
267289
# Check if this is a multi-statement query
268290
statements = split_statements(query)
269291
is_multi_statement = len(statements) > 1
@@ -305,7 +327,7 @@ async def _run_query_async(self: QueryMixinHost, query: str, keep_insert_mode: b
305327
return
306328

307329
try:
308-
await asyncio.to_thread(service._save_to_history, config.name, query)
330+
await asyncio.to_thread(self._save_query_history, config, query)
309331
except Exception:
310332
pass
311333
result = outcome.result
@@ -336,7 +358,7 @@ async def _run_query_async(self: QueryMixinHost, query: str, keep_insert_mode: b
336358
elapsed_ms = (time.perf_counter() - start_time) * 1000
337359

338360
try:
339-
await asyncio.to_thread(service._save_to_history, config.name, query)
361+
await asyncio.to_thread(self._save_query_history, config, query)
340362
except Exception:
341363
pass
342364
self._display_multi_statement_results(multi_result, elapsed_ms)
@@ -350,7 +372,7 @@ async def _run_query_async(self: QueryMixinHost, query: str, keep_insert_mode: b
350372
elapsed_ms = (time.perf_counter() - start_time) * 1000
351373

352374
try:
353-
await asyncio.to_thread(service._save_to_history, config.name, query)
375+
await asyncio.to_thread(self._save_query_history, config, query)
354376
except Exception:
355377
pass
356378

@@ -401,8 +423,6 @@ async def _run_query_atomic_async(self: QueryMixinHost, query: str) -> None:
401423

402424
# Create a dedicated executor for atomic execution
403425
executor = TransactionExecutor(config=config, provider=provider)
404-
service = self._get_query_service(provider)
405-
406426
try:
407427
start_time = time.perf_counter()
408428
max_rows = self.services.runtime.max_rows or MAX_FETCH_ROWS
@@ -414,7 +434,7 @@ async def _run_query_atomic_async(self: QueryMixinHost, query: str) -> None:
414434
elapsed_ms = (time.perf_counter() - start_time) * 1000
415435

416436
try:
417-
await asyncio.to_thread(service._save_to_history, config.name, query)
437+
await asyncio.to_thread(self._save_query_history, config, query)
418438
except Exception:
419439
pass
420440

@@ -551,9 +571,11 @@ def action_show_history(self: QueryMixinHost) -> None:
551571

552572
from ..screens import QueryHistoryScreen
553573

554-
history_store = self._get_history_store()
555574
starred_store = self.services.starred_store
556-
history = history_store.load_for_connection(self.current_config.name)
575+
if self._should_save_query_history(self.current_config):
576+
history = self._get_history_store().load_for_connection(self.current_config.name)
577+
else:
578+
history = self._get_unsaved_history_store().load_for_connection(self.current_config.name)
557579
starred = starred_store.load_for_connection(self.current_config.name)
558580
self.push_screen(
559581
QueryHistoryScreen(history, self.current_config.name, starred),
@@ -587,16 +609,28 @@ def _show_telescope(self: QueryMixinHost, *, auto_open_filter: bool) -> None:
587609
"""Open telescope with optional filter preset."""
588610
from ..screens import QueryHistoryScreen
589611

612+
connection_map = self._get_telescope_connection_map()
613+
available_connections = set(connection_map.keys())
614+
590615
history_store = self._get_history_store()
591616
if hasattr(history_store, "load_all"):
592617
history = history_store.load_all()
593618
else:
594619
history = []
595-
for config in self._get_telescope_connection_map().values():
620+
for config in connection_map.values():
596621
history.extend(history_store.load_for_connection(config.name))
597622
history.sort(key=lambda entry: entry.timestamp, reverse=True)
598623

599-
connection_map = self._get_telescope_connection_map()
624+
unsaved_store = getattr(self, "_unsaved_history_store", None)
625+
if unsaved_store is not None and hasattr(unsaved_store, "load_all"):
626+
history.extend(unsaved_store.load_all())
627+
628+
if available_connections:
629+
history = [
630+
entry for entry in history
631+
if getattr(entry, "connection_name", None) in available_connections
632+
]
633+
history.sort(key=lambda entry: entry.timestamp, reverse=True)
600634
connection_labels = {
601635
name: self._format_telescope_connection_label(config)
602636
for name, config in connection_map.items()

tests/ui/test_query_history.py

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,19 @@
44

55
import pytest
66

7+
from textual.widgets import OptionList
8+
9+
from sqlit.domains.query.store.history import QueryHistoryEntry
10+
from sqlit.domains.query.ui.screens.query_history import QueryHistoryScreen
711
from sqlit.domains.shell.app.main import SSMSTUI
812

9-
from .mocks import MockConnectionStore, MockSettingsStore, build_test_services, create_test_connection
13+
from .mocks import (
14+
MockConnectionStore,
15+
MockHistoryStore,
16+
MockSettingsStore,
17+
build_test_services,
18+
create_test_connection,
19+
)
1020

1121

1222
class TestQueryHistoryCursorMemory:
@@ -153,3 +163,105 @@ async def test_cursor_cache_handles_same_query_text(self):
153163

154164
# Cursor should be at the remembered position
155165
assert app.query_input.cursor_location == (0, 5)
166+
167+
168+
class TestQueryHistorySavePolicy:
169+
"""Tests for query history behavior across saved and unsaved connections."""
170+
171+
@pytest.mark.asyncio
172+
async def test_show_history_for_unsaved_connection_uses_session_history(self) -> None:
173+
unsaved_conn = create_test_connection("temp-db", "sqlite")
174+
history_store = MockHistoryStore()
175+
services = build_test_services(
176+
connection_store=MockConnectionStore([]),
177+
settings_store=MockSettingsStore({"theme": "tokyo-night"}),
178+
history_store=history_store,
179+
)
180+
app = SSMSTUI(services=services)
181+
182+
async with app.run_test(size=(100, 35)) as pilot:
183+
app.current_config = unsaved_conn
184+
app._save_query_history(unsaved_conn, "SELECT 1")
185+
186+
app.action_show_history()
187+
await pilot.pause(0.2)
188+
189+
screen = next(
190+
(s for s in app.screen_stack if isinstance(s, QueryHistoryScreen)),
191+
None,
192+
)
193+
assert screen is not None, "History screen should be present"
194+
195+
option_list = screen.query_one("#history-list", OptionList)
196+
assert option_list.option_count == 1
197+
198+
def test_saved_connection_queries_saved(self) -> None:
199+
saved_conn = create_test_connection("saved-db", "sqlite")
200+
history_store = MockHistoryStore()
201+
services = build_test_services(
202+
connection_store=MockConnectionStore([saved_conn]),
203+
settings_store=MockSettingsStore({"theme": "tokyo-night"}),
204+
history_store=history_store,
205+
)
206+
app = SSMSTUI(services=services)
207+
app.connections = [saved_conn]
208+
209+
app._save_query_history(saved_conn, "SELECT 1")
210+
211+
assert history_store.entries["saved-db"][0]["query"] == "SELECT 1"
212+
213+
@pytest.mark.asyncio
214+
async def test_telescope_hides_unavailable_unsaved_history(self) -> None:
215+
saved_conn = create_test_connection("saved-db", "sqlite")
216+
saved_entry = QueryHistoryEntry(
217+
query="select 1",
218+
timestamp="2026-01-01T00:00:00",
219+
connection_name="saved-db",
220+
)
221+
unsaved_entry = QueryHistoryEntry(
222+
query="select 2",
223+
timestamp="2026-01-02T00:00:00",
224+
connection_name="temp-db",
225+
)
226+
227+
class StubHistoryStore:
228+
def __init__(self, entries):
229+
self._entries = entries
230+
231+
def load_all(self):
232+
return list(self._entries)
233+
234+
def load_for_connection(self, connection_name):
235+
return [e for e in self._entries if e.connection_name == connection_name]
236+
237+
def delete_entry(self, connection_name, timestamp):
238+
_ = connection_name
239+
_ = timestamp
240+
return False
241+
242+
def save_query(self, connection_name, query):
243+
_ = connection_name
244+
_ = query
245+
246+
history_store = StubHistoryStore([saved_entry, unsaved_entry])
247+
services = build_test_services(
248+
connection_store=MockConnectionStore([saved_conn]),
249+
settings_store=MockSettingsStore({"theme": "tokyo-night"}),
250+
history_store=history_store,
251+
)
252+
app = SSMSTUI(services=services)
253+
254+
async with app.run_test(size=(100, 35)) as pilot:
255+
app.connections = [saved_conn]
256+
app.action_telescope()
257+
await pilot.pause(0.2)
258+
259+
screen = next(
260+
(s for s in app.screen_stack if isinstance(s, QueryHistoryScreen)),
261+
None,
262+
)
263+
assert screen is not None, "Telescope screen should be present"
264+
265+
option_list = screen.query_one("#history-list", OptionList)
266+
assert option_list.option_count == 1
267+
assert all(entry.connection_name == "saved-db" for entry in screen._merged_entries)

0 commit comments

Comments
 (0)