forked from zed-extensions/java
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.rs
More file actions
231 lines (203 loc) · 8.44 KB
/
main.rs
File metadata and controls
231 lines (203 loc) · 8.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
mod completions;
mod decompile;
mod http;
mod log;
mod lsp;
mod platform;
use completions::{should_sort_completions, sort_completions_by_param_count};
use decompile::rewrite_jdt_locations;
use http::handle_http;
use lsp::{parse_lsp_content, write_raw, write_to_stdout, LspReader};
use platform::spawn_parent_monitor;
use serde_json::Value;
use std::{
collections::{HashMap, HashSet},
env, fs,
io::{self, BufReader, Write},
net::TcpListener,
path::Path,
process::{self, Command, Stdio},
sync::{
atomic::{AtomicBool, AtomicU64, Ordering},
mpsc, Arc, Mutex,
},
thread,
};
/// Entry point of `java-lsp-proxy`: runs the language server given on the command line as a
/// child process and relays LSP messages between this process's stdio and the child's stdio.
///
/// Usage: `java-lsp-proxy <workdir> <bin> [args...]`.
///
/// On top of plain relaying, the proxy:
/// - records the ids of `textDocument/definition`, `textDocument/typeDefinition` and
///   `textDocument/implementation` requests so that the matching responses are passed through
///   `rewrite_jdt_locations` before being written to stdout;
/// - re-sorts responses selected by `should_sort_completions`;
/// - binds a loopback TCP listener whose connections are served by `handle_http`, and writes
///   the chosen port to `<workdir>/proxy/<hex of the current directory>`.
///
/// The process exits with status 1 when arguments are missing or a startup step fails.
/// Otherwise it blocks until the child exits and then removes the port file.
fn main() {
    let args: Vec<String> = env::args().skip(1).collect();
    if args.len() < 2 {
        exit_with_error("Usage: java-lsp-proxy <workdir> <bin> [args...]");
    }
    let workdir = &args[0];
    let bin = &args[1];
    let child_args = &args[2..];
    lsp_info!("java-lsp-proxy starting: bin={bin}, workdir={workdir}");

    // The proxy id is the hex-encoded current directory with trailing '/' removed; it names
    // the port file and prefixes the ids of requests generated by the proxy itself.
    let current_dir = env::current_dir().unwrap_or_else(|e| {
        exit_with_error(&format!("Failed to determine current directory: {e}"))
    });
    let proxy_id = hex_encode(current_dir.to_string_lossy().trim_end_matches('/'));

    // Spawn JDTLS (use shell on Windows for .bat files)
    let mut cmd = Command::new(bin);
    cmd.args(child_args)
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .stderr(Stdio::inherit());
    #[cfg(windows)]
    if bin.ends_with(".bat") || bin.ends_with(".cmd") {
        cmd = Command::new("cmd");
        cmd.arg("/C")
            .arg(bin)
            .args(child_args)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit());
    }
    let mut child = cmd
        .spawn()
        .unwrap_or_else(|e| exit_with_error(&format!("Failed to spawn {bin}: {e}")));
    lsp_info!("JDTLS process spawned (pid={})", child.id());

    let child_stdin = Arc::new(Mutex::new(
        child
            .stdin
            .take()
            .expect("child stdin was configured as piped"),
    ));
    let child_stdout = child
        .stdout
        .take()
        .expect("child stdout was configured as piped");
    let alive = Arc::new(AtomicBool::new(true));
    // Requests issued by the proxy itself, keyed by id; the stdout thread delivers the
    // matching response through the stored sender.
    let pending: Arc<Mutex<HashMap<Value, mpsc::Sender<Value>>>> =
        Arc::new(Mutex::new(HashMap::new()));

    // Port 0 lets the OS pick a free port; the port file publishes the choice.
    let (listener, port) = TcpListener::bind("127.0.0.1:0")
        .and_then(|listener| -> io::Result<(TcpListener, u16)> {
            let port = listener.local_addr()?.port();
            Ok((listener, port))
        })
        .unwrap_or_else(|e| {
            abort_startup(&mut child, &format!("Failed to bind HTTP listener: {e}"))
        });
    let port_file = Path::new(workdir).join("proxy").join(&proxy_id);
    let port_dir = port_file
        .parent()
        .expect("port file path always has a parent directory");
    if let Err(e) =
        fs::create_dir_all(port_dir).and_then(|()| fs::write(&port_file, port.to_string()))
    {
        abort_startup(
            &mut child,
            &format!("Failed to write port file {}: {e}", port_file.display()),
        );
    }
    lsp_info!("HTTP server listening on 127.0.0.1:{port}");

    let id_counter = Arc::new(AtomicU64::new(1));
    // Track definition/typeDefinition/implementation request IDs for jdt:// rewriting
    let definition_ids: Arc<Mutex<HashSet<Value>>> = Arc::new(Mutex::new(HashSet::new()));

    // --- Thread 1: Zed stdin -> JDTLS stdin (track definition requests) ---
    let stdin_writer = Arc::clone(&child_stdin);
    let alive_stdin = Arc::clone(&alive);
    let def_ids_in = Arc::clone(&definition_ids);
    thread::spawn(move || {
        let stdin = io::stdin().lock();
        let mut reader = LspReader::new(stdin);
        while alive_stdin.load(Ordering::Relaxed) {
            match reader.read_message() {
                Ok(Some(raw)) => {
                    if let Some(msg) = parse_lsp_content(&raw) {
                        if let Some(method) = msg.get("method").and_then(|m| m.as_str()) {
                            if matches!(
                                method,
                                "textDocument/definition"
                                    | "textDocument/typeDefinition"
                                    | "textDocument/implementation"
                            ) {
                                if let Some(id) = msg.get("id").cloned() {
                                    def_ids_in.lock().unwrap().insert(id);
                                }
                            }
                        }
                    }
                    // Forward the original bytes unchanged, whether or not they parsed.
                    let mut w = stdin_writer.lock().unwrap();
                    if w.write_all(&raw).is_err() || w.flush().is_err() {
                        break;
                    }
                }
                Ok(None) | Err(_) => break,
            }
        }
        alive_stdin.store(false, Ordering::Relaxed);
    });

    // --- Thread 2: JDTLS stdout -> rewrite jdt:// URIs, modify completions -> Zed stdout / resolve pending ---
    let pending_out = Arc::clone(&pending);
    let alive_out = Arc::clone(&alive);
    let def_ids_out = Arc::clone(&definition_ids);
    let decompile_writer = Arc::clone(&child_stdin);
    let decompile_pending = Arc::clone(&pending);
    let decompile_counter = Arc::clone(&id_counter);
    let decompile_proxy_id = proxy_id.clone();
    thread::spawn(move || {
        let mut reader = LspReader::new(BufReader::new(child_stdout));
        while alive_out.load(Ordering::Relaxed) {
            match reader.read_message() {
                Ok(Some(raw)) => {
                    let Some(mut msg) = parse_lsp_content(&raw) else {
                        write_raw(&mut io::stdout().lock(), &raw);
                        continue;
                    };
                    // Request ids are chosen independently by each side, so a request sent
                    // by the server may reuse an id that is tracked here. Only a response
                    // (an `id` and no `method`) may consume a pending or definition entry.
                    let response_id = if msg.get("method").is_none() {
                        msg.get("id").cloned()
                    } else {
                        None
                    };
                    if let Some(id) = response_id {
                        // Route responses to pending HTTP requests
                        if let Some(tx) = pending_out.lock().unwrap().remove(&id) {
                            let _ = tx.send(msg);
                            continue;
                        }
                        // Rewrite jdt:// URIs in definition responses
                        // Spawns a thread so this loop stays unblocked and can
                        // route the java/classFileContents response back via `pending`.
                        if def_ids_out.lock().unwrap().remove(&id) {
                            let writer = Arc::clone(&decompile_writer);
                            let pending = Arc::clone(&decompile_pending);
                            let pid = decompile_proxy_id.clone();
                            let counter = Arc::clone(&decompile_counter);
                            thread::spawn(move || {
                                let mut next_id = move || {
                                    let seq = counter.fetch_add(1, Ordering::Relaxed);
                                    Value::String(format!("{pid}-decompile-{seq}"))
                                };
                                rewrite_jdt_locations(&mut msg, &writer, &pending, &mut next_id);
                                write_to_stdout(&msg);
                            });
                            continue;
                        }
                    }
                    // Sort completion responses by param count
                    if should_sort_completions(&msg) {
                        sort_completions_by_param_count(&mut msg);
                        write_to_stdout(&msg);
                        continue;
                    }
                    // Passthrough
                    write_raw(&mut io::stdout().lock(), &raw);
                }
                Ok(None) | Err(_) => break,
            }
        }
        alive_out.store(false, Ordering::Relaxed);
    });

    // --- Thread 3: HTTP server for extension requests ---
    let http_writer = Arc::clone(&child_stdin);
    let http_pending = Arc::clone(&pending);
    let http_alive = Arc::clone(&alive);
    let http_id_counter = Arc::clone(&id_counter);
    let http_proxy_id = proxy_id.clone();
    thread::spawn(move || {
        for stream in listener.incoming() {
            if !http_alive.load(Ordering::Relaxed) {
                break;
            }
            let Ok(stream) = stream else { continue };
            let writer = Arc::clone(&http_writer);
            let pend = Arc::clone(&http_pending);
            let counter = Arc::clone(&http_id_counter);
            let pid = http_proxy_id.clone();
            // One thread per connection so the accept loop is never blocked by a handler.
            thread::spawn(move || {
                handle_http(stream, writer, pend, counter, &pid);
            });
        }
    });

    // --- Thread 4: Parent process monitor ---
    spawn_parent_monitor(Arc::clone(&alive), child.id());

    // Wait for child to exit
    let status = child.wait();
    lsp_info!("JDTLS process exited: {status:?}");
    alive.store(false, Ordering::Relaxed);
    // Best effort: a stale port file is harmless enough that a failure here is ignored.
    let _ = fs::remove_file(&port_file);
}

/// Reports `message` on stderr and through `lsp_error!`, then exits with status 1.
fn exit_with_error(message: &str) -> ! {
    eprintln!("{message}");
    lsp_error!("{message}");
    process::exit(1);
}

/// Stops the already-spawned child, then reports `message` and exits with status 1.
///
/// Used for failures after the language server has started, so it is not left running
/// without a proxy.
fn abort_startup(child: &mut process::Child, message: &str) -> ! {
    // Errors are ignored: the child may already have exited on its own.
    let _ = child.kill();
    let _ = child.wait();
    exit_with_error(message)
}
// --- Utilities ---
/// Encodes the UTF-8 bytes of `s` as lowercase hexadecimal, two digits per byte.
///
/// The result is twice as long as `s` in bytes; an empty input yields an empty string.
fn hex_encode(s: &str) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut encoded = String::with_capacity(s.len() * 2);
    for &byte in s.as_bytes() {
        // High nibble first, so each byte reads like `{:02x}`.
        encoded.push(char::from(DIGITS[usize::from(byte >> 4)]));
        encoded.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    encoded
}