Skip to content

Commit 4135def

Browse files
authored
Merge pull request #679 from FalkorDB/dvirdukhan/mcp-t5-t7-t8-query-tools
feat(mcp): query tools — get_callers/callees/deps, find_path, search_code (T5/T7/T8)
2 parents cabe9a1 + fbbb366 commit 4135def

3 files changed

Lines changed: 574 additions & 13 deletions

File tree

api/graph.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -448,14 +448,15 @@ def get_function_by_name(self, name: str) -> Optional[Node]:
448448

449449
return res[0][0]
450450

451-
def prefix_search(self, prefix: str) -> str:
451+
def prefix_search(self, prefix: str, limit: int = 10) -> str:
452452
"""
453453
Search for entities by prefix using a full-text search on the graph.
454-
The search is limited to 10 nodes. Each node's name and labels are retrieved,
455-
and the results are sorted based on their labels.
454+
The number of results is bounded by ``limit`` (default 10). Each node's
455+
name and labels are retrieved, and the results are sorted based on their labels.
456456
457457
Args:
458458
prefix (str): The prefix string to search for in the graph database.
459+
limit (int): Maximum number of nodes to return (default 10).
459460
460461
Returns:
461462
str: A list of entity names and corresponding labels, sorted by label.
@@ -465,19 +466,19 @@ def prefix_search(self, prefix: str) -> str:
465466
# Append a wildcard '*' to the prefix for full-text search.
466467
search_prefix = f"{prefix}*"
467468

468-
# Cypher query to perform full-text search and limit the result to 10 nodes.
469+
# Cypher query to perform full-text search, bounding the result at $limit.
469470
# The 'CALL db.idx.fulltext.queryNodes' method searches for nodes labeled 'Searchable'
470471
# that match the given prefix, collects the nodes, and returns the result.
471472
query = """
472473
CALL db.idx.fulltext.queryNodes('Searchable', $prefix)
473474
YIELD node
474475
WITH node
475476
RETURN node
476-
LIMIT 10
477+
LIMIT $limit
477478
"""
478479

479480
# Execute the query using the provided graph database connection.
480-
result_set = self._query(query, {'prefix': search_prefix}).result_set
481+
result_set = self._query(query, {'prefix': search_prefix, 'limit': int(limit)}).result_set
481482

482483
completions = [encode_node(row[0]) for row in result_set]
483484

@@ -658,13 +659,16 @@ def rerun_query(self, q: str, params: dict) -> QueryResult:
658659

659660
return self._query(q, params)
660661

661-
def find_paths(self, src: int, dest: int) -> list[Path]:
662+
def find_paths(self, src: int, dest: int, limit: Optional[int] = None) -> list[Path]:
662663
"""
663664
Find all paths between the source (src) and destination (dest) nodes.
664665
665666
Args:
666667
src (int): The ID of the source node.
667668
dest (int): The ID of the destination node.
669+
limit (Optional[int]): When provided, bound the number of paths
670+
enumerated by the database with a Cypher ``LIMIT``. When ``None``
671+
(default) all paths are returned (legacy behavior).
668672
669673
Returns:
670674
List[Optional[Path]]: A list of paths found between the src and dest nodes.
@@ -682,8 +686,13 @@ def find_paths(self, src: int, dest: int) -> list[Path]:
682686
RETURN p
683687
"""
684688

689+
params = {'src_id': src, 'dest_id': dest}
690+
if limit is not None:
691+
q += " LIMIT $limit\n"
692+
params['limit'] = int(limit)
693+
685694
# Perform the query with the source and destination node IDs.
686-
result_set = self._query(q, {'src_id': src, 'dest_id': dest}).result_set
695+
result_set = self._query(q, params).result_set
687696

688697
paths = []
689698

@@ -861,26 +870,30 @@ async def get_neighbors(self, node_ids: list[int], rel: Optional[str] = None, lb
861870
logging.error(f"Error fetching neighbors for node {node_ids}: {e}")
862871
return {'nodes': [], 'edges': []}
863872

864-
async def prefix_search(self, prefix: str) -> list:
873+
async def prefix_search(self, prefix: str, limit: int = 10) -> list:
865874
search_prefix = f"{prefix}*"
866875
query = """
867876
CALL db.idx.fulltext.queryNodes('Searchable', $prefix)
868877
YIELD node
869878
WITH node
870879
RETURN node
871-
LIMIT 10
880+
LIMIT $limit
872881
"""
873-
result_set = (await self._query(query, {'prefix': search_prefix})).result_set
882+
result_set = (await self._query(query, {'prefix': search_prefix, 'limit': int(limit)})).result_set
874883
return [encode_node(row[0]) for row in result_set]
875884

876-
async def find_paths(self, src: int, dest: int) -> list:
885+
async def find_paths(self, src: int, dest: int, limit: Optional[int] = None) -> list:
877886
q = """MATCH (src), (dest)
878887
WHERE ID(src) = $src_id AND ID(dest) = $dest_id
879888
WITH src, dest
880889
MATCH p = (src)-[:CALLS*]->(dest)
881890
RETURN p
882891
"""
883-
result_set = (await self._query(q, {'src_id': src, 'dest_id': dest})).result_set
892+
params = {'src_id': src, 'dest_id': dest}
893+
if limit is not None:
894+
q += " LIMIT $limit\n"
895+
params['limit'] = int(limit)
896+
result_set = (await self._query(q, params)).result_set
884897
paths = []
885898
for row in result_set:
886899
path = []

api/mcp/tools/structural.py

Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import asyncio
2121
import logging
2222
import os
23+
import re
2324
from pathlib import Path
2425
from typing import Any, Optional
2526

@@ -208,3 +209,277 @@ def _payload(project) -> dict[str, Any]:
208209
}
209210

210211
return await loop.run_in_executor(None, _do_index)
212+
213+
214+
# ---------------------------------------------------------------------------
215+
# T5 — get_callers / get_callees / get_dependencies
216+
# ---------------------------------------------------------------------------
217+
218+
219+
def _project_arg(project: str, branch: Optional[str]):
    """Build the async graph-query handle for a ``(project, branch)`` pair.

    The ``api.graph`` import is done inside the function so it is only
    resolved when a tool actually needs a graph handle.
    """
    from api.graph import AsyncGraphQuery

    handle = AsyncGraphQuery(project, branch=branch)
    return handle
224+
225+
226+
def _node_summary(n: Any) -> dict[str, Any]:
    """Flatten a graph node into the compact record handed back to agents.

    ``n`` is either an object exposing ``properties`` / ``labels`` / ``id``
    attributes, or something ``dict()`` can consume that is shaped like
    ``{id, labels, properties: {...}}``.

    Returns:
        A single-level dict with the keys ``id``, ``name``, ``label``,
        ``file`` and ``line``. ``label`` is the first label that is not the
        ``Searchable`` marker (``None`` if there is no such label); ``file``
        and ``line`` come from the ``path`` and ``src_start`` properties.
    """
    if hasattr(n, "properties"):
        # Attribute-style node object.
        properties = dict(n.properties or {})
        all_labels = list(n.labels or [])
        identifier = getattr(n, "id", None)
    else:
        # Mapping-style (already encoded) node.
        encoded = dict(n)
        properties = dict(encoded.get("properties") or {})
        all_labels = list(encoded.get("labels") or [])
        identifier = encoded.get("id")

    # Pick the first label that carries meaning; skip the index marker.
    primary_label = None
    for candidate in all_labels:
        if candidate != "Searchable":
            primary_label = candidate
            break

    summary: dict[str, Any] = {"id": identifier}
    summary["name"] = properties.get("name")
    summary["label"] = primary_label
    summary["file"] = properties.get("path")
    summary["line"] = properties.get("src_start")
    return summary
252+
253+
254+
# Relationship-type names are graph labels (SCREAMING_SNAKE_CASE, e.g. CALLS,
# IMPORTS, DEFINES). FalkorDB cannot parameterize relationship types, so any
# ``rel`` interpolated into Cypher must be validated against this pattern to
# prevent Cypher injection via agent-controlled input.
_REL_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")


def _validate_relation(rel: str) -> str:
    """Return ``rel`` if it is a safe relationship-type name, else raise.

    Guards the relationship types that are string-interpolated into Cypher
    (``-[e:{rel}]->``) because they cannot be passed as bound parameters.

    Args:
        rel: Candidate relationship-type name.

    Returns:
        ``rel`` unchanged when it is a plain identifier.

    Raises:
        ValueError: If ``rel`` is not a ``str`` or contains anything other
            than ``[A-Za-z0-9_]`` with a non-digit first character.
    """
    # ``fullmatch`` rather than ``match``: with ``match`` the trailing ``$``
    # also succeeds just before a final newline, so ``"CALLS\n"`` would be
    # accepted and the newline interpolated into the query text.
    if not isinstance(rel, str) or not _REL_NAME_RE.fullmatch(rel):
        raise ValueError(f"invalid relation type: {rel!r}")
    return rel
271+
272+
273+
def _coerce_node_id(symbol_id: Any) -> int:
    """Accept an int or a stringified int and return it as ``int``.

    The MCP wire format is JSON; agents sometimes hand back the id as a
    string. Be permissive on input, strict on type after parsing.

    Args:
        symbol_id: An ``int``, or a ``str`` made of decimal digits with at
            most one leading ``-``.

    Returns:
        The id as an ``int``.

    Raises:
        ValueError: For ``bool`` values, malformed strings (``""``, ``"-"``,
            ``"--5"``, ``"1.5"``, ``"²"``) and any other type.
    """
    if isinstance(symbol_id, bool):  # bool is an int subclass; reject loudly
        raise ValueError(f"symbol_id must be an integer, got bool: {symbol_id!r}")
    if isinstance(symbol_id, int):
        return symbol_id
    if isinstance(symbol_id, str):
        # Strip at most ONE sign: ``lstrip("-")`` would let "--5" through
        # to ``int()``, which then fails with its own, unrelated message.
        digits = symbol_id[1:] if symbol_id.startswith("-") else symbol_id
        # ``isdecimal`` (not ``isdigit``): characters such as "²" count as
        # digits but are not accepted by ``int()``.
        if digits.isdecimal():
            return int(symbol_id)
    raise ValueError(f"symbol_id must be an integer id, got: {symbol_id!r}")
286+
287+
288+
async def _neighbors_payload(
    project: str,
    branch: Optional[str],
    symbol_id: Any,
    rel: str,
    direction: str,
    limit: int,
) -> list[dict[str, Any]]:
    """Fetch one-hop neighbors of a symbol along a single relation type.

    Common backend for the caller / callee / dependency tools.

    Args:
        project: Project name used to open the graph handle.
        branch: Optional branch passed through to the graph handle.
        symbol_id: Node id (int or stringified int) of the anchor symbol.
        rel: Relationship type; validated before being placed in the query.
        direction: ``"OUT"`` follows edges leaving the symbol, ``"IN"``
            follows edges arriving at it.
        limit: Maximum number of rows requested from the database.

    Returns:
        One flat node summary per row, each extended with ``relation`` and
        ``direction`` keys.

    Raises:
        ValueError: If ``symbol_id``, ``rel`` or ``direction`` is invalid.
    """
    anchor_id = _coerce_node_id(symbol_id)
    rel = _validate_relation(rel)
    g = _project_arg(project, branch)
    try:
        # ``n`` is always the anchor; only the edge orientation and the
        # returned endpoint differ between the two directions.
        if direction == "OUT":
            pattern = f"(n)-[e:{rel}]->(dest)"
            returned = "dest"
        elif direction == "IN":
            pattern = f"(src)-[e:{rel}]->(n)"
            returned = "src AS dest"
        else:
            raise ValueError(f"direction must be IN or OUT, got: {direction!r}")

        cypher = (
            f"MATCH {pattern} "
            "WHERE ID(n) = $sid "
            f"RETURN {returned}, type(e) AS rel "
            "LIMIT $limit"
        )
        response = await g._query(cypher, {"sid": anchor_id, "limit": int(limit)})
        return [
            {**_node_summary(row[0]), "relation": row[1], "direction": direction}
            for row in response.result_set
        ]
    finally:
        # Always release the graph handle, including on validation errors.
        await g.close()
334+
335+
336+
@app.tool(
    name="get_callers",
    description=(
        "Return functions that call the given symbol (incoming CALLS edges). "
        "`symbol_id` is the integer node id returned by `search_code` or "
        "other tools."
    ),
)
async def get_callers(
    symbol_id: int | str,
    project: str,
    branch: Optional[str] = None,
    limit: int = 50,
) -> list[dict[str, Any]]:
    # Callers are the sources of CALLS edges that end at this symbol, so the
    # shared neighbor lookup is run in the incoming ("IN") direction.
    callers = await _neighbors_payload(
        project=project,
        branch=branch,
        symbol_id=symbol_id,
        rel="CALLS",
        direction="IN",
        limit=limit,
    )
    return callers
351+
352+
353+
@app.tool(
    name="get_callees",
    description=(
        "Return functions that the given symbol calls (outgoing CALLS edges)."
    ),
)
async def get_callees(
    symbol_id: int | str,
    project: str,
    branch: Optional[str] = None,
    limit: int = 50,
) -> list[dict[str, Any]]:
    # Callees are the targets of CALLS edges that start at this symbol, so
    # the shared neighbor lookup is run in the outgoing ("OUT") direction.
    callees = await _neighbors_payload(
        project=project,
        branch=branch,
        symbol_id=symbol_id,
        rel="CALLS",
        direction="OUT",
        limit=limit,
    )
    return callees
366+
367+
368+
@app.tool(
    name="get_dependencies",
    description=(
        "Return outgoing neighbors of the given symbol across any of the "
        "specified relation types (default: IMPORTS, CALLS, DEFINES). "
        "Useful for 'what does this depend on' queries."
    ),
)
async def get_dependencies(
    symbol_id: int | str,
    project: str,
    branch: Optional[str] = None,
    rels: Optional[list[str]] = None,
    limit: int = 50,
) -> list[dict[str, Any]]:
    relation_types = ["IMPORTS", "CALLS", "DEFINES"] if rels is None else rels

    # Results keep the order in which relations are listed. A row is dropped
    # only when the same (node id, relation) pair was already emitted, so a
    # node reachable through two different relations appears once per
    # relation.
    emitted: set[Any] = set()
    collected: list[dict[str, Any]] = []
    for relation_type in relation_types:
        # Ask the DB only for as many rows as can still be accepted, which
        # keeps total work bounded by ``limit`` instead of
        # ``limit * len(rels)``.
        budget = limit - len(collected)
        if budget <= 0:
            break
        candidates = await _neighbors_payload(
            project, branch, symbol_id, relation_type, "OUT", budget
        )
        for candidate in candidates:
            dedupe_key = (candidate.get("id"), candidate.get("relation"))
            if dedupe_key not in emitted:
                emitted.add(dedupe_key)
                collected.append(candidate)
                if len(collected) >= limit:
                    return collected
    return collected
406+
407+
408+
# ---------------------------------------------------------------------------
409+
# T7 — find_path
410+
# ---------------------------------------------------------------------------
411+
412+
413+
@app.tool(
    name="find_path",
    description=(
        "Return up to `max_paths` CALLS-path sequences from `source_id` to "
        "`dest_id`. Useful for 'how does A reach B' questions. Returns an "
        "empty list when no path exists."
    ),
)
async def find_path(
    source_id: int | str,
    dest_id: int | str,
    project: str,
    branch: Optional[str] = None,
    max_paths: int = 10,
) -> list[dict[str, Any]]:
    start_node = _coerce_node_id(source_id)
    end_node = _coerce_node_id(dest_id)

    g = _project_arg(project, branch)
    try:
        # ``max_paths`` is handed to the graph layer as its ``limit`` so the
        # cap is applied by the query rather than by slicing afterwards.
        found = await g.find_paths(start_node, end_node, limit=max_paths)
    finally:
        await g.close()

    def _is_encoded_node(element: Any) -> bool:
        # Path entries mix node and edge records. Only dicts carrying a
        # top-level ``labels`` key are treated as nodes; per the original
        # author's note, edge records do not have that key, whereas a test
        # on ``properties`` would let edges through as all-null entries.
        return isinstance(element, dict) and "labels" in element

    results: list[dict[str, Any]] = []
    for sequence in found:
        # Keep the node order of the path and drop everything else.
        nodes_only = [_node_summary(el) for el in sequence if _is_encoded_node(el)]
        results.append({"path": nodes_only})
    return results
456+
457+
458+
# ---------------------------------------------------------------------------
459+
# T8 — search_code
460+
# ---------------------------------------------------------------------------
461+
462+
463+
@app.tool(
    name="search_code",
    description=(
        "Prefix-search for symbols (functions, classes, files) whose name "
        "starts with `prefix`. Backed by FalkorDB's full-text index. The "
        "agent typically calls this first to discover symbol ids for the "
        "navigation tools (`get_callers`, `find_path`, ...)."
    ),
)
async def search_code(
    prefix: str,
    project: str,
    branch: Optional[str] = None,
    limit: int = 20,
) -> list[dict[str, Any]]:
    g = _project_arg(project, branch)
    try:
        # Forward ``limit`` explicitly; without it the graph layer's own
        # default would apply instead of the caller's value.
        matches = await g.prefix_search(prefix, limit=limit)
    finally:
        # Release the graph handle whether or not the search succeeded.
        await g.close()
    # Flatten each hit into the compact record the navigation tools expect.
    return list(map(_node_summary, matches))

0 commit comments

Comments
 (0)