Skip to content

Commit d71ac99

Browse files
committed
refactor(listener): use OnNotification callback
Replace the polling-based receiveLoop (100ms timeout sub-contexts, manual channel filtering, done channel) with pgx's OnNotification callback. The callback is registered on the connection config before connecting, and a simple WaitForNotification(ctx) loop drives reads. This eliminates: - The 100ms polling timeout and its per-iteration context allocation - The done channel (replaced by a cancellable context in Close) - Manual notification.Channel filtering (now in the callback)
1 parent 0bd76d6 commit d71ac99

1 file changed

Lines changed: 51 additions & 66 deletions

File tree

internal/database/listener.go

Lines changed: 51 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,29 @@ type Listener struct {
1818
channel string
1919
signals chan types.CoverageSignal
2020
errors chan error
21-
done chan struct{}
21+
cancel context.CancelFunc
2222
droppedSignals atomic.Int64
2323
}
2424

2525
// NewListener creates a new LISTEN/NOTIFY listener using the config from a pool.
2626
func NewListener(ctx context.Context, pool *pgxpool.Pool, channel string) (*Listener, error) {
27-
// Connect using the pool's connection config
28-
conn, err := pgx.ConnectConfig(ctx, pool.Config().ConnConfig.Copy())
27+
connCfg := pool.Config().ConnConfig.Copy()
28+
29+
listener := &Listener{
30+
channel: channel,
31+
signals: make(chan types.CoverageSignal, 1000), // Buffered to avoid blocking
32+
errors: make(chan error, 10),
33+
}
34+
35+
// Register the OnNotification callback before connecting so that
36+
// every notification is dispatched to our channel automatically.
37+
connCfg.OnNotification = listener.handleNotification
38+
39+
conn, err := pgx.ConnectConfig(ctx, connCfg)
2940
if err != nil {
3041
return nil, fmt.Errorf("failed to connect for LISTEN: %w", err)
3142
}
43+
listener.conn = conn
3244

3345
// Start listening on channel
3446
_, err = conn.Exec(ctx, fmt.Sprintf("LISTEN %s", channel))
@@ -37,80 +49,53 @@ func NewListener(ctx context.Context, pool *pgxpool.Pool, channel string) (*List
3749
return nil, fmt.Errorf("failed to execute LISTEN: %w", err)
3850
}
3951

40-
listener := &Listener{
41-
conn: conn,
42-
channel: channel,
43-
signals: make(chan types.CoverageSignal, 1000), // Buffered to avoid blocking
44-
errors: make(chan error, 10),
45-
done: make(chan struct{}),
46-
}
52+
// Derive a cancellable context so Close can interrupt the receive loop.
53+
loopCtx, cancel := context.WithCancel(ctx)
54+
listener.cancel = cancel
4755

48-
// Start background goroutine to receive notifications
49-
go listener.receiveLoop(ctx)
56+
// Start background goroutine to drive reads (required for OnNotification to fire).
57+
go listener.receiveLoop(loopCtx)
5058

5159
return listener, nil
5260
}
5361

54-
// receiveLoop continuously receives notifications from PostgreSQL
62+
// handleNotification is the pgx OnNotification callback. It is invoked
63+
// synchronously during WaitForNotification whenever a NOTIFY arrives.
64+
func (l *Listener) handleNotification(_ *pgconn.PgConn, n *pgconn.Notification) {
65+
if n.Channel != l.channel {
66+
return
67+
}
68+
69+
signal := types.CoverageSignal{
70+
SignalID: n.Payload,
71+
Timestamp: time.Now(),
72+
}
73+
74+
select {
75+
case l.signals <- signal:
76+
default:
77+
// Buffer full — increment counter so the caller can
78+
// detect and report lost signals after test execution.
79+
l.droppedSignals.Add(1)
80+
}
81+
}
82+
83+
// receiveLoop blocks on WaitForNotification so that the pgx connection
84+
// continuously reads from the server and dispatches OnNotification callbacks.
5585
func (l *Listener) receiveLoop(ctx context.Context) {
5686
defer close(l.signals)
5787
defer close(l.errors)
5888

5989
for {
60-
select {
61-
case <-ctx.Done():
62-
return
63-
case <-l.done:
64-
return
65-
default:
66-
// Wait for notification with short timeout to allow checking done/ctx
67-
waitCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
68-
notification, err := l.conn.WaitForNotification(waitCtx)
69-
cancel()
70-
71-
if err != nil {
72-
// Check if context was cancelled
73-
if ctx.Err() != nil {
74-
return
75-
}
76-
77-
// Check if connection is closed
78-
if l.conn.IsClosed() {
79-
select {
80-
case l.errors <- fmt.Errorf("connection closed"):
81-
default:
82-
}
83-
return
84-
}
85-
86-
// Timeout is expected, just continue
87-
if waitCtx.Err() == context.DeadlineExceeded {
88-
continue
89-
}
90-
91-
// Send error but continue
92-
select {
93-
case l.errors <- fmt.Errorf("notification error: %w", err):
94-
default:
95-
}
96-
continue
90+
_, err := l.conn.WaitForNotification(ctx)
91+
if err != nil {
92+
if ctx.Err() != nil || l.conn.IsClosed() {
93+
return
9794
}
9895

99-
if notification != nil && notification.Channel == l.channel {
100-
// Create coverage signal
101-
signal := types.CoverageSignal{
102-
SignalID: notification.Payload,
103-
Timestamp: time.Now(),
104-
}
105-
106-
// Send signal (non-blocking)
107-
select {
108-
case l.signals <- signal:
109-
default:
110-
// Buffer full — increment counter so the caller can
111-
// detect and report lost signals after test execution.
112-
l.droppedSignals.Add(1)
113-
}
96+
select {
97+
case l.errors <- fmt.Errorf("notification error: %w", err):
98+
default:
11499
}
115100
}
116101
}
@@ -135,7 +120,7 @@ func (l *Listener) DroppedSignals() int64 {
135120

136121
// Close stops the listener and closes the connection
137122
func (l *Listener) Close(ctx context.Context) error {
138-
close(l.done)
123+
l.cancel() // interrupt receiveLoop's WaitForNotification
139124

140125
// Unlisten
141126
if l.conn != nil && !l.conn.IsClosed() {

0 commit comments

Comments
 (0)