Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3733,7 +3733,7 @@ def iter_outputs(shard_iterable):
_time = time.time()
for i, example in iter_outputs(shard_iterable):
if update_data:
if i == 0:
if writer is None:
buf_writer, writer, tmp_file = init_buffer_and_writer()
stack.enter_context(writer)
if isinstance(example, pa.Table):
Expand All @@ -3759,7 +3759,7 @@ def iter_outputs(shard_iterable):
for i, batch in iter_outputs(shard_iterable):
num_examples_in_batch = len(i)
if update_data:
if i and i[0] == 0:
if writer is None:
buf_writer, writer, tmp_file = init_buffer_and_writer()
stack.enter_context(writer)
if isinstance(batch, pa.Table):
Expand Down
24 changes: 24 additions & 0 deletions tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1889,6 +1889,30 @@ def __call__(self, example):
dset.map(ex_cnt)
self.assertEqual(ex_cnt.cnt, len(dset))

def test_map_writer_initialized_when_first_examples_return_none(self, in_memory):
"""Dataset.map must not crash when early examples return None.

The writer was previously only initialized when i == 0. If the map
function returns None for the first N examples, update_data stays False
and the writer is never created. When a later example returns a dict,
update_data flips to True but writer is still None, causing:
AttributeError: 'NoneType' object has no attribute 'write'
(issue #7990)
"""
with tempfile.TemporaryDirectory() as tmp_dir:
with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
def fn(example, idx):
if idx < 2:
return None
return {"filename": example["filename"] + "_transformed"}

result = dset.map(fn, with_indices=True)
# First two rows are skipped (return None → no update)
# Remaining rows are transformed
self.assertEqual(len(result), len(dset))
for i in range(2, len(dset)):
self.assertTrue(result[i]["filename"].endswith("_transformed"))

@require_not_windows
def test_map_crash_subprocess(self, in_memory):
# be sure that a crash in one of the subprocess will not
Expand Down