Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 93 additions & 3 deletions host/src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,32 @@ pub struct HttpRequestMatcher {
}

/// Allow-list requests.
#[derive(Debug, Clone, Default)]
#[derive(Debug, Clone)]
pub struct AllowCertainHttpRequests {
/// Set of all matchers.
///
/// If ANY of them matches, the request will be allowed.
matchers: HashSet<HttpRequestMatcher>,

/// Set of all allowed paths.
///
/// The request path must start with one of these paths to be allowed.
allowed_paths: AllowHttpRequestPath,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the paths should be bound to the matcher since different hosts may have different path filters, i.e. it's likely matchers: HasMap<HttpRequestMatcher, AllowedHttpPath>, although that interface becomes a bit of a mess. I suggest the following: the public interface is based on HttpRequestMatcher and you add the add the list of allowed prefixes to that struct. For fast internal filtering I think you have two options, since you now need to replace the HashSet in AllowCertainHttpRequests:

  • hash map: Model the lookup table as HashMap<MatcherWithoutPrefixes, Prefixes>
  • binary search: Encode each prefix in each matcher as something like method\0host\0port\0prefixl and perform a binary search on that data
  • tree: Same as the binary search but as some kind of search tree.

}

impl Default for AllowCertainHttpRequests {
fn default() -> Self {
Self::new(vec!["/".to_string()])
}
}

impl AllowCertainHttpRequests {
/// Create new, empty request matcher.
pub fn new() -> Self {
Self::default()
pub fn new(allowed_paths: Vec<String>) -> Self {
Self {
matchers: HashSet::new(),
allowed_paths: AllowHttpRequestPath::new(allowed_paths),
}
}

/// Allow given request.
Expand All @@ -90,6 +104,10 @@ impl HttpRequestValidator for AllowCertainHttpRequests {
request: &hyper::Request<HyperOutgoingBody>,
use_tls: bool,
) -> Result<(), HttpRequestRejected> {
if !self.allowed_paths.is_allowed(request.uri().path()) {
return Err(HttpRequestRejected);
}

let matcher = HttpRequestMatcher {
method: request.method().clone(),
host: request
Expand All @@ -112,6 +130,43 @@ impl HttpRequestValidator for AllowCertainHttpRequests {
}
}

/// Restrict HTTP request paths.
#[derive(Debug, Clone, Default)]
pub(crate) struct AllowHttpRequestPath {
/// Allowed paths.
pub allowed_paths: Vec<String>,
}

impl AllowHttpRequestPath {
/// Create new path allow-list.
pub(crate) fn new(allowed_paths: Vec<String>) -> Self {
Self { allowed_paths }
}

/// Check if given path is allowed.
pub(crate) fn is_allowed(&self, path: &str) -> bool {
self.allowed_paths
.iter()
.any(|allowed| path.starts_with(allowed))
}
}

impl HttpRequestValidator for AllowHttpRequestPath {
fn validate(
&self,
request: &hyper::Request<HyperOutgoingBody>,
_use_tls: bool,
) -> Result<(), HttpRequestRejected> {
let path = request.uri().path_and_query().map_or("", |pq| pq.as_str());

if self.is_allowed(path) {
Ok(())
} else {
Err(HttpRequestRejected)
}
}
}

/// Reject HTTP request.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct HttpRequestRejected;
Expand Down Expand Up @@ -335,4 +390,39 @@ mod test {
);
}
}

#[test]
fn restrict_paths() {
let policy =
AllowHttpRequestPath::new(vec!["/allowed".to_string(), "/also/allowed".to_string()]);

struct Case {
path: &'static str,
result: Result<(), HttpRequestRejected>,
}

let cases = [
Case {
path: "/allowed",
result: Ok(()),
},
Case {
path: "/also/allowed",
result: Ok(()),
},
Case {
path: "/not/allowed",
result: Err(HttpRequestRejected),
},
];

for case in cases {
let request = hyper::Request::builder()
.uri(case.path)
.body(Default::default())
.unwrap();
let result = policy.validate(&request, false);
assert_eq!(result, case.result);
}
}
}
154 changes: 152 additions & 2 deletions host/tests/integration_tests/python/runtime/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def perform_request(url: str) -> str:
.mount(&server)
.await;

let mut permissions = AllowCertainHttpRequests::new();
let mut permissions = AllowCertainHttpRequests::default();
permissions.allow(HttpRequestMatcher {
method: http::Method::GET,
host: server.address().ip().to_string().into(),
Expand Down Expand Up @@ -638,7 +638,7 @@ def perform_request(url: str) -> str:

// deliberately use a runtime what we are going to throw away later to prevent tricks like `Handle::current`
let udf = rt_tmp.block_on(async {
let mut permissions = AllowCertainHttpRequests::new();
let mut permissions = AllowCertainHttpRequests::default();
permissions.allow(HttpRequestMatcher {
method: http::Method::GET,
host: server.address().ip().to_string().into(),
Expand Down Expand Up @@ -677,3 +677,153 @@ def perform_request(url: str) -> str:
&StringArray::from_iter([Some("hello world!".to_owned()),]) as &dyn Array,
);
}

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After you've changed it to a "prefixes are matcher-specific", you should probably also harded test_integration to only accept the paths that are part of the specific test cases.

#[tokio::test]
async fn test_allowed_http_request_path() {
const CODE: &str = r#"
import requests

def perform_request(url: str) -> str:
return requests.get(url).text
"#;

let server = MockServer::start().await;
Mock::given(matchers::any())
.respond_with(ResponseTemplate::new(200).set_body_string("hello world!"))
.expect(1)
.mount(&server)
.await;

let allowed_paths = vec!["/allowed".to_string()];

let mut permissions = AllowCertainHttpRequests::new(allowed_paths);
permissions.allow(HttpRequestMatcher {
method: http::Method::GET,
host: server.address().ip().to_string().into(),
port: server.address().port(),
});
let udf = python_udf_with_permissions(CODE, permissions).await;

let array = udf
.invoke_async_with_args(ScalarFunctionArgs {
args: vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some(format!(
"{}/allowed",
server.uri()
))))],
arg_fields: vec![Arc::new(Field::new("uri", DataType::Utf8, true))],
number_rows: 1,
return_field: Arc::new(Field::new("r", DataType::Utf8, true)),
config_options: Arc::new(ConfigOptions::default()),
})
.await
.unwrap()
.unwrap_array();

assert_eq!(
array.as_ref(),
&StringArray::from_iter([Some("hello world!".to_owned()),]) as &dyn Array,
);

let err = udf
.invoke_async_with_args(ScalarFunctionArgs {
args: vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some(format!(
"{}/not_allowed",
server.uri()
))))],
arg_fields: vec![Arc::new(Field::new("uri", DataType::Utf8, true))],
number_rows: 1,
return_field: Arc::new(Field::new("r", DataType::Utf8, true)),
config_options: Arc::new(ConfigOptions::default()),
})
.await
.unwrap_err();

insta::assert_snapshot!(
err.to_string(),
@r#"
cannot call function
caused by
Execution error: Traceback (most recent call last):
File "/lib/python3.14/site-packages/urllib3/connectionpool.py", line 787, in urlopen
response = self._make_request(
conn,
...<10 lines>...
**response_kw,
)
File "/lib/python3.14/site-packages/urllib3/connectionpool.py", line 493, in _make_request
conn.request(
~~~~~~~~~~~~^
method,
^^^^^^^
...<6 lines>...
enforce_content_length=enforce_content_length,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/lib/python3.14/site-packages/urllib3/contrib/wasi/connection.py", line 124, in request
self._response = wasi.send_request(request)
~~~~~~~~~~~~~~~~~^^^^^^^^^
File "/lib/python3.14/site-packages/urllib3/contrib/wasi/wasi.py", line 79, in send_request
raise errors.WasiErrorCode(str(response.value.value))
urllib3.contrib.wasi.errors.WasiErrorCode: Request failed with wasi http error ErrorCode_HttpRequestDenied

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/lib/python3.14/site-packages/requests/adapters.py", line 644, in send
resp = conn.urlopen(
method=request.method,
...<9 lines>...
chunked=chunked,
)
File "/lib/python3.14/site-packages/urllib3/connectionpool.py", line 841, in urlopen
retries = retries.increment(
method, url, error=new_e, _pool=self, _stacktrace=sys.exc_info()[2]
)
File "/lib/python3.14/site-packages/urllib3/util/retry.py", line 474, in increment
raise reraise(type(error), error, _stacktrace)
~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lib/python3.14/site-packages/urllib3/util/util.py", line 38, in reraise
raise value.with_traceback(tb)
File "/lib/python3.14/site-packages/urllib3/connectionpool.py", line 787, in urlopen
response = self._make_request(
conn,
...<10 lines>...
**response_kw,
)
File "/lib/python3.14/site-packages/urllib3/connectionpool.py", line 493, in _make_request
conn.request(
~~~~~~~~~~~~^
method,
^^^^^^^
...<6 lines>...
enforce_content_length=enforce_content_length,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/lib/python3.14/site-packages/urllib3/contrib/wasi/connection.py", line 124, in request
self._response = wasi.send_request(request)
~~~~~~~~~~~~~~~~~^^^^^^^^^
File "/lib/python3.14/site-packages/urllib3/contrib/wasi/wasi.py", line 79, in send_request
raise errors.WasiErrorCode(str(response.value.value))
urllib3.exceptions.ProtocolError: ('Connection aborted.', WasiErrorCode('Request failed with wasi http error ErrorCode_HttpRequestDenied'))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "<string>", line 5, in perform_request
File "/lib/python3.14/site-packages/requests/api.py", line 73, in get
return request("get", url, params=params, **kwargs)
File "/lib/python3.14/site-packages/requests/api.py", line 59, in request
return session.request(method=method, url=url, **kwargs)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lib/python3.14/site-packages/requests/sessions.py", line 589, in request
resp = self.send(prep, **send_kwargs)
File "/lib/python3.14/site-packages/requests/sessions.py", line 703, in send
r = adapter.send(request, **kwargs)
File "/lib/python3.14/site-packages/requests/adapters.py", line 659, in send
raise ConnectionError(err, request=request)
requests.exceptions.ConnectionError: ('Connection aborted.', WasiErrorCode('Request failed with wasi http error ErrorCode_HttpRequestDenied'))
"#,
);
}