2929import java .net .InetAddress ;
3030import java .net .ServerSocket ;
3131import java .net .Socket ;
32+ import java .net .SocketException ;
3233import java .nio .charset .StandardCharsets ;
3334import java .security .KeyPair ;
3435import java .security .KeyPairGenerator ;
5556
5657public class ProtocolHandlerTest {
5758
59+ public static final byte [] SERVER_DATA = "hello" .getBytes (StandardCharsets .UTF_8 );
60+ public static final byte [] CLIENT_DATA = "from client" .getBytes (StandardCharsets .UTF_8 );
61+ public static final byte [] MDX_REQUEST_DATA ;
62+ public static final byte [] WANT_FULL_REQUEST_BYTES ;
63+
64+ static {
65+ byte [] wantClientMessage = CLIENT_DATA ;
66+ MDX_REQUEST_DATA = wantRequestBytes ();
67+ byte [] wantRequest = new byte [wantClientMessage .length + MDX_REQUEST_DATA .length ];
68+ System .arraycopy (MDX_REQUEST_DATA , 0 , wantRequest , 0 , MDX_REQUEST_DATA .length );
69+ System .arraycopy (
70+ wantClientMessage , 0 , wantRequest , MDX_REQUEST_DATA .length , wantClientMessage .length );
71+ WANT_FULL_REQUEST_BYTES = wantRequest ;
72+ }
73+
5874 @ Test
5975 public void testSendMdx () throws IOException {
6076 ByteArrayOutputStream out = new ByteArrayOutputStream ();
@@ -85,35 +101,42 @@ public void testReadMdx_WithMdxResponse() throws IOException {
85101 public void testMdxSocket_clientWritesFirst_noMdxResponse () throws Exception {
86102 // 1. The client connects, writes to the server, sending an MDX request,
87103 // reads from the server but receives no MDX response
104+ AtomicReference <byte []> requestBytes = new AtomicReference <>();
88105
89106 // Setup SSL server that expects an MDX request, but does not send an MDX response.
90107 SslServer server =
91108 new SslServer (
92109 (in , out ) -> {
93110 // Server reads the client's MDX request.
94- byte [] req = new byte [wantRequestBytes () .length ];
111+ byte [] req = new byte [CLIENT_DATA .length ];
95112 new DataInputStream (in ).readFully (req );
96- assertThat ( req ). isEqualTo ( wantRequestBytes () );
113+ requestBytes . set ( req );
97114
98115 // Server writes a non-MDX response.
99116 out .write ("hello" .getBytes (StandardCharsets .UTF_8 ));
117+ out .flush ();
100118 });
101119 SslServer .SslServerParams p = server .start ();
102120
103121 // Setup client socket
104- Socket socket = new ProtocolHandler ("ua" ).connect (p .getSocket (), "tls" );
122+ MdxSocket socket = new ProtocolHandler ("ua" ).connect (p .getSocket (), "tls" );
105123
106124 // Client writes, which should trigger the MDX exchange.
107- socket .getOutputStream ().write ("from client" .getBytes (StandardCharsets .UTF_8 ));
125+ socket .getOutputStream ().write (CLIENT_DATA );
126+ socket .getOutputStream ().flush ();
108127
109128 // The server should not have sent an MDX response, so the client should
110129 // read the raw "hello" from the server.
111130 byte [] fromServer = new byte [5 ];
112131 new DataInputStream (socket .getInputStream ()).readFully (fromServer );
113- assertThat (fromServer ).isEqualTo ("hello" .getBytes (StandardCharsets .UTF_8 ));
132+ while (requestBytes .get () == null ) {
133+ Thread .sleep (10 );
134+ }
114135
115136 socket .close ();
116137 server .stop ();
138+ assertThat (fromServer ).isEqualTo (SERVER_DATA );
139+ assertThat (socket .getMdxResponse ()).isNull ();
117140 }
118141
119142 @ Test
@@ -127,35 +150,42 @@ public void testMdxSocket_clientWritesFirst_receivesMdxResponse() throws Excepti
127150 new SslServer (
128151 (in , out ) -> {
129152 // Server reads the client's MDX request.
130- byte [] req = new byte [wantRequestBytes () .length ];
153+ byte [] req = new byte [WANT_FULL_REQUEST_BYTES .length ];
131154 new DataInputStream (in ).readFully (req );
132155 requestBytes .set (req );
133156 // Server writes an MDX response.
134157 out .write (wantResponseBytes ("hello" .getBytes (StandardCharsets .UTF_8 )));
158+ out .flush ();
135159 });
136160
137161 SslServer .SslServerParams p = server .start ();
138162
139163 // Setup client socket
140- Socket socket = new ProtocolHandler ("ua" ).connect (p .getSocket (), "tls" );
164+ MdxSocket socket = new ProtocolHandler ("ua" ).connect (p .getSocket (), "tls" );
141165
142166 // Client writes, which should trigger the MDX exchange.
143- socket .getOutputStream ().write ("from client" .getBytes (StandardCharsets .UTF_8 ));
167+ socket .getOutputStream ().write (CLIENT_DATA );
168+ socket .getOutputStream ().flush ();
144169
145170 // The server should have sent an MDX response, so the client should
146171 // read the "hello" from the server.
147172 byte [] fromServer = new byte [5 ];
148173 new DataInputStream (socket .getInputStream ()).readFully (fromServer );
149- assertThat ( fromServer ). isEqualTo ( "hello" . getBytes ( StandardCharsets . UTF_8 ));
150-
151- assertThat ( requestBytes . get ()). isEqualTo ( wantRequestBytes ());
174+ while ( requestBytes . get () == null ) {
175+ Thread . sleep ( 10 );
176+ }
152177
153178 socket .close ();
154179 server .stop ();
180+
181+ assertThat (fromServer ).isEqualTo ("hello" .getBytes (StandardCharsets .UTF_8 ));
182+ assertThat (requestBytes .get ()).isEqualTo (WANT_FULL_REQUEST_BYTES );
183+ assertThat (socket .getMdxResponse ()).isNotNull ();
155184 }
156185
157186 @ Test
158187 public void testMdxSocket_clientReadsFirst_noMdxResponse () throws Exception {
188+
159189 // 3. The client connects, reads from the server but receives no MDX response,
160190 // then writes to the server, sending an MDX request
161191 AtomicReference <byte []> requestBytes = new AtomicReference <>();
@@ -166,36 +196,43 @@ public void testMdxSocket_clientReadsFirst_noMdxResponse() throws Exception {
166196 (in , out ) -> {
167197 // Server writes a non-MDX response.
168198 out .write ("hello" .getBytes (StandardCharsets .UTF_8 ));
199+ out .flush ();
169200
170201 // Server reads the client's MDX request.
171- byte [] req = new byte [wantRequestBytes () .length ];
202+ byte [] req = new byte [WANT_FULL_REQUEST_BYTES .length ];
172203 new DataInputStream (in ).readFully (req );
173204 requestBytes .set (req );
174205 });
175206 SslServer .SslServerParams p = server .start ();
176207
177208 // Setup client socket
178- Socket socket = new ProtocolHandler ("ua" ).connect (p .getSocket (), "tls" );
209+ MdxSocket socket = new ProtocolHandler ("ua" ).connect (p .getSocket (), "tls" );
179210
180211 // The server should not have sent an MDX response, so the client should
181212 // read the raw "hello" from the server.
182213 byte [] fromServer = new byte [5 ];
183214 new DataInputStream (socket .getInputStream ()).readFully (fromServer );
184- assertThat (fromServer ).isEqualTo ("hello" .getBytes (StandardCharsets .UTF_8 ));
185- assertThat (requestBytes .get ()).isNull ();
186215
187216 // Client writes, which should trigger the MDX exchange.
188- socket .getOutputStream ().write ("from client" .getBytes (StandardCharsets .UTF_8 ));
217+ socket .getOutputStream ().write (CLIENT_DATA );
218+ socket .getOutputStream ().flush ();
219+ while (requestBytes .get () == null ) {
220+ Thread .sleep (10 );
221+ }
189222
190223 socket .close ();
191224 server .stop ();
225+
226+ assertThat (fromServer ).isEqualTo ("hello" .getBytes (StandardCharsets .UTF_8 ));
227+ assertThat (requestBytes .get ()).isEqualTo (WANT_FULL_REQUEST_BYTES );
228+ assertThat (socket .getMdxResponse ()).isNull ();
192229 }
193230
194231 @ Test
195- public void testMdxSocket_clientReadsFirst_receivesMdxResponse () throws Exception {
196- // 4. The client connects, reads from the server and receives an MDX response,
232+ public void testMdxSocket_clientReadsFirst_receivesMdxResponse_writesMdxRequest ()
233+ throws Exception {
234+ // 4. The client connects, reads from the server and receives no MDX response,
197235 // then writes to the server, sending an MDX request
198-
199236 AtomicReference <byte []> requestBytes = new AtomicReference <>();
200237
201238 // Setup SSL server that expects an MDX request and sends an MDX response.
@@ -204,32 +241,39 @@ public void testMdxSocket_clientReadsFirst_receivesMdxResponse() throws Exceptio
204241 (in , out ) -> {
205242 // Server writes a non-MDX response.
206243 out .write ("hello" .getBytes (StandardCharsets .UTF_8 ));
244+ out .flush ();
207245
208246 // Server reads the client's MDX request.
209- byte [] req = new byte [wantRequestBytes () .length ];
247+ byte [] req = new byte [WANT_FULL_REQUEST_BYTES .length ];
210248 new DataInputStream (in ).readFully (req );
211249 requestBytes .set (req );
212250 });
213251 SslServer .SslServerParams p = server .start ();
214252
215253 // Setup client socket
216- Socket socket = new ProtocolHandler ("ua" ).connect (p .getSocket (), "tls" );
254+ MdxSocket socket = new ProtocolHandler ("ua" ).connect (p .getSocket (), "tls" );
217255
218- // The server should have sent an MDX response, so the client should
256+ // The server should not send MDX response, so the client should
219257 // read the "hello" from the server.
220258 byte [] fromServer = new byte [5 ];
221259 new DataInputStream (socket .getInputStream ()).readFully (fromServer );
222- assertThat (fromServer ).isEqualTo ("hello" .getBytes (StandardCharsets .UTF_8 ));
223260
224261 // Client writes, which should trigger the MDX exchange.
225- socket .getOutputStream ().write ("from client" .getBytes (StandardCharsets .UTF_8 ));
226- assertThat (requestBytes .get ()).isNull ();
262+ socket .getOutputStream ().write (CLIENT_DATA );
263+ socket .getOutputStream ().flush ();
264+ while (requestBytes .get () == null ) {
265+ Thread .sleep (10 );
266+ }
227267
228268 socket .close ();
229269 server .stop ();
270+
271+ assertThat (fromServer ).isEqualTo ("hello" .getBytes (StandardCharsets .UTF_8 ));
272+ assertThat (requestBytes .get ()).isEqualTo (WANT_FULL_REQUEST_BYTES );
273+ assertThat (socket .getMdxResponse ()).isNull ();
230274 }
231275
232- private static byte [] wantRequestBytes () throws IOException {
276+ private static byte [] wantRequestBytes () {
233277 ByteArrayOutputStream wantOut = new ByteArrayOutputStream ();
234278 MetadataExchange .MetadataExchangeRequest req =
235279 MetadataExchange .MetadataExchangeRequest .newBuilder ()
@@ -238,39 +282,46 @@ private static byte[] wantRequestBytes() throws IOException {
238282 .build ();
239283 int size = req .getSerializedSize ();
240284
241- // Write the protocoal header
242- wantOut .write ("CSQLMDEX" .getBytes (StandardCharsets .UTF_8 ));
243- // Write the uint32 size
244- wantOut .write ((size >>> 24 ) & 0xFF );
245- wantOut .write ((size >>> 16 ) & 0xFF );
246- wantOut .write ((size >>> 8 ) & 0xFF );
247- wantOut .write (size & 0xFF );
248- // Write the protobuf
249- req .writeTo (wantOut );
250- wantOut .flush ();
285+ try {
286+ // Write the protocoal header
287+ wantOut .write ("CSQLMDEX" .getBytes (StandardCharsets .UTF_8 ));
288+ // Write the uint32 size
289+ wantOut .write ((byte ) ((size >>> 24 ) & 0xFF ));
290+ wantOut .write ((byte ) ((size >>> 16 ) & 0xFF ));
291+ wantOut .write ((byte ) ((size >>> 8 ) & 0xFF ));
292+ wantOut .write ((byte ) (size & 0xFF ));
293+ // Write the protobuf
294+ req .writeTo (wantOut );
295+ } catch (IOException e ) {
296+ throw new RuntimeException (e );
297+ }
298+
251299 return wantOut .toByteArray ();
252300 }
253301
254- private static byte [] wantResponseBytes (byte [] data ) throws IOException {
302+ private static byte [] wantResponseBytes (byte [] data ) {
255303 ByteArrayOutputStream wantOut = new ByteArrayOutputStream ();
256304 MetadataExchange .MetadataExchangeResponse res =
257305 MetadataExchange .MetadataExchangeResponse .newBuilder ()
258306 .setResponseStatusCode (MetadataExchange .MetadataExchangeResponse .ResponseStatusCode .OK )
259307 .build ();
260-
261308 int size = res .getSerializedSize ();
262309
263- // Write the protocoal header
264- wantOut .write ("CSQLMDEX" .getBytes (StandardCharsets .UTF_8 ));
265- // Write the uint32 size
266- wantOut .write ((size >>> 24 ) & 0xFF );
267- wantOut .write ((size >>> 16 ) & 0xFF );
268- wantOut .write ((size >>> 8 ) & 0xFF );
269- wantOut .write (size & 0xFF );
270- // Write the protobuf
271- res .writeTo (wantOut );
272- wantOut .write (data );
273- wantOut .flush ();
310+ try {
311+ // Write the protocoal header
312+ wantOut .write ("CSQLMDEX" .getBytes (StandardCharsets .UTF_8 ));
313+ // Write the uint32 size
314+ wantOut .write ((size >>> 24 ) & 0xFF );
315+ wantOut .write ((size >>> 16 ) & 0xFF );
316+ wantOut .write ((size >>> 8 ) & 0xFF );
317+ wantOut .write (size & 0xFF );
318+ // Write the protobuf
319+ res .writeTo (wantOut );
320+ wantOut .write (data );
321+ wantOut .flush ();
322+ } catch (IOException e ) {
323+ throw new RuntimeException (e );
324+ }
274325 return wantOut .toByteArray ();
275326 }
276327
@@ -312,6 +363,8 @@ private SslServerParams start() throws Exception {
312363 }
313364 handler .accept (s .getInputStream (), s .getOutputStream ());
314365 }
366+ } catch (SocketException e ) {
367+ // do nothing, we don't care if the socket was closed.
315368 } catch (Exception e ) {
316369 throw new RuntimeException (e );
317370 }
0 commit comments