@@ -15,6 +15,100 @@ const ChatTemplate = tokenizer_mod.ChatTemplate;
1515// Sampling parameters struct
1616const SamplingParams = inference .SamplingParams ;
1717
18+ // ═══════════════════════════════════════════════════════════════════════════════
19+ // CONVERSATION HISTORY
20+ // ═══════════════════════════════════════════════════════════════════════════════
21+
/// One turn of a conversation: who spoke and what was said.
const Message = struct {
    /// Speaker categories that the chat template formats differently.
    const Role = enum {
        system,
        user,
        assistant,
    };

    /// Which participant produced this message.
    role: Role,
    /// Message text as raw bytes.
    content: []const u8,
};
28+
/// Bounded, owning list of chat messages.
///
/// Every stored `content` slice is a private copy made with `allocator`, and
/// is released by `deinit`, `clear`, or when the message is evicted.
const ConversationHistory = struct {
    /// Allocator used for the message list and for each content copy.
    allocator: std.mem.Allocator,
    /// Messages in chronological order (oldest first).
    messages: std.ArrayList(Message),
    /// Eviction threshold: `addMessage` evicts until fewer than this many
    /// messages remain before it appends the new one.
    max_messages: usize,

    /// Creates an empty history. No allocation happens until the first add.
    pub fn init(allocator: std.mem.Allocator, max_messages: usize) ConversationHistory {
        return .{
            .allocator = allocator,
            .messages = std.ArrayList(Message).init(allocator),
            .max_messages = max_messages,
        };
    }

    /// Frees every stored content copy and the list itself.
    pub fn deinit(self: *ConversationHistory) void {
        for (self.messages.items) |msg| {
            self.allocator.free(msg.content);
        }
        self.messages.deinit();
    }

    /// Appends a copy of `content` under `role`, evicting the oldest messages
    /// first when the history is at `max_messages`.
    ///
    /// A `.system` message in slot 0 is never evicted. If that leaves nothing
    /// evictable, the new message is appended anyway.
    ///
    /// Returns an allocation error if the copy or the list slot cannot be
    /// obtained; in that case the history is left exactly as it was.
    pub fn addMessage(self: *ConversationHistory, role: Message.Role, content: []const u8) !void {
        const content_copy = try self.allocator.dupe(u8, content);
        // Without this the copy leaks if reserving the list slot fails.
        errdefer self.allocator.free(content_copy);

        // Reserve the slot *before* evicting, so an allocation failure cannot
        // drop old messages without storing the new one.
        try self.messages.ensureUnusedCapacity(1);

        while (self.messages.items.len >= self.max_messages) {
            // Skip slot 0 when it holds the system message.
            const start_idx: usize = if (self.messages.items.len > 0 and
                self.messages.items[0].role == .system) 1 else 0;

            // Nothing evictable left (empty, or only the system message).
            if (self.messages.items.len <= start_idx) break;

            const removed = self.messages.orderedRemove(start_idx);
            self.allocator.free(removed.content);
        }

        // Cannot fail: capacity was reserved above and eviction only shrinks the list.
        self.messages.appendAssumeCapacity(.{ .role = role, .content = content_copy });
    }

    /// Renders the whole history through `template`, wrapping each message in
    /// the prefix/suffix for its role, then appends a trailing
    /// `assistant_prefix` so the model continues as the assistant.
    ///
    /// The result is allocated with the `allocator` argument (not the
    /// history's own allocator); the caller owns it and must free it.
    pub fn formatForModel(self: *const ConversationHistory, allocator: std.mem.Allocator, template: *const ChatTemplate) ![]u8 {
        var result = std.ArrayList(u8).init(allocator);
        errdefer result.deinit();

        for (self.messages.items) |msg| {
            switch (msg.role) {
                .system => {
                    try result.appendSlice(template.system_prefix);
                    try result.appendSlice(msg.content);
                    try result.appendSlice(template.system_suffix);
                },
                .user => {
                    try result.appendSlice(template.user_prefix);
                    try result.appendSlice(msg.content);
                    try result.appendSlice(template.user_suffix);
                },
                .assistant => {
                    try result.appendSlice(template.assistant_prefix);
                    try result.appendSlice(msg.content);
                    try result.appendSlice(template.assistant_suffix);
                },
            }
        }

        // Open an assistant turn for generation.
        try result.appendSlice(template.assistant_prefix);

        return result.toOwnedSlice();
    }

    /// Number of messages currently stored, including any system message.
    pub fn getMessageCount(self: *const ConversationHistory) usize {
        return self.messages.items.len;
    }

    /// Frees every stored content copy and empties the list, keeping the
    /// list's allocated capacity for reuse.
    pub fn clear(self: *ConversationHistory) void {
        for (self.messages.items) |msg| {
            self.allocator.free(msg.content);
        }
        self.messages.clearRetainingCapacity();
    }
};
111+
18112// Entry point for CLI chat command
19113pub fn runChat (allocator : std.mem.Allocator , model_path : []const u8 , initial_prompt : ? []const u8 , max_tokens : u32 , temperature : f32 , top_p : f32 ) ! void {
20114 const stdout = std .io .getStdOut ().writer ();
@@ -65,14 +159,28 @@ pub fn runChat(allocator: std.mem.Allocator, model_path: []const u8, initial_pro
65159 const template = ChatTemplate .TINYLLAMA ;
66160 const system_prompt = "You are a helpful AI assistant." ;
67161
162+ // Initialize conversation history (keep last 10 messages + system)
163+ var history = ConversationHistory .init (allocator , 12 );
164+ defer history .deinit ();
165+
166+ // Add system message
167+ try history .addMessage (.system , system_prompt );
168+
68169 std .debug .print ("Chat template: TinyLlama (ChatML format)\n " , .{});
69170 std .debug .print ("System: {s}\n " , .{system_prompt });
70171 std .debug .print ("Sampling: temperature={d:.2}, top_p={d:.2}\n " , .{sampling_params .temperature , sampling_params .top_p });
71- std .debug .print ("\n Ready! Type your message (or 'quit' to exit):\n\n " , .{});
172+ std .debug .print ("History: enabled (last 10 messages)\n " , .{});
173+ std .debug .print ("\n Commands: 'quit' to exit, '/clear' to reset history\n " , .{});
174+ std .debug .print ("\n Ready! Type your message:\n\n " , .{});
72175
73176 // Handle initial prompt if provided
74177 if (initial_prompt ) | prompt | {
75- try generateWithTemplate (allocator , stdout , & model , & tokenizer , & template , system_prompt , prompt , max_tokens , sampling_params );
178+ try history .addMessage (.user , prompt );
179+ const response = try generateWithHistory (allocator , stdout , & model , & tokenizer , & template , & history , max_tokens , sampling_params );
180+ if (response ) | resp | {
181+ try history .addMessage (.assistant , resp );
182+ allocator .free (resp );
183+ }
76184 }
77185
78186 // Interactive loop
@@ -87,13 +195,139 @@ pub fn runChat(allocator: std.mem.Allocator, model_path: []const u8, initial_pro
87195 if (trimmed .len == 0 ) continue ;
88196 if (std .mem .eql (u8 , trimmed , "quit" ) or std .mem .eql (u8 , trimmed , "exit" )) break ;
89197
90- try generateWithTemplate (allocator , stdout , & model , & tokenizer , & template , system_prompt , trimmed , max_tokens , sampling_params );
198+ // Handle commands
199+ if (std .mem .eql (u8 , trimmed , "/clear" )) {
200+ history .clear ();
201+ try history .addMessage (.system , system_prompt );
202+ model .resetKVCache ();
203+ try stdout .print ("[History cleared]\n\n " , .{});
204+ continue ;
205+ }
206+
207+ if (std .mem .eql (u8 , trimmed , "/history" )) {
208+ try stdout .print ("[{d} messages in history]\n\n " , .{history .getMessageCount ()});
209+ continue ;
210+ }
211+
212+ // Add user message to history
213+ try history .addMessage (.user , trimmed );
214+
215+ // Generate response with full history
216+ const response = try generateWithHistory (allocator , stdout , & model , & tokenizer , & template , & history , max_tokens , sampling_params );
217+
218+ // Add assistant response to history
219+ if (response ) | resp | {
220+ try history .addMessage (.assistant , resp );
221+ allocator .free (resp );
222+ }
91223 }
92224
93225 try stdout .print ("Goodbye!\n " , .{});
94226}
95227
96- // Generate response with chat template and streaming output
/// Generates one assistant reply from the full conversation history,
/// streaming each decoded token to `writer` as it is produced.
///
/// The history is rendered with `template`, tokenized, and fed through the
/// model from position 0 after resetting the KV cache. Up to `max_tokens`
/// tokens are then sampled with `params`.
///
/// Returns the reply text, allocated with `allocator` and owned by the
/// caller. Returns `null` (after writing a bracketed notice where one exists)
/// when tokenization fails, when prompt plus `max_tokens` does not fit the
/// model's context length, when a prefill forward pass fails, or when the
/// prompt produced no tokens. Writer errors and allocation failures from
/// formatting or collecting the reply are propagated.
fn generateWithHistory(
    allocator: std.mem.Allocator,
    writer: anytype,
    model: *model_mod.FullModel,
    tokenizer: *tokenizer_mod.Tokenizer,
    template: *const ChatTemplate,
    history: *const ConversationHistory,
    max_tokens: u32,
    params: SamplingParams,
) !?[]u8 {
    // Render the full conversation into a single prompt string.
    const formatted = try history.formatForModel(allocator, template);
    defer allocator.free(formatted);

    try writer.print("Assistant: ", .{});
    var gen_timer = try std.time.Timer.start();

    const tokens = tokenizer.encode(allocator, formatted) catch {
        try writer.print("[tokenization error]\n", .{});
        return null;
    };
    defer allocator.free(tokens);

    // Compare as a sum rather than `context_length - max_tokens`: the
    // subtraction underflows when max_tokens exceeds the context length.
    const required_len: usize = tokens.len +| @as(usize, max_tokens);
    if (required_len > model.config.context_length) {
        try writer.print("[context too long, use /clear]\n", .{});
        return null;
    }

    // The whole history is re-processed each turn, so start from a clean cache.
    model.resetKVCache();

    // Single owner for the most recent logits buffer. The defer releases
    // whatever is still held on every exit path, including when the loop
    // below stops because the token budget ran out.
    var pending_logits: ?[]f32 = null;
    defer if (pending_logits) |leftover| allocator.free(leftover);

    // Prefill: only the logits of the final prompt token are needed.
    for (tokens, 0..) |token, pos| {
        if (pending_logits) |stale| {
            allocator.free(stale);
            pending_logits = null;
        }
        pending_logits = model.forward(token, pos) catch {
            try writer.print("[forward error]\n", .{});
            return null;
        };
    }
    // Empty prompt: nothing to sample from.
    if (pending_logits == null) return null;

    // Reply text accumulated for the caller to store in history.
    var response = std.ArrayList(u8).init(allocator);
    errdefer response.deinit();

    var generated: u32 = 0;
    var current_pos = tokens.len;

    while (generated < max_tokens) : (generated += 1) {
        const logits = pending_logits orelse break;

        // Fall back to greedy argmax if the sampler fails.
        const sampled_token = inference.sampleWithParams(allocator, logits, params) catch blk: {
            var max_idx: u32 = 0;
            var max_val: f32 = logits[0];
            for (logits[1..], 1..) |l, i| {
                if (l > max_val) {
                    max_val = l;
                    max_idx = @intCast(i);
                }
            }
            break :blk max_idx;
        };

        allocator.free(logits);
        pending_logits = null;

        if (sampled_token == tokenizer.eos_token) break;

        // Track ownership separately from the text that is displayed, so the
        // static fallback used on decode failure is never passed to `free`.
        const owned_text: ?[]const u8 = tokenizer.decode(allocator, &[_]u32{sampled_token}) catch null;
        defer {
            if (owned_text) |text| {
                if (text.len > 0) allocator.free(text);
            }
        }
        const decoded: []const u8 = owned_text orelse " ";

        // Stream to the user and collect for history.
        try writer.print("{s}", .{decoded});
        try response.appendSlice(decoded);

        // Stop on textual end-of-turn markers.
        if (std.mem.indexOf(u8, decoded, "</s>") != null) break;
        if (std.mem.indexOf(u8, decoded, "<|") != null) break;

        // Feed the sampled token back to obtain the next logits.
        pending_logits = model.forward(sampled_token, current_pos) catch break;
        current_pos += 1;
    }
    try writer.print("\n", .{});

    // The timer was started before tokenization, so the rate includes prefill.
    const gen_time = gen_timer.read();
    const tok_per_sec = @as(f64, @floatFromInt(generated)) / (@as(f64, @floatFromInt(gen_time)) / 1e9);
    try writer.print("[{d} tokens, {d:.1} tok/s, history: {d} msgs]\n\n", .{ generated, tok_per_sec, history.getMessageCount() });

    return try response.toOwnedSlice();
}
329+
330+ // Generate response with chat template and streaming output (legacy single-turn)
97331fn generateWithTemplate (
98332 allocator : std.mem.Allocator ,
99333 writer : anytype ,
0 commit comments