|
20 | 20 | import asyncio |
21 | 21 | import logging |
22 | 22 | import os |
| 23 | +import re |
23 | 24 | from pathlib import Path |
24 | 25 | from typing import Any, Optional |
25 | 26 |
|
@@ -208,3 +209,277 @@ def _payload(project) -> dict[str, Any]: |
208 | 209 | } |
209 | 210 |
|
210 | 211 | return await loop.run_in_executor(None, _do_index) |
| 212 | + |
| 213 | + |
| 214 | +# --------------------------------------------------------------------------- |
| 215 | +# T5 — get_callers / get_callees / get_dependencies |
| 216 | +# --------------------------------------------------------------------------- |
| 217 | + |
| 218 | + |
def _project_arg(project: str, branch: Optional[str]):
    """Build the :class:`AsyncGraphQuery` handle for ``(project, branch)``.

    ``api.graph`` is imported at call time rather than at module import
    time. The caller owns the returned handle.
    """
    from api.graph import AsyncGraphQuery

    handle = AsyncGraphQuery(project, branch=branch)
    return handle
| 224 | + |
| 225 | + |
def _node_summary(n: Any) -> dict[str, Any]:
    """Flatten a graph node into the record shape handed back to agents.

    ``n`` is either an object exposing ``properties`` / ``labels`` / ``id``
    attributes (a FalkorDB Node) or a mapping shaped like the output of
    ``encode_node``: ``{id, labels, properties: {...}}``.

    Returns a dict with the keys ``id``, ``name``, ``label``, ``file`` and
    ``line``. ``label`` is the first label that is not the fulltext-index
    marker ``Searchable`` (``None`` when there is no such label); ``file``
    and ``line`` come from the ``path`` and ``src_start`` properties.
    Missing values are reported as ``None``.
    """
    if not hasattr(n, "properties"):
        # Already-encoded form: everything lives under dictionary keys.
        encoded = dict(n)
        attrs = dict(encoded.get("properties") or {})
        tags = list(encoded.get("labels") or [])
        ident = encoded.get("id")
    else:
        # Node object: properties and labels are attributes and may be None.
        attrs = dict(n.properties or {})
        tags = list(n.labels or [])
        ident = getattr(n, "id", None)

    # Pick the first meaningful label, skipping the index marker.
    primary = None
    for tag in tags:
        if tag != "Searchable":
            primary = tag
            break

    return dict(
        id=ident,
        name=attrs.get("name"),
        label=primary,
        file=attrs.get("path"),
        line=attrs.get("src_start"),
    )
| 252 | + |
| 253 | + |
# Relationship-type names are graph labels (SCREAMING_SNAKE_CASE, e.g. CALLS,
# IMPORTS, DEFINES). FalkorDB cannot parameterize relationship types, so any
# ``rel`` interpolated into Cypher must be validated against this pattern to
# prevent Cypher injection via agent-controlled input.
#
# The pattern ends with ``\Z`` rather than ``$``: in Python ``$`` also matches
# just before a trailing newline, so ``"CALLS\n"`` would otherwise be accepted.
_REL_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*\Z")


def _validate_relation(rel: str) -> str:
    """Return ``rel`` if it is a safe relationship-type name, else raise.

    Guards the relationship types that are string-interpolated into Cypher
    (``-[e:{rel}]->``) — parameter binding is not available for relation
    types in FalkorDB.

    Args:
        rel: Candidate relationship-type name.

    Returns:
        ``rel`` unchanged when it consists solely of ASCII letters, digits
        and underscores and does not start with a digit.

    Raises:
        ValueError: If ``rel`` is not a string or contains any other
            character (including a trailing newline).
    """
    # ``fullmatch`` requires the whole string to be consumed, so nothing can
    # ride along after a valid-looking prefix.
    if not isinstance(rel, str) or _REL_NAME_RE.fullmatch(rel) is None:
        raise ValueError(f"invalid relation type: {rel!r}")
    return rel
| 271 | + |
| 272 | + |
def _coerce_node_id(symbol_id: Any) -> int:
    """Accept int or stringified int; raise ValueError otherwise.

    The MCP wire format is JSON; agents sometimes hand back the id as a
    string. Be permissive on input, strict on type after parsing.

    Args:
        symbol_id: An ``int`` or a string holding a decimal integer,
            optionally prefixed with ``-``.

    Returns:
        The id as an ``int``.

    Raises:
        ValueError: If ``symbol_id`` is a bool, is neither an int nor a
            string, or is a string that does not parse as an integer.
    """
    if isinstance(symbol_id, bool):  # bool is an int subclass; reject loudly
        raise ValueError(f"symbol_id must be an integer, got bool: {symbol_id!r}")
    if isinstance(symbol_id, int):
        return symbol_id
    if isinstance(symbol_id, str) and symbol_id.lstrip("-").isdigit():
        # The pre-check is only a coarse filter: "--5" and digit-like
        # characters such as "\u00b2" pass it but are rejected by ``int``.
        # Fall through so those get the same error message as other bad input.
        try:
            return int(symbol_id)
        except ValueError:
            pass
    raise ValueError(f"symbol_id must be an integer id, got: {symbol_id!r}")
| 286 | + |
| 287 | + |
async def _neighbors_payload(
    project: str,
    branch: Optional[str],
    symbol_id: Any,
    rel: str,
    direction: str,
    limit: int,
) -> list[dict[str, Any]]:
    """Shared implementation for caller/callee/dependency tools.

    ``direction`` is ``IN`` (incoming edges, e.g. callers) or ``OUT``
    (outgoing edges, e.g. callees). When ``IN`` we run the inverse Cypher
    ``(neighbor)-[:rel]->(target)``; ``AsyncGraphQuery.get_neighbors`` only
    walks outgoing edges, so we inline the Cypher here for symmetry.

    Args:
        project: Project name passed to ``_project_arg``.
        branch: Optional branch passed to ``_project_arg``.
        symbol_id: Node id as an int or stringified int.
        rel: Relationship type; validated before being interpolated.
        direction: ``"IN"`` or ``"OUT"``.
        limit: Maximum number of rows requested from the database.

    Returns:
        One flat node summary per neighbor, each extended with
        ``relation`` (the edge type reported by the query) and
        ``direction`` (echoed from the argument).

    Raises:
        ValueError: If ``symbol_id``, ``rel`` or ``direction`` is invalid.
    """
    # All argument validation happens before a graph handle is opened, so
    # bad input never costs a connection.
    node_id = _coerce_node_id(symbol_id)
    rel = _validate_relation(rel)
    if direction == "OUT":
        pattern, returned = f"(n)-[e:{rel}]->(dest)", "dest"
    elif direction == "IN":
        # Same edge pattern with the anchor on the receiving end; the source
        # node is aliased to ``dest`` so both directions share a row shape.
        pattern, returned = f"(src)-[e:{rel}]->(n)", "src AS dest"
    else:
        raise ValueError(f"direction must be IN or OUT, got: {direction!r}")
    max_rows = int(limit)

    # Only the validated relation name is interpolated; the node id and the
    # row limit are bound as query parameters.
    q = (
        f"MATCH {pattern} "
        f"WHERE ID(n) = $sid "
        f"RETURN {returned}, type(e) AS rel "
        f"LIMIT $limit"
    )

    g = _project_arg(project, branch)
    try:
        res = await g._query(q, {"sid": node_id, "limit": max_rows})
        out: list[dict[str, Any]] = []
        for row in res.result_set:
            entry = _node_summary(row[0])
            entry["relation"] = row[1]
            entry["direction"] = direction
            out.append(entry)
        return out
    finally:
        await g.close()
| 334 | + |
| 335 | + |
@app.tool(
    name="get_callers",
    description=(
        "Return functions that call the given symbol (incoming CALLS edges). "
        "`symbol_id` is the integer node id returned by `search_code` or "
        "other tools."
    ),
)
async def get_callers(
    symbol_id: int | str,
    project: str,
    branch: Optional[str] = None,
    limit: int = 50,
) -> list[dict[str, Any]]:
    # Callers are the sources of CALLS edges that point at the symbol, so
    # the shared neighbor query is run in the incoming direction.
    return await _neighbors_payload(
        project=project,
        branch=branch,
        symbol_id=symbol_id,
        rel="CALLS",
        direction="IN",
        limit=limit,
    )
| 351 | + |
| 352 | + |
@app.tool(
    name="get_callees",
    description=(
        "Return functions that the given symbol calls (outgoing CALLS edges)."
    ),
)
async def get_callees(
    symbol_id: int | str,
    project: str,
    branch: Optional[str] = None,
    limit: int = 50,
) -> list[dict[str, Any]]:
    # Callees are the targets of CALLS edges that leave the symbol, so the
    # shared neighbor query is run in the outgoing direction.
    return await _neighbors_payload(
        project=project,
        branch=branch,
        symbol_id=symbol_id,
        rel="CALLS",
        direction="OUT",
        limit=limit,
    )
| 366 | + |
| 367 | + |
@app.tool(
    name="get_dependencies",
    description=(
        "Return outgoing neighbors of the given symbol across any of the "
        "specified relation types (default: IMPORTS, CALLS, DEFINES). "
        "Useful for 'what does this depend on' queries."
    ),
)
async def get_dependencies(
    symbol_id: int | str,
    project: str,
    branch: Optional[str] = None,
    rels: Optional[list[str]] = None,
    limit: int = 50,
) -> list[dict[str, Any]]:
    # Collects outgoing neighbors of ``symbol_id`` over each relation in
    # ``rels`` (in the order given) and returns at most ``limit`` entries.
    # Raises ValueError when any relation name is invalid.
    if rels is None:
        rels = ["IMPORTS", "CALLS", "DEFINES"]
    elif isinstance(rels, str):
        # A bare string would otherwise be iterated character by character,
        # and every single letter is itself a valid relation name.
        rels = [rels]

    # Validate every relation before the first query so a bad entry fails
    # fast instead of after earlier relations have already hit the DB.
    # ``dict.fromkeys`` drops repeated relations while keeping their order.
    relations = list(dict.fromkeys(_validate_relation(rel) for rel in rels))

    # Aggregate across relations; preserve ordering and dedupe by
    # ``(id, relation)`` so a node reachable through two different relation
    # types is reported once per relation.
    seen: set[Any] = set()
    out: list[dict[str, Any]] = []
    for rel in relations:
        # Only fetch the rows we can still accept, so total DB work is
        # bounded by ``limit`` rather than ``limit * len(rels)``.
        remaining = limit - len(out)
        if remaining <= 0:
            break
        rows = await _neighbors_payload(
            project, branch, symbol_id, rel, "OUT", remaining
        )
        for row in rows:
            key = (row.get("id"), row.get("relation"))
            if key in seen:
                continue
            seen.add(key)
            out.append(row)
            if len(out) >= limit:
                return out
    return out
| 406 | + |
| 407 | + |
| 408 | +# --------------------------------------------------------------------------- |
| 409 | +# T7 — find_path |
| 410 | +# --------------------------------------------------------------------------- |
| 411 | + |
| 412 | + |
@app.tool(
    name="find_path",
    description=(
        "Return up to `max_paths` CALLS-path sequences from `source_id` to "
        "`dest_id`. Useful for 'how does A reach B' questions. Returns an "
        "empty list when no path exists."
    ),
)
async def find_path(
    source_id: int | str,
    dest_id: int | str,
    project: str,
    branch: Optional[str] = None,
    max_paths: int = 10,
) -> list[dict[str, Any]]:
    # Both ids are coerced before the graph handle is created, so malformed
    # ids fail without touching the database.
    start = _coerce_node_id(source_id)
    end = _coerce_node_id(dest_id)

    graph = _project_arg(project, branch)
    try:
        # ``max_paths`` is handed to the query as its limit so the database
        # bounds the enumeration instead of Python slicing afterwards.
        found = await graph.find_paths(start, end, limit=max_paths)
    finally:
        await graph.close()

    def _is_encoded_node(item: Any) -> bool:
        # Each path alternates [node, edge, node, ...]. Encoded nodes carry
        # a top-level ``labels`` key while encoded edges do not (they carry
        # ``relation``/``src_node``/``dest_node``). ``properties`` cannot be
        # used to tell them apart because edges have properties too.
        return isinstance(item, dict) and "labels" in item

    # Drop the edges and surface only the node sequence of every path.
    return [
        {"path": [_node_summary(item) for item in steps if _is_encoded_node(item)]}
        for steps in found
    ]
| 456 | + |
| 457 | + |
| 458 | +# --------------------------------------------------------------------------- |
| 459 | +# T8 — search_code |
| 460 | +# --------------------------------------------------------------------------- |
| 461 | + |
| 462 | + |
@app.tool(
    name="search_code",
    description=(
        "Prefix-search for symbols (functions, classes, files) whose name "
        "starts with `prefix`. Backed by FalkorDB's full-text index. The "
        "agent typically calls this first to discover symbol ids for the "
        "navigation tools (`get_callers`, `find_path`, ...)."
    ),
)
async def search_code(
    prefix: str,
    project: str,
    branch: Optional[str] = None,
    limit: int = 20,
) -> list[dict[str, Any]]:
    graph = _project_arg(project, branch)
    try:
        # The caller's ``limit`` is forwarded to the query itself; without
        # it the underlying full-text search applies its own default cap.
        matches = await graph.prefix_search(prefix, limit=limit)
    finally:
        await graph.close()
    # Flatten every hit into the same record shape the other tools return.
    return list(map(_node_summary, matches))
0 commit comments