The .batch() method currently assumes the input (batch) is always a dictionary, which causes errors when it isn't. This can happen with formatted datasets, since formats like "pyarrow", "pandas" (only affects IterableDataset), and "polars" return tables/dataframes instead of dictionaries.
For example:
from datasets import IterableDataset, Dataset
list(IterableDataset.from_dict({"a": [1, 2, 3, 4]}).with_format("pyarrow").batch(2))
# AttributeError: 'pyarrow.lib.Table' object has no attribute 'items'
Ideally, the result should be the same whether the format is applied before or after batching, i.e., the following should hold for all the format types:
# format_type stands for any supported format name, e.g. "pyarrow", "pandas" or "polars"
assert list(IterableDataset.from_dict({"a": [1, 2, 3, 4]}).with_format(format_type).batch(2)) == list(IterableDataset.from_dict({"a": [1, 2, 3, 4]}).batch(2).with_format(format_type))
assert list(Dataset.from_dict({"a": [1, 2, 3, 4]}).with_format(format_type).batch(2)) == list(Dataset.from_dict({"a": [1, 2, 3, 4]}).batch(2).with_format(format_type))
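To make the failure mode concrete, here is a minimal dependency-free sketch. It is an illustration only and not necessarily how this PR implements the fix. `FakeTable`, `naive_batch` and `tolerant_batch` are hypothetical names; `FakeTable` stands in for a formatted batch such as a `pyarrow.Table`, which has a `to_pydict()` method but no `.items()`.

```python
from typing import Any, Dict, List


class FakeTable:
    """Hypothetical stand-in for a table-like batch: it has no .items() method."""

    def __init__(self, columns: Dict[str, List[Any]]) -> None:
        self._columns = columns

    def to_pydict(self) -> Dict[str, List[Any]]:
        # Same method name as pyarrow.Table.to_pydict; here it only copies the columns.
        return {name: list(values) for name, values in self._columns.items()}


def naive_batch(examples: Any) -> Dict[str, List[Any]]:
    # Assumes a dict input, which is the assumption described above.
    return {name: [values] for name, values in examples.items()}


def tolerant_batch(examples: Any) -> Dict[str, List[Any]]:
    # Assumed workaround: convert table-like inputs to a dict before batching.
    if not isinstance(examples, dict) and hasattr(examples, "to_pydict"):
        examples = examples.to_pydict()
    return naive_batch(examples)
```

Calling `naive_batch(FakeTable({"a": [1, 2]}))` raises `AttributeError`, mirroring the error shown above, while `tolerant_batch` accepts both the dict and the table-like input.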