
Commit 993f676

chandrasekharan-zipstack, coderabbitai[bot], claude, and athul-rs authored
UN-2946 [MISC] Flush embedding usage records on indexing path (#1962)
* UN-2946 [FIX] Flush embedding usage records on indexing path

  The deferred-batch usage refactor stopped flushing embedding-callback records: ``UsageHandler.on_event_end`` now appends to ``_pending_usage`` instead of pushing directly, but ``_handle_index`` (and ``_run_pipeline_index``) never drained the list. Embedding rows fell off ``Usage`` for every workflow / API-deployment run, so the API response's ``metadata.usage.embedding_tokens`` reported 0 despite indexing actually happening.

  - Add ``EmbeddingCompat.flush_pending_usage()`` mirroring the LLM shim.
  - ``_handle_index`` flushes embedding into ``ExecutionResult.metadata["usage_records"]`` on all success exits and attaches partial rows via ``LegacyExecutorError.partial_usage_records`` on the error path.
  - ``_run_pipeline_index`` now returns ``(metrics, records)`` so ``_handle_structure_pipeline`` can absorb embedding rows into ``pipeline_records``. Existing IDE-index path already absorbs via ``metadata["usage_records"]`` and starts working automatically.
  - Fix the ``_run_pipeline_index`` mock in test_phase5d to return the new tuple shape.

* UN-2946 [FIX] Preserve mid-loop index records on LegacyExecutorError

  If ``_handle_index`` raises in iteration N of the per-output loop, records accumulated from iterations 1…N-1 were dropped because the exception escaped ``_run_pipeline_index`` unmodified and ``_handle_structure_pipeline``'s ``except`` branch only inherits ``e.partial_usage_records`` (the N-th iteration's partial rows) and ``pipeline_records`` (which never received the in-flight tuple).

  Mirror the ``_handle_ide_index`` pattern: wrap the ``_handle_index`` call in a try/except and prepend ``index_records`` to ``e.partial_usage_records`` before re-raising so the outer handler sees every row the worker had collected so far.

* Update workers/executor/executors/legacy_executor.py

  Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
  Signed-off-by: Chandrasekharan M <117059509+chandrasekharan-zipstack@users.noreply.github.com>

* UN-2946 [REFACTOR] Split _run_pipeline_index / _handle_index helpers

  Extract structure-pipeline indexing loop body into _index_pipeline_output and missing-param validation into _missing_index_params to drop both functions below SonarCloud's cognitive complexity threshold.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* UN-2946 [FIX] Guard embedding handler flush and propagate index failure

  - EmbeddingCompat.flush_pending_usage now wraps each handler call in try/except so one bad handler doesn't drop usage rows from the rest and doesn't escape into the indexing success path.
  - _index_pipeline_output now raises LegacyExecutorError when _handle_index returns a failure result, so the structure pipeline aborts instead of running downstream steps against an incomplete vector store.
  - Added regression tests asserting embedding usage rows propagate through the structure pipeline and that an indexing returned-failure short-circuits before answer_prompt.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Signed-off-by: Chandrasekharan M <117059509+chandrasekharan-zipstack@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: Athul <89829560+athul-rs@users.noreply.github.com>
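The root cause described in the first commit (a handler that buffers rows, paired with a caller that never drains the buffer) can be illustrated with a minimal sketch. `ToyUsageHandler`, `index_without_flush`, and `index_with_flush` are hypothetical names invented for this illustration; they are not the real `UsageHandler` or `_handle_index`, and the row shape is an assumption.

```python
class ToyUsageHandler:
    """Hypothetical stand-in for a deferred-batch usage callback handler."""

    def __init__(self) -> None:
        self._pending_usage: list[dict] = []

    def on_event_end(self, tokens: int) -> None:
        # Deferred batching: buffer the row instead of pushing it right away.
        self._pending_usage.append({"type": "embedding", "tokens": tokens})

    def flush_pending_usage(self) -> list[dict]:
        # Hand the buffered rows to the caller and reset the buffer.
        drained, self._pending_usage = self._pending_usage, []
        return drained


def index_without_flush(handler: ToyUsageHandler) -> dict:
    """Mimics the regression: the embedding runs but nobody drains the buffer."""
    handler.on_event_end(tokens=128)
    return {"usage_records": []}


def index_with_flush(handler: ToyUsageHandler) -> dict:
    """Mimics the fix: drain the buffer into the result metadata."""
    handler.on_event_end(tokens=128)
    return {"usage_records": handler.flush_pending_usage()}


before = index_without_flush(ToyUsageHandler())
after = index_with_flush(ToyUsageHandler())
print(before["usage_records"])  # empty: the row is stranded in the handler
print(after["usage_records"])   # one embedding row
```

In the first function the row stays stranded in the handler, which matches the symptom of `embedding_tokens` being reported as 0 even though indexing ran.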
1 parent 8fdef4e commit 993f676

3 files changed

Lines changed: 260 additions & 88 deletions


unstract/sdk1/src/unstract/sdk1/embedding.py

Lines changed: 19 additions & 0 deletions
@@ -281,3 +281,22 @@ async def get_aquery_embedding(self, query: str) -> list[float]:

     def test_connection(self) -> bool:
         return self._embedding_instance.test_connection()
+
+    def flush_pending_usage(self) -> list[dict]:
+        """Drain pending usage rows from registered callback handlers."""
+        if not self.callback_manager:
+            return []
+        records: list[dict] = []
+        for handler in self.callback_manager.handlers:
+            if not hasattr(handler, "flush_pending_usage"):
+                continue
+            # Per-handler guard so one bad handler doesn't drop the rest.
+            try:
+                records.extend(handler.flush_pending_usage())
+            except Exception:
+                logger.warning(
+                    "Failed to flush usage from embedding handler %s",
+                    type(handler).__name__,
+                    exc_info=True,
+                )
+        return records
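The per-handler guard in the hunk above can be exercised in isolation. The following is a minimal sketch, not the SDK code: `GoodHandler`, `BrokenHandler`, `NoFlushHandler`, and `flush_all` are hypothetical names, and a plain list stands in for `callback_manager.handlers`.

```python
import logging

logger = logging.getLogger("flush_demo")


class GoodHandler:
    """Hypothetical handler that has one buffered usage row."""

    def flush_pending_usage(self) -> list[dict]:
        return [{"embedding_tokens": 42}]


class BrokenHandler:
    """Hypothetical handler whose flush always fails."""

    def flush_pending_usage(self) -> list[dict]:
        raise RuntimeError("boom")


class NoFlushHandler:
    """Hypothetical handler that does not implement flush at all."""


def flush_all(handlers: list) -> list[dict]:
    records: list[dict] = []
    for handler in handlers:
        if not hasattr(handler, "flush_pending_usage"):
            continue  # skip handlers that do not buffer usage
        try:
            records.extend(handler.flush_pending_usage())
        except Exception:
            # Log and move on so the remaining handlers still get drained.
            logger.warning(
                "Failed to flush usage from handler %s",
                type(handler).__name__,
                exc_info=True,
            )
    return records


rows = flush_all([BrokenHandler(), NoFlushHandler(), GoodHandler()])
print(rows)
```

Even with the failing handler first in the list, the row from `GoodHandler` is still collected, and the failure is logged rather than raised into the caller.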

workers/executor/executors/legacy_executor.py

Lines changed: 154 additions & 85 deletions
@@ -659,13 +659,15 @@ def _failure(child_result: ExecutionResult) -> ExecutionResult:
         elif not is_single_pass:
             # ---- Step 3: Index per output with dedup ----
             step += 1
-            index_metrics = self._run_pipeline_index(
+            index_metrics, index_records = self._run_pipeline_index(
                 context=context,
                 index_template=index_template,
                 answer_params=answer_params,
                 extracted_text=extracted_text,
                 usage_kwargs=extract_params.get("usage_kwargs", {}),
             )
+            if index_records:
+                pipeline_records.extend(index_records)

         # ---- Step 4: Table settings injection ----
         if not is_single_pass:
@@ -891,99 +893,132 @@ def _run_pipeline_index(
         answer_params: dict,
         extracted_text: str,
         usage_kwargs: dict | None = None,
-    ) -> dict:
+    ) -> tuple[dict, list[dict]]:
         """Run per-output indexing with dedup for the structure pipeline.

         Args:
             usage_kwargs: Audit-tracking kwargs (``run_id``,
                 ``execution_id``, ``file_name``) propagated to the
                 embedding adapter so its callback can record usage
-                rows against the correct file_execution_id. Without
-                this, embedding usage is missing from the API
-                deployment response metadata.
+                rows against the correct file_execution_id.

         Returns:
-            Dict of index metrics keyed by output name.
+            (index_metrics, usage_records) — metrics keyed by output
+            name and the flat list of usage rows collected across all
+            child ``_handle_index`` calls.
         """
-        import datetime
-
         tool_settings = answer_params.get("tool_settings", {})
         outputs = answer_params.get("outputs", [])
-        tool_id = index_template.get("tool_id", "")
-        file_hash = index_template.get("file_hash", "")
-        is_highlight = index_template.get("is_highlight_enabled", False)
-        platform_api_key = index_template.get("platform_api_key", "")
-        extracted_file_path = index_template.get("extracted_file_path", "")
-        usage_kwargs = usage_kwargs or {}
-
         index_metrics: dict = {}
+        index_records: list[dict] = []
         seen_params: set = set()

         for output in outputs:
-            chunk_size = output.get("chunk-size", 0)
-            chunk_overlap = output.get("chunk-overlap", 0)
-            vector_db = tool_settings.get("vector-db", "")
-            embedding = tool_settings.get("embedding", "")
-            x2text = tool_settings.get("x2text_adapter", "")
-
-            param_key = (
-                f"chunk_size={chunk_size}_"
-                f"chunk_overlap={chunk_overlap}_"
-                f"vector_db={vector_db}_"
-                f"embedding={embedding}_"
-                f"x2text={x2text}"
+            self._index_pipeline_output(
+                context=context,
+                output=output,
+                index_template=index_template,
+                tool_settings=tool_settings,
+                extracted_text=extracted_text,
+                usage_kwargs=usage_kwargs or {},
+                seen_params=seen_params,
+                index_metrics=index_metrics,
+                index_records=index_records,
             )

-            if chunk_size != 0 and param_key not in seen_params:
-                seen_params.add(param_key)
+        return index_metrics, index_records

-                indexing_start = datetime.datetime.now()
-                logger.info(
-                    "Pipeline indexing: chunk_size=%s chunk_overlap=%s vector_db=%s",
-                    chunk_size,
-                    chunk_overlap,
-                    vector_db,
-                )
+    def _index_pipeline_output(
+        self,
+        *,
+        context: ExecutionContext,
+        output: dict,
+        index_template: dict,
+        tool_settings: dict,
+        extracted_text: str,
+        usage_kwargs: dict,
+        seen_params: set,
+        index_metrics: dict,
+        index_records: list[dict],
+    ) -> None:
+        """Index a single structure-pipeline output entry in-place."""
+        import datetime

-                index_ctx = ExecutionContext(
-                    executor_name=context.executor_name,
-                    operation=Operation.INDEX.value,
-                    run_id=context.run_id,
-                    execution_source=context.execution_source,
-                    organization_id=context.organization_id,
-                    request_id=context.request_id,
-                    log_events_id=context.log_events_id,
-                    execution_id=context.execution_id,
-                    file_execution_id=context.file_execution_id,
-                    executor_params={
-                        "embedding_instance_id": embedding,
-                        "vector_db_instance_id": vector_db,
-                        "x2text_instance_id": x2text,
-                        "chunk_size": chunk_size,
-                        "chunk_overlap": chunk_overlap,
-                        "file_path": extracted_file_path,
-                        "reindex": True,
-                        "tool_id": tool_id,
-                        "file_hash": file_hash,
-                        "enable_highlight": is_highlight,
-                        "extracted_text": extracted_text,
-                        "platform_api_key": platform_api_key,
-                        "usage_kwargs": usage_kwargs,
-                    },
-                )
-                index_result = self._handle_index(index_ctx)
-                if not index_result.success:
-                    logger.warning(
-                        "Pipeline indexing failed for %s: %s",
-                        param_key,
-                        index_result.error,
-                    )
+        chunk_size = output.get("chunk-size", 0)
+        if chunk_size == 0:
+            return

-                elapsed = (datetime.datetime.now() - indexing_start).total_seconds()
-                output_name = output.get("name", "")
-                index_metrics[output_name] = {"indexing": {"time_taken(s)": elapsed}}
+        chunk_overlap = output.get("chunk-overlap", 0)
+        vector_db = tool_settings.get("vector-db", "")
+        embedding = tool_settings.get("embedding", "")
+        x2text = tool_settings.get("x2text_adapter", "")
+
+        param_key = (
+            f"chunk_size={chunk_size}_"
+            f"chunk_overlap={chunk_overlap}_"
+            f"vector_db={vector_db}_"
+            f"embedding={embedding}_"
+            f"x2text={x2text}"
+        )
+        if param_key in seen_params:
+            return
+        seen_params.add(param_key)

-        return index_metrics
+        indexing_start = datetime.datetime.now()
+        logger.info(
+            "Pipeline indexing: chunk_size=%s chunk_overlap=%s vector_db=%s",
+            chunk_size,
+            chunk_overlap,
+            vector_db,
+        )
+
+        index_ctx = ExecutionContext(
+            executor_name=context.executor_name,
+            operation=Operation.INDEX.value,
+            run_id=context.run_id,
+            execution_source=context.execution_source,
+            organization_id=context.organization_id,
+            request_id=context.request_id,
+            log_events_id=context.log_events_id,
+            execution_id=context.execution_id,
+            file_execution_id=context.file_execution_id,
+            executor_params={
+                "embedding_instance_id": embedding,
+                "vector_db_instance_id": vector_db,
+                "x2text_instance_id": x2text,
+                "chunk_size": chunk_size,
+                "chunk_overlap": chunk_overlap,
+                "file_path": index_template.get("extracted_file_path", ""),
+                "reindex": True,
+                "tool_id": index_template.get("tool_id", ""),
+                "file_hash": index_template.get("file_hash", ""),
+                "enable_highlight": index_template.get("is_highlight_enabled", False),
+                "extracted_text": extracted_text,
+                "platform_api_key": index_template.get("platform_api_key", ""),
+                "usage_kwargs": usage_kwargs,
+            },
+        )
+        try:
+            index_result = self._handle_index(index_ctx)
+        except LegacyExecutorError as e:
+            # Preserve usage rows accrued from prior iterations.
+            e.partial_usage_records = index_records + e.partial_usage_records
+            raise
+        if not index_result.success:
+            # Abort on returned-failure so downstream steps don't run
+            # against an incomplete vector store.
+            raise LegacyExecutorError(
+                message=f"Pipeline indexing failed for {param_key}: {index_result.error}",
+                code=500,
+                partial_usage_records=list(index_records),
+            )
+        child_records = (index_result.metadata or {}).get("usage_records") or []
+        if child_records:
+            index_records.extend(child_records)
+
+        elapsed = (datetime.datetime.now() - indexing_start).total_seconds()
+        output_name = output.get("name", "")
+        index_metrics[output_name] = {"indexing": {"time_taken(s)": elapsed}}

     @staticmethod
     def _merge_pipeline_metrics(metrics1: dict, metrics2: dict) -> dict:
@@ -1008,6 +1043,23 @@ def _merge_pipeline_metrics(metrics1: dict, metrics2: dict) -> dict:
     # Phase 2C — Index handler
     # ------------------------------------------------------------------

+    @staticmethod
+    def _missing_index_params(
+        *,
+        embedding_instance_id: str,
+        vector_db_instance_id: str,
+        x2text_instance_id: str,
+        file_path: str,
+    ) -> list[str]:
+        """Return required-param keys that are unset for an INDEX op."""
+        checks = (
+            (embedding_instance_id, IKeys.EMBEDDING_INSTANCE_ID),
+            (vector_db_instance_id, IKeys.VECTOR_DB_INSTANCE_ID),
+            (x2text_instance_id, IKeys.X2TEXT_INSTANCE_ID),
+            (file_path, IKeys.FILE_PATH),
+        )
+        return [key for value, key in checks if not value]
+
     def _handle_index(self, context: ExecutionContext) -> ExecutionResult:
         """Handle ``Operation.INDEX`` — vector DB indexing.

@@ -1027,15 +1079,12 @@ def _handle_index(self, context: ExecutionContext) -> ExecutionResult:
         extracted_text: str = params.get(IKeys.EXTRACTED_TEXT, "")
         platform_api_key: str = params.get("platform_api_key", "")

-        missing = []
-        if not embedding_instance_id:
-            missing.append(IKeys.EMBEDDING_INSTANCE_ID)
-        if not vector_db_instance_id:
-            missing.append(IKeys.VECTOR_DB_INSTANCE_ID)
-        if not x2text_instance_id:
-            missing.append(IKeys.X2TEXT_INSTANCE_ID)
-        if not file_path:
-            missing.append(IKeys.FILE_PATH)
+        missing = self._missing_index_params(
+            embedding_instance_id=embedding_instance_id,
+            vector_db_instance_id=vector_db_instance_id,
+            x2text_instance_id=x2text_instance_id,
+            file_path=file_path,
+        )
         if missing:
             return ExecutionResult.failure(
                 error=f"Missing required params: {', '.join(missing)}"
@@ -1110,6 +1159,7 @@ def _handle_index(self, context: ExecutionContext) -> ExecutionResult:
         index_cls, embedding_compat, vector_db_cls = self._get_indexing_deps()

         vector_db = None
+        embedding = None
         try:
             index = index_cls(
                 tool=shim,
@@ -1152,7 +1202,11 @@ def _handle_index(self, context: ExecutionContext) -> ExecutionResult:
                     "Skipping re-index: doc_id=%s already in vector DB and reindex=False",
                     doc_id,
                 )
-                return ExecutionResult(success=True, data={IKeys.DOC_ID: doc_id})
+                return ExecutionResult(
+                    success=True,
+                    data={IKeys.DOC_ID: doc_id},
+                    metadata={"usage_records": embedding.flush_pending_usage()},
+                )

             shim.stream_log(
                 "Re-indexing document" if doc_id_found else "Indexing document"
@@ -1169,16 +1223,31 @@ def _handle_index(self, context: ExecutionContext) -> ExecutionResult:
                 Path(file_path).name,
             )
             shim.stream_log("Document indexing completed")
-            return ExecutionResult(success=True, data={IKeys.DOC_ID: doc_id})
+            return ExecutionResult(
+                success=True,
+                data={IKeys.DOC_ID: doc_id},
+                metadata={"usage_records": embedding.flush_pending_usage()},
+            )
         except Exception as e:
             logger.error(
                 "Indexing failed: file=%s error=%s",
                 Path(file_path).name,
                 str(e),
             )
             status_code = getattr(e, "status_code", 500)
+            partial = []
+            if embedding is not None:
+                try:
+                    partial = list(embedding.flush_pending_usage())
+                except Exception:
+                    logger.warning(
+                        "Failed to flush embedding usage during indexing error path",
+                        exc_info=True,
+                    )
             raise LegacyExecutorError(
-                message=f"Error while indexing: {e}", code=status_code
+                message=f"Error while indexing: {e}",
+                code=status_code,
+                partial_usage_records=partial,
             ) from e
         finally:
             if vector_db is not None:
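The error-path behaviour in this file (rows from earlier loop iterations are prepended to the rows carried by the exception before it is re-raised) can be sketched on its own. `ToyExecutorError`, `handle_index`, and `run_pipeline_index` are hypothetical stand-ins for `LegacyExecutorError`, `_handle_index`, and `_run_pipeline_index`; the row shapes are assumptions made for illustration.

```python
class ToyExecutorError(Exception):
    """Hypothetical error type that carries partial usage rows."""

    def __init__(self, message: str, partial_usage_records=None) -> None:
        super().__init__(message)
        self.partial_usage_records = list(partial_usage_records or [])


def handle_index(name: str) -> list[dict]:
    """Hypothetical child call: fails for the output named 'bad'."""
    if name == "bad":
        # Rows flushed from the embedding before the failure.
        raise ToyExecutorError(
            "index failed",
            partial_usage_records=[{"output": name, "tokens": 3}],
        )
    return [{"output": name, "tokens": 10}]


def run_pipeline_index(outputs: list[str]) -> list[dict]:
    index_records: list[dict] = []
    for name in outputs:
        try:
            child_records = handle_index(name)
        except ToyExecutorError as e:
            # Prepend rows from earlier iterations so the outer handler sees them.
            e.partial_usage_records = index_records + e.partial_usage_records
            raise
        index_records.extend(child_records)
    return index_records


collected: list[dict] = []
try:
    run_pipeline_index(["a", "b", "bad", "c"])
except ToyExecutorError as e:
    collected = e.partial_usage_records

print(collected)
```

Here `collected` holds the rows for `a` and `b` followed by the partial row for `bad`; output `c` is never indexed because the loop aborts on the failure.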
