Skip to content

Commit 11bced1

Browse files
committed
Add worker
Signed-off-by: Bryan Frimin <bryan@getprobo.com>
1 parent 7fc915f commit 11bced1

File tree

2 files changed

+724
-0
lines changed

2 files changed

+724
-0
lines changed

worker/worker.go

Lines changed: 365 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,365 @@
1+
// Copyright (c) 2024 Bryan Frimin <bryan@frimin.fr>.
2+
//
3+
// Permission to use, copy, modify, and/or distribute this software
4+
// for any purpose with or without fee is hereby granted, provided
5+
// that the above copyright notice and this permission notice appear
6+
// in all copies.
7+
//
8+
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL
9+
// WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED
10+
// WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE
11+
// AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR
12+
// CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
13+
// OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
14+
// NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
15+
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16+
17+
package worker
18+
19+
import (
20+
"context"
21+
"errors"
22+
"sync"
23+
"time"
24+
25+
"github.com/prometheus/client_golang/prometheus"
26+
"go.gearno.de/kit/internal/version"
27+
"go.gearno.de/kit/log"
28+
"go.opentelemetry.io/otel"
29+
"go.opentelemetry.io/otel/attribute"
30+
"go.opentelemetry.io/otel/codes"
31+
"go.opentelemetry.io/otel/trace"
32+
)
33+
34+
// ErrNoTask is returned by Handler.Claim when there are no tasks
// available for processing. The worker uses this sentinel to
// distinguish "no work" from real errors.
var ErrNoTask = errors.New("worker: no task available")

const (
	// tracerName identifies this package's OpenTelemetry tracer.
	tracerName = "go.gearno.de/kit/worker"
)

type (
	// Handler defines the operations a worker needs to claim,
	// process, and manage tasks of type T.
	Handler[T any] interface {
		// Claim acquires the next available task.
		// Implementations must return ErrNoTask when no work is
		// available.
		Claim(ctx context.Context) (T, error)

		// Process performs the actual work on a claimed task.
		// Implementations are responsible for handling their own
		// failures (e.g. updating status, retrying, logging).
		Process(ctx context.Context, task T) error
	}

	// StaleRecoverer is an optional interface that a Handler can
	// implement to recover tasks stuck in a processing state. When
	// the handler implements this interface, RecoverStale is called
	// at the beginning of each polling cycle.
	StaleRecoverer interface {
		// RecoverStale releases or re-queues tasks that are stuck
		// in a processing state (e.g. left over from a crashed run).
		RecoverStale(ctx context.Context) error
	}

	// Option configures a Worker.
	Option func(*options)

	// options collects the configurable settings applied by Option
	// functions before a Worker is built by New.
	options struct {
		interval       time.Duration         // polling interval between work cycles
		maxConcurrency int                   // upper bound on concurrently processed tasks
		tracerProvider trace.TracerProvider  // source of the OpenTelemetry tracer
		registerer     prometheus.Registerer // destination for the worker's metrics
	}

	// Worker polls for tasks using a Handler and processes them
	// concurrently up to a configurable limit.
	Worker[T any] struct {
		name           string
		handler        Handler[T]
		logger         *log.Logger
		tracer         trace.Tracer
		interval       time.Duration
		maxConcurrency int

		// Prometheus instruments, labelled by worker name (the
		// per-task metrics also carry a "status" label).
		pollCyclesTotal      *prometheus.CounterVec
		claimErrorsTotal     *prometheus.CounterVec
		claimDuration        *prometheus.HistogramVec
		recoverStaleDuration *prometheus.HistogramVec
		tasksTotal           *prometheus.CounterVec
		taskDuration         *prometheus.HistogramVec
	}
)
94+
95+
// WithInterval sets the polling interval between work cycles.
96+
// Default is 10 seconds.
97+
func WithInterval(d time.Duration) Option {
98+
return func(o *options) { o.interval = d }
99+
}
100+
101+
// WithMaxConcurrency sets the maximum number of tasks processed
102+
// concurrently. Values less than 1 are ignored. Default is 5.
103+
func WithMaxConcurrency(n int) Option {
104+
return func(o *options) {
105+
if n > 0 {
106+
o.maxConcurrency = n
107+
}
108+
}
109+
}
110+
111+
// WithTracerProvider configures OpenTelemetry tracing with the
112+
// provided tracer provider.
113+
func WithTracerProvider(tp trace.TracerProvider) Option {
114+
return func(o *options) {
115+
o.tracerProvider = tp
116+
}
117+
}
118+
119+
// WithRegisterer sets a custom Prometheus registerer for metrics.
120+
func WithRegisterer(r prometheus.Registerer) Option {
121+
return func(o *options) {
122+
o.registerer = r
123+
}
124+
}
125+
126+
// New creates a Worker named name that uses handler to claim and
127+
// process tasks. The name identifies this worker in metrics, logs,
128+
// and traces.
129+
func New[T any](name string, handler Handler[T], logger *log.Logger, opts ...Option) *Worker[T] {
130+
o := options{
131+
interval: 10 * time.Second,
132+
maxConcurrency: 5,
133+
tracerProvider: otel.GetTracerProvider(),
134+
registerer: prometheus.DefaultRegisterer,
135+
}
136+
137+
for _, opt := range opts {
138+
opt(&o)
139+
}
140+
141+
workerLabel := []string{"worker"}
142+
metricLabels := []string{"worker", "status"}
143+
144+
pollCyclesTotal := prometheus.NewCounterVec(
145+
prometheus.CounterOpts{
146+
Subsystem: "worker",
147+
Name: "poll_cycles_total",
148+
Help: "Total number of polling cycles executed.",
149+
},
150+
workerLabel,
151+
)
152+
if err := o.registerer.Register(pollCyclesTotal); err != nil {
153+
if are, ok := err.(prometheus.AlreadyRegisteredError); ok {
154+
pollCyclesTotal = are.ExistingCollector.(*prometheus.CounterVec)
155+
} else {
156+
panic(err)
157+
}
158+
}
159+
160+
claimErrorsTotal := prometheus.NewCounterVec(
161+
prometheus.CounterOpts{
162+
Subsystem: "worker",
163+
Name: "claim_errors_total",
164+
Help: "Total number of task claim errors.",
165+
},
166+
workerLabel,
167+
)
168+
if err := o.registerer.Register(claimErrorsTotal); err != nil {
169+
if are, ok := err.(prometheus.AlreadyRegisteredError); ok {
170+
claimErrorsTotal = are.ExistingCollector.(*prometheus.CounterVec)
171+
} else {
172+
panic(err)
173+
}
174+
}
175+
176+
claimDuration := prometheus.NewHistogramVec(
177+
prometheus.HistogramOpts{
178+
Subsystem: "worker",
179+
Name: "claim_duration_seconds",
180+
Help: "Duration of task claim operations in seconds.",
181+
Buckets: prometheus.DefBuckets,
182+
},
183+
workerLabel,
184+
)
185+
if err := o.registerer.Register(claimDuration); err != nil {
186+
if are, ok := err.(prometheus.AlreadyRegisteredError); ok {
187+
claimDuration = are.ExistingCollector.(*prometheus.HistogramVec)
188+
} else {
189+
panic(err)
190+
}
191+
}
192+
193+
recoverStaleDuration := prometheus.NewHistogramVec(
194+
prometheus.HistogramOpts{
195+
Subsystem: "worker",
196+
Name: "recover_stale_duration_seconds",
197+
Help: "Duration of stale task recovery operations in seconds.",
198+
Buckets: prometheus.DefBuckets,
199+
},
200+
workerLabel,
201+
)
202+
if err := o.registerer.Register(recoverStaleDuration); err != nil {
203+
if are, ok := err.(prometheus.AlreadyRegisteredError); ok {
204+
recoverStaleDuration = are.ExistingCollector.(*prometheus.HistogramVec)
205+
} else {
206+
panic(err)
207+
}
208+
}
209+
210+
tasksTotal := prometheus.NewCounterVec(
211+
prometheus.CounterOpts{
212+
Subsystem: "worker",
213+
Name: "tasks_total",
214+
Help: "Total number of tasks processed.",
215+
},
216+
metricLabels,
217+
)
218+
if err := o.registerer.Register(tasksTotal); err != nil {
219+
if are, ok := err.(prometheus.AlreadyRegisteredError); ok {
220+
tasksTotal = are.ExistingCollector.(*prometheus.CounterVec)
221+
} else {
222+
panic(err)
223+
}
224+
}
225+
226+
taskDuration := prometheus.NewHistogramVec(
227+
prometheus.HistogramOpts{
228+
Subsystem: "worker",
229+
Name: "task_duration_seconds",
230+
Help: "Duration of task processing in seconds.",
231+
Buckets: prometheus.DefBuckets,
232+
},
233+
metricLabels,
234+
)
235+
if err := o.registerer.Register(taskDuration); err != nil {
236+
if are, ok := err.(prometheus.AlreadyRegisteredError); ok {
237+
taskDuration = are.ExistingCollector.(*prometheus.HistogramVec)
238+
} else {
239+
panic(err)
240+
}
241+
}
242+
243+
return &Worker[T]{
244+
name: name,
245+
handler: handler,
246+
logger: logger.Named(name),
247+
tracer: o.tracerProvider.Tracer(
248+
tracerName,
249+
trace.WithInstrumentationVersion(
250+
version.New(0).Alpha(1),
251+
),
252+
),
253+
interval: o.interval,
254+
maxConcurrency: o.maxConcurrency,
255+
pollCyclesTotal: pollCyclesTotal,
256+
claimErrorsTotal: claimErrorsTotal,
257+
claimDuration: claimDuration,
258+
recoverStaleDuration: recoverStaleDuration,
259+
tasksTotal: tasksTotal,
260+
taskDuration: taskDuration,
261+
}
262+
}
263+
264+
// Run starts the worker loop. It blocks until ctx is cancelled, then
// waits for all in-flight tasks to complete before returning. The
// return value is the cancellation cause (context.Cause).
func (w *Worker[T]) Run(ctx context.Context) error {
	var (
		wg     sync.WaitGroup
		sem    = make(chan struct{}, w.maxConcurrency) // bounds concurrent task processing
		ticker = time.NewTicker(w.interval)
	)

	// Defers run in reverse order: the ticker stops first, then we
	// block until every in-flight task goroutine has finished.
	defer ticker.Stop()
	defer wg.Wait()

	for {
		select {
		case <-ctx.Done():
			return context.Cause(ctx)
		case <-ticker.C:
			w.pollCyclesTotal.WithLabelValues(w.name).Inc()

			// Stale recovery (and later, claiming) runs on a context
			// that survives cancellation so in-progress operations are
			// not cut off mid-flight.
			nonCancelableCtx := context.WithoutCancel(ctx)
			if sr, ok := w.handler.(StaleRecoverer); ok {
				recoverStart := time.Now()
				if err := sr.RecoverStale(nonCancelableCtx); err != nil {
					w.logger.ErrorCtx(nonCancelableCtx, "cannot recover stale tasks", log.Error(err))
				}
				w.recoverStaleDuration.WithLabelValues(w.name).Observe(time.Since(recoverStart).Seconds())
			}

			// Drain the queue: keep claiming until the handler reports
			// no work (ErrNoTask), a claim fails, or ctx is cancelled
			// while waiting for a free concurrency slot.
			//
			// NOTE(review): a cancellation error from processNext is
			// also counted in claimErrorsTotal and logged here —
			// confirm whether cancellation should be excluded.
			for {
				if err := w.processNext(ctx, sem, &wg); err != nil {
					if !errors.Is(err, ErrNoTask) {
						w.claimErrorsTotal.WithLabelValues(w.name).Inc()
						w.logger.ErrorCtx(nonCancelableCtx, "cannot claim task", log.Error(err))
					}
					break
				}
			}
		}
	}
}
304+
305+
// processNext acquires a concurrency slot, claims one task, and starts
// processing it on a goroutine tracked by wg. It returns ErrNoTask
// when the handler has no work, the claim error on failure, or the
// cancellation cause if ctx is done before a slot frees up. A nil
// return means a task was claimed and its processing has been started.
func (w *Worker[T]) processNext(ctx context.Context, sem chan struct{}, wg *sync.WaitGroup) error {
	// Reserve a concurrency slot before claiming, so we never claim a
	// task we cannot immediately start processing.
	select {
	case sem <- struct{}{}:
	case <-ctx.Done():
		return context.Cause(ctx)
	}

	// Claiming and processing use a context that survives cancellation:
	// once a task is claimed it runs to completion instead of being
	// abandoned mid-flight (Run waits for it via wg).
	nonCancelableCtx := context.WithoutCancel(ctx)

	claimStart := time.Now()
	task, err := w.handler.Claim(nonCancelableCtx)
	w.claimDuration.WithLabelValues(w.name).Observe(time.Since(claimStart).Seconds())
	if err != nil {
		<-sem // release the unused slot (covers the ErrNoTask case too)
		return err
	}

	// wg.Go (Go 1.25+) increments the WaitGroup and runs the function
	// on a new goroutine; Run's deferred wg.Wait blocks on it.
	wg.Go(
		func() {
			defer func() { <-sem }() // free the slot whenever the task ends

			processCtx, span := w.tracer.Start(
				nonCancelableCtx,
				"worker.process",
				trace.WithSpanKind(trace.SpanKindInternal),
				trace.WithAttributes(
					attribute.String("worker.name", w.name),
				),
			)
			defer span.End()

			start := time.Now()
			err := w.handler.Process(processCtx, task)
			duration := time.Since(start)

			// Record outcome in logs, the span, and the per-status
			// metrics. The handler owns any retry/requeue behavior.
			status := "succeeded"
			if err != nil {
				status = "failed"
				span.RecordError(err)
				span.SetStatus(codes.Error, err.Error())
				w.logger.ErrorCtx(
					processCtx,
					"task processing failed",
					log.Error(err),
					log.Duration("duration", duration),
				)
			} else {
				w.logger.InfoCtx(
					processCtx,
					"task processing succeeded",
					log.Duration("duration", duration),
				)
			}

			w.tasksTotal.WithLabelValues(w.name, status).Inc()
			w.taskDuration.WithLabelValues(w.name, status).Observe(duration.Seconds())
		},
	)

	return nil
}

0 commit comments

Comments
 (0)