base_operator.py

import inspect
import os
from abc import ABC, abstractmethod
from typing import Iterable, Iterator, Optional, Tuple, Union

import numpy as np
import pandas as pd
import ray


def convert_to_serializable(obj):
    """Recursively convert numpy arrays/scalars into plain Python types so
    results can be stored in the KV backend and serialized safely."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, dict):
        return {k: convert_to_serializable(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [convert_to_serializable(v) for v in obj]
    return obj
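
# For example (illustrative):
#   convert_to_serializable({"a": np.int64(3), "b": np.array([1, 2])})
#   -> {"a": 3, "b": [1, 2]}  # plain int / list, safe for json or the KV store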


class BaseOperator(ABC):
    def __init__(
        self,
        working_dir: str = "cache",
        kv_backend: str = "rocksdb",
        op_name: Optional[str] = None,
    ):
        # lazy import to avoid circular import
        from graphgen.common.init_storage import init_storage
        from graphgen.utils import set_logger

        log_dir = os.path.join(working_dir, "logs")
        self.op_name = op_name or self.__class__.__name__
        self.working_dir = working_dir
        self.kv_backend = kv_backend
        self.kv_storage = init_storage(
            backend=kv_backend, working_dir=working_dir, namespace=self.op_name
        )

        try:
            ctx = ray.get_runtime_context()
            worker_id = ctx.get_actor_id() or ctx.get_worker_id()
            worker_id_short = worker_id[-6:] if worker_id else "driver"
        except Exception as e:  # fall back when not running inside Ray
            print(
                "Warning: Could not get Ray worker ID, defaulting to 'local'. Exception:",
                e,
            )
            worker_id_short = "local"

        # e.g. cache/logs/ChunkService_a1b2c3.log
        log_file = os.path.join(log_dir, f"{self.op_name}_{worker_id_short}.log")
        self.logger = set_logger(
            log_file=log_file, name=f"{self.op_name}.{worker_id_short}", force=True
        )
        self.logger.info(
            "[%s] Operator initialized on Worker %s", self.op_name, worker_id_short
        )

    def __call__(self, batch: pd.DataFrame) -> Iterator[pd.DataFrame]:
        # lazy import to avoid circular import
        from graphgen.utils import CURRENT_LOGGER_VAR

        logger_token = CURRENT_LOGGER_VAR.set(self.logger)
        try:
            self.kv_storage.reload()
            to_process, recovered = self.split(batch)
            # yield recovered chunks first
            if not recovered.empty:
                yield recovered
            if to_process.empty:
                return
            data = to_process.to_dict(orient="records")
            result, meta_update = self.process(data)
            if inspect.isgenerator(result):
                # stream results one record at a time; persist the meta
                # mapping only once, alongside the first record
                is_first = True
                for res in result:
                    yield pd.DataFrame([res])
                    self.store([res], meta_update if is_first else {})
                    is_first = False
            else:
                yield pd.DataFrame(result)
                self.store(result, meta_update)
        finally:
            CURRENT_LOGGER_VAR.reset(logger_token)
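
    # Illustrative usage (MyOperator is a hypothetical subclass, not part of
    # this module):
    #   op = MyOperator(working_dir="cache")
    #   for frame in op(batch_df):  # batch_df must carry a _trace_id column
    #       ...                     # recovered rows arrive first, then new results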

    def get_logger(self):
        return self.logger

    def get_meta_forward(self):
        return self.kv_storage.get_by_id("_meta_forward") or {}

    def get_meta_inverse(self):
        return self.kv_storage.get_by_id("_meta_inverse") or {}

    def get_trace_id(self, content: dict) -> str:
        from graphgen.utils import compute_dict_hash

        return compute_dict_hash(content, prefix=f"{self.op_name}-")
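
    # e.g. get_trace_id({"text": "hello"}) returns something like
    # "ChunkService-<digest>" for an operator named ChunkService; the digest
    # itself comes from graphgen.utils.compute_dict_hash.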

    def split(self, batch: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
        """
        Split the input batch into to-process and already-processed rows based
        on the _meta_forward mapping in KV storage.
        :param batch: input DataFrame; must contain a `_trace_id` column
        :return:
            to_process: DataFrame of documents that still need processing
            recovered: DataFrame of results recovered for already-processed documents
        """
        meta_forward = self.get_meta_forward()
        meta_ids = set(meta_forward.keys())
        mask = batch["_trace_id"].isin(meta_ids)
        to_process = batch[~mask]
        processed = batch[mask]
        if processed.empty:
            return to_process, pd.DataFrame()
        all_ids = [
            pid for tid in processed["_trace_id"] for pid in meta_forward.get(tid, [])
        ]
        recovered_chunks = self.kv_storage.get_by_ids(all_ids)
        recovered_chunks = [c for c in recovered_chunks if c is not None]
        return to_process, pd.DataFrame(recovered_chunks)

    def store(self, results: list, meta_update: dict):
        results = convert_to_serializable(results)
        meta_update = convert_to_serializable(meta_update)
        batch = {res["_trace_id"]: res for res in results}
        self.kv_storage.upsert(batch)
        # update forward meta
        forward_meta = self.get_meta_forward()
        forward_meta.update(meta_update)
        self.kv_storage.update({"_meta_forward": forward_meta})
        # update inverse meta
        inverse_meta = self.get_meta_inverse()
        for k, v_list in meta_update.items():
            for v in v_list:
                inverse_meta[v] = k
        self.kv_storage.update({"_meta_inverse": inverse_meta})
        self.kv_storage.index_done_callback()
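
    # Meta layout maintained above and consumed by split():
    #   _meta_forward: {input _trace_id: [output _trace_id, ...]}
    #   _meta_inverse: {output _trace_id: input _trace_id}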

    @abstractmethod
    def process(self, batch: list) -> Tuple[Union[list, Iterable[dict]], dict]:
        """
        Process the input batch and return the result.
        :param batch: list of record dicts (one per input row)
        :return:
            result: list of processed record dicts, or a generator yielding them
            meta_update: dict mapping each input _trace_id to its output _trace_ids
        """