Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ mod r_session;
mod sandbox;
mod sandbox_cli;
mod server;
mod stdin_payload;
#[cfg(target_os = "windows")]
mod windows_sandbox;
mod worker;
Expand Down
26 changes: 1 addition & 25 deletions src/server/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use rmcp::model::{
use serde_json::Value;
use tempfile::Builder;

pub(crate) use crate::stdin_payload::{TimeoutBundleReuse, timeout_bundle_reuse_for_input};
use crate::worker_process::WorkerError;
use crate::worker_protocol::{
ContentOrigin, TextStream, WorkerContent, WorkerErrorCode, WorkerReply,
Expand Down Expand Up @@ -170,31 +171,6 @@ struct TimeoutReplyView<'a> {
protected_bundle_id: Option<u64>,
}

#[derive(Clone, Copy)]
pub(crate) enum TimeoutBundleReuse {
None,
FullReply,
FollowUpInput,
}

pub(crate) fn timeout_bundle_reuse_for_input(input: &str) -> TimeoutBundleReuse {
if input.is_empty() {
return TimeoutBundleReuse::FullReply;
}

let Some(first) = input.chars().next() else {
return TimeoutBundleReuse::FullReply;
};
let tail = &input[first.len_utf8()..];

match first {
'\u{3}' if tail.is_empty() => TimeoutBundleReuse::FullReply,
'\u{3}' => TimeoutBundleReuse::FollowUpInput,
'\u{4}' => TimeoutBundleReuse::None,
_ => TimeoutBundleReuse::FollowUpInput,
}
}

impl ResponseState {
pub(crate) fn new() -> Result<Self, WorkerError> {
Ok(Self {
Expand Down
16 changes: 16 additions & 0 deletions src/server/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,22 @@ fn timeout_bundle_reuse_treats_blank_lines_as_fresh_input() {
super::response::timeout_bundle_reuse_for_input("\r\n"),
super::response::TimeoutBundleReuse::FollowUpInput
));
assert!(matches!(
super::response::timeout_bundle_reuse_for_input("\u{3}"),
super::response::TimeoutBundleReuse::FullReply
));
assert!(matches!(
super::response::timeout_bundle_reuse_for_input("\u{3}\n"),
super::response::TimeoutBundleReuse::FollowUpInput
));
assert!(matches!(
super::response::timeout_bundle_reuse_for_input("\u{3}\r\n"),
super::response::TimeoutBundleReuse::FollowUpInput
));
assert!(matches!(
super::response::timeout_bundle_reuse_for_input("\u{4}"),
super::response::TimeoutBundleReuse::None
));
}

#[test]
Expand Down
45 changes: 45 additions & 0 deletions src/stdin_payload.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum WriteStdinControlAction {
Interrupt,
Restart,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum TimeoutBundleReuse {
None,
FullReply,
FollowUpInput,
}

pub(crate) fn prepare_worker_stdin_payload(input: &str) -> Vec<u8> {
let mut payload = input.as_bytes().to_vec();
if !payload.is_empty() && !payload.ends_with(b"\n") {
payload.push(b'\n');
}
payload
}

pub(crate) fn split_write_stdin_control_prefix(
input: &str,
) -> Option<(WriteStdinControlAction, &str)> {
let first = input.chars().next()?;
let action = match first {
'\u{3}' => WriteStdinControlAction::Interrupt,
'\u{4}' => WriteStdinControlAction::Restart,
_ => return None,
};
Some((action, &input[first.len_utf8()..]))
}

pub(crate) fn timeout_bundle_reuse_for_input(input: &str) -> TimeoutBundleReuse {
if input.is_empty() {
return TimeoutBundleReuse::FullReply;
}

match split_write_stdin_control_prefix(input) {
Some((WriteStdinControlAction::Interrupt, "")) => TimeoutBundleReuse::FullReply,
Some((WriteStdinControlAction::Interrupt, _)) => TimeoutBundleReuse::FollowUpInput,
Some((WriteStdinControlAction::Restart, _)) => TimeoutBundleReuse::None,
None => TimeoutBundleReuse::FollowUpInput,
}
}
56 changes: 13 additions & 43 deletions src/worker_process.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ use crate::sandbox_cli::{
resolve_effective_sandbox_state_with_defaults, sandbox_plan_requests_inherited_state,
validate_sandbox_plan_with_defaults,
};
use crate::stdin_payload::prepare_worker_stdin_payload;
pub(crate) use crate::stdin_payload::{WriteStdinControlAction, split_write_stdin_control_prefix};
use crate::worker_protocol::{
ContentOrigin, TextStream, WORKER_MODE_ARG, WorkerContent, WorkerErrorCode, WorkerReply,
};
Expand Down Expand Up @@ -276,7 +278,14 @@ fn prechecked_follow_up_requires_meta_error() -> WorkerError {
}

trait BackendDriver: Send {
fn prepare_input_payload(&self, text: &str) -> Vec<u8>;
fn prepare_input_text(&self, text: String) -> String {
text
}

fn prepare_input_payload(&self, text: &str) -> Vec<u8> {
prepare_worker_stdin_payload(text)
}

fn on_input_start(
&mut self,
text: &str,
Expand Down Expand Up @@ -474,13 +483,8 @@ fn driver_refresh_worker_ready(
}

impl BackendDriver for RBackendDriver {
fn prepare_input_payload(&self, text: &str) -> Vec<u8> {
let normalized = normalize_input_newlines(text);
let mut payload = normalized.into_bytes();
if !payload.is_empty() && !payload.ends_with(b"\n") {
payload.push(b'\n');
}
payload
fn prepare_input_text(&self, text: String) -> String {
normalize_input_newlines(&text)
}

fn on_input_start(
Expand Down Expand Up @@ -852,14 +856,6 @@ fn strip_one_line_ending(text: &str) -> Option<&str> {

#[cfg(not(target_family = "unix"))]
impl BackendDriver for PythonBackendDriver {
fn prepare_input_payload(&self, text: &str) -> Vec<u8> {
let mut payload = text.as_bytes().to_vec();
if !payload.is_empty() && !payload.ends_with(b"\n") {
payload.push(b'\n');
}
payload
}

fn on_input_start(
&mut self,
text: &str,
Expand Down Expand Up @@ -936,14 +932,6 @@ impl ProtocolBackendDriver {
}

impl BackendDriver for ProtocolBackendDriver {
fn prepare_input_payload(&self, text: &str) -> Vec<u8> {
let mut payload = text.as_bytes().to_vec();
if !payload.is_empty() && !payload.ends_with(b"\n") {
payload.push(b'\n');
}
payload
}

fn on_input_start(
&mut self,
_text: &str,
Expand Down Expand Up @@ -1198,12 +1186,6 @@ fn completion_info_from_ipc(

const DEFERRED_SANDBOX_UPDATE_TIMEOUT: Duration = Duration::from_secs(5);

#[derive(Clone, Copy)]
pub(crate) enum WriteStdinControlAction {
Interrupt,
Restart,
}

#[derive(Debug, Clone, Default)]
pub(crate) struct WriteStdinOptions {
pub page_bytes_override: Option<u64>,
Expand All @@ -1225,19 +1207,6 @@ impl WriteStdinOptions {
}
}

pub(crate) fn split_write_stdin_control_prefix(
input: &str,
) -> Option<(WriteStdinControlAction, &str)> {
let first = input.chars().next()?;
let action = match first {
'\u{3}' => WriteStdinControlAction::Interrupt,
'\u{4}' => WriteStdinControlAction::Restart,
_ => return None,
};

Some((action, &input[first.len_utf8()..]))
}

fn worker_context_event_payload(
worker_launch: &WorkerLaunch,
backend: Backend,
Expand Down Expand Up @@ -2566,6 +2535,7 @@ impl WorkerManager {
worker_timeout: Duration,
server_timeout: Duration,
) -> Result<RequestState, WorkerError> {
let text = self.driver.prepare_input_text(text);
let started_at = std::time::Instant::now();
let prompt = self.current_prompt_hint();
self.remember_prompt(prompt);
Expand Down
8 changes: 6 additions & 2 deletions tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ impl McpTestSession {
#[allow(dead_code)]
pub async fn write_stdin_with(&mut self, input: impl Into<String>, timeout: Option<f64>) {
let mut input = input.into();
if !input.ends_with('\n') {
if test_input_needs_trailing_newline(&input) {
input.push('\n');
}
let timeout = normalized_test_timeout(timeout);
Expand Down Expand Up @@ -814,7 +814,7 @@ impl McpTestSession {
meta: Option<Value>,
) -> Result<rmcp::model::CallToolResult, ServiceError> {
let mut input = input.into();
if !input.is_empty() && !input.ends_with('\n') {
if !input.is_empty() && test_input_needs_trailing_newline(&input) {
input.push('\n');
}
let timeout = normalized_test_timeout(timeout);
Expand Down Expand Up @@ -889,6 +889,10 @@ impl McpTestSession {
}
}

fn test_input_needs_trailing_newline(input: &str) -> bool {
!input.ends_with('\n') && !matches!(input, "\u{3}" | "\u{4}")
}

pub struct McpSnapshot {
sessions: Vec<(String, Vec<SnapshotStep>)>,
}
Expand Down
27 changes: 27 additions & 0 deletions tests/fixtures/zod-worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
next_prompt: "zod> ".to_string(),
shutdown_mode: ShutdownMode::Normal,
previous_line_empty: false,
line_number: 0,
shutdown_log_path: shutdown_log_path.clone(),
};
let mut timeline = Timeline::default();
Expand All @@ -90,6 +91,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
send_session_end(&writer, &mut timeline, "shutdown")?;
return Ok(());
}
command_state.line_number += 1;

let command = line.trim_end_matches(['\r', '\n']);
let reported_input = if let Some(text) = command.strip_prefix("misreport-input ") {
Expand Down Expand Up @@ -166,6 +168,15 @@ fn run_command(
return Ok(());
}

if command.starts_with("raw-line-escape") {
let escaped = escape_bytes(raw_line.as_bytes());
writer.output_text(
"stdout",
format!("raw-line[{}]={escaped}\n", state.line_number).as_bytes(),
)?;
return Ok(());
}

if let Some(millis) = command.strip_prefix("prompt-then-sleep ") {
writer.send(&WorkerToServer::ReadlineStart {
prompt: "buffered> ".to_string(),
Expand Down Expand Up @@ -298,6 +309,7 @@ struct CommandState {
next_prompt: String,
shutdown_mode: ShutdownMode,
previous_line_empty: bool,
line_number: u64,
shutdown_log_path: Option<PathBuf>,
}

Expand Down Expand Up @@ -340,6 +352,21 @@ fn discard_buffered_stdin(reader: &mut dyn BufRead, writer: &IpcWriter) -> io::R
writer.send(&WorkerToServer::ReadlineDiscard { text })
}

fn escape_bytes(bytes: &[u8]) -> String {
let mut escaped = String::new();
for byte in bytes {
match byte {
b'\n' => escaped.push_str("\\n"),
b'\r' => escaped.push_str("\\r"),
b'\t' => escaped.push_str("\\t"),
b'\\' => escaped.push_str("\\\\"),
b' '..=b'~' => escaped.push(char::from(*byte)),
_ => escaped.push_str(&format!("\\x{byte:02x}")),
}
}
escaped
}

fn send_readline_start(
writer: &IpcWriter,
timeline: &mut Timeline,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ call:
{
"tool": "r_repl",
"arguments": {
"input": "\u0003\n",
"input": "\u0003",
"timeout_ms": 5000
}
}
Expand Down
64 changes: 64 additions & 0 deletions tests/zod_protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,70 @@ async fn zod_worker_echoes_input_and_returns_worker_prompt() -> TestResult<()> {
Ok(())
}

#[tokio::test(flavor = "multi_thread")]
async fn zod_worker_raw_line_escape_preserves_stdin_bytes() -> TestResult<()> {
let session = spawn_zod_server().await?;

let result = session
.call_tool_raw(
"repl",
json!({
"input": "raw-line-escape crlf\r\nraw-line-escape bare\rcoda",
"timeout_ms": 10_000
}),
)
.await?;
let text = result_text(&result);

assert!(
text.contains("raw-line[1]=raw-line-escape crlf\\r\\n\n"),
"expected Zod to receive existing CRLF bytes, got: {text:?}"
);
assert!(
text.contains("raw-line[2]=raw-line-escape bare\\rcoda\\n\n"),
"expected Zod to receive bare CR plus one appended LF, got: {text:?}"
);

session.cancel().await?;
Ok(())
}

#[tokio::test(flavor = "multi_thread")]
async fn zod_worker_restart_control_prefix_preserves_newline_tail() -> TestResult<()> {
let session = spawn_zod_server().await?;

let result = session
.call_tool_raw(
"repl",
json!({
"input": "\u{4}\nraw-line-escape after",
"timeout_ms": 10_000
}),
)
.await?;
let text = result_text(&result);

let poll = session
.call_tool_raw(
"repl",
json!({
"input": "",
"timeout_ms": 10_000
}),
)
.await?;
let combined_text = format!("{text}{}", result_text(&poll));

assert!(
combined_text
.contains("[repl] new session started\nraw-line[2]=raw-line-escape after\\n\n"),
"expected Ctrl-D tail to preserve the immediate newline before follow-up input, got: {text:?}"
);

session.cancel().await?;
Ok(())
}

#[tokio::test(flavor = "multi_thread")]
async fn zod_worker_pipe_launch_records_transport_and_starts_sideband() -> TestResult<()> {
let tempdir = tempfile::tempdir()?;
Expand Down
Loading