diff --git a/quickfixj-core/src/test/java/quickfix/mina/HttpProxyServer.java b/quickfixj-core/src/test/java/quickfix/mina/HttpProxyServer.java new file mode 100644 index 0000000000..9317b1acaf --- /dev/null +++ b/quickfixj-core/src/test/java/quickfix/mina/HttpProxyServer.java @@ -0,0 +1,307 @@ +package quickfix.mina; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.EventLoop; +import io.netty.channel.MultiThreadIoEventLoopGroup; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.nio.NioIoHandler; +import io.netty.channel.socket.ServerSocketChannel; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.base64.Base64; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpObjectAggregator; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpServerCodec; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.handler.ssl.SslContext; +import io.netty.util.CharsetUtil; +import io.netty.util.NetUtil; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.DefaultThreadFactory; +import io.netty.util.internal.SocketUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.Queue; +import java.util.concurrent.LinkedBlockingDeque; + +/** + * Http proxy server implementation with basic authentication support. The implementation is modified implementation of + * Netty's HttpProxyServer samples. + * + *
+ * - only HTTP 1.1 is supported
+ * - invalid requests are not challenged with a 407 Proxy Authentication Required response
+ * 
+ * + *
+ * io.netty.handler.proxy.ProxyServer
+ * io.netty.handler.proxy.HttpProxyServer
+ * 
+ */ +public class HttpProxyServer { + + private static final Logger LOGGER = LoggerFactory.getLogger(HttpProxyServer.class); + + private final ServerSocketChannel ch; + private final Deque recordedExceptions = new LinkedBlockingDeque<>(); + private final String username; + private final String password; + private final InetSocketAddress destination; + + public HttpProxyServer(int port, InetSocketAddress destination, String username, String password) { + this(null, port, destination, username, password); + } + + public HttpProxyServer(SslContext sslContext, int port, InetSocketAddress destination, String username, String password) { + this.destination = destination; + this.username = username; + this.password = password; + + ServerBootstrap b = new ServerBootstrap(); + b.channel(NioServerSocketChannel.class); + b.group(new MultiThreadIoEventLoopGroup(3, new DefaultThreadFactory("proxy", true), NioIoHandler.newFactory())); + b.childHandler(new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel ch) { + ChannelPipeline p = ch.pipeline(); + + if (sslContext != null) { + p.addLast(sslContext.newHandler(ch.alloc())); + } + + configure(ch); + } + }); + + ch = (ServerSocketChannel) b.bind(NetUtil.LOCALHOST, port).syncUninterruptibly().channel(); + } + + public int getPort() { + return ch.localAddress().getPort(); + } + + public InetSocketAddress getDestination() { + return destination; + } + + public Deque getRecordedExceptions() { + return recordedExceptions; + } + + protected void configure(SocketChannel ch) { + ChannelPipeline p = ch.pipeline(); + + p.addLast(new HttpServerCodec()); + p.addLast(new HttpObjectAggregator(1)); + p.addLast(new HttpIntermediaryHandler()); + } + + @SuppressWarnings("BooleanMethodIsAlwaysInverted") + private boolean authenticate(ChannelHandlerContext ctx, FullHttpRequest req) { + if (!req.method().equals(HttpMethod.CONNECT)) { + throw new IllegalArgumentException("Only HTTP CONNECT method is supported"); + } + + ctx.pipeline().remove(HttpObjectAggregator.class); + ctx.pipeline().get(HttpServerCodec.class).removeInboundHandler(); + + boolean authzSuccess = false; + if (username != null) { + CharSequence authz = req.headers().get(HttpHeaderNames.PROXY_AUTHORIZATION); + if (authz != null) { + String[] authzParts = authz.toString().split(" ", 2); + ByteBuf authzBuf64 = Unpooled.copiedBuffer(authzParts[1], CharsetUtil.US_ASCII); + ByteBuf authzBuf = Base64.decode(authzBuf64); + + String expectedAuthz = username + ':' + password; + authzSuccess = "Basic".equals(authzParts[0]) && + expectedAuthz.equals(authzBuf.toString(CharsetUtil.US_ASCII)); + + authzBuf64.release(); + authzBuf.release(); + } + } else { + authzSuccess = true; + } + + return authzSuccess; + } + + private void recordException(Throwable t) { + LOGGER.warn("Unexpected exception from proxy server", t); + recordedExceptions.add(t); + } + + public void stop() { + ch.close(); + } + + protected abstract class IntermediaryHandler extends SimpleChannelInboundHandler { + + private final Queue received = new ArrayDeque<>(); + + private boolean finished; + private Channel backend; + + @Override + protected final void channelRead0(final ChannelHandlerContext ctx, Object msg) throws Exception { + if (finished) { + received.add(ReferenceCountUtil.retain(msg)); + flush(); + return; + } + + boolean finished = handleProxyProtocol(ctx, msg); + if (finished) { + this.finished = true; + ChannelFuture f = connectToDestination(ctx.channel().eventLoop(), new BackendHandler(ctx)); + f.addListener((ChannelFutureListener) future -> { + if (!future.isSuccess()) { + recordException(future.cause()); + ctx.close(); + } else { + backend = future.channel(); + flush(); + } + }); + } + } + + private void flush() { + if (backend != null) { + boolean wrote = false; + for (; ; ) { + Object msg = received.poll(); + if (msg == null) { + break; + } + backend.write(msg); + wrote = true; + } + + if (wrote) { + backend.flush(); + } + } + } + + protected abstract boolean handleProxyProtocol(ChannelHandlerContext ctx, Object msg) throws Exception; + + protected abstract SocketAddress intermediaryDestination(); + + private ChannelFuture connectToDestination(EventLoop loop, ChannelHandler handler) { + Bootstrap b = new Bootstrap(); + b.channel(NioSocketChannel.class); + b.group(loop); + b.handler(handler); + return b.connect(intermediaryDestination()); + } + + @Override + public final void channelReadComplete(ChannelHandlerContext ctx) { + ctx.flush(); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + if (backend != null) { + backend.close(); + } + } + + @Override + public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + recordException(cause); + ctx.close(); + } + + private final class BackendHandler extends ChannelInboundHandlerAdapter { + + private final ChannelHandlerContext frontend; + + BackendHandler(ChannelHandlerContext frontend) { + this.frontend = frontend; + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + frontend.write(msg); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) { + frontend.flush(); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + frontend.close(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + recordException(cause); + ctx.close(); + } + } + } + + private final class HttpIntermediaryHandler extends IntermediaryHandler { + + private SocketAddress intermediaryDestination; + + @Override + protected boolean handleProxyProtocol(ChannelHandlerContext ctx, Object msg) { + FullHttpRequest req = (FullHttpRequest) msg; + FullHttpResponse res; + if (!authenticate(ctx, req)) { + res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.UNAUTHORIZED); + res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0); + } else { + res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + String uri = req.uri(); + int lastColonPos = uri.lastIndexOf(':'); + + if (lastColonPos <= 0) { + throw new IllegalArgumentException("Invalid URI: " + uri); + } + + intermediaryDestination = SocketUtils.socketAddress( + uri.substring(0, lastColonPos), Integer.parseInt(uri.substring(lastColonPos + 1))); + } + + System.out.println("Responding to proxy request with: " + res); + + ctx.write(res); + ctx.pipeline().get(HttpServerCodec.class).removeOutboundHandler(); + return true; + } + + @Override + protected SocketAddress intermediaryDestination() { + return intermediaryDestination; + } + } +} \ No newline at end of file diff --git a/quickfixj-core/src/test/java/quickfix/mina/HttpProxyTest.java b/quickfixj-core/src/test/java/quickfix/mina/HttpProxyTest.java new file mode 100644 index 0000000000..96884f92a7 --- /dev/null +++ b/quickfixj-core/src/test/java/quickfix/mina/HttpProxyTest.java @@ -0,0 +1,278 @@ +package quickfix.mina; + +import org.apache.mina.util.AvailablePortFinder; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import quickfix.Acceptor; +import quickfix.ApplicationAdapter; +import quickfix.ConfigError; +import quickfix.DefaultMessageFactory; +import quickfix.FixVersions; +import quickfix.Initiator; +import quickfix.MemoryStoreFactory; +import quickfix.MessageFactory; +import quickfix.MessageStoreFactory; +import quickfix.Session; +import quickfix.SessionFactory; +import quickfix.SessionID; +import quickfix.SessionSettings; +import quickfix.ThreadedSocketAcceptor; +import quickfix.ThreadedSocketInitiator; +import quickfix.mina.ssl.SSLSupport; +import quickfix.test.util.SSLUtil; +import quickfix.test.util.SessionUtil; + +import java.math.BigInteger; +import java.net.InetSocketAddress; +import java.util.HashMap; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +/** + * Performs end-to-end tests against HTTP proxy server. + */ +public class HttpProxyTest { + + private static final String PROXY_USERNAME = "proxy-username"; + private static final String PROXY_PASSWORD = "proxy-password"; + + private HttpProxyServer proxyServer; + + @Before + public void setUp() { + InetSocketAddress destination = new InetSocketAddress("127.0.0.1", AvailablePortFinder.getNextAvailable()); + proxyServer = new HttpProxyServer(AvailablePortFinder.getNextAvailable(), destination, PROXY_USERNAME, PROXY_PASSWORD); + } + + @After + public void tearDown() { + if (proxyServer != null) { + proxyServer.stop(); + } + } + + @Test + public void shouldLoginBasicAuth() throws ConfigError { + int port = proxyServer.getDestination().getPort(); + SessionConnector acceptor = createAcceptor(port, false, null); + + try { + acceptor.start(); + + SessionConnector initiator = createInitiator(proxyServer.getPort(), port, PROXY_USERNAME, PROXY_PASSWORD, false, null); + + try { + initiator.start(); + SessionUtil.assertLoggedOn(acceptor, new SessionID(FixVersions.BEGINSTRING_FIX44, "ALICE", "BOB")); + SessionUtil.assertLoggedOn(initiator, new SessionID(FixVersions.BEGINSTRING_FIX44, "BOB", "ALICE")); + SSLUtil.assertNotAuthenticated(acceptor, new SessionID(FixVersions.BEGINSTRING_FIX44, "ALICE", "BOB"), false); + SSLUtil.assertNotAuthenticated(initiator, new SessionID(FixVersions.BEGINSTRING_FIX44, "BOB", "ALICE"), false); + assertTrue(proxyServer.getRecordedExceptions().isEmpty()); + } finally { + initiator.stop(); + } + } finally { + acceptor.stop(); + } + } + + @Test + public void shouldLoginBasicAuthWithSsl() throws ConfigError { + int port = proxyServer.getDestination().getPort(); + SessionConnector acceptor = createAcceptor(port, true, "single-session/server.keystore"); + + try { + acceptor.start(); + + SessionConnector initiator = createInitiator(proxyServer.getPort(), port, PROXY_USERNAME, PROXY_PASSWORD, true, "single-session/client.truststore"); + + try { + initiator.start(); + SessionUtil.assertLoggedOn(acceptor, new SessionID(FixVersions.BEGINSTRING_FIX44, "ALICE", "BOB")); + SessionUtil.assertLoggedOn(initiator, new SessionID(FixVersions.BEGINSTRING_FIX44, "BOB", "ALICE")); + SSLUtil.assertNotAuthenticated(acceptor, new SessionID(FixVersions.BEGINSTRING_FIX44, "ALICE", "BOB"), false); + SSLUtil.assertAuthenticated(initiator, new SessionID(FixVersions.BEGINSTRING_FIX44, "BOB", "ALICE"), new BigInteger("1448538842")); + assertTrue(proxyServer.getRecordedExceptions().isEmpty()); + } finally { + initiator.stop(); + } + } finally { + acceptor.stop(); + } + } + + @Test + public void shouldFailLoginBasicAuthWhenServerIsUntrusted() throws ConfigError { + int port = proxyServer.getDestination().getPort(); + SessionConnector acceptor = createAcceptor(port, true, "single-session/server.keystore"); + + try { + acceptor.start(); + + SessionConnector initiator = createInitiator(proxyServer.getPort(), port, PROXY_USERNAME, PROXY_PASSWORD, true, "single-session/empty.keystore"); + + try { + initiator.start(); + SessionUtil.assertNotLoggedOn(acceptor, new SessionID(FixVersions.BEGINSTRING_FIX44, "ALICE", "BOB")); + SessionUtil.assertNotLoggedOn(initiator, new SessionID(FixVersions.BEGINSTRING_FIX44, "BOB", "ALICE")); + SSLUtil.assertNotAuthenticated(acceptor, new SessionID(FixVersions.BEGINSTRING_FIX44, "ALICE", "BOB"), true); + SSLUtil.assertNotAuthenticated(initiator, new SessionID(FixVersions.BEGINSTRING_FIX44, "BOB", "ALICE"), true); + assertTrue(proxyServer.getRecordedExceptions().isEmpty()); + } finally { + initiator.stop(); + } + } finally { + acceptor.stop(); + } + } + + @Test + public void shouldFailBasicAuthWhenInvalidCredentials() throws ConfigError { + int port = proxyServer.getDestination().getPort(); + SessionConnector acceptor = createAcceptor(port, false, null); + + try { + acceptor.start(); + + SessionConnector initiator = createInitiator(proxyServer.getPort(), port, "a", "b", false, null); + + try { + initiator.start(); + SessionUtil.assertNotLoggedOn(acceptor, new SessionID(FixVersions.BEGINSTRING_FIX44, "ALICE", "BOB")); + SessionUtil.assertNotLoggedOn(initiator, new SessionID(FixVersions.BEGINSTRING_FIX44, "BOB", "ALICE")); + SSLUtil.assertNotAuthenticated(acceptor, new SessionID(FixVersions.BEGINSTRING_FIX44, "ALICE", "BOB"), false); + SSLUtil.assertNotAuthenticated(initiator, new SessionID(FixVersions.BEGINSTRING_FIX44, "BOB", "ALICE"), false); + assertFalse(proxyServer.getRecordedExceptions().isEmpty()); + } finally { + initiator.stop(); + } + } finally { + acceptor.stop(); + } + } + + @Test + public void shouldFailBasicAuthWhenInvalidCredentialsWithSsl() throws ConfigError { + int port = proxyServer.getDestination().getPort(); + SessionConnector acceptor = createAcceptor(port, true, "single-session/server.keystore"); + + try { + acceptor.start(); + + SessionConnector initiator = createInitiator(proxyServer.getPort(), port, "a", "b", true, "single-session/client.truststore"); + + try { + initiator.start(); + SessionUtil.assertNotLoggedOn(acceptor, new SessionID(FixVersions.BEGINSTRING_FIX44, "ALICE", "BOB")); + SessionUtil.assertNotLoggedOn(initiator, new SessionID(FixVersions.BEGINSTRING_FIX44, "BOB", "ALICE")); + SSLUtil.assertNotAuthenticated(acceptor, new SessionID(FixVersions.BEGINSTRING_FIX44, "ALICE", "BOB"), true); + SSLUtil.assertNotAuthenticated(initiator, new SessionID(FixVersions.BEGINSTRING_FIX44, "BOB", "ALICE"), true); + assertFalse(proxyServer.getRecordedExceptions().isEmpty()); + } finally { + initiator.stop(); + } + } finally { + acceptor.stop(); + } + } + + private SessionConnector createAcceptor(int port, boolean useSsl, String keyStoreName) throws ConfigError { + MessageStoreFactory messageStoreFactory = new MemoryStoreFactory(); + MessageFactory messageFactory = new DefaultMessageFactory(); + SessionSettings acceptorSettings = createAcceptorSettings("ALICE", "BOB", port, useSsl, keyStoreName); + return new ThreadedSocketAcceptor(new ApplicationAdapter(), messageStoreFactory, acceptorSettings, messageFactory); + } + + private SessionConnector createInitiator(int proxyPort, int port, String proxyUsername, String proxyPassword, + boolean useSsl, String trustStoreName) throws ConfigError { + MessageStoreFactory messageStoreFactory = new MemoryStoreFactory(); + MessageFactory messageFactory = new DefaultMessageFactory(); + SessionSettings initiatorSettings = createInitiatorSettings("BOB", "ALICE", proxyPort, port, proxyUsername, proxyPassword, useSsl, trustStoreName); + return new ThreadedSocketInitiator(new ApplicationAdapter(), messageStoreFactory, initiatorSettings, messageFactory); + } + + private SessionSettings createAcceptorSettings(String senderId, String targetId, int port, boolean useSsl, String keyStoreName) { + HashMap defaults = new HashMap<>(); + defaults.put(SessionFactory.SETTING_CONNECTION_TYPE, "acceptor"); + defaults.put(Acceptor.SETTING_SOCKET_ACCEPT_PORT, Integer.toString(port)); + defaults.put(Session.SETTING_START_TIME, "00:00:00"); + defaults.put(Session.SETTING_END_TIME, "00:00:00"); + defaults.put(Session.SETTING_HEARTBTINT, "30"); + + if (useSsl) { + defaults.put(SSLSupport.SETTING_USE_SSL, "Y"); + } else { + defaults.put(SSLSupport.SETTING_USE_SSL, "N"); + } + + if (keyStoreName != null) { + defaults.put(SSLSupport.SETTING_KEY_STORE_NAME, keyStoreName); + defaults.put(SSLSupport.SETTING_KEY_STORE_PWD, "password"); + defaults.put(SSLSupport.SETTING_KEY_STORE_TYPE, "JCEKS"); + + } + + SessionID sessionID = new SessionID(FixVersions.BEGINSTRING_FIX44, senderId, targetId); + + SessionSettings sessionSettings = new SessionSettings(); + sessionSettings.set(defaults); + sessionSettings.setString(sessionID, "BeginString", FixVersions.BEGINSTRING_FIX44); + sessionSettings.setString(sessionID, "DataDictionary", "FIX44.xml"); + sessionSettings.setString(sessionID, "SenderCompID", senderId); + sessionSettings.setString(sessionID, "TargetCompID", targetId); + + return sessionSettings; + } + + private SessionSettings createInitiatorSettings(String senderId, String targetId, int proxyPort, int port, + String proxyUsername, String proxyPassword, + boolean useSsl, String trustStoreName) { + HashMap defaults = new HashMap<>(); + defaults.put(SessionFactory.SETTING_CONNECTION_TYPE, "initiator"); + defaults.put(Initiator.SETTING_SOCKET_CONNECT_PROTOCOL, ProtocolFactory.getTypeString(ProtocolFactory.SOCKET)); + defaults.put(Initiator.SETTING_SOCKET_CONNECT_HOST, "localhost"); + defaults.put(Initiator.SETTING_SOCKET_CONNECT_PORT, Integer.toString(port)); + defaults.put(Initiator.SETTING_RECONNECT_INTERVAL, "2"); + defaults.put(Initiator.SETTING_PROXY_HOST, "localhost"); + defaults.put(Initiator.SETTING_PROXY_PORT, Integer.toString(proxyPort)); + defaults.put(Initiator.SETTING_PROXY_TYPE, "http"); + defaults.put(Initiator.SETTING_PROXY_VERSION, "1.1"); + + if (proxyUsername != null) { + defaults.put(Initiator.SETTING_PROXY_USER, proxyUsername); + } + + if (proxyPassword != null) { + defaults.put(Initiator.SETTING_PROXY_PASSWORD, proxyPassword); + } + + if (useSsl) { + defaults.put(SSLSupport.SETTING_USE_SSL, "Y"); + } else { + defaults.put(SSLSupport.SETTING_USE_SSL, "N"); + } + + if (trustStoreName != null) { + defaults.put(SSLSupport.SETTING_TRUST_STORE_NAME, trustStoreName); + defaults.put(SSLSupport.SETTING_TRUST_STORE_PWD, "password"); + defaults.put(SSLSupport.SETTING_TRUST_STORE_TYPE, "JCEKS"); + } + + defaults.put(Session.SETTING_START_TIME, "00:00:00"); + defaults.put(Session.SETTING_END_TIME, "00:00:00"); + defaults.put(Session.SETTING_HEARTBTINT, "30"); + + SessionID sessionID = new SessionID(FixVersions.BEGINSTRING_FIX44, senderId, targetId); + + SessionSettings sessionSettings = new SessionSettings(); + sessionSettings.set(defaults); + sessionSettings.setString(sessionID, "BeginString", FixVersions.BEGINSTRING_FIX44); + sessionSettings.setString(sessionID, "DataDictionary", "FIX44.xml"); + sessionSettings.setString(sessionID, "SenderCompID", senderId); + sessionSettings.setString(sessionID, "TargetCompID", targetId); + + return sessionSettings; + } +} diff --git a/quickfixj-core/src/test/java/quickfix/mina/SocksProxyTest.java b/quickfixj-core/src/test/java/quickfix/mina/SocksProxyTest.java index 1403da9bd7..88459bb599 100644 --- a/quickfixj-core/src/test/java/quickfix/mina/SocksProxyTest.java +++ b/quickfixj-core/src/test/java/quickfix/mina/SocksProxyTest.java @@ -24,7 +24,7 @@ import java.util.concurrent.TimeUnit; /** - * Performs end to end tests against SOCKS proxy server. + * Performs end-to-end tests against SOCKS proxy server. */ public class SocksProxyTest { @@ -43,7 +43,9 @@ public void setUp() { @After public void tearDown() { - proxyServer.stop(); + if (proxyServer != null) { + proxyServer.stop(); + } } @Test @@ -173,5 +175,4 @@ private SessionSettings createInitiatorSettings(String senderId, String targetId return sessionSettings; } - } diff --git a/quickfixj-core/src/test/java/quickfix/mina/ssl/SSLCertificateTest.java b/quickfixj-core/src/test/java/quickfix/mina/ssl/SSLCertificateTest.java index ee5e13fb93..9df4907f7e 100644 --- a/quickfixj-core/src/test/java/quickfix/mina/ssl/SSLCertificateTest.java +++ b/quickfixj-core/src/test/java/quickfix/mina/ssl/SSLCertificateTest.java @@ -21,9 +21,11 @@ import org.apache.mina.core.filterchain.IoFilterAdapter; import org.apache.mina.core.session.IoSession; +import org.apache.mina.util.AvailablePortFinder; import org.burningwave.tools.net.DefaultHostResolver; import org.burningwave.tools.net.HostResolutionRequestInterceptor; import org.burningwave.tools.net.MappedHostResolver; +import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -49,6 +51,9 @@ import quickfix.ThreadedSocketInitiator; import quickfix.mina.ProtocolFactory; import quickfix.mina.SessionConnector; +import quickfix.mina.SocksProxyServer; +import quickfix.test.util.SSLUtil; +import quickfix.test.util.SessionUtil; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngineResult; @@ -57,7 +62,6 @@ import java.math.BigInteger; import java.security.Principal; import java.security.cert.Certificate; -import java.security.cert.X509Certificate; import java.util.Arrays; import java.util.HashMap; import java.util.List; @@ -66,10 +70,6 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; -import org.apache.mina.util.AvailablePortFinder; -import org.junit.After; -import quickfix.mina.SocksProxyServer; -import quickfix.test.util.SSLUtil; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertSame; @@ -1036,121 +1036,24 @@ public void exceptionCaught(NextFilter nextFilter, IoSession session, Throwable public abstract SessionConnector createConnector(SessionSettings sessionSettings) throws ConfigError; - private Session findSession(SessionID sessionID) { - for (Session session : connector.getManagedSessions()) { - if (session.getSessionID().equals(sessionID)) - return session; - } - - return null; - } - public void assertAuthenticated(SessionID sessionID, BigInteger serialNumber) { - Session session = findSession(sessionID); - SSLSession sslSession = SSLUtil.findSSLSession(session); - - if (sslSession == null) { - throw new AssertionError("No SSL session found: " + sessionID); - } - - Certificate[] peerCertificates = SSLUtil.getPeerCertificates(sslSession); - - if (peerCertificates == null || peerCertificates.length == 0) { - throw new AssertionError("Session was not authenticated: " + sslSession); - } - - for (Certificate peerCertificate : peerCertificates) { - if (!(peerCertificate instanceof X509Certificate)) { - continue; - } - - if (((X509Certificate)peerCertificate).getSerialNumber().compareTo(serialNumber) == 0) { - return; - } - } - - throw new AssertionError("Certificate with serial number " + serialNumber + " was not authenticated"); + SSLUtil.assertAuthenticated(connector, sessionID, serialNumber); } public void assertNotAuthenticated(SessionID sessionID) { assertNotAuthenticated(sessionID, true); } - /** - * Asserts that the session associated with the given {@code sessionID} is not authenticated. - * The behavior of this method depends on the {@code authOn} parameter: - * - *
    - *
  • If {@code authOn} is {@code true}, the method checks if the SSL session associated - * with the given session ID is still alive. If the SSL session persists beyond the - * specified timeout period, an {@link AssertionError} is thrown.
  • - *
  • If {@code authOn} is {@code false}, the method checks if there are any peer certificates - * associated with the SSL session. If peer certificates are found, an {@link AssertionError} - * is thrown, indicating that the session was authenticated.
  • - *
- * - * @param sessionID the session ID to check for authentication status - * @param authOn a flag indicating whether authentication is currently enabled - * @throws AssertionError if the session is still authenticated after the timeout period - * (when {@code authOn} is {@code true}) or if peer certificates are found - * (when {@code authOn} is {@code false}) - */ public void assertNotAuthenticated(SessionID sessionID, boolean authOn) { - Session session = findSession(sessionID); - SSLSession sslSession = SSLUtil.findSSLSession(session); - - if (sslSession == null) { - return; - } - - if (authOn) { - // when authentication is on, the SSL session maybe still be alive (invalid) for some time - long startTime = System.nanoTime(); - - while (SSLUtil.findSSLSession(session) != null) { - try { - Thread.sleep(100); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException("Thread was interrupted", e); - } - - if (TimeUnit.NANOSECONDS.toSeconds(System.nanoTime() - startTime) >= TIMEOUT_SECONDS) { - throw new AssertionError("SSL session still exists for session: " + sessionID); - } - } - } else { - // when authentication is off, there must be no peer certificates - Certificate[] peerCertificates = SSLUtil.getPeerCertificates(sslSession); - - if (peerCertificates != null && peerCertificates.length > 0) { - throw new AssertionError("Certificate was authenticated"); - } - } + SSLUtil.assertNotAuthenticated(connector, sessionID, authOn); } public void assertLoggedOn(SessionID sessionID) { - Session session = findSession(sessionID); - - if (session == null) { - throw new AssertionError("No session found: " + sessionID); - } - - if (!session.isLoggedOn()) { - throw new AssertionError("Session is not logged on: " + session); - } + SessionUtil.assertLoggedOn(connector, sessionID); } public void assertNotLoggedOn(SessionID sessionID) { - Session session = findSession(sessionID); - - if (session == null) { - throw new AssertionError("No session found: " + sessionID); - } - - if (session.isLoggedOn()) { - throw new AssertionError("Session is logged on: " + session); - } + SessionUtil.assertNotLoggedOn(connector, sessionID); } public void assertSslExceptionThrown() throws Exception { @@ -1186,7 +1089,7 @@ public void assertNoSslExceptionThrown() throws Exception { } public void assertSNIHostName(SessionID sessionID, String expectedSniHostName) { - Session session = findSession(sessionID); + Session session = SessionUtil.findSession(connector, sessionID); SSLSession sslSession = SSLUtil.findSSLSession(session); if (sslSession == null) { @@ -1203,7 +1106,7 @@ public void assertSNIHostName(SessionID sessionID, String expectedSniHostName) { } public void assertNoSNIHostName(SessionID sessionID) { - Session session = findSession(sessionID); + Session session = SessionUtil.findSession(connector, sessionID); SSLSession sslSession = SSLUtil.findSSLSession(session); if (sslSession == null) { @@ -1236,7 +1139,7 @@ private void logSSLInfo() { LOGGER.info("All session IDs: {}", sessionsIDs); for (SessionID sessionID : sessionsIDs) { - Session session = findSession(sessionID); + Session session = SessionUtil.findSession(connector, sessionID); if (session == null) { LOGGER.info("No session found for ID: {}", sessionID); diff --git a/quickfixj-core/src/test/java/quickfix/test/util/SSLUtil.java b/quickfixj-core/src/test/java/quickfix/test/util/SSLUtil.java index bbc59537c6..093c60e4cc 100644 --- a/quickfixj-core/src/test/java/quickfix/test/util/SSLUtil.java +++ b/quickfixj-core/src/test/java/quickfix/test/util/SSLUtil.java @@ -6,7 +6,9 @@ import org.apache.mina.filter.ssl.SslFilter; import org.apache.mina.filter.ssl.SslHandler; import quickfix.Session; +import quickfix.SessionID; import quickfix.mina.IoSessionResponder; +import quickfix.mina.SessionConnector; import quickfix.mina.ssl.SSLSupport; import javax.net.ssl.ExtendedSSLSession; @@ -15,17 +17,19 @@ import javax.net.ssl.SSLPeerUnverifiedException; import javax.net.ssl.SSLSession; import java.lang.reflect.Field; +import java.math.BigInteger; import java.net.IDN; -import java.net.InetAddress; -import java.net.UnknownHostException; import java.nio.charset.StandardCharsets; import java.security.Principal; import java.security.cert.Certificate; +import java.security.cert.X509Certificate; import java.util.List; +import java.util.concurrent.TimeUnit; /** - * A utility class for working with SSL/TLS sessions and retrieving SSL-related information from a {@link Session}. - * This class provides methods to find the underlying {@link SSLSession}, retrieve peer certificates, and get the peer principal etc. + * A utility class for working with SSL/TLS sessions and retrieving SSL-related information from a {@link Session}. This + * class provides methods to find the underlying {@link SSLSession}, retrieve peer certificates, and get the peer + * principal etc. */ public final class SSLUtil { @@ -35,6 +39,7 @@ public final class SSLUtil { private static final Field IO_SESSION_FIELD; private static final Field SSL_ENGINE_FIELD; private static final int SNI_HOST_NAME_TYPE = 0; + private static final long DEFAULT_TIMEOUT_SECONDS = 5L; static { try { @@ -99,14 +104,14 @@ public static IoSession findIoSession(Session session) { } /** - * Retrieves the {@link SslHandler} associated with the given {@link Session}. - * This method first finds the corresponding {@link IoSession} for the provided session, - * then retrieves the {@link SslFilter} from the session's filter chain. - * If the filter is found, it returns the {@link SslHandler} stored as an attribute in the {@link IoSession}. + * Retrieves the {@link SslHandler} associated with the given {@link Session}. This method first finds the + * corresponding {@link IoSession} for the provided session, then retrieves the {@link SslFilter} from the session's + * filter chain. If the filter is found, it returns the {@link SslHandler} stored as an attribute in the + * {@link IoSession}. * * @param session The session for which to retrieve the {@link SslHandler}. - * @return The {@link SslHandler} associated with the session, or {@code null} if either - * the {@link IoSession} or the {@link SslFilter} is not found. + * @return The {@link SslHandler} associated with the session, or {@code null} if either the {@link IoSession} or + * the {@link SslFilter} is not found. */ public static SslHandler getSSLHandler(Session session) { IoSession ioSession = findIoSession(session); @@ -126,14 +131,13 @@ public static SslHandler getSSLHandler(Session session) { } /** - * Retrieves the {@link SSLEngine} associated with the given {@link Session}. - * This method first retrieves the {@link SslHandler} using {@link #getSSLHandler(Session)}, - * and then attempts to access the {@link SSLEngine} stored within the {@link SslHandler} - * using reflection. + * Retrieves the {@link SSLEngine} associated with the given {@link Session}. This method first retrieves the + * {@link SslHandler} using {@link #getSSLHandler(Session)}, and then attempts to access the {@link SSLEngine} + * stored within the {@link SslHandler} using reflection. * * @param session The session for which to retrieve the {@link SSLEngine}. - * @return The {@link SSLEngine} associated with the session, or {@code null} if the - * {@link SslHandler} is not found. + * @return The {@link SSLEngine} associated with the session, or {@code null} if the {@link SslHandler} is not + * found. */ public static SSLEngine getSSLEngine(Session session) { SslHandler sslHandler = getSSLHandler(session); @@ -153,7 +157,8 @@ public static SSLEngine getSSLEngine(Session session) { * Retrieves the peer certificates from the given {@link SSLSession}. * * @param sslSession the SSL session from which to retrieve the peer certificates. - * @return an array of {@link Certificate} objects representing the peer certificates, or {@code null} if the peer is unverified. + * @return an array of {@link Certificate} objects representing the peer certificates, or {@code null} if the peer + * is unverified. */ public static Certificate[] getPeerCertificates(SSLSession sslSession) { try { @@ -200,4 +205,171 @@ public static String getSniHostName(SSLSession sslSession) { return null; } + + /** + * Checks if the session associated with the given {@code sessionID} is authenticated with a certificate matching + * the specified serial number. + * + * @param connector the {@link SessionConnector} used to retrieve the session information + * @param sessionID the session ID to check for authentication status + * @param certificateSerialNumber the expected serial number of the peer certificate + * @return {@code true} if the session is authenticated with a certificate matching the specified serial number, + * {@code false} otherwise + */ + public static boolean isAuthenticated(SessionConnector connector, SessionID sessionID, BigInteger certificateSerialNumber) { + Session session = SessionUtil.findSession(connector, sessionID); + SSLSession sslSession = SSLUtil.findSSLSession(session); + + if (sslSession == null) { + return false; + } + + Certificate[] peerCertificates = SSLUtil.getPeerCertificates(sslSession); + + if (peerCertificates == null || peerCertificates.length == 0) { + return false; + } + + for (Certificate peerCertificate : peerCertificates) { + if (!(peerCertificate instanceof X509Certificate)) { + continue; + } + + if (((X509Certificate) peerCertificate).getSerialNumber().compareTo(certificateSerialNumber) == 0) { + return true; + } + } + + return false; + } + + /** + * Asserts that the session associated with the given {@code sessionID} is authenticated with a certificate matching + * the specified serial number. + * + * @param connector the {@link SessionConnector} used to retrieve the session information + * @param sessionID the session ID to check for authentication status + * @param serialNumber the expected serial number of the peer certificate + * @throws AssertionError if the certificate with the specified serial number is not authenticated within the + * default timeout period + */ + public static void assertAuthenticated(SessionConnector connector, SessionID sessionID, BigInteger serialNumber) { + assertAuthenticated(connector, sessionID, serialNumber, DEFAULT_TIMEOUT_SECONDS, TimeUnit.SECONDS); + } + + /** + * Asserts that the session associated with the given {@code sessionID} is authenticated with a certificate matching + * the specified serial number. This method polls the session until either the certificate is authenticated or the + * specified timeout is reached. + * + * @param connector the {@link SessionConnector} used to retrieve the session information + * @param sessionID the session ID to check for authentication status + * @param serialNumber the expected serial number of the peer certificate + * @param timeout the maximum time to wait for authentication + * @param unit the time unit of the {@code timeout} parameter + * @throws AssertionError if the certificate with the specified serial number is not authenticated within the + * specified timeout period + */ + public static void assertAuthenticated( + SessionConnector connector, SessionID sessionID, BigInteger serialNumber, + long timeout, TimeUnit unit) { + long deadlineNs = System.nanoTime() + unit.toNanos(timeout); + + while (System.nanoTime() < deadlineNs) { + if (isAuthenticated(connector, sessionID, serialNumber)) { + return; + } + + try { + Thread.sleep(100); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Thread was interrupted", e); + } + } + + throw new AssertionError("Certificate with serial number " + serialNumber + " was not authenticated"); + } + + /** + * Asserts that the session associated with the given {@code sessionID} is not authenticated. The behavior of this + * method depends on the {@code authOn} parameter: + * + *
    + *
  • If {@code authOn} is {@code true}, the method checks if the SSL session associated + * with the given session ID is still alive. If the SSL session persists beyond the + * specified timeout period, an {@link AssertionError} is thrown.
  • + *
  • If {@code authOn} is {@code false}, the method checks if there are any peer certificates + * associated with the SSL session. If peer certificates are found, an {@link AssertionError} + * is thrown, indicating that the session was authenticated.
  • + *
+ * + * @param connector the {@link SessionConnector} used to retrieve the session information + * @param sessionID the session ID to check for authentication status + * @param authOn a flag indicating whether authentication is currently enabled + * @throws AssertionError if the session is still authenticated after the timeout period (when {@code authOn} is + * {@code true}) or if peer certificates are found (when {@code authOn} is {@code false}) + */ + public static void assertNotAuthenticated(SessionConnector connector, SessionID sessionID, boolean authOn) { + assertNotAuthenticated(connector, sessionID, authOn, DEFAULT_TIMEOUT_SECONDS, TimeUnit.SECONDS); + } + + /** + * Asserts that the session associated with the given {@code sessionID} is not authenticated. The behavior of this + * method depends on the {@code authOn} parameter: + * + *
    + *
  • If {@code authOn} is {@code true}, the method checks if the SSL session associated + * with the given session ID is still alive. If the SSL session persists beyond the + * specified timeout period, an {@link AssertionError} is thrown.
  • + *
  • If {@code authOn} is {@code false}, the method checks if there are any peer certificates + * associated with the SSL session. If peer certificates are found, an {@link AssertionError} + * is thrown, indicating that the session was authenticated.
  • + *
+ * + * @param connector the {@link SessionConnector} used to retrieve the session information + * @param sessionID the session ID to check for authentication status + * @param authOn a flag indicating whether authentication is currently enabled + * @param timeout the maximum time to wait for the session to become unauthenticated + * @param unit the time unit of the {@code timeout} parameter + * @throws AssertionError if the session is still authenticated after the timeout period (when {@code authOn} is + * {@code true}) or if peer certificates are found (when {@code authOn} is {@code false}) + */ + public static void assertNotAuthenticated( + SessionConnector connector, SessionID sessionID, boolean authOn, + long timeout, TimeUnit unit) { + Session session = SessionUtil.findSession(connector, sessionID); + SSLSession sslSession = findSSLSession(session); + + if (sslSession == null) { + return; + } + + if (authOn) { + long deadlineNs = System.nanoTime() + unit.toNanos(timeout); + + // when authentication is on, the SSL session maybe still be alive (invalid) for some time + while (System.nanoTime() < deadlineNs) { + if (findSSLSession(session) == null) { + return; + } + + try { + Thread.sleep(100); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Thread was interrupted", e); + } + } + + throw new AssertionError("SSL session still exists for session: " + sessionID); + } else { + // when authentication is off, there must be no peer certificates + Certificate[] peerCertificates = getPeerCertificates(sslSession); + + if (peerCertificates != null && peerCertificates.length > 0) { + throw new AssertionError("Certificate was authenticated"); + } + } + } } diff --git a/quickfixj-core/src/test/java/quickfix/test/util/SessionUtil.java b/quickfixj-core/src/test/java/quickfix/test/util/SessionUtil.java new file mode 100644 index 0000000000..33ac64c49c --- /dev/null +++ b/quickfixj-core/src/test/java/quickfix/test/util/SessionUtil.java @@ -0,0 +1,133 @@ +package quickfix.test.util; + +import quickfix.Session; +import quickfix.SessionID; +import quickfix.mina.SessionConnector; + +import java.util.List; +import java.util.concurrent.TimeUnit; + +/** + * Utility class for managing and asserting FIX session states in tests. + */ +public final class SessionUtil { + + private static final long DEFAULT_TIMEOUT_SECONDS = 5L; + + private SessionUtil() { + } + + /** + * Finds a FIX session by its session identifier. + * + * @param connector the session connector that manages sessions + * @param sessionID the target session identifier + * @return the matching {@link Session}, or {@code null} if no managed session matches + */ + public static Session findSession(SessionConnector connector, SessionID sessionID) { + List managedSessions = connector.getManagedSessions(); + + for (Session session : managedSessions) { + if (session.getSessionID().equals(sessionID)) { + return session; + } + } + + return null; + } + + /** + * Checks if a session is currently logged on. + * + * @param connector the session connector + * @param sessionID the session identifier + * @return true if the session is logged on, false otherwise + */ + public static boolean isLoggedOn(SessionConnector connector, SessionID sessionID) { + Session session = findSession(connector, sessionID); + + if (session == null) { + return false; + } + + return session.isLoggedOn(); + } + + /** + * Asserts that a session is logged on within the default timeout period. + * + * @param connector the session connector + * @param sessionID the session identifier + * @throws AssertionError if the session is not logged on within the timeout + */ + public static void assertLoggedOn(SessionConnector connector, SessionID sessionID) { + assertLoggedOn(connector, sessionID, DEFAULT_TIMEOUT_SECONDS, TimeUnit.SECONDS); + } + + /** + * Asserts that a session is logged on within the specified timeout period. + * + * @param connector the session connector + * @param sessionID the session identifier + * @param timeout the timeout duration + * @param unit the timeout unit + * @throws AssertionError if the session is not logged on within the timeout + * @throws RuntimeException if interrupted while waiting + */ + public static void assertLoggedOn(SessionConnector connector, SessionID sessionID, long timeout, TimeUnit unit) { + long deadlineNs = System.nanoTime() + unit.toNanos(timeout); + + while (System.nanoTime() < deadlineNs) { + if (isLoggedOn(connector, sessionID)) { + return; + } + + try { + Thread.sleep(100); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted", e); + } + } + + throw new AssertionError("Session " + sessionID + " is not logged on"); + } + + /** + * Asserts that a session is not logged on within the default timeout period. + * + * @param connector the session connector + * @param sessionID the session identifier + * @throws AssertionError if the session is logged on within the timeout + */ + public static void assertNotLoggedOn(SessionConnector connector, SessionID sessionID) { + assertNotLoggedOn(connector, sessionID, DEFAULT_TIMEOUT_SECONDS, TimeUnit.SECONDS); + } + + /** + * Asserts that a session is not logged on within the specified timeout period. + * + * @param connector the session connector + * @param sessionID the session identifier + * @param timeout the timeout duration + * @param unit the timeout unit + * @throws AssertionError if the session is logged on within the timeout + * @throws RuntimeException if interrupted while waiting + */ + public static void assertNotLoggedOn(SessionConnector connector, SessionID sessionID, long timeout, TimeUnit unit) { + long deadlineNs = System.nanoTime() + unit.toNanos(timeout); + + while (System.nanoTime() < deadlineNs) { + if (isLoggedOn(connector, sessionID)) { + throw new AssertionError("Session " + sessionID + " is logged on"); + } + + try { + Thread.sleep(100); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted", e); + } + } + } +}