Skip to content

Commit eb85701

Browse files
Merge pull request #7 from PoCInnovation/feat/cross_encoder
feat: add cross-encoder and update prompt for primary and secondary s…
2 parents bd91a33 + ed2e8e2 commit eb85701

7 files changed

Lines changed: 591 additions & 70 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
.env
Lines changed: 69 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,87 @@
11
from typing import Dict, List
22

33
def compute_entity_similarity(a: Dict, b: Dict) -> float:
4+
"""
5+
Compute entity similarity between two articles with primary/secondary importance weighting.
6+
7+
Args:
8+
a: Entity dict with primary_subject, secondary_subject, primary_orgs, secondary_orgs, primary_event, secondary_event
9+
b: Entity dict with same structure
10+
11+
Returns:
12+
Similarity score (0.0 to 1.0+)
13+
"""
414
score = 0.0
515

6-
if a["subject"] and a["subject"] == b["subject"]:
16+
# Primary subject match (highest weight)
17+
if a.get("primary_subject") and a["primary_subject"] == b.get("primary_subject"):
718
score += 1.0
19+
# Secondary subject match with lower weight
20+
elif a.get("secondary_subject") and a["secondary_subject"] == b.get("secondary_subject"):
21+
score += 0.3
22+
# Cross-match (primary vs secondary)
23+
elif (a.get("primary_subject") and a["primary_subject"] == b.get("secondary_subject")) or \
24+
(a.get("secondary_subject") and a["secondary_subject"] == b.get("primary_subject")):
25+
score += 0.2
826

9-
if a["event"] and a["event"] == b["event"]:
27+
# Primary event match (high weight)
28+
if a.get("primary_event") and a["primary_event"] == b.get("primary_event"):
1029
score += 0.5
30+
# Secondary event match (lower weight)
31+
elif a.get("secondary_event") and a["secondary_event"] == b.get("secondary_event"):
32+
score += 0.2
33+
# Cross-match
34+
elif (a.get("primary_event") and a["primary_event"] == b.get("secondary_event")) or \
35+
(a.get("secondary_event") and a["secondary_event"] == b.get("primary_event")):
36+
score += 0.15
1137

12-
orgs_a = set(a.get("orgs", []))
13-
orgs_b = set(b.get("orgs", []))
38+
# Organization matching with primary/secondary distinction
39+
primary_orgs_a = set(a.get("primary_orgs", []))
40+
primary_orgs_b = set(b.get("primary_orgs", []))
41+
secondary_orgs_a = set(a.get("secondary_orgs", []))
42+
secondary_orgs_b = set(b.get("secondary_orgs", []))
1443

15-
if orgs_a and orgs_b:
16-
score += 0.2 * len(orgs_a & orgs_b)
44+
# Primary org matches (higher weight)
45+
if primary_orgs_a and primary_orgs_b:
46+
score += 0.3 * len(primary_orgs_a & primary_orgs_b)
47+
48+
# Secondary org matches (lower weight)
49+
if secondary_orgs_a and secondary_orgs_b:
50+
score += 0.1 * len(secondary_orgs_a & secondary_orgs_b)
51+
52+
# Cross-org matches (primary <-> secondary)
53+
if primary_orgs_a and secondary_orgs_b:
54+
score += 0.1 * len(primary_orgs_a & secondary_orgs_b)
55+
if secondary_orgs_a and primary_orgs_b:
56+
score += 0.1 * len(secondary_orgs_a & primary_orgs_b)
1757

1858
return score
1959

2060
def compute_final_score(
2161
semantic_score: float,
2262
entity_score: float,
23-
w_sem: float = 0.6,
63+
cross_score: float = 0.5,
64+
w_sem: float = 0.3,
2465
w_ent: float = 0.4,
66+
w_cross: float = 0.3,
2567
) -> float:
26-
return w_sem * semantic_score + w_ent * entity_score
68+
"""
69+
Compute final clustering score combining multiple signals.
70+
71+
Args:
72+
semantic_score: Embedding-based similarity (0.0-1.0)
73+
entity_score: Entity matching score (0.0-1.0+)
74+
cross_score: Cross-encoder score (0.0-1.0)
75+
w_sem: Weight for semantic similarity
76+
w_ent: Weight for entity similarity
77+
w_cross: Weight for cross-encoder score
78+
79+
Returns:
80+
Final combined score
81+
"""
82+
# Normalize entity score to [0, 1] range
83+
normalized_entity = min(entity_score / 2.0, 1.0)
84+
85+
return (w_sem * semantic_score +
86+
w_ent * normalized_entity +
87+
w_cross * cross_score)

server/cross_encoder.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import logging
2+
from typing import Dict, List, Any, Tuple
3+
from sentence_transformers import CrossEncoder
4+
import numpy as np
5+
6+
logger = logging.getLogger(__name__)
7+
8+
9+
class CrossEncoderManager:
10+
"""Manages cross-encoder for computing semantic relevance scores between articles."""
11+
12+
def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
13+
"""
14+
Initialize the cross encoder.
15+
16+
Args:
17+
model_name: HuggingFace model identifier for cross-encoder
18+
Default: ms-marco-MiniLM-L-6-v2 (efficient and accurate for relevance)
19+
"""
20+
self.model_name = model_name
21+
try:
22+
self.model = CrossEncoder(model_name)
23+
logger.info(f"Cross-encoder loaded: {model_name}")
24+
except Exception as e:
25+
logger.error(f"Failed to load cross-encoder: {e}")
26+
self.model = None
27+
28+
def compute_relevance_score(
29+
self,
30+
query_article: Dict[str, Any],
31+
candidate_article: Dict[str, Any]
32+
) -> float:
33+
"""
34+
Compute semantic relevance score between two articles.
35+
36+
Args:
37+
query_article: Source article dict with title, description, full_content
38+
candidate_article: Target article dict for comparison
39+
40+
Returns:
41+
Relevance score between 0 and 1
42+
"""
43+
if self.model is None:
44+
logger.warning("Cross-encoder model not loaded, returning 0.5")
45+
return 0.5
46+
47+
try:
48+
query_text = self._build_article_text(query_article)
49+
candidate_text = self._build_article_text(candidate_article)
50+
scores = self.model.predict([
51+
[query_text, candidate_text]
52+
])
53+
relevance_score = self._sigmoid(scores[0])
54+
55+
return float(relevance_score)
56+
57+
except Exception as e:
58+
logger.error(f"Error computing relevance score: {e}")
59+
return 0.5
60+
61+
def compute_batch_relevance_scores(
62+
self,
63+
query_article: Dict[str, Any],
64+
candidate_articles: List[Dict[str, Any]]
65+
) -> List[float]:
66+
"""
67+
Compute relevance scores between one query article and multiple candidates.
68+
69+
Args:
70+
query_article: Source article
71+
candidate_articles: List of candidate articles
72+
73+
Returns:
74+
List of relevance scores
75+
"""
76+
if self.model is None or not candidate_articles:
77+
return [0.5] * len(candidate_articles)
78+
79+
try:
80+
query_text = self._build_article_text(query_article)
81+
82+
pairs = [
83+
[query_text, self._build_article_text(candidate)]
84+
for candidate in candidate_articles
85+
]
86+
scores = self.model.predict(pairs)
87+
normalized_scores = [float(self._sigmoid(score)) for score in scores]
88+
89+
return normalized_scores
90+
91+
except Exception as e:
92+
logger.error(f"Error computing batch relevance scores: {e}")
93+
return [0.5] * len(candidate_articles)
94+
95+
def _build_article_text(self, article: Dict[str, Any]) -> str:
96+
"""
97+
Build a text representation of an article for cross-encoder.
98+
99+
Args:
100+
article: Article dictionary
101+
102+
Returns:
103+
Combined text of title and description
104+
"""
105+
title = article.get("title", "").strip()
106+
description = article.get("description", "").strip()
107+
108+
if title and description:
109+
return f"{title} {description}"
110+
elif title:
111+
return title
112+
elif description:
113+
return description
114+
else:
115+
return ""
116+
117+
@staticmethod
118+
def _sigmoid(x: float) -> float:
119+
"""Apply sigmoid function to normalize cross-encoder output."""
120+
import math
121+
try:
122+
return 1.0 / (1.0 + math.exp(-x))
123+
except OverflowError:
124+
return 0.0 if x < 0 else 1.0

0 commit comments

Comments
 (0)