Skip to content

Commit 7c3dd30

Browse files
committed
Fix progress bar against non-UTF-8 encodings
For #439. `file_progress()` sets the progress bar length to the file size in bytes (`os.path.getsize(file.name)`), but `UpdateWrapper.__iter__` called `update(len(line))`, which is the decoded character count. With UTF-16-LE input every character is 2 bytes, so the bar capped at 50%; UTF-32 capped at 25%; etc. Simon noted in the issue that the obvious fix (calling `.tell()` on the wrapped text stream) doesn't work because text mode disables it during iteration. The underlying binary buffer doesn't have that restriction though, so this tracks progress against `TextIOWrapper.buffer.tell()` when the wrapped object exposes one. For raw binary streams (no `.buffer` attribute) we keep the old behaviour, which was already byte-accurate. Added six regression tests in tests/test_utils.py covering UTF-8, UTF-16-LE, BOM-prefixed UTF-16, the sniff-style `BufferedReader` chain, a raw binary fallback, and the `.read()` path used by the JSON loader. Each asserts that the sum of update() calls equals the on-disk file size, which is what `click.progressbar` needs to reach 100%.
1 parent 8f0c06e commit 7c3dd30

2 files changed

Lines changed: 132 additions & 2 deletions

File tree

sqlite_utils/utils.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,15 +215,52 @@ class UpdateWrapper:
215215
def __init__(self, wrapped: io.IOBase, update: Callable[[int], None]) -> None:
216216
self._wrapped = wrapped
217217
self._update = update
218+
# `file_progress` sets the progress bar length to the file size in
219+
# bytes, but iterating a text-mode stream yields decoded characters,
220+
# so reporting `len(line)` undercounts for any multi-byte encoding
221+
# (UTF-16-LE caps the bar at ~50%, UTF-32 at ~25%, etc.). When the
222+
# wrapped object is a text wrapper, track progress against the
223+
# underlying binary buffer's position instead. See #439.
224+
self._byte_source = getattr(wrapped, "buffer", None)
225+
self._last_byte_pos = 0
226+
if self._byte_source is not None:
227+
try:
228+
self._last_byte_pos = self._byte_source.tell()
229+
except (io.UnsupportedOperation, OSError):
230+
self._byte_source = None
231+
232+
def _advance_to_buffer_pos(self) -> None:
233+
# Bring the progress bar up to the current byte position of the
234+
# underlying binary buffer (which may have read ahead).
235+
assert self._byte_source is not None
236+
try:
237+
pos = self._byte_source.tell()
238+
except OSError:
239+
return
240+
delta = pos - self._last_byte_pos
241+
if delta > 0:
242+
self._update(delta)
243+
self._last_byte_pos = pos
218244

219245
def __iter__(self) -> Iterator[bytes]:
246+
if self._byte_source is None:
247+
for line in self._wrapped:
248+
self._update(len(line))
249+
yield line
250+
return
220251
for line in self._wrapped:
221-
self._update(len(line))
252+
self._advance_to_buffer_pos()
222253
yield line
254+
# The wrapper may have buffered the last chunk without emitting any
255+
# more lines; flush the remaining bytes so the bar reaches 100%.
256+
self._advance_to_buffer_pos()
223257

224258
def read(self, size: int = -1) -> bytes:
225259
data = self._wrapped.read(size)
226-
self._update(len(data))
260+
if self._byte_source is not None:
261+
self._advance_to_buffer_pos()
262+
else:
263+
self._update(len(data))
227264
return data
228265

229266

tests/test_utils.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,96 @@ def test_maximize_csv_field_size_limit():
8383
)
8484
def test_flatten(input, expected):
8585
assert utils.flatten(input) == expected
86+
87+
88+
# Regression tests for #439: progress bar against multi-byte encodings
89+
90+
91+
def _collect_updates(rows):
92+
"""Iterate the wrapper, capturing every update() value."""
93+
return list(rows)
94+
95+
96+
def _make_temp(content_bytes, tmp_path, name):
97+
path = tmp_path / name
98+
path.write_bytes(content_bytes)
99+
return path
100+
101+
102+
def test_updatewrapper_utf8_reports_byte_lengths(tmp_path):
103+
# Sanity: ASCII / UTF-8 still hits 100% (this was already correct,
104+
# but we want a baseline to protect.)
105+
raw = b"a,b\n1,2\n3,4\n"
106+
path = _make_temp(raw, tmp_path, "in.csv")
107+
updates = []
108+
with open(path, "rb") as fp:
109+
wrapper = utils.UpdateWrapper(io.TextIOWrapper(fp, encoding="utf-8"), updates.append)
110+
_collect_updates(wrapper)
111+
assert sum(updates) == len(raw)
112+
113+
114+
def test_updatewrapper_utf16le_reports_byte_lengths(tmp_path):
115+
# Without the fix this test fails: the bar only reaches len(decoded)
116+
# which is half the raw byte length for UTF-16-LE.
117+
raw = "a,b\n1,2\n3,4\n".encode("utf-16-le")
118+
path = _make_temp(raw, tmp_path, "in.csv")
119+
updates = []
120+
with open(path, "rb") as fp:
121+
wrapper = utils.UpdateWrapper(io.TextIOWrapper(fp, encoding="utf-16-le"), updates.append)
122+
_collect_updates(wrapper)
123+
assert sum(updates) == len(raw)
124+
125+
126+
def test_updatewrapper_utf16le_with_bom_reaches_total_bytes(tmp_path):
127+
# BOM-prefixed UTF-16. The BOM byte is consumed by the TextIOWrapper
128+
# before iteration starts; we should still account for the full file
129+
# size so the bar reaches 100%.
130+
raw = "" + "a,b\n1,2\n3,4\n"
131+
raw_bytes = raw.encode("utf-16-le")
132+
path = _make_temp(raw_bytes, tmp_path, "in.csv")
133+
updates = []
134+
with open(path, "rb") as fp:
135+
wrapper = utils.UpdateWrapper(io.TextIOWrapper(fp, encoding="utf-16"), updates.append)
136+
_collect_updates(wrapper)
137+
assert sum(updates) == len(raw_bytes)
138+
139+
140+
def test_updatewrapper_through_buffered_reader(tmp_path):
141+
# The --sniff path wraps the raw file in io.BufferedReader before the
142+
# TextIOWrapper. Progress reporting must still resolve to the binary
143+
# file's byte count.
144+
raw = "a,b\n1,2\n3,4\n".encode("utf-16-le")
145+
path = _make_temp(raw, tmp_path, "in.csv")
146+
updates = []
147+
with open(path, "rb") as fp:
148+
buffered = io.BufferedReader(fp, buffer_size=4096)
149+
wrapper = utils.UpdateWrapper(
150+
io.TextIOWrapper(buffered, encoding="utf-16-le"), updates.append
151+
)
152+
_collect_updates(wrapper)
153+
assert sum(updates) == len(raw)
154+
155+
156+
def test_updatewrapper_binary_file_unchanged(tmp_path):
157+
# If the wrapped object is itself a raw binary file (no .buffer attr),
158+
# we should keep the old behaviour: iterate yields bytes and len() is
159+
# already the byte count.
160+
raw = b"a,b\n1,2\n3,4\n"
161+
path = _make_temp(raw, tmp_path, "in.csv")
162+
updates = []
163+
with open(path, "rb") as fp:
164+
wrapper = utils.UpdateWrapper(fp, updates.append)
165+
_collect_updates(wrapper)
166+
assert sum(updates) == len(raw)
167+
168+
169+
def test_updatewrapper_read_path_utf16le(tmp_path):
170+
# The .read() path is used by the JSON loader (not the CSV iterator),
171+
# but must agree with the iterator path on byte accounting.
172+
raw = '{"a": 1}'.encode("utf-16-le")
173+
path = _make_temp(raw, tmp_path, "in.json")
174+
updates = []
175+
with open(path, "rb") as fp:
176+
wrapper = utils.UpdateWrapper(io.TextIOWrapper(fp, encoding="utf-16-le"), updates.append)
177+
wrapper.read()
178+
assert sum(updates) == len(raw)

0 commit comments

Comments
 (0)