|
| 1 | +from enum import Enum |
| 2 | + |
| 3 | +from pydantic import BaseModel, SecretStr |
| 4 | + |
| 5 | +from ..api import DBCaseConfig, DBConfig, IndexType, MetricType |
| 6 | + |
| 7 | + |
| 8 | +class TencentElasticsearchConfig(DBConfig, BaseModel): |
| 9 | + #: Protocol in use to connect to the node |
| 10 | + scheme: str = "http" |
| 11 | + host: str = "" |
| 12 | + port: int = 9200 |
| 13 | + user: str = "elastic" |
| 14 | + password: SecretStr |
| 15 | + |
| 16 | + def to_dict(self) -> dict: |
| 17 | + return { |
| 18 | + "hosts": [{"scheme": self.scheme, "host": self.host, "port": self.port}], |
| 19 | + "basic_auth": (self.user, self.password.get_secret_value()), |
| 20 | + } |
| 21 | + |
| 22 | + |
| 23 | +class ESElementType(str, Enum): |
| 24 | + float = "float" # 4 byte |
| 25 | + byte = "byte" # 1 byte, -128 to 127 |
| 26 | + |
| 27 | + |
| 28 | +class TencentElasticsearchIndexConfig(BaseModel, DBCaseConfig): |
| 29 | + element_type: ESElementType = ESElementType.float |
| 30 | + index: IndexType = IndexType.TES_VSEARCH |
| 31 | + number_of_shards: int = 1 |
| 32 | + number_of_replicas: int = 0 |
| 33 | + refresh_interval: str = "3s" |
| 34 | + merge_max_thread_count: int = 8 |
| 35 | + use_rescore: bool = False |
| 36 | + oversample_ratio: float = 2.0 |
| 37 | + use_routing: bool = False |
| 38 | + use_force_merge: bool = True |
| 39 | + |
| 40 | + metric_type: MetricType | None = None |
| 41 | + efConstruction: int | None = None |
| 42 | + M: int | None = None |
| 43 | + num_candidates: int | None = None |
| 44 | + |
| 45 | + def __eq__(self, obj: any): |
| 46 | + return ( |
| 47 | + self.index == obj.index |
| 48 | + and self.number_of_shards == obj.number_of_shards |
| 49 | + and self.number_of_replicas == obj.number_of_replicas |
| 50 | + and self.use_routing == obj.use_routing |
| 51 | + and self.efConstruction == obj.efConstruction |
| 52 | + and self.M == obj.M |
| 53 | + ) |
| 54 | + |
| 55 | + def __hash__(self) -> int: |
| 56 | + return hash( |
| 57 | + ( |
| 58 | + self.index, |
| 59 | + self.number_of_shards, |
| 60 | + self.number_of_replicas, |
| 61 | + self.use_routing, |
| 62 | + self.efConstruction, |
| 63 | + self.M, |
| 64 | + 2, |
| 65 | + ) |
| 66 | + ) |
| 67 | + |
| 68 | + def parse_metric(self) -> str: |
| 69 | + if self.metric_type == MetricType.L2: |
| 70 | + return "l2_norm" |
| 71 | + if self.metric_type == MetricType.IP: |
| 72 | + return "dot_product" |
| 73 | + return "cosine" |
| 74 | + |
| 75 | + def index_param(self) -> dict: |
| 76 | + return { |
| 77 | + "type": "dense_vector", |
| 78 | + "index": True, |
| 79 | + "element_type": self.element_type.value, |
| 80 | + "similarity": self.parse_metric(), |
| 81 | + "index_options": { |
| 82 | + "type": self.index.value, |
| 83 | + "index": "hnsw", |
| 84 | + "m": self.M, |
| 85 | + "ef_construction": self.efConstruction, |
| 86 | + }, |
| 87 | + } |
| 88 | + |
| 89 | + def search_param(self) -> dict: |
| 90 | + return { |
| 91 | + "num_candidates": self.num_candidates, |
| 92 | + } |
0 commit comments