Skip to content

Commit d96b76e

Browse files
committed
fix(pegboard-runner): clear terminal tunnel routes
1 parent 412eec8 commit d96b76e

2 files changed

Lines changed: 86 additions & 8 deletions

File tree

engine/packages/pegboard-runner/src/ws_to_tunnel_task.rs

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -860,6 +860,13 @@ async fn handle_tunnel_message_mk2(
860860
authorized_tunnel_routes: &HashMap<(protocol::mk2::GatewayId, protocol::mk2::RequestId), ()>,
861861
msg: protocol::mk2::ToServerTunnelMessage,
862862
) -> Result<()> {
863+
let route = (msg.message_id.gateway_id, msg.message_id.request_id);
864+
let clear_route = matches!(
865+
msg.message_kind,
866+
protocol::mk2::ToServerTunnelMessageKind::ToServerResponseAbort
867+
| protocol::mk2::ToServerTunnelMessageKind::ToServerWebSocketClose(_)
868+
);
869+
863870
// Extract inner data length before consuming msg
864871
let inner_data_len = tunnel_message_inner_data_len_mk2(&msg.message_kind);
865872

@@ -868,10 +875,7 @@ async fn handle_tunnel_message_mk2(
868875
return Err(errors::WsError::InvalidPacket("payload too large".to_string()).build());
869876
}
870877

871-
if !authorized_tunnel_routes
872-
.contains_async(&(msg.message_id.gateway_id, msg.message_id.request_id))
873-
.await
874-
{
878+
if !authorized_tunnel_routes.contains_async(&route).await {
875879
return Err(
876880
errors::WsError::InvalidPacket("unauthorized tunnel message".to_string()).build(),
877881
);
@@ -899,6 +903,10 @@ async fn handle_tunnel_message_mk2(
899903
)
900904
})?;
901905

906+
if clear_route {
907+
authorized_tunnel_routes.remove_async(&route).await;
908+
}
909+
902910
Ok(())
903911
}
904912

@@ -909,6 +917,13 @@ async fn handle_tunnel_message_mk1(
909917
authorized_tunnel_routes: &HashMap<(protocol::mk2::GatewayId, protocol::mk2::RequestId), ()>,
910918
msg: protocol::ToServerTunnelMessage,
911919
) -> Result<()> {
920+
let route = (msg.message_id.gateway_id, msg.message_id.request_id);
921+
let clear_route = matches!(
922+
msg.message_kind,
923+
protocol::ToServerTunnelMessageKind::ToServerResponseAbort
924+
| protocol::ToServerTunnelMessageKind::ToServerWebSocketClose(_)
925+
);
926+
912927
// Ignore DeprecatedTunnelAck messages (used only for backwards compatibility)
913928
if matches!(
914929
msg.message_kind,
@@ -925,10 +940,7 @@ async fn handle_tunnel_message_mk1(
925940
return Err(errors::WsError::InvalidPacket("payload too large".to_string()).build());
926941
}
927942

928-
if !authorized_tunnel_routes
929-
.contains_async(&(msg.message_id.gateway_id, msg.message_id.request_id))
930-
.await
931-
{
943+
if !authorized_tunnel_routes.contains_async(&route).await {
932944
return Err(
933945
errors::WsError::InvalidPacket("unauthorized tunnel message".to_string()).build(),
934946
);
@@ -950,6 +962,10 @@ async fn handle_tunnel_message_mk1(
950962
)
951963
})?;
952964

965+
if clear_route {
966+
authorized_tunnel_routes.remove_async(&route).await;
967+
}
968+
953969
Ok(())
954970
}
955971

engine/packages/pegboard-runner/tests/support/ws_to_tunnel_task.rs

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,11 @@ async fn republishes_issued_mk2_tunnel_message_pairs() {
9292
.unwrap()
9393
.unwrap();
9494
assert!(matches!(msg, NextOutput::Message(_)));
95+
assert!(
96+
authorized_tunnel_routes
97+
.contains_async(&(gateway_id, request_id))
98+
.await
99+
);
95100
}
96101

97102
#[tokio::test]
@@ -147,4 +152,61 @@ async fn republishes_issued_mk1_tunnel_message_pairs() {
147152
.unwrap()
148153
.unwrap();
149154
assert!(matches!(msg, NextOutput::Message(_)));
155+
assert!(
156+
authorized_tunnel_routes
157+
.contains_async(&(gateway_id, request_id))
158+
.await
159+
);
160+
}
161+
162+
#[tokio::test]
163+
async fn removes_terminal_mk2_tunnel_message_pairs() {
164+
let pubsub = memory_pubsub("pegboard-runner-ws-to-tunnel-test-remove-mk2");
165+
let gateway_id = [33, 34, 35, 36];
166+
let request_id = [37, 38, 39, 40];
167+
let authorized_tunnel_routes = HashMap::new();
168+
let _ = authorized_tunnel_routes
169+
.insert_async((gateway_id, request_id), ())
170+
.await;
171+
172+
handle_tunnel_message_mk2(
173+
&pubsub,
174+
1024,
175+
&authorized_tunnel_routes,
176+
response_abort_message_mk2(gateway_id, request_id),
177+
)
178+
.await
179+
.unwrap();
180+
181+
assert!(
182+
!authorized_tunnel_routes
183+
.contains_async(&(gateway_id, request_id))
184+
.await
185+
);
186+
}
187+
188+
#[tokio::test]
189+
async fn removes_terminal_mk1_tunnel_message_pairs() {
190+
let pubsub = memory_pubsub("pegboard-runner-ws-to-tunnel-test-remove-mk1");
191+
let gateway_id = [41, 42, 43, 44];
192+
let request_id = [45, 46, 47, 48];
193+
let authorized_tunnel_routes = HashMap::new();
194+
let _ = authorized_tunnel_routes
195+
.insert_async((gateway_id, request_id), ())
196+
.await;
197+
198+
handle_tunnel_message_mk1(
199+
&pubsub,
200+
1024,
201+
&authorized_tunnel_routes,
202+
response_abort_message_mk1(gateway_id, request_id),
203+
)
204+
.await
205+
.unwrap();
206+
207+
assert!(
208+
!authorized_tunnel_routes
209+
.contains_async(&(gateway_id, request_id))
210+
.await
211+
);
150212
}

0 commit comments

Comments
 (0)