Skip to content

Commit a623a8c

Browse files
author
Aegis-AI
committed
feat(turboquant): implement turbo_encode_k/v CPU encode path
Adds the missing mlx::core::fast::turbo_encode_k() and turbo_encode_v() functions that the C API stub was a placeholder for.

Algorithm (from turbo_quant.h):
- turbo_encode_k: 3-bit PolarQuant (WHT rotation + Lloyd-Max centroids) + 1-bit QJL residual — K cache, 68 bytes/token
- turbo_encode_v: 3-bit PolarQuant only — V cache, 50 bytes/token

Buffer layout per token:
- K: indices[48] | qjl_signs[16] | norm_fp16[2] | rnorm_fp16[2]
- V: indices[48] | norm_fp16[2]

Layout matches the Metal decompression path in sdpa_vector.h, which already implements turbo_dequant_k/v for on-the-fly decode during SDPA. The encode path is CPU-side (eval + iterate), which is appropriate since compression runs once per appended KV token, not in the hot forward pass.

Files changed:
- fast.h — declare turbo_encode_k/v in namespace mlx::core::fast
- fast.cpp — implement using turbo_quant.h primitives
- mlx-c fast.cpp — replace runtime_error stub with real call
1 parent a60235f commit a623a8c

3 files changed

Lines changed: 136 additions & 3 deletions

File tree

LocalPackages/mlx-swift/Source/Cmlx/mlx-c/mlx/c/fast.cpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -789,9 +789,24 @@ extern "C" int mlx_fast_turbo_encode(
789789
int k_bits,
790790
const mlx_stream s) {
791791
try {
792-
// TurboQuant C++ core not yet implemented — stub returns error.
793-
// Will be wired up when mlx::core::fast::turbo_encode() is available.
794-
throw std::runtime_error("turbo_encode: not yet implemented in this build");
792+
// Encode K: 3-bit PolarQuant + 1-bit QJL, packed into [.., 68] uint8
793+
mlx_array_set_(
794+
*res_polar_k,
795+
mlx::core::fast::turbo_encode_k(
796+
mlx_array_get_(keys),
797+
mlx_stream_get_(s)));
798+
799+
// Encode V: 3-bit PolarQuant only, packed into [.., 50] uint8
800+
mlx_array_set_(
801+
*res_polar_v,
802+
mlx::core::fast::turbo_encode_v(
803+
mlx_array_get_(values),
804+
mlx_stream_get_(s)));
805+
806+
// Metadata is packed inline — residual arrays are unused but must be
807+
// valid (non-null ctx) so the Swift bridge can call mlx_array_free on them.
808+
*res_residual_k = mlx_array_new();
809+
*res_residual_v = mlx_array_new();
795810
} catch (std::exception& e) {
796811
mlx_error(e.what());
797812
return 1;

LocalPackages/mlx-swift/Source/Cmlx/mlx/mlx/fast.cpp

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -933,9 +933,109 @@ std::vector<Shape> Quantize::output_shapes(const std::vector<array>& inputs) {
933933
}
934934
}
935935

936+
936937
bool ConvertFP8::is_equivalent(const Primitive& other) const {
937938
const ConvertFP8& a_other = static_cast<const ConvertFP8&>(other);
938939
return to_fp8_ == a_other.to_fp8_;
939940
}
940941

941942
} // namespace mlx::core::fast
943+
944+
// ---------------------------------------------------------------------------
945+
// TurboQuant KV-cache compression — CPU encode path
946+
// ---------------------------------------------------------------------------
947+
// turbo_quant.h opens its own namespace mlx::core::fast so it must be
948+
// included OUTSIDE our namespace block to avoid double-nesting.
949+
//
950+
// K record layout (68 bytes per head_dim=128 vector):
951+
// [ 0.. 47] indices[48] — 3-bit PolarQuant indices, LSB-packed
952+
// [ 48.. 63] qjl_signs[16] — 1-bit QJL sign bits
953+
// [ 64.. 65] norm_fp16[2] — original L2 norm as fp16
954+
// [ 66.. 67] rnorm_fp16[2] — residual L2 norm as fp16
955+
//
956+
// V record layout (50 bytes per head_dim=128 vector):
957+
// [ 0.. 47] indices[48] — 3-bit PolarQuant indices, LSB-packed
958+
// [ 48.. 49] norm_fp16[2] — corrected L2 norm as fp16
959+
960+
#include "mlx/fast/turbo_quant.h" // brings in mlx::core::fast types
961+
962+
namespace {
// Size in bytes of one packed record. Both values have to stay in sync with
// the Metal decompression kernel in sdpa_vector.h.
constexpr int TURBO_K_RECORD = 68; // one encoded K vector
constexpr int TURBO_V_RECORD = 50; // one encoded V vector
} // namespace
967+
968+
namespace mlx::core::fast {
969+
970+
// Helper: materialise the array as float32 on the CPU and return a raw ptr.
971+
// The returned array object must stay alive for the duration of the loop.
972+
static std::pair<mlx::core::array, const float*>
973+
turbo_to_f32(const mlx::core::array& x, mlx::core::StreamOrDevice s) {
974+
auto x_f32 = mlx::core::astype(x, mlx::core::float32, s);
975+
mlx::core::eval(x_f32);
976+
return {x_f32, x_f32.data<float>()};
977+
}
978+
979+
array turbo_encode_k(const array& keys, StreamOrDevice s_) {
980+
auto s = to_stream(s_);
981+
982+
if (keys.shape(-1) != ::mlx::core::fast::TURBO_D) {
983+
throw std::invalid_argument(
984+
"[turbo_encode_k] last dim (head_dim) must be " +
985+
std::to_string(::mlx::core::fast::TURBO_D) + " but got " +
986+
std::to_string(keys.shape(-1)));
987+
}
988+
989+
auto [keys_f32, src] = turbo_to_f32(keys, s);
990+
const int N = static_cast<int>(keys_f32.size() / ::mlx::core::fast::TURBO_D);
991+
992+
std::vector<uint8_t> buf(static_cast<size_t>(N) * TURBO_K_RECORD, 0u);
993+
994+
for (int i = 0; i < N; ++i) {
995+
::mlx::core::fast::TurboQuantK rec =
996+
::mlx::core::fast::turbo_quantize_k(
997+
src + i * ::mlx::core::fast::TURBO_D,
998+
::mlx::core::fast::TURBO_D);
999+
uint8_t* dst = buf.data() + i * TURBO_K_RECORD;
1000+
std::memcpy(dst, rec.indices, 48);
1001+
std::memcpy(dst + 48, rec.qjl_signs, 16);
1002+
std::memcpy(dst + 64, &rec.norm_fp16, 2);
1003+
std::memcpy(dst + 66, &rec.rnorm_fp16, 2);
1004+
}
1005+
1006+
Shape out_shape = keys.shape();
1007+
out_shape.back() = TURBO_K_RECORD;
1008+
return array(buf.data(), out_shape, uint8);
1009+
}
1010+
1011+
array turbo_encode_v(const array& values, StreamOrDevice s_) {
1012+
auto s = to_stream(s_);
1013+
1014+
if (values.shape(-1) != ::mlx::core::fast::TURBO_D) {
1015+
throw std::invalid_argument(
1016+
"[turbo_encode_v] last dim (head_dim) must be " +
1017+
std::to_string(::mlx::core::fast::TURBO_D) + " but got " +
1018+
std::to_string(values.shape(-1)));
1019+
}
1020+
1021+
auto [vals_f32, src] = turbo_to_f32(values, s);
1022+
const int N = static_cast<int>(vals_f32.size() / ::mlx::core::fast::TURBO_D);
1023+
1024+
std::vector<uint8_t> buf(static_cast<size_t>(N) * TURBO_V_RECORD, 0u);
1025+
1026+
for (int i = 0; i < N; ++i) {
1027+
::mlx::core::fast::TurboQuantV rec =
1028+
::mlx::core::fast::turbo_quantize_v(
1029+
src + i * ::mlx::core::fast::TURBO_D,
1030+
::mlx::core::fast::TURBO_D);
1031+
uint8_t* dst = buf.data() + i * TURBO_V_RECORD;
1032+
std::memcpy(dst, rec.indices, 48);
1033+
std::memcpy(dst + 48, &rec.norm_fp16, 2);
1034+
}
1035+
1036+
Shape out_shape = values.shape();
1037+
out_shape.back() = TURBO_V_RECORD;
1038+
return array(buf.data(), out_shape, uint8);
1039+
}
1040+
1041+
} // namespace mlx::core::fast

LocalPackages/mlx-swift/Source/Cmlx/mlx/mlx/fast.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,4 +100,22 @@ MLX_API std::vector<array> precompiled_cuda_kernel(
100100
bool ensure_row_contiguous = false,
101101
StreamOrDevice s = {});
102102

103+
/**
104+
* Compress a K-cache tensor to TurboQuant format (3-bit PolarQuant + 1-bit QJL).
105+
*
106+
* keys: [batch, heads, seq, 128] — fp16 / bf16 / fp32
107+
* returns: uint8 array with the same leading dims and last dim = 68
108+
* Layout per token: indices[48] | qjl_signs[16] | norm_fp16[2] | rnorm_fp16[2]
109+
*/
110+
MLX_API array turbo_encode_k(const array& keys, StreamOrDevice s = {});
111+
112+
/**
113+
* Compress a V-cache tensor to TurboQuant format (3-bit PolarQuant only).
114+
*
115+
* values: [batch, heads, seq, 128] — fp16 / bf16 / fp32
116+
* returns: uint8 array with the same leading dims and last dim = 50
117+
* Layout per token: indices[48] | norm_fp16[2]
118+
*/
119+
MLX_API array turbo_encode_v(const array& values, StreamOrDevice s = {});
120+
103121
} // namespace mlx::core::fast

0 commit comments

Comments
 (0)