Skip to content

Commit 424fb24

Browse files
authored
Fix stop words (#22)
* Updated README * Updated chat to reflect llama.cpp approach (stop conditions) * Update pkg/llamacpp/chat.go
1 parent bbffeb6 commit 424fb24

6 files changed

Lines changed: 287 additions & 28 deletions

File tree

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,11 @@ go.work.sum
2626

2727
# Editor/IDE
2828
.vscode/
29+
.claude
2930

3031
# Build artifacts
3132
build/
3233

3334
# Other
34-
/gollama
35+
/go-llama
3536
.DS*

pkg/llamacpp/chat.go

Lines changed: 66 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -281,25 +281,23 @@ var defaultStopSequences = []string{
281281

282282
type stopMarkerFilter struct {
283283
stops []string
284-
maxLen int
285284
buffer string
286285
stopped bool
287286
}
288287

289288
func newStopMarkerFilter(stops []string) *stopMarkerFilter {
290-
maxLen := 0
291-
for _, s := range stops {
292-
if len(s) > maxLen {
293-
maxLen = len(s)
294-
}
295-
}
296-
return &stopMarkerFilter{stops: stops, maxLen: maxLen}
289+
return &stopMarkerFilter{stops: stops}
297290
}
298291

299292
func (f *stopMarkerFilter) Stopped() bool {
300293
return f.stopped
301294
}
302295

296+
// Process handles incoming tokens using llama.cpp's two-phase approach:
297+
// 1. Check for full stop sequences
298+
// 2. If no full stop, check if text ends with a prefix of any stop word (partial match)
299+
//
300+
// Returns the text safe to send and whether generation should stop.
303301
func (f *stopMarkerFilter) Process(text string) (string, bool) {
304302
if f.stopped {
305303
return "", true
@@ -309,25 +307,28 @@ func (f *stopMarkerFilter) Process(text string) (string, bool) {
309307
}
310308

311309
combined := f.buffer + text
312-
idx, found := indexAnyStop(combined, f.stops)
313-
if found {
310+
311+
// Phase 1: Check for full stop sequence
312+
if idx, found := indexAnyStop(combined, f.stops); found {
314313
f.stopped = true
314+
f.buffer = ""
315315
return combined[:idx], true
316316
}
317317

318-
if f.maxLen <= 1 {
319-
f.buffer = ""
320-
return combined, false
321-
}
322-
keep := f.maxLen - 1
323-
if len(combined) <= keep {
324-
f.buffer = combined
325-
return "", false
318+
// Phase 2: Check if text ends with a partial stop sequence
319+
// If so, withhold the partial match in the buffer
320+
if partialPos := findPartialStop(combined, f.stops); partialPos != -1 {
321+
// Text ends with a prefix of a stop word - withhold from partialPos onward
322+
f.buffer = combined[partialPos:]
323+
if partialPos == 0 {
324+
return "", false
325+
}
326+
return combined[:partialPos], false
326327
}
327-
cut := len(combined) - keep
328-
out := combined[:cut]
329-
f.buffer = combined[cut:]
330-
return out, false
328+
329+
// No full or partial match - safe to send everything
330+
f.buffer = ""
331+
return combined, false
331332
}
332333

333334
func (f *stopMarkerFilter) Flush() string {
@@ -339,6 +340,49 @@ func (f *stopMarkerFilter) Flush() string {
339340
return out
340341
}
341342

343+
// findPartialStop checks if s ends with a prefix of any stop word.
344+
// Returns the position where the partial match starts, or -1 if no partial match.
345+
// This mirrors llama.cpp's string_find_partial_stop function.
346+
func findPartialStop(s string, stops []string) int {
347+
if len(s) == 0 {
348+
return -1
349+
}
350+
351+
lastChar := s[len(s)-1]
352+
353+
for _, stop := range stops {
354+
if len(stop) == 0 {
355+
continue
356+
}
357+
358+
// Check each possible prefix of the stop word, starting from longest.
359+
maxPrefixLen := 0
360+
361+
for _, stop := range stops {
362+
if len(stop) == 0 {
363+
continue
364+
}
365+
366+
// Check each possible prefix of the stop word, starting from longest.
367+
// We look for prefixes that end with the last character of s.
368+
for prefixLen := len(stop); prefixLen >= 1; prefixLen-- {
369+
if stop[prefixLen-1] != lastChar {
370+
continue
371+
}
372+
373+
prefix := stop[:prefixLen]
374+
if strings.HasSuffix(s, prefix) && prefixLen > maxPrefixLen {
375+
maxPrefixLen = prefixLen
376+
}
377+
}
378+
}
379+
380+
if maxPrefixLen > 0 {
381+
return len(s) - maxPrefixLen
382+
}
383+
return -1
384+
}
385+
342386
func indexAnyStop(s string, stops []string) (int, bool) {
343387
idx := -1
344388
for _, stop := range stops {

pkg/llamacpp/chat_test.go

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,197 @@ func TestChatLogsOutput(t *testing.T) {
150150
t.Logf("chat response: %q", resp.Message.Content)
151151
t.Logf("usage: input=%d output=%d total=%d", resp.Usage.InputTokens, resp.Usage.OutputTokens, resp.Usage.TotalTokens())
152152
}
153+
154+
func TestFindPartialStop(t *testing.T) {
155+
tests := []struct {
156+
name string
157+
text string
158+
stops []string
159+
expected int
160+
}{
161+
{
162+
name: "no partial match",
163+
text: "hello world",
164+
stops: []string{"</s>", "<|end|>"},
165+
expected: -1,
166+
},
167+
{
168+
name: "full stop not detected as partial",
169+
text: "hello</s>",
170+
stops: []string{"</s>"},
171+
expected: 5, // partial match starts at position 5
172+
},
173+
{
174+
name: "partial match single char",
175+
text: "hello<",
176+
stops: []string{"</s>"},
177+
expected: 5,
178+
},
179+
{
180+
name: "partial match two chars",
181+
text: "hello</",
182+
stops: []string{"</s>"},
183+
expected: 5,
184+
},
185+
{
186+
name: "partial match three chars",
187+
text: "hello</s",
188+
stops: []string{"</s>"},
189+
expected: 5,
190+
},
191+
{
192+
name: "partial match with pipe",
193+
text: "test<|",
194+
stops: []string{"<|end|>"},
195+
expected: 4,
196+
},
197+
{
198+
name: "partial match longer prefix",
199+
text: "test<|end",
200+
stops: []string{"<|end|>"},
201+
expected: 4,
202+
},
203+
{
204+
name: "partial match almost complete",
205+
text: "test<|end|",
206+
stops: []string{"<|end|>"},
207+
expected: 4,
208+
},
209+
{
210+
name: "multiple stops first matches",
211+
text: "hello<",
212+
stops: []string{"</s>", "<|end|>"},
213+
expected: 5,
214+
},
215+
{
216+
name: "empty text",
217+
text: "",
218+
stops: []string{"</s>"},
219+
expected: -1,
220+
},
221+
{
222+
name: "empty stops",
223+
text: "hello",
224+
stops: []string{},
225+
expected: -1,
226+
},
227+
{
228+
name: "text is just partial",
229+
text: "</",
230+
stops: []string{"</s>"},
231+
expected: 0,
232+
},
233+
{
234+
name: "no match when char not in stop",
235+
text: "hellox",
236+
stops: []string{"</s>"},
237+
expected: -1,
238+
},
239+
}
240+
241+
for _, tt := range tests {
242+
t.Run(tt.name, func(t *testing.T) {
243+
result := findPartialStop(tt.text, tt.stops)
244+
if result != tt.expected {
245+
t.Errorf("findPartialStop(%q, %v) = %d, want %d", tt.text, tt.stops, result, tt.expected)
246+
}
247+
})
248+
}
249+
}
250+
251+
func TestStopMarkerFilter(t *testing.T) {
252+
tests := []struct {
253+
name string
254+
stops []string
255+
tokens []string
256+
expectedChunks []string
257+
expectedStop bool
258+
}{
259+
{
260+
name: "no stop sequences",
261+
stops: []string{"</s>"},
262+
tokens: []string{"hello", " ", "world"},
263+
expectedChunks: []string{"hello", " ", "world"},
264+
expectedStop: false,
265+
},
266+
{
267+
name: "full stop in single token",
268+
stops: []string{"</s>"},
269+
tokens: []string{"hello", "</s>", "extra"},
270+
expectedChunks: []string{"hello", ""},
271+
expectedStop: true,
272+
},
273+
{
274+
name: "stop split across tokens",
275+
stops: []string{"</s>"},
276+
tokens: []string{"hello<", "/s>more"},
277+
expectedChunks: []string{"hello", ""},
278+
expectedStop: true,
279+
},
280+
{
281+
name: "partial withholding then release",
282+
stops: []string{"</s>"},
283+
tokens: []string{"hello<", "notastop"},
284+
expectedChunks: []string{"hello", "<notastop"},
285+
expectedStop: false,
286+
},
287+
{
288+
name: "partial withholding with longer sequence",
289+
stops: []string{"<|end|>"},
290+
tokens: []string{"test<|", "end|>done"},
291+
expectedChunks: []string{"test", ""},
292+
expectedStop: true,
293+
},
294+
{
295+
name: "buffer accumulates partial",
296+
stops: []string{"</s>"},
297+
tokens: []string{"a<", "/", "s>"},
298+
expectedChunks: []string{"a", "", ""},
299+
expectedStop: true,
300+
},
301+
{
302+
name: "stop at very beginning",
303+
stops: []string{"</s>"},
304+
tokens: []string{"</s>rest"},
305+
expectedChunks: []string{""},
306+
expectedStop: true,
307+
},
308+
}
309+
310+
for _, tt := range tests {
311+
t.Run(tt.name, func(t *testing.T) {
312+
filter := newStopMarkerFilter(tt.stops)
313+
var chunks []string
314+
315+
for _, token := range tt.tokens {
316+
chunk, stopped := filter.Process(token)
317+
chunks = append(chunks, chunk)
318+
if stopped {
319+
break
320+
}
321+
}
322+
323+
// If not stopped, flush remaining buffer
324+
if !filter.Stopped() {
325+
if tail := filter.Flush(); tail != "" {
326+
chunks[len(chunks)-1] += tail
327+
}
328+
}
329+
330+
if len(chunks) != len(tt.expectedChunks) {
331+
t.Errorf("got %d chunks %v, want %d chunks %v", len(chunks), chunks, len(tt.expectedChunks), tt.expectedChunks)
332+
return
333+
}
334+
335+
for i, chunk := range chunks {
336+
if chunk != tt.expectedChunks[i] {
337+
t.Errorf("chunk[%d] = %q, want %q", i, chunk, tt.expectedChunks[i])
338+
}
339+
}
340+
341+
if filter.Stopped() != tt.expectedStop {
342+
t.Errorf("Stopped() = %v, want %v", filter.Stopped(), tt.expectedStop)
343+
}
344+
})
345+
}
346+
}

pkg/llamacpp/completion.go

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,21 @@ func (l *Llama) Complete(ctx context.Context, req schema.CompletionRequest, onCh
3434

3535
opts := buildCompletionOptions(ctx, req)
3636
var callbackErr error
37+
var stopFilter *stopMarkerFilter
3738
if onChunk != nil {
39+
stopFilter = newStopMarkerFilter(opts.StopWords)
3840
opts.OnToken = func(token string) bool {
3941
if callbackErr != nil {
4042
return false
4143
}
42-
if err := onChunk(schema.CompletionChunk{Text: token}); err != nil {
43-
callbackErr = err
44+
filtered, stopped := stopFilter.Process(token)
45+
if filtered != "" {
46+
if err := onChunk(schema.CompletionChunk{Text: filtered}); err != nil {
47+
callbackErr = err
48+
return false
49+
}
50+
}
51+
if stopped {
4452
return false
4553
}
4654
return true
@@ -55,6 +63,18 @@ func (l *Llama) Complete(ctx context.Context, req schema.CompletionRequest, onCh
5563
return callbackErr
5664
}
5765

66+
// Flush any buffered content if we didn't hit a stop
67+
if onChunk != nil && stopFilter != nil && !stopFilter.Stopped() {
68+
if tail := stopFilter.Flush(); tail != "" {
69+
if err := onChunk(schema.CompletionChunk{Text: tail}); err != nil {
70+
return err
71+
}
72+
}
73+
}
74+
75+
// Trim stop sequences from final text
76+
text, _ = trimAtStop(text, opts.StopWords)
77+
5878
usage, err := completionUsage(task.Model(), req.Prompt, text)
5979
if err != nil {
6080
return err

pkg/llamacpp/httphandler/model.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,8 @@ func modelLoadUnload(w http.ResponseWriter, r *http.Request, llamaInstance *llam
143143
var req schema.LoadModelRequest
144144
if err := httprequest.Read(r, &req); err != nil {
145145
return httpresponse.Error(w, httpresponse.ErrBadRequest.With(err.Error()))
146-
} else {
147-
req.Name = r.PathValue("id")
148146
}
147+
req.Name = r.PathValue("id")
149148

150149
// Check if this is an unload request
151150
isUnload := req.Load != nil && !*req.Load

pkg/llamacpp/httphandler/model_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,8 @@ func TestModelLoad_NonExistentModel(t *testing.T) {
235235
router := http.NewServeMux()
236236
RegisterModelHandlers(router, "/api", llama, noopMiddleware())
237237

238-
req := httptest.NewRequest(http.MethodPost, "/api/model/nonexistent", nil)
238+
req := httptest.NewRequest(http.MethodPost, "/api/model/nonexistent", strings.NewReader("{}"))
239+
req.Header.Set("Content-Type", "application/json")
239240
rw := httptest.NewRecorder()
240241

241242
router.ServeHTTP(rw, req)

0 commit comments

Comments
 (0)