Skip to content

Commit 04750a5

Browse files
authored
Merge pull request #116 from AdaWorldAPI/claude/risc-thought-engine-TCZw7
feat: pooling strategies + builder pattern + commit sinks (EmbedAnything patterns)
2 parents 40e6d92 + 5322e19 commit 04750a5

3 files changed

Lines changed: 573 additions & 0 deletions

File tree

Lines changed: 304 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,304 @@
1+
//! ThinkingEngineBuilder: fluent API for engine construction.
2+
//!
3+
//! ```rust,no_run
4+
//! let engine = ThinkingEngineBuilder::new()
5+
//! .lens(Lens::Jina)
6+
//! .table_type(TableType::SignedI8)
7+
//! .pooling(Pooling::TopK(5))
8+
//! .on_commit(|bus| l4.learn_from(bus))
9+
//! .build();
10+
//! ```
11+
12+
use crate::engine::ThinkingEngine;
13+
use crate::signed_engine::SignedThinkingEngine;
14+
use crate::pooling::Pooling;
15+
16+
/// Which baked lens to use.
17+
#[derive(Clone, Debug)]
18+
pub enum Lens {
19+
Jina,
20+
BgeM3,
21+
Reranker,
22+
Custom(Vec<u8>),
23+
}
24+
25+
/// Distance table encoding.
26+
#[derive(Clone, Copy, Debug, PartialEq)]
27+
pub enum TableType {
28+
UnsignedU8,
29+
SignedI8,
30+
}
31+
32+
/// Built engine: either unsigned or signed.
33+
pub enum BuiltEngine {
34+
Unsigned(ThinkingEngine),
35+
Signed(SignedThinkingEngine),
36+
}
37+
38+
impl BuiltEngine {
39+
pub fn perturb(&mut self, indices: &[u16]) {
40+
match self {
41+
BuiltEngine::Unsigned(e) => e.perturb(indices),
42+
BuiltEngine::Signed(e) => e.perturb(indices),
43+
}
44+
}
45+
46+
pub fn reset(&mut self) {
47+
match self {
48+
BuiltEngine::Unsigned(e) => e.reset(),
49+
BuiltEngine::Signed(e) => e.reset(),
50+
}
51+
}
52+
53+
pub fn energy(&self) -> &[f32] {
54+
match self {
55+
BuiltEngine::Unsigned(e) => &e.energy,
56+
BuiltEngine::Signed(e) => &e.energy,
57+
}
58+
}
59+
60+
pub fn cycles(&self) -> u16 {
61+
match self {
62+
BuiltEngine::Unsigned(e) => e.cycles,
63+
BuiltEngine::Signed(e) => e.cycles,
64+
}
65+
}
66+
67+
pub fn size(&self) -> usize {
68+
match self {
69+
BuiltEngine::Unsigned(e) => e.size,
70+
BuiltEngine::Signed(e) => e.size,
71+
}
72+
}
73+
74+
pub fn think(&mut self, max_cycles: usize) {
75+
match self {
76+
BuiltEngine::Unsigned(e) => { e.think(max_cycles); }
77+
BuiltEngine::Signed(e) => { e.think(max_cycles); }
78+
}
79+
}
80+
}
81+
82+
/// Commit sink: where committed thoughts go.
83+
pub type CommitSink = Box<dyn Fn(&crate::dto::BusDto) + Send + Sync>;
84+
85+
/// Builder for ThinkingEngine with fluent API.
86+
pub struct ThinkingEngineBuilder {
87+
lens: Option<Lens>,
88+
table_type: TableType,
89+
pooling: Pooling,
90+
max_cycles: usize,
91+
sinks: Vec<CommitSink>,
92+
}
93+
94+
impl ThinkingEngineBuilder {
95+
pub fn new() -> Self {
96+
Self {
97+
lens: None,
98+
table_type: TableType::UnsignedU8,
99+
pooling: Pooling::ArgMax,
100+
max_cycles: 10,
101+
sinks: Vec::new(),
102+
}
103+
}
104+
105+
/// Select a baked lens.
106+
pub fn lens(mut self, lens: Lens) -> Self {
107+
self.lens = Some(lens);
108+
self
109+
}
110+
111+
/// Set table encoding type: u8 (default) or i8 signed.
112+
pub fn table_type(mut self, tt: TableType) -> Self {
113+
self.table_type = tt;
114+
self
115+
}
116+
117+
/// Set pooling strategy.
118+
pub fn pooling(mut self, p: Pooling) -> Self {
119+
self.pooling = p;
120+
self
121+
}
122+
123+
/// Set max think cycles (default: 10).
124+
pub fn max_cycles(mut self, n: usize) -> Self {
125+
self.max_cycles = n;
126+
self
127+
}
128+
129+
/// Add a commit sink (adapter pattern).
130+
/// Sinks receive the BusDto after every commit.
131+
pub fn on_commit(mut self, sink: impl Fn(&crate::dto::BusDto) + Send + Sync + 'static) -> Self {
132+
self.sinks.push(Box::new(sink));
133+
self
134+
}
135+
136+
/// Build the engine.
137+
pub fn build(self) -> Result<ConfiguredEngine, String> {
138+
let table = match self.lens {
139+
Some(Lens::Jina) => crate::jina_lens::JINA_HDR_TABLE.to_vec(),
140+
Some(Lens::BgeM3) => crate::bge_m3_lens::BGE_M3_HDR_TABLE.to_vec(),
141+
Some(Lens::Reranker) => crate::reranker_lens::RERANKER_HDR_TABLE.to_vec(),
142+
Some(Lens::Custom(t)) => t,
143+
None => return Err("no lens specified".into()),
144+
};
145+
146+
let engine = match self.table_type {
147+
TableType::UnsignedU8 => BuiltEngine::Unsigned(ThinkingEngine::new(table)),
148+
TableType::SignedI8 => BuiltEngine::Signed(
149+
crate::signed_engine::SignedThinkingEngine::from_unsigned(&table)
150+
),
151+
};
152+
153+
Ok(ConfiguredEngine {
154+
engine,
155+
pooling: self.pooling,
156+
max_cycles: self.max_cycles,
157+
sinks: self.sinks,
158+
})
159+
}
160+
}
161+
162+
impl Default for ThinkingEngineBuilder {
163+
fn default() -> Self {
164+
Self::new()
165+
}
166+
}
167+
168+
/// A fully configured engine with pooling and commit sinks.
169+
pub struct ConfiguredEngine {
170+
pub engine: BuiltEngine,
171+
pub pooling: Pooling,
172+
pub max_cycles: usize,
173+
sinks: Vec<CommitSink>,
174+
}
175+
176+
impl ConfiguredEngine {
177+
/// Full pipeline: perturb → think → pool → commit → notify sinks.
178+
pub fn process(&mut self, codebook_indices: &[u16]) -> crate::dto::BusDto {
179+
self.engine.reset();
180+
self.engine.perturb(codebook_indices);
181+
self.engine.think(self.max_cycles);
182+
183+
let bus = self.pooling.to_bus(self.engine.energy(), self.engine.cycles());
184+
185+
// Notify all sinks
186+
for sink in &self.sinks {
187+
sink(&bus);
188+
}
189+
190+
bus
191+
}
192+
193+
/// Access the underlying engine.
194+
pub fn inner(&self) -> &BuiltEngine {
195+
&self.engine
196+
}
197+
198+
/// Access the pooling strategy.
199+
pub fn pooling(&self) -> &Pooling {
200+
&self.pooling
201+
}
202+
}
203+
204+
#[cfg(test)]
205+
mod tests {
206+
use super::*;
207+
use std::sync::{Arc, atomic::{AtomicU32, Ordering}};
208+
209+
#[test]
210+
fn builder_jina_unsigned() {
211+
let engine = ThinkingEngineBuilder::new()
212+
.lens(Lens::Jina)
213+
.build()
214+
.unwrap();
215+
assert_eq!(engine.engine.size(), 256);
216+
}
217+
218+
#[test]
219+
fn builder_reranker_signed() {
220+
let engine = ThinkingEngineBuilder::new()
221+
.lens(Lens::Reranker)
222+
.table_type(TableType::SignedI8)
223+
.build()
224+
.unwrap();
225+
assert_eq!(engine.engine.size(), 256);
226+
}
227+
228+
#[test]
229+
fn builder_with_pooling() {
230+
let mut engine = ThinkingEngineBuilder::new()
231+
.lens(Lens::Jina)
232+
.pooling(Pooling::TopK(3))
233+
.build()
234+
.unwrap();
235+
236+
let bus = engine.process(&[50, 52, 54]);
237+
assert!(bus.energy > 0.0);
238+
}
239+
240+
#[test]
241+
fn builder_with_sink() {
242+
let counter = Arc::new(AtomicU32::new(0));
243+
let counter_clone = counter.clone();
244+
245+
let mut engine = ThinkingEngineBuilder::new()
246+
.lens(Lens::BgeM3)
247+
.on_commit(move |_bus| {
248+
counter_clone.fetch_add(1, Ordering::Relaxed);
249+
})
250+
.build()
251+
.unwrap();
252+
253+
engine.process(&[10, 20, 30]);
254+
engine.process(&[40, 50, 60]);
255+
256+
assert_eq!(counter.load(Ordering::Relaxed), 2);
257+
}
258+
259+
#[test]
260+
fn builder_no_lens_errors() {
261+
let result = ThinkingEngineBuilder::new().build();
262+
assert!(result.is_err());
263+
}
264+
265+
#[test]
266+
fn builder_custom_table() {
267+
let mut table = vec![128u8; 64 * 64];
268+
for i in 0..64 { table[i * 64 + i] = 255; }
269+
270+
let engine = ThinkingEngineBuilder::new()
271+
.lens(Lens::Custom(table))
272+
.table_type(TableType::SignedI8)
273+
.pooling(Pooling::Mean { threshold: 0.001 })
274+
.max_cycles(5)
275+
.build()
276+
.unwrap();
277+
278+
assert_eq!(engine.engine.size(), 64);
279+
}
280+
281+
#[test]
282+
fn builder_multiple_sinks() {
283+
let log = Arc::new(std::sync::Mutex::new(Vec::new()));
284+
let log1 = log.clone();
285+
let log2 = log.clone();
286+
287+
let mut engine = ThinkingEngineBuilder::new()
288+
.lens(Lens::Jina)
289+
.on_commit(move |bus| {
290+
log1.lock().unwrap().push(format!("sink1:{}", bus.codebook_index));
291+
})
292+
.on_commit(move |bus| {
293+
log2.lock().unwrap().push(format!("sink2:{}", bus.codebook_index));
294+
})
295+
.build()
296+
.unwrap();
297+
298+
engine.process(&[100]);
299+
let entries = log.lock().unwrap();
300+
assert_eq!(entries.len(), 2);
301+
assert!(entries[0].starts_with("sink1:"));
302+
assert!(entries[1].starts_with("sink2:"));
303+
}
304+
}

crates/thinking-engine/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,5 @@ pub mod dual_engine;
4444
pub mod l4_bridge;
4545
pub mod composite_engine;
4646
pub mod signed_domino;
47+
pub mod pooling;
48+
pub mod builder;

0 commit comments

Comments
 (0)