Skip to content

Commit e528f66

Browse files
committed
fix[pipeline]: use existing context to propagate cancellation
1 parent d52cb3c commit e528f66

79 files changed

Lines changed: 138 additions & 124 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

internal/scheduling/cinder/filter_weigher_pipeline_controller.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ func (c *FilterWeigherPipelineController) process(ctx context.Context, decision
129129
return err
130130
}
131131

132-
result, err := pipeline.Run(request)
132+
result, err := pipeline.Run(ctx, request)
133133
if err != nil {
134134
log.Error(err, "failed to run pipeline")
135135
return err

internal/scheduling/lib/filter_monitor.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,6 @@ func (fm *FilterMonitor[RequestType]) Init(ctx context.Context, client client.Cl
3838
}
3939

4040
// Run the filter and observe its execution.
41-
func (fm *FilterMonitor[RequestType]) Run(traceLog *slog.Logger, request RequestType) (*FilterWeigherPipelineStepResult, error) {
42-
return fm.monitor.RunWrapped(traceLog, request, fm.filter)
41+
func (fm *FilterMonitor[RequestType]) Run(ctx context.Context, traceLog *slog.Logger, request RequestType) (*FilterWeigherPipelineStepResult, error) {
42+
return fm.monitor.RunWrapped(ctx, traceLog, request, fm.filter)
4343
}

internal/scheduling/lib/filter_monitor_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ func TestMonitorFilter(t *testing.T) {
2323
InitFunc: func(ctx context.Context, cl client.Client, step v1alpha1.FilterSpec) error {
2424
return nil
2525
},
26-
RunFunc: func(traceLog *slog.Logger, request mockFilterWeigherPipelineRequest) (*FilterWeigherPipelineStepResult, error) {
26+
RunFunc: func(_ context.Context, traceLog *slog.Logger, request mockFilterWeigherPipelineRequest) (*FilterWeigherPipelineStepResult, error) {
2727
return &FilterWeigherPipelineStepResult{
2828
Activations: map[string]float64{"host1": 0.5, "host2": 1.0},
2929
}, nil
@@ -79,7 +79,7 @@ func TestFilterMonitor_Init(t *testing.T) {
7979
func TestFilterMonitor_Run(t *testing.T) {
8080
runCalled := false
8181
mockFilter := &mockFilter[mockFilterWeigherPipelineRequest]{
82-
RunFunc: func(traceLog *slog.Logger, request mockFilterWeigherPipelineRequest) (*FilterWeigherPipelineStepResult, error) {
82+
RunFunc: func(_ context.Context, traceLog *slog.Logger, request mockFilterWeigherPipelineRequest) (*FilterWeigherPipelineStepResult, error) {
8383
runCalled = true
8484
return &FilterWeigherPipelineStepResult{
8585
Activations: map[string]float64{"host1": 0.5, "host2": 1.0},
@@ -102,7 +102,7 @@ func TestFilterMonitor_Run(t *testing.T) {
102102
Weights: map[string]float64{"host1": 0.1, "host2": 0.2, "host3": 0.3},
103103
}
104104

105-
result, err := fm.Run(slog.Default(), request)
105+
result, err := fm.Run(t.Context(), slog.Default(), request)
106106
if err != nil {
107107
t.Errorf("expected no error, got %v", err)
108108
}

internal/scheduling/lib/filter_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import (
1616

1717
type mockFilter[RequestType FilterWeigherPipelineRequest] struct {
1818
InitFunc func(ctx context.Context, client client.Client, step v1alpha1.FilterSpec) error
19-
RunFunc func(traceLog *slog.Logger, request RequestType) (*FilterWeigherPipelineStepResult, error)
19+
RunFunc func(ctx context.Context, traceLog *slog.Logger, request RequestType) (*FilterWeigherPipelineStepResult, error)
2020
}
2121

2222
func (m *mockFilter[RequestType]) Init(ctx context.Context, client client.Client, step v1alpha1.FilterSpec) error {
@@ -25,11 +25,11 @@ func (m *mockFilter[RequestType]) Init(ctx context.Context, client client.Client
2525
}
2626
return m.InitFunc(ctx, client, step)
2727
}
28-
func (m *mockFilter[RequestType]) Run(traceLog *slog.Logger, request RequestType) (*FilterWeigherPipelineStepResult, error) {
28+
func (m *mockFilter[RequestType]) Run(ctx context.Context, traceLog *slog.Logger, request RequestType) (*FilterWeigherPipelineStepResult, error) {
2929
if m.RunFunc == nil {
3030
return &FilterWeigherPipelineStepResult{}, nil
3131
}
32-
return m.RunFunc(traceLog, request)
32+
return m.RunFunc(ctx, traceLog, request)
3333
}
3434

3535
// filterTestOptions implements FilterWeigherPipelineStepOpts for testing.

internal/scheduling/lib/filter_validation.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ func validateFilter[RequestType FilterWeigherPipelineRequest](filter Filter[Requ
3030
}
3131

3232
// Run the filter and validate what happens.
33-
func (s *FilterValidator[RequestType]) Run(traceLog *slog.Logger, request RequestType) (*FilterWeigherPipelineStepResult, error) {
34-
result, err := s.Filter.Run(traceLog, request)
33+
func (s *FilterValidator[RequestType]) Run(ctx context.Context, traceLog *slog.Logger, request RequestType) (*FilterWeigherPipelineStepResult, error) {
34+
result, err := s.Filter.Run(ctx, traceLog, request)
3535
if err != nil {
3636
return nil, err
3737
}

internal/scheduling/lib/filter_validation_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ func TestFilterValidator_Run(t *testing.T) {
151151
for _, tt := range tests {
152152
t.Run(tt.name, func(t *testing.T) {
153153
filter := &mockFilter[mockFilterWeigherPipelineRequest]{
154-
RunFunc: func(traceLog *slog.Logger, request mockFilterWeigherPipelineRequest) (*FilterWeigherPipelineStepResult, error) {
154+
RunFunc: func(_ context.Context, traceLog *slog.Logger, request mockFilterWeigherPipelineRequest) (*FilterWeigherPipelineStepResult, error) {
155155
return tt.runResult, tt.runError
156156
},
157157
}
@@ -161,7 +161,7 @@ func TestFilterValidator_Run(t *testing.T) {
161161
}
162162
traceLog := slog.Default()
163163

164-
result, err := validator.Run(traceLog, request)
164+
result, err := validator.Run(t.Context(), traceLog, request)
165165

166166
if tt.expectError && err == nil {
167167
t.Error("expected error but got nil")

internal/scheduling/lib/filter_weigher_pipeline.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import (
1919

2020
type FilterWeigherPipeline[RequestType FilterWeigherPipelineRequest] interface {
2121
// Run the scheduling pipeline with the given request.
22-
Run(request RequestType) (v1alpha1.DecisionResult, error)
22+
Run(ctx context.Context, request RequestType) (v1alpha1.DecisionResult, error)
2323
}
2424

2525
// Pipeline of scheduler steps.
@@ -132,6 +132,7 @@ func InitNewFilterWeigherPipeline[RequestType FilterWeigherPipelineRequest](
132132
// During this process, the request is mutated to only include the
133133
// remaining hosts.
134134
func (p *filterWeigherPipeline[RequestType]) runFilters(
135+
ctx context.Context,
135136
log *slog.Logger,
136137
request RequestType,
137138
) (filteredRequest RequestType) {
@@ -141,7 +142,7 @@ func (p *filterWeigherPipeline[RequestType]) runFilters(
141142
filter := p.filters[filterName]
142143
stepLog := log.With("filter", filterName)
143144
stepLog.Info("scheduler: running filter")
144-
result, err := filter.Run(stepLog, filteredRequest)
145+
result, err := filter.Run(ctx, stepLog, filteredRequest)
145146
if errors.Is(err, ErrStepSkipped) {
146147
stepLog.Info("scheduler: filter skipped")
147148
continue
@@ -160,6 +161,7 @@ func (p *filterWeigherPipeline[RequestType]) runFilters(
160161

161162
// Execute weighers and collect their activations by step name.
162163
func (p *filterWeigherPipeline[RequestType]) runWeighers(
164+
ctx context.Context,
163165
log *slog.Logger,
164166
filteredRequest RequestType,
165167
) map[string]map[string]float64 {
@@ -173,7 +175,7 @@ func (p *filterWeigherPipeline[RequestType]) runWeighers(
173175
wg.Go(func() {
174176
stepLog := log.With("weigher", weigherName)
175177
stepLog.Info("scheduler: running weigher")
176-
result, err := weigher.Run(stepLog, filteredRequest)
178+
result, err := weigher.Run(ctx, stepLog, filteredRequest)
177179
if errors.Is(err, ErrStepSkipped) {
178180
stepLog.Info("scheduler: weigher skipped")
179181
return
@@ -243,7 +245,7 @@ func (s *filterWeigherPipeline[RequestType]) sortHostsByWeights(weights map[stri
243245
}
244246

245247
// Evaluate the pipeline and return a list of hosts in order of preference.
246-
func (p *filterWeigherPipeline[RequestType]) Run(request RequestType) (v1alpha1.DecisionResult, error) {
248+
func (p *filterWeigherPipeline[RequestType]) Run(ctx context.Context, request RequestType) (v1alpha1.DecisionResult, error) {
247249
slogArgs := request.GetTraceLogArgs()
248250
slogArgsAny := make([]any, 0, len(slogArgs))
249251
for _, arg := range slogArgs {
@@ -260,7 +262,7 @@ func (p *filterWeigherPipeline[RequestType]) Run(request RequestType) (v1alpha1.
260262

261263
// Run filters first to reduce the number of hosts.
262264
// Any weights assigned to filtered out hosts are ignored.
263-
filteredRequest := p.runFilters(traceLog, request)
265+
filteredRequest := p.runFilters(ctx, traceLog, request)
264266
traceLog.Info(
265267
"scheduler: finished filters",
266268
"remainingHosts", filteredRequest.GetHosts(),
@@ -271,7 +273,7 @@ func (p *filterWeigherPipeline[RequestType]) Run(request RequestType) (v1alpha1.
271273
for _, host := range filteredRequest.GetHosts() {
272274
remainingWeights[host] = inWeights[host]
273275
}
274-
stepWeights := p.runWeighers(traceLog, filteredRequest)
276+
stepWeights := p.runWeighers(ctx, traceLog, filteredRequest)
275277
outWeights := p.applyWeights(stepWeights, remainingWeights)
276278
traceLog.Info("scheduler: output weights", "weights", outWeights)
277279

internal/scheduling/lib/filter_weigher_pipeline_step.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ type FilterWeigherPipelineStep[RequestType FilterWeigherPipelineRequest] interfa
2929
//
3030
// A traceLog is provided that contains the global request id and should
3131
// be used to log the step's execution.
32-
Run(traceLog *slog.Logger, request RequestType) (*FilterWeigherPipelineStepResult, error)
32+
Run(ctx context.Context, traceLog *slog.Logger, request RequestType) (*FilterWeigherPipelineStepResult, error)
3333
}
3434

3535
// Common base for all steps that provides some functionality

internal/scheduling/lib/filter_weigher_pipeline_step_monitor.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package lib
55

66
import (
7+
"context"
78
"fmt"
89
"log/slog"
910
"maps"
@@ -63,6 +64,7 @@ func monitorStep[RequestType FilterWeigherPipelineRequest](stepName string, m Fi
6364

6465
// Run the step and observe its execution.
6566
func (s *FilterWeigherPipelineStepMonitor[RequestType]) RunWrapped(
67+
ctx context.Context,
6668
traceLog *slog.Logger,
6769
request RequestType,
6870
step FilterWeigherPipelineStep[RequestType],
@@ -74,7 +76,7 @@ func (s *FilterWeigherPipelineStepMonitor[RequestType]) RunWrapped(
7476
}
7577

7678
inWeights := request.GetWeights()
77-
stepResult, err := step.Run(traceLog, request)
79+
stepResult, err := step.Run(ctx, traceLog, request)
7880
if err != nil {
7981
return nil, err
8082
}

internal/scheduling/lib/filter_weigher_pipeline_step_monitor_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package lib
55

66
import (
7+
"context"
78
"log/slog"
89
"os"
910
"testing"
@@ -28,7 +29,7 @@ func TestStepMonitorRun(t *testing.T) {
2829
removedHostsObserver: removedHostsObserver,
2930
}
3031
step := &mockWeigher[mockFilterWeigherPipelineRequest]{
31-
RunFunc: func(traceLog *slog.Logger, request mockFilterWeigherPipelineRequest) (*FilterWeigherPipelineStepResult, error) {
32+
RunFunc: func(_ context.Context, traceLog *slog.Logger, request mockFilterWeigherPipelineRequest) (*FilterWeigherPipelineStepResult, error) {
3233
return &FilterWeigherPipelineStepResult{
3334
Activations: map[string]float64{"host1": 0.1, "host2": 1.0, "host3": 0.0},
3435
}, nil
@@ -38,7 +39,7 @@ func TestStepMonitorRun(t *testing.T) {
3839
Hosts: []string{"host1", "host2", "host3"},
3940
Weights: map[string]float64{"host1": 0.2, "host2": 0.1, "host3": 0.0},
4041
}
41-
if _, err := monitor.RunWrapped(slog.Default(), request, step); err != nil {
42+
if _, err := monitor.RunWrapped(t.Context(), slog.Default(), request, step); err != nil {
4243
t.Fatalf("Run() error = %v, want nil", err)
4344
}
4445
if len(removedHostsObserver.Observations) != 1 {

0 commit comments

Comments
 (0)