Skip to content

Commit 2d6cd80

Browse files
authored
Expand .df Array/Query accessor to allow indexing with NumPy and PyArrow arrays (#2170)
1 parent 9b6a893 commit 2d6cd80

3 files changed

Lines changed: 70 additions & 10 deletions

File tree

tiledb/array.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1039,9 +1039,9 @@ def multi_index(self):
10391039
"""Retrieve data cells with multi-range, domain-inclusive indexing. Returns
10401040
the cross-product of the ranges.
10411041
1042-
:param list selection: Per dimension, a scalar, ``slice``, or list of scalars
1043-
or ``slice`` objects. Scalars and ``slice`` components should match the
1044-
type of the underlying Dimension.
1042+
:param list selection: Per dimension, a scalar, ``slice``,
1043+
or a list/numpy array/pyarrow array of scalars or ``slice`` objects.
1044+
Scalars and ``slice`` components should match the type of the underlying Dimension.
10451045
:returns: dict of {'attribute': result}. Coords are included by default for
10461046
Sparse arrays only (use `Array.query(coords=<>)` to select).
10471047
:raises IndexError: invalid or unsupported index selection
@@ -1093,9 +1093,9 @@ def df(self):
10931093
"""Retrieve data cells as a Pandas dataframe, with multi-range,
10941094
domain-inclusive indexing using ``multi_index``.
10951095
1096-
:param list selection: Per dimension, a scalar, ``slice``, or list of scalars
1097-
or ``slice`` objects. Scalars and ``slice`` components should match the
1098-
type of the underlying Dimension.
1096+
:param list selection: Per dimension, a scalar, ``slice``,
1097+
or a list/numpy array/pyarrow array of scalars or ``slice`` objects.
1098+
Scalars and ``slice`` components should match the type of the underlying Dimension.
10991099
:returns: dict of {'attribute': result}. Coords are included by default for
11001100
Sparse arrays only (use `Array.query(coords=<>)` to select).
11011101
:raises IndexError: invalid or unsupported index selection

tiledb/multirange_indexing.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,14 @@
4040
# We don't want to import these eagerly since importing Pandas in particular
4141
# can add around half a second of import time even if we never use it.
4242
import pandas
43+
44+
45+
try:
4346
import pyarrow
4447

48+
has_pyarrow = True
49+
except ImportError:
50+
has_pyarrow = False
4551

4652
current_timer: ContextVar[str] = ContextVar("timer_scope")
4753

@@ -112,11 +118,15 @@ def to_scalar(obj: Any) -> Scalar:
112118
return cast(Scalar, obj)
113119
if isinstance(obj, np.ndarray) and obj.ndim == 0:
114120
return cast(Scalar, obj[()])
121+
if has_pyarrow and isinstance(obj, pyarrow.Array):
122+
return to_scalar(obj.to_numpy()[()])
123+
if has_pyarrow and isinstance(obj, pyarrow.Scalar):
124+
return cast(Scalar, obj.as_py())
115125
raise ValueError(f"Cannot convert {type(obj)} to scalar")
116126

117127

118128
def iter_ranges(
119-
sel: Union[Scalar, slice, Range, List[Scalar]],
129+
sel: Union[Scalar, slice, Range, List[Scalar], np.ndarray, "pyarrow.Array"],
120130
sparse: bool,
121131
nonempty_domain: Optional[Range] = None,
122132
) -> Iterator[Range]:
@@ -145,7 +155,9 @@ def iter_ranges(
145155
assert len(sel) == 2
146156
yield to_scalar(sel[0]), to_scalar(sel[1])
147157

148-
elif isinstance(sel, list):
158+
elif isinstance(sel, (list, np.ndarray)) or (
159+
has_pyarrow and isinstance(sel, pyarrow.Array)
160+
):
149161
for scalar in map(to_scalar, sel):
150162
yield scalar, scalar
151163

@@ -178,8 +190,6 @@ def iter_label_range(sel: Union[Scalar, slice, Range, List[Scalar]]):
178190

179191
def dim_ranges_from_selection(selection, nonempty_domain, is_sparse):
180192
# don't try to index nonempty_domain if None
181-
if isinstance(selection, np.ndarray):
182-
return selection
183193
selection = selection if isinstance(selection, list) else [selection]
184194
return tuple(
185195
rng for sel in selection for rng in iter_ranges(sel, is_sparse, nonempty_domain)

tiledb/tests/test_pandas_dataframe.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1276,6 +1276,56 @@ def try_rt(name, df, pq_args={}):
12761276
basic3 = make_dataframe_basic3()
12771277
try_rt("basic3", basic3)
12781278

1279+
@pytest.mark.parametrize(
1280+
"dim_data, attr_data, dtype, domain",
1281+
[
1282+
(pyarrow.array([1, 2, 3]), pyarrow.array([1, 2, 3]), np.int64, (1, 3)),
1283+
(pyarrow.array(["a", "b", "c"]), pyarrow.array([1, 2, 3]), "ascii", None),
1284+
],
1285+
)
1286+
def test_read_indexing_with_pyarrow_and_numpy_arrays(
1287+
self, dim_data, attr_data, dtype, domain
1288+
):
1289+
# This test is to ensure that .df can be indexed with both PyArrow and NumPy arrays.
1290+
uri = self.path("read_indexing_with_pyarrow_and_numpy_arrays")
1291+
1292+
dim = (
1293+
tiledb.Dim(name="dim_a", dtype=dtype, domain=domain)
1294+
if domain
1295+
else tiledb.Dim(name="dim_a", dtype=dtype)
1296+
)
1297+
schema = tiledb.ArraySchema(
1298+
domain=tiledb.Domain(dim),
1299+
sparse=True,
1300+
attrs=[tiledb.Attr(name="rand", dtype=np.int32)],
1301+
allows_duplicates=True,
1302+
)
1303+
tiledb.Array.create(uri, schema)
1304+
1305+
with tiledb.open(uri, "w") as arr:
1306+
arr[dim_data] = attr_data
1307+
1308+
with tiledb.open(uri, "r") as arr:
1309+
expected_df = pd.DataFrame(
1310+
{"dim_a": dim_data.tolist(), "rand": attr_data.tolist()}
1311+
)
1312+
1313+
assert_array_equal(arr.df[:], expected_df)
1314+
assert_array_equal(arr.df[pyarrow.array(dim_data)], expected_df)
1315+
assert_array_equal(arr.df[np.array(dim_data)], expected_df)
1316+
1317+
partial_dim_data = dim_data[:2]
1318+
expected_partial_df = expected_df.iloc[:2]
1319+
1320+
assert_array_equal(
1321+
arr.df[pyarrow.array(partial_dim_data)], expected_partial_df
1322+
)
1323+
assert_array_equal(arr.df[np.array(partial_dim_data)], expected_partial_df)
1324+
1325+
expected_dict = OrderedDict(
1326+
[("dim_a", dim_data.tolist()), ("rand", attr_data.tolist())]
1327+
)
1328+
12791329
def test_nullable_integers(self):
12801330
nullable_int_dtypes = (
12811331
pd.Int64Dtype(),

0 commit comments

Comments
 (0)