Skip to content

Commit 3ad4a64

Browse files
committed
Add Mojo SIMD filter kernel for primitive arrays
Extends the Mojo AOT kernel with filter-by-indices support. The primitive filter path converts sparse masks (<80% selectivity) into an index array, then gathers values at those positions — identical to the take operation but with usize indices. Four new exported symbols (vortex_filter_{1,2,4,8}byte) are added to the Mojo kernel and wired into filter_slice_by_indices behind cfg(vortex_mojo). Falls back to scalar when Mojo is unavailable. All 121 existing filter tests pass with the Mojo kernel active. Signed-off-by: Claude <noreply@anthropic.com> https://claude.ai/code/session_01EVcJZP4ZmfvWRRg2CsgvST
1 parent cd086d9 commit 3ad4a64

2 files changed

Lines changed: 82 additions & 0 deletions

File tree

vortex-array/kernels/take.mojo

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,29 @@ fn take_1byte_u32idx(src: Int, idx: Int, dst: Int, n: Int):
143143
@export("vortex_take_1byte_u64idx")
144144
fn take_1byte_u64idx(src: Int, idx: Int, dst: Int, n: Int):
145145
_take[DType.uint8, DType.uint64, W1](src, idx, dst, n)
146+
147+
148+
# ---------------------------------------------------------------------------
149+
# Filter kernels (gather by usize indices from mask)
150+
#
151+
# These are used by the primitive filter path when the mask is sparse (<80%
152+
# selectivity). The Rust side converts the bitmap to a &[usize] index array
153+
# and passes it here. On x86_64 usize = u64, so these are gathers with
154+
# u64 element indices.
155+
# ---------------------------------------------------------------------------
156+
157+
@export("vortex_filter_1byte")
158+
fn filter_1byte(src: Int, idx: Int, dst: Int, n: Int):
159+
_take[DType.uint8, DType.uint64, W1](src, idx, dst, n)
160+
161+
@export("vortex_filter_2byte")
162+
fn filter_2byte(src: Int, idx: Int, dst: Int, n: Int):
163+
_take[DType.uint16, DType.uint64, W2](src, idx, dst, n)
164+
165+
@export("vortex_filter_4byte")
166+
fn filter_4byte(src: Int, idx: Int, dst: Int, n: Int):
167+
_take[DType.uint32, DType.uint64, W4](src, idx, dst, n)
168+
169+
@export("vortex_filter_8byte")
170+
fn filter_8byte(src: Int, idx: Int, dst: Int, n: Int):
171+
_take[DType.uint64, DType.uint64, W8](src, idx, dst, n)

vortex-array/src/arrays/filter/execute/slice.rs

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
//! Provides both immutable and mutable (in-place) filtering of typed slices by various mask
77
//! representations: indices and ranges (slices).
88
9+
use std::mem::size_of;
910
use std::ptr;
1011

1112
use vortex_buffer::Buffer;
@@ -37,9 +38,64 @@ pub(super) fn filter_slice_by_mask_values<T: Copy>(slice: &[T], mask: &MaskValue
3738

3839
/// Filter a slice by a set of strictly increasing indices.
3940
fn filter_slice_by_indices<T: Copy>(slice: &[T], indices: &[usize]) -> Buffer<T> {
41+
#[cfg(vortex_mojo)]
42+
{
43+
if let Some(buf) = mojo::filter_by_indices_mojo(slice, indices) {
44+
return buf;
45+
}
46+
}
47+
4048
Buffer::<T>::from_trusted_len_iter(indices.iter().map(|&idx| slice[idx]))
4149
}
4250

51+
#[cfg(vortex_mojo)]
52+
mod mojo {
53+
use vortex_buffer::Buffer;
54+
use vortex_buffer::BufferMut;
55+
56+
use super::size_of;
57+
58+
unsafe extern "C" {
59+
fn vortex_filter_1byte(src: usize, idx: usize, dst: usize, n: usize);
60+
fn vortex_filter_2byte(src: usize, idx: usize, dst: usize, n: usize);
61+
fn vortex_filter_4byte(src: usize, idx: usize, dst: usize, n: usize);
62+
fn vortex_filter_8byte(src: usize, idx: usize, dst: usize, n: usize);
63+
}
64+
65+
/// SIMD gather for the filter-by-indices path. Returns `None` for unsupported
66+
/// element sizes so the caller falls back to scalar.
67+
pub(super) fn filter_by_indices_mojo<T: Copy>(
68+
slice: &[T],
69+
indices: &[usize],
70+
) -> Option<Buffer<T>> {
71+
let kernel: unsafe extern "C" fn(usize, usize, usize, usize) = match size_of::<T>() {
72+
1 => vortex_filter_1byte,
73+
2 => vortex_filter_2byte,
74+
4 => vortex_filter_4byte,
75+
8 => vortex_filter_8byte,
76+
_ => return None,
77+
};
78+
79+
let len = indices.len();
80+
let mut buffer = BufferMut::<T>::with_capacity(len);
81+
let dst = buffer.spare_capacity_mut().as_mut_ptr().cast::<T>();
82+
83+
// SAFETY: The Mojo kernel reads `len` indices from `indices`, gathers from
84+
// `slice`, and writes `len` elements to `dst`. All pointers are valid.
85+
unsafe {
86+
kernel(
87+
slice.as_ptr() as usize,
88+
indices.as_ptr() as usize,
89+
dst as usize,
90+
len,
91+
);
92+
buffer.set_len(len);
93+
}
94+
95+
Some(buffer.freeze())
96+
}
97+
}
98+
4399
/// Filter a slice by a set of strictly increasing `(start, end)` ranges.
44100
fn filter_slice_by_slices<T: Copy>(slice: &[T], slices: &[(usize, usize)]) -> Buffer<T> {
45101
let output_len: usize = slices.iter().map(|(start, end)| end - start).sum();

0 commit comments

Comments
 (0)