|
1 | 1 | # frozen_string_literal: true |
2 | 2 |
|
3 | | -require "webrick" |
4 | | -require "webrick/ssl" |
| 3 | +require "socket" |
| 4 | +require "openssl" |
5 | 5 |
|
6 | | -require "support/black_hole" |
7 | 6 | require "support/dummy_server/servlet" |
8 | | -require "support/servers/config" |
9 | 7 | require "support/servers/runner" |
10 | 8 | require "support/ssl_helper" |
11 | 9 |
|
12 | | -class DummyServer < WEBrick::HTTPServer |
13 | | - include ServerConfig |
14 | | - |
15 | | - CONFIG = { |
16 | | - BindAddress: "127.0.0.1", |
17 | | - Port: 0, |
18 | | - AccessLog: BlackHole, |
19 | | - Logger: BlackHole |
20 | | - }.freeze |
21 | | - |
22 | | - SSL_CONFIG = CONFIG.merge( |
23 | | - SSLEnable: true, |
24 | | - SSLStartImmediately: true |
25 | | - ).freeze |
26 | | - |
| 10 | +class DummyServer |
27 | 11 | def initialize(options = {}) |
28 | | - super(options[:ssl] ? SSL_CONFIG : CONFIG) |
29 | | - @memo = {} |
30 | | - mount("/", Servlet, @memo) |
| 12 | + @ssl = options[:ssl] |
| 13 | + @tcp_server = TCPServer.new("127.0.0.1", 0) |
| 14 | + @port = @tcp_server.addr[1] |
| 15 | + @memo = {} |
| 16 | + @servlet = Servlet.new(self, @memo) |
| 17 | + @running = false |
| 18 | + end |
| 19 | + |
| 20 | + def addr |
| 21 | + "127.0.0.1" |
31 | 22 | end |
32 | 23 |
|
| 24 | + attr_reader :port |
| 25 | + |
33 | 26 | def endpoint |
34 | 27 | "#{scheme}://#{addr}:#{port}" |
35 | 28 | end |
36 | 29 |
|
37 | 30 | def scheme |
38 | | - config[:SSLEnable] ? "https" : "http" |
| 31 | + @ssl ? "https" : "http" |
| 32 | + end |
| 33 | + |
| 34 | + def start |
| 35 | + server = @ssl ? ssl_server : @tcp_server |
| 36 | + @running = true |
| 37 | + |
| 38 | + while @running |
| 39 | + client = server.accept |
| 40 | + Thread.new(client) { |c| handle_connection(c) } |
| 41 | + end |
| 42 | + rescue IOError, Errno::EBADF |
| 43 | + # Server socket closed during shutdown |
| 44 | + end |
| 45 | + |
| 46 | + def shutdown |
| 47 | + @running = false |
| 48 | + @tcp_server.close |
| 49 | + rescue |
| 50 | + nil |
39 | 51 | end |
40 | 52 |
|
41 | 53 | def ssl_context |
42 | 54 | @ssl_context ||= SSLHelper.server_context |
43 | 55 | end |
| 56 | + |
| 57 | + private |
| 58 | + |
| 59 | + def ssl_server |
| 60 | + OpenSSL::SSL::SSLServer.new(@tcp_server, ssl_context) |
| 61 | + end |
| 62 | + |
| 63 | + def handle_connection(client) |
| 64 | + loop do |
| 65 | + request = read_request(client) |
| 66 | + break unless request |
| 67 | + |
| 68 | + Thread.pass |
| 69 | + respond(client, request) |
| 70 | + end |
| 71 | + rescue IOError, Errno::ECONNRESET, Errno::EPIPE, Errno::EPROTOTYPE, OpenSSL::SSL::SSLError |
| 72 | + # Connection closed or SSL error |
| 73 | + ensure |
| 74 | + client.close rescue nil # rubocop:disable Style/RescueModifier |
| 75 | + end |
| 76 | + |
| 77 | + def respond(client, request) |
| 78 | + response = Response.new |
| 79 | + @servlet.dispatch(request, response) |
| 80 | + client.write(response.serialize(head_request: request.request_method == "HEAD")) |
| 81 | + end |
| 82 | + |
| 83 | + def read_request(client) |
| 84 | + line = client.gets |
| 85 | + return unless line |
| 86 | + |
| 87 | + method, uri, = line.split(" ", 3) |
| 88 | + return bad_request(client) unless uri.ascii_only? |
| 89 | + |
| 90 | + raw_path, query_string = uri.split("?", 2) |
| 91 | + headers = read_headers(client) |
| 92 | + |
| 93 | + Request.new({ |
| 94 | + request_method: method, request_path: percent_decode(raw_path), |
| 95 | + query_string: query_string, headers: headers, |
| 96 | + body: read_body(client, headers), socket: client, unparsed_uri: uri |
| 97 | + }) |
| 98 | + end |
| 99 | + |
| 100 | + def read_headers(client) |
| 101 | + headers = {} |
| 102 | + while (header_line = client.gets) |
| 103 | + break if header_line == "\r\n" |
| 104 | + |
| 105 | + key, value = header_line.split(": ", 2) |
| 106 | + headers[key.downcase] = value.strip |
| 107 | + end |
| 108 | + headers |
| 109 | + end |
| 110 | + |
| 111 | + def read_body(client, headers) |
| 112 | + content_length = headers["content-length"] |
| 113 | + client.read(content_length.to_i) if content_length |
| 114 | + end |
| 115 | + |
| 116 | + def bad_request(client) |
| 117 | + client.write("HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\nConnection: close\r\n\r\n") |
| 118 | + nil |
| 119 | + end |
| 120 | + |
| 121 | + def percent_decode(str) |
| 122 | + str.b.gsub(/%([0-9A-Fa-f]{2})/) { [::Regexp.last_match(1)].pack("H2") } |
| 123 | + end |
44 | 124 | end |
0 commit comments