Skip to content

Commit a6c95f5

Browse files
committed
fix(scrapping): fix embedding model
1 parent 814bafb commit a6c95f5

10 files changed

Lines changed: 79 additions & 232 deletions

File tree

server/.env.example

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,5 @@
1+
# OpenAI API Configuration
2+
OPENAI_API_KEY=your_openai_api_key_here
3+
4+
# Optional: GitHub Token for higher rate limits
5+
GITHUB_TOKEN=your_github_token_here

server/README.md

Lines changed: 0 additions & 207 deletions
This file was deleted.

server/config.py

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -23,9 +23,6 @@ class ServerConfig:
2323

2424
watch_interval_seconds: int = 300
2525

26-
use_dummy_embeddings: bool = True
27-
embedding_model: str = "all-MiniLM-L6-v2"
28-
2926
log_level: str = "INFO"
3027

3128
scrapers: Dict[str, ScraperConfig] = None
@@ -66,13 +63,11 @@ def from_file(cls, filepath: str) -> "ServerConfig":
6663

6764
DEV_CONFIG = ServerConfig(
6865
db_path="veille_technique_dev.db",
69-
use_dummy_embeddings=True,
7066
watch_interval_seconds=60,
7167
)
7268

7369
PROD_CONFIG = ServerConfig(
7470
db_path="veille_technique.db",
75-
use_dummy_embeddings=False,
7671
watch_interval_seconds=600,
7772
log_level="WARNING",
7873
)

server/embeddings.py

Lines changed: 71 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,8 @@
44
from typing import List, Optional
55
from abc import ABC, abstractmethod
66
import numpy as np
7+
import os
8+
from pathlib import Path
79

810

911
class EmbeddingProvider(ABC):
@@ -82,6 +84,75 @@ def get_name(self) -> str:
8284
return f"sentence-transformers-{self.model_name}"
8385

8486

87+
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """Embedding provider using OpenAI API.

    The API key is resolved in this order: the ``api_key`` constructor
    argument, the ``OPENAI_API_KEY`` environment variable, then an
    ``OPENAI_API_KEY=...`` entry in a ``.env`` file located next to this
    module.
    """

    # Inputs longer than this many characters are truncated before being sent.
    # NOTE(review): presumably a coarse guard against the model's input-size
    # limit; the cap is in characters, not tokens -- confirm it is sufficient.
    MAX_INPUT_CHARS = 30000

    # Name of the variable looked up in the environment and in the .env file.
    _API_KEY_VAR = "OPENAI_API_KEY"

    def __init__(self, model: str = "text-embedding-3-small", api_key: Optional[str] = None):
        """
        Initialize OpenAI embedding provider.

        Args:
            model: OpenAI embedding model to use (text-embedding-3-small,
                text-embedding-3-large, text-embedding-ada-002)
            api_key: OpenAI API key (defaults to OPENAI_API_KEY env var or
                .env file)

        Raises:
            ImportError: If the ``openai`` package is not installed.
            ValueError: If no API key could be resolved.
        """
        # Imported lazily so the module can be loaded without the optional
        # ``openai`` dependency when this provider is not used.
        try:
            from openai import OpenAI
        except ImportError as exc:
            # Chain the original error so the real import failure stays visible.
            raise ImportError(
                "openai package is required. Install it with: "
                "pip install openai"
            ) from exc

        self.model = model

        # An explicit (non-empty) argument wins over the environment / .env file.
        self.api_key = api_key or self._load_api_key_from_env()

        if not self.api_key:
            raise ValueError(
                "OpenAI API key is required. Add OPENAI_API_KEY to .env file "
                "or set OPENAI_API_KEY environment variable."
            )

        self.client = OpenAI(api_key=self.api_key)

    def _load_api_key_from_env(self) -> Optional[str]:
        """Load API key from environment variable or .env file.

        The environment variable takes precedence. The ``.env`` file is read
        from the directory containing this module; the first line assigning
        ``OPENAI_API_KEY`` is used. An optional ``export `` prefix, spaces
        around ``=`` and surrounding quotes on the value are tolerated.

        Returns:
            The key string (possibly empty if the .env entry has no value),
            or None when neither source defines it.
        """
        api_key = os.getenv(self._API_KEY_VAR)
        if api_key:
            return api_key

        env_path = Path(__file__).parent / ".env"
        # is_file() rather than exists(): a directory named ".env" must not
        # reach open() and crash.
        if not env_path.is_file():
            return None

        # Explicit encoding so parsing does not depend on the platform locale.
        with open(env_path, "r", encoding="utf-8") as f:
            for line in f:
                name, sep, value = line.strip().partition("=")
                name = name.strip()
                if name.startswith("export "):
                    name = name[len("export "):].strip()
                if sep and name == self._API_KEY_VAR:
                    return value.strip().strip('"').strip("'")

        return None

    def embed(self, text: str) -> np.ndarray:
        """Generate embedding with OpenAI API.

        Args:
            text: Text to embed; truncated to ``MAX_INPUT_CHARS`` characters.

        Returns:
            The first embedding of the API response as a float32 array.
        """
        if len(text) > self.MAX_INPUT_CHARS:
            text = text[:self.MAX_INPUT_CHARS]

        response = self.client.embeddings.create(
            input=text,
            model=self.model,
        )

        return np.array(response.data[0].embedding, dtype=np.float32)

    def get_name(self) -> str:
        """Return provider name."""
        return f"openai-{self.model}"
154+
155+
85156
class EmbeddingManager:
86157
"""Manage embeddings for articles."""
87158

server/examples.py

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,3 @@
1-
#!/usr/bin/env python3
21
"""
32
Usage examples for the watch server.
43
"""
@@ -16,7 +15,6 @@ def example_backfill():
1615

1716
server = WatchServer(
1817
db_path="test_backfill.db",
19-
use_dummy_embeddings=True,
2018
check_interval=300
2119
)
2220

@@ -34,7 +32,6 @@ def example_watch_limited():
3432

3533
server = WatchServer(
3634
db_path="test_watch.db",
37-
use_dummy_embeddings=True,
3835
check_interval=10
3936
)
4037

@@ -96,7 +93,6 @@ def example_custom_config():
9693
custom_config = ServerConfig(
9794
db_path="test_custom.db",
9895
watch_interval_seconds=120,
99-
use_dummy_embeddings=True,
10096
scrapers={
10197
"arxiv": ScraperConfig(enabled=True, limit_latest=10, limit_all=30),
10298
"github": ScraperConfig(enabled=True, limit_latest=15, limit_all=50),
@@ -109,7 +105,6 @@ def example_custom_config():
109105
print(f"\n✓ Custom config created")
110106
print(f" DB : {custom_config.db_path}")
111107
print(f" Interval : {custom_config.watch_interval_seconds}s")
112-
print(f" Dummy embeddings : {custom_config.use_dummy_embeddings}")
113108

114109
print(f"\n✓ Enabled scrapers :")
115110
for name, cfg in custom_config.scrapers.items():

0 commit comments

Comments
 (0)