Skip to content

Commit 1363288

Browse files
committed
renaming async methods to use 'async' suffix + converting some methods to staticmethods
1 parent 9370805 commit 1363288

2 files changed

Lines changed: 18 additions & 15 deletions

File tree

integrations/oracle/src/haystack_integrations/components/retrievers/oracle/embedding_retriever.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ async def run_async(
8282
) -> dict[str, list[Document]]:
8383
"""Async variant of :meth:`run`."""
8484
merged = _merge_filters(self.filters, filters)
85-
docs = await self.document_store._async_embedding_retrieval(
85+
docs = await self.document_store._embedding_retrieval_async(
8686
query_embedding,
8787
filters=merged,
8888
top_k=top_k if top_k is not None else self.top_k,

integrations/oracle/src/haystack_integrations/document_stores/oracle/document_store.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def translate(
113113
raise ValueError(msg)
114114
field: str = filters["field"]
115115
value: Any = filters["value"]
116-
col = self._field_to_sql(field, value)
116+
col = _FilterTranslator._field_to_sql(field, value)
117117

118118
if op in ("in", "not in"):
119119
placeholders = []
@@ -131,7 +131,8 @@ def translate(
131131
sql_op = self._OP_MAP[op]
132132
return f"{col} {sql_op} :{pname}"
133133

134-
def _field_to_sql(self, field: str, value: Any) -> str:
134+
@staticmethod
135+
def _field_to_sql(field: str, value: Any) -> str:
135136
if field == "id":
136137
return "id"
137138
if field == "content":
@@ -285,7 +286,7 @@ def create_hnsw_index(self) -> None:
285286
cur.execute(sql)
286287
conn.commit()
287288

288-
async def acreate_hnsw_index(self) -> None:
289+
async def create_hnsw_index_async(self) -> None:
289290
await asyncio.to_thread(self.create_hnsw_index)
290291

291292
def write_documents(
@@ -304,7 +305,8 @@ def write_documents(
304305
msg = f"Unknown DuplicatePolicy: {policy}"
305306
raise ValueError(msg)
306307

307-
def _to_row(self, doc: Document) -> tuple[str, str | None, str, bytes | None]:
308+
@staticmethod
309+
def _to_row(doc: Document) -> tuple[str, str | None, str, bytes | None]:
308310
"""Convert a Document to (id, text, metadata_json, embedding_bytes).
309311
310312
Haystack IDs are stored verbatim in a VARCHAR2(64) column, so any
@@ -319,7 +321,7 @@ def _to_row(self, doc: Document) -> tuple[str, str | None, str, bytes | None]:
319321
return doc_id, text, meta, emb
320322

321323
def _to_named_row(self, doc: Document) -> dict[str, Any]:
322-
doc_id, text, meta, emb = self._to_row(doc)
324+
doc_id, text, meta, emb = OracleDocumentStore._to_row(doc)
323325
return {"doc_id": doc_id, "doc_text": text, "doc_meta": meta, "doc_emb": emb}
324326

325327
def _insert_documents(self, documents: list[Document]) -> int:
@@ -373,14 +375,15 @@ def _upsert_documents(self, documents: list[Document]) -> int:
373375
conn.commit()
374376
return written
375377

376-
async def awrite_documents(
378+
async def write_documents_async(
377379
self,
378380
documents: list[Document],
379381
policy: DuplicatePolicy = DuplicatePolicy.NONE,
380382
) -> int:
381383
return await asyncio.to_thread(self.write_documents, documents, policy)
382384

383-
def _build_where(self, filters: dict[str, Any] | None) -> tuple[str, dict[str, Any]]:
385+
@staticmethod
386+
def _build_where(filters: dict[str, Any] | None) -> tuple[str, dict[str, Any]]:
384387
if not filters:
385388
return "", {}
386389
params: dict[str, Any] = {}
@@ -389,12 +392,12 @@ def _build_where(self, filters: dict[str, Any] | None) -> tuple[str, dict[str, A
389392
return f"WHERE {fragment}", params
390393

391394
def filter_documents(self, filters: dict[str, Any] | None = None) -> list[Document]:
392-
where, params = self._build_where(filters)
395+
where, params = OracleDocumentStore._build_where(filters)
393396
sql = f"SELECT id, text, metadata FROM {self.table_name} {where}"
394397
with self._get_connection() as conn, conn.cursor() as cur:
395398
cur.execute(sql, params)
396399
rows = cur.fetchall()
397-
return [self._row_to_document(r) for r in rows]
400+
return [OracleDocumentStore._row_to_document(r) for r in rows]
398401

399402
async def afilter_documents(self, filters: dict[str, Any] | None = None) -> list[Document]:
400403
return await asyncio.to_thread(self.filter_documents, filters)
@@ -409,7 +412,7 @@ def delete_documents(self, document_ids: list[str]) -> None:
409412
cur.execute(sql, params)
410413
conn.commit()
411414

412-
async def adelete_documents(self, document_ids: list[str]) -> None:
415+
async def delete_documents_async(self, document_ids: list[str]) -> None:
413416
await asyncio.to_thread(self.delete_documents, document_ids)
414417

415418
def count_documents(self) -> int:
@@ -419,7 +422,7 @@ def count_documents(self) -> int:
419422
row = cur.fetchone()
420423
return row[0] if row else 0
421424

422-
async def acount_documents(self) -> int:
425+
async def count_documents_async(self) -> int:
423426
return await asyncio.to_thread(self.count_documents)
424427

425428
def _embedding_retrieval(
@@ -446,7 +449,7 @@ def _embedding_retrieval(
446449
rows = cur.fetchall()
447450
return [self._row_to_document(r, with_score=True) for r in rows]
448451

449-
async def _async_embedding_retrieval(
452+
async def _embedding_retrieval_async(
450453
self,
451454
query_embedding: list[float],
452455
*,
@@ -460,8 +463,8 @@ async def _async_embedding_retrieval(
460463
top_k=top_k,
461464
)
462465

463-
464-
def _row_to_document(self, row: tuple, *, with_score: bool = False) -> Document:
466+
@staticmethod
467+
def _row_to_document(row: tuple, *, with_score: bool = False) -> Document:
465468
if with_score:
466469
raw_id, text, metadata_raw, score = row
467470
else:

0 commit comments

Comments
 (0)