Skip to content

Commit 0726567

Browse files
committed
Fix timeout on Windows ARM
1 parent 11444ad commit 0726567

4 files changed

Lines changed: 120 additions & 21 deletions

File tree

crates/amalthea/src/fixtures/dummy_frontend.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,7 @@ impl DummyFrontend {
316316
}
317317

318318
#[track_caller]
319+
#[cfg(not(all(target_os = "windows", target_arch = "aarch64")))]
319320
pub fn recv(socket: &Socket) -> Message {
320321
// It's important to wait with a timeout because the kernel thread might have
321322
// panicked, preventing it from sending the expected message. The tests would then
@@ -335,6 +336,28 @@ impl DummyFrontend {
335336
panic!("Timeout while expecting message on socket {}", socket.name);
336337
}
337338

339+
/// On Windows ARM, ZMQ poll with a non-zero timeout blocks forever
340+
/// instead of respecting the timeout. Use non-blocking poll with
341+
/// manual timing, matching the workaround in `recv_iopub_with_timeout`.
342+
#[track_caller]
343+
#[cfg(all(target_os = "windows", target_arch = "aarch64"))]
344+
pub fn recv(socket: &Socket) -> Message {
345+
let start = std::time::Instant::now();
346+
let timeout = std::time::Duration::from_secs(30);
347+
348+
loop {
349+
if start.elapsed() >= timeout {
350+
panic!("Timeout while expecting message on socket {}", socket.name);
351+
}
352+
353+
match socket.poll_incoming(0) {
354+
Ok(true) => return Message::read_from_socket(socket).unwrap(),
355+
Ok(false) => std::thread::sleep(std::time::Duration::from_millis(10)),
356+
Err(e) => panic!("Error polling socket {}: {e}", socket.name),
357+
}
358+
}
359+
}
360+
338361
/// Receives a Jupyter message from the Control socket
339362
#[track_caller]
340363
pub fn recv_control(&self) -> Message {

crates/amalthea/src/kernel.rs

Lines changed: 50 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -531,37 +531,69 @@ fn socket_bridge_thread(
531531
};
532532

533533
loop {
534-
let n = unwrap!(
535-
zmq::poll(&mut poll_items, -1),
536-
Err(err) => {
537-
debug_panic!("While polling 0MQ items: {err:?}");
538-
0
539-
}
540-
);
541-
542-
for _ in 0..n {
543-
if consume_outbound_notification() {
544-
forward_outbound();
545-
continue;
546-
}
547-
534+
// On Windows ARM, zmq::poll with a non-zero timeout blocks forever
535+
// and inproc notification sockets may not wake the poll at all.
536+
// Use a fully separate polling path that doesn't rely on ZMQ
537+
// readability reporting: non-blocking poll + unconditional drain
538+
// of all sources + short sleep when idle.
539+
#[cfg(all(target_os = "windows", target_arch = "aarch64"))]
540+
{
541+
// Drain outbound messages (IOPub, StdIn) unconditionally
542+
consume_outbound_notification();
543+
forward_outbound();
544+
545+
// Check inbound sockets with non-blocking poll
548546
if has_inbound(&stdin_socket) {
549547
unwrap!(
550548
forward_inbound(&stdin_socket, &stdin_inbound_tx),
551549
Err(err) => debug_panic!("While forwarding inbound message: {err:?}")
552550
);
553-
continue;
554551
}
555-
556552
if has_inbound(&iopub_socket) {
557553
unwrap!(
558554
forward_inbound_subscription(&iopub_socket, &iopub_inbound_tx),
559555
Err(err) => debug_panic!("While forwarding inbound message: {err:?}")
560556
);
561-
continue;
562557
}
563558

564-
debug_panic!("Could not find readable message");
559+
std::thread::sleep(std::time::Duration::from_millis(1));
560+
continue;
561+
}
562+
563+
#[cfg(not(all(target_os = "windows", target_arch = "aarch64")))]
564+
{
565+
let n = unwrap!(
566+
zmq::poll(&mut poll_items, -1),
567+
Err(err) => {
568+
debug_panic!("While polling 0MQ items: {err:?}");
569+
0
570+
}
571+
);
572+
573+
for _ in 0..n {
574+
if consume_outbound_notification() {
575+
forward_outbound();
576+
continue;
577+
}
578+
579+
if has_inbound(&stdin_socket) {
580+
unwrap!(
581+
forward_inbound(&stdin_socket, &stdin_inbound_tx),
582+
Err(err) => debug_panic!("While forwarding inbound message: {err:?}")
583+
);
584+
continue;
585+
}
586+
587+
if has_inbound(&iopub_socket) {
588+
unwrap!(
589+
forward_inbound_subscription(&iopub_socket, &iopub_inbound_tx),
590+
Err(err) => debug_panic!("While forwarding inbound message: {err:?}")
591+
);
592+
continue;
593+
}
594+
595+
debug_panic!("Could not find readable message");
596+
}
565597
}
566598
}
567599
}

crates/amalthea/src/socket.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,10 +214,38 @@ impl Socket {
214214
}
215215
}
216216

217+
#[cfg(not(all(target_os = "windows", target_arch = "aarch64")))]
217218
pub fn poll_incoming(&self, timeout_ms: i64) -> zmq::Result<bool> {
218219
Ok(self.socket.poll(zmq::PollEvents::POLLIN, timeout_ms)? != 0)
219220
}
220221

222+
/// On Windows ARM, ZMQ poll with a non-zero timeout blocks forever
223+
/// instead of respecting the timeout. Use non-blocking poll with
224+
/// manual timing.
225+
#[cfg(all(target_os = "windows", target_arch = "aarch64"))]
226+
pub fn poll_incoming(&self, timeout_ms: i64) -> zmq::Result<bool> {
227+
if timeout_ms == 0 {
228+
return Ok(self.socket.poll(zmq::PollEvents::POLLIN, 0)? != 0);
229+
}
230+
231+
let start = std::time::Instant::now();
232+
let timeout = if timeout_ms < 0 {
233+
std::time::Duration::from_secs(u64::MAX / 2)
234+
} else {
235+
std::time::Duration::from_millis(timeout_ms as u64)
236+
};
237+
238+
loop {
239+
if self.socket.poll(zmq::PollEvents::POLLIN, 0)? != 0 {
240+
return Ok(true);
241+
}
242+
if start.elapsed() >= timeout {
243+
return Ok(false);
244+
}
245+
std::thread::sleep(std::time::Duration::from_millis(1));
246+
}
247+
}
248+
221249
pub fn has_incoming_data(&self) -> zmq::Result<bool> {
222250
self.poll_incoming(0)
223251
}

crates/amalthea/src/socket/shell.rs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,25 @@ impl Shell {
123123
self.comm_notif_socket.socket.as_poll_item(zmq::POLLIN),
124124
];
125125

126-
// -1 means block indefinitely
127-
match zmq::poll(&mut poll_items, -1) {
128-
Ok(0) => continue,
126+
// On Windows ARM, zmq::poll with a non-zero timeout blocks
127+
// forever instead of respecting the timeout.
128+
#[cfg(all(target_os = "windows", target_arch = "aarch64"))]
129+
let poll_timeout_ms: i64 = 0;
130+
#[cfg(not(all(target_os = "windows", target_arch = "aarch64")))]
131+
let poll_timeout_ms: i64 = -1;
132+
133+
match zmq::poll(&mut poll_items, poll_timeout_ms) {
134+
Ok(0) => {
135+
// On Windows ARM, check for comm events even when
136+
// poll reports nothing ready since inproc
137+
// notifications may not wake the poll.
138+
#[cfg(all(target_os = "windows", target_arch = "aarch64"))]
139+
{
140+
self.process_comm_notification();
141+
std::thread::sleep(std::time::Duration::from_millis(1));
142+
}
143+
continue;
144+
},
129145
Ok(_) => (poll_items[0].is_readable(), poll_items[1].is_readable()),
130146
Err(err) => {
131147
log::warn!("Could not poll shell sockets: {err:?}");

0 commit comments

Comments
 (0)