Skip to content

Commit 33f47db

Browse files
committed
fix(lib): handle embedding dimension mismatch in Unlimiformer retrieval
Skip keys whose length differs from the query instead of letting cosineSimilarity throw; attentionScore now returns NaN on a length mismatch; extend tests to cover both behaviors.
1 parent 071b9d6 commit 33f47db

2 files changed

Lines changed: 71 additions & 10 deletions

File tree

packages/lib/unlimiformer.test.ts

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,27 @@ describe("topKAttentionKeys", () => {
3939
const top = topKAttentionKeys(q, keys, 10)
4040
expect(top).toHaveLength(2)
4141
})
42+
43+
test("skips keys whose dimension does not match the query (no throw)", () => {
44+
const q = unit(1, 0, 0)
45+
const keys = [unit(1, 0, 0), [1, 0], unit(0, 1, 0)]
46+
const top = topKAttentionKeys(q, keys, 5)
47+
expect(top.map((t) => t.index)).toEqual([0, 2])
48+
})
49+
50+
test("returns empty when no key matches query dimension", () => {
51+
const q = unit(1, 0, 0)
52+
expect(
53+
topKAttentionKeys(
54+
q,
55+
[
56+
[1, 0],
57+
[0, 1],
58+
],
59+
3,
60+
),
61+
).toEqual([])
62+
})
4263
})
4364

4465
describe("attentionScores", () => {
@@ -48,6 +69,14 @@ describe("attentionScores", () => {
4869
const scores = attentionScores(q, keys)
4970
expect(scores).toHaveLength(2)
5071
})
72+
73+
test("uses NaN when key dimension mismatches query", () => {
74+
const q = unit(1, 0, 0)
75+
const scores = attentionScores(q, [unit(1, 0, 0), [1, 0]])
76+
expect(scores).toHaveLength(2)
77+
expect(Number.isFinite(scores[0] ?? Number.NaN)).toBe(true)
78+
expect(scores[1]).toBeNaN()
79+
})
5180
})
5281

5382
describe("topKAttentionKeysMultiHead", () => {
@@ -73,4 +102,15 @@ describe("rankItemsByAttentionTopK", () => {
73102
expect(ranked[0]?.item.id).toBe("c")
74103
expect(ranked[0]?.originalIndex).toBe(2)
75104
})
105+
106+
test("skips items whose embedding length does not match the query", () => {
107+
const items = [
108+
{ id: "wide", e: [0.1, 0.2, 0.3, 0.4] },
109+
{ id: "ok", e: unit(0, 1, 0) },
110+
]
111+
const q = unit(0, 1, 0)
112+
const ranked = rankItemsByAttentionTopK(q, items, (x) => x.e, 2)
113+
expect(ranked).toHaveLength(1)
114+
expect(ranked[0]?.item.id).toBe("ok")
115+
})
76116
})

packages/lib/unlimiformer.ts

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,21 @@ export type AttentionTopK = {
1717
/**
1818
* Dot-product attention score between one query vector and one key vector.
1919
* For normalized embeddings this matches cosine similarity.
20+
*
21+
* Returns `NaN` when `query` and `key` have different lengths (e.g. mixed embedding
22+
* models), rather than letting `cosineSimilarity` throw on the mismatch.
2023
*/
21-
export const attentionScore = (query: number[], key: number[]): number =>
22-
cosineSimilarity(query, key)
24+
export const attentionScore = (query: number[], key: number[]): number => {
25+
if (query.length !== key.length) {
26+
return Number.NaN
27+
}
28+
return cosineSimilarity(query, key)
29+
}
2330

2431
/**
25-
* Attention scores for `query` against every row in `keys` (same dimension as query).
32+
* Attention scores for `query` against every row in `keys`, aligned by index.
33+
* Entries are `NaN` where a key's length differs from the query's (the same embedding model
34+
* is required for a meaningful score).
2635
*/
2736
export const attentionScores = (query: number[], keys: number[][]): number[] =>
2837
keys.map((key) => attentionScore(query, key))
@@ -40,12 +49,19 @@ export const topKAttentionKeys = (
4049
return []
4150
}
4251

43-
const effectiveK = Math.min(k, keys.length)
44-
const scored: AttentionTopK[] = keys.map((key, index) => ({
45-
index,
46-
score: attentionScore(query, key),
47-
}))
52+
const scored: AttentionTopK[] = keys.flatMap((key, index) => {
53+
const score = attentionScore(query, key)
54+
if (!Number.isFinite(score)) {
55+
return []
56+
}
57+
return [{ index, score }]
58+
})
59+
60+
if (scored.length === 0) {
61+
return []
62+
}
4863

64+
const effectiveK = Math.min(k, scored.length)
4965
scored.sort((a, b) => b.score - a.score)
5066
return scored.slice(0, effectiveK)
5167
}
@@ -68,7 +84,8 @@ export type RankedItem<T> = {
6884

6985
/**
7086
* Rank arbitrary items that carry embeddings, returning the top-k by attention score.
71-
* Items with missing or empty embeddings are skipped.
87+
* Items with missing or empty embeddings, or embeddings whose length does not match
88+
* `queryEmbedding`, are skipped.
7289
*/
7390
export const rankItemsByAttentionTopK = <T>(
7491
queryEmbedding: number[],
@@ -87,7 +104,11 @@ export const rankItemsByAttentionTopK = <T>(
87104
const item = items[i]
88105
if (item === undefined) continue
89106
const embedding = getEmbedding(item)
90-
if (embedding && embedding.length > 0) {
107+
if (
108+
embedding &&
109+
embedding.length > 0 &&
110+
embedding.length === queryEmbedding.length
111+
) {
91112
packed.push({ item, originalIndex: i, embedding })
92113
}
93114
}

0 commit comments

Comments
 (0)