Skip to content

Commit 5d90ae0

Browse files
authored
Merge pull request #1926 from dbcli/RW/add-test-coverage-special-init-py
Add test coverage for `packages/special/__init__.py`
2 parents 3c42720 + 14f6812 commit 5d90ae0

2 files changed

Lines changed: 103 additions & 0 deletions

File tree

changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Internal
66
* Improve test coverage for `completion_refresher.py`.
77
* Add test coverage for `client_query.py`.
88
* Improve test coverage for `output.py`.
9+
* Add test coverage for `special/__init__.py`.
910

1011

1112
1.74.0 (2026/06/06)

test/pytests/test_special_init.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Callable, Generator
4+
import importlib
5+
import sys
6+
from types import ModuleType
7+
8+
import pytest
9+
10+
import mycli.packages
11+
12+
13+
@pytest.fixture
14+
def load_special(monkeypatch: pytest.MonkeyPatch) -> Generator[Callable[[bool], ModuleType], None, None]:
15+
original_module = sys.modules.get('mycli.packages.special')
16+
parent_had_special = hasattr(mycli.packages, 'special')
17+
original_parent_special = getattr(mycli.packages, 'special', None)
18+
19+
def load(llm_off: bool) -> ModuleType:
20+
if llm_off:
21+
monkeypatch.setenv('MYCLI_LLM_OFF', '1')
22+
else:
23+
monkeypatch.delenv('MYCLI_LLM_OFF', raising=False)
24+
sys.modules.pop('mycli.packages.special', None)
25+
if hasattr(mycli.packages, 'special'):
26+
delattr(mycli.packages, 'special')
27+
return importlib.import_module('mycli.packages.special')
28+
29+
yield load
30+
31+
sys.modules.pop('mycli.packages.special', None)
32+
if original_module is not None:
33+
sys.modules['mycli.packages.special'] = original_module
34+
if parent_had_special:
35+
mycli.packages.special = original_parent_special # type: ignore[attr-defined]
36+
elif hasattr(mycli.packages, 'special'):
37+
delattr(mycli.packages, 'special')
38+
39+
40+
def test_special_init_exports_public_names(load_special: Callable[[bool], ModuleType]) -> None:
41+
special = load_special(False)
42+
43+
for name in special.__all__:
44+
assert hasattr(special, name)
45+
46+
47+
def test_special_init_reexports_special_command_api(load_special: Callable[[bool], ModuleType]) -> None:
48+
special = load_special(False)
49+
special_main = importlib.import_module('mycli.packages.special.main')
50+
51+
assert special.execute is special_main.execute
52+
assert special.special_command is special_main.special_command
53+
assert special.CommandNotFound is special_main.CommandNotFound
54+
55+
56+
def test_special_init_reexports_io_state_api(load_special: Callable[[bool], ModuleType]) -> None:
57+
special = load_special(False)
58+
iocommands = importlib.import_module('mycli.packages.special.iocommands')
59+
60+
assert special.set_pager_enabled is iocommands.set_pager_enabled
61+
assert special.is_pager_enabled is iocommands.is_pager_enabled
62+
assert special.write_tee is iocommands.write_tee
63+
64+
65+
def test_special_init_reexports_dbcommands(load_special: Callable[[bool], ModuleType]) -> None:
66+
special = load_special(False)
67+
dbcommands = importlib.import_module('mycli.packages.special.dbcommands')
68+
69+
assert special.list_databases is dbcommands.list_databases
70+
assert special.list_tables is dbcommands.list_tables
71+
assert special.status is dbcommands.status
72+
73+
74+
def test_special_init_uses_llm_implementation_when_enabled(load_special: Callable[[bool], ModuleType]) -> None:
75+
special = load_special(False)
76+
llm = importlib.import_module('mycli.packages.special.llm')
77+
78+
assert special.FinishIteration is llm.FinishIteration
79+
assert special.is_llm_command is llm.is_llm_command
80+
assert special.handle_llm is llm.handle_llm
81+
assert special.sql_using_llm is llm.sql_using_llm
82+
83+
84+
def test_special_init_uses_llm_stubs_when_disabled(load_special: Callable[[bool], ModuleType]) -> None:
85+
special = load_special(True)
86+
87+
assert special.is_llm_command(r'\llm prompt') is False
88+
with pytest.raises(special.FinishIteration) as handle_exc:
89+
special.handle_llm(cast_args := object())
90+
with pytest.raises(special.FinishIteration) as sql_exc:
91+
special.sql_using_llm(cast_args)
92+
93+
assert handle_exc.value.results is None
94+
assert sql_exc.value.results is None
95+
96+
97+
def test_special_init_stub_finish_iteration_stores_results(load_special: Callable[[bool], ModuleType]) -> None:
98+
special = load_special(True)
99+
100+
error = special.FinishIteration(results=['done'])
101+
102+
assert error.results == ['done']

0 commit comments

Comments
 (0)