This repository was archived by the owner on May 27, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 422
Expand file tree
/
Copy pathgraph.py
More file actions
105 lines (94 loc) · 3.58 KB
/
graph.py
File metadata and controls
105 lines (94 loc) · 3.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import traceback
from io import BytesIO
import networkx as nx
from fastapi import (
APIRouter,
Depends,
HTTPException,
status,
)
from fastapi.responses import StreamingResponse
from graphrag_app.logger.load_logger import load_pipeline_logger
from graphrag_app.typing.models import GraphDataResponse
from graphrag_app.utils.azure_clients import AzureClientManager
from graphrag_app.utils.common import (
sanitize_name,
subscription_key_check,
validate_index_file_exist,
)
# Router that groups the graph endpoints under the "/graph" URL prefix.
graph_route = APIRouter(prefix="/graph", tags=["Graph Operations"])

# When KUBERNETES_SERVICE_HOST is set (non-empty) in the environment, attach
# the subscription-key check as a dependency of every route on this router.
if os.environ.get("KUBERNETES_SERVICE_HOST"):
    graph_route.dependencies.append(Depends(subscription_key_check))
@graph_route.get(
    "/graphml/{container_name}",
    summary="Retrieve a GraphML file of the knowledge graph",
    response_description="GraphML file successfully downloaded",
    status_code=status.HTTP_200_OK,
)
async def get_graphml_file(
    container_name: str, sanitized_container_name: str = Depends(sanitize_name)
):
    """Stream the index's GraphML file back to the caller as a download.

    Args:
        container_name: Name taken from the URL path. Within this function
            it is used only in the error detail returned to the client.
        sanitized_container_name: Value produced by the ``sanitize_name``
            dependency; used as the blob container name.

    Returns:
        A ``StreamingResponse`` over the blob's chunks, sent as an
        ``application/octet-stream`` attachment named ``graph.graphml``.

    Raises:
        HTTPException: With status 500 when creating the blob download or
            the response fails. Anything raised by
            ``validate_index_file_exist`` propagates unchanged because that
            call sits outside the ``try`` block.
    """
    logger = load_pipeline_logger()

    # validate graphml file existence
    azure_client_manager = AzureClientManager()
    graphml_filename = "graph.graphml"
    # expected file location of the graph based on the workflow
    blob_filepath = f"output/{graphml_filename}"
    validate_index_file_exist(sanitized_container_name, blob_filepath)

    try:
        blob_client = azure_client_manager.get_blob_service_client().get_blob_client(
            container=sanitized_container_name, blob=blob_filepath
        )
        # NOTE(review): the chunks are presumably fetched lazily while the
        # response is being streamed, so a failure after the return below
        # would not reach the except clause - confirm against the SDK.
        blob_stream = blob_client.download_blob().chunks()
        return StreamingResponse(
            blob_stream,
            media_type="application/octet-stream",
            headers={
                "Content-Disposition": f"attachment; filename={graphml_filename}",
                "filename": graphml_filename,
            },
        )
    except Exception as e:
        logger.error(
            message="Could not fetch graphml file",
            cause=e,
            stack=traceback.format_exc(),
        )
        # Chain the original exception so the root cause stays attached to
        # the HTTP error in tracebacks.
        raise HTTPException(
            status_code=500,
            detail=f"Could not fetch graphml file for '{container_name}'.",
        ) from e
@graph_route.get(
    "/stats/{index_name}",
    summary="Retrieve basic graph statistics, number of nodes and edges",
    response_model=GraphDataResponse,
    responses={200: {"model": GraphDataResponse}},
    response_description="Retrieve the number of nodes and edges from the index graph",
)
async def retrieve_graph_stats(index_name: str):
    """Return the node and edge counts of the index's GraphML graph.

    The whole ``output/graph.graphml`` blob is downloaded into memory and
    parsed with ``networkx.read_graphml`` before counting.

    Args:
        index_name: Name taken from the URL path. It is passed through
            ``sanitize_name`` to obtain the blob container name, and is
            echoed in the error detail on failure.

    Returns:
        A ``GraphDataResponse`` whose ``nodes`` and ``edges`` fields hold
        the number of nodes and edges in the parsed graph.

    Raises:
        HTTPException: With status 500 when downloading or parsing the
            graph fails. Anything raised by ``sanitize_name`` or
            ``validate_index_file_exist`` propagates unchanged because those
            calls sit outside the ``try`` block.
    """
    logger = load_pipeline_logger()

    # validate index_name and graphml file existence
    sanitized_index_name = sanitize_name(index_name)
    graphml_filename = "graph.graphml"
    # expected file location of the graph based on the workflow
    graphml_filepath = f"output/{graphml_filename}"
    validate_index_file_exist(sanitized_index_name, graphml_filepath)

    try:
        azure_client_manager = AzureClientManager()
        storage_client = azure_client_manager.get_blob_service_client().get_blob_client(
            container=sanitized_index_name, blob=graphml_filepath
        )
        blob_data = storage_client.download_blob().readall()
        # read_graphml accepts a file-like object, so wrap the raw bytes.
        g = nx.read_graphml(BytesIO(blob_data))
        return GraphDataResponse(nodes=len(g.nodes), edges=len(g.edges))
    except Exception as e:
        # Record the cause and stack trace, mirroring how get_graphml_file
        # reports failures, so the log entry is actionable.
        logger.error(
            message="Could not retrieve graph data file",
            cause=e,
            stack=traceback.format_exc(),
        )
        raise HTTPException(
            status_code=500,
            detail=f"Could not retrieve graph statistics for index '{index_name}'.",
        ) from e