This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit 75e16cf

Added cross encoder re-ranker (#2)

* Added cross encoder re-ranker
* Update README.md
* Update README.md

1 parent a5acfc4, commit 75e16cf

6 files changed: 89 additions and 32 deletions

File tree

- README.md
- assets/e2e-embedding-reranking.png
- src/configs/vse_config_base.yml
- src/configs/vse_config_inc.yml
- src/run_quantize_inc.py
- src/run_query_search.py

README.md

Lines changed: 17 additions and 12 deletions
````diff
@@ -40,9 +40,12 @@ Deploying these techniques, the pipeline for building a semantic vertical search
 In some situations, an organization may also wish to fine-tune and retrain the pre-trained model with their own specialized dataset in order to improve the performance of the model to documents that an organization may have. For example, if an organization's documents are largely financial in nature, it could be useful to fine-tune these models so that they become aware of domain-specific jargon related to financial transactions or common phrases. In this reference kit, we do not demonstrate this process but more information on training and transfer learning techniques can be found at https://www.sbert.net/examples/training/sts/README.html.
 
+Moreover, if companies aim to enhance capabilities centered around the vertical search engine, it can serve as a retriever for custom documentation. The results from this retriever can subsequently be fed into a large language model, enabling context-aware responses for building a high-quality chatbot.
+
+
 ### Re-ranking
 
-In this reference kit, we focus on the document retrieval aspect of building a vertical search engine to obtain an initial list of the top-K most similar documents in the corpus for a given query. Often times, this is sufficient for building a feature rich system. However, in some situations, a 3rd component, the re-ranker, which is not included in this reference kit, could be added to the search pipeline to improve results. In this architecture, for a given query, the *document retrieval* step will use one model to rapidly obtain a list of the top-K documents (as shown in this reference kit), followed by a *re-ranking* step which will use a different model to re-order the list of K retrieved documents before returning to the user. The second re-ranking refinement step has been shown to improve user satisfaction, especially when fine-tuned on a custom corpus, but may be unnecessary as a starting point for building a functional vertical search engine. To extend this reference implementation with re-ranking, we direct you to https://www.sbert.net/examples/applications/retrieve_rerank/README.html for further details on implementation where Intel® oneAPI optimizations can also be applied to speed up re-ranking models.
+In this reference kit, we focus on the document retrieval aspect of building a vertical search engine to obtain an initial list of the top-K most similar documents in the corpus for a given query. Oftentimes, this is sufficient for building a feature-rich system. However, in some situations, a third component, the re-ranker, could be added to the search pipeline to improve results. In this architecture, for a given query, the *document retrieval* step will use one model to rapidly obtain a list of the top-K documents, followed by a *re-ranking* step which will use a different model to re-order the list of K retrieved documents before returning to the user. The second re-ranking refinement step has been shown to improve user satisfaction, especially when fine-tuned on a custom corpus, but may be unnecessary as a starting point for building a functional vertical search engine. To learn more about re-rankers, we direct you to https://www.sbert.net/examples/applications/retrieve_rerank/README.html for further details. In this reference kit, we use the `cross-encoder/ms-marco-MiniLM-L-6-v2` model as the re-ranker. For more details about different re-ranker models, visit https://www.sbert.net/docs/pretrained-models/ce-msmarco.html.
 
 ### Key Implementation Details
 
````
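The retrieve-then-rerank pipeline described above can be sketched in miniature. This is an illustrative sketch only: the toy token-overlap scorers below are hypothetical stand-ins for the bi-encoder (`sentence-transformers/msmarco-distilbert-base-tas-b`) and the cross-encoder (`cross-encoder/ms-marco-MiniLM-L-6-v2`); only the two-stage shape mirrors the kit.

```python
# Minimal retrieve-then-rerank sketch. The scoring functions are toy
# stand-ins for the bi-encoder and cross-encoder models.

def retrieve(query, corpus, top_k=5):
    """Stage 1: fast retrieval -- score every document, keep the top-k."""
    def overlap(doc):
        return len(set(query.lower().split()) & set(doc.lower().split()))
    ranked = sorted(range(len(corpus)), key=lambda i: overlap(corpus[i]), reverse=True)
    return [{'corpus_id': i, 'score': overlap(corpus[i])} for i in ranked[:top_k]]

def rerank(query, corpus, hits, cross_score):
    """Stage 2: re-score only the k retrieved documents with a slower scorer."""
    for hit in hits:
        hit['cross-score'] = cross_score(query, corpus[hit['corpus_id']])
    return sorted(hits, key=lambda h: h['cross-score'], reverse=True)

corpus = [
    "quarterly financial report for shareholders",
    "financial transactions and wire transfer fees",
    "company picnic schedule",
]
query = "fees for financial transactions"
hits = retrieve(query, corpus, top_k=2)
hits = rerank(query, corpus, hits,
              cross_score=lambda q, d: len(set(q.split()) & set(d.split())))
```

In the real kit, stage 1 compares dense embeddings across the whole corpus, while stage 2 runs the heavier cross-encoder model only on the K retrieved hits.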
````diff
@@ -55,7 +58,7 @@ The reference kit implementation is a reference solution to the described use ca
 
 ### E2E Architecture
 
-![Use_case_flow](assets/e2e-embedding-original.png)
+![Use_case_flow](assets/e2e-embedding-reranking.png)
 
 ### Expected Input-Output
 
````
````diff
@@ -204,14 +207,16 @@ optional arguments:
   --benchmark_mode      toggle to benchmark embedding
   --n_runs N_RUNS       number of iterations to benchmark embedding
   --intel               use intel pytorch extension to optimize model
+  --use_re_ranker       toggle to use cross encoder re-ranker model
+  --input_corpus INPUT_CORPUS path to corpus to embed
 ```
 
 To perform realtime query search using the above set of saved corpus embeddings and the provided configuration file, which points to the saved embeddings file, we can run the commands:
 
 ```shell
 cd src
 conda activate vse_stock
-python run_query_search.py --vse_config configs/vse_config_base.yml --input_queries ../data/test_queries.csv --output_file ../saved_output/rankings.json
+python run_query_search.py --vse_config configs/vse_config_base.yml --input_queries ../data/test_queries.csv --output_file ../saved_output/rankings.json --use_re_ranker --input_corpus ../data/corpus_abbreviated.csv
 cd ..
 ```
 
````
````diff
@@ -256,7 +261,7 @@ This reference kit extends to demonstrate the advantages of using the Intel® Ex
 
 ![Model Quantization](assets/embedding-optimized.png)
 
-### IIntel® Optimized Offline Realtime Query Search Decision Flow
+### Intel® Optimized Offline Realtime Query Search Decision Flow
 
 ![Optimized Execution](assets/realtime-search-optimized.png)
 
````
````diff
@@ -322,7 +327,7 @@ To perform query searches with these additional optimizations and the `ipexrun`
 ```shell
 cd src
 conda activate vse_intel
-ipexrun --use_logical_core --enable_tcmalloc run_query_search.py --vse_config configs/vse_config_base.yml --input_queries ../data/test_queries.csv --output_file ../saved_output/rankings.json --intel
+ipexrun --use_logical_core --enable_tcmalloc run_query_search.py --vse_config configs/vse_config_base.yml --input_queries ../data/test_queries.csv --output_file ../saved_output/rankings.json --intel --use_re_ranker --input_corpus ../data/corpus_abbreviated.csv
 cd ..
 ```
 
````
````diff
@@ -348,15 +353,15 @@ optional arguments:
   --batch_size BATCH_SIZE batch size to use. Defaults to 32.
   --save_model_dir SAVE_MODEL_DIR directory to save the quantized model to
   --inc_config_file INC_CONFIG_FILE INC conf yaml
-
+  --use_re_ranker       toggle to use cross encoder re-ranker model
 ```
 
 which can be used for our models as follows:
 
 ```shell
 cd src
 conda activate vse_intel
-python run_quantize_inc.py --query_file ../data/quant_queries.csv --corpus_file ../data/corpus_quantization.csv --ground_truth_file ../data/ground_truth_quant.csv --vse_config configs/vse_config_inc.yml --save_model_dir ../saved_models/inc_int8 --inc_config_file conf.yml
+python run_quantize_inc.py --query_file ../data/quant_queries.csv --corpus_file ../data/corpus_quantization.csv --ground_truth_file ../data/ground_truth_quant.csv --vse_config configs/vse_config_inc.yml --save_model_dir ../saved_models/inc_int8 --inc_config_file conf.yml --use_re_ranker
 cd ..
 ```
 
````
````diff
@@ -398,7 +403,7 @@ To do realtime query searching, we can run the commands:
 ```shell
 cd src
 conda activate vse_intel
-ipexrun --use_logical_core --enable_tcmalloc run_query_search.py --vse_config configs/vse_config_inc.yml --input_queries ../data/test_queries.csv --output_file ../saved_output/rankings.json --intel
+ipexrun --use_logical_core --enable_tcmalloc run_query_search.py --vse_config configs/vse_config_inc.yml --input_queries ../data/test_queries.csv --output_file ../saved_output/rankings.json --intel --use_re_ranker --input_corpus ../data/corpus_abbreviated.csv
 cd ..
 ```
 
````
````diff
@@ -518,7 +523,7 @@ To replicate the performance experiments described above, do the following:
 python run_document_embedder.py --vse_config configs/vse_config_base.yml --input_corpus ../data/corpus_abbreviated.csv --output_file ../saved_output/embeddings.pkl --batch_size 64
 
 # Run benchmarks on single query search
-python run_query_search.py --vse_config configs/vse_config_base.yml --input_queries ../data/test_queries.csv --benchmark_mode --n_runs 10000 --batch_size 1 --logfile ../logs/stock.log
+python run_query_search.py --vse_config configs/vse_config_base.yml --input_queries ../data/test_queries.csv --benchmark_mode --n_runs 10000 --batch_size 1 --logfile ../logs/stock.log --input_corpus ../data/corpus_abbreviated.csv
 ```
 
 6. For the intel environment, run the following to run and log results to the ../logs/intel.log file
````
````diff
@@ -540,7 +545,7 @@ To replicate the performance experiments described above, do the following:
 ipexrun --use_logical_core --enable_tcmalloc run_document_embedder.py --vse_config configs/vse_config_base.yml --input_corpus ../data/corpus_abbreviated.csv --output_file ../saved_output/embeddings.pkl --batch_size 64 --intel
 
 # Run single query search experiments using IPEX
-ipexrun --use_logical_core --enable_tcmalloc run_query_search.py --vse_config configs/vse_config_base.yml --input_queries ../data/test_queries.csv --benchmark_mode --n_runs 10000 --batch_size 1 --logfile ../logs/intel.log --intel
+ipexrun --use_logical_core --enable_tcmalloc run_query_search.py --vse_config configs/vse_config_base.yml --input_queries ../data/test_queries.csv --benchmark_mode --n_runs 10000 --batch_size 1 --logfile ../logs/intel.log --intel --input_corpus ../data/corpus_abbreviated.csv
 
 # Quantize the model using INC (long run time!)
 python run_quantize_inc.py --query_file ../data/quant_queries.csv --corpus_file ../data/corpus_quantization.csv --ground_truth_file ../data/ground_truth_quant.csv --vse_config configs/vse_config_inc.yml --save_model_dir ../saved_models/inc_int8 --inc_config_file conf.yml
````
````diff
@@ -553,7 +558,7 @@ To replicate the performance experiments described above, do the following:
 ipexrun --use_logical_core --enable_tcmalloc run_document_embedder.py --vse_config configs/vse_config_inc.yml --input_corpus ../data/corpus_abbreviated.csv --logfile ../logs/intel_inc_int8.log --batch_size 128 --benchmark_mode --intel
 
 # Run single query search experiments using INC INT8
-ipexrun --use_logical_core --enable_tcmalloc run_query_search.py --vse_config configs/vse_config_inc.yml --input_queries ../data/test_queries.csv --benchmark_mode --n_runs 10000 --batch_size 1 --logfile ../logs/intel_inc_int8.log --intel
+ipexrun --use_logical_core --enable_tcmalloc run_query_search.py --vse_config configs/vse_config_inc.yml --input_queries ../data/test_queries.csv --benchmark_mode --n_runs 10000 --batch_size 1 --logfile ../logs/intel_inc_int8.log --intel --input_corpus ../data/corpus_abbreviated.csv
 
 ```
 
````
````diff
@@ -570,4 +575,4 @@ To replicate the performance experiments described above, do the following:
 ```bash
 apt install libgl1-mesa-glx
 ```
-
+
````

assets/e2e-embedding-reranking.png

Binary file added (83.2 KB)

src/configs/vse_config_base.yml

Lines changed: 3 additions and 3 deletions
```diff
@@ -5,11 +5,11 @@ version: 1.0
 model:
   format: default # default, inc, pt
   pretrained_model: sentence-transformers/msmarco-distilbert-base-tas-b
+  cross_encoder: cross-encoder/ms-marco-MiniLM-L-6-v2
   max_seq_length: 128
 
 # inference config
-inference:
+inference:
   top_k : 5
   score_function : dot # cos_sim, dot
-  corpus_embeddings_path : ../saved_output/embeddings.pkl
-
+  corpus_embeddings_path : ../saved_output/embeddings.pkl
```

src/configs/vse_config_inc.yml

Lines changed: 2 additions and 1 deletion
```diff
@@ -5,13 +5,14 @@ version: 1.0
 model:
   format: inc # default, inc, pt
   pretrained_model: sentence-transformers/msmarco-distilbert-base-tas-b
+  cross_encoder: cross-encoder/ms-marco-MiniLM-L-6-v2
   max_seq_length: 128
 
 # inc required config parameters
 path: ../saved_models/inc_int8
 
 # inference config
-inference:
+inference:
   top_k : 5
   score_function : dot # cos_sim, dot
   corpus_embeddings_path : ../saved_output/embeddings.pkl
```
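Both configs are consumed after YAML parsing into a nested dict. A sketch of that consumption (a plain dict stands in for the `yaml.safe_load` result, and `pick_score_function` is a hypothetical helper mirroring the kit's dot/cos_sim selection):

```python
# Stand-in for the dict produced by parsing vse_config_base.yml.
conf = {
    'model': {
        'format': 'default',
        'pretrained_model': 'sentence-transformers/msmarco-distilbert-base-tas-b',
        'cross_encoder': 'cross-encoder/ms-marco-MiniLM-L-6-v2',
        'max_seq_length': 128,
    },
    'inference': {
        'top_k': 5,
        'score_function': 'dot',
        'corpus_embeddings_path': '../saved_output/embeddings.pkl',
    },
}

def pick_score_function(conf):
    """Mirror the kit's selection: cos_sim by default, dot when configured."""
    return 'dot_score' if conf['inference']['score_function'] == 'dot' else 'cos_sim'

# The new cross_encoder key names the re-ranker model to load.
model_name = conf['model']['cross_encoder']
score_fn = pick_score_function(conf)
```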

src/run_quantize_inc.py

Lines changed: 25 additions and 5 deletions
```diff
@@ -25,7 +25,7 @@
     PreTrainedTokenizer,
     PreTrainedModel
 )
-
+from sentence_transformers.cross_encoder import CrossEncoder
 from utils.dataloader import (
     load_queries, load_corpus, QueryDataset, CorpusDataset
 )
```
```diff
@@ -35,11 +35,13 @@
 def quantize_model(
         tokenizer: PreTrainedTokenizer,
         embedder: PreTrainedModel,
+        cross_encoder: CrossEncoder,
         queries: QueryDataset,
         corpus: CorpusDataset,
         inc_config_file: str,
         score_func,
         gt,
+        use_re_ranker: bool = True,
         top_k: int = 5,
         max_seq_length: int = 128,
         batch_size: int = 64):
```
```diff
@@ -93,7 +95,16 @@ def evaluate_contains_top_entries(model_q) -> float:
             corpus_embeddings,
             top_k=top_k,
             score_function=score_func)
-
+
+        ### Added Re-Ranker
+        if use_re_ranker:
+            for i in range(len(res)):
+                cross_inp = [[queries[i], corpus[entry['corpus_id']]] for entry in res[i]]
+                cross_scores = cross_encoder.predict(cross_inp)
+                for idx in range(len(cross_scores)):
+                    res[i][idx]['cross-score'] = float(cross_scores[idx])
+                res[i] = sorted(res[i], key=lambda x: x['cross-score'], reverse=True)
+        #######
         correct = 0
         for idx, query_ranking in enumerate(res):
             matches = []
```
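The re-ranking loop added in this hunk can be exercised in isolation. Below, a stub `predict` (a hypothetical stand-in for `CrossEncoder.predict`, which returns one relevance score per query-document pair) drives the same attach-score-and-sort logic:

```python
class StubCrossEncoder:
    """Stand-in for sentence_transformers' CrossEncoder: scores each
    (query, document) pair; here, score = shared-token count."""
    def predict(self, pairs):
        return [len(set(q.split()) & set(d.split())) for q, d in pairs]

queries = ["apple banana", "cherry"]
corpus = ["banana bread", "apple banana smoothie", "cherry pie"]
# Top-k hits per query, as produced by the bi-encoder retrieval stage.
res = [
    [{'corpus_id': 0}, {'corpus_id': 1}],
    [{'corpus_id': 2}],
]

cross_encoder = StubCrossEncoder()
# Same loop as the commit: score each retrieved doc against its query,
# attach 'cross-score', and re-sort hits in descending score order.
for i in range(len(res)):
    cross_inp = [[queries[i], corpus[entry['corpus_id']]] for entry in res[i]]
    cross_scores = cross_encoder.predict(cross_inp)
    for idx in range(len(cross_scores)):
        res[i][idx]['cross-score'] = float(cross_scores[idx])
    res[i] = sorted(res[i], key=lambda x: x['cross-score'], reverse=True)
```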
```diff
@@ -170,7 +181,7 @@ def main(flags) -> None:
         conf['model']['pretrained_model'])
     embedder = AutoModel.from_pretrained(conf['model']['pretrained_model'])
     embedder.eval()
-
+    cross_encoder = CrossEncoder(conf['model']['cross_encoder'])
     score_func = util.cos_sim
     if conf['inference']['score_function'] == 'dot':
         score_func = util.dot_score
```
```diff
@@ -181,14 +192,16 @@ def main(flags) -> None:
     quantized_model = quantize_model(
         tokenizer,
         embedder,
+        cross_encoder,
         query_dataset,
         corpus_dataset,
         flags.inc_config_file,
         score_func,
         ground_truth,
         top_k=conf['inference']['top_k'],
         max_seq_length=conf["model"]["max_seq_length"],
-        batch_size=64)
+        batch_size=64,
+        use_re_ranker=flags.use_re_ranker)
     quantized_model.save(flags.save_model_dir)
 
 
```
```diff
@@ -230,6 +243,13 @@ def main(flags) -> None:
                         required=True
                         )
 
+    parser.add_argument('--use_re_ranker',
+                        required=False,
+                        help="Use cross encoder re-ranking",
+                        action="store_true",
+                        default=True
+                        )
+
     FLAGS = parser.parse_args()
 
-    main(FLAGS)
+    main(FLAGS)
```
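One argparse detail worth noting in the flag added above: with `action="store_true"`, passing the flag stores `True`, and `default` applies only when the flag is absent, so `default=True` makes the parsed value `True` either way, meaning re-ranking is always enabled during quantization. A quick check:

```python
import argparse

# Reproduce the flag definition from the commit.
parser = argparse.ArgumentParser()
parser.add_argument('--use_re_ranker',
                    required=False,
                    help="Use cross encoder re-ranking",
                    action="store_true",
                    default=True)

with_flag = parser.parse_args(['--use_re_ranker'])
without_flag = parser.parse_args([])
# store_true can only set the value to True; combined with default=True,
# the flag can never yield False.
```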

src/run_query_search.py

Lines changed: 42 additions and 11 deletions
```diff
@@ -29,21 +29,23 @@
     PreTrainedTokenizer,
     PreTrainedModel
 )
-
-from utils.dataloader import load_queries
-from utils.embed import encode, batch_encode
-
+from sentence_transformers.cross_encoder import CrossEncoder
+from utils.dataloader import load_queries, load_corpus, CorpusDataset
+from utils.embed import encode, batch_encode
 random.seed(0)
 
 
 def search_query(
         logger: logging.Logger,
         tokenizer: PreTrainedTokenizer,
         embedder: PreTrainedModel,
+        cross_encoder: CrossEncoder,
         queries: List[str],
         corpus_embeddings: np.ndarray,
         idx_to_ids: Dict[int, str],
+        corpus: CorpusDataset,
         top_k: int = 5,
+        use_re_ranker: bool = True,
         max_sequence_length: int = 128,
         batch_size: int = 1,
         n_runs: int = 100,
```
```diff
@@ -65,8 +67,10 @@ def search_query(
             Pre-embedded corpus of documents.
         idx_to_ids: (Dict[int, int]):
             Map of embedding index to document ids.
+        corpus (CorpusDataset):
+            CorpusDataset to embed.
         top_k (int, optional):
-            Number of entries similar corpus documents to return.
+            Number of entries similar corpus documents to return.
         max_sequence_length (int, optional):
             max sequence length. Defaults to 128.
         batch_size (int, optional):
```
```diff
@@ -139,10 +143,19 @@ def search_query(
             corpus_embeddings,
             top_k=top_k,
             score_function=score_func)
-
+
+        if use_re_ranker and cross_encoder!=None:
+            ### Prepare inp for cross_encoder using output of bi_encoder
+            for i in range(len(out)):
+                cross_inp = [[queries[i], corpus[entry['corpus_id']]] for entry in out[i]]
+                cross_scores = cross_encoder.predict(cross_inp)
+                for idx in range(len(cross_scores)):
+                    out[i][idx]['cross-score'] = float(cross_scores[idx])
+                out[i] = sorted(out[i], key=lambda x: x['cross-score'], reverse=True)
+
         # map index based ids to raw corpus_ids
-        for res in out:
-            for entry in res:
+        for i in range(len(out)):
+            for entry in out[i]:
                 entry['corpus_id'] = idx_to_ids[entry['corpus_id']]
 
         if output_file is not None:
```
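The remapping step at the end of this hunk relies on `idx_to_ids = dict(enumerate(ids))`, built later in `main`, which maps an embedding's position back to its raw corpus id. A small self-contained illustration (the ids and scores are made up):

```python
# Raw corpus ids in embedding order, as loaded from the saved embeddings file.
ids = ['doc-101', 'doc-202', 'doc-303']
idx_to_ids = dict(enumerate(ids))  # position -> raw corpus id

# Search results carry positional ids; remap them to raw corpus ids,
# mirroring the loop in the commit.
out = [[{'corpus_id': 2, 'score': 0.9}, {'corpus_id': 0, 'score': 0.4}]]
for i in range(len(out)):
    for entry in out[i]:
        entry['corpus_id'] = idx_to_ids[entry['corpus_id']]
```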
```diff
@@ -186,6 +199,7 @@ def main(flags):
 
         # load the pretrained embedding model
         embedder = AutoModel.from_pretrained(conf['model']['pretrained_model'])
+        cross_encoder = CrossEncoder(conf['model']['cross_encoder'])
 
     elif conf["model"]["format"] == "inc":
 
```
```diff
@@ -194,7 +208,7 @@ def main(flags):
 
         embedder = AutoModel.from_pretrained(conf['model']['pretrained_model'])
         embedder = load(conf["model"]["path"], embedder)
-
+        cross_encoder = CrossEncoder(conf['model']['cross_encoder'])
     # re-establish logger because it breaks from above
     logging.getLogger().handlers.clear()
 
```
```diff
@@ -222,7 +236,6 @@ def main(flags):
     if flags.intel:
         import intel_extension_for_pytorch as ipex
         embedder = ipex.optimize(embedder, dtype=torch.float32)
-
     sample_inputs = tokenizer.batch_decode([
         random.sample(
             range(tokenizer.vocab_size), max_sequence_length) for
```
```diff
@@ -252,16 +265,21 @@ def main(flags):
     corpus_embeddings = saved_embeddings['embeddings']
     idx_to_ids = dict(enumerate(ids))
 
+    # read in corpus dataset
+    corpus = load_corpus(flags.input_corpus)
     search_query(
         logger=logger,
         tokenizer=tokenizer,
         embedder=embedder,
+        cross_encoder=cross_encoder,
         queries=input_file.queries,
         corpus_embeddings=corpus_embeddings,
         idx_to_ids=idx_to_ids,
+        corpus=corpus,
         top_k=conf['inference']['top_k'],
         max_sequence_length=max_sequence_length,
         batch_size=flags.batch_size,
+        use_re_ranker=flags.use_re_ranker,
         n_runs=flags.n_runs,
         score=conf['inference']['score_function'],
         output_file=flags.output_file,
```
```diff
@@ -289,6 +307,12 @@ def main(flags):
                         type=str
                         )
 
+    parser.add_argument('--input_corpus',
+                        required=True,
+                        help="path to corpus to embed",
+                        type=str
+                        )
+
     parser.add_argument('--output_file',
                         required=False,
                         help="file to output top k documents to",
```
```diff
@@ -310,6 +334,13 @@ def main(flags):
                         default=False
                         )
 
+    parser.add_argument('--use_re_ranker',
+                        required=False,
+                        help="Use cross encoder reranking",
+                        action="store_true",
+                        default=False
+                        )
+
     parser.add_argument('--n_runs',
                         required=False,
                         help="number of iterations to benchmark embedding",
```
```diff
@@ -326,4 +357,4 @@ def main(flags):
 
     FLAGS = parser.parse_args()
 
-    main(FLAGS)
+    main(FLAGS)
```
