@@ -18,17 +18,29 @@ type Listener struct {
1818 channel string
1919 signals chan types.CoverageSignal
2020 errors chan error
21- done chan struct {}
21+ cancel context. CancelFunc
2222 droppedSignals atomic.Int64
2323}
2424
2525// NewListener creates a new LISTEN/NOTIFY listener using the config from a pool.
2626func NewListener (ctx context.Context , pool * pgxpool.Pool , channel string ) (* Listener , error ) {
27- // Connect using the pool's connection config
28- conn , err := pgx .ConnectConfig (ctx , pool .Config ().ConnConfig .Copy ())
27+ connCfg := pool .Config ().ConnConfig .Copy ()
28+
29+ listener := & Listener {
30+ channel : channel ,
31+ signals : make (chan types.CoverageSignal , 1000 ), // Buffered to avoid blocking
32+ errors : make (chan error , 10 ),
33+ }
34+
35+ // Register the OnNotification callback before connecting so that
36+ // every notification is dispatched to our channel automatically.
37+ connCfg .OnNotification = listener .handleNotification
38+
39+ conn , err := pgx .ConnectConfig (ctx , connCfg )
2940 if err != nil {
3041 return nil , fmt .Errorf ("failed to connect for LISTEN: %w" , err )
3142 }
43+ listener .conn = conn
3244
3345 // Start listening on channel
3446 _ , err = conn .Exec (ctx , fmt .Sprintf ("LISTEN %s" , channel ))
@@ -37,80 +49,53 @@ func NewListener(ctx context.Context, pool *pgxpool.Pool, channel string) (*List
3749 return nil , fmt .Errorf ("failed to execute LISTEN: %w" , err )
3850 }
3951
40- listener := & Listener {
41- conn : conn ,
42- channel : channel ,
43- signals : make (chan types.CoverageSignal , 1000 ), // Buffered to avoid blocking
44- errors : make (chan error , 10 ),
45- done : make (chan struct {}),
46- }
52+ // Derive a cancellable context so Close can interrupt the receive loop.
53+ loopCtx , cancel := context .WithCancel (ctx )
54+ listener .cancel = cancel
4755
48- // Start background goroutine to receive notifications
49- go listener .receiveLoop (ctx )
56+ // Start background goroutine to drive reads (required for OnNotification to fire).
57+ go listener .receiveLoop (loopCtx )
5058
5159 return listener , nil
5260}
5361
54- // receiveLoop continuously receives notifications from PostgreSQL
62+ // handleNotification is the pgx OnNotification callback. It is invoked
63+ // synchronously during WaitForNotification whenever a NOTIFY arrives.
64+ func (l * Listener ) handleNotification (_ * pgconn.PgConn , n * pgconn.Notification ) {
65+ if n .Channel != l .channel {
66+ return
67+ }
68+
69+ signal := types.CoverageSignal {
70+ SignalID : n .Payload ,
71+ Timestamp : time .Now (),
72+ }
73+
74+ select {
75+ case l .signals <- signal :
76+ default :
77+ // Buffer full — increment counter so the caller can
78+ // detect and report lost signals after test execution.
79+ l .droppedSignals .Add (1 )
80+ }
81+ }
82+
83+ // receiveLoop blocks on WaitForNotification so that the pgx connection
84+ // continuously reads from the server and dispatches OnNotification callbacks.
5585func (l * Listener ) receiveLoop (ctx context.Context ) {
5686 defer close (l .signals )
5787 defer close (l .errors )
5888
5989 for {
60- select {
61- case <- ctx .Done ():
62- return
63- case <- l .done :
64- return
65- default :
66- // Wait for notification with short timeout to allow checking done/ctx
67- waitCtx , cancel := context .WithTimeout (ctx , 100 * time .Millisecond )
68- notification , err := l .conn .WaitForNotification (waitCtx )
69- cancel ()
70-
71- if err != nil {
72- // Check if context was cancelled
73- if ctx .Err () != nil {
74- return
75- }
76-
77- // Check if connection is closed
78- if l .conn .IsClosed () {
79- select {
80- case l .errors <- fmt .Errorf ("connection closed" ):
81- default :
82- }
83- return
84- }
85-
86- // Timeout is expected, just continue
87- if waitCtx .Err () == context .DeadlineExceeded {
88- continue
89- }
90-
91- // Send error but continue
92- select {
93- case l .errors <- fmt .Errorf ("notification error: %w" , err ):
94- default :
95- }
96- continue
90+ _ , err := l .conn .WaitForNotification (ctx )
91+ if err != nil {
92+ if ctx .Err () != nil || l .conn .IsClosed () {
93+ return
9794 }
9895
99- if notification != nil && notification .Channel == l .channel {
100- // Create coverage signal
101- signal := types.CoverageSignal {
102- SignalID : notification .Payload ,
103- Timestamp : time .Now (),
104- }
105-
106- // Send signal (non-blocking)
107- select {
108- case l .signals <- signal :
109- default :
110- // Buffer full — increment counter so the caller can
111- // detect and report lost signals after test execution.
112- l .droppedSignals .Add (1 )
113- }
96+ select {
97+ case l .errors <- fmt .Errorf ("notification error: %w" , err ):
98+ default :
11499 }
115100 }
116101 }
@@ -135,7 +120,7 @@ func (l *Listener) DroppedSignals() int64 {
135120
136121// Close stops the listener and closes the connection
137122func (l * Listener ) Close (ctx context.Context ) error {
138- close ( l . done )
123+ l . cancel () // interrupt receiveLoop's WaitForNotification
139124
140125 // Unlisten
141126 if l .conn != nil && ! l .conn .IsClosed () {
0 commit comments