Skip to content

Commit 5a70b4e

Browse files
Merge branch 'uptrace:master' into jeffreydwalter-fix-migration-sort
2 parents 60f07c3 + e135221 commit 5a70b4e

11 files changed

Lines changed: 159 additions & 29 deletions

File tree

driver/pgdriver/listener.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,21 +237,30 @@ type Notification struct {
237237

238238
type ChannelOption func(c *channel)
239239

240+
type channelOverflowHandler func(n Notification)
241+
240242
func WithChannelSize(size int) ChannelOption {
241243
return func(c *channel) {
242244
c.size = size
243245
}
244246
}
245247

248+
func WithChannelOverflowHandler(handler channelOverflowHandler) ChannelOption {
249+
return func(c *channel) {
250+
c.overflowHandler = handler
251+
}
252+
}
253+
246254
type channel struct {
247255
ctx context.Context
248256
ln *Listener
249257

250258
size int
251259
pingTimeout time.Duration
252260

253-
ch chan Notification
254-
pingCh chan struct{}
261+
ch chan Notification
262+
pingCh chan struct{}
263+
overflowHandler channelOverflowHandler
255264
}
256265

257266
func newChannel(ln *Listener, opts []ChannelOption) *channel {
@@ -310,6 +319,9 @@ func (c *channel) startReceive() {
310319
case c.ch <- Notification{channel, payload}:
311320
default:
312321
Logger.Printf(c.ctx, "pgdriver: Listener buffer is full (message is dropped)")
322+
if c.overflowHandler != nil {
323+
c.overflowHandler(Notification{channel, payload})
324+
}
313325
}
314326
}
315327
}

driver/pgdriver/listener_test.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package pgdriver
2+
3+
import (
4+
"reflect"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
)
9+
10+
func TestWithChannelOverflowHandler(t *testing.T) {
11+
// Create a test handler
12+
testHandler := func(n Notification) {}
13+
14+
// Create a channel instance
15+
c := &channel{}
16+
17+
// Apply the option
18+
opt := WithChannelOverflowHandler(testHandler)
19+
opt(c)
20+
21+
// Verify the handler was set correctly
22+
assert.NotNil(t, c.overflowHandler, "overflow handler should be set")
23+
assert.Equal(t,
24+
reflect.ValueOf(testHandler).Pointer(),
25+
reflect.ValueOf(c.overflowHandler).Pointer(),
26+
"overflow handler should match the provided handler",
27+
)
28+
}

example/migrate/main.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,14 @@ func main() {
3030
bundebug.FromEnv(),
3131
))
3232

33+
templateData := map[string]string{
34+
"Prefix": "example_",
35+
}
3336
app := &cli.App{
3437
Name: "bun",
3538

3639
Commands: []*cli.Command{
37-
newDBCommand(migrate.NewMigrator(db, migrations.Migrations)),
40+
newDBCommand(migrate.NewMigrator(db, migrations.Migrations, migrate.WithTemplateData(templateData))),
3841
},
3942
}
4043
if err := app.Run(os.Args); err != nil {
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
DROP TABLE IF EXISTS test;
1+
DROP TABLE IF EXISTS {{.Prefix}}test;
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
CREATE TABLE test (
1+
CREATE TABLE {{.Prefix}}test (
22
id bigint PRIMARY KEY
33
);
44

55
--bun:split
66

7-
ALTER TABLE test ADD COLUMN name varchar(100);
7+
ALTER TABLE {{.Prefix}}test ADD COLUMN name varchar(100);

internal/dbtest/listener_test.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package dbtest_test
22

33
import (
44
"context"
5+
"fmt"
6+
"sync/atomic"
57
"testing"
68
"time"
79

@@ -84,3 +86,45 @@ func TestListenerChannel(t *testing.T) {
8486
return !ok
8587
}, 3*time.Second, 100*time.Millisecond)
8688
}
89+
90+
func TestListenerChannelOverflowHandler(t *testing.T) {
91+
ctx := context.Background()
92+
channelSize := 1
93+
overflowMessagesCount := channelSize * 3
94+
95+
db := pg(t)
96+
defer db.Close()
97+
98+
ln := pgdriver.NewListener(db)
99+
defer ln.Close()
100+
101+
var overflowCount atomic.Int32
102+
103+
// Create channel with small buffer and overflow handler
104+
ch := ln.Channel(
105+
pgdriver.WithChannelSize(channelSize),
106+
pgdriver.WithChannelOverflowHandler(func(n pgdriver.Notification) {
107+
overflowCount.Add(1)
108+
}),
109+
)
110+
111+
err := ln.Listen(ctx, "test_channel")
112+
require.NoError(t, err)
113+
114+
// Fill the channel buffer
115+
_, err = db.ExecContext(ctx, "NOTIFY test_channel, ?", "msg1")
116+
require.NoError(t, err)
117+
118+
// Wait for the first message to be received
119+
<-ch
120+
121+
// Send more messages to trigger overflow
122+
for i := 0; i < overflowMessagesCount; i++ {
123+
_, err = db.ExecContext(ctx, "NOTIFY test_channel, ?", fmt.Sprintf("msg%d", i+2))
124+
require.NoError(t, err)
125+
}
126+
127+
require.Eventually(t, func() bool {
128+
return overflowCount.Load() > 0
129+
}, time.Second, 10*time.Millisecond, "overflow handler should have been called")
130+
}

internal/dbtest/migrate_test.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -75,22 +75,22 @@ func testMigrateUpAndDown(t *testing.T, db *bun.DB) {
7575
migrations := migrate.NewMigrations()
7676
migrations.Add(migrate.Migration{
7777
Name: "20060102150405",
78-
Up: func(ctx context.Context, db *bun.DB) error {
78+
Up: func(ctx context.Context, db *bun.DB, templateData any) error {
7979
history = append(history, "up1")
8080
return nil
8181
},
82-
Down: func(ctx context.Context, db *bun.DB) error {
82+
Down: func(ctx context.Context, db *bun.DB, templateData any) error {
8383
history = append(history, "down1")
8484
return nil
8585
},
8686
})
8787
migrations.Add(migrate.Migration{
8888
Name: "20060102160405",
89-
Up: func(ctx context.Context, db *bun.DB) error {
89+
Up: func(ctx context.Context, db *bun.DB, templateData any) error {
9090
history = append(history, "up2")
9191
return nil
9292
},
93-
Down: func(ctx context.Context, db *bun.DB) error {
93+
Down: func(ctx context.Context, db *bun.DB, templateData any) error {
9494
history = append(history, "down2")
9595
return nil
9696
},
@@ -125,33 +125,33 @@ func testMigrateUpError(t *testing.T, db *bun.DB) {
125125
migrations := migrate.NewMigrations()
126126
migrations.Add(migrate.Migration{
127127
Name: "20060102150405",
128-
Up: func(ctx context.Context, db *bun.DB) error {
128+
Up: func(ctx context.Context, db *bun.DB, templateData any) error {
129129
history = append(history, "up1")
130130
return nil
131131
},
132-
Down: func(ctx context.Context, db *bun.DB) error {
132+
Down: func(ctx context.Context, db *bun.DB, templateData any) error {
133133
history = append(history, "down1")
134134
return nil
135135
},
136136
})
137137
migrations.Add(migrate.Migration{
138138
Name: "20060102160405",
139-
Up: func(ctx context.Context, db *bun.DB) error {
139+
Up: func(ctx context.Context, db *bun.DB, templateData any) error {
140140
history = append(history, "up2")
141141
return errors.New("failed")
142142
},
143-
Down: func(ctx context.Context, db *bun.DB) error {
143+
Down: func(ctx context.Context, db *bun.DB, templateData any) error {
144144
history = append(history, "down2")
145145
return nil
146146
},
147147
})
148148
migrations.Add(migrate.Migration{
149149
Name: "20060102170405",
150-
Up: func(ctx context.Context, db *bun.DB) error {
150+
Up: func(ctx context.Context, db *bun.DB, templateData any) error {
151151
history = append(history, "up3")
152152
return errors.New("failed")
153153
},
154-
Down: func(ctx context.Context, db *bun.DB) error {
154+
Down: func(ctx context.Context, db *bun.DB, templateData any) error {
155155
history = append(history, "down3")
156156
return nil
157157
},

migrate/auto.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,8 +268,8 @@ func (am *AutoMigrator) createSQLMigrations(ctx context.Context, transactional b
268268
migrations := NewMigrations(am.migrationsOpts...)
269269
migrations.Add(Migration{
270270
Name: name,
271-
Up: changes.Up(am.dbMigrator),
272-
Down: changes.Down(am.dbMigrator),
271+
Up: wrapMigrationFunc(changes.Up(am.dbMigrator)),
272+
Down: wrapMigrationFunc(changes.Down(am.dbMigrator)),
273273
Comment: "Changes detected by bun.AutoMigrator",
274274
})
275275

migrate/migration.go

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"sort"
1111
"strconv"
1212
"strings"
13+
"text/template"
1314
"time"
1415

1516
"github.com/uptrace/bun"
@@ -24,8 +25,8 @@ type Migration struct {
2425
GroupID int64
2526
MigratedAt time.Time `bun:",notnull,nullzero,default:current_timestamp"`
2627

27-
Up MigrationFunc `bun:"-"`
28-
Down MigrationFunc `bun:"-"`
28+
Up internalMigrationFunc `bun:"-"`
29+
Down internalMigrationFunc `bun:"-"`
2930
}
3031

3132
func (m Migration) String() string {
@@ -36,23 +37,57 @@ func (m Migration) IsApplied() bool {
3637
return m.ID > 0
3738
}
3839

40+
type internalMigrationFunc func(ctx context.Context, db *bun.DB, templateData any) error
41+
3942
type MigrationFunc func(ctx context.Context, db *bun.DB) error
4043

41-
func NewSQLMigrationFunc(fsys fs.FS, name string) MigrationFunc {
42-
return func(ctx context.Context, db *bun.DB) error {
44+
func NewSQLMigrationFunc(fsys fs.FS, name string) internalMigrationFunc {
45+
return func(ctx context.Context, db *bun.DB, templateData any) error {
4346
f, err := fsys.Open(name)
4447
if err != nil {
4548
return err
4649
}
4750

4851
isTx := strings.HasSuffix(name, ".tx.up.sql") || strings.HasSuffix(name, ".tx.down.sql")
49-
return Exec(ctx, db, f, isTx)
52+
return Exec(ctx, db, f, templateData, isTx)
53+
}
54+
}
55+
56+
func wrapMigrationFunc(fn MigrationFunc) internalMigrationFunc {
57+
return func(ctx context.Context, db *bun.DB, templateData any) error {
58+
return fn(ctx, db)
59+
}
60+
}
61+
62+
func renderTemplate(contents []byte, templateData any) (*bytes.Buffer, error) {
63+
tmpl, err := template.New("migration").Parse(string(contents))
64+
if err != nil {
65+
return nil, fmt.Errorf("failed to parse template: %w", err)
66+
}
67+
68+
var rendered bytes.Buffer
69+
if err := tmpl.Execute(&rendered, templateData); err != nil {
70+
return nil, fmt.Errorf("failed to execute template: %w", err)
5071
}
72+
73+
return &rendered, nil
5174
}
5275

5376
// Exec reads and executes the SQL migration in the f.
54-
func Exec(ctx context.Context, db *bun.DB, f io.Reader, isTx bool) error {
55-
scanner := bufio.NewScanner(f)
77+
func Exec(ctx context.Context, db *bun.DB, f io.Reader, templateData any, isTx bool) error {
78+
contents, err := io.ReadAll(f)
79+
if err != nil {
80+
return err
81+
}
82+
var reader io.Reader = bytes.NewReader(contents)
83+
if templateData != nil {
84+
buf, err := renderTemplate(contents, templateData)
85+
if err != nil {
86+
return err
87+
}
88+
reader = buf
89+
}
90+
scanner := bufio.NewScanner(reader)
5691
var queries []string
5792

5893
var query []byte

migrate/migrations.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ func (m *Migrations) Register(up, down MigrationFunc) error {
5858
m.Add(Migration{
5959
Name: name,
6060
Comment: comment,
61-
Up: up,
62-
Down: down,
61+
Up: wrapMigrationFunc(up),
62+
Down: wrapMigrationFunc(down),
6363
})
6464

6565
return nil

0 commit comments

Comments
 (0)