|
| 1 | +//! **LAB-ONLY.** D1.1 — `CodecKernelCache`: JIT kernel cache keyed by |
| 2 | +//! `CodecParams::kernel_signature()`. |
| 3 | +//! |
| 4 | +//! The structural layer of Phase 1 — independent of the underlying |
| 5 | +//! Cranelift / jitson implementation. This module defines the cache |
| 6 | +//! semantics; D1.2 (rotation primitives), D1.3 (residual composition), |
| 7 | +//! and D1.1b (actual Cranelift IR emission) plug into it. |
| 8 | +//! |
| 9 | +//! The insight this module captures: **kernel signature and sweep grid |
| 10 | +//! axis are the same object viewed from two sides** (EPIPHANIES 2026-04-20 |
| 11 | +//! "D0.3 sweep grid IS the JIT cache warmer"). Every unique |
| 12 | +//! `(subspaces, centroids, residual_depth, rotation_kind, distance, |
| 13 | +//! lane_width)` tuple maps to exactly one `kernel_signature()` — so the |
| 14 | +//! grid traversal order determines how fast the cache warms. |
| 15 | +//! |
| 16 | +//! ## Design — generic over handle type |
| 17 | +//! |
| 18 | +//! `CodecKernelCache<H>` is generic over `H: Clone` so this scaffold can |
| 19 | +//! host: |
| 20 | +//! |
| 21 | +//! - **Production:** `H = KernelHandle` from `lance-graph-contract::jit` |
| 22 | +//! (raw function pointer to Cranelift-emitted code). |
| 23 | +//! - **Stub / testing:** `H = StubKernel` (deterministic fake — what the |
| 24 | +//! kernel WOULD be, without compilation). |
| 25 | +//! - **Future variants:** e.g., a GPU-kernel handle when/if that lands. |
| 26 | +//! |
| 27 | +//! The cache itself doesn't know or care what a kernel IS — it only |
| 28 | +//! manages the `kernel_signature() → H` map with concurrent read-many / |
| 29 | +//! single-writer semantics. Per ndarray/.claude/rules/data-flow.md: |
| 30 | +//! "No `&mut self` during computation" — cache uses interior mutability. |
| 31 | +
|
| 32 | +use lance_graph_contract::cam::{CodecParams, CodecParamsError}; |
| 33 | +use std::collections::HashMap; |
| 34 | +use std::sync::RwLock; |
| 35 | + |
| 36 | +/// JIT kernel cache keyed by `CodecParams::kernel_signature()`. |
| 37 | +/// |
| 38 | +/// Generic over kernel handle type. Concurrent-safe via `RwLock`; multiple |
| 39 | +/// readers can hit cache simultaneously; exactly one writer at a time for |
| 40 | +/// insert. |
| 41 | +pub struct CodecKernelCache<H: Clone> { |
| 42 | + cache: RwLock<HashMap<u64, H>>, |
| 43 | + compile_count: RwLock<u64>, |
| 44 | + hit_count: RwLock<u64>, |
| 45 | +} |
| 46 | + |
| 47 | +impl<H: Clone> CodecKernelCache<H> { |
| 48 | + /// Create an empty cache. |
| 49 | + pub fn new() -> Self { |
| 50 | + Self { |
| 51 | + cache: RwLock::new(HashMap::new()), |
| 52 | + compile_count: RwLock::new(0), |
| 53 | + hit_count: RwLock::new(0), |
| 54 | + } |
| 55 | + } |
| 56 | + |
| 57 | + /// Get the kernel for `params`, compiling if missing. |
| 58 | + /// |
| 59 | + /// The `compile` closure runs **only on cache miss**; for the typical |
| 60 | + /// sweep where overlapping grid tuples share a kernel signature, most |
| 61 | + /// calls are zero-cost cache reads. |
| 62 | + /// |
| 63 | + /// Returns a cloned handle — the caller drives the kernel; the cache |
| 64 | + /// retains its own copy indefinitely. |
| 65 | + pub fn get_or_compile<F>(&self, params: &CodecParams, compile: F) -> H |
| 66 | + where |
| 67 | + F: FnOnce() -> H, |
| 68 | + { |
| 69 | + let sig = params.kernel_signature(); |
| 70 | + // Fast path: read-lock check for cache hit. |
| 71 | + if let Some(h) = self.cache.read().unwrap().get(&sig).cloned() { |
| 72 | + *self.hit_count.write().unwrap() += 1; |
| 73 | + return h; |
| 74 | + } |
| 75 | + // Slow path: compile + insert. Double-check inside write-lock to |
| 76 | + // prevent duplicate compilation under concurrent misses. |
| 77 | + let mut w = self.cache.write().unwrap(); |
| 78 | + if let Some(h) = w.get(&sig).cloned() { |
| 79 | + *self.hit_count.write().unwrap() += 1; |
| 80 | + return h; |
| 81 | + } |
| 82 | + let h = compile(); |
| 83 | + w.insert(sig, h.clone()); |
| 84 | + *self.compile_count.write().unwrap() += 1; |
| 85 | + h |
| 86 | + } |
| 87 | + |
| 88 | + /// Same as `get_or_compile` but with a fallible compile closure. |
| 89 | + pub fn try_get_or_compile<F>( |
| 90 | + &self, |
| 91 | + params: &CodecParams, |
| 92 | + compile: F, |
| 93 | + ) -> Result<H, CodecParamsError> |
| 94 | + where |
| 95 | + F: FnOnce() -> Result<H, CodecParamsError>, |
| 96 | + { |
| 97 | + let sig = params.kernel_signature(); |
| 98 | + if let Some(h) = self.cache.read().unwrap().get(&sig).cloned() { |
| 99 | + *self.hit_count.write().unwrap() += 1; |
| 100 | + return Ok(h); |
| 101 | + } |
| 102 | + let mut w = self.cache.write().unwrap(); |
| 103 | + if let Some(h) = w.get(&sig).cloned() { |
| 104 | + *self.hit_count.write().unwrap() += 1; |
| 105 | + return Ok(h); |
| 106 | + } |
| 107 | + let h = compile()?; |
| 108 | + w.insert(sig, h.clone()); |
| 109 | + *self.compile_count.write().unwrap() += 1; |
| 110 | + Ok(h) |
| 111 | + } |
| 112 | + |
| 113 | + /// Number of unique kernels in the cache (= unique signatures seen). |
| 114 | + pub fn len(&self) -> usize { |
| 115 | + self.cache.read().unwrap().len() |
| 116 | + } |
| 117 | + |
| 118 | + /// `true` if the cache is empty. |
| 119 | + pub fn is_empty(&self) -> bool { |
| 120 | + self.len() == 0 |
| 121 | + } |
| 122 | + |
| 123 | + /// Number of `compile()` invocations — one per unique signature. |
| 124 | + pub fn compile_count(&self) -> u64 { |
| 125 | + *self.compile_count.read().unwrap() |
| 126 | + } |
| 127 | + |
| 128 | + /// Number of cache hits (compile closure NOT invoked). |
| 129 | + pub fn hit_count(&self) -> u64 { |
| 130 | + *self.hit_count.read().unwrap() |
| 131 | + } |
| 132 | + |
| 133 | + /// Cache hit ratio: `hit_count / (hit_count + compile_count)`. |
| 134 | + /// Returns 0.0 when no calls have been made. |
| 135 | + pub fn hit_ratio(&self) -> f64 { |
| 136 | + let hits = self.hit_count() as f64; |
| 137 | + let compiles = self.compile_count() as f64; |
| 138 | + let total = hits + compiles; |
| 139 | + if total < 0.5 { 0.0 } else { hits / total } |
| 140 | + } |
| 141 | + |
| 142 | + /// Check whether a specific signature is cached without calling compile. |
| 143 | + pub fn has_signature(&self, signature: u64) -> bool { |
| 144 | + self.cache.read().unwrap().contains_key(&signature) |
| 145 | + } |
| 146 | + |
| 147 | + /// Clear the cache (and reset counters). Useful for test isolation. |
| 148 | + pub fn clear(&self) { |
| 149 | + self.cache.write().unwrap().clear(); |
| 150 | + *self.compile_count.write().unwrap() = 0; |
| 151 | + *self.hit_count.write().unwrap() = 0; |
| 152 | + } |
| 153 | +} |
| 154 | + |
| 155 | +impl<H: Clone> Default for CodecKernelCache<H> { |
| 156 | + fn default() -> Self { |
| 157 | + Self::new() |
| 158 | + } |
| 159 | +} |
| 160 | + |
| 161 | +/// Deterministic stub kernel handle — for testing the cache without |
| 162 | +/// invoking the real Cranelift / jitson compilation path. |
| 163 | +/// |
| 164 | +/// Captures what the kernel WOULD be (the signature it was compiled for + |
| 165 | +/// whether AMX would be used). D1.1b's Cranelift path replaces the |
| 166 | +/// stub with a real `KernelHandle`. |
| 167 | +#[derive(Debug, Clone, PartialEq, Eq)] |
| 168 | +pub struct StubKernel { |
| 169 | + /// `CodecParams::kernel_signature()` this stub represents. |
| 170 | + pub signature: u64, |
| 171 | + /// `params.is_matmul_heavy()` at compile time — drives Tier-1 AMX dispatch. |
| 172 | + pub is_matmul_heavy: bool, |
| 173 | + /// SIMD tier name this stub claims ("amx" | "vnni" | "avx512" | "avx2"). |
| 174 | + /// Never "scalar" on a SoA path — iron rule. |
| 175 | + pub backend: &'static str, |
| 176 | +} |
| 177 | + |
| 178 | +impl StubKernel { |
| 179 | + /// Build a stub from the current `CodecParams`, selecting a tier label |
| 180 | + /// under the assumption that AMX is available for matmul-heavy paths. |
| 181 | + /// The actual per-process capability query is |
| 182 | + /// `ndarray::simd_amx::amx_available()`; this stub pretends it's true. |
| 183 | + pub fn from_params(params: &CodecParams) -> Self { |
| 184 | + Self { |
| 185 | + signature: params.kernel_signature(), |
| 186 | + is_matmul_heavy: params.is_matmul_heavy(), |
| 187 | + backend: if params.is_matmul_heavy() { "amx" } else { "avx512" }, |
| 188 | + } |
| 189 | + } |
| 190 | +} |
| 191 | + |
| 192 | +#[cfg(test)] |
| 193 | +mod tests { |
| 194 | + use super::*; |
| 195 | + use lance_graph_contract::cam::{CodecParamsBuilder, LaneWidth, Rotation}; |
| 196 | + |
| 197 | + #[test] |
| 198 | + fn cache_starts_empty() { |
| 199 | + let c: CodecKernelCache<StubKernel> = CodecKernelCache::new(); |
| 200 | + assert_eq!(c.len(), 0); |
| 201 | + assert!(c.is_empty()); |
| 202 | + assert_eq!(c.compile_count(), 0); |
| 203 | + assert_eq!(c.hit_count(), 0); |
| 204 | + assert_eq!(c.hit_ratio(), 0.0); |
| 205 | + } |
| 206 | + |
| 207 | + #[test] |
| 208 | + fn first_call_compiles_second_is_cache_hit() { |
| 209 | + let c: CodecKernelCache<StubKernel> = CodecKernelCache::new(); |
| 210 | + let p = CodecParamsBuilder::new().centroids(1024).build().unwrap(); |
| 211 | + |
| 212 | + let k1 = c.get_or_compile(&p, || StubKernel::from_params(&p)); |
| 213 | + let k2 = c.get_or_compile(&p, || panic!("must not recompile on cache hit")); |
| 214 | + |
| 215 | + assert_eq!(k1, k2); |
| 216 | + assert_eq!(c.compile_count(), 1); |
| 217 | + assert_eq!(c.hit_count(), 1); |
| 218 | + assert_eq!(c.len(), 1); |
| 219 | + assert_eq!(c.hit_ratio(), 0.5); |
| 220 | + } |
| 221 | + |
| 222 | + #[test] |
| 223 | + fn different_params_produce_different_kernels() { |
| 224 | + let c: CodecKernelCache<StubKernel> = CodecKernelCache::new(); |
| 225 | + let p1 = CodecParamsBuilder::new().centroids(256).build().unwrap(); |
| 226 | + let p2 = CodecParamsBuilder::new().centroids(1024).build().unwrap(); |
| 227 | + |
| 228 | + let k1 = c.get_or_compile(&p1, || StubKernel::from_params(&p1)); |
| 229 | + let k2 = c.get_or_compile(&p2, || StubKernel::from_params(&p2)); |
| 230 | + |
| 231 | + assert_ne!(k1.signature, k2.signature); |
| 232 | + assert_eq!(c.compile_count(), 2); |
| 233 | + assert_eq!(c.hit_count(), 0); |
| 234 | + assert_eq!(c.len(), 2); |
| 235 | + } |
| 236 | + |
| 237 | + #[test] |
| 238 | + fn seed_changes_do_not_invalidate_cache() { |
| 239 | + // CodecParams::kernel_signature() excludes `seed` (PR #225). |
| 240 | + // Same IR-shaping fields → same signature → cache hit. |
| 241 | + let c: CodecKernelCache<StubKernel> = CodecKernelCache::new(); |
| 242 | + let p1 = CodecParamsBuilder::new().seed(1).build().unwrap(); |
| 243 | + let p2 = CodecParamsBuilder::new().seed(2).build().unwrap(); |
| 244 | + |
| 245 | + let k1 = c.get_or_compile(&p1, || StubKernel::from_params(&p1)); |
| 246 | + let k2 = c.get_or_compile(&p2, || panic!("seed change must not invalidate cache")); |
| 247 | + |
| 248 | + assert_eq!(k1, k2); |
| 249 | + assert_eq!(c.compile_count(), 1); |
| 250 | + assert_eq!(c.hit_count(), 1); |
| 251 | + } |
| 252 | + |
| 253 | + #[test] |
| 254 | + fn matmul_heavy_params_select_amx_backend_in_stub() { |
| 255 | + let opq = CodecParamsBuilder::new() |
| 256 | + .lane_width(LaneWidth::BF16x32) |
| 257 | + .rotation(Rotation::Opq { matrix_blob_id: 42, dim: 4096 }) |
| 258 | + .build() |
| 259 | + .unwrap(); |
| 260 | + let identity = CodecParamsBuilder::new().build().unwrap(); |
| 261 | + |
| 262 | + let k_opq = StubKernel::from_params(&opq); |
| 263 | + let k_id = StubKernel::from_params(&identity); |
| 264 | + |
| 265 | + assert_eq!(k_opq.backend, "amx"); |
| 266 | + assert!(k_opq.is_matmul_heavy); |
| 267 | + assert_eq!(k_id.backend, "avx512"); |
| 268 | + assert!(!k_id.is_matmul_heavy); |
| 269 | + } |
| 270 | + |
| 271 | + #[test] |
| 272 | + fn clear_resets_cache_and_counters() { |
| 273 | + let c: CodecKernelCache<StubKernel> = CodecKernelCache::new(); |
| 274 | + let p = CodecParamsBuilder::new().build().unwrap(); |
| 275 | + c.get_or_compile(&p, || StubKernel::from_params(&p)); |
| 276 | + c.get_or_compile(&p, || StubKernel::from_params(&p)); |
| 277 | + |
| 278 | + assert_eq!(c.len(), 1); |
| 279 | + assert_eq!(c.compile_count(), 1); |
| 280 | + assert_eq!(c.hit_count(), 1); |
| 281 | + |
| 282 | + c.clear(); |
| 283 | + |
| 284 | + assert_eq!(c.len(), 0); |
| 285 | + assert_eq!(c.compile_count(), 0); |
| 286 | + assert_eq!(c.hit_count(), 0); |
| 287 | + assert!(c.is_empty()); |
| 288 | + } |
| 289 | + |
| 290 | + #[test] |
| 291 | + fn try_get_or_compile_propagates_errors() { |
| 292 | + let c: CodecKernelCache<StubKernel> = CodecKernelCache::new(); |
| 293 | + let p = CodecParamsBuilder::new().build().unwrap(); |
| 294 | + let result: Result<StubKernel, _> = c.try_get_or_compile(&p, || { |
| 295 | + Err(CodecParamsError::ZeroDimension { field: "test" }) |
| 296 | + }); |
| 297 | + assert!(result.is_err()); |
| 298 | + // Failed compile doesn't populate cache. |
| 299 | + assert_eq!(c.len(), 0); |
| 300 | + } |
| 301 | + |
| 302 | + #[test] |
| 303 | + fn has_signature_checks_without_compiling() { |
| 304 | + let c: CodecKernelCache<StubKernel> = CodecKernelCache::new(); |
| 305 | + let p = CodecParamsBuilder::new().centroids(512).build().unwrap(); |
| 306 | + let sig = p.kernel_signature(); |
| 307 | + |
| 308 | + assert!(!c.has_signature(sig)); |
| 309 | + c.get_or_compile(&p, || StubKernel::from_params(&p)); |
| 310 | + assert!(c.has_signature(sig)); |
| 311 | + } |
| 312 | + |
| 313 | + #[test] |
| 314 | + fn sweep_grid_warms_cache_deterministically() { |
| 315 | + // Simulate the D0.3 insight: a sweep grid with 4 distinct kernel |
| 316 | + // signatures + 1 repeat (seed difference) compiles exactly 4 kernels. |
| 317 | + let c: CodecKernelCache<StubKernel> = CodecKernelCache::new(); |
| 318 | + let candidates: Vec<CodecParams> = vec![ |
| 319 | + CodecParamsBuilder::new().centroids(256).build().unwrap(), |
| 320 | + CodecParamsBuilder::new().centroids(512).build().unwrap(), |
| 321 | + CodecParamsBuilder::new().centroids(1024).build().unwrap(), |
| 322 | + CodecParamsBuilder::new().centroids(256).seed(999).build().unwrap(), // same sig as first |
| 323 | + CodecParamsBuilder::new() |
| 324 | + .lane_width(LaneWidth::BF16x32) |
| 325 | + .rotation(Rotation::Opq { matrix_blob_id: 1, dim: 4096 }) |
| 326 | + .build().unwrap(), |
| 327 | + ]; |
| 328 | + |
| 329 | + for p in &candidates { |
| 330 | + c.get_or_compile(p, || StubKernel::from_params(p)); |
| 331 | + } |
| 332 | + |
| 333 | + // 4 unique signatures (seed=999 collides with the first). |
| 334 | + assert_eq!(c.len(), 4); |
| 335 | + assert_eq!(c.compile_count(), 4); |
| 336 | + assert_eq!(c.hit_count(), 1); |
| 337 | + assert!((c.hit_ratio() - 0.2).abs() < 1e-9); |
| 338 | + } |
| 339 | +} |
0 commit comments