|
| 1 | +use anyhow::{Context as _, Result, anyhow}; |
| 2 | +use core::future::Future; |
| 3 | +use test_programs::p3::wasi::sockets::ip_name_lookup::resolve_addresses; |
| 4 | +use test_programs::p3::wasi::sockets::types::{IpAddress, IpSocketAddress, TcpSocket}; |
| 5 | +use test_programs::p3::wasi::tls::client::Connector; |
| 6 | +use test_programs::p3::wit_stream; |
| 7 | + |
| 8 | +struct Component; |
| 9 | + |
| 10 | +test_programs::p3::export!(Component); |
| 11 | + |
| 12 | +const PORT: u16 = 443; |
| 13 | + |
| 14 | +async fn test_tls_sample_application(domain: &str, ip: IpAddress) -> Result<()> { |
| 15 | + let request = format!( |
| 16 | + "GET / HTTP/1.1\r\nHost: {domain}\r\nUser-Agent: wasmtime-wasi-rust\r\nConnection: close\r\n\r\n" |
| 17 | + ); |
| 18 | + |
| 19 | + let sock = TcpSocket::create(ip.family()).unwrap(); |
| 20 | + sock.connect(IpSocketAddress::new(ip, PORT)) |
| 21 | + .await |
| 22 | + .context("tcp connect failed")?; |
| 23 | + |
| 24 | + let conn = Connector::new(); |
| 25 | + |
| 26 | + let (sock_rx, sock_rx_fut) = sock.receive(); |
| 27 | + let (tls_rx, tls_rx_fut) = conn.receive(sock_rx); |
| 28 | + |
| 29 | + let (mut data_tx, data_rx) = wit_stream::new(); |
| 30 | + let (tls_tx, tls_tx_err_fut) = conn.send(data_rx); |
| 31 | + let sock_tx_fut = sock.send(tls_tx); |
| 32 | + |
| 33 | + Connector::connect(conn, domain.into()) |
| 34 | + .await |
| 35 | + .context("tls handshake failed")?; |
| 36 | + let buf = data_tx.write_all(request.into()).await; |
| 37 | + assert!(buf.is_empty()); |
| 38 | + |
| 39 | + let response = tls_rx.collect().await; |
| 40 | + let response = String::from_utf8(response)?; |
| 41 | + if !response.contains("HTTP/1.1 200 OK") { |
| 42 | + return Err(anyhow!("server did not respond with 200 OK: {response}")); |
| 43 | + } |
| 44 | + drop(data_tx); |
| 45 | + sock_rx_fut.await.context("tcp recv")?; |
| 46 | + sock_tx_fut.await.context("tcp send")?; |
| 47 | + tls_rx_fut.await.context("tls recv")?; |
| 48 | + tls_tx_err_fut.await.context("tls send")?; |
| 49 | + |
| 50 | + Ok(()) |
| 51 | +} |
| 52 | + |
| 53 | +/// This test sets up a TCP connection using one domain, and then attempts to |
| 54 | +/// perform a TLS handshake using another unrelated domain. This should result |
| 55 | +/// in a handshake error. |
| 56 | +async fn test_tls_invalid_certificate(_domain: &str, ip: IpAddress) -> Result<()> { |
| 57 | + const BAD_DOMAIN: &str = "wrongdomain.localhost"; |
| 58 | + |
| 59 | + let sock = TcpSocket::create(ip.family()).unwrap(); |
| 60 | + sock.connect(IpSocketAddress::new(ip, PORT)) |
| 61 | + .await |
| 62 | + .context("tcp connect failed")?; |
| 63 | + |
| 64 | + let (_, data_rx) = wit_stream::new(); |
| 65 | + let conn = Connector::new(); |
| 66 | + |
| 67 | + conn.receive(sock.receive().0); |
| 68 | + sock.send(conn.send(data_rx).0); |
| 69 | + |
| 70 | + match Connector::connect(conn, BAD_DOMAIN.into()).await { |
| 71 | + Err(e) => { |
| 72 | + let debug_string = e.to_debug_string(); |
| 73 | + // We're expecting an error regarding certificates in some form or |
| 74 | + // another. When we add more TLS backends this naive check will |
| 75 | + // likely need to be revisited/expanded: |
| 76 | + if debug_string.contains("certificate") || debug_string.contains("HandshakeFailure") { |
| 77 | + return Ok(()); |
| 78 | + } |
| 79 | + Err(anyhow!(debug_string)) |
| 80 | + } |
| 81 | + Ok(_) => panic!("expecting server name mismatch"), |
| 82 | + } |
| 83 | +} |
| 84 | + |
| 85 | +async fn try_live_endpoints<'a, Fut>(test: impl Fn(&'a str, IpAddress) -> Fut) |
| 86 | +where |
| 87 | + Fut: Future<Output = Result<()>> + 'a, |
| 88 | +{ |
| 89 | + // since this is testing remote endpoints to ensure system cert store works |
| 90 | + // the test uses a couple different endpoints to reduce the number of flakes |
| 91 | + const DOMAINS: &[&str] = &[ |
| 92 | + "example.com", |
| 93 | + "api.github.com", |
| 94 | + "docs.wasmtime.dev", |
| 95 | + "bytecodealliance.org", |
| 96 | + "www.rust-lang.org", |
| 97 | + ]; |
| 98 | + |
| 99 | + for &domain in DOMAINS { |
| 100 | + let result = (|| async { |
| 101 | + let ip = resolve_addresses(domain.into()) |
| 102 | + .await? |
| 103 | + .first() |
| 104 | + .map(|a| a.to_owned()) |
| 105 | + .ok_or_else(|| anyhow!("DNS lookup failed."))?; |
| 106 | + test(domain, ip).await |
| 107 | + })(); |
| 108 | + |
| 109 | + match result.await { |
| 110 | + Ok(()) => return, |
| 111 | + Err(e) => { |
| 112 | + eprintln!("test for {domain} failed: {e:#}"); |
| 113 | + } |
| 114 | + } |
| 115 | + } |
| 116 | + |
| 117 | + panic!("all tests failed"); |
| 118 | +} |
| 119 | + |
| 120 | +impl test_programs::p3::exports::wasi::cli::run::Guest for Component { |
| 121 | + async fn run() -> Result<(), ()> { |
| 122 | + println!("sample app"); |
| 123 | + try_live_endpoints(test_tls_sample_application).await; |
| 124 | + println!("invalid cert"); |
| 125 | + try_live_endpoints(test_tls_invalid_certificate).await; |
| 126 | + Ok(()) |
| 127 | + } |
| 128 | +} |
| 129 | + |
| 130 | +fn main() {} |
0 commit comments