Skip to content

Commit fcd765f

Browse files
authored
Merge pull request #46 from sourcehold/feature/tool-decompilation-cache
Feature/tool decompilation cache
2 parents 4f8b4eb + 42ebc14 commit fcd765f

3 files changed

Lines changed: 273 additions & 13 deletions

File tree

tools/mcp/decomphelper.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
)
4242
import mcp.server.stdio
4343

44-
PATH_CMAKE_OPENSHC_SOURCES = Path("cmake/openshc-sources.txt")
44+
PATH_CMAKE_OPENSHC_SOURCES = Path("cmake/openshc-sources.txt.local")
4545
if not PATH_CMAKE_OPENSHC_SOURCES.exists():
4646
raise Exception(f"could not find cmake core sources txt file: {str(PATH_CMAKE_OPENSHC_SOURCES)}")
4747

@@ -102,7 +102,7 @@ def extract_function_assembly_diff(function_name: str) -> tuple[bool, Any, str,
102102
all_data = diff['data']
103103
data = [entry for entry in all_data if entry['name'] == function_name]
104104
if len(data) == 0:
105-
return False, "", "", f"no function with name '{function_name}' in diff.json"
105+
return False, "", "", f"no function with name '{function_name}' in .pdb file. Cannot execute diff"
106106
data = data[0]
107107
return True, data, "", ""
108108

@@ -133,7 +133,7 @@ def compile_cpp_code_for_function(function_name: str, contents: str) -> tuple[bo
133133
return rstate, "", f"could not resolve function name to file path: {rerr}"
134134
path = Path(rresult)
135135
if not path.exists():
136-
return False, "", f"cpp file path does not exist: {str(path)}"
136+
path.parent.mkdir(parents = True, exist_ok=True)
137137
path.write_text(contents)
138138

139139
# Ensure the cpp file is included in the build
@@ -149,6 +149,18 @@ def compile_cpp_code_for_function(function_name: str, contents: str) -> tuple[bo
149149
# Compile the project and return the resulting state
150150
return compile_project()
151151

152+
def read_function(function_name: str, base_path: Path = Path("src")):
    """
    Read the contents of the cpp file that belongs to a function.

    Args:
        function_name: Name of the function, resolved to a file path via
            'function_name_to_cpp_path'
        base_path: Directory passed through to 'function_name_to_cpp_path'
            as the base of the lookup (defaults to "src")

    Returns:
        Tuple of (success, contents, stderr)
    """
    ok, resolved, resolve_err = function_name_to_cpp_path(function_name=function_name, base_path=base_path)
    if not ok:
        return ok, "", f"could not resolve function name to file path: {resolve_err}"

    cpp_path = Path(resolved)
    if cpp_path.exists():
        # Any failure while reading is reported through the stderr slot
        # instead of being raised to the caller.
        try:
            return True, cpp_path.read_text(), ""
        except Exception as exc:
            return False, "", f"{exc}"

    return False, "", f"cpp file path does not exist: {str(cpp_path)}"
163+
152164
@mcp.tool()
153165
def read_cpp_code_for_function(function_name: str) -> tuple[bool, str, str]:
154166
"""
@@ -160,16 +172,7 @@ def read_cpp_code_for_function(function_name: str) -> tuple[bool, str, str]:
160172
Returns:
161173
Tuple of (success, contents, stderr)
162174
"""
163-
rstate, rresult, rerr = function_name_to_cpp_path(function_name=function_name)
164-
if not rstate:
165-
return rstate, "", f"could not resolve function name to file path: {rerr}"
166-
path = Path(rresult)
167-
if not path.exists():
168-
return False, "", f"cpp file path does not exist: {str(path)}"
169-
try:
170-
return True, path.read_text(), ""
171-
except Exception as e:
172-
return False, "", f"{e}"
175+
return read_function(function_name=function_name)
173176

174177
@mcp.tool()
175178
def read_source_file(relative_path: str) -> tuple[bool, str, str]:
@@ -215,5 +218,19 @@ def fetch_ghidra_function_decompilation(function_name: str) -> tuple[bool, str,
215218
return False, "", f"{e}"
216219

217220

221+
@mcp.tool()
def fetch_cached_ghidra_function_decompilation(function_name: str) -> tuple[bool, str, str]:
    """
    Fetches the cached decompilation of a function from disk. The contents are the decompiled C-like code (which may contain ghidra special functions), preceded by a comment header with additional information (such as program, namespace, function name and address).
    The cache is produced by the 'functionexporter.py' ghidra script; if no cached file exists for the function, this fails with a 'does not exist' error.
    This should be used if 'fetch_ghidra_function_decompilation' isn't available or fails.

    Args:
        function_name: Name of the function to extract, fully namespaced using '::'

    Returns:
        Tuple of (success, contents, stderr)
    """
    # Directory the exporter script writes its per-function files to.
    cache_root = Path("tools") / "mcp" / "ghidra_scripts" / "decompilation"
    return read_function(function_name=function_name, base_path=cache_root)
234+
218235
if __name__ == "__main__":
219236
mcp.run(transport="stdio")
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Ignore cached decompiled functions from 'functionexporter.py'
2+
decompilation/OpenSHC
Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
# Export Decompiled Code for _HoldStrong Namespace (Recursive, one file per function)
2+
# Ghidra 11.4.1 - Python 3 (PyGhidra runtime, see @runtime below)
3+
#
4+
# Output layout:
5+
# Given function _HoldStrong::A::B(), the decompiled code is written to:
6+
# <OUTPUT_ROOT>/OpenSHC/A/B.cpp (the root namespace "_HoldStrong" is renamed to "OpenSHC")
7+
#
8+
# Usage:
9+
# 1. Open your binary in Ghidra and run auto-analysis.
10+
# 2. Go to Window > Script Manager > Run Script, select this file.
11+
# 3. Files are written under OUTPUT_ROOT (default: a folder named
12+
# "decompilation", placed next to this script).
13+
#
14+
# @author Claude
15+
# @category _OPENSHC.TOOLS.DECOMPILATION
16+
# @keybinding
17+
# @menupath
18+
# @toolbar
19+
# @runtime PyGhidra
20+
21+
import typing
22+
if typing.TYPE_CHECKING:
23+
from ghidra.ghidra_builtins import * # type: ignore
24+
25+
import os
26+
import re
27+
from ghidra.app.decompiler import DecompInterface, DecompileOptions
28+
from ghidra.util.task import ConsoleTaskMonitor
29+
from ghidra.program.model.symbol import SymbolType
30+
31+
32+
33+
TARGET_NAMESPACE = "_HoldStrong" # Root namespace to search (case-sensitive)
34+
OUTPUT_ROOT = None # None -> <script_dir>/decompilation/
35+
TIMEOUT_SECONDS = 60 # Per-function decompile timeout
36+
INCLUDE_SIGNATURE = True # Prepend the function signature as a comment
37+
RECURSIVE = True # Descend into child namespaces
38+
39+
currentProgram = getCurrentProgram()
40+
41+
42+
def get_output_root():
    """
    Return (and create if needed) the root output directory.

    Uses OUTPUT_ROOT when it is set; otherwise falls back to a
    "decompilation" folder inside the directory of this script.
    """
    if OUTPUT_ROOT:
        root = OUTPUT_ROOT
    else:
        # NOTE(review): if getScriptName() yields only the script's file name
        # rather than a path, abspath() resolves it against the current
        # working directory instead of the script's folder - confirm.
        script_dir = os.path.dirname(os.path.abspath(
            getScriptName() if hasattr(__builtins__, 'getScriptName') else __file__))
        root = os.path.join(script_dir, "decompilation")
    # exist_ok avoids the race between checking for the directory and
    # creating it.
    os.makedirs(root, exist_ok=True)
    return root
54+
55+
56+
def sanitize_path_component(name):
    """
    Turn a single namespace / function name into a safe file-system component.

    Operator and punctuation characters are spelled out as words, every
    remaining character that is not a word character, '-' or '.' becomes '_',
    runs of '_' are collapsed and leading/trailing '_' are removed.
    'operator+=' therefore becomes 'operator_plus_eq'. A name that ends up
    empty is returned as '_unnamed_'.
    """
    # Applied in this order; '::' has to be handled before single characters.
    spelled_out = {
        '::': '__',
        '<': '_lt_',
        '>': '_gt_',
        '*': '_star_',
        '/': '_div_',
        '%': '_mod_',
        '+': '_plus_',
        '-': '_minus_',
        '&': '_amp_',
        '|': '_pipe_',
        '^': '_xor_',
        '~': '_tilde_',
        '!': '_not_',
        '=': '_eq_',
        '(': '_',
        ')': '_',
        ',': '_',
        ' ': '_',
        '[': '_',
        ']': '_',
    }
    cleaned = name
    for symbol, word in spelled_out.items():
        cleaned = cleaned.replace(symbol, word)

    # Neutralise whatever is still unsafe, then squeeze underscore runs.
    cleaned = re.sub(r'_+', '_', re.sub(r'[^\w\-.]', '_', cleaned))
    cleaned = cleaned.strip('_')
    return cleaned if cleaned else '_unnamed_'
77+
78+
79+
def ns_path_to_file_path(output_root, ns_qualified_name, func_name):
    """
    Convert a fully-qualified namespace path and function name into a file path.

    The namespace is split on '::'; a leading "_HoldStrong" component is
    renamed to "OpenSHC". Every component is sanitized, the directory is
    created if it does not exist yet, and the path of the '.cpp' file inside
    it is returned. An empty namespace places the file directly in
    output_root.

    Example:
        ns_qualified_name = "_HoldStrong::A"
        func_name = "B"
        -> <output_root>/OpenSHC/A/B.cpp
    """
    # Split on '::' to get each namespace component, dropping empty pieces
    namespace_parts = [p for p in ns_qualified_name.split('::') if p]
    # Guard against an empty namespace before looking at the first component
    if namespace_parts and namespace_parts[0] == "_HoldStrong":
        namespace_parts[0] = "OpenSHC"

    dir_parts = [sanitize_path_component(p) for p in namespace_parts]
    # os.path.join with no extra components returns output_root unchanged
    out_dir = os.path.join(output_root, *dir_parts)
    # exist_ok avoids the race between checking for the directory and
    # creating it
    os.makedirs(out_dir, exist_ok=True)

    # The function name becomes the file name
    return os.path.join(out_dir, sanitize_path_component(func_name) + '.cpp')
100+
101+
102+
def init_decompiler():
    """Create a DecompInterface with default options, opened on the current program."""
    decompiler = DecompInterface()
    decompiler.setOptions(DecompileOptions())
    decompiler.openProgram(currentProgram)
    return decompiler
109+
110+
111+
def decompile_function(func, iface, monitor):
    """
    Decompile *func* with *iface*, waiting at most TIMEOUT_SECONDS.

    Returns (code_str, None) on success or (None, error_str) on failure.
    In the returned code every "_HoldStrong" is replaced by "OpenSHC".
    """
    result = iface.decompileFunction(func, TIMEOUT_SECONDS, monitor)

    completed = result is not None and result.decompileCompleted()
    if not completed:
        if result:
            reason = result.getErrorMessage()
        else:
            reason = "unknown error"
        return None, "/* ERROR decompiling {}: {} */\n".format(func.getName(), reason)

    c_code = result.getDecompiledFunction().getC()
    if c_code is None:
        return None, "/* ERROR: no C markup for {} */\n".format(func.getName())

    renamed = str(c_code).replace("_HoldStrong", "OpenSHC")
    return renamed, None
121+
122+
123+
def collect_functions_in_namespace(ns):
    """
    Gather the functions that live inside *ns* and, when RECURSIVE is set,
    inside its child namespaces and classes.

    Returns a list of (ns_qualified_name, Function) tuples. ns_qualified_name
    is the fully-qualified name of the namespace that owns the function
    (without the function name itself), e.g. "_HoldStrong::A".
    """
    owner_name = ns.getName(True)  # fully-qualified, e.g. "_HoldStrong::A"
    found = []

    symbol_iter = currentProgram.getSymbolTable().getChildren(ns.getSymbol())
    while symbol_iter.hasNext():
        symbol = symbol_iter.next()
        kind = symbol.getSymbolType()

        if kind == SymbolType.FUNCTION:
            function = getFunctionAt(symbol.getAddress())
            if function is not None:
                found.append((owner_name, function))
            continue

        # Everything that is not a nested namespace/class is ignored.
        if not RECURSIVE or kind not in (SymbolType.NAMESPACE, SymbolType.CLASS):
            continue

        nested_ns = symbol.getObject()
        if nested_ns is not None:
            found.extend(collect_functions_in_namespace(nested_ns))

    return found
151+
152+
153+
def find_root_namespace(name):
    """
    Look up symbols called *name* and return the object of the first one that
    is a namespace or a class.

    Returns None when no such symbol exists.
    """
    namespace_kinds = (SymbolType.NAMESPACE, SymbolType.CLASS)
    candidates = currentProgram.getSymbolTable().getSymbols(name)
    while candidates.hasNext():
        candidate = candidates.next()
        if candidate.getSymbolType() in namespace_kinds:
            return candidate.getObject()
    return None
165+
166+
167+
168+
def run():
    """
    Export the decompilation of every function under TARGET_NAMESPACE.

    Looks up the root namespace, collects its functions, decompiles each one
    and writes it to its own .cpp file below the output root, preceded by a
    comment header. Functions that fail to decompile still get a file that
    contains the error text. A summary is printed and shown in a popup at the
    end; a missing namespace or an empty function list is reported the same
    way and ends the run early.
    """
    monitor = ConsoleTaskMonitor()

    print("[*] Searching for namespace: {}".format(TARGET_NAMESPACE))
    root_ns = find_root_namespace(TARGET_NAMESPACE)
    if root_ns is None:
        msg = "Namespace '{}' not found in program.".format(TARGET_NAMESPACE)
        print("[!] " + msg)
        popup(msg)
        return

    print("[*] Found namespace: {}".format(root_ns.getName(True)))
    print("[*] Collecting functions{}...".format(
        " recursively" if RECURSIVE else ""))

    entries = collect_functions_in_namespace(root_ns)
    if not entries:
        msg = "No functions found under '{}'.".format(TARGET_NAMESPACE)
        print("[!] " + msg)
        popup(msg)
        return

    print("[*] Found {} function(s). Decompiling...".format(len(entries)))

    iface = init_decompiler()
    # The decompiler interface is released even when the export aborts with
    # an exception (I/O error, the sanity check below, ...).
    try:
        output_root = get_output_root()
        error_count = 0
        written = []
        seen_paths = set()

        for idx, (ns_qualified, func) in enumerate(entries, 1):
            addr = func.getEntryPoint()
            func_name = func.getName()
            out_path = ns_path_to_file_path(output_root, ns_qualified, func_name)

            print(" [{}/{}] {}::{} @{}".format(
                idx, len(entries), ns_qualified, func_name, addr))
            print(" -> {}".format(out_path))

            # Functions that share namespace and name map to the same file;
            # opening with 'w' below truncates the earlier output, so make
            # that visible instead of losing it silently.
            if out_path in seen_paths:
                print(" [!] output path was already written by another function; overwriting")
            seen_paths.add(out_path)

            code, err = decompile_function(func, iface, monitor)

            with open(out_path, 'w', encoding='UTF-8') as fh:
                fh.write("// {}\n".format("=" * 76))
                fh.write("// Program : {}\n".format(currentProgram.getName()))
                fh.write("// Namespace : {}\n".format(ns_qualified))
                fh.write("// Function : {}\n".format(func_name))
                fh.write("// Address : {}\n".format(addr))
                if INCLUDE_SIGNATURE:
                    fh.write("// Signature : {}\n".format(func.getSignature()))
                fh.write("// {}\n\n".format("=" * 76))

                if err is not None:
                    fh.write(err + "\n")
                    error_count += 1
                    print(" [!] " + err.strip())
                else:
                    # decompile_function returns either code or an error
                    if not code:
                        raise Exception("impossible situation")
                    fh.write(code)
                    fh.write("\n")

            written.append(out_path)
    finally:
        iface.dispose()

    summary = (
        "[+] Done. {} file(s) written, {} error(s).\n"
        " Output root: {}"
    ).format(len(written), error_count, output_root)
    print(summary)
    popup(summary)
238+
239+
if __name__ == "__main__":
240+
# Run the export when this file is executed as a script (e.g. via Ghidra's Script Manager).
241+
run()

0 commit comments

Comments
 (0)