|
1 | 1 | use core::net::SocketAddr; |
| 2 | +use core::time::Duration; |
2 | 3 | use std::rc::Rc; |
3 | 4 | use std::sync::Arc; |
4 | 5 |
|
@@ -42,6 +43,50 @@ use crate::{SoundServerFactory, builder, capabilities}; |
42 | 43 | /// TCP listen backlog size for the RDP server socket. |
43 | 44 | const LISTENER_BACKLOG: u32 = 1024; |
44 | 45 |
|
| 46 | +/// Action to take after a client disconnects. |
| 47 | +/// |
| 48 | +/// Returned by [`ConnectionHandler::on_disconnected`] to control whether |
| 49 | +/// the server continues accepting new connections or shuts down. |
| 50 | +#[derive(Debug, Clone, Copy, PartialEq, Eq)] |
| 51 | +pub enum PostConnectionAction { |
| 52 | + /// Continue accepting new connections. |
| 53 | + Continue, |
| 54 | + /// Stop the accept loop and return from [`RdpServer::run`]. |
| 55 | + Stop, |
| 56 | +} |
| 57 | + |
| 58 | +/// Hooks for connection lifecycle events in [`RdpServer::run`]. |
| 59 | +/// |
| 60 | +/// Implement this trait to add pre-accept filtering (rate limiting, |
| 61 | +/// IP allowlists) and post-disconnect logic (cleanup, session validity |
| 62 | +/// checks, metrics). |
| 63 | +/// |
| 64 | +/// All methods have default implementations that accept all connections |
| 65 | +/// and continue unconditionally. |
| 66 | +pub trait ConnectionHandler: Send { |
| 67 | + /// Called after `accept()` returns but before `run_connection()`. |
| 68 | + /// |
| 69 | + /// Return `false` to reject the connection (the TCP stream is dropped). |
| 70 | + fn on_accept(&mut self, peer: SocketAddr) -> bool { |
| 71 | + let _ = peer; |
| 72 | + true |
| 73 | + } |
| 74 | + |
| 75 | + /// Called after `run_connection()` completes (successfully or with error). |
| 76 | + /// |
| 77 | + /// `duration` is the wall-clock time the connection was active. |
| 78 | + /// `error` is `Some` if the connection ended with an error. |
| 79 | + fn on_disconnected( |
| 80 | + &mut self, |
| 81 | + peer: SocketAddr, |
| 82 | + duration: Duration, |
| 83 | + error: Option<&anyhow::Error>, |
| 84 | + ) -> PostConnectionAction { |
| 85 | + let _ = (peer, duration, error); |
| 86 | + PostConnectionAction::Continue |
| 87 | + } |
| 88 | +} |
| 89 | + |
45 | 90 | #[derive(Clone)] |
46 | 91 | pub struct RdpServerOptions { |
47 | 92 | pub addr: SocketAddr, |
@@ -245,6 +290,7 @@ pub struct RdpServer { |
245 | 290 | creds: Option<Credentials>, |
246 | 291 | local_addr: Option<SocketAddr>, |
247 | 292 | autodetect: Option<AutoDetectManager>, |
| 293 | + connection_handler: Option<Box<dyn ConnectionHandler>>, |
248 | 294 | } |
249 | 295 |
|
250 | 296 | #[derive(Debug)] |
@@ -285,6 +331,7 @@ impl RdpServer { |
285 | 331 | display: Box<dyn RdpServerDisplay>, |
286 | 332 | mut sound_factory: Option<Box<dyn SoundServerFactory>>, |
287 | 333 | mut cliprdr_factory: Option<Box<dyn CliprdrServerFactory>>, |
| 334 | + connection_handler: Option<Box<dyn ConnectionHandler>>, |
288 | 335 | #[cfg(feature = "egfx")] mut gfx_factory: Option<Box<dyn GfxServerFactory>>, |
289 | 336 | ) -> Self { |
290 | 337 | let (ev_sender, ev_receiver) = ServerEvent::create_channel(); |
@@ -315,6 +362,7 @@ impl RdpServer { |
315 | 362 | creds: None, |
316 | 363 | local_addr: None, |
317 | 364 | autodetect: None, |
| 365 | + connection_handler, |
318 | 366 | } |
319 | 367 | } |
320 | 368 |
|
@@ -531,10 +579,37 @@ impl RdpServer { |
531 | 579 | Ok((stream, peer)) = listener.accept() => { |
532 | 580 | debug!(?peer, "Received connection"); |
533 | 581 | drop(ev_receiver); |
534 | | - if let Err(error) = self.run_connection(stream).await { |
535 | | - error!(?error, "Connection error"); |
| 582 | + |
| 583 | + let accepted = self.connection_handler |
| 584 | + .as_mut() |
| 585 | + .is_none_or(|h| h.on_accept(peer)); |
| 586 | + |
| 587 | + if !accepted { |
| 588 | + debug!(?peer, "Connection rejected by handler"); |
| 589 | + drop(stream); |
| 590 | + } else { |
| 591 | + let started = tokio::time::Instant::now(); |
| 592 | + let result = self.run_connection(stream).await; |
| 593 | + let duration = started.elapsed(); |
| 594 | + |
| 595 | + if let Err(ref error) = result { |
| 596 | + error!(?error, "Connection error"); |
| 597 | + } |
| 598 | + |
| 599 | + self.static_channels = StaticChannelSet::new(); |
| 600 | + |
| 601 | + if let Some(ref mut handler) = self.connection_handler { |
| 602 | + let action = handler.on_disconnected( |
| 603 | + peer, |
| 604 | + duration, |
| 605 | + result.as_ref().err(), |
| 606 | + ); |
| 607 | + if action == PostConnectionAction::Stop { |
| 608 | + debug!(?peer, "Handler requested stop after disconnect"); |
| 609 | + break; |
| 610 | + } |
| 611 | + } |
536 | 612 | } |
537 | | - self.static_channels = StaticChannelSet::new(); |
538 | 613 | } |
539 | 614 | else => break, |
540 | 615 | } |
|
0 commit comments