@@ -152,6 +152,9 @@ func (r *baseRegistry) List(_ context.Context) ([]capabilities.BaseCapability, e
152152}
153153
154154func (r * baseRegistry ) Add (ctx context.Context , c capabilities.BaseCapability ) error {
155+ if c == nil {
156+ return errors .New ("cannot add a nil capability to the registry" )
157+ }
155158 r .mu .Lock ()
156159 defer r .mu .Unlock ()
157160 info , err := c .Info (ctx )
@@ -175,11 +178,11 @@ func (r *baseRegistry) Add(ctx context.Context, c capabilities.BaseCapability) e
175178 var ac atomicBaseCapability
176179 switch info .CapabilityType {
177180 case capabilities .CapabilityTypeTrigger :
178- ac = & atomicTriggerCapability {}
181+ ac = newAtomicTriggerCapability ()
179182 case capabilities .CapabilityTypeAction , capabilities .CapabilityTypeConsensus , capabilities .CapabilityTypeTarget :
180183 ac = & atomicExecuteCapability {}
181184 case capabilities .CapabilityTypeCombined :
182- ac = & atomicExecuteAndTriggerCapability {}
185+ ac = newAtomicExecuteAndTriggerCapability ()
183186 default :
184187 return fmt .Errorf ("unknown capability type: %s" , info .CapabilityType )
185188 }
@@ -206,23 +209,180 @@ func (r *baseRegistry) Remove(_ context.Context, id string) error {
206209 return nil
207210}
208211
212+ type triggerRegistration struct {
213+ request capabilities.TriggerRegistrationRequest
214+ outCh chan capabilities.TriggerResponse
215+ cancel context.CancelFunc // used to shut down the forwarding goroutine when the trigger is unregistered
216+ }
217+
218+ // Caches all trigger registrations and replays them when the underlying capability is updated.
219+ // Owns channels passed to the higher layer (Engine or Don2Don) and goroutines forwarding events
220+ // from the underlying capability.
221+ type triggerRegistrationManager struct {
222+ mu sync.RWMutex
223+ regs map [string ]* triggerRegistration
224+ }
225+
226+ func newTriggerRegistrationManager () * triggerRegistrationManager {
227+ return & triggerRegistrationManager {
228+ regs : make (map [string ]* triggerRegistration ),
229+ }
230+ }
231+
232+ func (m * triggerRegistrationManager ) register (ctx context.Context , exec capabilities.TriggerExecutable , req capabilities.TriggerRegistrationRequest ) (<- chan capabilities.TriggerResponse , error ) {
233+ if exec == nil {
234+ return nil , errors .New ("capability unavailable" )
235+ }
236+ in , err := exec .RegisterTrigger (ctx , req )
237+ if err != nil {
238+ return nil , err
239+ }
240+
241+ m .mu .Lock ()
242+ reg , ok := m .regs [req .TriggerID ]
243+ if ! ok {
244+ reg = & triggerRegistration {
245+ request : req ,
246+ outCh : make (chan capabilities.TriggerResponse ),
247+ }
248+ m .regs [req .TriggerID ] = reg
249+ } else {
250+ reg .request = req
251+ if reg .cancel != nil {
252+ reg .cancel ()
253+ }
254+ }
255+ ctxForward , cancel := context .WithCancel (context .Background ())
256+ reg .cancel = cancel
257+ out := reg .outCh
258+ m .mu .Unlock ()
259+
260+ go forwardTriggerResponses (ctxForward , in , out )
261+
262+ return out , nil
263+ }
264+
265+ func (m * triggerRegistrationManager ) unregister (ctx context.Context , exec capabilities.TriggerExecutable , req capabilities.TriggerRegistrationRequest ) error {
266+ if exec == nil {
267+ return errors .New ("capability unavailable" )
268+ }
269+
270+ var out chan capabilities.TriggerResponse
271+ m .mu .Lock ()
272+ if reg , ok := m .regs [req .TriggerID ]; ok {
273+ if reg .cancel != nil {
274+ reg .cancel ()
275+ }
276+ out = reg .outCh
277+ delete (m .regs , req .TriggerID )
278+ }
279+ m .mu .Unlock ()
280+
281+ if out != nil {
282+ close (out )
283+ }
284+ return exec .UnregisterTrigger (ctx , req )
285+ }
286+
287+ func (m * triggerRegistrationManager ) rebind (oldExec , newExec capabilities.TriggerExecutable ) error {
288+ m .mu .RLock ()
289+ regs := make ([]* triggerRegistration , 0 , len (m .regs ))
290+ for _ , reg := range m .regs {
291+ regs = append (regs , & triggerRegistration {
292+ request : reg .request ,
293+ outCh : reg .outCh ,
294+ cancel : reg .cancel ,
295+ })
296+ }
297+ m .mu .RUnlock ()
298+
299+ for _ , reg := range regs {
300+ if reg .cancel != nil {
301+ reg .cancel ()
302+ }
303+ if oldExec != nil {
304+ _ = oldExec .UnregisterTrigger (context .Background (), reg .request )
305+ }
306+ if newExec == nil {
307+ continue
308+ }
309+
310+ in , err := newExec .RegisterTrigger (context .Background (), reg .request )
311+ if err != nil {
312+ return fmt .Errorf ("failed to re-register trigger %s: %w" , reg .request .TriggerID , err )
313+ }
314+
315+ m .mu .Lock ()
316+ regInMap , ok := m .regs [reg .request .TriggerID ]
317+ if ! ok {
318+ regInMap = & triggerRegistration {
319+ request : reg .request ,
320+ outCh : reg .outCh ,
321+ }
322+ m .regs [reg .request .TriggerID ] = regInMap
323+ } else {
324+ regInMap .request = reg .request
325+ regInMap .outCh = reg .outCh
326+ if regInMap .cancel != nil {
327+ regInMap .cancel ()
328+ }
329+ }
330+ ctxForward , cancel := context .WithCancel (context .Background ())
331+ regInMap .cancel = cancel
332+ out := regInMap .outCh
333+ m .mu .Unlock ()
334+
335+ go forwardTriggerResponses (ctxForward , in , out )
336+ }
337+ return nil
338+ }
339+
340+ func forwardTriggerResponses (ctx context.Context , in <- chan capabilities.TriggerResponse , out chan <- capabilities.TriggerResponse ) {
341+ for {
342+ select {
343+ case <- ctx .Done ():
344+ return
345+ case resp , ok := <- in :
346+ if ! ok {
347+ return
348+ }
349+ select {
350+ case <- ctx .Done ():
351+ return
352+ case out <- resp :
353+ }
354+ }
355+ }
356+ }
357+
209358var _ capabilities.TriggerCapability = & atomicTriggerCapability {}
210359
211360type atomicTriggerCapability struct {
212361 atomic.Pointer [capabilities.TriggerCapability ]
362+ registrations * triggerRegistrationManager
363+ }
364+
365+ func newAtomicTriggerCapability () * atomicTriggerCapability {
366+ return & atomicTriggerCapability {
367+ registrations : newTriggerRegistrationManager (),
368+ }
213369}
214370
215371func (a * atomicTriggerCapability ) Update (c capabilities.BaseCapability ) error {
372+ var prev capabilities.TriggerExecutable
373+ if existing := a .Load (); existing != nil {
374+ prev = * existing
375+ }
216376 if c == nil {
217377 a .Store (nil )
218- return nil
378+ return a . registrations . rebind ( prev , nil )
219379 }
220380 tc , ok := c .(capabilities.TriggerCapability )
221381 if ! ok {
222382 return errors .New ("trigger capability does not satisfy TriggerCapability interface" )
223383 }
224384 a .Store (& tc )
225- return nil
385+ return a . registrations . rebind ( prev , tc )
226386}
227387
228388func (a * atomicTriggerCapability ) Info (ctx context.Context ) (capabilities.CapabilityInfo , error ) {
@@ -249,15 +409,15 @@ func (a *atomicTriggerCapability) RegisterTrigger(ctx context.Context, request c
249409 if c == nil {
250410 return nil , errors .New ("capability unavailable" )
251411 }
252- return ( * c ). RegisterTrigger (ctx , request )
412+ return a . registrations . register (ctx , * c , request )
253413}
254414
255415func (a * atomicTriggerCapability ) UnregisterTrigger (ctx context.Context , request capabilities.TriggerRegistrationRequest ) error {
256416 c := a .Load ()
257417 if c == nil {
258418 return errors .New ("capability unavailable" )
259419 }
260- return ( * c ). UnregisterTrigger (ctx , request )
420+ return a . registrations . unregister (ctx , * c , request )
261421}
262422
263423var _ capabilities.ExecutableCapability = & atomicExecuteCapability {}
@@ -326,19 +486,30 @@ var _ capabilities.ExecutableAndTriggerCapability = &atomicExecuteAndTriggerCapa
326486
327487type atomicExecuteAndTriggerCapability struct {
328488 atomic.Pointer [capabilities.ExecutableAndTriggerCapability ]
489+ registrations * triggerRegistrationManager
490+ }
491+
492+ func newAtomicExecuteAndTriggerCapability () * atomicExecuteAndTriggerCapability {
493+ return & atomicExecuteAndTriggerCapability {
494+ registrations : newTriggerRegistrationManager (),
495+ }
329496}
330497
331498func (a * atomicExecuteAndTriggerCapability ) Update (c capabilities.BaseCapability ) error {
499+ var prev capabilities.TriggerExecutable
500+ if existing := a .Load (); existing != nil {
501+ prev = * existing
502+ }
332503 if c == nil {
333504 a .Store (nil )
334- return nil
505+ return a . registrations . rebind ( prev , nil )
335506 }
336507 tc , ok := c .(capabilities.ExecutableAndTriggerCapability )
337508 if ! ok {
338509 return errors .New ("target capability does not satisfy ExecutableAndTriggerCapability interface" )
339510 }
340511 a .Store (& tc )
341- return nil
512+ return a . registrations . rebind ( prev , tc )
342513}
343514
344515func (a * atomicExecuteAndTriggerCapability ) Info (ctx context.Context ) (capabilities.CapabilityInfo , error ) {
@@ -365,15 +536,15 @@ func (a *atomicExecuteAndTriggerCapability) RegisterTrigger(ctx context.Context,
365536 if c == nil {
366537 return nil , errors .New ("capability unavailable" )
367538 }
368- return ( * c ). RegisterTrigger (ctx , request )
539+ return a . registrations . register (ctx , * c , request )
369540}
370541
371542func (a * atomicExecuteAndTriggerCapability ) UnregisterTrigger (ctx context.Context , request capabilities.TriggerRegistrationRequest ) error {
372543 c := a .Load ()
373544 if c == nil {
374545 return errors .New ("capability unavailable" )
375546 }
376- return ( * c ). UnregisterTrigger (ctx , request )
547+ return a . registrations . unregister (ctx , * c , request )
377548}
378549
379550func (a * atomicExecuteAndTriggerCapability ) RegisterToWorkflow (ctx context.Context , request capabilities.RegisterToWorkflowRequest ) error {
0 commit comments