Skip to content

Commit da5bccc

Browse files
committed
Send full requirements to modules that accept them
1 parent e578b91 commit da5bccc

5 files changed

Lines changed: 117 additions & 30 deletions

File tree

pkg/workflows/host/module.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@ type Module interface {
2323
Execute(ctx context.Context, request *sdkpb.ExecuteRequest, handler ExecutionHelper) (*sdkpb.ExecutionResult, error)
2424
}
2525

26+
type RequirementEnforcingModule interface {
27+
Module
28+
29+
// SetRequirements must respect the requirements for the execution until it completes
30+
SetRequirements(executionId string, requirements *sdkpb.Requirements)
31+
}
32+
2633
// ExecutionHelper Implemented by those running the host, for example the Workflow Engine
2734
type ExecutionHelper interface {
2835
// CallCapability blocking call to the Workflow Engine

pkg/workflows/host/requirement_selecting_module.go

Lines changed: 37 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -27,64 +27,64 @@ func (l *lazyModule) ensureStarted() {
2727
})
2828
}
2929

30+
// NewRequirementSelectingModule creates a module that routes trigger executions
31+
// based on subscription requirements. main is prepended as modules[0]; additional
32+
// modules follow. Subscribe always runs on modules[0].
3033
func NewRequirementSelectingModule(main ModuleAndHandler, additional []ModuleAndHandler) Module {
31-
wrapped := make([]*lazyModule, len(additional))
32-
for i := range additional {
33-
wrapped[i] = &lazyModule{ModuleAndHandler: additional[i]}
34-
}
35-
return &requirementSelectingModule{
36-
main: main,
37-
additional: wrapped,
34+
modules := make([]*lazyModule, 1+len(additional))
35+
modules[0] = &lazyModule{ModuleAndHandler: main}
36+
for i, a := range additional {
37+
modules[1+i] = &lazyModule{ModuleAndHandler: a}
3838
}
39+
return &requirementSelectingModule{modules: modules}
40+
}
41+
42+
type triggerInfo struct {
43+
moduleIdx int
44+
requirements *sdk.Requirements
3945
}
4046

4147
type requirementSelectingModule struct {
42-
main ModuleAndHandler
43-
additional []*lazyModule
44-
// triggerID → index into additional
48+
modules []*lazyModule
49+
// triggerID → triggerInfo
4550
cache sync.Map
4651
}
4752

4853
func (r *requirementSelectingModule) Start() {
49-
r.main.Start()
54+
r.modules[0].ensureStarted()
5055
}
5156

5257
func (r *requirementSelectingModule) Close() {
53-
r.main.Close()
54-
for _, m := range r.additional {
58+
for _, m := range r.modules {
5559
if m.started {
5660
m.Close()
5761
}
5862
}
5963
}
6064

6165
func (r *requirementSelectingModule) IsLegacyDAG() bool {
62-
return r.main.IsLegacyDAG()
66+
return r.modules[0].IsLegacyDAG()
6367
}
6468

6569
func (r *requirementSelectingModule) Execute(ctx context.Context, request *sdk.ExecuteRequest, handler ExecutionHelper) (*sdk.ExecutionResult, error) {
66-
if triggerID, ok := extractTriggerID(request); ok {
67-
if idx, cached := r.cache.Load(triggerID); cached {
68-
return r.additional[idx.(int)].Execute(ctx, request, handler)
69-
}
70-
return r.main.Execute(ctx, request, handler)
70+
if request.GetTrigger() == nil {
71+
return r.subscribe(ctx, request, handler)
7172
}
73+
return r.trigger(ctx, request, handler)
74+
}
7275

73-
// Subscribe: run main, then build triggerID→module cache from subscription requirements
74-
result, err := r.main.Execute(ctx, request, handler)
76+
func (r *requirementSelectingModule) subscribe(ctx context.Context, request *sdk.ExecuteRequest, handler ExecutionHelper) (*sdk.ExecutionResult, error) {
77+
result, err := r.modules[0].Execute(ctx, request, handler)
7578
if err != nil {
7679
return nil, err
7780
}
7881

7982
for i, sub := range result.GetTriggerSubscriptions().GetSubscriptions() {
80-
if sub.Requirements == nil || CheckRequirements(ctx, r.main.RequirementsHandler, sub.Requirements) {
81-
continue
82-
}
8383
matched := false
84-
for j, m := range r.additional {
84+
for j, m := range r.modules {
8585
if CheckRequirements(ctx, m.RequirementsHandler, sub.Requirements) {
8686
m.ensureStarted()
87-
r.cache.Store(uint64(i), j)
87+
r.cache.Store(uint64(i), triggerInfo{moduleIdx: j, requirements: sub.Requirements})
8888
matched = true
8989
break
9090
}
@@ -97,11 +97,18 @@ func (r *requirementSelectingModule) Execute(ctx context.Context, request *sdk.E
9797
return result, nil
9898
}
9999

100-
func extractTriggerID(req *sdk.ExecuteRequest) (uint64, bool) {
101-
if t := req.GetTrigger(); t != nil {
102-
return t.Id, true
100+
func (r *requirementSelectingModule) trigger(ctx context.Context, request *sdk.ExecuteRequest, handler ExecutionHelper) (*sdk.ExecutionResult, error) {
101+
trigger := request.GetTrigger()
102+
if val, cached := r.cache.Load(trigger.Id); cached {
103+
info := val.(triggerInfo)
104+
m := r.modules[info.moduleIdx]
105+
if rem, ok := m.Module.(RequirementEnforcingModule); ok && info.requirements != nil {
106+
rem.SetRequirements(handler.GetWorkflowExecutionID(), info.requirements)
107+
}
108+
109+
return m.Execute(ctx, request, handler)
103110
}
104-
return 0, false
111+
return r.modules[0].Execute(ctx, request, handler)
105112
}
106113

107114
var _ Module = &requirementSelectingModule{}

pkg/workflows/host/requirement_selecting_module_test.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,15 @@ func (s *stubModule) Execute(ctx context.Context, req *sdk.ExecuteRequest, h Exe
2727
return s.executeFn(ctx, req, h)
2828
}
2929

30+
type requirementEnforcingStub struct {
31+
*stubModule
32+
setRequirementsFn func(string, *sdk.Requirements)
33+
}
34+
35+
func (s *requirementEnforcingStub) SetRequirements(executionID string, requirements *sdk.Requirements) {
36+
s.setRequirementsFn(executionID, requirements)
37+
}
38+
3039
func noop() {}
3140
func noopClose() {}
3241

@@ -362,6 +371,62 @@ func TestRequirementSelectingModule_Execute(t *testing.T) {
362371
assert.Equal(t, want, got)
363372
assert.Equal(t, int32(1), atomic.LoadInt32(&mainTriggerCalls), "trigger should run on main")
364373
})
374+
375+
t.Run("cached trigger sets requirements before execute", func(t *testing.T) {
376+
teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}}
377+
want := &sdk.ExecutionResult{}
378+
executionID := "wf-exec-1"
379+
380+
main := ModuleAndHandler{
381+
Module: &stubModule{
382+
startFn: noop,
383+
executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) {
384+
return subscribeResult(subWithReqs(teeReqs)), nil
385+
},
386+
},
387+
RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return false }},
388+
}
389+
390+
var calls []string
391+
var gotReqs *sdk.Requirements
392+
var gotExecutionID string
393+
enforcingAdd := &requirementEnforcingStub{
394+
stubModule: &stubModule{
395+
startFn: noop,
396+
closeFn: noopClose,
397+
executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) {
398+
calls = append(calls, "execute")
399+
return want, nil
400+
},
401+
},
402+
setRequirementsFn: func(id string, requirements *sdk.Requirements) {
403+
calls = append(calls, "set")
404+
gotExecutionID = id
405+
gotReqs = requirements
406+
},
407+
}
408+
add := ModuleAndHandler{
409+
Module: enforcingAdd,
410+
RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }},
411+
}
412+
413+
m := NewRequirementSelectingModule(main, []ModuleAndHandler{add})
414+
m.Start()
415+
416+
helper := &MockExecutionHelper{}
417+
helper.On("GetWorkflowExecutionID").Return(executionID).Once()
418+
419+
_, err := m.Execute(t.Context(), subscribeRequest(), nil)
420+
require.NoError(t, err)
421+
422+
got, err := m.Execute(t.Context(), triggerRequest(0), helper)
423+
require.NoError(t, err)
424+
assert.Equal(t, want, got)
425+
assert.Equal(t, []string{"set", "execute"}, calls)
426+
assert.Equal(t, executionID, gotExecutionID)
427+
assert.Same(t, teeReqs, gotReqs)
428+
helper.AssertExpectations(t)
429+
})
365430
}
366431

367432
func TestRequirementSelectingModule_TriggerCache(t *testing.T) {

pkg/workflows/host/requirements_gen/requirements_helper.go.tmpl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ type RequirementsHandler struct {
1414
// non-nil field in req, returning false if any are false, or if the handler is nil.
1515
// Unknown fields on the proto also result in a false return value.
1616
func CheckRequirements(ctx context.Context, handler RequirementsHandler, req *sdk.Requirements) bool {
17+
if req == nil {
18+
return true
19+
}
20+
1721
if len(req.ProtoReflect().GetUnknown()) != 0 {
1822
return false
1923
}

pkg/workflows/host/requirements_helper_gen.go

Lines changed: 4 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)