6060JSONRPC_DOWNSTREAM_TIMEOUT = - 32014
6161DEFAULT_DOWNSTREAM_RESPONSE_TIMEOUT_SECONDS = 30.0
6262MAX_DOWNSTREAM_MESSAGE_BYTES = 1 * 1024 * 1024
63+ MAX_CLIENT_MESSAGE_BYTES = 1 * 1024 * 1024
64+ MAX_PENDING_RESPONSES = 1000
65+ DEFAULT_TIMED_OUT_ID_RETENTION_SECONDS = 600.0
6366SAFE_ENV_KEYS = (
6467 "PATH" ,
6568 "HOME" ,
@@ -287,6 +290,46 @@ def jsonrpc_error(
287290 }
288291
289292
293+ def _read_bounded_line (client_in : TextIO , max_bytes : int ) -> tuple [str | None , bool ]:
294+ read = getattr (client_in , "read" , None )
295+ if not callable (read ):
296+ try :
297+ raw_line = next (client_in ) # type: ignore[arg-type]
298+ except StopIteration :
299+ return None , False
300+ raw_bytes = raw_line .encode ("utf-8" , errors = "replace" )
301+ if not raw_line .endswith ("\n " ) or len (raw_bytes .rstrip (b"\n " )) > max_bytes :
302+ return "" , True
303+ return raw_line , False
304+
305+ chunks : list [str ] = []
306+ byte_count = 0
307+ while True :
308+ char = read (1 )
309+ if char == "" :
310+ if chunks :
311+ return "" , True
312+ return None , False
313+ if char == "\n " :
314+ return "" .join (chunks ) + "\n " , False
315+ char_size = len (char .encode ("utf-8" , errors = "replace" ))
316+ if byte_count + char_size > max_bytes :
317+ _discard_line_remainder (client_in )
318+ return "" , True
319+ chunks .append (char )
320+ byte_count += char_size
321+
322+
323+ def _discard_line_remainder (client_in : TextIO ) -> None :
324+ read = getattr (client_in , "read" , None )
325+ if not callable (read ):
326+ return
327+ while True :
328+ char = read (1 )
329+ if char in {"" , "\n " }:
330+ return
331+
332+
290333def _blocked_error (
291334 request_id : Any ,
292335 message : str ,
@@ -356,8 +399,11 @@ def __init__(
356399 self ._runtime_gate_startup_error : Exception | None = None
357400 self ._runtime_gate_errors = 0
358401 self ._downstream_timeouts = 0
402+ self ._client_oversized_messages = 0
403+ self ._unsolicited_downstream_responses = 0
359404 self ._security_events : Deque [Mapping [str , Any ]] = deque (maxlen = 1000 )
360- self ._timed_out_response_ids : set [str ] = set ()
405+ self ._inflight_ids : set [str ] = set ()
406+ self ._timed_out_response_ids : dict [str , float ] = {}
361407 self ._windows_job : _WindowsJobObject | None = None
362408
363409 @property
@@ -384,6 +430,18 @@ def downstream_timeouts(self) -> int:
384430
385431 return self ._downstream_timeouts
386432
433+ @property
434+ def client_oversized_messages (self ) -> int :
435+ """Number of oversized or unterminated client messages rejected."""
436+
437+ return self ._client_oversized_messages
438+
439+ @property
440+ def unsolicited_downstream_responses (self ) -> int :
441+ """Number of downstream responses dropped for unknown client request IDs."""
442+
443+ return self ._unsolicited_downstream_responses
444+
387445 @property
388446 def security_events (self ) -> tuple [Mapping [str , Any ], ...]:
389447 """Sanitized in-memory security events for P5 failure handling."""
@@ -496,7 +554,21 @@ def run_stdio(self, client_in: TextIO, client_out: TextIO) -> int:
496554 self ._notification_writer = lambda message : self ._write_client (client_out , message )
497555 self .start ()
498556 try :
499- for raw_line in client_in :
557+ while True :
558+ raw_line , rejected = _read_bounded_line (client_in , MAX_CLIENT_MESSAGE_BYTES )
559+ if rejected :
560+ self ._increment_client_oversized_messages ()
561+ self ._write_client (
562+ client_out ,
563+ jsonrpc_error (
564+ None ,
565+ JSONRPC_INVALID_REQUEST ,
566+ "client request exceeds maximum size" ,
567+ ),
568+ )
569+ continue
570+ if raw_line is None :
571+ break
500572 if not raw_line .strip ():
501573 continue
502574 responses = self .handle_client_line (raw_line )
@@ -528,12 +600,17 @@ def handle_client_line(self, raw_line: str) -> list[dict[str, Any]]:
528600 policy_error , approval_outcome = self ._policy_error_response (classification , request_id )
529601 if policy_error is not None :
530602 return [policy_error ] if has_id else []
531- self ._send_downstream (message )
532- if not has_id :
533- return []
534- response = self ._wait_downstream_response (request_id )
535- self ._record_approval_result (approval_outcome , response )
536- return [response ]
603+ response_key = self ._register_inflight_id (request_id ) if has_id else None
604+ try :
605+ self ._send_downstream (message )
606+ if not has_id :
607+ return []
608+ response = self ._wait_downstream_response (request_id )
609+ self ._record_approval_result (approval_outcome , response )
610+ return [response ]
611+ finally :
612+ if response_key is not None :
613+ self ._unregister_inflight_id (response_key )
537614 except DownstreamTimeoutError :
538615 self ._increment_downstream_timeouts ()
539616 self ._record_approval_error (approval_outcome , "downstream_response_timeout" )
@@ -801,11 +878,31 @@ def _increment_downstream_timeouts(self) -> None:
801878 with self ._counters_lock :
802879 self ._downstream_timeouts += 1
803880
881+ def _increment_client_oversized_messages (self ) -> None :
882+ with self ._counters_lock :
883+ self ._client_oversized_messages += 1
884+
885+ def _increment_unsolicited_downstream_responses (self ) -> None :
886+ with self ._counters_lock :
887+ self ._unsolicited_downstream_responses += 1
888+
889+ def _register_inflight_id (self , request_id : Any ) -> str :
890+ response_key = self ._id_key (request_id )
891+ with self ._stdout_condition :
892+ self ._inflight_ids .add (response_key )
893+ return response_key
894+
895+ def _unregister_inflight_id (self , response_key : str ) -> None :
896+ with self ._stdout_condition :
897+ self ._inflight_ids .discard (response_key )
898+ self ._prune_pending_responses_locked ()
899+
804900 def _wait_downstream_response (self , expected_id : Any ) -> dict [str , Any ]:
805901 response_key = self ._id_key (expected_id )
806902 deadline = time .monotonic () + self .downstream .response_timeout_seconds
807903 with self ._stdout_condition :
808904 while True :
905+ self ._prune_timed_out_ids_locked ()
809906 queued = self ._responses .get (response_key )
810907 if queued :
811908 response = queued .pop (0 )
@@ -816,7 +913,9 @@ def _wait_downstream_response(self, expected_id: Any) -> dict[str, Any]:
816913 raise self ._downstream_error
817914 remaining = deadline - time .monotonic ()
818915 if remaining <= 0 :
819- self ._timed_out_response_ids .add (response_key )
916+ self ._timed_out_response_ids [
917+ response_key
918+ ] = time .monotonic () + DEFAULT_TIMED_OUT_ID_RETENTION_SECONDS
820919 raise DownstreamTimeoutError ("downstream response timed out" )
821920 self ._stdout_condition .wait (timeout = remaining )
822921
@@ -911,13 +1010,46 @@ def _handle_downstream_message(self, response: Any) -> None:
9111010 return
9121011 if "id" in response :
9131012 with self ._stdout_condition :
1013+ self ._prune_timed_out_ids_locked ()
9141014 response_key = self ._id_key (response .get ("id" ))
9151015 if response_key in self ._timed_out_response_ids :
916- self ._timed_out_response_ids .remove (response_key )
1016+ self ._timed_out_response_ids .pop (response_key , None )
1017+ return
1018+ if response_key not in self ._inflight_ids :
1019+ self ._increment_unsolicited_downstream_responses ()
9171020 return
9181021 self ._responses .setdefault (response_key , []).append (response )
1022+ self ._prune_pending_responses_locked ()
9191023 self ._stdout_condition .notify_all ()
9201024
1025+ def _prune_timed_out_ids_locked (self , now : float | None = None ) -> None :
1026+ now = time .monotonic () if now is None else now
1027+ expired = [
1028+ response_key
1029+ for response_key , expires_at in self ._timed_out_response_ids .items ()
1030+ if expires_at <= now
1031+ ]
1032+ for response_key in expired :
1033+ self ._timed_out_response_ids .pop (response_key , None )
1034+
1035+ def _prune_pending_responses_locked (self ) -> None :
1036+ pending_count = sum (len (responses ) for responses in self ._responses .values ())
1037+ while pending_count > MAX_PENDING_RESPONSES :
1038+ dropped = False
1039+ for response_key , responses in list (self ._responses .items ()):
1040+ if response_key in self ._inflight_ids :
1041+ continue
1042+ if responses :
1043+ responses .pop (0 )
1044+ pending_count -= 1
1045+ dropped = True
1046+ if not responses :
1047+ self ._responses .pop (response_key , None )
1048+ if pending_count <= MAX_PENDING_RESPONSES :
1049+ return
1050+ if not dropped :
1051+ return
1052+
9211053 def _downstream_buffer_too_large (self , buffer : str ) -> bool :
9221054 return len (buffer .encode ("utf-8" , errors = "replace" )) > MAX_DOWNSTREAM_MESSAGE_BYTES
9231055
0 commit comments