"""Embedding utilities backed by the OpenAI client."""

import logging
import math
from typing import Any, Mapping, MutableMapping, Optional, Sequence

LOGGER = logging.getLogger(__name__)

DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"

# Option names (compared case-insensitively) whose values must never be logged.
_SENSITIVE_OPTION_KEYS = frozenset({"api_key", "api-key", "authorization"})


class EmbeddingGenerator:
    """Generate embeddings using OpenAI's embeddings API.

    The OpenAI client is built lazily on the first ``embed`` call, so creating
    a generator never imports ``openai`` or touches the network.

    Args:
        model: Name of the embedding model sent with every request.
        client_options: Keyword arguments forwarded to ``openai.OpenAI(...)``
            when the client is created lazily. The mapping is copied.
        client: Optional pre-built client exposing
            ``client.embeddings.create(model=..., input=...)``. When given, it
            is used as-is and ``openai`` is never imported.
    """

    def __init__(
        self,
        *,
        model: str = DEFAULT_EMBEDDING_MODEL,
        client_options: Optional[Mapping[str, Any]] = None,
        client: Any = None,
    ) -> None:
        self._model = model
        self._client_options: MutableMapping[str, Any] = dict(client_options or {})
        self._client = client

    def embed(self, texts: Sequence[str]) -> list[list[float]]:
        """Return embeddings for the provided texts.

        Falsy entries (empty strings, ``None``) are dropped before the request,
        so the result lines up with the *non-empty* inputs only and may be
        shorter than ``texts``. No request is made when nothing is left.

        Raises:
            RuntimeError: If the client cannot be created, the response has no
                ``data`` entries, an entry lacks an ``embedding``, or a vector
                contains non-numeric values.
        """
        # A bare string is itself a sequence of one-character strings; treat it
        # as a single text instead of embedding every character separately.
        if isinstance(texts, str):
            texts = [texts]
        sanitized = [text for text in texts if text]
        if not sanitized:
            return []

        client = self._ensure_client()
        response = client.embeddings.create(model=self._model, input=sanitized)
        data = getattr(response, "data", None)
        if not data:
            raise RuntimeError("OpenAI embedding response missing 'data' entries.")
        return [_extract_vector(item) for item in data]

    def embed_one(self, text: str) -> Optional[list[float]]:
        """Return a single embedding for convenience.

        Returns ``None`` when ``text`` is empty and therefore nothing was
        embedded.
        """
        vectors = self.embed([text])
        return vectors[0] if vectors else None

    def _ensure_client(self) -> Any:
        """Return the client, creating the OpenAI client on first use.

        Raises:
            RuntimeError: If no client was supplied and ``openai`` is not
                installed.
        """
        if self._client is None:
            try:
                from openai import OpenAI  # type: ignore[import-not-found]
            except ImportError as exc:
                raise RuntimeError(
                    "The `openai` package is required for embeddings. "
                    "Install it or supply a custom client."
                ) from exc

            # Only the masked copy is logged; the real options go to the client.
            sanitized_opts = _sanitize_options(self._client_options)
            LOGGER.debug(
                "Initializing OpenAI embeddings client",
                extra={
                    "lance_graph": {
                        "openai_model": self._model,
                        "openai_options": sanitized_opts,
                    }
                },
            )
            self._client = OpenAI(**self._client_options)
        return self._client


def _extract_vector(item: Any) -> list[float]:
    """Pull the embedding out of one response entry and coerce it to floats.

    Entries may expose the vector as an ``embedding`` attribute or, for
    mapping-style entries, under the ``"embedding"`` key.
    """
    vector = getattr(item, "embedding", None)
    if vector is None and isinstance(item, Mapping):
        vector = item.get("embedding")
    if vector is None:
        raise RuntimeError("OpenAI embedding response missing 'embedding'.")
    try:
        return [float(value) for value in vector]
    except (TypeError, ValueError) as exc:
        raise RuntimeError("Embedding vector contains non-numeric values.") from exc


def cosine_similarity(lhs: Sequence[float], rhs: Sequence[float]) -> float:
    """Return cosine similarity between two vectors.

    Returns ``0.0`` (rather than raising) when the vectors differ in length or
    when either has zero magnitude, so callers can rank mixed data safely.
    """
    if len(lhs) != len(rhs):
        LOGGER.debug(
            "Unable to compute cosine similarity due to mismatched lengths: %s vs %s",
            len(lhs),
            len(rhs),
        )
        return 0.0
    dot = sum(x * y for x, y in zip(lhs, rhs))
    lhs_norm = math.sqrt(sum(x * x for x in lhs))
    rhs_norm = math.sqrt(sum(y * y for y in rhs))
    if lhs_norm == 0 or rhs_norm == 0:
        return 0.0
    return dot / (lhs_norm * rhs_norm)


def _sanitize_options(options: Mapping[str, Any]) -> dict[str, Any]:
    """Return a copy of ``options`` with sensitive values masked for logging.

    Keys are matched case-insensitively. Nested mappings (for example a
    headers mapping carrying ``Authorization``) are sanitized recursively so
    credentials cannot leak through one level of nesting. The input is never
    modified.
    """
    sanitized: dict[str, Any] = {}
    for key, value in options.items():
        if isinstance(key, str) and key.lower() in _SENSITIVE_OPTION_KEYS:
            sanitized[key] = "***"
        elif isinstance(value, Mapping):
            sanitized[key] = _sanitize_options(value)
        else:
            sanitized[key] = value
    return sanitized
a/python/python/knowledge_graph/main.py +++ b/python/python/knowledge_graph/main.py @@ -9,12 +9,17 @@ import sys from dataclasses import asdict, is_dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence +from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, Sequence import yaml from . import extraction as kg_extraction from .config import KnowledgeGraphConfig +from .embeddings import ( + DEFAULT_EMBEDDING_MODEL, + EmbeddingGenerator, + cosine_similarity, +) from .service import LanceKnowledgeGraph from .store import LanceGraphStore @@ -24,6 +29,12 @@ from .extractors import ExtractionResult +LOGGER = logging.getLogger(__name__) + +DEFAULT_SEED_COUNT = 5 +DEFAULT_SEED_NEIGHBOR_LIMIT = 50 + + def init_graph(config: KnowledgeGraphConfig) -> None: """Initialize the on-disk storage and scaffold the schema file.""" config.ensure_directories() @@ -172,13 +183,17 @@ def extract_and_add( source: str, service: LanceKnowledgeGraph, extractor: kg_extraction.BaseExtractor, + *, + embedding_generator: EmbeddingGenerator | None = None, ) -> None: """Extract knowledge and append it to the backing graph.""" import pyarrow as pa text = _resolve_text_input(source) result = kg_extraction.preview_extraction(text, extractor=extractor) - entity_rows, name_to_id = _prepare_entity_rows(result.entities) + entity_rows, name_to_id = _prepare_entity_rows( + result.entities, embedding_generator=embedding_generator + ) relationships = result.relationships if not entity_rows and not relationships: @@ -191,7 +206,11 @@ def extract_and_add( message = f"Upserted {entity_table.num_rows} entity rows into dataset 'Entity'." 
print(message) - relationship_rows = _prepare_relationship_rows(relationships, name_to_id) + relationship_rows = _prepare_relationship_rows( + relationships, + name_to_id, + embedding_generator=embedding_generator, + ) if relationship_rows: rel_table = pa.Table.from_pylist(relationship_rows) service.upsert_table("RELATIONSHIP", rel_table, merge=True) @@ -209,7 +228,29 @@ def ask_question( args: argparse.Namespace, ) -> None: """Answer a natural-language question using the graph via LLM-assisted Cypher.""" - llm_client = _create_llm_client(args) + client_options = _load_llm_options(args.llm_config) + llm_client = _create_llm_client(args, options=client_options) + embedding_generator = _resolve_embedding_generator(args, options=client_options) + seed_limit = getattr(args, "seed_count", DEFAULT_SEED_COUNT) + try: + seed_limit = int(seed_limit) + except (TypeError, ValueError): + seed_limit = DEFAULT_SEED_COUNT + if seed_limit < 0: + seed_limit = 0 + + seed_entities = _find_seed_entities( + question, + service, + embedding_generator, + limit=seed_limit, + ) + seed_neighbors = _collect_seed_neighbors( + service, + seed_entities, + limit=DEFAULT_SEED_NEIGHBOR_LIMIT, + ) + schema_summary = _summarize_schema(service) type_hints = service.store.config.type_hints() type_hint_lines = _build_type_hint_lines(type_hints) @@ -218,17 +259,49 @@ def ask_question( schema_summary, type_hint_lines, type_hints, + seed_entities, + seed_neighbors, ) raw_plan = llm_client.complete(query_prompt) plan_payload = kg_extraction.parse_llm_json(raw_plan) query_plan = _extract_query_plan(plan_payload) - if not query_plan: + if not query_plan and not seed_entities: print("Unable to generate Cypher queries for the question.") return - execution_results = _execute_queries(service, query_plan) + execution_results: list[dict[str, Any]] = [] + if seed_entities: + execution_results.append( + { + "cypher": "(semantic search)", + "description": ( + "Top seed entities retrieved via embedding similarity 
def _embedding_text_field(value: Any) -> str:
    """Return ``value`` as stripped text, mapping ``None`` to an empty string."""
    # str(None) would otherwise put the literal word "None" into embedding input.
    return "" if value is None else str(value).strip()


def _assign_embeddings(
    rows: list[dict[str, Any]],
    embedding_generator: EmbeddingGenerator,
    formatter: Callable[[Mapping[str, Any]], str],
) -> None:
    """Attach an ``embedding`` vector to every row that yields embeddable text.

    Rows are mutated in place. This is best-effort: if the generator raises or
    returns a different number of vectors than texts sent, a warning is logged
    and no row is modified.

    Args:
        rows: Row payloads to enrich.
        embedding_generator: Object whose ``embed(texts)`` returns one vector
            per text.
        formatter: Builds the text to embed for a row; rows for which it
            returns an empty string are skipped.
    """
    # Keep the row index next to each text so vectors can be mapped back even
    # though rows with empty text are left out of the request.
    indexed_texts = [(idx, formatter(row)) for idx, row in enumerate(rows)]
    indexed_texts = [(idx, text) for idx, text in indexed_texts if text]
    if not indexed_texts:
        return
    indices = [idx for idx, _ in indexed_texts]
    texts = [text for _, text in indexed_texts]
    try:
        vectors = embedding_generator.embed(texts)
    except Exception as exc:  # pragma: no cover - defensive logging path
        LOGGER.warning("Failed to generate embeddings: %s", exc)
        return
    if len(vectors) != len(indices):
        LOGGER.warning(
            "Mismatch between embedding count and row count: expected %s, got %s",
            len(indices),
            len(vectors),
        )
        return
    for idx, vector in zip(indices, vectors):
        rows[idx]["embedding"] = vector


def _format_entity_embedding_input(row: Mapping[str, Any]) -> str:
    """Build the text embedded for an entity row.

    Joins the name, ``Type: <entity_type>`` and ``Context: <context>`` with
    `` | ``, omitting any part that is missing, ``None`` or blank. Returns an
    empty string when nothing is available.
    """
    name = _embedding_text_field(row.get("name"))
    entity_type = _embedding_text_field(row.get("entity_type"))
    context = _embedding_text_field(row.get("context"))
    pieces = []
    if name:
        pieces.append(name)
    if entity_type:
        pieces.append(f"Type: {entity_type}")
    if context:
        pieces.append(f"Context: {context}")
    return " | ".join(pieces)


def _format_relationship_embedding_input(row: Mapping[str, Any]) -> str:
    """Build the text embedded for a relationship row.

    Produces ``<source> -[<type>]-> <target>`` (or ``<source> -> <target>``
    without a type) followed by ``Description: ...`` when present. Endpoint
    names fall back from ``*_entity_name`` to ``source`` / ``target``. Missing,
    ``None`` or blank fields are omitted.
    """
    source = _embedding_text_field(
        row.get("source_entity_name") or row.get("source") or ""
    )
    target = _embedding_text_field(
        row.get("target_entity_name") or row.get("target") or ""
    )
    relationship_type = _embedding_text_field(row.get("relationship_type"))
    description = _embedding_text_field(row.get("description"))
    core: list[str] = []
    if source or target:
        if relationship_type:
            core.append(f"{source} -[{relationship_type}]-> {target}".strip())
        else:
            core.append(f"{source} -> {target}".strip())
    if description:
        core.append(f"Description: {description}")
    return " | ".join(part for part in core if part)


def _find_seed_entities(
    question: str,
    service: LanceKnowledgeGraph,
    embedding_generator: EmbeddingGenerator | None,
    *,
    limit: int = DEFAULT_SEED_COUNT,
) -> list[dict[str, Any]]:
    """Return the entities whose stored embeddings best match ``question``.

    The question is embedded once and compared by cosine similarity against
    the ``embedding`` value of every row in the ``Entity`` dataset. Rows
    without a list/tuple embedding, with non-numeric values, or without an
    ``entity_id`` are ignored.

    Returns:
        Up to ``limit`` dicts with ``entity_id``, ``name``, ``entity_type``,
        ``similarity`` and ``context``, ordered by descending similarity. An
        empty list when semantic search is unavailable (no generator, blank
        question, non-positive ``limit``, missing dataset) or any step fails;
        failures are logged as warnings, never raised.
    """
    if not embedding_generator:
        return []
    prepared_question = question.strip()
    if not prepared_question:
        return []
    if limit <= 0:
        return []
    if not service.has_dataset("Entity"):
        return []
    try:
        question_vector = embedding_generator.embed_one(prepared_question)
    except Exception as exc:  # pragma: no cover - defensive logging path
        LOGGER.warning("Failed to embed question for semantic search: %s", exc)
        return []
    if not question_vector:
        return []
    try:
        question_vector = [float(value) for value in question_vector]
    except (TypeError, ValueError):
        LOGGER.warning("Question embedding returned non-numeric values.")
        return []
    try:
        entity_table = service.load_table("Entity")
    except Exception as exc:
        LOGGER.warning("Unable to load Entity dataset for semantic search: %s", exc)
        return []

    seeds: list[dict[str, Any]] = []
    for row in entity_table.to_pylist():
        entity_id = row.get("entity_id")
        embedding = row.get("embedding")
        # Skip unusable rows before doing any numeric work.
        if not entity_id or not isinstance(embedding, (list, tuple)):
            continue
        try:
            vector = [float(value) for value in embedding]
        except (TypeError, ValueError):
            continue
        try:
            similarity = float(cosine_similarity(question_vector, vector))
        except Exception:
            # One malformed vector must not abort the search; rank it last.
            similarity = 0.0
        seeds.append(
            {
                "entity_id": entity_id,
                "name": row.get("name"),
                "entity_type": row.get("entity_type"),
                "similarity": similarity,
                "context": row.get("context"),
            }
        )
    seeds.sort(key=lambda item: item.get("similarity", 0.0), reverse=True)
    return seeds[:limit]


def _collect_seed_neighbors(
    service: LanceKnowledgeGraph,
    seed_entities: Sequence[Mapping[str, Any]],
    *,
    limit: int = DEFAULT_SEED_NEIGHBOR_LIMIT,
) -> list[dict[str, Any]]:
    """Return relationships that touch any seed entity, one entry per edge.

    Each entry describes the seed, the entity on the other end, the
    relationship type/description, and a ``direction``: ``"outgoing"`` when
    the seed is the relationship source, ``"incoming"`` when it is the target.
    When both endpoints are seeds the edge is reported once, as outgoing.

    Entity ids are compared as strings on both sides, so ids stored with a
    non-string type still match the seed ids.

    Returns:
        Entries sorted by seed name, neighbor name and relationship type,
        truncated to ``limit`` when ``limit`` is truthy. An empty list when
        there are no seeds, a dataset is missing, or loading fails (logged as
        a warning).
    """
    if not seed_entities:
        return []
    if not (service.has_dataset("Entity") and service.has_dataset("RELATIONSHIP")):
        return []
    try:
        entity_rows = service.load_table("Entity").to_pylist()
        relationship_rows = service.load_table("RELATIONSHIP").to_pylist()
    except Exception as exc:
        LOGGER.warning("Unable to load datasets for neighbor expansion: %s", exc)
        return []

    id_to_entity: dict[str, Mapping[str, Any]] = {
        str(entity["entity_id"]): entity
        for entity in entity_rows
        if entity.get("entity_id")
    }
    seed_ids = {
        str(seed.get("entity_id")) for seed in seed_entities if seed.get("entity_id")
    }
    if not seed_ids:
        return []

    neighbors: list[dict[str, Any]] = []
    for relation in relationship_rows:
        raw_source = relation.get("source_entity_id")
        raw_target = relation.get("target_entity_id")
        # Normalize to strings first: seed_ids and id_to_entity are keyed by str.
        source_id = str(raw_source) if raw_source else ""
        target_id = str(raw_target) if raw_target else ""
        if source_id in seed_ids:
            direction, seed_id, neighbor_id = "outgoing", source_id, target_id
        elif target_id in seed_ids:
            direction, seed_id, neighbor_id = "incoming", target_id, source_id
        else:
            continue
        if not neighbor_id:
            continue
        seed_entity = id_to_entity.get(seed_id, {})
        neighbor_entity = id_to_entity.get(neighbor_id, {})
        neighbors.append(
            {
                "seed_entity_id": seed_id,
                "seed_name": seed_entity.get("name"),
                "seed_entity_type": seed_entity.get("entity_type"),
                "neighbor_entity_id": neighbor_id,
                "neighbor_name": neighbor_entity.get("name"),
                "neighbor_entity_type": neighbor_entity.get("entity_type"),
                "relationship_type": relation.get("relationship_type"),
                "relationship_description": relation.get("description"),
                "direction": direction,
            }
        )
    if not neighbors:
        return []
    neighbors.sort(
        key=lambda item: (
            str(item.get("seed_name") or ""),
            str(item.get("neighbor_name") or ""),
            str(item.get("relationship_type") or ""),
        )
    )
    if limit and len(neighbors) > limit:
        return neighbors[:limit]
    return neighbors
Mapping[str, tuple[str, ...]], + seed_entities: Sequence[Mapping[str, Any]] | None = None, + seed_neighbors: Sequence[Mapping[str, Any]] | None = None, ) -> str: example_rel_type = _select_example_relationship_type(type_hints) - instructions = "\n".join( - [ - "You translate questions into Cypher for Lance graph datasets.", - ( - "Use the schema summary to craft queries that directly answer the " - "question." - ), - ( - " • Use the schema summary and allowed relationship_type values to " - "identify candidate relationship directions and types." - ), - ( - " • When the schema lists relationship_type values and the question " - "does not narrow them down, treat the list as exhaustive and include " - "every value in your filter using OR clauses or " - "WHERE rel.relationship_type IN [...]." - ), - ( - "Always specify node labels and relationship types in MATCH patterns " - "that introduce aliases." - ), - "Supported constructs include:", - (" • MATCH (e:Entity) to scan entity rows (name, name_lower, entity_id)."), - ( - " • MATCH (src:Entity)-[rel:RELATIONSHIP]->(dst:Entity) to traverse " - "relationships (relationship_type column); `src` aligns with " - "`source_entity_id` and `dst` with `target_entity_id`." - ), - ( - " • Decide which node should be `src` versus `dst` based on the " - "relationship meaning in the question and schema hints." - ), - ( - " • Map natural language roles (team, person, product, etc.) to the " - "`entity_type` column so queries filter to the expected entities." 
- ), - " • Use WHERE e.column = 'value' for node-level filters.", - ( - " • Filter relationships with WHERE rel.relationship_type = 'VALUE' " - "or by comparing rel.source_entity_id / rel.target_entity_id; when the " - "question does not name a specific relationship type, include every " - "relevant value from the schema summary using OR clauses or " - "WHERE rel.relationship_type IN [...], explicitly note which values " - "you considered, and avoid emitting only a single guessed type." - ), - ( - " • Select columns using the aliases you define, such as e.name or " - "rel.relationship_type." - ), - ( - " • Avoid inventing relationship datasets; match RELATIONSHIP and " - "filter rel.relationship_type instead of [:TYPE]." - ), - ( - "Example: MATCH (part:Entity)-[rel:RELATIONSHIP]->(whole:Entity) " - f"WHERE rel.relationship_type = '{example_rel_type}' " - "RETURN part.name, whole.name." - ), - ( - "Example: MATCH (a:Entity)-[rel:RELATIONSHIP]->(b:Entity) WHERE " - "rel.relationship_type = 'TYPE_A' OR rel.relationship_type = 'TYPE_B' " - "RETURN a.name, b.name." - ), - ( - "Example: MATCH (src:Entity)-[rel:RELATIONSHIP]->(dst:Entity) WHERE " - "rel.relationship_type IN ['TYPE_A', 'TYPE_B', 'TYPE_C'] " - "RETURN src.name, dst.name." - ), - ( - "Example: MATCH (dst:Entity) WHERE dst.name_lower = 'acme corp' " - "RETURN dst.name, dst.entity_id." - ), - ( - f"Do not use relationship patterns like [:{example_rel_type}]; rely on " - "rel.relationship_type filters instead." - ), - ( - "Always emit at least one query when relevant data exists; only " - "return [] when it is impossible to answer." 
- ), - "Return ONLY a JSON array where each item has `cypher` and `description`.", - ] - ) + instruction_lines = [ + "You translate questions into Cypher for Lance graph datasets.", + ("Use the schema summary to craft queries that directly answer the question."), + ( + " • Use the schema summary and allowed relationship_type values to " + "identify candidate relationship directions and types." + ), + ( + " • When the schema lists relationship_type values and the question " + "does not narrow them down, treat the list as exhaustive and include " + "every value in your filter using OR clauses or " + "WHERE rel.relationship_type IN [...]." + ), + ( + "Always specify node labels and relationship types in MATCH patterns " + "that introduce aliases." + ), + "Supported constructs include:", + (" • MATCH (e:Entity) to scan entity rows (name, name_lower, entity_id)."), + ( + " • MATCH (src:Entity)-[rel:RELATIONSHIP]->(dst:Entity) to traverse " + "relationships (relationship_type column); `src` aligns with " + "`source_entity_id` and `dst` with `target_entity_id`." + ), + ( + " • Decide which node should be `src` versus `dst` based on the " + "relationship meaning in the question and schema hints." + ), + ( + " • Map natural language roles (team, person, product, etc.) to the " + "`entity_type` column so queries filter to the expected entities." + ), + " • Use WHERE e.column = 'value' for node-level filters.", + ( + " • Filter relationships with WHERE rel.relationship_type = 'VALUE' " + "or by comparing rel.source_entity_id / rel.target_entity_id; when the " + "question does not name a specific relationship type, include every " + "relevant value from the schema summary using OR clauses or " + "WHERE rel.relationship_type IN [...], explicitly note which values " + "you considered, and avoid emitting only a single guessed type." + ), + ( + " • Select columns using the aliases you define, such as e.name or " + "rel.relationship_type." 
+ ), + ( + " • Avoid inventing relationship datasets; match RELATIONSHIP and " + "filter rel.relationship_type instead of [:TYPE]." + ), + ( + "Example: MATCH (part:Entity)-[rel:RELATIONSHIP]->(whole:Entity) " + f"WHERE rel.relationship_type = '{example_rel_type}' " + "RETURN part.name, whole.name." + ), + ( + "Example: MATCH (a:Entity)-[rel:RELATIONSHIP]->(b:Entity) WHERE " + "rel.relationship_type = 'TYPE_A' OR rel.relationship_type = 'TYPE_B' " + "RETURN a.name, b.name." + ), + ( + "Example: MATCH (src:Entity)-[rel:RELATIONSHIP]->(dst:Entity) WHERE " + "rel.relationship_type IN ['TYPE_A', 'TYPE_B', 'TYPE_C'] " + "RETURN src.name, dst.name." + ), + ( + "Example: MATCH (dst:Entity) WHERE dst.name_lower = 'acme corp' " + "RETURN dst.name, dst.entity_id." + ), + ( + f"Do not use relationship patterns like [:{example_rel_type}]; rely on " + "rel.relationship_type filters instead." + ), + ( + "Always emit at least one query when relevant data exists; only " + "return [] when it is impossible to answer." + ), + "Return ONLY a JSON array where each item has `cypher` and `description`.", + ] + if seed_entities: + instruction_lines.append( + "Prefer queries that start from the provided seed entities by referencing " + "their entity_id values before exploring related nodes." + ) + if seed_neighbors: + instruction_lines.extend( + [ + ( + "Use the provided seed neighbor relationships to decide " + "relationship direction." + ), + ( + " • Each neighbor entry includes a `direction` field: 'outgoing' " + "means the seed entity is the relationship source; 'incoming' " + "means the seed entity is the target." + ), + ( + " • Build MATCH patterns accordingly, e.g., outgoing -> " + "(seed)-[rel:RELATIONSHIP]->(neighbor); incoming -> " + "(neighbor)-[rel:RELATIONSHIP]->(seed)." 
+ ), + ] + ) + instructions = "\n".join(instruction_lines) if type_hint_lines: hint_block = "\n".join(f" • {line}" for line in type_hint_lines) @@ -487,9 +795,65 @@ def _build_query_prompt( prompt_parts = [ instructions, f"Schema summary:\n{schema_summary}", - f"Question:\n{question}", - "JSON:", ] + if seed_entities: + seed_lines = [] + for item in seed_entities: + similarity = item.get("similarity") + if isinstance(similarity, (int, float)): + score = f"{similarity:.3f}" + else: + score = "n/a" + display_name = str(item.get("name") or "(unknown)") + seed_lines.append( + ( + f"- {display_name} " + f"(entity_id={item.get('entity_id')}, " + f"entity_type={item.get('entity_type')}, similarity={score})" + ) + ) + prompt_parts.append( + "Seed entities discovered via embedding similarity:\n" + + "\n".join(seed_lines) + ) + if seed_neighbors: + neighbor_lines: list[str] = [] + for entry in seed_neighbors: + direction = str(entry.get("direction") or "outgoing") + seed_name = str( + entry.get("seed_name") or entry.get("seed_entity_id") or "(seed)" + ) + neighbor_name = str( + entry.get("neighbor_name") + or entry.get("neighbor_entity_id") + or "(neighbor)" + ) + rel_type = entry.get("relationship_type") or "RELATIONSHIP" + description = entry.get("relationship_description") or "" + seed_id = entry.get("seed_entity_id") + neighbor_id = entry.get("neighbor_entity_id") + if direction.lower() == "incoming": + arrow = f"{neighbor_name} -[{rel_type}]-> {seed_name}" + else: + arrow = f"{seed_name} -[{rel_type}]-> {neighbor_name}" + line = ( + f"- {arrow} (seed_entity_id={seed_id}, " + f"neighbor_entity_id={neighbor_id}, direction={direction}" + ) + if description: + line += f", description={description}" + line += ")" + neighbor_lines.append(line) + prompt_parts.append( + "Seed neighbor relationships (match patterns to respect direction):\n" + + "\n".join(neighbor_lines) + ) + prompt_parts.extend( + [ + f"Question:\n{question}", + "JSON:", + ] + ) return "\n\n".join(prompt_parts) 
def _resolve_embedding_generator(
    args: argparse.Namespace,
    *,
    options: Optional[Mapping[str, Any]] = None,
) -> EmbeddingGenerator | None:
    """Build the embedding generator selected on the command line.

    Args:
        args: Parsed CLI arguments; reads ``embedding_model`` (optional) and
            ``llm_config``.
        options: Client options to reuse. Only when this is ``None`` are the
            options loaded from ``args.llm_config``; an explicitly passed
            empty mapping is honored as "no options".

    Returns:
        An ``EmbeddingGenerator``, or ``None`` when embeddings are disabled
        (model missing, blank, or the literal ``"none"`` in any case) or the
        generator reports a ``RuntimeError`` while being constructed.
    """
    model = getattr(args, "embedding_model", None)
    model_name = (model or "").strip()
    if not model_name or model_name.lower() == "none":
        return None
    if options is None:
        options = _load_llm_options(args.llm_config)
    try:
        # Copy so the generator never shares the caller's mapping.
        return EmbeddingGenerator(model=model_name, client_options=dict(options))
    except RuntimeError as exc:
        LOGGER.warning("Embeddings disabled: %s", exc)
        return None


def _create_llm_client(
    args: argparse.Namespace,
    *,
    options: Optional[Mapping[str, Any]] = None,
) -> kg_extraction.LLMClient:
    """Create the LLM client configured by the CLI arguments.

    Args:
        args: Parsed CLI arguments; reads ``llm_model``, ``llm_temperature``
            and ``llm_config``.
        options: Pre-loaded client options. Only when this is ``None`` are the
            options loaded from ``args.llm_config``, so a caller that already
            loaded them (even to an empty mapping) does not trigger a re-read.
    """
    if options is None:
        options = _load_llm_options(args.llm_config)
    return kg_extraction.get_llm_client(
        llm_model=args.llm_model,
        llm_temperature=args.llm_temperature,
        llm_options=dict(options),
    )
+ ), + ) parser.add_argument( "--log-level", default="WARNING", @@ -790,7 +1189,13 @@ def main(argv: Optional[Sequence[str]] = None) -> int: if args.extract_and_add: extractor = _resolve_extractor(args) - extract_and_add(args.extract_and_add, service, extractor) + embedding_generator = _resolve_embedding_generator(args) + extract_and_add( + args.extract_and_add, + service, + extractor, + embedding_generator=embedding_generator, + ) return 0 if args.ask: ask_question(args.ask, service, args) diff --git a/python/uv.lock b/python/uv.lock index d2a07722..1c0fd833 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -316,6 +316,7 @@ tests = [ { name = "pandas" }, { name = "pyarrow" }, { name = "pytest" }, + { name = "ruff" }, ] [package.metadata] @@ -331,6 +332,7 @@ requires-dist = [ { name = "pytest", marker = "extra == 'tests'" }, { name = "pyyaml", specifier = ">=6.0" }, { name = "ruff", marker = "extra == 'dev'" }, + { name = "ruff", marker = "extra == 'tests'" }, { name = "uvicorn", specifier = ">=0.24.0" }, ] provides-extras = ["tests", "dev", "llm", "lance-storage"]