Skip to content

Commit f62b212

Browse files
gHashTagona-agent
andcommitted
Add multi-turn conversation history
- ConversationHistory struct with Message storage - Automatic truncation (keeps last 10 messages + system) - formatForModel() builds full chat template - Context length check before generation - Commands: /clear (reset), /history (show count) - Response collection for history tracking History visible in output: [N tokens, X tok/s, history: M msgs] Co-authored-by: Ona <no-reply@ona.com>
1 parent 94ec90f commit f62b212

2 files changed

Lines changed: 238 additions & 4 deletions

File tree

bin/vibee

84.6 KB
Binary file not shown.

src/vibeec/gguf_chat.zig

Lines changed: 238 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,100 @@ const ChatTemplate = tokenizer_mod.ChatTemplate;
1515
// Sampling parameters struct
1616
const SamplingParams = inference.SamplingParams;
1717

18+
// ═══════════════════════════════════════════════════════════════════════════════
19+
// CONVERSATION HISTORY
20+
// ═══════════════════════════════════════════════════════════════════════════════
21+
22+
const Message = struct {
23+
role: Role,
24+
content: []const u8,
25+
26+
const Role = enum { system, user, assistant };
27+
};
28+
29+
const ConversationHistory = struct {
30+
allocator: std.mem.Allocator,
31+
messages: std.ArrayList(Message),
32+
max_messages: usize,
33+
34+
pub fn init(allocator: std.mem.Allocator, max_messages: usize) ConversationHistory {
35+
return .{
36+
.allocator = allocator,
37+
.messages = std.ArrayList(Message).init(allocator),
38+
.max_messages = max_messages,
39+
};
40+
}
41+
42+
pub fn deinit(self: *ConversationHistory) void {
43+
for (self.messages.items) |msg| {
44+
self.allocator.free(msg.content);
45+
}
46+
self.messages.deinit();
47+
}
48+
49+
pub fn addMessage(self: *ConversationHistory, role: Message.Role, content: []const u8) !void {
50+
// Copy content
51+
const content_copy = try self.allocator.dupe(u8, content);
52+
53+
// Truncate old messages if needed (keep system + last N)
54+
while (self.messages.items.len >= self.max_messages) {
55+
// Keep first message if it's system
56+
const start_idx: usize = if (self.messages.items.len > 0 and
57+
self.messages.items[0].role == .system) 1 else 0;
58+
59+
if (self.messages.items.len > start_idx) {
60+
const removed = self.messages.orderedRemove(start_idx);
61+
self.allocator.free(removed.content);
62+
} else {
63+
break;
64+
}
65+
}
66+
67+
try self.messages.append(.{ .role = role, .content = content_copy });
68+
}
69+
70+
pub fn formatForModel(self: *const ConversationHistory, allocator: std.mem.Allocator, template: *const ChatTemplate) ![]u8 {
71+
var result = std.ArrayList(u8).init(allocator);
72+
errdefer result.deinit();
73+
74+
for (self.messages.items) |msg| {
75+
switch (msg.role) {
76+
.system => {
77+
try result.appendSlice(template.system_prefix);
78+
try result.appendSlice(msg.content);
79+
try result.appendSlice(template.system_suffix);
80+
},
81+
.user => {
82+
try result.appendSlice(template.user_prefix);
83+
try result.appendSlice(msg.content);
84+
try result.appendSlice(template.user_suffix);
85+
},
86+
.assistant => {
87+
try result.appendSlice(template.assistant_prefix);
88+
try result.appendSlice(msg.content);
89+
try result.appendSlice(template.assistant_suffix);
90+
},
91+
}
92+
}
93+
94+
// Add assistant prefix for generation
95+
try result.appendSlice(template.assistant_prefix);
96+
97+
return result.toOwnedSlice();
98+
}
99+
100+
pub fn getMessageCount(self: *const ConversationHistory) usize {
101+
return self.messages.items.len;
102+
}
103+
104+
pub fn clear(self: *ConversationHistory) void {
105+
for (self.messages.items) |msg| {
106+
self.allocator.free(msg.content);
107+
}
108+
self.messages.clearRetainingCapacity();
109+
}
110+
};
111+
18112
// Entry point for CLI chat command
19113
pub fn runChat(allocator: std.mem.Allocator, model_path: []const u8, initial_prompt: ?[]const u8, max_tokens: u32, temperature: f32, top_p: f32) !void {
20114
const stdout = std.io.getStdOut().writer();
@@ -65,14 +159,28 @@ pub fn runChat(allocator: std.mem.Allocator, model_path: []const u8, initial_pro
65159
const template = ChatTemplate.TINYLLAMA;
66160
const system_prompt = "You are a helpful AI assistant.";
67161

162+
// Initialize conversation history (keep last 10 messages + system)
163+
var history = ConversationHistory.init(allocator, 12);
164+
defer history.deinit();
165+
166+
// Add system message
167+
try history.addMessage(.system, system_prompt);
168+
68169
std.debug.print("Chat template: TinyLlama (ChatML format)\n", .{});
69170
std.debug.print("System: {s}\n", .{system_prompt});
70171
std.debug.print("Sampling: temperature={d:.2}, top_p={d:.2}\n", .{sampling_params.temperature, sampling_params.top_p});
71-
std.debug.print("\nReady! Type your message (or 'quit' to exit):\n\n", .{});
172+
std.debug.print("History: enabled (last 10 messages)\n", .{});
173+
std.debug.print("\nCommands: 'quit' to exit, '/clear' to reset history\n", .{});
174+
std.debug.print("\nReady! Type your message:\n\n", .{});
72175

73176
// Handle initial prompt if provided
74177
if (initial_prompt) |prompt| {
75-
try generateWithTemplate(allocator, stdout, &model, &tokenizer, &template, system_prompt, prompt, max_tokens, sampling_params);
178+
try history.addMessage(.user, prompt);
179+
const response = try generateWithHistory(allocator, stdout, &model, &tokenizer, &template, &history, max_tokens, sampling_params);
180+
if (response) |resp| {
181+
try history.addMessage(.assistant, resp);
182+
allocator.free(resp);
183+
}
76184
}
77185

78186
// Interactive loop
@@ -87,13 +195,139 @@ pub fn runChat(allocator: std.mem.Allocator, model_path: []const u8, initial_pro
87195
if (trimmed.len == 0) continue;
88196
if (std.mem.eql(u8, trimmed, "quit") or std.mem.eql(u8, trimmed, "exit")) break;
89197

90-
try generateWithTemplate(allocator, stdout, &model, &tokenizer, &template, system_prompt, trimmed, max_tokens, sampling_params);
198+
// Handle commands
199+
if (std.mem.eql(u8, trimmed, "/clear")) {
200+
history.clear();
201+
try history.addMessage(.system, system_prompt);
202+
model.resetKVCache();
203+
try stdout.print("[History cleared]\n\n", .{});
204+
continue;
205+
}
206+
207+
if (std.mem.eql(u8, trimmed, "/history")) {
208+
try stdout.print("[{d} messages in history]\n\n", .{history.getMessageCount()});
209+
continue;
210+
}
211+
212+
// Add user message to history
213+
try history.addMessage(.user, trimmed);
214+
215+
// Generate response with full history
216+
const response = try generateWithHistory(allocator, stdout, &model, &tokenizer, &template, &history, max_tokens, sampling_params);
217+
218+
// Add assistant response to history
219+
if (response) |resp| {
220+
try history.addMessage(.assistant, resp);
221+
allocator.free(resp);
222+
}
91223
}
92224

93225
try stdout.print("Goodbye!\n", .{});
94226
}
95227

96-
// Generate response with chat template and streaming output
228+
// Generate response with conversation history
229+
fn generateWithHistory(
230+
allocator: std.mem.Allocator,
231+
writer: anytype,
232+
model: *model_mod.FullModel,
233+
tokenizer: *tokenizer_mod.Tokenizer,
234+
template: *const ChatTemplate,
235+
history: *const ConversationHistory,
236+
max_tokens: u32,
237+
params: SamplingParams,
238+
) !?[]u8 {
239+
// Format full conversation history
240+
const formatted = try history.formatForModel(allocator, template);
241+
defer allocator.free(formatted);
242+
243+
try writer.print("Assistant: ", .{});
244+
var gen_timer = try std.time.Timer.start();
245+
246+
// Tokenize formatted prompt
247+
const tokens = tokenizer.encode(allocator, formatted) catch {
248+
try writer.print("[tokenization error]\n", .{});
249+
return null;
250+
};
251+
defer allocator.free(tokens);
252+
253+
// Check context length
254+
if (tokens.len > model.config.context_length - max_tokens) {
255+
try writer.print("[context too long, use /clear]\n", .{});
256+
return null;
257+
}
258+
259+
// Reset KV cache for full history processing
260+
model.resetKVCache();
261+
262+
// Process all history tokens (prefill)
263+
var last_logits: ?[]f32 = null;
264+
for (tokens, 0..) |token, pos| {
265+
if (last_logits) |l| allocator.free(l);
266+
last_logits = model.forward(token, pos) catch {
267+
try writer.print("[forward error]\n", .{});
268+
return null;
269+
};
270+
}
271+
272+
// Generate tokens with streaming output
273+
var generated: u32 = 0;
274+
var current_pos = tokens.len;
275+
var current_logits = last_logits orelse return null;
276+
var last_token: u32 = 0;
277+
278+
// Collect response for history
279+
var response = std.ArrayList(u8).init(allocator);
280+
errdefer response.deinit();
281+
282+
while (generated < max_tokens) : (generated += 1) {
283+
// Sample next token
284+
const sampled_token = inference.sampleWithParams(allocator, current_logits, params) catch blk: {
285+
var max_idx: u32 = 0;
286+
var max_val: f32 = current_logits[0];
287+
for (current_logits[1..], 1..) |l, i| {
288+
if (l > max_val) {
289+
max_val = l;
290+
max_idx = @intCast(i);
291+
}
292+
}
293+
break :blk max_idx;
294+
};
295+
296+
allocator.free(current_logits);
297+
298+
// Check for EOS
299+
if (sampled_token == tokenizer.eos_token) break;
300+
301+
// Decode token
302+
const decoded = tokenizer.decode(allocator, &[_]u32{sampled_token}) catch " ";
303+
defer if (decoded.len > 0) allocator.free(decoded);
304+
305+
// Stream output
306+
try writer.print("{s}", .{decoded});
307+
308+
// Collect for history
309+
try response.appendSlice(decoded);
310+
311+
// Check for end markers
312+
if (std.mem.indexOf(u8, decoded, "</s>") != null) break;
313+
if (std.mem.indexOf(u8, decoded, "<|") != null) break;
314+
315+
// Get next logits
316+
last_token = sampled_token;
317+
current_logits = model.forward(last_token, current_pos) catch break;
318+
current_pos += 1;
319+
}
320+
try writer.print("\n", .{});
321+
322+
const gen_time = gen_timer.read();
323+
const tok_per_sec = @as(f64, @floatFromInt(generated)) / (@as(f64, @floatFromInt(gen_time)) / 1e9);
324+
try writer.print("[{d} tokens, {d:.1} tok/s, history: {d} msgs]\n\n", .{ generated, tok_per_sec, history.getMessageCount() });
325+
326+
const result = try response.toOwnedSlice();
327+
return result;
328+
}
329+
330+
// Generate response with chat template and streaming output (legacy single-turn)
97331
fn generateWithTemplate(
98332
allocator: std.mem.Allocator,
99333
writer: anytype,

0 commit comments

Comments
 (0)