|
7 | 7 | //! pre-decided by the caller. |
8 | 8 |
|
9 | 9 | use std::io; |
| 10 | +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; |
10 | 11 | use std::sync::Arc; |
11 | 12 |
|
12 | 13 | use axum::http::StatusCode; |
@@ -128,24 +129,30 @@ impl KdcConnector { |
128 | 129 | .err(), |
129 | 130 | )?) |
130 | 131 | } else { |
| 132 | + let udp_payload = message.get(4..).ok_or_else(|| { |
| 133 | + HttpError::bad_request().msg("KDC UDP message is too short to contain a length prefix") |
| 134 | + })?; |
| 135 | + |
| 136 | + let destination_addr = resolve_udp_destination(kdc_addr).await?; |
| 137 | + let bind_addr = udp_bind_addr_for(destination_addr); |
| 138 | + |
131 | 139 | // We assume that ticket length is not bigger than 2048 bytes. |
132 | 140 | let mut buf = [0; 2048]; |
133 | 141 |
|
134 | | - let udp_socket = UdpSocket::bind("127.0.0.1:0") |
| 142 | + let udp_socket = UdpSocket::bind(bind_addr) |
135 | 143 | .await |
136 | 144 | .map_err(HttpError::internal().with_msg("unable to bind UDP socket").err())?; |
137 | 145 |
|
138 | | - let port = udp_socket |
| 146 | + let local_addr = udp_socket |
139 | 147 | .local_addr() |
140 | | - .map_err(HttpError::internal().with_msg("unable to get UDP socket address").err())? |
141 | | - .port(); |
| 148 | + .map_err(HttpError::internal().with_msg("unable to get UDP socket address").err())?; |
142 | 149 |
|
143 | | - trace!("Binded UDP listener to 127.0.0.1:{port}, forwarding KDC message..."); |
| 150 | + trace!(%local_addr, %destination_addr, "Bound UDP listener, forwarding KDC message"); |
144 | 151 |
|
145 | 152 | // First 4 bytes contains message length. We don't need it for UDP. |
146 | 153 | #[allow(clippy::redundant_closure)] // We get a better caller location for the error by using a closure. |
147 | 154 | udp_socket |
148 | | - .send_to(&message[4..], kdc_addr.as_addr()) |
| 155 | + .send_to(udp_payload, destination_addr) |
149 | 156 | .await |
150 | 157 | .map_err(|e| unable_to_reach_kdc_server_err(e))?; |
151 | 158 |
|
@@ -187,6 +194,24 @@ impl KdcConnector { |
187 | 194 | } |
188 | 195 | } |
189 | 196 |
|
| 197 | +async fn resolve_udp_destination(kdc_addr: &TargetAddr) -> Result<SocketAddr, HttpError> { |
| 198 | + let mut addrs = tokio::net::lookup_host(kdc_addr.as_addr()) |
| 199 | + .await |
| 200 | + .map_err(unable_to_reach_kdc_server_err)?; |
| 201 | + |
| 202 | + addrs.next().ok_or_else(|| { |
| 203 | + unable_to_reach_kdc_server_err(io::Error::new(io::ErrorKind::NotFound, "KDC address resolved empty")) |
| 204 | + }) |
| 205 | +} |
| 206 | + |
| 207 | +fn udp_bind_addr_for(destination_addr: SocketAddr) -> SocketAddr { |
| 208 | + if destination_addr.is_ipv4() { |
| 209 | + SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)) |
| 210 | + } else { |
| 211 | + SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0)) |
| 212 | + } |
| 213 | +} |
| 214 | + |
190 | 215 | /// Hard ceiling on the announced length of a TCP-framed KDC reply. |
191 | 216 | /// |
192 | 217 | /// The KDC TCP transport prefixes its message with a 4-byte big-endian length. |
@@ -257,6 +282,10 @@ mod tests { |
257 | 282 | TargetAddr::parse("tcp://127.0.0.1:1", Some(88)).expect("static target addr is valid") |
258 | 283 | } |
259 | 284 |
|
| 285 | + fn udp_kdc_addr() -> TargetAddr { |
| 286 | + TargetAddr::parse("udp://127.0.0.1:88", Some(88)).expect("static target addr is valid") |
| 287 | + } |
| 288 | + |
260 | 289 | /// No tunnel handle + explicit agent pin → must error. |
261 | 290 | /// |
262 | 291 | /// `jet_agent_id` declares a routing requirement; with no agent tunnel listener |
@@ -288,4 +317,26 @@ mod tests { |
288 | 317 | "should have reached the direct-connect branch, got: {err}", |
289 | 318 | ); |
290 | 319 | } |
| 320 | + |
| 321 | + #[tokio::test] |
| 322 | + async fn udp_message_shorter_than_length_prefix_errors() { |
| 323 | + let connector = KdcConnector::new(Uuid::new_v4(), None, None); |
| 324 | + let result = connector.send(&udp_kdc_addr(), b"\x00\x01\x02").await; |
| 325 | + let err = result.expect_err("UDP message shorter than the TCP length prefix must be rejected"); |
| 326 | + assert!( |
| 327 | + format!("{err}").contains("too short"), |
| 328 | + "error message should explain the malformed UDP payload, got: {err}", |
| 329 | + ); |
| 330 | + } |
| 331 | + |
| 332 | + #[test] |
| 333 | + fn udp_bind_addr_matches_destination_family() { |
| 334 | + let v4_bind = udp_bind_addr_for(SocketAddr::from((Ipv4Addr::LOCALHOST, 88))); |
| 335 | + assert!(v4_bind.is_ipv4()); |
| 336 | + assert_eq!(v4_bind.port(), 0); |
| 337 | + |
| 338 | + let v6_bind = udp_bind_addr_for(SocketAddr::from((Ipv6Addr::LOCALHOST, 88))); |
| 339 | + assert!(v6_bind.is_ipv6()); |
| 340 | + assert_eq!(v6_bind.port(), 0); |
| 341 | + } |
291 | 342 | } |
0 commit comments