Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/shim-protos/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ required-features = ["async"]
[dependencies]
async-trait = { workspace = true, optional = true }
protobuf = { version = "3.7", default-features = false }
ttrpc = { version = "0.8", default-features = false, features = ["sync"] }
ttrpc = { version = "0.9", default-features = false, features = ["sync"] }

[build-dependencies]
ttrpc-codegen = "0.6.0"
Expand Down
4 changes: 3 additions & 1 deletion crates/shim-protos/examples/connect-async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ async fn main() {
let pid = args.get(2).map(|str| str.to_owned()).unwrap_or_default();

println!("Connecting to {}...", socket_path);
let client = Client::connect(socket_path).expect("Failed to connect to shim");
let client = Client::connect(socket_path)
.await
.expect("Failed to connect to shim");

let task_client = TaskClient::new(client);

Expand Down
4 changes: 3 additions & 1 deletion crates/shim-protos/examples/ttrpc-client-async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ fn default_ctx() -> Context {

#[tokio::main]
async fn main() {
let c = Client::connect("unix:///tmp/shim-proto-ttrpc-001").unwrap();
let c = Client::connect("unix:///tmp/shim-proto-ttrpc-001")
.await
.unwrap();
let task = TaskClient::new(c);
let now = std::time::Instant::now();

Expand Down
9 changes: 4 additions & 5 deletions crates/shim/src/asynchronous/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,11 +349,10 @@ pub async fn spawn(opts: StartOpts, grouping: &str, vars: Vec<(&str, &str)>) ->

#[cfg_attr(feature = "tracing", tracing::instrument(skip_all, level = "info"))]
async fn create_server(flags: &args::Flags) -> Result<Server> {
use std::os::fd::IntoRawFd;
use containerd_shim_protos::ttrpc::r#async::transport::Listener;
let listener = start_listener(&flags.socket).await?;
let mut server = Server::new();
server = server.add_listener(listener.into_raw_fd())?;
server = server.set_domain_unix();
let listener = Listener::try_from(listener).map_err(io_error!(e, "creating ttrpc listener"))?;
let server = Server::new().add_listener(listener);
Ok(server)
}

Expand Down Expand Up @@ -543,7 +542,7 @@ async fn start_listener(address: &str) -> Result<UnixListener> {
#[cfg_attr(feature = "tracing", tracing::instrument(level = "info"))]
async fn wait_socket_working(address: &str, interval_in_ms: u64, count: u32) -> Result<()> {
for _i in 0..count {
match Client::connect(address) {
match Client::connect(address).await {
Ok(_) => {
return Ok(());
}
Expand Down
18 changes: 7 additions & 11 deletions crates/shim/src/asynchronous/publisher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,9 @@ impl RemotePublisher {
})
.await?;

// Client::new() takes ownership of the RawFd.
Ok(Client::new(fd))
// Safety: `fd` is a unix socket returned by `connect()`.
// `from_raw_unix_socket_fd` takes ownership of the RawFd.
Ok(unsafe { Client::from_raw_unix_socket_fd(fd) })
}

/// Publish a new event.
Expand Down Expand Up @@ -195,17 +196,14 @@ impl Events for RemotePublisher {

#[cfg(test)]
mod tests {
use std::{
os::unix::{io::AsRawFd, net::UnixListener},
sync::Arc,
};
use std::{os::unix::net::UnixListener, sync::Arc};

use async_trait::async_trait;
use containerd_shim_protos::{
api::{Empty, ForwardRequest},
events::task::TaskOOM,
shim_async::{create_events, Events},
ttrpc::asynchronous::Server,
ttrpc::asynchronous::{transport::Listener, Server},
};
use tokio::sync::{
mpsc::{channel, Sender},
Expand Down Expand Up @@ -247,13 +245,11 @@ mod tests {
let barrier2 = barrier.clone();
let server_thread = tokio::spawn(async move {
let listener = UnixListener::bind(&path1).unwrap();
let listener = Listener::try_from(listener).unwrap();
let service = create_events(Arc::new(server));
let mut server = Server::new()
.set_domain_unix()
.add_listener(listener.as_raw_fd())
.unwrap()
.add_listener(listener)
.register_service(service);
std::mem::forget(listener);
server.start().await.unwrap();
barrier2.wait().await;

Expand Down
Loading