Skip to content

Commit cf2fac1

Browse files
committed
Add better test coverage for initiator and fix bug with shut down session trying to reconnect
1 parent 1fe5adf commit cf2fac1

2 files changed

Lines changed: 158 additions & 4 deletions

File tree

crates/hotfix/src/initiator.rs

Lines changed: 152 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,15 +152,22 @@ async fn establish_connection<Outbound: OutboundMessage>(
152152
completion_tx.send_replace(true);
153153
}
154154

155-
#[cfg(all(test, feature = "fix44"))]
155+
#[cfg(test)]
156+
#[allow(clippy::expect_used)]
156157
mod tests {
157158
use super::*;
158159
use crate::application::{Application, InboundDecision, OutboundDecision};
159-
use crate::message::InboundMessage;
160+
use crate::message::logon::{Logon, ResetSeqNumConfig};
161+
use crate::message::logout::Logout;
162+
use crate::message::parser::Parser;
163+
use crate::message::{InboundMessage, generate_message};
160164
use crate::store::in_memory::InMemoryMessageStore;
165+
use hotfix_message::Part;
161166
use hotfix_message::message::Message;
167+
use hotfix_message::session_fields::MSG_TYPE;
162168
use std::time::Duration;
163-
use tokio::net::TcpListener;
169+
use tokio::io::{AsyncReadExt, AsyncWriteExt};
170+
use tokio::net::{TcpListener, TcpStream};
164171

165172
// Minimal message type for tests
166173
#[derive(Clone)]
@@ -194,6 +201,90 @@ mod tests {
194201
async fn on_logon(&mut self) {}
195202
}
196203

204+
/// A minimal FIX counterparty for testing the Initiator over TCP.
205+
struct TestCounterparty {
206+
stream: TcpStream,
207+
parser: Parser,
208+
seq_num: u64,
209+
// Counterparty's view: sender is TEST-TARGET, target is TEST-SENDER
210+
sender_comp_id: String,
211+
target_comp_id: String,
212+
}
213+
214+
impl TestCounterparty {
215+
async fn accept(listener: &TcpListener, config: &SessionConfig) -> Self {
216+
let (stream, _) = tokio::time::timeout(Duration::from_secs(2), listener.accept())
217+
.await
218+
.expect("timeout waiting for connection")
219+
.expect("failed to accept connection");
220+
221+
Self {
222+
stream,
223+
parser: Parser::default(),
224+
seq_num: 1,
225+
// Swap sender/target for counterparty perspective
226+
sender_comp_id: config.target_comp_id.clone(),
227+
target_comp_id: config.sender_comp_id.clone(),
228+
}
229+
}
230+
231+
async fn read_message(&mut self) -> Message {
232+
let mut buf = [0u8; 4096];
233+
loop {
234+
let n = self.stream.read(&mut buf).await.expect("read failed");
235+
if n == 0 {
236+
panic!("connection closed before receiving complete message");
237+
}
238+
let messages = self.parser.parse(&buf[..n]);
239+
if let Some(raw_msg) = messages.into_iter().next() {
240+
let builder = hotfix_message::MessageBuilder::new(
241+
hotfix_message::dict::Dictionary::fix44(),
242+
hotfix_message::message::Config::default(),
243+
)
244+
.expect("failed to create message builder");
245+
match builder.build(raw_msg.as_bytes()) {
246+
hotfix_message::parsed_message::ParsedMessage::Valid(msg) => return msg,
247+
_ => panic!("received invalid FIX message"),
248+
}
249+
}
250+
}
251+
}
252+
253+
async fn expect_message(&mut self, expected_type: &str) -> Message {
254+
let msg = tokio::time::timeout(Duration::from_secs(2), self.read_message())
255+
.await
256+
.expect("timeout waiting for message");
257+
let msg_type: &str = msg.header().get(MSG_TYPE).expect("missing MSG_TYPE");
258+
assert_eq!(msg_type, expected_type, "unexpected message type");
259+
msg
260+
}
261+
262+
async fn send_logon(&mut self, heartbeat_interval: u64) {
263+
let logon = Logon::new(heartbeat_interval, ResetSeqNumConfig::NoReset(None));
264+
self.send_message(logon).await;
265+
}
266+
267+
async fn send_logout(&mut self) {
268+
self.send_message(Logout::default()).await;
269+
}
270+
271+
async fn send_message(&mut self, message: impl OutboundMessage) {
272+
let raw = generate_message(
273+
"FIX.4.4",
274+
&self.sender_comp_id,
275+
&self.target_comp_id,
276+
self.seq_num,
277+
message,
278+
)
279+
.expect("failed to generate message");
280+
self.seq_num += 1;
281+
self.stream
282+
.write_all(&raw)
283+
.await
284+
.expect("failed to send message");
285+
}
286+
}
287+
197288
fn create_test_config(host: &str, port: u16) -> SessionConfig {
198289
SessionConfig {
199290
begin_string: "FIX.4.4".to_string(),
@@ -212,6 +303,27 @@ mod tests {
212303
}
213304
}
214305

306+
async fn given_logged_on_initiator() -> (Initiator<DummyMessage>, TestCounterparty) {
307+
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
308+
let port = listener.local_addr().unwrap().port();
309+
let config = create_test_config("127.0.0.1", port);
310+
311+
let initiator = Initiator::start(config.clone(), NoOpApp, InMemoryMessageStore::default())
312+
.await
313+
.unwrap();
314+
315+
let mut counterparty = TestCounterparty::accept(&listener, &config).await;
316+
317+
// Complete the logon handshake
318+
counterparty.expect_message("A").await; // Receive Logon
319+
counterparty.send_logon(30).await; // Send Logon response
320+
321+
// Give the session a moment to process the logon
322+
sleep(Duration::from_millis(50)).await;
323+
324+
(initiator, counterparty)
325+
}
326+
215327
#[tokio::test]
216328
async fn test_start_creates_initiator_successfully() {
217329
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
@@ -321,4 +433,41 @@ mod tests {
321433
let result = initiator.send_forget(DummyMessage).await;
322434
assert!(result.is_ok());
323435
}
436+
437+
#[tokio::test]
438+
async fn test_session_handle_returns_working_handle() {
439+
use crate::session::error::SendOutcome;
440+
441+
let (initiator, mut counterparty) = given_logged_on_initiator().await;
442+
443+
// Get the session handle and use it to send a message
444+
let handle = initiator.session_handle();
445+
let result = handle.send(DummyMessage).await;
446+
447+
assert!(matches!(result, Ok(SendOutcome::Sent { .. })));
448+
449+
// Verify counterparty received the message (msg type "0" = Heartbeat)
450+
counterparty.expect_message("0").await;
451+
}
452+
453+
#[tokio::test]
454+
async fn test_shutdown_with_logout_handshake() {
455+
let (initiator, mut counterparty) = given_logged_on_initiator().await;
456+
457+
assert!(!initiator.is_shutdown());
458+
459+
// Spawn shutdown in background - it sends Logout and waits for response
460+
let shutdown_handle = tokio::spawn(async move { initiator.shutdown(false).await });
461+
462+
// Counterparty receives Logout and responds
463+
counterparty.expect_message("5").await; // Logout
464+
counterparty.send_logout().await;
465+
466+
// Close the TCP connection - this completes the disconnect
467+
drop(counterparty);
468+
469+
// Shutdown should complete successfully
470+
let result = shutdown_handle.await.expect("shutdown task panicked");
471+
assert!(result.is_ok(), "Shutdown should complete, got {:?}", result);
472+
}
324473
}

crates/hotfix/src/session.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,12 @@ where
363363
match self.state {
364364
// if the session is already disconnected, we have nothing else to do
365365
SessionState::Disconnected(..) => {}
366-
// otherwise set the state to disconnected and assume it makes sense to try to reconnect
366+
// if we initiated the logout, preserve the reconnect flag
367+
SessionState::AwaitingLogout { reconnect, .. } => {
368+
self.state.disconnect_writer().await;
369+
self.state = SessionState::new_disconnected(reconnect, "logout completed");
370+
}
371+
// otherwise assume it makes sense to try to reconnect
367372
_ => {
368373
self.state.disconnect_writer().await;
369374
self.state = SessionState::new_disconnected(true, "peer has logged us out")

0 commit comments

Comments
 (0)