
Commit 631f45f

Merge pull request #2 from Andrian0s/add_feature/custom_refixes_adapters_ndcg
Extend support for HuggingFace models and add NDCG@k metric
2 parents b83a0bb + ff17d31 commit 631f45f

14 files changed

Lines changed: 1637 additions & 1392 deletions


.github/workflows/lint.yml

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v5
         with:
-          python-version: "3.11"
+          python-version: "3.13"
 
       - name: Install uv
         run: pip3 install uv

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v5
         with:
-          python-version: '3.11'
+          python-version: '3.13'
 
       - name: Install uv
         run: pip3 install uv

.python-version

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-3.11
+3.13

README.md

Lines changed: 9 additions & 3 deletions
@@ -3,7 +3,7 @@
 **A framework for evaluating semantic search across custom datasets, metrics, and embedding backends.**
 
 ![GitHub License](https://img.shields.io/github/license/machinelearningZH/semantic-search-eval)
-[![PyPI - Python](https://img.shields.io/badge/python-v3.11+-blue.svg)](https://github.com/machinelearningZH/semantic-search-eval)
+[![PyPI - Python](https://img.shields.io/badge/python-v3.13+-blue.svg)](https://github.com/machinelearningZH/semantic-search-eval)
 [![GitHub Stars](https://img.shields.io/github/stars/machinelearningZH/semantic-search-eval.svg)](https://github.com/machinelearningZH/semantic-search-eval/stargazers)
 [![GitHub Issues](https://img.shields.io/github/issues/machinelearningZH/semantic-search-eval.svg)](https://github.com/machinelearningZH/semantic-search-eval/issues)
 [![GitHub Pull Requests](https://img.shields.io/github/issues-pr/machinelearningZH/semantic-search-eval.svg)](https://img.shields.io/github/issues-pr/machinelearningZH/semantic-search-eval)
@@ -33,7 +33,7 @@
 ## Features
 - **Flexible model integration**: HuggingFace, OpenAI, BM25, and more.
 - **Simple YAML-based configuration**.
-- **Custom evaluation metrics**: e.g., Accuracy@k, Latency.
+- **Custom evaluation metrics**: e.g., Accuracy@k, NDCG@k, Latency.
 - **Integrated visualizations** via `seaborn`/`matplotlib`.
 
 ## Installation
@@ -67,6 +67,9 @@ You need two input files:
 > [!NOTE]
 > Embedding models have a maximum input length - more on this in the next section. If your documents exceed this length, they should be split into smaller chunks before evaluation to ensure compatibility with the models. All preprocessing (e.g., cleaning, tokenization) should be completed before evaluation, as it is not (yet) supported in this toolkit.
 
+> [!NOTE]
+> The current implementation of this toolkit assumes **exactly one relevant document per query**. All metrics (Accuracy@k, NDCG@k, etc.) are designed for this single-answer evaluation scenario. If you have multiple relevant documents per query, the current metrics will not produce meaningful results.
+
 ### Configuration
 Create a YAML config to define datasets, models, and metrics. Use [`configs/example.yaml`](configs/example.yaml) as a template.
 
@@ -76,7 +79,10 @@ Key fields:
 - `docs` and `queries`: paths to your documents and queries in CSV or parquet format
 - `is-public-data`: set to true to use OpenAI query creator if data is public
 - `max-len`: set to the **shortest model limit** to ensure fair evaluation with same input text length for all models
-- `models`: define model backends and options
+- `models`: define model backends and options. Supported backends are `huggingface`, `lexical`, and `open-ai`. HuggingFace models support the following optional parameters (check the usage examples on HuggingFace to see whether a model uses either of these parameters):
+  - `set_builtin_query_prompt` / `set_builtin_passage_prompt`: use a model's built-in named prompt for queries/passages
+  - `set_query_task_prompt` / `set_passage_task_prompt`: pass a task string to the encoder (e.g. for Jina models)
+  - `set_custom_query_prefix` / `set_custom_passage_prefix`: prepend a custom string to each query/passage at inference time (mutually exclusive with the built-in prompt options)
 
 ### OpenAI Key
 To use OpenAI-based features:
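
The custom prefix options described in the README hunk above can be illustrated with a small sketch. The helper name `apply_prefix` and its signature are hypothetical and are not taken from the repository; the sketch only assumes what the README states, namely that a custom prefix is prepended to each text at inference time and cannot be combined with a built-in prompt.

```python
from typing import List, Optional


def apply_prefix(
    texts: List[str],
    custom_prefix: Optional[str] = None,
    builtin_prompt: Optional[str] = None,
) -> List[str]:
    """Hypothetical helper: prepend a custom prefix to every text.

    Raises if a custom prefix and a built-in prompt name are both set,
    mirroring the mutual exclusivity described in the README.
    """
    if custom_prefix is not None and builtin_prompt is not None:
        raise ValueError("custom prefix and built-in prompt are mutually exclusive")
    if custom_prefix is None:
        # Nothing to prepend; return a copy so callers can mutate it safely.
        return list(texts)
    return [custom_prefix + text for text in texts]


# E5-style usage, matching the "query: " prefix in configs/example.yaml.
queries = apply_prefix(["Wie beantrage ich einen Pass?"], custom_prefix="query: ")
print(queries)
```

In this sketch the built-in prompt is only checked, not applied, since resolving a named prompt is left to the embedding library.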

configs/example.yaml

Lines changed: 27 additions & 10 deletions
@@ -15,6 +15,7 @@ metrics:
   - accuracy@1 # Top-1 accuracy: checks if gold doc is rank 1
   - accuracy@5 # Top-5 accuracy: checks if gold doc is in top 5
   - accuracy@10 # Top-10 accuracy: same logic for top 10
+  - ndcg@10 # NDCG@10: rewards higher-ranked gold docs (position-aware)
   - latency # Time taken for full inference per model
 
 # Global max token length for truncation (needs to be based on smallest model max len for fair comparison)
@@ -26,24 +27,40 @@ models:
   lexical:
     bm25: de_core_news_sm # uses spacy model to lemmatize before indexing
 
-  intfloat:
-    intfloat-small: intfloat/multilingual-e5-small # max-len 512
-    intfloat-base: intfloat/multilingual-e5-base # max-len 512
-
-  intfloat-instruct:
-    intfloat-instruct: intfloat/multilingual-e5-large-instruct # max-len 512
-
   huggingface:
     jina-v2: jinaai/jina-embeddings-v2-base-de # max-len 8192
     all-MiniLM-v2: sentence-transformers/all-MiniLM-L6-v2 # max-len 512
     granite: ibm-granite/granite-embedding-278m-multilingual # max-len 512
     nomic:
       model: nomic-ai/nomic-embed-text-v2-moe # max-len 512
-      use_query_prompt: true
-      use_passage_prompt: true
+      set_builtin_query_prompt: query
+      set_builtin_passage_prompt: passage
     snowflake:
       model: Snowflake/snowflake-arctic-embed-l-v2.0 # max-len 8192
-      use_query_prompt: true
+      set_builtin_query_prompt: query
+    jina-v3:
+      model: jinaai/jina-embeddings-v3 # max-len 8192
+      set_builtin_query_prompt: retrieval.query
+      set_builtin_passage_prompt: retrieval.passage
+      set_passage_task_prompt: retrieval.passage
+      set_query_task_prompt: retrieval.query
+    jina-v5-small:
+      model: jinaai/jina-embeddings-v5-text-small # max-len 32768
+      set_query_task_prompt: retrieval
+      set_passage_task_prompt: retrieval
+      set_builtin_passage_prompt: document
+      set_builtin_query_prompt: query
+    intfloat-small:
+      model: intfloat/multilingual-e5-small # max-len 512
+      set_custom_query_prefix: "query: "
+      set_custom_passage_prefix: "passage: "
+    intfloat-base:
+      model: intfloat/multilingual-e5-base # max-len 512
+      set_custom_query_prefix: "query: "
+      set_custom_passage_prefix: "passage: "
+    intfloat-instruct:
+      model: intfloat/multilingual-e5-large-instruct # max-len 512
+      set_custom_query_prefix: "Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: "
 
   # open-ai:
   # open-ai-3-small: text-embedding-3-small # max-len 8191
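
The example config above uses two shapes for a model entry: a bare model ID string (`jina-v2`) and a mapping with a `model` key plus options (`nomic`, `intfloat-small`). One way a loader might normalize both shapes is sketched below. The function `normalize_model_entry` is a hypothetical name for illustration only; how the toolkit actually parses these entries is not shown in this commit page.

```python
from typing import Any, Dict, Union


def normalize_model_entry(entry: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
    """Hypothetical helper: turn either config shape into {"model": ..., **options}."""
    if isinstance(entry, str):
        # Short form, e.g. `jina-v2: jinaai/jina-embeddings-v2-base-de`.
        return {"model": entry}
    if "model" not in entry:
        raise ValueError("mapping-style model entries need a 'model' key")
    # Long form: copy so the caller's parsed YAML is not mutated.
    return dict(entry)


short = normalize_model_entry("jinaai/jina-embeddings-v2-base-de")
long = normalize_model_entry(
    {"model": "intfloat/multilingual-e5-small", "set_custom_query_prefix": "query: "}
)
print(short)
print(long)
```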

pyproject.toml

Lines changed: 11 additions & 9 deletions
@@ -7,22 +7,24 @@ authors = [
 ]
 license = { file = "LICENSE" }
 readme = "README.md"
-requires-python = ">=3.11"
+requires-python = ">=3.13"
 dependencies = [
-    "accelerate>=1.6.0",
-    "einops>=0.8.1",
+    "accelerate>=1.13.0",
+    "einops>=0.8.2",
     "de-core-news-sm",
     "openai>=1.77.0",
     "ordered-set>=4.1.0",
-    "polars>=1.29.0",
-    "pytest-mock>=3.14.0",
+    "polars>=1.40.1",
+    "pytest-mock>=3.15.1",
     "rank-bm25>=0.2.2",
-    "ruamel-yaml>=0.18.10",
+    "ruamel-yaml>=0.19.1",
     "seaborn>=0.13.2",
-    "sentence-transformers>=4.1.0",
-    "spacy>=3.8.5",
-    "tiktoken>=0.9.0",
+    "sentence-transformers<=5.4.1",
+    "transformers<5.0.0",
+    "spacy>=3.8.14",
+    "tiktoken>=0.12.0",
     "dotenv>=0.9.9",
+    "peft>=0.19.1",
 ]
 
 [dependency-groups]

semsearcheval/constants.py

Lines changed: 6 additions & 6 deletions
@@ -4,23 +4,23 @@
 
 from typing import Dict, Type
 
-from semsearcheval.metrics import Accuracy, Latency, Metric
+from semsearcheval.metrics import NDCG, Accuracy, Latency, Metric
 from semsearcheval.models import (
     BM25Model,
     HuggingFaceModel,
-    IntFloatInstructModel,
-    IntFloatModel,
     Model,
     OpenAIModel,
 )
 
 
 model_registry: Dict[str, Type[Model]] = {
     "huggingface": HuggingFaceModel,
-    "intfloat": IntFloatModel,
-    "intfloat-instruct": IntFloatInstructModel,
     "lexical": BM25Model,
     "open-ai": OpenAIModel,
 }
 
-metric_registry: Dict[str, Type[Metric]] = {"accuracy": Accuracy, "latency": Latency}
+metric_registry: Dict[str, Type[Metric]] = {
+    "accuracy": Accuracy,
+    "latency": Latency,
+    "ndcg": NDCG,
+}
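
As a rough illustration of how a registry like `metric_registry` could map a configured name such as `ndcg@10` to a class, the sketch below looks up the part before `@`. The stand-in classes and the `build_metric` function are assumptions made for this example; the commit does not show the toolkit's actual lookup code.

```python
from typing import Dict, Type


class Metric:
    """Stand-in base class for this sketch only (not the toolkit's Metric)."""

    def __init__(self, name: str) -> None:
        self.name = name


class Accuracy(Metric):
    pass


class Latency(Metric):
    pass


class NDCG(Metric):
    pass


metric_registry: Dict[str, Type[Metric]] = {
    "accuracy": Accuracy,
    "latency": Latency,
    "ndcg": NDCG,
}


def build_metric(name: str) -> Metric:
    """Hypothetical factory: 'ndcg@10' -> NDCG('ndcg@10'), 'latency' -> Latency('latency')."""
    key = name.split("@")[0]  # registry keys carry no '@k' suffix
    if key not in metric_registry:
        raise KeyError(f"Unknown metric: {name}")
    return metric_registry[key](name)


metrics = [build_metric(n) for n in ["accuracy@5", "ndcg@10", "latency"]]
print([type(m).__name__ for m in metrics])
```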

semsearcheval/metrics.py

Lines changed: 54 additions & 10 deletions
@@ -17,6 +17,15 @@ class Metric(ABC):
     def __init__(self, name: str) -> None:
         self.name = name
 
+    def _parse_k(self, name: str) -> None:
+        """Extracts the top-k cutoff from the metric name."""
+        if "@" not in name:
+            raise ValueError(f"Invalid metric name: {name}. Expected format: metric@k")
+        k = int(name.split("@")[1])
+        if k <= 0:
+            raise ValueError(f"Invalid k value: {k}. Must be a positive integer.")
+        return k
+
     @abstractmethod
     def compute(self, result: Result) -> Tuple[float, str]:
         pass
@@ -34,24 +43,18 @@ def __init__(self, name: str) -> None:
         super().__init__(name)
         self.k = self._parse_k(name)
 
-    def _parse_k(self, name: str) -> None:
-        """Extracts the top-k cutoff from the metric name."""
-        if "@" not in name:
-            raise ValueError(f"Invalid metric name: {name}. Expected format: accuracy@k")
-        k = int(name.split("@")[1])
-        if k <= 0:
-            raise ValueError(f"Invalid k value: {k}. Must be a positive integer.")
-        return k
-
     def compute(self, result: Result) -> Tuple[float, str]:
         """
         Compute top-k accuracy: proportion of queries where the gold document index
         appears in the top-k predicted documents (by similarity score).
         """
         correct_retrieved = 0
         for q_sims, gold_index in zip(result.similarity, result.gold_indices):
-            # Get indices of top-k most similar documents (descending order as higher similarity is better)
+            # Sort similarity scores in descending order (highest similarity first)
+            # and select the indices of the top-k documents
             top_k = np.argsort(q_sims)[::-1][: self.k]
+
+            # Check if the gold (correct) document is among the top-k predictions
             if gold_index in top_k:
                 correct_retrieved += 1
 
@@ -68,3 +71,44 @@ class Latency(Metric):
 
     def compute(self, result: Result) -> Tuple[float, str]:
         return result.time, "s"
+
+
+class NDCG(Metric):
+    """
+    Computes NDCG@k (Normalized Discounted Cumulative Gain) over all queries.
+
+    With a single relevant document per query the ideal DCG is always 1.0,
+    so NDCG simplifies to 1/log2(rank+1) if the gold doc is within the top k,
+    and 0 otherwise. The result is averaged across all queries.
+    The metric name should be in the form: ndcg@k.
+    """
+
+    def __init__(self, name: str) -> None:
+        super().__init__(name)
+        self.k = self._parse_k(name)
+
+    def compute(self, result: Result) -> Tuple[float, str]:
+        """
+        Compute NDCG@k. For each query, find the rank of the gold document
+        within the top k. Score is 1/log2(rank+1) if found, 0 otherwise.
+        """
+        total_ndcg = 0.0
+        for q_sims, gold_index in zip(result.similarity, result.gold_indices):
+            # Sort similarity scores in descending order and select top-k document indices
+            ranked = np.argsort(q_sims)[::-1][: self.k]
+
+            # Find if and where the gold document appears in the top-k ranked results
+            positions = np.where(ranked == gold_index)[0]
+
+            if len(positions) > 0:
+                # Convert 0-based position to 1-based rank
+                rank = positions[0] + 1
+
+                # Apply logarithmic discount: 1/log2(rank+1)
+                # This rewards higher-ranked results more than lower-ranked ones
+                total_ndcg += 1.0 / np.log2(rank + 1)
+            # If gold doc not in top-k, contributes 0 to the score
+
+        # Calculate average NDCG across all queries and convert to percentage
+        avg_ndcg = total_ndcg / result.similarity.shape[0] * 100
+        return avg_ndcg, "%"
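
To make the single-relevant-document simplification in the `NDCG` docstring concrete, here is a small dependency-free re-implementation with a worked example. It is a sketch, not the committed code: it replaces `np.argsort` with Python's `sorted`, and the function name `ndcg_at_k` and the toy similarity rows are invented for illustration. Tie-breaking may differ from the NumPy version.

```python
import math
from typing import List, Sequence


def ndcg_at_k(similarity: Sequence[Sequence[float]], gold_indices: List[int], k: int) -> float:
    """Average single-gold NDCG@k over all queries, as a percentage."""
    total = 0.0
    for q_sims, gold in zip(similarity, gold_indices):
        # Document indices sorted by descending similarity, cut at k.
        ranked = sorted(range(len(q_sims)), key=lambda i: q_sims[i], reverse=True)[:k]
        if gold in ranked:
            rank = ranked.index(gold) + 1  # 1-based rank
            total += 1.0 / math.log2(rank + 1)
        # Gold document outside the top k contributes 0.
    return total / len(similarity) * 100


similarity = [
    [0.9, 0.1, 0.3],  # gold doc 0 is ranked 1st: 1/log2(2) = 1.0
    [0.2, 0.8, 0.5],  # gold doc 2 is ranked 2nd: 1/log2(3), about 0.631
    [0.7, 0.6, 0.1],  # gold doc 2 is ranked 3rd: outside the top 2, so 0
]
gold_indices = [0, 2, 2]

score = ndcg_at_k(similarity, gold_indices, k=2)
print(round(score, 2))  # about 54.36
```

With three queries the average is (1.0 + 0.631 + 0) / 3, which is why the result lands near 54.36 %. Accuracy@2 on the same data would be 66.67 %, since it does not discount the second-ranked hit.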
