Skip to content

Commit 20363bc

Browse files
authored
Add extension storage scalar function (#8540)
This function allows you to extract the storage values from an extension array. It is safe for us to implement this generically, but casting storage -> extension type should be done by custom kernels so that they can perform validation, etc. Signed-off-by: "Nicholas Gates" <nick@nickgates.com>
1 parent e7a7ad9 commit 20363bc

5 files changed

Lines changed: 208 additions & 1 deletion

File tree

vortex-array/src/expr/exprs.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ use crate::scalar_fn::fns::cast::Cast;
2929
use crate::scalar_fn::fns::dynamic::DynamicComparison;
3030
use crate::scalar_fn::fns::dynamic::DynamicComparisonExpr;
3131
use crate::scalar_fn::fns::dynamic::Rhs;
32+
use crate::scalar_fn::fns::ext_storage::ExtStorage;
3233
use crate::scalar_fn::fns::fill_null::FillNull;
3334
use crate::scalar_fn::fns::get_item::GetItem;
3435
use crate::scalar_fn::fns::is_not_null::IsNotNull;
@@ -737,3 +738,15 @@ pub fn list_contains(list: Expression, value: Expression) -> Expression {
737738
pub fn byte_length(input: Expression) -> Expression {
738739
ByteLength.new_expr(EmptyOptions, [input])
739740
}
741+
742+
// ---- ExtStorage ----
743+
744+
/// Creates an expression that extracts the storage values from an extension array.
745+
///
/// `input` must have an extension dtype; for any other dtype, resolving the return dtype
/// of the expression fails. The result has the extension's storage dtype.
///
746+
/// ```rust
747+
/// # use vortex_array::expr::{ext_storage, root};
748+
/// let expr = ext_storage(root());
749+
/// ```
750+
pub fn ext_storage(input: Expression) -> Expression {
751+
ExtStorage.new_expr(EmptyOptions, [input])
752+
}
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_error::VortexResult;
5+
use vortex_error::vortex_bail;
6+
use vortex_session::VortexSession;
7+
use vortex_session::registry::CachedId;
8+
9+
use crate::ArrayRef;
10+
use crate::ExecutionCtx;
11+
use crate::IntoArray;
12+
use crate::arrays::ConstantArray;
13+
use crate::arrays::ExtensionArray;
14+
use crate::arrays::extension::ExtensionArrayExt;
15+
use crate::dtype::DType;
16+
use crate::expr::Expression;
17+
use crate::scalar_fn::Arity;
18+
use crate::scalar_fn::ChildName;
19+
use crate::scalar_fn::EmptyOptions;
20+
use crate::scalar_fn::ExecutionArgs;
21+
use crate::scalar_fn::ScalarFnId;
22+
use crate::scalar_fn::ScalarFnVTable;
23+
24+
/// Scalar function that extracts the storage values from an extension array.
25+
#[derive(Clone)]
26+
pub struct ExtStorage;
27+
28+
impl ScalarFnVTable for ExtStorage {
29+
type Options = EmptyOptions;
30+
/// Identifies this scalar function as `vortex.ext.storage`.
31+
fn id(&self) -> ScalarFnId {
32+
static ID: CachedId = CachedId::new("vortex.ext.storage");
33+
*ID
34+
}
35+
/// Serializes the (empty) options as an empty byte vector.
36+
fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
37+
Ok(Some(vec![]))
38+
}
39+
/// Ignores the metadata bytes and the session, and returns `EmptyOptions`.
40+
fn deserialize(
41+
&self,
42+
_metadata: &[u8],
43+
_session: &VortexSession,
44+
) -> VortexResult<Self::Options> {
45+
Ok(EmptyOptions)
46+
}
47+
/// This function takes exactly one child.
48+
fn arity(&self, _options: &Self::Options) -> Arity {
49+
Arity::Exact(1)
50+
}
51+
/// Names the only child `input`; any other index is treated as unreachable.
52+
fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
53+
match child_idx {
54+
0 => ChildName::from("input"),
55+
_ => unreachable!("Invalid child index {child_idx} for ext_storage()"),
56+
}
57+
}
58+
/// Returns the storage dtype of the extension input, or an error for a non-extension dtype.
59+
fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
60+
let DType::Extension(ext_dtype) = &arg_dtypes[0] else {
61+
vortex_bail!("ext_storage() requires Extension, got {}", arg_dtypes[0]);
62+
};
63+
// The result dtype is the extension's storage dtype.
64+
Ok(ext_dtype.storage_dtype().clone())
65+
}
66+
/// Returns the storage of the extension input, keeping constant input constant.
67+
fn execute(
68+
&self,
69+
_options: &Self::Options,
70+
args: &dyn ExecutionArgs,
71+
ctx: &mut ExecutionCtx,
72+
) -> VortexResult<ArrayRef> {
73+
let input = args.get(0)?;
74+
// Reject non-extension input before doing any further work.
75+
if !matches!(input.dtype(), DType::Extension(_)) {
76+
vortex_bail!("ext_storage() requires Extension, got {}", input.dtype());
77+
}
78+
// Constant input: unwrap the scalar and return a constant storage array of the same row count.
79+
if let Some(scalar) = input.as_constant() {
80+
let storage_scalar = scalar.as_extension().to_storage_scalar();
81+
return Ok(ConstantArray::new(storage_scalar, args.row_count()).into_array());
82+
}
83+
// Otherwise execute the input into an `ExtensionArray` and return its storage array.
84+
let input = input.execute::<ExtensionArray>(ctx)?;
85+
Ok(input.storage_array().clone())
86+
}
87+
/// The validity expression of the result is the validity of the input child.
88+
fn validity(
89+
&self,
90+
_options: &Self::Options,
91+
expression: &Expression,
92+
) -> VortexResult<Option<Expression>> {
93+
Ok(Some(expression.child(0).validity()?))
94+
}
95+
/// Reports this function as not null-sensitive.
96+
fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
97+
false
98+
}
99+
/// Reports this function as not fallible.
100+
fn is_fallible(&self, _options: &Self::Options) -> bool {
101+
false
102+
}
103+
}
104+
105+
#[cfg(test)]
106+
mod tests {
107+
use vortex_buffer::buffer;
108+
use vortex_error::VortexResult;
109+
110+
use crate::IntoArray;
111+
use crate::arrays::ConstantArray;
112+
use crate::arrays::ExtensionArray;
113+
use crate::arrays::PrimitiveArray;
114+
use crate::assert_arrays_eq;
115+
use crate::dtype::DType;
116+
use crate::dtype::Nullability;
117+
use crate::dtype::PType;
118+
use crate::dtype::extension::ExtDTypeRef;
119+
use crate::expr::ext_storage;
120+
use crate::expr::root;
121+
use crate::extension::datetime::TimeUnit;
122+
use crate::extension::datetime::Timestamp;
123+
use crate::scalar::Scalar;
124+
// Builds an erased nanosecond timestamp extension dtype with the given nullability.
125+
fn ext_dtype(nullability: Nullability) -> ExtDTypeRef {
126+
Timestamp::new(TimeUnit::Nanoseconds, nullability).erased()
127+
}
128+
// A non-nullable extension array yields its i64 storage array.
129+
#[test]
130+
fn extracts_extension_storage_array() -> VortexResult<()> {
131+
let storage = buffer![2i64, 4, 6].into_array();
132+
let array =
133+
ExtensionArray::new(ext_dtype(Nullability::NonNullable), storage.clone()).into_array();
134+
135+
let result = array.apply(&ext_storage(root()))?;
136+
137+
assert_eq!(
138+
result.dtype(),
139+
&DType::Primitive(PType::I64, Nullability::NonNullable)
140+
);
141+
assert_arrays_eq!(result, storage);
142+
Ok(())
143+
}
144+
// A nullable extension array yields its nullable storage array, nulls included.
145+
#[test]
146+
fn extracts_nullable_extension_storage_array() -> VortexResult<()> {
147+
let storage = PrimitiveArray::from_option_iter([Some(2i64), None, Some(6)]).into_array();
148+
let array =
149+
ExtensionArray::new(ext_dtype(Nullability::Nullable), storage.clone()).into_array();
150+
151+
let result = array.apply(&ext_storage(root()))?;
152+
153+
assert_eq!(
154+
result.dtype(),
155+
&DType::Primitive(PType::I64, Nullability::Nullable)
156+
);
157+
assert_arrays_eq!(result, storage);
158+
Ok(())
159+
}
160+
// A constant extension array yields a constant array of the storage scalar with the same length.
161+
#[test]
162+
fn extracts_constant_extension_storage_scalar() -> VortexResult<()> {
163+
let storage_scalar = Scalar::primitive(4i64, Nullability::NonNullable);
164+
let scalar =
165+
Scalar::extension_ref(ext_dtype(Nullability::NonNullable), storage_scalar.clone());
166+
let array = ConstantArray::new(scalar, 3).into_array();
167+
168+
let result = array.apply(&ext_storage(root()))?;
169+
170+
assert_eq!(
171+
result.dtype(),
172+
&DType::Primitive(PType::I64, Nullability::NonNullable)
173+
);
174+
assert_arrays_eq!(result, ConstantArray::new(storage_scalar, 3));
175+
Ok(())
176+
}
177+
// Resolving the return dtype against a non-extension dtype is an error.
178+
#[test]
179+
fn rejects_non_extension_input() {
180+
let dtype = DType::Primitive(PType::U64, Nullability::NonNullable);
181+
let err = ext_storage(root()).return_dtype(&dtype).unwrap_err();
182+
assert!(err.to_string().contains("requires Extension"));
183+
}
184+
// The expression over `root()` renders as `vortex.ext.storage($)`.
185+
#[test]
186+
fn test_display() {
187+
assert_eq!(ext_storage(root()).to_string(), "vortex.ext.storage($)");
188+
}
189+
}

vortex-array/src/scalar_fn/fns/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ pub mod byte_length;
77
pub mod case_when;
88
pub mod cast;
99
pub mod dynamic;
10+
pub mod ext_storage;
1011
pub mod fill_null;
1112
pub mod get_item;
1213
pub mod is_not_null;

vortex-array/src/scalar_fn/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
use vortex_session::registry::Id;
1111

1212
use crate::scalar_fn::fns::byte_length::ByteLength;
13+
use crate::scalar_fn::fns::ext_storage::ExtStorage;
1314
use crate::scalar_fn::fns::get_item::GetItem;
1415
use crate::scalar_fn::fns::literal::Literal;
1516

@@ -56,13 +57,14 @@ mod sealed {
5657
/// A scalar function has a negative cost if applying it to an array and
5758
/// canonicalizing is cheaper than canonicalizing an array and applying it.
5859
///
59-
/// Example of negative cost expressions are byte_length() and get_item() since
60+
/// Examples of negative cost expressions are byte_length(), ext_storage(), and get_item() since
6061
/// they don't depend on input size.
6162
///
6263
/// An example of a non-negative cost expression is like(), as it's linear over
6364
/// individual input.
6465
pub fn is_negative_cost(id: ScalarFnId) -> bool {
6566
id == ScalarFnVTable::id(&ByteLength)
67+
|| id == ScalarFnVTable::id(&ExtStorage)
6668
|| id == ScalarFnVTable::id(&GetItem)
6769
|| id == ScalarFnVTable::id(&Literal)
6870
}

vortex-array/src/scalar_fn/session.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use crate::scalar_fn::ScalarFnVTable;
1313
use crate::scalar_fn::fns::between::Between;
1414
use crate::scalar_fn::fns::binary::Binary;
1515
use crate::scalar_fn::fns::cast::Cast;
16+
use crate::scalar_fn::fns::ext_storage::ExtStorage;
1617
use crate::scalar_fn::fns::fill_null::FillNull;
1718
use crate::scalar_fn::fns::get_item::GetItem;
1819
use crate::scalar_fn::fns::is_not_null::IsNotNull;
@@ -59,6 +60,7 @@ impl Default for ScalarFnSession {
5960
this.register(Between);
6061
this.register(Binary);
6162
this.register(Cast);
63+
this.register(ExtStorage);
6264
this.register(FillNull);
6365
this.register(GetItem);
6466
this.register(IsNotNull);

0 commit comments

Comments
 (0)