Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
generate_from_arrow_type,
pandas_types_mapper,
require_decoding,
require_storage_embed,
)
from .filesystems import is_remote_filesystem
from .fingerprint import (
Expand Down Expand Up @@ -715,6 +716,45 @@ def __eq__(self, value):
return value == list(self)


def _get_token_per_repo_id_for_embed(
table: pa.Table,
token: Optional[str],
resolved_output_path: Optional[HfFileSystemResolvedPath] = None,
dataset_name: Optional[str] = None,
) -> dict[str, Union[str, bool, None]]:
"""Build a repo_id -> token mapping for embedding remote files during push_to_hub."""
token_per_repo_id: dict[str, Union[str, bool, None]] = {}
if token is None:
return token_per_repo_id

repo_ids: set[str] = set()
if dataset_name:
repo_ids.add(dataset_name)
if isinstance(resolved_output_path, HfFileSystemResolvedRepositoryPath) and resolved_output_path.repo_id:
repo_ids.add(resolved_output_path.repo_id)

def collect_repo_ids(array, feature):
if not require_storage_embed(feature) or not pa.types.is_struct(array.type):
return
if array.type.get_field_index("path") < 0:
return
for path in array.field("path").to_pylist():
if path is None:
continue
source_url = path.split("::")[-1]
if not source_url.startswith(("hf://", config.HF_ENDPOINT)):
continue
pattern = (
config.HUB_DATASETS_URL if source_url.startswith(config.HF_ENDPOINT) else config.HUB_DATASETS_HFFS_URL
)
source_url_fields = string_to_dict(source_url, pattern)
if source_url_fields is not None and "repo_id" in source_url_fields:
repo_ids.add(source_url_fields["repo_id"])

table_visitor(table, collect_repo_ids)
return {repo_id: token for repo_id in repo_ids}


class Dataset(DatasetInfoMixin, IndexableMixin, TensorflowDatasetMixin):
"""A Dataset backed by an Arrow table."""

Expand Down Expand Up @@ -5891,8 +5931,14 @@ def _push_parquet_shards_to_hub_single(
if embed_external_files:
format = shard.format
shard = shard.with_format("arrow")
token_per_repo_id = _get_token_per_repo_id_for_embed(
shard.data,
token=token,
resolved_output_path=resolved_output_path,
dataset_name=shard._info.dataset_name,
)
shard = shard.map(
embed_table_storage,
partial(embed_table_storage, token_per_repo_id=token_per_repo_id),
batched=True,
batch_size=writer_batch_size,
keep_in_memory=True,
Expand Down
19 changes: 19 additions & 0 deletions tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5120,3 +5120,22 @@ def test_process_large_few_examples(tmp_path):
# make sure this is split into 2 shards
ds.save_to_disk(dataset_path, max_shard_size="1KB")
assert (dataset_path / "data-00000-of-00001.arrow").exists()


def test_get_token_per_repo_id_for_embed():
from datasets.arrow_dataset import _get_token_per_repo_id_for_embed
from datasets.features import Features, Image

table = pa.table(
{
"image": [
{
"bytes": None,
"path": "hf://datasets/hf-internal-testing/fixtures_image_utils@main/image.jpg",
}
]
}
)
table = table.cast(Features({"image": Image()}).arrow_schema)
token_per_repo_id = _get_token_per_repo_id_for_embed(table, token="hf_test_token")
assert token_per_repo_id == {"hf-internal-testing/fixtures_image_utils": "hf_test_token"}