Skip to content

Commit c6a0b7b

Browse files
connortsui20 and claude authored
Remove len parameter from ScalarFnArray::try_new (#8378)
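## Summary

`ScalarFnArray::try_new` didn't need to take a length because it already gets it from the array children.

## API Changes

Removes the `len` parameter from `ScalarFnArray::try_new`. Callers that have no child array to infer the length from (zero-child scalar functions and deserialization paths) use `ScalarFnArray::try_new_with_len` instead.

## Testing

N/A

Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>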
1 parent 8475910 commit c6a0b7b

16 files changed

Lines changed: 196 additions & 170 deletions

File tree

encodings/datetime-parts/src/compute/rules.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,9 @@ impl ArrayParentReduceRule<DateTimeParts> for DTPComparisonPushDownRule {
133133
}
134134
}
135135

136-
let result =
137-
ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children, parent.len())?
138-
.into_array()
139-
.optimize()?;
136+
let result = ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children)?
137+
.into_array()
138+
.optimize()?;
140139

141140
Ok(Some(result))
142141
}

encodings/runend/src/rules.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,7 @@ impl ArrayParentReduceRule<RunEnd> for RunEndScalarFnRule {
7777
}
7878

7979
let new_values =
80-
ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children, values_len)?
81-
.into_array();
80+
ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children)?.into_array();
8281

8382
Ok(Some(
8483
unsafe {

vortex-array/src/arrays/chunked/compute/rules.rs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,9 @@ impl ArrayParentReduceRule<Chunked> for ChunkedUnaryScalarFnPushDownRule {
4848
let new_chunks: Vec<_> = array
4949
.iter_chunks()
5050
.map(|chunk| {
51-
ScalarFnArray::try_new(
52-
parent.scalar_fn().clone(),
53-
vec![chunk.clone()],
54-
chunk.len(),
55-
)?
56-
.into_array()
57-
.optimize()
51+
ScalarFnArray::try_new(parent.scalar_fn().clone(), vec![chunk.clone()])?
52+
.into_array()
53+
.optimize()
5854
})
5955
.try_collect()?;
6056

@@ -104,7 +100,7 @@ impl ArrayParentReduceRule<Chunked> for ChunkedConstantScalarFnPushDownRule {
104100
})
105101
.collect();
106102

107-
ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children, chunk.len())?
103+
ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children)?
108104
.into_array()
109105
.optimize()
110106
})

vortex-array/src/arrays/dict/compute/rules.rs

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,9 @@ impl ArrayParentReduceRule<Dict> for DictionaryScalarFnValuesPushDownRule {
126126
}
127127
}
128128

129-
let new_values =
130-
ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children, values_len)?
131-
.into_array()
132-
.optimize()?;
129+
let new_values = ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children)?
130+
.into_array()
131+
.optimize()?;
133132

134133
// We can only push down null-sensitive functions when we have all-valid codes.
135134
// In these cases, we cannot have the codes influence the nullability of the output DType.
@@ -192,13 +191,9 @@ impl ArrayParentReduceRule<Dict> for DictionaryScalarFnCodesPullUpRule {
192191
}
193192
}
194193

195-
let new_values = ScalarFnArray::try_new(
196-
parent.scalar_fn().clone(),
197-
new_children,
198-
array.values().len(),
199-
)?
200-
.into_array()
201-
.optimize()?;
194+
let new_values = ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children)?
195+
.into_array()
196+
.optimize()?;
202197

203198
let new_dict =
204199
unsafe { DictArray::new_unchecked(array.codes().clone(), new_values) }.into_array();

vortex-array/src/arrays/fixed_size_list/tests/nested.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,6 @@ fn test_fsl_of_fsl_with_nulls() {
270270

271271
#[test]
272272
fn test_deeply_nested_fsl() {
273-
let _len = 2;
274273
let list_size = 2;
275274

276275
// Create a 3-level nested FSL: FSL[FSL[FSL[i32]]].

vortex-array/src/arrays/scalar_fn/array.rs

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::fmt::Formatter;
66

77
use vortex_error::VortexExpect;
88
use vortex_error::VortexResult;
9+
use vortex_error::vortex_bail;
910
use vortex_error::vortex_ensure;
1011

1112
use crate::ArrayRef;
@@ -30,19 +31,6 @@ impl Display for ScalarFnData {
3031
}
3132

3233
impl ScalarFnData {
33-
/// Create a new ScalarFnArray from a scalar function and its children.
34-
pub fn build(
35-
scalar_fn: ScalarFnRef,
36-
children: Vec<ArrayRef>,
37-
len: usize,
38-
) -> VortexResult<Self> {
39-
vortex_ensure!(
40-
children.iter().all(|c| c.len() == len),
41-
"ScalarFnArray must have children equal to the array length"
42-
);
43-
Ok(Self { scalar_fn })
44-
}
45-
4634
/// Get the scalar function bound to this array.
4735
#[inline(always)]
4836
pub fn scalar_fn(&self) -> &ScalarFnRef {
@@ -85,14 +73,26 @@ impl<T: TypedArrayRef<ScalarFn>> ScalarFnArrayExt for T {}
8573

8674
impl Array<ScalarFn> {
8775
/// Create a new ScalarFnArray from a scalar function and its children.
88-
pub fn try_new(
76+
pub fn try_new(scalar_fn: ScalarFnRef, children: Vec<ArrayRef>) -> VortexResult<Self> {
77+
let len = Self::infer_len(&children)?;
78+
Self::try_new_with_len(scalar_fn, children, len)
79+
}
80+
81+
/// Create a new ScalarFnArray from a scalar function, children, and an explicit length.
82+
///
83+
/// This is needed for zero-child scalar functions and deserialization paths where there is no
84+
/// child array to infer the length from.
85+
pub fn try_new_with_len(
8986
scalar_fn: ScalarFnRef,
9087
children: Vec<ArrayRef>,
9188
len: usize,
9289
) -> VortexResult<Self> {
90+
Self::validate_children_len(&children, len)?;
9391
let arg_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect();
9492
let dtype = scalar_fn.return_dtype(&arg_dtypes)?;
95-
let data = ScalarFnData::build(scalar_fn.clone(), children.clone(), len)?;
93+
let data = ScalarFnData {
94+
scalar_fn: scalar_fn.clone(),
95+
};
9696
let vtable = ScalarFn { id: scalar_fn.id() };
9797
Ok(unsafe {
9898
Array::from_parts_unchecked(
@@ -101,4 +101,19 @@ impl Array<ScalarFn> {
101101
)
102102
})
103103
}
104+
105+
fn infer_len(children: &[ArrayRef]) -> VortexResult<usize> {
106+
let Some(child) = children.first() else {
107+
vortex_bail!("ScalarFnArray length cannot be inferred without children");
108+
};
109+
Ok(child.len())
110+
}
111+
112+
fn validate_children_len(children: &[ArrayRef], len: usize) -> VortexResult<()> {
113+
vortex_ensure!(
114+
children.iter().all(|c| c.len() == len),
115+
"ScalarFnArray must have children equal to the array length"
116+
);
117+
Ok(())
118+
}
104119
}

vortex-array/src/arrays/scalar_fn/plugin.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ impl<V: ScalarFnVTable + ScalarFnArrayVTable> ArrayPlugin for ScalarFnArrayPlugi
8282
let parts = <V as ScalarFnArrayVTable>::deserialize(
8383
&self.0, dtype, len, metadata, children, session,
8484
)?;
85-
Ok(ScalarFnArray::try_new(
85+
Ok(ScalarFnArray::try_new_with_len(
8686
TypedScalarFnInstance::new(self.0.clone(), parts.options).erased(),
8787
parts.children,
8888
len,

vortex-array/src/arrays/scalar_fn/rules.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ impl ArrayParentReduceRule<ScalarFn> for ScalarFnSliceReduceRule {
8484
.collect::<VortexResult<_>>()?;
8585

8686
Ok(Some(
87-
ScalarFnArray::try_new(array.scalar_fn().clone(), children, range.len())?.into_array(),
87+
ScalarFnArray::try_new_with_len(array.scalar_fn().clone(), children, range.len())?
88+
.into_array(),
8889
))
8990
}
9091
}
@@ -142,7 +143,7 @@ impl ReduceCtx for ArrayReduceCtx {
142143
children: &[ReduceNodeRef],
143144
) -> VortexResult<ReduceNodeRef> {
144145
Ok(Arc::new(
145-
ScalarFnArray::try_new(
146+
ScalarFnArray::try_new_with_len(
146147
scalar_fn,
147148
children
148149
.iter()
@@ -191,8 +192,7 @@ impl ArrayParentReduceRule<ScalarFn> for ScalarFnUnaryFilterPushDownRule {
191192
.try_collect()?;
192193

193194
let new_array =
194-
ScalarFnArray::try_new(child.scalar_fn().clone(), new_children, parent.len())?
195-
.into_array();
195+
ScalarFnArray::try_new(child.scalar_fn().clone(), new_children)?.into_array();
196196

197197
return Ok(Some(new_array));
198198
}

vortex-array/src/arrays/scalar_fn/vtable/operations.rs

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,12 @@ mod tests {
6565
use crate::arrays::BoolArray;
6666
use crate::arrays::PrimitiveArray;
6767
use crate::arrays::ScalarFnArray;
68+
use crate::arrays::scalar_fn::ScalarFnArrayExt;
6869
use crate::assert_arrays_eq;
70+
use crate::scalar::Scalar;
6971
use crate::scalar_fn::TypedScalarFnInstance;
7072
use crate::scalar_fn::fns::binary::Binary;
73+
use crate::scalar_fn::fns::literal::Literal;
7174
use crate::scalar_fn::fns::operators::Operator;
7275
use crate::validity::Validity;
7376

@@ -77,7 +80,9 @@ mod tests {
7780
let rhs = buffer![10i32, 20, 30].into_array();
7881

7982
let scalar_fn = TypedScalarFnInstance::new(Binary, Operator::Add).erased();
80-
let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 3)?;
83+
let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?;
84+
85+
assert_eq!(scalar_fn_array.len(), 3);
8186

8287
let result = scalar_fn_array
8388
.into_array()
@@ -89,13 +94,47 @@ mod tests {
8994
Ok(())
9095
}
9196

97+
#[test]
98+
fn test_scalar_fn_inferred_len_rejects_mismatched_children() {
99+
let lhs = buffer![1i32, 2, 3].into_array();
100+
let rhs = buffer![10i32, 20].into_array();
101+
102+
let scalar_fn = TypedScalarFnInstance::new(Binary, Operator::Add).erased();
103+
let err = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])
104+
.expect_err("ScalarFnArray::try_new must reject mismatched child lengths");
105+
106+
assert!(
107+
err.to_string()
108+
.contains("ScalarFnArray must have children equal to the array length")
109+
);
110+
}
111+
112+
#[test]
113+
fn test_scalar_fn_without_children_requires_explicit_len() -> VortexResult<()> {
114+
let scalar_fn = TypedScalarFnInstance::new(Literal, Scalar::from(1i32)).erased();
115+
116+
let Err(err) = ScalarFnArray::try_new(scalar_fn.clone(), vec![]) else {
117+
panic!("ScalarFnArray::try_new should reject zero children");
118+
};
119+
assert!(
120+
err.to_string()
121+
.contains("ScalarFnArray length cannot be inferred without children")
122+
);
123+
124+
let scalar_fn_array = ScalarFnArray::try_new_with_len(scalar_fn, vec![], 3)?;
125+
assert_eq!(scalar_fn_array.len(), 3);
126+
assert_eq!(scalar_fn_array.child_count(), 0);
127+
128+
Ok(())
129+
}
130+
92131
#[test]
93132
fn test_scalar_fn_mul() -> VortexResult<()> {
94133
let lhs = buffer![2i32, 3, 4].into_array();
95134
let rhs = buffer![5i32, 6, 7].into_array();
96135

97136
let scalar_fn = TypedScalarFnInstance::new(Binary, Operator::Mul).erased();
98-
let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 3)?;
137+
let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?;
99138

100139
let result = scalar_fn_array
101140
.into_array()
@@ -117,7 +156,7 @@ mod tests {
117156
.into_array();
118157

119158
let scalar_fn = TypedScalarFnInstance::new(Binary, Operator::Add).erased();
120-
let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 3)?;
159+
let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?;
121160

122161
let result = scalar_fn_array
123162
.into_array()
@@ -139,7 +178,7 @@ mod tests {
139178
let rhs = buffer![2i32, 5, 1].into_array();
140179

141180
let scalar_fn = TypedScalarFnInstance::new(Binary, Operator::Eq).erased();
142-
let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 3)?;
181+
let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?;
143182

144183
let result = scalar_fn_array
145184
.into_array()

vortex-array/src/expression.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ impl ArrayRef {
3535

3636
// And wrap the scalar function up in an array.
3737
let array =
38-
ScalarFnArray::try_new(expr.scalar_fn().clone(), children, self.len())?.into_array();
38+
ScalarFnArray::try_new_with_len(expr.scalar_fn().clone(), children, self.len())?
39+
.into_array();
3940

4041
// Optimize the resulting array's root.
4142
array.optimize()

0 commit comments

Comments
 (0)