Skip to content

Commit 1bbc849

Browse files
authored
feat(python): discover actual entity and relationship types from the … (#65)
* feat(python): discover actual entity and relationship types from the table * format code
1 parent f0aafbe commit 1bbc849

1 file changed

Lines changed: 69 additions & 14 deletions

File tree

  • python/python/knowledge_graph/llm

python/python/knowledge_graph/llm/qa.py

Lines changed: 69 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -64,23 +64,47 @@ def ask_question(
6464

6565
schema_summary = summarize_schema(service)
6666
type_hints = service.store.config.type_hints()
67-
allowed_relationship_types = tuple(
68-
str(t) for t in (type_hints.get("relationship_types") or ())
69-
)
70-
if not allowed_relationship_types:
71-
discovered = _discover_relationship_types(service)
72-
allowed_relationship_types = tuple(discovered)
73-
if discovered:
74-
LOGGER.debug(
75-
"Discovered relationship_type values from dataset: %s",
76-
", ".join(discovered),
77-
)
78-
type_hint_lines = build_type_hint_lines(type_hints)
67+
68+
# Discover actual relationship types from data
69+
discovered_rel_types = _discover_relationship_types(service)
70+
if discovered_rel_types:
71+
allowed_relationship_types = tuple(discovered_rel_types)
72+
LOGGER.debug(
73+
"Discovered relationship_type values from dataset: %s",
74+
", ".join(discovered_rel_types),
75+
)
76+
else:
77+
# Fall back to config types if discovery fails
78+
allowed_relationship_types = tuple(
79+
str(t) for t in (type_hints.get("relationship_types") or ())
80+
)
81+
82+
# Discover actual entity types from data
83+
discovered_entity_types = _discover_entity_types(service)
84+
if discovered_entity_types:
85+
LOGGER.debug(
86+
"Discovered entity_type values from dataset: %s",
87+
", ".join(discovered_entity_types),
88+
)
89+
else:
90+
# Fall back to config types if discovery fails
91+
discovered_entity_types = list(
92+
str(t) for t in (type_hints.get("entity_types") or ())
93+
)
94+
95+
# Use discovered types in the prompt instead of config types
96+
actual_type_hints = dict(type_hints)
97+
if discovered_rel_types:
98+
actual_type_hints["relationship_types"] = tuple(discovered_rel_types)
99+
if discovered_entity_types:
100+
actual_type_hints["entity_types"] = tuple(discovered_entity_types)
101+
102+
type_hint_lines = build_type_hint_lines(actual_type_hints)
79103
query_prompt = build_query_prompt(
80104
question,
81105
schema_summary,
82106
type_hint_lines,
83-
type_hints,
107+
actual_type_hints,
84108
seed_entities,
85109
seed_neighbors,
86110
)
@@ -291,13 +315,44 @@ def replace_in(match: re.Match[str]) -> str:
291315

292316

293317
def _discover_relationship_types(service: LanceKnowledgeGraph) -> list[str]:
294-
"""Discover distinct relationship_type values from the dataset as a fallback."""
318+
"""Discover distinct relationship_type values from the dataset.
319+
320+
Results are cached on the service object to avoid repeated table loads.
321+
"""
322+
if hasattr(service, "_cached_rel_types"):
323+
return service._cached_rel_types
324+
295325
try:
296326
table = service.load_table("RELATIONSHIP")
297327
if "relationship_type" in table.column_names:
298328
values = table.column("relationship_type").to_pylist()
299329
distinct = sorted({str(v) for v in values if v is not None and str(v)})
330+
service._cached_rel_types = distinct
300331
return distinct
301332
except Exception as exc: # pragma: no cover - defensive
302333
LOGGER.debug("Unable to discover relationship types: %s", exc)
334+
335+
service._cached_rel_types = []
336+
return []
337+
338+
339+
def _discover_entity_types(service: LanceKnowledgeGraph) -> list[str]:
340+
"""Discover distinct entity_type values from the dataset.
341+
342+
Results are cached on the service object to avoid repeated table loads.
343+
"""
344+
if hasattr(service, "_cached_entity_types"):
345+
return service._cached_entity_types
346+
347+
try:
348+
table = service.load_table("Entity")
349+
if "entity_type" in table.column_names:
350+
values = table.column("entity_type").to_pylist()
351+
distinct = sorted({str(v) for v in values if v is not None and str(v)})
352+
service._cached_entity_types = distinct
353+
return distinct
354+
except Exception as exc: # pragma: no cover - defensive
355+
LOGGER.debug("Unable to discover entity types: %s", exc)
356+
357+
service._cached_entity_types = []
303358
return []

0 commit comments

Comments
 (0)