@@ -185,6 +185,15 @@ pub const HttpServer = struct {
185185 }
186186
187187 fn handleChatCompletion (self : * HttpServer , connection : * std.net.Server.Connection , body : []const u8 , model : * FullModel , tokenizer : * Tokenizer ) ! void {
188+ // Check if streaming is requested
189+ const is_streaming = std .mem .indexOf (u8 , body , "\" stream\" :true" ) != null or
190+ std .mem .indexOf (u8 , body , "\" stream\" : true" ) != null ;
191+
192+ if (is_streaming ) {
193+ try self .handleStreamingCompletion (connection , body , model , tokenizer );
194+ return ;
195+ }
196+
188197 // Extract prompt from JSON body
189198 var prompt : []const u8 = "Hello" ;
190199
@@ -201,7 +210,7 @@ pub const HttpServer = struct {
201210
202211 std .debug .print (" Prompt: {s}\n " , .{prompt });
203212
204- // Generate response
213+ // Generate response with system prompt for better quality
205214 const sampling = SamplingParams {
206215 .temperature = 0.7 ,
207216 .top_p = 0.9 ,
@@ -213,8 +222,16 @@ pub const HttpServer = struct {
213222 var generated : ? []u8 = null ;
214223 defer if (generated ) | g | self .allocator .free (g );
215224
225+ // Build full prompt with system instruction
226+ const system_prompt = "You are TRINITY, a helpful AI assistant. Be concise and direct." ;
227+ const full_prompt = std .fmt .allocPrint (self .allocator ,
228+ "<|system|>\n {s}<|end|>\n <|user|>\n {s}<|end|>\n <|assistant|>\n " ,
229+ .{system_prompt , prompt }
230+ ) catch prompt ;
231+ defer if (full_prompt .ptr != prompt .ptr ) self .allocator .free (full_prompt );
232+
216233 // Tokenize and generate
217- const tokens = tokenizer .encode (self .allocator , prompt ) catch null ;
234+ const tokens = tokenizer .encode (self .allocator , full_prompt ) catch null ;
218235 defer if (tokens ) | t | self .allocator .free (t );
219236
220237 if (tokens ) | toks | {
@@ -282,6 +299,105 @@ pub const HttpServer = struct {
282299 try connection .stream .writeAll (json_body );
283300 std .debug .print (" Sent: {d} bytes\n " , .{json_body .len });
284301 }
302+
303+ /// Handle streaming chat completion (SSE)
304+ fn handleStreamingCompletion (self : * HttpServer , connection : * std.net.Server.Connection , body : []const u8 , model : * FullModel , tokenizer : * Tokenizer ) ! void {
305+ // Extract prompt
306+ var prompt : []const u8 = "Hello" ;
307+ if (std .mem .lastIndexOf (u8 , body , "\" content\" " )) | idx | {
308+ const after_key = body [idx + 10.. ];
309+ if (std .mem .indexOf (u8 , after_key , "\" " )) | start | {
310+ const content_start = after_key [start + 1.. ];
311+ if (std .mem .indexOf (u8 , content_start , "\" " )) | end | {
312+ prompt = content_start [0.. end ];
313+ }
314+ }
315+ }
316+
317+ std .debug .print (" Streaming prompt: {s}\n " , .{prompt });
318+
319+ // Send SSE headers
320+ const sse_header =
321+ "HTTP/1.1 200 OK\r \n " ++
322+ "Content-Type: text/event-stream\r \n " ++
323+ "Cache-Control: no-cache\r \n " ++
324+ "Access-Control-Allow-Origin: *\r \n " ++
325+ "Connection: keep-alive\r \n \r \n " ;
326+ try connection .stream .writeAll (sse_header );
327+
328+ // Build prompt with system instruction
329+ const system_prompt = "You are TRINITY, a helpful AI assistant. Be concise and direct." ;
330+ const full_prompt = std .fmt .allocPrint (self .allocator ,
331+ "<|system|>\n {s}<|end|>\n <|user|>\n {s}<|end|>\n <|assistant|>\n " ,
332+ .{system_prompt , prompt }
333+ ) catch prompt ;
334+ defer if (full_prompt .ptr != prompt .ptr ) self .allocator .free (full_prompt );
335+
336+ // Tokenize
337+ const tokens = tokenizer .encode (self .allocator , full_prompt ) catch null ;
338+ defer if (tokens ) | t | self .allocator .free (t );
339+
340+ const sampling = SamplingParams {
341+ .temperature = 0.7 ,
342+ .top_p = 0.9 ,
343+ .top_k = 40 ,
344+ .repeat_penalty = 1.1 ,
345+ };
346+
347+ if (tokens ) | toks | {
348+ // Process input tokens
349+ var pos : usize = 0 ;
350+ for (toks ) | tok | {
351+ _ = model .forward (tok , pos ) catch null ;
352+ pos += 1 ;
353+ }
354+
355+ // Generate and stream tokens
356+ var last_token : u32 = if (toks .len > 0 ) toks [toks .len - 1 ] else 0 ;
357+ var i : usize = 0 ;
358+ while (i < 100 ) : (i += 1 ) {
359+ const logits = model .forward (last_token , pos ) catch break ;
360+ const next_token = inference .sampleWithParams (self .allocator , @constCast (logits ), sampling ) catch break ;
361+
362+ if (next_token == tokenizer .eos_token ) break ;
363+
364+ // Decode single token
365+ const token_arr = [_ ]u32 {next_token };
366+ const token_text = tokenizer .decode (self .allocator , & token_arr ) catch null ;
367+ defer if (token_text ) | t | self .allocator .free (t );
368+
369+ if (token_text ) | text | {
370+ // Escape for JSON
371+ var escaped = std .ArrayList (u8 ).init (self .allocator );
372+ defer escaped .deinit ();
373+ for (text ) | c | {
374+ switch (c ) {
375+ '"' = > escaped .appendSlice ("\\ \" " ) catch break ,
376+ '\\ ' = > escaped .appendSlice ("\\\\ " ) catch break ,
377+ '\n ' = > escaped .appendSlice ("\\ n" ) catch break ,
378+ '\r ' = > escaped .appendSlice ("\\ r" ) catch break ,
379+ else = > escaped .append (c ) catch break ,
380+ }
381+ }
382+
383+ // Send SSE event
384+ const event = std .fmt .allocPrint (self .allocator ,
385+ "data: {{\" choices\" :[{{\" delta\" :{{\" content\" :\" {s}\" }},\" index\" :0}}]}}\n\n "
386+ , .{escaped .items }) catch continue ;
387+ defer self .allocator .free (event );
388+
389+ connection .stream .writeAll (event ) catch break ;
390+ }
391+
392+ last_token = next_token ;
393+ pos += 1 ;
394+ }
395+ }
396+
397+ // Send done event
398+ try connection .stream .writeAll ("data: [DONE]\n\n " );
399+ std .debug .print (" Streaming complete\n " , .{});
400+ }
285401};
286402
287403// ═══════════════════════════════════════════════════════════════════════════════
0 commit comments