diff --git a/guests/python/src/python_modules/mod.rs b/guests/python/src/python_modules/mod.rs index 41ceac06..bb1eeced 100644 --- a/guests/python/src/python_modules/mod.rs +++ b/guests/python/src/python_modules/mod.rs @@ -1154,7 +1154,7 @@ mod wit_world { } fn stream(&self) -> PyResult { - let stream = self.inner()?.stream().to_pyres()?; + let stream = self.inner()?.stream().map_err(|()| "stream").to_pyres()?; Ok(InputStream { inner: Some(stream), }) @@ -1170,7 +1170,7 @@ mod wit_world { #[pymethods] impl IncomingResponse { fn consume(&self) -> PyResult { - let body = self.inner.consume().to_pyres()?; + let body = self.inner.consume().map_err(|()| "consume").to_pyres()?; Ok(IncomingBody { inner: Some(body) }) } @@ -1402,7 +1402,7 @@ mod wit_world { #[pymethods] impl OutgoingBody { fn write(&self) -> PyResult { - let stream = self.inner()?.write().to_pyres()?; + let stream = self.inner()?.write().map_err(|()| "write").to_pyres()?; Ok(OutputStream { inner: Some(stream), }) @@ -1442,32 +1442,40 @@ mod wit_world { } fn body(&self) -> PyResult { - let body = self.inner()?.body().to_pyres()?; + let body = self.inner()?.body().map_err(|()| "body").to_pyres()?; Ok(OutgoingBody { inner: Some(body) }) } fn set_authority(&self, authority: Option) -> PyResult<()> { self.inner()? .set_authority(authority.as_deref()) + .map_err(|()| format!("set_authority: {authority:?}")) .to_pyres()?; Ok(()) } fn set_method(&self, method: Method) -> PyResult<()> { - self.inner()?.set_method(&method.into()).to_pyres()?; + let method = method.into(); + self.inner()? + .set_method(&method) + .map_err(|()| format!("set_method: {method:?}")) + .to_pyres()?; Ok(()) } fn set_path_with_query(&self, path_with_query: Option) -> PyResult<()> { self.inner()? .set_path_with_query(path_with_query.as_deref()) + .map_err(|()| format!("set_path_with_query: {path_with_query:?}")) .to_pyres()?; Ok(()) } fn set_scheme(&self, scheme: Option) -> PyResult<()> { + let scheme = scheme.map(|s| s.into()); self.inner()? - .set_scheme(scheme.map(|s| s.into()).as_ref()) + .set_scheme(scheme.as_ref()) + .map_err(|()| format!("set_scheme: {scheme:?}")) .to_pyres()?; Ok(()) } @@ -1494,7 +1502,10 @@ mod wit_world { } fn set_connect_timeout(&self, duration: Option) -> PyResult<()> { - self.inner()?.set_connect_timeout(duration).to_pyres()?; + self.inner()? + .set_connect_timeout(duration) + .map_err(|()| format!("set_connect_timeout: {duration:?}")) + .to_pyres()?; Ok(()) } } diff --git a/host/src/http/mod.rs b/host/src/http/mod.rs index b18f4e19..c913fbf8 100644 --- a/host/src/http/mod.rs +++ b/host/src/http/mod.rs @@ -1,6 +1,6 @@ //! Interfaces for HTTP interactions of the guest. -use std::sync::Arc; +use std::{io::ErrorKind, sync::Arc}; use datafusion_common::{DataFusionError, error::Result as DataFusionResult}; use http::HeaderName; @@ -235,5 +235,57 @@ fn assemble_response( /// Map [`reqwest::Error`] to [`HttpErrorCode`]. fn map_reqwest_err(e: reqwest::Error) -> HttpErrorCode { + // try to find an IO error first, since this is potentially the most low-level information + if let Some(e) = extract_error_type::(&e) { + match e.kind() { + ErrorKind::ConnectionRefused => { + return HttpErrorCode::ConnectionRefused; + } + ErrorKind::ConnectionReset => { + return HttpErrorCode::ConnectionTerminated; + } + ErrorKind::TimedOut => { + return HttpErrorCode::ConnectionTimeout; + } + _ => {} + } + } + + // hyper might have some hints for us + if let Some(e) = extract_error_type::(&e) { + if e.is_incomplete_message() { + return HttpErrorCode::HttpResponseIncomplete; + } else if e.is_parse() { + return HttpErrorCode::HttpProtocolError; + } else if e.is_timeout() { + return HttpErrorCode::ConnectionTimeout; + } + } + + // cannot really extract anything meaningful, fall back to "internal error" ("internal" as in "in our stack", not + // as in "internal server error") HttpErrorCode::InternalError(Some(e.to_string())) } + +/// Extract concrete error type from error chain. +fn extract_error_type<'a, E>(e: &'a (dyn std::error::Error + 'static)) -> Option<&'a E> +where + E: std::error::Error + 'static, +{ + let mut current = e; + + loop { + if let Some(concrete) = current.downcast_ref::() { + return Some(concrete); + } + + match current.source() { + Some(next) => { + current = next; + } + None => { + return None; + } + } + } +} diff --git a/host/tests/integration_tests/python/runtime/http/mock_server.rs b/host/tests/integration_tests/python/runtime/http/mock_server.rs index 235cb108..cc125c2a 100644 --- a/host/tests/integration_tests/python/runtime/http/mock_server.rs +++ b/host/tests/integration_tests/python/runtime/http/mock_server.rs @@ -13,7 +13,8 @@ use hyper_util::rt::{TokioExecutor, TokioIo}; use tokio::{net::TcpListener, task::JoinSet}; use wasmtime_wasi_http::DEFAULT_FORBIDDEN_HEADERS; -const LISTEN_ADDR: &str = "127.0.0.1:0"; +const LISTEN_ADDR_IPV4: &str = "127.0.0.1:0"; +const LISTEN_ADDR_IPV6: &str = "[::1]:0"; pub(crate) type Request = http::Request; pub(crate) type Response = http::Response>; @@ -182,6 +183,28 @@ impl State { } } +#[derive(Debug, Default)] +pub(crate) enum IpVersion { + #[default] + V4, + V6, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum Failure { + /// Reject all TCP connections. + RejectConnections, + + /// Close connection w/o answering. + CloseWithoutAnswer, +} + +#[derive(Debug, Default)] +pub(crate) struct MockServerOptions { + pub(crate) ip_version: IpVersion, + pub(crate) failure: Option, +} + type SharedState = Arc>; #[derive(Debug)] @@ -193,89 +216,112 @@ pub(crate) struct MockServer { impl MockServer { pub(crate) async fn start() -> Self { - let tcp_listener = TcpListener::bind(LISTEN_ADDR).await.expect("bind"); + Self::with_options(MockServerOptions::default()).await + } + + pub(crate) async fn with_options(options: MockServerOptions) -> Self { + let MockServerOptions { + ip_version, + failure, + } = options; + let tcp_listener = TcpListener::bind(match ip_version { + IpVersion::V4 => LISTEN_ADDR_IPV4, + IpVersion::V6 => LISTEN_ADDR_IPV6, + }) + .await + .expect("bind"); let addr = tcp_listener.local_addr().unwrap(); let state = SharedState::default(); let mut task = JoinSet::new(); - let state_captured = Arc::clone(&state); - task.spawn(async move { - let mut connections = JoinSet::new(); - - loop { - let (stream, accept_addr) = match tcp_listener.accept().await { - Ok(x) => x, - Err(e) => { - eprintln!("failed to accept connection: {e}"); + if failure != Some(Failure::RejectConnections) { + let state_captured = Arc::clone(&state); + task.spawn(async move { + let mut connections = JoinSet::new(); + + loop { + let (stream, accept_addr) = match tcp_listener.accept().await { + Ok(x) => x, + Err(e) => { + eprintln!("failed to accept connection: {e}"); + continue; + } + }; + + if failure == Some(Failure::CloseWithoutAnswer) { continue; } - }; - - let state = Arc::clone(&state_captured); - let serve_connection = async move { - let result = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()) - .serve_connection( - TokioIo::new(stream), - service_fn(move |req: http::Request| { - let state = Arc::clone(&state); - - async move { - // hydrate entire body so we can process it easier - let (parts, body) = req.into_parts(); - let body = body.collect().await.unwrap().to_bytes(); - let req = http::Request::from_parts(parts, body); - - let mut state = state.lock().unwrap(); - - if has_forbidden_headers(&req, addr) { - return Ok( - state.fail(format!("Forbidden headers:\n{req:#?}")) - ); - } - - let Some((mock, count)) = state - .mocks - .iter_mut() - .find(|(mock, _hits)| mock.matcher.matches(&req)) - else { - return Ok(state.fail(format!("Not mocked:\n{req:#?}"))); - }; - - *count += 1; - - let mut resp = mock.response.resp(&req); - - // combine repeated headers, but we should probably not do that - // See https://github.com/influxdata/datafusion-udf-wasm/issues/452 - let headers = resp.headers_mut(); - *headers = headers - .keys() - .map(|k| { - let vals = headers - .get_all(k) - .iter() - .map(|v| v.to_str().unwrap()) - .collect::>(); - let v = HeaderValue::from_str(&vals.join(",")).unwrap(); - (k.clone(), v) - }) - .collect(); - - Result::<_, Infallible>::Ok(resp) - } - }), - ) - .await; - - if let Err(e) = result { - eprintln!("error serving {accept_addr}: {e}"); - } - }; - connections.spawn(serve_connection); - } - }); + let state = Arc::clone(&state_captured); + let serve_connection = async move { + let result = + hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()) + .serve_connection( + TokioIo::new(stream), + service_fn(move |req: http::Request| { + let state = Arc::clone(&state); + + async move { + // hydrate entire body so we can process it easier + let (parts, body) = req.into_parts(); + let body = body.collect().await.unwrap().to_bytes(); + let req = http::Request::from_parts(parts, body); + + let mut state = state.lock().unwrap(); + + if has_forbidden_headers(&req, addr) { + return Ok(state.fail(format!( + "Forbidden headers:\n{req:#?}" + ))); + } + + let Some((mock, count)) = state + .mocks + .iter_mut() + .find(|(mock, _hits)| mock.matcher.matches(&req)) + else { + return Ok( + state.fail(format!("Not mocked:\n{req:#?}")) + ); + }; + + *count += 1; + + let mut resp = mock.response.resp(&req); + + // combine repeated headers, but we should probably not do that + // See https://github.com/influxdata/datafusion-udf-wasm/issues/452 + let headers = resp.headers_mut(); + *headers = headers + .keys() + .map(|k| { + let vals = headers + .get_all(k) + .iter() + .map(|v| v.to_str().unwrap()) + .collect::>(); + let v = HeaderValue::from_str(&vals.join(",")) + .unwrap(); + (k.clone(), v) + }) + .collect(); + + Result::<_, Infallible>::Ok(resp) + } + }), + ) + .await; + + if let Err(e) = result { + eprintln!("error serving {accept_addr}: {e}"); + } + }; + + connections.spawn(serve_connection); + } + }); + } Self { _task: task, diff --git a/host/tests/integration_tests/python/runtime/http/mod.rs b/host/tests/integration_tests/python/runtime/http/mod.rs index a3fcef6e..0502d871 100644 --- a/host/tests/integration_tests/python/runtime/http/mod.rs +++ b/host/tests/integration_tests/python/runtime/http/mod.rs @@ -1,5 +1,5 @@ use std::{ - collections::HashSet, + collections::{BTreeMap, HashSet}, io::Write, sync::{Arc, LazyLock}, time::Duration, @@ -31,7 +31,8 @@ use wasmtime_wasi_http::DEFAULT_FORBIDDEN_HEADERS; use crate::integration_tests::{ python::{ runtime::http::mock_server::{ - Matcher, MockServer, Request, ResponseGenFn, ServerMock, SimpleResponseGen, + Matcher, MockServer, MockServerOptions, Request, ResponseGenFn, ServerMock, + SimpleResponseGen, }, test_utils::{python_component, python_scalar_udf}, }, @@ -263,6 +264,41 @@ def test_urllib3(method: str, url: str, headers: str | None, body: str | None) - "#; const NUMBER_OF_IMPLEMENTATIONS: usize = 2; + let servers = BTreeMap::from([ + ( + "ipv4", + MockServer::with_options(MockServerOptions { + ip_version: mock_server::IpVersion::V4, + ..Default::default() + }) + .await, + ), + ( + "ipv6", + MockServer::with_options(MockServerOptions { + ip_version: mock_server::IpVersion::V6, + ..Default::default() + }) + .await, + ), + ( + "reject_all", + MockServer::with_options(MockServerOptions { + failure: Some(mock_server::Failure::RejectConnections), + ..Default::default() + }) + .await, + ), + ( + "no_answer", + MockServer::with_options(MockServerOptions { + failure: Some(mock_server::Failure::CloseWithoutAnswer), + ..Default::default() + }) + .await, + ), + ]); + let mut cases = vec![ TestCase { resp: Ok(TestResponse { @@ -352,6 +388,26 @@ def test_urllib3(method: str, url: str, headers: str | None, body: str | None) - }), ..Default::default() }, + // https://github.com/influxdata/datafusion-udf-wasm/issues/464 + TestCase { + server: "ipv6", + resp: Err(format!( + "Err {{ value: set_authority: Some(\"{ip}:{port}\") }}", + ip=servers.get("ipv6").unwrap().address().ip(), + port=servers.get("ipv6").unwrap().address().port(), + )), + ..Default::default() + }, + TestCase { + server: "reject_all", + resp: Err("('Connection aborted.', WasiErrorCode('Request failed with wasi http error ErrorCode_ConnectionRefused'))".to_owned()), + ..Default::default() + }, + TestCase { + server: "no_answer", + resp: Err("('Connection aborted.', WasiErrorCode('Request failed with wasi http error ErrorCode_HttpResponseIncomplete'))".to_owned()), + ..Default::default() + }, ]; cases.extend( DEFAULT_FORBIDDEN_HEADERS @@ -365,8 +421,6 @@ def test_urllib3(method: str, url: str, headers: str | None, body: str | None) - ..Default::default() }), ); - - let server = MockServer::start().await; let mut permissions = AllowCertainHttpRequests::default(); let mut builder_method = StringBuilder::new(); @@ -376,9 +430,8 @@ def test_urllib3(method: str, url: str, headers: str | None, body: str | None) - let mut builder_result = StringBuilder::new(); for case in &cases { - case.allow(&server, &mut permissions); - let TestCase { + server, base, method, path, @@ -387,6 +440,10 @@ def test_urllib3(method: str, url: str, headers: str | None, body: str | None) - resp, } = case; + case.allow(&servers, &mut permissions); + + let server = servers.get(server).unwrap(); + builder_method.append_value(method); builder_url.append_value(format!( "{}{}", @@ -438,10 +495,11 @@ def test_urllib3(method: str, url: str, headers: str | None, body: str | None) - assert_eq!(udfs.len(), NUMBER_OF_IMPLEMENTATIONS); for udf in udfs { + println!("============================================================"); println!("{}", udf.name()); for case in &cases { - case.mock(&server); + case.mock(&servers); } let actual = udf @@ -475,7 +533,10 @@ def test_urllib3(method: str, url: str, headers: str | None, body: str | None) - panic!("FAIL:\n\n{s}"); } - server.clear_mocks(); + for (name, server) in &servers { + println!("clean up server `{name}`..."); + server.clear_mocks(); + } } } @@ -497,6 +558,7 @@ impl Default for TestResponse { } struct TestCase { + server: &'static str, base: Option<&'static str>, method: &'static str, path: String, @@ -508,6 +570,7 @@ struct TestCase { impl Default for TestCase { fn default() -> Self { Self { + server: "ipv4", base: None, method: "GET", path: "/".to_owned(), @@ -519,7 +582,12 @@ impl Default for TestCase { } impl TestCase { - fn allow(&self, server: &MockServer, permissions: &mut AllowCertainHttpRequests) { + fn allow( + &self, + servers: &BTreeMap<&'static str, MockServer>, + permissions: &mut AllowCertainHttpRequests, + ) { + let server = servers.get(self.server).unwrap(); let endpoint = permissions .allow_host(server.address().ip().to_string()) .allow_port(HttpPort::new(server.address().port()).unwrap()); @@ -527,8 +595,9 @@ impl TestCase { endpoint.allow_method(self.method.try_into().unwrap()); } - fn mock(&self, server: &MockServer) { + fn mock(&self, servers: &BTreeMap<&'static str, MockServer>) { let Self { + server, base, method, path, @@ -536,6 +605,7 @@ impl TestCase { requ_body, resp, } = self; + let server = servers.get(server).unwrap(); if base.is_some() { return; }