Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions encodings/alp/benches/alp_compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ use vortex_alp::ALPRDFloat;
use vortex_alp::RDEncoder;
use vortex_alp::alp_encode;
use vortex_alp::decompress_into_array;
use vortex_array::Canonical;
use vortex_array::IntoArray;
use vortex_array::LEGACY_SESSION;
use vortex_array::VortexSessionExecute;
use vortex_array::arrays::PrimitiveArray;
Expand Down Expand Up @@ -153,6 +155,6 @@ fn decompress_rd<T: ALPRDFloat + NativePType>(bencher: Bencher, args: (usize, f6
let encoded = encoder.encode(primitive.as_view());

bencher
.with_inputs(|| &encoded)
.bench_refs(|encoded| encoded.to_canonical());
.with_inputs(|| (&encoded, LEGACY_SESSION.create_execution_ctx()))
.bench_refs(|(encoded, ctx)| (**encoded).clone().into_array().execute::<Canonical>(ctx));
}
4 changes: 2 additions & 2 deletions encodings/alp/public-api.lock
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ pub fn vortex_alp::ALP::validate(&self, data: &vortex_alp::ALPData, dtype: &vort

impl vortex_array::array::vtable::operations::OperationsVTable<vortex_alp::ALP> for vortex_alp::ALP

pub fn vortex_alp::ALP::scalar_at(array: vortex_array::array::view::ArrayView<'_, vortex_alp::ALP>, index: usize, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
pub fn vortex_alp::ALP::scalar_at(array: vortex_array::array::view::ArrayView<'_, vortex_alp::ALP>, index: usize, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>

impl vortex_array::array::vtable::validity::ValidityChild<vortex_alp::ALP> for vortex_alp::ALP

Expand Down Expand Up @@ -188,7 +188,7 @@ pub fn vortex_alp::ALPRD::validate(&self, data: &vortex_alp::ALPRDData, dtype: &

impl vortex_array::array::vtable::operations::OperationsVTable<vortex_alp::ALPRD> for vortex_alp::ALPRD

pub fn vortex_alp::ALPRD::scalar_at(array: vortex_array::array::view::ArrayView<'_, vortex_alp::ALPRD>, index: usize, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
pub fn vortex_alp::ALPRD::scalar_at(array: vortex_array::array::view::ArrayView<'_, vortex_alp::ALPRD>, index: usize, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>

impl vortex_array::array::vtable::validity::ValidityChild<vortex_alp::ALPRD> for vortex_alp::ALPRD

Expand Down
4 changes: 2 additions & 2 deletions encodings/alp/src/alp/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ impl OperationsVTable<ALP> for ALP {
fn scalar_at(
array: ArrayView<'_, ALP>,
index: usize,
_ctx: &mut ExecutionCtx,
ctx: &mut ExecutionCtx,
) -> VortexResult<Scalar> {
if let Some(patches) = array.patches()
&& let Some(patch) = patches.get_patched(index)?
{
return patch.cast(array.dtype());
}

let encoded_val = array.encoded().scalar_at(index)?;
let encoded_val = array.encoded().execute_scalar(index, ctx)?;

Ok(match_each_alp_float_ptype!(array.dtype().as_ptype(), |T| {
let encoded_val: <T as ALPFloat>::ALPInt =
Expand Down
11 changes: 9 additions & 2 deletions encodings/alp/src/alp_rd/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ use vortex_array::ArrayView;
use vortex_array::ExecutionCtx;
use vortex_array::ExecutionResult;
use vortex_array::IntoArray;
use vortex_array::LEGACY_SESSION;
use vortex_array::Precision;
use vortex_array::TypedArrayRef;
use vortex_array::VortexSessionExecute;
use vortex_array::arrays::Primitive;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::buffer::BufferHandle;
Expand Down Expand Up @@ -405,7 +407,10 @@ impl ALPRDData {
) -> VortexResult<Option<Patches>> {
left_parts_patches
.map(|patches| {
if !patches.values().all_valid()? {
if !patches
.values()
.all_valid(&mut LEGACY_SESSION.create_execution_ctx())?
{
vortex_bail!("patches must be all valid: {}", patches.values());
}
// TODO(ngates): assert the DType, don't cast it.
Expand Down Expand Up @@ -586,7 +591,9 @@ fn validate_parts(
left_parts.dtype(),
);
vortex_ensure!(
patches.values().all_valid()?,
patches
.values()
.all_valid(&mut LEGACY_SESSION.create_execution_ctx())?,
"patches must be all valid: {}",
patches.values()
);
Expand Down
8 changes: 4 additions & 4 deletions encodings/alp/src/alp_rd/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ impl OperationsVTable<ALPRD> for ALPRD {
fn scalar_at(
array: ArrayView<'_, ALPRD>,
index: usize,
_ctx: &mut ExecutionCtx,
ctx: &mut ExecutionCtx,
) -> VortexResult<Scalar> {
// The left value can either be a direct value, or an exception.
// The exceptions array represents exception positions with non-null values.
Expand All @@ -32,7 +32,7 @@ impl OperationsVTable<ALPRD> for ALPRD {
_ => {
let left_code: u16 = array
.left_parts()
.scalar_at(index)?
.execute_scalar(index, ctx)?
.as_primitive()
.as_::<u16>()
.vortex_expect("left_code must be non-null");
Expand All @@ -44,7 +44,7 @@ impl OperationsVTable<ALPRD> for ALPRD {
Ok(if array.dtype().as_ptype() == PType::F32 {
let right: u32 = array
.right_parts()
.scalar_at(index)?
.execute_scalar(index, ctx)?
.as_primitive()
.as_::<u32>()
.vortex_expect("non-null");
Expand All @@ -53,7 +53,7 @@ impl OperationsVTable<ALPRD> for ALPRD {
} else {
let right: u64 = array
.right_parts()
.scalar_at(index)?
.execute_scalar(index, ctx)?
.as_primitive()
.as_::<u64>()
.vortex_expect("non-null");
Expand Down
13 changes: 8 additions & 5 deletions encodings/bytebool/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,8 @@ impl From<Vec<Option<bool>>> for ByteBoolData {
mod tests {
use vortex_array::ArrayContext;
use vortex_array::IntoArray;
use vortex_array::LEGACY_SESSION;
use vortex_array::VortexSessionExecute;
use vortex_array::assert_arrays_eq;
use vortex_array::serde::SerializeOptions;
use vortex_array::serde::SerializedArray;
Expand All @@ -366,15 +368,16 @@ mod tests {
let arr = ByteBool::from_vec(v, Validity::AllValid);
assert_eq!(v_len, arr.len());

let mut ctx = LEGACY_SESSION.create_execution_ctx();
for idx in 0..arr.len() {
assert!(arr.is_valid(idx).unwrap());
assert!(arr.is_valid(idx, &mut ctx).unwrap());
}

let v = vec![Some(true), None, Some(false)];
let arr = ByteBool::from_option_vec(v);
assert!(arr.is_valid(0).unwrap());
assert!(!arr.is_valid(1).unwrap());
assert!(arr.is_valid(2).unwrap());
assert!(arr.is_valid(0, &mut ctx).unwrap());
assert!(!arr.is_valid(1, &mut ctx).unwrap());
assert!(arr.is_valid(2, &mut ctx).unwrap());
assert_eq!(arr.len(), 3);

let v: Vec<Option<bool>> = vec![None, None];
Expand All @@ -384,7 +387,7 @@ mod tests {
assert_eq!(v_len, arr.len());

for idx in 0..arr.len() {
assert!(!arr.is_valid(idx).unwrap());
assert!(!arr.is_valid(idx, &mut ctx).unwrap());
}
assert_eq!(arr.len(), 2);
}
Expand Down
4 changes: 2 additions & 2 deletions encodings/datetime-parts/public-api.lock
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ pub fn vortex_datetime_parts::DateTimeParts::validate(&self, _data: &Self::Array

impl vortex_array::array::vtable::operations::OperationsVTable<vortex_datetime_parts::DateTimeParts> for vortex_datetime_parts::DateTimeParts

pub fn vortex_datetime_parts::DateTimeParts::scalar_at(array: vortex_array::array::view::ArrayView<'_, vortex_datetime_parts::DateTimeParts>, index: usize, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
pub fn vortex_datetime_parts::DateTimeParts::scalar_at(array: vortex_array::array::view::ArrayView<'_, vortex_datetime_parts::DateTimeParts>, index: usize, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>

impl vortex_array::array::vtable::validity::ValidityChild<vortex_datetime_parts::DateTimeParts> for vortex_datetime_parts::DateTimeParts

Expand All @@ -68,7 +68,7 @@ pub fn vortex_datetime_parts::DateTimeParts::slice(array: vortex_array::array::v

impl vortex_array::scalar_fn::fns::binary::compare::CompareKernel for vortex_datetime_parts::DateTimeParts

pub fn vortex_datetime_parts::DateTimeParts::compare(lhs: vortex_array::array::view::ArrayView<'_, Self>, rhs: &vortex_array::array::erased::ArrayRef, operator: vortex_array::scalar_fn::fns::operators::CompareOperator, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<core::option::Option<vortex_array::array::erased::ArrayRef>>
pub fn vortex_datetime_parts::DateTimeParts::compare(lhs: vortex_array::array::view::ArrayView<'_, Self>, rhs: &vortex_array::array::erased::ArrayRef, operator: vortex_array::scalar_fn::fns::operators::CompareOperator, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<core::option::Option<vortex_array::array::erased::ArrayRef>>

impl vortex_array::scalar_fn::fns::cast::kernel::CastReduce for vortex_datetime_parts::DateTimeParts

Expand Down
30 changes: 17 additions & 13 deletions encodings/datetime-parts/src/compute/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ impl CompareKernel for DateTimeParts {
lhs: ArrayView<'_, Self>,
rhs: &ArrayRef,
operator: CompareOperator,
_ctx: &mut ExecutionCtx,
ctx: &mut ExecutionCtx,
) -> VortexResult<Option<ArrayRef>> {
let Some(rhs_const) = rhs.as_constant() else {
return Ok(None);
Expand All @@ -51,17 +51,17 @@ impl CompareKernel for DateTimeParts {
let ts_parts = timestamp::split(timestamp, options.unit)?;

match operator {
CompareOperator::Eq => compare_eq(lhs, &ts_parts, nullability),
CompareOperator::NotEq => compare_ne(lhs, &ts_parts, nullability),
CompareOperator::Eq => compare_eq(lhs, &ts_parts, nullability, ctx),
CompareOperator::NotEq => compare_ne(lhs, &ts_parts, nullability, ctx),
// lt and lte have identical behavior, as we optimize
// for the case that all days on the lhs are smaller.
// If that special case is not hit, we return `Ok(None)` to
// signal that the comparison wasn't handled within dtp.
CompareOperator::Lt => compare_lt(lhs, &ts_parts, nullability),
CompareOperator::Lte => compare_lt(lhs, &ts_parts, nullability),
CompareOperator::Lt => compare_lt(lhs, &ts_parts, nullability, ctx),
CompareOperator::Lte => compare_lt(lhs, &ts_parts, nullability, ctx),
// (Like for lt, lte)
CompareOperator::Gt => compare_gt(lhs, &ts_parts, nullability),
CompareOperator::Gte => compare_gt(lhs, &ts_parts, nullability),
CompareOperator::Gt => compare_gt(lhs, &ts_parts, nullability, ctx),
CompareOperator::Gte => compare_gt(lhs, &ts_parts, nullability, ctx),
}
}
}
Expand All @@ -70,9 +70,10 @@ fn compare_eq(
lhs: ArrayView<DateTimeParts>,
ts_parts: &timestamp::TimestampParts,
nullability: Nullability,
ctx: &mut ExecutionCtx,
) -> VortexResult<Option<ArrayRef>> {
let mut comparison = compare_dtp(lhs.days(), ts_parts.days, CompareOperator::Eq, nullability)?;
if comparison.statistics().compute_max::<bool>() == Some(false) {
if comparison.statistics().compute_max::<bool>(ctx) == Some(false) {
// All values are different.
return Ok(Some(comparison));
}
Expand All @@ -85,7 +86,7 @@ fn compare_eq(
)?
.binary(comparison, Operator::And)?;

if comparison.statistics().compute_max::<bool>() == Some(false) {
if comparison.statistics().compute_max::<bool>(ctx) == Some(false) {
// All values are different.
return Ok(Some(comparison));
}
Expand All @@ -105,14 +106,15 @@ fn compare_ne(
lhs: ArrayView<DateTimeParts>,
ts_parts: &timestamp::TimestampParts,
nullability: Nullability,
ctx: &mut ExecutionCtx,
) -> VortexResult<Option<ArrayRef>> {
let mut comparison = compare_dtp(
lhs.days(),
ts_parts.days,
CompareOperator::NotEq,
nullability,
)?;
if comparison.statistics().compute_min::<bool>() == Some(true) {
if comparison.statistics().compute_min::<bool>(ctx) == Some(true) {
// All values are different.
return Ok(Some(comparison));
}
Expand All @@ -125,7 +127,7 @@ fn compare_ne(
)?
.binary(comparison, Operator::Or)?;

if comparison.statistics().compute_min::<bool>() == Some(true) {
if comparison.statistics().compute_min::<bool>(ctx) == Some(true) {
// All values are different.
return Ok(Some(comparison));
}
Expand All @@ -145,9 +147,10 @@ fn compare_lt(
lhs: ArrayView<DateTimeParts>,
ts_parts: &timestamp::TimestampParts,
nullability: Nullability,
ctx: &mut ExecutionCtx,
) -> VortexResult<Option<ArrayRef>> {
let days_lt = compare_dtp(lhs.days(), ts_parts.days, CompareOperator::Lt, nullability)?;
if days_lt.statistics().compute_min::<bool>() == Some(true) {
if days_lt.statistics().compute_min::<bool>(ctx) == Some(true) {
// All values on the lhs are smaller.
return Ok(Some(days_lt));
}
Expand All @@ -159,9 +162,10 @@ fn compare_gt(
lhs: ArrayView<DateTimeParts>,
ts_parts: &timestamp::TimestampParts,
nullability: Nullability,
ctx: &mut ExecutionCtx,
) -> VortexResult<Option<ArrayRef>> {
let days_gt = compare_dtp(lhs.days(), ts_parts.days, CompareOperator::Gt, nullability)?;
if days_gt.statistics().compute_min::<bool>() == Some(true) {
if days_gt.statistics().compute_min::<bool>(ctx) == Some(true) {
// All values on the lhs are larger.
return Ok(Some(days_gt));
}
Expand Down
2 changes: 1 addition & 1 deletion encodings/datetime-parts/src/compute/is_constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,6 @@ impl DynAggregateKernel for DateTimePartsIsConstantKernel {
let result = is_constant(array.days(), ctx)?
&& is_constant(array.seconds(), ctx)?
&& is_constant(array.subseconds(), ctx)?;
Ok(Some(IsConstant::make_partial(batch, result)?))
Ok(Some(IsConstant::make_partial(batch, result, ctx)?))
}
}
10 changes: 5 additions & 5 deletions encodings/datetime-parts/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ impl OperationsVTable<DateTimeParts> for DateTimeParts {
fn scalar_at(
array: ArrayView<'_, DateTimeParts>,
index: usize,
_ctx: &mut ExecutionCtx,
ctx: &mut ExecutionCtx,
) -> VortexResult<Scalar> {
let DType::Extension(ext) = array.dtype().clone() else {
vortex_panic!(
Expand All @@ -33,25 +33,25 @@ impl OperationsVTable<DateTimeParts> for DateTimeParts {
vortex_panic!(Compute: "must decode TemporalMetadata from extension metadata");
};

if !array.as_ref().is_valid(index)? {
if !array.as_ref().is_valid(index, ctx)? {
return Ok(Scalar::null(DType::Extension(ext)));
}

let days: i64 = array
.days()
.scalar_at(index)?
.execute_scalar(index, ctx)?
.as_primitive()
.as_::<i64>()
.vortex_expect("days fits in i64");
let seconds: i64 = array
.seconds()
.scalar_at(index)?
.execute_scalar(index, ctx)?
.as_primitive()
.as_::<i64>()
.vortex_expect("seconds fits in i64");
let subseconds: i64 = array
.subseconds()
.scalar_at(index)?
.execute_scalar(index, ctx)?
.as_primitive()
.as_::<i64>()
.vortex_expect("subseconds fits in i64");
Expand Down
4 changes: 2 additions & 2 deletions encodings/decimal-byte-parts/public-api.lock
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub fn vortex_decimal_byte_parts::DecimalByteParts::validate(&self, _data: &Self

impl vortex_array::array::vtable::operations::OperationsVTable<vortex_decimal_byte_parts::DecimalByteParts> for vortex_decimal_byte_parts::DecimalByteParts

pub fn vortex_decimal_byte_parts::DecimalByteParts::scalar_at(array: vortex_array::array::view::ArrayView<'_, vortex_decimal_byte_parts::DecimalByteParts>, index: usize, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
pub fn vortex_decimal_byte_parts::DecimalByteParts::scalar_at(array: vortex_array::array::view::ArrayView<'_, vortex_decimal_byte_parts::DecimalByteParts>, index: usize, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>

impl vortex_array::array::vtable::validity::ValidityChild<vortex_decimal_byte_parts::DecimalByteParts> for vortex_decimal_byte_parts::DecimalByteParts

Expand All @@ -66,7 +66,7 @@ pub fn vortex_decimal_byte_parts::DecimalByteParts::slice(array: vortex_array::a

impl vortex_array::scalar_fn::fns::binary::compare::CompareKernel for vortex_decimal_byte_parts::DecimalByteParts

pub fn vortex_decimal_byte_parts::DecimalByteParts::compare(lhs: vortex_array::array::view::ArrayView<'_, Self>, rhs: &vortex_array::array::erased::ArrayRef, operator: vortex_array::scalar_fn::fns::operators::CompareOperator, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<core::option::Option<vortex_array::array::erased::ArrayRef>>
pub fn vortex_decimal_byte_parts::DecimalByteParts::compare(lhs: vortex_array::array::view::ArrayView<'_, Self>, rhs: &vortex_array::array::erased::ArrayRef, operator: vortex_array::scalar_fn::fns::operators::CompareOperator, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<core::option::Option<vortex_array::array::erased::ArrayRef>>

impl vortex_array::scalar_fn::fns::cast::kernel::CastReduce for vortex_decimal_byte_parts::DecimalByteParts

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ impl CompareKernel for DecimalByteParts {
lhs: ArrayView<'_, Self>,
rhs: &ArrayRef,
operator: CompareOperator,
_ctx: &mut ExecutionCtx,
ctx: &mut ExecutionCtx,
) -> VortexResult<Option<ArrayRef>> {
let Some(rhs_const) = rhs.as_constant() else {
return Ok(None);
Expand Down Expand Up @@ -65,7 +65,7 @@ impl CompareKernel for DecimalByteParts {
// (depending on the `sign`) than all values in MSP.
// If the LHS or the RHS contain nulls, then we must fallback to the canonicalized
// implementation which does null-checking instead.
if lhs.array().all_valid()? && rhs.all_valid()? {
if lhs.array().all_valid(ctx)? && rhs.all_valid(ctx)? {
Ok(Some(
ConstantArray::new(
unconvertible_value(sign, operator, nullability),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,6 @@ impl DynAggregateKernel for DecimalBytePartsIsConstantKernel {
};

let result = is_constant(array.msp(), ctx)?;
Ok(Some(IsConstant::make_partial(batch, result)?))
Ok(Some(IsConstant::make_partial(batch, result, ctx)?))
}
}
Loading
Loading