-
Notifications
You must be signed in to change notification settings - Fork 102
Expand file tree
/
Copy pathshared.py
More file actions
86 lines (63 loc) · 2.71 KB
/
shared.py
File metadata and controls
86 lines (63 loc) · 2.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""Shared context keys, embedder factory, and CodeChunk schema."""
from __future__ import annotations

import logging
import pathlib
from dataclasses import dataclass
from typing import TYPE_CHECKING, Annotated, Final, Union

import cocoindex as coco
from cocoindex.connectors import sqlite
from numpy.typing import NDArray

if TYPE_CHECKING:
    from cocoindex.ops.litellm import LiteLLMEmbedder
    from cocoindex.ops.sentence_transformers import SentenceTransformerEmbedder

    from .settings import EmbeddingSettings
logger = logging.getLogger(__name__)

# Legacy prefix that may precede sentence-transformers model names in settings;
# create_embedder() strips it before loading the model.
SBERT_PREFIX: Final = "sbert/"

# Models that define a "query" prompt for asymmetric retrieval.
# A frozenset because this is a read-only lookup table that nothing should mutate.
_QUERY_PROMPT_MODELS: Final = frozenset(
    {"nomic-ai/nomic-embed-code", "nomic-ai/CodeRankEmbed"}
)

# Type alias covering the embedder backends that create_embedder() can return.
Embedder = Union["SentenceTransformerEmbedder", "LiteLLMEmbedder"]

# Context keys
EMBEDDER = coco.ContextKey[Embedder]("embedder")
SQLITE_DB = coco.ContextKey[sqlite.ManagedConnection]("index_db", tracked=False)
CODEBASE_DIR = coco.ContextKey[pathlib.Path]("codebase", tracked=False)
EXT_LANG_OVERRIDE_MAP = coco.ContextKey[dict[str, str]]("ext_lang_override_map")

# Module-level mutable state (intentionally not Final): rebound by create_embedder(),
# which the daemon calls at startup. ``embedder`` is also referenced by name from the
# CodeChunk ``embedding`` annotation.
embedder: Embedder | None = None
# Query prompt name — set alongside ``embedder`` by create_embedder().
query_prompt_name: str | None = None
def create_embedder(settings: EmbeddingSettings) -> Embedder:
    """Create and return an embedder instance based on settings.

    Also sets the module-level ``embedder`` and ``query_prompt_name`` variables.
    Both globals are assigned together, and only after the embedder has been
    constructed successfully. If construction raises, the previously configured
    pair is left untouched and consistent.

    Args:
        settings: Embedding configuration. When ``settings.provider`` is
            ``"sentence-transformers"``, a SentenceTransformerEmbedder is built
            for ``settings.model`` (with any legacy ``sbert/`` prefix removed) on
            ``settings.device``. Any other provider value falls back to a
            LiteLLMEmbedder for ``settings.model``.

    Returns:
        The newly created embedder, which is also stored in the module-level
        ``embedder``.
    """
    global embedder, query_prompt_name

    instance: Embedder
    prompt_name: str | None
    if settings.provider == "sentence-transformers":
        from cocoindex.ops.sentence_transformers import SentenceTransformerEmbedder

        # Strip the legacy sbert/ prefix if present.
        model_name = settings.model.removeprefix(SBERT_PREFIX)
        # Only models known to define a "query" prompt get one; others use none.
        prompt_name = "query" if model_name in _QUERY_PROMPT_MODELS else None
        # NOTE(review): trust_remote_code=True lets the model repository execute
        # its own Python code on load; only point settings.model at trusted models.
        instance = SentenceTransformerEmbedder(
            model_name,
            device=settings.device,
            trust_remote_code=True,
        )
        logger.info("Embedding model: %s | device: %s", settings.model, settings.device)
    else:
        from cocoindex.ops.litellm import LiteLLMEmbedder

        instance = LiteLLMEmbedder(settings.model)
        prompt_name = None
        logger.info("Embedding model (LiteLLM): %s", settings.model)

    # Publish both globals together, only after construction succeeded, so a
    # failed load never leaves a prompt name paired with a stale embedder.
    embedder = instance
    query_prompt_name = prompt_name
    return instance
@dataclass
class CodeChunk:
    """Schema for storing code chunks in SQLite.

    Each instance describes one chunk of a source file: where it came from,
    its detected language, its text, its line span, and its embedding vector.
    """

    id: int
    # Path of the file the chunk was taken from.
    # NOTE(review): presumably relative to the CODEBASE_DIR context value — confirm.
    file_path: str
    # Language name of the chunk's source file.
    language: str
    # Raw text of the chunk.
    content: str
    # Line span of the chunk within the file.
    # NOTE(review): 0- vs 1-based and inclusive/exclusive end not shown here — confirm.
    start_line: int
    end_line: int
    # Embedding vector. The Annotated metadata is the module-level ``embedder``.
    # Because the module uses ``from __future__ import annotations``, that name is
    # looked up only when the type hints are resolved. If create_embedder() has not
    # run by then, the metadata is still the initial None.
    # NOTE(review): presumably cocoindex reads this metadata to derive the vector
    # column schema — confirm.
    embedding: Annotated[NDArray, embedder]  # type: ignore[type-arg]