Skip to content

Commit 01a12ef

Browse files
fix: fix ray redundant execution
1 parent 1f13f44 commit 01a12ef

4 files changed

Lines changed: 56 additions & 51 deletions

File tree

graphgen/engine.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,38 @@
88
import ray
99
import ray.data
1010
from ray.data import DataContext
11+
from ray.data.block import Block
12+
from ray.data.datasource.filename_provider import FilenameProvider
1113

1214
from graphgen.bases import Config, Node
1315
from graphgen.common import init_llm, init_storage
1416
from graphgen.utils import logger
1517

1618

19+
class NodeFilenameProvider(FilenameProvider):
    """Deterministic, node-scoped filenames for Ray Dataset JSON writes.

    Each output file is prefixed with the owning node's id so that a node's
    files can be identified (and later re-read) independently of other nodes
    writing into sibling directories.
    """

    def __init__(self, node_id: str):
        # Pipeline node id used as the filename prefix.
        self.node_id = node_id

    def get_filename_for_block(
        self, block: Block, write_uuid: str, task_index: int, block_index: int
    ) -> str:
        # format: {node_id}_{write_uuid}_{task_index:06}_{block_index:06}.jsonl
        # (zero-padded indices keep lexicographic order == write order)
        return f"{self.node_id}_{write_uuid}_{task_index:06d}_{block_index:06d}.jsonl"

    def get_filename_for_row(
        self,
        row: Dict[str, Any],
        write_uuid: str,
        task_index: int,
        block_index: int,
        row_index: int,
    ) -> str:
        # write_json emits whole blocks, never individual rows, so a
        # per-row filename request indicates a misuse of this provider.
        raise NotImplementedError(
            f"Row-based filenames are not supported by write_json. "
            f"Node: {self.node_id}, write_uuid: {write_uuid}"
        )
42+
1743
class Engine:
1844
def __init__(
1945
self, config: Dict[str, Any], functions: Dict[str, Callable], **ray_init_kwargs
@@ -263,13 +289,28 @@ def func_wrapper(row_or_batch: Dict[str, Any]) -> Dict[str, Any]:
263289
f"Unsupported node type {node.type} for node {node.id}"
264290
)
265291

266-
def execute(self, initial_ds: ray.data.Dataset, output_dir: str):
    """Run every pipeline node in topological order, persisting flagged outputs.

    Nodes with ``save_output`` set have their dataset written to
    ``output_dir/<node_id>`` as JSON-lines files, and the in-memory dataset
    is then replaced by a lazy ``ray.data.read_json`` over those files.
    Replacing the dataset truncates its lineage, so downstream consumers
    re-read the persisted output instead of re-executing upstream operators
    (the redundant-execution fix this commit introduces).

    :param initial_ds: seed dataset fed to each node's execution.
    :param output_dir: directory under which per-node output folders are created.
    """
    sorted_nodes = self._topo_sort(self.config.nodes)

    for node in sorted_nodes:
        logger.info("Executing node %s of type %s", node.id, node.type)
        self._execute_node(node, initial_ds)
        if getattr(node, "save_output", False):
            # node.id is already a string; no f-string wrapper needed.
            node_output_path = os.path.join(output_dir, node.id)
            os.makedirs(node_output_path, exist_ok=True)
            logger.info("Saving output of node %s to %s", node.id, node_output_path)

            ds = self.datasets[node.id]
            ds.write_json(
                node_output_path,
                filename_provider=NodeFilenameProvider(node.id),
                pandas_json_args_fn=lambda: {
                    "orient": "records",
                    "lines": True,
                    "force_ascii": False,
                },
            )
            logger.info("Node %s output saved to %s", node.id, node_output_path)

            # Ray reads lazily; swapping in read_json cuts the lineage so the
            # saved output is not recomputed when later nodes consume it.
            self.datasets[node.id] = ray.data.read_json(node_output_path)

graphgen/operators/generate/generate_service.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import json
2+
13
import pandas as pd
24

35
from graphgen.bases import BaseLLMWrapper, BaseOperator
@@ -85,7 +87,10 @@ def generate(self, items: list[dict]) -> list[dict]:
8587
:return: QA pairs
8688
"""
8789
logger.info("[Generation] mode: %s, batches: %d", self.method, len(items))
88-
items = [(item["nodes"], item["edges"]) for item in items]
90+
# items = [(item["nodes"], item["edges"]) for item in items]
91+
items = [
92+
(json.loads(item["nodes"]), json.loads(item["edges"])) for item in items
93+
]
8994
results = run_concurrent(
9095
self.generator.generate,
9196
items,

graphgen/operators/partition/partition_service.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,10 @@ def partition(self) -> Iterable[pd.DataFrame]:
8989

9090
yield pd.DataFrame(
9191
{
92-
"nodes": [batch[0]],
93-
"edges": [batch[1]],
94-
}
92+
"nodes": json.dumps(batch[0]),
93+
"edges": json.dumps(batch[1]),
94+
},
95+
index=[0],
9596
)
9697
logger.info("Total communities partitioned: %d", count)
9798

graphgen/run.py

Lines changed: 1 addition & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,10 @@
22
import os
33
import time
44
from importlib import resources
5-
from typing import Any, Dict
65

76
import ray
87
import yaml
98
from dotenv import load_dotenv
10-
from ray.data.block import Block
11-
from ray.data.datasource.filename_provider import FilenameProvider
129

1310
from graphgen.engine import Engine
1411
from graphgen.operators import operators
@@ -32,30 +29,6 @@ def save_config(config_path, global_config):
3229
)
3330

3431

35-
class NodeFilenameProvider(FilenameProvider):
36-
def __init__(self, node_id: str):
37-
self.node_id = node_id
38-
39-
def get_filename_for_block(
40-
self, block: Block, write_uuid: str, task_index: int, block_index: int
41-
) -> str:
42-
# format: {node_id}_{write_uuid}_{task_index:06}_{block_index:06}.json
43-
return f"{self.node_id}_{write_uuid}_{task_index:06d}_{block_index:06d}.jsonl"
44-
45-
def get_filename_for_row(
46-
self,
47-
row: Dict[str, Any],
48-
write_uuid: str,
49-
task_index: int,
50-
block_index: int,
51-
row_index: int,
52-
) -> str:
53-
raise NotImplementedError(
54-
f"Row-based filenames are not supported by write_json. "
55-
f"Node: {self.node_id}, write_uuid: {write_uuid}"
56-
)
57-
58-
5932
def main():
6033
parser = argparse.ArgumentParser()
6134
parser.add_argument(
@@ -91,22 +64,7 @@ def main():
9164

9265
engine = Engine(config, operators)
9366
ds = ray.data.from_items([])
94-
results = engine.execute(ds)
95-
96-
for node_id, dataset in results.items():
97-
logger.info("Saving results for node %s", node_id)
98-
node_output_path = os.path.join(output_path, f"{node_id}")
99-
os.makedirs(node_output_path, exist_ok=True)
100-
dataset.write_json(
101-
node_output_path,
102-
filename_provider=NodeFilenameProvider(node_id),
103-
pandas_json_args_fn=lambda: {
104-
"force_ascii": False,
105-
"orient": "records",
106-
"lines": True,
107-
},
108-
)
109-
logger.info("Node %s results saved to %s", node_id, node_output_path)
67+
engine.execute(ds, output_dir=output_path)
11068

11169
save_config(os.path.join(output_path, "config.yaml"), config)
11270
logger.info("GraphGen completed successfully. Data saved to %s", output_path)

0 commit comments

Comments
 (0)