Skip to content

Commit fb6e474

Browse files
authored
Merge pull request #12 from StudioLambda/fix/medium-severity-audit-findings
Fix medium-severity audit findings across all modules
2 parents e0f5eab + 3ec96dc commit fb6e474

50 files changed

Lines changed: 1860 additions & 282 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

contract/hash.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,28 @@ package contract
33
// Hasher defines the interface for hashing and verifying hashed values.
44
// Implementations of Hasher are responsible for generating cryptographic hashes
55
// and verifying that values match their corresponding hashes.
6+
//
7+
// Both Hash and Check zero the input value slice after use as a security
8+
// measure. Callers must not reuse the value slice after calling either method.
69
type Hasher interface {
710
// Hash computes a cryptographic hash of the given byte slice and returns the hash.
11+
// The input value is zeroed after hashing as a security measure.
12+
// Callers must not reuse the value slice after calling Hash.
813
// It returns an error if the hashing operation fails.
914
Hash(value []byte) ([]byte, error)
1015

1116
// Check verifies that the given value matches the provided hash.
17+
// The input value is zeroed after verification as a security measure.
18+
// Callers must not reuse the value slice after calling Check.
1219
// It returns true if the value and hash match, false otherwise.
1320
// It returns an error if the verification operation fails.
1421
Check(value []byte, hash []byte) (bool, error)
1522
}
23+
24+
// Rehashable extends [Hasher] with the ability to detect stale hash parameters.
25+
// Implementations should return true when the given hash was produced with
26+
// different parameters than the current configuration, indicating the value
27+
// should be re-hashed on the next successful authentication.
28+
type Rehashable interface {
29+
NeedsRehash(hash []byte) bool
30+
}

contract/request/query.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,13 @@ func HasQuery(r *http.Request, name string) bool {
4848
//
4949
// Returns the parameter value if it exists, otherwise the default value.
5050
func QueryOr(r *http.Request, name string, fallback string) string {
51-
if HasQuery(r, name) {
52-
return Query(r, name)
51+
values := r.URL.Query()
52+
53+
if !values.Has(name) {
54+
return fallback
5355
}
5456

55-
return fallback
57+
return values.Get(name)
5658
}
5759

5860
// QueryInt retrieves a query parameter by name and parses

contract/request/query_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,3 +157,33 @@ func TestQueryIntOrReturnsFallbackWhenNotInteger(t *testing.T) {
157157

158158
require.Equal(t, 1, result)
159159
}
160+
161+
func TestQueryOrParsesOnceReturnsValueWhenPresent(t *testing.T) {
162+
t.Parallel()
163+
164+
r := httptest.NewRequest(http.MethodGet, "/?key=value", nil)
165+
166+
result := request.QueryOr(r, "key", "default")
167+
168+
require.Equal(t, "value", result)
169+
}
170+
171+
func TestQueryOrParsesOnceReturnsFallbackWhenMissing(t *testing.T) {
172+
t.Parallel()
173+
174+
r := httptest.NewRequest(http.MethodGet, "/", nil)
175+
176+
result := request.QueryOr(r, "key", "default")
177+
178+
require.Equal(t, "default", result)
179+
}
180+
181+
func TestQueryOrParsesOnceReturnsEmptyWhenPresentButEmpty(t *testing.T) {
182+
t.Parallel()
183+
184+
r := httptest.NewRequest(http.MethodGet, "/?key=", nil)
185+
186+
result := request.QueryOr(r, "key", "default")
187+
188+
require.Equal(t, "", result)
189+
}

contract/response/static.go

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package response
22

33
import (
4+
"bytes"
45
"encoding/json"
56
"encoding/xml"
67
"errors"
@@ -87,10 +88,18 @@ func String(w http.ResponseWriter, status int, data string) error {
8788
// - tmpl: The text template to execute
8889
// - data: The data to pass to the template for execution
8990
func StringTemplate(w http.ResponseWriter, status int, tmpl template.Template, data any) error {
91+
var buf bytes.Buffer
92+
93+
if err := tmpl.Execute(&buf, data); err != nil {
94+
return err
95+
}
96+
9097
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
9198
w.WriteHeader(status)
9299

93-
return tmpl.Execute(w, data)
100+
_, err := w.Write(buf.Bytes())
101+
102+
return err
94103
}
95104

96105
// HTML writes HTML content to the response writer with the
@@ -121,10 +130,18 @@ func HTML(w http.ResponseWriter, status int, data string) error {
121130
// - tmpl: The HTML template to execute (must be html/template for XSS safety)
122131
// - data: The data to pass to the template for execution
123132
func HTMLTemplate(w http.ResponseWriter, status int, tmpl htmltemplate.Template, data any) error {
133+
var buf bytes.Buffer
134+
135+
if err := tmpl.Execute(&buf, data); err != nil {
136+
return err
137+
}
138+
124139
w.Header().Set("Content-Type", "text/html; charset=utf-8")
125140
w.WriteHeader(status)
126141

127-
return tmpl.Execute(w, data)
142+
_, err := w.Write(buf.Bytes())
143+
144+
return err
128145
}
129146

130147
// JSON serializes the given data to JSON format and writes it to the

contract/response/static_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,3 +389,33 @@ func TestErrUnsafeRedirectMessage(t *testing.T) {
389389
response.ErrUnsafeRedirect.Error(),
390390
)
391391
}
392+
393+
func TestStringTemplateBuffersBeforeWritingStatus(t *testing.T) {
394+
t.Parallel()
395+
396+
w := httptest.NewRecorder()
397+
tmpl := template.Must(
398+
template.New("test").Parse("{{.Name}}"),
399+
)
400+
401+
err := response.StringTemplate(w, http.StatusOK, *tmpl, 42)
402+
403+
require.Error(t, err)
404+
require.Equal(t, http.StatusOK, w.Code)
405+
require.Empty(t, w.Body.String())
406+
}
407+
408+
func TestHTMLTemplateBuffersBeforeWritingStatus(t *testing.T) {
409+
t.Parallel()
410+
411+
w := httptest.NewRecorder()
412+
tmpl := htmltemplate.Must(
413+
htmltemplate.New("test").Parse("{{.Name}}"),
414+
)
415+
416+
err := response.HTMLTemplate(w, http.StatusOK, *tmpl, 42)
417+
418+
require.Error(t, err)
419+
require.Equal(t, http.StatusOK, w.Code)
420+
require.Empty(t, w.Body.String())
421+
}

framework/cache/memory.go

Lines changed: 42 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,20 @@ package cache
22

33
import (
44
"context"
5-
"sync"
5+
"errors"
66
"time"
77

88
"github.com/patrickmn/go-cache"
99
"github.com/studiolambda/cosmos/contract"
1010
)
1111

12-
// Memory implements contract.Cache using an in-memory store backed
13-
// by patrickmn/go-cache. It is suitable for single-process
14-
// applications and testing scenarios where persistence across
15-
// restarts is not required.
12+
// Memory implements [contract.Cache] using an in-memory store backed by
13+
// patrickmn/go-cache. Note that this dependency is unmaintained since 2017;
14+
// consider migrating to a maintained alternative for production use.
15+
//
16+
// Memory is suitable for single-process applications and testing scenarios
17+
// where persistence across restarts is not required.
1618
type Memory struct {
17-
mux sync.Mutex
1819
store *cache.Cache
1920
}
2021

@@ -43,6 +44,10 @@ func (memory *Memory) Get(_ context.Context, key string) (any, error) {
4344

4445
// Put stores a value in the in-memory cache with the given TTL.
4546
// A zero TTL uses the default expiration configured at creation.
47+
//
48+
// WARNING: Values are stored by reference. Callers must not mutate
49+
// values after storing or after retrieval. For safety with mutable
50+
// types, store copies or use value types.
4651
func (memory *Memory) Put(_ context.Context, key string, value any, ttl time.Duration) error {
4752
memory.store.Set(key, value, ttl)
4853

@@ -64,12 +69,8 @@ func (memory *Memory) Has(_ context.Context, key string) (bool, error) {
6469
return found, nil
6570
}
6671

67-
// Pull atomically retrieves and removes the value for the given key.
68-
// It holds a mutex to prevent races between the get and delete steps.
72+
// Pull retrieves and removes the value for the given key.
6973
func (memory *Memory) Pull(ctx context.Context, key string) (any, error) {
70-
memory.mux.Lock()
71-
defer memory.mux.Unlock()
72-
7374
val, err := memory.Get(ctx, key)
7475

7576
if err != nil {
@@ -84,38 +85,40 @@ func (memory *Memory) Pull(ctx context.Context, key string) (any, error) {
8485
}
8586

8687
// Forever stores a value permanently with no expiration.
88+
//
89+
// WARNING: Values are stored by reference. Callers must not mutate
90+
// values after storing or after retrieval. For safety with mutable
91+
// types, store copies or use value types.
8792
func (memory *Memory) Forever(_ context.Context, key string, value any) error {
8893
memory.store.Set(key, value, cache.NoExpiration)
8994

9095
return nil
9196
}
9297

9398
// Increment atomically increases the integer value stored at key by
94-
// the given amount. Returns contract.ErrCacheKeyNotFound if the key
99+
// the given amount. Returns [contract.ErrCacheKeyNotFound] if the key
95100
// does not exist.
96-
func (memory *Memory) Increment(ctx context.Context, key string, by int64) (int64, error) {
97-
memory.mux.Lock()
98-
defer memory.mux.Unlock()
101+
func (memory *Memory) Increment(_ context.Context, key string, by int64) (int64, error) {
102+
result, err := memory.store.IncrementInt64(key, by)
99103

100-
if found, _ := memory.Has(ctx, key); !found {
104+
if err != nil {
101105
return 0, contract.ErrCacheKeyNotFound
102106
}
103107

104-
return memory.store.IncrementInt64(key, by)
108+
return result, nil
105109
}
106110

107111
// Decrement atomically decreases the integer value stored at key by
108-
// the given amount. Returns contract.ErrCacheKeyNotFound if the key
112+
// the given amount. Returns [contract.ErrCacheKeyNotFound] if the key
109113
// does not exist.
110-
func (memory *Memory) Decrement(ctx context.Context, key string, by int64) (int64, error) {
111-
memory.mux.Lock()
112-
defer memory.mux.Unlock()
114+
func (memory *Memory) Decrement(_ context.Context, key string, by int64) (int64, error) {
115+
result, err := memory.store.DecrementInt64(key, by)
113116

114-
if found, _ := memory.Has(ctx, key); !found {
117+
if err != nil {
115118
return 0, contract.ErrCacheKeyNotFound
116119
}
117120

118-
return memory.store.DecrementInt64(key, by)
121+
return result, nil
119122
}
120123

121124
// Remember retrieves the cached value for the given key, or
@@ -129,39 +132,43 @@ func (memory *Memory) Decrement(ctx context.Context, key string, by int64) (int6
129132
// use golang.org/x/sync/singleflight to deduplicate concurrent
130133
// calls for the same key.
131134
func (memory *Memory) Remember(ctx context.Context, key string, ttl time.Duration, compute func() (any, error)) (any, error) {
132-
val, err := memory.Get(ctx, key)
135+
value, err := memory.Get(ctx, key)
133136

134137
if err == nil {
135-
return val, nil
138+
return value, nil
139+
}
140+
141+
if !errors.Is(err, contract.ErrCacheKeyNotFound) {
142+
return nil, err
136143
}
137144

138-
val, err = compute()
145+
value, err = compute()
139146

140147
if err != nil {
141148
return nil, err
142149
}
143150

144-
_ = memory.Put(ctx, key, val, ttl)
145-
146-
return val, nil
151+
return value, memory.Put(ctx, key, value, ttl)
147152
}
148153

149154
// RememberForever retrieves the cached value for the given key, or
150155
// computes and stores it permanently if the key is not found.
151156
func (memory *Memory) RememberForever(ctx context.Context, key string, compute func() (any, error)) (any, error) {
152-
val, err := memory.Get(ctx, key)
157+
value, err := memory.Get(ctx, key)
153158

154159
if err == nil {
155-
return val, nil
160+
return value, nil
156161
}
157162

158-
val, err = compute()
163+
if !errors.Is(err, contract.ErrCacheKeyNotFound) {
164+
return nil, err
165+
}
166+
167+
value, err = compute()
159168

160169
if err != nil {
161170
return nil, err
162171
}
163172

164-
_ = memory.Forever(ctx, key, val)
165-
166-
return val, nil
173+
return value, memory.Forever(ctx, key, value)
167174
}

0 commit comments

Comments
 (0)