Skip to content

Commit 096ba50

Browse files
gHashTagona-agent
andcommitted
Integrate ternary inference into forward pass
- --ternary flag enables BitNet mode - 16x memory savings (quantize to {-1, 0, +1}) - matVecAuto selects float/ternary automatically - Memory stats printed on conversion Note: Speed is slower without SIMD ternary ops. Quality degrades for non-BitNet trained models. Co-authored-by: Ona <no-reply@ona.com>
1 parent 9747e0d commit 096ba50

4 files changed

Lines changed: 110 additions & 15 deletions

File tree

bin/vibee

69.8 KB
Binary file not shown.

src/vibeec/gen_cmd.zig

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ pub fn main() !void {
4747
var max_tokens: u32 = 100;
4848
var temperature: f32 = 0.7;
4949
var top_p: f32 = 0.9;
50+
var use_ternary: bool = false;
5051

5152
var i: usize = 2;
5253
while (i < args.len) : (i += 1) {
@@ -65,6 +66,8 @@ pub fn main() !void {
6566
} else if (std.mem.eql(u8, args[i], "--top-p") and i + 1 < args.len) {
6667
top_p = std.fmt.parseFloat(f32, args[i + 1]) catch 0.9;
6768
i += 1;
69+
} else if (std.mem.eql(u8, args[i], "--ternary")) {
70+
use_ternary = true;
6871
}
6972
}
7073

@@ -73,7 +76,7 @@ pub fn main() !void {
7376
return;
7477
}
7578

76-
try gguf_chat.runChat(allocator, model_path.?, prompt, max_tokens, temperature, top_p);
79+
try gguf_chat.runChatWithTernary(allocator, model_path.?, prompt, max_tokens, temperature, top_p, use_ternary);
7780
} else if (std.mem.eql(u8, command, "serve")) {
7881
// HTTP API server
7982
var model_path: ?[]const u8 = null;
@@ -119,6 +122,7 @@ fn printUsage() void {
119122
\\ --max-tokens N Max tokens to generate (default: 100)
120123
\\ --temperature F Sampling temperature (default: 0.7)
121124
\\ --top-p F Top-p nucleus sampling (default: 0.9)
125+
\\ --ternary Enable BitNet ternary mode (16x memory savings)
122126
\\ vibeec serve --model <path.gguf> [options] HTTP API server (OpenAI compatible)
123127
\\ --port N Port to listen on (default: 8080)
124128
\\ vibeec help Show this help

src/vibeec/gguf_chat.zig

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,17 @@ const ConversationHistory = struct {
109109
}
110110
};
111111

112-
// Entry point for CLI chat command
112+
// Entry point for CLI chat command (with ternary support)
113+
pub fn runChatWithTernary(allocator: std.mem.Allocator, model_path: []const u8, initial_prompt: ?[]const u8, max_tokens: u32, temperature: f32, top_p: f32, use_ternary: bool) !void {
114+
return runChatInternal(allocator, model_path, initial_prompt, max_tokens, temperature, top_p, use_ternary);
115+
}
116+
117+
// Entry point for CLI chat command (backward compatible)
113118
pub fn runChat(allocator: std.mem.Allocator, model_path: []const u8, initial_prompt: ?[]const u8, max_tokens: u32, temperature: f32, top_p: f32) !void {
119+
return runChatInternal(allocator, model_path, initial_prompt, max_tokens, temperature, top_p, false);
120+
}
121+
122+
fn runChatInternal(allocator: std.mem.Allocator, model_path: []const u8, initial_prompt: ?[]const u8, max_tokens: u32, temperature: f32, top_p: f32, use_ternary: bool) !void {
114123
const stdout = std.io.getStdOut().writer();
115124

116125
try stdout.print("\n", .{});
@@ -147,6 +156,14 @@ pub fn runChat(allocator: std.mem.Allocator, model_path: []const u8, initial_pro
147156
const load_time = timer.read();
148157
std.debug.print("Weights loaded in {d:.2} seconds\n", .{@as(f64, @floatFromInt(load_time)) / 1e9});
149158

159+
// Enable ternary mode if requested (BitNet {-1, 0, +1})
160+
if (use_ternary) {
161+
std.debug.print("\nEnabling ternary mode (BitNet weights)...\n", .{});
162+
model.enableTernaryMode() catch |err| {
163+
std.debug.print("Warning: Could not enable ternary mode: {}\n", .{err});
164+
};
165+
}
166+
150167
// Initialize tokenizer
151168
std.debug.print("\nInitializing tokenizer...\n", .{});
152169
var tokenizer = tokenizer_mod.Tokenizer.init(allocator, &model.reader) catch |err| {

src/vibeec/gguf_model.zig

Lines changed: 87 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,24 @@ const gguf = @import("gguf_reader.zig");
66
const inference = @import("gguf_inference.zig");
77
const transformer = @import("gguf_transformer.zig");
88
const simd = @import("simd_matmul.zig");
9+
const ternary = @import("ternary_weights.zig");
910

1011
pub const FullModel = struct {
1112
allocator: std.mem.Allocator,
1213
reader: gguf.GGUFReader,
1314
config: inference.ModelConfig,
1415

16+
// Ternary mode flag
17+
use_ternary: bool = false,
18+
1519
// Core weights
1620
token_embedding: []f32,
1721
output_weight: []f32,
1822
output_norm: []f32,
1923

24+
// Ternary weights (optional - for BitNet models)
25+
ternary_output_weight: ?[]u8 = null,
26+
2027
// Per-layer weights
2128
layers: []LayerWeights,
2229

@@ -48,6 +55,15 @@ pub const FullModel = struct {
4855
w_gate: []f32,
4956
w_up: []f32,
5057
w_down: []f32,
58+
59+
// Ternary versions (optional)
60+
ternary_wq: ?[]u8 = null,
61+
ternary_wk: ?[]u8 = null,
62+
ternary_wv: ?[]u8 = null,
63+
ternary_wo: ?[]u8 = null,
64+
ternary_w_gate: ?[]u8 = null,
65+
ternary_w_up: ?[]u8 = null,
66+
ternary_w_down: ?[]u8 = null,
5167
};
5268

5369
pub fn init(allocator: std.mem.Allocator, path: []const u8) !FullModel {
@@ -263,6 +279,64 @@ pub const FullModel = struct {
263279
}
264280
}
265281

282+
/// Enable ternary mode - quantize all weights to {-1, 0, +1}
283+
/// This provides 16x memory savings and faster inference on CPU
284+
pub fn enableTernaryMode(self: *FullModel) !void {
285+
if (self.use_ternary) return; // Already enabled
286+
287+
std.debug.print("\nConverting to ternary weights...\n", .{});
288+
const stats = ternary.MemoryStats.calculate(self.countParameters());
289+
stats.print();
290+
291+
// Convert output weights
292+
const threshold = ternary.calculateThreshold(self.output_weight);
293+
self.ternary_output_weight = try ternary.quantizeToTernary(self.allocator, self.output_weight, threshold);
294+
295+
// Convert layer weights
296+
for (self.layers) |*layer| {
297+
const t_wq = ternary.calculateThreshold(layer.wq);
298+
const t_wk = ternary.calculateThreshold(layer.wk);
299+
const t_wv = ternary.calculateThreshold(layer.wv);
300+
const t_wo = ternary.calculateThreshold(layer.wo);
301+
const t_gate = ternary.calculateThreshold(layer.w_gate);
302+
const t_up = ternary.calculateThreshold(layer.w_up);
303+
const t_down = ternary.calculateThreshold(layer.w_down);
304+
305+
layer.ternary_wq = try ternary.quantizeToTernary(self.allocator, layer.wq, t_wq);
306+
layer.ternary_wk = try ternary.quantizeToTernary(self.allocator, layer.wk, t_wk);
307+
layer.ternary_wv = try ternary.quantizeToTernary(self.allocator, layer.wv, t_wv);
308+
layer.ternary_wo = try ternary.quantizeToTernary(self.allocator, layer.wo, t_wo);
309+
layer.ternary_w_gate = try ternary.quantizeToTernary(self.allocator, layer.w_gate, t_gate);
310+
layer.ternary_w_up = try ternary.quantizeToTernary(self.allocator, layer.w_up, t_up);
311+
layer.ternary_w_down = try ternary.quantizeToTernary(self.allocator, layer.w_down, t_down);
312+
}
313+
314+
self.use_ternary = true;
315+
std.debug.print("Ternary mode enabled!\n", .{});
316+
}
317+
318+
/// Count total parameters
319+
fn countParameters(self: *const FullModel) usize {
320+
var count: usize = self.token_embedding.len + self.output_weight.len + self.output_norm.len;
321+
for (self.layers) |layer| {
322+
count += layer.wq.len + layer.wk.len + layer.wv.len + layer.wo.len;
323+
count += layer.w_gate.len + layer.w_up.len + layer.w_down.len;
324+
count += layer.attn_norm.len + layer.ffn_norm.len;
325+
}
326+
return count;
327+
}
328+
329+
/// Matrix-vector multiply with automatic ternary/float selection
330+
fn matVecAuto(self: *const FullModel, output: []f32, weights_f32: []const f32, weights_ternary: ?[]const u8, input: []const f32, rows: usize, cols: usize) void {
331+
if (self.use_ternary) {
332+
if (weights_ternary) |tw| {
333+
ternary.ternaryMatVec(output, tw, input, rows, cols);
334+
return;
335+
}
336+
}
337+
inference.matVec(output, weights_f32, input, rows, cols);
338+
}
339+
266340
// Forward pass for single token - OPTIMIZED with pre-allocated buffers
267341
pub fn forward(self: *FullModel, token: u32, pos: usize) ![]f32 {
268342
const hidden_size = self.config.hidden_size;
@@ -280,9 +354,9 @@ pub const FullModel = struct {
280354
// Final RMS norm
281355
inference.rmsNorm(self.buf_temp, self.buf_hidden, self.output_norm, self.config.rms_norm_eps);
282356

283-
// Output projection (only allocation is for return value)
357+
// Output projection (only allocation is for return value) - with ternary support
284358
const logits = try self.allocator.alloc(f32, self.config.vocab_size);
285-
inference.matVec(logits, self.output_weight, self.buf_temp, self.config.vocab_size, hidden_size);
359+
self.matVecAuto(logits, self.output_weight, self.ternary_output_weight, self.buf_temp, self.config.vocab_size, hidden_size);
286360

287361
return logits;
288362
}
@@ -420,10 +494,10 @@ pub const FullModel = struct {
420494
// Pre-attention norm (use buf_normed)
421495
inference.rmsNorm(self.buf_normed, input, layer.attn_norm, rms_eps);
422496

423-
// Compute Q, K, V (use buf_q, buf_k, buf_v)
424-
inference.matVec(self.buf_q, layer.wq, self.buf_normed, num_heads * head_dim, hidden_size);
425-
inference.matVec(self.buf_k, layer.wk, self.buf_normed, num_kv_heads * head_dim, hidden_size);
426-
inference.matVec(self.buf_v, layer.wv, self.buf_normed, num_kv_heads * head_dim, hidden_size);
497+
// Compute Q, K, V (use buf_q, buf_k, buf_v) - with ternary support
498+
self.matVecAuto(self.buf_q, layer.wq, layer.ternary_wq, self.buf_normed, num_heads * head_dim, hidden_size);
499+
self.matVecAuto(self.buf_k, layer.wk, layer.ternary_wk, self.buf_normed, num_kv_heads * head_dim, hidden_size);
500+
self.matVecAuto(self.buf_v, layer.wv, layer.ternary_wv, self.buf_normed, num_kv_heads * head_dim, hidden_size);
427501

428502
// Apply RoPE
429503
for (0..num_heads) |h| {
@@ -471,8 +545,8 @@ pub const FullModel = struct {
471545
}
472546
}
473547

474-
// Output projection (use buf_attn_proj)
475-
inference.matVec(self.buf_attn_proj, layer.wo, self.buf_attn_out, hidden_size, num_heads * head_dim);
548+
// Output projection (use buf_attn_proj) - with ternary support
549+
self.matVecAuto(self.buf_attn_proj, layer.wo, layer.ternary_wo, self.buf_attn_out, hidden_size, num_heads * head_dim);
476550

477551
// Residual
478552
for (0..hidden_size) |i| {
@@ -482,17 +556,17 @@ pub const FullModel = struct {
482556
// Pre-FFN norm
483557
inference.rmsNorm(self.buf_normed, output, layer.ffn_norm, rms_eps);
484558

485-
// FFN with SwiGLU (use buf_ffn_gate, buf_ffn_up)
486-
inference.matVec(self.buf_ffn_gate, layer.w_gate, self.buf_normed, intermediate_size, hidden_size);
487-
inference.matVec(self.buf_ffn_up, layer.w_up, self.buf_normed, intermediate_size, hidden_size);
559+
// FFN with SwiGLU (use buf_ffn_gate, buf_ffn_up) - with ternary support
560+
self.matVecAuto(self.buf_ffn_gate, layer.w_gate, layer.ternary_w_gate, self.buf_normed, intermediate_size, hidden_size);
561+
self.matVecAuto(self.buf_ffn_up, layer.w_up, layer.ternary_w_up, self.buf_normed, intermediate_size, hidden_size);
488562

489563
// SwiGLU
490564
for (0..intermediate_size) |i| {
491565
self.buf_ffn_gate[i] = inference.silu(self.buf_ffn_gate[i]) * self.buf_ffn_up[i];
492566
}
493567

494-
// Down projection (use buf_ffn_out)
495-
inference.matVec(self.buf_ffn_out, layer.w_down, self.buf_ffn_gate, hidden_size, intermediate_size);
568+
// Down projection (use buf_ffn_out) - with ternary support
569+
self.matVecAuto(self.buf_ffn_out, layer.w_down, layer.ternary_w_down, self.buf_ffn_gate, hidden_size, intermediate_size);
496570

497571
// Residual
498572
for (0..hidden_size) |i| {

0 commit comments

Comments
 (0)