Skip to content

Commit e8265c9

Browse files
gHashTagona-agent
andcommitted
Add temperature + top-p (nucleus) sampling
- Temperature scaling for logits diversity - Top-p nucleus sampling (sorted probability cutoff) - CLI params: --temperature (default 0.7), --top-p (default 0.9) - SamplingParams struct with temperature, top_p, top_k, repeat_penalty Result: Model now gives diverse, relevant responses instead of repetitive 'You are a person...' output. Co-authored-by: Ona <no-reply@ona.com>
1 parent cb7358e commit e8265c9

4 files changed

Lines changed: 168 additions & 17 deletions

File tree

bin/vibee

190 KB
Binary file not shown.

src/vibeec/gen_cmd.zig

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ pub fn main() !void {
4444
var model_path: ?[]const u8 = null;
4545
var prompt: ?[]const u8 = null;
4646
var max_tokens: u32 = 100;
47+
var temperature: f32 = 0.7;
48+
var top_p: f32 = 0.9;
4749

4850
var i: usize = 2;
4951
while (i < args.len) : (i += 1) {
@@ -56,6 +58,12 @@ pub fn main() !void {
5658
} else if (std.mem.eql(u8, args[i], "--max-tokens") and i + 1 < args.len) {
5759
max_tokens = std.fmt.parseInt(u32, args[i + 1], 10) catch 100;
5860
i += 1;
61+
} else if (std.mem.eql(u8, args[i], "--temperature") and i + 1 < args.len) {
62+
temperature = std.fmt.parseFloat(f32, args[i + 1]) catch 0.7;
63+
i += 1;
64+
} else if (std.mem.eql(u8, args[i], "--top-p") and i + 1 < args.len) {
65+
top_p = std.fmt.parseFloat(f32, args[i + 1]) catch 0.9;
66+
i += 1;
5967
}
6068
}
6169

@@ -64,7 +72,7 @@ pub fn main() !void {
6472
return;
6573
}
6674

67-
try gguf_chat.runChat(allocator, model_path.?, prompt, max_tokens);
75+
try gguf_chat.runChat(allocator, model_path.?, prompt, max_tokens, temperature, top_p);
6876
} else if (std.mem.eql(u8, command, "help") or std.mem.eql(u8, command, "--help")) {
6977
printUsage();
7078
} else {
@@ -86,6 +94,8 @@ fn printUsage() void {
8694
\\ vibeec chat --model <path.gguf> [options] Chat with GGUF model (SIMD optimized)
8795
\\ --prompt "text" Initial prompt
8896
\\ --max-tokens N Max tokens to generate (default: 100)
97+
\\ --temperature F Sampling temperature (default: 0.7)
98+
\\ --top-p F Top-p nucleus sampling (default: 0.9)
8999
\\ vibeec help Show this help
90100
\\
91101
, .{});

src/vibeec/gguf_chat.zig

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,28 @@ const inference = @import("gguf_inference.zig");
1212
// Chat template for formatting prompts
1313
const ChatTemplate = tokenizer_mod.ChatTemplate;
1414

15+
// Sampling parameters struct
16+
const SamplingParams = inference.SamplingParams;
17+
1518
// Entry point for CLI chat command
16-
pub fn runChat(allocator: std.mem.Allocator, model_path: []const u8, initial_prompt: ?[]const u8, max_tokens: u32) !void {
19+
pub fn runChat(allocator: std.mem.Allocator, model_path: []const u8, initial_prompt: ?[]const u8, max_tokens: u32, temperature: f32, top_p: f32) !void {
1720
const stdout = std.io.getStdOut().writer();
1821

1922
try stdout.print("\n", .{});
2023
try stdout.print("╔══════════════════════════════════════════════════════════════╗\n", .{});
2124
try stdout.print("║ TRINITY CHAT - SIMD Optimized LLM ║\n", .{});
22-
try stdout.print("║ Chat Template + Streaming Output\n", .{});
25+
try stdout.print("║ Temperature + Top-p Sampling \n", .{});
2326
try stdout.print("╚══════════════════════════════════════════════════════════════╝\n", .{});
2427
try stdout.print("\n", .{});
2528

29+
// Create sampling params
30+
const sampling_params = SamplingParams{
31+
.temperature = temperature,
32+
.top_p = top_p,
33+
.top_k = 40,
34+
.repeat_penalty = 1.1,
35+
};
36+
2637
// Load model
2738
std.debug.print("Loading model: {s}\n", .{model_path});
2839
var model = model_mod.FullModel.init(allocator, model_path) catch |err| {
@@ -56,11 +67,12 @@ pub fn runChat(allocator: std.mem.Allocator, model_path: []const u8, initial_pro
5667

5768
std.debug.print("Chat template: TinyLlama (ChatML format)\n", .{});
5869
std.debug.print("System: {s}\n", .{system_prompt});
70+
std.debug.print("Sampling: temperature={d:.2}, top_p={d:.2}\n", .{sampling_params.temperature, sampling_params.top_p});
5971
std.debug.print("\nReady! Type your message (or 'quit' to exit):\n\n", .{});
6072

6173
// Handle initial prompt if provided
6274
if (initial_prompt) |prompt| {
63-
try generateWithTemplate(allocator, stdout, &model, &tokenizer, &template, system_prompt, prompt, max_tokens);
75+
try generateWithTemplate(allocator, stdout, &model, &tokenizer, &template, system_prompt, prompt, max_tokens, sampling_params);
6476
}
6577

6678
// Interactive loop
@@ -75,7 +87,7 @@ pub fn runChat(allocator: std.mem.Allocator, model_path: []const u8, initial_pro
7587
if (trimmed.len == 0) continue;
7688
if (std.mem.eql(u8, trimmed, "quit") or std.mem.eql(u8, trimmed, "exit")) break;
7789

78-
try generateWithTemplate(allocator, stdout, &model, &tokenizer, &template, system_prompt, trimmed, max_tokens);
90+
try generateWithTemplate(allocator, stdout, &model, &tokenizer, &template, system_prompt, trimmed, max_tokens, sampling_params);
7991
}
8092

8193
try stdout.print("Goodbye!\n", .{});
@@ -91,6 +103,7 @@ fn generateWithTemplate(
91103
system: []const u8,
92104
user_input: []const u8,
93105
max_tokens: u32,
106+
params: SamplingParams,
94107
) !void {
95108
// Format prompt with chat template
96109
const formatted = try template.formatPrompt(allocator, system, user_input);
@@ -128,24 +141,28 @@ fn generateWithTemplate(
128141
var last_token: u32 = 0;
129142

130143
while (generated < max_tokens) : (generated += 1) {
131-
// Sample next token (greedy)
132-
var max_idx: u32 = 0;
133-
var max_val: f32 = current_logits[0];
134-
for (current_logits[1..], 1..) |l, i| {
135-
if (l > max_val) {
136-
max_val = l;
137-
max_idx = @intCast(i);
144+
// Sample next token with temperature + top-p
145+
const sampled_token = inference.sampleWithParams(allocator, current_logits, params) catch blk: {
146+
// Fallback to greedy on error
147+
var max_idx: u32 = 0;
148+
var max_val: f32 = current_logits[0];
149+
for (current_logits[1..], 1..) |l, i| {
150+
if (l > max_val) {
151+
max_val = l;
152+
max_idx = @intCast(i);
153+
}
138154
}
139-
}
155+
break :blk max_idx;
156+
};
140157

141158
// Free current logits
142159
allocator.free(current_logits);
143160

144161
// Check for EOS
145-
if (max_idx == tokenizer.eos_token) break;
162+
if (sampled_token == tokenizer.eos_token) break;
146163

147164
// Decode and stream output immediately
148-
const decoded = tokenizer.decode(allocator, &[_]u32{max_idx}) catch " ";
165+
const decoded = tokenizer.decode(allocator, &[_]u32{sampled_token}) catch " ";
149166
defer if (decoded.len > 0) allocator.free(decoded);
150167

151168
// Stream: print immediately without buffering
@@ -156,7 +173,7 @@ fn generateWithTemplate(
156173
if (std.mem.indexOf(u8, decoded, "<|") != null) break;
157174

158175
// Get next logits
159-
last_token = max_idx;
176+
last_token = sampled_token;
160177
current_logits = model.forward(last_token, current_pos) catch break;
161178
current_pos += 1;
162179
}

src/vibeec/gguf_inference.zig

Lines changed: 125 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ pub fn softmax(output: []f32, input: []const f32) void {
144144
}
145145
}
146146

147-
// Sample from probability distribution
147+
// Sample from probability distribution (basic)
148148
pub fn sample(probs: []const f32, temperature: f32) u32 {
149149
if (temperature == 0.0) {
150150
// Greedy sampling
@@ -174,6 +174,130 @@ pub fn sample(probs: []const f32, temperature: f32) u32 {
174174
return @intCast(probs.len - 1);
175175
}
176176

177+
// ═══════════════════════════════════════════════════════════════════════════════
178+
// ADVANCED SAMPLING - Temperature + Top-p (Nucleus) Sampling
179+
// ═══════════════════════════════════════════════════════════════════════════════
180+
181+
/// Sampling parameters
182+
pub const SamplingParams = struct {
183+
temperature: f32 = 0.7,
184+
top_p: f32 = 0.9,
185+
top_k: u32 = 40,
186+
repeat_penalty: f32 = 1.1,
187+
};
188+
189+
/// Apply temperature scaling to logits
190+
pub fn applyTemperature(logits: []f32, temperature: f32) void {
191+
if (temperature <= 0.0 or temperature == 1.0) return;
192+
193+
const inv_temp = 1.0 / temperature;
194+
for (logits) |*l| {
195+
l.* *= inv_temp;
196+
}
197+
}
198+
199+
/// Sample with temperature and top-p (nucleus sampling)
200+
/// Returns token index
201+
pub fn sampleWithParams(allocator: std.mem.Allocator, logits: []f32, params: SamplingParams) !u32 {
202+
const n = logits.len;
203+
204+
// Apply temperature
205+
if (params.temperature > 0.0 and params.temperature != 1.0) {
206+
applyTemperature(logits, params.temperature);
207+
}
208+
209+
// Greedy if temperature is 0
210+
if (params.temperature == 0.0) {
211+
var max_idx: u32 = 0;
212+
var max_val: f32 = logits[0];
213+
for (logits[1..], 1..) |l, i| {
214+
if (l > max_val) {
215+
max_val = l;
216+
max_idx = @intCast(i);
217+
}
218+
}
219+
return max_idx;
220+
}
221+
222+
// Convert to probabilities with softmax
223+
var max_logit: f32 = logits[0];
224+
for (logits[1..]) |l| {
225+
if (l > max_logit) max_logit = l;
226+
}
227+
228+
var sum: f32 = 0.0;
229+
for (logits) |*l| {
230+
l.* = @exp(l.* - max_logit);
231+
sum += l.*;
232+
}
233+
234+
const inv_sum = 1.0 / sum;
235+
for (logits) |*l| {
236+
l.* *= inv_sum;
237+
}
238+
239+
// Top-p (nucleus) sampling
240+
if (params.top_p < 1.0) {
241+
// Create index array for sorting
242+
const indices = try allocator.alloc(u32, n);
243+
defer allocator.free(indices);
244+
for (indices, 0..) |*idx, i| {
245+
idx.* = @intCast(i);
246+
}
247+
248+
// Sort indices by probability (descending)
249+
std.mem.sort(u32, indices, logits, struct {
250+
fn lessThan(probs: []f32, a: u32, b: u32) bool {
251+
return probs[a] > probs[b]; // Descending
252+
}
253+
}.lessThan);
254+
255+
// Find cutoff for top-p
256+
var cumsum: f32 = 0.0;
257+
var cutoff_idx: usize = n;
258+
for (indices, 0..) |idx, i| {
259+
cumsum += logits[idx];
260+
if (cumsum >= params.top_p) {
261+
cutoff_idx = i + 1;
262+
break;
263+
}
264+
}
265+
266+
// Zero out tokens below cutoff
267+
for (indices[cutoff_idx..]) |idx| {
268+
logits[idx] = 0.0;
269+
}
270+
271+
// Renormalize
272+
sum = 0.0;
273+
for (logits) |l| {
274+
sum += l;
275+
}
276+
if (sum > 0.0) {
277+
const inv = 1.0 / sum;
278+
for (logits) |*l| {
279+
l.* *= inv;
280+
}
281+
}
282+
}
283+
284+
// Sample from distribution
285+
var prng = std.Random.DefaultPrng.init(@intCast(std.time.milliTimestamp()));
286+
const random = prng.random();
287+
const r = random.float(f32);
288+
289+
var cumsum: f32 = 0.0;
290+
for (logits, 0..) |p, i| {
291+
cumsum += p;
292+
if (r < cumsum) {
293+
return @intCast(i);
294+
}
295+
}
296+
297+
// Fallback to last token
298+
return @intCast(n - 1);
299+
}
300+
177301
// GGUF Model for inference
178302
pub const GGUFModel = struct {
179303
allocator: std.mem.Allocator,

0 commit comments

Comments
 (0)