diff --git a/chasm/lib/activity/activity.go b/chasm/lib/activity/activity.go index 9605fde4d2e..d2ad6175b40 100644 --- a/chasm/lib/activity/activity.go +++ b/chasm/lib/activity/activity.go @@ -6,7 +6,9 @@ import ( "slices" "time" + "github.com/nexus-rpc/sdk-go/nexus" apiactivitypb "go.temporal.io/api/activity/v1" //nolint:importas + callbackpb "go.temporal.io/api/callback/v1" commonpb "go.temporal.io/api/common/v1" enumspb "go.temporal.io/api/enums/v1" failurepb "go.temporal.io/api/failure/v1" @@ -18,10 +20,14 @@ import ( tokenspb "go.temporal.io/server/api/token/v1" "go.temporal.io/server/chasm" "go.temporal.io/server/chasm/lib/activity/gen/activitypb/v1" + "go.temporal.io/server/chasm/lib/callback" + callbackspb "go.temporal.io/server/chasm/lib/callback/gen/callbackpb/v1" "go.temporal.io/server/common" "go.temporal.io/server/common/backoff" "go.temporal.io/server/common/metrics" "go.temporal.io/server/common/namespace" + commonnexus "go.temporal.io/server/common/nexus" + "go.temporal.io/server/common/nexus/nexusrpc" "go.temporal.io/server/common/payload" serviceerrors "go.temporal.io/server/common/serviceerror" "go.temporal.io/server/common/tqid" @@ -41,6 +47,7 @@ var ( ) var _ chasm.VisibilitySearchAttributesProvider = (*Activity)(nil) +var _ callback.CompletionSource = (*Activity)(nil) type ActivityStore interface { // RecordCompleted applies the provided function to record activity completion @@ -65,6 +72,10 @@ type Activity struct { // implements the ActivityStore interface). // TODO(saa-preview): figure out better naming. Store chasm.ParentPtr[ActivityStore] + + // Callbacks holds completion callbacks to be invoked when this standalone activity reaches a terminal state. Nil + // for workflow-embedded activities as the workflow handles its own callbacks. + Callbacks chasm.Map[string, *callback.Callback] } // WithToken wraps a request with its deserialized task token. 
@@ -256,8 +267,116 @@ func attemptScheduleTimeForRetry(attempt *activitypb.ActivityAttemptState) *time } // RecordCompleted applies the provided function to record activity completion. +// For standalone activities, it also triggers any registered completion callbacks. func (a *Activity) RecordCompleted(ctx chasm.MutableContext, applyFn func(ctx chasm.MutableContext) error) error { - return applyFn(ctx) + if err := applyFn(ctx); err != nil { + return err + } + return callback.ScheduleStandbyCallbacks(ctx, a.Callbacks) +} + +func (a *Activity) addCompletionCallbacks( + ctx chasm.MutableContext, + requestID string, + completionCallbacks []*commonpb.Callback, + maxCallbacks int, +) error { + if len(completionCallbacks) == 0 { + return nil + } + if a.LifecycleState(ctx).IsClosed() { + return serviceerror.NewFailedPrecondition("cannot attach callbacks to a closed activity") + } + + currentCount := len(a.Callbacks) + if len(completionCallbacks)+currentCount > maxCallbacks { + return serviceerror.NewFailedPreconditionf( + "cannot attach more than %d callbacks to an activity (%d callbacks already attached)", + maxCallbacks, + currentCount, + ) + } + + if a.Callbacks == nil { + a.Callbacks = make(chasm.Map[string, *callback.Callback], len(completionCallbacks)) + } + + registrationTime := timestamppb.New(ctx.Now(a)) + + for idx, cb := range completionCallbacks { + chasmCB := &callbackspb.Callback{ + Links: cb.GetLinks(), + } + switch variant := cb.Variant.(type) { + case *commonpb.Callback_Nexus_: + chasmCB.Variant = &callbackspb.Callback_Nexus_{ + Nexus: &callbackspb.Callback_Nexus{ + Url: variant.Nexus.GetUrl(), + Header: variant.Nexus.GetHeader(), + }, + } + default: + return serviceerror.NewInvalidArgumentf("unsupported callback variant: %T", variant) + } + + // requestID (unique per API call) + idx (position within the request) ensures unique, idempotent callback IDs. 
+ id := fmt.Sprintf("%s-%d", requestID, idx) + callbackObj := callback.NewCallback(requestID, registrationTime, &callbackspb.CallbackState{}, chasmCB) + a.Callbacks[id] = chasm.NewComponentField(ctx, callbackObj) + } + return nil +} + +// GetNexusCompletion returns the activity's completion data in the format required by the Nexus callback invocation. +// Implements callback.CompletionSource. +func (a *Activity) GetNexusCompletion(ctx chasm.Context, _ string) (nexusrpc.CompleteOperationOptions, error) { + if !a.LifecycleState(ctx).IsClosed() { + return nexusrpc.CompleteOperationOptions{}, serviceerror.NewInternal("activity has not completed yet") + } + + opts := nexusrpc.CompleteOperationOptions{ + StartTime: a.GetScheduleTime().AsTime(), + CloseTime: ctx.ExecutionInfo().CloseTime, + } + + outcome := a.Outcome.Get(ctx) + if successful := outcome.GetSuccessful(); successful != nil { + // Successful completion: return the first output payload as the result as Nexus supports only a single payload + var p *commonpb.Payload + if payloads := successful.GetOutput().GetPayloads(); len(payloads) > 0 { + p = payloads[0] + } + opts.Result = p + return opts, nil + } + + failure := a.terminalFailure(ctx) + if failure != nil { + state := nexus.OperationStateFailed + message := "operation failed" + if a.Status == activitypb.ACTIVITY_EXECUTION_STATUS_CANCELED { + state = nexus.OperationStateCanceled + message = "operation canceled" + } + + nf, err := commonnexus.TemporalFailureToNexusFailure(failure) + if err != nil { + return nexusrpc.CompleteOperationOptions{}, serviceerror.NewInternalf("failed to convert failure: %v", err) + } + + opErr := &nexus.OperationError{ + State: state, + Message: message, + Cause: &nexus.FailureError{Failure: nf}, + } + if err := nexusrpc.MarkAsWrapperError(nexusrpc.DefaultFailureConverter(), opErr); err != nil { + return nexusrpc.CompleteOperationOptions{}, err + } + opts.Error = opErr + return opts, nil + } + + return 
nexusrpc.CompleteOperationOptions{}, serviceerror.NewInternalf("activity in status %v has no outcome", a.Status) } // HandleCompleted updates the activity on activity completion. @@ -716,11 +835,17 @@ func (a *Activity) buildDescribeActivityExecutionResponse( input = a.RequestData.Get(ctx).GetInput() } + callbackInfos, err := a.buildCallbackInfos(ctx) + if err != nil { + return nil, err + } + response := &workflowservice.DescribeActivityExecutionResponse{ Info: info, RunId: ctx.ExecutionKey().RunID, Input: input, LongPollToken: token, + Callbacks: callbackInfos, } if request.GetIncludeOutcome() { @@ -732,6 +857,56 @@ func (a *Activity) buildDescribeActivityExecutionResponse( }, nil } +func (a *Activity) buildCallbackInfos(ctx chasm.Context) ([]*apiactivitypb.CallbackInfo, error) { + if len(a.Callbacks) == 0 { + return nil, nil + } + + cbInfos := make([]*apiactivitypb.CallbackInfo, 0, len(a.Callbacks)) + for _, field := range a.Callbacks { + cb := field.Get(ctx) + + cbSpec, err := cb.ToAPICallback() + if err != nil { + return nil, err + } + + var state enumspb.CallbackState + switch cb.Status { + case callbackspb.CALLBACK_STATUS_UNSPECIFIED: + return nil, serviceerror.NewInternal("callback with UNSPECIFIED state") + case callbackspb.CALLBACK_STATUS_STANDBY: + state = enumspb.CALLBACK_STATE_STANDBY + case callbackspb.CALLBACK_STATUS_SCHEDULED: + state = enumspb.CALLBACK_STATE_SCHEDULED + case callbackspb.CALLBACK_STATUS_BACKING_OFF: + state = enumspb.CALLBACK_STATE_BACKING_OFF + case callbackspb.CALLBACK_STATUS_FAILED: + state = enumspb.CALLBACK_STATE_FAILED + case callbackspb.CALLBACK_STATUS_SUCCEEDED: + state = enumspb.CALLBACK_STATE_SUCCEEDED + default: + return nil, serviceerror.NewInternalf("unknown callback state: %v", cb.Status) + } + + cbInfos = append(cbInfos, &apiactivitypb.CallbackInfo{ + Trigger: &apiactivitypb.CallbackInfo_Trigger{ + Variant: &apiactivitypb.CallbackInfo_Trigger_ActivityClosed{}, + }, + Info: &callbackpb.CallbackInfo{ + Callback: cbSpec, 
+ RegistrationTime: cb.RegistrationTime, + State: state, + Attempt: cb.Attempt, + LastAttemptCompleteTime: cb.LastAttemptCompleteTime, + LastAttemptFailure: cb.LastAttemptFailure, + NextAttemptScheduleTime: cb.NextAttemptScheduleTime, + }, + }) + } + return cbInfos, nil +} + func (a *Activity) buildPollActivityExecutionResponse( ctx chasm.Context, ) *activitypb.PollActivityExecutionResponse { @@ -755,15 +930,23 @@ func (a *Activity) outcome(ctx chasm.Context) *apiactivitypb.ActivityExecutionOu Value: &apiactivitypb.ActivityExecutionOutcome_Result{Result: successful.GetOutput()}, } } - if failure := activityOutcome.GetFailed().GetFailure(); failure != nil { + if failure := a.terminalFailure(ctx); failure != nil { return &apiactivitypb.ActivityExecutionOutcome{ Value: &apiactivitypb.ActivityExecutionOutcome_Failure{Failure: failure}, } } + return nil +} + +// terminalFailure returns the failure for a closed activity. The failure may be stored in Outcome.Failed +// (terminated, canceled, timed out) or in LastAttempt.LastFailureDetails (failed after exhausting retries). +// Returns nil if no failure is found. 
+func (a *Activity) terminalFailure(ctx chasm.Context) *failurepb.Failure { + if f := a.Outcome.Get(ctx).GetFailed(); f != nil { + return f.GetFailure() + } if details := a.LastAttempt.Get(ctx).GetLastFailureDetails(); details != nil { - return &apiactivitypb.ActivityExecutionOutcome{ - Value: &apiactivitypb.ActivityExecutionOutcome_Failure{Failure: details.GetFailure()}, - } + return details.GetFailure() } return nil } diff --git a/chasm/lib/activity/config.go b/chasm/lib/activity/config.go index 763b030b910..7336f253143 100644 --- a/chasm/lib/activity/config.go +++ b/chasm/lib/activity/config.go @@ -3,6 +3,7 @@ package activity import ( "time" + "go.temporal.io/server/chasm/lib/callback" "go.temporal.io/server/common/dynamicconfig" "go.temporal.io/server/common/retrypolicy" ) @@ -37,6 +38,7 @@ type Config struct { LongPollBuffer dynamicconfig.DurationPropertyFnWithNamespaceFilter LongPollTimeout dynamicconfig.DurationPropertyFnWithNamespaceFilter MaxIDLengthLimit dynamicconfig.IntPropertyFn + MaxCallbacksPerExecution dynamicconfig.IntPropertyFnWithNamespaceFilter DefaultActivityRetryPolicy dynamicconfig.TypedPropertyFnWithNamespaceFilter[retrypolicy.DefaultRetrySettings] VisibilityMaxPageSize dynamicconfig.IntPropertyFnWithNamespaceFilter } @@ -51,6 +53,7 @@ func ConfigProvider(dc *dynamicconfig.Collection) *Config { LongPollBuffer: LongPollBuffer.Get(dc), LongPollTimeout: LongPollTimeout.Get(dc), MaxIDLengthLimit: dynamicconfig.MaxIDLengthLimit.Get(dc), + MaxCallbacksPerExecution: callback.MaxPerExecution.Get(dc), VisibilityMaxPageSize: dynamicconfig.FrontendVisibilityMaxPageSize.Get(dc), } } diff --git a/chasm/lib/activity/frontend.go b/chasm/lib/activity/frontend.go index f2633672037..aa62c77bb1b 100644 --- a/chasm/lib/activity/frontend.go +++ b/chasm/lib/activity/frontend.go @@ -11,6 +11,7 @@ import ( "go.temporal.io/api/workflowservice/v1" "go.temporal.io/server/chasm" "go.temporal.io/server/chasm/lib/activity/gen/activitypb/v1" + 
"go.temporal.io/server/chasm/lib/callback" "go.temporal.io/server/common" "go.temporal.io/server/common/log" "go.temporal.io/server/common/metrics" @@ -37,6 +38,7 @@ var ErrStandaloneActivityDisabled = serviceerror.NewUnimplemented("Standalone ac type frontendHandler struct { FrontendHandler + callbackValidator *callback.Validator client activitypb.ActivityServiceClient config *Config logger log.Logger @@ -48,6 +50,7 @@ type frontendHandler struct { // NewFrontendHandler creates a new FrontendHandler instance for processing activity frontend requests. func NewFrontendHandler( + callbackValidator *callback.Validator, client activitypb.ActivityServiceClient, config *Config, logger log.Logger, @@ -57,6 +60,7 @@ func NewFrontendHandler( saValidator *searchattribute.Validator, ) FrontendHandler { return &frontendHandler{ + callbackValidator: callbackValidator, client: client, config: config, logger: logger, @@ -392,6 +396,12 @@ func (h *frontendHandler) validateAndPopulateStartRequest( return nil, err } + if cbs := req.GetCompletionCallbacks(); len(cbs) > 0 { + if err := h.callbackValidator.Validate(req.GetNamespace(), cbs); err != nil { + return nil, err + } + } + return req, nil } diff --git a/chasm/lib/activity/handler.go b/chasm/lib/activity/handler.go index f97a0f1a9a0..378776ebca4 100644 --- a/chasm/lib/activity/handler.go +++ b/chasm/lib/activity/handler.go @@ -4,6 +4,7 @@ import ( "context" "errors" + commonpb "go.temporal.io/api/common/v1" enumspb "go.temporal.io/api/enums/v1" "go.temporal.io/api/serviceerror" "go.temporal.io/api/workflowservice/v1" @@ -61,11 +62,13 @@ func (h *handler) StartActivityExecution(ctx context.Context, req *activitypb.St return nil, serviceerror.NewInvalidArgumentf("unsupported ID conflict policy: %v", frontendReq.GetIdConflictPolicy()) } + maxCallbacks := h.config.MaxCallbacksPerExecution(frontendReq.GetNamespace()) + result, err := chasm.StartExecution( ctx, chasm.ExecutionKey{ NamespaceID: req.GetNamespaceId(), - BusinessID: 
req.GetFrontendRequest().GetActivityId(), + BusinessID: frontendReq.GetActivityId(), }, func(mutableContext chasm.MutableContext, request *workflowservice.StartActivityExecutionRequest) (*Activity, error) { newActivity, err := NewStandaloneActivity(mutableContext, request) @@ -73,6 +76,12 @@ func (h *handler) StartActivityExecution(ctx context.Context, req *activitypb.St return nil, err } + if cbs := request.GetCompletionCallbacks(); len(cbs) > 0 { + if err := newActivity.addCompletionCallbacks(mutableContext, request.GetRequestId(), cbs, maxCallbacks); err != nil { + return nil, err + } + } + err = TransitionScheduled.Apply(newActivity, mutableContext, nil) if err != nil { return nil, err @@ -80,8 +89,8 @@ func (h *handler) StartActivityExecution(ctx context.Context, req *activitypb.St return newActivity, nil }, - req.GetFrontendRequest(), - chasm.WithRequestID(req.GetFrontendRequest().GetRequestId()), + frontendReq, + chasm.WithRequestID(frontendReq.GetRequestId()), chasm.WithBusinessIDPolicy(reusePolicy, conflictPolicy), ) @@ -94,10 +103,38 @@ func (h *handler) StartActivityExecution(ctx context.Context, req *activitypb.St return nil, err } + // Attach callbacks to an existing activity when on_conflict_options.attach_completion_callbacks is set. + // TODO: Use chasm.UpdateWithStartExecution to avoid a second transaction once the engine supports BusinessIDConflictPolicyFail in the updateFn path. 
+ cbs := frontendReq.GetCompletionCallbacks() + if !result.Created && frontendReq.GetOnConflictOptions().GetAttachCompletionCallbacks() && len(cbs) > 0 { + requestID := frontendReq.GetRequestId() + ref := chasm.NewComponentRef[*Activity](result.ExecutionKey) + _, _, err := chasm.UpdateComponent( + ctx, + ref, + func(a *Activity, ctx chasm.MutableContext, _ any) (any, error) { + return nil, a.addCompletionCallbacks(ctx, requestID, cbs, maxCallbacks) + }, + nil, + ) + if err != nil { + return nil, err + } + } + return &activitypb.StartActivityExecutionResponse{ FrontendResponse: &workflowservice.StartActivityExecutionResponse{ RunId: result.ExecutionKey.RunID, Started: result.Created, + Link: &commonpb.Link{ + Variant: &commonpb.Link_Activity_{ + Activity: &commonpb.Link_Activity{ + Namespace: frontendReq.GetNamespace(), + ActivityId: frontendReq.GetActivityId(), + RunId: result.ExecutionKey.RunID, + }, + }, + }, // EagerTask: TODO when supported, need to call the same code that would handle the HandleStarted API }, }, nil diff --git a/chasm/lib/callback/component.go b/chasm/lib/callback/component.go index 0f64150f687..c017a7d2f30 100644 --- a/chasm/lib/callback/component.go +++ b/chasm/lib/callback/component.go @@ -163,3 +163,19 @@ func (c *Callback) ToAPICallback() (*commonpb.Callback, error) { // This should not happen as CHASM only supports Nexus callbacks currently return nil, serviceerror.NewInternal("unsupported CHASM callback type") } + +// ScheduleStandbyCallbacks transitions all STANDBY callbacks to SCHEDULED state, +// triggering their invocation. Used by both workflows and standalone activities +// when the execution reaches a terminal state. 
+func ScheduleStandbyCallbacks(ctx chasm.MutableContext, callbacks chasm.Map[string, *Callback]) error { + for _, field := range callbacks { + cb := field.Get(ctx) + if cb.Status != callbackspb.CALLBACK_STATUS_STANDBY { + continue + } + if err := TransitionScheduled.Apply(cb, ctx, EventScheduled{}); err != nil { + return err + } + } + return nil +} diff --git a/chasm/lib/callback/config.go b/chasm/lib/callback/config.go index c0b4aadf485..844add8d671 100644 --- a/chasm/lib/callback/config.go +++ b/chasm/lib/callback/config.go @@ -14,6 +14,12 @@ import ( "google.golang.org/grpc/status" ) +var MaxPerExecution = dynamicconfig.NewNamespaceIntSetting( + "callback.maxPerExecution", + 2000, + `MaxPerExecution is the maximum number of callbacks that can be attached to an execution (workflow or standalone activity).`, +) + var RequestTimeout = dynamicconfig.NewDestinationDurationSetting( "callback.request.timeout", time.Second*10, diff --git a/chasm/lib/callback/validator.go b/chasm/lib/callback/validator.go new file mode 100644 index 00000000000..6a1c58ec46e --- /dev/null +++ b/chasm/lib/callback/validator.go @@ -0,0 +1,79 @@ +package callback + +import ( + "fmt" + "strings" + + commonpb "go.temporal.io/api/common/v1" + "go.temporal.io/api/serviceerror" + "go.temporal.io/server/common/dynamicconfig" + "google.golang.org/grpc/status" +) + +// Validator validates completion callbacks attached to executions (workflows and standalone activities). 
+type Validator struct { + maxCallbacksPerExecution dynamicconfig.IntPropertyFnWithNamespaceFilter + urlMaxLength dynamicconfig.IntPropertyFnWithNamespaceFilter + headerMaxSize dynamicconfig.IntPropertyFnWithNamespaceFilter + endpointRules dynamicconfig.TypedPropertyFnWithNamespaceFilter[AddressMatchRules] +} + +func NewValidator( + maxCallbacksPerExecution dynamicconfig.IntPropertyFnWithNamespaceFilter, + urlMaxLength dynamicconfig.IntPropertyFnWithNamespaceFilter, + headerMaxSize dynamicconfig.IntPropertyFnWithNamespaceFilter, + endpointRules dynamicconfig.TypedPropertyFnWithNamespaceFilter[AddressMatchRules], +) *Validator { + return &Validator{ + maxCallbacksPerExecution: maxCallbacksPerExecution, + urlMaxLength: urlMaxLength, + headerMaxSize: headerMaxSize, + endpointRules: endpointRules, + } +} + +// Validate validates completion callbacks: count, URL length, endpoint allowlist, header size, and normalizes header +// keys to lowercase. +func (v *Validator) Validate(namespaceName string, cbs []*commonpb.Callback) error { + if len(cbs) > v.maxCallbacksPerExecution(namespaceName) { + return serviceerror.NewInvalidArgumentf( + "cannot attach more than %d callbacks to an execution", v.maxCallbacksPerExecution(namespaceName), + ) + } + + for _, cb := range cbs { + switch variant := cb.GetVariant().(type) { + case *commonpb.Callback_Nexus_: + rawURL := variant.Nexus.GetUrl() + if len(rawURL) > v.urlMaxLength(namespaceName) { + return serviceerror.NewInvalidArgumentf( + "invalid url: url length longer than max length allowed of %d", v.urlMaxLength(namespaceName), + ) + } + if err := v.endpointRules(namespaceName).Validate(rawURL); err != nil { + if s, ok := status.FromError(err); ok { + return serviceerror.NewInvalidArgument(s.Message()) + } + return serviceerror.NewInvalidArgument(err.Error()) + } + + headerSize := 0 + lowerCaseHeaders := make(map[string]string, len(variant.Nexus.GetHeader())) + for k, val := range variant.Nexus.GetHeader() { + headerSize += len(k) 
+ len(val) + lowerCaseHeaders[strings.ToLower(k)] = val + } + if headerSize > v.headerMaxSize(namespaceName) { + return serviceerror.NewInvalidArgumentf( + "invalid header: header size longer than max allowed size of %d", v.headerMaxSize(namespaceName), + ) + } + variant.Nexus.Header = lowerCaseHeaders + case *commonpb.Callback_Internal_: + continue + default: + return serviceerror.NewUnimplemented(fmt.Sprintf("unknown callback variant: %T", variant)) + } + } + return nil +} diff --git a/chasm/lib/callback/validator_test.go b/chasm/lib/callback/validator_test.go new file mode 100644 index 00000000000..55ae7dd5fd9 --- /dev/null +++ b/chasm/lib/callback/validator_test.go @@ -0,0 +1,152 @@ +package callback + +import ( + "regexp" + "testing" + + "github.com/stretchr/testify/require" + commonpb "go.temporal.io/api/common/v1" + "go.temporal.io/api/serviceerror" +) + +func TestValidateCallbacks(t *testing.T) { + allowAll := AddressMatchRules{ + Rules: []AddressMatchRule{ + {Regexp: regexp.MustCompile(`.*`), AllowInsecure: true}, + }, + } + v := NewValidator( + func(string) int { return 10 }, + func(string) int { return 1000 }, + func(string) int { return 4096 }, + func(string) AddressMatchRules { return allowAll }, + ) + + t.Run("ValidNexusCallback", func(t *testing.T) { + cbs := []*commonpb.Callback{ + {Variant: &commonpb.Callback_Nexus_{ + Nexus: &commonpb.Callback_Nexus{ + Url: "http://localhost:8080/callback", + Header: map[string]string{"Content-Type": "application/json"}, + }, + }}, + } + err := v.Validate("ns", cbs) + require.NoError(t, err) + }) + + t.Run("TooManyCallbacks", func(t *testing.T) { + v := NewValidator( + func(string) int { return 1 }, + func(string) int { return 1000 }, + func(string) int { return 4096 }, + func(string) AddressMatchRules { return allowAll }, + ) + cbs := []*commonpb.Callback{ + {Variant: &commonpb.Callback_Nexus_{Nexus: &commonpb.Callback_Nexus{Url: "http://localhost/cb1"}}}, + {Variant: &commonpb.Callback_Nexus_{Nexus: 
&commonpb.Callback_Nexus{Url: "http://localhost/cb2"}}}, + } + err := v.Validate("ns", cbs) + var invalidArgErr *serviceerror.InvalidArgument + require.ErrorAs(t, err, &invalidArgErr) + require.Contains(t, err.Error(), "cannot attach more than 1 callbacks") + }) + + t.Run("URLTooLong", func(t *testing.T) { + v := NewValidator( + func(string) int { return 10 }, + func(string) int { return 50 }, + func(string) int { return 4096 }, + func(string) AddressMatchRules { return allowAll }, + ) + cbs := []*commonpb.Callback{ + {Variant: &commonpb.Callback_Nexus_{ + Nexus: &commonpb.Callback_Nexus{ + Url: "http://localhost/" + string(make([]byte, 51)), + }, + }}, + } + err := v.Validate("ns", cbs) + var invalidArgErr *serviceerror.InvalidArgument + require.ErrorAs(t, err, &invalidArgErr) + require.Contains(t, err.Error(), "url length longer than max length allowed") + }) + + t.Run("HeaderTooLarge", func(t *testing.T) { + cbs := []*commonpb.Callback{ + {Variant: &commonpb.Callback_Nexus_{ + Nexus: &commonpb.Callback_Nexus{ + Url: "http://localhost:8080/callback", + Header: map[string]string{"X-Large": string(make([]byte, 5000))}, + }, + }}, + } + err := v.Validate("ns", cbs) + var invalidArgErr *serviceerror.InvalidArgument + require.ErrorAs(t, err, &invalidArgErr) + require.Contains(t, err.Error(), "header size longer than max allowed size") + }) + + t.Run("HeaderKeysNormalizedToLowercase", func(t *testing.T) { + cbs := []*commonpb.Callback{ + {Variant: &commonpb.Callback_Nexus_{ + Nexus: &commonpb.Callback_Nexus{ + Url: "http://localhost:8080/callback", + Header: map[string]string{"Content-Type": "application/json", "X-Custom": "value"}, + }, + }}, + } + err := v.Validate("ns", cbs) + require.NoError(t, err) + nexus := cbs[0].GetNexus() + require.Equal(t, "application/json", nexus.Header["content-type"]) + require.Equal(t, "value", nexus.Header["x-custom"]) + _, hasMixed := nexus.Header["Content-Type"] + require.False(t, hasMixed) + }) + + t.Run("URLNotInAllowlist", func(t 
*testing.T) { + v := NewValidator( + func(string) int { return 10 }, + func(string) int { return 1000 }, + func(string) int { return 4096 }, + func(string) AddressMatchRules { return AddressMatchRules{} }, + ) + cbs := []*commonpb.Callback{ + {Variant: &commonpb.Callback_Nexus_{ + Nexus: &commonpb.Callback_Nexus{ + Url: "http://localhost:8080/callback", + }, + }}, + } + err := v.Validate("ns", cbs) + var invalidArgErr *serviceerror.InvalidArgument + require.ErrorAs(t, err, &invalidArgErr) + require.Contains(t, err.Error(), "does not match any configured callback address") + }) + + t.Run("UnsupportedVariant", func(t *testing.T) { + cbs := []*commonpb.Callback{ + {Variant: nil}, + } + err := v.Validate("ns", cbs) + var unimplementedErr *serviceerror.Unimplemented + require.ErrorAs(t, err, &unimplementedErr) + require.Contains(t, err.Error(), "unknown callback variant") + }) + + t.Run("EmptyCallbacksNoError", func(t *testing.T) { + err := v.Validate("ns", nil) + require.NoError(t, err) + }) + + t.Run("InternalCallbackSkipped", func(t *testing.T) { + cbs := []*commonpb.Callback{ + {Variant: &commonpb.Callback_Internal_{ + Internal: &commonpb.Callback_Internal{}, + }}, + } + err := v.Validate("ns", cbs) + require.NoError(t, err) + }) +} diff --git a/chasm/lib/workflow/workflow.go b/chasm/lib/workflow/workflow.go index 9094c6d2d88..7bbd305e837 100644 --- a/chasm/lib/workflow/workflow.go +++ b/chasm/lib/workflow/workflow.go @@ -70,24 +70,6 @@ func (w *Workflow) Terminate( return chasm.TerminateComponentResponse{}, serviceerror.NewInternal("workflow root Terminate should not be called") } -// ProcessCloseCallbacks triggers "WorkflowClosed" callbacks using the CHASM implementation. -// It iterates through all callbacks and schedules WorkflowClosed ones that are in STANDBY state. 
-func (w *Workflow) ProcessCloseCallbacks(ctx chasm.MutableContext) error { - // Iterate through all callbacks and schedule WorkflowClosed ones - for _, field := range w.Callbacks { - cb := field.Get(ctx) - // Only process callbacks in STANDBY state (not already triggered) - if cb.Status != callbackspb.CALLBACK_STATUS_STANDBY { - continue - } - // Trigger the callback by transitioning to SCHEDULED state - if err := callback.TransitionScheduled.Apply(cb, ctx, callback.EventScheduled{}); err != nil { - return err - } - } - return nil -} - // AddCompletionCallbacks creates completion callbacks using the CHASM implementation. // maxCallbacksPerWorkflow is the configured maximum number of callbacks allowed per workflow. func (w *Workflow) AddCompletionCallbacks( @@ -129,6 +111,9 @@ func (w *Workflow) AddCompletionCallbacks( return serviceerror.NewInvalidArgumentf("unsupported callback variant: %T", variant) } + // requestID (unique per API call) + idx (position within the request) ensures unique, idempotent callback IDs. + // Unlike HSM callbacks, CHASM replicates entire trees rather than replaying events, so deterministic + // cross-cluster IDs based on event version are not needed. id := fmt.Sprintf("%s-%d", requestID, idx) // Create and add callback diff --git a/cmd/tools/getproto/files.go b/cmd/tools/getproto/files.go index 6ef97ed180e..e333a06f297 100644 --- a/cmd/tools/getproto/files.go +++ b/cmd/tools/getproto/files.go @@ -1,4 +1,3 @@ - // Code generated by getproto. DO NOT EDIT. // If you get build errors in this file, just delete it. It will be regenerated. 
@@ -9,6 +8,7 @@ import ( activity "go.temporal.io/api/activity/v1" batch "go.temporal.io/api/batch/v1" + callback "go.temporal.io/api/callback/v1" command "go.temporal.io/api/command/v1" common "go.temporal.io/api/common/v1" compute "go.temporal.io/api/compute/v1" @@ -49,6 +49,7 @@ func init() { importMap["google/protobuf/wrappers.proto"] = wrapperspb.File_google_protobuf_wrappers_proto importMap["temporal/api/activity/v1/message.proto"] = activity.File_temporal_api_activity_v1_message_proto importMap["temporal/api/batch/v1/message.proto"] = batch.File_temporal_api_batch_v1_message_proto + importMap["temporal/api/callback/v1/message.proto"] = callback.File_temporal_api_callback_v1_message_proto importMap["temporal/api/command/v1/message.proto"] = command.File_temporal_api_command_v1_message_proto importMap["temporal/api/common/v1/message.proto"] = common.File_temporal_api_common_v1_message_proto importMap["temporal/api/compute/v1/config.proto"] = compute.File_temporal_api_compute_v1_config_proto diff --git a/common/dynamicconfig/constants.go b/common/dynamicconfig/constants.go index eee31ed89f2..4e596234782 100644 --- a/common/dynamicconfig/constants.go +++ b/common/dynamicconfig/constants.go @@ -1010,13 +1010,6 @@ so forwarding by endpoint ID will not work out of the box.`, 32, `MaxCallbacksPerWorkflow is the maximum number of callbacks that can be attached to a workflow.`, ) - // NOTE (seankane): MaxCHASMCallbacksPerWorkflow is temporary, this will be removed and replaced with MaxCallbacksPerWorkflow - // once CHASM is fully enabled - MaxCHASMCallbacksPerWorkflow = NewNamespaceIntSetting( - "system.maxCHASMCallbacksPerWorkflow", - 2000, - `MaxCHASMCallbacksPerWorkflow is the maximum number of callbacks that can be attached to a workflow when using the CHASM implementation.`, - ) FrontendLinkMaxSize = NewNamespaceIntSetting( "frontend.linkMaxSize", 4000, // Links may include a workflow ID and namespace name, both of which are limited to a length of 1000. 
diff --git a/components/callbacks/config.go b/components/callbacks/config.go index 6c6c5b9258a..8b6a52e5239 100644 --- a/components/callbacks/config.go +++ b/components/callbacks/config.go @@ -68,7 +68,7 @@ type AddressMatchRules struct { Rules []AddressMatchRule } -func (a AddressMatchRules) Validate(rawURL string) error { +func (a AddressMatchRules) validate(rawURL string) error { // Exact match only; no path, query, or fragment allowed for system URL if rawURL == nexus.SystemCallbackURL || rawURL == chasm.NexusCompletionHandlerURL { return nil @@ -84,7 +84,7 @@ func (a AddressMatchRules) Validate(rawURL string) error { return status.Errorf(codes.InvalidArgument, "invalid url: missing host") } for _, rule := range a.Rules { - allow, err := rule.Allow(u) + allow, err := rule.allow(u) if err != nil { return err } @@ -105,7 +105,7 @@ type AddressMatchRule struct { // for the given rule. // 2. false, nil if the URL does not match the rule. // 3. It false, error if there is a match and the URL fails validation -func (a AddressMatchRule) Allow(u *url.URL) (bool, error) { +func (a AddressMatchRule) allow(u *url.URL) (bool, error) { if !a.Regexp.MatchString(u.Host) { return false, nil } diff --git a/components/callbacks/config_test.go b/components/callbacks/config_test.go index 0c0043d875c..5f209f9f4ae 100644 --- a/components/callbacks/config_test.go +++ b/components/callbacks/config_test.go @@ -324,7 +324,7 @@ func TestAddressMatchRules_Validate(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { rules := AddressMatchRules{Rules: tt.args.rules} - tt.validateErr(t, rules.Validate(tt.args.rawURL)) + tt.validateErr(t, rules.validate(tt.args.rawURL)) }) } } diff --git a/go.mod b/go.mod index 6ded7a3143d..64f91c65f76 100644 --- a/go.mod +++ b/go.mod @@ -63,7 +63,7 @@ require ( go.opentelemetry.io/otel/sdk v1.40.0 go.opentelemetry.io/otel/sdk/metric v1.40.0 go.opentelemetry.io/otel/trace v1.40.0 - go.temporal.io/api v1.62.8 + go.temporal.io/api 
v1.62.10-0.20260415205944-dbe8c077fbf1 go.temporal.io/auto-scaled-workers v0.0.0-20260407181057-edd947d743d2 go.temporal.io/sdk v1.41.1 go.uber.org/fx v1.24.0 diff --git a/go.sum b/go.sum index 597b47e7c3f..6b0c1677e65 100644 --- a/go.sum +++ b/go.sum @@ -442,8 +442,8 @@ go.opentelemetry.io/otel/trace v1.40.0 h1:WA4etStDttCSYuhwvEa8OP8I5EWu24lkOzp+ZY go.opentelemetry.io/otel/trace v1.40.0/go.mod h1:zeAhriXecNGP/s2SEG3+Y8X9ujcJOTqQ5RgdEJcawiA= go.opentelemetry.io/proto/otlp v1.7.1 h1:gTOMpGDb0WTBOP8JaO72iL3auEZhVmAQg4ipjOVAtj4= go.opentelemetry.io/proto/otlp v1.7.1/go.mod h1:b2rVh6rfI/s2pHWNlB7ILJcRALpcNDzKhACevjI+ZnE= -go.temporal.io/api v1.62.8 h1:g8RAZmdebYODoNa2GLA4M4TsXNe1096WV3n26C4+fdw= -go.temporal.io/api v1.62.8/go.mod h1:iaxoP/9OXMJcQkETTECfwYq4cw/bj4nwov8b3ZLVnXM= +go.temporal.io/api v1.62.10-0.20260415205944-dbe8c077fbf1 h1:YM1RgEu5BHI5HnuRoAsZ+3UDbreTdmPu0oXTIl2PKEs= +go.temporal.io/api v1.62.10-0.20260415205944-dbe8c077fbf1/go.mod h1:iaxoP/9OXMJcQkETTECfwYq4cw/bj4nwov8b3ZLVnXM= go.temporal.io/auto-scaled-workers v0.0.0-20260407181057-edd947d743d2 h1:1hKeH3GyR6YD6LKMHGCZ76t6h1Sgha0hXVQBxWi3dlQ= go.temporal.io/auto-scaled-workers v0.0.0-20260407181057-edd947d743d2/go.mod h1:T8dnzVPeO+gaUTj9eDgm/lT2lZH4+JXNvrGaQGyVi50= go.temporal.io/sdk v1.41.1 h1:yOpvsHyDD1lNuwlGBv/SUodCPhjv9nDeC9lLHW/fJUA= diff --git a/service/frontend/fx.go b/service/frontend/fx.go index 893f5d479e3..8946f0954ba 100644 --- a/service/frontend/fx.go +++ b/service/frontend/fx.go @@ -8,6 +8,7 @@ import ( "go.temporal.io/server/api/adminservice/v1" "go.temporal.io/server/chasm" "go.temporal.io/server/chasm/lib/activity" + "go.temporal.io/server/chasm/lib/callback" "go.temporal.io/server/chasm/lib/scheduler/gen/schedulerpb/v1" "go.temporal.io/server/client" "go.temporal.io/server/common" @@ -40,6 +41,7 @@ import ( "go.temporal.io/server/common/sdk" "go.temporal.io/server/common/searchattribute" "go.temporal.io/server/common/telemetry" + hsmcallbacks 
"go.temporal.io/server/components/callbacks" nexusfrontend "go.temporal.io/server/components/nexusoperations/frontend" "go.temporal.io/server/service" "go.temporal.io/server/service/frontend/configs" @@ -101,6 +103,7 @@ var Module = fx.Options( fx.Provide(AuthorizationInterceptorProvider), fx.Provide(NamespaceCheckerProvider), fx.Provide(func(so GrpcServerOptions) *grpc.Server { return grpc.NewServer(so.Options...) }), + fx.Provide(callbackValidatorProvider), fx.Provide(HandlerProvider), fx.Provide(AdminHandlerProvider), fx.Provide(NamespaceDLQHandlerProvider), @@ -787,6 +790,26 @@ func OperatorHandlerProvider( return NewOperatorHandlerImpl(args) } +// callbackValidatorProvider creates a callback Validator using the production dynamic config keys +// so that existing operator configurations (component.callbacks.allowedAddresses) are honored. +// TODO: Once HSM callbacks (components/callbacks) are removed, move this provider into +// chasm/lib/callback/fx.go and read directly from callback.AllowedAddresses. 
+func callbackValidatorProvider(dc *dynamicconfig.Collection) *callback.Validator { + return callback.NewValidator( + callback.MaxPerExecution.Get(dc), + dynamicconfig.FrontendCallbackURLMaxLength.Get(dc), + dynamicconfig.FrontendCallbackHeaderMaxSize.Get(dc), + func(ns string) callback.AddressMatchRules { + hsmRules := hsmcallbacks.AllowedAddresses.Get(dc)(ns) + chasmRules := make([]callback.AddressMatchRule, len(hsmRules.Rules)) + for i, r := range hsmRules.Rules { + chasmRules[i] = callback.AddressMatchRule{Regexp: r.Regexp, AllowInsecure: r.AllowInsecure} + } + return callback.AddressMatchRules{Rules: chasmRules} + }, + ) +} + func HandlerProvider( cfg *config.Config, serviceName primitives.ServiceName, @@ -820,6 +843,7 @@ func HandlerProvider( healthInterceptor *interceptor.HealthInterceptor, scheduleSpecBuilder *scheduler.SpecBuilder, activityHandler activity.FrontendHandler, + callbackValidator *callback.Validator, registry *chasm.Registry, frontendServiceResolver membership.ServiceResolver, ) Handler { @@ -830,6 +854,7 @@ func HandlerProvider( ) wfHandler := NewWorkflowHandler( + callbackValidator, serviceConfig, namespaceReplicationQueue, visibilityMgr, diff --git a/service/frontend/workflow_handler.go b/service/frontend/workflow_handler.go index 458b3a6a36e..8857fe6a893 100644 --- a/service/frontend/workflow_handler.go +++ b/service/frontend/workflow_handler.go @@ -36,6 +36,7 @@ import ( taskqueuespb "go.temporal.io/server/api/taskqueue/v1" "go.temporal.io/server/chasm" "go.temporal.io/server/chasm/lib/activity" + "go.temporal.io/server/chasm/lib/callback" chasmscheduler "go.temporal.io/server/chasm/lib/scheduler" "go.temporal.io/server/chasm/lib/scheduler/gen/schedulerpb/v1" "go.temporal.io/server/client/frontend" @@ -81,10 +82,8 @@ import ( "go.temporal.io/server/service/worker/dummy" "go.temporal.io/server/service/worker/scheduler" "go.temporal.io/server/service/worker/workerdeployment" - "google.golang.org/grpc/codes" "google.golang.org/grpc/health" 
healthpb "google.golang.org/grpc/health/grpc_health_v1" - "google.golang.org/grpc/status" "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/durationpb" @@ -117,6 +116,7 @@ type ( status int32 + callbackValidator *callback.Validator tokenSerializer *tasktoken.Serializer config *Config versionChecker headers.VersionChecker @@ -295,6 +295,7 @@ func (wh *WorkflowHandler) ValidateWorkerDeploymentVersionComputeConfig( // NewWorkflowHandler creates a gRPC handler for workflowservice func NewWorkflowHandler( + callbackValidator *callback.Validator, config *Config, namespaceReplicationQueue persistence.NamespaceReplicationQueue, visibilityMgr manager.VisibilityManager, @@ -325,11 +326,12 @@ func NewWorkflowHandler( workerDeploymentReadRateLimiter quotas.RequestRateLimiter, ) *WorkflowHandler { handler := &WorkflowHandler{ - FrontendHandler: activityHandler, - status: common.DaemonStatusInitialized, - config: config, - tokenSerializer: tasktoken.NewSerializer(), - versionChecker: headers.NewDefaultVersionChecker(), + FrontendHandler: activityHandler, + status: common.DaemonStatusInitialized, + callbackValidator: callbackValidator, + config: config, + tokenSerializer: tasktoken.NewSerializer(), + versionChecker: headers.NewDefaultVersionChecker(), namespaceHandler: newNamespaceHandler( logger, persistenceMetadataManager, @@ -668,8 +670,10 @@ func (wh *WorkflowHandler) prepareStartWorkflowRequest( return nil, err } - if err := wh.validateWorkflowCompletionCallbacks(namespaceName, request.GetCompletionCallbacks()); err != nil { - return nil, err + if cbs := request.GetCompletionCallbacks(); len(cbs) > 0 { + if err := wh.callbackValidator.Validate(namespaceName.String(), cbs); err != nil { + return nil, err + } } request.Links = dedupLinksFromCallbacks(request.GetLinks(), request.GetCompletionCallbacks()) @@ -6350,61 +6354,6 @@ func (wh *WorkflowHandler) validateLinks( return nil } -func (wh 
*WorkflowHandler) validateWorkflowCompletionCallbacks( - ns namespace.Name, - callbacks []*commonpb.Callback, -) error { - if len(callbacks) > wh.config.MaxCallbacksPerWorkflow(ns.String()) { - return status.Error( - codes.InvalidArgument, - fmt.Sprintf( - "cannot attach more than %d callbacks to a workflow", - wh.config.MaxCallbacksPerWorkflow(ns.String()), - ), - ) - } - - for _, callback := range callbacks { - switch cb := callback.GetVariant().(type) { - case *commonpb.Callback_Nexus_: - if err := wh.validateCallbackURL(ns, cb.Nexus.GetUrl()); err != nil { - return err - } - - headerSize := 0 - lowerCaseHeaders := make(map[string]string, len(cb.Nexus.GetHeader())) - for k, v := range cb.Nexus.GetHeader() { - headerSize += len(k) + len(v) - lowerCaseHeaders[strings.ToLower(k)] = v - } - if headerSize > wh.config.CallbackHeaderMaxSize(ns.String()) { - return status.Error( - codes.InvalidArgument, - fmt.Sprintf( - "invalid header: header size longer than max allowed size of %d", - wh.config.CallbackHeaderMaxSize(ns.String()), - ), - ) - } - cb.Nexus.Header = lowerCaseHeaders - case *commonpb.Callback_Internal_: - // TODO(Tianyu): For now, there is nothing to validate given that this is an internal field. 
- continue - default: - return status.Error(codes.Unimplemented, fmt.Sprintf("unknown callback variant: %T", cb)) - } - } - return nil -} - -func (wh *WorkflowHandler) validateCallbackURL(ns namespace.Name, rawURL string) error { - if len(rawURL) > wh.config.CallbackURLMaxLength(ns.String()) { - return status.Errorf(codes.InvalidArgument, "invalid url: url length longer than max length allowed of %d", wh.config.CallbackURLMaxLength(ns.String())) - } - rules := wh.config.CallbackEndpointConfigs(ns.String()) - return rules.Validate(rawURL) -} - type buildIdAndFlag interface { GetBuildId() string GetUseVersioning() bool diff --git a/service/frontend/workflow_handler_test.go b/service/frontend/workflow_handler_test.go index 3d64ea0a492..5f644ddb49b 100644 --- a/service/frontend/workflow_handler_test.go +++ b/service/frontend/workflow_handler_test.go @@ -37,6 +37,7 @@ import ( "go.temporal.io/server/api/matchingservicemock/v1" persistencespb "go.temporal.io/server/api/persistence/v1" taskqueuespb "go.temporal.io/server/api/taskqueue/v1" + "go.temporal.io/server/chasm/lib/callback" "go.temporal.io/server/common" "go.temporal.io/server/common/archiver" "go.temporal.io/server/common/archiver/provider" @@ -167,7 +168,20 @@ func (s *WorkflowHandlerSuite) getWorkflowHandler(config *Config) *WorkflowHandl s.mockVisibilityMgr.EXPECT().GetIndexName().Return(esIndexName).AnyTimes() healthInterceptor := interceptor.NewHealthInterceptor() healthInterceptor.SetHealthy(true) + cbValidator := callback.NewValidator( + func(string) int { return 2000 }, + config.CallbackURLMaxLength, + config.CallbackHeaderMaxSize, + func(string) callback.AddressMatchRules { + return callback.AddressMatchRules{ + Rules: []callback.AddressMatchRule{ + {Regexp: regexp.MustCompile(`.*`), AllowInsecure: true}, + }, + } + }, + ) return NewWorkflowHandler( + cbValidator, config, s.mockProducer, s.mockResource.GetVisibilityManager(), diff --git a/service/history/configs/config.go 
b/service/history/configs/config.go index 61ecc2d99bc..c7d911df239 100644 --- a/service/history/configs/config.go +++ b/service/history/configs/config.go @@ -1,6 +1,7 @@ package configs import ( + "go.temporal.io/server/chasm/lib/callback" "go.temporal.io/server/common" "go.temporal.io/server/common/dynamicconfig" "go.temporal.io/server/common/namespace" @@ -63,7 +64,7 @@ type Config struct { EnableUpdateWorkflowModeIgnoreCurrent dynamicconfig.BoolPropertyFn EnableTransitionHistory dynamicconfig.BoolPropertyFnWithNamespaceFilter MaxCallbacksPerWorkflow dynamicconfig.IntPropertyFnWithNamespaceFilter - MaxCHASMCallbacksPerWorkflow dynamicconfig.IntPropertyFnWithNamespaceFilter + MaxCallbacksPerExecution dynamicconfig.IntPropertyFnWithNamespaceFilter EnableChasm dynamicconfig.BoolPropertyFnWithNamespaceFilter EnableCHASMCallbacks dynamicconfig.BoolPropertyFnWithNamespaceFilter ChasmMaxInMemoryPureTasks dynamicconfig.IntPropertyFn @@ -483,7 +484,7 @@ func NewConfig( EnableUpdateWorkflowModeIgnoreCurrent: dynamicconfig.EnableUpdateWorkflowModeIgnoreCurrent.Get(dc), EnableTransitionHistory: dynamicconfig.EnableTransitionHistory.Get(dc), MaxCallbacksPerWorkflow: dynamicconfig.MaxCallbacksPerWorkflow.Get(dc), - MaxCHASMCallbacksPerWorkflow: dynamicconfig.MaxCHASMCallbacksPerWorkflow.Get(dc), + MaxCallbacksPerExecution: callback.MaxPerExecution.Get(dc), EnableChasm: dynamicconfig.EnableChasm.Get(dc), ChasmMaxInMemoryPureTasks: dynamicconfig.ChasmMaxInMemoryPureTasks.Get(dc), diff --git a/service/history/workflow/mutable_state_impl.go b/service/history/workflow/mutable_state_impl.go index 48125fb0161..f7dcee427a6 100644 --- a/service/history/workflow/mutable_state_impl.go +++ b/service/history/workflow/mutable_state_impl.go @@ -35,6 +35,7 @@ import ( tokenspb "go.temporal.io/server/api/token/v1" workflowspb "go.temporal.io/server/api/workflow/v1" "go.temporal.io/server/chasm" + "go.temporal.io/server/chasm/lib/callback" chasmworkflow 
"go.temporal.io/server/chasm/lib/workflow" "go.temporal.io/server/common" "go.temporal.io/server/common/backoff" @@ -3217,7 +3218,7 @@ func (ms *MutableStateImpl) addCompletionCallbacksChasm( return err } - maxCallbacksPerWorkflow := ms.config.MaxCHASMCallbacksPerWorkflow(ms.GetNamespaceEntry().Name().String()) + maxCallbacksPerWorkflow := ms.config.MaxCallbacksPerExecution(ms.GetNamespaceEntry().Name().String()) return wf.AddCompletionCallbacks(ctx, event.EventTime, requestID, completionCallbacks, maxCallbacksPerWorkflow) } @@ -6720,7 +6721,7 @@ func (ms *MutableStateImpl) processCloseCallbacksChasm() error { return err } - return wf.ProcessCloseCallbacks(ctx) + return callback.ScheduleStandbyCallbacks(ctx, wf.Callbacks) } func (ms *MutableStateImpl) AddTasks( diff --git a/tests/callbacks_test.go b/tests/callbacks_test.go index d53158e5965..789ee2321e4 100644 --- a/tests/callbacks_test.go +++ b/tests/callbacks_test.go @@ -21,6 +21,7 @@ import ( "go.temporal.io/sdk/client" "go.temporal.io/sdk/worker" "go.temporal.io/sdk/workflow" + "go.temporal.io/server/chasm/lib/callback" "go.temporal.io/server/common/dynamicconfig" "go.temporal.io/server/common/nexus/nexusrpc" "go.temporal.io/server/common/testing/protoassert" @@ -106,7 +107,7 @@ func (s *CallbacksSuite) TestWorkflowCallbacks_InvalidArgument() { { name: "too many callbacks", urls: []string{"http://url-1", "http://url-2", "http://url-3"}, - message: "cannot attach more than 2 callbacks to a workflow", + message: "cannot attach more than 2 callbacks to an execution", }, { name: "url not configured", @@ -123,6 +124,7 @@ func (s *CallbacksSuite) TestWorkflowCallbacks_InvalidArgument() { s.OverrideDynamicConfig(dynamicconfig.FrontendCallbackURLMaxLength, 50) s.OverrideDynamicConfig(dynamicconfig.FrontendCallbackHeaderMaxSize, 6) s.OverrideDynamicConfig(dynamicconfig.MaxCallbacksPerWorkflow, 2) + s.OverrideDynamicConfig(callback.MaxPerExecution, 2) s.OverrideDynamicConfig( callbacks.AllowedAddresses, 
[]any{map[string]any{"Pattern": "some-ignored-address", "AllowInsecure": true}, map[string]any{"Pattern": "some-secure-address", "AllowInsecure": false}}, diff --git a/tests/standalone_activity_test.go b/tests/standalone_activity_test.go index 95cb57393ba..9ef9b3c4948 100644 --- a/tests/standalone_activity_test.go +++ b/tests/standalone_activity_test.go @@ -4,10 +4,13 @@ import ( "context" "errors" "fmt" + "io" + "net/http/httptest" "testing" "time" "github.com/google/go-cmp/cmp" + "github.com/nexus-rpc/sdk-go/nexus" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" activitypb "go.temporal.io/api/activity/v1" @@ -19,14 +22,19 @@ import ( "go.temporal.io/api/serviceerror" taskqueuepb "go.temporal.io/api/taskqueue/v1" "go.temporal.io/api/workflowservice/v1" + "go.temporal.io/sdk/temporal" "go.temporal.io/server/chasm/lib/activity" + "go.temporal.io/server/chasm/lib/callback" "go.temporal.io/server/common/dynamicconfig" "go.temporal.io/server/common/log" + commonnexus "go.temporal.io/server/common/nexus" + "go.temporal.io/server/common/nexus/nexusrpc" "go.temporal.io/server/common/payload" "go.temporal.io/server/common/payloads" "go.temporal.io/server/common/tasktoken" "go.temporal.io/server/common/testing/protorequire" "go.temporal.io/server/common/testing/testvars" + "go.temporal.io/server/components/callbacks" "go.temporal.io/server/tests/testcore" "google.golang.org/grpc/codes" "google.golang.org/protobuf/testing/protocmp" @@ -217,14 +225,14 @@ func (s *standaloneActivityTestSuite) TestIDConflictPolicy() { }) t.Run("UseExisting", func(t *testing.T) { - activityID := testcore.RandomizeStr(t.Name()) + originalActivityID := testcore.RandomizeStr(t.Name()) taskQueue := testcore.RandomizeStr(t.Name()) - firstStartResp := s.startAndValidateActivity(ctx, t, activityID, taskQueue) + firstStartResp := s.startAndValidateActivity(ctx, t, originalActivityID, taskQueue) startWithUseExisting := func(requestID string) 
(*workflowservice.StartActivityExecutionResponse, error) { return s.FrontendClient().StartActivityExecution(ctx, &workflowservice.StartActivityExecutionRequest{ Namespace: s.Namespace().String(), - ActivityId: activityID, + ActivityId: originalActivityID, ActivityType: s.tv.ActivityType(), Identity: s.tv.WorkerIdentity(), Input: defaultInput, @@ -242,6 +250,13 @@ func (s *standaloneActivityTestSuite) TestIDConflictPolicy() { require.NoError(t, err) require.Equal(t, firstStartResp.RunId, resp.RunId) require.False(t, resp.GetStarted()) + + // Link should point to the existing activity run. + link := resp.GetLink().GetActivity() + require.NotNil(t, link) + require.Equal(t, s.Namespace().String(), link.Namespace) + require.Equal(t, originalActivityID, link.ActivityId) + require.Equal(t, firstStartResp.RunId, link.RunId) }) t.Run("SecondStartWithSameRequestIdReturnsExistingRun", func(t *testing.T) { resp, err := startWithUseExisting(s.tv.RequestID()) @@ -250,8 +265,132 @@ func (s *standaloneActivityTestSuite) TestIDConflictPolicy() { require.False(t, resp.GetStarted()) }) + t.Run("OnConflictOptions", func(t *testing.T) { + s.OverrideDynamicConfig( + callbacks.AllowedAddresses, + []any{map[string]any{"Pattern": "*", "AllowInsecure": true}}, + ) + + onConflictOpts := &commonpb.OnConflictOptions{ + AttachRequestId: true, + AttachCompletionCallbacks: true, + AttachLinks: true, + } + + t.Run("AttachesToNewActivity", func(t *testing.T) { + newActivityID := testcore.RandomizeStr(t.Name()) + newTaskQueue := testcore.RandomizeStr(t.Name()) + + resp, err := s.FrontendClient().StartActivityExecution(ctx, &workflowservice.StartActivityExecutionRequest{ + Namespace: s.Namespace().String(), + ActivityId: newActivityID, + ActivityType: s.tv.ActivityType(), + Identity: s.tv.WorkerIdentity(), + Input: defaultInput, + TaskQueue: &taskqueuepb.TaskQueue{ + Name: newTaskQueue, + }, + StartToCloseTimeout: durationpb.New(1 * time.Minute), + IdConflictPolicy: 
enumspb.ACTIVITY_ID_CONFLICT_POLICY_USE_EXISTING, + RequestId: s.tv.Any().String(), + CompletionCallbacks: []*commonpb.Callback{ + {Variant: &commonpb.Callback_Nexus_{Nexus: &commonpb.Callback_Nexus{Url: "http://localhost/new-activity-cb"}}}, + }, + OnConflictOptions: onConflictOpts, + }) + require.NoError(t, err) + require.True(t, resp.GetStarted()) + + descResp, err := s.FrontendClient().DescribeActivityExecution(ctx, &workflowservice.DescribeActivityExecutionRequest{ + Namespace: s.Namespace().String(), + ActivityId: newActivityID, + RunId: resp.RunId, + }) + require.NoError(t, err) + require.Len(t, descResp.Callbacks, 1) + require.Equal(t, "http://localhost/new-activity-cb", descResp.Callbacks[0].GetInfo().GetCallback().GetNexus().GetUrl()) + }) + + t.Run("AttachesToExistingActivity", func(t *testing.T) { + resp, err := s.FrontendClient().StartActivityExecution(ctx, &workflowservice.StartActivityExecutionRequest{ + Namespace: s.Namespace().String(), + ActivityId: originalActivityID, + ActivityType: s.tv.ActivityType(), + Identity: s.tv.WorkerIdentity(), + Input: defaultInput, + TaskQueue: &taskqueuepb.TaskQueue{ + Name: taskQueue, + }, + StartToCloseTimeout: durationpb.New(1 * time.Minute), + IdConflictPolicy: enumspb.ACTIVITY_ID_CONFLICT_POLICY_USE_EXISTING, + RequestId: s.tv.Any().String(), + CompletionCallbacks: []*commonpb.Callback{ + {Variant: &commonpb.Callback_Nexus_{Nexus: &commonpb.Callback_Nexus{Url: "http://localhost/existing-activity-cb"}}}, + }, + OnConflictOptions: onConflictOpts, + }) + require.NoError(t, err) + require.False(t, resp.GetStarted()) + require.Equal(t, firstStartResp.RunId, resp.RunId) + + descResp, err := s.FrontendClient().DescribeActivityExecution(ctx, &workflowservice.DescribeActivityExecutionRequest{ + Namespace: s.Namespace().String(), + ActivityId: originalActivityID, + RunId: firstStartResp.RunId, + }) + require.NoError(t, err) + require.Len(t, descResp.Callbacks, 1) + require.Equal(t, 
"http://localhost/existing-activity-cb", descResp.Callbacks[0].GetInfo().GetCallback().GetNexus().GetUrl()) + }) + + t.Run("IdempotentWithSameRequestId", func(t *testing.T) { + idempotentActivityID := testcore.RandomizeStr(t.Name()) + idempotentTaskQueue := testcore.RandomizeStr(t.Name()) + idempotentStartResp := s.startAndValidateActivity(ctx, t, idempotentActivityID, idempotentTaskQueue) + + requestID := s.tv.Any().String() + startReq := &workflowservice.StartActivityExecutionRequest{ + Namespace: s.Namespace().String(), + ActivityId: idempotentActivityID, + ActivityType: s.tv.ActivityType(), + Identity: s.tv.WorkerIdentity(), + Input: defaultInput, + TaskQueue: &taskqueuepb.TaskQueue{ + Name: idempotentTaskQueue, + }, + StartToCloseTimeout: durationpb.New(1 * time.Minute), + IdConflictPolicy: enumspb.ACTIVITY_ID_CONFLICT_POLICY_USE_EXISTING, + RequestId: requestID, + CompletionCallbacks: []*commonpb.Callback{ + {Variant: &commonpb.Callback_Nexus_{Nexus: &commonpb.Callback_Nexus{Url: "http://localhost/idempotent-cb"}}}, + }, + OnConflictOptions: onConflictOpts, + } + + // First call attaches the callback. + resp1, err := s.FrontendClient().StartActivityExecution(ctx, startReq) + require.NoError(t, err) + require.False(t, resp1.GetStarted()) + + // Second call with the same request ID should not duplicate the callback. + resp2, err := s.FrontendClient().StartActivityExecution(ctx, startReq) + require.NoError(t, err) + require.False(t, resp2.GetStarted()) + + descResp, err := s.FrontendClient().DescribeActivityExecution(ctx, &workflowservice.DescribeActivityExecutionRequest{ + Namespace: s.Namespace().String(), + ActivityId: idempotentActivityID, + RunId: idempotentStartResp.RunId, + }) + require.NoError(t, err) + // Only 1 callback: the second call with the same request ID should not add another. 
+ require.Len(t, descResp.Callbacks, 1) + require.Equal(t, "http://localhost/idempotent-cb", descResp.Callbacks[0].GetInfo().GetCallback().GetNexus().GetUrl()) + }) + }) + t.Run("DoesNotApplyToCompletedActivity", func(t *testing.T) { - pollTaskResp := s.pollActivityTaskAndValidate(ctx, t, activityID, taskQueue, firstStartResp.RunId) + pollTaskResp := s.pollActivityTaskAndValidate(ctx, t, originalActivityID, taskQueue, firstStartResp.RunId) _, err := s.FrontendClient().RespondActivityTaskCompleted(ctx, &workflowservice.RespondActivityTaskCompletedRequest{ Namespace: s.Namespace().String(), TaskToken: pollTaskResp.TaskToken, @@ -521,6 +660,35 @@ func (s *standaloneActivityTestSuite) TestStart() { require.ErrorAs(t, err, &invalidArgErr) }) }) + + t.Run("ResponseFields", func(t *testing.T) { + activityID := testcore.RandomizeStr(t.Name()) + taskQueue := testcore.RandomizeStr(t.Name()) + + resp, err := s.FrontendClient().StartActivityExecution(ctx, &workflowservice.StartActivityExecutionRequest{ + Namespace: s.Namespace().String(), + ActivityId: activityID, + ActivityType: s.tv.ActivityType(), + Identity: s.tv.WorkerIdentity(), + Input: defaultInput, + TaskQueue: &taskqueuepb.TaskQueue{ + Name: taskQueue, + }, + StartToCloseTimeout: durationpb.New(defaultStartToCloseTimeout), + RequestId: s.tv.Any().String(), + }) + require.NoError(t, err) + + require.True(t, resp.Started) + require.NotEmpty(t, resp.RunId) + + // Verify link points to the started activity. 
+ link := resp.GetLink().GetActivity() + require.NotNil(t, link) + require.Equal(t, s.Namespace().String(), link.Namespace) + require.Equal(t, activityID, link.ActivityId) + require.Equal(t, resp.RunId, link.RunId) + }) } func (s *standaloneActivityTestSuite) TestComplete() { @@ -4950,3 +5118,507 @@ func (s *standaloneActivityTestSuite) startActivityWithType(ctx context.Context, RequestId: s.tv.RequestID(), }) } + +func (s *standaloneActivityTestSuite) runNexusCompletionHTTPServer(t *testing.T, h *completionHandler) string { + hh := nexusrpc.NewCompletionHTTPHandler(nexusrpc.CompletionHandlerOptions{Handler: h}) + srv := httptest.NewServer(hh) + t.Cleanup(func() { + srv.Close() + }) + return srv.URL +} + +func (s *standaloneActivityTestSuite) TestCallbacks() { + t := s.T() + ctx, cancel := context.WithTimeout(t.Context(), 15*time.Second) + defer cancel() + + s.OverrideDynamicConfig( + callbacks.AllowedAddresses, + []any{map[string]any{"Pattern": "*", "AllowInsecure": true}}, + ) + + t.Run("AcceptedOnStart", func(t *testing.T) { + activityID := testcore.RandomizeStr(t.Name()) + taskQueue := testcore.RandomizeStr(t.Name()) + + resp, err := s.FrontendClient().StartActivityExecution(ctx, &workflowservice.StartActivityExecutionRequest{ + Namespace: s.Namespace().String(), + ActivityId: activityID, + ActivityType: s.tv.ActivityType(), + Identity: s.tv.WorkerIdentity(), + Input: defaultInput, + TaskQueue: &taskqueuepb.TaskQueue{ + Name: taskQueue, + }, + StartToCloseTimeout: durationpb.New(defaultStartToCloseTimeout), + RequestId: s.tv.Any().String(), + CompletionCallbacks: []*commonpb.Callback{{ + Variant: &commonpb.Callback_Nexus_{ + Nexus: &commonpb.Callback_Nexus{ + Url: "http://localhost/callback", + }, + }, + }}, + }) + require.NoError(t, err) + require.True(t, resp.Started) + require.NotEmpty(t, resp.RunId) + }) + + t.Run("MultipleCallbacksAccepted", func(t *testing.T) { + activityID := testcore.RandomizeStr(t.Name()) + taskQueue := testcore.RandomizeStr(t.Name()) 
+ + resp, err := s.FrontendClient().StartActivityExecution(ctx, &workflowservice.StartActivityExecutionRequest{ + Namespace: s.Namespace().String(), + ActivityId: activityID, + ActivityType: s.tv.ActivityType(), + Identity: s.tv.WorkerIdentity(), + Input: defaultInput, + TaskQueue: &taskqueuepb.TaskQueue{ + Name: taskQueue, + }, + StartToCloseTimeout: durationpb.New(defaultStartToCloseTimeout), + RequestId: s.tv.Any().String(), + CompletionCallbacks: []*commonpb.Callback{ + {Variant: &commonpb.Callback_Nexus_{Nexus: &commonpb.Callback_Nexus{Url: "http://localhost/callback1"}}}, + {Variant: &commonpb.Callback_Nexus_{Nexus: &commonpb.Callback_Nexus{Url: "http://localhost/callback2"}}}, + }, + }) + require.NoError(t, err) + require.True(t, resp.Started) + require.NotEmpty(t, resp.RunId) + }) + + t.Run("DescribeIncludesCallbackInfo", func(t *testing.T) { + activityID := testcore.RandomizeStr(t.Name()) + taskQueue := testcore.RandomizeStr(t.Name()) + + callbackURL := "http://localhost/describe-callback" + startResp, err := s.FrontendClient().StartActivityExecution(ctx, &workflowservice.StartActivityExecutionRequest{ + Namespace: s.Namespace().String(), + ActivityId: activityID, + ActivityType: s.tv.ActivityType(), + Identity: s.tv.WorkerIdentity(), + Input: defaultInput, + TaskQueue: &taskqueuepb.TaskQueue{ + Name: taskQueue, + }, + StartToCloseTimeout: durationpb.New(defaultStartToCloseTimeout), + RequestId: s.tv.Any().String(), + CompletionCallbacks: []*commonpb.Callback{ + {Variant: &commonpb.Callback_Nexus_{Nexus: &commonpb.Callback_Nexus{Url: callbackURL}}}, + }, + }) + require.NoError(t, err) + + describeResp, err := s.FrontendClient().DescribeActivityExecution(ctx, &workflowservice.DescribeActivityExecutionRequest{ + Namespace: s.Namespace().String(), + ActivityId: activityID, + RunId: startResp.RunId, + }) + require.NoError(t, err) + + require.Len(t, describeResp.Callbacks, 1) + cbInfo := describeResp.Callbacks[0] + require.NotNil(t, 
cbInfo.GetTrigger().GetActivityClosed()) + require.Equal(t, callbackURL, cbInfo.GetInfo().GetCallback().GetNexus().GetUrl()) + require.Equal(t, enumspb.CALLBACK_STATE_STANDBY, cbInfo.GetInfo().GetState()) + require.NotNil(t, cbInfo.GetInfo().GetRegistrationTime()) + }) + + t.Run("ExceedsMaxCallbacksLimit", func(t *testing.T) { + maxCallbacks := 1 + s.OverrideDynamicConfig( + callback.MaxPerExecution, + maxCallbacks, + ) + + activityID := testcore.RandomizeStr(t.Name()) + taskQueue := testcore.RandomizeStr(t.Name()) + + _, err := s.FrontendClient().StartActivityExecution(ctx, &workflowservice.StartActivityExecutionRequest{ + Namespace: s.Namespace().String(), + ActivityId: activityID, + ActivityType: s.tv.ActivityType(), + Identity: s.tv.WorkerIdentity(), + Input: defaultInput, + TaskQueue: &taskqueuepb.TaskQueue{ + Name: taskQueue, + }, + StartToCloseTimeout: durationpb.New(defaultStartToCloseTimeout), + RequestId: s.tv.Any().String(), + // Two callbacks when overridden max dynamic config is 1, so should error. 
+ CompletionCallbacks: []*commonpb.Callback{ + {Variant: &commonpb.Callback_Nexus_{Nexus: &commonpb.Callback_Nexus{Url: "http://localhost/callback1"}}}, + {Variant: &commonpb.Callback_Nexus_{Nexus: &commonpb.Callback_Nexus{Url: "http://localhost/callback2"}}}, + }, + }) + require.Error(t, err) + require.ErrorContains(t, err, fmt.Sprintf("cannot attach more than %d callbacks", maxCallbacks)) + }) + + t.Run("CompletesWithCallbacks", func(t *testing.T) { + activityID := testcore.RandomizeStr(t.Name()) + taskQueue := testcore.RandomizeStr(t.Name()) + + ch := &completionHandler{ + requestCh: make(chan *nexusrpc.CompletionRequest, 1), + requestCompleteCh: make(chan error, 1), + } + defer func() { + close(ch.requestCh) + close(ch.requestCompleteCh) + }() + callbackAddress := s.runNexusCompletionHTTPServer(t, ch) + + _, err := s.FrontendClient().StartActivityExecution(ctx, &workflowservice.StartActivityExecutionRequest{ + Namespace: s.Namespace().String(), + ActivityId: activityID, + ActivityType: s.tv.ActivityType(), + Identity: s.tv.WorkerIdentity(), + Input: defaultInput, + TaskQueue: &taskqueuepb.TaskQueue{ + Name: taskQueue, + }, + StartToCloseTimeout: durationpb.New(defaultStartToCloseTimeout), + RequestId: s.tv.Any().String(), + CompletionCallbacks: []*commonpb.Callback{{ + Variant: &commonpb.Callback_Nexus_{Nexus: &commonpb.Callback_Nexus{Url: callbackAddress}}, + }}, + }) + require.NoError(t, err) + + pollResp, err := s.FrontendClient().PollActivityTaskQueue(ctx, &workflowservice.PollActivityTaskQueueRequest{ + Namespace: s.Namespace().String(), + TaskQueue: &taskqueuepb.TaskQueue{Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL}, + Identity: s.tv.WorkerIdentity(), + }) + require.NoError(t, err) + + _, err = s.FrontendClient().RespondActivityTaskCompleted(ctx, &workflowservice.RespondActivityTaskCompletedRequest{ + Namespace: s.Namespace().String(), + TaskToken: pollResp.TaskToken, + Result: defaultResult, + Identity: defaultIdentity, + }) + 
require.NoError(t, err) + + // Verify the callback was actually delivered with the correct result. + select { + case completion := <-ch.requestCh: + require.Equal(t, nexus.OperationStateSucceeded, completion.State) + require.False(t, completion.StartTime.IsZero()) + require.False(t, completion.CloseTime.IsZero()) + body, readErr := io.ReadAll(completion.HTTPRequest.Body) + _ = completion.HTTPRequest.Body.Close() + require.NoError(t, readErr) + require.JSONEq(t, string(defaultResult.Payloads[0].Data), string(body)) + // Unblock CompleteOperation so it returns 200 OK to the callback library + ch.requestCompleteCh <- nil + case <-ctx.Done(): + require.Fail(t, "timed out waiting for completion callback") + } + + // Verify the activity is in completed state. + descResp, err := s.FrontendClient().DescribeActivityExecution(ctx, &workflowservice.DescribeActivityExecutionRequest{ + Namespace: s.Namespace().String(), + ActivityId: activityID, + }) + require.NoError(t, err) + require.Equal(t, enumspb.ACTIVITY_EXECUTION_STATUS_COMPLETED, descResp.GetInfo().GetStatus()) + }) + + t.Run("FailsWithCallbacks", func(t *testing.T) { + activityID := testcore.RandomizeStr(t.Name()) + taskQueue := testcore.RandomizeStr(t.Name()) + + ch := &completionHandler{ + requestCh: make(chan *nexusrpc.CompletionRequest, 1), + requestCompleteCh: make(chan error, 1), + } + defer func() { + close(ch.requestCh) + close(ch.requestCompleteCh) + }() + callbackAddress := s.runNexusCompletionHTTPServer(t, ch) + + _, err := s.FrontendClient().StartActivityExecution(ctx, &workflowservice.StartActivityExecutionRequest{ + Namespace: s.Namespace().String(), + ActivityId: activityID, + ActivityType: s.tv.ActivityType(), + Identity: s.tv.WorkerIdentity(), + Input: defaultInput, + TaskQueue: &taskqueuepb.TaskQueue{ + Name: taskQueue, + }, + StartToCloseTimeout: durationpb.New(defaultStartToCloseTimeout), + RequestId: s.tv.Any().String(), + CompletionCallbacks: []*commonpb.Callback{{ + Variant: 
&commonpb.Callback_Nexus_{Nexus: &commonpb.Callback_Nexus{Url: callbackAddress}}, + }}, + }) + require.NoError(t, err) + + pollResp, err := s.FrontendClient().PollActivityTaskQueue(ctx, &workflowservice.PollActivityTaskQueueRequest{ + Namespace: s.Namespace().String(), + TaskQueue: &taskqueuepb.TaskQueue{Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL}, + Identity: s.tv.WorkerIdentity(), + }) + require.NoError(t, err) + + _, err = s.FrontendClient().RespondActivityTaskFailed(ctx, &workflowservice.RespondActivityTaskFailedRequest{ + Namespace: s.Namespace().String(), + TaskToken: pollResp.TaskToken, + Failure: defaultFailure, + Identity: defaultIdentity, + }) + require.NoError(t, err) + + // Verify the callback was actually delivered with failure state. + select { + case completion := <-ch.requestCh: + require.Equal(t, nexus.OperationStateFailed, completion.State) + require.False(t, completion.StartTime.IsZero()) + require.False(t, completion.CloseTime.IsZero()) + var failureErr *nexus.FailureError + require.ErrorAs(t, completion.Error.Cause, &failureErr) + tFailure, convErr := commonnexus.NexusFailureToTemporalFailure(failureErr.Failure) + require.NoError(t, convErr) + sdkErr := temporal.GetDefaultFailureConverter().FailureToError(tFailure) + var appErr *temporal.ApplicationError + require.ErrorAs(t, sdkErr, &appErr) + require.Equal(t, defaultFailure.Message, appErr.Message()) + ch.requestCompleteCh <- nil + case <-ctx.Done(): + require.Fail(t, "timed out waiting for completion callback") + } + + // Verify the activity is in failed state. 
+ descResp, err := s.FrontendClient().DescribeActivityExecution(ctx, &workflowservice.DescribeActivityExecutionRequest{ + Namespace: s.Namespace().String(), + ActivityId: activityID, + }) + require.NoError(t, err) + require.Equal(t, enumspb.ACTIVITY_EXECUTION_STATUS_FAILED, descResp.GetInfo().GetStatus()) + }) + + t.Run("TerminatedWithCallbacks", func(t *testing.T) { + activityID := testcore.RandomizeStr(t.Name()) + taskQueue := testcore.RandomizeStr(t.Name()) + + ch := &completionHandler{ + requestCh: make(chan *nexusrpc.CompletionRequest, 1), + requestCompleteCh: make(chan error, 1), + } + defer func() { + close(ch.requestCh) + close(ch.requestCompleteCh) + }() + callbackAddress := s.runNexusCompletionHTTPServer(t, ch) + + startResp, err := s.FrontendClient().StartActivityExecution(ctx, &workflowservice.StartActivityExecutionRequest{ + Namespace: s.Namespace().String(), + ActivityId: activityID, + ActivityType: s.tv.ActivityType(), + Identity: s.tv.WorkerIdentity(), + Input: defaultInput, + TaskQueue: &taskqueuepb.TaskQueue{ + Name: taskQueue, + }, + StartToCloseTimeout: durationpb.New(defaultStartToCloseTimeout), + RequestId: s.tv.Any().String(), + CompletionCallbacks: []*commonpb.Callback{{ + Variant: &commonpb.Callback_Nexus_{Nexus: &commonpb.Callback_Nexus{Url: callbackAddress}}, + }}, + }) + require.NoError(t, err) + runID := startResp.RunId + + _, err = s.FrontendClient().PollActivityTaskQueue(ctx, &workflowservice.PollActivityTaskQueueRequest{ + Namespace: s.Namespace().String(), + TaskQueue: &taskqueuepb.TaskQueue{Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL}, + Identity: s.tv.WorkerIdentity(), + }) + require.NoError(t, err) + + reason := "Test Termination" + _, err = s.FrontendClient().TerminateActivityExecution(ctx, &workflowservice.TerminateActivityExecutionRequest{ + Namespace: s.Namespace().String(), + ActivityId: activityID, + RunId: runID, + Reason: reason, + Identity: "terminator", + }) + require.NoError(t, err) + + // Verify the callback 
was delivered with failure state (terminated maps to failed). + select { + case completion := <-ch.requestCh: + require.Equal(t, nexus.OperationStateFailed, completion.State) + require.False(t, completion.StartTime.IsZero()) + require.False(t, completion.CloseTime.IsZero()) + var failureErr *nexus.FailureError + require.ErrorAs(t, completion.Error.Cause, &failureErr) + tFailure, convErr := commonnexus.NexusFailureToTemporalFailure(failureErr.Failure) + require.NoError(t, convErr) + sdkErr := temporal.GetDefaultFailureConverter().FailureToError(tFailure) + var termErr *temporal.TerminatedError + require.ErrorAs(t, sdkErr, &termErr) + ch.requestCompleteCh <- nil + case <-ctx.Done(): + require.Fail(t, "timed out waiting for completion callback") + } + + descResp, err := s.FrontendClient().DescribeActivityExecution(ctx, &workflowservice.DescribeActivityExecutionRequest{ + Namespace: s.Namespace().String(), + ActivityId: activityID, + }) + require.NoError(t, err) + require.Equal(t, enumspb.ACTIVITY_EXECUTION_STATUS_TERMINATED, descResp.GetInfo().GetStatus()) + }) + + t.Run("CanceledWithCallbacks", func(t *testing.T) { + activityID := testcore.RandomizeStr(t.Name()) + taskQueue := testcore.RandomizeStr(t.Name()) + + ch := &completionHandler{ + requestCh: make(chan *nexusrpc.CompletionRequest, 1), + requestCompleteCh: make(chan error, 1), + } + defer func() { + close(ch.requestCh) + close(ch.requestCompleteCh) + }() + callbackAddress := s.runNexusCompletionHTTPServer(t, ch) + + _, err := s.FrontendClient().StartActivityExecution(ctx, &workflowservice.StartActivityExecutionRequest{ + Namespace: s.Namespace().String(), + ActivityId: activityID, + ActivityType: s.tv.ActivityType(), + Identity: s.tv.WorkerIdentity(), + Input: defaultInput, + TaskQueue: &taskqueuepb.TaskQueue{ + Name: taskQueue, + }, + StartToCloseTimeout: durationpb.New(defaultStartToCloseTimeout), + RequestId: s.tv.Any().String(), + CompletionCallbacks: []*commonpb.Callback{{ + Variant: 
&commonpb.Callback_Nexus_{Nexus: &commonpb.Callback_Nexus{Url: callbackAddress}}, + }}, + }) + require.NoError(t, err) + + pollResp, err := s.FrontendClient().PollActivityTaskQueue(ctx, &workflowservice.PollActivityTaskQueueRequest{ + Namespace: s.Namespace().String(), + TaskQueue: &taskqueuepb.TaskQueue{Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL}, + Identity: s.tv.WorkerIdentity(), + }) + require.NoError(t, err) + + _, err = s.FrontendClient().RequestCancelActivityExecution(ctx, &workflowservice.RequestCancelActivityExecutionRequest{ + Namespace: s.Namespace().String(), + ActivityId: activityID, + Identity: "cancelling-worker", + RequestId: s.tv.Any().String(), + Reason: "Test Cancellation", + }) + require.NoError(t, err) + + _, err = s.FrontendClient().RespondActivityTaskCanceled(ctx, &workflowservice.RespondActivityTaskCanceledRequest{ + Namespace: s.Namespace().String(), + TaskToken: pollResp.TaskToken, + Identity: defaultIdentity, + }) + require.NoError(t, err) + + // Verify the callback was delivered with canceled state. 
+ select { + case completion := <-ch.requestCh: + require.Equal(t, nexus.OperationStateCanceled, completion.State) + require.False(t, completion.StartTime.IsZero()) + require.False(t, completion.CloseTime.IsZero()) + var failureErr *nexus.FailureError + require.ErrorAs(t, completion.Error.Cause, &failureErr) + tFailure, convErr := commonnexus.NexusFailureToTemporalFailure(failureErr.Failure) + require.NoError(t, convErr) + sdkErr := temporal.GetDefaultFailureConverter().FailureToError(tFailure) + var canceledErr *temporal.CanceledError + require.ErrorAs(t, sdkErr, &canceledErr) + ch.requestCompleteCh <- nil + case <-ctx.Done(): + require.Fail(t, "timed out waiting for completion callback") + } + + descResp, err := s.FrontendClient().DescribeActivityExecution(ctx, &workflowservice.DescribeActivityExecutionRequest{ + Namespace: s.Namespace().String(), + ActivityId: activityID, + }) + require.NoError(t, err) + require.Equal(t, enumspb.ACTIVITY_EXECUTION_STATUS_CANCELED, descResp.GetInfo().GetStatus()) + }) + + // This test covers the timeout callback path using schedule-to-start, but the callback behavior + // is the same for all timeout types (schedule-to-close, start-to-close, heartbeat). 
+ t.Run("TimeoutWithCallbacks", func(t *testing.T) { + activityID := testcore.RandomizeStr(t.Name()) + taskQueue := testcore.RandomizeStr(t.Name()) + + ch := &completionHandler{ + requestCh: make(chan *nexusrpc.CompletionRequest, 1), + requestCompleteCh: make(chan error, 1), + } + defer func() { + close(ch.requestCh) + close(ch.requestCompleteCh) + }() + callbackAddress := s.runNexusCompletionHTTPServer(t, ch) + + _, err := s.FrontendClient().StartActivityExecution(ctx, &workflowservice.StartActivityExecutionRequest{ + Namespace: s.Namespace().String(), + ActivityId: activityID, + ActivityType: s.tv.ActivityType(), + Identity: s.tv.WorkerIdentity(), + Input: defaultInput, + TaskQueue: &taskqueuepb.TaskQueue{ + Name: taskQueue, + }, + StartToCloseTimeout: durationpb.New(1 * time.Minute), + ScheduleToStartTimeout: durationpb.New(1 * time.Second), + RequestId: s.tv.Any().String(), + CompletionCallbacks: []*commonpb.Callback{{ + Variant: &commonpb.Callback_Nexus_{Nexus: &commonpb.Callback_Nexus{Url: callbackAddress}}, + }}, + }) + require.NoError(t, err) + + // No worker polls — activity will time out waiting to be started. + + // Verify the callback is delivered with failure state and non-zero CloseTime. + select { + case completion := <-ch.requestCh: + require.Equal(t, nexus.OperationStateFailed, completion.State) + var failureErr *nexus.FailureError + require.ErrorAs(t, completion.Error.Cause, &failureErr) + tFailure, convErr := commonnexus.NexusFailureToTemporalFailure(failureErr.Failure) + require.NoError(t, convErr) + sdkErr := temporal.GetDefaultFailureConverter().FailureToError(tFailure) + var timeoutErr *temporal.TimeoutError + require.ErrorAs(t, sdkErr, &timeoutErr) + require.False(t, completion.StartTime.IsZero()) + require.False(t, completion.CloseTime.IsZero()) + ch.requestCompleteCh <- nil + case <-ctx.Done(): + require.Fail(t, "timed out waiting for completion callback") + } + + // Verify the activity is in timed-out state. 
+ descResp, err := s.FrontendClient().DescribeActivityExecution(ctx, &workflowservice.DescribeActivityExecutionRequest{ + Namespace: s.Namespace().String(), + ActivityId: activityID, + }) + require.NoError(t, err) + require.Equal(t, enumspb.ACTIVITY_EXECUTION_STATUS_TIMED_OUT, descResp.GetInfo().GetStatus()) + }) +}