diff --git a/pkg/services/stop.go b/pkg/services/stop.go index d7f95f658..7efeb1fa1 100644 --- a/pkg/services/stop.go +++ b/pkg/services/stop.go @@ -48,16 +48,3 @@ func (s StopRChan) Ctx(ctx context.Context) (context.Context, context.CancelFunc func (s StopRChan) CtxWithTimeout(timeout time.Duration) (context.Context, context.CancelFunc) { return s.CtxCancel(context.WithTimeout(context.Background(), timeout)) } - -// CtxCancel cancels a [context.Context] when StopChan is closed. -// Returns ctx and cancel unmodified, for convenience. -func (s StopRChan) CtxCancel(ctx context.Context, cancel context.CancelFunc) (context.Context, context.CancelFunc) { - go func() { - select { - case <-s: - cancel() - case <-ctx.Done(): - } - }() - return ctx, cancel -} diff --git a/pkg/services/stop_!race.go b/pkg/services/stop_!race.go new file mode 100644 index 000000000..b8732c4a0 --- /dev/null +++ b/pkg/services/stop_!race.go @@ -0,0 +1,18 @@ +//go:build !race + +package services + +import "context" + +// CtxCancel cancels a [context.Context] when StopChan is closed. +// Returns ctx and cancel unmodified, for convenience. +func (s StopRChan) CtxCancel(ctx context.Context, cancel context.CancelFunc) (context.Context, context.CancelFunc) { + go func() { + select { + case <-s: + cancel() + case <-ctx.Done(): + } + }() + return ctx, cancel +} diff --git a/pkg/services/stop_race.go b/pkg/services/stop_race.go new file mode 100644 index 000000000..a97a8e316 --- /dev/null +++ b/pkg/services/stop_race.go @@ -0,0 +1,41 @@ +//go:build race + +package services + +import ( + "context" + "time" +) + +// CtxCancel cancels a [context.Context] when StopChan is closed. +// Returns ctx and cancel unmodified, for convenience. +func (s StopRChan) CtxCancel(ctx context.Context, cancel context.CancelFunc) (context.Context, context.CancelFunc) { + go func() { + select { + case <-s: + cancel() + case <-ctx.Done(): + } + }() + return &syncCtx{ + deadline: ctx.Deadline, + done: ctx.Done, + value: ctx.Value, + err: ctx.Err, + }, cancel +} + +var _ context.Context = &syncCtx{} + +// syncCtx is a context.Context implementation that is safe to format via %#v, which mockery uses. +type syncCtx struct { + deadline func() (time.Time, bool) + done func() <-chan struct{} + value func(any) any + err func() error +} + +func (s *syncCtx) Deadline() (time.Time, bool) { return s.deadline() } +func (s *syncCtx) Done() <-chan struct{} { return s.done() } +func (s *syncCtx) Value(k any) any { return s.value(k) } +func (c *syncCtx) Err() error { return c.err() }