Skip to content

Commit f7a1163

Browse files
committed
feat: Enabling calls to the rag_api simple reranker.
1 parent bc74b6c commit f7a1163

4 files changed

Lines changed: 510 additions & 5 deletions

File tree

package.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@
117117
"cheerio": "^1.0.0",
118118
"dotenv": "^16.4.7",
119119
"https-proxy-agent": "^7.0.6",
120+
"jsonwebtoken": "^9.0.2",
120121
"mathjs": "^15.1.0",
121122
"nanoid": "^3.3.7",
122123
"openai": "5.8.2"
@@ -135,6 +136,7 @@
135136
"@rollup/plugin-typescript": "^12.1.2",
136137
"@swc/core": "^1.6.13",
137138
"@types/jest": "^30.0.0",
139+
"@types/jsonwebtoken": "^9.0.10",
138140
"@types/node": "^20.14.11",
139141
"@types/node-fetch": "^2.6.13",
140142
"@types/yargs-parser": "^21.0.3",

src/tools/search/rerankers.ts

Lines changed: 100 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import axios from 'axios';
22
import type * as t from './types';
33
import { createDefaultLogger } from './utils';
4+
import jwt from 'jsonwebtoken';
5+
import { nanoid } from 'nanoid';
46

57
export abstract class BaseReranker {
68
protected apiKey: string | undefined;
9+
protected instanceUrl: string | undefined;
710
protected logger: t.Logger;
811

912
constructor(logger?: t.Logger) {
@@ -27,6 +30,85 @@ export abstract class BaseReranker {
2730
}
2831
}
2932

33+
export class SimpleReranker extends BaseReranker {
34+
constructor({ logger }: { logger?: t.Logger }) {
35+
super(logger);
36+
if (
37+
process.env.RAG_API_URL !== undefined &&
38+
process.env.RAG_API_URL !== ''
39+
) {
40+
this.instanceUrl = process.env.RAG_API_URL + '/rerank';
41+
}
42+
}
43+
44+
async rerank(
45+
query: string,
46+
documents: string[],
47+
topK: number = 5
48+
): Promise<t.Highlight[]> {
49+
this.logger.debug(
50+
`Reranking ${documents.length} chunks with SimpleReranker`
51+
);
52+
53+
if (this.instanceUrl === undefined || this.instanceUrl === '') {
54+
this.logger.warn('RAG_API_URL is not set. Using default ranking.');
55+
return this.getDefaultRanking(documents, topK);
56+
}
57+
58+
try {
59+
const requestData = {
60+
query: query,
61+
docs: documents,
62+
k: topK,
63+
};
64+
65+
const statePayload = {
66+
nonce: nanoid(),
67+
};
68+
69+
const jwtSecret = process.env.JWT_SECRET;
70+
71+
if (jwtSecret === undefined || jwtSecret === '') {
72+
this.logger.warn('JWT_SECRET is not set. Using default ranking.');
73+
return this.getDefaultRanking(documents, topK);
74+
}
75+
76+
const stateToken = jwt.sign(statePayload, jwtSecret, {
77+
expiresIn: '10m',
78+
});
79+
80+
const resp = await axios.post<t.SimpleRerankerResponse | undefined>(
81+
this.instanceUrl,
82+
requestData,
83+
{
84+
headers: {
85+
'Content-Type': 'application/json',
86+
Authorization: 'Bearer ' + stateToken,
87+
},
88+
}
89+
);
90+
91+
if (resp.data && Array.isArray(resp.data) && resp.data.length > 0) {
92+
const isValid = resp.data.every(
93+
(item: t.SimpleRerankerResponse) =>
94+
typeof item.text === 'string' && typeof item.score === 'number'
95+
);
96+
if (isValid) {
97+
return resp.data;
98+
}
99+
this.logger.warn(
100+
'Unexpected response format from Simple reranker. Using default ranking.'
101+
);
102+
}
103+
return this.getDefaultRanking(documents, topK);
104+
} catch (error) {
105+
this.logger.error('Error using Simple reranker:', error);
106+
// Fallback to default ranking on error
107+
return this.getDefaultRanking(documents, topK);
108+
}
109+
}
110+
}
111+
30112
export class JinaReranker extends BaseReranker {
31113
private apiUrl: string;
32114

@@ -49,7 +131,9 @@ export class JinaReranker extends BaseReranker {
49131
documents: string[],
50132
topK: number = 5
51133
): Promise<t.Highlight[]> {
52-
this.logger.debug(`Reranking ${documents.length} chunks with Jina using API URL: ${this.apiUrl}`);
134+
this.logger.debug(
135+
`Reranking ${documents.length} chunks with Jina using API URL: ${this.apiUrl}`
136+
);
53137

54138
try {
55139
if (this.apiKey == null || this.apiKey === '') {
@@ -217,22 +301,34 @@ export const createReranker = (config: {
217301

218302
switch (rerankerType.toLowerCase()) {
219303
case 'jina':
220-
return new JinaReranker({ apiKey: jinaApiKey, apiUrl: jinaApiUrl, logger: defaultLogger });
304+
return new JinaReranker({
305+
apiKey: jinaApiKey,
306+
apiUrl: jinaApiUrl,
307+
logger: defaultLogger,
308+
});
221309
case 'cohere':
222310
return new CohereReranker({
223311
apiKey: cohereApiKey,
224312
logger: defaultLogger,
225313
});
314+
case 'simple':
315+
return new SimpleReranker({
316+
logger: defaultLogger,
317+
});
226318
case 'infinity':
227319
return new InfinityReranker(defaultLogger);
228320
case 'none':
229321
defaultLogger.debug('Skipping reranking as reranker is set to "none"');
230322
return undefined;
231323
default:
232324
defaultLogger.warn(
233-
`Unknown reranker type: ${rerankerType}. Defaulting to InfinityReranker.`
325+
`Unknown reranker type: ${rerankerType}. Defaulting to JinaReranker.`
234326
);
235-
return new JinaReranker({ apiKey: jinaApiKey, apiUrl: jinaApiUrl, logger: defaultLogger });
327+
return new JinaReranker({
328+
apiKey: jinaApiKey,
329+
apiUrl: jinaApiUrl,
330+
logger: defaultLogger,
331+
});
236332
}
237333
};
238334

0 commit comments

Comments
 (0)