|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright the Vortex contributors |
| 3 | + |
| 4 | +"""Mojo SIMD run-end decode kernels. |
| 5 | +
|
| 6 | +Provides 8 exports: 4 with u32 ends and 4 with u64 ends, for {1,2,4,8}-byte |
| 7 | +value widths. Each uses a 4x-unrolled SIMD broadcast fill loop. |
| 8 | +""" |
| 9 | + |
| 10 | +from std.memory import UnsafePointer |
| 11 | + |
# SIMD lane counts per value width. Each count is 32 bytes (one 256-bit
# register) divided by the value size in bytes.
comptime W1 = 32  # 1-byte values
comptime W2 = 16  # 2-byte values
comptime W4 = 8  # 4-byte values
comptime W8 = 4  # 8-byte values
| 17 | + |
| 18 | + |
fn _runend_decode[VT: DType, ET: DType, W: Int](
    ends_addr: Int,
    vals_addr: Int,
    dst_addr: Int,
    n_runs: Int,
    out_len: Int,
):
    """Decode run-end encoded data by broadcast-filling each run.

    Run `i` covers output positions `[ends[i - 1], ends[i])` (starting at 0
    for the first run) and is filled with `vals[i]`. Ends larger than
    `out_len` are clamped to `out_len`. Output positions past the last end
    are left untouched.

    Parameters:
        VT: Element type of the values and of the output.
        ET: Element type of the run-end positions.
        W: Number of `VT` lanes written per SIMD store.

    Args:
        ends_addr: Address of `n_runs` exclusive end positions of type `ET`,
            expected to be monotonically increasing.
        vals_addr: Address of `n_runs` values of type `VT`, one per run.
        dst_addr: Address of the output buffer, with room for `out_len`
            elements of type `VT`.
        n_runs: Number of runs.
        out_len: Number of output elements the kernel may write.
    """
    # The anchors exist only so `type_of` can name a concrete pointer type
    # that is then rebuilt from the raw addresses.
    var _anchor_v: Scalar[VT] = 0
    comptime VP = type_of(UnsafePointer(to=_anchor_v))
    var vals = VP(unsafe_from_address=vals_addr)
    var dst = VP(unsafe_from_address=dst_addr)

    var _anchor_e: Scalar[ET] = 0
    comptime EP = type_of(UnsafePointer(to=_anchor_e))
    var ends = EP(unsafe_from_address=ends_addr)

    var pos = 0
    for run in range(n_runs):
        # Output is full: the remaining runs cannot contribute anything.
        if pos >= out_len:
            break

        var end = Int((ends + run).load())
        if end > out_len:
            end = out_len
        # Skip empty runs and any end that lands before `pos` (out-of-order
        # input, or presumably a u64 end that does not fit in `Int`).
        # Letting `pos` move backwards could make a later run store in front
        # of `dst`.
        if end <= pos:
            continue

        var val = (vals + run).load()
        var splat = SIMD[VT, W](val)
        var out = dst + pos

        # Number of elements to fill for this run
        var run_len = end - pos
        var filled = 0

        # 4x unrolled SIMD broadcast fill
        var run4 = (run_len // (4 * W)) * (4 * W)
        while filled < run4:
            (out + filled).store(splat)
            (out + filled + W).store(splat)
            (out + filled + 2 * W).store(splat)
            (out + filled + 3 * W).store(splat)
            filled += 4 * W

        # Single-vector fill, so runs shorter than 4 * W still use SIMD stores
        var run1 = (run_len // W) * W
        while filled < run1:
            (out + filled).store(splat)
            filled += W

        # Scalar remainder (fewer than W elements)
        while filled < run_len:
            (out + filled).store(val)
            filled += 1

        pos = end
| 68 | + |
| 69 | + |
| 70 | +# =========================================================================== |
| 71 | +# Exports with u32 ends |
| 72 | +# =========================================================================== |
| 73 | + |
@export("vortex_runend_decode_1byte")
fn runend_decode_1byte(ends: Int, vals: Int, dst: Int, n_runs: Int, out_len: Int):
    """Run-end decode 1-byte values with u32 run ends.

    Args:
        ends: Address of the `n_runs` run-end positions (uint32).
        vals: Address of the `n_runs` run values (uint8).
        dst: Address of the output buffer (uint8).
        n_runs: Number of runs.
        out_len: Number of output elements.
    """
    _runend_decode[VT = DType.uint8, ET = DType.uint32, W=W1](
        ends_addr=ends,
        vals_addr=vals,
        dst_addr=dst,
        n_runs=n_runs,
        out_len=out_len,
    )
| 77 | + |
@export("vortex_runend_decode_2byte")
fn runend_decode_2byte(ends: Int, vals: Int, dst: Int, n_runs: Int, out_len: Int):
    """Run-end decode 2-byte values with u32 run ends.

    Args:
        ends: Address of the `n_runs` run-end positions (uint32).
        vals: Address of the `n_runs` run values (uint16).
        dst: Address of the output buffer (uint16).
        n_runs: Number of runs.
        out_len: Number of output elements.
    """
    _runend_decode[VT = DType.uint16, ET = DType.uint32, W=W2](
        ends_addr=ends,
        vals_addr=vals,
        dst_addr=dst,
        n_runs=n_runs,
        out_len=out_len,
    )
| 81 | + |
@export("vortex_runend_decode_4byte")
fn runend_decode_4byte(ends: Int, vals: Int, dst: Int, n_runs: Int, out_len: Int):
    """Run-end decode 4-byte values with u32 run ends.

    Args:
        ends: Address of the `n_runs` run-end positions (uint32).
        vals: Address of the `n_runs` run values (uint32).
        dst: Address of the output buffer (uint32).
        n_runs: Number of runs.
        out_len: Number of output elements.
    """
    _runend_decode[VT = DType.uint32, ET = DType.uint32, W=W4](
        ends_addr=ends,
        vals_addr=vals,
        dst_addr=dst,
        n_runs=n_runs,
        out_len=out_len,
    )
| 85 | + |
@export("vortex_runend_decode_8byte")
fn runend_decode_8byte(ends: Int, vals: Int, dst: Int, n_runs: Int, out_len: Int):
    """Run-end decode 8-byte values with u32 run ends.

    Args:
        ends: Address of the `n_runs` run-end positions (uint32).
        vals: Address of the `n_runs` run values (uint64).
        dst: Address of the output buffer (uint64).
        n_runs: Number of runs.
        out_len: Number of output elements.
    """
    _runend_decode[VT = DType.uint64, ET = DType.uint32, W=W8](
        ends_addr=ends,
        vals_addr=vals,
        dst_addr=dst,
        n_runs=n_runs,
        out_len=out_len,
    )
| 89 | + |
| 90 | +# =========================================================================== |
| 91 | +# Exports with u64 ends |
| 92 | +# =========================================================================== |
| 93 | + |
@export("vortex_runend_decode_1byte_u64ends")
fn runend_decode_1byte_u64ends(ends: Int, vals: Int, dst: Int, n_runs: Int, out_len: Int):
    """Run-end decode 1-byte values with u64 run ends.

    Args:
        ends: Address of the `n_runs` run-end positions (uint64).
        vals: Address of the `n_runs` run values (uint8).
        dst: Address of the output buffer (uint8).
        n_runs: Number of runs.
        out_len: Number of output elements.
    """
    _runend_decode[VT = DType.uint8, ET = DType.uint64, W=W1](
        ends_addr=ends,
        vals_addr=vals,
        dst_addr=dst,
        n_runs=n_runs,
        out_len=out_len,
    )
| 97 | + |
@export("vortex_runend_decode_2byte_u64ends")
fn runend_decode_2byte_u64ends(ends: Int, vals: Int, dst: Int, n_runs: Int, out_len: Int):
    """Run-end decode 2-byte values with u64 run ends.

    Args:
        ends: Address of the `n_runs` run-end positions (uint64).
        vals: Address of the `n_runs` run values (uint16).
        dst: Address of the output buffer (uint16).
        n_runs: Number of runs.
        out_len: Number of output elements.
    """
    _runend_decode[VT = DType.uint16, ET = DType.uint64, W=W2](
        ends_addr=ends,
        vals_addr=vals,
        dst_addr=dst,
        n_runs=n_runs,
        out_len=out_len,
    )
| 101 | + |
@export("vortex_runend_decode_4byte_u64ends")
fn runend_decode_4byte_u64ends(ends: Int, vals: Int, dst: Int, n_runs: Int, out_len: Int):
    """Run-end decode 4-byte values with u64 run ends.

    Args:
        ends: Address of the `n_runs` run-end positions (uint64).
        vals: Address of the `n_runs` run values (uint32).
        dst: Address of the output buffer (uint32).
        n_runs: Number of runs.
        out_len: Number of output elements.
    """
    _runend_decode[VT = DType.uint32, ET = DType.uint64, W=W4](
        ends_addr=ends,
        vals_addr=vals,
        dst_addr=dst,
        n_runs=n_runs,
        out_len=out_len,
    )
| 105 | + |
@export("vortex_runend_decode_8byte_u64ends")
fn runend_decode_8byte_u64ends(ends: Int, vals: Int, dst: Int, n_runs: Int, out_len: Int):
    """Run-end decode 8-byte values with u64 run ends.

    Args:
        ends: Address of the `n_runs` run-end positions (uint64).
        vals: Address of the `n_runs` run values (uint64).
        dst: Address of the output buffer (uint64).
        n_runs: Number of runs.
        out_len: Number of output elements.
    """
    _runend_decode[VT = DType.uint64, ET = DType.uint64, W=W8](
        ends_addr=ends,
        vals_addr=vals,
        dst_addr=dst,
        n_runs=n_runs,
        out_len=out_len,
    )
0 commit comments