Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ jspecify = "1.0.0"
mockito = "5.20.0"
# https://mvnrepository.com/artifact/org.assertj/assertj-core
assertj-core = "4.0.0-M1"
# https://mvnrepository.com/artifact/org.awaitility/awaitility
awaitility = "4.3.0"

[libraries]
project-reactor-core = { module = "io.projectreactor:reactor-core", version.ref = "project-reactor" }
Expand All @@ -37,6 +39,7 @@ testcontainers = { module = "org.testcontainers:testcontainers", version.ref = "
mockito-core = { module = "org.mockito:mockito-core", version.ref = "mockito" }
mockito-junit-jupiter = { module = "org.mockito:mockito-junit-jupiter", version.ref = "mockito" }
assertj-core = { module = "org.assertj:assertj-core", version.ref = "assertj-core" }
awaitility = { module = "org.awaitility:awaitility", version.ref = "awaitility" }

[bundles]
mail = ["jakarta-mail-api", "angus-mail"]

This file was deleted.

This file was deleted.

1 change: 1 addition & 0 deletions rlib-network/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ dependencies {
testRuntimeOnly projects.rlibLoggerImpl
loadTestRuntimeOnly projects.rlibLoggerImpl
testImplementation testFixtures(projects.rlibCommon)
testImplementation libs.awaitility
}
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ protected int doHandshake(ByteBuffer networkBuffer, int receivedBytes) {
case NEED_WRAP: {
log.debug(remoteAddress, "[%s] Send command to wrap data"::formatted);
packetWriter.accept(SslWrapRequestNetworkPacket.getInstance());
if (networkBuffer.hasRemaining()) {
return decryptAndRead(networkBuffer);
}
NetworkUtils.cleanNetworkBuffer(networkBuffer);
return SKIP_READ_PACKETS;
}
Expand Down Expand Up @@ -204,6 +207,11 @@ protected int decryptAndRead(ByteBuffer receivedBuffer) {
}
switch (result.getStatus()) {
case OK: {
if (result.bytesConsumed() == 0 && result.bytesProduced() == 0) {
log.debug(remoteAddress, "[%s] No progress during decryption, skip read packets"::formatted);
NetworkUtils.cleanNetworkBuffer(receivedBuffer);
return SKIP_READ_PACKETS;
}
sslDataBuffer.flip();
logDataAfterDecrypt(remoteAddress, sslDataBuffer);
total += readPackets(sslDataBuffer, sslDataPendingBuffer);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
package javasabr.rlib.network;

import static java.util.function.Predicate.isEqual;
import static javasabr.rlib.network.util.NetworkUtils.createAllTrustedClientSslContext;
import static javasabr.rlib.network.util.NetworkUtils.createSslContext;
import static org.assertj.core.api.Assertions.assertThat;
import static org.awaitility.Awaitility.await;

import java.net.InetSocketAddress;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import javasabr.rlib.common.util.AwaitUtils;
import javasabr.rlib.network.exception.ConnectionClosedException;
import javasabr.rlib.network.impl.AbstractConnection;
import javasabr.rlib.network.impl.DefaultConnection;
Expand Down Expand Up @@ -80,14 +81,16 @@ void shouldCloseServerConnectionWhenClientClosesTcpChannelAbruptly() {

// when
clientConnection.channel().close();
assertThat(AwaitUtils.await(5, TimeUnit.SECONDS, clientConnection::closed))
.as("Client connection should be closed prior server side verification")
.isTrue();

// then
assertThat(AwaitUtils.await(5, TimeUnit.SECONDS, serverConnection::closed))
.as("Server connection should be closed after receiving EOF from abruptly closed client channel")
.isTrue();
await()
.alias("Client connection should be closed prior server side verification")
.atMost(5, TimeUnit.SECONDS)
.until(clientConnection::closed, isEqual(true));
await()
.alias("Server connection should be closed after receiving EOF from abruptly closed client channel")
.atMost(5, TimeUnit.SECONDS)
.until(serverConnection::closed, isEqual(true));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
package javasabr.rlib.network.packet.impl;

import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.nio.ByteBuffer;
import java.time.Duration;
import java.util.function.Consumer;
import javasabr.rlib.network.BufferAllocator;
import javasabr.rlib.network.Network;
import javasabr.rlib.network.NetworkConfig;
import javasabr.rlib.network.UnsafeConnection;
import javasabr.rlib.network.impl.DefaultBufferAllocator;
import javasabr.rlib.network.packet.ReadableNetworkPacket;
import javasabr.rlib.network.packet.WritableNetworkPacket;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLEngineResult.Status;
import javax.net.ssl.SSLSession;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness;

/**
* The tests of SSL packet reader
*
* @author crazyrokr
*/
@ExtendWith(MockitoExtension.class)
@MockitoSettings(strictness = Strictness.LENIENT)
class AbstractSslNetworkPacketReaderTest {

@Mock
private TestConnection connection;

@Mock
private Network<TestConnection> network;

@Mock
private SSLEngine sslEngine;

@Mock
private SSLSession sslSession;

@Mock
private Consumer<ReadableNetworkPacket<TestConnection>> packetHandler;

@Mock
private Consumer<WritableNetworkPacket<TestConnection>> packetWriter;

DefaultSslNetworkPacketReader<ReadableNetworkPacket<TestConnection>, TestConnection> reader;

private BufferAllocator bufferAllocator;

private interface TestConnection extends UnsafeConnection<TestConnection> {}

@BeforeEach
void setUp() {
bufferAllocator = new DefaultBufferAllocator(NetworkConfig.DEFAULT_CLIENT);
when(connection.bufferAllocator()).thenReturn(bufferAllocator);
when(connection.network()).thenReturn((Network) network);
when(connection.remoteAddress()).thenReturn("test-address");
when(network.config()).thenReturn(NetworkConfig.DEFAULT_CLIENT);
when(sslEngine.getSession()).thenReturn(sslSession);
when(sslSession.getApplicationBufferSize()).thenReturn(1024);
when(sslSession.getPacketBufferSize()).thenReturn(1024);
reader = spy(new DefaultSslNetworkPacketReader<ReadableNetworkPacket<TestConnection>, TestConnection>(
connection,
() -> {},
packetHandler,
packetHandler,
len -> mock(ReadableNetworkPacket.class),
sslEngine,
packetWriter,
1,
100) {
});
doReturn(1).when(reader).readFullPacketLength(any(ByteBuffer.class));
}

@Test
void shouldNotLoseDataOnNeedWrapDuringHandshake() throws Exception {
// given
// Initial state: NEED_UNWRAP
when(sslEngine.getHandshakeStatus()).thenReturn(HandshakeStatus.NEED_UNWRAP);

// First unwrap will result in NEED_WRAP and status OK, consuming some data.
// Simulate a single network buffer containing 5 bytes of handshake data followed by
// 5 bytes of application data, so the remaining bytes can still be processed afterward.
ByteBuffer networkData = ByteBuffer.allocate(10);
networkData.put(new byte[10]);
networkData.flip();

// doHandshake calls unwrap in NEED_UNWRAP, consumes first 5 bytes, then returns OK
when(sslEngine.unwrap(any(ByteBuffer.class), any(ByteBuffer[].class))).thenAnswer(invocation -> {
ByteBuffer in = invocation.getArgument(0);
in.position(in.position() + 5); // consume 5 bytes of handshake
// Change status to NEED_WRAP for next getHandshakeStatus() call
when(sslEngine.getHandshakeStatus()).thenReturn(HandshakeStatus.NEED_WRAP);
return new SSLEngineResult(Status.OK, HandshakeStatus.NEED_WRAP, 5, 0);
});

// decryptAndRead calls unwrap, consumes the remaining 5 bytes, then return FINISHED or NOT_HANDSHAKING
when(sslEngine.unwrap(any(ByteBuffer.class), any(ByteBuffer.class))).thenAnswer(invocation -> {
ByteBuffer in = invocation.getArgument(0);
ByteBuffer out = invocation.getArgument(1);
int remaining = in.remaining();
in.position(in.limit()); // consume all
out.put(new byte[remaining]); // put decrypted data (mocked)
when(sslEngine.getHandshakeStatus()).thenReturn(HandshakeStatus.NOT_HANDSHAKING);
return new SSLEngineResult(Status.OK, HandshakeStatus.NOT_HANDSHAKING, remaining, remaining);
});

// when
reader.readPackets(networkData);

// then
// readPackets should have been called for the remaining 5 bytes,
// since each packet is 1 byte, it should have read 5 packets
verify(reader, times(5)).createPacketFor(any(ByteBuffer.class), anyInt(), anyInt(), anyInt());
verify(packetWriter).accept(any(SslWrapRequestNetworkPacket.class));
}

@Test
void testShouldNotDeadLoopWhenNeedWrapAndNoProgress() throws Exception {
// given
// Initial state: NEED_WRAP
when(sslEngine.getHandshakeStatus()).thenReturn(HandshakeStatus.NEED_WRAP);

// Network buffer has data
ByteBuffer networkData = ByteBuffer.allocate(10);
networkData.put(new byte[10]);
networkData.flip();

// Mock unwrap in decryptAndRead to return OK with 0 progress
// This happens if engine is in NEED_WRAP and can't decrypt application data
when(sslEngine.unwrap(any(ByteBuffer.class), any(ByteBuffer.class)))
.thenReturn(new SSLEngineResult(Status.OK, HandshakeStatus.NEED_WRAP, 0, 0));

// when
// We expect this NOT to hang indefinitely.
// If it dead-loops, the test will fail by timeout.
assertTimeoutPreemptively(Duration.ofSeconds(5), () -> reader.readPackets(networkData));

// then
// Should have requested wrap
verify(packetWriter).accept(any(SslWrapRequestNetworkPacket.class));
}
}
Loading