Skip to content

Commit e6f04fc

Browse files
committed
feat(engine): add LLM query-rewrite-before-retrieval strategy
Add a KeygenLLM interface (mirroring the ConsolidationLLM idiom that keeps yaad free of any direct LLM dependency) and RewriteQuery / Engine.RecallWithKeygen, which let a caller turn a natural-language query into tight search keywords with an LLM before recall. Three strategies: KeygenSynonyms (default, reuses the existing ExpandQuery synonym map, no LLM), KeygenLLMThenSynonyms (union+dedup), and KeygenLLMOnly. The step is best-effort: a nil LLM or any LLM error degrades gracefully to local synonym expansion rather than failing the recall.
1 parent ec3acdd commit e6f04fc

2 files changed

Lines changed: 177 additions & 0 deletions

File tree

engine/query_keygen.go

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
package engine
2+
3+
import (
4+
"context"
5+
"strings"
6+
)
7+
8+
// KeygenLLM is the hook a caller implements so that a natural-language query
// can be rewritten into retrieval keywords by an LLM before searching. The
// engine depends only on this one-method interface, never on a concrete LLM
// client, which is the same arrangement used for ConsolidationLLM.
//
// GenerateKeywords is handed the user's query and returns the keywords or
// phrases to search for, separated by spaces or newlines. An empty string or
// a non-nil error both mean "no rewrite": the caller then falls back to the
// query itself, after synonym expansion.
type KeygenLLM interface {
	GenerateKeywords(ctx context.Context, query string) (string, error)
}
20+
21+
// KeygenStrategy selects how RewriteQuery derives search terms from a raw
// query. The zero value is KeygenSynonyms.
type KeygenStrategy int

const (
	// KeygenSynonyms relies solely on the built-in, dependency-free synonym
	// expansion (ExpandQuery). It is the default strategy and needs no LLM.
	KeygenSynonyms KeygenStrategy = iota

	// KeygenLLMThenSynonyms has the LLM extract keywords first and then
	// unions them with the synonym expansion of the original query. This is
	// the "generate keywords before retrieval" pattern: the LLM-driven
	// keyword step is kept separate from retrieval, so a verbose question is
	// reduced to tight search terms and the retriever never sees the prose.
	KeygenLLMThenSynonyms

	// KeygenLLMOnly searches with the LLM's keywords alone, reverting to
	// synonym expansion when no LLM is supplied or it returns nothing.
	KeygenLLMOnly
)
38+
39+
// RewriteQuery turns a raw user query into the search string used for recall,
40+
// applying the given strategy. llm may be nil; when it is, the function behaves
41+
// exactly as KeygenSynonyms regardless of strategy, so callers can wire an LLM
42+
// optionally without branching. The returned string is always non-empty as long
43+
// as query is non-empty.
44+
//
45+
// Defaults here are yaad's own — deliberately not borrowed from any external
46+
// framework's RAG config — and the whole step is best-effort: any LLM failure
47+
// degrades gracefully to local synonym expansion rather than erroring out of a
48+
// recall.
49+
func RewriteQuery(ctx context.Context, llm KeygenLLM, strategy KeygenStrategy, query string) string {
50+
q := strings.TrimSpace(query)
51+
if q == "" {
52+
return ""
53+
}
54+
55+
synonyms := ExpandQuery(q)
56+
57+
if llm == nil || strategy == KeygenSynonyms {
58+
return synonyms
59+
}
60+
61+
kw, err := llm.GenerateKeywords(ctx, q)
62+
kw = strings.TrimSpace(kw)
63+
if err != nil || kw == "" {
64+
// Best-effort: fall back to local expansion.
65+
return synonyms
66+
}
67+
68+
if strategy == KeygenLLMOnly {
69+
return kw
70+
}
71+
72+
// KeygenLLMThenSynonyms: union the LLM keywords with synonym expansion,
73+
// de-duplicated, preserving first-seen order (LLM terms first).
74+
return unionTerms(kw, synonyms)
75+
}
76+
77+
// RecallWithKeygen rewrites opts.Query via RewriteQuery and then runs Recall on
78+
// the rewritten query, leaving every other RecallOpts field untouched. It is a
79+
// thin convenience over Recall for callers that have an LLM available; callers
80+
// without one can keep using Recall directly.
81+
func (e *Engine) RecallWithKeygen(ctx context.Context, llm KeygenLLM, strategy KeygenStrategy, opts RecallOpts) (*RecallResult, error) {
82+
if opts.Query != "" {
83+
opts.Query = RewriteQuery(ctx, llm, strategy, opts.Query)
84+
}
85+
return e.Recall(ctx, opts)
86+
}
87+
88+
// unionTerms concatenates the whitespace-separated terms of a and then b into
// a single space-separated string with duplicates removed. Two terms count as
// duplicates when their lower-cased forms match; the spelling that appeared
// first is the one kept, and relative order is preserved.
func unionTerms(a, b string) string {
	var (
		merged []string
		known  = make(map[string]struct{})
	)
	for _, list := range [...]string{a, b} {
		for _, term := range strings.Fields(list) {
			folded := strings.ToLower(term)
			if _, dup := known[folded]; dup {
				continue
			}
			known[folded] = struct{}{}
			merged = append(merged, term)
		}
	}
	return strings.Join(merged, " ")
}

engine/query_keygen_test.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package engine
2+
3+
import (
4+
"context"
5+
"errors"
6+
"strings"
7+
"testing"
8+
)
9+
10+
// fakeKeygen is a canned keyword generator for tests: it ignores its inputs
// and always answers with the configured output and error.
type fakeKeygen struct {
	out string // keywords returned from GenerateKeywords
	err error  // error returned from GenerateKeywords
}

// GenerateKeywords returns the preconfigured out/err pair, whatever the
// context and query are.
func (fk fakeKeygen) GenerateKeywords(context.Context, string) (string, error) {
	return fk.out, fk.err
}
18+
19+
func TestRewriteQuery_NilLLMUsesSynonyms(t *testing.T) {
20+
got := RewriteQuery(context.Background(), nil, KeygenLLMThenSynonyms, "auth bug")
21+
// ExpandQuery expands "auth" → adds authentication/authorize/login.
22+
if !strings.Contains(got, "auth") || !strings.Contains(got, "authentication") {
23+
t.Errorf("expected synonym expansion, got %q", got)
24+
}
25+
}
26+
27+
func TestRewriteQuery_EmptyQuery(t *testing.T) {
28+
if got := RewriteQuery(context.Background(), fakeKeygen{out: "x"}, KeygenLLMOnly, " "); got != "" {
29+
t.Errorf("empty query should rewrite to empty, got %q", got)
30+
}
31+
}
32+
33+
func TestRewriteQuery_LLMOnly(t *testing.T) {
34+
got := RewriteQuery(context.Background(), fakeKeygen{out: "login session cookie"}, KeygenLLMOnly, "why am I logged out?")
35+
if got != "login session cookie" {
36+
t.Errorf("LLMOnly should return raw keywords, got %q", got)
37+
}
38+
}
39+
40+
func TestRewriteQuery_LLMErrorFallsBack(t *testing.T) {
41+
got := RewriteQuery(context.Background(), fakeKeygen{err: errors.New("boom")}, KeygenLLMOnly, "db config")
42+
// Falls back to synonym expansion of the original query.
43+
if !strings.Contains(got, "database") || !strings.Contains(got, "configuration") {
44+
t.Errorf("expected fallback to synonyms, got %q", got)
45+
}
46+
}
47+
48+
func TestRewriteQuery_LLMThenSynonymsUnionDedup(t *testing.T) {
49+
// LLM returns "auth token"; synonym expansion of "auth" also yields auth/...
50+
got := RewriteQuery(context.Background(), fakeKeygen{out: "auth token"}, KeygenLLMThenSynonyms, "auth")
51+
fields := strings.Fields(got)
52+
seen := map[string]int{}
53+
for _, f := range fields {
54+
seen[strings.ToLower(f)]++
55+
}
56+
for term, n := range seen {
57+
if n > 1 {
58+
t.Errorf("term %q appears %d times, want deduped: %q", term, n, got)
59+
}
60+
}
61+
// LLM terms come first.
62+
if fields[0] != "auth" {
63+
t.Errorf("expected LLM terms first, got %q", got)
64+
}
65+
if !strings.Contains(got, "token") {
66+
t.Errorf("expected LLM keyword 'token' retained, got %q", got)
67+
}
68+
}
69+
70+
func TestRewriteQuery_SynonymsStrategyIgnoresLLM(t *testing.T) {
71+
got := RewriteQuery(context.Background(), fakeKeygen{out: "SHOULD_NOT_APPEAR"}, KeygenSynonyms, "config")
72+
if strings.Contains(got, "SHOULD_NOT_APPEAR") {
73+
t.Errorf("KeygenSynonyms must not call the LLM, got %q", got)
74+
}
75+
}

0 commit comments

Comments
 (0)