Commit 41992c2
Session D: Interpreter ↔ codegen JIT dispatch wiring
End-to-end: parse OMC source → compile to bytecode → JIT eligible fns in dual-band mode → register dispatch hook on Interpreter → JIT-compiled native code runs in place of tree-walk for those fns. The mechanism is proven via integration tests that drive the full pipeline.

The CLI-level `OMC_HBIT_JIT=1` env var is one small follow-up away: extracting omnimcode-cli to a new package (see the deferred Session D.5 task). The workspace-cycle constraint prevents adding codegen as a dep of core's existing binary.

Codegen API additions:
- `JitContext::jit_module(&Module) -> Result<HashMap<String, JittedFn>, CodegenError>`: walks user fns, attempts dual-band lowering on each, and returns the successes as arity-tagged raw fn pointers. Failed fns are silently skipped. Crucially, their partial LLVM IR is now erased from the module (`function.delete()` after the lowering error), so JIT finalization for the rest of the module doesn't crash on broken IR.
- `JittedFn { arity, fn_ptr: *const () }` with a `call(&[i64])` dispatcher for arities 0..=4. Uses transmute to cast the raw pointer to the right `unsafe extern "C" fn(...) -> i64` signature based on arity.
- `extract_raw_fn_ptr` helper: typed `engine.get_function::<F>` per arity, then `JitFunction::into_raw()` for storage. (Bug discovered along the way: `engine.get_function_address` returned garbage where the typed `get_function::<F>` produces correct fn pointers, so this sticks with the typed path.)

Codegen lowerer fix (applies to BOTH scalar and dual-band):
- `bind_params_into_locals()` now runs at fn entry. The OMC bytecode compiler emits `Op::LoadVar("x")` for parameter access in fn bodies (treating params as already-bound locals); the bytecode VM and tree-walk both pre-populate these bindings before executing the body. Without this fix, the JIT was reading uninitialized allocas and returning pointer-looking garbage instead of the param value. This was the symptom that broke the first integration test before it was diagnosed.

Interpreter additions:
- `JitDispatch = Rc<dyn Fn(&str, &[Value]) -> Option<Result<Value, String>>>` type alias for the dispatch closure.
- `Interpreter::jit_dispatch: Option<JitDispatch>` field.
- `Interpreter::set_jit_dispatch(Some(hook))` setter (the embedder API the future omnimcode-cli will call).
- `invoke_user_function_at` consults the hook BEFORE running the tree-walk body. A `Some(_)` return short-circuits tree-walk; `None` falls through. Marshalling Value↔i64 happens in the hook closure (decoupled from core).

Tests (5 new integration tests, all passing):
- jit_dispatch_routes_simple_int_fn: `double(21) == 42` via JIT.
- jit_module_returns_callable_fn_directly: isolation test that calls the JIT'd fn without going through Interpreter (proves jit_module's fn-ptr extraction).
- jit_dispatch_matches_tree_walk_factorial: factorial(10) == 3628800, matches the tree-walk-only run.
- jit_dispatch_matches_tree_walk_sum_loop: while loop with two locals matches across both code paths.
- jit_dispatch_falls_through_on_unsupported_fn: a string-using fn is silently skipped by the JIT (delete() of partial IR), the int fn still routes through the JIT, and both produce correct outputs.

Workspace state: 19/19 codegen tests green, 149/149 omnimcode-core Rust unit tests green, 18/18 OMC harmonic-lib tests green.

Deferred to Session D.5: extract omnimcode-cli, wire OMC_HBIT_JIT to populate set_jit_dispatch from inside main.rs.

Deferred to Session E: AVX-512 path (`<8 x i64>` carrier with explicit LLVM intrinsics), benchmark harness measuring actual speedup vs tree-walk and bytecode VM.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
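The hook contract described in the commit message (a `Some(_)` return short-circuits tree-walk, `None` falls through) can be sketched roughly as follows. This is an illustration only, not the interpreter's code: `Value` is a two-variant stand-in, `NativeUnary`, `native_double`, `demo_table`, `make_hook`, and `invoke` are hypothetical names, and an ordinary `extern "C"` Rust function takes the place of JIT output.

```rust
use std::collections::HashMap;
use std::rc::Rc;

/// Stand-in for the interpreter's value type; the real `Value` is
/// assumed to have many more variants.
#[derive(Clone, Debug, PartialEq)]
pub enum Value {
    Int(i64),
    Str(String),
}

/// Same shape as the `JitDispatch` alias named in the commit message.
pub type JitDispatch = Rc<dyn Fn(&str, &[Value]) -> Option<Result<Value, String>>>;

/// Hypothetical unary native entry point standing in for JIT-compiled code.
pub type NativeUnary = extern "C" fn(i64) -> i64;

extern "C" fn native_double(x: i64) -> i64 {
    x.wrapping_mul(2)
}

/// A one-entry table mapping fn names to native entry points.
pub fn demo_table() -> HashMap<String, NativeUnary> {
    let mut table: HashMap<String, NativeUnary> = HashMap::new();
    table.insert("double".to_string(), native_double);
    table
}

/// Build a dispatch hook. Value <-> i64 marshalling lives here, in the
/// closure, so the interpreter core never needs to know about the JIT.
pub fn make_hook(table: HashMap<String, NativeUnary>) -> JitDispatch {
    Rc::new(
        move |name: &str, args: &[Value]| -> Option<Result<Value, String>> {
            // Unknown fn name: return None so the caller falls through.
            let f = *table.get(name)?;
            match args {
                [Value::Int(x)] => Some(Ok(Value::Int(f(*x)))),
                // Wrong arity or a non-int argument: fall through as well.
                _ => None,
            }
        },
    )
}

/// Minimal stand-in for the call path: consult the hook first, then
/// run the fallback (the tree-walk body in the real interpreter).
pub fn invoke<F>(
    hook: Option<&JitDispatch>,
    name: &str,
    args: &[Value],
    tree_walk: F,
) -> Result<Value, String>
where
    F: Fn(&[Value]) -> Result<Value, String>,
{
    if let Some(h) = hook {
        if let Some(result) = (**h)(name, args) {
            return result;
        }
    }
    tree_walk(args)
}
```

Returning `None` rather than an error for unsupported calls keeps the fallback decision in one place: the hook never needs to know what the tree-walk path would have done.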
1 parent e9f0fbc commit 41992c2

4 files changed

Lines changed: 429 additions & 0 deletions

omnimcode-codegen/src/dual_band.rs

Lines changed: 31 additions & 0 deletions
@@ -103,10 +103,41 @@ impl<'ctx, 'a> DualBandLowerer<'ctx, 'a> {
 
         self.collect_leaders()?;
         self.collect_cleanup_pops();
+        self.bind_params_into_locals()?;
         self.emit_body()?;
         Ok(self.function)
     }
 
+    /// Bind each fn parameter into a named local-variable slot. The
+    /// OMC bytecode compiler emits `LoadVar("x")` for parameter access
+    /// in fn bodies; we mirror what the bytecode VM does at fn entry
+    /// and pre-populate each parameter into a `<2 x i64>` alloca slot
+    /// keyed by the parameter name. β = α at entry (matched bands);
+    /// later sessions add explicit phi-shadow ops that diverge β.
+    fn bind_params_into_locals(&mut self) -> Result<(), CodegenError> {
+        for (i, pname) in self.f.params.clone().iter().enumerate() {
+            let param = self
+                .function
+                .get_nth_param(i as u32)
+                .ok_or_else(|| format!("hbit bind_params: no param at slot {}", i))?;
+            let iv = match param {
+                BasicValueEnum::IntValue(iv) => iv,
+                _ => {
+                    return Err(format!(
+                        "hbit bind_params: non-int param at slot {}",
+                        i
+                    ))
+                }
+            };
+            let v = self.splat(iv, &format!("{}_init", pname))?;
+            let slot = self.get_or_create_slot(pname)?;
+            self.builder
+                .build_store(slot, v)
+                .map_err(|e| format!("hbit bind_params store {}: {}", pname, e))?;
+        }
+        Ok(())
+    }
+
     fn collect_leaders(&mut self) -> Result<(), CodegenError> {
         let mut leaders: std::collections::BTreeSet<usize> = std::collections::BTreeSet::new();
         leaders.insert(0);
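To see why the entry-time binding matters, here is a small model of the same idea outside LLVM. It is a sketch under stated assumptions: the three-variant `Op` enum, the `Band` array (lane 0 as α and lane 1 as β is an assumed layout), and the `HashMap` of locals are simplified stand-ins for the real bytecode, the `<2 x i64>` carrier, and the alloca slots. `splat` here only mirrors the name of the lowerer's helper.

```rust
use std::collections::HashMap;

/// Two-lane carrier standing in for the `<2 x i64>` value.
/// Assumed layout for this sketch: lane 0 = alpha, lane 1 = beta.
pub type Band = [i64; 2];

/// Hypothetical, heavily simplified bytecode.
#[derive(Clone, Debug)]
pub enum Op {
    LoadVar(String),
    Add,
    Return,
}

/// Copy one scalar into both lanes (beta = alpha at entry).
pub fn splat(x: i64) -> Band {
    [x, x]
}

/// Pre-populate one named slot per parameter before the body runs.
pub fn bind_params(params: &[&str], args: &[i64]) -> Result<HashMap<String, Band>, String> {
    if params.len() != args.len() {
        return Err(format!("expected {} args, got {}", params.len(), args.len()));
    }
    Ok(params
        .iter()
        .zip(args)
        .map(|(p, &a)| (p.to_string(), splat(a)))
        .collect())
}

/// Run a body whose parameter reads are plain `LoadVar` ops.
pub fn run(params: &[&str], body: &[Op], args: &[i64]) -> Result<Band, String> {
    let locals = bind_params(params, args)?;
    let mut stack: Vec<Band> = Vec::new();
    for op in body {
        match op {
            Op::LoadVar(name) => {
                // Without the binding step this lookup has nothing to find;
                // the JIT analogue was a load from an uninitialized alloca.
                let v = locals
                    .get(name)
                    .ok_or_else(|| format!("unbound variable {}", name))?;
                stack.push(*v);
            }
            Op::Add => {
                let b = stack.pop().ok_or("stack underflow")?;
                let a = stack.pop().ok_or("stack underflow")?;
                stack.push([a[0].wrapping_add(b[0]), a[1].wrapping_add(b[1])]);
            }
            Op::Return => return stack.pop().ok_or_else(|| "stack underflow".to_string()),
        }
    }
    Err("missing Return".to_string())
}
```

In this model a missing binding is a clean error; in the lowered code there is no such check, which is why the unbound read surfaced as garbage values rather than a failure.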

omnimcode-codegen/src/lib.rs

Lines changed: 173 additions & 0 deletions
@@ -56,6 +56,71 @@ pub struct JitContext<'ctx> {
 /// Error type for codegen failures. Keeps it simple: just a String.
 pub type CodegenError = String;
 
+/// A successfully JIT'd OMC function, presented as an arity-tagged
+/// raw function pointer. Callable via `JittedFn::call(args)` for
+/// the supported arities (0..=4); larger arities should be folded
+/// down via a future uniform-arg-array calling convention.
+///
+/// SAFETY: the underlying machine code is owned by the
+/// `JitContext`/`ExecutionEngine` that produced this struct. Calling
+/// after that JitContext is dropped is undefined behavior. In the
+/// current Session D design, the main CLI keeps the JitContext
+/// alive for the entire program duration (Box::leak), so the
+/// invariant holds for normal use.
+#[derive(Clone, Copy, Debug)]
+pub struct JittedFn {
+    pub arity: usize,
+    /// Erased fn pointer. Cast to the right `unsafe extern "C" fn`
+    /// signature at call time based on `arity`.
+    pub fn_ptr: *const (),
+}
+
+// SAFETY: a raw function pointer is `Send + Sync`; it's plain data.
+// The LLVM-generated machine code is read-only and re-entrant.
+unsafe impl Send for JittedFn {}
+unsafe impl Sync for JittedFn {}
+
+impl JittedFn {
+    /// Call this JITted fn with i64 args. Returns `Some(result)` when
+    /// arity matches a supported overload, `None` otherwise. Caller is
+    /// responsible for keeping the producing JitContext alive; that's
+    /// the unsafe invariant this method enforces minimally (it's
+    /// "safe" because we trust the pointer, but a use-after-free of
+    /// the JitContext would crash here).
+    pub fn call(&self, args: &[i64]) -> Option<i64> {
+        if args.len() != self.arity {
+            return None;
+        }
+        unsafe {
+            match self.arity {
+                0 => {
+                    let f: unsafe extern "C" fn() -> i64 = std::mem::transmute(self.fn_ptr);
+                    Some(f())
+                }
+                1 => {
+                    let f: unsafe extern "C" fn(i64) -> i64 = std::mem::transmute(self.fn_ptr);
+                    Some(f(args[0]))
+                }
+                2 => {
+                    let f: unsafe extern "C" fn(i64, i64) -> i64 = std::mem::transmute(self.fn_ptr);
+                    Some(f(args[0], args[1]))
+                }
+                3 => {
+                    let f: unsafe extern "C" fn(i64, i64, i64) -> i64 =
+                        std::mem::transmute(self.fn_ptr);
+                    Some(f(args[0], args[1], args[2]))
+                }
+                4 => {
+                    let f: unsafe extern "C" fn(i64, i64, i64, i64) -> i64 =
+                        std::mem::transmute(self.fn_ptr);
+                    Some(f(args[0], args[1], args[2], args[3]))
+                }
+                _ => None,
+            }
+        }
+    }
+}
+
 impl<'ctx> JitContext<'ctx> {
     pub fn new(context: &'ctx Context) -> Result<Self, CodegenError> {
         let module = context.create_module("omc_jit");
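The arity-tagged erasure pattern in `JittedFn` can be exercised without LLVM. The sketch below is an assumption-laden miniature, not the commit's code: `ErasedFn`, `from_nullary`, `from_binary`, `answer`, and `add2` are hypothetical names, and ordinary `extern "C"` Rust functions stand in for JIT output. One deliberate difference from the diff: `call` is marked `unsafe` here so the pointer-validity invariant appears in the signature, which is one possible design and not necessarily what the project wants.

```rust
use std::mem::transmute;

/// Hypothetical miniature of an arity-tagged, type-erased fn pointer.
#[derive(Clone, Copy, Debug)]
pub struct ErasedFn {
    pub arity: usize,
    pub fn_ptr: *const (),
}

/// Ordinary functions standing in for JIT-compiled code.
extern "C" fn answer() -> i64 {
    42
}

extern "C" fn add2(a: i64, b: i64) -> i64 {
    a.wrapping_add(b)
}

impl ErasedFn {
    /// Erase a zero-argument fn pointer, recording its arity.
    pub fn from_nullary(f: extern "C" fn() -> i64) -> Self {
        ErasedFn { arity: 0, fn_ptr: f as *const () }
    }

    /// Erase a two-argument fn pointer, recording its arity.
    pub fn from_binary(f: extern "C" fn(i64, i64) -> i64) -> Self {
        ErasedFn { arity: 2, fn_ptr: f as *const () }
    }

    /// Recover the typed signature from `arity` and call it.
    ///
    /// # Safety
    /// `fn_ptr` must point to live code whose real signature takes
    /// exactly `arity` `i64` arguments and returns `i64`.
    pub unsafe fn call(&self, args: &[i64]) -> Option<i64> {
        if args.len() != self.arity {
            return None;
        }
        match self.arity {
            0 => {
                // The tag, not the pointer, decides which signature is used.
                let f: extern "C" fn() -> i64 = unsafe { transmute(self.fn_ptr) };
                Some(f())
            }
            2 => {
                let f: extern "C" fn(i64, i64) -> i64 = unsafe { transmute(self.fn_ptr) };
                Some(f(args[0], args[1]))
            }
            // Arities this sketch does not cover.
            _ => None,
        }
    }
}
```

The argument-count check guards against a caller passing the wrong number of values, but nothing can verify that the tag matches the pointer's true signature. That remains the constructor's responsibility.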
@@ -107,6 +172,88 @@ impl<'ctx> JitContext<'ctx> {
         lowerer.lower()
     }
 
+    /// Try to JIT every user function in a bytecode `Module` in dual-band
+    /// mode. Functions whose bodies use ops the codegen layer doesn't
+    /// yet support (strings, dicts, builtins, cross-fn calls, etc.)
+    /// are silently skipped; they stay routed through the tree-walk
+    /// interpreter at runtime.
+    ///
+    /// Returns a map of `fn_name -> JittedFn` for every fn that did
+    /// lower successfully. The native function pointers inside
+    /// `JittedFn` are owned by `self` (the underlying ExecutionEngine);
+    /// callers must not invoke the returned fns after `self` is dropped.
+    ///
+    /// The returned name uses the ORIGINAL (un-suffixed) bytecode-side
+    /// fn name; under the hood the LLVM module sees `<name>_hbit` per
+    /// the dual-band lowerer's naming convention.
+    ///
+    /// Session D scope: every user fn is attempted. Sessions later
+    /// add explicit `@hbit` pragma filtering so non-tagged fns aren't
+    /// JIT'd even if they could be.
+    pub fn jit_module(
+        &self,
+        module: &omnimcode_core::bytecode::Module,
+    ) -> Result<HashMap<String, JittedFn>, CodegenError> {
+        let mut out: HashMap<String, JittedFn> = HashMap::new();
+        for (name, cf) in &module.functions {
+            let suffixed = format!("{}_hbit", name);
+            match self.lower_function_dual_band(cf) {
+                Ok(_) => {
+                    // get_function::<F> triggers JIT finalization and
+                    // returns a JitFunction wrapping the raw pointer.
+                    // We dispatch on arity to pick the right F so we
+                    // can extract the raw fn pointer for storage.
+                    let arity = cf.params.len();
+                    let fn_ptr = unsafe { self.extract_raw_fn_ptr(&suffixed, arity)? };
+                    out.insert(
+                        name.clone(),
+                        JittedFn { arity, fn_ptr },
+                    );
+                }
+                Err(_) => {
+                    // Lowering failed mid-emission. The LLVM module
+                    // now contains a partial / broken function; leaving
+                    // it in place would corrupt JIT finalization for
+                    // every subsequent fn (and crash on first call).
+                    // Erase it so the rest of the module stays valid.
+                    if let Some(broken) = self.module.get_function(&suffixed) {
+                        unsafe { broken.delete() };
+                    }
+                }
+            }
+        }
+        Ok(out)
+    }
+
+    /// Erase a typed JitFunction down to a `*const ()` pointer for
+    /// arity-tagged storage in `JittedFn`. Internal helper for
+    /// `jit_module`; the caller is responsible for not invoking the
+    /// returned pointer after `self` is dropped.
+    unsafe fn extract_raw_fn_ptr(
+        &self,
+        name: &str,
+        arity: usize,
+    ) -> Result<*const (), CodegenError> {
+        macro_rules! by_arity {
+            ($t:ty) => {{
+                let jf: JitFunction<'ctx, $t> = self
+                    .engine
+                    .get_function(name)
+                    .map_err(|e| format!("get_function({}): {:?}", name, e))?;
+                jf.into_raw() as *const ()
+            }};
+        }
+        let ptr = match arity {
+            0 => by_arity!(unsafe extern "C" fn() -> i64),
+            1 => by_arity!(unsafe extern "C" fn(i64) -> i64),
+            2 => by_arity!(unsafe extern "C" fn(i64, i64) -> i64),
+            3 => by_arity!(unsafe extern "C" fn(i64, i64, i64) -> i64),
+            4 => by_arity!(unsafe extern "C" fn(i64, i64, i64, i64) -> i64),
+            _ => return Err(format!("arity {} not supported in Session D jit_module", arity)),
+        };
+        Ok(ptr)
+    }
+
     /// JIT-lookup helper for single-arg i64 functions.
     pub unsafe fn get_i64_i64(
         &self,
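The skip-and-erase control flow in `jit_module` can be modelled with plain collections. This is a rough sketch with assumed names: `FakeModule`, `lower`, and `lower_all` are hypothetical, op names are bare strings, and `"str_concat"` is an invented example of an unsupported op. Only the `_hbit` suffix comes from the diff.

```rust
use std::collections::BTreeMap;

/// Hypothetical stand-in for the LLVM module: emitted fn name -> ops.
#[derive(Default, Debug)]
pub struct FakeModule {
    pub emitted: BTreeMap<String, Vec<String>>,
}

/// Emit ops one at a time. On an unsupported op this returns early and
/// leaves a partial entry behind, as a failed lowering leaves partial IR.
fn lower(module: &mut FakeModule, name: &str, ops: &[&str]) -> Result<(), String> {
    let suffixed = format!("{}_hbit", name);
    let body = module.emitted.entry(suffixed).or_default();
    for op in ops {
        if *op == "str_concat" {
            return Err(format!("unsupported op {}", op));
        }
        body.push(op.to_string());
    }
    Ok(())
}

/// Attempt every fn; keep the successes and erase the leftovers of failures.
pub fn lower_all(module: &mut FakeModule, fns: &[(&str, Vec<&str>)]) -> Vec<String> {
    let mut ok = Vec::new();
    for (name, ops) in fns {
        match lower(module, name, ops) {
            Ok(()) => ok.push(name.to_string()),
            Err(_) => {
                // Remove the half-built entry so later consumers of the
                // module only ever see complete functions.
                module.emitted.remove(&format!("{}_hbit", name));
            }
        }
    }
    ok
}
```

The result is keyed by the original, un-suffixed name while the module only ever holds suffixed names, matching the naming split described in the `jit_module` doc comment.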
@@ -187,10 +334,36 @@ impl<'ctx, 'a> FunctionLowerer<'ctx, 'a> {
 
         self.collect_leaders()?;
         self.collect_cleanup_pops();
+        self.bind_params_into_locals()?;
         self.emit_body()?;
         Ok(self.function)
     }
 
+    /// Bind each fn parameter into a named local-variable slot.
+    /// The OMC bytecode compiler emits `Op::LoadVar("x")` for parameter
+    /// access in the body (treating params as locals already in scope).
+    /// The bytecode VM and tree-walk interpreter both pre-populate
+    /// these bindings before executing the body; we mirror that here
+    /// so LoadVar resolves to the actual parameter value rather than
+    /// reading from an uninitialized alloca.
+    fn bind_params_into_locals(&mut self) -> Result<(), CodegenError> {
+        for (i, pname) in self.f.params.clone().iter().enumerate() {
+            let param = self
+                .function
+                .get_nth_param(i as u32)
+                .ok_or_else(|| format!("bind_params: no param at slot {}", i))?;
+            let iv = match param {
+                BasicValueEnum::IntValue(iv) => iv,
+                _ => return Err(format!("bind_params: non-int param at slot {}", i)),
+            };
+            let slot = self.get_or_create_slot(pname)?;
+            self.builder
+                .build_store(slot, iv)
+                .map_err(|e| format!("bind_params store {}: {}", pname, e))?;
+        }
+        Ok(())
+    }
+
     /// First pass: find op-indices that begin a new basic block. An
     /// op-index is a leader if:
     /// - it's 0 (entry)