Skip to content

Commit a2c489b

Browse files
authored
Merge pull request #10 from StudioLambda/fix/critical-audit-findings
Fix all critical audit findings (router, database, events)
2 parents 9be9e44 + 8fc6c8f commit a2c489b

10 files changed

Lines changed: 482 additions & 38 deletions

File tree

framework/database/sql.go

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,10 @@ func (database *SQL) FindNamed(ctx context.Context, query string, dest any, arg
174174
// original error and any rollback error are joined. If fn succeeds,
175175
// the transaction is committed. Nested transactions are not supported
176176
// and return contract.ErrDatabaseNestedTransaction.
177-
func (database *SQL) WithTransaction(ctx context.Context, fn func(tx contract.Database) error) error {
177+
//
178+
// If fn panics, the transaction is rolled back before the panic is
179+
// re-raised, preventing connection pool leaks.
180+
func (database *SQL) WithTransaction(ctx context.Context, fn func(tx contract.Database) error) (retErr error) {
178181
if _, ok := database.db.(*sqlx.Tx); ok {
179182
return contract.ErrDatabaseNestedTransaction
180183
}
@@ -185,10 +188,23 @@ func (database *SQL) WithTransaction(ctx context.Context, fn func(tx contract.Da
185188
return err
186189
}
187190

191+
defer func() {
192+
if p := recover(); p != nil {
193+
_ = tx.Rollback()
194+
panic(p)
195+
}
196+
197+
if retErr != nil {
198+
retErr = errors.Join(retErr, tx.Rollback())
199+
}
200+
}()
201+
188202
txWrapper := &SQL{db: tx, raw: database.raw}
189203

190204
if err := fn(txWrapper); err != nil {
191-
return errors.Join(err, tx.Rollback())
205+
retErr = err
206+
207+
return
192208
}
193209

194210
return tx.Commit()

framework/database/sql_test.go

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
//go:build cgo
2+
3+
package database_test
4+
5+
import (
6+
"context"
7+
"errors"
8+
"testing"
9+
10+
_ "github.com/mattn/go-sqlite3"
11+
"github.com/stretchr/testify/require"
12+
"github.com/studiolambda/cosmos/contract"
13+
"github.com/studiolambda/cosmos/framework/database"
14+
)
15+
16+
func newTestDB(t *testing.T) *database.SQL {
17+
t.Helper()
18+
19+
db, err := database.NewSQL("sqlite3", ":memory:")
20+
21+
require.NoError(t, err)
22+
23+
t.Cleanup(func() {
24+
require.NoError(t, db.Close())
25+
})
26+
27+
return db
28+
}
29+
30+
func TestWithTransactionPanicRollsBack(t *testing.T) {
31+
t.Parallel()
32+
33+
db := newTestDB(t)
34+
ctx := context.Background()
35+
36+
_, err := db.Exec(ctx, "CREATE TABLE items (id INTEGER PRIMARY KEY, name TEXT)")
37+
38+
require.NoError(t, err)
39+
40+
require.Panics(t, func() {
41+
_ = db.WithTransaction(ctx, func(tx contract.Database) error {
42+
_, err := tx.Exec(ctx, "INSERT INTO items (name) VALUES (?)", "should-be-rolled-back")
43+
44+
if err != nil {
45+
return err
46+
}
47+
48+
panic("unexpected failure")
49+
})
50+
})
51+
52+
// The row must not exist because the transaction was rolled back.
53+
var count int
54+
err = db.Find(ctx, "SELECT COUNT(*) FROM items", &count)
55+
56+
require.NoError(t, err)
57+
require.Equal(t, 0, count)
58+
}
59+
60+
func TestWithTransactionNestedReturnsError(t *testing.T) {
61+
t.Parallel()
62+
63+
db := newTestDB(t)
64+
ctx := context.Background()
65+
66+
err := db.WithTransaction(ctx, func(tx contract.Database) error {
67+
return tx.WithTransaction(ctx, func(inner contract.Database) error {
68+
return nil
69+
})
70+
})
71+
72+
require.ErrorIs(t, err, contract.ErrDatabaseNestedTransaction)
73+
}
74+
75+
func TestWithTransactionCommitsOnSuccess(t *testing.T) {
76+
t.Parallel()
77+
78+
db := newTestDB(t)
79+
ctx := context.Background()
80+
81+
_, err := db.Exec(ctx, "CREATE TABLE items (id INTEGER PRIMARY KEY, name TEXT)")
82+
83+
require.NoError(t, err)
84+
85+
err = db.WithTransaction(ctx, func(tx contract.Database) error {
86+
_, err := tx.Exec(ctx, "INSERT INTO items (name) VALUES (?)", "committed")
87+
88+
return err
89+
})
90+
91+
require.NoError(t, err)
92+
93+
var count int
94+
err = db.Find(ctx, "SELECT COUNT(*) FROM items", &count)
95+
96+
require.NoError(t, err)
97+
require.Equal(t, 1, count)
98+
}
99+
100+
func TestWithTransactionRollsBackOnError(t *testing.T) {
101+
t.Parallel()
102+
103+
db := newTestDB(t)
104+
ctx := context.Background()
105+
106+
_, err := db.Exec(ctx, "CREATE TABLE items (id INTEGER PRIMARY KEY, name TEXT)")
107+
108+
require.NoError(t, err)
109+
110+
callbackErr := errors.New("something went wrong")
111+
112+
err = db.WithTransaction(ctx, func(tx contract.Database) error {
113+
_, err := tx.Exec(ctx, "INSERT INTO items (name) VALUES (?)", "should-be-rolled-back")
114+
115+
if err != nil {
116+
return err
117+
}
118+
119+
return callbackErr
120+
})
121+
122+
require.ErrorIs(t, err, callbackErr)
123+
124+
var count int
125+
err = db.Find(ctx, "SELECT COUNT(*) FROM items", &count)
126+
127+
require.NoError(t, err)
128+
require.Equal(t, 0, count)
129+
}
130+
131+
func TestWithTransactionPanicPreservesValue(t *testing.T) {
132+
t.Parallel()
133+
134+
db := newTestDB(t)
135+
ctx := context.Background()
136+
137+
var recovered any
138+
139+
func() {
140+
defer func() {
141+
recovered = recover()
142+
}()
143+
144+
_ = db.WithTransaction(ctx, func(tx contract.Database) error {
145+
panic("test-panic-value")
146+
})
147+
}()
148+
149+
require.Equal(t, "test-panic-value", recovered)
150+
}

framework/event/memory.go

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -109,26 +109,31 @@ func (broker *MemoryBroker) Publish(
109109
}
110110

111111
broker.mu.RLock()
112-
defer broker.mu.RUnlock()
112+
113+
var matched []contract.EventHandler
113114

114115
for pattern, patternHandlers := range broker.handlers {
115-
if !matchEvent(pattern, event) {
116-
continue
116+
if matchEvent(pattern, event) {
117+
for _, handler := range patternHandlers {
118+
matched = append(matched, handler)
119+
}
117120
}
121+
}
118122

119-
for _, handler := range patternHandlers {
120-
broker.wg.Add(1)
121-
broker.sem <- struct{}{}
123+
broker.mu.RUnlock()
122124

123-
go func(h contract.EventHandler) {
124-
defer func() {
125-
<-broker.sem
126-
broker.wg.Done()
127-
}()
125+
for _, handler := range matched {
126+
broker.wg.Add(1)
127+
broker.sem <- struct{}{}
128128

129-
broker.deliverToHandler(h, encoded)
130-
}(handler)
131-
}
129+
go func() {
130+
defer func() {
131+
<-broker.sem
132+
broker.wg.Done()
133+
}()
134+
135+
broker.deliverToHandler(handler, encoded)
136+
}()
132137
}
133138

134139
return nil

framework/event/memory_test.go

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ func TestMemoryBrokerWildcardStar(t *testing.T) {
8484
require.Equal(t, int64(1), atomic.LoadInt64(&received))
8585
}
8686

87-
func TestMemoryBrokerWildcardStarDoesNotMatchMultipleTokens(t *testing.T) {
87+
func TestMemoryBrokerPublishDoesNotDeadlockWithSubscribeInHandler(t *testing.T) {
8888
t.Parallel()
8989

9090
ctx := context.Background()
@@ -94,25 +94,32 @@ func TestMemoryBrokerWildcardStarDoesNotMatchMultipleTokens(t *testing.T) {
9494
_ = broker.Close()
9595
})
9696

97-
var received int64
97+
done := make(chan struct{})
9898

9999
_, err := broker.Subscribe(
100-
ctx, "user.*.created", func(payload contract.EventPayload) {
101-
atomic.AddInt64(&received, 1)
100+
ctx, "test.event", func(payload contract.EventPayload) {
101+
// This Subscribe call requires broker.mu.Lock().
102+
// If Publish still holds broker.mu.RLock() during
103+
// dispatch, this will deadlock.
104+
_, _ = broker.Subscribe(
105+
ctx, "other.event", func(payload contract.EventPayload) {},
106+
)
107+
108+
close(done)
102109
},
103110
)
104111

105112
require.NoError(t, err)
106113

107-
err = broker.Publish(
108-
ctx, "user.123.456.created", "data",
109-
)
110-
114+
err = broker.Publish(ctx, "test.event", "data")
111115
require.NoError(t, err)
112116

113-
time.Sleep(50 * time.Millisecond)
114-
115-
require.Equal(t, int64(0), atomic.LoadInt64(&received))
117+
select {
118+
case <-done:
119+
// Handler completed without deadlock.
120+
case <-time.After(3 * time.Second):
121+
t.Fatal("deadlock detected: handler calling Subscribe blocked for 3 seconds")
122+
}
116123
}
117124

118125
func TestMemoryBrokerWildcardHash(t *testing.T) {

framework/event/mqtt.go

Lines changed: 62 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"encoding/json"
66
"fmt"
7+
"log/slog"
78
"net/url"
89
"strconv"
910
"strings"
@@ -196,11 +197,7 @@ func NewMQTTBrokerWith(options *MQTTBrokerOptions) (*MQTTBroker, error) {
196197
ClientConfig: paho.ClientConfig{
197198
ClientID: "",
198199
OnPublishReceived: []func(paho.PublishReceived) (bool, error){
199-
func(pr paho.PublishReceived) (bool, error) {
200-
broker.route(pr.Packet)
201-
202-
return true, nil
203-
},
200+
broker.HandlePublish,
204201
},
205202
},
206203
}
@@ -225,6 +222,19 @@ func NewMQTTBrokerWith(options *MQTTBrokerOptions) (*MQTTBroker, error) {
225222
// autopaho ConnectionManager and QoS level. This constructor is
226223
// useful for advanced scenarios where the user needs full control
227224
// over the MQTT connection configuration.
225+
//
226+
// Because autopaho.ConnectionManager does not allow post-creation
227+
// configuration changes, the caller must wire up message routing
228+
// when building the paho config. Use [MQTTBroker.HandlePublish] in
229+
// the ClientConfig.OnPublishReceived slice:
230+
//
231+
// broker := event.NewMQTTBrokerFrom(nil, 1)
232+
// cfg.ClientConfig.OnPublishReceived = []func(paho.PublishReceived) (bool, error){
233+
// broker.HandlePublish,
234+
// }
235+
// cm, err := autopaho.NewConnection(ctx, cfg)
236+
// // then set the client so Publish/Subscribe/Close work:
237+
// broker.SetClient(cm)
228238
func NewMQTTBrokerFrom(
229239
client *autopaho.ConnectionManager,
230240
qos byte,
@@ -237,25 +247,69 @@ func NewMQTTBrokerFrom(
237247
}
238248
}
239249

250+
// SetClient sets the underlying autopaho ConnectionManager. This is
251+
// intended for use with [NewMQTTBrokerFrom] when the broker must be
252+
// created before the ConnectionManager so that [MQTTBroker.HandlePublish]
253+
// can be wired into the paho config.
254+
func (broker *MQTTBroker) SetClient(client *autopaho.ConnectionManager) {
255+
broker.client = client
256+
}
257+
258+
// HandlePublish is a paho OnPublishReceived callback that routes
259+
// incoming MQTT messages to registered handlers. Callers using
260+
// [NewMQTTBrokerFrom] must include this method in the paho
261+
// ClientConfig.OnPublishReceived slice so that subscribed handlers
262+
// receive messages.
263+
func (broker *MQTTBroker) HandlePublish(pr paho.PublishReceived) (bool, error) {
264+
broker.route(pr.Packet)
265+
266+
return true, nil
267+
}
268+
240269
// route delivers an incoming MQTT message to all matching
241270
// handlers based on topic pattern matching. This implements
242271
// fan-out behavior where multiple handlers can receive the
243272
// same message if they subscribed to matching patterns.
273+
//
274+
// Each handler is called with panic recovery so that a
275+
// panicking handler does not crash the paho client goroutine
276+
// and tear down the entire MQTT connection.
244277
func (broker *MQTTBroker) route(pb *paho.Publish) {
245278
broker.mu.RLock()
246279
defer broker.mu.RUnlock()
247280

248281
for pattern, handlers := range broker.handlers {
249282
if matchTopic(pattern, pb.Topic) {
250283
for _, handler := range handlers {
251-
handler(func(dest any) error {
252-
return json.Unmarshal(pb.Payload, dest)
253-
})
284+
broker.deliverToHandler(handler, pb.Topic, pb.Payload)
254285
}
255286
}
256287
}
257288
}
258289

290+
// deliverToHandler invokes a single handler with panic recovery.
291+
// Recovered panics are logged via slog so they remain visible for
292+
// debugging without propagating to the caller.
293+
func (broker *MQTTBroker) deliverToHandler(
294+
handler contract.EventHandler,
295+
topic string,
296+
payload []byte,
297+
) {
298+
defer func() {
299+
if recovered := recover(); recovered != nil {
300+
slog.Error(
301+
"mqtt event handler panicked",
302+
"topic", topic,
303+
"error", recovered,
304+
)
305+
}
306+
}()
307+
308+
handler(func(dest any) error {
309+
return json.Unmarshal(payload, dest)
310+
})
311+
}
312+
259313
// Publish sends an event with the given name and payload to all
260314
// subscribers listening for that event. The payload is serialized
261315
// to JSON and the event name is converted to MQTT topic format.

0 commit comments

Comments
 (0)