Skip to content

Commit 6bdb86d

Browse files
committed
chore: more tests
Signed-off-by: Valery Piashchynski <piashchynski.valery@gmail.com>
1 parent 290b234 commit 6bdb86d

2 files changed

Lines changed: 216 additions & 0 deletions

File tree

handler/handler_test.go

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,123 @@ func TestServeHTTP_PoolExecError_Returns500(t *testing.T) {
217217
}
218218
}
219219

220+
// ── Group B′: FetchIP edge cases ─────────────────────────────────────────────
221+
222+
func TestFetchIP_EdgeCases(t *testing.T) {
223+
tests := []struct {
224+
name string
225+
input string
226+
want string
227+
}{
228+
{"empty string", "", ""},
229+
{"ipv4 no port", "10.0.0.1", "10.0.0.1"},
230+
{"ipv6 bracketed with port", "[::1]:8080", "::1"},
231+
{"ipv6 full address bare", "2001:db8::1", "2001:db8::1"},
232+
{"garbage with colons", "not:a:valid:thing", ""},
233+
{"port only", ":8080", ""},
234+
{"ipv4 with empty port", "127.0.0.1:", "127.0.0.1"},
235+
{"ipv6 full with port", "[2001:db8::1]:443", "2001:db8::1"},
236+
}
237+
238+
log := zap.NewNop()
239+
for _, tt := range tests {
240+
t.Run(tt.name, func(t *testing.T) {
241+
got := FetchIP(tt.input, log)
242+
if got != tt.want {
243+
t.Errorf("FetchIP(%q) = %q, want %q", tt.input, got, tt.want)
244+
}
245+
})
246+
}
247+
}
248+
249+
// ── Group B″: URI edge cases ─────────────────────────────────────────────────
250+
251+
func TestURI_EdgeCases(t *testing.T) {
252+
tests := []struct {
253+
name string
254+
setup func() *http.Request
255+
want string
256+
}{
257+
{
258+
name: "empty host",
259+
setup: func() *http.Request {
260+
r := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", nil)
261+
r.Host = ""
262+
return r
263+
},
264+
want: "http:///",
265+
},
266+
{
267+
name: "host with port",
268+
setup: func() *http.Request {
269+
r := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/p", nil)
270+
r.Host = "example.com:8080"
271+
return r
272+
},
273+
want: "http://example.com:8080/p",
274+
},
275+
{
276+
name: "url already has host set",
277+
setup: func() *http.Request {
278+
r := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/x", nil)
279+
r.URL.Host = "other.com"
280+
return r
281+
},
282+
want: "//other.com/x",
283+
},
284+
{
285+
name: "root path only",
286+
setup: func() *http.Request {
287+
r := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", nil)
288+
r.Host = "example.com"
289+
return r
290+
},
291+
want: "http://example.com/",
292+
},
293+
{
294+
name: "query but no path",
295+
setup: func() *http.Request {
296+
r := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", nil)
297+
r.Host = "example.com"
298+
r.URL.Path = ""
299+
r.URL.RawQuery = "a=1"
300+
return r
301+
},
302+
want: "http://example.com?a=1",
303+
},
304+
{
305+
name: "encoded CRLF in path preserved",
306+
setup: func() *http.Request {
307+
r := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/foo%0D%0Abar", nil)
308+
r.Host = "example.com"
309+
return r
310+
},
311+
want: "http://example.com/foo%0D%0Abar",
312+
},
313+
{
314+
name: "tab in query not stripped",
315+
setup: func() *http.Request {
316+
r := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", nil)
317+
r.Host = "example.com"
318+
r.URL.RawQuery = "x=1\tX-Bad: true"
319+
return r
320+
},
321+
want: "http://example.com/?x=1\tX-Bad: true",
322+
},
323+
}
324+
325+
for _, tt := range tests {
326+
t.Run(tt.name, func(t *testing.T) {
327+
got := URI(tt.setup())
328+
if got != tt.want {
329+
t.Errorf("URI() = %q, want %q", got, tt.want)
330+
}
331+
})
332+
}
333+
}
334+
335+
// ── Group C: mockPool tests ───────────────────────────────────────────────────
336+
220337
func TestServeHTTP_NoFreeWorkers_SetsHeader(t *testing.T) {
221338
mp := &mockPool{execErr: errors.E(errors.NoFreeWorkers)}
222339
h := newTestHandler(t, defaultCfg(), mp)

middleware/redirect_test.go

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package middleware
22

33
import (
4+
"net/http"
5+
"net/http/httptest"
46
"testing"
57

68
"github.com/stretchr/testify/assert"
@@ -31,6 +33,12 @@ func TestTLSAddr(t *testing.T) {
3133
// IPv6 without port (bare bracketed form from r.Host)
3234
{"ipv6 no port, default ssl port", "[::1]", false, 443, "[::1]"},
3335
{"ipv6 no port, force port", "[::1]", true, 443, "[::1]:443"},
36+
37+
// Edge cases — degenerate inputs
38+
{"empty host", "", false, 443, ""},
39+
{"empty host forced port", "", true, 443, ":443"},
40+
{"host with trailing colon", "example.com:", false, 443, "example.com"},
41+
{"ip literal with zone", "[fe80::1%25eth0]:80", false, 443, "[fe80::1%eth0]"},
3442
}
3543

3644
for _, tt := range tests {
@@ -40,3 +48,94 @@ func TestTLSAddr(t *testing.T) {
4048
})
4149
}
4250
}
51+
52+
func TestRedirect(t *testing.T) {
53+
tests := []struct {
54+
name string
55+
method string
56+
host string
57+
path string
58+
query string
59+
sslPort int
60+
wantLoc string
61+
wantCode int
62+
}{
63+
{
64+
name: "basic http to https",
65+
method: http.MethodGet,
66+
host: "example.com",
67+
path: "/page",
68+
sslPort: 443,
69+
wantLoc: "https://example.com/page",
70+
wantCode: http.StatusPermanentRedirect,
71+
},
72+
{
73+
name: "preserves query string",
74+
method: http.MethodGet,
75+
host: "example.com",
76+
path: "/search",
77+
query: "q=hello&lang=en",
78+
sslPort: 443,
79+
wantLoc: "https://example.com/search?q=hello&lang=en",
80+
wantCode: http.StatusPermanentRedirect,
81+
},
82+
{
83+
name: "non-default ssl port",
84+
method: http.MethodGet,
85+
host: "example.com",
86+
path: "/",
87+
sslPort: 8443,
88+
wantLoc: "https://example.com:8443/",
89+
wantCode: http.StatusPermanentRedirect,
90+
},
91+
{
92+
name: "POST gets 308 not 301",
93+
method: http.MethodPost,
94+
host: "example.com",
95+
path: "/api",
96+
sslPort: 443,
97+
wantLoc: "https://example.com/api",
98+
wantCode: http.StatusPermanentRedirect,
99+
},
100+
{
101+
name: "IPv6 host",
102+
method: http.MethodGet,
103+
host: "[::1]:8080",
104+
path: "/",
105+
sslPort: 443,
106+
wantLoc: "https://[::1]/",
107+
wantCode: http.StatusPermanentRedirect,
108+
},
109+
{
110+
name: "empty path",
111+
method: http.MethodGet,
112+
host: "example.com",
113+
path: "",
114+
sslPort: 443,
115+
wantLoc: "https://example.com",
116+
wantCode: http.StatusPermanentRedirect,
117+
},
118+
}
119+
120+
for _, tt := range tests {
121+
t.Run(tt.name, func(t *testing.T) {
122+
req := httptest.NewRequestWithContext(t.Context(), tt.method, "/", nil)
123+
req.Host = tt.host
124+
req.URL.Path = tt.path
125+
req.URL.RawQuery = tt.query
126+
127+
called := false
128+
inner := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
129+
called = true
130+
})
131+
132+
rr := httptest.NewRecorder()
133+
Redirect(inner, tt.sslPort).ServeHTTP(rr, req)
134+
135+
assert.Equal(t, tt.wantCode, rr.Code)
136+
assert.Equal(t, tt.wantLoc, rr.Header().Get("Location"))
137+
assert.NotEmpty(t, rr.Header().Get("Strict-Transport-Security"), "STS header must be set")
138+
assert.False(t, called, "inner handler should not be called")
139+
})
140+
}
141+
}

0 commit comments

Comments
 (0)