@@ -17,12 +17,21 @@ export type AttentionTopK = {
1717/**
1818 * Dot-product attention score between one query vector and one key vector.
1919 * For normalized embeddings this matches cosine similarity.
20+ *
21+ * Returns `NaN` when `query` and `key` have different lengths (e.g. mixed embedding
22+ * models) so callers can avoid throwing from `cosineSimilarity`.
2023 */
21- export const attentionScore = ( query : number [ ] , key : number [ ] ) : number =>
22- cosineSimilarity ( query , key )
24+ export const attentionScore = ( query : number [ ] , key : number [ ] ) : number => {
25+ if ( query . length !== key . length ) {
26+ return Number . NaN
27+ }
28+ return cosineSimilarity ( query , key )
29+ }
2330
2431/**
25- * Attention scores for `query` against every row in `keys` (same dimension as query).
32+ * Attention scores for `query` against every row in `keys`, aligned by index.
33+ * Entries are `NaN` when a key length does not match the query (same embedding model
34+ * is required for a meaningful score).
2635 */
2736export const attentionScores = ( query : number [ ] , keys : number [ ] [ ] ) : number [ ] =>
2837 keys . map ( ( key ) => attentionScore ( query , key ) )
@@ -40,12 +49,19 @@ export const topKAttentionKeys = (
4049 return [ ]
4150 }
4251
43- const effectiveK = Math . min ( k , keys . length )
44- const scored : AttentionTopK [ ] = keys . map ( ( key , index ) => ( {
45- index,
46- score : attentionScore ( query , key ) ,
47- } ) )
52+ const scored : AttentionTopK [ ] = keys . flatMap ( ( key , index ) => {
53+ const score = attentionScore ( query , key )
54+ if ( ! Number . isFinite ( score ) ) {
55+ return [ ]
56+ }
57+ return [ { index, score } ]
58+ } )
59+
60+ if ( scored . length === 0 ) {
61+ return [ ]
62+ }
4863
64+ const effectiveK = Math . min ( k , scored . length )
4965 scored . sort ( ( a , b ) => b . score - a . score )
5066 return scored . slice ( 0 , effectiveK )
5167}
@@ -68,7 +84,8 @@ export type RankedItem<T> = {
6884
6985/**
7086 * Rank arbitrary items that carry embeddings, returning the top-k by attention score.
71- * Items with missing or empty embeddings are skipped.
87+ * Items with missing or empty embeddings, or embeddings whose length does not match
88+ * `queryEmbedding`, are skipped.
7289 */
7390export const rankItemsByAttentionTopK = < T > (
7491 queryEmbedding : number [ ] ,
@@ -87,7 +104,11 @@ export const rankItemsByAttentionTopK = <T>(
87104 const item = items [ i ]
88105 if ( item === undefined ) continue
89106 const embedding = getEmbedding ( item )
90- if ( embedding && embedding . length > 0 ) {
107+ if (
108+ embedding &&
109+ embedding . length > 0 &&
110+ embedding . length === queryEmbedding . length
111+ ) {
91112 packed . push ( { item, originalIndex : i , embedding } )
92113 }
93114 }
0 commit comments