.batch() error on formatted datasets #8075

@mariosasko

Description

The .batch() method currently assumes that its input (batch) is always a dictionary and raises an error when it is not. This can happen with formatted datasets, because the "pyarrow", "pandas" (which only affects IterableDataset), and "polars" formats return tables/dataframes instead of dictionaries.

For example:

from datasets import IterableDataset, Dataset
list(IterableDataset.from_dict({"a": [1, 2, 3, 4]}).with_format("pyarrow").batch(2))
# AttributeError: 'pyarrow.lib.Table' object has no attribute 'items'
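
The failure mode can be illustrated without the library. The following is only a minimal sketch, not the actual `datasets` implementation: `FakeTable`, `naive_batch`, `as_row`, and `tolerant_batch` are hypothetical names, and `FakeTable` is a stand-in for a single-row table that exposes `to_pydict()` but no `.items()`.

```python
from typing import Any, Dict, Iterable, Iterator, List


class FakeTable:
    """Hypothetical single-row, table-like example; not part of `datasets`."""

    def __init__(self, columns: Dict[str, List[Any]]) -> None:
        self._columns = columns

    def to_pydict(self) -> Dict[str, List[Any]]:
        # Column name -> list of values, like a table exported to Python.
        return {name: list(values) for name, values in self._columns.items()}


def naive_batch(examples: Iterable[Any], batch_size: int) -> Iterator[Dict[str, List[Any]]]:
    """Group examples into column-wise batches, assuming each one is a dict."""
    batch: Dict[str, List[Any]] = {}
    count = 0
    for example in examples:
        for key, value in example.items():  # AttributeError if not a dict
            batch.setdefault(key, []).append(value)
        count += 1
        if count == batch_size:
            yield batch
            batch, count = {}, 0
    if count:
        yield batch  # final, possibly smaller, batch


def as_row(example: Any) -> Dict[str, Any]:
    """Normalize a dict or a single-row table-like object to a plain dict."""
    if isinstance(example, dict):
        return example
    # Assumption: table-like examples hold exactly one row.
    return {name: values[0] for name, values in example.to_pydict().items()}


def tolerant_batch(examples: Iterable[Any], batch_size: int) -> Iterator[Dict[str, List[Any]]]:
    """Same batching, but each example is normalized to a dict first."""
    return naive_batch((as_row(example) for example in examples), batch_size)
```

With dictionary examples `naive_batch` works, while table-like examples trigger the same kind of `AttributeError` as above; `tolerant_batch` accepts both. Note that this sketch always yields plain dictionaries, so a real fix would presumably also need to re-apply the requested format to each output batch.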

Ideally, the result should be the same whether the format is applied before or after batching, i.e., the following should hold for every format type:

assert list(IterableDataset.from_dict({"a": [1, 2, 3, 4]}).with_format(format_type).batch(2)) == list(IterableDataset.from_dict({"a": [1, 2, 3, 4]}).batch(2).with_format(format_type))
assert list(Dataset.from_dict({"a": [1, 2, 3, 4]}).with_format(format_type).batch(2)) == list(Dataset.from_dict({"a": [1, 2, 3, 4]}).batch(2).with_format(format_type))
