Skip to content

Commit 1c85757

Browse files
authored
refactor:refine embedding and cli (#33)
* Refine embedding and cli * fix lint
1 parent d6fd022 commit 1c85757

16 files changed

Lines changed: 1591 additions & 1077 deletions

File tree

python/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ dependencies = [
77
"fastapi>=0.104.0",
88
"uvicorn>=0.24.0",
99
"pydantic>=2.0.0",
10+
"openai>=1.52.0",
11+
"lance>=0.17.0",
1012
]
1113
description = "Python bindings for the lance-graph Cypher engine"
1214
authors = [{ name = "Lance Devs", email = "dev@lancedb.com" }]
@@ -43,8 +45,6 @@ build-backend = "maturin"
4345
[project.optional-dependencies]
4446
tests = ["pytest", "pyarrow>=14", "pandas", "ruff"]
4547
dev = ["ruff", "pyright"]
46-
llm = ["openai>=1.52.0"]
47-
lance-storage = ["lance>=0.17.0"]
4848

4949
[project.scripts]
5050
knowledge_graph = "knowledge_graph.main:main"
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from .ingest import extract_and_add, preview_extraction
2+
from .interactive import list_datasets, run_interactive
3+
4+
__all__ = [
5+
"run_interactive",
6+
"list_datasets",
7+
"preview_extraction",
8+
"extract_and_add",
9+
]
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
"""Shared helpers for preparing rows and embeddings (internal)."""
2+
3+
from __future__ import annotations
4+
5+
from .ingest import (
6+
_assign_embeddings,
7+
_format_entity_embedding_input,
8+
_format_relationship_embedding_input,
9+
_prepare_entity_rows,
10+
_prepare_relationship_rows,
11+
)
12+
13+
__all__ = [
14+
"_assign_embeddings",
15+
"_format_entity_embedding_input",
16+
"_format_relationship_embedding_input",
17+
"_prepare_entity_rows",
18+
"_prepare_relationship_rows",
19+
]
Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
"""Extraction preview and ingest helpers for the knowledge graph CLI."""
2+
3+
from __future__ import annotations
4+
5+
import hashlib
6+
import json
7+
import logging
8+
from dataclasses import asdict, is_dataclass
9+
from pathlib import Path
10+
from typing import TYPE_CHECKING, Any, Callable, Mapping
11+
12+
from .. import extraction as kg_extraction
13+
14+
if TYPE_CHECKING:
15+
from ..embeddings import EmbeddingGenerator
16+
from ..service import LanceKnowledgeGraph
17+
18+
LOGGER = logging.getLogger(__name__)
19+
20+
21+
def preview_extraction(source: str, extractor: kg_extraction.BaseExtractor) -> None:
    """Print the knowledge extracted from ``source`` as indented JSON.

    ``source`` is resolved through ``_resolve_text_input`` (a file path or
    inline text), passed to ``kg_extraction.preview_extraction`` with the
    given ``extractor``, and the result is dumped to stdout.
    """
    extracted = kg_extraction.preview_extraction(
        _resolve_text_input(source), extractor=extractor
    )
    as_plain_dict = _result_to_dict(extracted)
    print(json.dumps(as_plain_dict, indent=2))
26+
27+
28+
def extract_and_add(
    source: str,
    service: LanceKnowledgeGraph,
    extractor: kg_extraction.BaseExtractor,
    *,
    embedding_generator: EmbeddingGenerator | None = None,
) -> None:
    """Run extraction on ``source`` and upsert the results via ``service``.

    Entities are written to the ``"Entity"`` dataset and relationships to the
    ``"RELATIONSHIP"`` dataset, each through ``service.upsert_table`` with
    ``merge=True``. A short status line is printed for every table written,
    or a single notice when nothing was extracted.

    Args:
        source: File path or inline text, resolved by ``_resolve_text_input``.
        service: Graph service that receives the upserted tables.
        extractor: Extractor handed to ``kg_extraction.preview_extraction``.
        embedding_generator: Optional generator forwarded to the row
            preparation helpers.
    """
    # Imported lazily so that importing this module does not require pyarrow.
    import pyarrow as pa

    extraction = kg_extraction.preview_extraction(
        _resolve_text_input(source), extractor=extractor
    )
    entity_rows, name_to_id = _prepare_entity_rows(
        extraction.entities, embedding_generator=embedding_generator
    )
    relationships = extraction.relationships

    if not (entity_rows or relationships):
        print("No candidate entities or relationships detected.")
        return

    if entity_rows:
        entity_table = pa.Table.from_pylist(entity_rows)
        service.upsert_table("Entity", entity_table, merge=True)
        print(f"Upserted {entity_table.num_rows} entity rows into dataset 'Entity'.")

    relationship_rows = _prepare_relationship_rows(
        relationships,
        name_to_id,
        embedding_generator=embedding_generator,
    )
    if not relationship_rows:
        return

    rel_table = pa.Table.from_pylist(relationship_rows)
    service.upsert_table("RELATIONSHIP", rel_table, merge=True)
    summary = (
        f"Upserted {rel_table.num_rows} relationship rows into dataset "
        "'RELATIONSHIP'."
    )
    print(summary)
69+
70+
71+
def _resolve_text_input(raw: str) -> str:
    """Load text from a file if it exists, otherwise treat the string as content.

    Args:
        raw: Either a filesystem path or the literal text to process.

    Returns:
        The UTF-8 decoded file contents when ``raw`` names an existing file,
        otherwise ``raw`` itself, unchanged.

    Raises:
        IsADirectoryError: If ``raw`` names an existing directory.
    """
    candidate = Path(raw)
    try:
        found = candidate.exists()
    except (OSError, ValueError):
        # Inline text is often not a legal path (for example, longer than the
        # OS file-name limit). Depending on platform and Python version the
        # existence probe then raises instead of returning False; such input
        # can only be literal content.
        return raw
    if not found:
        return raw
    if candidate.is_dir():
        raise IsADirectoryError(f"Expected text file, got directory: {candidate}")
    return candidate.read_text(encoding="utf-8")
79+
80+
81+
def _ensure_dict(item: object) -> dict:
    """Return a fresh ``dict`` describing an extraction item.

    Dataclass instances are converted with ``dataclasses.asdict``; plain
    dictionaries are shallow-copied. In both cases the caller receives a
    dictionary it may extend without touching the object it passed in.

    Args:
        item: A dataclass instance or a ``dict``.

    Returns:
        A new dictionary holding the item's fields.

    Raises:
        TypeError: If ``item`` is neither a dataclass instance nor a ``dict``.
    """
    # is_dataclass() is also true for dataclass *types*, which asdict()
    # rejects; only instances are convertible.
    if is_dataclass(item) and not isinstance(item, type):
        return asdict(item)  # type: ignore[arg-type]
    if isinstance(item, dict):
        # Copy so that row preparation (which adds ids and normalised fields)
        # does not mutate the extractor's own result objects.
        return dict(item)
    raise TypeError(f"Unsupported extraction item type: {type(item)!r}")
87+
88+
89+
def _result_to_dict(result: "kg_extraction.ExtractionResult") -> dict[str, list[dict]]:
    """Convert an extraction result into plain dictionaries.

    Every item of ``result.entities`` and ``result.relationships`` is passed
    through ``dataclasses.asdict``; the two lists are returned under the keys
    ``"entities"`` and ``"relationships"``.
    """
    entity_dicts = [asdict(item) for item in result.entities]
    relationship_dicts = [asdict(item) for item in result.relationships]
    return {"entities": entity_dicts, "relationships": relationship_dicts}
94+
95+
96+
def _prepare_entity_rows(
    entities: list[Any],
    *,
    embedding_generator: EmbeddingGenerator | None = None,
) -> tuple[list[dict[str, Any]], dict[str, str]]:
    """Turn extracted entities into table rows plus a name-to-id lookup.

    Each entity is converted with ``_ensure_dict`` and extended with:

    * ``entity_id`` - MD5 hex digest of ``"<name>|<entity_type>"``, using the
      stripped name and type;
    * ``entity_type`` - the ``entity_type`` or ``type`` field, or
      ``"UNKNOWN"`` when both are empty;
    * ``name_lower`` - the stripped name, lower-cased.

    Entities whose name is missing, ``None`` or blank are skipped.

    Args:
        entities: Dataclass instances or dicts accepted by ``_ensure_dict``.
        embedding_generator: When provided and at least one row was built,
            ``_assign_embeddings`` is invoked on the rows.

    Returns:
        ``(rows, name_to_id)`` where ``name_to_id`` maps each lower-cased name
        to the id of the first entity seen with that name.
    """
    rows: list[dict[str, Any]] = []
    name_to_id: dict[str, str] = {}
    for entity in entities:
        payload = _ensure_dict(entity)
        raw_name = payload.get("name")
        # An explicit None must count as "no name"; str(None) would otherwise
        # create an entity literally called "None".
        name = "" if raw_name is None else str(raw_name).strip()
        if not name:
            continue
        entity_type = str(
            payload.get("entity_type") or payload.get("type") or ""
        ).strip()
        # The digest only serves as a deterministic identifier for name + type.
        entity_id = hashlib.md5(f"{name}|{entity_type}".encode("utf-8")).hexdigest()
        payload["entity_id"] = entity_id
        payload["entity_type"] = entity_type or "UNKNOWN"
        payload["name_lower"] = name.lower()
        rows.append(payload)
        # First occurrence wins when the same name appears more than once.
        name_to_id.setdefault(name.lower(), entity_id)
    if embedding_generator and rows:
        _assign_embeddings(
            rows,
            embedding_generator,
            _format_entity_embedding_input,
        )
    return rows, name_to_id
125+
126+
127+
def _prepare_relationship_rows(
    relationships: list[Any],
    name_to_id: dict[str, str],
    *,
    embedding_generator: EmbeddingGenerator | None = None,
) -> list[dict[str, Any]]:
    """Turn extracted relationships into table rows with resolved endpoints.

    Endpoint names are read from ``source_entity_name``/``source`` and
    ``target_entity_name``/``target`` and looked up (lower-cased) in
    ``name_to_id``. Relationships with an unresolved endpoint are dropped.
    Kept rows receive ``source_entity_id``, ``target_entity_id``, a
    ``relationship_type`` (falling back to ``type`` and then
    ``"RELATED_TO"``) and, if absent, the endpoint name fields.

    Args:
        relationships: Dataclass instances or dicts accepted by
            ``_ensure_dict``.
        name_to_id: Lower-cased entity name to entity id.
        embedding_generator: When provided and at least one row was built,
            ``_assign_embeddings`` is invoked on the rows.

    Returns:
        The prepared relationship rows.
    """
    prepared: list[dict[str, Any]] = []
    for relation in relationships:
        record = _ensure_dict(relation)

        def endpoint(primary: str, fallback: str) -> str:
            return str(record.get(primary) or record.get(fallback) or "").strip()

        src_name = endpoint("source_entity_name", "source")
        dst_name = endpoint("target_entity_name", "target")
        src_id = name_to_id.get(src_name.lower())
        dst_id = name_to_id.get(dst_name.lower())
        if not src_id or not dst_id:
            continue

        record.update(source_entity_id=src_id, target_entity_id=dst_id)
        record["relationship_type"] = (
            record.get("relationship_type") or record.get("type") or "RELATED_TO"
        )
        record.setdefault("source_entity_name", src_name)
        record.setdefault("target_entity_name", dst_name)
        prepared.append(record)

    if embedding_generator and prepared:
        _assign_embeddings(
            prepared,
            embedding_generator,
            _format_relationship_embedding_input,
        )
    return prepared
161+
162+
163+
def _assign_embeddings(
    rows: list[dict[str, Any]],
    embedding_generator: EmbeddingGenerator,
    formatter: Callable[[Mapping[str, Any]], str],
) -> None:
    """Attach an ``"embedding"`` value to every row that yields text.

    ``formatter`` is applied to each row; rows with an empty result are left
    alone. The remaining texts are sent to ``embedding_generator.embed`` in a
    single call and the returned vectors are stored on their rows in order.

    Failures are best-effort: if ``embed`` raises, or returns a different
    number of vectors than texts, a warning is logged and no row is modified.
    """
    rendered = [(position, formatter(row)) for position, row in enumerate(rows)]
    targets = [(position, text) for position, text in rendered if text]
    if not targets:
        return

    texts = [text for _, text in targets]
    try:
        vectors = embedding_generator.embed(texts)
    except Exception as exc:  # pragma: no cover - defensive logging path
        LOGGER.warning("Failed to generate embeddings: %s", exc)
        return

    received = len(vectors)
    expected = len(targets)
    if received != expected:
        LOGGER.warning(
            "Mismatch between embedding count and row count: expected %s, got %s",
            expected,
            received,
        )
        return

    for (position, _), vector in zip(targets, vectors):
        rows[position]["embedding"] = vector
191+
192+
193+
def _format_entity_embedding_input(row: Mapping[str, Any]) -> str:
    """Build the text that is embedded for an entity row.

    The ``name``, ``entity_type`` and ``context`` fields are stripped and
    joined with ``" | "``, the latter two prefixed with ``"Type: "`` and
    ``"Context: "``. Fields that are missing, ``None`` or blank are omitted.

    Args:
        row: Entity row to describe.

    Returns:
        The combined text, or ``""`` when no field has content.
    """

    def field_text(key: str) -> str:
        # None must read as "absent"; str(None) would leak the word "None"
        # into the embedding input.
        value = row.get(key)
        return "" if value is None else str(value).strip()

    labelled = (
        ("", field_text("name")),
        ("Type: ", field_text("entity_type")),
        ("Context: ", field_text("context")),
    )
    return " | ".join(f"{label}{text}" for label, text in labelled if text)
205+
206+
207+
def _format_relationship_embedding_input(row: Mapping[str, Any]) -> str:
    """Build the text that is embedded for a relationship row.

    The endpoints come from ``source_entity_name``/``source`` and
    ``target_entity_name``/``target``. When at least one endpoint is present
    the text starts with ``"<source> -[<type>]-> <target>"`` (or
    ``"<source> -> <target>"`` without a relationship type). A non-empty
    ``description`` is appended as ``"Description: ..."`` after ``" | "``.

    Args:
        row: Relationship row to describe.

    Returns:
        The combined text, or ``""`` when the row has no usable content.
    """

    def field_text(key: str) -> str:
        # None must read as "absent"; str(None) would leak the word "None"
        # into the embedding input.
        value = row.get(key)
        return "" if value is None else str(value).strip()

    source = str(row.get("source_entity_name") or row.get("source") or "").strip()
    target = str(row.get("target_entity_name") or row.get("target") or "").strip()
    relationship_type = field_text("relationship_type")
    description = field_text("description")

    parts: list[str] = []
    if source or target:
        arrow = f"-[{relationship_type}]->" if relationship_type else "->"
        # strip() tidies the edge left behind by a missing endpoint.
        parts.append(f"{source} {arrow} {target}".strip())
    if description:
        parts.append(f"Description: {description}")
    return " | ".join(parts)
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
"""Interactive shell and CLI display helpers for the knowledge graph."""
2+
3+
from __future__ import annotations
4+
5+
import sys
6+
from typing import TYPE_CHECKING
7+
8+
from ..store import LanceGraphStore
9+
10+
if TYPE_CHECKING:
11+
from ..config import KnowledgeGraphConfig
12+
from ..service import LanceKnowledgeGraph
13+
14+
if TYPE_CHECKING:
15+
import pyarrow as pa
16+
17+
18+
def list_datasets(config: "KnowledgeGraphConfig") -> None:
    """Print the Lance datasets that the store reports for ``config``.

    A ``LanceGraphStore`` is created for ``config``, its layout is ensured,
    and each dataset returned by ``store.list_datasets()`` is printed as
    ``name: path`` in sorted order. A hint is printed when there are none.
    """
    store = LanceGraphStore(config)
    store.ensure_layout()
    found = store.list_datasets()
    if found:
        print("Available Lance datasets:")
        for entry in sorted(found.items()):
            dataset_name, dataset_path = entry
            print(f" - {dataset_name}: {dataset_path}")
    else:
        print("No Lance datasets found. Load data or run extraction first.")
29+
30+
31+
def run_interactive(service: "LanceKnowledgeGraph") -> None:
    """Run the interactive ``kg>`` prompt until the user leaves.

    Blank lines are ignored. ``quit``, ``exit`` or ``q`` (any letter case)
    and end-of-input end the loop. Lines starting with ``:`` are handed to
    ``_handle_command``; everything else goes to ``_execute_query``.
    """
    print("Lance Knowledge Graph interactive shell")
    print("Type ':help' for commands, or 'quit' to exit.")

    exit_words = {"quit", "exit", "q"}
    while True:
        try:
            line = input("kg> ").strip()
        except EOFError:
            # Finish the prompt line before leaving on Ctrl-D / closed stdin.
            print()
            return

        if not line:
            continue
        if line.lower() in exit_words:
            return

        if line.startswith(":"):
            _handle_command(line, service)
        else:
            _execute_query(service, line)
53+
54+
55+
def _handle_command(command: str, service: "LanceKnowledgeGraph") -> None:
    """Dispatch a ``:``-prefixed meta-command from the interactive shell.

    ``:help``/``:h`` prints the command overview, ``:datasets``/``:ls`` calls
    ``list_datasets`` with ``service.store.config`` and ``:config``/
    ``:schema`` calls ``_print_config_summary``. Anything else is reported as
    unknown, echoing the command exactly as given.
    """
    normalized = command.strip()
    if normalized in (":help", ":h"):
        help_lines = (
            "Commands:",
            " :help Show this message",
            " :datasets List persisted Lance datasets",
            " :config Show the configured node/relationship mappings",
            " quit/exit/q Leave the shell",
        )
        for help_line in help_lines:
            print(help_line)
    elif normalized in (":datasets", ":ls"):
        list_datasets(service.store.config)
    elif normalized in (":config", ":schema"):
        _print_config_summary(service)
    else:
        print(f"Unknown command: {command}")
72+
73+
74+
def _print_config_summary(service: "LanceKnowledgeGraph") -> None:
    """Print a heading followed by the ``repr`` of ``service.config``."""
    # GraphConfig does not currently expose direct iterators; rely on repr.
    print("Graph configuration:", f" {service.config!r}", sep="\n")
80+
81+
82+
def _execute_query(service: "LanceKnowledgeGraph", statement: str) -> None:
    """Run ``statement`` through ``service.run`` and print the outcome.

    A successful result is rendered with ``_print_table``. Any exception
    raised by ``service.run`` is reported on stderr instead of propagating,
    so a failing query does not end the interactive shell.
    """
    try:
        outcome = service.run(statement)
    except Exception as exc:  # pragma: no cover - CLI feedback path
        print(f"Query failed: {exc}", file=sys.stderr)
    else:
        _print_table(outcome)
91+
92+
93+
def _print_table(table: "pa.Table") -> None:
    """Print ``table`` as left-aligned, pipe-separated text columns.

    An empty table prints ``(no rows)``. Otherwise a header row, a dashed
    separator and one line per row are printed. ``None`` cells render as
    empty strings, and every column is padded to the width of its longest
    cell or its name, whichever is longer.
    """
    if table.num_rows == 0:
        print("(no rows)")
        return

    names = table.column_names

    # Stringify each column once; the text is reused for widths and for rows.
    text_columns = []
    for position in range(len(names)):
        cells = table.column(position).to_pylist()
        text_columns.append(["" if cell is None else str(cell) for cell in cells])

    widths = [
        max([len(name), *map(len, cells)])
        for name, cells in zip(names, text_columns)
    ]

    def render(cells) -> str:
        return " | ".join(cell.ljust(width) for cell, width in zip(cells, widths))

    print(render(names))
    print("-+-".join("-" * width for width in widths))
    for row_cells in zip(*text_columns):
        print(render(row_cells))

0 commit comments

Comments
 (0)