diff --git a/crates/hotfix/src/initiator.rs b/crates/hotfix/src/initiator.rs index 19543e4a..8661be0e 100644 --- a/crates/hotfix/src/initiator.rs +++ b/crates/hotfix/src/initiator.rs @@ -107,8 +107,8 @@ async fn establish_connection( completion_tx: watch::Sender, ) { loop { - if session_ref.await_active_session_time().await.is_err() { - warn!("session task terminated when checking active session time"); + if session_ref.await_in_schedule().await.is_err() { + warn!("session task terminated when checking schedule"); break; } diff --git a/crates/hotfix/src/session.rs b/crates/hotfix/src/session.rs index 3926bc74..be2013b4 100644 --- a/crates/hotfix/src/session.rs +++ b/crates/hotfix/src/session.rs @@ -45,7 +45,9 @@ pub(crate) use crate::session::session_ref::InternalSessionRef; pub use crate::session::session_ref::InternalSessionRef; use crate::session::session_ref::OutboundRequest; use crate::session::state::SessionState; -use crate::session::state::{AwaitingResendTransitionOutcome, TestRequestId}; +use crate::session::state::{ + AwaitingLogonState, AwaitingLogoutState, AwaitingResendTransitionOutcome, TestRequestId, +}; use crate::session_schedule::{SessionPeriodComparison, SessionSchedule}; use crate::store::MessageStore; use crate::transport::writer::WriterRef; @@ -200,7 +202,7 @@ where } } - if let SessionState::AwaitingLogon { .. } = &mut self.state { + if let SessionState::AwaitingLogon(_) = &mut self.state { // TODO: should this (and all inbound message processing) logic be pushed into the state? if message_type != Logon::MSG_TYPE { self.state.disconnect_writer().await; @@ -332,11 +334,11 @@ where } async fn on_connect(&mut self, writer: WriterRef) -> Result<(), SessionOperationError> { - self.state = SessionState::AwaitingLogon { + self.state = SessionState::AwaitingLogon(AwaitingLogonState { writer, logon_sent: false, logon_timeout: Instant::now() + Duration::from_secs(self.config.logon_timeout), - }; + }); self.reset_peer_timer(None); self.send_logon().await?; @@ -345,23 +347,23 @@ where async fn on_disconnect(&mut self, reason: String) { match self.state { - SessionState::Active { .. } - | SessionState::AwaitingLogon { .. } + SessionState::Active(_) + | SessionState::AwaitingLogon(_) | SessionState::AwaitingResend(_) => { self.state.disconnect_writer().await; self.state = SessionState::new_disconnected(true, &reason); } - SessionState::Disconnected { .. } => { + SessionState::Disconnected(_) => { warn!("disconnect message was received, but the session is already disconnected") } - SessionState::AwaitingLogout { reconnect, .. } => { + SessionState::AwaitingLogout(AwaitingLogoutState { reconnect, .. }) => { self.state = SessionState::new_disconnected(reconnect, &reason); } } } async fn on_logon(&mut self, message: &Message) -> Result<(), SessionOperationError> { - if let SessionState::AwaitingLogon { writer, .. } = &self.state { + if let SessionState::AwaitingLogon(AwaitingLogonState { writer, .. }) = &self.state { match self.verify_message(message, true, true) { Ok(_) => { // happy logon flow, the session is now active @@ -395,7 +397,7 @@ where // if the session is already disconnected, we have nothing else to do SessionState::Disconnected(..) => {} // if we initiated the logout, preserve the reconnect flag - SessionState::AwaitingLogout { reconnect, .. } => { + SessionState::AwaitingLogout(AwaitingLogoutState { reconnect, .. }) => { self.state.disconnect_writer().await; self.state = SessionState::new_disconnected(reconnect, "logout completed"); } @@ -1039,8 +1041,8 @@ where warn!("tried to respond to ShouldReconnect query but the receiver is gone"); } } - SessionEvent::AwaitingActiveSession(responder) => { - self.state.register_session_awaiter(responder); + SessionEvent::AwaitSchedule(responder) => { + self.state.register_schedule_awaiter(responder); } } } @@ -1117,7 +1119,7 @@ where let is_active = self.schedule.is_active_at(&now); if is_active { - self.state.notify_session_awaiter(); + self.state.notify_schedule_awaiter(); match self .schedule .is_same_session_period(&self.store.creation_time(), &now) diff --git a/crates/hotfix/src/session/event.rs b/crates/hotfix/src/session/event.rs index c144fd85..323cbeef 100644 --- a/crates/hotfix/src/session/event.rs +++ b/crates/hotfix/src/session/event.rs @@ -13,19 +13,19 @@ pub enum SessionEvent { Connected(WriterRef), /// Ask the session whether we should attempt to reconnect. ShouldReconnect(oneshot::Sender), - /// Ask the session to notify us when the session is active. - AwaitingActiveSession(oneshot::Sender), + /// Ask the session to notify us when the schedule indicates we should connect. + AwaitSchedule(oneshot::Sender), } -/// The response sent by the session to AwaitingActiveSession messages. +/// The response sent by the session to AwaitSchedule messages. /// -/// This doesn't include an Inactive variant, as the session won't respond until -/// it's active or in a state that indicates it should just be shut down due to an -/// unrecoverable error. +/// This doesn't include an out-of-schedule variant, as the session won't respond +/// until the schedule indicates we should connect or the session is in a state that +/// indicates it should just be shut down due to an unrecoverable error. #[derive(Debug, Clone, Copy)] -pub enum AwaitingActiveSessionResponse { - /// The session is now active and ready to connect. - Active, +pub enum ScheduleResponse { + /// The schedule indicates we should connect. + InSchedule, /// The session should be shut down due to an unrecoverable error. Shutdown, } diff --git a/crates/hotfix/src/session/session_ref.rs b/crates/hotfix/src/session/session_ref.rs index e04ba2db..15aa3957 100644 --- a/crates/hotfix/src/session/session_ref.rs +++ b/crates/hotfix/src/session/session_ref.rs @@ -7,7 +7,7 @@ use crate::message::{OutboundMessage, RawFixMessage}; use crate::session::Session; use crate::session::admin_request::AdminRequest; use crate::session::error::{SendError, SendOutcome, SessionCreationError}; -use crate::session::event::{AwaitingActiveSessionResponse, SessionEvent}; +use crate::session::event::{ScheduleResponse, SessionEvent}; use crate::store::MessageStore; use crate::transport::writer::WriterRef; use crate::{Application, session}; @@ -82,15 +82,15 @@ impl InternalSessionRef { Ok(receiver.await?) } - pub async fn await_active_session_time(&self) -> Result<(), SessionGone> { - debug!("awaiting active session time"); - let (sender, receiver) = oneshot::channel::(); + pub async fn await_in_schedule(&self) -> Result<(), SessionGone> { + debug!("awaiting in-schedule time"); + let (sender, receiver) = oneshot::channel::(); self.event_sender - .send(SessionEvent::AwaitingActiveSession(sender)) + .send(SessionEvent::AwaitSchedule(sender)) .await?; receiver.await?; - debug!("resuming connection as session is active"); + debug!("resuming connection as schedule is active"); Ok(()) } } diff --git a/crates/hotfix/src/session/state.rs b/crates/hotfix/src/session/state.rs index fa84472d..e4ecc44d 100644 --- a/crates/hotfix/src/session/state.rs +++ b/crates/hotfix/src/session/state.rs @@ -1,36 +1,37 @@ +mod active; +mod awaiting_logon; +mod awaiting_logout; +mod awaiting_resend; +mod disconnected; + +pub(crate) use active::{ActiveState, calculate_peer_interval}; +pub(crate) use awaiting_logon::AwaitingLogonState; +pub(crate) use awaiting_logout::AwaitingLogoutState; +pub(crate) use awaiting_resend::{AwaitingResendState, AwaitingResendTransitionOutcome}; +pub(crate) use disconnected::DisconnectedState; + use crate::message::logon::Logon; use crate::message::logout::Logout; use crate::message::parser::RawFixMessage; -use crate::session::event::AwaitingActiveSessionResponse; +use crate::session::event::ScheduleResponse; use crate::session::info::Status as SessionInfoStatus; use crate::transport::writer::WriterRef; -use hotfix_message::message::Message; -use std::collections::VecDeque; use std::time::Duration; use tokio::sync::oneshot; use tokio::time::Instant; use tracing::{debug, error}; const TEST_REQUEST_THRESHOLD: f64 = 1.2; -const MAX_RESEND_ATTEMPTS: usize = 3; pub(crate) type TestRequestId = String; pub enum SessionState { /// We have established a connection, sent a logon message and await a response. - AwaitingLogon { - writer: WriterRef, - logon_sent: bool, - logon_timeout: Instant, - }, + AwaitingLogon(AwaitingLogonState), /// We are awaiting the target to resend the gap we have. AwaitingResend(AwaitingResendState), /// We are in the process of gracefully logging out - AwaitingLogout { - writer: WriterRef, // we need the writer so we can disconnect it on successful logout - logout_timeout: Instant, - reconnect: bool, // we carry this forward for the subsequent disconnected state - }, + AwaitingLogout(AwaitingLogoutState), /// The session is active, we have connected and mutually logged on. Active(ActiveState), /// The TCP connection has been dropped. @@ -73,9 +74,9 @@ impl SessionState { writer.send_raw_message(message).await } } - Self::AwaitingLogon { + Self::AwaitingLogon(AwaitingLogonState { writer, logon_sent, .. - } => match message_type { + }) => match message_type { Logon::MSG_TYPE => { if *logon_sent { error!("trying to send logon twice"); @@ -89,7 +90,7 @@ impl SessionState { } _ => error!("invalid outgoing message for AwaitingLogon state"), }, - Self::AwaitingLogout { writer, .. } => { + Self::AwaitingLogout(AwaitingLogoutState { writer, .. }) => { // Logout messages are allowed because we first transition into AwaitingLogout // and only then send the logout message if message_type == Logout::MSG_TYPE { @@ -103,8 +104,8 @@ impl SessionState { pub async fn disconnect_writer(&self) { match self { Self::Active(ActiveState { writer, .. }) - | Self::AwaitingLogon { writer, .. } - | Self::AwaitingLogout { writer, .. } + | Self::AwaitingLogon(AwaitingLogonState { writer, .. }) + | Self::AwaitingLogout(AwaitingLogoutState { writer, .. }) | Self::AwaitingResend(AwaitingResendState { writer, .. }) => writer.disconnect().await, _ => debug!("disconnecting an already disconnected session"), } @@ -113,8 +114,8 @@ impl SessionState { fn get_writer(&self) -> Option<&WriterRef> { match self { Self::Active(ActiveState { writer, .. }) - | Self::AwaitingLogon { writer, .. } - | Self::AwaitingLogout { writer, .. } + | Self::AwaitingLogon(AwaitingLogonState { writer, .. }) + | Self::AwaitingLogout(AwaitingLogoutState { writer, .. }) | Self::AwaitingResend(AwaitingResendState { writer, .. }) => Some(writer), _ => None, } @@ -125,17 +126,17 @@ impl SessionState { logout_timeout: Duration, reconnect: bool, ) -> bool { - if matches!(self, SessionState::AwaitingLogout { .. }) { + if matches!(self, SessionState::AwaitingLogout(_)) { debug!("already in awaiting logout state"); return false; } if let Some(writer) = self.get_writer() { - *self = SessionState::AwaitingLogout { + *self = SessionState::AwaitingLogout(AwaitingLogoutState { writer: writer.clone(), logout_timeout: Instant::now() + logout_timeout, reconnect, - }; + }); true } else { error!("trying to transition to awaiting logout without an established connection"); @@ -149,14 +150,14 @@ impl SessionState { end: u64, ) -> AwaitingResendTransitionOutcome { match self { - SessionState::AwaitingLogon { writer, .. } + SessionState::AwaitingLogon(AwaitingLogonState { writer, .. }) | SessionState::Active(ActiveState { writer, .. }) => { let awaiting_resend = AwaitingResendState::new(writer.to_owned(), begin, end); *self = SessionState::AwaitingResend(awaiting_resend); AwaitingResendTransitionOutcome::Success } SessionState::AwaitingResend(state) => state.update(begin, end), - SessionState::AwaitingLogout { .. } => AwaitingResendTransitionOutcome::InvalidState( + SessionState::AwaitingLogout(_) => AwaitingResendTransitionOutcome::InvalidState( "trying to request a resend while we are already logging out".to_string(), ), SessionState::Disconnected(_) => AwaitingResendTransitionOutcome::InvalidState( @@ -166,42 +167,39 @@ impl SessionState { } } - pub fn register_session_awaiter( - &mut self, - responder: oneshot::Sender, - ) { + pub fn register_schedule_awaiter(&mut self, responder: oneshot::Sender) { match self { SessionState::Disconnected(state) => { - if state.has_session_awaiter() { + if state.has_schedule_awaiter() { let reason = &state.reason; error!( - "session awaiter already registered on state disconnected due to: {reason}" + "schedule awaiter already registered on state disconnected due to: {reason}" ); - if let Err(err) = responder.send(AwaitingActiveSessionResponse::Shutdown) { - error!("failed to send session awaiter response: {err:?}"); + if let Err(err) = responder.send(ScheduleResponse::Shutdown) { + error!("failed to send schedule awaiter response: {err:?}"); } } else { - state.set_session_awaiter(responder); - debug!("registered session awaiter"); + state.set_schedule_awaiter(responder); + debug!("registered schedule awaiter"); } } _ => { - error!("session awaiter can only be registered on disconnected sessions"); - if let Err(err) = responder.send(AwaitingActiveSessionResponse::Shutdown) { - error!("failed to send session awaiter response: {err:?}"); + error!("schedule awaiter can only be registered on disconnected sessions"); + if let Err(err) = responder.send(ScheduleResponse::Shutdown) { + error!("failed to send schedule awaiter response: {err:?}"); } } } } - pub fn notify_session_awaiter(&mut self) { + pub fn notify_schedule_awaiter(&mut self) { if let SessionState::Disconnected(state) = self - && let Some(awaiter) = state.take_session_awaiter() + && let Some(awaiter) = state.take_schedule_awaiter() { - if let Err(err) = awaiter.send(AwaitingActiveSessionResponse::Active) { - error!("failed to send session awaiter response: {err:?}"); + if let Err(err) = awaiter.send(ScheduleResponse::InSchedule) { + error!("failed to send schedule awaiter response: {err:?}"); } else { - debug!("notified session awaiter"); + debug!("notified schedule awaiter"); } } } @@ -227,8 +225,10 @@ impl SessionState { pub fn peer_deadline(&self) -> Option<&Instant> { match self { Self::Active(ActiveState { peer_deadline, .. }) => Some(peer_deadline), - Self::AwaitingLogon { logon_timeout, .. } => Some(logon_timeout), - Self::AwaitingLogout { logout_timeout, .. } => Some(logout_timeout), + Self::AwaitingLogon(AwaitingLogonState { logon_timeout, .. }) => Some(logon_timeout), + Self::AwaitingLogout(AwaitingLogoutState { logout_timeout, .. }) => { + Some(logout_timeout) + } _ => None, } } @@ -274,16 +274,16 @@ impl SessionState { } pub fn is_awaiting_logon(&self) -> bool { - matches!(self, SessionState::AwaitingLogon { .. }) + matches!(self, SessionState::AwaitingLogon(_)) } pub fn is_awaiting_logout(&self) -> bool { - matches!(self, SessionState::AwaitingLogout { .. }) + matches!(self, SessionState::AwaitingLogout(_)) } pub fn as_status(&self) -> SessionInfoStatus { match self { - SessionState::AwaitingLogon { .. } => SessionInfoStatus::AwaitingLogon, + SessionState::AwaitingLogon(_) => SessionInfoStatus::AwaitingLogon, SessionState::AwaitingResend(AwaitingResendState { begin_seq_number, end_seq_number, @@ -294,165 +294,9 @@ impl SessionState { end: *end_seq_number, attempts: *resend_attempts, }, - SessionState::AwaitingLogout { .. } => SessionInfoStatus::AwaitingLogout, + SessionState::AwaitingLogout(_) => SessionInfoStatus::AwaitingLogout, SessionState::Active(_) => SessionInfoStatus::Active, SessionState::Disconnected(_) => SessionInfoStatus::Disconnected, } } } - -#[inline] -fn calculate_peer_interval(heartbeat_interval: u64) -> u64 { - (heartbeat_interval as f64 * TEST_REQUEST_THRESHOLD).round() as u64 -} - -pub struct ActiveState { - /// The writer's reference to send messages to the counterparty - writer: WriterRef, - /// When we should send the next heartbeat message to the counterparty - heartbeat_deadline: Instant, - /// When the next message from the counterparty is expected at the latest - peer_deadline: Instant, - /// The ID of the test request we sent on peer timer expiry - sent_test_request_id: Option, -} - -/// Session state we're in while processing messages we requested to be resent. -pub struct AwaitingResendState { - /// The reference to the writer loop. - pub(crate) writer: WriterRef, - /// The beginning of the gap we're waiting for the target to resend. - pub(crate) begin_seq_number: u64, - /// The end of the gap we're waiting for the target to resend. - pub(crate) end_seq_number: u64, - /// Inbound messages we receive while processing the resend. - pub(crate) inbound_queue: VecDeque, - /// The number of times we've attempted to ask the counterparty to resend the gap. - pub(crate) resend_attempts: usize, -} - -impl AwaitingResendState { - fn new(writer: WriterRef, begin_seq_number: u64, end_seq_number: u64) -> Self { - Self { - writer, - begin_seq_number, - end_seq_number, - inbound_queue: Default::default(), - resend_attempts: 1, - } - } - - fn update( - &mut self, - begin_seq_number: u64, - end_seq_number: u64, - ) -> AwaitingResendTransitionOutcome { - let resend_attempts = if self.begin_seq_number == begin_seq_number { - if self.resend_attempts + 1 > MAX_RESEND_ATTEMPTS { - return AwaitingResendTransitionOutcome::AttemptsExceeded; - } - self.resend_attempts + 1 - } else if begin_seq_number < self.begin_seq_number { - return AwaitingResendTransitionOutcome::BeginSeqNumberTooLow; - } else { - 1 - }; - - self.resend_attempts = resend_attempts; - self.begin_seq_number = begin_seq_number; - self.end_seq_number = end_seq_number; - - AwaitingResendTransitionOutcome::Success - } -} - -pub struct DisconnectedState { - reconnect: bool, - session_awaiter: Option>, - reason: String, -} - -impl DisconnectedState { - fn new(reconnect: bool, reason: &str) -> Self { - Self { - reconnect, - session_awaiter: None, - reason: reason.to_string(), - } - } - - fn set_session_awaiter(&mut self, responder: oneshot::Sender) { - self.session_awaiter = Some(responder); - } - - fn has_session_awaiter(&self) -> bool { - self.session_awaiter.is_some() - } - - fn take_session_awaiter(&mut self) -> Option> { - self.session_awaiter.take() - } -} - -pub enum AwaitingResendTransitionOutcome { - Success, - InvalidState(String), - BeginSeqNumberTooLow, - AttemptsExceeded, -} - -#[cfg(test)] -mod tests { - use super::*; - use tokio::sync::mpsc; - - #[test] - fn test_awaiting_resend_transition_begin_seq_number_too_low() { - let writer = create_writer_ref(); - let mut state = SessionState::AwaitingResend(AwaitingResendState::new(writer, 1, 5)); - let result = state.try_transition_to_awaiting_resend(0, 5); - assert!(matches!( - result, - AwaitingResendTransitionOutcome::BeginSeqNumberTooLow - )); - } - - #[test] - fn test_awaiting_resend_transition_attempts_exceeded() { - let writer = create_writer_ref(); - let mut state = SessionState::AwaitingResend(AwaitingResendState::new(writer, 1, 5)); - - // we can transition twice more without hitting the limit - let result = state.try_transition_to_awaiting_resend(1, 5); - assert!(matches!(result, AwaitingResendTransitionOutcome::Success)); - let result = state.try_transition_to_awaiting_resend(1, 5); - assert!(matches!(result, AwaitingResendTransitionOutcome::Success)); - - // the fourth time we'd get into an AwaitingResendState with the same begin seq number, we get an error - let result = state.try_transition_to_awaiting_resend(1, 5); - assert!(matches!( - result, - AwaitingResendTransitionOutcome::AttemptsExceeded - )); - } - - #[test] - fn test_awaiting_resend_transition_when_awaiting_logout_is_prevented() { - let mut state = SessionState::AwaitingLogout { - writer: create_writer_ref(), - logout_timeout: Instant::now(), - reconnect: false, - }; - - let result = state.try_transition_to_awaiting_resend(1, 5); - assert!(matches!( - result, - AwaitingResendTransitionOutcome::InvalidState(_) - )); - } - - fn create_writer_ref() -> WriterRef { - let (sender, _) = mpsc::channel(10); - WriterRef::new(sender) - } -} diff --git a/crates/hotfix/src/session/state/active.rs b/crates/hotfix/src/session/state/active.rs new file mode 100644 index 00000000..b0828bbe --- /dev/null +++ b/crates/hotfix/src/session/state/active.rs @@ -0,0 +1,19 @@ +use crate::session::state::TestRequestId; +use crate::transport::writer::WriterRef; +use tokio::time::Instant; + +pub(crate) struct ActiveState { + /// The writer's reference to send messages to the counterparty + pub(crate) writer: WriterRef, + /// When we should send the next heartbeat message to the counterparty + pub(crate) heartbeat_deadline: Instant, + /// When the next message from the counterparty is expected at the latest + pub(crate) peer_deadline: Instant, + /// The ID of the test request we sent on peer timer expiry + pub(crate) sent_test_request_id: Option, +} + +#[inline] +pub(crate) fn calculate_peer_interval(heartbeat_interval: u64) -> u64 { + (heartbeat_interval as f64 * super::TEST_REQUEST_THRESHOLD).round() as u64 +} diff --git a/crates/hotfix/src/session/state/awaiting_logon.rs b/crates/hotfix/src/session/state/awaiting_logon.rs new file mode 100644 index 00000000..64a6758c --- /dev/null +++ b/crates/hotfix/src/session/state/awaiting_logon.rs @@ -0,0 +1,11 @@ +use crate::transport::writer::WriterRef; +use tokio::time::Instant; + +pub(crate) struct AwaitingLogonState { + /// The writer's reference to send messages to the counterparty + pub(crate) writer: WriterRef, + /// Indicates whether we have sent Logon - safeguards against accidental double sends + pub(crate) logon_sent: bool, + /// When we are expecting the Logon response at the latest + pub(crate) logon_timeout: Instant, +} diff --git a/crates/hotfix/src/session/state/awaiting_logout.rs b/crates/hotfix/src/session/state/awaiting_logout.rs new file mode 100644 index 00000000..0dc5d329 --- /dev/null +++ b/crates/hotfix/src/session/state/awaiting_logout.rs @@ -0,0 +1,11 @@ +use crate::transport::writer::WriterRef; +use tokio::time::Instant; + +pub(crate) struct AwaitingLogoutState { + /// The writer's reference to send messages to the counterparty + pub(crate) writer: WriterRef, + /// When we are expecting the Logout response at the latest + pub(crate) logout_timeout: Instant, + /// Indicates whether we should attempt to reconnect after we've fully logged out + pub(crate) reconnect: bool, +} diff --git a/crates/hotfix/src/session/state/awaiting_resend.rs b/crates/hotfix/src/session/state/awaiting_resend.rs new file mode 100644 index 00000000..ede8a2ed --- /dev/null +++ b/crates/hotfix/src/session/state/awaiting_resend.rs @@ -0,0 +1,121 @@ +use crate::transport::writer::WriterRef; +use hotfix_message::message::Message; +use std::collections::VecDeque; + +const MAX_RESEND_ATTEMPTS: usize = 3; + +/// Session state we're in while processing messages we requested to be resent. +pub(crate) struct AwaitingResendState { + /// The reference to the writer loop. + pub(crate) writer: WriterRef, + /// The beginning of the gap we're waiting for the target to resend. + pub(crate) begin_seq_number: u64, + /// The end of the gap we're waiting for the target to resend. + pub(crate) end_seq_number: u64, + /// Inbound messages we receive while processing the resend. + pub(crate) inbound_queue: VecDeque, + /// The number of times we've attempted to ask the counterparty to resend the gap. + pub(crate) resend_attempts: usize, +} + +impl AwaitingResendState { + pub(crate) fn new(writer: WriterRef, begin_seq_number: u64, end_seq_number: u64) -> Self { + Self { + writer, + begin_seq_number, + end_seq_number, + inbound_queue: Default::default(), + resend_attempts: 1, + } + } + + pub(crate) fn update( + &mut self, + begin_seq_number: u64, + end_seq_number: u64, + ) -> AwaitingResendTransitionOutcome { + let resend_attempts = if self.begin_seq_number == begin_seq_number { + if self.resend_attempts + 1 > MAX_RESEND_ATTEMPTS { + return AwaitingResendTransitionOutcome::AttemptsExceeded; + } + self.resend_attempts + 1 + } else if begin_seq_number < self.begin_seq_number { + return AwaitingResendTransitionOutcome::BeginSeqNumberTooLow; + } else { + 1 + }; + + self.resend_attempts = resend_attempts; + self.begin_seq_number = begin_seq_number; + self.end_seq_number = end_seq_number; + + AwaitingResendTransitionOutcome::Success + } +} + +pub(crate) enum AwaitingResendTransitionOutcome { + Success, + InvalidState(String), + BeginSeqNumberTooLow, + AttemptsExceeded, +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::session::state::SessionState; + use tokio::sync::mpsc; + use tokio::time::Instant; + + #[test] + fn test_awaiting_resend_transition_begin_seq_number_too_low() { + let writer = create_writer_ref(); + let mut state = SessionState::AwaitingResend(AwaitingResendState::new(writer, 1, 5)); + let result = state.try_transition_to_awaiting_resend(0, 5); + assert!(matches!( + result, + AwaitingResendTransitionOutcome::BeginSeqNumberTooLow + )); + } + + #[test] + fn test_awaiting_resend_transition_attempts_exceeded() { + let writer = create_writer_ref(); + let mut state = SessionState::AwaitingResend(AwaitingResendState::new(writer, 1, 5)); + + // we can transition twice more without hitting the limit + let result = state.try_transition_to_awaiting_resend(1, 5); + assert!(matches!(result, AwaitingResendTransitionOutcome::Success)); + let result = state.try_transition_to_awaiting_resend(1, 5); + assert!(matches!(result, AwaitingResendTransitionOutcome::Success)); + + // the fourth time we'd get into an AwaitingResendState with the same begin seq number, we get an error + let result = state.try_transition_to_awaiting_resend(1, 5); + assert!(matches!( + result, + AwaitingResendTransitionOutcome::AttemptsExceeded + )); + } + + #[test] + fn test_awaiting_resend_transition_when_awaiting_logout_is_prevented() { + use crate::session::state::AwaitingLogoutState; + + let mut state = SessionState::AwaitingLogout(AwaitingLogoutState { + writer: create_writer_ref(), + logout_timeout: Instant::now(), + reconnect: false, + }); + + let result = state.try_transition_to_awaiting_resend(1, 5); + assert!(matches!( + result, + AwaitingResendTransitionOutcome::InvalidState(_) + )); + } + + fn create_writer_ref() -> WriterRef { + let (sender, _) = mpsc::channel(10); + WriterRef::new(sender) + } +} diff --git a/crates/hotfix/src/session/state/disconnected.rs b/crates/hotfix/src/session/state/disconnected.rs new file mode 100644 index 00000000..2d527ace --- /dev/null +++ b/crates/hotfix/src/session/state/disconnected.rs @@ -0,0 +1,34 @@ +use crate::session::event::ScheduleResponse; +use tokio::sync::oneshot; + +pub(crate) struct DisconnectedState { + /// Indicates whether we should attempt to reconnect + pub(crate) reconnect: bool, + /// The channel for notifying the session loop when trading hours resume + /// as indicated by the schedule + schedule_awaiter: Option>, + /// The reason we are disconnected + pub(crate) reason: String, +} + +impl DisconnectedState { + pub(crate) fn new(reconnect: bool, reason: &str) -> Self { + Self { + reconnect, + schedule_awaiter: None, + reason: reason.to_string(), + } + } + + pub(crate) fn set_schedule_awaiter(&mut self, responder: oneshot::Sender) { + self.schedule_awaiter = Some(responder); + } + + pub(crate) fn has_schedule_awaiter(&self) -> bool { + self.schedule_awaiter.is_some() + } + + pub(crate) fn take_schedule_awaiter(&mut self) -> Option> { + self.schedule_awaiter.take() + } +}