|
| 1 | +from typing import Annotated, TypedDict, Unpack |
| 2 | + |
| 3 | +import click |
| 4 | +from pydantic import SecretStr |
| 5 | + |
| 6 | +from ....cli.cli import ( |
| 7 | + CommonTypedDict, |
| 8 | + cli, |
| 9 | + click_parameter_decorators_from_typed_dict, |
| 10 | + get_custom_case_config, |
| 11 | + run, |
| 12 | +) |
| 13 | +from .. import DB |
| 14 | +from ..api import MetricType |
| 15 | +from .config import S3VectorsIndexConfig |
| 16 | + |
| 17 | + |
| 18 | +class S3VectorsTypedDict(TypedDict): |
| 19 | + region_name: Annotated[ |
| 20 | + str, click.option("--region", type=str, help="AWS region for S3 bucket (eg. us-east-1)", default="us-east-1") |
| 21 | + ] |
| 22 | + access_key_id: Annotated[str, click.option("--access_key_id", type=str, help="AWS access key ID", required=True)] |
| 23 | + secret_access_key: Annotated[ |
| 24 | + str, click.option("--secret_access_key", type=str, help="AWS secret access key", required=True) |
| 25 | + ] |
| 26 | + |
| 27 | + bucket: Annotated[str, click.option("--bucket", type=str, help="S3 bucket name", required=True)] |
| 28 | + index: Annotated[str, click.option("--index", type=str, help="Unique vector index name", default="vdbbench-index")] |
| 29 | + |
| 30 | + metric: Annotated[ |
| 31 | + str, |
| 32 | + click.option( |
| 33 | + "--metric", |
| 34 | + type=str, |
| 35 | + help="Distance metric for vector similarity (e.g., 'cosine', 'euclidean').", |
| 36 | + default=None, |
| 37 | + ), |
| 38 | + ] |
| 39 | + |
| 40 | + |
| 41 | +class S3VectorsIndexTypedDict(CommonTypedDict, S3VectorsTypedDict): ... |
| 42 | + |
| 43 | + |
| 44 | +@cli.command() |
| 45 | +@click_parameter_decorators_from_typed_dict(S3VectorsIndexTypedDict) |
| 46 | +def S3Vectors(**parameters: Unpack[S3VectorsIndexTypedDict]): |
| 47 | + from .config import S3VectorsConfig |
| 48 | + |
| 49 | + parameters["custom_case"] = get_custom_case_config(parameters) |
| 50 | + run( |
| 51 | + db=DB.S3Vectors, |
| 52 | + db_config=S3VectorsConfig( |
| 53 | + region_name=parameters["region"], |
| 54 | + access_key_id=SecretStr(parameters["access_key_id"]), |
| 55 | + secret_access_key=SecretStr(parameters["secret_access_key"]), |
| 56 | + bucket_name=parameters["bucket"], |
| 57 | + index_name=parameters["index"] if parameters["index"] else "vdbbench-index", |
| 58 | + ), |
| 59 | + db_case_config=S3VectorsIndexConfig( |
| 60 | + metric_type=( |
| 61 | + MetricType.COSINE |
| 62 | + if parameters["metric"] == "cosine" |
| 63 | + else MetricType.L2 if parameters["metric"] == "l2" else None |
| 64 | + ) |
| 65 | + ), |
| 66 | + **parameters, |
| 67 | + ) |
0 commit comments