Skip to content

Commit 1d2af98

Browse files
authored
Fix vector_read data server size query (#56)
1 parent b12503e commit 1d2af98

2 files changed

Lines changed: 36 additions & 28 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ strict = true
2525
show_error_codes = true
2626
enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
2727
warn_unreachable = true
28+
ignore_missing_imports = true
2829

2930

3031
[tool.check-manifest]

src/fsspec_xrootd/xrootd.py

Lines changed: 35 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,17 @@
99
from functools import partial
1010
from typing import Any, Callable, Iterable
1111

12-
from fsspec.asyn import ( # type: ignore[import-not-found]
13-
AsyncFileSystem,
14-
_run_coros_in_chunks,
15-
sync_wrapper,
16-
)
17-
from fsspec.spec import AbstractBufferedFile # type: ignore[import-not-found]
18-
from XRootD import client # type: ignore[import-not-found]
19-
from XRootD.client.flags import ( # type: ignore[import-not-found]
12+
from fsspec.asyn import AsyncFileSystem, _run_coros_in_chunks, sync_wrapper
13+
from fsspec.spec import AbstractBufferedFile
14+
from XRootD import client
15+
from XRootD.client.flags import (
2016
DirListFlags,
2117
MkDirFlags,
2218
OpenFlags,
2319
QueryCode,
2420
StatInfoFlags,
2521
)
26-
from XRootD.client.responses import ( # type: ignore[import-not-found]
27-
HostList,
28-
XRootDStatus,
29-
)
22+
from XRootD.client.responses import HostList, XRootDStatus
3023

3124

3225
class ErrorCodes(IntEnum):
@@ -73,9 +66,11 @@ async def _async_wrap(func: Callable[..., Any], *args: Any) -> Any:
7366
An asyncio future. Result is set when _handle() is called back.
7467
"""
7568
future = asyncio.get_running_loop().create_future()
76-
status = func(*args, callback=partial(_handle, future))
77-
if not status.ok:
78-
raise OSError(status.message.strip())
69+
submit_status = func(*args, callback=partial(_handle, future))
70+
if not submit_status.ok:
71+
raise OSError(
72+
f"Failed to submit {func!r} request: {submit_status.message.strip()}"
73+
)
7974
return await future
8075

8176

@@ -149,6 +144,8 @@ class XRootDFileSystem(AsyncFileSystem): # type: ignore[misc]
149144
root_marker = "/"
150145
default_timeout = 60
151146
async_impl = True
147+
default_max_num_chunks = 1024
148+
default_max_chunk_size = 2097136
152149

153150
_dataserver_info_cache: dict[str, Any] = defaultdict(dict)
154151

@@ -458,7 +455,8 @@ async def _get_file(
458455
# Close the remote file
459456
await _async_wrap(remote_file.close, self.timeout)
460457

461-
async def _get_max_chunk_info(self, file: Any) -> tuple[int, int]:
458+
@classmethod
459+
async def _get_max_chunk_info(cls, file: Any) -> tuple[int, int]:
462460
"""Queries the XRootD server for info required for pyxrootd vector_read() function.
463461
Queries for maximum number of chunks and the maximum chunk size allowed by the server.
464462
@@ -471,20 +469,31 @@ async def _get_max_chunk_info(self, file: Any) -> tuple[int, int]:
471469
Tuple of max number of chunks and max chunk size. Both ints.
472470
"""
473471
data_server = file.get_property("DataServer")
474-
if data_server not in XRootDFileSystem._dataserver_info_cache:
472+
if data_server == "":
473+
return cls.default_max_num_chunks, cls.default_max_chunk_size
474+
# Normalize to URL
475+
data_server = client.URL(data_server)
476+
data_server = f"{data_server.protocol}://{data_server.hostid}/"
477+
if data_server not in cls._dataserver_info_cache:
478+
fs = client.FileSystem(data_server)
475479
status, result = await _async_wrap(
476-
self._myclient.query, QueryCode.CONFIG, "readv_iov_max readv_ior_max"
480+
fs.query, QueryCode.CONFIG, "readv_iov_max readv_ior_max"
477481
)
478482
if not status.ok:
479483
raise OSError(
480484
f"Server query for vector read info failed: {status.message}"
481485
)
482-
max_num_chunks, max_chunk_size = map(int, result.split(b"\n", 1))
483-
XRootDFileSystem._dataserver_info_cache[data_server] = {
484-
"max_num_chunks": int(max_num_chunks),
485-
"max_chunk_size": int(max_chunk_size),
486+
try:
487+
max_num_chunks, max_chunk_size = map(int, result.split(b"\n", 1))
488+
except ValueError:
489+
raise OSError(
490+
f"Server query for vector read info failed: could not parse {result!r}"
491+
) from None
492+
cls._dataserver_info_cache[data_server] = {
493+
"max_num_chunks": max_num_chunks,
494+
"max_chunk_size": max_chunk_size,
486495
}
487-
info = XRootDFileSystem._dataserver_info_cache[data_server]
496+
info = cls._dataserver_info_cache[data_server]
488497
return (info["max_num_chunks"], info["max_chunk_size"])
489498

490499
async def _cat_vector_read(
@@ -649,10 +658,8 @@ def open(
649658
**kwargs,
650659
)
651660
if compression is not None:
652-
from fsspec.compression import compr # type: ignore[import-not-found]
653-
from fsspec.core import ( # type: ignore[import-not-found]
654-
get_compression,
655-
)
661+
from fsspec.compression import compr
662+
from fsspec.core import get_compression
656663

657664
compression = get_compression(path, compression)
658665
compress = compr[compression]
@@ -697,7 +704,7 @@ def __init__(
697704

698705
self._myFile = client.File()
699706
status, _n = self._myFile.open(
700-
fs.protocol + "://" + fs.storage_options["hostid"] + "/" + path,
707+
fs.unstrip_protocol(path),
701708
self.mode,
702709
timeout=self.timeout,
703710
)

0 commit comments

Comments
 (0)