Skip to content

Commit 9340a82

Browse files
feat(agent): transparent routing through agent tunnel (#1741)
1 parent 9d5a662 commit 9340a82

14 files changed

Lines changed: 916 additions & 248 deletions

File tree

crates/agent-tunnel/Cargo.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@ publish = false
88
[lints]
99
workspace = true
1010

11+
[features]
12+
# Exposes `_for_test` helpers (e.g. `AgentPeer::set_last_seen_for_test`) so
13+
# integration tests in other crates can force specific peer states without
14+
# wall-clock sleeps. Production builds must not enable this.
15+
test-utils = []
16+
1117
[dependencies]
1218
# Internal crates
1319
agent-tunnel-proto = { path = "../agent-tunnel-proto", features = ["serde"] }

crates/agent-tunnel/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ pub mod cert;
1010
pub mod enrollment_store;
1111
pub mod listener;
1212
pub mod registry;
13+
pub mod routing;
1314
pub mod stream;
1415

1516
pub use enrollment_store::EnrollmentTokenStore;

crates/agent-tunnel/src/listener.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ impl AgentTunnelListener {
158158
let handle = AgentTunnelHandle {
159159
registry: Arc::clone(&registry),
160160
agent_connections: Arc::clone(&agent_connections),
161-
ca_manager,
161+
ca_manager: Arc::clone(&ca_manager),
162162
enrollment_token_store,
163163
};
164164

@@ -170,6 +170,11 @@ impl AgentTunnelListener {
170170

171171
Ok((listener, handle))
172172
}
173+
174+
/// Returns the local address the QUIC endpoint is bound to.
175+
pub fn local_addr(&self) -> SocketAddr {
176+
self.endpoint.local_addr().expect("endpoint has local addr")
177+
}
173178
}
174179

175180
#[async_trait]
@@ -202,9 +207,7 @@ impl devolutions_gateway_task::Task for AgentTunnelListener {
202207
let registry = Arc::clone(&self.registry);
203208
let agent_connections = Arc::clone(&self.agent_connections);
204209

205-
conn_handles.spawn(
206-
run_agent_connection(registry, agent_connections, incoming),
207-
);
210+
conn_handles.spawn(run_agent_connection(registry, agent_connections, incoming));
208211
}
209212

210213
// Reap completed connection tasks to prevent unbounded growth.
@@ -253,7 +256,7 @@ async fn run_agent_connection(
253256

254257
info!(%agent_id, %agent_name, %peer_addr, "Agent authenticated via mTLS");
255258

256-
let peer = Arc::new(AgentPeer::new(agent_id, agent_name, fingerprint));
259+
let peer = Arc::new(AgentPeer::new(agent_id, agent_name.clone(), fingerprint));
257260
registry.register(Arc::clone(&peer)).await;
258261
agent_connections.write().await.insert(agent_id, conn.clone());
259262

crates/agent-tunnel/src/registry.rs

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ use serde::Serialize;
1010
use tokio::sync::RwLock as TokioRwLock;
1111
use uuid::Uuid;
1212

13+
use crate::routing::RouteTarget;
14+
1315
/// Duration after which an agent is considered offline if no heartbeat has been received.
1416
pub const AGENT_OFFLINE_TIMEOUT: Duration = Duration::from_secs(90);
1517

@@ -32,6 +34,33 @@ pub struct RouteAdvertisementState {
3234
pub updated_at: SystemTime,
3335
}
3436

37+
impl RouteAdvertisementState {
38+
/// Match this route set against a parsed target host.
39+
///
40+
/// Returns a specificity score if matched, or `None` if no match.
41+
/// IP subnet matches return `usize::MAX` (always highest priority).
42+
/// Domain matches return the matched domain length (longer = more specific).
43+
pub fn matches_target(&self, target: &RouteTarget) -> Option<usize> {
44+
use std::net::IpAddr;
45+
46+
match target {
47+
// Only IPv4 subnets are currently tracked; only match IPv4 target IPs.
48+
RouteTarget::Ip(IpAddr::V4(ipv4)) => self
49+
.subnets
50+
.iter()
51+
.any(|subnet| subnet.contains(*ipv4))
52+
.then_some(usize::MAX),
53+
RouteTarget::Ip(IpAddr::V6(_)) => None,
54+
RouteTarget::Hostname(hostname) => self
55+
.domains
56+
.iter()
57+
.filter(|adv| adv.domain.matches_hostname(hostname.as_str()))
58+
.map(|adv| adv.domain.as_str().len())
59+
.max(),
60+
}
61+
}
62+
}
63+
3564
impl Default for RouteAdvertisementState {
3665
fn default() -> Self {
3766
let now = SystemTime::now();
@@ -79,6 +108,37 @@ impl AgentPeer {
79108
self.last_seen.store(now_ms, Ordering::Release);
80109
}
81110

111+
/// Force `last_seen` to a caller-chosen value (milliseconds since UNIX epoch).
///
/// Test-only API: the `_for_test` suffix is the project's marker that
/// production code must not call this; production code should use
/// [`touch`](Self::touch) instead. Integration tests in other crates (e.g.
/// the workspace `testsuite`) use it to push an agent into the "offline"
/// state without waiting for the real timeout to elapse.
///
/// Compiled only under `cfg(test)` (this crate's own unit tests) or with
/// the `test-utils` feature, so production builds cannot link against it.
/// Cross-crate consumers must opt in via `features = ["test-utils"]` on
/// their `agent-tunnel` dev-dependency.
#[cfg(any(test, feature = "test-utils"))]
#[doc(hidden)]
pub fn set_last_seen_for_test(&self, last_seen_ms: u64) {
    let forced_ms = last_seen_ms;
    self.last_seen.store(forced_ms, Ordering::Release);
}
127+
128+
/// Replace `received_at` on the current route state with `received_at`.
///
/// Test-only API. It lets tests assert ordering by arrival time without
/// wall-clock `thread::sleep`, which is flaky on platforms with coarse
/// timer resolution (e.g. Windows ~16 ms). The gating rationale is the same
/// as for [`set_last_seen_for_test`](Self::set_last_seen_for_test).
#[cfg(any(test, feature = "test-utils"))]
#[doc(hidden)]
pub fn set_received_at_for_test(&self, received_at: SystemTime) {
    // The write guard is a temporary and is released at the end of this statement.
    self.route_state.write().received_at = received_at;
}
141+
82142
/// Returns the last-seen timestamp as milliseconds since UNIX epoch.
83143
pub fn last_seen_ms(&self) -> u64 {
84144
self.last_seen.load(Ordering::Acquire)
@@ -223,6 +283,39 @@ impl AgentRegistry {
223283
pub async fn agent_infos(&self) -> Vec<AgentInfo> {
224284
self.agents.read().await.values().map(AgentInfo::from).collect()
225285
}
286+
287+
/// Find all online agents that can route to the given parsed target host.
288+
///
289+
/// For IP targets: matches against advertised subnets.
290+
/// For domain targets: uses longest suffix match (more specific domain wins).
291+
///
292+
/// Results with equal specificity are sorted by `received_at` descending (most recent first).
293+
pub async fn find_agents_for(&self, target: &RouteTarget) -> Vec<Arc<AgentPeer>> {
294+
let mut best_specificity: usize = 0;
295+
let mut candidates: Vec<(SystemTime, Arc<AgentPeer>)> = Vec::new();
296+
297+
let agents = self.agents.read().await;
298+
for agent in agents.values() {
299+
if !agent.is_online(AGENT_OFFLINE_TIMEOUT) {
300+
continue;
301+
}
302+
303+
let route_state = agent.route_state();
304+
305+
if let Some(specificity) = route_state.matches_target(target) {
306+
if specificity > best_specificity {
307+
best_specificity = specificity;
308+
candidates.clear();
309+
candidates.push((route_state.received_at, Arc::clone(agent)));
310+
} else if specificity == best_specificity {
311+
candidates.push((route_state.received_at, Arc::clone(agent)));
312+
}
313+
}
314+
}
315+
316+
candidates.sort_by(|a, b| b.0.cmp(&a.0));
317+
candidates.into_iter().map(|(_, agent)| agent).collect()
318+
}
226319
}
227320

228321
impl Default for AgentRegistry {

crates/agent-tunnel/src/routing.rs

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
//! Shared routing pipeline for agent tunnel.
2+
//!
3+
//! Consumed by the upstream connection paths (forwarding, RDP clean path,
4+
//! generic client) to ensure consistent routing behavior and error messages.
5+
6+
use std::net::IpAddr;
7+
use std::sync::Arc;
8+
9+
use agent_tunnel_proto::DomainName;
10+
use anyhow::{Result, anyhow};
11+
use uuid::Uuid;
12+
13+
use super::listener::AgentTunnelHandle;
14+
use super::registry::{AgentPeer, AgentRegistry};
15+
use super::stream::TunnelStream;
16+
17+
/// A parsed target host used for route matching.
18+
///
19+
/// Routing cares only about the host identity, not the port or scheme used by
20+
/// the eventual connection attempt.
21+
#[derive(Debug, Clone, PartialEq, Eq)]
22+
pub enum RouteTarget {
23+
Ip(IpAddr),
24+
Hostname(DomainName),
25+
}
26+
27+
impl RouteTarget {
28+
pub fn ip(ip: IpAddr) -> Self {
29+
Self::Ip(ip)
30+
}
31+
32+
pub fn hostname(hostname: impl Into<String>) -> Self {
33+
Self::Hostname(DomainName::new(hostname))
34+
}
35+
}
36+
37+
impl std::fmt::Display for RouteTarget {
38+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39+
match self {
40+
Self::Ip(ip) => ip.fmt(f),
41+
Self::Hostname(hostname) => hostname.fmt(f),
42+
}
43+
}
44+
}
45+
46+
/// Result of the routing pipeline.
47+
///
48+
/// Each variant carries enough context for the caller to produce an actionable error message.
49+
#[derive(Debug)]
50+
pub enum RoutingDecision {
51+
/// Route through these agent candidates (try in order, first success wins).
52+
ViaAgent(Vec<Arc<AgentPeer>>),
53+
/// Explicit agent_id was specified but not found in registry.
54+
ExplicitAgentNotFound(Uuid),
55+
/// No agent matched — caller should attempt direct connection.
56+
Direct,
57+
}
58+
59+
/// Determines how to route a connection to the given target.
60+
///
61+
/// Pipeline (in order of priority):
62+
/// 1. Explicit agent_id (from JWT) → route to that agent
63+
/// 2. Target match (IP subnet or domain suffix) → best match wins
64+
/// 3. No match → direct connection
65+
pub async fn resolve_route(
66+
registry: &AgentRegistry,
67+
explicit_agent_id: Option<Uuid>,
68+
target: &RouteTarget,
69+
) -> RoutingDecision {
70+
// Step 1: Explicit agent ID (from JWT)
71+
if let Some(id) = explicit_agent_id {
72+
return match registry.get(&id).await {
73+
Some(agent) => RoutingDecision::ViaAgent(vec![agent]),
74+
None => RoutingDecision::ExplicitAgentNotFound(id),
75+
};
76+
}
77+
78+
// Step 2: Match target against all agents (IP subnet or domain suffix)
79+
let agents = registry.find_agents_for(target).await;
80+
81+
if agents.is_empty() {
82+
RoutingDecision::Direct
83+
} else {
84+
RoutingDecision::ViaAgent(agents)
85+
}
86+
}
87+
88+
/// Attempt to route a connection via the agent tunnel.
89+
///
90+
/// Returns `Ok(Some(stream))` if routed through an agent, `Ok(None)` if the caller
91+
/// should fall through to direct connect, or `Err` if an explicit agent was specified
92+
/// but not found (or all candidates failed).
93+
pub async fn try_route(
94+
handle: Option<&AgentTunnelHandle>,
95+
explicit_agent_id: Option<Uuid>,
96+
target: &RouteTarget,
97+
session_id: Uuid,
98+
target_addr: &str,
99+
) -> Result<Option<(TunnelStream, Arc<AgentPeer>)>> {
100+
let Some(handle) = handle else {
101+
// An explicit `jet_agent_id` claim means the token requires routing via that
102+
// specific agent; silently falling back to a direct connect would bypass the
103+
// intended network boundary. Reject instead.
104+
return match explicit_agent_id {
105+
Some(id) => Err(anyhow!(
106+
"agent {id} specified in token requires agent tunnel routing, but no tunnel handle is configured"
107+
)),
108+
None => Ok(None),
109+
};
110+
};
111+
112+
match resolve_route(handle.registry(), explicit_agent_id, target).await {
113+
RoutingDecision::ExplicitAgentNotFound(id) => {
114+
Err(anyhow!("agent {id} specified in token not found in registry"))
115+
}
116+
RoutingDecision::Direct => Ok(None),
117+
RoutingDecision::ViaAgent(candidates) => {
118+
let result = route_and_connect(handle, &candidates, session_id, target_addr).await?;
119+
Ok(Some(result))
120+
}
121+
}
122+
}
123+
124+
/// Try connecting to target through agent candidates (try-fail-retry).
125+
///
126+
/// Returns the connected `TunnelStream` and the agent that succeeded.
127+
///
128+
/// Callers must handle `RoutingDecision::ExplicitAgentNotFound` and
129+
/// `RoutingDecision::Direct` before calling this function.
130+
pub async fn route_and_connect(
131+
handle: &AgentTunnelHandle,
132+
candidates: &[Arc<AgentPeer>],
133+
session_id: Uuid,
134+
target: &str,
135+
) -> Result<(TunnelStream, Arc<AgentPeer>)> {
136+
if candidates.is_empty() {
137+
return Err(anyhow!("route_and_connect called with empty candidates"));
138+
}
139+
140+
let mut last_error = None;
141+
142+
for agent in candidates {
143+
info!(
144+
agent_id = %agent.agent_id,
145+
agent_name = %agent.name,
146+
%target,
147+
"Routing via agent tunnel"
148+
);
149+
150+
match handle.connect_via_agent(agent.agent_id, session_id, target).await {
151+
Ok(stream) => {
152+
info!(
153+
agent_id = %agent.agent_id,
154+
agent_name = %agent.name,
155+
%target,
156+
"Agent tunnel connection established"
157+
);
158+
return Ok((stream, Arc::clone(agent)));
159+
}
160+
Err(error) => {
161+
warn!(
162+
agent_id = %agent.agent_id,
163+
agent_name = %agent.name,
164+
%target,
165+
error = format!("{error:#}"),
166+
"Agent tunnel connection failed, trying next candidate"
167+
);
168+
last_error = Some(error);
169+
}
170+
}
171+
}
172+
173+
let agent_names: Vec<&str> = candidates.iter().map(|a| a.name.as_str()).collect();
174+
let last_err_msg = last_error.as_ref().map(|e| format!("{e:#}")).unwrap_or_default();
175+
176+
error!(
177+
agent_count = candidates.len(),
178+
%target,
179+
agents = ?agent_names,
180+
last_error = %last_err_msg,
181+
"All agent tunnel candidates failed"
182+
);
183+
184+
Err(last_error.unwrap_or_else(|| {
185+
anyhow!(
186+
"All {} agents matching target '{}' failed to connect. Agents tried: [{}]",
187+
candidates.len(),
188+
target,
189+
agent_names.join(", "),
190+
)
191+
}))
192+
}

0 commit comments

Comments
 (0)