Skip to content

Commit 1525c7f

Browse files
authored
Merge pull request #1943 from dbcli/RW/more-batch-steps-before-progress-bar
Spinners for setup steps with `--progress` `--batch`
2 parents fddea11 + 7336214 commit 1525c7f

4 files changed

Lines changed: 174 additions & 6 deletions

File tree

changelog.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@ Upcoming (TBD)
22
==============
33

44
Features
5-
---------
5+
--------
66
* Silently accept forward slash to introduce special commands.
7+
* `--progress` spinners for setup steps in `--batch` mode.
78

89

910
Internal

mycli/main_modes/batch.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from prompt_toolkit.shortcuts import ProgressBar
1111
from prompt_toolkit.shortcuts.progress_bar import formatters as progress_bar_formatters
1212
import pymysql
13+
from yaspin import yaspin
1314

1415
from mycli.packages.batch_utils import statements_from_filehandle
1516
from mycli.packages.interactive_utils import confirm_destructive_query
@@ -28,6 +29,7 @@ def replay_checkpoint_file(
2829
batch_path: str,
2930
checkpoint_path: str | None,
3031
resume: bool,
32+
progress: bool = False,
3133
) -> int:
3234
if not resume:
3335
return 0
@@ -42,27 +44,46 @@ def replay_checkpoint_file(
4244
raise CheckpointReplayError('--resume is incompatible with reading from the standard input.')
4345

4446
completed_count = 0
47+
if progress:
48+
spinner = yaspin(text='replaying checkpoint', side='right', stream=sys.stderr)
49+
spinner.start()
4550
try:
4651
with click.open_file(batch_path) as batch_h, click.open_file(checkpoint_path, mode='r', encoding='utf-8') as checkpoint_h:
4752
try:
4853
batch_gen = statements_from_filehandle(batch_h)
4954
except ValueError as e:
55+
if progress:
56+
spinner.fail('✘')
5057
raise CheckpointReplayError(f'Error reading --batch file: {batch_path}: {e}') from None
5158
for checkpoint_statement, _checkpoint_counter in statements_from_filehandle(checkpoint_h):
5259
try:
5360
batch_statement, _batch_counter = next(batch_gen)
5461
except StopIteration:
62+
if progress:
63+
spinner.fail('✘')
5564
raise CheckpointReplayError('Checkpoint script longer than batch script.') from None
5665
except ValueError as e:
66+
if progress:
67+
spinner.fail('✘')
5768
raise CheckpointReplayError(f'Error reading --batch file: {batch_path}: {e}') from None
5869
if checkpoint_statement != batch_statement:
70+
if progress:
71+
spinner.fail('✘')
5972
raise CheckpointReplayError(f'Statement mismatch: {checkpoint_statement}.')
6073
completed_count += 1
74+
if progress:
75+
spinner.ok('✔')
6176
except ValueError as e:
77+
if progress:
78+
spinner.fail('✘')
6279
raise CheckpointReplayError(f'Error reading --checkpoint file: {checkpoint_path}: {e}') from None
6380
except FileNotFoundError as e:
81+
if progress:
82+
spinner.fail('✘')
6483
raise CheckpointReplayError(f'FileNotFoundError: {e}') from None
6584
except OSError as e:
85+
if progress:
86+
spinner.fail('✘')
6687
raise CheckpointReplayError(f'OSError: {e}') from None
6788

6889
return completed_count
@@ -119,10 +140,12 @@ def main_batch_with_progress_bar(mycli: 'MyCli', cli_args: 'CliArgs') -> int:
119140
click.secho('--progress is only compatible with a plain file.', err=True, fg='red')
120141
return 1
121142
try:
122-
completed_statement_count = replay_checkpoint_file(cli_args.batch, cli_args.checkpoint, cli_args.resume)
143+
completed_statement_count = replay_checkpoint_file(cli_args.batch, cli_args.checkpoint, cli_args.resume, progress=True)
123144
batch_count_h = click.open_file(cli_args.batch)
124-
for _statement, _counter in statements_from_filehandle(batch_count_h):
125-
goal_statements += 1
145+
with yaspin(text='validating batch ', side='right', stream=sys.stderr) as spinner:
146+
for _statement, _counter in statements_from_filehandle(batch_count_h):
147+
goal_statements += 1
148+
spinner.ok('✔')
126149
batch_count_h.close()
127150
batch_h = click.open_file(cli_args.batch)
128151
batch_gen = statements_from_filehandle(batch_h)
@@ -140,7 +163,7 @@ def main_batch_with_progress_bar(mycli: 'MyCli', cli_args: 'CliArgs') -> int:
140163
if goal_statements:
141164
pb_style = prompt_toolkit.styles.Style.from_dict({'bar-a': 'reverse'})
142165
custom_formatters = [
143-
progress_bar_formatters.Bar(start='[', end=']', sym_a=' ', sym_b=' ', sym_c=' '),
166+
progress_bar_formatters.Bar(start='running queries [', end=']', sym_a=' ', sym_b=' ', sym_c=' '),
144167
progress_bar_formatters.Text(' '),
145168
progress_bar_formatters.Progress(),
146169
progress_bar_formatters.Text(' '),

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ dependencies = [
2525
"pyfzf ~= 0.3.1",
2626
"rapidfuzz ~= 3.14.3",
2727
"keyring ~= 25.7.0",
28+
"yaspin ~= 3.4.0",
2829
]
2930

3031
[project.urls]

test/pytests/test_main_modes_batch.py

Lines changed: 144 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,23 @@ def close(self) -> None:
6868
self.closed = True
6969

7070

71+
class DummyStream:
72+
def __init__(self, tty: bool = False) -> None:
73+
self.closed = False
74+
self.tty = tty
75+
self.writes: list[str] = []
76+
77+
def isatty(self) -> bool:
78+
return self.tty
79+
80+
def write(self, value: str) -> int:
81+
self.writes.append(value)
82+
return len(value)
83+
84+
def flush(self) -> None:
85+
return None
86+
87+
7188
class DummyProgressBar:
7289
calls: list[list[int]] = []
7390

@@ -86,6 +103,32 @@ def __call__(self, iterable) -> list[int]:
86103
return values
87104

88105

106+
class DummySpinner:
107+
instances: list['DummySpinner'] = []
108+
109+
def __init__(self, *args, **kwargs) -> None:
110+
self.fail_calls: list[str] = []
111+
self.ok_calls: list[str] = []
112+
self.started = False
113+
DummySpinner.instances.append(self)
114+
115+
def __enter__(self) -> 'DummySpinner':
116+
self.start()
117+
return self
118+
119+
def __exit__(self, exc_type, exc, tb) -> Literal[False]:
120+
return False
121+
122+
def start(self) -> None:
123+
self.started = True
124+
125+
def fail(self, text: str) -> None:
126+
self.fail_calls.append(text)
127+
128+
def ok(self, text: str) -> None:
129+
self.ok_calls.append(text)
130+
131+
89132
def dispatch_batch_statements(
90133
mycli: DummyMyCli,
91134
cli_args: DummyCliArgs,
@@ -108,7 +151,7 @@ def main_batch_from_stdin(mycli: DummyMyCli, cli_args: DummyCliArgs) -> int:
108151

109152

110153
def make_fake_sys(stdin_tty: bool, stderr_tty: bool | None = None) -> SimpleNamespace:
111-
stderr = SimpleNamespace(isatty=lambda: stderr_tty) if stderr_tty is not None else object()
154+
stderr = DummyStream(bool(stderr_tty))
112155
return SimpleNamespace(
113156
stdin=SimpleNamespace(isatty=lambda: stdin_tty),
114157
stderr=stderr,
@@ -186,6 +229,21 @@ def test_replay_checkpoint_file_rejects_checkpoint_longer_than_batch(tmp_path: P
186229
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)
187230

188231

232+
def test_replay_checkpoint_file_marks_progress_failed_when_checkpoint_is_longer(
233+
monkeypatch,
234+
tmp_path: Path,
235+
) -> None:
236+
batch_path = write_batch_file(tmp_path, 'select 1;\n')
237+
checkpoint = write_checkpoint_file(tmp_path, 'select 1;\nselect 2;\n')
238+
DummySpinner.instances.clear()
239+
monkeypatch.setattr(batch_mode, 'yaspin', DummySpinner)
240+
241+
with pytest.raises(batch_mode.CheckpointReplayError, match='Checkpoint script longer than batch script.'):
242+
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True, progress=True)
243+
244+
assert DummySpinner.instances[0].fail_calls == ['✘']
245+
246+
189247
@pytest.mark.skipif(os.name == 'nt', reason='todo: unknown')
190248
def test_replay_checkpoint_file_rejects_batch_read_error(monkeypatch, tmp_path: Path) -> None:
191249
batch_path = write_batch_file(tmp_path, 'select 1;\n')
@@ -198,6 +256,20 @@ def test_replay_checkpoint_file_rejects_batch_read_error(monkeypatch, tmp_path:
198256
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)
199257

200258

259+
@pytest.mark.skipif(os.name == 'nt', reason='todo: unknown')
260+
def test_replay_checkpoint_file_marks_progress_failed_for_batch_read_error(monkeypatch, tmp_path: Path) -> None:
261+
batch_path = write_batch_file(tmp_path, 'select 1;\n')
262+
checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n')
263+
DummySpinner.instances.clear()
264+
monkeypatch.setattr(batch_mode, 'statements_from_filehandle', lambda _handle: (_ for _ in ()).throw(ValueError('bad batch')))
265+
monkeypatch.setattr(batch_mode, 'yaspin', DummySpinner)
266+
267+
with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --batch file: {batch_path}: bad batch'):
268+
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True, progress=True)
269+
270+
assert DummySpinner.instances[0].fail_calls == ['✘']
271+
272+
201273
@pytest.mark.skipif(os.name == 'nt', reason='todo: unknown')
202274
def test_replay_checkpoint_file_rejects_batch_iteration_error(monkeypatch, tmp_path: Path) -> None:
203275
batch_path = write_batch_file(tmp_path, 'select 1;\n')
@@ -219,6 +291,31 @@ def fake_statements_from_filehandle(handle):
219291
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)
220292

221293

294+
@pytest.mark.skipif(os.name == 'nt', reason='todo: unknown')
295+
def test_replay_checkpoint_file_marks_progress_failed_for_batch_iteration_error(monkeypatch, tmp_path: Path) -> None:
296+
batch_path = write_batch_file(tmp_path, 'select 1;\n')
297+
298+
def raise_on_next():
299+
raise ValueError('bad batch iterator')
300+
yield
301+
302+
def fake_statements_from_filehandle(handle):
303+
if handle.name == batch_path:
304+
return raise_on_next()
305+
return iter([('select 1;', 0)])
306+
307+
DummySpinner.instances.clear()
308+
monkeypatch.setattr(batch_mode, 'statements_from_filehandle', fake_statements_from_filehandle)
309+
monkeypatch.setattr(batch_mode, 'yaspin', DummySpinner)
310+
311+
checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n')
312+
313+
with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --batch file: {batch_path}: bad batch iterator'):
314+
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True, progress=True)
315+
316+
assert DummySpinner.instances[0].fail_calls == ['✘']
317+
318+
222319
@pytest.mark.skipif(os.name == 'nt', reason='todo: unknown')
223320
def test_replay_checkpoint_file_rejects_checkpoint_read_error(monkeypatch, tmp_path: Path) -> None:
224321
batch_path = write_batch_file(tmp_path, 'select 1;\n')
@@ -236,6 +333,27 @@ def fake_statements_from_filehandle(handle):
236333
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)
237334

238335

336+
@pytest.mark.skipif(os.name == 'nt', reason='todo: unknown')
337+
def test_replay_checkpoint_file_marks_progress_failed_for_checkpoint_read_error(monkeypatch, tmp_path: Path) -> None:
338+
batch_path = write_batch_file(tmp_path, 'select 1;\n')
339+
340+
def fake_statements_from_filehandle(handle):
341+
if handle.name == batch_path:
342+
return iter([('select 1;', 0)])
343+
return (_ for _ in ()).throw(ValueError('bad checkpoint'))
344+
345+
DummySpinner.instances.clear()
346+
monkeypatch.setattr(batch_mode, 'statements_from_filehandle', fake_statements_from_filehandle)
347+
monkeypatch.setattr(batch_mode, 'yaspin', DummySpinner)
348+
349+
checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n')
350+
351+
with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --checkpoint file: {checkpoint}: bad checkpoint'):
352+
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True, progress=True)
353+
354+
assert DummySpinner.instances[0].fail_calls == ['✘']
355+
356+
239357
def test_replay_checkpoint_file_rejects_missing_files(tmp_path: Path) -> None:
240358
batch_path = str(tmp_path / 'missing.sql')
241359

@@ -245,6 +363,18 @@ def test_replay_checkpoint_file_rejects_missing_files(tmp_path: Path) -> None:
245363
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)
246364

247365

366+
def test_replay_checkpoint_file_marks_progress_failed_for_missing_files(monkeypatch, tmp_path: Path) -> None:
367+
batch_path = str(tmp_path / 'missing.sql')
368+
checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n')
369+
DummySpinner.instances.clear()
370+
monkeypatch.setattr(batch_mode, 'yaspin', DummySpinner)
371+
372+
with pytest.raises(batch_mode.CheckpointReplayError, match='FileNotFoundError'):
373+
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True, progress=True)
374+
375+
assert DummySpinner.instances[0].fail_calls == ['✘']
376+
377+
248378
def test_replay_checkpoint_file_rejects_open_errors(monkeypatch, tmp_path: Path) -> None:
249379
batch_path = write_batch_file(tmp_path, 'select 1;\n')
250380

@@ -256,6 +386,19 @@ def test_replay_checkpoint_file_rejects_open_errors(monkeypatch, tmp_path: Path)
256386
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)
257387

258388

389+
def test_replay_checkpoint_file_marks_progress_failed_for_open_errors(monkeypatch, tmp_path: Path) -> None:
390+
batch_path = write_batch_file(tmp_path, 'select 1;\n')
391+
checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n')
392+
DummySpinner.instances.clear()
393+
monkeypatch.setattr(batch_mode.click, 'open_file', lambda *_args, **_kwargs: (_ for _ in ()).throw(OSError('open failed')))
394+
monkeypatch.setattr(batch_mode, 'yaspin', DummySpinner)
395+
396+
with pytest.raises(batch_mode.CheckpointReplayError, match='OSError'):
397+
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True, progress=True)
398+
399+
assert DummySpinner.instances[0].fail_calls == ['✘']
400+
401+
259402
@pytest.mark.parametrize(
260403
('format_name', 'batch_counter', 'expected'),
261404
(

0 commit comments

Comments
 (0)