Skip to content

Commit 337d578

Browse files
authored
Merge pull request #4 from StudioLambda/fix/high-security-vulnerabilities
Fix high-severity security vulnerabilities
2 parents 9edd6a0 + b779ef2 commit 337d578

25 files changed

Lines changed: 603 additions & 150 deletions

contract/request/body.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,26 +44,26 @@ func LimitedBytes(r *http.Request, maxSize int64) ([]byte, error) {
4444
// Prefer [LimitedString] or apply [http.MaxBytesReader] in a
4545
// middleware to prevent memory exhaustion from oversized requests.
4646
func String(r *http.Request) (string, error) {
47-
b, err := Bytes(r)
47+
body, err := Bytes(r)
4848

4949
if err != nil {
5050
return "", err
5151
}
5252

53-
return string(b), nil
53+
return string(body), nil
5454
}
5555

5656
// LimitedString reads the request body up to maxSize bytes and
5757
// returns it as a string. If the body exceeds maxSize, the result
5858
// is truncated. Pass -1 to use [DefaultMaxBodySize].
5959
func LimitedString(r *http.Request, maxSize int64) (string, error) {
60-
b, err := LimitedBytes(r, maxSize)
60+
body, err := LimitedBytes(r, maxSize)
6161

6262
if err != nil {
6363
return "", err
6464
}
6565

66-
return string(b), nil
66+
return string(body), nil
6767
}
6868

6969
// JSON decodes JSON data from the request body into a value of type T.

contract/request/cookie.go

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@ import "net/http"
88
//
99
// Parameters:
1010
// - r: The HTTP request to search for the cookie
11-
// - k: The name of the cookie to retrieve
11+
// - name: The name of the cookie to retrieve
1212
//
1313
// Returns the cookie object or nil if not found.
14-
func Cookie(r *http.Request, k string) *http.Cookie {
15-
if cookie, err := r.Cookie(k); err == nil {
14+
func Cookie(r *http.Request, name string) *http.Cookie {
15+
if cookie, err := r.Cookie(name); err == nil {
1616
return cookie
1717
}
1818

@@ -25,12 +25,12 @@ func Cookie(r *http.Request, k string) *http.Cookie {
2525
//
2626
// Parameters:
2727
// - r: The HTTP request to search for the cookie
28-
// - k: The name of the cookie whose value to retrieve
28+
// - name: The name of the cookie whose value to retrieve
2929
//
3030
// Returns the cookie value as a string, or empty string if not found.
31-
func CookieValue(r *http.Request, k string) string {
32-
if c := Cookie(r, k); c != nil {
33-
return c.Value
31+
func CookieValue(r *http.Request, name string) string {
32+
if cookie := Cookie(r, name); cookie != nil {
33+
return cookie.Value
3434
}
3535

3636
return ""
@@ -42,14 +42,14 @@ func CookieValue(r *http.Request, k string) string {
4242
//
4343
// Parameters:
4444
// - r: The HTTP request to search for the cookie
45-
// - k: The name of the cookie whose value to retrieve
46-
// - d: The default value to return if the cookie is not found or empty
45+
// - name: The name of the cookie whose value to retrieve
46+
// - fallback: The default value to return if the cookie is not found or empty
4747
//
4848
// Returns the cookie value if found and non-empty, otherwise the default value.
49-
func CookieValueOr(r *http.Request, k string, d string) string {
50-
if v := CookieValue(r, k); v != "" {
51-
return v
49+
func CookieValueOr(r *http.Request, name string, fallback string) string {
50+
if value := CookieValue(r, name); value != "" {
51+
return value
5252
}
5353

54-
return d
54+
return fallback
5555
}

contract/request/header.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@ func Header(r *http.Request, key string) string {
1717

1818
// HeaderOr returns the first value for the given header key, falling
1919
// back to the provided default value if the header is missing or empty.
20-
func HeaderOr(r *http.Request, key string, def string) string {
21-
if h := Header(r, key); h != "" {
22-
return h
20+
func HeaderOr(r *http.Request, key string, fallback string) string {
21+
if value := Header(r, key); value != "" {
22+
return value
2323
}
2424

25-
return def
25+
return fallback
2626
}
2727

2828
// HeaderValues returns all values associated with the given header

contract/request/hooks.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,24 @@ var ErrNoHooksMiddleware = problem.Problem{
1818
// Hooks returns the hooks instance used to attach callbacks
1919
// to lifecycle events. It panics if the hooks middleware has
2020
// not been applied to the request's context.
21+
//
22+
// WARNING: This function panics when hooks are missing. Use
23+
// [TryHooks] for a non-panicking alternative, or ensure the
24+
// [framework.Recover] middleware is in place.
2125
func Hooks(r *http.Request) contract.Hooks {
2226
if hooks, ok := r.Context().Value(contract.HooksKey).(contract.Hooks); ok {
2327
return hooks
2428
}
2529

2630
panic(ErrNoHooksMiddleware)
2731
}
32+
33+
// TryHooks retrieves the hooks instance from the request
34+
// context without panicking. The boolean return value indicates
35+
// whether hooks were found. This is the safe alternative to
36+
// [Hooks] for use outside the framework handler chain.
37+
func TryHooks(r *http.Request) (contract.Hooks, bool) {
38+
hooks, ok := r.Context().Value(contract.HooksKey).(contract.Hooks)
39+
40+
return hooks, ok
41+
}

contract/request/param.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@ import "net/http"
88
//
99
// Parameters:
1010
// - r: The HTTP request containing the path parameters
11-
// - k: The name of the path parameter to retrieve
11+
// - name: The name of the path parameter to retrieve
1212
//
1313
// Returns the parameter value as a string, or empty string if not found.
14-
func Param(r *http.Request, k string) string {
15-
return r.PathValue(k)
14+
func Param(r *http.Request, name string) string {
15+
return r.PathValue(name)
1616
}
1717

1818
// ParamOr retrieves a path parameter value by name, returning a default
@@ -22,14 +22,14 @@ func Param(r *http.Request, k string) string {
2222
//
2323
// Parameters:
2424
// - r: The HTTP request containing the path parameters
25-
// - k: The name of the path parameter to retrieve
26-
// - d: The default value to return if the parameter is not found or empty
25+
// - name: The name of the path parameter to retrieve
26+
// - fallback: The default value to return if the parameter is not found or empty
2727
//
2828
// Returns the parameter value if found and non-empty, otherwise the default value.
29-
func ParamOr(r *http.Request, k string, d string) string {
30-
if p := Param(r, k); p != "" {
31-
return p
29+
func ParamOr(r *http.Request, name string, fallback string) string {
30+
if value := Param(r, name); value != "" {
31+
return value
3232
}
3333

34-
return d
34+
return fallback
3535
}

contract/request/query.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@ import "net/http"
88
//
99
// Parameters:
1010
// - r: The HTTP request containing the URL with query parameters
11-
// - k: The name of the query parameter to retrieve
11+
// - name: The name of the query parameter to retrieve
1212
//
1313
// Returns the first value associated with the key, or empty string if not found.
14-
func Query(r *http.Request, k string) string {
15-
return r.URL.Query().Get(k)
14+
func Query(r *http.Request, name string) string {
15+
return r.URL.Query().Get(name)
1616
}
1717

1818
// HasQuery checks if a query parameter exists in the HTTP request URL,
@@ -21,11 +21,11 @@ func Query(r *http.Request, k string) string {
2121
//
2222
// Parameters:
2323
// - r: The HTTP request containing the URL with query parameters
24-
// - k: The name of the query parameter to check for
24+
// - name: The name of the query parameter to check for
2525
//
2626
// Returns true if the parameter exists in the query string, false otherwise.
27-
func HasQuery(r *http.Request, k string) bool {
28-
return r.URL.Query().Has(k)
27+
func HasQuery(r *http.Request, name string) bool {
28+
return r.URL.Query().Has(name)
2929
}
3030

3131
// QueryOr retrieves a query parameter value by name, returning a default
@@ -35,14 +35,14 @@ func HasQuery(r *http.Request, k string) bool {
3535
//
3636
// Parameters:
3737
// - r: The HTTP request containing the URL with query parameters
38-
// - k: The name of the query parameter to retrieve
39-
// - d: The default value to return if the parameter doesn't exist
38+
// - name: The name of the query parameter to retrieve
39+
// - fallback: The default value to return if the parameter doesn't exist
4040
//
4141
// Returns the parameter value if it exists, otherwise the default value.
42-
func QueryOr(r *http.Request, k string, d string) string {
43-
if HasQuery(r, k) {
44-
return Query(r, k)
42+
func QueryOr(r *http.Request, name string, fallback string) string {
43+
if HasQuery(r, name) {
44+
return Query(r, name)
4545
}
4646

47-
return d
47+
return fallback
4848
}

contract/request/session.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ func SessionKeyed(r *http.Request, key any) (contract.Session, bool) {
3434
// MustSessionKeyed retrieves the session from the request context
3535
// using the provided key. It panics with [ErrSessionNotFound] if
3636
// no session is found.
37+
//
38+
// WARNING: This function panics when the session is missing. Use
39+
// [SessionKeyed] for a non-panicking alternative, or ensure the
40+
// [framework.Recover] middleware is in place.
3741
func MustSessionKeyed(r *http.Request, key any) contract.Session {
3842
if s, ok := SessionKeyed(r, key); ok {
3943
return s
@@ -45,6 +49,10 @@ func MustSessionKeyed(r *http.Request, key any) contract.Session {
4549
// MustSession retrieves the session from the request context using
4650
// the default [contract.SessionKey]. It panics with [ErrSessionNotFound]
4751
// if no session is found.
52+
//
53+
// WARNING: This function panics when the session is missing. Use
54+
// [Session] for a non-panicking alternative, or ensure the
55+
// [framework.Recover] middleware is in place.
4856
func MustSession(r *http.Request) contract.Session {
4957
return MustSessionKeyed(r, contract.SessionKey)
5058
}

contract/response/static.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@ package response
33
import (
44
"encoding/json"
55
"encoding/xml"
6+
"errors"
67
htmltemplate "html/template"
78
"net/http"
9+
"net/url"
10+
"strings"
811
"text/template"
912
)
1013

@@ -159,6 +162,9 @@ func XML(w http.ResponseWriter, status int, data any) error {
159162
// This is a generic redirect function that allows you to specify any redirect status code.
160163
// The Location header is set to the provided URL and the appropriate status code is returned.
161164
//
165+
// WARNING: This function does not validate the redirect URL. If the URL comes
166+
// from user input, use [SafeRedirect] instead to prevent open redirect attacks.
167+
//
162168
// Common redirect status codes:
163169
// - 301: Moved Permanently
164170
// - 302: Found (temporary redirect)
@@ -175,3 +181,52 @@ func Redirect(w http.ResponseWriter, status int, url string) error {
175181

176182
return Status(w, status)
177183
}
184+
185+
// ErrUnsafeRedirect is returned by [SafeRedirect] when the target
186+
// URL is not a safe relative path. This prevents open redirect
187+
// attacks where an attacker tricks users into visiting a malicious
188+
// external site via your application's redirect endpoint.
189+
var ErrUnsafeRedirect = errors.New("unsafe redirect URL: must be a relative path")
190+
191+
// SafeRedirect sends an HTTP redirect response only if the target
192+
// URL is a safe relative path (starts with "/" and does not contain
193+
// a scheme, host, or protocol-relative prefix "//"). This prevents
194+
// open redirect vulnerabilities when the redirect target comes from
195+
// user input such as query parameters or form fields.
196+
//
197+
// Returns [ErrUnsafeRedirect] if the URL fails validation.
198+
//
199+
// Parameters:
200+
// - w: The HTTP response writer
201+
// - status: The HTTP redirect status code to set
202+
// - rawURL: The URL to redirect the user to (must be a relative path)
203+
func SafeRedirect(w http.ResponseWriter, status int, rawURL string) error {
204+
if !isRelativePath(rawURL) {
205+
return ErrUnsafeRedirect
206+
}
207+
208+
return Redirect(w, status, rawURL)
209+
}
210+
211+
// isRelativePath validates that a URL is a safe relative path
212+
// and not an absolute URL, protocol-relative URL, javascript:
213+
// URI, or data: URI that could be used in an open redirect attack.
214+
func isRelativePath(rawURL string) bool {
215+
if !strings.HasPrefix(rawURL, "/") {
216+
return false
217+
}
218+
219+
// Reject protocol-relative URLs like "//evil.com"
220+
if strings.HasPrefix(rawURL, "//") {
221+
return false
222+
}
223+
224+
parsed, err := url.Parse(rawURL)
225+
226+
if err != nil {
227+
return false
228+
}
229+
230+
// Reject if a scheme or host is present.
231+
return parsed.Scheme == "" && parsed.Host == ""
232+
}

framework/cache/redis.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ func NewRedisFrom(client *redis.Client) *RedisClient {
3535
// Get retrieves a value by key. Returns contract.ErrCacheKeyNotFound
3636
// wrapped with the key name when the key does not exist.
3737
func (client *RedisClient) Get(ctx context.Context, key string) (any, error) {
38-
v, err := (*redis.Client)(client).Get(ctx, key).Result()
38+
value, err := (*redis.Client)(client).Get(ctx, key).Result()
3939

4040
if errors.Is(err, redis.Nil) {
4141
return nil, fmt.Errorf("%w: %s", contract.ErrCacheKeyNotFound, key)
@@ -45,7 +45,7 @@ func (client *RedisClient) Get(ctx context.Context, key string) (any, error) {
4545
return nil, err
4646
}
4747

48-
return v, nil
48+
return value, nil
4949
}
5050

5151
// Put stores a value with the given TTL. A zero TTL means the key
@@ -61,18 +61,18 @@ func (client *RedisClient) Delete(ctx context.Context, key string) error {
6161

6262
// Has reports whether the key exists in Redis.
6363
func (client *RedisClient) Has(ctx context.Context, key string) (bool, error) {
64-
n, err := (*redis.Client)(client).Exists(ctx, key).Result()
64+
count, err := (*redis.Client)(client).Exists(ctx, key).Result()
6565

6666
if err != nil {
6767
return false, err
6868
}
6969

70-
return n > 0, nil
70+
return count > 0, nil
7171
}
7272

7373
// Pull atomically retrieves and deletes a key using Redis GETDEL.
7474
// The stored value is JSON-decoded into the return value.
75-
func (client *RedisClient) Pull(ctx context.Context, key string) (v any, e error) {
75+
func (client *RedisClient) Pull(ctx context.Context, key string) (value any, err error) {
7676
encoded, err := (*redis.Client)(client).GetDel(ctx, key).Result()
7777

7878
if errors.Is(err, redis.Nil) {
@@ -83,11 +83,11 @@ func (client *RedisClient) Pull(ctx context.Context, key string) (v any, e error
8383
return nil, err
8484
}
8585

86-
if err := json.Unmarshal([]byte(encoded), &v); err != nil {
86+
if err := json.Unmarshal([]byte(encoded), &value); err != nil {
8787
return nil, err
8888
}
8989

90-
return v, nil
90+
return value, nil
9191
}
9292

9393
// Forever stores a value with no expiration.

0 commit comments

Comments
 (0)