Skip to content

Commit 5fc510b

Browse files
authored
Centralize stdin control semantics (#74)
1 parent 1c960f3 commit 5fc510b

9 files changed

Lines changed: 174 additions & 71 deletions

File tree

src/main.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ mod r_session;
2323
mod sandbox;
2424
mod sandbox_cli;
2525
mod server;
26+
mod stdin_payload;
2627
#[cfg(target_os = "windows")]
2728
mod windows_sandbox;
2829
mod worker;

src/server/response.rs

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use rmcp::model::{
1111
use serde_json::Value;
1212
use tempfile::Builder;
1313

14+
pub(crate) use crate::stdin_payload::{TimeoutBundleReuse, timeout_bundle_reuse_for_input};
1415
use crate::worker_process::WorkerError;
1516
use crate::worker_protocol::{
1617
ContentOrigin, TextStream, WorkerContent, WorkerErrorCode, WorkerReply,
@@ -170,31 +171,6 @@ struct TimeoutReplyView<'a> {
170171
protected_bundle_id: Option<u64>,
171172
}
172173

173-
#[derive(Clone, Copy)]
174-
pub(crate) enum TimeoutBundleReuse {
175-
None,
176-
FullReply,
177-
FollowUpInput,
178-
}
179-
180-
pub(crate) fn timeout_bundle_reuse_for_input(input: &str) -> TimeoutBundleReuse {
181-
if input.is_empty() {
182-
return TimeoutBundleReuse::FullReply;
183-
}
184-
185-
let Some(first) = input.chars().next() else {
186-
return TimeoutBundleReuse::FullReply;
187-
};
188-
let tail = &input[first.len_utf8()..];
189-
190-
match first {
191-
'\u{3}' if tail.is_empty() => TimeoutBundleReuse::FullReply,
192-
'\u{3}' => TimeoutBundleReuse::FollowUpInput,
193-
'\u{4}' => TimeoutBundleReuse::None,
194-
_ => TimeoutBundleReuse::FollowUpInput,
195-
}
196-
}
197-
198174
impl ResponseState {
199175
pub(crate) fn new() -> Result<Self, WorkerError> {
200176
Ok(Self {

src/server/tests.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,22 @@ fn timeout_bundle_reuse_treats_blank_lines_as_fresh_input() {
195195
super::response::timeout_bundle_reuse_for_input("\r\n"),
196196
super::response::TimeoutBundleReuse::FollowUpInput
197197
));
198+
assert!(matches!(
199+
super::response::timeout_bundle_reuse_for_input("\u{3}"),
200+
super::response::TimeoutBundleReuse::FullReply
201+
));
202+
assert!(matches!(
203+
super::response::timeout_bundle_reuse_for_input("\u{3}\n"),
204+
super::response::TimeoutBundleReuse::FollowUpInput
205+
));
206+
assert!(matches!(
207+
super::response::timeout_bundle_reuse_for_input("\u{3}\r\n"),
208+
super::response::TimeoutBundleReuse::FollowUpInput
209+
));
210+
assert!(matches!(
211+
super::response::timeout_bundle_reuse_for_input("\u{4}"),
212+
super::response::TimeoutBundleReuse::None
213+
));
198214
}
199215

200216
#[test]

src/stdin_payload.rs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
2+
pub(crate) enum WriteStdinControlAction {
3+
Interrupt,
4+
Restart,
5+
}
6+
7+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
8+
pub(crate) enum TimeoutBundleReuse {
9+
None,
10+
FullReply,
11+
FollowUpInput,
12+
}
13+
14+
pub(crate) fn prepare_worker_stdin_payload(input: &str) -> Vec<u8> {
15+
let mut payload = input.as_bytes().to_vec();
16+
if !payload.is_empty() && !payload.ends_with(b"\n") {
17+
payload.push(b'\n');
18+
}
19+
payload
20+
}
21+
22+
pub(crate) fn split_write_stdin_control_prefix(
23+
input: &str,
24+
) -> Option<(WriteStdinControlAction, &str)> {
25+
let first = input.chars().next()?;
26+
let action = match first {
27+
'\u{3}' => WriteStdinControlAction::Interrupt,
28+
'\u{4}' => WriteStdinControlAction::Restart,
29+
_ => return None,
30+
};
31+
Some((action, &input[first.len_utf8()..]))
32+
}
33+
34+
pub(crate) fn timeout_bundle_reuse_for_input(input: &str) -> TimeoutBundleReuse {
35+
if input.is_empty() {
36+
return TimeoutBundleReuse::FullReply;
37+
}
38+
39+
match split_write_stdin_control_prefix(input) {
40+
Some((WriteStdinControlAction::Interrupt, "")) => TimeoutBundleReuse::FullReply,
41+
Some((WriteStdinControlAction::Interrupt, _)) => TimeoutBundleReuse::FollowUpInput,
42+
Some((WriteStdinControlAction::Restart, _)) => TimeoutBundleReuse::None,
43+
None => TimeoutBundleReuse::FollowUpInput,
44+
}
45+
}

src/worker_process.rs

Lines changed: 13 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ use crate::sandbox_cli::{
5151
resolve_effective_sandbox_state_with_defaults, sandbox_plan_requests_inherited_state,
5252
validate_sandbox_plan_with_defaults,
5353
};
54+
use crate::stdin_payload::prepare_worker_stdin_payload;
55+
pub(crate) use crate::stdin_payload::{WriteStdinControlAction, split_write_stdin_control_prefix};
5456
use crate::worker_protocol::{
5557
ContentOrigin, TextStream, WORKER_MODE_ARG, WorkerContent, WorkerErrorCode, WorkerReply,
5658
};
@@ -276,7 +278,14 @@ fn prechecked_follow_up_requires_meta_error() -> WorkerError {
276278
}
277279

278280
trait BackendDriver: Send {
279-
fn prepare_input_payload(&self, text: &str) -> Vec<u8>;
281+
fn prepare_input_text(&self, text: String) -> String {
282+
text
283+
}
284+
285+
fn prepare_input_payload(&self, text: &str) -> Vec<u8> {
286+
prepare_worker_stdin_payload(text)
287+
}
288+
280289
fn on_input_start(
281290
&mut self,
282291
text: &str,
@@ -474,13 +483,8 @@ fn driver_refresh_worker_ready(
474483
}
475484

476485
impl BackendDriver for RBackendDriver {
477-
fn prepare_input_payload(&self, text: &str) -> Vec<u8> {
478-
let normalized = normalize_input_newlines(text);
479-
let mut payload = normalized.into_bytes();
480-
if !payload.is_empty() && !payload.ends_with(b"\n") {
481-
payload.push(b'\n');
482-
}
483-
payload
486+
fn prepare_input_text(&self, text: String) -> String {
487+
normalize_input_newlines(&text)
484488
}
485489

486490
fn on_input_start(
@@ -852,14 +856,6 @@ fn strip_one_line_ending(text: &str) -> Option<&str> {
852856

853857
#[cfg(not(target_family = "unix"))]
854858
impl BackendDriver for PythonBackendDriver {
855-
fn prepare_input_payload(&self, text: &str) -> Vec<u8> {
856-
let mut payload = text.as_bytes().to_vec();
857-
if !payload.is_empty() && !payload.ends_with(b"\n") {
858-
payload.push(b'\n');
859-
}
860-
payload
861-
}
862-
863859
fn on_input_start(
864860
&mut self,
865861
text: &str,
@@ -936,14 +932,6 @@ impl ProtocolBackendDriver {
936932
}
937933

938934
impl BackendDriver for ProtocolBackendDriver {
939-
fn prepare_input_payload(&self, text: &str) -> Vec<u8> {
940-
let mut payload = text.as_bytes().to_vec();
941-
if !payload.is_empty() && !payload.ends_with(b"\n") {
942-
payload.push(b'\n');
943-
}
944-
payload
945-
}
946-
947935
fn on_input_start(
948936
&mut self,
949937
_text: &str,
@@ -1198,12 +1186,6 @@ fn completion_info_from_ipc(
11981186

11991187
const DEFERRED_SANDBOX_UPDATE_TIMEOUT: Duration = Duration::from_secs(5);
12001188

1201-
#[derive(Clone, Copy)]
1202-
pub(crate) enum WriteStdinControlAction {
1203-
Interrupt,
1204-
Restart,
1205-
}
1206-
12071189
#[derive(Debug, Clone, Default)]
12081190
pub(crate) struct WriteStdinOptions {
12091191
pub page_bytes_override: Option<u64>,
@@ -1225,19 +1207,6 @@ impl WriteStdinOptions {
12251207
}
12261208
}
12271209

1228-
pub(crate) fn split_write_stdin_control_prefix(
1229-
input: &str,
1230-
) -> Option<(WriteStdinControlAction, &str)> {
1231-
let first = input.chars().next()?;
1232-
let action = match first {
1233-
'\u{3}' => WriteStdinControlAction::Interrupt,
1234-
'\u{4}' => WriteStdinControlAction::Restart,
1235-
_ => return None,
1236-
};
1237-
1238-
Some((action, &input[first.len_utf8()..]))
1239-
}
1240-
12411210
fn worker_context_event_payload(
12421211
worker_launch: &WorkerLaunch,
12431212
backend: Backend,
@@ -2566,6 +2535,7 @@ impl WorkerManager {
25662535
worker_timeout: Duration,
25672536
server_timeout: Duration,
25682537
) -> Result<RequestState, WorkerError> {
2538+
let text = self.driver.prepare_input_text(text);
25692539
let started_at = std::time::Instant::now();
25702540
let prompt = self.current_prompt_hint();
25712541
self.remember_prompt(prompt);

tests/common/mod.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -686,7 +686,7 @@ impl McpTestSession {
686686
#[allow(dead_code)]
687687
pub async fn write_stdin_with(&mut self, input: impl Into<String>, timeout: Option<f64>) {
688688
let mut input = input.into();
689-
if !input.ends_with('\n') {
689+
if test_input_needs_trailing_newline(&input) {
690690
input.push('\n');
691691
}
692692
let timeout = normalized_test_timeout(timeout);
@@ -814,7 +814,7 @@ impl McpTestSession {
814814
meta: Option<Value>,
815815
) -> Result<rmcp::model::CallToolResult, ServiceError> {
816816
let mut input = input.into();
817-
if !input.is_empty() && !input.ends_with('\n') {
817+
if !input.is_empty() && test_input_needs_trailing_newline(&input) {
818818
input.push('\n');
819819
}
820820
let timeout = normalized_test_timeout(timeout);
@@ -889,6 +889,10 @@ impl McpTestSession {
889889
}
890890
}
891891

892+
fn test_input_needs_trailing_newline(input: &str) -> bool {
893+
!input.ends_with('\n') && !matches!(input, "\u{3}" | "\u{4}")
894+
}
895+
892896
pub struct McpSnapshot {
893897
sessions: Vec<(String, Vec<SnapshotStep>)>,
894898
}

tests/fixtures/zod-worker.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
7373
next_prompt: "zod> ".to_string(),
7474
shutdown_mode: ShutdownMode::Normal,
7575
previous_line_empty: false,
76+
line_number: 0,
7677
shutdown_log_path: shutdown_log_path.clone(),
7778
};
7879
let mut timeline = Timeline::default();
@@ -90,6 +91,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
9091
send_session_end(&writer, &mut timeline, "shutdown")?;
9192
return Ok(());
9293
}
94+
command_state.line_number += 1;
9395

9496
let command = line.trim_end_matches(['\r', '\n']);
9597
let reported_input = if let Some(text) = command.strip_prefix("misreport-input ") {
@@ -166,6 +168,15 @@ fn run_command(
166168
return Ok(());
167169
}
168170

171+
if command.starts_with("raw-line-escape") {
172+
let escaped = escape_bytes(raw_line.as_bytes());
173+
writer.output_text(
174+
"stdout",
175+
format!("raw-line[{}]={escaped}\n", state.line_number).as_bytes(),
176+
)?;
177+
return Ok(());
178+
}
179+
169180
if let Some(millis) = command.strip_prefix("prompt-then-sleep ") {
170181
writer.send(&WorkerToServer::ReadlineStart {
171182
prompt: "buffered> ".to_string(),
@@ -298,6 +309,7 @@ struct CommandState {
298309
next_prompt: String,
299310
shutdown_mode: ShutdownMode,
300311
previous_line_empty: bool,
312+
line_number: u64,
301313
shutdown_log_path: Option<PathBuf>,
302314
}
303315

@@ -340,6 +352,21 @@ fn discard_buffered_stdin(reader: &mut dyn BufRead, writer: &IpcWriter) -> io::R
340352
writer.send(&WorkerToServer::ReadlineDiscard { text })
341353
}
342354

355+
fn escape_bytes(bytes: &[u8]) -> String {
356+
let mut escaped = String::new();
357+
for byte in bytes {
358+
match byte {
359+
b'\n' => escaped.push_str("\\n"),
360+
b'\r' => escaped.push_str("\\r"),
361+
b'\t' => escaped.push_str("\\t"),
362+
b'\\' => escaped.push_str("\\\\"),
363+
b' '..=b'~' => escaped.push(char::from(*byte)),
364+
_ => escaped.push_str(&format!("\\x{byte:02x}")),
365+
}
366+
}
367+
escaped
368+
}
369+
343370
fn send_readline_start(
344371
writer: &IpcWriter,
345372
timeline: &mut Timeline,

tests/snapshots/mcp_transcripts__snapshots_interrupt_handler_output.snap

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ call:
2828
{
2929
"tool": "r_repl",
3030
"arguments": {
31-
"input": "\u0003\n",
31+
"input": "\u0003",
3232
"timeout_ms": 5000
3333
}
3434
}

tests/zod_protocol.rs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,70 @@ async fn zod_worker_echoes_input_and_returns_worker_prompt() -> TestResult<()> {
179179
Ok(())
180180
}
181181

182+
#[tokio::test(flavor = "multi_thread")]
183+
async fn zod_worker_raw_line_escape_preserves_stdin_bytes() -> TestResult<()> {
184+
let session = spawn_zod_server().await?;
185+
186+
let result = session
187+
.call_tool_raw(
188+
"repl",
189+
json!({
190+
"input": "raw-line-escape crlf\r\nraw-line-escape bare\rcoda",
191+
"timeout_ms": 10_000
192+
}),
193+
)
194+
.await?;
195+
let text = result_text(&result);
196+
197+
assert!(
198+
text.contains("raw-line[1]=raw-line-escape crlf\\r\\n\n"),
199+
"expected Zod to receive existing CRLF bytes, got: {text:?}"
200+
);
201+
assert!(
202+
text.contains("raw-line[2]=raw-line-escape bare\\rcoda\\n\n"),
203+
"expected Zod to receive bare CR plus one appended LF, got: {text:?}"
204+
);
205+
206+
session.cancel().await?;
207+
Ok(())
208+
}
209+
210+
#[tokio::test(flavor = "multi_thread")]
211+
async fn zod_worker_restart_control_prefix_preserves_newline_tail() -> TestResult<()> {
212+
let session = spawn_zod_server().await?;
213+
214+
let result = session
215+
.call_tool_raw(
216+
"repl",
217+
json!({
218+
"input": "\u{4}\nraw-line-escape after",
219+
"timeout_ms": 10_000
220+
}),
221+
)
222+
.await?;
223+
let text = result_text(&result);
224+
225+
let poll = session
226+
.call_tool_raw(
227+
"repl",
228+
json!({
229+
"input": "",
230+
"timeout_ms": 10_000
231+
}),
232+
)
233+
.await?;
234+
let combined_text = format!("{text}{}", result_text(&poll));
235+
236+
assert!(
237+
combined_text
238+
.contains("[repl] new session started\nraw-line[2]=raw-line-escape after\\n\n"),
239+
"expected Ctrl-D tail to preserve the immediate newline before follow-up input, got: {text:?}"
240+
);
241+
242+
session.cancel().await?;
243+
Ok(())
244+
}
245+
182246
#[tokio::test(flavor = "multi_thread")]
183247
async fn zod_worker_pipe_launch_records_transport_and_starts_sideband() -> TestResult<()> {
184248
let tempdir = tempfile::tempdir()?;

0 commit comments

Comments
 (0)