Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/skhep_testdata/local_files.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import shutil
import tempfile
import zipfile
from importlib import resources
Expand Down Expand Up @@ -60,5 +61,8 @@ def download_all(cache_dir: str | None = None) -> None:

with zipfile.ZipFile(f) as z:
for n in z.namelist():
if "src/skhep_testdata/data/" in n and not n.endswith(".py"):
z.extract(n, str(local_dir / str(n.split("/")[-1])))
if n.endswith(("/", ".py")) or "src/skhep_testdata/data/" not in n:
continue
target = local_dir / n.split("/")[-1]
with z.open(n) as src, target.open("wb") as dst:
shutil.copyfileobj(src, dst)
9 changes: 8 additions & 1 deletion src/skhep_testdata/remote_files.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import sys
import tarfile
from importlib import resources
from pathlib import Path
Expand Down Expand Up @@ -76,7 +77,13 @@ def fetch_remote_dataset(
logging.warning("Extracting %s", writefile)
with tarfile.open(str(writefile)) as tar:
members = [tar.getmember(f) for f in files.values()]
tar.extractall(str(dataset_dir), members)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't supported on all patch releases before 3.12.0.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

πŸ€– AI text below πŸ€–

Good catch. The data filter only exists on patch releases that got the backport (e.g. 3.10.12, not 3.10.0). Pushed df8b23d: the filtered call is now gated to sys.version_info >= (3, 12) or a hasattr(tarfile, "data_filter") feature check, falling back to the unfiltered extractall otherwise.

# The "data" filter is always present on 3.12+, but before that it
# only exists on patch releases that received the backport (e.g.
# 3.10.12, not 3.10.0).
if sys.version_info >= (3, 12) or hasattr(tarfile, "data_filter"):
tar.extractall(str(dataset_dir), members, filter="data")
else: # pragma: no cover
tar.extractall(str(dataset_dir), members)

for outfile, infile in files.items():
full_in = dataset_dir / infile
Expand Down
38 changes: 38 additions & 0 deletions tests/test_local_files.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import io
import zipfile
from pathlib import Path

import pytest
Expand Down Expand Up @@ -42,3 +44,39 @@ def dummy_remote_file(filename, cache_dir=None, raise_missing=False):

path = skhtd.data_path(str(Path("dataset") / "a_remote_file.root"))
assert path == str(tmpdir / "dataset" / "a_remote_file.root")


def test_download_all(monkeypatch, tmp_path):
file_bytes = b"\x90root-file-contents\x00"

buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as z:
z.writestr("repo-abc123/", b"") # directory entry, must be ignored
z.writestr("repo-abc123/src/skhep_testdata/data/somefile.root", file_bytes)
z.writestr("repo-abc123/src/skhep_testdata/data/__init__.py", b"# ignore me")
zip_bytes = buffer.getvalue()

class DummyResponse:
def __enter__(self):
return self

def __exit__(self, *args):
return False

def raise_for_status(self):
pass

def iter_content(self, chunk_size=8192):
for i in range(0, len(zip_bytes), chunk_size):
yield zip_bytes[i : i + chunk_size]

monkeypatch.setattr(
skhtd.local_files.requests, "get", lambda *a, **k: DummyResponse()
)

skhtd.download_all(cache_dir=str(tmp_path))

extracted = tmp_path / "somefile.root"
assert extracted.is_file()
assert extracted.read_bytes() == file_bytes
assert not (tmp_path / "__init__.py").exists()
Loading