Skip to content

Commit 8673eef

Browse files
committed
feat: Enabling calls to the rag_api simple reranker.
1 parent 7abb9ff commit 8673eef

4 files changed

Lines changed: 511 additions & 5 deletions

File tree

package.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@
121121
"cheerio": "^1.0.0",
122122
"dotenv": "^16.4.7",
123123
"https-proxy-agent": "^7.0.6",
124+
"jsonwebtoken": "^9.0.2",
124125
"mathjs": "^15.1.0",
125126
"nanoid": "^3.3.7",
126127
"openai": "5.8.2"
@@ -139,6 +140,7 @@
139140
"@rollup/plugin-typescript": "^12.1.2",
140141
"@swc/core": "^1.6.13",
141142
"@types/jest": "^30.0.0",
143+
"@types/jsonwebtoken": "^9.0.10",
142144
"@types/node": "^20.14.11",
143145
"@types/node-fetch": "^2.6.13",
144146
"@types/yargs-parser": "^21.0.3",

src/tools/search/rerankers.ts

Lines changed: 101 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import axios from 'axios';
22
import type * as t from './types';
33
import { createDefaultLogger } from './utils';
4+
import jwt from 'jsonwebtoken';
5+
import { nanoid } from 'nanoid';
46

57
export abstract class BaseReranker {
68
protected apiKey: string | undefined;
@@ -27,6 +29,87 @@ export abstract class BaseReranker {
2729
}
2830
}
2931

32+
export class SimpleReranker extends BaseReranker {
33+
private instanceUrl: string | undefined;
34+
35+
constructor({ logger }: { logger?: t.Logger }) {
36+
super(logger);
37+
if (
38+
process.env.RAG_API_URL !== undefined &&
39+
process.env.RAG_API_URL !== ''
40+
) {
41+
this.instanceUrl = process.env.RAG_API_URL + '/rerank';
42+
}
43+
}
44+
45+
async rerank(
46+
query: string,
47+
documents: string[],
48+
topK: number = 5
49+
): Promise<t.Highlight[]> {
50+
this.logger.debug(
51+
`Reranking ${documents.length} chunks with SimpleReranker`
52+
);
53+
54+
if (this.instanceUrl === undefined || this.instanceUrl === '') {
55+
this.logger.warn('RAG_API_URL is not set. Using default ranking.');
56+
return this.getDefaultRanking(documents, topK);
57+
}
58+
59+
try {
60+
const requestData = {
61+
query: query,
62+
docs: documents,
63+
k: topK,
64+
};
65+
66+
const statePayload = {
67+
nonce: nanoid(),
68+
};
69+
70+
const jwtSecret = process.env.JWT_SECRET;
71+
72+
if (jwtSecret === undefined || jwtSecret === '') {
73+
this.logger.warn('JWT_SECRET is not set. Using default ranking.');
74+
return this.getDefaultRanking(documents, topK);
75+
}
76+
77+
const stateToken = jwt.sign(statePayload, jwtSecret, {
78+
expiresIn: '10m',
79+
});
80+
81+
const resp = await axios.post<t.SimpleRerankerResponse | undefined>(
82+
this.instanceUrl,
83+
requestData,
84+
{
85+
headers: {
86+
'Content-Type': 'application/json',
87+
Authorization: 'Bearer ' + stateToken,
88+
},
89+
}
90+
);
91+
92+
if (resp.data && Array.isArray(resp.data) && resp.data.length > 0) {
93+
const isValid = resp.data.every(
94+
(item: t.SimpleRerankerResponse) =>
95+
typeof item.text === 'string' && typeof item.score === 'number'
96+
);
97+
if (isValid) {
98+
return resp.data;
99+
}
100+
this.logger.warn(
101+
'Unexpected response format from Simple reranker. Using default ranking.'
102+
);
103+
}
104+
return this.getDefaultRanking(documents, topK);
105+
} catch (error) {
106+
this.logger.error('Error using Simple reranker:', error);
107+
// Fallback to default ranking on error
108+
return this.getDefaultRanking(documents, topK);
109+
}
110+
}
111+
}
112+
30113
export class JinaReranker extends BaseReranker {
31114
private apiUrl: string;
32115

@@ -49,7 +132,9 @@ export class JinaReranker extends BaseReranker {
49132
documents: string[],
50133
topK: number = 5
51134
): Promise<t.Highlight[]> {
52-
this.logger.debug(`Reranking ${documents.length} chunks with Jina using API URL: ${this.apiUrl}`);
135+
this.logger.debug(
136+
`Reranking ${documents.length} chunks with Jina using API URL: ${this.apiUrl}`
137+
);
53138

54139
try {
55140
if (this.apiKey == null || this.apiKey === '') {
@@ -217,22 +302,34 @@ export const createReranker = (config: {
217302

218303
switch (rerankerType.toLowerCase()) {
219304
case 'jina':
220-
return new JinaReranker({ apiKey: jinaApiKey, apiUrl: jinaApiUrl, logger: defaultLogger });
305+
return new JinaReranker({
306+
apiKey: jinaApiKey,
307+
apiUrl: jinaApiUrl,
308+
logger: defaultLogger,
309+
});
221310
case 'cohere':
222311
return new CohereReranker({
223312
apiKey: cohereApiKey,
224313
logger: defaultLogger,
225314
});
315+
case 'simple':
316+
return new SimpleReranker({
317+
logger: defaultLogger,
318+
});
226319
case 'infinity':
227320
return new InfinityReranker(defaultLogger);
228321
case 'none':
229322
defaultLogger.debug('Skipping reranking as reranker is set to "none"');
230323
return undefined;
231324
default:
232325
defaultLogger.warn(
233-
`Unknown reranker type: ${rerankerType}. Defaulting to InfinityReranker.`
326+
`Unknown reranker type: ${rerankerType}. Defaulting to JinaReranker.`
234327
);
235-
return new JinaReranker({ apiKey: jinaApiKey, apiUrl: jinaApiUrl, logger: defaultLogger });
328+
return new JinaReranker({
329+
apiKey: jinaApiKey,
330+
apiUrl: jinaApiUrl,
331+
logger: defaultLogger,
332+
});
236333
}
237334
};
238335

0 commit comments

Comments
 (0)