This repository was archived by the owner on May 27, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 419
Expand file tree
/
Copy pathcommon.py
More file actions
208 lines (177 loc) · 7.26 KB
/
common.py
File metadata and controls
208 lines (177 loc) · 7.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import hashlib
import os
import traceback
from typing import Annotated
import pandas as pd
from azure.core.exceptions import ResourceNotFoundError
from azure.cosmos import ContainerProxy, exceptions
from azure.identity import DefaultAzureCredential
from azure.storage.blob.aio import ContainerClient
from fastapi import Header, HTTPException
from graphrag_app.logger.load_logger import load_pipeline_logger
from graphrag_app.utils.azure_clients import AzureClientManager
def get_df(
    table_path: str,
) -> pd.DataFrame:
    """
    Load a parquet table into a pandas DataFrame.
    Args:
    -----
    table_path (str)
        Location of the parquet file; handed to ``pd.read_parquet`` unchanged.
    Returns: pd.DataFrame
        The contents of the parquet file.
    """
    # Storage options are rebuilt on every call by pandas_storage_options().
    storage_options = pandas_storage_options()
    return pd.read_parquet(table_path, storage_options=storage_options)
def pandas_storage_options() -> dict:
    """
    Build the storage options pandas needs to read parquet files from Storage.
    The account name and host come from ``AzureClientManager``. If the
    ``STORAGE_CONNECTION_STRING`` environment variable is set to a non-empty
    value it is passed along as ``connection_string``; otherwise a
    ``DefaultAzureCredential`` is supplied as ``credential``.
    Returns: dict
        Keyword options for the ``storage_options`` argument of pandas readers.
    """
    # Available options are described at:
    # https://github.com/fsspec/adlfs?tab=readme-ov-file#setting-credentials
    client_manager = AzureClientManager()
    storage_options = {
        "account_name": client_manager.storage_account_name,
        "account_host": client_manager.storage_account_hostname,
    }
    connection_string = os.getenv("STORAGE_CONNECTION_STRING")
    if connection_string:
        storage_options["connection_string"] = connection_string
        return storage_options
    # No connection string configured: fall back to credential-based auth.
    storage_options["credential"] = DefaultAzureCredential()
    return storage_options
def delete_storage_container_if_exist(container_name: str):
    """
    Delete the blob container named ``container_name``.
    A missing container is not treated as an error. Any other exception
    propagates and must be handled by the caller.
    """
    service_client = AzureClientManager().get_blob_service_client()
    try:
        service_client.delete_container(container_name)
    except ResourceNotFoundError:
        # The container is already absent, so there is nothing to delete.
        return
def delete_cosmos_container_item_if_exist(container: str, item_id: str):
    """
    Delete the item ``item_id`` from a container of the "graphrag" cosmosdb database.
    A missing item is not treated as an error. Any other exception propagates
    and must be handled by the caller.
    """
    client_manager = AzureClientManager()
    try:
        # The client lookup stays inside the try so that a
        # ResourceNotFoundError raised there is ignored as well.
        container_client = client_manager.get_cosmos_container_client(
            database="graphrag", container=container
        )
        # item_id is passed for both positional arguments of delete_item.
        container_client.delete_item(item_id, item_id)
    except ResourceNotFoundError:
        # The item is already absent, so there is nothing to delete.
        return
def validate_index_file_exist(sanitized_container_name: str, file_name: str):
    """
    Check if index exists and that the specified blob file exists.
    A "valid" index is defined by having an entry in the container-store table in cosmos db.
    Further checks are done to ensure the blob container and file exist.
    Args:
    -----
    sanitized_container_name (str)
        Sanitized name of a blob container.
    file_name (str)
        The blob file to be validated.
    Raises: ValueError
        If the index has no container-store entry (the underlying error is
        attached as ``__cause__``), if the blob container does not exist, or
        if the blob file does not exist in the container.
    """
    azure_client_manager = AzureClientManager()
    original_container_name = desanitize_name(sanitized_container_name)
    # desanitize_name returns None when the index has no container-store entry,
    # which is the very case reported below. Fall back to the sanitized name so
    # the error message identifies the index instead of reading "None ...".
    display_name = (
        original_container_name
        if original_container_name is not None
        else sanitized_container_name
    )
    try:
        cosmos_container_client = get_cosmos_container_store_client()
        cosmos_container_client.read_item(
            sanitized_container_name, sanitized_container_name
        )
    except Exception as e:
        # Chain explicitly so the root cause is kept for debugging.
        raise ValueError(f"{display_name} is not a valid index.") from e
    # check for file existence
    index_container_client = (
        azure_client_manager.get_blob_service_client().get_container_client(
            sanitized_container_name
        )
    )
    if not index_container_client.exists():
        raise ValueError(f"{display_name} not found.")
    if not index_container_client.get_blob_client(file_name).exists():
        raise ValueError(
            f"File {file_name} unavailable for container {display_name}."
        )
def get_cosmos_container_store_client() -> ContainerProxy:
    """
    Return a client for the "container-store" container of the "graphrag" database.
    Raises: HTTPException
        With status code 500 if the client cannot be obtained. The failure is
        logged through the pipeline logger first.
    """
    try:
        return AzureClientManager().get_cosmos_container_client(
            database="graphrag", container="container-store"
        )
    except Exception as e:
        # Log the cause and stack trace before returning a generic 500.
        load_pipeline_logger().error(
            message="Error fetching cosmosdb client.",
            cause=e,
            stack=traceback.format_exc(),
        )
        raise HTTPException(status_code=500, detail="Error fetching cosmosdb client.")
async def get_blob_container_client(name: str) -> ContainerClient:
    """
    Return an async client for the blob container ``name``, creating the container if needed.
    Raises: HTTPException
        With status code 500 if any step fails. The failure is logged through
        the pipeline logger first.
    """
    try:
        service_client = AzureClientManager().get_blob_service_client_async()
        client = service_client.get_container_client(name)
        already_exists = await client.exists()
        if not already_exists:
            await client.create_container()
        return client
    except Exception as e:
        # Log the cause and stack trace before returning a generic 500.
        load_pipeline_logger().error(
            message="Error fetching storage client.",
            cause=e,
            stack=traceback.format_exc(),
        )
        raise HTTPException(status_code=500, detail="Error fetching storage client.")
def sanitize_name(container_name: str) -> str:
    """
    Sanitize a user-provided string to be used as an Azure Storage container name.
    Convert the string to a SHA256 hash, then truncate to 128 bit length to ensure
    it is within the 63 character limit imposed by Azure Storage.
    The sanitized name will be used to identify container names in both Azure Storage and CosmosDB.
    Args:
    -----
    container_name (str)
        The name to be sanitized.
    Returns: str
        The sanitized name: the first 16 bytes of the digest rendered as
        32 lowercase hexadecimal characters.
    """
    # Keep the str parameter untouched; the encoded bytes get their own name
    # so the annotation stays truthful for type checkers.
    encoded_name = container_name.encode()
    digest = hashlib.sha256(encoded_name).digest()
    return digest[:16].hex()  # first 16 bytes (128 bits)
def desanitize_name(sanitized_container_name: str) -> str | None:
    """
    Look up the original user-provided name for a sanitized container name.
    The name is read from the "human_readable_name" field of the matching
    container-store item.
    Args:
    -----
    sanitized_container_name (str)
        The sanitized name to be converted back to the original name.
    Returns: str | None
        The original human-readable name or None if it does not exist.
    Raises: HTTPException
        With status code 500 on any failure other than the item being absent.
    """
    try:
        store_client = get_cosmos_container_store_client()
        try:
            record = store_client.read_item(
                sanitized_container_name, sanitized_container_name
            )
        except exceptions.CosmosResourceNotFoundError:
            # No entry for this sanitized name.
            return None
        return record["human_readable_name"]
    except Exception:
        raise HTTPException(
            status_code=500, detail="Error retrieving original container name."
        )
async def subscription_key_check(
    Ocp_Apim_Subscription_Key: Annotated[str, Header()],
):
    """
    Verify that the request carries the Ocp_Apim_Subscription_Key (APIM subscription key) header.
    Returns the key when it is non-empty; otherwise raises an HTTPException with a 400 status code.
    Note: this check is unnecessary (APIM validates subscription keys automatically), but declaring
    the header here adds the key as a required parameter in the swagger docs page, enabling users
    to send requests using the swagger docs "Try it out" feature.
    """
    if Ocp_Apim_Subscription_Key:
        return Ocp_Apim_Subscription_Key
    raise HTTPException(
        status_code=400, detail="Ocp-Apim-Subscription-Key required"
    )