Skip to content

Commit 9e994a4

Browse files
Fix duplicate flush bug (#185)
1 parent 8da1261 commit 9e994a4

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

graphgen/bases/base_operator.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -92,8 +92,11 @@ def __call__(
9292
is_first = True
9393
for res in result:
9494
yield pd.DataFrame([res])
95-
self.store([res], meta_update if is_first else {})
95+
self.store(
96+
[res], meta_update if is_first else {}, flush=False
97+
)
9698
is_first = False
99+
self.kv_storage.index_done_callback()
97100
else:
98101
yield pd.DataFrame(result)
99102
self.store(result, meta_update)
@@ -141,7 +144,7 @@ def split(self, batch: "pd.DataFrame") -> tuple["pd.DataFrame", "pd.DataFrame"]:
141144
recovered_chunks = [c for c in recovered_chunks if c is not None]
142145
return to_process, pd.DataFrame(recovered_chunks)
143146

144-
def store(self, results: list, meta_update: dict):
147+
def store(self, results: list, meta_update: dict, flush: bool = True):
145148
results = convert_to_serializable(results)
146149
meta_update = convert_to_serializable(meta_update)
147150

@@ -159,7 +162,8 @@ def store(self, results: list, meta_update: dict):
159162
for v in v_list:
160163
inverse_meta[v] = k
161164
self.kv_storage.update({"_meta_inverse": inverse_meta})
162-
self.kv_storage.index_done_callback()
165+
if flush:
166+
self.kv_storage.index_done_callback()
163167

164168
@abstractmethod
165169
def process(self, batch: list) -> Tuple[Union[list, Iterable[dict]], dict]:

graphgen/storage/kv/rocksdb_storage.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import logging
12
import os
23
from dataclasses import dataclass
34
from typing import Any, Dict, List, Set
@@ -8,6 +9,8 @@
89

910
from graphgen.bases.base_storage import BaseKVStorage
1011

12+
logger = logging.getLogger(__name__)
13+
1114

1215
@dataclass
1316
class RocksDBKVStorage(BaseKVStorage):
@@ -17,8 +20,10 @@ class RocksDBKVStorage(BaseKVStorage):
1720
def __post_init__(self):
1821
self._db_path = os.path.join(self.working_dir, f"{self.namespace}.db")
1922
self._db = Rdict(self._db_path)
20-
print(
21-
f"RocksDBKVStorage initialized for namespace '{self.namespace}' at '{self._db_path}'"
23+
logger.debug(
24+
"RocksDBKVStorage initialized for namespace '%s' at '%s'",
25+
self.namespace,
26+
self._db_path,
2227
)
2328

2429
@property
@@ -30,7 +35,7 @@ def all_keys(self) -> List[str]:
3035

3136
def index_done_callback(self):
3237
self._db.flush()
33-
print(f"RocksDB flushed for {self.namespace}")
38+
logger.debug("RocksDB flushed for %s", self.namespace)
3439

3540
def get_by_id(self, id: str) -> Any:
3641
return self._db.get(id, None)

0 commit comments

Comments
 (0)