Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions engine/packages/pegboard-envoy/src/sqlite_runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -676,9 +676,9 @@ mod tests {
use universaldb::driver::RocksDbDatabaseDriver;

use super::{
FILE_TAG_JOURNAL, FILE_TAG_MAIN, FILE_TAG_SHM, FILE_TAG_WAL,
SQLITE_V1_CHUNK_SIZE, SQLITE_V1_MAX_MIGRATION_BYTES, SQLITE_V1_MIGRATION_LEASE_MS,
maybe_migrate_v1_to_v2, read_v1_file, sqlite_subspace, v1_chunk_key, v1_meta_key,
FILE_TAG_JOURNAL, FILE_TAG_MAIN, FILE_TAG_SHM, FILE_TAG_WAL, SQLITE_V1_CHUNK_SIZE,
SQLITE_V1_MAX_MIGRATION_BYTES, SQLITE_V1_MIGRATION_LEASE_MS, maybe_migrate_v1_to_v2,
read_v1_file, sqlite_subspace, v1_chunk_key, v1_meta_key,
};

fn recipient(actor_id: Id) -> Recipient {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"code": "incoming_too_long",
"group": "message",
"message": "Incoming message too long."
"message": "Incoming message too long"
}
315 changes: 68 additions & 247 deletions rivetkit-rust/packages/rivetkit-core/src/registry/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@ impl RegistryDispatcher {
return self.handle_metrics_fetch(&instance, &request);
}

let original_path = request.path.clone();
let request = build_http_request(request).await?;
let framework_route = framework_http_route(request.uri().path())?;
let route = RegistryHttpRoute::from_paths(&original_path, request.uri().path())?;
let instance = match self.active_actor(actor_id).await {
Ok(instance) => instance,
Err(error) => {
if framework_route.is_some() {
if matches!(route, RegistryHttpRoute::Framework(_)) {
return message_boundary_error_response(
request_encoding(request.headers()),
framework_anyhow_status(&error),
Expand All @@ -36,20 +37,24 @@ impl RegistryDispatcher {

instance.ctx.cancel_sleep_timer();

let rearm_sleep_after_request = |ctx: ActorContext| {
let sleep_ctx = ctx.clone();
ctx.wait_until(async move {
sleep_ctx.wait_for_http_requests_idle().await;
sleep_ctx.reset_sleep_timer();
});
let response = match route {
RegistryHttpRoute::Framework(route) => {
let response = self.handle_framework_fetch(&instance, request, route).await;
rearm_sleep_after_request(instance.ctx.clone());
response
}
RegistryHttpRoute::UserRawRequest => {
self.handle_user_request_fetch(&instance, request).await
}
};
response
}

if let Some(route) = framework_route {
let response = self.handle_framework_fetch(&instance, request, route).await;
rearm_sleep_after_request(instance.ctx.clone());
return response;
}

async fn handle_user_request_fetch(
&self,
instance: &ActorTaskHandle,
request: Request,
) -> Result<HttpResponse> {
let (reply_tx, reply_rx) = oneshot::channel();
try_send_dispatch_command(
&instance.dispatch,
Expand All @@ -72,7 +77,11 @@ impl RegistryDispatcher {
build_envoy_response(response)
}
Err(error) => {
tracing::error!(actor_id, ?error, "actor request callback failed");
tracing::error!(
actor_id = instance.actor_id,
?error,
"actor request callback failed"
);
rearm_sleep_after_request(instance.ctx.clone());
Ok(inspector_anyhow_response(error))
}
Expand Down Expand Up @@ -373,6 +382,40 @@ impl RegistryDispatcher {
}
}

enum RegistryHttpRoute {
Framework(FrameworkHttpRoute),
UserRawRequest,
}

impl RegistryHttpRoute {
fn from_paths(original_path: &str, normalized_path: &str) -> Result<Self> {
if let Some(stripped) = original_path.strip_prefix("/request") {
if stripped.is_empty() || matches!(stripped.as_bytes().first(), Some(b'/') | Some(b'?'))
{
return Ok(Self::UserRawRequest);
}
}

if let Some(segment) = single_path_segment(normalized_path, "/action/") {
return Ok(Self::Framework(FrameworkHttpRoute::Action(
percent_decode_path_segment(segment)?,
)));
}
if let Some(segment) = single_path_segment(normalized_path, "/queue/") {
return Ok(Self::Framework(FrameworkHttpRoute::Queue(
percent_decode_path_segment(segment)?,
)));
}

match normalized_path {
"/metadata" => Ok(Self::Framework(FrameworkHttpRoute::Metadata)),
"/health" => Ok(Self::Framework(FrameworkHttpRoute::Health)),
"/" => Ok(Self::Framework(FrameworkHttpRoute::Root)),
_ => Ok(Self::UserRawRequest),
}
}
}

pub(super) enum FrameworkHttpRoute {
Action(String),
Queue(String),
Expand All @@ -387,25 +430,6 @@ pub(super) struct DecodedHttpQueueRequest {
timeout: Option<u64>,
}

pub(super) fn framework_http_route(path: &str) -> Result<Option<FrameworkHttpRoute>> {
if let Some(segment) = single_path_segment(path, "/action/") {
return Ok(Some(FrameworkHttpRoute::Action(
percent_decode_path_segment(segment)?,
)));
}
if let Some(segment) = single_path_segment(path, "/queue/") {
return Ok(Some(FrameworkHttpRoute::Queue(
percent_decode_path_segment(segment)?,
)));
}
match path {
"/metadata" => Ok(Some(FrameworkHttpRoute::Metadata)),
"/health" => Ok(Some(FrameworkHttpRoute::Health)),
"/" => Ok(Some(FrameworkHttpRoute::Root)),
_ => Ok(None),
}
}

fn handle_metadata_fetch(request: &Request) -> Result<HttpResponse> {
if request.method() != http::Method::GET {
return method_not_allowed_response(request);
Expand Down Expand Up @@ -469,6 +493,14 @@ fn method_not_allowed_response(request: &Request) -> Result<HttpResponse> {
)
}

fn rearm_sleep_after_request(ctx: ActorContext) {
let sleep_ctx = ctx.clone();
ctx.wait_until(async move {
sleep_ctx.wait_for_http_requests_idle().await;
sleep_ctx.reset_sleep_timer();
});
}

pub(super) fn single_path_segment<'a>(path: &'a str, prefix: &str) -> Option<&'a str> {
let segment = path.strip_prefix(prefix)?;
(!segment.is_empty() && !segment.contains('/')).then_some(segment)
Expand Down Expand Up @@ -899,216 +931,5 @@ fn bearer_token_from_authorization(value: &str) -> Option<&str> {
}

#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::time::Duration;

use super::{
HttpRequest, HttpResponseEncoding, authorization_bearer_token,
authorization_bearer_token_map, framework_action_error_response,
message_boundary_error_response, request_encoding, request_has_bearer_token,
workflow_dispatch_result,
};
use crate::actor::action::ActionDispatchError;
use crate::error::ActorLifecycle as ActorLifecycleError;
use http::StatusCode;
use rivet_error::RivetError;
use serde_json::json;
use vbare::OwnedVersionedData;

#[derive(RivetError)]
#[error("message", "incoming_too_long", "Incoming message too long")]
struct IncomingMessageTooLong;

#[derive(RivetError)]
#[error("message", "outgoing_too_long", "Outgoing message too long")]
struct OutgoingMessageTooLong;

#[test]
fn workflow_dispatch_result_marks_handled_workflow_as_enabled() {
assert_eq!(
workflow_dispatch_result(Ok(Some(vec![1, 2, 3])))
.expect("workflow dispatch should succeed"),
(true, Some(vec![1, 2, 3])),
);
assert_eq!(
workflow_dispatch_result(Ok(None)).expect("workflow dispatch should succeed"),
(true, None),
);
}

#[test]
fn workflow_dispatch_result_treats_dropped_reply_as_disabled() {
assert_eq!(
workflow_dispatch_result(Err(ActorLifecycleError::DroppedReply.build()))
.expect("dropped reply should map to workflow disabled"),
(false, None),
);
}

#[test]
fn workflow_dispatch_result_preserves_non_dropped_reply_errors() {
let error = workflow_dispatch_result(Err(ActorLifecycleError::Destroying.build()))
.expect_err("non-dropped reply errors should be preserved");
let error = rivet_error::RivetError::extract(&error);
assert_eq!(error.group(), "actor");
assert_eq!(error.code(), "destroying");
}

#[test]
fn inspector_error_status_maps_action_timeout_to_408() {
assert_eq!(
super::inspector_error_status("actor", "action_timed_out"),
StatusCode::REQUEST_TIMEOUT,
);
}

#[test]
fn authorization_bearer_token_accepts_case_insensitive_scheme_and_whitespace() {
let mut headers = http::HeaderMap::new();
headers.insert(
http::header::AUTHORIZATION,
"bearer test-token".parse().unwrap(),
);

assert_eq!(authorization_bearer_token(&headers), Some("test-token"));

let map = HashMap::from([(
http::header::AUTHORIZATION.as_str().to_owned(),
"BEARER\ttest-token".to_owned(),
)]);
assert_eq!(authorization_bearer_token_map(&map), Some("test-token"));
}

#[test]
fn request_has_bearer_token_uses_same_authorization_parser() {
let request = HttpRequest {
method: "GET".to_owned(),
path: "/metrics".to_owned(),
headers: HashMap::from([(
http::header::AUTHORIZATION.as_str().to_owned(),
"Bearer configured".to_owned(),
)]),
body: Some(Vec::new()),
body_stream: None,
};

assert!(request_has_bearer_token(&request, Some("configured")));
assert!(!request_has_bearer_token(&request, Some("other")));
}

#[tokio::test]
async fn action_dispatch_timeout_returns_structured_error() {
let error = super::with_action_dispatch_timeout(Duration::from_millis(1), async {
tokio::time::sleep(Duration::from_secs(60)).await;
Ok::<Vec<u8>, ActionDispatchError>(Vec::new())
})
.await
.expect_err("timeout should return an action dispatch error");

assert_eq!(error.group, "actor");
assert_eq!(error.code, "action_timed_out");
assert_eq!(error.message, "Action timed out");
}

#[tokio::test]
async fn framework_action_timeout_returns_structured_error() {
let error = super::with_framework_action_timeout(Duration::from_millis(1), async {
tokio::time::sleep(Duration::from_secs(60)).await;
Ok::<(), anyhow::Error>(())
})
.await
.expect_err("timeout should return a framework error");
let error = RivetError::extract(&error);

assert_eq!(error.group(), "actor");
assert_eq!(error.code(), "action_timed_out");
assert_eq!(error.message(), "Action timed out");
}

#[test]
fn framework_action_error_response_maps_timeout_to_408() {
let response = framework_action_error_response(
HttpResponseEncoding::Json,
ActionDispatchError {
group: "actor".to_owned(),
code: "action_timed_out".to_owned(),
message: "Action timed out".to_owned(),
metadata: None,
},
)
.expect("timeout error response should serialize");

assert_eq!(response.status, StatusCode::REQUEST_TIMEOUT.as_u16());
assert_eq!(
response.body,
Some(
serde_json::to_vec(&json!({
"group": "actor",
"code": "action_timed_out",
"message": "Action timed out",
}))
.expect("json body should encode")
)
);
}

#[test]
fn message_boundary_error_response_defaults_to_json() {
let response = message_boundary_error_response(
HttpResponseEncoding::Json,
StatusCode::BAD_REQUEST,
IncomingMessageTooLong.build(),
)
.expect("json response should serialize");

assert_eq!(response.status, StatusCode::BAD_REQUEST.as_u16());
assert_eq!(
response.headers.get(http::header::CONTENT_TYPE.as_str()),
Some(&"application/json".to_owned())
);
assert_eq!(
response.body,
Some(
serde_json::to_vec(&json!({
"group": "message",
"code": "incoming_too_long",
"message": "Incoming message too long",
}))
.expect("json body should encode")
)
);
}

#[test]
fn request_encoding_reads_cbor_header() {
let mut headers = http::HeaderMap::new();
headers.insert("x-rivet-encoding", "cbor".parse().unwrap());

assert_eq!(request_encoding(&headers), HttpResponseEncoding::Cbor);
}

#[test]
fn message_boundary_error_response_serializes_bare_v3() {
let response = message_boundary_error_response(
HttpResponseEncoding::Bare,
StatusCode::BAD_REQUEST,
OutgoingMessageTooLong.build(),
)
.expect("bare response should serialize");

assert_eq!(
response.headers.get(http::header::CONTENT_TYPE.as_str()),
Some(&"application/octet-stream".to_owned())
);

let body = response.body.expect("bare response should include body");
let decoded =
<rivetkit_client_protocol::versioned::HttpResponseError as OwnedVersionedData>::deserialize_with_embedded_version(&body)
.expect("bare error should decode");
assert_eq!(decoded.group, "message");
assert_eq!(decoded.code, "outgoing_too_long");
assert_eq!(decoded.message, "Outgoing message too long");
assert_eq!(decoded.metadata, None);
}
}
#[path = "../../tests/modules/registry_http.rs"]
mod tests;
Loading
Loading