@@ -152,6 +152,9 @@ func (r *baseRegistry) List(_ context.Context) ([]capabilities.BaseCapability, e
152152}
153153
154154func (r * baseRegistry ) Add (ctx context.Context , c capabilities.BaseCapability ) error {
155+ if c == nil {
156+ return errors .New ("cannot add a nil capability to the registry" )
157+ }
155158 r .mu .Lock ()
156159 defer r .mu .Unlock ()
157160 info , err := c .Info (ctx )
@@ -175,11 +178,11 @@ func (r *baseRegistry) Add(ctx context.Context, c capabilities.BaseCapability) e
175178 var ac atomicBaseCapability
176179 switch info .CapabilityType {
177180 case capabilities .CapabilityTypeTrigger :
178- ac = & atomicTriggerCapability {}
181+ ac = newAtomicTriggerCapability ()
179182 case capabilities .CapabilityTypeAction , capabilities .CapabilityTypeConsensus , capabilities .CapabilityTypeTarget :
180183 ac = & atomicExecuteCapability {}
181184 case capabilities .CapabilityTypeCombined :
182- ac = & atomicExecuteAndTriggerCapability {}
185+ ac = newAtomicExecuteAndTriggerCapability ()
183186 default :
184187 return fmt .Errorf ("unknown capability type: %s" , info .CapabilityType )
185188 }
@@ -206,23 +209,161 @@ func (r *baseRegistry) Remove(_ context.Context, id string) error {
206209 return nil
207210}
208211
212+ type triggerRegistration struct {
213+ request capabilities.TriggerRegistrationRequest
214+ outCh chan capabilities.TriggerResponse
215+ cancel context.CancelFunc // used to shut down the forwarding goroutine when the trigger is unregistered
216+ }
217+
218+ // Caches all trigger registrations and replays them when the underlying capability is updated.
219+ // Owns channels passed to the higher layer (Engine or Don2Don) and goroutines forwarding events
220+ // from the underlying capability.
221+ type triggerRegistrationManager struct {
222+ mu sync.RWMutex
223+ regs map [string ]* triggerRegistration
224+ bindMu sync.Mutex
225+ }
226+
227+ func newTriggerRegistrationManager () * triggerRegistrationManager {
228+ return & triggerRegistrationManager {
229+ regs : make (map [string ]* triggerRegistration ),
230+ }
231+ }
232+
233+ func (m * triggerRegistrationManager ) register (ctx context.Context , loadExec func () (capabilities.TriggerExecutable , error ), req capabilities.TriggerRegistrationRequest ) (<- chan capabilities.TriggerResponse , error ) {
234+ m .bindMu .Lock ()
235+ defer m .bindMu .Unlock ()
236+ exec , err := loadExec ()
237+ if err != nil {
238+ return nil , err
239+ }
240+ in , err := exec .RegisterTrigger (ctx , req )
241+ if err != nil {
242+ return nil , err
243+ }
244+
245+ return m .upsertRegistration (req , nil , in ), nil
246+ }
247+
248+ func (m * triggerRegistrationManager ) unregister (ctx context.Context , exec capabilities.TriggerExecutable , req capabilities.TriggerRegistrationRequest ) error {
249+ m .bindMu .Lock ()
250+ defer m .bindMu .Unlock ()
251+ var out chan capabilities.TriggerResponse
252+ m .mu .Lock ()
253+ if reg , ok := m .regs [req .TriggerID ]; ok {
254+ if reg .cancel != nil {
255+ reg .cancel ()
256+ }
257+ out = reg .outCh
258+ delete (m .regs , req .TriggerID )
259+ }
260+ m .mu .Unlock ()
261+
262+ if out != nil {
263+ close (out )
264+ }
265+ return exec .UnregisterTrigger (ctx , req )
266+ }
267+
268+ func (m * triggerRegistrationManager ) upsertRegistration (req capabilities.TriggerRegistrationRequest , outCh chan capabilities.TriggerResponse , in <- chan capabilities.TriggerResponse ) chan capabilities.TriggerResponse {
269+ m .mu .Lock ()
270+ defer m .mu .Unlock ()
271+ regInMap , ok := m .regs [req .TriggerID ]
272+ if ! ok {
273+ if outCh == nil {
274+ outCh = make (chan capabilities.TriggerResponse )
275+ }
276+ regInMap = & triggerRegistration {
277+ request : req ,
278+ outCh : outCh ,
279+ }
280+ m .regs [req .TriggerID ] = regInMap
281+ } else {
282+ regInMap .request = req
283+ if outCh != nil {
284+ regInMap .outCh = outCh
285+ }
286+ if regInMap .cancel != nil {
287+ regInMap .cancel () // shuts down the previous forwarding goroutine
288+ }
289+ }
290+ ctxForward , cancel := context .WithCancel (context .Background ())
291+ regInMap .cancel = cancel
292+ go forwardTriggerResponses (ctxForward , in , regInMap .outCh )
293+ return regInMap .outCh
294+ }
295+
296+ func (m * triggerRegistrationManager ) rebind (newExec capabilities.TriggerExecutable ) error {
297+ m .bindMu .Lock ()
298+ defer m .bindMu .Unlock ()
299+ m .mu .RLock ()
300+ regs := make ([]* triggerRegistration , 0 , len (m .regs ))
301+ for _ , reg := range m .regs {
302+ regs = append (regs , & triggerRegistration {
303+ request : reg .request ,
304+ outCh : reg .outCh ,
305+ cancel : reg .cancel ,
306+ })
307+ }
308+ m .mu .RUnlock ()
309+
310+ for _ , reg := range regs {
311+ if reg .cancel != nil {
312+ reg .cancel ()
313+ }
314+ if newExec != nil {
315+ in , err := newExec .RegisterTrigger (context .Background (), reg .request )
316+ if err != nil {
317+ return fmt .Errorf ("failed to re-register trigger %s: %w" , reg .request .TriggerID , err )
318+ }
319+ _ = m .upsertRegistration (reg .request , reg .outCh , in )
320+ }
321+ }
322+ return nil
323+ }
324+
325+ func forwardTriggerResponses (ctx context.Context , in <- chan capabilities.TriggerResponse , out chan <- capabilities.TriggerResponse ) {
326+ for {
327+ select {
328+ case <- ctx .Done ():
329+ return
330+ case resp , ok := <- in :
331+ if ! ok {
332+ return
333+ }
334+ select {
335+ case <- ctx .Done ():
336+ return
337+ case out <- resp :
338+ }
339+ }
340+ }
341+ }
342+
209343var _ capabilities.TriggerCapability = & atomicTriggerCapability {}
210344
211345type atomicTriggerCapability struct {
212346 atomic.Pointer [capabilities.TriggerCapability ]
347+ registrations * triggerRegistrationManager
348+ }
349+
350+ func newAtomicTriggerCapability () * atomicTriggerCapability {
351+ return & atomicTriggerCapability {
352+ registrations : newTriggerRegistrationManager (),
353+ }
213354}
214355
215356func (a * atomicTriggerCapability ) Update (c capabilities.BaseCapability ) error {
216357 if c == nil {
217358 a .Store (nil )
218- return nil
359+ return a . registrations . rebind ( nil )
219360 }
220361 tc , ok := c .(capabilities.TriggerCapability )
221362 if ! ok {
222363 return errors .New ("trigger capability does not satisfy TriggerCapability interface" )
223364 }
224365 a .Store (& tc )
225- return nil
366+ return a . registrations . rebind ( tc )
226367}
227368
228369func (a * atomicTriggerCapability ) Info (ctx context.Context ) (capabilities.CapabilityInfo , error ) {
@@ -245,19 +386,22 @@ func (a *atomicTriggerCapability) GetState() connectivity.State {
245386}
246387
247388func (a * atomicTriggerCapability ) RegisterTrigger (ctx context.Context , request capabilities.TriggerRegistrationRequest ) (<- chan capabilities.TriggerResponse , error ) {
248- c := a .Load ()
249- if c == nil {
250- return nil , errors .New ("capability unavailable" )
389+ loadExec := func () (capabilities.TriggerExecutable , error ) {
390+ c := a .Load ()
391+ if c == nil {
392+ return nil , errors .New ("capability unavailable" )
393+ }
394+ return * c , nil
251395 }
252- return ( * c ). RegisterTrigger (ctx , request )
396+ return a . registrations . register (ctx , loadExec , request )
253397}
254398
255399func (a * atomicTriggerCapability ) UnregisterTrigger (ctx context.Context , request capabilities.TriggerRegistrationRequest ) error {
256400 c := a .Load ()
257401 if c == nil {
258402 return errors .New ("capability unavailable" )
259403 }
260- return ( * c ). UnregisterTrigger (ctx , request )
404+ return a . registrations . unregister (ctx , * c , request )
261405}
262406
263407var _ capabilities.ExecutableCapability = & atomicExecuteCapability {}
@@ -326,19 +470,26 @@ var _ capabilities.ExecutableAndTriggerCapability = &atomicExecuteAndTriggerCapa
326470
327471type atomicExecuteAndTriggerCapability struct {
328472 atomic.Pointer [capabilities.ExecutableAndTriggerCapability ]
473+ registrations * triggerRegistrationManager
474+ }
475+
476+ func newAtomicExecuteAndTriggerCapability () * atomicExecuteAndTriggerCapability {
477+ return & atomicExecuteAndTriggerCapability {
478+ registrations : newTriggerRegistrationManager (),
479+ }
329480}
330481
331482func (a * atomicExecuteAndTriggerCapability ) Update (c capabilities.BaseCapability ) error {
332483 if c == nil {
333484 a .Store (nil )
334- return nil
485+ return a . registrations . rebind ( nil )
335486 }
336487 tc , ok := c .(capabilities.ExecutableAndTriggerCapability )
337488 if ! ok {
338489 return errors .New ("target capability does not satisfy ExecutableAndTriggerCapability interface" )
339490 }
340491 a .Store (& tc )
341- return nil
492+ return a . registrations . rebind ( tc )
342493}
343494
344495func (a * atomicExecuteAndTriggerCapability ) Info (ctx context.Context ) (capabilities.CapabilityInfo , error ) {
@@ -361,19 +512,22 @@ func (a *atomicExecuteAndTriggerCapability) GetState() connectivity.State {
361512}
362513
363514func (a * atomicExecuteAndTriggerCapability ) RegisterTrigger (ctx context.Context , request capabilities.TriggerRegistrationRequest ) (<- chan capabilities.TriggerResponse , error ) {
364- c := a .Load ()
365- if c == nil {
366- return nil , errors .New ("capability unavailable" )
515+ loadExec := func () (capabilities.TriggerExecutable , error ) {
516+ c := a .Load ()
517+ if c == nil {
518+ return nil , errors .New ("capability unavailable" )
519+ }
520+ return * c , nil
367521 }
368- return ( * c ). RegisterTrigger (ctx , request )
522+ return a . registrations . register (ctx , loadExec , request )
369523}
370524
371525func (a * atomicExecuteAndTriggerCapability ) UnregisterTrigger (ctx context.Context , request capabilities.TriggerRegistrationRequest ) error {
372526 c := a .Load ()
373527 if c == nil {
374528 return errors .New ("capability unavailable" )
375529 }
376- return ( * c ). UnregisterTrigger (ctx , request )
530+ return a . registrations . unregister (ctx , * c , request )
377531}
378532
379533func (a * atomicExecuteAndTriggerCapability ) RegisterToWorkflow (ctx context.Context , request capabilities.RegisterToWorkflowRequest ) error {
0 commit comments