Skip to content

Commit 4faead0

Browse files
authored
fix: fit into the latest cocoindex API (#25)
1 parent 02e5282 commit 4faead0

3 files changed

Lines changed: 7 additions & 8 deletions

File tree

src/cocoindex_code/embedder.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from typing import TYPE_CHECKING, Any
77

88
import cocoindex as coco
9-
import cocoindex.asyncio as coco_aio
109
import numpy as np
1110
from cocoindex.resources import schema as _schema
1211
from numpy.typing import NDArray
@@ -74,7 +73,7 @@ def _get_model(self) -> SentenceTransformer:
7473
)
7574
return self._model
7675

77-
@coco_aio.function(batching=True, runner=coco.GPU, memo=True, max_batch_size=_config.batch_size)
76+
@coco.fn.as_async(batching=True, runner=coco.GPU, memo=True, max_batch_size=_config.batch_size)
7877
def embed(self, texts: list[str]) -> list[NDArray[np.float32]]:
7978
"""Embed a batch of texts into float32 vectors."""
8079
model = self._get_model()
@@ -85,7 +84,7 @@ def embed(self, texts: list[str]) -> list[NDArray[np.float32]]:
8584
) # type: ignore[assignment]
8685
return list(embeddings)
8786

88-
@coco_aio.function(batching=True, runner=coco.GPU, memo=True, max_batch_size=_config.batch_size)
87+
@coco.fn.as_async(batching=True, runner=coco.GPU, memo=True, max_batch_size=_config.batch_size)
8988
def embed_query(self, texts: list[str]) -> list[NDArray[np.float32]]:
9089
"""Embed query texts, applying query_prompt_name if configured."""
9190
model = self._get_model()
@@ -97,7 +96,7 @@ def embed_query(self, texts: list[str]) -> list[NDArray[np.float32]]:
9796
) # type: ignore[assignment]
9897
return list(embeddings)
9998

100-
@coco_aio.function(runner=coco.GPU, memo=True)
99+
@coco.fn.as_async(runner=coco.GPU, memo=True)
101100
def __coco_vector_schema__(self) -> _schema.VectorSchema:
102101
"""Return the vector schema (dimension + dtype) for this model."""
103102
model = self._get_model()

tests/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections.abc import AsyncIterator
66
from pathlib import Path
77

8-
import cocoindex.asyncio as coco_aio
8+
import cocoindex as coco
99
import pytest
1010
import pytest_asyncio
1111

@@ -29,5 +29,5 @@ async def coco_runtime() -> AsyncIterator[None]:
2929
Uses session-scoped event loop to ensure CocoIndex environment
3030
persists across all tests.
3131
"""
32-
async with coco_aio.runtime():
32+
async with coco.runtime():
3333
yield

tests/test_embedder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def test_embed_query_forwards_prompt_name(self) -> None:
5151
embedder._get_model()
5252
mock_model.encode.reset_mock()
5353
# Call the underlying method directly, bypassing the CocoIndex batching decorator.
54-
LocalEmbedder.embed_query.__wrapped__(embedder, ["find functions that embed text"])
54+
LocalEmbedder.embed_query.__wrapped__(embedder, ["find functions that embed text"]) # type: ignore
5555
_, kwargs = mock_model.encode.call_args
5656
assert kwargs.get("prompt_name") == "query"
5757

@@ -61,7 +61,7 @@ def test_embed_query_no_prompt_when_unset(self) -> None:
6161
embedder = LocalEmbedder("some-model", device="cpu", query_prompt_name=None)
6262
embedder._get_model()
6363
mock_model.encode.reset_mock()
64-
LocalEmbedder.embed_query.__wrapped__(embedder, ["some query"])
64+
LocalEmbedder.embed_query.__wrapped__(embedder, ["some query"]) # type: ignore
6565
_, kwargs = mock_model.encode.call_args
6666
assert kwargs.get("prompt_name") is None
6767

0 commit comments

Comments (0)