Skip to content

Commit 11c4432

Browse files
authored
Add jdt:// URI rewriting for documentation responses (#224)
Extends the LSP proxy to resolve jdt:// URIs embedded inside documentation responses, and applies a few performance optimizations along the way
1 parent c55c60a commit 11c4432

3 files changed

Lines changed: 152 additions & 22 deletions

File tree

proxy/src/decompile.rs

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,80 @@ pub fn rewrite_jdt_locations(
141141
}
142142
rewritten
143143
}
144+
145+
/// Returns the byte length of the `jdt://` URI that begins at the start of `s`.
///
/// The URI is considered to stop at the first whitespace character or at the
/// first delimiter typically found around links in markdown and JSON text
/// (`)`, `]`, `"`, `>`, a backtick or `'`). If no such character occurs, the
/// whole of `s` is treated as the URI and `s.len()` is returned.
///
/// The scan relies on the URI carrying these characters only in URL-encoded
/// form, so the first literal occurrence marks the end of the URI.
fn jdt_uri_end(s: &str) -> usize {
    const TERMINATORS: [char; 6] = [')', ']', '"', '>', '`', '\''];

    for (offset, ch) in s.char_indices() {
        if ch.is_whitespace() || TERMINATORS.contains(&ch) {
            return offset;
        }
    }
    s.len()
}
153+
154+
/// Extract all unique `jdt://` URIs appearing inside any string in `value`.
155+
fn collect_jdt_uris(value: &Value, out: &mut Vec<String>) {
156+
match value {
157+
Value::String(s) => {
158+
let mut rest = s.as_str();
159+
while let Some(pos) = rest.find("jdt://") {
160+
let tail = &rest[pos..];
161+
let end = jdt_uri_end(tail);
162+
let uri = tail[..end].to_string();
163+
if !out.contains(&uri) {
164+
out.push(uri);
165+
}
166+
rest = &tail[end..];
167+
}
168+
}
169+
Value::Array(arr) => arr.iter().for_each(|v| collect_jdt_uris(v, out)),
170+
Value::Object(obj) => obj.values().for_each(|v| collect_jdt_uris(v, out)),
171+
_ => {}
172+
}
173+
}
174+
175+
/// Replace all occurrences of any key in `map` with its value, inside every string
176+
/// contained in `value` (recursively).
177+
fn replace_in_strings(value: &mut Value, map: &HashMap<String, String>) {
178+
match value {
179+
Value::String(s) => {
180+
for (from, to) in map {
181+
if s.contains(from.as_str()) {
182+
*s = s.replace(from.as_str(), to);
183+
}
184+
}
185+
}
186+
Value::Array(arr) => arr.iter_mut().for_each(|v| replace_in_strings(v, map)),
187+
Value::Object(obj) => obj.values_mut().for_each(|v| replace_in_strings(v, map)),
188+
_ => {}
189+
}
190+
}
191+
192+
/// Scan a documentation response (hover, signatureHelp, completionItem/resolve, …)
193+
/// for embedded `jdt://` URIs, resolve each one to a `file://` URI backed by a temp
194+
/// file, and replace the URIs in-place in every string of `msg.result`.
195+
pub fn rewrite_jdt_in_strings(
196+
msg: &mut Value,
197+
writer: &Arc<Mutex<impl Write>>,
198+
pending: &Arc<Mutex<HashMap<Value, mpsc::Sender<Value>>>>,
199+
next_id: &mut impl FnMut() -> Value,
200+
) {
201+
let Some(result) = msg.get_mut("result") else {
202+
return;
203+
};
204+
205+
let mut uris = Vec::new();
206+
collect_jdt_uris(result, &mut uris);
207+
if uris.is_empty() {
208+
return;
209+
}
210+
211+
let mut map = HashMap::new();
212+
for uri in uris {
213+
if let Some(file_uri) = resolve_jdt_uri(&uri, writer, pending, next_id()) {
214+
map.insert(uri, file_uri);
215+
}
216+
}
217+
if !map.is_empty() {
218+
replace_in_strings(result, &map);
219+
}
220+
}

proxy/src/lsp.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,17 @@ pub fn parse_lsp_content(raw: &[u8]) -> Option<serde_json::Value> {
5454
serde_json::from_slice(&raw[sep_pos + 4..]).ok()
5555
}
5656

57+
/// Cheap check for the presence of an `"id"` key in the JSON body of a raw LSP
58+
/// message. Used to skip full JSON parsing for notifications, which carry no
59+
/// `id` and therefore cannot be responses or completion results.
60+
pub fn raw_has_id(raw: &[u8]) -> bool {
61+
let Some(sep_pos) = raw.windows(4).position(|w| w == HEADER_SEP) else {
62+
return false;
63+
};
64+
let body = &raw[sep_pos + 4..];
65+
body.windows(5).any(|w| w == b"\"id\":")
66+
}
67+
5768
pub fn encode_lsp(value: &impl Serialize) -> String {
5869
let json = serde_json::to_string(value).unwrap();
5970
format!("{CONTENT_LENGTH}: {}\r\n\r\n{json}", json.len())

proxy/src/main.rs

Lines changed: 64 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@ mod lsp;
66
mod platform;
77

88
use completions::{should_sort_completions, sort_completions_by_param_count};
9-
use decompile::rewrite_jdt_locations;
9+
use decompile::{rewrite_jdt_in_strings, rewrite_jdt_locations};
1010
use http::handle_http;
11-
use lsp::{parse_lsp_content, write_raw, write_to_stdout, LspReader};
11+
use lsp::{parse_lsp_content, raw_has_id, write_raw, write_to_stdout, LspReader};
1212
use platform::spawn_parent_monitor;
1313
use serde_json::Value;
1414
use std::{
15-
collections::{HashMap, HashSet},
15+
collections::HashMap,
1616
env, fs,
1717
io::{self, BufReader, Write},
1818
net::TcpListener,
@@ -25,6 +25,12 @@ use std::{
2525
thread,
2626
};
2727

28+
/// The kind of rewriting to apply to the response of a tracked request.
#[derive(Clone, Copy)]
enum TrackedKind {
    /// The response is handled by `rewrite_jdt_locations`.
    Definition,
    /// The response is handled by `rewrite_jdt_in_strings`.
    Doc,
}
33+
2834
fn main() {
2935
let args: Vec<String> = env::args().skip(1).collect();
3036
if args.len() < 2 {
@@ -90,29 +96,41 @@ fn main() {
9096

9197
let id_counter = Arc::new(AtomicU64::new(1));
9298

93-
// Track definition/typeDefinition/implementation request IDs for jdt:// rewriting
94-
let definition_ids: Arc<Mutex<HashSet<Value>>> = Arc::new(Mutex::new(HashSet::new()));
99+
// Track definition/typeDefinition/implementation and documentation request IDs
100+
// so their responses can be intercepted and rewritten.
101+
let tracked_ids: Arc<Mutex<HashMap<Value, TrackedKind>>> = Arc::new(Mutex::new(HashMap::new()));
95102

96103
// --- Thread 1: Zed stdin -> JDTLS stdin (track definition and documentation requests) ---
97104
let stdin_writer = Arc::clone(&child_stdin);
98105
let alive_stdin = Arc::clone(&alive);
99-
let def_ids_in = Arc::clone(&definition_ids);
106+
let tracked_in = Arc::clone(&tracked_ids);
100107
thread::spawn(move || {
101108
let stdin = io::stdin().lock();
102-
let mut reader = LspReader::new(stdin);
109+
let mut reader = LspReader::new(BufReader::new(stdin));
103110
while alive_stdin.load(Ordering::Relaxed) {
104111
match reader.read_message() {
105112
Ok(Some(raw)) => {
106-
if let Some(msg) = parse_lsp_content(&raw) {
107-
if let Some(method) = msg.get("method").and_then(|m| m.as_str()) {
108-
if matches!(
109-
method,
110-
"textDocument/definition"
113+
// Only requests (not notifications) carry an `id`; skip the
114+
// JSON parse entirely for high-volume notifications like
115+
// textDocument/didChange.
116+
if raw_has_id(&raw) {
117+
if let Some(msg) = parse_lsp_content(&raw) {
118+
if let Some(method) = msg.get("method").and_then(|m| m.as_str()) {
119+
let kind = match method {
120+
"textDocument/definition"
111121
| "textDocument/typeDefinition"
112-
| "textDocument/implementation"
113-
) {
114-
if let Some(id) = msg.get("id").cloned() {
115-
def_ids_in.lock().unwrap().insert(id);
122+
| "textDocument/implementation" => {
123+
Some(TrackedKind::Definition)
124+
}
125+
"textDocument/hover"
126+
| "textDocument/signatureHelp"
127+
| "completionItem/resolve" => Some(TrackedKind::Doc),
128+
_ => None,
129+
};
130+
if let Some(kind) = kind {
131+
if let Some(id) = msg.get("id").cloned() {
132+
tracked_in.lock().unwrap().insert(id, kind);
133+
}
116134
}
117135
}
118136
}
@@ -131,7 +149,7 @@ fn main() {
131149
// --- Thread 2: JDTLS stdout -> rewrite jdt:// URIs, modify completions -> Zed stdout / resolve pending ---
132150
let pending_out = Arc::clone(&pending);
133151
let alive_out = Arc::clone(&alive);
134-
let def_ids_out = Arc::clone(&definition_ids);
152+
let tracked_out = Arc::clone(&tracked_ids);
135153
let decompile_writer = Arc::clone(&child_stdin);
136154
let decompile_pending = Arc::clone(&pending);
137155
let decompile_counter = Arc::clone(&id_counter);
@@ -141,6 +159,13 @@ fn main() {
141159
while alive_out.load(Ordering::Relaxed) {
142160
match reader.read_message() {
143161
Ok(Some(raw)) => {
162+
// Fast path: notifications (no `id`) can't be responses we
163+
// need to intercept. Forward the raw bytes without parsing.
164+
if !raw_has_id(&raw) {
165+
write_raw(&mut io::stdout().lock(), &raw);
166+
continue;
167+
}
168+
144169
let Some(mut msg) = parse_lsp_content(&raw) else {
145170
write_raw(&mut io::stdout().lock(), &raw);
146171
continue;
@@ -154,11 +179,11 @@ fn main() {
154179
}
155180
}
156181

157-
// Rewrite jdt:// URIs in definition responses
158-
// Spawns a thread so this loop stays unblocked and can
159-
// route the java/classFileContents response back via `pending`.
182+
// Rewrite jdt:// URIs in definition or documentation responses.
183+
// Spawns a thread so this loop stays unblocked and can route
184+
// the java/classFileContents response back via `pending`.
160185
if let Some(id) = msg.get("id").cloned() {
161-
if def_ids_out.lock().unwrap().remove(&id) {
186+
if let Some(kind) = tracked_out.lock().unwrap().remove(&id) {
162187
let writer = Arc::clone(&decompile_writer);
163188
let pending = Arc::clone(&decompile_pending);
164189
let pid = decompile_proxy_id.clone();
@@ -168,7 +193,24 @@ fn main() {
168193
let seq = counter.fetch_add(1, Ordering::Relaxed);
169194
Value::String(format!("{pid}-decompile-{seq}"))
170195
};
171-
rewrite_jdt_locations(&mut msg, &writer, &pending, &mut next_id);
196+
match kind {
197+
TrackedKind::Definition => {
198+
rewrite_jdt_locations(
199+
&mut msg,
200+
&writer,
201+
&pending,
202+
&mut next_id,
203+
);
204+
}
205+
TrackedKind::Doc => {
206+
rewrite_jdt_in_strings(
207+
&mut msg,
208+
&writer,
209+
&pending,
210+
&mut next_id,
211+
);
212+
}
213+
}
172214
write_to_stdout(&msg);
173215
});
174216
continue;

0 commit comments

Comments
 (0)