Skip to content

Commit 3b87c0a

Browse files
add delayed mysql requeuer (#52)
Co-authored-by: adrian.zajkowski <adrian.zajkowski@nordsec.com>
1 parent 4683792 commit 3b87c0a

3 files changed

Lines changed: 125 additions & 8 deletions

File tree

pkg/sql/delayed_mysql.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"encoding/json"
55
"fmt"
66
"strings"
7+
"time"
78

89
"github.com/ThreeDotsLabs/watermill"
910
"github.com/ThreeDotsLabs/watermill/components/delay"
@@ -87,7 +88,7 @@ func (c *DelayedMySQLSubscriberConfig) setDefaults() {
8788
func NewDelayedMySQLSubscriber(db Beginner, config DelayedMySQLSubscriberConfig) (message.Subscriber, error) {
8889
config.setDefaults()
8990

90-
where := "delayed_until <= NOW()"
91+
where := "delayed_until <= UTC_TIMESTAMP()"
9192

9293
if config.AllowNoDelay {
9394
where += " OR delayed_until IS NULL"
@@ -138,7 +139,7 @@ func (a delayedMySQLSchemaAdapter) SchemaInitializingQueries(params SchemaInitia
138139
` + "`acked`" + ` BOOLEAN NOT NULL DEFAULT FALSE,
139140
` + "`created_at`" + ` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
140141
` + "`delayed_until`" + ` TIMESTAMP NULL DEFAULT NULL,
141-
INDEX ` + "`delayed_until_idx`" + ` (` + "`delayed_until`" + `)
142+
INDEX ` + "`idx_acked_delayed`" + ` (` + "`acked`" + `, ` + "`delayed_until`" + `)
142143
);
143144
`
144145

@@ -186,11 +187,11 @@ func delayedMySQLInsertArgs(msgs message.Messages) ([]any, error) {
186187
if delayedUntilStr == "" {
187188
args = append(args, nil)
188189
} else {
189-
// Convert ISO 8601 to MySQL TIMESTAMP format: "2025-10-22T09:58:00Z" -> "2025-10-22 09:58:00"
190-
delayedUntilStr = strings.Replace(delayedUntilStr, "T", " ", 1)
191-
delayedUntilStr = strings.TrimSuffix(delayedUntilStr, "Z")
192-
193-
args = append(args, delayedUntilStr)
190+
delayedUntil, err := time.Parse(time.RFC3339, delayedUntilStr)
191+
if err != nil {
192+
return nil, fmt.Errorf("could not parse delayed_until timestamp %s: %w", delayedUntilStr, err)
193+
}
194+
args = append(args, delayedUntil)
194195
}
195196
}
196197

pkg/sql/delayed_requeuer.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,3 +136,50 @@ func NewPostgreSQLDelayedRequeuer(config DelayedRequeuerConfig) (*DelayedRequeue
136136
requeuer: requeuer,
137137
}, nil
138138
}
139+
140+
// NewMySQLDelayedRequeuer creates a new DelayedRequeuer that uses MySQL as a storage.
141+
func NewMySQLDelayedRequeuer(config DelayedRequeuerConfig) (*DelayedRequeuer, error) {
142+
config.setDefaults()
143+
err := config.Validate()
144+
if err != nil {
145+
return nil, err
146+
}
147+
148+
publisher, err := NewDelayedMySQLPublisher(config.DB, DelayedMySQLPublisherConfig{
149+
Logger: config.Logger,
150+
})
151+
if err != nil {
152+
return nil, err
153+
}
154+
155+
subscriber, err := NewDelayedMySQLSubscriber(config.DB, DelayedMySQLSubscriberConfig{
156+
DeleteOnAck: true,
157+
Logger: config.Logger,
158+
})
159+
if err != nil {
160+
return nil, err
161+
}
162+
163+
poisonQueue, err := middleware.PoisonQueue(publisher, config.RequeueTopic)
164+
if err != nil {
165+
return nil, err
166+
}
167+
168+
requeuer, err := requeuer.NewRequeuer(requeuer.Config{
169+
Subscriber: subscriber,
170+
SubscribeTopic: config.RequeueTopic,
171+
Publisher: config.Publisher,
172+
GeneratePublishTopic: config.GeneratePublishTopic,
173+
}, config.Logger)
174+
if err != nil {
175+
return nil, err
176+
}
177+
178+
return &DelayedRequeuer{
179+
middleware: []message.HandlerMiddleware{
180+
poisonQueue,
181+
config.DelayOnError.Middleware,
182+
},
183+
requeuer: requeuer,
184+
}, nil
185+
}

pkg/sql/delayed_requeuer_test.go

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import (
1414
"github.com/ThreeDotsLabs/watermill/message"
1515
)
1616

17-
func TestDelayedRequeuer(t *testing.T) {
17+
func TestPostgreSQLDelayedRequeuer(t *testing.T) {
1818
t.Parallel()
1919

2020
db := newPostgreSQL(t)
@@ -81,3 +81,72 @@ func TestDelayedRequeuer(t *testing.T) {
8181
assert.Equal(t, []string{"1", "3"}, receivedMessages)
8282
}, 1*time.Second, 100*time.Millisecond)
8383
}
84+
85+
func TestMySQLDelayedRequeuer(t *testing.T) {
86+
t.Parallel()
87+
88+
db := newMySQL(t)
89+
schemaAdapter := sql.DefaultMySQLSchema{}
90+
offsetsAdapter := sql.DefaultMySQLOffsetsAdapter{}
91+
publisher, subscriber := newPubSub(t, db, "test", schemaAdapter, offsetsAdapter)
92+
93+
topic := watermill.NewUUID()
94+
95+
err := subscriber.(message.SubscribeInitializer).SubscribeInitialize(topic)
96+
require.NoError(t, err)
97+
98+
delayedRequeuer, err := sql.NewMySQLDelayedRequeuer(sql.DelayedRequeuerConfig{
99+
DB: db,
100+
RequeueTopic: watermill.NewUUID(),
101+
Publisher: publisher,
102+
Logger: logger,
103+
})
104+
require.NoError(t, err)
105+
106+
router := message.NewDefaultRouter(logger)
107+
router.AddMiddleware(delayedRequeuer.Middleware()...)
108+
109+
var receivedMessages []string
110+
111+
router.AddNoPublisherHandler(
112+
"test",
113+
topic,
114+
subscriber,
115+
func(msg *message.Message) error {
116+
payload := string(msg.Payload)
117+
// MySQL and PostgreSQL format JSON with spaces, so we need to check both variants
118+
if payload == `{"error":true}` || payload == `{"error": true}` {
119+
return fmt.Errorf("error")
120+
}
121+
122+
receivedMessages = append(receivedMessages, msg.UUID)
123+
124+
return nil
125+
},
126+
)
127+
128+
go func() {
129+
err := router.Run(context.Background())
130+
require.NoError(t, err)
131+
}()
132+
133+
<-router.Running()
134+
135+
go func() {
136+
err := delayedRequeuer.Run(context.Background())
137+
require.NoError(t, err)
138+
}()
139+
140+
err = publisher.Publish(topic, message.NewMessage("1", []byte(`{}`)))
141+
require.NoError(t, err)
142+
143+
err = publisher.Publish(topic, message.NewMessage("2", []byte(`{"error":true}`)))
144+
require.NoError(t, err)
145+
146+
err = publisher.Publish(topic, message.NewMessage("3", []byte(`{}`)))
147+
require.NoError(t, err)
148+
149+
assert.EventuallyWithT(t, func(t *assert.CollectT) {
150+
assert.Equal(t, []string{"1", "3"}, receivedMessages)
151+
}, 1*time.Second, 100*time.Millisecond)
152+
}

0 commit comments

Comments
 (0)