Skip to content

Commit f8c57cd

Browse files
authored
feat: support simplecache url chaining (#42)
* test for simplecache
* implementing _get_file
* read in chunks
* update test
* working _get_file
* add additional cache test
* fix problem with chunking
* repeat on FileNotFoundError to fix flaky glob
* use unstrip_protocol to build url
* match parent signature for _rm_file
* increase chunk size to 256kiB
1 parent 851adaa commit f8c57cd

3 files changed

Lines changed: 75 additions & 4 deletions

File tree

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -64,7 +64,7 @@ jobs:
6464
run: python -m pip install .[test]
6565
- name: Test package
6666
run: |
67-
python -m pytest -vv tests --reruns 10 --reruns-delay 30 --only-rerun "(?i)OSError|timeout|expired|connection|socket"
67+
python -m pytest -vv tests --reruns 10 --reruns-delay 30 --only-rerun "(?i)OSError|FileNotFoundError|timeout|expired|connection|socket"
6868
6969
- name: Run fsspec-xrootd tests from uproot latest release
7070
run: |
@@ -75,7 +75,7 @@ jobs:
7575
python -m pip install ./uproot[test]
7676
# Install xrootd-fsspec again because it may have been overwritten by uproot
7777
python -m pip install .[test]
78-
python -m pytest -vv -k "xrootd" uproot/tests --reruns 10 --reruns-delay 30 --only-rerun "(?i)OSError|timeout|expired|connection|socket"
78+
python -m pytest -vv -k "xrootd" uproot/tests --reruns 10 --reruns-delay 30 --only-rerun "(?i)OSError|FileNotFoundError|timeout|expired|connection|socket"
7979
8080
dist:
8181
name: Distribution build

src/fsspec_xrootd/xrootd.py

Lines changed: 41 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -259,7 +259,7 @@ async def _rmdir(self, path: str) -> None:
259259

260260
rmdir = sync_wrapper(_rmdir)
261261

262-
async def _rm_file(self, path: str) -> None:
262+
async def _rm_file(self, path: str, **kwargs: Any) -> None:
263263
status, n = await _async_wrap(self._myclient.rm, path, self.timeout)
264264
if not status.ok:
265265
raise OSError(f"File not removed properly: {status.message}")
@@ -391,7 +391,7 @@ async def _cat_file(self, path: str, start: int, end: int, **kwargs: Any) -> Any
391391
try:
392392
status, _n = await _async_wrap(
393393
_myFile.open,
394-
self.protocol + "://" + self.storage_options["hostid"] + "/" + path,
394+
self.unstrip_protocol(path),
395395
OpenFlags.READ,
396396
self.timeout,
397397
)
@@ -412,6 +412,45 @@ async def _cat_file(self, path: str, start: int, end: int, **kwargs: Any) -> Any
412412
self.timeout,
413413
)
414414

415+
async def _get_file(
416+
self, rpath: str, lpath: str, chunk_size: int = 262_144, **kwargs: Any
417+
) -> None:
418+
# Open the remote file for reading
419+
remote_file = client.File()
420+
421+
try:
422+
status, _n = await _async_wrap(
423+
remote_file.open,
424+
self.unstrip_protocol(rpath),
425+
OpenFlags.READ,
426+
self.timeout,
427+
)
428+
if not status.ok:
429+
raise OSError(f"Remote file failed to open: {status.message}")
430+
431+
with open(lpath, "wb") as local_file:
432+
start: int = 0
433+
while True:
434+
# Read a chunk of content from the remote file
435+
status, chunk = await _async_wrap(
436+
remote_file.read, start, chunk_size, self.timeout
437+
)
438+
start += chunk_size
439+
440+
if not status.ok:
441+
raise OSError(f"Remote file failed to read: {status.message}")
442+
443+
# Break if there is no more content
444+
if not chunk:
445+
break
446+
447+
# Write the chunk to the local file
448+
local_file.write(chunk)
449+
450+
finally:
451+
# Close the remote file
452+
await _async_wrap(remote_file.close, self.timeout)
453+
415454
async def _get_max_chunk_info(self, file: Any) -> tuple[int, int]:
416455
"""Queries the XRootD server for info required for pyxrootd vector_read() function.
417456
Queries for maximum number of chunks and the maximum chunk size allowed by the server.

tests/test_basicio.py

Lines changed: 32 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -412,3 +412,35 @@ def test_glob_full_names(localserver, clear_server):
412412
for name in full_names:
413413
with fsspec.open(name) as f:
414414
assert f.read() in [bytes(data, "utf-8") for data in [TESTDATA1, TESTDATA2]]
415+
416+
417+
@pytest.mark.parametrize("protocol_prefix", ["", "simplecache::"])
418+
def test_cache(localserver, clear_server, protocol_prefix):
419+
data = TESTDATA1 * int(1e7 / len(TESTDATA1)) # bigger than the chunk size
420+
remoteurl, localpath = localserver
421+
with open(localpath + "/testfile.txt", "w") as fout:
422+
fout.write(data)
423+
424+
with fsspec.open(protocol_prefix + remoteurl + "/testfile.txt", "rb") as f:
425+
contents = f.read()
426+
assert contents == data.encode("utf-8")
427+
428+
429+
def test_cache_directory(localserver, clear_server, tmp_path):
430+
remoteurl, localpath = localserver
431+
with open(localpath + "/testfile.txt", "w") as fout:
432+
fout.write(TESTDATA1)
433+
434+
cache_directory = tmp_path / "cache"
435+
with fsspec.open(
436+
"simplecache::" + remoteurl + "/testfile.txt",
437+
"rb",
438+
simplecache={"cache_storage": str(cache_directory)},
439+
) as f:
440+
contents = f.read()
441+
assert contents == TESTDATA1.encode("utf-8")
442+
443+
assert len(os.listdir(cache_directory)) == 1
444+
with open(cache_directory / os.listdir(cache_directory)[0], "rb") as f:
445+
contents = f.read()
446+
assert contents == TESTDATA1.encode("utf-8")

0 commit comments

Comments
 (0)