Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions crates/amalthea/src/fixtures/dummy_frontend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@ use crate::registration_file::RegistrationFile;
use crate::session::Session;
use crate::socket::Socket;
use crate::wire::comm_msg::CommWireMsg;
use crate::wire::debug_request::DebugRequest;
use crate::wire::execute_input::ExecuteInput;
use crate::wire::execute_request::ExecuteRequest;
use crate::wire::execute_request::ExecuteRequestPositron;
use crate::wire::handshake_reply::HandshakeReply;
use crate::wire::input_reply::InputReply;
use crate::wire::interrupt_request::InterruptRequest;
use crate::wire::jupyter_message::JupyterMessage;
use crate::wire::jupyter_message::Message;
use crate::wire::jupyter_message::ProtocolMessage;
Expand Down Expand Up @@ -235,6 +237,10 @@ impl DummyFrontend {
self.send_control(ShutdownRequest { restart })
}

pub fn send_interrupt_request(&self) -> String {
self.send_control(InterruptRequest {})
}

pub fn send_execute_request(&self, code: &str, options: ExecuteRequestOptions) -> String {
self.send_shell(ExecuteRequest {
code: String::from(code),
Expand All @@ -247,6 +253,56 @@ impl DummyFrontend {
})
}

/// Send an execute request with custom metadata (e.g. `cellId` for notebook debugging).
pub fn send_execute_request_with_metadata(
&self,
code: &str,
options: ExecuteRequestOptions,
metadata: serde_json::Value,
) -> String {
Self::send_with_metadata(
&self.shell_socket,
&self.session,
ExecuteRequest {
code: String::from(code),
silent: false,
store_history: true,
user_expressions: serde_json::Value::Null,
allow_stdin: options.allow_stdin,
stop_on_error: false,
positron: options.positron,
},
metadata,
)
}

/// Send a DAP request wrapped in a Jupyter `debug_request` on the control channel.
pub fn send_debug_request(&self, dap_request: serde_json::Value) -> String {
self.send_control(DebugRequest {
content: dap_request,
})
}

/// Receive a `debug_reply` from the control channel.
#[track_caller]
pub fn recv_debug_reply(&self) -> serde_json::Value {
let msg = Self::recv(&self.control_socket);
match msg {
Message::DebugReply(msg) => msg.content.content,
other => panic!("Expected DebugReply, got {other:?}"),
}
}

/// Receive a `debug_event` from the IOPub channel.
#[track_caller]
pub fn recv_iopub_debug_event(&self) -> serde_json::Value {
let msg = Self::recv(&self.iopub_socket);
match msg {
Message::DebugEvent(msg) => msg.content.content,
other => panic!("Expected DebugEvent, got {other:?}"),
}
}

/// Sends a Jupyter message on the Stdin socket
pub fn send_stdin<T: ProtocolMessage>(&self, msg: T) {
Self::send(&self.stdin_socket, &self.session, msg);
Expand All @@ -259,6 +315,19 @@ impl DummyFrontend {
id
}

fn send_with_metadata<T: ProtocolMessage>(
socket: &Socket,
session: &Session,
msg: T,
metadata: serde_json::Value,
) -> String {
let mut message = JupyterMessage::create(msg, None, session);
message.metadata = metadata;
let id = message.header.msg_id.clone();
message.send(socket).unwrap();
id
}

#[track_caller]
pub fn recv(socket: &Socket) -> Message {
// It's important to wait with a timeout because the kernel thread might have
Expand Down Expand Up @@ -312,6 +381,15 @@ impl DummyFrontend {
})
}

/// Receive from Control and assert `InterruptReply` message.
#[track_caller]
pub fn recv_control_interrupt_reply(&self) {
let message = self.recv_control();
assert_matches!(message, Message::InterruptReply(message) => {
assert_eq!(message.content.status, Status::Ok);
});
}

/// Receive from Shell and assert `ExecuteReply` message.
/// Returns `execution_count`.
#[track_caller]
Expand Down
72 changes: 52 additions & 20 deletions crates/amalthea/src/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ pub enum StreamBehavior {
/// Handler implementations provided by the language runtime.
pub struct Handlers {
pub shell_handler: Box<dyn ShellHandler>,
pub control_handler: Arc<Mutex<dyn ControlHandler>>,
pub control_handler: Box<dyn ControlHandler>,
pub server_handlers: HashMap<String, Arc<Mutex<dyn ServerHandler>>>,
}

Expand Down Expand Up @@ -379,7 +379,7 @@ pub fn read_connection(connection_file: &str) -> (ConnectionFile, Option<Registr
fn control_thread(
socket: Socket,
iopub_tx: Sender<IOPubMessage>,
handler: Arc<Mutex<dyn ControlHandler>>,
handler: Box<dyn ControlHandler>,
stdin_interrupt_tx: Sender<bool>,
) {
let control = Control::new(socket, iopub_tx, handler, stdin_interrupt_tx);
Expand Down Expand Up @@ -531,37 +531,69 @@ fn socket_bridge_thread(
};

loop {
let n = unwrap!(
zmq::poll(&mut poll_items, -1),
Err(err) => {
debug_panic!("While polling 0MQ items: {err:?}");
0
}
);

for _ in 0..n {
if consume_outbound_notification() {
forward_outbound();
continue;
}

// On Windows ARM, zmq::poll with a non-zero timeout blocks forever
// and inproc notification sockets may not wake the poll at all.
// Use a fully separate polling path that doesn't rely on ZMQ
// readability reporting: non-blocking poll + unconditional drain
// of all sources + short sleep when idle.
#[cfg(all(target_os = "windows", target_arch = "aarch64"))]
{
// Drain outbound messages (IOPub, StdIn) unconditionally
consume_outbound_notification();
forward_outbound();

// Check inbound sockets with non-blocking poll
if has_inbound(&stdin_socket) {
unwrap!(
forward_inbound(&stdin_socket, &stdin_inbound_tx),
Err(err) => debug_panic!("While forwarding inbound message: {err:?}")
);
continue;
}

if has_inbound(&iopub_socket) {
unwrap!(
forward_inbound_subscription(&iopub_socket, &iopub_inbound_tx),
Err(err) => debug_panic!("While forwarding inbound message: {err:?}")
);
continue;
}

debug_panic!("Could not find readable message");
std::thread::sleep(std::time::Duration::from_millis(1));
continue;
}

#[cfg(not(all(target_os = "windows", target_arch = "aarch64")))]
{
let n = unwrap!(
zmq::poll(&mut poll_items, -1),
Err(err) => {
debug_panic!("While polling 0MQ items: {err:?}");
0
}
);

for _ in 0..n {
if consume_outbound_notification() {
forward_outbound();
continue;
}

if has_inbound(&stdin_socket) {
unwrap!(
forward_inbound(&stdin_socket, &stdin_inbound_tx),
Err(err) => debug_panic!("While forwarding inbound message: {err:?}")
);
continue;
}

if has_inbound(&iopub_socket) {
unwrap!(
forward_inbound_subscription(&iopub_socket, &iopub_inbound_tx),
Err(err) => debug_panic!("While forwarding inbound message: {err:?}")
);
continue;
}

debug_panic!("Could not find readable message");
}
}
}
}
Expand Down
18 changes: 10 additions & 8 deletions crates/amalthea/src/language/control_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,29 @@
*
*/

use async_trait::async_trait;

use crate::wire::debug_reply::DebugReply;
use crate::wire::debug_request::DebugRequest;
use crate::wire::exception::Exception;
use crate::wire::interrupt_reply::InterruptReply;
use crate::wire::shutdown_reply::ShutdownReply;
use crate::wire::shutdown_request::ShutdownRequest;

#[async_trait]
pub trait ControlHandler: Send {
/// Handles a request to shut down the kernel. This message is forwarded
/// from the Control socket.
///
/// https://jupyter-client.readthedocs.io/en/stable/messaging.html#kernel-shutdown
async fn handle_shutdown_request(
&self,
msg: &ShutdownRequest,
) -> Result<ShutdownReply, Exception>;
fn handle_shutdown_request(&self, msg: &ShutdownRequest) -> Result<ShutdownReply, Exception>;

/// Handles a request to interrupt the kernel. This message is forwarded
/// from the Control socket.
///
/// https://jupyter-client.readthedocs.io/en/stable/messaging.html#kernel-interrupt
async fn handle_interrupt_request(&self) -> Result<InterruptReply, Exception>;
fn handle_interrupt_request(&self) -> Result<InterruptReply, Exception>;

/// Handles a debug request forwarded from the Control socket.
/// The request and reply contents are opaque DAP messages.
///
/// https://jupyter-client.readthedocs.io/en/latest/messaging.html#debug-request
fn handle_debug_request(&self, msg: &DebugRequest) -> Result<DebugReply, Exception>;
}
28 changes: 28 additions & 0 deletions crates/amalthea/src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,38 @@ impl Socket {
}
}

#[cfg(not(all(target_os = "windows", target_arch = "aarch64")))]
pub fn poll_incoming(&self, timeout_ms: i64) -> zmq::Result<bool> {
Ok(self.socket.poll(zmq::PollEvents::POLLIN, timeout_ms)? != 0)
}

/// On Windows ARM, ZMQ poll with a non-zero timeout blocks forever
/// instead of respecting the timeout. Use non-blocking poll with
/// manual timing.
#[cfg(all(target_os = "windows", target_arch = "aarch64"))]
pub fn poll_incoming(&self, timeout_ms: i64) -> zmq::Result<bool> {
if timeout_ms == 0 {
return Ok(self.socket.poll(zmq::PollEvents::POLLIN, 0)? != 0);
}

let start = std::time::Instant::now();
let timeout = if timeout_ms < 0 {
std::time::Duration::from_secs(u64::MAX / 2)
} else {
std::time::Duration::from_millis(timeout_ms as u64)
};

loop {
if self.socket.poll(zmq::PollEvents::POLLIN, 0)? != 0 {
return Ok(true);
}
if start.elapsed() >= timeout {
return Ok(false);
}
std::thread::sleep(std::time::Duration::from_millis(1));
}
}

pub fn has_incoming_data(&self) -> zmq::Result<bool> {
self.poll_incoming(0)
}
Expand Down
35 changes: 21 additions & 14 deletions crates/amalthea/src/socket/control.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,21 @@
*
*/

use std::sync::Arc;
use std::sync::Mutex;

use crossbeam::channel::SendError;
use crossbeam::channel::Sender;
use futures::executor::block_on;
use log::error;
use log::info;
use log::trace;
use log::warn;
use stdext::result::ResultExt;
use stdext::unwrap;

use crate::error::Error;
use crate::language::control_handler::ControlHandler;
use crate::socket::iopub::IOPubContextChannel;
use crate::socket::iopub::IOPubMessage;
use crate::socket::Socket;
use crate::wire::debug_request::DebugRequest;
use crate::wire::interrupt_request::InterruptRequest;
use crate::wire::jupyter_message::JupyterMessage;
use crate::wire::jupyter_message::Message;
Expand All @@ -33,15 +31,15 @@ use crate::wire::status::KernelStatus;
pub struct Control {
socket: Socket,
iopub_tx: Sender<IOPubMessage>,
handler: Arc<Mutex<dyn ControlHandler>>,
handler: Box<dyn ControlHandler>,
stdin_interrupt_tx: Sender<bool>,
}

impl Control {
pub fn new(
socket: Socket,
iopub_tx: Sender<IOPubMessage>,
handler: Arc<Mutex<dyn ControlHandler>>,
handler: Box<dyn ControlHandler>,
stdin_interrupt_tx: Sender<bool>,
) -> Self {
Self {
Expand Down Expand Up @@ -73,6 +71,9 @@ impl Control {

fn process_message(&self, message: Message) -> Result<(), Error> {
match message {
Message::DebugRequest(req) => {
self.handle_request(req, |r| self.handle_debug_request(r))
},
Message::ShutdownRequest(req) => {
self.handle_request(req, |r| self.handle_shutdown_request(r))
},
Expand Down Expand Up @@ -130,11 +131,8 @@ impl Control {
fn handle_shutdown_request(&self, req: JupyterMessage<ShutdownRequest>) -> Result<(), Error> {
info!("Received shutdown request, shutting down kernel: {:?}", req);

// Lock the control handler object on this thread
let control_handler = self.handler.lock().unwrap();

let reply = unwrap!(
block_on(control_handler.handle_shutdown_request(&req.content)),
self.handler.handle_shutdown_request(&req.content),
Err(err) => {
log::error!("Failed to handle shutdown request: {err:?}");
return Ok(())
Expand All @@ -156,6 +154,18 @@ impl Control {
Ok(())
}

fn handle_debug_request(&self, req: JupyterMessage<DebugRequest>) -> Result<(), Error> {
log::trace!("Received debug request: {:?}", req);

let Some(reply) = self.handler.handle_debug_request(&req.content).log_err() else {
return Ok(());
};

req.send_reply(reply, &self.socket).log_err();

Ok(())
Comment on lines +157 to +166
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's interesting how we never actually return an Error from these methods...

}

fn handle_interrupt_request(&self, req: JupyterMessage<InterruptRequest>) -> Result<(), Error> {
info!(
"Received interrupt request, asking kernel to stop: {:?}",
Expand All @@ -169,11 +179,8 @@ impl Control {
error!("Failed to send interrupt request: {:?}", err);
}

// Lock the control handler object on this thread
let control_handler = self.handler.lock().unwrap();

let reply = unwrap!(
block_on(control_handler.handle_interrupt_request()),
self.handler.handle_interrupt_request(),
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happened to allow us to remove all of this and the async ness of the methods? Have we never needed it?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We've never needed amalthea handler methods to be async, and should move towards sync over time. I made that clean up for Control as I was in the neighborhood.

Err(err) => {
log::error!("Failed to handle interrupt request: {err:?}");
return Ok(())
Expand Down
Loading
Loading