@@ -6,13 +6,13 @@ mod lsp;
66mod platform;
77
88use completions:: { should_sort_completions, sort_completions_by_param_count} ;
9- use decompile:: rewrite_jdt_locations;
9+ use decompile:: { rewrite_jdt_in_strings , rewrite_jdt_locations} ;
1010use http:: handle_http;
11- use lsp:: { parse_lsp_content, write_raw, write_to_stdout, LspReader } ;
11+ use lsp:: { parse_lsp_content, raw_has_id , write_raw, write_to_stdout, LspReader } ;
1212use platform:: spawn_parent_monitor;
1313use serde_json:: Value ;
1414use std:: {
15- collections:: { HashMap , HashSet } ,
15+ collections:: HashMap ,
1616 env, fs,
1717 io:: { self , BufReader , Write } ,
1818 net:: TcpListener ,
@@ -25,6 +25,12 @@ use std::{
2525 thread,
2626} ;
2727
28+ #[ derive( Clone , Copy ) ]
29+ enum TrackedKind {
30+ Definition ,
31+ Doc ,
32+ }
33+
2834fn main ( ) {
2935 let args: Vec < String > = env:: args ( ) . skip ( 1 ) . collect ( ) ;
3036 if args. len ( ) < 2 {
@@ -90,29 +96,41 @@ fn main() {
9096
9197 let id_counter = Arc :: new ( AtomicU64 :: new ( 1 ) ) ;
9298
93- // Track definition/typeDefinition/implementation request IDs for jdt:// rewriting
94- let definition_ids: Arc < Mutex < HashSet < Value > > > = Arc :: new ( Mutex :: new ( HashSet :: new ( ) ) ) ;
99+ // Track definition/typeDefinition/implementation and documentation request IDs
100+ // so their responses can be intercepted and rewritten.
101+ let tracked_ids: Arc < Mutex < HashMap < Value , TrackedKind > > > = Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ;
95102
96103 // --- Thread 1: Zed stdin -> JDTLS stdin (track definition requests) ---
97104 let stdin_writer = Arc :: clone ( & child_stdin) ;
98105 let alive_stdin = Arc :: clone ( & alive) ;
99- let def_ids_in = Arc :: clone ( & definition_ids ) ;
106+ let tracked_in = Arc :: clone ( & tracked_ids ) ;
100107 thread:: spawn ( move || {
101108 let stdin = io:: stdin ( ) . lock ( ) ;
102- let mut reader = LspReader :: new ( stdin) ;
109+ let mut reader = LspReader :: new ( BufReader :: new ( stdin) ) ;
103110 while alive_stdin. load ( Ordering :: Relaxed ) {
104111 match reader. read_message ( ) {
105112 Ok ( Some ( raw) ) => {
106- if let Some ( msg) = parse_lsp_content ( & raw ) {
107- if let Some ( method) = msg. get ( "method" ) . and_then ( |m| m. as_str ( ) ) {
108- if matches ! (
109- method,
110- "textDocument/definition"
113+ // Only requests (not notifications) carry an `id`; skip the
114+ // JSON parse entirely for high-volume notifications like
115+ // textDocument/didChange.
116+ if raw_has_id ( & raw ) {
117+ if let Some ( msg) = parse_lsp_content ( & raw ) {
118+ if let Some ( method) = msg. get ( "method" ) . and_then ( |m| m. as_str ( ) ) {
119+ let kind = match method {
120+ "textDocument/definition"
111121 | "textDocument/typeDefinition"
112- | "textDocument/implementation"
113- ) {
114- if let Some ( id) = msg. get ( "id" ) . cloned ( ) {
115- def_ids_in. lock ( ) . unwrap ( ) . insert ( id) ;
122+ | "textDocument/implementation" => {
123+ Some ( TrackedKind :: Definition )
124+ }
125+ "textDocument/hover"
126+ | "textDocument/signatureHelp"
127+ | "completionItem/resolve" => Some ( TrackedKind :: Doc ) ,
128+ _ => None ,
129+ } ;
130+ if let Some ( kind) = kind {
131+ if let Some ( id) = msg. get ( "id" ) . cloned ( ) {
132+ tracked_in. lock ( ) . unwrap ( ) . insert ( id, kind) ;
133+ }
116134 }
117135 }
118136 }
@@ -131,7 +149,7 @@ fn main() {
131149 // --- Thread 2: JDTLS stdout -> rewrite jdt:// URIs, modify completions -> Zed stdout / resolve pending ---
132150 let pending_out = Arc :: clone ( & pending) ;
133151 let alive_out = Arc :: clone ( & alive) ;
134- let def_ids_out = Arc :: clone ( & definition_ids ) ;
152+ let tracked_out = Arc :: clone ( & tracked_ids ) ;
135153 let decompile_writer = Arc :: clone ( & child_stdin) ;
136154 let decompile_pending = Arc :: clone ( & pending) ;
137155 let decompile_counter = Arc :: clone ( & id_counter) ;
@@ -141,6 +159,13 @@ fn main() {
141159 while alive_out. load ( Ordering :: Relaxed ) {
142160 match reader. read_message ( ) {
143161 Ok ( Some ( raw) ) => {
162+ // Fast path: notifications (no `id`) can't be responses we
163+ // need to intercept. Forward the raw bytes without parsing.
164+ if !raw_has_id ( & raw ) {
165+ write_raw ( & mut io:: stdout ( ) . lock ( ) , & raw ) ;
166+ continue ;
167+ }
168+
144169 let Some ( mut msg) = parse_lsp_content ( & raw ) else {
145170 write_raw ( & mut io:: stdout ( ) . lock ( ) , & raw ) ;
146171 continue ;
@@ -154,11 +179,11 @@ fn main() {
154179 }
155180 }
156181
157- // Rewrite jdt:// URIs in definition responses
158- // Spawns a thread so this loop stays unblocked and can
159- // route the java/classFileContents response back via `pending`.
182+ // Rewrite jdt:// URIs in definition or documentation responses.
183+ // Spawns a thread so this loop stays unblocked and can route
184+ // the java/classFileContents response back via `pending`.
160185 if let Some ( id) = msg. get ( "id" ) . cloned ( ) {
161- if def_ids_out . lock ( ) . unwrap ( ) . remove ( & id) {
186+ if let Some ( kind ) = tracked_out . lock ( ) . unwrap ( ) . remove ( & id) {
162187 let writer = Arc :: clone ( & decompile_writer) ;
163188 let pending = Arc :: clone ( & decompile_pending) ;
164189 let pid = decompile_proxy_id. clone ( ) ;
@@ -168,7 +193,24 @@ fn main() {
168193 let seq = counter. fetch_add ( 1 , Ordering :: Relaxed ) ;
169194 Value :: String ( format ! ( "{pid}-decompile-{seq}" ) )
170195 } ;
171- rewrite_jdt_locations ( & mut msg, & writer, & pending, & mut next_id) ;
196+ match kind {
197+ TrackedKind :: Definition => {
198+ rewrite_jdt_locations (
199+ & mut msg,
200+ & writer,
201+ & pending,
202+ & mut next_id,
203+ ) ;
204+ }
205+ TrackedKind :: Doc => {
206+ rewrite_jdt_in_strings (
207+ & mut msg,
208+ & writer,
209+ & pending,
210+ & mut next_id,
211+ ) ;
212+ }
213+ }
172214 write_to_stdout ( & msg) ;
173215 } ) ;
174216 continue ;
0 commit comments