Skip to content

Commit b75d042

Browse files
fix: adjust ray import
1 parent 8fa5bc8 commit b75d042

2 files changed

Lines changed: 8 additions & 65 deletions

File tree

graphgen/common/init_storage.py

Lines changed: 4 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
from typing import TYPE_CHECKING, Any, Dict, List, Set, Union
1+
from typing import Any, Dict, List, Set, Union
22

3-
from graphgen.bases.base_storage import BaseGraphStorage, BaseKVStorage
3+
import ray
44

5-
if TYPE_CHECKING:
6-
import ray
5+
from graphgen.bases.base_storage import BaseGraphStorage, BaseKVStorage
76

87

98
class KVStorageActor:
@@ -152,186 +151,137 @@ def __init__(self, actor_handle: "ray.actor.ActorHandle"):
152151
self.actor = actor_handle
153152

154153
def data(self) -> Dict[str, Any]:
155-
import ray
156-
157154
return ray.get(self.actor.data.remote())
158155

159156
def all_keys(self) -> list[str]:
160-
import ray
161-
162157
return ray.get(self.actor.all_keys.remote())
163158

164159
def index_done_callback(self):
165-
import ray
166-
167160
return ray.get(self.actor.index_done_callback.remote())
168161

169162
def get_by_id(self, id: str) -> Union[Any, None]:
170-
import ray
171-
172163
return ray.get(self.actor.get_by_id.remote(id))
173164

174165
def get_by_ids(self, ids: list[str], fields=None) -> list[Any]:
175-
import ray
176-
177166
return ray.get(self.actor.get_by_ids.remote(ids, fields))
178167

179168
def get_all(self) -> Dict[str, Any]:
180-
import ray
181-
182169
return ray.get(self.actor.get_all.remote())
183170

184171
def filter_keys(self, data: list[str]) -> set[str]:
185-
import ray
186-
187172
return ray.get(self.actor.filter_keys.remote(data))
188173

189174
def upsert(self, data: Dict[str, Any]):
190-
import ray
191-
192175
return ray.get(self.actor.upsert.remote(data))
193176

194177
def update(self, data: Dict[str, Any]):
195-
import ray
196-
197178
return ray.get(self.actor.update.remote(data))
198179

199180
def delete(self, ids: list[str]):
200-
import ray
201-
202181
return ray.get(self.actor.delete.remote(ids))
203182

204183
def drop(self):
205-
import ray
206-
207184
return ray.get(self.actor.drop.remote())
208185

209186
def reload(self):
210-
import ray
211-
212187
return ray.get(self.actor.reload.remote())
213188

214189

215190
class RemoteGraphStorageProxy(BaseGraphStorage):
216-
def __init__(self, actor_handle: "ray.actor.ActorHandle"):
191+
def __init__(self, actor_handle: ray.actor.ActorHandle):
217192
super().__init__()
218193
self.actor = actor_handle
219194

220195
def index_done_callback(self):
221-
import ray
222-
223196
return ray.get(self.actor.index_done_callback.remote())
224197

225198
def is_directed(self) -> bool:
226-
import ray
227-
228199
return ray.get(self.actor.is_directed.remote())
229200

230201
def get_all_node_degrees(self) -> Dict[str, int]:
231-
import ray
232-
233202
return ray.get(self.actor.get_all_node_degrees.remote())
234203

235204
def get_node_count(self) -> int:
236-
import ray
237205

238206
return ray.get(self.actor.get_node_count.remote())
239207

240208
def get_edge_count(self) -> int:
241-
import ray
242209

243210
return ray.get(self.actor.get_edge_count.remote())
244211

245212
def get_connected_components(self, undirected: bool = True) -> List[Set[str]]:
246-
import ray
247213

248214
return ray.get(self.actor.get_connected_components.remote(undirected))
249215

250216
def has_node(self, node_id: str) -> bool:
251-
import ray
252217

253218
return ray.get(self.actor.has_node.remote(node_id))
254219

255220
def has_edge(self, source_node_id: str, target_node_id: str):
256-
import ray
257221

258222
return ray.get(self.actor.has_edge.remote(source_node_id, target_node_id))
259223

260224
def node_degree(self, node_id: str) -> int:
261-
import ray
262225

263226
return ray.get(self.actor.node_degree.remote(node_id))
264227

265228
def edge_degree(self, src_id: str, tgt_id: str) -> int:
266-
import ray
267229

268230
return ray.get(self.actor.edge_degree.remote(src_id, tgt_id))
269231

270232
def get_node(self, node_id: str) -> Any:
271-
import ray
272233

273234
return ray.get(self.actor.get_node.remote(node_id))
274235

275236
def update_node(self, node_id: str, node_data: dict[str, str]):
276-
import ray
277237

278238
return ray.get(self.actor.update_node.remote(node_id, node_data))
279239

280240
def get_all_nodes(self) -> Any:
281-
import ray
282241

283242
return ray.get(self.actor.get_all_nodes.remote())
284243

285244
def get_edge(self, source_node_id: str, target_node_id: str):
286-
import ray
287245

288246
return ray.get(self.actor.get_edge.remote(source_node_id, target_node_id))
289247

290248
def update_edge(
291249
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
292250
):
293-
import ray
294251

295252
return ray.get(
296253
self.actor.update_edge.remote(source_node_id, target_node_id, edge_data)
297254
)
298255

299256
def get_all_edges(self) -> Any:
300-
import ray
301257

302258
return ray.get(self.actor.get_all_edges.remote())
303259

304260
def get_node_edges(self, source_node_id: str) -> Any:
305-
import ray
306261

307262
return ray.get(self.actor.get_node_edges.remote(source_node_id))
308263

309264
def upsert_node(self, node_id: str, node_data: dict[str, str]):
310-
import ray
311265

312266
return ray.get(self.actor.upsert_node.remote(node_id, node_data))
313267

314268
def upsert_edge(
315269
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
316270
):
317-
import ray
318271

319272
return ray.get(
320273
self.actor.upsert_edge.remote(source_node_id, target_node_id, edge_data)
321274
)
322275

323276
def delete_node(self, node_id: str):
324-
import ray
325277

326278
return ray.get(self.actor.delete_node.remote(node_id))
327279

328280
def get_neighbors(self, node_id: str) -> List[str]:
329-
import ray
330281

331282
return ray.get(self.actor.get_neighbors.remote(node_id))
332283

333284
def reload(self):
334-
import ray
335285

336286
return ray.get(self.actor.reload.remote())
337287

@@ -343,7 +293,6 @@ class StorageFactory:
343293

344294
@staticmethod
345295
def create_storage(backend: str, working_dir: str, namespace: str):
346-
import ray
347296

348297
if backend in ["json_kv", "rocksdb"]:
349298
actor_name = f"Actor_KV_{namespace}"

graphgen/storage/graph/networkx_storage.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import html
22
import os
33
from dataclasses import dataclass
4-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union, cast
4+
from typing import Any, Dict, List, Optional, Set, Union, cast
55

6-
from graphgen.bases.base_storage import BaseGraphStorage
6+
import networkx as nx
77

8-
if TYPE_CHECKING:
9-
import networkx as nx
8+
from graphgen.bases.base_storage import BaseGraphStorage
109

1110

1211
@dataclass
@@ -27,7 +26,6 @@ def get_edge_count(self) -> int:
2726
return self._graph.number_of_edges()
2827

2928
def get_connected_components(self, undirected: bool = True) -> List[Set[str]]:
30-
import networkx as nx
3129

3230
graph = self._graph
3331

@@ -40,15 +38,13 @@ def get_connected_components(self, undirected: bool = True) -> List[Set[str]]:
4038

4139
@staticmethod
4240
def load_nx_graph(file_name) -> Optional["nx.Graph"]:
43-
import networkx as nx
4441

4542
if os.path.exists(file_name):
4643
return nx.read_graphml(file_name)
4744
return None
4845

4946
@staticmethod
5047
def write_nx_graph(graph: "nx.Graph", file_name):
51-
import networkx as nx
5248

5349
nx.write_graphml(graph, file_name)
5450

@@ -57,7 +53,7 @@ def stable_largest_connected_component(graph: "nx.Graph") -> "nx.Graph":
5753
"""Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
5854
Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
5955
"""
60-
import networkx as nx
56+
6157
from graspologic.utils import largest_connected_component
6258

6359
graph = graph.copy()
@@ -74,7 +70,6 @@ def _stabilize_graph(graph: "nx.Graph") -> "nx.Graph":
7470
Ensure an undirected graph with the same relationships will always be read the same way.
7571
通过对节点和边进行排序来实现
7672
"""
77-
import networkx as nx
7873

7974
fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()
8075

@@ -107,7 +102,6 @@ def __post_init__(self):
107102
Initialize the NetworkX graph storage by loading an existing graph from a GraphML file,
108103
if it exists, or creating a new empty graph otherwise.
109104
"""
110-
import networkx as nx
111105

112106
self._graphml_xml_file = os.path.join(
113107
self.working_dir, f"{self.namespace}.graphml"

0 commit comments

Comments
 (0)