Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions guests/python/src/python_modules/mod.rs
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some improvements to the error mapping here. This is on the guest side, mapping from WASI to Python.

I was debugging #464 and only got error messages a la Err { value: () }" because some of the respective WASI methods return Result<..., ()>. While you can get the Python backtrace to find out which method was called, it still seemed rather bad as a debugging experience. So I've improved these error messages.

Original file line number Diff line number Diff line change
Expand Up @@ -1154,7 +1154,7 @@ mod wit_world {
}

fn stream(&self) -> PyResult<InputStream> {
let stream = self.inner()?.stream().to_pyres()?;
let stream = self.inner()?.stream().map_err(|()| "stream").to_pyres()?;
Ok(InputStream {
inner: Some(stream),
})
Expand All @@ -1170,7 +1170,7 @@ mod wit_world {
#[pymethods]
impl IncomingResponse {
fn consume(&self) -> PyResult<IncomingBody> {
let body = self.inner.consume().to_pyres()?;
let body = self.inner.consume().map_err(|()| "consume").to_pyres()?;
Ok(IncomingBody { inner: Some(body) })
}

Expand Down Expand Up @@ -1402,7 +1402,7 @@ mod wit_world {
#[pymethods]
impl OutgoingBody {
fn write(&self) -> PyResult<OutputStream> {
let stream = self.inner()?.write().to_pyres()?;
let stream = self.inner()?.write().map_err(|()| "write").to_pyres()?;
Ok(OutputStream {
inner: Some(stream),
})
Expand Down Expand Up @@ -1442,32 +1442,40 @@ mod wit_world {
}

fn body(&self) -> PyResult<OutgoingBody> {
let body = self.inner()?.body().to_pyres()?;
let body = self.inner()?.body().map_err(|()| "body").to_pyres()?;
Ok(OutgoingBody { inner: Some(body) })
}

fn set_authority(&self, authority: Option<String>) -> PyResult<()> {
self.inner()?
.set_authority(authority.as_deref())
.map_err(|()| format!("set_authority: {authority:?}"))
.to_pyres()?;
Ok(())
}

fn set_method(&self, method: Method) -> PyResult<()> {
self.inner()?.set_method(&method.into()).to_pyres()?;
let method = method.into();
self.inner()?
.set_method(&method)
.map_err(|()| format!("set_method: {method:?}"))
.to_pyres()?;
Ok(())
}

fn set_path_with_query(&self, path_with_query: Option<String>) -> PyResult<()> {
self.inner()?
.set_path_with_query(path_with_query.as_deref())
.map_err(|()| format!("set_path_with_query: {path_with_query:?}"))
.to_pyres()?;
Ok(())
}

fn set_scheme(&self, scheme: Option<Scheme>) -> PyResult<()> {
let scheme = scheme.map(|s| s.into());
self.inner()?
.set_scheme(scheme.map(|s| s.into()).as_ref())
.set_scheme(scheme.as_ref())
.map_err(|()| format!("set_scheme: {scheme:?}"))
.to_pyres()?;
Ok(())
}
Expand All @@ -1494,7 +1502,10 @@ mod wit_world {
}

fn set_connect_timeout(&self, duration: Option<u64>) -> PyResult<()> {
self.inner()?.set_connect_timeout(duration).to_pyres()?;
self.inner()?
.set_connect_timeout(duration)
.map_err(|()| format!("set_connect_timeout: {duration:?}"))
.to_pyres()?;
Ok(())
}
}
Expand Down
54 changes: 53 additions & 1 deletion host/src/http/mod.rs
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actual bug fix: try to peak into the reqwests::Error to map the error to a proper WASI error code.

Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Interfaces for HTTP interactions of the guest.

use std::sync::Arc;
use std::{io::ErrorKind, sync::Arc};

use datafusion_common::{DataFusionError, error::Result as DataFusionResult};
use http::HeaderName;
Expand Down Expand Up @@ -235,5 +235,57 @@ fn assemble_response(

/// Map [`reqwest::Error`] to [`HttpErrorCode`].
fn map_reqwest_err(e: reqwest::Error) -> HttpErrorCode {
// try to find an IO error first, since this is potentially the most low-level information
if let Some(e) = extract_error_type::<std::io::Error>(&e) {
match e.kind() {
ErrorKind::ConnectionRefused => {
return HttpErrorCode::ConnectionRefused;
}
ErrorKind::ConnectionReset => {
return HttpErrorCode::ConnectionTerminated;
}
ErrorKind::TimedOut => {
return HttpErrorCode::ConnectionTimeout;
}
_ => {}
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that you CANNOT match all kinds because that enum is non-exhaustive.

}
}

// hyper might have some hints for us
if let Some(e) = extract_error_type::<hyper::Error>(&e) {
if e.is_incomplete_message() {
return HttpErrorCode::HttpResponseIncomplete;
} else if e.is_parse() {
return HttpErrorCode::HttpProtocolError;
} else if e.is_timeout() {
return HttpErrorCode::ConnectionTimeout;
}
}

// cannot really extract anything meaningful, fall back to "internal error" ("internal" as in "in our stack", not
// as in "internal server error")
HttpErrorCode::InternalError(Some(e.to_string()))
}

/// Extract concrete error type from error chain.
fn extract_error_type<'a, E>(e: &'a (dyn std::error::Error + 'static)) -> Option<&'a E>
where
E: std::error::Error + 'static,
{
let mut current = e;

loop {
if let Some(concrete) = current.downcast_ref::<E>() {
return Some(concrete);
}

match current.source() {
Some(next) => {
current = next;
}
None => {
return None;
}
}
}
}
198 changes: 122 additions & 76 deletions host/tests/integration_tests/python/runtime/http/mock_server.rs
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The change here looks bigger than it is due to a changed indentation. It really just introduces:

  • IPv6 support
  • two low-level failure modes that we can emulate

Both is now possible due to #453, showing that this "use home-grown test harness" approach pays off.

Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ use hyper_util::rt::{TokioExecutor, TokioIo};
use tokio::{net::TcpListener, task::JoinSet};
use wasmtime_wasi_http::DEFAULT_FORBIDDEN_HEADERS;

const LISTEN_ADDR: &str = "127.0.0.1:0";
const LISTEN_ADDR_IPV4: &str = "127.0.0.1:0";
const LISTEN_ADDR_IPV6: &str = "[::1]:0";

pub(crate) type Request = http::Request<Bytes>;
pub(crate) type Response = http::Response<BoxBody<Bytes, Infallible>>;
Expand Down Expand Up @@ -182,6 +183,28 @@ impl State {
}
}

#[derive(Debug, Default)]
pub(crate) enum IpVersion {
#[default]
V4,
V6,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Failure {
/// Reject all TCP connections.
RejectConnections,

/// Close connection w/o answering.
CloseWithoutAnswer,
}

#[derive(Debug, Default)]
pub(crate) struct MockServerOptions {
pub(crate) ip_version: IpVersion,
pub(crate) failure: Option<Failure>,
}

type SharedState = Arc<Mutex<State>>;

#[derive(Debug)]
Expand All @@ -193,89 +216,112 @@ pub(crate) struct MockServer {

impl MockServer {
pub(crate) async fn start() -> Self {
let tcp_listener = TcpListener::bind(LISTEN_ADDR).await.expect("bind");
Self::with_options(MockServerOptions::default()).await
}

pub(crate) async fn with_options(options: MockServerOptions) -> Self {
let MockServerOptions {
ip_version,
failure,
} = options;
let tcp_listener = TcpListener::bind(match ip_version {
IpVersion::V4 => LISTEN_ADDR_IPV4,
IpVersion::V6 => LISTEN_ADDR_IPV6,
})
.await
.expect("bind");
let addr = tcp_listener.local_addr().unwrap();

let state = SharedState::default();

let mut task = JoinSet::new();
let state_captured = Arc::clone(&state);
task.spawn(async move {
let mut connections = JoinSet::new();

loop {
let (stream, accept_addr) = match tcp_listener.accept().await {
Ok(x) => x,
Err(e) => {
eprintln!("failed to accept connection: {e}");
if failure != Some(Failure::RejectConnections) {
let state_captured = Arc::clone(&state);
task.spawn(async move {
let mut connections = JoinSet::new();

loop {
let (stream, accept_addr) = match tcp_listener.accept().await {
Ok(x) => x,
Err(e) => {
eprintln!("failed to accept connection: {e}");
continue;
}
};

if failure == Some(Failure::CloseWithoutAnswer) {
continue;
}
};

let state = Arc::clone(&state_captured);
let serve_connection = async move {
let result = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
.serve_connection(
TokioIo::new(stream),
service_fn(move |req: http::Request<Incoming>| {
let state = Arc::clone(&state);

async move {
// hydrate entire body so we can process it easier
let (parts, body) = req.into_parts();
let body = body.collect().await.unwrap().to_bytes();
let req = http::Request::from_parts(parts, body);

let mut state = state.lock().unwrap();

if has_forbidden_headers(&req, addr) {
return Ok(
state.fail(format!("Forbidden headers:\n{req:#?}"))
);
}

let Some((mock, count)) = state
.mocks
.iter_mut()
.find(|(mock, _hits)| mock.matcher.matches(&req))
else {
return Ok(state.fail(format!("Not mocked:\n{req:#?}")));
};

*count += 1;

let mut resp = mock.response.resp(&req);

// combine repeated headers, but we should probably not do that
// See https://github.com/influxdata/datafusion-udf-wasm/issues/452
let headers = resp.headers_mut();
*headers = headers
.keys()
.map(|k| {
let vals = headers
.get_all(k)
.iter()
.map(|v| v.to_str().unwrap())
.collect::<Vec<_>>();
let v = HeaderValue::from_str(&vals.join(",")).unwrap();
(k.clone(), v)
})
.collect();

Result::<_, Infallible>::Ok(resp)
}
}),
)
.await;

if let Err(e) = result {
eprintln!("error serving {accept_addr}: {e}");
}
};

connections.spawn(serve_connection);
}
});
let state = Arc::clone(&state_captured);
let serve_connection = async move {
let result =
hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
.serve_connection(
TokioIo::new(stream),
service_fn(move |req: http::Request<Incoming>| {
let state = Arc::clone(&state);

async move {
// hydrate entire body so we can process it easier
let (parts, body) = req.into_parts();
let body = body.collect().await.unwrap().to_bytes();
let req = http::Request::from_parts(parts, body);

let mut state = state.lock().unwrap();

if has_forbidden_headers(&req, addr) {
return Ok(state.fail(format!(
"Forbidden headers:\n{req:#?}"
)));
}

let Some((mock, count)) = state
.mocks
.iter_mut()
.find(|(mock, _hits)| mock.matcher.matches(&req))
else {
return Ok(
state.fail(format!("Not mocked:\n{req:#?}"))
);
};

*count += 1;

let mut resp = mock.response.resp(&req);

// combine repeated headers, but we should probably not do that
// See https://github.com/influxdata/datafusion-udf-wasm/issues/452
let headers = resp.headers_mut();
*headers = headers
.keys()
.map(|k| {
let vals = headers
.get_all(k)
.iter()
.map(|v| v.to_str().unwrap())
.collect::<Vec<_>>();
let v = HeaderValue::from_str(&vals.join(","))
.unwrap();
(k.clone(), v)
})
.collect();

Result::<_, Infallible>::Ok(resp)
}
}),
)
.await;

if let Err(e) = result {
eprintln!("error serving {accept_addr}: {e}");
}
};

connections.spawn(serve_connection);
}
});
}

Self {
_task: task,
Expand Down
Loading