Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion packages/graphrag/graphrag/api/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,9 @@ async def build_index(
input_documents=input_documents,
):
outputs.append(output)
if output.errors and len(output.errors) > 0:
if output.error is not None:
logger.error("Workflow %s completed with errors", output.workflow)
workflow_callbacks.pipeline_error(output.error)
else:
logger.info("Workflow %s completed successfully", output.workflow)
logger.debug(str(output.result))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ def workflow_end(self, name: str, instance: object) -> None:
if self._verbose:
print(instance)

def pipeline_error(self, error: BaseException) -> None:
"""Execute this callback when an error occurs in the pipeline."""
print(f"Pipeline error: {error}")

def progress(self, progress: Progress) -> None:
"""Handle when progress occurs."""
complete = progress.completed_items or 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@ def workflow_end(self, name: str, instance: object) -> None:

def progress(self, progress: Progress) -> None:
"""Handle when progress occurs."""

def pipeline_error(self, error: BaseException) -> None:
"""Execute this callback when an error occurs in the pipeline."""
4 changes: 4 additions & 0 deletions packages/graphrag/graphrag/callbacks/workflow_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,7 @@ def workflow_end(self, name: str, instance: object) -> None:
def progress(self, progress: Progress) -> None:
"""Handle when progress occurs."""
...

def pipeline_error(self, error: BaseException) -> None:
"""Execute this callback when an error occurs in the pipeline."""
...
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,9 @@ def progress(self, progress: Progress) -> None:
for callback in self._callbacks:
if hasattr(callback, "progress"):
callback.progress(progress)

def pipeline_error(self, error: BaseException) -> None:
"""Execute this callback when an error occurs in the pipeline."""
for callback in self._callbacks:
if hasattr(callback, "pipeline_error"):
callback.pipeline_error(error)
11 changes: 1 addition & 10 deletions packages/graphrag/graphrag/cli/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,15 +134,6 @@ def _run_index(
verbose=verbose,
)
)
encountered_errors = any(
output.errors and len(output.errors) > 0 for output in outputs
)

if encountered_errors:
logger.error(
"Errors occurred during the pipeline run, see logs for more details."
)
else:
logger.info("All workflows completed successfully.")
encountered_errors = any(output.error is not None for output in outputs)

Comment thread
natoverse marked this conversation as resolved.
sys.exit(1 if encountered_errors else 0)
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,11 @@

import logging

import networkx as nx
import pandas as pd

from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.enums import AsyncType
from graphrag.index.operations.extract_graph.graph_extractor import GraphExtractor
from graphrag.index.operations.extract_graph.typing import (
Document,
EntityExtractionResult,
EntityTypes,
)
from graphrag.index.utils.derive_from_rows import derive_from_rows
from graphrag.language_model.protocol.base import ChatModel

Expand All @@ -42,14 +36,15 @@ async def run_strategy(row):
text = row[text_column]
id = row[id_column]
result = await run_extract_graph(
[Document(text=text, id=id)],
entity_types,
model,
prompt,
max_gleanings,
text=text,
source_id=id,
entity_types=entity_types,
model=model,
prompt=prompt,
max_gleanings=max_gleanings,
)
num_started += 1
return [result.entities, result.relationships, result.graph]
return result

results = await derive_from_rows(
text_units,
Expand All @@ -64,8 +59,8 @@ async def run_strategy(row):
relationship_dfs = []
for result in results:
if result:
entity_dfs.append(pd.DataFrame(result[0]))
relationship_dfs.append(pd.DataFrame(result[1]))
entity_dfs.append(result[0])
relationship_dfs.append(result[1])

entities = _merge_entities(entity_dfs)
relationships = _merge_relationships(relationship_dfs)
Expand All @@ -74,12 +69,13 @@ async def run_strategy(row):


async def run_extract_graph(
docs: list[Document],
entity_types: EntityTypes,
text: str,
source_id: str,
entity_types: list[str],
model: ChatModel,
prompt: str,
max_gleanings: int,
) -> EntityExtractionResult:
) -> tuple[pd.DataFrame, pd.DataFrame]:
"""Run the graph intelligence entity extraction strategy."""
extractor = GraphExtractor(
model=model,
Expand All @@ -89,36 +85,15 @@ async def run_extract_graph(
"Entity Extraction Error", exc_info=e, extra={"stack": s, "details": d}
),
)
text_list = [doc.text.strip() for doc in docs]
text = text.strip()

results = await extractor(
list(text_list),
entities_df, relationships_df = await extractor(
text,
entity_types=entity_types,
source_id=source_id,
)

graph = results.output
# Map the "source_id" back to the "id" field
for _, node in graph.nodes(data=True): # type: ignore
if node is not None:
node["source_id"] = ",".join(
docs[int(id)].id for id in node["source_id"].split(",")
)

for _, _, edge in graph.edges(data=True): # type: ignore
if edge is not None:
edge["source_id"] = ",".join(
docs[int(id)].id for id in edge["source_id"].split(",")
)

entities = [
({"title": item[0], **(item[1] or {})})
for item in graph.nodes(data=True)
if item is not None
]

relationships = nx.to_pandas_edgelist(graph)

return EntityExtractionResult(entities, relationships, graph)
return (entities_df, relationships_df)


def _merge_entities(entity_dfs) -> pd.DataFrame:
Expand Down
Loading
Loading