Skip to content

Commit 632f898

Browse files
committed
fix reduction rule
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 1ece694 commit 632f898

File tree

2 files changed

+62
-62
lines changed

2 files changed

+62
-62
lines changed

vortex-array/src/arrays/extension/compute/rules.rs

Lines changed: 26 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -14,52 +14,46 @@ use crate::arrays::Filter;
1414
use crate::arrays::extension::ExtensionArrayExt;
1515
use crate::arrays::filter::FilterReduceAdaptor;
1616
use crate::arrays::slice::SliceReduceAdaptor;
17-
use crate::matcher::AnyArray;
1817
use crate::optimizer::rules::ArrayParentReduceRule;
18+
use crate::optimizer::rules::ArrayReduceRule;
1919
use crate::optimizer::rules::ParentRuleSet;
20+
use crate::optimizer::rules::ReduceRuleSet;
2021
use crate::scalar::Scalar;
2122
use crate::scalar_fn::fns::cast::CastReduceAdaptor;
2223
use crate::scalar_fn::fns::mask::MaskReduceAdaptor;
2324

24-
pub(crate) const PARENT_RULES: ParentRuleSet<Extension> = ParentRuleSet::new(&[
25-
ParentRuleSet::lift(&ExtensionConstantParentRule),
26-
ParentRuleSet::lift(&ExtensionFilterPushDownRule),
27-
ParentRuleSet::lift(&CastReduceAdaptor(Extension)),
28-
ParentRuleSet::lift(&FilterReduceAdaptor(Extension)),
29-
ParentRuleSet::lift(&MaskReduceAdaptor(Extension)),
30-
ParentRuleSet::lift(&SliceReduceAdaptor(Extension)),
31-
]);
25+
pub(crate) const RULES: ReduceRuleSet<Extension> = ReduceRuleSet::new(&[&ExtensionConstantRule]);
3226

3327
/// Normalize `Extension(Constant(storage))` children to `Constant(Extension(storage))`.
3428
#[derive(Debug)]
35-
struct ExtensionConstantParentRule;
36-
37-
impl ArrayParentReduceRule<Extension> for ExtensionConstantParentRule {
38-
type Parent = AnyArray;
29+
struct ExtensionConstantRule;
3930

40-
fn reduce_parent(
41-
&self,
42-
child: ArrayView<'_, Extension>,
43-
parent: &ArrayRef,
44-
child_idx: usize,
45-
) -> VortexResult<Option<ArrayRef>> {
46-
let Some(const_array) = child.storage_array().as_opt::<Constant>() else {
31+
impl ArrayReduceRule<Extension> for ExtensionConstantRule {
32+
fn reduce(&self, array: ArrayView<'_, Extension>) -> VortexResult<Option<ArrayRef>> {
33+
let Some(const_array) = array.storage_array().as_opt::<Constant>() else {
34+
println!("not constant");
4735
return Ok(None);
4836
};
37+
println!("reducing");
4938

5039
let storage_scalar = const_array.scalar().clone();
51-
let ext_scalar = Scalar::extension_ref(child.ext_dtype().clone(), storage_scalar);
40+
let ext_scalar = Scalar::extension_ref(array.ext_dtype().clone(), storage_scalar);
5241

5342
let constant_with_extension_scalar =
54-
ConstantArray::new(ext_scalar, child.len()).into_array();
43+
ConstantArray::new(ext_scalar, array.len()).into_array();
5544

56-
parent
57-
.clone()
58-
.with_slot(child_idx, constant_with_extension_scalar)
59-
.map(Some)
45+
Ok(Some(constant_with_extension_scalar.into_array()))
6046
}
6147
}
6248

49+
pub(crate) const PARENT_RULES: ParentRuleSet<Extension> = ParentRuleSet::new(&[
50+
ParentRuleSet::lift(&ExtensionFilterPushDownRule),
51+
ParentRuleSet::lift(&CastReduceAdaptor(Extension)),
52+
ParentRuleSet::lift(&FilterReduceAdaptor(Extension)),
53+
ParentRuleSet::lift(&MaskReduceAdaptor(Extension)),
54+
ParentRuleSet::lift(&SliceReduceAdaptor(Extension)),
55+
]);
56+
6357
/// Push filter operations into the storage array of an extension array.
6458
#[derive(Debug)]
6559
struct ExtensionFilterPushDownRule;
@@ -99,6 +93,7 @@ mod tests {
9993
use crate::arrays::ExtensionArray;
10094
use crate::arrays::FilterArray;
10195
use crate::arrays::PrimitiveArray;
96+
use crate::arrays::ScalarFnVTable;
10297
use crate::arrays::extension::ExtensionArrayExt;
10398
use crate::arrays::scalar_fn::ScalarFnArrayExt;
10499
use crate::arrays::scalar_fn::ScalarFnFactoryExt;
@@ -227,8 +222,8 @@ mod tests {
227222
.try_new_array(3, Operator::Lt, [constant_ext, ext_array])
228223
.unwrap();
229224

230-
let optimized = scalar_fn_array.optimize().unwrap();
231-
let scalar_fn = optimized.as_opt::<crate::arrays::ScalarFnVTable>().unwrap();
225+
let optimized = scalar_fn_array.optimize_recursive().unwrap();
226+
let scalar_fn = optimized.as_opt::<ScalarFnVTable>().unwrap();
232227
let children = scalar_fn.children();
233228
let constant = children[0]
234229
.as_opt::<Constant>()
@@ -291,7 +286,7 @@ mod tests {
291286
let optimized = scalar_fn_array.optimize().unwrap();
292287

293288
// The first child should still be an ExtensionArray (no pushdown happened)
294-
let scalar_fn = optimized.as_opt::<crate::arrays::ScalarFnVTable>().unwrap();
289+
let scalar_fn = optimized.as_opt::<ScalarFnVTable>().unwrap();
295290
assert!(
296291
scalar_fn.children()[0].as_opt::<Extension>().is_some(),
297292
"Expected first child to remain ExtensionArray when ext types differ"
@@ -316,7 +311,7 @@ mod tests {
316311
let optimized = scalar_fn_array.optimize().unwrap();
317312

318313
// No pushdown should happen because sibling is not a constant
319-
let scalar_fn = optimized.as_opt::<crate::arrays::ScalarFnVTable>().unwrap();
314+
let scalar_fn = optimized.as_opt::<ScalarFnVTable>().unwrap();
320315
assert!(
321316
scalar_fn.children()[0].as_opt::<Extension>().is_some(),
322317
"Expected first child to remain ExtensionArray when sibling is not constant"
@@ -339,7 +334,7 @@ mod tests {
339334
let optimized = scalar_fn_array.optimize().unwrap();
340335

341336
// No pushdown should happen because constant is not an extension scalar
342-
let scalar_fn = optimized.as_opt::<crate::arrays::ScalarFnVTable>().unwrap();
337+
let scalar_fn = optimized.as_opt::<ScalarFnVTable>().unwrap();
343338
assert!(
344339
scalar_fn.children()[0].as_opt::<Extension>().is_some(),
345340
"Expected first child to remain ExtensionArray when constant is not extension"

vortex-array/src/arrays/extension/vtable/mod.rs

Lines changed: 36 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,14 @@ use crate::arrays::extension::ExtensionData;
3131
use crate::arrays::extension::array::SLOT_NAMES;
3232
use crate::arrays::extension::array::STORAGE_SLOT;
3333
use crate::arrays::extension::compute::rules::PARENT_RULES;
34+
use crate::arrays::extension::compute::rules::RULES;
3435
use crate::buffer::BufferHandle;
3536
use crate::dtype::DType;
3637
use crate::serde::ArrayChildren;
3738

39+
#[derive(Clone, Debug)]
40+
pub struct Extension;
41+
3842
/// A [`Extension`]-encoded Vortex array.
3943
pub type ExtensionArray = Array<Extension>;
4044

@@ -59,29 +63,6 @@ impl VTable for Extension {
5963
*ID
6064
}
6165

62-
fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
63-
0
64-
}
65-
66-
fn buffer(_array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
67-
vortex_panic!("ExtensionArray buffer index {idx} out of bounds")
68-
}
69-
70-
fn buffer_name(_array: ArrayView<'_, Self>, _idx: usize) -> Option<String> {
71-
None
72-
}
73-
74-
fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
75-
SLOT_NAMES[idx].to_string()
76-
}
77-
78-
fn serialize(
79-
_array: ArrayView<'_, Self>,
80-
_session: &VortexSession,
81-
) -> VortexResult<Option<Vec<u8>>> {
82-
Ok(Some(vec![]))
83-
}
84-
8566
fn validate(
8667
&self,
8768
data: &ExtensionData,
@@ -111,6 +92,25 @@ impl VTable for Extension {
11192
Ok(())
11293
}
11394

95+
fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
96+
0
97+
}
98+
99+
fn buffer(_array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
100+
vortex_panic!("ExtensionArray buffer index {idx} out of bounds")
101+
}
102+
103+
fn buffer_name(_array: ArrayView<'_, Self>, _idx: usize) -> Option<String> {
104+
None
105+
}
106+
107+
fn serialize(
108+
_array: ArrayView<'_, Self>,
109+
_session: &VortexSession,
110+
) -> VortexResult<Option<Vec<u8>>> {
111+
Ok(Some(vec![]))
112+
}
113+
114114
fn deserialize(
115115
&self,
116116
dtype: &DType,
@@ -143,27 +143,32 @@ impl VTable for Extension {
143143
.with_slots(vec![Some(storage)]))
144144
}
145145

146+
fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
147+
SLOT_NAMES[idx].to_string()
148+
}
149+
146150
fn execute(array: Array<Self>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
147151
Ok(ExecutionResult::done(array))
148152
}
149153

150-
fn reduce_parent(
154+
fn execute_parent(
151155
array: ArrayView<'_, Self>,
152156
parent: &ArrayRef,
153157
child_idx: usize,
158+
ctx: &mut ExecutionCtx,
154159
) -> VortexResult<Option<ArrayRef>> {
155-
PARENT_RULES.evaluate(array, parent, child_idx)
160+
PARENT_KERNELS.execute(array, parent, child_idx, ctx)
156161
}
157162

158-
fn execute_parent(
163+
fn reduce(array: ArrayView<'_, Self>) -> VortexResult<Option<ArrayRef>> {
164+
RULES.evaluate(array)
165+
}
166+
167+
fn reduce_parent(
159168
array: ArrayView<'_, Self>,
160169
parent: &ArrayRef,
161170
child_idx: usize,
162-
ctx: &mut ExecutionCtx,
163171
) -> VortexResult<Option<ArrayRef>> {
164-
PARENT_KERNELS.execute(array, parent, child_idx, ctx)
172+
PARENT_RULES.evaluate(array, parent, child_idx)
165173
}
166174
}
167-
168-
#[derive(Clone, Debug)]
169-
pub struct Extension;

0 commit comments

Comments
 (0)