Skip to content

Commit b11861f

Browse files
committed
fix[pipeline]: use existing context to propagate cancellation
1 parent dea447c commit b11861f

84 files changed

Lines changed: 148 additions & 135 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

internal/scheduling/cinder/filter_weigher_pipeline_controller.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ func (c *FilterWeigherPipelineController) process(ctx context.Context, decision
121121
return err
122122
}
123123

124-
result, err := pipeline.Run(request)
124+
result, err := pipeline.Run(ctx, request)
125125
if err != nil {
126126
log.Error(err, "failed to run pipeline")
127127
return err

internal/scheduling/lib/filter_monitor.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,6 @@ func (fm *FilterMonitor[RequestType]) Validate(ctx context.Context, params v1alp
4343
}
4444

4545
// Run the filter and observe its execution.
46-
func (fm *FilterMonitor[RequestType]) Run(traceLog *slog.Logger, request RequestType) (*FilterWeigherPipelineStepResult, error) {
47-
return fm.monitor.RunWrapped(traceLog, request, fm.filter)
46+
func (fm *FilterMonitor[RequestType]) Run(ctx context.Context, traceLog *slog.Logger, request RequestType) (*FilterWeigherPipelineStepResult, error) {
47+
return fm.monitor.RunWrapped(ctx, traceLog, request, fm.filter)
4848
}

internal/scheduling/lib/filter_monitor_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ func TestMonitorFilter(t *testing.T) {
2323
InitFunc: func(ctx context.Context, cl client.Client, step v1alpha1.FilterSpec) error {
2424
return nil
2525
},
26-
RunFunc: func(traceLog *slog.Logger, request mockFilterWeigherPipelineRequest) (*FilterWeigherPipelineStepResult, error) {
26+
RunFunc: func(_ context.Context, traceLog *slog.Logger, request mockFilterWeigherPipelineRequest) (*FilterWeigherPipelineStepResult, error) {
2727
return &FilterWeigherPipelineStepResult{
2828
Activations: map[string]float64{"host1": 0.5, "host2": 1.0},
2929
}, nil
@@ -77,7 +77,7 @@ func TestFilterMonitor_Init(t *testing.T) {
7777
func TestFilterMonitor_Run(t *testing.T) {
7878
runCalled := false
7979
mockFilter := &mockFilter[mockFilterWeigherPipelineRequest]{
80-
RunFunc: func(traceLog *slog.Logger, request mockFilterWeigherPipelineRequest) (*FilterWeigherPipelineStepResult, error) {
80+
RunFunc: func(_ context.Context, traceLog *slog.Logger, request mockFilterWeigherPipelineRequest) (*FilterWeigherPipelineStepResult, error) {
8181
runCalled = true
8282
return &FilterWeigherPipelineStepResult{
8383
Activations: map[string]float64{"host1": 0.5, "host2": 1.0},
@@ -100,7 +100,7 @@ func TestFilterMonitor_Run(t *testing.T) {
100100
Weights: map[string]float64{"host1": 0.1, "host2": 0.2, "host3": 0.3},
101101
}
102102

103-
result, err := fm.Run(slog.Default(), request)
103+
result, err := fm.Run(t.Context(), slog.Default(), request)
104104
if err != nil {
105105
t.Errorf("expected no error, got %v", err)
106106
}

internal/scheduling/lib/filter_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import (
1616
type mockFilter[RequestType FilterWeigherPipelineRequest] struct {
1717
InitFunc func(ctx context.Context, client client.Client, step v1alpha1.FilterSpec) error
1818
ValidateFunc func(ctx context.Context, params v1alpha1.Parameters) error
19-
RunFunc func(traceLog *slog.Logger, request RequestType) (*FilterWeigherPipelineStepResult, error)
19+
RunFunc func(ctx context.Context, traceLog *slog.Logger, request RequestType) (*FilterWeigherPipelineStepResult, error)
2020
}
2121

2222
func (m *mockFilter[RequestType]) Init(ctx context.Context, client client.Client, step v1alpha1.FilterSpec) error {
@@ -31,11 +31,11 @@ func (m *mockFilter[RequestType]) Validate(ctx context.Context, params v1alpha1.
3131
}
3232
return m.ValidateFunc(ctx, params)
3333
}
34-
func (m *mockFilter[RequestType]) Run(traceLog *slog.Logger, request RequestType) (*FilterWeigherPipelineStepResult, error) {
34+
func (m *mockFilter[RequestType]) Run(ctx context.Context, traceLog *slog.Logger, request RequestType) (*FilterWeigherPipelineStepResult, error) {
3535
if m.RunFunc == nil {
3636
return &FilterWeigherPipelineStepResult{}, nil
3737
}
38-
return m.RunFunc(traceLog, request)
38+
return m.RunFunc(ctx, traceLog, request)
3939
}
4040

4141
// filterTestOptions implements FilterWeigherPipelineStepOpts for testing.

internal/scheduling/lib/filter_validation.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ func validateFilter[RequestType FilterWeigherPipelineRequest](filter Filter[Requ
3535
}
3636

3737
// Run the filter and validate what happens.
38-
func (s *FilterValidator[RequestType]) Run(traceLog *slog.Logger, request RequestType) (*FilterWeigherPipelineStepResult, error) {
39-
result, err := s.Filter.Run(traceLog, request)
38+
func (s *FilterValidator[RequestType]) Run(ctx context.Context, traceLog *slog.Logger, request RequestType) (*FilterWeigherPipelineStepResult, error) {
39+
result, err := s.Filter.Run(ctx, traceLog, request)
4040
if err != nil {
4141
return nil, err
4242
}

internal/scheduling/lib/filter_validation_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ func TestFilterValidator_Run(t *testing.T) {
146146
for _, tt := range tests {
147147
t.Run(tt.name, func(t *testing.T) {
148148
filter := &mockFilter[mockFilterWeigherPipelineRequest]{
149-
RunFunc: func(traceLog *slog.Logger, request mockFilterWeigherPipelineRequest) (*FilterWeigherPipelineStepResult, error) {
149+
RunFunc: func(_ context.Context, traceLog *slog.Logger, request mockFilterWeigherPipelineRequest) (*FilterWeigherPipelineStepResult, error) {
150150
return tt.runResult, tt.runError
151151
},
152152
}
@@ -156,7 +156,7 @@ func TestFilterValidator_Run(t *testing.T) {
156156
}
157157
traceLog := slog.Default()
158158

159-
result, err := validator.Run(traceLog, request)
159+
result, err := validator.Run(t.Context(), traceLog, request)
160160

161161
if tt.expectError && err == nil {
162162
t.Error("expected error but got nil")

internal/scheduling/lib/filter_weigher_pipeline.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import (
1919

2020
type FilterWeigherPipeline[RequestType FilterWeigherPipelineRequest] interface {
2121
// Run the scheduling pipeline with the given request.
22-
Run(request RequestType) (v1alpha1.DecisionResult, error)
22+
Run(ctx context.Context, request RequestType) (v1alpha1.DecisionResult, error)
2323
}
2424

2525
// Pipeline of scheduler steps.
@@ -136,6 +136,7 @@ func InitNewFilterWeigherPipeline[RequestType FilterWeigherPipelineRequest](
136136
// During this process, the request is mutated to only include the
137137
// remaining hosts.
138138
func (p *filterWeigherPipeline[RequestType]) runFilters(
139+
ctx context.Context,
139140
log *slog.Logger,
140141
request RequestType,
141142
) (filteredRequest RequestType, stepResults []v1alpha1.StepResult) {
@@ -145,7 +146,7 @@ func (p *filterWeigherPipeline[RequestType]) runFilters(
145146
filter := p.filters[filterName]
146147
stepLog := log.With("filter", filterName)
147148
stepLog.Info("scheduler: running filter")
148-
result, err := filter.Run(stepLog, filteredRequest)
149+
result, err := filter.Run(ctx, stepLog, filteredRequest)
149150
if errors.Is(err, ErrStepSkipped) {
150151
stepLog.Info("scheduler: filter skipped")
151152
continue
@@ -168,6 +169,7 @@ func (p *filterWeigherPipeline[RequestType]) runFilters(
168169

169170
// Execute weighers and collect their activations by step name.
170171
func (p *filterWeigherPipeline[RequestType]) runWeighers(
172+
ctx context.Context,
171173
log *slog.Logger,
172174
filteredRequest RequestType,
173175
) map[string]map[string]float64 {
@@ -181,7 +183,7 @@ func (p *filterWeigherPipeline[RequestType]) runWeighers(
181183
wg.Go(func() {
182184
stepLog := log.With("weigher", weigherName)
183185
stepLog.Info("scheduler: running weigher")
184-
result, err := weigher.Run(stepLog, filteredRequest)
186+
result, err := weigher.Run(ctx, stepLog, filteredRequest)
185187
if errors.Is(err, ErrStepSkipped) {
186188
stepLog.Info("scheduler: weigher skipped")
187189
return
@@ -262,7 +264,7 @@ func (s *filterWeigherPipeline[RequestType]) sortHostsByWeights(weights map[stri
262264
}
263265

264266
// Evaluate the pipeline and return a list of hosts in order of preference.
265-
func (p *filterWeigherPipeline[RequestType]) Run(request RequestType) (v1alpha1.DecisionResult, error) {
267+
func (p *filterWeigherPipeline[RequestType]) Run(ctx context.Context, request RequestType) (v1alpha1.DecisionResult, error) {
266268
slogArgs := request.GetTraceLogArgs()
267269
slogArgsAny := make([]any, 0, len(slogArgs))
268270
for _, arg := range slogArgs {
@@ -279,7 +281,7 @@ func (p *filterWeigherPipeline[RequestType]) Run(request RequestType) (v1alpha1.
279281

280282
// Run filters first to reduce the number of hosts.
281283
// Any weights assigned to filtered out hosts are ignored.
282-
filteredRequest, filterStepResults := p.runFilters(traceLog, request)
284+
filteredRequest, filterStepResults := p.runFilters(ctx, traceLog, request)
283285
traceLog.Info(
284286
"scheduler: finished filters",
285287
"remainingHosts", filteredRequest.GetHosts(),
@@ -290,7 +292,7 @@ func (p *filterWeigherPipeline[RequestType]) Run(request RequestType) (v1alpha1.
290292
for _, host := range filteredRequest.GetHosts() {
291293
remainingWeights[host] = inWeights[host]
292294
}
293-
stepWeights := p.runWeighers(traceLog, filteredRequest)
295+
stepWeights := p.runWeighers(ctx, traceLog, filteredRequest)
294296
outWeights := p.applyWeights(traceLog, stepWeights, remainingWeights)
295297
traceLog.Info("scheduler: output weights", "weights", outWeights)
296298

internal/scheduling/lib/filter_weigher_pipeline_step.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ type FilterWeigherPipelineStep[RequestType FilterWeigherPipelineRequest] interfa
3030
//
3131
// A traceLog is provided that contains the global request id and should
3232
// be used to log the step's execution.
33-
Run(traceLog *slog.Logger, request RequestType) (*FilterWeigherPipelineStepResult, error)
33+
Run(ctx context.Context, traceLog *slog.Logger, request RequestType) (*FilterWeigherPipelineStepResult, error)
3434
}
3535

3636
// Common base for all steps that provides some functionality

internal/scheduling/lib/filter_weigher_pipeline_step_monitor.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package lib
55

66
import (
7+
"context"
78
"fmt"
89
"log/slog"
910
"maps"
@@ -63,6 +64,7 @@ func monitorStep[RequestType FilterWeigherPipelineRequest](stepName string, m Fi
6364

6465
// Run the step and observe its execution.
6566
func (s *FilterWeigherPipelineStepMonitor[RequestType]) RunWrapped(
67+
ctx context.Context,
6668
traceLog *slog.Logger,
6769
request RequestType,
6870
step FilterWeigherPipelineStep[RequestType],
@@ -74,7 +76,7 @@ func (s *FilterWeigherPipelineStepMonitor[RequestType]) RunWrapped(
7476
}
7577

7678
inWeights := request.GetWeights()
77-
stepResult, err := step.Run(traceLog, request)
79+
stepResult, err := step.Run(ctx, traceLog, request)
7880
if err != nil {
7981
return nil, err
8082
}

internal/scheduling/lib/filter_weigher_pipeline_step_monitor_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package lib
55

66
import (
7+
"context"
78
"log/slog"
89
"os"
910
"testing"
@@ -28,7 +29,7 @@ func TestStepMonitorRun(t *testing.T) {
2829
removedHostsObserver: removedHostsObserver,
2930
}
3031
step := &mockWeigher[mockFilterWeigherPipelineRequest]{
31-
RunFunc: func(traceLog *slog.Logger, request mockFilterWeigherPipelineRequest) (*FilterWeigherPipelineStepResult, error) {
32+
RunFunc: func(_ context.Context, traceLog *slog.Logger, request mockFilterWeigherPipelineRequest) (*FilterWeigherPipelineStepResult, error) {
3233
return &FilterWeigherPipelineStepResult{
3334
Activations: map[string]float64{"host1": 0.1, "host2": 1.0, "host3": 0.0},
3435
}, nil
@@ -38,7 +39,7 @@ func TestStepMonitorRun(t *testing.T) {
3839
Hosts: []string{"host1", "host2", "host3"},
3940
Weights: map[string]float64{"host1": 0.2, "host2": 0.1, "host3": 0.0},
4041
}
41-
if _, err := monitor.RunWrapped(slog.Default(), request, step); err != nil {
42+
if _, err := monitor.RunWrapped(t.Context(), slog.Default(), request, step); err != nil {
4243
t.Fatalf("Run() error = %v, want nil", err)
4344
}
4445
if len(removedHostsObserver.Observations) != 1 {

0 commit comments

Comments
 (0)