Skip to content

Commit c55c60a

Browse files
authored
Add decompile module (#221)
Implement a workaround to get decompiled files working within Zed. The proxy intercepts definitions requests being returned from JDTLS with a URI starting with `jdt://` and creates a temp java file which Zed is then able to read and further decompilation can happen. To work, JDTLS needs to receive additional client capabilities within the initialization options. If the user has not set such the extension will enable those by default.
1 parent 584465a commit c55c60a

4 files changed

Lines changed: 230 additions & 20 deletions

File tree

proxy/src/decompile.rs

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
use serde_json::{json, Value};
2+
use std::{
3+
collections::{hash_map::DefaultHasher, HashMap},
4+
env, fs,
5+
hash::{Hash, Hasher},
6+
io::Write,
7+
path::{Path, PathBuf},
8+
sync::{mpsc, Arc, Mutex},
9+
};
10+
11+
use crate::{lsp::encode_lsp, lsp_error, lsp_warn};
12+
13+
const DECOMPILED_DIR: &str = "jdtls-decompiled";
14+
15+
/// Convert a `PathBuf` to a proper `file://` URI.
16+
///
17+
/// On Unix the path already starts with `/`, so `file://` + path gives us
18+
/// the correct `file:///…` form with no extra work.
19+
///
20+
/// On Windows we must replace `\` with `/` and prepend `file:///` before the
21+
/// drive letter so that we get `file:///C:/…` instead of `file://C:\…`.
22+
#[cfg(unix)]
23+
fn path_to_file_uri(path: &Path) -> String {
24+
format!("file://{}", path.display())
25+
}
26+
27+
#[cfg(windows)]
28+
fn path_to_file_uri(path: &Path) -> String {
29+
let s = path.display().to_string().replace('\\', "/");
30+
format!("file:///{s}")
31+
}
32+
33+
fn cache_dir() -> PathBuf {
34+
env::temp_dir().join(DECOMPILED_DIR)
35+
}
36+
37+
fn cache_path(uri: &str) -> PathBuf {
38+
let mut hasher = DefaultHasher::new();
39+
uri.hash(&mut hasher);
40+
let hex = format!("{:016x}", hasher.finish());
41+
42+
// jdt://contents/java.base/java.util/ArrayList.java?=.../%3Cjava.util%28ArrayList.class
43+
// The class name is between the last %28 (URL-encoded '(') and .class at the end
44+
let name = uri
45+
.rsplit_once("%28")
46+
.and_then(|(_, rest)| rest.strip_suffix(".class"))
47+
.or_else(|| {
48+
uri.split('?')
49+
.next()
50+
.and_then(|path| path.rsplit('/').next())
51+
.and_then(|seg| seg.strip_suffix(".java").or(seg.strip_suffix(".class")))
52+
})
53+
.unwrap_or("Decompiled");
54+
55+
cache_dir().join(format!("{name}-{hex}.java"))
56+
}
57+
58+
/// Send `java/classFileContents` to JDTLS and wait for the response.
59+
fn fetch_class_contents(
60+
uri: &str,
61+
writer: &Arc<Mutex<impl Write>>,
62+
pending: &Arc<Mutex<HashMap<Value, mpsc::Sender<Value>>>>,
63+
request_id: Value,
64+
) -> Option<String> {
65+
let (tx, rx) = mpsc::channel();
66+
pending.lock().unwrap().insert(request_id.clone(), tx);
67+
68+
let req = encode_lsp(&json!({
69+
"jsonrpc": "2.0",
70+
"id": request_id,
71+
"method": "java/classFileContents",
72+
"params": { "uri": uri }
73+
}));
74+
{
75+
let mut w = writer.lock().unwrap();
76+
let _ = w.write_all(req.as_bytes());
77+
let _ = w.flush();
78+
}
79+
80+
match rx.recv_timeout(std::time::Duration::from_secs(10)) {
81+
Ok(resp) => {
82+
let content = resp.get("result")?.as_str()?;
83+
Some(content.to_string())
84+
}
85+
Err(_) => {
86+
lsp_warn!("[decompile] Timed out fetching class contents for {uri}");
87+
None
88+
}
89+
}
90+
}
91+
92+
fn resolve_jdt_uri(
93+
uri: &str,
94+
writer: &Arc<Mutex<impl Write>>,
95+
pending: &Arc<Mutex<HashMap<Value, mpsc::Sender<Value>>>>,
96+
request_id: Value,
97+
) -> Option<String> {
98+
let path = cache_path(uri);
99+
if path.exists() {
100+
return Some(path_to_file_uri(&path));
101+
}
102+
103+
let content = fetch_class_contents(uri, writer, pending, request_id)?;
104+
let _ = fs::create_dir_all(cache_dir());
105+
match fs::write(&path, &content) {
106+
Ok(_) => Some(path_to_file_uri(&path)),
107+
Err(e) => {
108+
lsp_error!("[decompile] Failed to write {}: {e}", path.display());
109+
None
110+
}
111+
}
112+
}
113+
114+
/// Rewrite any `jdt://` URIs in a definition/typeDefinition/implementation response.
115+
/// Returns `true` if any URI was rewritten.
116+
pub fn rewrite_jdt_locations(
117+
msg: &mut Value,
118+
writer: &Arc<Mutex<impl Write>>,
119+
pending: &Arc<Mutex<HashMap<Value, mpsc::Sender<Value>>>>,
120+
next_id: &mut impl FnMut() -> Value,
121+
) -> bool {
122+
let results = match msg.get_mut("result") {
123+
Some(Value::Array(arr)) => arr.iter_mut().collect::<Vec<_>>(),
124+
Some(obj @ Value::Object(_)) => vec![obj],
125+
_ => return false,
126+
};
127+
128+
let mut rewritten = false;
129+
for loc in results {
130+
for key in &["uri", "targetUri"] {
131+
if let Some(Value::String(uri)) = loc.get(key) {
132+
if uri.starts_with("jdt://") {
133+
let jdt_uri = uri.clone();
134+
if let Some(file_uri) = resolve_jdt_uri(&jdt_uri, writer, pending, next_id()) {
135+
loc[*key] = Value::String(file_uri);
136+
rewritten = true;
137+
}
138+
}
139+
}
140+
}
141+
}
142+
rewritten
143+
}

proxy/src/lsp.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use serde::Serialize;
2-
use std::io::{self, Read};
2+
use std::io::{self, Read, Write};
33

44
pub const CONTENT_LENGTH: &str = "Content-Length";
55
pub const HEADER_SEP: &[u8] = b"\r\n\r\n";
@@ -58,3 +58,17 @@ pub fn encode_lsp(value: &impl Serialize) -> String {
5858
let json = serde_json::to_string(value).unwrap();
5959
format!("{CONTENT_LENGTH}: {}\r\n\r\n{json}", json.len())
6060
}
61+
62+
/// Write raw LSP bytes to a writer, flushing afterward.
63+
pub fn write_raw(w: &mut impl Write, raw: &[u8]) {
64+
let _ = w.write_all(raw);
65+
let _ = w.flush();
66+
}
67+
68+
/// Encode a value as an LSP message and write it to stdout.
69+
pub fn write_to_stdout(value: &impl Serialize) {
70+
let out = encode_lsp(value);
71+
let mut w = io::stdout().lock();
72+
let _ = w.write_all(out.as_bytes());
73+
let _ = w.flush();
74+
}

proxy/src/main.rs

Lines changed: 53 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
mod completions;
2+
mod decompile;
23
mod http;
34
mod log;
45
mod lsp;
56
mod platform;
67

78
use completions::{should_sort_completions, sort_completions_by_param_count};
9+
use decompile::rewrite_jdt_locations;
810
use http::handle_http;
9-
use lsp::{encode_lsp, parse_lsp_content, LspReader};
11+
use lsp::{parse_lsp_content, write_raw, write_to_stdout, LspReader};
1012
use platform::spawn_parent_monitor;
1113
use serde_json::Value;
1214
use std::{
13-
collections::HashMap,
15+
collections::{HashMap, HashSet},
1416
env, fs,
1517
io::{self, BufReader, Write},
1618
net::TcpListener,
@@ -88,15 +90,33 @@ fn main() {
8890

8991
let id_counter = Arc::new(AtomicU64::new(1));
9092

91-
// --- Thread 1: Zed stdin -> JDTLS stdin (passthrough) ---
93+
// Track definition/typeDefinition/implementation request IDs for jdt:// rewriting
94+
let definition_ids: Arc<Mutex<HashSet<Value>>> = Arc::new(Mutex::new(HashSet::new()));
95+
96+
// --- Thread 1: Zed stdin -> JDTLS stdin (track definition requests) ---
9297
let stdin_writer = Arc::clone(&child_stdin);
9398
let alive_stdin = Arc::clone(&alive);
99+
let def_ids_in = Arc::clone(&definition_ids);
94100
thread::spawn(move || {
95101
let stdin = io::stdin().lock();
96102
let mut reader = LspReader::new(stdin);
97103
while alive_stdin.load(Ordering::Relaxed) {
98104
match reader.read_message() {
99105
Ok(Some(raw)) => {
106+
if let Some(msg) = parse_lsp_content(&raw) {
107+
if let Some(method) = msg.get("method").and_then(|m| m.as_str()) {
108+
if matches!(
109+
method,
110+
"textDocument/definition"
111+
| "textDocument/typeDefinition"
112+
| "textDocument/implementation"
113+
) {
114+
if let Some(id) = msg.get("id").cloned() {
115+
def_ids_in.lock().unwrap().insert(id);
116+
}
117+
}
118+
}
119+
}
100120
let mut w = stdin_writer.lock().unwrap();
101121
if w.write_all(&raw).is_err() || w.flush().is_err() {
102122
break;
@@ -108,19 +128,21 @@ fn main() {
108128
alive_stdin.store(false, Ordering::Relaxed);
109129
});
110130

111-
// --- Thread 2: JDTLS stdout -> modify completions -> Zed stdout / resolve pending ---
131+
// --- Thread 2: JDTLS stdout -> rewrite jdt:// URIs, modify completions -> Zed stdout / resolve pending ---
112132
let pending_out = Arc::clone(&pending);
113133
let alive_out = Arc::clone(&alive);
134+
let def_ids_out = Arc::clone(&definition_ids);
135+
let decompile_writer = Arc::clone(&child_stdin);
136+
let decompile_pending = Arc::clone(&pending);
137+
let decompile_counter = Arc::clone(&id_counter);
138+
let decompile_proxy_id = proxy_id.clone();
114139
thread::spawn(move || {
115140
let mut reader = LspReader::new(BufReader::new(child_stdout));
116-
let stdout = io::stdout();
117141
while alive_out.load(Ordering::Relaxed) {
118142
match reader.read_message() {
119143
Ok(Some(raw)) => {
120144
let Some(mut msg) = parse_lsp_content(&raw) else {
121-
let mut w = stdout.lock();
122-
let _ = w.write_all(&raw);
123-
let _ = w.flush();
145+
write_raw(&mut io::stdout().lock(), &raw);
124146
continue;
125147
};
126148

@@ -132,20 +154,36 @@ fn main() {
132154
}
133155
}
134156

157+
// Rewrite jdt:// URIs in definition responses
158+
// Spawns a thread so this loop stays unblocked and can
159+
// route the java/classFileContents response back via `pending`.
160+
if let Some(id) = msg.get("id").cloned() {
161+
if def_ids_out.lock().unwrap().remove(&id) {
162+
let writer = Arc::clone(&decompile_writer);
163+
let pending = Arc::clone(&decompile_pending);
164+
let pid = decompile_proxy_id.clone();
165+
let counter = Arc::clone(&decompile_counter);
166+
thread::spawn(move || {
167+
let mut next_id = move || {
168+
let seq = counter.fetch_add(1, Ordering::Relaxed);
169+
Value::String(format!("{pid}-decompile-{seq}"))
170+
};
171+
rewrite_jdt_locations(&mut msg, &writer, &pending, &mut next_id);
172+
write_to_stdout(&msg);
173+
});
174+
continue;
175+
}
176+
}
177+
135178
// Sort completion responses by param count
136179
if should_sort_completions(&msg) {
137180
sort_completions_by_param_count(&mut msg);
138-
let out = encode_lsp(&msg);
139-
let mut w = stdout.lock();
140-
let _ = w.write_all(out.as_bytes());
141-
let _ = w.flush();
181+
write_to_stdout(&msg);
142182
continue;
143183
}
144184

145185
// Passthrough
146-
let mut w = stdout.lock();
147-
let _ = w.write_all(&raw);
148-
let _ = w.flush();
186+
write_raw(&mut io::stdout().lock(), &raw);
149187
}
150188
Ok(None) | Err(_) => break,
151189
}

src/java.rs

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -380,21 +380,36 @@ impl Extension for Java {
380380
})?;
381381
}
382382

383-
let options = LspSettings::for_worktree(language_server_id.as_ref(), worktree)
383+
let mut options = LspSettings::for_worktree(language_server_id.as_ref(), worktree)
384384
.map(|lsp_settings| lsp_settings.initialization_options)
385-
.map_err(|err| format!("Failed to get LSP settings for worktree: {err}"))?;
385+
.map_err(|err| format!("Failed to get LSP settings for worktree: {err}"))?
386+
.unwrap_or_else(|| json!({}));
387+
388+
// Inject extendedClientCapabilities defaults if not already set by the user
389+
let caps = options
390+
.as_object_mut()
391+
.unwrap()
392+
.entry("extendedClientCapabilities")
393+
.or_insert_with(|| json!({}));
394+
let caps_obj = caps.as_object_mut().unwrap();
395+
caps_obj
396+
.entry("classFileContentsSupport")
397+
.or_insert(json!(true));
398+
caps_obj
399+
.entry("resolveAdditionalTextEditsSupport")
400+
.or_insert(json!(true));
386401

387402
if self.debugger().is_ok_and(|v| v.loaded()) {
388403
return Ok(Some(
389404
self.debugger()?
390-
.inject_plugin_into_options(options)
405+
.inject_plugin_into_options(Some(options))
391406
.map_err(|err| {
392407
format!("Failed to inject debugger plugin into options: {err}")
393408
})?,
394409
));
395410
}
396411

397-
Ok(options)
412+
Ok(Some(options))
398413
}
399414

400415
fn language_server_workspace_configuration(

0 commit comments

Comments
 (0)