Skip to content

Commit 223d2ee

Browse files
committed
fixes
Signed-off-by: Baris Palaska <barispalaska@gmail.com>
1 parent c03c8e8 commit 223d2ee

6 files changed

Lines changed: 136 additions & 42 deletions

File tree

vortex-array/public-api.lock

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8888,6 +8888,8 @@ pub use vortex_array::dtype::half
88888888

88898889
pub mod vortex_array::dtype::arrow
88908890

8891+
pub const vortex_array::dtype::arrow::ARROW_EXT_NAME_VARIANT: &str
8892+
88918893
pub trait vortex_array::dtype::arrow::FromArrowType<T>: core::marker::Sized
88928894

88938895
pub fn vortex_array::dtype::arrow::FromArrowType::from_arrow(value: T) -> Self

vortex-array/src/arrow/convert.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -589,8 +589,7 @@ impl FromArrowArray<&dyn ArrowArray> for ArrayRef {
589589
DataType::Dictionary(key_type, _) => {
590590
Ok(dict_from_arrow_with_session(array, key_type, nullable, session)?.into_array())
591591
}
592-
// Other arrays don't carry child Fields, so the session-aware path is
593-
// equivalent to the legacy one.
592+
// Leaves: no child Fields to thread session through.
594593
_ => Self::from_arrow(array, nullable),
595594
}
596595
}
@@ -744,8 +743,8 @@ impl FromArrowArray<&RecordBatch> for ArrayRef {
744743
}
745744
}
746745

747-
/// Inverse of `field_from_dtype` (in `dtype/arrow.rs`): rewrap `storage` as `ExtensionArray`
748-
/// when `field` carries `ARROW:extension:name` metadata for a registered extension.
746+
/// Rewrap `storage` as `ExtensionArray` if `field` carries `ARROW:extension:name`
747+
/// for a registered extension. Inverse of `field_from_dtype` in `dtype/arrow.rs`.
749748
fn wrap_extension_if_field_has_metadata(
750749
storage: ArrayRef,
751750
field: &Field,

vortex-array/src/arrow/executor/mod.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,9 @@ impl ArrowArrayExecutor for ArrayRef {
9999
None => preferred_arrow_type(&self)?,
100100
};
101101

102-
// Extensions with a native Arrow mapping (temporal) keep their wrapper so
103-
// `to_arrow_temporal` can read the metadata. Other extensions carry identity in
104-
// Field metadata, so dispatch on the storage array. Mirror the discriminator used
105-
// by `field_from_dtype` / `native_arrow_dtype_for_extension` in `dtype/arrow.rs`.
102+
// Temporal extensions keep their wrapper so `to_arrow_temporal` can read the
103+
// metadata. Other extensions unwrap to storage; their identity rides on Field
104+
// metadata.
106105
if let DType::Extension(ext) = self.dtype()
107106
&& ext.metadata_opt::<AnyTemporal>().is_none()
108107
{

vortex-array/src/dtype/arrow.rs

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ use crate::extension::datetime::Time;
5353
use crate::extension::datetime::TimeUnit;
5454
use crate::extension::datetime::Timestamp;
5555

56-
const ARROW_EXT_NAME_VARIANT: &str = "arrow.parquet.variant";
56+
/// Canonical Arrow extension name for Parquet Variant — handled as `DType::Variant` rather
57+
/// than going through the extension registry.
58+
pub const ARROW_EXT_NAME_VARIANT: &str = "arrow.parquet.variant";
5759

5860
/// Trait for converting Arrow types to Vortex types.
5961
pub trait FromArrowType<T>: Sized {
@@ -251,8 +253,7 @@ impl FromArrowType<&Field> for DType {
251253
}
252254
}
253255

254-
/// Convert an Arrow Field to a [`DType`] with `dtypes` already borrowed from the session,
255-
/// so the handle is acquired once per schema rather than once per field.
256+
/// Convert a Field to a [`DType`]. Takes `dtypes` borrowed once per schema (not per field).
256257
fn dtype_from_field(field: &Field, dtypes: &DTypeSession) -> DType {
257258
if field
258259
.extension_type_name()
@@ -268,10 +269,8 @@ fn dtype_from_field(field: &Field, dtypes: &DTypeSession) -> DType {
268269
}
269270
}
270271

271-
/// Resolve the [`ExtDTypeRef`] for an Arrow Field whose `ARROW:extension:name` metadata names
272-
/// a registered Vortex extension. Returns `None` for unregistered extensions, malformed
273-
/// metadata, or fields with no extension name; callers fall back to the storage representation
274-
/// and `tracing::warn!` reports the anomaly.
272+
/// Resolve the [`ExtDTypeRef`] for a Field whose `ARROW:extension:name` names a registered
273+
/// Vortex extension. Returns `None` for missing/unregistered/malformed metadata.
275274
pub(crate) fn resolve_extension_dtype(
276275
field: &Field,
277276
dtypes: &DTypeSession,
@@ -377,10 +376,8 @@ impl DType {
377376

378377
/// Returns the Arrow [`DataType`] that best corresponds to this Vortex [`DType`].
379378
///
380-
/// Extensions without a native Arrow mapping (e.g. user-registered extensions like
381-
/// `Vector`) degrade to their storage `DataType`; extension identity only survives when
382-
/// emitted onto an Arrow `Field` (see [`Self::to_arrow_schema`]). Callers that must
383-
/// reject non-temporal extensions should match on `DType::Extension` themselves.
379+
/// Extensions without a native Arrow mapping degrade to their storage `DataType`;
380+
/// identity only survives via `Field` metadata (see [`Self::to_arrow_schema`]).
384381
pub fn to_arrow_dtype(&self) -> VortexResult<DataType> {
385382
arrow_dtype_from_dtype(self)
386383
}
@@ -460,8 +457,7 @@ fn arrow_dtype_from_dtype(dtype: &DType) -> VortexResult<DataType> {
460457
})
461458
}
462459

463-
/// Build an Arrow [`Field`], attaching `ARROW:extension:name` and, when present,
464-
/// `ARROW:extension:metadata` for extensions and Variant that have no native Arrow mapping.
460+
/// Build a Field, attaching extension/Variant metadata when there's no native Arrow mapping.
465461
fn field_from_dtype(name: &str, dtype: &DType) -> VortexResult<Field> {
466462
if dtype.is_variant() {
467463
let storage = DataType::Struct(variant_storage_fields_minimal());

vortex-python/src/arrays/from_arrow.rs

Lines changed: 55 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@ use vortex::array::arrays::ChunkedArray;
1717
use vortex::array::arrays::ExtensionArray;
1818
use vortex::array::arrow::FromArrowArray;
1919
use vortex::dtype::DType;
20+
use vortex::dtype::arrow::ARROW_EXT_NAME_VARIANT;
2021
use vortex::dtype::arrow::FromArrowType;
2122
use vortex::dtype::extension::ExtId;
2223
use vortex::dtype::session::DTypeSessionExt;
2324
use vortex::error::VortexError;
2425
use vortex::error::VortexResult;
25-
use vortex::error::vortex_err;
2626
use vortex::session::VortexSession;
2727

2828
use crate::SESSION;
@@ -37,8 +37,7 @@ use crate::error::PyVortexResult;
3737

3838
/// Convert a Python `pyarrow` array (including `pa.ExtensionArray`) into a Vortex array.
3939
///
40-
/// Arrow's C ABI strips extension identity from the array layer — it lives on `Field`
41-
/// metadata, and a leaf `pa.ExtensionArray` has no enclosing field. We recover it from the
40+
/// The Arrow C ABI strips extension identity from leaf arrays; we recover it from the
4241
/// Python object via `extension_name` and `__arrow_ext_serialize__`.
4342
pub trait FromPyArrowArray: Sized {
4443
/// Convert a Python `pyarrow` array to a Vortex array.
@@ -78,22 +77,27 @@ impl FromPyArrowArray for ArrayRef {
7877
}
7978
}
8079

81-
/// `__arrow_ext_serialize__` returns raw bytes; we pass them straight to the plugin.
82-
/// Base64 is only the encoding used in the Arrow Field-metadata string channel — going
83-
/// directly Python → registry skips that hop.
80+
/// Raw bytes from `__arrow_ext_serialize__` — no base64 (that's only for the
81+
/// Arrow Field-metadata string channel). Variant short-circuits to `None` so it surfaces
82+
/// as `DType::Variant` via the storage path, mirroring `dtype/arrow.rs::dtype_from_field`.
8483
fn extract_extension_info(py_array: &Bound<'_, PyAny>) -> PyResult<Option<(String, Vec<u8>)>> {
8584
let py = py_array.py();
8685
let py_type = py_array.getattr(intern!(py, "type"))?;
8786
if !py_type.is_instance(extension_type_class(py)?)? {
8887
return Ok(None);
8988
}
9089
let ext_name: String = py_type.getattr(intern!(py, "extension_name"))?.extract()?;
90+
if ext_name == ARROW_EXT_NAME_VARIANT {
91+
return Ok(None);
92+
}
9193
let ext_meta_bytes: Vec<u8> = py_type
9294
.call_method0(intern!(py, "__arrow_ext_serialize__"))?
9395
.extract()?;
9496
Ok(Some((ext_name, ext_meta_bytes)))
9597
}
9698

99+
/// Soft fallback to storage on registry miss or malformed metadata, mirroring
100+
/// `dtype/arrow.rs::resolve_extension_dtype`.
97101
fn wrap_with_extension(
98102
storage: ArrayRef,
99103
ext_name: &str,
@@ -102,11 +106,20 @@ fn wrap_with_extension(
102106
) -> VortexResult<ArrayRef> {
103107
let ext_id = ExtId::new(ext_name);
104108
let dtypes = session.dtypes();
105-
let plugin = dtypes
106-
.registry()
107-
.find(&ext_id)
108-
.ok_or_else(|| vortex_err!("extension `{ext_name}` is not registered on the session"))?;
109-
let ext_dtype = plugin.deserialize(ext_meta_bytes, storage.dtype().clone())?;
109+
let Some(plugin) = dtypes.registry().find(&ext_id) else {
110+
log::warn!("pyarrow extension {ext_name:?} not registered on session; using storage dtype");
111+
return Ok(storage);
112+
};
113+
let ext_dtype = match plugin.deserialize(ext_meta_bytes, storage.dtype().clone()) {
114+
Ok(dt) => dt,
115+
Err(e) => {
116+
log::warn!(
117+
"pyarrow extension {ext_name:?} failed to deserialize metadata ({e}); \
118+
using storage dtype",
119+
);
120+
return Ok(storage);
121+
}
122+
};
110123
Ok(ExtensionArray::try_new(ext_dtype, storage)?.into_array())
111124
}
112125

@@ -135,28 +148,49 @@ pub(super) fn from_arrow(obj: &Borrowed<'_, '_, PyAny>) -> PyVortexResult<PyArra
135148
Ok(PyArrayRef::from(enc_array))
136149
} else if obj.is_instance(chunked_array)? {
137150
let chunks: Vec<Bound<PyAny>> = obj.getattr(intern!(py, "chunks"))?.extract()?;
151+
// ChunkedArray has a uniform type — peek extension identity once and reuse.
152+
let bound = obj.to_owned();
153+
let ext_info = extract_extension_info(&bound)?;
138154
let encoded_chunks = chunks
139155
.iter()
140-
.map(|a| {
141-
let arrow_array = ArrowArrayData::from_pyarrow(&a.as_borrowed()).map(make_array)?;
142-
ArrayRef::from_arrow(arrow_array.as_ref(), false).map_err(PyVortexError::from)
156+
.map(|chunk| {
157+
let arrow_array =
158+
ArrowArrayData::from_pyarrow(&chunk.as_borrowed()).map(make_array)?;
159+
let storage = ArrayRef::from_arrow_with_session(
160+
arrow_array.as_ref(),
161+
arrow_array.is_nullable(),
162+
&SESSION,
163+
)
164+
.map_err(PyVortexError::from)?;
165+
match &ext_info {
166+
None => Ok(storage),
167+
Some((name, meta)) => wrap_with_extension(storage, name, meta, &SESSION)
168+
.map_err(|e| PyVortexError::from(e).into()),
169+
}
143170
})
144-
.collect::<PyVortexResult<Vec<_>>>()?;
145-
let dtype: DType = obj
146-
.getattr(intern!(py, "type"))
147-
.and_then(|v| DataType::from_pyarrow(&v.as_borrowed()))
148-
.map(|dt| DType::from_arrow(&Field::new("_", dt, false)))?;
171+
.collect::<PyResult<Vec<_>>>()?;
172+
let dtype: DType = if let Some(first) = encoded_chunks.first() {
173+
first.dtype().clone()
174+
} else {
175+
// Empty array: `obj.type` over the C ABI loses extension metadata, so we
176+
// recover only the storage dtype.
177+
obj.getattr(intern!(py, "type"))
178+
.and_then(|v| DataType::from_pyarrow(&v.as_borrowed()))
179+
.map(|dt| DType::from_arrow_with_session(&Field::new("_", dt, false), &SESSION))?
180+
};
149181
Ok(PyArrayRef::from(
150182
ChunkedArray::try_new(encoded_chunks, dtype)?.into_array(),
151183
))
152184
} else if obj.is_instance(table)? {
185+
// The C ABI Stream carries Field metadata on the schema — session-aware
186+
// conversion recovers extensions directly, no Python peek needed.
153187
let array_stream = ArrowArrayStreamReader::from_pyarrow(&obj.as_borrowed())?;
154-
let dtype = DType::from_arrow(array_stream.schema());
188+
let dtype = DType::from_arrow_with_session(array_stream.schema(), &SESSION);
155189
let chunks = array_stream
156190
.into_iter()
157191
.map(|b| {
158192
b.map_err(VortexError::from)
159-
.and_then(|b| ArrayRef::from_arrow(b, false))
193+
.and_then(|b| ArrayRef::from_arrow_with_session(b, false, &SESSION))
160194
})
161195
.collect::<VortexResult<Vec<_>>>()?;
162196
Ok(PyArrayRef::from(
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
"""Round-trip tests for `vx.array` against pyarrow inputs carrying Vortex
5+
extension identity over the Arrow C ABI."""
6+
7+
from __future__ import annotations
8+
9+
import base64
10+
11+
import pyarrow as pa
12+
13+
import vortex as vx
14+
15+
# vortex.timestamp wire format: u8 unit_tag + u16 LE tz_len (us=1, no tz).
16+
# See vortex-array/src/extension/datetime/timestamp.rs::serialize_metadata.
17+
_TIMESTAMP_US_METADATA = bytes([1, 0, 0])
18+
19+
20+
class VortexTimestampType(pa.ExtensionType):
21+
"""A pyarrow `ExtensionType` matching Vortex's `vortex.timestamp` extension."""
22+
23+
def __init__(self, unit: str = "us"):
24+
# pyarrow calls `__arrow_ext_serialize__` in __init__, so set `_unit` first.
25+
self._unit = unit
26+
pa.ExtensionType.__init__(self, pa.int64(), "vortex.timestamp")
27+
28+
def __arrow_ext_serialize__(self) -> bytes:
29+
unit_tag = {"ns": 0, "us": 1, "ms": 2, "s": 3}[self._unit]
30+
return bytes([unit_tag, 0, 0])
31+
32+
@classmethod
33+
def __arrow_ext_deserialize__(cls, storage_type, serialized): # noqa: ARG003
34+
unit_tag = serialized[0]
35+
unit = {0: "ns", 1: "us", 2: "ms", 3: "s"}[unit_tag]
36+
return cls(unit)
37+
38+
39+
def test_chunked_extension_array_uses_session_for_leaf_extension_type():
40+
ext_type = VortexTimestampType()
41+
storage = pa.array([1, 2, 3], type=pa.int64())
42+
arrow = pa.chunked_array([pa.ExtensionArray.from_storage(ext_type, storage)])
43+
array = vx.array(arrow)
44+
assert isinstance(array, vx.ChunkedArray)
45+
assert repr(array.dtype) == repr(vx.timestamp("us"))
46+
assert repr(array.chunks()[0].dtype) == repr(vx.timestamp("us"))
47+
48+
49+
def test_table_uses_session_for_extension_field_metadata():
50+
field = pa.field("ts", pa.int64(), nullable=False).with_metadata(
51+
{
52+
b"ARROW:extension:name": b"vortex.timestamp",
53+
b"ARROW:extension:metadata": base64.b64encode(_TIMESTAMP_US_METADATA),
54+
}
55+
)
56+
table = pa.Table.from_arrays(
57+
[pa.array([1, 2, 3], type=pa.int64())],
58+
schema=pa.schema([field]),
59+
)
60+
array = vx.array(table)
61+
expected = vx.struct({"ts": vx.timestamp("us")})
62+
assert isinstance(array, vx.ChunkedArray)
63+
assert repr(array.dtype) == repr(expected)
64+
assert repr(array.chunks()[0].dtype) == repr(expected)

0 commit comments

Comments
 (0)