Skip to content

Commit c8659a3

Browse files
committed
fix(pegboard-runner): clear terminal tunnel routes
1 parent 412eec8 commit c8659a3

2 files changed

Lines changed: 222 additions & 10 deletions

File tree

engine/packages/pegboard-runner/src/ws_to_tunnel_task.rs

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -860,6 +860,14 @@ async fn handle_tunnel_message_mk2(
860860
authorized_tunnel_routes: &HashMap<(protocol::mk2::GatewayId, protocol::mk2::RequestId), ()>,
861861
msg: protocol::mk2::ToServerTunnelMessage,
862862
) -> Result<()> {
863+
let route = (msg.message_id.gateway_id, msg.message_id.request_id);
864+
let clear_route = matches!(
865+
msg.message_kind,
866+
protocol::mk2::ToServerTunnelMessageKind::ToServerResponseStart(_)
867+
| protocol::mk2::ToServerTunnelMessageKind::ToServerResponseAbort
868+
| protocol::mk2::ToServerTunnelMessageKind::ToServerWebSocketClose(_)
869+
);
870+
863871
// Extract inner data length before consuming msg
864872
let inner_data_len = tunnel_message_inner_data_len_mk2(&msg.message_kind);
865873

@@ -868,10 +876,7 @@ async fn handle_tunnel_message_mk2(
868876
return Err(errors::WsError::InvalidPacket("payload too large".to_string()).build());
869877
}
870878

871-
if !authorized_tunnel_routes
872-
.contains_async(&(msg.message_id.gateway_id, msg.message_id.request_id))
873-
.await
874-
{
879+
if !authorized_tunnel_routes.contains_async(&route).await {
875880
return Err(
876881
errors::WsError::InvalidPacket("unauthorized tunnel message".to_string()).build(),
877882
);
@@ -899,6 +904,10 @@ async fn handle_tunnel_message_mk2(
899904
)
900905
})?;
901906

907+
if clear_route {
908+
authorized_tunnel_routes.remove_async(&route).await;
909+
}
910+
902911
Ok(())
903912
}
904913

@@ -909,6 +918,14 @@ async fn handle_tunnel_message_mk1(
909918
authorized_tunnel_routes: &HashMap<(protocol::mk2::GatewayId, protocol::mk2::RequestId), ()>,
910919
msg: protocol::ToServerTunnelMessage,
911920
) -> Result<()> {
921+
let route = (msg.message_id.gateway_id, msg.message_id.request_id);
922+
let clear_route = matches!(
923+
msg.message_kind,
924+
protocol::ToServerTunnelMessageKind::ToServerResponseStart(_)
925+
| protocol::ToServerTunnelMessageKind::ToServerResponseAbort
926+
| protocol::ToServerTunnelMessageKind::ToServerWebSocketClose(_)
927+
);
928+
912929
// Ignore DeprecatedTunnelAck messages (used only for backwards compatibility)
913930
if matches!(
914931
msg.message_kind,
@@ -925,10 +942,7 @@ async fn handle_tunnel_message_mk1(
925942
return Err(errors::WsError::InvalidPacket("payload too large".to_string()).build());
926943
}
927944

928-
if !authorized_tunnel_routes
929-
.contains_async(&(msg.message_id.gateway_id, msg.message_id.request_id))
930-
.await
931-
{
945+
if !authorized_tunnel_routes.contains_async(&route).await {
932946
return Err(
933947
errors::WsError::InvalidPacket("unauthorized tunnel message".to_string()).build(),
934948
);
@@ -950,6 +964,10 @@ async fn handle_tunnel_message_mk1(
950964
)
951965
})?;
952966

967+
if clear_route {
968+
authorized_tunnel_routes.remove_async(&route).await;
969+
}
970+
953971
Ok(())
954972
}
955973

engine/packages/pegboard-runner/tests/support/ws_to_tunnel_task.rs

Lines changed: 196 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,46 @@ fn response_abort_message_mk2(
2525
}
2626
}
2727

28+
fn response_start_message_mk2(
29+
gateway_id: protocol::mk2::GatewayId,
30+
request_id: protocol::mk2::RequestId,
31+
) -> protocol::mk2::ToServerTunnelMessage {
32+
protocol::mk2::ToServerTunnelMessage {
33+
message_id: protocol::mk2::MessageId {
34+
gateway_id,
35+
request_id,
36+
message_index: 0,
37+
},
38+
message_kind: protocol::mk2::ToServerTunnelMessageKind::ToServerResponseStart(
39+
protocol::mk2::ToServerResponseStart {
40+
status: 200,
41+
headers: Default::default(),
42+
body: None,
43+
stream: false,
44+
},
45+
),
46+
}
47+
}
48+
49+
fn websocket_message_mk2(
50+
gateway_id: protocol::mk2::GatewayId,
51+
request_id: protocol::mk2::RequestId,
52+
) -> protocol::mk2::ToServerTunnelMessage {
53+
protocol::mk2::ToServerTunnelMessage {
54+
message_id: protocol::mk2::MessageId {
55+
gateway_id,
56+
request_id,
57+
message_index: 0,
58+
},
59+
message_kind: protocol::mk2::ToServerTunnelMessageKind::ToServerWebSocketMessage(
60+
protocol::mk2::ToServerWebSocketMessage {
61+
data: b"ping".to_vec(),
62+
binary: false,
63+
},
64+
),
65+
}
66+
}
67+
2868
fn response_abort_message_mk1(
2969
gateway_id: protocol::mk2::GatewayId,
3070
request_id: protocol::mk2::RequestId,
@@ -39,6 +79,46 @@ fn response_abort_message_mk1(
3979
}
4080
}
4181

82+
fn websocket_message_mk1(
83+
gateway_id: protocol::mk2::GatewayId,
84+
request_id: protocol::mk2::RequestId,
85+
) -> protocol::ToServerTunnelMessage {
86+
protocol::ToServerTunnelMessage {
87+
message_id: protocol::MessageId {
88+
gateway_id,
89+
request_id,
90+
message_index: 0,
91+
},
92+
message_kind: protocol::ToServerTunnelMessageKind::ToServerWebSocketMessage(
93+
protocol::ToServerWebSocketMessage {
94+
data: b"ping".to_vec(),
95+
binary: false,
96+
},
97+
),
98+
}
99+
}
100+
101+
fn response_start_message_mk1(
102+
gateway_id: protocol::mk2::GatewayId,
103+
request_id: protocol::mk2::RequestId,
104+
) -> protocol::ToServerTunnelMessage {
105+
protocol::ToServerTunnelMessage {
106+
message_id: protocol::MessageId {
107+
gateway_id,
108+
request_id,
109+
message_index: 0,
110+
},
111+
message_kind: protocol::ToServerTunnelMessageKind::ToServerResponseStart(
112+
protocol::ToServerResponseStart {
113+
status: 200,
114+
headers: Default::default(),
115+
body: None,
116+
stream: false,
117+
},
118+
),
119+
}
120+
}
121+
42122
#[tokio::test]
43123
async fn rejects_unissued_mk2_tunnel_message_pairs() {
44124
let pubsub = memory_pubsub("pegboard-runner-ws-to-tunnel-test-reject-mk2");
@@ -82,7 +162,7 @@ async fn republishes_issued_mk2_tunnel_message_pairs() {
82162
&pubsub,
83163
1024,
84164
&authorized_tunnel_routes,
85-
response_abort_message_mk2(gateway_id, request_id),
165+
websocket_message_mk2(gateway_id, request_id),
86166
)
87167
.await
88168
.unwrap();
@@ -92,6 +172,11 @@ async fn republishes_issued_mk2_tunnel_message_pairs() {
92172
.unwrap()
93173
.unwrap();
94174
assert!(matches!(msg, NextOutput::Message(_)));
175+
assert!(
176+
authorized_tunnel_routes
177+
.contains_async(&(gateway_id, request_id))
178+
.await
179+
);
95180
}
96181

97182
#[tokio::test]
@@ -137,7 +222,7 @@ async fn republishes_issued_mk1_tunnel_message_pairs() {
137222
&pubsub,
138223
1024,
139224
&authorized_tunnel_routes,
140-
response_abort_message_mk1(gateway_id, request_id),
225+
websocket_message_mk1(gateway_id, request_id),
141226
)
142227
.await
143228
.unwrap();
@@ -147,4 +232,113 @@ async fn republishes_issued_mk1_tunnel_message_pairs() {
147232
.unwrap()
148233
.unwrap();
149234
assert!(matches!(msg, NextOutput::Message(_)));
235+
assert!(
236+
authorized_tunnel_routes
237+
.contains_async(&(gateway_id, request_id))
238+
.await
239+
);
240+
}
241+
242+
#[tokio::test]
243+
async fn removes_terminal_mk2_tunnel_message_pairs() {
244+
let pubsub = memory_pubsub("pegboard-runner-ws-to-tunnel-test-remove-mk2");
245+
let gateway_id = [33, 34, 35, 36];
246+
let request_id = [37, 38, 39, 40];
247+
let authorized_tunnel_routes = HashMap::new();
248+
let _ = authorized_tunnel_routes
249+
.insert_async((gateway_id, request_id), ())
250+
.await;
251+
252+
handle_tunnel_message_mk2(
253+
&pubsub,
254+
1024,
255+
&authorized_tunnel_routes,
256+
response_abort_message_mk2(gateway_id, request_id),
257+
)
258+
.await
259+
.unwrap();
260+
261+
assert!(
262+
!authorized_tunnel_routes
263+
.contains_async(&(gateway_id, request_id))
264+
.await
265+
);
266+
}
267+
268+
#[tokio::test]
269+
async fn removes_response_start_mk2_tunnel_message_pairs() {
270+
let pubsub = memory_pubsub("pegboard-runner-ws-to-tunnel-test-remove-response-start-mk2");
271+
let gateway_id = [49, 50, 51, 52];
272+
let request_id = [53, 54, 55, 56];
273+
let authorized_tunnel_routes = HashMap::new();
274+
let _ = authorized_tunnel_routes
275+
.insert_async((gateway_id, request_id), ())
276+
.await;
277+
278+
handle_tunnel_message_mk2(
279+
&pubsub,
280+
1024,
281+
&authorized_tunnel_routes,
282+
response_start_message_mk2(gateway_id, request_id),
283+
)
284+
.await
285+
.unwrap();
286+
287+
assert!(
288+
!authorized_tunnel_routes
289+
.contains_async(&(gateway_id, request_id))
290+
.await
291+
);
292+
}
293+
294+
#[tokio::test]
295+
async fn removes_terminal_mk1_tunnel_message_pairs() {
296+
let pubsub = memory_pubsub("pegboard-runner-ws-to-tunnel-test-remove-mk1");
297+
let gateway_id = [41, 42, 43, 44];
298+
let request_id = [45, 46, 47, 48];
299+
let authorized_tunnel_routes = HashMap::new();
300+
let _ = authorized_tunnel_routes
301+
.insert_async((gateway_id, request_id), ())
302+
.await;
303+
304+
handle_tunnel_message_mk1(
305+
&pubsub,
306+
1024,
307+
&authorized_tunnel_routes,
308+
response_abort_message_mk1(gateway_id, request_id),
309+
)
310+
.await
311+
.unwrap();
312+
313+
assert!(
314+
!authorized_tunnel_routes
315+
.contains_async(&(gateway_id, request_id))
316+
.await
317+
);
318+
}
319+
320+
#[tokio::test]
321+
async fn removes_response_start_mk1_tunnel_message_pairs() {
322+
let pubsub = memory_pubsub("pegboard-runner-ws-to-tunnel-test-remove-response-start-mk1");
323+
let gateway_id = [57, 58, 59, 60];
324+
let request_id = [61, 62, 63, 64];
325+
let authorized_tunnel_routes = HashMap::new();
326+
let _ = authorized_tunnel_routes
327+
.insert_async((gateway_id, request_id), ())
328+
.await;
329+
330+
handle_tunnel_message_mk1(
331+
&pubsub,
332+
1024,
333+
&authorized_tunnel_routes,
334+
response_start_message_mk1(gateway_id, request_id),
335+
)
336+
.await
337+
.unwrap();
338+
339+
assert!(
340+
!authorized_tunnel_routes
341+
.contains_async(&(gateway_id, request_id))
342+
.await
343+
);
150344
}

0 commit comments

Comments
 (0)