@@ -9,17 +9,19 @@ const model_mod = @import("gguf_model.zig");
99const tokenizer_mod = @import ("gguf_tokenizer.zig" );
1010const inference = @import ("gguf_inference.zig" );
1111
12+ // Chat template for formatting prompts
13+ const ChatTemplate = tokenizer_mod .ChatTemplate ;
14+
1215// Entry point for CLI chat command
1316pub fn runChat (allocator : std.mem.Allocator , model_path : []const u8 , initial_prompt : ? []const u8 , max_tokens : u32 ) ! void {
14- _ = initial_prompt ;
15- _ = max_tokens ;
17+ const stdout = std .io .getStdOut ().writer ();
1618
17- std . debug .print ("\n " , .{});
18- std . debug .print ("╔══════════════════════════════════════════════════════════════╗\n " , .{});
19- std . debug .print ("║ TRINITY CHAT - SIMD Optimized LLM ║\n " , .{});
20- std . debug . print ("║ phi^2 + 1/phi^2 = 3 = TRINITY ║\n " , .{});
21- std . debug .print ("╚══════════════════════════════════════════════════════════════╝\n " , .{});
22- std . debug .print ("\n " , .{});
19+ try stdout .print ("\n " , .{});
20+ try stdout .print ("╔══════════════════════════════════════════════════════════════╗\n " , .{});
21+ try stdout .print ("║ TRINITY CHAT - SIMD Optimized LLM ║\n " , .{});
22+ try stdout . print ("║ Chat Template + Streaming Output ║\n " , .{});
23+ try stdout .print ("╚══════════════════════════════════════════════════════════════╝\n " , .{});
24+ try stdout .print ("\n " , .{});
2325
2426 // Load model
2527 std .debug .print ("Loading model: {s}\n " , .{model_path });
@@ -48,76 +50,121 @@ pub fn runChat(allocator: std.mem.Allocator, model_path: []const u8, initial_pro
4850 };
4951 defer tokenizer .deinit ();
5052
51- std .debug .print ("Ready! Type your message (or 'quit' to exit):\n\n " , .{});
53+ // Use TinyLlama chat template
54+ const template = ChatTemplate .TINYLLAMA ;
55+ const system_prompt = "You are a helpful AI assistant." ;
56+
57+ std .debug .print ("Chat template: TinyLlama (ChatML format)\n " , .{});
58+ std .debug .print ("System: {s}\n " , .{system_prompt });
59+ std .debug .print ("\n Ready! Type your message (or 'quit' to exit):\n\n " , .{});
60+
61+ // Handle initial prompt if provided
62+ if (initial_prompt ) | prompt | {
63+ try generateWithTemplate (allocator , stdout , & model , & tokenizer , & template , system_prompt , prompt , max_tokens );
64+ }
5265
5366 // Interactive loop
5467 const stdin = std .io .getStdIn ().reader ();
5568 var buf : [1024 ]u8 = undefined ;
5669
5770 while (true ) {
58- std . debug .print ("User: " , .{});
71+ try stdout .print ("User: " , .{});
5972 const line = stdin .readUntilDelimiter (& buf , '\n ' ) catch break ;
6073 const trimmed = std .mem .trim (u8 , line , " \t \r \n " );
6174
6275 if (trimmed .len == 0 ) continue ;
6376 if (std .mem .eql (u8 , trimmed , "quit" ) or std .mem .eql (u8 , trimmed , "exit" )) break ;
6477
65- // Generate response using full transformer forward pass
66- std .debug .print ("Assistant: " , .{});
67- var gen_timer = try std .time .Timer .start ();
68-
69- const tokens = tokenizer .encode (allocator , trimmed ) catch {
70- std .debug .print ("[tokenization error]\n " , .{});
71- continue ;
72- };
73- defer allocator .free (tokens );
74-
75- // Real generation with transformer
76- var generated : u32 = 0 ;
77- var current_tokens = std .ArrayList (u32 ).init (allocator );
78- defer current_tokens .deinit ();
79- for (tokens ) | t | try current_tokens .append (t );
80-
81- const max_gen : u32 = 50 ;
82- while (generated < max_gen ) : (generated += 1 ) {
83- // Forward pass for last token
84- const pos = current_tokens .items .len - 1 ;
85- const last_token = current_tokens .items [pos ];
86-
87- const logits = model .forward (last_token , pos ) catch {
88- std .debug .print ("[forward error]" , .{});
89- break ;
90- };
91- defer allocator .free (logits );
92-
93- // Sample next token (greedy)
94- var max_idx : u32 = 0 ;
95- var max_val : f32 = logits [0 ];
96- for (logits [1.. ], 1.. ) | l , i | {
97- if (l > max_val ) {
98- max_val = l ;
99- max_idx = @intCast (i );
100- }
101- }
78+ try generateWithTemplate (allocator , stdout , & model , & tokenizer , & template , system_prompt , trimmed , max_tokens );
79+ }
10280
103- // Check for EOS
104- if ( max_idx == tokenizer . eos_token ) break ;
81+ try stdout . print ( "Goodbye! \n " , .{});
82+ }
10583
106- // Decode and print
107- const decoded = tokenizer .decode (allocator , &[_ ]u32 {max_idx }) catch " " ;
108- defer if (decoded .len > 0 ) allocator .free (decoded );
109- std .debug .print ("{s}" , .{decoded });
/// Generate one assistant reply for `user_input` and stream it to `writer`.
///
/// The prompt is built with `template.formatPrompt(system, user_input)`,
/// tokenized, and fed through `model.forward` one token at a time (prefill)
/// after the KV cache has been reset. Decoding is then greedy (argmax) and
/// each decoded token is written to `writer` as soon as it is produced.
///
/// Generation stops when any of the following happens:
///   * `max_tokens` tokens have been produced,
///   * the sampled token equals `tokenizer.eos_token`,
///   * the decoded text contains `</s>` or `<|`,
///   * `model.forward` returns an error during decoding.
///
/// Tokenization and prefill failures are reported on `writer` and the
/// function returns normally; only prompt formatting, timer start-up and
/// writer failures are propagated to the caller.
///
/// Ownership: every logits buffer returned by `model.forward` is released
/// with `allocator` (as the original code did — assumes `forward` allocates
/// from that same allocator; confirm against gguf_model.zig).
fn generateWithTemplate(
    allocator: std.mem.Allocator,
    writer: anytype,
    model: *model_mod.FullModel,
    tokenizer: *tokenizer_mod.Tokenizer,
    template: *const ChatTemplate,
    system: []const u8,
    user_input: []const u8,
    max_tokens: u32,
) !void {
    // Format prompt with chat template
    const formatted = try template.formatPrompt(allocator, system, user_input);
    defer allocator.free(formatted);

    try writer.print("Assistant: ", .{});
    var gen_timer = try std.time.Timer.start();

    // Tokenize formatted prompt
    const tokens = tokenizer.encode(allocator, formatted) catch {
        try writer.print("[tokenization error]\n", .{});
        return;
    };
    defer allocator.free(tokens);

    // Reset KV cache for new conversation
    model.resetKVCache();

    // `logits` is the single owner of the most recent logits buffer. It is
    // set to null whenever the buffer is released, so this one defer frees
    // whatever is still live on every exit path (max_tokens reached,
    // max_tokens == 0, writer error, early return) without double-freeing.
    var logits: ?[]f32 = null;
    defer if (logits) |live| allocator.free(live);

    // Process prompt tokens (prefill) - build up KV cache
    for (tokens, 0..) |token, pos| {
        if (logits) |previous| {
            allocator.free(previous);
            logits = null;
        }
        logits = model.forward(token, pos) catch {
            try writer.print("[forward error]\n", .{});
            return;
        };
    }

    // An empty token sequence produced no logits to sample from. End the
    // "Assistant: " line so the next prompt starts on a fresh line.
    if (logits == null) {
        try writer.print("\n", .{});
        return;
    }

    // Generate tokens with streaming output
    var generated: u32 = 0;
    var current_pos = tokens.len;

    while (generated < max_tokens) : (generated += 1) {
        const current = logits orelse break;
        if (current.len == 0) break;

        // Sample next token (greedy): first index holding the largest logit
        var max_idx: u32 = 0;
        var max_val: f32 = current[0];
        for (current[1..], 1..) |value, i| {
            if (value > max_val) {
                max_val = value;
                max_idx = @intCast(i);
            }
        }

        // The buffer is no longer needed once the token is chosen
        allocator.free(current);
        logits = null;

        // Check for EOS
        if (max_idx == tokenizer.eos_token) break;

        // Decode the single token. A decode failure is treated as empty
        // text (best effort); keeping the result optional guarantees that
        // only an actually allocated buffer is ever handed to `free`.
        const decoded: ?[]const u8 = tokenizer.decode(allocator, &[_]u32{max_idx}) catch null;
        defer if (decoded) |owned| allocator.free(owned);
        const text: []const u8 = decoded orelse "";

        // Stream: print immediately without buffering
        try writer.print("{s}", .{text});

        // Check for </s> or end markers in decoded text
        if (std.mem.indexOf(u8, text, "</s>") != null) break;
        if (std.mem.indexOf(u8, text, "<|") != null) break;

        // Get next logits by feeding the sampled token back into the model
        logits = model.forward(max_idx, current_pos) catch break;
        current_pos += 1;
    }
    try writer.print("\n", .{});

    // Throughput is measured from just before tokenization, so it includes
    // prefill time as well as decoding time.
    const gen_time = gen_timer.read();
    const tok_per_sec = @as(f64, @floatFromInt(generated)) / (@as(f64, @floatFromInt(gen_time)) / 1e9);
    try writer.print("[{d} tokens, {d:.1} tok/s]\n\n", .{ generated, tok_per_sec });
}
122169
123170pub fn main () ! void {
0 commit comments