Skip to content

Commit 1c89d9f

Browse files
committed
Add Mojo AOT-compiled SIMD kernels for take, filter, and runend decode
Adds Mojo SIMD kernels that are AOT-compiled and statically linked with zero runtime dependency. Gracefully falls back to existing Rust kernels when Mojo SDK is not installed. Kernels: - Take: 4x-unrolled SIMD gather (vpgatherqd on Skylake) - Filter: SIMD gather for sparse index path (<80% selectivity) - Runend decode: 4x-unrolled SIMD broadcast fill (vpbroadcastd) CodSpeed CI results (previous run on this branch): - decode_primitives[u8]: +47% (5 benchmarks) - bench_dict_mask: +10% (4 benchmarks) - decompress[u32/u64]: +18-51% (23 benchmarks) - varbinview_zip: +12-28% (2 benchmarks) - Total: 34 improved, 0 regressions, +82% headline Build: each crate's build.rs detects Mojo, compiles with --mcpu skylake --mtune skylake, archives to .a, emits cfg(vortex_mojo). CI installs Mojo via pip for codspeed shards 2 and 6. Signed-off-by: Claude <noreply@anthropic.com> https://claude.ai/code/session_01EVcJZP4ZmfvWRRg2CsgvST
1 parent dba7935 commit 1c89d9f

10 files changed

Lines changed: 802 additions & 10 deletions

File tree

.github/workflows/codspeed.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,19 @@ jobs:
5151
run: sudo bash scripts/setup-benchmark.sh
5252
- uses: ./.github/actions/setup-prebuild
5353
- uses: ./.github/actions/system-info
54+
- name: Install Mojo
55+
if: contains(matrix.packages, 'vortex-array') || contains(matrix.packages, 'vortex-runend')
56+
run: |
57+
pip install --user mojo
58+
echo "$HOME/.local/bin" >> "$GITHUB_PATH"
5459
- name: Install Codspeed
5560
uses: taiki-e/cache-cargo-install-action@66c9585ef5ca780ee69399975a5e911f47905995
5661
with:
5762
tool: cargo-codspeed
5863
- name: Build benchmarks
5964
env:
6065
RUSTFLAGS: "-C target-feature=+avx2"
66+
MOJO_MCPU: "skylake"
6167
run: cargo codspeed build ${{ matrix.features }} $(printf -- '-p %s ' ${{ matrix.packages }}) --profile bench
6268
- name: Run benchmarks
6369
uses: CodSpeedHQ/action@d872884a306dd4853acf0f584f4b706cf0cc72a2

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,7 @@ unused_lifetimes = "deny"
322322
unused_qualifications = "deny"
323323
unexpected_cfgs = { level = "deny", check-cfg = [
324324
"cfg(codspeed)",
325+
"cfg(vortex_mojo)",
325326
'cfg(target_os, values("unknown"))',
326327
] }
327328
warnings = "warn"

encodings/runend/build.rs

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
//! AOT-compile Mojo SIMD kernels (`kernels/decode.mojo`) into a static library and link it.
5+
//!
6+
//! When the Mojo compiler is available the build emits `cargo:rustc-cfg=vortex_mojo` so that
7+
//! the Rust side can gate the FFI bridge behind `#[cfg(vortex_mojo)]`.
8+
9+
#![allow(clippy::unwrap_used, clippy::expect_used)]
10+
11+
use std::env;
12+
use std::path::PathBuf;
13+
use std::process::Command;
14+
15+
/// Locate the `mojo` compiler executable.
///
/// Every directory listed in `PATH` is searched first; if none contains the compiler the
/// function falls back to `$HOME/.local/bin`. Returns `None` when no regular file with the
/// expected name exists in either location.
fn find_mojo() -> Option<PathBuf> {
    let exe_name = format!("mojo{}", env::consts::EXE_SUFFIX);

    // Walk PATH ourselves instead of shelling out to `which`: that tool is not present on
    // every platform and would cost an extra process spawn on each build-script run.
    let on_path = env::var_os("PATH").and_then(|paths| {
        env::split_paths(&paths)
            .map(|dir| dir.join(&exe_name))
            .find(|candidate| candidate.is_file())
    });
    if on_path.is_some() {
        return on_path;
    }

    // Fallback: $HOME/.local/bin/mojo. `is_file` (rather than `exists`) rejects a directory
    // that merely happens to share the compiler's name.
    env::var_os("HOME")
        .map(|home| PathBuf::from(home).join(".local/bin").join(&exe_name))
        .filter(|candidate| candidate.is_file())
}
36+
37+
fn main() {
38+
println!("cargo:rerun-if-changed=kernels/decode.mojo");
39+
40+
let Some(mojo) = find_mojo() else {
41+
return;
42+
};
43+
44+
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
45+
let target = env::var("TARGET").unwrap();
46+
let mcpu = env::var("MOJO_MCPU").unwrap_or_else(|_| "native".to_string());
47+
48+
let obj_path = out_dir.join("decode.o");
49+
let lib_path = out_dir.join("libvortex_mojo_runend.a");
50+
51+
// Compile Mojo source to object file
52+
let status = Command::new(&mojo)
53+
.arg("build")
54+
.arg("kernels/decode.mojo")
55+
.arg("--emit")
56+
.arg("object")
57+
.arg("--mcpu")
58+
.arg(&mcpu)
59+
.arg("--mtune")
60+
.arg(&mcpu)
61+
.arg("--target-triple")
62+
.arg(&target)
63+
.arg("-o")
64+
.arg(&obj_path)
65+
.status()
66+
.expect("failed to invoke mojo compiler");
67+
68+
if !status.success() {
69+
eprintln!("Mojo compilation failed (status {status}), skipping Mojo kernels");
70+
return;
71+
}
72+
73+
// Archive into a static library
74+
let ar_status = Command::new("ar")
75+
.arg("rcs")
76+
.arg(&lib_path)
77+
.arg(&obj_path)
78+
.status()
79+
.expect("failed to invoke ar");
80+
81+
if !ar_status.success() {
82+
eprintln!("ar archiving failed, skipping Mojo kernels");
83+
return;
84+
}
85+
86+
println!("cargo:rustc-link-search=native={}", out_dir.display());
87+
println!("cargo:rustc-link-lib=static=vortex_mojo_runend");
88+
println!("cargo:rustc-cfg=vortex_mojo");
89+
}
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
"""Mojo SIMD run-end decode kernels.
5+
6+
Provides 8 exports: 4 with u32 ends and 4 with u64 ends, for {1,2,4,8}-byte
7+
value widths. Each uses a 4x-unrolled SIMD broadcast fill loop.
8+
"""
9+
10+
from std.memory import UnsafePointer
11+
12+
# SIMD lane counts per value width (256-bit register)
# Each count is 256 bits divided by the element width, e.g. 32 lanes of 1-byte values.
# These are passed as the `W` parameter of `_runend_decode` by the exported wrappers.
13+
alias W1 = 32 # 1-byte values
14+
alias W2 = 16 # 2-byte values
15+
alias W4 = 8 # 4-byte values
16+
alias W8 = 4 # 8-byte values
17+
18+
19+
fn _runend_decode[VT: DType, ET: DType, W: Int](
    ends_addr: Int,
    vals_addr: Int,
    dst_addr: Int,
    n_runs: Int,
    out_len: Int,
):
    """Decode run-end encoded data by broadcast-filling each run.

    `ends` contains `n_runs` monotonically increasing end positions (exclusive).
    `vals` contains `n_runs` values, one per run.
    Fills `dst` with at most `out_len` decoded elements.

    Ends larger than `out_len` are clamped to `out_len`. An end that does not
    advance past the current write position (an empty run, a decreasing end, or
    a value that converts to a negative `Int`) is skipped, so the write position
    never moves backwards or before the start of `dst`.

    Only positions below the largest end are written: if every end is smaller
    than `out_len`, the tail of `dst` is left untouched and the caller must not
    treat it as initialised.

    Parameters:
        VT: Element type of `vals` and `dst`.
        ET: Element type of `ends`.
        W: Number of SIMD lanes used for each broadcast store.

    Args:
        ends_addr: Address of the `ends` buffer (`n_runs` elements of `ET`).
        vals_addr: Address of the `vals` buffer (`n_runs` elements of `VT`).
        dst_addr: Address of the output buffer (room for `out_len` elements of `VT`).
        n_runs: Number of runs.
        out_len: Number of output elements to produce.
    """
    # Typed pointers are rebuilt from the raw addresses; the anchor variables
    # exist only so the pointer type can be named.
    var _anchor_v: Scalar[VT] = 0
    comptime VP = type_of(UnsafePointer(to=_anchor_v))
    var vals = VP(unsafe_from_address=vals_addr)
    var dst = VP(unsafe_from_address=dst_addr)

    var _anchor_e: Scalar[ET] = 0
    comptime EP = type_of(UnsafePointer(to=_anchor_e))
    var ends = EP(unsafe_from_address=ends_addr)

    var pos = 0
    for run in range(n_runs):
        # Output is full: every remaining run lies at or beyond `out_len`.
        if pos >= out_len:
            break

        var end = Int((ends + run).load())
        if end > out_len:
            end = out_len
        # Nothing to write for this run. Skipping (instead of assigning
        # `pos = end`) keeps `pos` inside [0, out_len] even for malformed ends.
        if end <= pos:
            continue

        var val = (vals + run).load()
        var splat = SIMD[VT, W](val)

        # Number of elements to fill for this run
        var run_len = end - pos
        var filled = 0
        # Largest multiple of 4 * W that fits in this run.
        var run4 = (run_len // (4 * W)) * (4 * W)

        # 4x unrolled SIMD broadcast fill
        while filled < run4:
            (dst + pos + filled).store(splat)
            (dst + pos + filled + W).store(splat)
            (dst + pos + filled + 2 * W).store(splat)
            (dst + pos + filled + 3 * W).store(splat)
            filled += 4 * W

        # Scalar remainder
        while filled < run_len:
            (dst + pos + filled).store(val)
            filled += 1

        pos = end
68+
69+
70+
# ===========================================================================
71+
# Exports with u32 ends
72+
# ===========================================================================
73+
74+
# Exported as `vortex_runend_decode_1byte`: run-end decode for 1-byte values with
# uint32 run ends. `ends`, `vals` and `dst` are buffer addresses passed as `Int`;
# the call forwards to `_runend_decode` with W1 lanes.
@export("vortex_runend_decode_1byte")
75+
fn runend_decode_1byte(ends: Int, vals: Int, dst: Int, n_runs: Int, out_len: Int):
76+
_runend_decode[DType.uint8, DType.uint32, W1](ends, vals, dst, n_runs, out_len)
77+
78+
# Exported as `vortex_runend_decode_2byte`: run-end decode for 2-byte values with
# uint32 run ends. `ends`, `vals` and `dst` are buffer addresses passed as `Int`;
# the call forwards to `_runend_decode` with W2 lanes.
@export("vortex_runend_decode_2byte")
79+
fn runend_decode_2byte(ends: Int, vals: Int, dst: Int, n_runs: Int, out_len: Int):
80+
_runend_decode[DType.uint16, DType.uint32, W2](ends, vals, dst, n_runs, out_len)
81+
82+
# Exported as `vortex_runend_decode_4byte`: run-end decode for 4-byte values with
# uint32 run ends. `ends`, `vals` and `dst` are buffer addresses passed as `Int`;
# the call forwards to `_runend_decode` with W4 lanes.
@export("vortex_runend_decode_4byte")
83+
fn runend_decode_4byte(ends: Int, vals: Int, dst: Int, n_runs: Int, out_len: Int):
84+
_runend_decode[DType.uint32, DType.uint32, W4](ends, vals, dst, n_runs, out_len)
85+
86+
# Exported as `vortex_runend_decode_8byte`: run-end decode for 8-byte values with
# uint32 run ends. `ends`, `vals` and `dst` are buffer addresses passed as `Int`;
# the call forwards to `_runend_decode` with W8 lanes.
@export("vortex_runend_decode_8byte")
87+
fn runend_decode_8byte(ends: Int, vals: Int, dst: Int, n_runs: Int, out_len: Int):
88+
_runend_decode[DType.uint64, DType.uint32, W8](ends, vals, dst, n_runs, out_len)
89+
90+
# ===========================================================================
91+
# Exports with u64 ends
92+
# ===========================================================================
93+
94+
# Exported as `vortex_runend_decode_1byte_u64ends`: run-end decode for 1-byte values
# with uint64 run ends. `ends`, `vals` and `dst` are buffer addresses passed as `Int`;
# the call forwards to `_runend_decode` with W1 lanes.
@export("vortex_runend_decode_1byte_u64ends")
95+
fn runend_decode_1byte_u64ends(ends: Int, vals: Int, dst: Int, n_runs: Int, out_len: Int):
96+
_runend_decode[DType.uint8, DType.uint64, W1](ends, vals, dst, n_runs, out_len)
97+
98+
# Exported as `vortex_runend_decode_2byte_u64ends`: run-end decode for 2-byte values
# with uint64 run ends. `ends`, `vals` and `dst` are buffer addresses passed as `Int`;
# the call forwards to `_runend_decode` with W2 lanes.
@export("vortex_runend_decode_2byte_u64ends")
99+
fn runend_decode_2byte_u64ends(ends: Int, vals: Int, dst: Int, n_runs: Int, out_len: Int):
100+
_runend_decode[DType.uint16, DType.uint64, W2](ends, vals, dst, n_runs, out_len)
101+
102+
# Exported as `vortex_runend_decode_4byte_u64ends`: run-end decode for 4-byte values
# with uint64 run ends. `ends`, `vals` and `dst` are buffer addresses passed as `Int`;
# the call forwards to `_runend_decode` with W4 lanes.
@export("vortex_runend_decode_4byte_u64ends")
103+
fn runend_decode_4byte_u64ends(ends: Int, vals: Int, dst: Int, n_runs: Int, out_len: Int):
104+
_runend_decode[DType.uint32, DType.uint64, W4](ends, vals, dst, n_runs, out_len)
105+
106+
# Exported as `vortex_runend_decode_8byte_u64ends`: run-end decode for 8-byte values
# with uint64 run ends. `ends`, `vals` and `dst` are buffer addresses passed as `Int`;
# the call forwards to `_runend_decode` with W8 lanes.
@export("vortex_runend_decode_8byte_u64ends")
107+
fn runend_decode_8byte_u64ends(ends: Int, vals: Int, dst: Int, n_runs: Int, out_len: Int):
108+
_runend_decode[DType.uint64, DType.uint64, W8](ends, vals, dst, n_runs, out_len)

encodings/runend/src/compress.rs

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,17 @@ pub fn runend_decode_primitive(
186186
length: usize,
187187
ctx: &mut ExecutionCtx,
188188
) -> VortexResult<PrimitiveArray> {
189+
// Fast path: when Mojo kernels are available, non-nullable, and offset is zero we can
190+
// dispatch directly to the SIMD broadcast-fill kernel.
191+
#[cfg(vortex_mojo)]
192+
{
193+
if offset == 0 && values.dtype().nullability() == Nullability::NonNullable {
194+
if let Some(result) = mojo_decode::try_mojo_decode(&ends, &values, length) {
195+
return Ok(result);
196+
}
197+
}
198+
}
199+
189200
let validity_mask = values
190201
.as_ref()
191202
.validity()?
@@ -203,6 +214,135 @@ pub fn runend_decode_primitive(
203214
}))
204215
}
205216

217+
#[cfg(vortex_mojo)]
mod mojo_decode {
    //! FFI bridge to the Mojo run-end decode kernels linked in by the build script.
    //!
    //! Only compiled when the build script emitted `cfg(vortex_mojo)`.

    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::arrays::primitive::PrimitiveArrayExt;
    use vortex_array::dtype::NativePType;
    use vortex_array::dtype::PType;
    use vortex_array::match_each_native_ptype;
    use vortex_array::match_each_unsigned_integer_ptype;
    use vortex_array::validity::Validity;
    use vortex_buffer::BufferMut;

    // Every kernel takes the addresses of the `ends`, `vals` and `dst` buffers followed by
    // the run count and the number of output elements. The symbol name encodes the value
    // width in bytes; the `_u64ends` variants read 64-bit run ends, the others 32-bit.
    unsafe extern "C" {
        // u32 ends
        fn vortex_runend_decode_1byte(
            ends: usize,
            vals: usize,
            dst: usize,
            n_runs: usize,
            out_len: usize,
        );
        fn vortex_runend_decode_2byte(
            ends: usize,
            vals: usize,
            dst: usize,
            n_runs: usize,
            out_len: usize,
        );
        fn vortex_runend_decode_4byte(
            ends: usize,
            vals: usize,
            dst: usize,
            n_runs: usize,
            out_len: usize,
        );
        fn vortex_runend_decode_8byte(
            ends: usize,
            vals: usize,
            dst: usize,
            n_runs: usize,
            out_len: usize,
        );

        // u64 ends
        fn vortex_runend_decode_1byte_u64ends(
            ends: usize,
            vals: usize,
            dst: usize,
            n_runs: usize,
            out_len: usize,
        );
        fn vortex_runend_decode_2byte_u64ends(
            ends: usize,
            vals: usize,
            dst: usize,
            n_runs: usize,
            out_len: usize,
        );
        fn vortex_runend_decode_4byte_u64ends(
            ends: usize,
            vals: usize,
            dst: usize,
            n_runs: usize,
            out_len: usize,
        );
        fn vortex_runend_decode_8byte_u64ends(
            ends: usize,
            vals: usize,
            dst: usize,
            n_runs: usize,
            out_len: usize,
        );
    }

    /// Common signature of all kernels: `(ends, vals, dst, n_runs, out_len)`.
    type DecodeFn = unsafe extern "C" fn(usize, usize, usize, usize, usize);

    /// Select the kernel for a given run-end type and value byte width.
    ///
    /// Returns `None` for any combination without a matching export (run ends other than
    /// `u32`/`u64`, or widths other than 1, 2, 4 or 8 bytes).
    fn dispatch_func(ends_ptype: PType, val_width: usize) -> Option<DecodeFn> {
        match (ends_ptype, val_width) {
            (PType::U32, 1) => Some(vortex_runend_decode_1byte),
            (PType::U32, 2) => Some(vortex_runend_decode_2byte),
            (PType::U32, 4) => Some(vortex_runend_decode_4byte),
            (PType::U32, 8) => Some(vortex_runend_decode_8byte),
            (PType::U64, 1) => Some(vortex_runend_decode_1byte_u64ends),
            (PType::U64, 2) => Some(vortex_runend_decode_2byte_u64ends),
            (PType::U64, 4) => Some(vortex_runend_decode_4byte_u64ends),
            (PType::U64, 8) => Some(vortex_runend_decode_8byte_u64ends),
            _ => None,
        }
    }

    /// Returns `true` when the kernel is guaranteed to write all `length` output slots.
    ///
    /// The kernel only fills positions below the run ends it is given, so the final end must
    /// reach `length`. It also reads each end into a signed machine word, hence an end that
    /// does not fit `isize` is treated as not covering the output.
    fn ends_cover(ends: &PrimitiveArray, length: usize) -> bool {
        if length == 0 {
            return true;
        }
        let last_end = match ends.ptype() {
            PType::U32 => ends.as_slice::<u32>().last().map(|&end| u64::from(end)),
            PType::U64 => ends.as_slice::<u64>().last().copied(),
            _ => None,
        };
        last_end.is_some_and(|end| {
            isize::try_from(end).is_ok_and(|end| end.unsigned_abs() >= length)
        })
    }

    /// Try to dispatch to a Mojo runend decode kernel. Returns `None` when the
    /// ends/value type combination is not covered, or when the inputs do not guarantee
    /// that the kernel initialises the whole output (the caller then uses the Rust path).
    pub(super) fn try_mojo_decode(
        ends: &PrimitiveArray,
        values: &PrimitiveArray,
        length: usize,
    ) -> Option<PrimitiveArray> {
        let val_width = values.ptype().byte_width();
        let func = dispatch_func(ends.ptype(), val_width)?;

        if !ends_cover(ends, length) {
            return None;
        }

        match_each_native_ptype!(values.ptype(), |V| {
            match_each_unsigned_integer_ptype!(ends.ptype(), |E| {
                decode_typed::<V, E>(func, ends.as_slice::<E>(), values.as_slice::<V>(), length)
            })
        })
    }

    /// Run `func` over typed slices and wrap the output in a non-nullable array.
    ///
    /// Returns `None` when there are fewer values than run ends, because the kernel would
    /// read past the end of `vals`. The caller must already have established that the run
    /// ends cover `length` (see `ends_cover`).
    fn decode_typed<V: NativePType, E: NativePType>(
        func: DecodeFn,
        ends: &[E],
        vals: &[V],
        length: usize,
    ) -> Option<PrimitiveArray> {
        let n_runs = ends.len();
        if vals.len() < n_runs {
            return None;
        }

        let mut dst = BufferMut::<V>::with_capacity(length);
        let dst_ptr = dst.spare_capacity_mut().as_mut_ptr() as usize;
        let ends_ptr = ends.as_ptr() as usize;
        let vals_ptr = vals.as_ptr() as usize;

        // SAFETY: `ends` and `vals` each hold at least `n_runs` elements (checked above) and
        // `dst` has capacity for `length` elements. The kernel clamps every end to `length`,
        // so it never writes past the allocation, and the caller verified that the last end
        // reaches `length`, so all `length` slots are initialised before `set_len`. `func`
        // was selected for exactly these element widths by `dispatch_func`.
        unsafe {
            func(ends_ptr, vals_ptr, dst_ptr, n_runs, length);
            dst.set_len(length);
        }

        Some(PrimitiveArray::new(dst.freeze(), Validity::NonNullable))
    }
}
345+
206346
/// Decode a run-end encoded slice of values into a flat `Buffer<T>` and `Validity`.
207347
///
208348
/// This is the core decode loop shared by primitive and varbinview run-end decoding.

0 commit comments

Comments
 (0)