|
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
use std::iter::zip;

use vortex_buffer::BitBufferMut;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;

use crate::ArrayRef;
use crate::Canonical;
use crate::ExecutionCtx;
use crate::IntoArray;
use crate::arrays::BoolArray;
use crate::arrays::ConstantArray;
use crate::arrays::Patched;
use crate::arrays::PrimitiveArray;
use crate::arrays::bool::BoolArrayParts;
use crate::arrays::primitive::NativeValue;
use crate::builtins::ArrayBuiltins;
use crate::dtype::NativePType;
use crate::match_each_native_ptype;
use crate::scalar_fn::fns::binary::CompareKernel;
use crate::scalar_fn::fns::operators::CompareOperator;
| 23 | + |
| 24 | +impl CompareKernel for Patched { |
| 25 | + fn compare( |
| 26 | + lhs: &Self::Array, |
| 27 | + rhs: &ArrayRef, |
| 28 | + operator: CompareOperator, |
| 29 | + ctx: &mut ExecutionCtx, |
| 30 | + ) -> VortexResult<Option<ArrayRef>> { |
| 31 | + // We only accelerate comparisons for primitives |
| 32 | + if !lhs.dtype().is_primitive() { |
| 33 | + return Ok(None); |
| 34 | + } |
| 35 | + |
| 36 | + // We only accelerate comparisons against constants |
| 37 | + let Some(constant) = rhs.as_constant() else { |
| 38 | + return Ok(None); |
| 39 | + }; |
| 40 | + |
| 41 | + // NOTE: due to offset, it's possible that the inner.len != array.len. |
| 42 | + // We slice the inner before performing the comparison. |
| 43 | + let result = lhs |
| 44 | + .base_array() |
| 45 | + .binary( |
| 46 | + ConstantArray::new(constant.clone(), lhs.len()).into_array(), |
| 47 | + operator.into(), |
| 48 | + )? |
| 49 | + .execute::<Canonical>(ctx)? |
| 50 | + .into_bool(); |
| 51 | + |
| 52 | + let BoolArrayParts { |
| 53 | + bits, |
| 54 | + offset, |
| 55 | + len, |
| 56 | + validity, |
| 57 | + } = result.into_parts(); |
| 58 | + |
| 59 | + let mut bits = BitBufferMut::from_buffer(bits.unwrap_host().into_mut(), offset, len); |
| 60 | + |
| 61 | + let lane_offsets = lhs.lane_offsets().clone().execute::<PrimitiveArray>(ctx)?; |
| 62 | + let indices = lhs.patch_indices().clone().execute::<PrimitiveArray>(ctx)?; |
| 63 | + let values = lhs.patch_values().clone().execute::<PrimitiveArray>(ctx)?; |
| 64 | + let n_lanes = lhs.n_lanes; |
| 65 | + |
| 66 | + match_each_native_ptype!(values.ptype(), |V| { |
| 67 | + let offset = lhs.offset; |
| 68 | + let indices = indices.as_slice::<u16>(); |
| 69 | + let values = values.as_slice::<V>(); |
| 70 | + let constant = constant |
| 71 | + .as_primitive() |
| 72 | + .as_::<V>() |
| 73 | + .vortex_expect("compare constant not null"); |
| 74 | + |
| 75 | + let apply_patches = ApplyPatches { |
| 76 | + bits: &mut bits, |
| 77 | + offset, |
| 78 | + n_lanes, |
| 79 | + lane_offsets: lane_offsets.as_slice::<u32>(), |
| 80 | + indices, |
| 81 | + values, |
| 82 | + constant, |
| 83 | + }; |
| 84 | + |
| 85 | + match operator { |
| 86 | + CompareOperator::Eq => { |
| 87 | + apply_patches.apply(|l, r| NativeValue(l) == NativeValue(r))?; |
| 88 | + } |
| 89 | + CompareOperator::NotEq => { |
| 90 | + apply_patches.apply(|l, r| NativeValue(l) != NativeValue(r))?; |
| 91 | + } |
| 92 | + CompareOperator::Gt => { |
| 93 | + apply_patches.apply(|l, r| NativeValue(l) > NativeValue(r))?; |
| 94 | + } |
| 95 | + CompareOperator::Gte => { |
| 96 | + apply_patches.apply(|l, r| NativeValue(l) >= NativeValue(r))?; |
| 97 | + } |
| 98 | + CompareOperator::Lt => { |
| 99 | + apply_patches.apply(|l, r| NativeValue(l) < NativeValue(r))?; |
| 100 | + } |
| 101 | + CompareOperator::Lte => { |
| 102 | + apply_patches.apply(|l, r| NativeValue(l) <= NativeValue(r))?; |
| 103 | + } |
| 104 | + } |
| 105 | + }); |
| 106 | + |
| 107 | + let result = BoolArray::new(bits.freeze(), validity); |
| 108 | + Ok(Some(result.into_array())) |
| 109 | + } |
| 110 | +} |
| 111 | + |
| 112 | +struct ApplyPatches<'a, V: NativePType> { |
| 113 | + bits: &'a mut BitBufferMut, |
| 114 | + offset: usize, |
| 115 | + n_lanes: usize, |
| 116 | + lane_offsets: &'a [u32], |
| 117 | + indices: &'a [u16], |
| 118 | + values: &'a [V], |
| 119 | + constant: V, |
| 120 | +} |
| 121 | + |
| 122 | +impl<V: NativePType> ApplyPatches<'_, V> { |
| 123 | + fn apply<F>(self, cmp: F) -> VortexResult<()> |
| 124 | + where |
| 125 | + F: Fn(V, V) -> bool, |
| 126 | + { |
| 127 | + for index in 0..(self.lane_offsets.len() - 1) { |
| 128 | + let chunk = index / self.n_lanes; |
| 129 | + |
| 130 | + let lane_start = self.lane_offsets[index] as usize; |
| 131 | + let lane_end = self.lane_offsets[index + 1] as usize; |
| 132 | + |
| 133 | + for (&patch_index, &patch_value) in std::iter::zip( |
| 134 | + &self.indices[lane_start..lane_end], |
| 135 | + &self.values[lane_start..lane_end], |
| 136 | + ) { |
| 137 | + let bit_index = chunk * 1024 + patch_index as usize; |
| 138 | + // Skip any indices < the offset. |
| 139 | + if bit_index < self.offset { |
| 140 | + continue; |
| 141 | + } |
| 142 | + let bit_index = bit_index - self.offset; |
| 143 | + if bit_index >= self.bits.len() { |
| 144 | + break; |
| 145 | + } |
| 146 | + if cmp(patch_value, self.constant) { |
| 147 | + self.bits.set(bit_index) |
| 148 | + } else { |
| 149 | + self.bits.unset(bit_index) |
| 150 | + } |
| 151 | + } |
| 152 | + } |
| 153 | + |
| 154 | + Ok(()) |
| 155 | + } |
| 156 | +} |
| 157 | + |
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::DynArray;
    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::arrays::BoolArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::Patched;
    use crate::arrays::PatchedArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::optimizer::ArrayOptimizer;
    use crate::patches::Patches;
    use crate::scalar_fn::fns::binary::CompareKernel;
    use crate::scalar_fn::fns::operators::CompareOperator;
    use crate::validity::Validity;

    /// Eq against a constant where the last three positions are patched to the target.
    #[test]
    fn test_basic() {
        let mut ctx = ExecutionCtx::new(LEGACY_SESSION.clone());

        let base = PrimitiveArray::from_iter(0u32..512).into_array();
        let patches = Patches::new(
            512,
            0,
            buffer![509u16, 510, 511].into_array(),
            buffer![u32::MAX; 3].into_array(),
            None,
        )
        .unwrap();
        let patched = PatchedArray::from_array_and_patches(base, &patches, &mut ctx).unwrap();

        let target = ConstantArray::new(u32::MAX, 512).into_array();
        let actual =
            <Patched as CompareKernel>::compare(&patched, &target, CompareOperator::Eq, &mut ctx)
                .unwrap()
                .unwrap();

        let expected =
            BoolArray::from_indices(512, [509, 510, 511], Validity::NonNullable).into_array();
        assert_arrays_eq!(expected, actual);
    }

    /// A sliced patched array must drop patches that fall before the slice offset.
    #[test]
    fn test_with_offset() {
        let mut ctx = ExecutionCtx::new(LEGACY_SESSION.clone());

        let base = PrimitiveArray::from_iter(0u32..512).into_array();
        let patches = Patches::new(
            512,
            0,
            buffer![5u16, 510, 511].into_array(),
            buffer![u32::MAX; 3].into_array(),
            None,
        )
        .unwrap();
        let patched = PatchedArray::from_array_and_patches(base, &patches, &mut ctx).unwrap();

        // Slice off the first ten elements so the patch at index 5 must be skipped.
        let sliced = patched
            .slice(10..patched.len())
            .unwrap()
            .optimize()
            .unwrap()
            .try_into::<Patched>()
            .unwrap();
        assert_eq!(sliced.len(), 502);

        let target = ConstantArray::new(u32::MAX, sliced.len()).into_array();
        let actual =
            <Patched as CompareKernel>::compare(&sliced, &target, CompareOperator::Eq, &mut ctx)
                .unwrap()
                .unwrap();

        let expected = BoolArray::from_indices(502, [500, 501], Validity::NonNullable).into_array();
        assert_arrays_eq!(expected, actual);
    }

    /// Subnormal floats must compare exactly rather than collapse toward zero.
    #[test]
    fn test_subnormal_f32() -> VortexResult<()> {
        // A subnormal sits strictly between 0 and f32::MIN_POSITIVE.
        let subnormal: f32 = f32::MIN_POSITIVE / 2.0;
        assert!(subnormal > 0.0 && subnormal < f32::MIN_POSITIVE);

        let mut ctx = ExecutionCtx::new(LEGACY_SESSION.clone());

        let base = PrimitiveArray::from_iter((0..512).map(|i| i as f32)).into_array();
        let patches = Patches::new(
            512,
            0,
            buffer![509u16, 510, 511].into_array(),
            buffer![f32::NAN, subnormal, f32::NEG_INFINITY].into_array(),
            None,
        )?;
        let patched = PatchedArray::from_array_and_patches(base, &patches, &mut ctx)?;

        let target = ConstantArray::new(subnormal, 512).into_array();
        let actual =
            <Patched as CompareKernel>::compare(&patched, &target, CompareOperator::Eq, &mut ctx)?
                .unwrap();

        let expected = BoolArray::from_indices(512, [510], Validity::NonNullable).into_array();
        assert_arrays_eq!(expected, actual);
        Ok(())
    }

    /// Positive and negative zero are distinct under the patched comparison.
    #[test]
    fn test_pos_neg_zero() -> VortexResult<()> {
        let mut ctx = ExecutionCtx::new(LEGACY_SESSION.clone());

        let base = PrimitiveArray::from_iter([-0.0f32; 10]).into_array();
        let patches = Patches::new(
            10,
            0,
            buffer![5u16, 6, 7, 8, 9].into_array(),
            buffer![f32::NAN, f32::NEG_INFINITY, 0f32, -0.0f32, f32::INFINITY].into_array(),
            None,
        )?;
        let patched = PatchedArray::from_array_and_patches(base, &patches, &mut ctx)?;

        let target = ConstantArray::new(0.0f32, 10).into_array();
        let actual =
            <Patched as CompareKernel>::compare(&patched, &target, CompareOperator::Eq, &mut ctx)?
                .unwrap();

        // Only the +0.0 patch (index 7) matches; -0.0 does not equal +0.0 here.
        let expected = BoolArray::from_indices(10, [7], Validity::NonNullable).into_array();
        assert_arrays_eq!(expected, actual);

        Ok(())
    }
}