def _get_token_per_repo_id_for_embed(
    table: pa.Table,
    token: Optional[str],
    resolved_output_path: Optional[HfFileSystemResolvedPath] = None,
    dataset_name: Optional[str] = None,
) -> dict[str, Union[str, bool, None]]:
    """Map every Hub repo id relevant to `table` to `token`, for embedding remote files during push_to_hub.

    Repo ids are gathered from three places:

    - `dataset_name`, when it is truthy;
    - `resolved_output_path.repo_id`, when the path is an `HfFileSystemResolvedRepositoryPath`
      and its `repo_id` is truthy;
    - the `"path"` field of every struct array whose feature satisfies `require_storage_embed`.
      Only the last `::`-separated segment of each path is inspected, and only when it starts
      with `"hf://"` or `config.HF_ENDPOINT`.

    Args:
        table (`pa.Table`):
            Table whose columns are walked with `table_visitor`.
        token (`str`, *optional*):
            Token assigned to every collected repo id. If `None`, nothing is collected.
        resolved_output_path (`HfFileSystemResolvedPath`, *optional*):
            Resolved destination of the upload.
        dataset_name (`str`, *optional*):
            Added to the mapping as-is.
            NOTE(review): this is used verbatim as a repo id - confirm callers pass a value
            that matches the repo ids parsed from file URLs.

    Returns:
        `dict`: `{repo_id: token}` for each collected repo id; empty when `token` is `None`.
    """
    if token is None:
        return {}

    found_repo_ids: set[str] = set()
    if dataset_name:
        found_repo_ids.add(dataset_name)
    if isinstance(resolved_output_path, HfFileSystemResolvedRepositoryPath) and resolved_output_path.repo_id:
        found_repo_ids.add(resolved_output_path.repo_id)

    def _record_repo_ids(array, feature):
        # Callback for `table_visitor`: only struct storages that carry a "path" field
        # and belong to a feature requiring storage embedding are of interest.
        has_embeddable_paths = (
            require_storage_embed(feature)
            and pa.types.is_struct(array.type)
            and array.type.get_field_index("path") >= 0
        )
        if not has_embeddable_paths:
            return

        for raw_path in array.field("path").to_pylist():
            if raw_path is None:
                continue
            # For chained URLs ("a::b::c") only the last segment is examined.
            url = raw_path.split("::")[-1]
            if url.startswith(config.HF_ENDPOINT):
                url_pattern = config.HUB_DATASETS_URL
            elif url.startswith("hf://"):
                url_pattern = config.HUB_DATASETS_HFFS_URL
            else:
                # Neither an endpoint URL nor an hf:// path: no repo id to extract.
                continue
            url_fields = string_to_dict(url, url_pattern)
            if url_fields is None or "repo_id" not in url_fields:
                continue
            found_repo_ids.add(url_fields["repo_id"])

    table_visitor(table, _record_repo_ids)
    return dict.fromkeys(found_repo_ids, token)
def test_get_token_per_repo_id_for_embed():
    """A single `hf://` image path yields a one-entry mapping from its repo id to the token."""
    from datasets.arrow_dataset import _get_token_per_repo_id_for_embed
    from datasets.features import Features, Image

    token = "hf_test_token"
    # One image example that has no embedded bytes, only a path on the Hub.
    image_example = {
        "bytes": None,
        "path": "hf://datasets/hf-internal-testing/fixtures_image_utils@main/image.jpg",
    }
    image_schema = Features({"image": Image()}).arrow_schema
    table = pa.table({"image": [image_example]}).cast(image_schema)

    expected = {"hf-internal-testing/fixtures_image_utils": token}
    assert _get_token_per_repo_id_for_embed(table, token=token) == expected