Skip to content

Commit eaf8a98

Browse files
authored
Add: get_callers & get_callees tools (#58)
- `get_callers`: Takes one or more function names/addresses and returns JSON `{"results":[{identifier,function,callers,caller_sites}], "errors":[...]}`, where each entry includes normalized function metadata, every caller, and HLIL/IL snippets for each call site.
- `get_callees`: Accepts the same identifier inputs and returns `{"results":[{identifier,function,callees,call_sites}], "errors":[...]}`, listing every outgoing callee plus per-site IL context, falling back to raw addresses when no symbol exists.
1 parent 410d3dd commit eaf8a98

5 files changed

Lines changed: 317 additions & 0 deletions

File tree

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,8 @@ The following table lists the available MCP functions for use:
189189
| `get_xrefs_to_union(union_name)` | Get xrefs/usages related to a union (members, globals, code refs). |
190190
| `get_stack_frame_vars(function_identifier)` | Get stack frame variable information for a function (names, offsets, sizes, types). |
191191
| `get_type_info(type_name)` | Resolve a type and return declaration, kind, and members. |
192+
| `get_callers(identifiers)` | List callers plus call sites for one or more function identifiers. |
193+
| `get_callees(identifiers)` | List callees plus call sites for one or more function identifiers. |
192194
| `make_function_at(address, platform)` | Create a function at an address. `platform` optional; use `default` to pick the BinaryView/platform default. |
193195
| `list_platforms()` | List all available platform names. |
194196
| `list_binaries()` | List managed/open binaries with ids and active flag. |
@@ -236,6 +238,8 @@ These are the list of HTTP endpoints that can be called:
236238
- `/getTypeInfo?name=<type>`: Resolve a type and return declaration and details.
237239
- `/getXrefsToUnion?name=<union>`: Union xrefs/usages (members, globals, refs).
238240
- `/getStackFrameVars?name=<function>|address=<addr>`: Get stack frame variable information for a function.
241+
- `/getCallers?identifiers=<name|addr>[,...]`: Return caller summaries (functions, call sites, HLIL/IL snippets) for one or more identifiers. Accepts the `identifiers`, `identifier`, `functions`, `function`, `names`, `name`, `addresses`, or `address` query params; values may be comma- or semicolon-separated.
242+
- `/getCallees?identifiers=<name|addr>[,...]`: Return callee summaries mirroring the `/getCallers` schema (with `callees`/`call_sites` in place of `callers`/`caller_sites`), detailing every outgoing call target per request identifier.
239243
- `/localTypes?offset=<n>&limit=<m>`: List local types.
240244
- `/strings?offset=<n>&limit=<m>`: Paginated strings.
241245
- `/strings/filter?offset=<n>&limit=<m>&filter=<substr>`: Filtered strings.

bridge/binja_mcp_bridge.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -805,6 +805,53 @@ def get_type_info(type_name: str) -> str:
805805
return _json.dumps(data, indent=2, ensure_ascii=False)
806806

807807

808+
def _normalize_identifier_input(value: str | list[str]) -> list[str]:
809+
tokens: list[str] = []
810+
if isinstance(value, str):
811+
raw = value.replace(";", ",").split(",")
812+
tokens.extend([tok.strip() for tok in raw if tok.strip()])
813+
elif isinstance(value, (list, tuple, set)):
814+
for item in value:
815+
if item is None:
816+
continue
817+
tokens.extend(_normalize_identifier_input(str(item)))
818+
return tokens
819+
820+
821+
@mcp.tool()
def get_callers(identifiers: str) -> str:
    """
    List callers and caller sites for one or more function identifiers (name or address).
    Provide comma-separated identifiers like "sub_401000,main".
    """
    import json as _json

    # Tokenize first so blank or separator-only input is rejected locally.
    wanted = _normalize_identifier_input(identifiers)
    if not wanted:
        return "Error: provide at least one identifier"

    # The identifiers are re-joined with commas and sent to the getCallers endpoint.
    response = get_json("getCallers", {"identifiers": ",".join(wanted)}, timeout=None)
    if response:
        return _json.dumps(response, indent=2, ensure_ascii=False)
    return "Error: no response"
836+
837+
838+
@mcp.tool()
def get_callees(identifiers: str) -> str:
    """
    List callees and call sites for one or more function identifiers (name or address).
    Provide comma-separated identifiers like "sub_401000,main".
    """
    import json as _json

    # Tokenize first so blank or separator-only input is rejected locally.
    wanted = _normalize_identifier_input(identifiers)
    if not wanted:
        return "Error: provide at least one identifier"

    # The identifiers are re-joined with commas and sent to the getCallees endpoint.
    response = get_json("getCallees", {"identifiers": ",".join(wanted)}, timeout=None)
    if response:
        return _json.dumps(response, indent=2, ensure_ascii=False)
    return "Error: no response"
853+
854+
808855
@mcp.tool()
809856
def set_function_prototype(name_or_address: str, prototype: str) -> str:
810857
"""

plugin/api/endpoints.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,14 @@ def get_function_info(self, identifier: str) -> dict[str, Any] | None:
104104
bn.log_error(f"Error getting function info: {e}")
105105
return None
106106

107+
def get_callers(self, identifiers: list[str]) -> dict[str, Any]:
    """Return caller details for ``identifiers`` via the binary operations layer.

    The argument is forwarded unchanged to ``self.binary_ops.get_callers`` and
    that call's result is returned as-is.
    """
    operations = self.binary_ops
    return operations.get_callers(identifiers)
110+
111+
def get_callees(self, identifiers: list[str]) -> dict[str, Any]:
    """Return callee details for ``identifiers`` via the binary operations layer.

    The argument is forwarded unchanged to ``self.binary_ops.get_callees`` and
    that call's result is returned as-is.
    """
    operations = self.binary_ops
    return operations.get_callees(identifiers)
114+
107115
def get_imports(self, offset: int = 0, limit: int = 100) -> list[dict[str, Any]]:
108116
"""Get list of imported functions"""
109117
if not self.binary_ops.current_view:

plugin/core/binary_operations.py

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,204 @@ def get_function_by_name_or_address(self, identifier: str | int) -> bn.Function
350350
bn.log_error(f"Could not find function: {identifier}")
351351
return None
352352

353+
def _normalize_identifier_list(self, identifiers: Any) -> list[Any]:
354+
"""Normalize comma-delimited strings or iterables into a list of identifiers."""
355+
if identifiers is None:
356+
return []
357+
if isinstance(identifiers, (list, tuple, set)):
358+
raw_items = list(identifiers)
359+
else:
360+
raw_items = [identifiers]
361+
normalized: list[Any] = []
362+
for item in raw_items:
363+
if item is None:
364+
continue
365+
if isinstance(item, str):
366+
# Allow comma or semicolon separation for convenience
367+
tokens = [tok.strip() for tok in item.replace(";", ",").split(",")]
368+
normalized.extend([tok for tok in tokens if tok])
369+
else:
370+
normalized.append(item)
371+
return normalized
372+
373+
def _format_function_reference(self, func: bn.Function | None) -> dict[str, Any] | None:
374+
if not func:
375+
return None
376+
try:
377+
return {
378+
"name": getattr(func, "name", None),
379+
"address": hex(int(func.start)) if hasattr(func, "start") else None,
380+
}
381+
except Exception:
382+
return {
383+
"name": getattr(func, "name", None),
384+
"address": None,
385+
}
386+
387+
def _collect_related_functions(
388+
self, func: bn.Function, relation_attr: str
389+
) -> list[dict[str, Any]]:
390+
related: list[dict[str, Any]] = []
391+
seen: set[int] = set()
392+
try:
393+
rel_iter = getattr(func, relation_attr, None)
394+
except Exception:
395+
rel_iter = None
396+
if rel_iter is None:
397+
return related
398+
try:
399+
for rel_func in list(rel_iter):
400+
if not rel_func:
401+
continue
402+
addr = None
403+
try:
404+
addr = int(rel_func.start)
405+
except Exception:
406+
addr = None
407+
if addr is not None and addr in seen:
408+
continue
409+
if addr is not None:
410+
seen.add(addr)
411+
ref = self._format_function_reference(rel_func)
412+
if ref:
413+
related.append(ref)
414+
except Exception:
415+
pass
416+
return related
417+
418+
def _summarize_call_sites(self, func: bn.Function, relation: str) -> list[dict[str, Any]]:
419+
entries: list[dict[str, Any]] = []
420+
attr = "caller_sites" if relation == "callers" else "call_sites"
421+
try:
422+
sites = getattr(func, attr, None)
423+
except Exception:
424+
sites = None
425+
if not sites:
426+
return entries
427+
428+
def _extract_function(site: Any, names: tuple[str, ...]) -> bn.Function | None:
429+
for name in names:
430+
try:
431+
value = getattr(site, name, None)
432+
except Exception:
433+
value = None
434+
if value:
435+
return value
436+
return None
437+
438+
for site in list(sites):
439+
try:
440+
entry: dict[str, Any] = {}
441+
addr = getattr(site, "address", None)
442+
if isinstance(addr, int):
443+
entry["address"] = hex(addr)
444+
elif isinstance(addr, str) and addr:
445+
entry["address"] = addr
446+
447+
if relation == "callers":
448+
caller_func = _extract_function(site, ("function", "source_function", "caller"))
449+
ref = self._format_function_reference(caller_func)
450+
if ref:
451+
entry["caller"] = ref
452+
else:
453+
callee_func = _extract_function(
454+
site, ("callee", "dest_function", "target_function")
455+
)
456+
ref = self._format_function_reference(callee_func)
457+
if ref:
458+
entry["callee"] = ref
459+
else:
460+
# Fall back to raw destination address when available
461+
dest = None
462+
for attr_name in ("dest", "target", "constant"):
463+
try:
464+
dest = getattr(site, attr_name)
465+
except Exception:
466+
dest = None
467+
if dest is not None:
468+
break
469+
if isinstance(dest, int):
470+
entry["callee"] = {"name": None, "address": hex(dest)}
471+
472+
# Attach textual representation for quick context
473+
summary_text = None
474+
for attr_name in ("hlil", "il"):
475+
try:
476+
val = getattr(site, attr_name, None)
477+
except Exception:
478+
val = None
479+
if val is not None:
480+
summary_text = str(val)
481+
break
482+
if summary_text is None:
483+
summary_text = str(site)
484+
entry["il"] = summary_text
485+
486+
entries.append(entry)
487+
except Exception:
488+
continue
489+
return entries
490+
491+
def get_callers(self, identifiers: Any) -> dict[str, Any]:
    """Collect caller information for the given function identifiers.

    Args:
        identifiers: A single identifier or a collection of identifiers,
            normalized with ``_normalize_identifier_list``.

    Returns:
        ``{"results": [...], "errors": [...]}``. Each result holds the
        ``identifier`` as a string, the resolved ``function`` reference, its
        ``callers`` and its ``caller_sites``. Each identifier that cannot be
        resolved contributes exactly one message to ``errors``.

    Raises:
        RuntimeError: If no binary view is loaded.
        ValueError: If normalization leaves no identifiers.
    """
    if not self._current_view:
        raise RuntimeError("No binary loaded")

    items = self._normalize_identifier_list(identifiers)
    if not items:
        raise ValueError("No function identifiers provided")

    results: list[dict[str, Any]] = []
    errors: list[str] = []
    for ident in items:
        try:
            func = self.get_function_by_name_or_address(ident)
        except Exception as exc:
            # Record the lookup failure once; do not also report "not found".
            errors.append(f"{ident}: {exc}")
            continue
        if not func:
            errors.append(f"Function not found: {ident}")
            continue
        results.append(
            {
                "identifier": str(ident),
                "function": self._format_function_reference(func),
                "callers": self._collect_related_functions(func, "callers"),
                "caller_sites": self._summarize_call_sites(func, "callers"),
            }
        )

    return {"results": results, "errors": errors}
520+
521+
def get_callees(self, identifiers: Any) -> dict[str, Any]:
    """Collect callee information for the given function identifiers.

    Args:
        identifiers: A single identifier or a collection of identifiers,
            normalized with ``_normalize_identifier_list``.

    Returns:
        ``{"results": [...], "errors": [...]}``. Each result holds the
        ``identifier`` as a string, the resolved ``function`` reference, its
        ``callees`` and its ``call_sites``. Each identifier that cannot be
        resolved contributes exactly one message to ``errors``.

    Raises:
        RuntimeError: If no binary view is loaded.
        ValueError: If normalization leaves no identifiers.
    """
    if not self._current_view:
        raise RuntimeError("No binary loaded")

    items = self._normalize_identifier_list(identifiers)
    if not items:
        raise ValueError("No function identifiers provided")

    results: list[dict[str, Any]] = []
    errors: list[str] = []
    for ident in items:
        try:
            func = self.get_function_by_name_or_address(ident)
        except Exception as exc:
            # Record the lookup failure once; do not also report "not found".
            errors.append(f"{ident}: {exc}")
            continue
        if not func:
            errors.append(f"Function not found: {ident}")
            continue
        results.append(
            {
                "identifier": str(ident),
                "function": self._format_function_reference(func),
                "callees": self._collect_related_functions(func, "callees"),
                "call_sites": self._summarize_call_sites(func, "callees"),
            }
        )

    return {"results": results, "errors": errors}
550+
353551
def get_function_names(self, offset: int = 0, limit: int = 100) -> list[dict[str, str]]:
354552
"""Get list of function names with addresses"""
355553
if not self._current_view:

plugin/server/http_server.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,28 @@ def _parse_query_params(self) -> dict[str, str]:
7878
parsed_path = urllib.parse.urlparse(self.path)
7979
return dict(urllib.parse.parse_qsl(parsed_path.query))
8080

81+
def _extract_identifiers(self, params: dict[str, str]) -> list[str]:
82+
"""Normalize identifier-bearing query params into a list."""
83+
candidates: list[str] = []
84+
for key in (
85+
"identifiers",
86+
"identifier",
87+
"functions",
88+
"function",
89+
"names",
90+
"name",
91+
"addresses",
92+
"address",
93+
):
94+
value = params.get(key)
95+
if value:
96+
candidates.append(value)
97+
identifiers: list[str] = []
98+
for raw in candidates:
99+
tokens = [tok.strip() for tok in raw.replace(";", ",").split(",")]
100+
identifiers.extend([tok for tok in tokens if tok])
101+
return identifiers
102+
81103
def _parse_post_params(self) -> dict[str, Any]:
82104
"""Parse POST request parameters from various formats.
83105
@@ -745,6 +767,44 @@ def _printable(b: int) -> str:
745767
matches = self.endpoints.search_functions(search_term, offset, limit)
746768
self._send_json_response({"matches": matches})
747769

770+
elif path == "/getCallers":
771+
identifiers = self._extract_identifiers(params)
772+
if not identifiers:
773+
self._send_json_response(
774+
{
775+
"error": "Missing identifier parameter",
776+
"help": "Provide ?identifier=<name|address> or comma-separated ?identifiers=a,b",
777+
},
778+
400,
779+
)
780+
return
781+
try:
782+
payload = self.binary_ops.get_callers(identifiers)
783+
except Exception as e:
784+
bn.log_error(f"Error handling getCallers: {e}")
785+
self._send_json_response({"error": str(e)}, 500)
786+
else:
787+
self._send_json_response(payload)
788+
789+
elif path == "/getCallees":
790+
identifiers = self._extract_identifiers(params)
791+
if not identifiers:
792+
self._send_json_response(
793+
{
794+
"error": "Missing identifier parameter",
795+
"help": "Provide ?identifier=<name|address> or comma-separated ?identifiers=a,b",
796+
},
797+
400,
798+
)
799+
return
800+
try:
801+
payload = self.binary_ops.get_callees(identifiers)
802+
except Exception as e:
803+
bn.log_error(f"Error handling getCallees: {e}")
804+
self._send_json_response({"error": str(e)}, 500)
805+
else:
806+
self._send_json_response(payload)
807+
748808
elif path == "/decompile":
749809
function_name = params.get("name") or params.get("functionName")
750810
if not function_name:

0 commit comments

Comments
 (0)