diff --git a/crates/shim-protos/Cargo.toml b/crates/shim-protos/Cargo.toml index ed2c5764..bf0b8456 100644 --- a/crates/shim-protos/Cargo.toml +++ b/crates/shim-protos/Cargo.toml @@ -50,7 +50,7 @@ required-features = ["async"] [dependencies] async-trait = { workspace = true, optional = true } protobuf = { version = "3.7", default-features = false } -ttrpc = { version = "0.8", default-features = false, features = ["sync"] } +ttrpc = { version = "0.9", default-features = false, features = ["sync"] } [build-dependencies] ttrpc-codegen = "0.6.0" diff --git a/crates/shim-protos/examples/connect-async.rs b/crates/shim-protos/examples/connect-async.rs index 665db8e6..734b955f 100644 --- a/crates/shim-protos/examples/connect-async.rs +++ b/crates/shim-protos/examples/connect-async.rs @@ -32,7 +32,9 @@ async fn main() { let pid = args.get(2).map(|str| str.to_owned()).unwrap_or_default(); println!("Connecting to {}...", socket_path); - let client = Client::connect(socket_path).expect("Failed to connect to shim"); + let client = Client::connect(socket_path) + .await + .expect("Failed to connect to shim"); let task_client = TaskClient::new(client); diff --git a/crates/shim-protos/examples/ttrpc-client-async.rs b/crates/shim-protos/examples/ttrpc-client-async.rs index e7fa6574..c052a754 100644 --- a/crates/shim-protos/examples/ttrpc-client-async.rs +++ b/crates/shim-protos/examples/ttrpc-client-async.rs @@ -30,7 +30,9 @@ fn default_ctx() -> Context { #[tokio::main] async fn main() { - let c = Client::connect("unix:///tmp/shim-proto-ttrpc-001").unwrap(); + let c = Client::connect("unix:///tmp/shim-proto-ttrpc-001") + .await + .unwrap(); let task = TaskClient::new(c); let now = std::time::Instant::now(); diff --git a/crates/shim/src/asynchronous/mod.rs b/crates/shim/src/asynchronous/mod.rs index c05839d2..114677ee 100644 --- a/crates/shim/src/asynchronous/mod.rs +++ b/crates/shim/src/asynchronous/mod.rs @@ -349,11 +349,10 @@ pub async fn spawn(opts: StartOpts, grouping: &str, vars: Vec<(&str, &str)>) -> #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, level = "info"))] async fn create_server(flags: &args::Flags) -> Result { - use std::os::fd::IntoRawFd; + use containerd_shim_protos::ttrpc::r#async::transport::Listener; let listener = start_listener(&flags.socket).await?; - let mut server = Server::new(); - server = server.add_listener(listener.into_raw_fd())?; - server = server.set_domain_unix(); + let listener = Listener::try_from(listener).map_err(io_error!(e, "creating ttrpc listener"))?; + let server = Server::new().add_listener(listener); Ok(server) } @@ -543,7 +542,7 @@ async fn start_listener(address: &str) -> Result { #[cfg_attr(feature = "tracing", tracing::instrument(level = "info"))] async fn wait_socket_working(address: &str, interval_in_ms: u64, count: u32) -> Result<()> { for _i in 0..count { - match Client::connect(address) { + match Client::connect(address).await { Ok(_) => { return Ok(()); } diff --git a/crates/shim/src/asynchronous/publisher.rs b/crates/shim/src/asynchronous/publisher.rs index e1a72dcf..da366826 100644 --- a/crates/shim/src/asynchronous/publisher.rs +++ b/crates/shim/src/asynchronous/publisher.rs @@ -134,8 +134,9 @@ impl RemotePublisher { }) .await?; - // Client::new() takes ownership of the RawFd. - Ok(Client::new(fd)) + // Safety: `fd` is a unix socket returned by `connect()`. + // `from_raw_unix_socket_fd` takes ownership of the RawFd. + Ok(unsafe { Client::from_raw_unix_socket_fd(fd) }) } /// Publish a new event. @@ -195,17 +196,14 @@ impl Events for RemotePublisher { #[cfg(test)] mod tests { - use std::{ - os::unix::{io::AsRawFd, net::UnixListener}, - sync::Arc, - }; + use std::{os::unix::net::UnixListener, sync::Arc}; use async_trait::async_trait; use containerd_shim_protos::{ api::{Empty, ForwardRequest}, events::task::TaskOOM, shim_async::{create_events, Events}, - ttrpc::asynchronous::Server, + ttrpc::asynchronous::{transport::Listener, Server}, }; use tokio::sync::{ mpsc::{channel, Sender}, @@ -247,13 +245,11 @@ mod tests { let barrier2 = barrier.clone(); let server_thread = tokio::spawn(async move { let listener = UnixListener::bind(&path1).unwrap(); + let listener = Listener::try_from(listener).unwrap(); let service = create_events(Arc::new(server)); let mut server = Server::new() - .set_domain_unix() - .add_listener(listener.as_raw_fd()) - .unwrap() + .add_listener(listener) .register_service(service); - std::mem::forget(listener); server.start().await.unwrap(); barrier2.wait().await;