Skip to content

Commit d7a61ec

Browse files
refactor(dgw): extract upstream routing into shared module
Moves the duplicated upstream connection machinery (RoutePlan, UpstreamLeg, UpstreamSession, connect_upstream, prepare_upstream) out of api/fwd.rs into a new crate::upstream module, and consumes it from fwd.rs, generic_client.rs, and rd_clean_path.rs. Net effect: - Eliminates three copies of the same `resolve route → connect → optional TLS wrap` sequence. All three call sites now share one implementation, so future fixes (e.g. alternate-target iteration, IPv6 bracketing, TLS over agent tunnel) happen in one place. - Fixes TLS-over-agent-tunnel silently not working in fwd.rs: the new `UpstreamSession::Tls(Box<TlsStream<UpstreamLeg>>)` wraps either TCP or tunnel legs, where previously the TLS path only handled TCP. - RDP credential injection in generic_client.rs now works over agent tunnel too (RdpProxy<_, S> is generic over S; UpstreamLeg satisfies its bounds). rd_clean_path.rs's local `ServerTransport` enum is removed in favour of the shared `UpstreamLeg`; the comment explaining why this must be an enum (not Box<dyn>) moves to the shared module. devolutions-agent/src/main.rs: nightly rustfmt reflow of a long `.context(...)` chain, no behavioural change. - fwd.rs: 942 → 650 LOC (−292) - generic_client.rs: 240 → 205 LOC (−35) - rd_clean_path.rs: ~900 → ~860 LOC (−40) - upstream.rs: +344 new (shared) Verified: cargo check --workspace --all-targets clean (video-streamer bench failure is pre-existing). cargo clippy -p agent-tunnel -p agent-tunnel-proto -p devolutions-gateway -p devolutions-agent --all-targets: 0 warnings. cargo +nightly fmt applied. Tests: 52 lib + 23 gateway integration (routing/registry) all green.
1 parent c3dd7b9 commit d7a61ec

6 files changed

Lines changed: 403 additions & 447 deletions

File tree

devolutions-agent/src/main.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,9 +194,9 @@ fn parse_up_command_args(args: &[String]) -> Result<UpCommand> {
194194

195195
// CLI flag wins over JWT claim. At least one must be provided — the gateway does
196196
// not self-report a QUIC endpoint (see `EnrollmentJwtClaims::jet_quic_endpoint`).
197-
let quic_endpoint = cli_quic_endpoint.or(jwt_quic_endpoint).context(
198-
"missing QUIC endpoint: pass --quic-endpoint or include `jet_quic_endpoint` in the enrollment JWT",
199-
)?;
197+
let quic_endpoint = cli_quic_endpoint
198+
.or(jwt_quic_endpoint)
199+
.context("missing QUIC endpoint: pass --quic-endpoint or include `jet_quic_endpoint` in the enrollment JWT")?;
200200

201201
Ok(UpCommand {
202202
gateway_url: gateway_url.context("missing required --gateway")?,

devolutions-gateway/src/api/fwd.rs

Lines changed: 17 additions & 304 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,19 @@ use bytes::Bytes;
1111
use devolutions_gateway_task::ShutdownSignal;
1212
use tap::Pipe as _;
1313
use tokio::io::{AsyncRead, AsyncWrite};
14-
use tokio::net::TcpStream;
15-
use tokio_rustls::client::TlsStream;
1614
use tracing::{Instrument as _, field};
1715
use typed_builder::TypedBuilder;
1816
use uuid::Uuid;
1917

18+
use crate::DgwState;
2019
use crate::config::Conf;
2120
use crate::extract::{AssociationToken, BridgeToken};
2221
use crate::http::HttpError;
2322
use crate::proxy::Proxy;
2423
use crate::session::{ConnectionModeDetails, DisconnectInterest, SessionInfo, SessionMessageSender};
2524
use crate::subscriber::SubscriberSender;
26-
use crate::target_addr::TargetAddr;
2725
use crate::token::{ApplicationProtocol, AssociationTokenClaims, ConnectionMode, Protocol, RecordingPolicy};
28-
use crate::{DgwState, utils};
26+
use crate::upstream::{self, PreparedUpstream, UpstreamMode};
2927

3028
pub fn make_router<S>(state: DgwState) -> Router<S> {
3129
use axum::routing::{self, MethodFilter, get};
@@ -161,7 +159,7 @@ async fn handle_fwd(
161159
.claims(claims)
162160
.sessions(sessions)
163161
.subscriber_tx(subscriber_tx)
164-
.mode(if with_tls { ForwardMode::Tls } else { ForwardMode::Tcp })
162+
.mode(if with_tls { UpstreamMode::Tls } else { UpstreamMode::Tcp })
165163
.agent_tunnel_handle(agent_tunnel_handle)
166164
.build()
167165
.run()
@@ -192,240 +190,11 @@ struct Forward<S> {
192190
client_addr: SocketAddr,
193191
sessions: SessionMessageSender,
194192
subscriber_tx: SubscriberSender,
195-
mode: ForwardMode,
193+
mode: UpstreamMode,
196194
#[builder(default)]
197195
agent_tunnel_handle: Option<Arc<agent_tunnel::AgentTunnelHandle>>,
198196
}
199197

200-
#[derive(Debug, Clone, Copy)]
201-
enum ForwardMode {
202-
Tcp,
203-
Tls,
204-
}
205-
206-
enum UpstreamLeg {
207-
Tcp(TcpStream),
208-
Tunnel(agent_tunnel::stream::TunnelStream),
209-
}
210-
211-
impl AsyncRead for UpstreamLeg {
212-
fn poll_read(
213-
self: std::pin::Pin<&mut Self>,
214-
cx: &mut std::task::Context<'_>,
215-
buf: &mut tokio::io::ReadBuf<'_>,
216-
) -> std::task::Poll<std::io::Result<()>> {
217-
match self.get_mut() {
218-
Self::Tcp(stream) => std::pin::Pin::new(stream).poll_read(cx, buf),
219-
Self::Tunnel(stream) => std::pin::Pin::new(stream).poll_read(cx, buf),
220-
}
221-
}
222-
}
223-
224-
impl AsyncWrite for UpstreamLeg {
225-
fn poll_write(
226-
self: std::pin::Pin<&mut Self>,
227-
cx: &mut std::task::Context<'_>,
228-
buf: &[u8],
229-
) -> std::task::Poll<std::io::Result<usize>> {
230-
match self.get_mut() {
231-
Self::Tcp(stream) => std::pin::Pin::new(stream).poll_write(cx, buf),
232-
Self::Tunnel(stream) => std::pin::Pin::new(stream).poll_write(cx, buf),
233-
}
234-
}
235-
236-
fn poll_flush(
237-
self: std::pin::Pin<&mut Self>,
238-
cx: &mut std::task::Context<'_>,
239-
) -> std::task::Poll<std::io::Result<()>> {
240-
match self.get_mut() {
241-
Self::Tcp(stream) => std::pin::Pin::new(stream).poll_flush(cx),
242-
Self::Tunnel(stream) => std::pin::Pin::new(stream).poll_flush(cx),
243-
}
244-
}
245-
246-
fn poll_shutdown(
247-
self: std::pin::Pin<&mut Self>,
248-
cx: &mut std::task::Context<'_>,
249-
) -> std::task::Poll<std::io::Result<()>> {
250-
match self.get_mut() {
251-
Self::Tcp(stream) => std::pin::Pin::new(stream).poll_shutdown(cx),
252-
Self::Tunnel(stream) => std::pin::Pin::new(stream).poll_shutdown(cx),
253-
}
254-
}
255-
}
256-
257-
enum UpstreamSession {
258-
Tcp(UpstreamLeg),
259-
Tls(Box<TlsStream<UpstreamLeg>>),
260-
}
261-
262-
impl AsyncRead for UpstreamSession {
263-
fn poll_read(
264-
self: std::pin::Pin<&mut Self>,
265-
cx: &mut std::task::Context<'_>,
266-
buf: &mut tokio::io::ReadBuf<'_>,
267-
) -> std::task::Poll<std::io::Result<()>> {
268-
match self.get_mut() {
269-
Self::Tcp(stream) => std::pin::Pin::new(stream).poll_read(cx, buf),
270-
Self::Tls(stream) => std::pin::Pin::new(stream.as_mut()).poll_read(cx, buf),
271-
}
272-
}
273-
}
274-
275-
impl AsyncWrite for UpstreamSession {
276-
fn poll_write(
277-
self: std::pin::Pin<&mut Self>,
278-
cx: &mut std::task::Context<'_>,
279-
buf: &[u8],
280-
) -> std::task::Poll<std::io::Result<usize>> {
281-
match self.get_mut() {
282-
Self::Tcp(stream) => std::pin::Pin::new(stream).poll_write(cx, buf),
283-
Self::Tls(stream) => std::pin::Pin::new(stream.as_mut()).poll_write(cx, buf),
284-
}
285-
}
286-
287-
fn poll_flush(
288-
self: std::pin::Pin<&mut Self>,
289-
cx: &mut std::task::Context<'_>,
290-
) -> std::task::Poll<std::io::Result<()>> {
291-
match self.get_mut() {
292-
Self::Tcp(stream) => std::pin::Pin::new(stream).poll_flush(cx),
293-
Self::Tls(stream) => std::pin::Pin::new(stream.as_mut()).poll_flush(cx),
294-
}
295-
}
296-
297-
fn poll_shutdown(
298-
self: std::pin::Pin<&mut Self>,
299-
cx: &mut std::task::Context<'_>,
300-
) -> std::task::Poll<std::io::Result<()>> {
301-
match self.get_mut() {
302-
Self::Tcp(stream) => std::pin::Pin::new(stream).poll_shutdown(cx),
303-
Self::Tls(stream) => std::pin::Pin::new(stream.as_mut()).poll_shutdown(cx),
304-
}
305-
}
306-
}
307-
308-
enum RoutePlan<'a> {
309-
Direct(&'a TargetAddr),
310-
ViaAgent {
311-
target: &'a TargetAddr,
312-
candidates: Vec<Arc<agent_tunnel::registry::AgentPeer>>,
313-
},
314-
}
315-
316-
impl<'a> RoutePlan<'a> {
317-
async fn resolve(
318-
agent_tunnel_handle: Option<&agent_tunnel::AgentTunnelHandle>,
319-
explicit_agent_id: Option<Uuid>,
320-
target: &'a TargetAddr,
321-
) -> Result<Self, ForwardError> {
322-
if let Some(agent_id) = explicit_agent_id {
323-
let handle = agent_tunnel_handle.ok_or_else(|| {
324-
ForwardError::BadGateway(anyhow::anyhow!(
325-
"agent {agent_id} specified in token requires agent tunnel routing, but no tunnel handle is configured"
326-
))
327-
})?;
328-
329-
let agent = handle.registry().get(&agent_id).await.ok_or_else(|| {
330-
ForwardError::BadGateway(anyhow::anyhow!(
331-
"agent {agent_id} specified in token not found in registry"
332-
))
333-
})?;
334-
335-
return Ok(Self::ViaAgent {
336-
target,
337-
candidates: vec![agent],
338-
});
339-
}
340-
341-
let Some(handle) = agent_tunnel_handle else {
342-
return Ok(Self::Direct(target));
343-
};
344-
345-
match agent_tunnel::routing::resolve_route(handle.registry(), None, target.host()).await {
346-
agent_tunnel::routing::RoutingDecision::ViaAgent(candidates) => Ok(Self::ViaAgent { target, candidates }),
347-
agent_tunnel::routing::RoutingDecision::Direct => Ok(Self::Direct(target)),
348-
agent_tunnel::routing::RoutingDecision::ExplicitAgentNotFound(_) => {
349-
unreachable!("explicit agent IDs are handled before route resolution")
350-
}
351-
}
352-
}
353-
354-
async fn execute(
355-
self,
356-
agent_tunnel_handle: Option<&agent_tunnel::AgentTunnelHandle>,
357-
session_id: Uuid,
358-
) -> anyhow::Result<ConnectedTarget> {
359-
match self {
360-
Self::Direct(target) => {
361-
trace!(%target, "Select and connect to target");
362-
363-
let (stream, server_addr) = utils::tcp_connect(target).await?;
364-
365-
trace!(%target, "Connected");
366-
367-
Ok(ConnectedTarget {
368-
leg: UpstreamLeg::Tcp(stream),
369-
server_addr,
370-
selected_target: target.clone(),
371-
})
372-
}
373-
Self::ViaAgent { target, candidates } => {
374-
let handle = agent_tunnel_handle.expect("route plan requires configured agent tunnel");
375-
let mut last_error = None;
376-
377-
for agent in &candidates {
378-
info!(
379-
agent_id = %agent.agent_id,
380-
agent_name = %agent.name,
381-
target = %target.as_addr(),
382-
"Routing via agent tunnel"
383-
);
384-
385-
match handle
386-
.connect_via_agent(agent.agent_id, session_id, target.as_addr())
387-
.await
388-
{
389-
Ok(stream) => {
390-
let server_addr: SocketAddr = "0.0.0.0:0".parse().expect("valid placeholder");
391-
392-
return Ok(ConnectedTarget {
393-
leg: UpstreamLeg::Tunnel(stream),
394-
server_addr,
395-
selected_target: target.clone(),
396-
});
397-
}
398-
Err(error) => {
399-
warn!(
400-
agent_id = %agent.agent_id,
401-
agent_name = %agent.name,
402-
target = %target.as_addr(),
403-
error = format!("{error:#}"),
404-
"Agent tunnel candidate failed"
405-
);
406-
last_error = Some(error);
407-
}
408-
}
409-
}
410-
411-
Err(last_error.unwrap_or_else(|| anyhow::anyhow!("all agent tunnel candidates failed")))
412-
}
413-
}
414-
}
415-
}
416-
417-
struct ConnectedTarget {
418-
leg: UpstreamLeg,
419-
server_addr: SocketAddr,
420-
selected_target: TargetAddr,
421-
}
422-
423-
struct PreparedTarget {
424-
session: UpstreamSession,
425-
server_addr: SocketAddr,
426-
selected_target: TargetAddr,
427-
}
428-
429198
#[derive(Debug, thiserror::Error)]
430199
pub enum ForwardError {
431200
#[error("bad gateway")]
@@ -464,19 +233,22 @@ where
464233
None
465234
};
466235

467-
let PreparedTarget {
468-
session,
469-
server_addr,
470-
selected_target,
471-
} = connect_target(
236+
let connected = upstream::connect_upstream(
472237
targets,
473238
claims.jet_agent_id,
474239
claims.jet_aid,
475-
mode,
476-
claims.cert_thumb256,
477240
agent_tunnel_handle.as_deref(),
478241
)
479-
.await?;
242+
.await
243+
.map_err(ForwardError::BadGateway)?;
244+
245+
let PreparedUpstream {
246+
session,
247+
server_addr,
248+
selected_target,
249+
} = upstream::prepare_upstream(connected, mode, claims.cert_thumb256)
250+
.await
251+
.map_err(ForwardError::BadGateway)?;
480252

481253
tracing::Span::current().record("target", selected_target.to_string());
482254

@@ -493,8 +265,8 @@ where
493265

494266
info!(
495267
mode = match mode {
496-
ForwardMode::Tcp => "tcp",
497-
ForwardMode::Tls => "tls",
268+
UpstreamMode::Tcp => "tcp",
269+
UpstreamMode::Tls => "tls",
498270
},
499271
"WebSocket forwarding"
500272
);
@@ -535,65 +307,6 @@ fn validate_forward_request(claims: &AssociationTokenClaims) -> Result<(), Forwa
535307
Ok(())
536308
}
537309

538-
async fn connect_target(
539-
targets: &nonempty::NonEmpty<TargetAddr>,
540-
explicit_agent_id: Option<Uuid>,
541-
session_id: Uuid,
542-
mode: ForwardMode,
543-
cert_thumb256: Option<crate::tls::thumbprint::Sha256Thumbprint>,
544-
agent_tunnel_handle: Option<&agent_tunnel::AgentTunnelHandle>,
545-
) -> Result<PreparedTarget, ForwardError> {
546-
let mut last_error = None;
547-
548-
for target in targets {
549-
match RoutePlan::resolve(agent_tunnel_handle, explicit_agent_id, target)
550-
.await?
551-
.execute(agent_tunnel_handle, session_id)
552-
.await
553-
{
554-
Err(error) => {
555-
last_error = Some(error);
556-
}
557-
Ok(connected_upstream) => return prepare_target(mode, cert_thumb256, connected_upstream).await,
558-
}
559-
}
560-
561-
Err(ForwardError::BadGateway(
562-
last_error.unwrap_or_else(|| anyhow::anyhow!("no target candidates available")),
563-
))
564-
}
565-
async fn prepare_target(
566-
mode: ForwardMode,
567-
cert_thumb256: Option<crate::tls::thumbprint::Sha256Thumbprint>,
568-
connected_upstream: ConnectedTarget,
569-
) -> Result<PreparedTarget, ForwardError> {
570-
let ConnectedTarget {
571-
leg,
572-
server_addr,
573-
selected_target,
574-
} = connected_upstream;
575-
576-
let session = match mode {
577-
ForwardMode::Tcp => UpstreamSession::Tcp(leg),
578-
ForwardMode::Tls => {
579-
trace!(target = %selected_target, "Establishing TLS connection with server");
580-
581-
let tls_stream = crate::tls::safe_connect(selected_target.host().to_owned(), leg, cert_thumb256)
582-
.await
583-
.context("TLS connect")
584-
.map_err(ForwardError::BadGateway)?;
585-
586-
UpstreamSession::Tls(Box::new(tls_stream))
587-
}
588-
};
589-
590-
Ok(PreparedTarget {
591-
session,
592-
server_addr,
593-
selected_target,
594-
})
595-
}
596-
597310
async fn fwd_http(
598311
State(state): State<DgwState>,
599312
BridgeToken(claims): BridgeToken,

0 commit comments

Comments
 (0)