@@ -27,64 +27,64 @@ func (l *lazyModule) ensureStarted() {
2727 })
2828}
2929
30+ // NewRequirementSelectingModule creates a module that routes trigger executions
31+ // based on subscription requirements. main is prepended as modules[0]; additional
32+ // modules follow. Subscribe always runs on modules[0].
3033func NewRequirementSelectingModule (main ModuleAndHandler , additional []ModuleAndHandler ) Module {
31- wrapped := make ([]* lazyModule , len (additional ))
32- for i := range additional {
33- wrapped [i ] = & lazyModule {ModuleAndHandler : additional [i ]}
34- }
35- return & requirementSelectingModule {
36- main : main ,
37- additional : wrapped ,
34+ modules := make ([]* lazyModule , 1 + len (additional ))
35+ modules [0 ] = & lazyModule {ModuleAndHandler : main }
36+ for i , a := range additional {
37+ modules [1 + i ] = & lazyModule {ModuleAndHandler : a }
3838 }
39+ return & requirementSelectingModule {modules : modules }
40+ }
41+
42+ type triggerInfo struct {
43+ moduleIdx int
44+ requirements * sdk.Requirements
3945}
4046
4147type requirementSelectingModule struct {
42- main ModuleAndHandler
43- additional []* lazyModule
44- // triggerID → index into additional
48+ modules []* lazyModule
49+ // triggerID → triggerInfo
4550 cache sync.Map
4651}
4752
4853func (r * requirementSelectingModule ) Start () {
49- r .main . Start ()
54+ r .modules [ 0 ]. ensureStarted ()
5055}
5156
5257func (r * requirementSelectingModule ) Close () {
53- r .main .Close ()
54- for _ , m := range r .additional {
58+ for _ , m := range r .modules {
5559 if m .started {
5660 m .Close ()
5761 }
5862 }
5963}
6064
6165func (r * requirementSelectingModule ) IsLegacyDAG () bool {
62- return r .main .IsLegacyDAG ()
66+ return r .modules [ 0 ] .IsLegacyDAG ()
6367}
6468
6569func (r * requirementSelectingModule ) Execute (ctx context.Context , request * sdk.ExecuteRequest , handler ExecutionHelper ) (* sdk.ExecutionResult , error ) {
66- if triggerID , ok := extractTriggerID (request ); ok {
67- if idx , cached := r .cache .Load (triggerID ); cached {
68- return r .additional [idx .(int )].Execute (ctx , request , handler )
69- }
70- return r .main .Execute (ctx , request , handler )
70+ if request .GetTrigger () == nil {
71+ return r .subscribe (ctx , request , handler )
7172 }
73+ return r .trigger (ctx , request , handler )
74+ }
7275
73- // Subscribe: run main, then build triggerID→module cache from subscription requirements
74- result , err := r .main .Execute (ctx , request , handler )
76+ func ( r * requirementSelectingModule ) subscribe ( ctx context. Context , request * sdk. ExecuteRequest , handler ExecutionHelper ) ( * sdk. ExecutionResult , error ) {
77+ result , err := r .modules [ 0 ] .Execute (ctx , request , handler )
7578 if err != nil {
7679 return nil , err
7780 }
7881
7982 for i , sub := range result .GetTriggerSubscriptions ().GetSubscriptions () {
80- if sub .Requirements == nil || CheckRequirements (ctx , r .main .RequirementsHandler , sub .Requirements ) {
81- continue
82- }
8383 matched := false
84- for j , m := range r .additional {
84+ for j , m := range r .modules {
8585 if CheckRequirements (ctx , m .RequirementsHandler , sub .Requirements ) {
8686 m .ensureStarted ()
87- r .cache .Store (uint64 (i ), j )
87+ r .cache .Store (uint64 (i ), triggerInfo { moduleIdx : j , requirements : sub . Requirements } )
8888 matched = true
8989 break
9090 }
@@ -97,11 +97,18 @@ func (r *requirementSelectingModule) Execute(ctx context.Context, request *sdk.E
9797 return result , nil
9898}
9999
100- func extractTriggerID (req * sdk.ExecuteRequest ) (uint64 , bool ) {
101- if t := req .GetTrigger (); t != nil {
102- return t .Id , true
100+ func (r * requirementSelectingModule ) trigger (ctx context.Context , request * sdk.ExecuteRequest , handler ExecutionHelper ) (* sdk.ExecutionResult , error ) {
101+ trigger := request .GetTrigger ()
102+ if val , cached := r .cache .Load (trigger .Id ); cached {
103+ info := val .(triggerInfo )
104+ m := r .modules [info .moduleIdx ]
105+ if rem , ok := m .Module .(RequirementEnforcingModule ); ok && info .requirements != nil {
106+ rem .SetRequirements (handler .GetWorkflowExecutionID (), info .requirements )
107+ }
108+
109+ return m .Execute (ctx , request , handler )
103110 }
104- return 0 , false
111+ return r . modules [ 0 ]. Execute ( ctx , request , handler )
105112}
106113
107114var _ Module = & requirementSelectingModule {}
0 commit comments