diff --git a/router/server.go b/router/server.go index d56e20ce8..a1c79cf05 100644 --- a/router/server.go +++ b/router/server.go @@ -280,9 +280,24 @@ func countNotificationTargets(notification *notify.PushNotification) int { return count } +// withEitherCancel returns a context that is cancelled when either ctx1 or ctx2 is done. +// This is useful for merging an HTTP request context with a queue-task context so that +// a push notification is aborted when the caller disconnects OR when the queue shuts down. +func withEitherCancel(ctx1, ctx2 context.Context) (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(ctx1) + go func() { + select { + case <-ctx2.Done(): + cancel() + case <-ctx.Done(): + } + }() + return ctx, cancel +} + // HandleNotification add notification to queue list. func handleNotification( - _ context.Context, + ctx context.Context, cfg *config.ConfYaml, req notify.RequestPush, q *queue.Queue, @@ -297,6 +312,7 @@ func handleNotification( var ( count int wg sync.WaitGroup + mu sync.Mutex logs = make([]logx.LogPushEntry, 0) ) @@ -306,19 +322,30 @@ func handleNotification( } if isLocalSync { - func(msg *notify.PushNotification, cfg *config.ConfYaml) { - if err := q.QueueTask(func(ctx context.Context) error { + func(msg *notify.PushNotification, cfg *config.ConfYaml, reqCtx context.Context) { + if err := q.QueueTask(func(queueCtx context.Context) error { defer wg.Done() - resp, err := notify.SendNotification(ctx, msg, cfg) + // Merge the HTTP request context with the queue context so that + // the notification is cancelled when either the client disconnects + // or the queue shuts down. See: + // https://github.com/appleboy/gorush/issues/422 + mergedCtx, cancel := withEitherCancel(reqCtx, queueCtx) + defer cancel() + resp, err := notify.SendNotification(mergedCtx, msg, cfg) if err != nil { return err } + mu.Lock() logs = append(logs, resp.Logs...) + mu.Unlock() return nil }); err != nil { + if cfg.Core.Sync { + wg.Done() + } logx.LogError.Error(err) } - }(notification, cfg) + }(notification, cfg, ctx) } else if err := q.Queue(notification); err != nil { resp := markFailedNotification(cfg, notification, "max capacity reached") logs = append(logs, resp...) diff --git a/router/server_test.go b/router/server_test.go index 59641fcee..c9068d77b 100644 --- a/router/server_test.go +++ b/router/server_test.go @@ -853,3 +853,166 @@ func TestCountNotificationTargets(t *testing.T) { }) } } + +// TestWithEitherCancel_Ctx1Cancel verifies that the derived context is cancelled +// when ctx1 (the first parent, e.g. the HTTP request context) is cancelled. +func TestWithEitherCancel_Ctx1Cancel(t *testing.T) { + ctx1, cancel1 := context.WithCancel(context.Background()) + ctx2 := t.Context() + + merged, cancelMerged := withEitherCancel(ctx1, ctx2) + defer cancelMerged() + + // ctx1 has not been cancelled yet - merged should be alive + select { + case <-merged.Done(): + t.Fatal("merged context should not be done yet") + default: + } + + // Cancel ctx1 (simulates HTTP client disconnect) + cancel1() + + select { + case <-merged.Done(): + // expected + case <-time.After(100 * time.Millisecond): + t.Fatal("merged context should have been cancelled when ctx1 was cancelled") + } +} + +// TestWithEitherCancel_Ctx2Cancel verifies that the derived context is cancelled +// when ctx2 (the second parent, e.g. the queue-task context) is cancelled. +func TestWithEitherCancel_Ctx2Cancel(t *testing.T) { + ctx1 := t.Context() + ctx2, cancel2 := context.WithCancel(context.Background()) + + merged, cancelMerged := withEitherCancel(ctx1, ctx2) + defer cancelMerged() + + // Neither parent cancelled yet - merged should be alive + select { + case <-merged.Done(): + t.Fatal("merged context should not be done yet") + default: + } + + // Cancel ctx2 (simulates queue shutdown) + cancel2() + + select { + case <-merged.Done(): + // expected + case <-time.After(100 * time.Millisecond): + t.Fatal("merged context should have been cancelled when ctx2 was cancelled") + } +} + +// TestWithEitherCancel_ExplicitCancel verifies that calling the returned +// CancelFunc directly cancels the merged context without affecting either parent. +func TestWithEitherCancel_ExplicitCancel(t *testing.T) { + ctx1, cancel1 := context.WithCancel(context.Background()) + ctx2, cancel2 := context.WithCancel(context.Background()) + defer cancel1() + defer cancel2() + + merged, cancelMerged := withEitherCancel(ctx1, ctx2) + + // Explicitly cancel the merged context + cancelMerged() + + select { + case <-merged.Done(): + // expected + case <-time.After(100 * time.Millisecond): + t.Fatal("merged context should have been cancelled by explicit cancelMerged()") + } + + // Parents should still be alive + select { + case <-ctx1.Done(): + t.Fatal("ctx1 should not be cancelled") + default: + } + select { + case <-ctx2.Done(): + t.Fatal("ctx2 should not be cancelled") + default: + } +} + +// TestWithEitherCancel_NoGoroutineLeak verifies that the internal goroutine +// spawned by withEitherCancel exits when the merged context is cancelled, +// preventing goroutine leaks. +func TestWithEitherCancel_NoGoroutineLeak(t *testing.T) { + ctx1, cancel1 := context.WithCancel(context.Background()) + ctx2 := t.Context() + + goroutinesBefore := runtime.NumGoroutine() + + merged, cancelMerged := withEitherCancel(ctx1, ctx2) + + // Give the internal goroutine time to start + time.Sleep(10 * time.Millisecond) + goroutinesDuring := runtime.NumGoroutine() + assert.GreaterOrEqual(t, goroutinesDuring, goroutinesBefore, + "at least one new goroutine should exist while merged context is live") + + // Cancel via ctx1 - this should trigger the internal goroutine to exit + cancel1() + cancelMerged() // also call cancelMerged to release resources + + // Give the goroutine time to clean up + time.Sleep(50 * time.Millisecond) + goroutinesAfter := runtime.NumGoroutine() + + // Allow a ±1 goroutine variance due to runtime scheduling + assert.LessOrEqual(t, goroutinesAfter, goroutinesBefore+1, + "goroutine count should be near baseline after cancellation") + + // merged context must be done + select { + case <-merged.Done(): + // expected + default: + t.Fatal("merged context should be done after cancel1() and cancelMerged()") + } +} + +// TestWithEitherCancel_AlreadyCancelledCtx1 verifies that if ctx1 is already +// cancelled before withEitherCancel is called, the merged context is immediately done. +func TestWithEitherCancel_AlreadyCancelledCtx1(t *testing.T) { + ctx1, cancel1 := context.WithCancel(context.Background()) + cancel1() // already cancelled + + ctx2 := t.Context() + + merged, cancelMerged := withEitherCancel(ctx1, ctx2) + defer cancelMerged() + + select { + case <-merged.Done(): + // expected - ctx1 was already done, so merged should be immediately done + case <-time.After(100 * time.Millisecond): + t.Fatal("merged context should be immediately done when ctx1 is already cancelled") + } +} + +// TestWithEitherCancel_AlreadyCancelledCtx2 verifies that if ctx2 is already +// cancelled before withEitherCancel is called, the merged context is cancelled promptly. +func TestWithEitherCancel_AlreadyCancelledCtx2(t *testing.T) { + ctx1 := t.Context() + + ctx2, cancel2 := context.WithCancel(context.Background()) + cancel2() // already cancelled + + merged, cancelMerged := withEitherCancel(ctx1, ctx2) + defer cancelMerged() + + select { + case <-merged.Done(): + // expected + case <-time.After(100 * time.Millisecond): + t.Fatal("merged context should be cancelled promptly when ctx2 is already cancelled") + } +}