Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions graphgen/bases/base_filter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from abc import ABC, abstractmethod
from typing import Any, Union
from typing import TYPE_CHECKING, Any, Union

import numpy as np
if TYPE_CHECKING:
import numpy as np


class BaseFilter(ABC):
Expand All @@ -15,7 +16,7 @@ def filter(self, data: Any) -> bool:

class BaseValueFilter(BaseFilter, ABC):
@abstractmethod
def filter(self, data: Union[int, float, np.number]) -> bool:
def filter(self, data: Union[int, float, "np.number"]) -> bool:
"""
Filter the numeric value and return True if it passes the filter, False otherwise.
"""
Expand Down
24 changes: 17 additions & 7 deletions graphgen/bases/base_operator.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from __future__ import annotations

import inspect
import os
from abc import ABC, abstractmethod
from typing import Iterable, Tuple, Union
from typing import TYPE_CHECKING, Iterable, Tuple, Union

import numpy as np
import pandas as pd
import ray
if TYPE_CHECKING:
import numpy as np
import pandas as pd


def convert_to_serializable(obj):
import numpy as np

if isinstance(obj, np.ndarray):
return obj.tolist()
if isinstance(obj, np.generic):
Expand Down Expand Up @@ -40,6 +44,8 @@ def __init__(
)

try:
import ray

ctx = ray.get_runtime_context()
worker_id = ctx.get_actor_id() or ctx.get_worker_id()
worker_id_short = worker_id[-6:] if worker_id else "driver"
Expand All @@ -62,9 +68,11 @@ def __init__(
)

def __call__(
self, batch: pd.DataFrame
) -> Union[pd.DataFrame, Iterable[pd.DataFrame]]:
self, batch: "pd.DataFrame"
) -> Union["pd.DataFrame", Iterable["pd.DataFrame"]]:
# lazy import to avoid circular import
import pandas as pd

from graphgen.utils import CURRENT_LOGGER_VAR

logger_token = CURRENT_LOGGER_VAR.set(self.logger)
Expand Down Expand Up @@ -106,14 +114,16 @@ def get_trace_id(self, content: dict) -> str:

return compute_dict_hash(content, prefix=f"{self.op_name}-")

def split(self, batch: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
def split(self, batch: "pd.DataFrame") -> tuple["pd.DataFrame", "pd.DataFrame"]:
"""
Split the input batch into to_process & processed based on _meta data in KV_storage
:param batch
:return:
to_process: DataFrame of documents to be chunked
recovered: Result DataFrame of already chunked documents
"""
import pandas as pd

meta_forward = self.get_meta_forward()
meta_ids = set(meta_forward.keys())
mask = batch["_trace_id"].isin(meta_ids)
Expand Down
11 changes: 8 additions & 3 deletions graphgen/bases/base_reader.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from __future__ import annotations

import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Union
from typing import TYPE_CHECKING, Any, Dict, List, Union

import pandas as pd
import requests
from ray.data import Dataset

if TYPE_CHECKING:
import pandas as pd
from ray.data import Dataset


class BaseReader(ABC):
Expand Down Expand Up @@ -51,6 +55,7 @@ def _validate_batch(self, batch: pd.DataFrame) -> pd.DataFrame:
"""
Validate data format.
"""

if "type" not in batch.columns:
raise ValueError(f"Missing 'type' column. Found: {list(batch.columns)}")

Expand Down
11 changes: 7 additions & 4 deletions graphgen/common/init_llm.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os
from typing import Any, Dict, Optional

import ray
from typing import TYPE_CHECKING, Any, Dict, Optional

from graphgen.bases import BaseLLMWrapper
from graphgen.models import Tokenizer

if TYPE_CHECKING:
import ray


class LLMServiceActor:
"""
Expand Down Expand Up @@ -73,7 +74,7 @@ class LLMServiceProxy(BaseLLMWrapper):
A proxy class to interact with the LLMServiceActor for distributed LLM operations.
"""

def __init__(self, actor_handle: ray.actor.ActorHandle):
def __init__(self, actor_handle: "ray.actor.ActorHandle"):
super().__init__()
self.actor_handle = actor_handle
self._create_local_tokenizer()
Expand Down Expand Up @@ -120,6 +121,8 @@ class LLMFactory:
def create_llm(
model_type: str, backend: str, config: Dict[str, Any]
) -> BaseLLMWrapper:
import ray

if not config:
raise ValueError(
f"No configuration provided for LLM {model_type} with backend {backend}."
Expand Down
81 changes: 76 additions & 5 deletions graphgen/common/init_storage.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Any, Dict, List, Set, Union

import ray
from typing import TYPE_CHECKING, Any, Dict, List, Set, Union

from graphgen.bases.base_storage import BaseGraphStorage, BaseKVStorage

if TYPE_CHECKING:
import ray


class KVStorageActor:
def __init__(self, backend: str, working_dir: str, namespace: str):
Expand Down Expand Up @@ -146,124 +147,192 @@ def ready(self) -> bool:


class RemoteKVStorageProxy(BaseKVStorage):
def __init__(self, actor_handle: ray.actor.ActorHandle):
def __init__(self, actor_handle: "ray.actor.ActorHandle"):
super().__init__()
self.actor = actor_handle

def data(self) -> Dict[str, Any]:
import ray
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Importing ray in every method of RemoteKVStorageProxy and RemoteGraphStorageProxy introduces significant code duplication. Since these classes are tightly coupled with Ray, consider importing ray once at the top of the file. This would simplify the code and improve maintainability. Python's import caching mechanism ensures that subsequent imports are very fast, so there's minimal performance overhead.


return ray.get(self.actor.data.remote())

def all_keys(self) -> list[str]:
import ray

return ray.get(self.actor.all_keys.remote())

def index_done_callback(self):
import ray

return ray.get(self.actor.index_done_callback.remote())

def get_by_id(self, id: str) -> Union[Any, None]:
import ray

return ray.get(self.actor.get_by_id.remote(id))

def get_by_ids(self, ids: list[str], fields=None) -> list[Any]:
import ray

return ray.get(self.actor.get_by_ids.remote(ids, fields))

def get_all(self) -> Dict[str, Any]:
import ray

return ray.get(self.actor.get_all.remote())

def filter_keys(self, data: list[str]) -> set[str]:
import ray

return ray.get(self.actor.filter_keys.remote(data))

def upsert(self, data: Dict[str, Any]):
import ray

return ray.get(self.actor.upsert.remote(data))

def update(self, data: Dict[str, Any]):
import ray

return ray.get(self.actor.update.remote(data))

def delete(self, ids: list[str]):
import ray

return ray.get(self.actor.delete.remote(ids))

def drop(self):
import ray

return ray.get(self.actor.drop.remote())

def reload(self):
import ray

return ray.get(self.actor.reload.remote())


class RemoteGraphStorageProxy(BaseGraphStorage):
def __init__(self, actor_handle: ray.actor.ActorHandle):
def __init__(self, actor_handle: "ray.actor.ActorHandle"):
super().__init__()
self.actor = actor_handle

def index_done_callback(self):
import ray

return ray.get(self.actor.index_done_callback.remote())

def is_directed(self) -> bool:
import ray

return ray.get(self.actor.is_directed.remote())

def get_all_node_degrees(self) -> Dict[str, int]:
import ray

return ray.get(self.actor.get_all_node_degrees.remote())

def get_node_count(self) -> int:
import ray

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This blank line is unnecessary and harms readability by making the method less compact. This applies to many other simple proxy methods in this class as well (e.g., get_edge_count, has_node). For simple one-line proxy methods, it's best to keep them compact.

return ray.get(self.actor.get_node_count.remote())

def get_edge_count(self) -> int:
import ray

return ray.get(self.actor.get_edge_count.remote())

def get_connected_components(self, undirected: bool = True) -> List[Set[str]]:
import ray

return ray.get(self.actor.get_connected_components.remote(undirected))

def has_node(self, node_id: str) -> bool:
import ray

return ray.get(self.actor.has_node.remote(node_id))

def has_edge(self, source_node_id: str, target_node_id: str):
import ray

return ray.get(self.actor.has_edge.remote(source_node_id, target_node_id))

def node_degree(self, node_id: str) -> int:
import ray

return ray.get(self.actor.node_degree.remote(node_id))

def edge_degree(self, src_id: str, tgt_id: str) -> int:
import ray

return ray.get(self.actor.edge_degree.remote(src_id, tgt_id))

def get_node(self, node_id: str) -> Any:
import ray

return ray.get(self.actor.get_node.remote(node_id))

def update_node(self, node_id: str, node_data: dict[str, str]):
import ray

return ray.get(self.actor.update_node.remote(node_id, node_data))

def get_all_nodes(self) -> Any:
import ray

return ray.get(self.actor.get_all_nodes.remote())

def get_edge(self, source_node_id: str, target_node_id: str):
import ray

return ray.get(self.actor.get_edge.remote(source_node_id, target_node_id))

def update_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):
import ray

return ray.get(
self.actor.update_edge.remote(source_node_id, target_node_id, edge_data)
)

def get_all_edges(self) -> Any:
import ray

return ray.get(self.actor.get_all_edges.remote())

def get_node_edges(self, source_node_id: str) -> Any:
import ray

return ray.get(self.actor.get_node_edges.remote(source_node_id))

def upsert_node(self, node_id: str, node_data: dict[str, str]):
import ray

return ray.get(self.actor.upsert_node.remote(node_id, node_data))

def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):
import ray

return ray.get(
self.actor.upsert_edge.remote(source_node_id, target_node_id, edge_data)
)

def delete_node(self, node_id: str):
import ray

return ray.get(self.actor.delete_node.remote(node_id))

def get_neighbors(self, node_id: str) -> List[str]:
import ray

return ray.get(self.actor.get_neighbors.remote(node_id))

def reload(self):
import ray

return ray.get(self.actor.reload.remote())


Expand All @@ -274,6 +343,8 @@ class StorageFactory:

@staticmethod
def create_storage(backend: str, working_dir: str, namespace: str):
import ray

if backend in ["json_kv", "rocksdb"]:
actor_name = f"Actor_KV_{namespace}"
actor_class = KVStorageActor
Expand Down
Loading
Loading