Skip to content

Commit fe7aa96

Browse files
fix: Copilot review — version checks, cert validation, type safety
- Hoist protocol version validation before match in both gateway and agent control loops (single check, no per-variant boilerplate) - Validate ConnectResponse protocol version in connect_via_agent - ServerCertStatus enum for ensure_server_cert (expiry + hostname SAN) - send.finish() after proxy copy (graceful QUIC EOF) - Fix constant_time_eq doc (inaccurate timing claim) - Extract ALPN to agent_tunnel_proto::ALPN_PROTOCOL constant - Destruct EnrollResponse at parameter level for readability - ValidatedTunnelConf: make wrong state unrepresentable at type level (dto::TunnelConf for JSON, TunnelConf for runtime with non-optional fields) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 194833b commit fe7aa96

8 files changed

Lines changed: 205 additions & 52 deletions

File tree

crates/agent-tunnel-proto/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ pub use control::{ControlMessage, DomainAdvertisement, MAX_CONTROL_MESSAGE_SIZE}
2323
pub use error::ProtoError;
2424
pub use session::{ConnectRequest, ConnectResponse, MAX_SESSION_MESSAGE_SIZE};
2525
pub use stream::{ControlRecvStream, ControlSendStream, ControlStream, SessionStream};
26-
pub use version::{CURRENT_PROTOCOL_VERSION, MIN_SUPPORTED_VERSION, validate_protocol_version};
26+
pub use version::{ALPN_PROTOCOL, CURRENT_PROTOCOL_VERSION, MIN_SUPPORTED_VERSION, validate_protocol_version};
2727

2828
/// Current wall-clock time in milliseconds since UNIX epoch.
2929
pub fn current_time_millis() -> u64 {

crates/agent-tunnel-proto/src/version.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
/// ALPN protocol identifier for the QUIC agent tunnel, including version suffix.
2+
pub const ALPN_PROTOCOL: &[u8] = b"gw-agent-tunnel/1";
3+
14
/// Current protocol version.
25
pub const CURRENT_PROTOCOL_VERSION: u16 = 1;
36

devolutions-agent/src/config.rs

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,84 @@ pub struct Conf {
2020
pub remote_desktop: RemoteDesktopConf,
2121
pub pedm: dto::PedmConf,
2222
pub session: dto::SessionConf,
23-
pub tunnel: dto::TunnelConf,
23+
pub tunnel: TunnelConf,
2424
pub proxy: dto::ProxyConf,
2525
pub debug: dto::DebugConf,
2626
}
2727

28+
/// Validated tunnel configuration — required fields are guaranteed present.
29+
///
30+
/// Constructed from `dto::TunnelConf` via `TryFrom`. If the tunnel is disabled
31+
/// or not yet enrolled, the `enabled` field is `false` and path fields are empty
32+
/// (but the struct is always constructible).
33+
#[derive(Debug, Clone)]
34+
pub struct TunnelConf {
35+
pub enabled: bool,
36+
pub gateway_endpoint: String,
37+
pub client_cert_path: Utf8PathBuf,
38+
pub client_key_path: Utf8PathBuf,
39+
pub gateway_ca_cert_path: Utf8PathBuf,
40+
pub advertise_subnets: Vec<String>,
41+
pub advertise_domains: Vec<String>,
42+
pub auto_detect_domain: bool,
43+
pub heartbeat_interval_secs: u64,
44+
pub route_advertise_interval_secs: u64,
45+
pub server_spki_sha256: Option<String>,
46+
}
47+
48+
impl TryFrom<dto::TunnelConf> for TunnelConf {
49+
type Error = anyhow::Error;
50+
51+
fn try_from(conf: dto::TunnelConf) -> anyhow::Result<Self> {
52+
if !conf.enabled {
53+
// Disabled tunnel — return a placeholder with defaults.
54+
return Ok(Self {
55+
enabled: false,
56+
gateway_endpoint: String::new(),
57+
client_cert_path: Utf8PathBuf::new(),
58+
client_key_path: Utf8PathBuf::new(),
59+
gateway_ca_cert_path: Utf8PathBuf::new(),
60+
advertise_subnets: Vec::new(),
61+
advertise_domains: Vec::new(),
62+
auto_detect_domain: true,
63+
heartbeat_interval_secs: 60,
64+
route_advertise_interval_secs: 30,
65+
server_spki_sha256: None,
66+
});
67+
}
68+
69+
// Enabled tunnel — all required fields must be present.
70+
let client_cert_path = conf
71+
.client_cert_path
72+
.context("tunnel enabled but client_cert_path not configured")?;
73+
let client_key_path = conf
74+
.client_key_path
75+
.context("tunnel enabled but client_key_path not configured")?;
76+
let gateway_ca_cert_path = conf
77+
.gateway_ca_cert_path
78+
.context("tunnel enabled but gateway_ca_cert_path not configured")?;
79+
80+
anyhow::ensure!(
81+
!conf.gateway_endpoint.is_empty(),
82+
"tunnel enabled but gateway_endpoint is empty"
83+
);
84+
85+
Ok(Self {
86+
enabled: true,
87+
gateway_endpoint: conf.gateway_endpoint,
88+
client_cert_path,
89+
client_key_path,
90+
gateway_ca_cert_path,
91+
advertise_subnets: conf.advertise_subnets,
92+
advertise_domains: conf.advertise_domains,
93+
auto_detect_domain: conf.auto_detect_domain,
94+
heartbeat_interval_secs: conf.heartbeat_interval_secs.unwrap_or(60),
95+
route_advertise_interval_secs: conf.route_advertise_interval_secs.unwrap_or(30),
96+
server_spki_sha256: conf.server_spki_sha256,
97+
})
98+
}
99+
}
100+
28101
impl Conf {
29102
pub fn from_conf_file(conf_file: &dto::ConfFile) -> anyhow::Result<Self> {
30103
let data_dir = get_data_dir();
@@ -49,7 +122,12 @@ impl Conf {
49122
remote_desktop,
50123
pedm: conf_file.pedm.clone().unwrap_or_default(),
51124
session: conf_file.session.clone().unwrap_or_default(),
52-
tunnel: conf_file.tunnel.clone().unwrap_or_default(),
125+
tunnel: conf_file
126+
.tunnel
127+
.clone()
128+
.unwrap_or_default()
129+
.pipe(TunnelConf::try_from)
130+
.context("invalid tunnel config")?,
53131
proxy: conf_file.proxy.clone().unwrap_or_default(),
54132
debug: conf_file.debug.clone().unwrap_or_default(),
55133
})

devolutions-agent/src/enrollment.rs

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,13 @@ async fn request_enrollment(
118118
fn persist_enrollment_response(
119119
agent_name: &str,
120120
advertise_subnets: Vec<String>,
121-
enroll_response: EnrollResponse,
121+
EnrollResponse {
122+
agent_id,
123+
client_cert_pem,
124+
gateway_ca_cert_pem,
125+
quic_endpoint,
126+
server_spki_sha256,
127+
}: EnrollResponse,
122128
key_pem: &str,
123129
) -> anyhow::Result<PersistedEnrollment> {
124130
let config_path = config::get_conf_file_path();
@@ -132,19 +138,19 @@ fn persist_enrollment_response(
132138
std::fs::create_dir_all(&cert_dir)
133139
.with_context(|| format!("failed to create certificate directory: {}", cert_dir))?;
134140

135-
let client_cert_path = cert_dir.join(format!("{}-cert.pem", enroll_response.agent_id));
136-
let client_key_path = cert_dir.join(format!("{}-key.pem", enroll_response.agent_id));
141+
let client_cert_path = cert_dir.join(format!("{agent_id}-cert.pem"));
142+
let client_key_path = cert_dir.join(format!("{agent_id}-key.pem"));
137143
let gateway_ca_path = cert_dir.join("gateway-ca.pem");
138144

139145
// Write the locally-generated private key first (before cert/CA from the network).
140146
std::fs::write(&client_key_path, key_pem)
141-
.with_context(|| format!("failed to write client private key: {}", client_key_path))?;
147+
.with_context(|| format!("failed to write client private key: {client_key_path}"))?;
142148

143-
std::fs::write(&client_cert_path, &enroll_response.client_cert_pem)
144-
.with_context(|| format!("failed to write client certificate: {}", client_cert_path))?;
149+
std::fs::write(&client_cert_path, &client_cert_pem)
150+
.with_context(|| format!("failed to write client certificate: {client_cert_path}"))?;
145151

146-
std::fs::write(&gateway_ca_path, &enroll_response.gateway_ca_cert_pem)
147-
.with_context(|| format!("failed to write gateway CA certificate: {}", gateway_ca_path))?;
152+
std::fs::write(&gateway_ca_path, &gateway_ca_cert_pem)
153+
.with_context(|| format!("failed to write gateway CA certificate: {gateway_ca_path}"))?;
148154

149155
// Restrict permissions on cert/key files (owner-only on Unix).
150156
#[cfg(unix)]
@@ -167,7 +173,7 @@ fn persist_enrollment_response(
167173

168174
let tunnel_conf = config::dto::TunnelConf {
169175
enabled: true,
170-
gateway_endpoint: enroll_response.quic_endpoint.clone(),
176+
gateway_endpoint: quic_endpoint.clone(),
171177
client_cert_path: Some(client_cert_path.clone()),
172178
client_key_path: Some(client_key_path.clone()),
173179
gateway_ca_cert_path: Some(gateway_ca_path.clone()),
@@ -176,19 +182,19 @@ fn persist_enrollment_response(
176182
auto_detect_domain: existing_tunnel.map(|t| t.auto_detect_domain).unwrap_or(true),
177183
heartbeat_interval_secs: Some(60),
178184
route_advertise_interval_secs: Some(30),
179-
server_spki_sha256: Some(enroll_response.server_spki_sha256.clone()),
185+
server_spki_sha256: Some(server_spki_sha256),
180186
};
181187

182188
conf_file.tunnel = Some(tunnel_conf);
183189

184190
config::save_config(&conf_file)?;
185191

186192
Ok(PersistedEnrollment {
187-
agent_id: enroll_response.agent_id,
193+
agent_id,
188194
agent_name: agent_name.to_owned(),
189195
client_cert_path,
190196
client_key_path,
191197
gateway_ca_path,
192-
quic_endpoint: enroll_response.quic_endpoint,
198+
quic_endpoint,
193199
})
194200
}

devolutions-agent/src/tunnel.rs

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -181,18 +181,9 @@ async fn run_single_connection(conf_handle: &ConfHandle, shutdown_signal: &mut S
181181
let agent_conf = conf_handle.get_conf();
182182
let tunnel_conf = &agent_conf.tunnel;
183183

184-
let cert_path = tunnel_conf
185-
.client_cert_path
186-
.as_ref()
187-
.context("client_cert_path not configured")?;
188-
let key_path = tunnel_conf
189-
.client_key_path
190-
.as_ref()
191-
.context("client_key_path not configured")?;
192-
let ca_path = tunnel_conf
193-
.gateway_ca_cert_path
194-
.as_ref()
195-
.context("gateway_ca_cert_path not configured")?;
184+
let cert_path = &tunnel_conf.client_cert_path;
185+
let key_path = &tunnel_conf.client_key_path;
186+
let ca_path = &tunnel_conf.gateway_ca_cert_path;
196187

197188
let advertise_subnets: Vec<Ipv4Network> = tunnel_conf
198189
.advertise_subnets
@@ -291,7 +282,7 @@ async fn run_single_connection(conf_handle: &ConfHandle, shutdown_signal: &mut S
291282
.with_client_auth_cert(certs, key)
292283
.context("build rustls client config with client auth")?;
293284

294-
client_crypto.alpn_protocols = vec![b"gw-agent-tunnel/1".to_vec()];
285+
client_crypto.alpn_protocols = vec![agent_tunnel_proto::ALPN_PROTOCOL.to_vec()];
295286

296287
let mut transport = quinn::TransportConfig::default();
297288
transport
@@ -356,8 +347,8 @@ async fn run_single_connection(conf_handle: &ConfHandle, shutdown_signal: &mut S
356347

357348
// -- Main loop: accept incoming session streams + periodic tasks --
358349

359-
let route_interval = tunnel_conf.route_advertise_interval_secs.unwrap_or(30);
360-
let heartbeat_interval_secs = tunnel_conf.heartbeat_interval_secs.unwrap_or(60);
350+
let route_interval = tunnel_conf.route_advertise_interval_secs;
351+
let heartbeat_interval_secs = tunnel_conf.heartbeat_interval_secs;
361352
let mut route_tick = tokio::time::interval(Duration::from_secs(route_interval));
362353
let mut heartbeat_tick = tokio::time::interval(Duration::from_secs(heartbeat_interval_secs));
363354
// Skip the first immediate tick (we already sent the initial RouteAdvertise).
@@ -414,15 +405,16 @@ async fn run_control_reader<R: tokio::io::AsyncRead + Unpin>(mut ctrl: ControlRe
414405
loop {
415406
let message = ctrl.recv().await.context("recv control message")?;
416407

408+
let protocol_version = message.protocol_version();
409+
if agent_tunnel_proto::validate_protocol_version(protocol_version)
410+
.inspect_err(|e| warn!(%protocol_version, %e, "Ignoring control message: unsupported version"))
411+
.is_err()
412+
{
413+
continue;
414+
}
415+
417416
match message {
418-
ControlMessage::HeartbeatAck {
419-
protocol_version,
420-
timestamp_ms,
421-
} => {
422-
if let Err(e) = agent_tunnel_proto::validate_protocol_version(protocol_version) {
423-
warn!(%protocol_version, %e, "Ignoring HeartbeatAck: unsupported protocol version");
424-
continue;
425-
}
417+
ControlMessage::HeartbeatAck { timestamp_ms, .. } => {
426418
let rtt = current_time_millis().saturating_sub(timestamp_ms);
427419
debug!(rtt_ms = rtt, "Received HeartbeatAck");
428420
}
@@ -495,6 +487,9 @@ async fn run_session_proxy(advertise_subnets: Vec<Ipv4Network>, send: quinn::Sen
495487
r1.inspect_err(|e| debug!(%e, "QUIC->TCP copy ended"))?;
496488
r2.inspect_err(|e| debug!(%e, "TCP->QUIC copy ended"))?;
497489

490+
// Gracefully finish the QUIC send stream (signals EOF to peer).
491+
let _ = send.finish();
492+
498493
Ok(())
499494
}
500495
.await

devolutions-gateway/src/agent_tunnel/cert.rs

Lines changed: 68 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -179,10 +179,14 @@ impl CaManager {
179179
let cert_path = self.data_dir.join(SERVER_CERT_FILENAME);
180180
let key_path = self.data_dir.join(SERVER_KEY_FILENAME);
181181

182-
if cert_path.exists() && key_path.exists() {
183-
// TODO: check cert expiry and regenerate if near/past expiration (365-day validity).
184-
info!(%cert_path, "Using existing agent tunnel server certificate");
185-
return Ok((cert_path, key_path));
182+
match check_server_cert(&cert_path, &key_path, hostname) {
183+
ServerCertStatus::Valid => {
184+
info!(%cert_path, "Using existing agent tunnel server certificate");
185+
return Ok((cert_path, key_path));
186+
}
187+
status => {
188+
info!(%cert_path, ?status, "Generating server certificate");
189+
}
186190
}
187191

188192
info!(%hostname, "Generating agent tunnel server certificate");
@@ -301,7 +305,7 @@ impl CaManager {
301305
.with_single_cert(server_cert_chain, server_private_key)
302306
.context("build rustls ServerConfig")?;
303307

304-
tls_config.alpn_protocols = vec![b"gw-agent-tunnel/1".to_vec()];
308+
tls_config.alpn_protocols = vec![agent_tunnel_proto::ALPN_PROTOCOL.to_vec()];
305309

306310
Ok(tls_config)
307311
}
@@ -378,6 +382,65 @@ pub fn extract_agent_id_from_der(der_bytes: &[u8]) -> anyhow::Result<Uuid> {
378382
anyhow::bail!("no urn:uuid: SAN found in certificate")
379383
}
380384

385+
// ---------------------------------------------------------------------------
386+
// Server certificate validation
387+
// ---------------------------------------------------------------------------
388+
389+
/// Why an existing server certificate cannot be reused.
390+
#[derive(Debug)]
391+
enum ServerCertStatus {
392+
/// Certificate is valid and matches the configured hostname.
393+
Valid,
394+
/// Certificate or key file does not exist yet.
395+
NotFound,
396+
/// Certificate expires within 7 days.
397+
ExpiringSoon,
398+
/// Certificate's DNS SAN does not match the configured hostname.
399+
HostnameMismatch,
400+
/// Certificate file is corrupt or unparseable.
401+
Unreadable,
402+
}
403+
404+
fn check_server_cert(cert_path: &Utf8Path, key_path: &Utf8Path, hostname: &str) -> ServerCertStatus {
405+
if !cert_path.exists() || !key_path.exists() {
406+
return ServerCertStatus::NotFound;
407+
}
408+
409+
let Ok(pem_str) = std::fs::read_to_string(cert_path) else {
410+
return ServerCertStatus::Unreadable;
411+
};
412+
let Ok(parsed) = pem::parse(&pem_str) else {
413+
return ServerCertStatus::Unreadable;
414+
};
415+
let Ok((_, cert)) = x509_parser::parse_x509_certificate(parsed.contents()) else {
416+
return ServerCertStatus::Unreadable;
417+
};
418+
419+
// Expiry: reject if < 7 days remaining.
420+
let not_after = cert.validity().not_after.to_datetime();
421+
let threshold = time::OffsetDateTime::now_utc() + Duration::from_secs(7 * SECS_PER_DAY);
422+
if not_after <= threshold {
423+
return ServerCertStatus::ExpiringSoon;
424+
}
425+
426+
// Hostname: reject if DNS SAN doesn't match the configured hostname.
427+
let san_matches = cert.extensions().iter().any(|ext| {
428+
if let x509_parser::extensions::ParsedExtension::SubjectAlternativeName(san) = ext.parsed_extension() {
429+
san.general_names
430+
.iter()
431+
.any(|name| matches!(name, x509_parser::extensions::GeneralName::DNSName(h) if *h == hostname))
432+
} else {
433+
false
434+
}
435+
});
436+
437+
if !san_matches {
438+
return ServerCertStatus::HostnameMismatch;
439+
}
440+
441+
ServerCertStatus::Valid
442+
}
443+
381444
#[cfg(test)]
382445
mod tests {
383446
use super::*;

0 commit comments

Comments
 (0)