Skip to content

Commit f36be4c

Browse files
committed
fix data lose issue during SSL communication
1 parent 426fb5e commit f36be4c

2 files changed

Lines changed: 168 additions & 0 deletions

File tree

rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketReader.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,9 @@ protected int doHandshake(ByteBuffer networkBuffer, int receivedBytes) {
160160
case NEED_WRAP: {
161161
log.debug(remoteAddress, "[%s] Send command to wrap data"::formatted);
162162
packetWriter.accept(SslWrapRequestNetworkPacket.getInstance());
163+
if (networkBuffer.hasRemaining()) {
164+
return decryptAndRead(networkBuffer);
165+
}
163166
NetworkUtils.cleanNetworkBuffer(networkBuffer);
164167
return SKIP_READ_PACKETS;
165168
}
@@ -204,6 +207,10 @@ protected int decryptAndRead(ByteBuffer receivedBuffer) {
204207
}
205208
switch (result.getStatus()) {
206209
case OK: {
210+
if (result.bytesConsumed() == 0 && result.bytesProduced() == 0) {
211+
log.debug(remoteAddress, "[%s] No progress during decryption, stop processing"::formatted);
212+
return SKIP_READ_PACKETS;
213+
}
207214
sslDataBuffer.flip();
208215
logDataAfterDecrypt(remoteAddress, sslDataBuffer);
209216
total += readPackets(sslDataBuffer, sslDataPendingBuffer);
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
package javasabr.rlib.network.packet.impl;
2+
3+
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
4+
import static org.mockito.ArgumentMatchers.any;
5+
import static org.mockito.ArgumentMatchers.anyInt;
6+
import static org.mockito.Mockito.doReturn;
7+
import static org.mockito.Mockito.mock;
8+
import static org.mockito.Mockito.spy;
9+
import static org.mockito.Mockito.times;
10+
import static org.mockito.Mockito.verify;
11+
import static org.mockito.Mockito.when;
12+
13+
import java.nio.ByteBuffer;
14+
import java.time.Duration;
15+
import java.util.function.Consumer;
16+
import javasabr.rlib.network.BufferAllocator;
17+
import javasabr.rlib.network.Network;
18+
import javasabr.rlib.network.NetworkConfig;
19+
import javasabr.rlib.network.UnsafeConnection;
20+
import javasabr.rlib.network.impl.DefaultBufferAllocator;
21+
import javasabr.rlib.network.packet.ReadableNetworkPacket;
22+
import javasabr.rlib.network.packet.WritableNetworkPacket;
23+
import javax.net.ssl.SSLEngine;
24+
import javax.net.ssl.SSLEngineResult;
25+
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
26+
import javax.net.ssl.SSLEngineResult.Status;
27+
import javax.net.ssl.SSLSession;
28+
import org.junit.jupiter.api.BeforeEach;
29+
import org.junit.jupiter.api.Test;
30+
import org.junit.jupiter.api.extension.ExtendWith;
31+
import org.mockito.Mock;
32+
import org.mockito.junit.jupiter.MockitoExtension;
33+
import org.mockito.junit.jupiter.MockitoSettings;
34+
import org.mockito.quality.Strictness;
35+
36+
/**
37+
* The tests of SSL packet reader
38+
*
39+
* @author crazyrokr
40+
*/
41+
@ExtendWith(MockitoExtension.class)
42+
@MockitoSettings(strictness = Strictness.LENIENT)
43+
public class AbstractSslNetworkPacketReaderTest {
44+
45+
@Mock
46+
private TestConnection connection;
47+
48+
@Mock
49+
private Network<TestConnection> network;
50+
51+
@Mock
52+
private SSLEngine sslEngine;
53+
54+
@Mock
55+
private SSLSession sslSession;
56+
57+
@Mock
58+
private Consumer<ReadableNetworkPacket<TestConnection>> packetHandler;
59+
60+
@Mock
61+
private Consumer<WritableNetworkPacket<TestConnection>> packetWriter;
62+
63+
DefaultSslNetworkPacketReader<ReadableNetworkPacket<TestConnection>, TestConnection> reader;
64+
65+
private BufferAllocator bufferAllocator;
66+
67+
private interface TestConnection extends UnsafeConnection<TestConnection> {}
68+
69+
@BeforeEach
70+
void setUp() {
71+
bufferAllocator = new DefaultBufferAllocator(NetworkConfig.DEFAULT_CLIENT);
72+
when(connection.bufferAllocator()).thenReturn(bufferAllocator);
73+
when(connection.network()).thenReturn((Network) network);
74+
when(connection.remoteAddress()).thenReturn("test-address");
75+
when(network.config()).thenReturn(NetworkConfig.DEFAULT_CLIENT);
76+
when(sslEngine.getSession()).thenReturn(sslSession);
77+
when(sslSession.getApplicationBufferSize()).thenReturn(1024);
78+
when(sslSession.getPacketBufferSize()).thenReturn(1024);
79+
reader = spy(new DefaultSslNetworkPacketReader<ReadableNetworkPacket<TestConnection>, TestConnection>(
80+
connection,
81+
() -> {},
82+
packetHandler,
83+
packetHandler,
84+
len -> mock(ReadableNetworkPacket.class),
85+
sslEngine,
86+
packetWriter,
87+
1,
88+
100) {
89+
});
90+
doReturn(1).when(reader).readFullPacketLength(any(ByteBuffer.class));
91+
}
92+
93+
@Test
94+
void testShouldNotLoseDataOnNeedWrapDuringHandshake() throws Exception {
95+
// given
96+
// Initial state: NEED_UNWRAP
97+
when(sslEngine.getHandshakeStatus()).thenReturn(HandshakeStatus.NEED_UNWRAP);
98+
99+
// First unwrap will result in NEED_WRAP and status OK, consuming some data.
100+
// Simulate a single network buffer containing 5 bytes of handshake data followed by
101+
// 5 bytes of application data, so the remaining bytes can still be processed afterward.
102+
ByteBuffer networkData = ByteBuffer.allocate(10);
103+
networkData.put(new byte[10]);
104+
networkData.flip();
105+
106+
// doHandshake calls unwrap in NEED_UNWRAP, consumes first 5 bytes, then returns OK
107+
when(sslEngine.unwrap(any(ByteBuffer.class), any(ByteBuffer[].class))).thenAnswer(invocation -> {
108+
ByteBuffer in = invocation.getArgument(0);
109+
in.position(in.position() + 5); // consume 5 bytes of handshake
110+
// Change status to NEED_WRAP for next getHandshakeStatus() call
111+
when(sslEngine.getHandshakeStatus()).thenReturn(HandshakeStatus.NEED_WRAP);
112+
return new SSLEngineResult(Status.OK, HandshakeStatus.NEED_WRAP, 5, 0);
113+
});
114+
115+
// decryptAndRead calls unwrap, consumes the remaining 5 bytes, then return FINISHED or NOT_HANDSHAKING
116+
when(sslEngine.unwrap(any(ByteBuffer.class), any(ByteBuffer.class))).thenAnswer(invocation -> {
117+
ByteBuffer in = invocation.getArgument(0);
118+
ByteBuffer out = invocation.getArgument(1);
119+
int remaining = in.remaining();
120+
in.position(in.limit()); // consume all
121+
out.put(new byte[remaining]); // put decrypted data (mocked)
122+
when(sslEngine.getHandshakeStatus()).thenReturn(HandshakeStatus.NOT_HANDSHAKING);
123+
return new SSLEngineResult(Status.OK, HandshakeStatus.NOT_HANDSHAKING, remaining, remaining);
124+
});
125+
126+
// when
127+
reader.readPackets(networkData);
128+
129+
// then
130+
// readPackets should have been called for the remaining 5 bytes,
131+
// since each packet is 1 byte, it should have read 5 packets
132+
verify(reader, times(5)).createPacketFor(any(ByteBuffer.class), anyInt(), anyInt(), anyInt());
133+
verify(packetWriter).accept(any(SslWrapRequestNetworkPacket.class));
134+
}
135+
136+
@Test
137+
void testShouldNotDeadLoopWhenNeedWrapAndNoProgress() throws Exception {
138+
// given
139+
// Initial state: NEED_WRAP
140+
when(sslEngine.getHandshakeStatus()).thenReturn(HandshakeStatus.NEED_WRAP);
141+
142+
// Network buffer has data
143+
ByteBuffer networkData = ByteBuffer.allocate(10);
144+
networkData.put(new byte[10]);
145+
networkData.flip();
146+
147+
// Mock unwrap in decryptAndRead to return OK with 0 progress
148+
// This happens if engine is in NEED_WRAP and can't decrypt application data
149+
when(sslEngine.unwrap(any(ByteBuffer.class), any(ByteBuffer.class)))
150+
.thenReturn(new SSLEngineResult(Status.OK, HandshakeStatus.NEED_WRAP, 0, 0));
151+
152+
// when
153+
// We expect this NOT to hang indefinitely.
154+
// If it dead-loops, the test will fail by timeout.
155+
assertTimeoutPreemptively(Duration.ofSeconds(5), () -> reader.readPackets(networkData));
156+
157+
// then
158+
// Should have requested wrap
159+
verify(packetWriter).accept(any(SslWrapRequestNetworkPacket.class));
160+
}
161+
}

0 commit comments

Comments
 (0)