Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 26 additions & 13 deletions api/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,14 +448,15 @@ def get_function_by_name(self, name: str) -> Optional[Node]:

return res[0][0]

def prefix_search(self, prefix: str) -> str:
def prefix_search(self, prefix: str, limit: int = 10) -> str:
"""
Search for entities by prefix using a full-text search on the graph.
The search is limited to 10 nodes. Each node's name and labels are retrieved,
and the results are sorted based on their labels.
The number of results is bounded by ``limit`` (default 10). Each node's
name and labels are retrieved, and the results are sorted based on their labels.
Args:
prefix (str): The prefix string to search for in the graph database.
limit (int): Maximum number of nodes to return (default 10).
Returns:
str: A list of entity names and corresponding labels, sorted by label.
Expand All @@ -465,19 +466,19 @@ def prefix_search(self, prefix: str) -> str:
# Append a wildcard '*' to the prefix for full-text search.
search_prefix = f"{prefix}*"

# Cypher query to perform full-text search and limit the result to 10 nodes.
# Cypher query to perform full-text search, bounding the result at $limit.
# The 'CALL db.idx.fulltext.queryNodes' method searches for nodes labeled 'Searchable'
# that match the given prefix, collects the nodes, and returns the result.
query = """
CALL db.idx.fulltext.queryNodes('Searchable', $prefix)
YIELD node
WITH node
RETURN node
LIMIT 10
LIMIT $limit
"""

# Execute the query using the provided graph database connection.
result_set = self._query(query, {'prefix': search_prefix}).result_set
result_set = self._query(query, {'prefix': search_prefix, 'limit': int(limit)}).result_set

completions = [encode_node(row[0]) for row in result_set]

Expand Down Expand Up @@ -658,13 +659,16 @@ def rerun_query(self, q: str, params: dict) -> QueryResult:

return self._query(q, params)

def find_paths(self, src: int, dest: int) -> list[Path]:
def find_paths(self, src: int, dest: int, limit: Optional[int] = None) -> list[Path]:
"""
Find all paths between the source (src) and destination (dest) nodes.
Args:
src (int): The ID of the source node.
dest (int): The ID of the destination node.
limit (Optional[int]): When provided, bound the number of paths
enumerated by the database with a Cypher ``LIMIT``. When ``None``
(default) all paths are returned (legacy behavior).
Returns:
List[Optional[Path]]: A list of paths found between the src and dest nodes.
Expand All @@ -682,8 +686,13 @@ def find_paths(self, src: int, dest: int) -> list[Path]:
RETURN p
"""

params = {'src_id': src, 'dest_id': dest}
if limit is not None:
q += " LIMIT $limit\n"
params['limit'] = int(limit)

# Perform the query with the source and destination node IDs.
result_set = self._query(q, {'src_id': src, 'dest_id': dest}).result_set
result_set = self._query(q, params).result_set

paths = []

Expand Down Expand Up @@ -861,26 +870,30 @@ async def get_neighbors(self, node_ids: list[int], rel: Optional[str] = None, lb
logging.error(f"Error fetching neighbors for node {node_ids}: {e}")
return {'nodes': [], 'edges': []}

async def prefix_search(self, prefix: str) -> list:
async def prefix_search(self, prefix: str, limit: int = 10) -> list:
search_prefix = f"{prefix}*"
query = """
CALL db.idx.fulltext.queryNodes('Searchable', $prefix)
YIELD node
WITH node
RETURN node
LIMIT 10
LIMIT $limit
"""
result_set = (await self._query(query, {'prefix': search_prefix})).result_set
result_set = (await self._query(query, {'prefix': search_prefix, 'limit': int(limit)})).result_set
return [encode_node(row[0]) for row in result_set]

async def find_paths(self, src: int, dest: int) -> list:
async def find_paths(self, src: int, dest: int, limit: Optional[int] = None) -> list:
q = """MATCH (src), (dest)
WHERE ID(src) = $src_id AND ID(dest) = $dest_id
WITH src, dest
MATCH p = (src)-[:CALLS*]->(dest)
RETURN p
"""
result_set = (await self._query(q, {'src_id': src, 'dest_id': dest})).result_set
params = {'src_id': src, 'dest_id': dest}
if limit is not None:
q += " LIMIT $limit\n"
params['limit'] = int(limit)
result_set = (await self._query(q, params)).result_set
paths = []
for row in result_set:
path = []
Expand Down
275 changes: 275 additions & 0 deletions api/mcp/tools/structural.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import asyncio
import logging
import os
import re
from pathlib import Path
from typing import Any, Optional

Expand Down Expand Up @@ -208,3 +209,277 @@ def _payload(project) -> dict[str, Any]:
}

return await loop.run_in_executor(None, _do_index)


# ---------------------------------------------------------------------------
# T5 — get_callers / get_callees / get_dependencies
# ---------------------------------------------------------------------------


def _project_arg(project: str, branch: Optional[str]):
    """Build the :class:`AsyncGraphQuery` handle for ``(project, branch)``.

    Args:
        project: Project name, passed positionally to ``AsyncGraphQuery``.
        branch: Branch name (or ``None``), passed as the ``branch`` keyword.

    Returns:
        A newly constructed ``AsyncGraphQuery``. The tools in this module
        ``await handle.close()`` on it once they are done.
    """
    # Resolved at call time: the import lives in the function body, not at
    # module level.
    from api.graph import AsyncGraphQuery as _GraphQuery

    handle = _GraphQuery(project, branch=branch)
    return handle


def _node_summary(n: Any) -> dict[str, Any]:
"""Normalize a FalkorDB Node (or already-encoded dict) to a flat payload.

``encode_node`` returns ``{id, labels, properties: {...}}`` because Node
properties live on a nested attribute. Agents want a flat record, and
they also want a single ``label`` (the meaningful one — File, Class,
Function — not the fulltext-index marker ``Searchable``).
"""
if hasattr(n, "properties"):
props = dict(n.properties or {})
labels = list(n.labels or [])
node_id = getattr(n, "id", None)
else:
d = dict(n)
props = dict(d.get("properties") or {})
labels = list(d.get("labels") or [])
node_id = d.get("id")

label = next((lbl for lbl in labels if lbl != "Searchable"), None)
return {
"id": node_id,
"name": props.get("name"),
"label": label,
"file": props.get("path"),
"line": props.get("src_start"),
}


# Relationship-type names are graph labels (SCREAMING_SNAKE_CASE, e.g. CALLS,
# IMPORTS, DEFINES). FalkorDB cannot parameterize relationship types, so any
# ``rel`` interpolated into Cypher must be validated against this pattern to
# prevent Cypher injection via agent-controlled input.
_REL_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")


def _validate_relation(rel: str) -> str:
"""Return ``rel`` if it is a safe relationship-type name, else raise.

Guards the relationship types that are string-interpolated into Cypher
(``-[e:{rel}]->``) — parameter binding is not available for relation
types in FalkorDB.
"""
if not isinstance(rel, str) or not _REL_NAME_RE.match(rel):
raise ValueError(f"invalid relation type: {rel!r}")
return rel


def _coerce_node_id(symbol_id: Any) -> int:
"""Accept int or stringified int; raise ValueError otherwise.

The MCP wire format is JSON; agents sometimes hand back the id as a
string. Be permissive on input, strict on type after parsing.
"""
if isinstance(symbol_id, bool): # bool is an int subclass; reject loudly
raise ValueError(f"symbol_id must be an integer, got bool: {symbol_id!r}")
if isinstance(symbol_id, int):
return symbol_id
if isinstance(symbol_id, str) and symbol_id.lstrip("-").isdigit():
return int(symbol_id)
raise ValueError(f"symbol_id must be an integer id, got: {symbol_id!r}")


async def _neighbors_payload(
    project: str,
    branch: Optional[str],
    symbol_id: Any,
    rel: str,
    direction: str,
    limit: int,
) -> list[dict[str, Any]]:
    """Shared implementation for caller/callee/dependency tools.

    ``direction`` is ``IN`` (incoming edges, e.g. callers) or ``OUT``
    (outgoing edges, e.g. callees). When ``IN`` we run the inverse Cypher
    ``(neighbor)-[:rel]->(target)``; ``AsyncGraphQuery.get_neighbors`` only
    walks outgoing edges, so we inline the Cypher here for symmetry.

    Args:
        project: Project name used to open the graph handle.
        branch: Optional branch name used to open the graph handle.
        symbol_id: Node id as an int or digit string (see ``_coerce_node_id``).
        rel: Relationship-type name; validated before being interpolated
            into the query because it cannot be bound as a parameter.
        direction: ``"IN"`` or ``"OUT"``.
        limit: Upper bound on returned rows, bound as the Cypher ``$limit``.

    Returns:
        One flat node summary per row, each extended with ``relation`` (the
        edge type reported by the query) and ``direction``.

    Raises:
        ValueError: on an invalid ``symbol_id``, ``rel`` or ``direction``.
    """
    node_id = _coerce_node_id(symbol_id)
    rel = _validate_relation(rel)

    # Validate every argument and build the query BEFORE opening a graph
    # handle, so a bad ``direction``/``limit`` never costs a connection.
    if direction == "OUT":
        pattern = f"(n)-[e:{rel}]->(dest)"
        returned = "dest"
    elif direction == "IN":
        # The neighbor is aliased to ``dest`` so both directions yield the
        # same column layout.
        pattern = f"(src)-[e:{rel}]->(n)"
        returned = "src AS dest"
    else:
        raise ValueError(f"direction must be IN or OUT, got: {direction!r}")

    q = (
        f"MATCH {pattern} "
        f"WHERE ID(n) = $sid "
        f"RETURN {returned}, type(e) AS rel "
        f"LIMIT $limit"
    )
    params = {"sid": node_id, "limit": int(limit)}

    g = _project_arg(project, branch)
    try:
        res = await g._query(q, params)
        out: list[dict[str, Any]] = []
        for row in res.result_set:
            # row[0] is the neighbor node, row[1] the edge type.
            entry = _node_summary(row[0])
            entry["relation"] = row[1]
            entry["direction"] = direction
            out.append(entry)
        return out
    finally:
        # Always release the handle, including when the query raises.
        await g.close()


@app.tool(
    name="get_callers",
    description=(
        "Return functions that call the given symbol (incoming CALLS edges). "
        "`symbol_id` is the integer node id returned by `search_code` or "
        "other tools."
    ),
)
async def get_callers(
    symbol_id: int | str,
    project: str,
    branch: Optional[str] = None,
    limit: int = 50,
) -> list[dict[str, Any]]:
    """Return the nodes with a ``CALLS`` edge pointing at ``symbol_id``.

    Args:
        symbol_id: Target node id (int or digit string).
        project: Project whose graph is queried.
        branch: Optional branch of the project.
        limit: Maximum number of callers to return (default 50).

    Returns:
        Flat node summaries, each tagged with ``relation`` and
        ``direction == "IN"``.

    Raises:
        ValueError: if ``symbol_id`` is not an integer id.
    """
    return await _neighbors_payload(project, branch, symbol_id, "CALLS", "IN", limit)


@app.tool(
    name="get_callees",
    description=(
        "Return functions that the given symbol calls (outgoing CALLS edges)."
    ),
)
async def get_callees(
    symbol_id: int | str,
    project: str,
    branch: Optional[str] = None,
    limit: int = 50,
) -> list[dict[str, Any]]:
    """Return the nodes reached from ``symbol_id`` over a ``CALLS`` edge.

    Args:
        symbol_id: Source node id (int or digit string).
        project: Project whose graph is queried.
        branch: Optional branch of the project.
        limit: Maximum number of callees to return (default 50).

    Returns:
        Flat node summaries, each tagged with ``relation`` and
        ``direction == "OUT"``.

    Raises:
        ValueError: if ``symbol_id`` is not an integer id.
    """
    return await _neighbors_payload(project, branch, symbol_id, "CALLS", "OUT", limit)


@app.tool(
    name="get_dependencies",
    description=(
        "Return outgoing neighbors of the given symbol across any of the "
        "specified relation types (default: IMPORTS, CALLS, DEFINES). "
        "Useful for 'what does this depend on' queries."
    ),
)
async def get_dependencies(
    symbol_id: int | str,
    project: str,
    branch: Optional[str] = None,
    rels: Optional[list[str]] = None,
    limit: int = 50,
) -> list[dict[str, Any]]:
    """Collect outgoing neighbors of ``symbol_id`` across several relations.

    Relations are queried one at a time, in the order given, until ``limit``
    entries have been gathered.

    Args:
        symbol_id: Source node id (int or digit string).
        project: Project whose graph is queried.
        branch: Optional branch of the project.
        rels: Relationship-type names to follow. ``None`` means
            ``["IMPORTS", "CALLS", "DEFINES"]``.
        limit: Maximum total number of entries to return (default 50). A
            value ``<= 0`` yields an empty list without querying.

    Returns:
        Flat node summaries in query order, each tagged with ``relation`` and
        ``direction == "OUT"``, containing each ``(id, relation)`` pair at
        most once.

    Raises:
        ValueError: if ``symbol_id`` or a queried relation name is invalid.
    """
    if rels is None:
        rels = ["IMPORTS", "CALLS", "DEFINES"]
    # Aggregate across relations; preserve ordering and dedupe by
    # (id, relation) — the same node reached through two different relation
    # types is reported once per relation.
    seen: set[Any] = set()
    out: list[dict[str, Any]] = []
    for rel in rels:
        # Only fetch the rows we can still accept, so total DB work is
        # bounded by ``limit`` rather than ``limit * len(rels)``.
        remaining = limit - len(out)
        if remaining <= 0:
            break
        rows = await _neighbors_payload(
            project, branch, symbol_id, rel, "OUT", remaining
        )
        for row in rows:
            key = (row.get("id"), row.get("relation"))
            if key in seen:
                continue
            seen.add(key)
            out.append(row)
            if len(out) >= limit:
                return out
    return out


# ---------------------------------------------------------------------------
# T7 — find_path
# ---------------------------------------------------------------------------


@app.tool(
    name="find_path",
    description=(
        "Return up to `max_paths` CALLS-path sequences from `source_id` to "
        "`dest_id`. Useful for 'how does A reach B' questions. Returns an "
        "empty list when no path exists."
    ),
)
async def find_path(
    source_id: int | str,
    dest_id: int | str,
    project: str,
    branch: Optional[str] = None,
    max_paths: int = 10,
) -> list[dict[str, Any]]:
    """Return node sequences for ``CALLS`` paths from ``source_id`` to ``dest_id``.

    Args:
        source_id: Start node id (int or digit string).
        dest_id: End node id (int or digit string).
        project: Project whose graph is queried.
        branch: Optional branch of the project.
        max_paths: Passed to ``find_paths`` as ``limit`` to bound the number
            of paths enumerated (default 10).

    Returns:
        One ``{"path": [node_summary, ...]}`` dict per path, edges omitted.

    Raises:
        ValueError: if either id is not an integer id.
    """
    # Validate both ids before opening a graph handle.
    src = _coerce_node_id(source_id)
    dst = _coerce_node_id(dest_id)
    g = _project_arg(project, branch)
    try:
        # Bound DB work by ``max_paths`` so large graphs don't enumerate an
        # unbounded number of paths before we slice in Python.
        raw = await g.find_paths(src, dst, limit=max_paths)
    finally:
        await g.close()

    # ``AsyncGraphQuery.find_paths`` returns each path as an alternating
    # [node, edge, node, edge, ..., node] list; we strip edges and surface
    # only the node sequence — that's what agents typically want.
    paths: list[dict[str, Any]] = []
    for entry in raw:
        node_seq = [
            _node_summary(x)
            for x in entry
            # Discriminate on ``labels``: ``encode_node`` emits a top-level
            # ``labels`` key, while ``encode_edge`` does not (edges carry
            # ``relation``/``src_node``/``dest_node`` instead). Filtering on
            # ``properties`` would be wrong because FalkorDB's Edge also has a
            # ``properties`` attribute, so edges would slip through as bogus
            # all-null node entries.
            if isinstance(x, dict) and "labels" in x
        ]
        paths.append({"path": node_seq})
    return paths


# ---------------------------------------------------------------------------
# T8 — search_code
# ---------------------------------------------------------------------------


@app.tool(
    name="search_code",
    description=(
        "Prefix-search for symbols (functions, classes, files) whose name "
        "starts with `prefix`. Backed by FalkorDB's full-text index. The "
        "agent typically calls this first to discover symbol ids for the "
        "navigation tools (`get_callers`, `find_path`, ...)."
    ),
)
async def search_code(
    prefix: str,
    project: str,
    branch: Optional[str] = None,
    limit: int = 20,
) -> list[dict[str, Any]]:
    """Look up symbols by name prefix and return them as flat summaries.

    Args:
        prefix: Leading characters of the symbol name to search for.
        project: Project whose graph is queried.
        branch: Optional branch of the project.
        limit: Maximum number of matches, forwarded to ``prefix_search``
            (default 20).

    Returns:
        One flat node summary (``id``, ``name``, ``label``, ``file``,
        ``line``) per match, in the order ``prefix_search`` returned them.
    """
    graph = _project_arg(project, branch)
    try:
        # Forward ``limit`` so the query itself is bounded by the caller's
        # value instead of ``prefix_search``'s own default.
        matches = await graph.prefix_search(prefix, limit=limit)
    finally:
        # Release the handle whether or not the search succeeded.
        await graph.close()

    summaries = []
    for match in matches:
        summaries.append(_node_summary(match))
    return summaries
Loading
Loading