Skip to content

Commit cb7358e

Browse files
gHashTag and ona-agent committed
Add chat template (ChatML) and streaming output
- ChatML format for TinyLlama: <|system|>, <|user|>, <|assistant|>
- Streaming output: tokens printed as generated
- Proper prefill + generation loop
- System prompt support

Note: The base model doesn't follow instructions well. An instruction-tuned model is needed for proper Q&A.

Co-authored-by: Ona <no-reply@ona.com>
1 parent de84bc3 commit cb7358e

2 files changed

Lines changed: 106 additions & 59 deletions

File tree

bin/vibee

15.8 KB
Binary file not shown.

src/vibeec/gguf_chat.zig

Lines changed: 106 additions & 59 deletions
Original file line number | Diff line number | Diff line change
@@ -9,17 +9,19 @@ const model_mod = @import("gguf_model.zig");
99
const tokenizer_mod = @import("gguf_tokenizer.zig");
1010
const inference = @import("gguf_inference.zig");
1111

12+
// Chat template for formatting prompts
13+
const ChatTemplate = tokenizer_mod.ChatTemplate;
14+
1215
// Entry point for CLI chat command
1316
pub fn runChat(allocator: std.mem.Allocator, model_path: []const u8, initial_prompt: ?[]const u8, max_tokens: u32) !void {
14-
_ = initial_prompt;
15-
_ = max_tokens;
17+
const stdout = std.io.getStdOut().writer();
1618

17-
std.debug.print("\n", .{});
18-
std.debug.print("╔══════════════════════════════════════════════════════════════╗\n", .{});
19-
std.debug.print("║ TRINITY CHAT - SIMD Optimized LLM ║\n", .{});
20-
std.debug.print("║ phi^2 + 1/phi^2 = 3 = TRINITY \n", .{});
21-
std.debug.print("╚══════════════════════════════════════════════════════════════╝\n", .{});
22-
std.debug.print("\n", .{});
19+
try stdout.print("\n", .{});
20+
try stdout.print("╔══════════════════════════════════════════════════════════════╗\n", .{});
21+
try stdout.print("║ TRINITY CHAT - SIMD Optimized LLM ║\n", .{});
22+
try stdout.print("║ Chat Template + Streaming Output\n", .{});
23+
try stdout.print("╚══════════════════════════════════════════════════════════════╝\n", .{});
24+
try stdout.print("\n", .{});
2325

2426
// Load model
2527
std.debug.print("Loading model: {s}\n", .{model_path});
@@ -48,76 +50,121 @@ pub fn runChat(allocator: std.mem.Allocator, model_path: []const u8, initial_pro
4850
};
4951
defer tokenizer.deinit();
5052

51-
std.debug.print("Ready! Type your message (or 'quit' to exit):\n\n", .{});
53+
// Use TinyLlama chat template
54+
const template = ChatTemplate.TINYLLAMA;
55+
const system_prompt = "You are a helpful AI assistant.";
56+
57+
std.debug.print("Chat template: TinyLlama (ChatML format)\n", .{});
58+
std.debug.print("System: {s}\n", .{system_prompt});
59+
std.debug.print("\nReady! Type your message (or 'quit' to exit):\n\n", .{});
60+
61+
// Handle initial prompt if provided
62+
if (initial_prompt) |prompt| {
63+
try generateWithTemplate(allocator, stdout, &model, &tokenizer, &template, system_prompt, prompt, max_tokens);
64+
}
5265

5366
// Interactive loop
5467
const stdin = std.io.getStdIn().reader();
5568
var buf: [1024]u8 = undefined;
5669

5770
while (true) {
58-
std.debug.print("User: ", .{});
71+
try stdout.print("User: ", .{});
5972
const line = stdin.readUntilDelimiter(&buf, '\n') catch break;
6073
const trimmed = std.mem.trim(u8, line, " \t\r\n");
6174

6275
if (trimmed.len == 0) continue;
6376
if (std.mem.eql(u8, trimmed, "quit") or std.mem.eql(u8, trimmed, "exit")) break;
6477

65-
// Generate response using full transformer forward pass
66-
std.debug.print("Assistant: ", .{});
67-
var gen_timer = try std.time.Timer.start();
68-
69-
const tokens = tokenizer.encode(allocator, trimmed) catch {
70-
std.debug.print("[tokenization error]\n", .{});
71-
continue;
72-
};
73-
defer allocator.free(tokens);
74-
75-
// Real generation with transformer
76-
var generated: u32 = 0;
77-
var current_tokens = std.ArrayList(u32).init(allocator);
78-
defer current_tokens.deinit();
79-
for (tokens) |t| try current_tokens.append(t);
80-
81-
const max_gen: u32 = 50;
82-
while (generated < max_gen) : (generated += 1) {
83-
// Forward pass for last token
84-
const pos = current_tokens.items.len - 1;
85-
const last_token = current_tokens.items[pos];
86-
87-
const logits = model.forward(last_token, pos) catch {
88-
std.debug.print("[forward error]", .{});
89-
break;
90-
};
91-
defer allocator.free(logits);
92-
93-
// Sample next token (greedy)
94-
var max_idx: u32 = 0;
95-
var max_val: f32 = logits[0];
96-
for (logits[1..], 1..) |l, i| {
97-
if (l > max_val) {
98-
max_val = l;
99-
max_idx = @intCast(i);
100-
}
101-
}
78+
try generateWithTemplate(allocator, stdout, &model, &tokenizer, &template, system_prompt, trimmed, max_tokens);
79+
}
10280

103-
// Check for EOS
104-
if (max_idx == tokenizer.eos_token) break;
81+
try stdout.print("Goodbye!\n", .{});
82+
}
10583

106-
// Decode and print
107-
const decoded = tokenizer.decode(allocator, &[_]u32{max_idx}) catch " ";
108-
defer if (decoded.len > 0) allocator.free(decoded);
109-
std.debug.print("{s}", .{decoded});
84+
// Generate response with chat template and streaming output
85+
fn generateWithTemplate(
86+
allocator: std.mem.Allocator,
87+
writer: anytype,
88+
model: *model_mod.FullModel,
89+
tokenizer: *tokenizer_mod.Tokenizer,
90+
template: *const ChatTemplate,
91+
system: []const u8,
92+
user_input: []const u8,
93+
max_tokens: u32,
94+
) !void {
95+
// Format prompt with chat template
96+
const formatted = try template.formatPrompt(allocator, system, user_input);
97+
defer allocator.free(formatted);
98+
99+
try writer.print("Assistant: ", .{});
100+
var gen_timer = try std.time.Timer.start();
101+
102+
// Tokenize formatted prompt
103+
const tokens = tokenizer.encode(allocator, formatted) catch {
104+
try writer.print("[tokenization error]\n", .{});
105+
return;
106+
};
107+
defer allocator.free(tokens);
108+
109+
// Reset KV cache for new conversation
110+
model.resetKVCache();
111+
112+
// Process prompt tokens (prefill) - build up KV cache
113+
var last_logits: ?[]f32 = null;
114+
for (tokens, 0..) |token, pos| {
115+
if (last_logits) |l| allocator.free(l);
116+
last_logits = model.forward(token, pos) catch {
117+
try writer.print("[forward error]\n", .{});
118+
return;
119+
};
120+
}
110121

111-
try current_tokens.append(max_idx);
122+
// Generate tokens with streaming output
123+
var generated: u32 = 0;
124+
var current_pos = tokens.len;
125+
126+
// Use logits from last prefill token for first generation
127+
var current_logits = last_logits orelse return;
128+
var last_token: u32 = 0;
129+
130+
while (generated < max_tokens) : (generated += 1) {
131+
// Sample next token (greedy)
132+
var max_idx: u32 = 0;
133+
var max_val: f32 = current_logits[0];
134+
for (current_logits[1..], 1..) |l, i| {
135+
if (l > max_val) {
136+
max_val = l;
137+
max_idx = @intCast(i);
138+
}
112139
}
113-
std.debug.print("\n", .{});
114140

115-
const gen_time = gen_timer.read();
116-
const tok_per_sec = @as(f64, @floatFromInt(generated)) / (@as(f64, @floatFromInt(gen_time)) / 1e9);
117-
std.debug.print("[{d} tokens, {d:.1} tok/s]\n\n", .{ generated, tok_per_sec });
141+
// Free current logits
142+
allocator.free(current_logits);
143+
144+
// Check for EOS
145+
if (max_idx == tokenizer.eos_token) break;
146+
147+
// Decode and stream output immediately
148+
const decoded = tokenizer.decode(allocator, &[_]u32{max_idx}) catch " ";
149+
defer if (decoded.len > 0) allocator.free(decoded);
150+
151+
// Stream: print immediately without buffering
152+
try writer.print("{s}", .{decoded});
153+
154+
// Check for </s> or end markers in decoded text
155+
if (std.mem.indexOf(u8, decoded, "</s>") != null) break;
156+
if (std.mem.indexOf(u8, decoded, "<|") != null) break;
157+
158+
// Get next logits
159+
last_token = max_idx;
160+
current_logits = model.forward(last_token, current_pos) catch break;
161+
current_pos += 1;
118162
}
163+
try writer.print("\n", .{});
119164

120-
std.debug.print("Goodbye!\n", .{});
165+
const gen_time = gen_timer.read();
166+
const tok_per_sec = @as(f64, @floatFromInt(generated)) / (@as(f64, @floatFromInt(gen_time)) / 1e9);
167+
try writer.print("[{d} tokens, {d:.1} tok/s]\n\n", .{ generated, tok_per_sec });
121168
}
122169

123170
pub fn main() !void {

0 commit comments

Comments
 (0)