@@ -11,21 +11,19 @@ use bytes::Bytes;
1111use devolutions_gateway_task:: ShutdownSignal ;
1212use tap:: Pipe as _;
1313use tokio:: io:: { AsyncRead , AsyncWrite } ;
14- use tokio:: net:: TcpStream ;
15- use tokio_rustls:: client:: TlsStream ;
1614use tracing:: { Instrument as _, field} ;
1715use typed_builder:: TypedBuilder ;
1816use uuid:: Uuid ;
1917
18+ use crate :: DgwState ;
2019use crate :: config:: Conf ;
2120use crate :: extract:: { AssociationToken , BridgeToken } ;
2221use crate :: http:: HttpError ;
2322use crate :: proxy:: Proxy ;
2423use crate :: session:: { ConnectionModeDetails , DisconnectInterest , SessionInfo , SessionMessageSender } ;
2524use crate :: subscriber:: SubscriberSender ;
26- use crate :: target_addr:: TargetAddr ;
2725use crate :: token:: { ApplicationProtocol , AssociationTokenClaims , ConnectionMode , Protocol , RecordingPolicy } ;
28- use crate :: { DgwState , utils } ;
26+ use crate :: upstream :: { self , PreparedUpstream , UpstreamMode } ;
2927
3028pub fn make_router < S > ( state : DgwState ) -> Router < S > {
3129 use axum:: routing:: { self , MethodFilter , get} ;
@@ -161,7 +159,7 @@ async fn handle_fwd(
161159 . claims ( claims)
162160 . sessions ( sessions)
163161 . subscriber_tx ( subscriber_tx)
164- . mode ( if with_tls { ForwardMode :: Tls } else { ForwardMode :: Tcp } )
162+ . mode ( if with_tls { UpstreamMode :: Tls } else { UpstreamMode :: Tcp } )
165163 . agent_tunnel_handle ( agent_tunnel_handle)
166164 . build ( )
167165 . run ( )
@@ -192,240 +190,11 @@ struct Forward<S> {
192190 client_addr : SocketAddr ,
193191 sessions : SessionMessageSender ,
194192 subscriber_tx : SubscriberSender ,
195- mode : ForwardMode ,
193+ mode : UpstreamMode ,
196194 #[ builder( default ) ]
197195 agent_tunnel_handle : Option < Arc < agent_tunnel:: AgentTunnelHandle > > ,
198196}
199197
200- #[ derive( Debug , Clone , Copy ) ]
201- enum ForwardMode {
202- Tcp ,
203- Tls ,
204- }
205-
206- enum UpstreamLeg {
207- Tcp ( TcpStream ) ,
208- Tunnel ( agent_tunnel:: stream:: TunnelStream ) ,
209- }
210-
211- impl AsyncRead for UpstreamLeg {
212- fn poll_read (
213- self : std:: pin:: Pin < & mut Self > ,
214- cx : & mut std:: task:: Context < ' _ > ,
215- buf : & mut tokio:: io:: ReadBuf < ' _ > ,
216- ) -> std:: task:: Poll < std:: io:: Result < ( ) > > {
217- match self . get_mut ( ) {
218- Self :: Tcp ( stream) => std:: pin:: Pin :: new ( stream) . poll_read ( cx, buf) ,
219- Self :: Tunnel ( stream) => std:: pin:: Pin :: new ( stream) . poll_read ( cx, buf) ,
220- }
221- }
222- }
223-
224- impl AsyncWrite for UpstreamLeg {
225- fn poll_write (
226- self : std:: pin:: Pin < & mut Self > ,
227- cx : & mut std:: task:: Context < ' _ > ,
228- buf : & [ u8 ] ,
229- ) -> std:: task:: Poll < std:: io:: Result < usize > > {
230- match self . get_mut ( ) {
231- Self :: Tcp ( stream) => std:: pin:: Pin :: new ( stream) . poll_write ( cx, buf) ,
232- Self :: Tunnel ( stream) => std:: pin:: Pin :: new ( stream) . poll_write ( cx, buf) ,
233- }
234- }
235-
236- fn poll_flush (
237- self : std:: pin:: Pin < & mut Self > ,
238- cx : & mut std:: task:: Context < ' _ > ,
239- ) -> std:: task:: Poll < std:: io:: Result < ( ) > > {
240- match self . get_mut ( ) {
241- Self :: Tcp ( stream) => std:: pin:: Pin :: new ( stream) . poll_flush ( cx) ,
242- Self :: Tunnel ( stream) => std:: pin:: Pin :: new ( stream) . poll_flush ( cx) ,
243- }
244- }
245-
246- fn poll_shutdown (
247- self : std:: pin:: Pin < & mut Self > ,
248- cx : & mut std:: task:: Context < ' _ > ,
249- ) -> std:: task:: Poll < std:: io:: Result < ( ) > > {
250- match self . get_mut ( ) {
251- Self :: Tcp ( stream) => std:: pin:: Pin :: new ( stream) . poll_shutdown ( cx) ,
252- Self :: Tunnel ( stream) => std:: pin:: Pin :: new ( stream) . poll_shutdown ( cx) ,
253- }
254- }
255- }
256-
257- enum UpstreamSession {
258- Tcp ( UpstreamLeg ) ,
259- Tls ( Box < TlsStream < UpstreamLeg > > ) ,
260- }
261-
262- impl AsyncRead for UpstreamSession {
263- fn poll_read (
264- self : std:: pin:: Pin < & mut Self > ,
265- cx : & mut std:: task:: Context < ' _ > ,
266- buf : & mut tokio:: io:: ReadBuf < ' _ > ,
267- ) -> std:: task:: Poll < std:: io:: Result < ( ) > > {
268- match self . get_mut ( ) {
269- Self :: Tcp ( stream) => std:: pin:: Pin :: new ( stream) . poll_read ( cx, buf) ,
270- Self :: Tls ( stream) => std:: pin:: Pin :: new ( stream. as_mut ( ) ) . poll_read ( cx, buf) ,
271- }
272- }
273- }
274-
275- impl AsyncWrite for UpstreamSession {
276- fn poll_write (
277- self : std:: pin:: Pin < & mut Self > ,
278- cx : & mut std:: task:: Context < ' _ > ,
279- buf : & [ u8 ] ,
280- ) -> std:: task:: Poll < std:: io:: Result < usize > > {
281- match self . get_mut ( ) {
282- Self :: Tcp ( stream) => std:: pin:: Pin :: new ( stream) . poll_write ( cx, buf) ,
283- Self :: Tls ( stream) => std:: pin:: Pin :: new ( stream. as_mut ( ) ) . poll_write ( cx, buf) ,
284- }
285- }
286-
287- fn poll_flush (
288- self : std:: pin:: Pin < & mut Self > ,
289- cx : & mut std:: task:: Context < ' _ > ,
290- ) -> std:: task:: Poll < std:: io:: Result < ( ) > > {
291- match self . get_mut ( ) {
292- Self :: Tcp ( stream) => std:: pin:: Pin :: new ( stream) . poll_flush ( cx) ,
293- Self :: Tls ( stream) => std:: pin:: Pin :: new ( stream. as_mut ( ) ) . poll_flush ( cx) ,
294- }
295- }
296-
297- fn poll_shutdown (
298- self : std:: pin:: Pin < & mut Self > ,
299- cx : & mut std:: task:: Context < ' _ > ,
300- ) -> std:: task:: Poll < std:: io:: Result < ( ) > > {
301- match self . get_mut ( ) {
302- Self :: Tcp ( stream) => std:: pin:: Pin :: new ( stream) . poll_shutdown ( cx) ,
303- Self :: Tls ( stream) => std:: pin:: Pin :: new ( stream. as_mut ( ) ) . poll_shutdown ( cx) ,
304- }
305- }
306- }
307-
308- enum RoutePlan < ' a > {
309- Direct ( & ' a TargetAddr ) ,
310- ViaAgent {
311- target : & ' a TargetAddr ,
312- candidates : Vec < Arc < agent_tunnel:: registry:: AgentPeer > > ,
313- } ,
314- }
315-
316- impl < ' a > RoutePlan < ' a > {
317- async fn resolve (
318- agent_tunnel_handle : Option < & agent_tunnel:: AgentTunnelHandle > ,
319- explicit_agent_id : Option < Uuid > ,
320- target : & ' a TargetAddr ,
321- ) -> Result < Self , ForwardError > {
322- if let Some ( agent_id) = explicit_agent_id {
323- let handle = agent_tunnel_handle. ok_or_else ( || {
324- ForwardError :: BadGateway ( anyhow:: anyhow!(
325- "agent {agent_id} specified in token requires agent tunnel routing, but no tunnel handle is configured"
326- ) )
327- } ) ?;
328-
329- let agent = handle. registry ( ) . get ( & agent_id) . await . ok_or_else ( || {
330- ForwardError :: BadGateway ( anyhow:: anyhow!(
331- "agent {agent_id} specified in token not found in registry"
332- ) )
333- } ) ?;
334-
335- return Ok ( Self :: ViaAgent {
336- target,
337- candidates : vec ! [ agent] ,
338- } ) ;
339- }
340-
341- let Some ( handle) = agent_tunnel_handle else {
342- return Ok ( Self :: Direct ( target) ) ;
343- } ;
344-
345- match agent_tunnel:: routing:: resolve_route ( handle. registry ( ) , None , target. host ( ) ) . await {
346- agent_tunnel:: routing:: RoutingDecision :: ViaAgent ( candidates) => Ok ( Self :: ViaAgent { target, candidates } ) ,
347- agent_tunnel:: routing:: RoutingDecision :: Direct => Ok ( Self :: Direct ( target) ) ,
348- agent_tunnel:: routing:: RoutingDecision :: ExplicitAgentNotFound ( _) => {
349- unreachable ! ( "explicit agent IDs are handled before route resolution" )
350- }
351- }
352- }
353-
354- async fn execute (
355- self ,
356- agent_tunnel_handle : Option < & agent_tunnel:: AgentTunnelHandle > ,
357- session_id : Uuid ,
358- ) -> anyhow:: Result < ConnectedTarget > {
359- match self {
360- Self :: Direct ( target) => {
361- trace ! ( %target, "Select and connect to target" ) ;
362-
363- let ( stream, server_addr) = utils:: tcp_connect ( target) . await ?;
364-
365- trace ! ( %target, "Connected" ) ;
366-
367- Ok ( ConnectedTarget {
368- leg : UpstreamLeg :: Tcp ( stream) ,
369- server_addr,
370- selected_target : target. clone ( ) ,
371- } )
372- }
373- Self :: ViaAgent { target, candidates } => {
374- let handle = agent_tunnel_handle. expect ( "route plan requires configured agent tunnel" ) ;
375- let mut last_error = None ;
376-
377- for agent in & candidates {
378- info ! (
379- agent_id = %agent. agent_id,
380- agent_name = %agent. name,
381- target = %target. as_addr( ) ,
382- "Routing via agent tunnel"
383- ) ;
384-
385- match handle
386- . connect_via_agent ( agent. agent_id , session_id, target. as_addr ( ) )
387- . await
388- {
389- Ok ( stream) => {
390- let server_addr: SocketAddr = "0.0.0.0:0" . parse ( ) . expect ( "valid placeholder" ) ;
391-
392- return Ok ( ConnectedTarget {
393- leg : UpstreamLeg :: Tunnel ( stream) ,
394- server_addr,
395- selected_target : target. clone ( ) ,
396- } ) ;
397- }
398- Err ( error) => {
399- warn ! (
400- agent_id = %agent. agent_id,
401- agent_name = %agent. name,
402- target = %target. as_addr( ) ,
403- error = format!( "{error:#}" ) ,
404- "Agent tunnel candidate failed"
405- ) ;
406- last_error = Some ( error) ;
407- }
408- }
409- }
410-
411- Err ( last_error. unwrap_or_else ( || anyhow:: anyhow!( "all agent tunnel candidates failed" ) ) )
412- }
413- }
414- }
415- }
416-
417- struct ConnectedTarget {
418- leg : UpstreamLeg ,
419- server_addr : SocketAddr ,
420- selected_target : TargetAddr ,
421- }
422-
423- struct PreparedTarget {
424- session : UpstreamSession ,
425- server_addr : SocketAddr ,
426- selected_target : TargetAddr ,
427- }
428-
429198#[ derive( Debug , thiserror:: Error ) ]
430199pub enum ForwardError {
431200 #[ error( "bad gateway" ) ]
@@ -464,19 +233,22 @@ where
464233 None
465234 } ;
466235
467- let PreparedTarget {
468- session,
469- server_addr,
470- selected_target,
471- } = connect_target (
236+ let connected = upstream:: connect_upstream (
472237 targets,
473238 claims. jet_agent_id ,
474239 claims. jet_aid ,
475- mode,
476- claims. cert_thumb256 ,
477240 agent_tunnel_handle. as_deref ( ) ,
478241 )
479- . await ?;
242+ . await
243+ . map_err ( ForwardError :: BadGateway ) ?;
244+
245+ let PreparedUpstream {
246+ session,
247+ server_addr,
248+ selected_target,
249+ } = upstream:: prepare_upstream ( connected, mode, claims. cert_thumb256 )
250+ . await
251+ . map_err ( ForwardError :: BadGateway ) ?;
480252
481253 tracing:: Span :: current ( ) . record ( "target" , selected_target. to_string ( ) ) ;
482254
@@ -493,8 +265,8 @@ where
493265
494266 info ! (
495267 mode = match mode {
496- ForwardMode :: Tcp => "tcp" ,
497- ForwardMode :: Tls => "tls" ,
268+ UpstreamMode :: Tcp => "tcp" ,
269+ UpstreamMode :: Tls => "tls" ,
498270 } ,
499271 "WebSocket forwarding"
500272 ) ;
@@ -535,65 +307,6 @@ fn validate_forward_request(claims: &AssociationTokenClaims) -> Result<(), Forwa
535307 Ok ( ( ) )
536308}
537309
538- async fn connect_target (
539- targets : & nonempty:: NonEmpty < TargetAddr > ,
540- explicit_agent_id : Option < Uuid > ,
541- session_id : Uuid ,
542- mode : ForwardMode ,
543- cert_thumb256 : Option < crate :: tls:: thumbprint:: Sha256Thumbprint > ,
544- agent_tunnel_handle : Option < & agent_tunnel:: AgentTunnelHandle > ,
545- ) -> Result < PreparedTarget , ForwardError > {
546- let mut last_error = None ;
547-
548- for target in targets {
549- match RoutePlan :: resolve ( agent_tunnel_handle, explicit_agent_id, target)
550- . await ?
551- . execute ( agent_tunnel_handle, session_id)
552- . await
553- {
554- Err ( error) => {
555- last_error = Some ( error) ;
556- }
557- Ok ( connected_upstream) => return prepare_target ( mode, cert_thumb256, connected_upstream) . await ,
558- }
559- }
560-
561- Err ( ForwardError :: BadGateway (
562- last_error. unwrap_or_else ( || anyhow:: anyhow!( "no target candidates available" ) ) ,
563- ) )
564- }
565- async fn prepare_target (
566- mode : ForwardMode ,
567- cert_thumb256 : Option < crate :: tls:: thumbprint:: Sha256Thumbprint > ,
568- connected_upstream : ConnectedTarget ,
569- ) -> Result < PreparedTarget , ForwardError > {
570- let ConnectedTarget {
571- leg,
572- server_addr,
573- selected_target,
574- } = connected_upstream;
575-
576- let session = match mode {
577- ForwardMode :: Tcp => UpstreamSession :: Tcp ( leg) ,
578- ForwardMode :: Tls => {
579- trace ! ( target = %selected_target, "Establishing TLS connection with server" ) ;
580-
581- let tls_stream = crate :: tls:: safe_connect ( selected_target. host ( ) . to_owned ( ) , leg, cert_thumb256)
582- . await
583- . context ( "TLS connect" )
584- . map_err ( ForwardError :: BadGateway ) ?;
585-
586- UpstreamSession :: Tls ( Box :: new ( tls_stream) )
587- }
588- } ;
589-
590- Ok ( PreparedTarget {
591- session,
592- server_addr,
593- selected_target,
594- } )
595- }
596-
597310async fn fwd_http (
598311 State ( state) : State < DgwState > ,
599312 BridgeToken ( claims) : BridgeToken ,
0 commit comments