Skip to content

Commit b55d091

Browse files
committed
ensure distinct --batch and --checkpoint files
to avoid --checkpoint logging SQL back to the source --batch file. The --checkpoint argument is now taken as a path string rather than an already-opened file, and the opened checkpoint handle is cached on the client to avoid repeatedly seeking to the end of a large file. The change to a string is visible in Click's generated help text. Incidentally improve the --checkpoint help text where it describes --resume, which looks like it may have been mangled by an agent during refactoring.
1 parent 8b0df0d commit b55d091

9 files changed

Lines changed: 112 additions & 65 deletions

File tree

changelog.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
Upcoming (TBD)
22
==============
33

4+
Bug Fixes
5+
---------
6+
* Ensure that `--batch` and `--checkpoint` files are distinct.
7+
8+
49
Internal
510
--------
611
* Improve test coverage for `completion_refresher.py`.

mycli/client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def __init__(
9090
self._keepalive_counter = 0
9191
self.keepalive_ticks: int | None = 0
9292
self.sandbox_mode: bool = False
93+
self.checkpoint: IO | None = None
9394

9495
# self.cnf_files is a class variable that stores the list of mysql
9596
# config files to read in at launch.

mycli/client_query.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

3-
from io import TextIOWrapper
4-
from typing import TYPE_CHECKING, Any
3+
from typing import IO, TYPE_CHECKING, Any
54

65
import click
76
from pymysql.cursors import Cursor
@@ -26,6 +25,7 @@ class ClientQueryMixin:
2625
numeric_alignment: str | None
2726
binary_display: str | None
2827
query_history: list[Any]
28+
checkpoint: IO | None
2929

3030
def log_query(self, query: str) -> None: ...
3131
def log_output(self, output: str) -> None: ...
@@ -74,12 +74,14 @@ def _on_completions_refreshed(self, new_completer: SQLCompleter) -> None:
7474
def run_query(
7575
self,
7676
query: str,
77-
checkpoint: TextIOWrapper | None = None,
77+
checkpoint: str | None = None,
7878
new_line: bool = True,
7979
) -> None:
8080
"""Runs *query*."""
8181
assert self.sqlexecute is not None
8282
self.log_query(query)
83+
if checkpoint and not self.checkpoint:
84+
self.checkpoint = click.open_file(checkpoint, mode='a')
8385
results = self.sqlexecute.run(query)
8486
for result in results:
8587
self.main_formatter.query = query
@@ -111,9 +113,9 @@ def run_query(
111113
)
112114
for line in output:
113115
click.echo(line, nl=new_line)
114-
if checkpoint:
115-
checkpoint.write(query.rstrip('\n') + '\n')
116-
checkpoint.flush()
116+
if self.checkpoint:
117+
self.checkpoint.write(query.rstrip('\n') + '\n')
118+
self.checkpoint.flush()
117119

118120
def get_last_query(self) -> str | None:
119121
"""Get the last query executed or None."""

mycli/main.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,9 +206,9 @@ class CliArgs:
206206
type=click.File(mode='a', encoding='utf-8'),
207207
help='Log every query and its results to a file.',
208208
)
209-
checkpoint: TextIOWrapper | None = clickdc.option(
210-
type=click.File(mode='a', encoding='utf-8'),
211-
help='In batch or --execute mode, log successful queries to a file, and skipped with --resume.',
209+
checkpoint: str | None = clickdc.option(
210+
type=str,
211+
help='In batch or --execute mode, log successful queries to a file, and skip them with --resume.',
212212
)
213213
resume: bool = clickdc.option(
214214
'--resume',
@@ -369,6 +369,17 @@ def preprocess_cli_args(
369369
click.secho('Error: --resume requires a --batch file.', err=True, fg='red')
370370
sys.exit(1)
371371

372+
if (
373+
cli_args.checkpoint
374+
and os.path.exists(cli_args.checkpoint)
375+
and cli_args.batch
376+
and cli_args.batch != '-'
377+
and os.path.exists(cli_args.batch)
378+
):
379+
if os.stat(cli_args.batch) == os.stat(cli_args.checkpoint):
380+
click.secho('Error: --batch and --checkpoint must be different files.', err=True, fg='red')
381+
sys.exit(1)
382+
372383
if cli_args.verbose and cli_args.quiet:
373384
click.secho('Error: --verbose and --quiet are incompatible.', err=True, fg='red')
374385
sys.exit(1)

mycli/main_modes/batch.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
from io import TextIOWrapper
43
import os
54
import sys
65
import time
@@ -27,23 +26,24 @@ class CheckpointReplayError(Exception):
2726

2827
def replay_checkpoint_file(
2928
batch_path: str,
30-
checkpoint: TextIOWrapper | None,
29+
checkpoint_path: str | None,
3130
resume: bool,
3231
) -> int:
3332
if not resume:
3433
return 0
3534

36-
if checkpoint is None:
35+
if checkpoint_path is None:
36+
return 0
37+
38+
if not os.path.exists(checkpoint_path):
3739
return 0
3840

3941
if batch_path == '-':
4042
raise CheckpointReplayError('--resume is incompatible with reading from the standard input.')
4143

42-
checkpoint_name = checkpoint.name
43-
checkpoint.flush()
4444
completed_count = 0
4545
try:
46-
with click.open_file(batch_path) as batch_h, click.open_file(checkpoint_name, mode='r', encoding='utf-8') as checkpoint_h:
46+
with click.open_file(batch_path) as batch_h, click.open_file(checkpoint_path, mode='r', encoding='utf-8') as checkpoint_h:
4747
try:
4848
batch_gen = statements_from_filehandle(batch_h)
4949
except ValueError as e:
@@ -59,7 +59,7 @@ def replay_checkpoint_file(
5959
raise CheckpointReplayError(f'Statement mismatch: {checkpoint_statement}.')
6060
completed_count += 1
6161
except ValueError as e:
62-
raise CheckpointReplayError(f'Error reading --checkpoint file: {checkpoint.name}: {e}') from None
62+
raise CheckpointReplayError(f'Error reading --checkpoint file: {checkpoint_path}: {e}') from None
6363
except FileNotFoundError as e:
6464
raise CheckpointReplayError(f'FileNotFoundError: {e}') from None
6565
except OSError as e:
@@ -133,7 +133,7 @@ def main_batch_with_progress_bar(mycli: 'MyCli', cli_args: 'CliArgs') -> int:
133133
click.secho(f'Error reading --batch file: {cli_args.batch}: {e}', err=True, fg='red')
134134
return 1
135135
except CheckpointReplayError as e:
136-
name = cli_args.checkpoint.name if cli_args.checkpoint else 'None'
136+
name = cli_args.checkpoint if cli_args.checkpoint else 'None'
137137
click.secho(f'Error replaying --checkpoint file: {name}: {e}', err=True, fg='red')
138138
return 1
139139
try:
@@ -175,7 +175,7 @@ def main_batch_without_progress_bar(mycli: 'MyCli', cli_args: 'CliArgs') -> int:
175175
click.secho(f'Failed to open --batch file: {cli_args.batch}', err=True, fg='red')
176176
return 1
177177
except CheckpointReplayError as e:
178-
name = cli_args.checkpoint.name if cli_args.checkpoint else 'None'
178+
name = cli_args.checkpoint if cli_args.checkpoint else 'None'
179179
click.secho(f'Error replaying --checkpoint file: {name}: {e}', err=True, fg='red')
180180
return 1
181181
try:

test/pytests/test_client_query.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -202,11 +202,7 @@ def format_sqlresult(result: SQLResult, **kwargs: Any) -> list[str]:
202202
cli.log_query = lambda query: state['logged_queries'].append(query)
203203
cli.log_output = lambda line: state['logged_output'].append(line)
204204
cli.format_sqlresult = format_sqlresult
205-
checkpoint = state['checkpoint_path'].open('w+', encoding='utf-8')
206-
try:
207-
main.MyCli.run_query(cli, 'select 1;\n', checkpoint=checkpoint, new_line=False)
208-
finally:
209-
checkpoint.close()
205+
main.MyCli.run_query(cli, 'select 1;\n', checkpoint=str(state['checkpoint_path']), new_line=False)
210206
state['cli'] = cli
211207
return state
212208

test/pytests/test_main.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2711,6 +2711,23 @@ def test_preprocess_cli_args_validates_resume_requirements(
27112711
assert expected in capsys.readouterr().err
27122712

27132713

2714+
def test_preprocess_cli_args_rejects_same_batch_and_checkpoint_file(
2715+
capsys: pytest.CaptureFixture[str],
2716+
tmp_path: Path,
2717+
) -> None:
2718+
batch_path = tmp_path / 'batch.sql'
2719+
batch_path.write_text('select 1;\n', encoding='utf-8')
2720+
cli_args = CliArgs()
2721+
cli_args.batch = str(batch_path)
2722+
cli_args.checkpoint = str(batch_path)
2723+
2724+
with pytest.raises(SystemExit) as excinfo:
2725+
preprocess_cli_args(cli_args, valid_connection_scheme)
2726+
2727+
assert excinfo.value.code == 1
2728+
assert 'Error: --batch and --checkpoint must be different files.' in capsys.readouterr().err
2729+
2730+
27142731
def test_preprocess_cli_args_rejects_verbose_and_quiet(capsys: pytest.CaptureFixture[str]) -> None:
27152732
cli_args = CliArgs()
27162733
cli_args.verbose = 1

test/pytests/test_main_modes_batch.py

Lines changed: 56 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -153,28 +153,37 @@ def write_batch_file(tmp_path: Path, contents: str) -> str:
153153
return str(batch_path)
154154

155155

156-
def open_checkpoint_file(tmp_path: Path, contents: str) -> TextIOWrapper:
156+
def write_checkpoint_file(tmp_path: Path, contents: str) -> str:
157157
checkpoint_path = tmp_path / 'checkpoint.sql'
158158
checkpoint_path.write_text(contents, encoding='utf-8')
159-
return checkpoint_path.open('a', encoding='utf-8')
159+
return str(checkpoint_path)
160160

161161

162162
def test_replay_checkpoint_file_returns_zero_without_replayable_batch(tmp_path: Path) -> None:
163163
batch_path = write_batch_file(tmp_path, 'select 1;\n')
164164

165165
assert batch_mode.replay_checkpoint_file(batch_path, None, resume=True) == 0
166166

167-
with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint:
168-
with pytest.raises(batch_mode.CheckpointReplayError, match='incompatible with reading from the standard input'):
169-
batch_mode.replay_checkpoint_file('-', checkpoint, resume=True)
167+
checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n')
168+
169+
with pytest.raises(batch_mode.CheckpointReplayError, match='incompatible with reading from the standard input'):
170+
batch_mode.replay_checkpoint_file('-', checkpoint, resume=True)
171+
172+
173+
def test_replay_checkpoint_file_returns_zero_when_checkpoint_is_missing(tmp_path: Path) -> None:
174+
batch_path = write_batch_file(tmp_path, 'select 1;\n')
175+
checkpoint_path = str(tmp_path / 'missing-checkpoint.sql')
176+
177+
assert batch_mode.replay_checkpoint_file(batch_path, checkpoint_path, resume=True) == 0
170178

171179

172180
def test_replay_checkpoint_file_rejects_checkpoint_longer_than_batch(tmp_path: Path) -> None:
173181
batch_path = write_batch_file(tmp_path, 'select 1;\n')
174182

175-
with open_checkpoint_file(tmp_path, 'select 1;\nselect 2;\n') as checkpoint:
176-
with pytest.raises(batch_mode.CheckpointReplayError, match='Checkpoint script longer than batch script.'):
177-
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)
183+
checkpoint = write_checkpoint_file(tmp_path, 'select 1;\nselect 2;\n')
184+
185+
with pytest.raises(batch_mode.CheckpointReplayError, match='Checkpoint script longer than batch script.'):
186+
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)
178187

179188

180189
@pytest.mark.skipif(os.name == 'nt', reason='todo: unknown')
@@ -183,9 +192,10 @@ def test_replay_checkpoint_file_rejects_batch_read_error(monkeypatch, tmp_path:
183192

184193
monkeypatch.setattr(batch_mode, 'statements_from_filehandle', lambda _handle: (_ for _ in ()).throw(ValueError('bad batch')))
185194

186-
with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint:
187-
with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --batch file: {batch_path}: bad batch'):
188-
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)
195+
checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n')
196+
197+
with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --batch file: {batch_path}: bad batch'):
198+
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)
189199

190200

191201
@pytest.mark.skipif(os.name == 'nt', reason='todo: unknown')
@@ -203,9 +213,10 @@ def fake_statements_from_filehandle(handle):
203213

204214
monkeypatch.setattr(batch_mode, 'statements_from_filehandle', fake_statements_from_filehandle)
205215

206-
with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint:
207-
with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --batch file: {batch_path}: bad batch iterator'):
208-
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)
216+
checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n')
217+
218+
with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --batch file: {batch_path}: bad batch iterator'):
219+
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)
209220

210221

211222
@pytest.mark.skipif(os.name == 'nt', reason='todo: unknown')
@@ -219,27 +230,30 @@ def fake_statements_from_filehandle(handle):
219230

220231
monkeypatch.setattr(batch_mode, 'statements_from_filehandle', fake_statements_from_filehandle)
221232

222-
with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint:
223-
with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --checkpoint file: {checkpoint.name}: bad checkpoint'):
224-
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)
233+
checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n')
234+
235+
with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --checkpoint file: {checkpoint}: bad checkpoint'):
236+
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)
225237

226238

227239
def test_replay_checkpoint_file_rejects_missing_files(tmp_path: Path) -> None:
228240
batch_path = str(tmp_path / 'missing.sql')
229241

230-
with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint:
231-
with pytest.raises(batch_mode.CheckpointReplayError, match='FileNotFoundError'):
232-
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)
242+
checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n')
243+
244+
with pytest.raises(batch_mode.CheckpointReplayError, match='FileNotFoundError'):
245+
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)
233246

234247

235248
def test_replay_checkpoint_file_rejects_open_errors(monkeypatch, tmp_path: Path) -> None:
236249
batch_path = write_batch_file(tmp_path, 'select 1;\n')
237250

238251
monkeypatch.setattr(batch_mode.click, 'open_file', lambda *_args, **_kwargs: (_ for _ in ()).throw(OSError('open failed')))
239252

240-
with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint:
241-
with pytest.raises(batch_mode.CheckpointReplayError, match='OSError'):
242-
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)
253+
checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n')
254+
255+
with pytest.raises(batch_mode.CheckpointReplayError, match='OSError'):
256+
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)
243257

244258

245259
@pytest.mark.parametrize(
@@ -514,10 +528,10 @@ def test_main_batch_without_progress_bar_skips_checkpoint_prefix(monkeypatch, tm
514528
)
515529
monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True))
516530

517-
with open_checkpoint_file(tmp_path, 'select 1;\nselect 2;\n') as checkpoint:
518-
cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True)
531+
checkpoint = write_checkpoint_file(tmp_path, 'select 1;\nselect 2;\n')
519532

520-
result = main_batch_without_progress_bar(DummyMyCli(), cli_args)
533+
cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True)
534+
result = main_batch_without_progress_bar(DummyMyCli(), cli_args)
521535

522536
assert result == 0
523537
assert dispatch_calls == [('select 3;', 2)]
@@ -534,10 +548,10 @@ def test_main_batch_without_progress_bar_skips_only_matching_duplicate_prefix(mo
534548
)
535549
monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True))
536550

537-
with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint:
538-
cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True)
551+
checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n')
539552

540-
result = main_batch_without_progress_bar(DummyMyCli(), cli_args)
553+
cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True)
554+
result = main_batch_without_progress_bar(DummyMyCli(), cli_args)
541555

542556
assert result == 0
543557
assert dispatch_calls == [('select 1;', 1), ('select 2;', 2)]
@@ -554,10 +568,10 @@ def test_main_batch_without_progress_bar_fails_on_mismatched_checkpoint(monkeypa
554568
)
555569
monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True))
556570

557-
with open_checkpoint_file(tmp_path, 'select 9;\n') as checkpoint:
558-
cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True)
571+
checkpoint = write_checkpoint_file(tmp_path, 'select 9;\n')
559572

560-
result = main_batch_without_progress_bar(DummyMyCli(), cli_args)
573+
cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True)
574+
result = main_batch_without_progress_bar(DummyMyCli(), cli_args)
561575

562576
assert result == 1
563577
assert dispatch_calls == []
@@ -574,10 +588,10 @@ def test_main_batch_without_progress_bar_succeeds_when_checkpoint_skips_all(monk
574588
)
575589
monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True))
576590

577-
with open_checkpoint_file(tmp_path, 'select 1;\nselect 2;\n') as checkpoint:
578-
cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True)
591+
checkpoint = write_checkpoint_file(tmp_path, 'select 1;\nselect 2;\n')
579592

580-
result = main_batch_without_progress_bar(DummyMyCli(), cli_args)
593+
cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True)
594+
result = main_batch_without_progress_bar(DummyMyCli(), cli_args)
581595

582596
assert result == 0
583597
assert dispatch_calls == []
@@ -597,10 +611,10 @@ def test_main_batch_with_progress_bar_skips_checkpoint_prefix_and_counts_all_sta
597611
)
598612
monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True))
599613

600-
with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint:
601-
cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True)
614+
checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n')
602615

603-
result = main_batch_with_progress_bar(DummyMyCli(), cli_args)
616+
cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True)
617+
result = main_batch_with_progress_bar(DummyMyCli(), cli_args)
604618

605619
assert result == 0
606620
assert dispatch_calls == [('select 2;', 1), ('select 3;', 2)]
@@ -614,13 +628,13 @@ def test_main_batch_with_progress_bar_returns_error_when_checkpoint_replay_fails
614628
monkeypatch.setattr(batch_mode.click, 'secho', lambda message, err, fg: messages.append((message, err, fg)))
615629
monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True))
616630

617-
with open_checkpoint_file(tmp_path, 'select 9;\n') as checkpoint:
618-
cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True)
631+
checkpoint = write_checkpoint_file(tmp_path, 'select 9;\n')
619632

620-
result = main_batch_with_progress_bar(DummyMyCli(), cli_args)
633+
cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True)
634+
result = main_batch_with_progress_bar(DummyMyCli(), cli_args)
621635

622636
assert result == 1
623-
assert messages == [(f'Error replaying --checkpoint file: {checkpoint.name}: Statement mismatch: select 9;.', True, 'red')]
637+
assert messages == [(f'Error replaying --checkpoint file: {checkpoint}: Statement mismatch: select 9;.', True, 'red')]
624638

625639

626640
def test_main_batch_without_progress_bar_returns_error_when_iteration_fails(monkeypatch) -> None:

0 commit comments

Comments
 (0)