|
| 1 | +package org.jruby.ext.openssl; |
| 2 | + |
| 3 | +import java.io.IOException; |
| 4 | +import java.io.ByteArrayOutputStream; |
| 5 | +import java.io.InputStream; |
| 6 | +import java.nio.ByteBuffer; |
| 7 | +import java.nio.charset.StandardCharsets; |
| 8 | + |
| 9 | +import org.jruby.Ruby; |
| 10 | +import org.jruby.RubyArray; |
| 11 | +import org.jruby.RubyFixnum; |
| 12 | +import org.jruby.RubyInteger; |
| 13 | +import org.jruby.RubyString; |
| 14 | +import org.jruby.exceptions.RaiseException; |
| 15 | +import org.jruby.runtime.ThreadContext; |
| 16 | +import org.jruby.runtime.builtin.IRubyObject; |
| 17 | + |
| 18 | +import org.junit.After; |
| 19 | +import org.junit.Before; |
| 20 | +import org.junit.Test; |
| 21 | +import static org.junit.Assert.*; |
| 22 | + |
| 23 | +public class SSLSocketTest { |
| 24 | + |
| 25 | + private Ruby runtime; |
| 26 | + |
| 27 | + /** Loads the ssl_pair.rb script that creates a connected SSL socket pair. */ |
| 28 | + private String start_ssl_server_rb() { return readResource("/start_ssl_server.rb"); } |
| 29 | + |
| 30 | + @Before |
| 31 | + public void setUp() { |
| 32 | + runtime = Ruby.newInstance(); |
| 33 | + // prepend lib/ so openssl.rb + jopenssl/ are loaded instead of the ones bundled in jruby-stdlib |
| 34 | + String libDir = new java.io.File("lib").getAbsolutePath(); |
| 35 | + runtime.evalScriptlet("$LOAD_PATH.unshift '" + libDir + "'"); |
| 36 | + runtime.evalScriptlet("require 'openssl'"); |
| 37 | + } |
| 38 | + |
| 39 | + @After |
| 40 | + public void tearDown() { |
| 41 | + if (runtime != null) { |
| 42 | + runtime.tearDown(false); |
| 43 | + runtime = null; |
| 44 | + } |
| 45 | + } |
| 46 | + |
| 47 | + /** |
| 48 | + * Real-world scenario: {@code gem push} sends a large POST body via {@code syswrite_nonblock}, |
| 49 | + * then reads the HTTP response via {@code sysread}. |
| 50 | + * |
| 51 | + * Approximates the {@code gem push} scenario: |
| 52 | + * <ol> |
| 53 | + * <li>Write 256KB via {@code syswrite_nonblock} in a loop (the net/http POST pattern)</li> |
| 54 | + * <li>Server reads via {@code sysread} and counts bytes</li> |
| 55 | + * <li>Assert: server received exactly what client sent</li> |
| 56 | + * </ol> |
| 57 | + * |
| 58 | + * With the old {@code clear()} bug, encrypted bytes were silently |
| 59 | + * discarded during partial non-blocking writes, so the server would |
| 60 | + * receive fewer bytes than sent. |
| 61 | + */ |
| 62 | + @Test |
| 63 | + public void syswriteNonblockDataIntegrity() throws Exception { |
| 64 | + final RubyArray pair = (RubyArray) runtime.evalScriptlet(start_ssl_server_rb()); |
| 65 | + SSLSocket client = (SSLSocket) pair.entry(0).toJava(SSLSocket.class); |
| 66 | + SSLSocket server = (SSLSocket) pair.entry(1).toJava(SSLSocket.class); |
| 67 | + |
| 68 | + try { |
| 69 | + // Server: read all data in a background thread, counting bytes |
| 70 | + final long[] serverReceived = { 0 }; |
| 71 | + Thread serverReader = startServerReader(server, serverReceived); |
| 72 | + |
| 73 | + // Client: write 256KB in 4KB chunks via syswrite_nonblock |
| 74 | + byte[] chunk = new byte[4096]; |
| 75 | + java.util.Arrays.fill(chunk, (byte) 'P'); // P for POST body |
| 76 | + RubyString payload = RubyString.newString(runtime, chunk); |
| 77 | + |
| 78 | + long totalSent = 0; |
| 79 | + for (int i = 0; i < 64; i++) { // 64 * 4KB = 256KB |
| 80 | + try { |
| 81 | + IRubyObject written = client.syswrite_nonblock(currentContext(), payload); |
| 82 | + totalSent += ((RubyInteger) written).getLongValue(); |
| 83 | + } catch (RaiseException e) { |
| 84 | + String rubyClass = e.getException().getMetaClass().getName(); |
| 85 | + if (rubyClass.contains("WaitWritable")) { |
| 86 | + // Expected: non-blocking write would block — retry as blocking |
| 87 | + IRubyObject written = client.syswrite(currentContext(), payload); |
| 88 | + totalSent += ((RubyInteger) written).getLongValue(); |
| 89 | + } else { |
| 90 | + System.err.println("syswrite_nonblock unexpected: " + rubyClass + ": " + e.getMessage()); |
| 91 | + throw e; |
| 92 | + } |
| 93 | + } |
| 94 | + } |
| 95 | + assertTrue("should have sent data", totalSent > 0); |
| 96 | + |
| 97 | + // Close client to signal EOF, let server finish reading |
| 98 | + client.callMethod(currentContext(), "close"); |
| 99 | + serverReader.join(10_000); |
| 100 | + |
| 101 | + assertEquals( |
| 102 | + "server must receive exactly what client sent — mismatch means encrypted bytes were lost!", |
| 103 | + totalSent, serverReceived[0] |
| 104 | + ); |
| 105 | + } finally { |
| 106 | + closeQuietly(pair); |
| 107 | + } |
| 108 | + } |
| 109 | + |
| 110 | + private Thread startServerReader(final SSLSocket server, final long[] serverReceived) { |
| 111 | + Thread serverReader = new Thread(() -> { |
| 112 | + try { |
| 113 | + RubyFixnum len = RubyFixnum.newFixnum(runtime, 8192); |
| 114 | + while (true) { |
| 115 | + IRubyObject data = server.sysread(currentContext(), len); |
| 116 | + serverReceived[0] += ((RubyString) data).getByteList().getRealSize(); |
| 117 | + } |
| 118 | + } catch (RaiseException e) { |
| 119 | + String rubyClass = e.getException().getMetaClass().getName(); |
| 120 | + // EOFError or IOError expected when client closes the connection |
| 121 | + if (!rubyClass.equals("EOFError") && !rubyClass.equals("IOError")) { |
| 122 | + System.err.println("server reader unexpected: " + rubyClass + ": " + e.getMessage()); |
| 123 | + e.printStackTrace(System.err); |
| 124 | + } |
| 125 | + } |
| 126 | + }); |
| 127 | + serverReader.start(); |
| 128 | + return serverReader; |
| 129 | + } |
| 130 | + |
| 131 | + /** |
| 132 | + * After saturating the TCP send buffer with {@code syswrite_nonblock}, |
| 133 | + * inspect {@code netWriteData} to verify the buffer is consistent. |
| 134 | + */ |
| 135 | + @Test |
| 136 | + public void syswriteNonblockNetWriteDataConsistency() { |
| 137 | + final RubyArray pair = (RubyArray) runtime.evalScriptlet(start_ssl_server_rb()); |
| 138 | + SSLSocket client = (SSLSocket) pair.entry(0).toJava(SSLSocket.class); |
| 139 | + |
| 140 | + try { |
| 141 | + assertNotNull("netWriteData initialized after handshake", client.netWriteData); |
| 142 | + |
| 143 | + // Saturate: server is not reading yet, so backpressure builds |
| 144 | + byte[] chunk = new byte[16384]; |
| 145 | + java.util.Arrays.fill(chunk, (byte) 'S'); |
| 146 | + RubyString payload = RubyString.newString(runtime, chunk); |
| 147 | + |
| 148 | + int successfulWrites = 0; |
| 149 | + for (int i = 0; i < 200; i++) { |
| 150 | + try { |
| 151 | + client.syswrite_nonblock(currentContext(), payload); |
| 152 | + successfulWrites++; |
| 153 | + } catch (RaiseException e) { |
| 154 | + String rubyClass = e.getException().getMetaClass().getName(); |
| 155 | + if (rubyClass.contains("WaitWritable") || rubyClass.equals("IOError")) { |
| 156 | + break; // buffer saturated — expected |
| 157 | + } |
| 158 | + System.err.println("saturate loop unexpected: " + rubyClass + ": " + e.getMessage()); |
| 159 | + throw e; |
| 160 | + } |
| 161 | + } |
| 162 | + assertTrue("at least one write should succeed", successfulWrites > 0); |
| 163 | + |
| 164 | + // Inspect netWriteData directly |
| 165 | + ByteBuffer netWriteData = client.netWriteData; |
| 166 | + assertTrue("position <= limit", netWriteData.position() <= netWriteData.limit()); |
| 167 | + assertTrue("limit <= capacity", netWriteData.limit() <= netWriteData.capacity()); |
| 168 | + |
| 169 | + // If there are unflushed bytes, compact() preserved them |
| 170 | + if (netWriteData.remaining() > 0) { |
| 171 | + // The bytes should be valid TLS record data, not zeroed memory |
| 172 | + byte b = netWriteData.get(netWriteData.position()); |
| 173 | + assertNotEquals("preserved bytes should be TLS data, not zeroed", 0, b); |
| 174 | + } |
| 175 | + |
| 176 | + } finally { |
| 177 | + closeQuietly(pair); |
| 178 | + } |
| 179 | + } |
| 180 | + |
| 181 | + private ThreadContext currentContext() { |
| 182 | + return runtime.getCurrentContext(); |
| 183 | + } |
| 184 | + |
| 185 | + private void closeQuietly(final RubyArray sslPair) { |
| 186 | + for (int i = 0; i < sslPair.getLength(); i++) { |
| 187 | + try { sslPair.entry(i).callMethod(currentContext(), "close"); } |
| 188 | + catch (RaiseException e) { /* already closed */ } |
| 189 | + } |
| 190 | + } |
| 191 | + |
| 192 | + static String readResource(final String resource) { |
| 193 | + int n; |
| 194 | + try (InputStream in = SSLSocketTest.class.getResourceAsStream(resource)) { |
| 195 | + if (in == null) throw new IllegalArgumentException(resource + " not found on classpath"); |
| 196 | + |
| 197 | + ByteArrayOutputStream out = new ByteArrayOutputStream(); |
| 198 | + byte[] buf = new byte[8192]; |
| 199 | + while ((n = in.read(buf)) != -1) out.write(buf, 0, n); |
| 200 | + return new String(out.toByteArray(), StandardCharsets.UTF_8); |
| 201 | + } catch (IOException e) { |
| 202 | + throw new IllegalStateException("failed to load" + resource, e); |
| 203 | + } |
| 204 | + } |
| 205 | +} |
0 commit comments